claude_code_acp/session/
usage.rs

1//! Token usage tracking for sessions
2//!
3//! Tracks cumulative token usage across a session's lifetime.
4
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use crate::types::TokenUsage;
8
/// Running token totals for a single session.
///
/// Every counter is an [`AtomicU64`], which lets the tracker be updated
/// through a shared reference without an external lock. Each counter is
/// independent of the others: there is no mechanism here that groups the four
/// values into one atomic unit.
#[derive(Default, Debug)]
pub struct UsageTracker {
    /// Sum of input tokens across all recorded requests.
    input_tokens: AtomicU64,
    /// Sum of output tokens across all recorded requests.
    output_tokens: AtomicU64,
    /// Sum of input tokens that were served from the cache.
    cache_read_input_tokens: AtomicU64,
    /// Sum of input tokens that were spent creating cache entries.
    cache_creation_input_tokens: AtomicU64,
}
23
24impl UsageTracker {
25    /// Create a new usage tracker
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Add usage from a completed request
31    pub fn add(&self, usage: &TokenUsage) {
32        self.input_tokens
33            .fetch_add(usage.input_tokens, Ordering::Relaxed);
34        self.output_tokens
35            .fetch_add(usage.output_tokens, Ordering::Relaxed);
36        if let Some(v) = usage.cache_read_input_tokens {
37            self.cache_read_input_tokens.fetch_add(v, Ordering::Relaxed);
38        }
39        if let Some(v) = usage.cache_creation_input_tokens {
40            self.cache_creation_input_tokens
41                .fetch_add(v, Ordering::Relaxed);
42        }
43    }
44
45    /// Get current cumulative usage
46    pub fn get(&self) -> TokenUsage {
47        TokenUsage {
48            input_tokens: self.input_tokens.load(Ordering::Relaxed),
49            output_tokens: self.output_tokens.load(Ordering::Relaxed),
50            cache_read_input_tokens: Some(self.cache_read_input_tokens.load(Ordering::Relaxed)),
51            cache_creation_input_tokens: Some(
52                self.cache_creation_input_tokens.load(Ordering::Relaxed),
53            ),
54        }
55    }
56
57    /// Reset usage counters
58    pub fn reset(&self) {
59        self.input_tokens.store(0, Ordering::Relaxed);
60        self.output_tokens.store(0, Ordering::Relaxed);
61        self.cache_read_input_tokens.store(0, Ordering::Relaxed);
62        self.cache_creation_input_tokens.store(0, Ordering::Relaxed);
63    }
64
65    /// Get input tokens
66    pub fn input_tokens(&self) -> u64 {
67        self.input_tokens.load(Ordering::Relaxed)
68    }
69
70    /// Get output tokens
71    pub fn output_tokens(&self) -> u64 {
72        self.output_tokens.load(Ordering::Relaxed)
73    }
74
75    /// Get total tokens (input + output)
76    pub fn total_tokens(&self) -> u64 {
77        self.input_tokens() + self.output_tokens()
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TokenUsage` value for a single simulated request.
    fn sample(
        input: u64,
        output: u64,
        cache_read: Option<u64>,
        cache_creation: Option<u64>,
    ) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            cache_read_input_tokens: cache_read,
            cache_creation_input_tokens: cache_creation,
        }
    }

    /// A fresh tracker reports zero everywhere, with cache fields as `Some(0)`.
    #[test]
    fn test_usage_tracker_new() {
        let tracker = UsageTracker::new();
        let usage = tracker.get();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.cache_read_input_tokens, Some(0));
        assert_eq!(usage.cache_creation_input_tokens, Some(0));
        assert_eq!(tracker.total_tokens(), 0);
    }

    /// A single `add` is reflected in every counter, including the cache ones.
    #[test]
    fn test_usage_tracker_add() {
        let tracker = UsageTracker::new();
        tracker.add(&sample(100, 50, Some(10), Some(5)));

        let total = tracker.get();
        assert_eq!(total.input_tokens, 100);
        assert_eq!(total.output_tokens, 50);
        assert_eq!(total.cache_read_input_tokens, Some(10));
        assert_eq!(total.cache_creation_input_tokens, Some(5));
        assert_eq!(tracker.input_tokens(), 100);
        assert_eq!(tracker.output_tokens(), 50);
    }

    /// Repeated `add` calls sum up; `None` cache fields leave the cache
    /// counters at zero.
    #[test]
    fn test_usage_tracker_cumulative() {
        let tracker = UsageTracker::new();

        tracker.add(&sample(100, 50, None, None));
        tracker.add(&sample(200, 100, None, None));

        let total = tracker.get();
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 150);
        assert_eq!(total.cache_read_input_tokens, Some(0));
        assert_eq!(total.cache_creation_input_tokens, Some(0));
        assert_eq!(tracker.total_tokens(), 450);
    }

    /// Cache counters accumulate across requests and skip `None` entries.
    #[test]
    fn test_usage_tracker_cache_cumulative() {
        let tracker = UsageTracker::new();

        tracker.add(&sample(1, 1, Some(10), None));
        tracker.add(&sample(1, 1, None, Some(5)));
        tracker.add(&sample(1, 1, Some(7), Some(3)));

        let total = tracker.get();
        assert_eq!(total.cache_read_input_tokens, Some(17));
        assert_eq!(total.cache_creation_input_tokens, Some(8));
        // Cache tokens are not part of the input + output total.
        assert_eq!(tracker.total_tokens(), 6);
    }

    /// `reset` clears all four counters, not only input and output.
    #[test]
    fn test_usage_tracker_reset() {
        let tracker = UsageTracker::new();
        tracker.add(&sample(100, 50, Some(10), Some(5)));

        tracker.reset();
        let total = tracker.get();
        assert_eq!(total.input_tokens, 0);
        assert_eq!(total.output_tokens, 0);
        assert_eq!(total.cache_read_input_tokens, Some(0));
        assert_eq!(total.cache_creation_input_tokens, Some(0));
        assert_eq!(tracker.total_tokens(), 0);
    }
}
148}