// aether_core/context/token_tracker.rs
use llm::TokenUsage;
2
/// Fraction of the context window (85%) at which conversation compaction
/// is recommended; used as the default `threshold` for `should_compact`.
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.85;
5
/// Tracks token usage across LLM calls: cumulative totals over the whole
/// session plus the most recent per-call sample, which drives context-window
/// compaction decisions.
#[derive(Debug, Clone, Default)]
pub struct TokenTracker {
    // Sum of `input_tokens` over every recorded sample.
    total_input_tokens: u64,
    // Sum of `output_tokens` over every recorded sample.
    total_output_tokens: u64,
    // Sum of `cache_read_tokens`; absent values count as 0.
    total_cache_read_tokens: u64,
    // Sum of `cache_creation_tokens`; absent values count as 0.
    total_cache_creation_tokens: u64,
    // Sum of `reasoning_tokens`; absent values count as 0.
    total_reasoning_tokens: u64,
    // Most recently recorded sample; cleared by `reset_current_usage`.
    last_usage: TokenUsage,
    // Model context-window size in tokens; `None` when unknown.
    context_limit: Option<u32>,
}
24
25impl TokenTracker {
26 pub fn new(context_limit: Option<u32>) -> Self {
27 Self { context_limit, ..Self::default() }
28 }
29
30 pub fn record_usage(&mut self, sample: TokenUsage) {
32 self.total_input_tokens += u64::from(sample.input_tokens);
33 self.total_output_tokens += u64::from(sample.output_tokens);
34 self.total_cache_read_tokens += u64::from(sample.cache_read_tokens.unwrap_or(0));
35 self.total_cache_creation_tokens += u64::from(sample.cache_creation_tokens.unwrap_or(0));
36 self.total_reasoning_tokens += u64::from(sample.reasoning_tokens.unwrap_or(0));
37 self.last_usage = sample;
38 }
39
40 pub fn usage_ratio(&self) -> Option<f64> {
42 let context_limit = self.context_limit?;
43 if context_limit == 0 {
44 return None;
45 }
46 Some(f64::from(self.last_usage.input_tokens) / f64::from(context_limit))
47 }
48
49 pub fn exceeds_threshold(&self, threshold: f64) -> bool {
51 self.usage_ratio().is_some_and(|ratio| ratio >= threshold)
52 }
53
54 pub fn should_compact(&self, threshold: f64) -> bool {
59 let Some(context_limit) = self.context_limit else {
60 return false;
61 };
62 let min_tokens = std::cmp::max(context_limit / 10, 1000);
63 self.last_usage.input_tokens >= min_tokens && self.exceeds_threshold(threshold)
64 }
65
66 pub fn tokens_remaining(&self) -> Option<u32> {
68 self.context_limit.map(|context_limit| context_limit.saturating_sub(self.last_usage.input_tokens))
69 }
70
71 pub fn set_context_limit(&mut self, limit: Option<u32>) {
73 self.context_limit = limit;
74 }
75
76 pub fn context_limit(&self) -> Option<u32> {
78 self.context_limit
79 }
80
81 pub fn last_input_tokens(&self) -> u32 {
83 self.last_usage.input_tokens
84 }
85
86 pub fn last_usage(&self) -> &TokenUsage {
89 &self.last_usage
90 }
91
92 pub fn total_input_tokens(&self) -> u64 {
94 self.total_input_tokens
95 }
96
97 pub fn total_output_tokens(&self) -> u64 {
99 self.total_output_tokens
100 }
101
102 pub fn total_cache_read_tokens(&self) -> u64 {
104 self.total_cache_read_tokens
105 }
106
107 pub fn total_cache_creation_tokens(&self) -> u64 {
109 self.total_cache_creation_tokens
110 }
111
112 pub fn total_reasoning_tokens(&self) -> u64 {
114 self.total_reasoning_tokens
115 }
116
117 pub fn reset_current_usage(&mut self) {
121 self.last_usage = TokenUsage::default();
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    // Ratio and threshold queries follow the latest recorded sample only.
    #[test]
    fn test_usage_tracking() {
        let mut tt = TokenTracker::new(Some(1000));

        tt.record_usage(TokenUsage::new(500, 100));
        assert_eq!(tt.usage_ratio(), Some(0.5));
        assert!(!tt.exceeds_threshold(0.85));

        tt.record_usage(TokenUsage::new(900, 50));
        assert_eq!(tt.usage_ratio(), Some(0.9));
        assert!(tt.exceeds_threshold(0.85));
    }

    // Remaining capacity = limit minus last input tokens.
    #[test]
    fn test_tokens_remaining() {
        let mut tt = TokenTracker::new(Some(1000));
        tt.record_usage(TokenUsage::new(700, 50));
        assert_eq!(tt.tokens_remaining(), Some(300));
    }

    // Totals accumulate across samples; last_input reflects the final one.
    #[test]
    fn test_cumulative_totals() {
        let mut tt = TokenTracker::new(Some(1000));
        tt.record_usage(TokenUsage::new(100, 50));
        tt.record_usage(TokenUsage::new(200, 60));

        assert_eq!(tt.total_input_tokens(), 300);
        assert_eq!(tt.total_output_tokens(), 110);
        assert_eq!(tt.last_input_tokens(), 200);
    }

    // Without a limit, ratio/remaining are None and compaction never fires.
    #[test]
    fn test_unknown_context_limit() {
        let tt = TokenTracker::new(None);
        assert_eq!(tt.usage_ratio(), None);
        assert_eq!(tt.tokens_remaining(), None);
        assert!(!tt.should_compact(0.85));
    }

    // Threshold comparison is inclusive (>=).
    #[test]
    fn test_exceeds_threshold() {
        let mut tt = TokenTracker::new(Some(1000));

        tt.record_usage(TokenUsage::new(500, 100));
        assert!(!tt.exceeds_threshold(0.6));
        assert!(tt.exceeds_threshold(0.5));

        tt.record_usage(TokenUsage::new(850, 50));
        assert!(tt.exceeds_threshold(0.8));
        assert!(tt.exceeds_threshold(0.85));
    }

    // Compaction needs both the threshold and the minimum-token floor.
    #[test]
    fn test_should_compact() {
        let mut tt = TokenTracker::new(Some(10000));

        // 500 tokens is below the floor (max(10000/10, 1000) = 1000),
        // so even a tiny threshold does not trigger compaction.
        tt.record_usage(TokenUsage::new(500, 100));
        assert!(!tt.should_compact(0.04));

        // 9000/10000 = 0.9 >= 0.85 and above the floor.
        tt.record_usage(TokenUsage::new(9000, 100));
        assert!(tt.should_compact(0.85));

        // 7000/10000 = 0.7 < 0.85: above the floor but below the threshold.
        tt.record_usage(TokenUsage::new(7000, 100));
        assert!(!tt.should_compact(0.85));
    }

    #[test]
    fn test_default_compaction_threshold() {
        use super::DEFAULT_COMPACTION_THRESHOLD;
        assert!((DEFAULT_COMPACTION_THRESHOLD - 0.85).abs() < 0.001);
    }

    // Limit updates take effect for subsequent ratio calculations.
    #[test]
    fn test_set_context_limit() {
        let mut tt = TokenTracker::new(Some(200_000));
        assert_eq!(tt.context_limit(), Some(200_000));

        tt.set_context_limit(Some(128_000));
        assert_eq!(tt.context_limit(), Some(128_000));

        tt.record_usage(TokenUsage::new(100_000, 50));
        let want = 100_000.0 / 128_000.0;
        assert!((tt.usage_ratio().unwrap_or_default() - want).abs() < 0.001);
    }

    // Reset clears the last sample but keeps cumulative totals intact.
    #[test]
    fn test_reset_current_usage() {
        let mut tt = TokenTracker::new(Some(10000));
        tt.record_usage(TokenUsage::new(9000, 100));

        assert!(tt.should_compact(0.85));

        tt.reset_current_usage();

        assert_eq!(tt.last_input_tokens(), 0);
        assert!(!tt.should_compact(0.85));
        assert_eq!(tt.total_input_tokens(), 9000);
        assert_eq!(tt.total_output_tokens(), 100);
    }

    // Optional counters accumulate, with None treated as zero.
    #[test]
    fn test_cache_and_reasoning_totals_accumulate() {
        let mut tt = TokenTracker::new(Some(10000));

        let first = TokenUsage {
            input_tokens: 500,
            output_tokens: 100,
            cache_read_tokens: Some(200),
            cache_creation_tokens: Some(50),
            reasoning_tokens: Some(30),
            ..TokenUsage::default()
        };
        let second = TokenUsage {
            input_tokens: 600,
            output_tokens: 80,
            cache_read_tokens: Some(300),
            cache_creation_tokens: None,
            reasoning_tokens: Some(20),
            ..TokenUsage::default()
        };
        tt.record_usage(first);
        tt.record_usage(second);

        assert_eq!(tt.total_cache_read_tokens(), 500);
        assert_eq!(tt.total_cache_creation_tokens(), 50);
        assert_eq!(tt.total_reasoning_tokens(), 50);
    }

    // The stored last sample preserves every field, not just input/output.
    #[test]
    fn test_last_usage_exposes_full_token_usage() {
        let mut tt = TokenTracker::new(Some(10000));
        let sample = TokenUsage {
            input_tokens: 500,
            output_tokens: 100,
            cache_read_tokens: Some(200),
            cache_creation_tokens: Some(50),
            reasoning_tokens: Some(30),
            input_audio_tokens: Some(5),
            ..TokenUsage::default()
        };

        tt.record_usage(sample);

        assert_eq!(*tt.last_usage(), sample);
    }

    // After reset the last sample is default, but optional-counter totals
    // recorded earlier survive.
    #[test]
    fn test_reset_clears_last_usage_but_keeps_cache_totals() {
        let mut tt = TokenTracker::new(Some(10000));
        tt.record_usage(TokenUsage {
            input_tokens: 500,
            output_tokens: 100,
            cache_read_tokens: Some(200),
            cache_creation_tokens: Some(50),
            reasoning_tokens: Some(30),
            ..TokenUsage::default()
        });

        tt.reset_current_usage();

        assert_eq!(*tt.last_usage(), TokenUsage::default());
        assert_eq!(tt.total_cache_read_tokens(), 200);
        assert_eq!(tt.total_cache_creation_tokens(), 50);
        assert_eq!(tt.total_reasoning_tokens(), 30);
    }
}