opencode_orchestrator_mcp/
token_tracker.rs1use opencode_rs::types::event::Event;
7use opencode_rs::types::message::Part;
8use opencode_rs::types::message::TokenUsage;
9
/// Narrows a `u64` count to `u32`, clamping instead of wrapping.
///
/// Returns the narrowed value together with a flag that is `true` when the
/// input was larger than `u32::MAX` and the result was clamped to `u32::MAX`.
fn sat_u32(value: u64) -> (u32, bool) {
    match u32::try_from(value) {
        Ok(narrowed) => (narrowed, false),
        // The conversion can only fail because `value` exceeds `u32::MAX`.
        Err(_) => (u32::MAX, true),
    }
}
17
18#[derive(Debug, Clone)]
20pub struct TokenTracker {
21 pub provider_id: Option<String>,
23 pub model_id: Option<String>,
25 pub context_limit: Option<u64>,
27 pub latest_input_tokens: Option<u64>,
29 pub latest_tokens: Option<TokenUsage>,
31 pub compaction_needed: bool,
33 threshold: f64,
35}
36
37impl Default for TokenTracker {
38 fn default() -> Self {
39 Self::with_threshold(0.80)
40 }
41}
42
43impl TokenTracker {
44 pub fn with_threshold(threshold: f64) -> Self {
48 Self {
49 provider_id: None,
50 model_id: None,
51 context_limit: None,
52 latest_input_tokens: None,
53 latest_tokens: None,
54 compaction_needed: false,
55 threshold,
56 }
57 }
58
59 #[cfg(test)]
61 pub fn new() -> Self {
62 Self::default()
63 }
64
65 pub fn observe_event<F>(&mut self, ev: &Event, context_limit_lookup: F)
70 where
71 F: Fn(&str, &str) -> Option<u64>,
72 {
73 match ev {
74 Event::MessageUpdated { properties } => {
75 if let Some(pid) = properties.info.provider_id.as_ref()
77 && let Some(mid) = properties.info.model_id.as_ref()
78 {
79 self.provider_id = Some(pid.clone());
80 self.model_id = Some(mid.clone());
81 self.context_limit = context_limit_lookup(pid, mid);
82 if properties.info.tokens.is_none() {
84 self.recompute_flag();
85 }
86 }
87
88 if let Some(tokens) = &properties.info.tokens {
90 self.observe_tokens(tokens);
91 }
92 }
93 Event::MessagePartUpdated { properties } => {
94 if let Some(part) = properties.part.as_ref()
96 && let Part::StepFinish {
97 tokens: Some(tokens),
98 ..
99 } = part
100 {
101 self.observe_tokens(tokens);
102 }
103 }
104 _ => {}
105 }
106 }
107
108 pub fn observe_tokens(&mut self, tokens: &TokenUsage) {
110 self.latest_input_tokens = Some(tokens.input);
111 self.latest_tokens = Some(tokens.clone());
112 self.recompute_flag();
113 }
114
115 pub fn to_log_token_usage(&self) -> (Option<agentic_logging::TokenUsage>, bool) {
116 let Some(tokens) = &self.latest_tokens else {
117 return (None, false);
118 };
119
120 let (prompt, prompt_saturated) = sat_u32(tokens.input);
121 let (completion, completion_saturated) = sat_u32(tokens.output);
122 let total_raw = tokens
123 .total
124 .unwrap_or_else(|| tokens.input.saturating_add(tokens.output));
125 let (total, total_saturated) = sat_u32(total_raw);
126 let (reasoning, reasoning_saturated) = sat_u32(tokens.reasoning);
127 let saturated =
128 prompt_saturated || completion_saturated || total_saturated || reasoning_saturated;
129
130 (
131 Some(agentic_logging::TokenUsage {
132 prompt,
133 completion,
134 total,
135 reasoning_tokens: (tokens.reasoning > 0).then_some(reasoning),
136 }),
137 saturated,
138 )
139 }
140
141 fn recompute_flag(&mut self) {
143 if let (Some(input), Some(limit)) = (self.latest_input_tokens, self.context_limit)
144 && limit > 0
145 {
146 let ratio = input as f64 / limit as f64;
147 if ratio >= self.threshold {
148 self.compaction_needed = true;
149 tracing::info!(
150 "Context limit threshold reached: {}/{} ({:.1}%)",
151 input,
152 limit,
153 ratio * 100.0
154 );
155 }
156 }
157 }
158}
159
#[cfg(test)]
impl TokenTracker {
    /// Clears the compaction flag and forgets the last observed token usage.
    ///
    /// Provider id, model id, context limit and threshold are left as they
    /// are.
    pub fn reset_after_compaction(&mut self) {
        self.latest_tokens = None;
        self.latest_input_tokens = None;
        self.compaction_needed = false;
    }

    /// Returns `latest_input_tokens / context_limit`, or `None` when either
    /// value is unknown or the limit is zero.
    pub fn usage_ratio(&self) -> Option<f64> {
        let used = self.latest_input_tokens?;
        let capacity = self.context_limit.filter(|&limit| limit > 0)?;
        Some(used as f64 / capacity as f64)
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;
    use opencode_rs::types::event::MessagePartEventProps;
    use opencode_rs::types::event::MessageUpdatedProps;
    use opencode_rs::types::message::MessageInfo;
    use opencode_rs::types::message::MessageTime;

    /// Usage record carrying only an input count; every other counter is
    /// zero or absent.
    fn mk_token_usage(input: u64) -> TokenUsage {
        TokenUsage {
            total: None,
            input,
            output: 0,
            reasoning: 0,
            cache: None,
            extra: serde_json::Value::Null,
        }
    }

    /// `MessageUpdated` event for an assistant message with the given
    /// provider/model ids and token usage.
    fn mk_message_updated(
        provider_id: Option<&str>,
        model_id: Option<&str>,
        tokens: Option<TokenUsage>,
    ) -> Event {
        let info = MessageInfo {
            id: String::from("msg-1"),
            session_id: None,
            role: String::from("assistant"),
            time: MessageTime {
                created: 0,
                completed: None,
            },
            agent: None,
            variant: None,
            format: None,
            model: None,
            system: None,
            tools: std::collections::HashMap::new(),
            parent_id: None,
            model_id: model_id.map(String::from),
            provider_id: provider_id.map(String::from),
            path: None,
            cost: None,
            tokens,
            structured: None,
            finish: None,
            extra: serde_json::Value::Null,
        };
        let props = MessageUpdatedProps {
            info,
            extra: serde_json::Value::Null,
        };

        Event::MessageUpdated {
            properties: Box::new(props),
        }
    }

    /// `MessagePartUpdated` event wrapping a `StepFinish` part with the given
    /// token usage.
    fn mk_message_part_step_finish(tokens: Option<TokenUsage>) -> Event {
        let step_finish = Part::StepFinish {
            id: None,
            reason: String::from("done"),
            snapshot: None,
            cost: 0.0,
            tokens,
        };
        let props = MessagePartEventProps {
            session_id: None,
            message_id: None,
            index: None,
            part: Some(step_finish),
            delta: None,
            extra: serde_json::Value::Null,
        };

        Event::MessagePartUpdated {
            properties: Box::new(props),
        }
    }

    #[test]
    fn triggers_compaction_at_80_percent() {
        let mut tracker = TokenTracker {
            context_limit: Some(1000),
            ..TokenTracker::new()
        };

        // Just below the default threshold: the flag stays down.
        tracker.latest_input_tokens = Some(799);
        tracker.recompute_flag();
        assert!(!tracker.compaction_needed);

        // Exactly at the default threshold: the flag goes up.
        tracker.latest_input_tokens = Some(800);
        tracker.recompute_flag();
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn does_not_trigger_without_limit() {
        let mut tracker = TokenTracker {
            latest_input_tokens: Some(10000),
            ..TokenTracker::new()
        };

        tracker.recompute_flag();

        assert!(!tracker.compaction_needed);
    }

    #[test]
    fn reset_clears_flag() {
        let mut tracker = TokenTracker {
            context_limit: Some(100),
            latest_input_tokens: Some(90),
            ..TokenTracker::new()
        };
        tracker.recompute_flag();
        assert!(tracker.compaction_needed);

        tracker.reset_after_compaction();

        assert!(!tracker.compaction_needed);
        assert_eq!(tracker.latest_input_tokens, None);
    }

    #[test]
    fn usage_ratio_calculation() {
        let tracker = TokenTracker {
            context_limit: Some(1000),
            latest_input_tokens: Some(500),
            ..TokenTracker::new()
        };

        assert_eq!(tracker.usage_ratio(), Some(0.5));
    }

    #[test]
    fn observe_event_tokens_first_limit_later_triggers_compaction() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // Tokens arrive while no context limit is known yet.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        tracker.observe_event(&tokens_event, lookup);
        assert!(!tracker.compaction_needed);

        // The limit becomes known afterwards and the stored tokens are
        // re-evaluated against it.
        let limit_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        tracker.observe_event(&limit_event, lookup);

        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_limit_first_tokens_later_triggers_compaction() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // The limit is known first; no tokens have been seen yet.
        let limit_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        tracker.observe_event(&limit_event, lookup);
        assert!(!tracker.compaction_needed);

        // Tokens arrive afterwards and cross the threshold.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        tracker.observe_event(&tokens_event, lookup);

        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_combined_message_updated_event_triggers_compaction() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // A single event carries both the provider/model ids and the tokens.
        let combined_event = mk_message_updated(
            Some("provider-1"),
            Some("model-1"),
            Some(mk_token_usage(800)),
        );
        tracker.observe_event(&combined_event, lookup);

        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_tokens_without_any_limit_does_not_trigger_compaction() {
        // The lookup would return a limit, but a part event never consults it.
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(10_000)));
        tracker.observe_event(&tokens_event, lookup);

        assert!(!tracker.compaction_needed);
        assert!(tracker.context_limit.is_none());
    }

    #[test]
    fn to_log_token_usage_preserves_values_without_saturation() {
        let observed = TokenUsage {
            total: Some(30),
            input: 10,
            output: 20,
            reasoning: 5,
            cache: None,
            extra: serde_json::Value::Null,
        };
        let mut tracker = TokenTracker::new();
        tracker.observe_tokens(&observed);

        let (logged, saturated) = tracker.to_log_token_usage();
        let logged = logged.expect("usage should be present");

        assert!(!saturated);
        assert_eq!(logged.prompt, 10);
        assert_eq!(logged.completion, 20);
        assert_eq!(logged.total, 30);
        assert_eq!(logged.reasoning_tokens, Some(5));
    }

    #[test]
    fn to_log_token_usage_saturates_large_values() {
        let observed = TokenUsage {
            total: Some(u64::MAX),
            input: u64::MAX,
            output: u64::MAX,
            reasoning: u64::MAX,
            cache: None,
            extra: serde_json::Value::Null,
        };
        let mut tracker = TokenTracker::new();
        tracker.observe_tokens(&observed);

        let (logged, saturated) = tracker.to_log_token_usage();
        let logged = logged.expect("usage should be present");

        assert!(saturated);
        assert_eq!(logged.prompt, u32::MAX);
        assert_eq!(logged.completion, u32::MAX);
        assert_eq!(logged.total, u32::MAX);
        assert_eq!(logged.reasoning_tokens, Some(u32::MAX));
    }
}