// aether_core/context/token_tracker.rs

/// Default usage-ratio threshold for compaction checks such as
/// [`TokenTracker::should_compact`].
///
/// Expressed as a fraction of the context window: `0.85` means the most recent
/// input must use at least 85% of the context limit.
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.85;
/// Tracks token usage and decides when compaction is due.
///
/// It keeps running totals over every recorded usage report. It also keeps a snapshot
/// of the most recent report, which is compared against an optional context-window limit.
#[derive(Debug, Clone, Default)]
pub struct TokenTracker {
    /// Sum of `input_tokens` across all [`TokenTracker::record_usage`] calls.
    total_input_tokens: u64,
    /// Sum of `output_tokens` across all recorded usage.
    total_output_tokens: u64,
    /// Sum of every cached-input count that was reported (`None` reports add nothing).
    total_cached_input_tokens: u64,
    /// Input tokens from the most recent report; the basis for every ratio check.
    last_input_tokens: u32,
    /// Cached-input tokens from the most recent report, if that report included them.
    last_cached_input_tokens: Option<u32>,
    /// Context-window size in tokens, or `None` when unknown.
    context_limit: Option<u32>,
}

impl TokenTracker {
    /// Absolute lower bound on the latest input size before compaction is suggested.
    ///
    /// The effective floor is the larger of this value and 10% of the context limit.
    const MIN_COMPACTION_TOKENS: u32 = 1_000;

    /// Creates a tracker with every counter at zero and the given context-window size.
    ///
    /// Pass `None` when the window size is unknown. The ratio-based checks then report
    /// `None` or `false`.
    pub fn new(context_limit: Option<u32>) -> Self {
        Self {
            context_limit,
            ..Self::default()
        }
    }

    /// Records one usage report.
    ///
    /// Each count is added to its running total. Totals saturate at `u64::MAX` instead of
    /// panicking (debug) or wrapping (release).
    ///
    /// `input_tokens` and `cached_input_tokens` replace the latest-report snapshot that
    /// [`usage_ratio`](Self::usage_ratio) and the compaction checks read. Passing `None`
    /// for `cached_input_tokens` leaves the cached total unchanged and clears the cached
    /// snapshot.
    pub fn record_usage(
        &mut self,
        input_tokens: u32,
        output_tokens: u32,
        cached_input_tokens: Option<u32>,
    ) {
        self.total_input_tokens = self.total_input_tokens.saturating_add(input_tokens.into());
        self.total_output_tokens = self.total_output_tokens.saturating_add(output_tokens.into());
        if let Some(cached) = cached_input_tokens {
            self.total_cached_input_tokens =
                self.total_cached_input_tokens.saturating_add(cached.into());
        }
        self.last_input_tokens = input_tokens;
        self.last_cached_input_tokens = cached_input_tokens;
    }

    /// Returns the fraction of the context window used by the latest input.
    ///
    /// The value is `last_input_tokens / context_limit`. It returns `None` when the limit
    /// is unknown or zero. The ratio can exceed `1.0` if the latest input was larger than
    /// the limit.
    pub fn usage_ratio(&self) -> Option<f64> {
        let context_limit = self.context_limit.filter(|&limit| limit != 0)?;
        Some(f64::from(self.last_input_tokens) / f64::from(context_limit))
    }

    /// Returns `true` when the usage ratio is known and is at least `threshold`.
    pub fn exceeds_threshold(&self, threshold: f64) -> bool {
        self.usage_ratio().is_some_and(|ratio| ratio >= threshold)
    }

    /// Returns `true` when the latest input is both large enough and over `threshold`.
    ///
    /// The size check requires the latest input to be at least
    /// `max(context_limit / 10, 1000)` tokens. The threshold check uses
    /// [`exceeds_threshold`](Self::exceeds_threshold). This always returns `false` when
    /// the context limit is unknown.
    ///
    /// NOTE(review): for limits below 1000 the floor is larger than the window itself,
    /// so compaction effectively never triggers. Confirm that this is intended.
    pub fn should_compact(&self, threshold: f64) -> bool {
        let Some(context_limit) = self.context_limit else {
            return false;
        };
        let min_tokens = (context_limit / 10).max(Self::MIN_COMPACTION_TOKENS);
        self.last_input_tokens >= min_tokens && self.exceeds_threshold(threshold)
    }

    /// Returns the tokens left in the context window after the latest input.
    ///
    /// The result saturates at 0. It is `None` if the limit is unknown.
    pub fn tokens_remaining(&self) -> Option<u32> {
        self.context_limit
            .map(|context_limit| context_limit.saturating_sub(self.last_input_tokens))
    }

    /// Replaces the context-window size. Later ratio checks use the new value.
    pub fn set_context_limit(&mut self, limit: Option<u32>) {
        self.context_limit = limit;
    }

    /// Returns the configured context-window size, if known.
    pub fn context_limit(&self) -> Option<u32> {
        self.context_limit
    }

    /// Returns the input tokens from the most recent report (0 after construction or reset).
    pub fn last_input_tokens(&self) -> u32 {
        self.last_input_tokens
    }

    /// Returns the cumulative input tokens across all recorded usage.
    pub fn total_input_tokens(&self) -> u64 {
        self.total_input_tokens
    }

    /// Returns the cumulative output tokens across all recorded usage.
    pub fn total_output_tokens(&self) -> u64 {
        self.total_output_tokens
    }

    /// Returns the cumulative cached-input tokens across all reports that included them.
    pub fn total_cached_input_tokens(&self) -> u64 {
        self.total_cached_input_tokens
    }

    /// Returns the cached-input tokens from the most recent report.
    ///
    /// This is `None` if that report had no cached count, and also after a reset.
    pub fn last_cached_input_tokens(&self) -> Option<u32> {
        self.last_cached_input_tokens
    }

    /// Clears the latest-report snapshot (both input and cached counts).
    ///
    /// Ratio-based checks then read as empty until the next
    /// [`record_usage`](Self::record_usage). Cumulative totals are preserved. The cached
    /// snapshot is cleared too, so it never describes a report whose input count has
    /// been zeroed.
    pub fn reset_current_usage(&mut self) {
        self.last_input_tokens = 0;
        self.last_cached_input_tokens = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_tracking() {
        let mut tracker = TokenTracker::new(Some(1000));

        tracker.record_usage(500, 100, None);
        assert_eq!(tracker.usage_ratio(), Some(0.5));
        assert!(!tracker.exceeds_threshold(0.85));

        tracker.record_usage(900, 50, None);
        assert_eq!(tracker.usage_ratio(), Some(0.9));
        assert!(tracker.exceeds_threshold(0.85));
    }

    #[test]
    fn test_tokens_remaining() {
        let mut tracker = TokenTracker::new(Some(1000));
        tracker.record_usage(700, 50, None);
        assert_eq!(tracker.tokens_remaining(), Some(300));
    }

    #[test]
    fn test_tokens_remaining_saturates_at_zero() {
        // An input larger than the window must not underflow.
        let mut tracker = TokenTracker::new(Some(1000));
        tracker.record_usage(1500, 10, None);
        assert_eq!(tracker.tokens_remaining(), Some(0));
    }

    #[test]
    fn test_cumulative_totals() {
        let mut tracker = TokenTracker::new(Some(1000));
        tracker.record_usage(100, 50, None);
        tracker.record_usage(200, 60, None);

        assert_eq!(tracker.total_input_tokens(), 300);
        assert_eq!(tracker.total_output_tokens(), 110);
        assert_eq!(tracker.last_input_tokens(), 200);
    }

    #[test]
    fn test_unknown_context_limit() {
        let tracker = TokenTracker::new(None);
        assert_eq!(tracker.usage_ratio(), None);
        assert_eq!(tracker.tokens_remaining(), None);
        assert!(!tracker.should_compact(0.85));
    }

    #[test]
    fn test_zero_context_limit() {
        // A zero limit must not divide by zero; ratio checks report "unknown".
        let mut tracker = TokenTracker::new(Some(0));
        tracker.record_usage(5000, 10, None);
        assert_eq!(tracker.usage_ratio(), None);
        assert!(!tracker.exceeds_threshold(0.0));
        assert!(!tracker.should_compact(0.85));
        assert_eq!(tracker.tokens_remaining(), Some(0));
    }

    #[test]
    fn test_exceeds_threshold() {
        let mut tracker = TokenTracker::new(Some(1000));

        tracker.record_usage(500, 100, None);
        assert!(!tracker.exceeds_threshold(0.6));
        assert!(tracker.exceeds_threshold(0.5));

        tracker.record_usage(850, 50, None);
        assert!(tracker.exceeds_threshold(0.8));
        assert!(tracker.exceeds_threshold(0.85));
    }

    #[test]
    fn test_should_compact() {
        let mut tracker = TokenTracker::new(Some(10000));

        // Ratio 0.05 exceeds 0.04, but 500 tokens is below the 1000-token floor.
        tracker.record_usage(500, 100, None);
        assert!(!tracker.should_compact(0.04));

        tracker.record_usage(9000, 100, None);
        assert!(tracker.should_compact(0.85));

        tracker.record_usage(7000, 100, None);
        assert!(!tracker.should_compact(0.85));
    }

    #[test]
    fn test_should_compact_min_token_floor() {
        // With a 10k limit the floor is max(10_000 / 10, 1000) = 1000 tokens.
        let mut tracker = TokenTracker::new(Some(10000));

        tracker.record_usage(999, 0, None);
        assert!(!tracker.should_compact(0.05));

        tracker.record_usage(1000, 0, None);
        assert!(tracker.should_compact(0.05));
    }

    #[test]
    fn test_default_compaction_threshold() {
        assert!((DEFAULT_COMPACTION_THRESHOLD - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_set_context_limit() {
        let mut tracker = TokenTracker::new(Some(200_000));
        assert_eq!(tracker.context_limit(), Some(200_000));

        tracker.set_context_limit(Some(128_000));
        assert_eq!(tracker.context_limit(), Some(128_000));

        tracker.record_usage(100_000, 50, None);
        let expected_ratio = 100_000.0 / 128_000.0;
        let ratio = tracker
            .usage_ratio()
            .expect("ratio must be known once a non-zero limit is set");
        assert!((ratio - expected_ratio).abs() < 0.001);
    }

    #[test]
    fn test_reset_current_usage() {
        let mut tracker = TokenTracker::new(Some(10000));
        tracker.record_usage(9000, 100, None);

        assert!(tracker.should_compact(0.85));

        tracker.reset_current_usage();

        assert_eq!(tracker.last_input_tokens(), 0);
        assert!(!tracker.should_compact(0.85));
        assert_eq!(tracker.total_input_tokens(), 9000);
        assert_eq!(tracker.total_output_tokens(), 100);
    }

    #[test]
    fn test_cached_token_tracking() {
        let mut tracker = TokenTracker::new(Some(1000));

        tracker.record_usage(500, 100, Some(200));
        assert_eq!(tracker.last_cached_input_tokens(), Some(200));
        assert_eq!(tracker.total_cached_input_tokens(), 200);

        tracker.record_usage(600, 50, Some(400));
        assert_eq!(tracker.last_cached_input_tokens(), Some(400));
        assert_eq!(tracker.total_cached_input_tokens(), 600);

        // A report without a cached count clears the snapshot but keeps the total.
        tracker.record_usage(300, 30, None);
        assert_eq!(tracker.last_cached_input_tokens(), None);
        assert_eq!(tracker.total_cached_input_tokens(), 600);
    }
}