// aether_core/context/token_tracker.rs
1use llm::TokenUsage;
2
/// Default threshold for triggering context compaction (85%),
/// expressed as a fraction of the model's context window (0.0 - 1.0).
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.85;
5
/// Tracks token usage from LLM API responses.
/// Uses real usage data from API, not estimation.
///
/// Cumulative totals are stored for the dimensions consumers care about today
/// (input/output, cache read/creation, reasoning). The `last_usage` field
/// preserves the full `TokenUsage` from the most recent API call, so audio /
/// video / prediction dimensions are still accessible without growing
/// dedicated accumulators until a consumer asks for them.
#[derive(Debug, Clone, Default)]
pub struct TokenTracker {
    // Lifetime totals accumulated across every recorded call; u64 so long
    // sessions cannot overflow the per-call u32 counters.
    total_input_tokens: u64,
    total_output_tokens: u64,
    total_cache_read_tokens: u64,
    total_cache_creation_tokens: u64,
    total_reasoning_tokens: u64,
    // Full usage struct from the most recent API call; reset to default by
    // `reset_current_usage` after compaction.
    last_usage: TokenUsage,
    // Model context window size in tokens; `None` when the limit is unknown.
    context_limit: Option<u32>,
}
24
25impl TokenTracker {
26    pub fn new(context_limit: Option<u32>) -> Self {
27        Self { context_limit, ..Self::default() }
28    }
29
30    /// Record usage from an LLM API response.
31    pub fn record_usage(&mut self, sample: TokenUsage) {
32        self.total_input_tokens += u64::from(sample.input_tokens);
33        self.total_output_tokens += u64::from(sample.output_tokens);
34        self.total_cache_read_tokens += u64::from(sample.cache_read_tokens.unwrap_or(0));
35        self.total_cache_creation_tokens += u64::from(sample.cache_creation_tokens.unwrap_or(0));
36        self.total_reasoning_tokens += u64::from(sample.reasoning_tokens.unwrap_or(0));
37        self.last_usage = sample;
38    }
39
40    /// Current context usage as a ratio (0.0 - 1.0)
41    pub fn usage_ratio(&self) -> Option<f64> {
42        let context_limit = self.context_limit?;
43        if context_limit == 0 {
44            return None;
45        }
46        Some(f64::from(self.last_usage.input_tokens) / f64::from(context_limit))
47    }
48
49    /// Whether current usage exceeds the given threshold
50    pub fn exceeds_threshold(&self, threshold: f64) -> bool {
51        self.usage_ratio().is_some_and(|ratio| ratio >= threshold)
52    }
53
54    /// Check if context should be compacted based on the given threshold.
55    /// This is a convenience method that combines usage ratio check with
56    /// a minimum context size requirement to avoid unnecessary compaction
57    /// on small conversations.
58    pub fn should_compact(&self, threshold: f64) -> bool {
59        let Some(context_limit) = self.context_limit else {
60            return false;
61        };
62        let min_tokens = std::cmp::max(context_limit / 10, 1000);
63        self.last_usage.input_tokens >= min_tokens && self.exceeds_threshold(threshold)
64    }
65
66    /// Tokens remaining before hitting limit
67    pub fn tokens_remaining(&self) -> Option<u32> {
68        self.context_limit.map(|context_limit| context_limit.saturating_sub(self.last_usage.input_tokens))
69    }
70
71    /// Update the context limit (e.g. when switching models)
72    pub fn set_context_limit(&mut self, limit: Option<u32>) {
73        self.context_limit = limit;
74    }
75
76    /// Get the context limit
77    pub fn context_limit(&self) -> Option<u32> {
78        self.context_limit
79    }
80
81    /// Get last recorded input tokens (current context size)
82    pub fn last_input_tokens(&self) -> u32 {
83        self.last_usage.input_tokens
84    }
85
86    /// Get the full `TokenUsage` from the most recent API call. Returns the
87    /// default (all zeros / `None`) before any call has been recorded.
88    pub fn last_usage(&self) -> &TokenUsage {
89        &self.last_usage
90    }
91
92    /// Get total input tokens across all calls
93    pub fn total_input_tokens(&self) -> u64 {
94        self.total_input_tokens
95    }
96
97    /// Get total output tokens across all calls
98    pub fn total_output_tokens(&self) -> u64 {
99        self.total_output_tokens
100    }
101
102    /// Get total cache-read tokens across all calls
103    pub fn total_cache_read_tokens(&self) -> u64 {
104        self.total_cache_read_tokens
105    }
106
107    /// Get total cache-creation tokens across all calls
108    pub fn total_cache_creation_tokens(&self) -> u64 {
109        self.total_cache_creation_tokens
110    }
111
112    /// Get total reasoning tokens across all calls
113    pub fn total_reasoning_tokens(&self) -> u64 {
114        self.total_reasoning_tokens
115    }
116
117    /// Reset current usage tracking after context compaction.
118    /// Preserves cumulative totals for metrics while clearing `last_usage` to
119    /// prevent immediate re-triggering of compaction.
120    pub fn reset_current_usage(&mut self) {
121        self.last_usage = TokenUsage::default();
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_tracking() {
        let mut tt = TokenTracker::new(Some(1000));

        tt.record_usage(TokenUsage::new(500, 100));
        assert_eq!(Some(0.5), tt.usage_ratio());
        assert!(!tt.exceeds_threshold(0.85));

        tt.record_usage(TokenUsage::new(900, 50));
        assert_eq!(Some(0.9), tt.usage_ratio());
        assert!(tt.exceeds_threshold(0.85));
    }

    #[test]
    fn test_tokens_remaining() {
        let mut tt = TokenTracker::new(Some(1000));
        tt.record_usage(TokenUsage::new(700, 50));
        assert_eq!(Some(300), tt.tokens_remaining());
    }

    #[test]
    fn test_cumulative_totals() {
        let mut tt = TokenTracker::new(Some(1000));
        tt.record_usage(TokenUsage::new(100, 50));
        tt.record_usage(TokenUsage::new(200, 60));

        // Totals accumulate; last_input_tokens reflects only the latest call.
        assert_eq!(300, tt.total_input_tokens());
        assert_eq!(110, tt.total_output_tokens());
        assert_eq!(200, tt.last_input_tokens());
    }

    #[test]
    fn test_unknown_context_limit() {
        let tt = TokenTracker::new(None);
        assert_eq!(None, tt.usage_ratio());
        assert_eq!(None, tt.tokens_remaining());
        assert!(!tt.should_compact(0.85));
    }

    #[test]
    fn test_exceeds_threshold() {
        let mut tt = TokenTracker::new(Some(1000));

        tt.record_usage(TokenUsage::new(500, 100));
        assert!(!tt.exceeds_threshold(0.6));
        assert!(tt.exceeds_threshold(0.5));

        // Exactly at the threshold still counts (>=, not >).
        tt.record_usage(TokenUsage::new(850, 50));
        assert!(tt.exceeds_threshold(0.8));
        assert!(tt.exceeds_threshold(0.85));
    }

    #[test]
    fn test_should_compact() {
        let mut tt = TokenTracker::new(Some(10000));

        // Below the minimum-size floor: never compacts, even past threshold.
        tt.record_usage(TokenUsage::new(500, 100));
        assert!(!tt.should_compact(0.04));

        tt.record_usage(TokenUsage::new(9000, 100));
        assert!(tt.should_compact(0.85));

        tt.record_usage(TokenUsage::new(7000, 100));
        assert!(!tt.should_compact(0.85));
    }

    #[test]
    fn test_default_compaction_threshold() {
        // Brought in by `use super::*` above.
        assert!((DEFAULT_COMPACTION_THRESHOLD - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_set_context_limit() {
        let mut tt = TokenTracker::new(Some(200_000));
        assert_eq!(Some(200_000), tt.context_limit());

        tt.set_context_limit(Some(128_000));
        assert_eq!(Some(128_000), tt.context_limit());

        // Usage ratio must be computed against the new limit.
        tt.record_usage(TokenUsage::new(100_000, 50));
        let want = 100_000.0 / 128_000.0;
        assert!((tt.usage_ratio().unwrap_or_default() - want).abs() < 0.001);
    }

    #[test]
    fn test_reset_current_usage() {
        let mut tt = TokenTracker::new(Some(10000));
        tt.record_usage(TokenUsage::new(9000, 100));

        assert!(tt.should_compact(0.85));

        tt.reset_current_usage();

        // Current usage cleared; cumulative metrics untouched.
        assert_eq!(0, tt.last_input_tokens());
        assert!(!tt.should_compact(0.85));
        assert_eq!(9000, tt.total_input_tokens());
        assert_eq!(100, tt.total_output_tokens());
    }

    #[test]
    fn test_cache_and_reasoning_totals_accumulate() {
        let mut tt = TokenTracker::new(Some(10000));

        tt.record_usage(TokenUsage {
            input_tokens: 500,
            output_tokens: 100,
            cache_read_tokens: Some(200),
            cache_creation_tokens: Some(50),
            reasoning_tokens: Some(30),
            ..TokenUsage::default()
        });
        // A `None` dimension contributes zero to its running total.
        tt.record_usage(TokenUsage {
            input_tokens: 600,
            output_tokens: 80,
            cache_read_tokens: Some(300),
            cache_creation_tokens: None,
            reasoning_tokens: Some(20),
            ..TokenUsage::default()
        });

        assert_eq!(500, tt.total_cache_read_tokens());
        assert_eq!(50, tt.total_cache_creation_tokens());
        assert_eq!(50, tt.total_reasoning_tokens());
    }

    #[test]
    fn test_last_usage_exposes_full_token_usage() {
        let mut tt = TokenTracker::new(Some(10000));
        let sample = TokenUsage {
            input_tokens: 500,
            output_tokens: 100,
            cache_read_tokens: Some(200),
            cache_creation_tokens: Some(50),
            reasoning_tokens: Some(30),
            input_audio_tokens: Some(5),
            ..TokenUsage::default()
        };

        tt.record_usage(sample);

        // Dimensions without dedicated accumulators (e.g. audio) survive.
        assert_eq!(sample, *tt.last_usage());
    }

    #[test]
    fn test_reset_clears_last_usage_but_keeps_cache_totals() {
        let mut tt = TokenTracker::new(Some(10000));
        tt.record_usage(TokenUsage {
            input_tokens: 500,
            output_tokens: 100,
            cache_read_tokens: Some(200),
            cache_creation_tokens: Some(50),
            reasoning_tokens: Some(30),
            ..TokenUsage::default()
        });

        tt.reset_current_usage();

        assert_eq!(TokenUsage::default(), *tt.last_usage());
        assert_eq!(200, tt.total_cache_read_tokens());
        assert_eq!(50, tt.total_cache_creation_tokens());
        assert_eq!(30, tt.total_reasoning_tokens());
    }
}