Skip to main content

aether_core/context/
token_tracker.rs

/// Default usage-ratio threshold at which context compaction is triggered.
///
/// Expressed as a fraction of the context limit: `0.85` means 85% of the
/// configured window.
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.85;
3
/// Tracks token usage from LLM API responses.
///
/// Uses the usage figures reported by the API rather than local estimation.
/// The `last_*` fields describe only the most recent recorded call. The
/// `total_*` counters accumulate across every recorded call.
#[derive(Debug, Clone, Default)]
pub struct TokenTracker {
    /// Total input tokens across all API calls
    total_input_tokens: u64,
    /// Total output tokens across all API calls
    total_output_tokens: u64,
    /// Total cached input tokens across all API calls
    total_cached_input_tokens: u64,
    /// Input tokens from the most recent API call (current context size)
    last_input_tokens: u32,
    /// Cached input tokens from the most recent API call (`None` if not reported)
    last_cached_input_tokens: Option<u32>,
    /// Configured context limit for the current provider (`None` if unknown)
    context_limit: Option<u32>,
}

impl TokenTracker {
    /// Minimum context size (in tokens) below which compaction is never
    /// suggested, regardless of the usage ratio.
    const MIN_COMPACTION_TOKENS: u32 = 1_000;

    /// Creates a tracker with all counters at zero and the given context limit.
    ///
    /// Pass `None` when the provider's context window is unknown. Ratio- and
    /// limit-based queries then return `None` / `false`.
    pub fn new(context_limit: Option<u32>) -> Self {
        Self {
            context_limit,
            ..Self::default()
        }
    }

    /// Records the usage reported by one LLM API response.
    ///
    /// Adds every value to the cumulative totals and overwrites the per-call
    /// (`last_*`) values. A `cached_input_tokens` of `None` leaves the cached
    /// total unchanged and sets the last cached value to `None`.
    pub fn record_usage(
        &mut self,
        input_tokens: u32,
        output_tokens: u32,
        cached_input_tokens: Option<u32>,
    ) {
        self.total_input_tokens += u64::from(input_tokens);
        self.total_output_tokens += u64::from(output_tokens);
        if let Some(cached) = cached_input_tokens {
            self.total_cached_input_tokens += u64::from(cached);
        }
        self.last_input_tokens = input_tokens;
        self.last_cached_input_tokens = cached_input_tokens;
    }

    /// Current context usage as a ratio of the context limit.
    ///
    /// This is normally within `0.0..=1.0`. It exceeds `1.0` when the last
    /// recorded request is larger than the limit, for example after the limit
    /// is lowered with [`Self::set_context_limit`].
    ///
    /// Returns `None` when the limit is unknown or zero, which also avoids a
    /// division by zero.
    pub fn usage_ratio(&self) -> Option<f64> {
        let context_limit = self.context_limit.filter(|&limit| limit != 0)?;
        Some(f64::from(self.last_input_tokens) / f64::from(context_limit))
    }

    /// Whether current usage is at or above `threshold` (inclusive).
    ///
    /// Returns `false` when no usage ratio is available (unknown or zero limit).
    pub fn exceeds_threshold(&self, threshold: f64) -> bool {
        self.usage_ratio().is_some_and(|ratio| ratio >= threshold)
    }

    /// Check if context should be compacted based on the given threshold.
    ///
    /// Returns `false` when the context limit is unknown. Otherwise both
    /// conditions must hold:
    /// - the last request used at least `max(limit / 10, 1000)` tokens, so
    ///   small conversations are not compacted unnecessarily;
    /// - [`Self::exceeds_threshold`] is `true` for `threshold`.
    pub fn should_compact(&self, threshold: f64) -> bool {
        let Some(context_limit) = self.context_limit else {
            return false;
        };
        let min_tokens = (context_limit / 10).max(Self::MIN_COMPACTION_TOKENS);
        self.last_input_tokens >= min_tokens && self.exceeds_threshold(threshold)
    }

    /// Tokens remaining before hitting the limit.
    ///
    /// Saturates at `0` when the last request exceeded the limit. Returns
    /// `None` when the limit is unknown.
    pub fn tokens_remaining(&self) -> Option<u32> {
        self.context_limit
            .map(|context_limit| context_limit.saturating_sub(self.last_input_tokens))
    }

    /// Update the context limit (e.g. when switching models).
    ///
    /// Recorded usage is kept, so ratios are recomputed against the new limit.
    pub fn set_context_limit(&mut self, limit: Option<u32>) {
        self.context_limit = limit;
    }

    /// Get the context limit (`None` if unknown).
    pub fn context_limit(&self) -> Option<u32> {
        self.context_limit
    }

    /// Get last recorded input tokens (current context size).
    pub fn last_input_tokens(&self) -> u32 {
        self.last_input_tokens
    }

    /// Get total input tokens across all calls.
    pub fn total_input_tokens(&self) -> u64 {
        self.total_input_tokens
    }

    /// Get total output tokens across all calls.
    pub fn total_output_tokens(&self) -> u64 {
        self.total_output_tokens
    }

    /// Get total cached input tokens across all calls.
    pub fn total_cached_input_tokens(&self) -> u64 {
        self.total_cached_input_tokens
    }

    /// Get last recorded cached input tokens (`None` if not reported).
    pub fn last_cached_input_tokens(&self) -> Option<u32> {
        self.last_cached_input_tokens
    }

    /// Reset per-call usage tracking after context compaction.
    ///
    /// Clears `last_input_tokens` so compaction is not immediately
    /// re-triggered. Also clears `last_cached_input_tokens`, so the tracker
    /// does not report a cached count larger than the (now zero) current
    /// context.
    ///
    /// Cumulative totals and the context limit are preserved for metrics.
    pub fn reset_current_usage(&mut self) {
        self.last_input_tokens = 0;
        self.last_cached_input_tokens = None;
    }
}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_tracking() {
        let mut tracker = TokenTracker::new(Some(1000));

        tracker.record_usage(500, 100, None);
        assert_eq!(tracker.usage_ratio(), Some(0.5));
        assert!(!tracker.exceeds_threshold(0.85));

        tracker.record_usage(900, 50, None);
        assert_eq!(tracker.usage_ratio(), Some(0.9));
        assert!(tracker.exceeds_threshold(0.85));
    }

    #[test]
    fn test_tokens_remaining() {
        let mut tracker = TokenTracker::new(Some(1000));
        tracker.record_usage(700, 50, None);
        assert_eq!(tracker.tokens_remaining(), Some(300));
    }

    #[test]
    fn test_cumulative_totals() {
        let mut tracker = TokenTracker::new(Some(1000));
        tracker.record_usage(100, 50, None);
        tracker.record_usage(200, 60, None);

        assert_eq!(tracker.total_input_tokens(), 300);
        assert_eq!(tracker.total_output_tokens(), 110);
        assert_eq!(tracker.last_input_tokens(), 200); // Only last call
    }

    #[test]
    fn test_unknown_context_limit() {
        let tracker = TokenTracker::new(None);
        assert_eq!(tracker.usage_ratio(), None);
        assert_eq!(tracker.tokens_remaining(), None);
        assert!(!tracker.exceeds_threshold(0.0));
        assert!(!tracker.should_compact(0.85));
    }

    #[test]
    fn test_zero_context_limit() {
        // A zero limit is treated like an unknown one for ratio-based checks,
        // which avoids a division by zero.
        let mut tracker = TokenTracker::new(Some(0));
        tracker.record_usage(500, 10, None);
        assert_eq!(tracker.usage_ratio(), None);
        assert!(!tracker.exceeds_threshold(0.0));
        assert!(!tracker.should_compact(0.0));
        assert_eq!(tracker.tokens_remaining(), Some(0));
    }

    #[test]
    fn test_exceeds_threshold() {
        let mut tracker = TokenTracker::new(Some(1000));

        tracker.record_usage(500, 100, None);
        assert!(!tracker.exceeds_threshold(0.6));
        assert!(tracker.exceeds_threshold(0.5));

        tracker.record_usage(850, 50, None);
        assert!(tracker.exceeds_threshold(0.8));
        assert!(tracker.exceeds_threshold(0.85));
    }

    #[test]
    fn test_usage_over_limit() {
        // The ratio is not clamped, and the remaining-token count saturates at zero.
        let mut tracker = TokenTracker::new(Some(1000));
        tracker.record_usage(1500, 10, None);
        assert_eq!(tracker.usage_ratio(), Some(1.5));
        assert_eq!(tracker.tokens_remaining(), Some(0));
        assert!(tracker.should_compact(DEFAULT_COMPACTION_THRESHOLD));
    }

    #[test]
    fn test_should_compact() {
        let mut tracker = TokenTracker::new(Some(10000));

        tracker.record_usage(500, 100, None);
        assert!(!tracker.should_compact(0.04));

        tracker.record_usage(9000, 100, None);
        assert!(tracker.should_compact(0.85));

        tracker.record_usage(7000, 100, None);
        assert!(!tracker.should_compact(0.85));
    }

    #[test]
    fn test_should_compact_min_tokens_floor() {
        // At 90% of a tiny window, a context under 1000 tokens is still not compacted.
        let mut tracker = TokenTracker::new(Some(500));
        tracker.record_usage(450, 10, None);
        assert!(tracker.exceeds_threshold(0.85));
        assert!(!tracker.should_compact(0.85));
    }

    #[test]
    fn test_default_compaction_threshold() {
        // Already in scope through `use super::*`.
        assert!((DEFAULT_COMPACTION_THRESHOLD - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_set_context_limit() {
        let mut tracker = TokenTracker::new(Some(200_000));
        assert_eq!(tracker.context_limit(), Some(200_000));

        tracker.set_context_limit(Some(128_000));
        assert_eq!(tracker.context_limit(), Some(128_000));

        // Verify usage ratio recalculates against new limit
        tracker.record_usage(100_000, 50, None);
        let expected_ratio = 100_000.0 / 128_000.0;
        assert!((tracker.usage_ratio().unwrap_or_default() - expected_ratio).abs() < 0.001);
    }

    #[test]
    fn test_reset_current_usage() {
        let mut tracker = TokenTracker::new(Some(10000));
        tracker.record_usage(9000, 100, None);

        assert!(tracker.should_compact(0.85));

        tracker.reset_current_usage();

        assert_eq!(tracker.last_input_tokens(), 0);
        assert!(!tracker.should_compact(0.85));
        assert_eq!(tracker.total_input_tokens(), 9000);
        assert_eq!(tracker.total_output_tokens(), 100);
    }

    #[test]
    fn test_reset_preserves_cached_totals() {
        let mut tracker = TokenTracker::new(Some(10000));
        tracker.record_usage(9000, 100, Some(4000));

        tracker.reset_current_usage();

        assert_eq!(tracker.total_cached_input_tokens(), 4000);
        assert_eq!(tracker.context_limit(), Some(10000));
    }

    #[test]
    fn test_cached_token_tracking() {
        let mut tracker = TokenTracker::new(Some(1000));

        tracker.record_usage(500, 100, Some(200));
        assert_eq!(tracker.last_cached_input_tokens(), Some(200));
        assert_eq!(tracker.total_cached_input_tokens(), 200);

        tracker.record_usage(600, 50, Some(400));
        assert_eq!(tracker.last_cached_input_tokens(), Some(400));
        assert_eq!(tracker.total_cached_input_tokens(), 600);

        tracker.record_usage(300, 30, None);
        assert_eq!(tracker.last_cached_input_tokens(), None);
        assert_eq!(tracker.total_cached_input_tokens(), 600);
    }
}