opencode_orchestrator_mcp/
token_tracker.rs1use opencode_rs::types::event::Event;
7use opencode_rs::types::message::Part;
8use opencode_rs::types::message::TokenUsage;
9
/// Narrows a `u64` count to `u32`, clamping instead of wrapping.
///
/// Returns the narrowed value together with a flag that is `true` when
/// `value` exceeded `u32::MAX` and was clamped to it.
fn sat_u32(value: u64) -> (u32, bool) {
    // `try_from` performs the range check itself, so no truncating `as` cast
    // (and no hand-written bound comparison) is needed.
    u32::try_from(value).map_or((u32::MAX, true), |narrowed| (narrowed, false))
}
17
18#[derive(Debug, Clone)]
20pub struct TokenTracker {
21 pub provider_id: Option<String>,
23 pub model_id: Option<String>,
25 pub context_limit: Option<u64>,
27 pub latest_input_tokens: Option<u64>,
29 pub latest_tokens: Option<TokenUsage>,
31 pub compaction_needed: bool,
33 threshold: f64,
35}
36
37impl Default for TokenTracker {
38 fn default() -> Self {
39 Self::with_threshold(0.80)
40 }
41}
42
43impl TokenTracker {
44 pub fn with_threshold(threshold: f64) -> Self {
48 Self {
49 provider_id: None,
50 model_id: None,
51 context_limit: None,
52 latest_input_tokens: None,
53 latest_tokens: None,
54 compaction_needed: false,
55 threshold,
56 }
57 }
58
59 #[cfg(test)]
61 pub fn new() -> Self {
62 Self::default()
63 }
64
65 pub fn observe_event<F>(&mut self, ev: &Event, context_limit_lookup: F)
70 where
71 F: Fn(&str, &str) -> Option<u64>,
72 {
73 match ev {
74 Event::MessageUpdated { properties } => {
75 if let Some(pid) = properties.info.provider_id.as_ref()
77 && let Some(mid) = properties.info.model_id.as_ref()
78 {
79 self.provider_id = Some(pid.clone());
80 self.model_id = Some(mid.clone());
81 self.context_limit = context_limit_lookup(pid, mid);
82 if properties.info.tokens.is_none() {
84 self.recompute_flag();
85 }
86 }
87
88 if let Some(tokens) = &properties.info.tokens {
90 self.observe_tokens(tokens);
91 }
92 }
93 Event::MessagePartUpdated { properties } => {
94 if let Some(part) = properties.part.as_ref()
96 && let Part::StepFinish {
97 tokens: Some(tokens),
98 ..
99 } = part
100 {
101 self.observe_tokens(tokens);
102 }
103 }
104 _ => {}
105 }
106 }
107
108 pub fn observe_tokens(&mut self, tokens: &TokenUsage) {
110 self.latest_input_tokens = Some(tokens.input);
111 self.latest_tokens = Some(tokens.clone());
112 self.recompute_flag();
113 }
114
115 pub fn to_log_token_usage(&self) -> (Option<agentic_logging::TokenUsage>, bool) {
116 let Some(tokens) = &self.latest_tokens else {
117 return (None, false);
118 };
119
120 let (prompt, prompt_saturated) = sat_u32(tokens.input);
121 let (completion, completion_saturated) = sat_u32(tokens.output);
122 let total_raw = tokens
123 .total
124 .unwrap_or_else(|| tokens.input.saturating_add(tokens.output));
125 let (total, total_saturated) = sat_u32(total_raw);
126 let (reasoning, reasoning_saturated) = sat_u32(tokens.reasoning);
127 let saturated =
128 prompt_saturated || completion_saturated || total_saturated || reasoning_saturated;
129
130 (
131 Some(agentic_logging::TokenUsage {
132 prompt,
133 completion,
134 total,
135 reasoning_tokens: (tokens.reasoning > 0).then_some(reasoning),
136 }),
137 saturated,
138 )
139 }
140
141 fn recompute_flag(&mut self) {
143 if let (Some(input), Some(limit)) = (self.latest_input_tokens, self.context_limit)
144 && limit > 0
145 {
146 let ratio = input as f64 / limit as f64;
147 if ratio >= self.threshold {
148 self.compaction_needed = true;
149 tracing::info!(
150 "Context limit threshold reached: {}/{} ({:.1}%)",
151 input,
152 limit,
153 ratio * 100.0
154 );
155 }
156 }
157 }
158}
159
#[cfg(test)]
impl TokenTracker {
    /// Clears the compaction flag and forgets the latest token observation.
    ///
    /// Provider id, model id, context limit and threshold are left untouched.
    pub fn reset_after_compaction(&mut self) {
        self.latest_tokens = None;
        self.latest_input_tokens = None;
        self.compaction_needed = false;
    }

    /// Returns `input tokens / context limit`, or `None` when either value is
    /// unknown or the limit is zero.
    pub fn usage_ratio(&self) -> Option<f64> {
        let input = self.latest_input_tokens?;
        let limit = self.context_limit.filter(|&limit| limit > 0)?;
        Some(input as f64 / limit as f64)
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;
    use opencode_rs::types::event::MessagePartEventProps;
    use opencode_rs::types::event::MessageUpdatedProps;
    use opencode_rs::types::message::MessageInfo;
    use opencode_rs::types::message::MessageTime;

    /// Lookup stub reporting a 1000-token context limit for any provider/model.
    fn limit_of_1000(_provider: &str, _model: &str) -> Option<u64> {
        Some(1000)
    }

    /// Token usage with the given input count and every other counter empty.
    fn mk_token_usage(input: u64) -> TokenUsage {
        TokenUsage {
            total: None,
            input,
            output: 0,
            reasoning: 0,
            cache: None,
            extra: serde_json::Value::Null,
        }
    }

    /// `MessageUpdated` event for an assistant message with the given
    /// provider/model ids and token usage.
    fn mk_message_updated(
        provider_id: Option<&str>,
        model_id: Option<&str>,
        tokens: Option<TokenUsage>,
    ) -> Event {
        let info = MessageInfo {
            id: "msg-1".to_string(),
            session_id: None,
            role: "assistant".to_string(),
            time: MessageTime {
                created: 0,
                completed: None,
            },
            agent: None,
            format: None,
            model: None,
            system: None,
            tools: std::collections::HashMap::new(),
            parent_id: None,
            model_id: model_id.map(str::to_string),
            provider_id: provider_id.map(str::to_string),
            path: None,
            cost: None,
            tokens,
            structured: None,
            finish: None,
            extra: serde_json::Value::Null,
        };

        Event::MessageUpdated {
            properties: Box::new(MessageUpdatedProps {
                info,
                extra: serde_json::Value::Null,
            }),
        }
    }

    /// `MessagePartUpdated` event wrapping a `StepFinish` part with the given
    /// token usage.
    fn mk_message_part_step_finish(tokens: Option<TokenUsage>) -> Event {
        let part = Part::StepFinish {
            id: None,
            reason: "done".to_string(),
            snapshot: None,
            cost: 0.0,
            tokens,
        };

        Event::MessagePartUpdated {
            properties: Box::new(MessagePartEventProps {
                session_id: None,
                message_id: None,
                index: None,
                part: Some(part),
                delta: None,
                extra: serde_json::Value::Null,
            }),
        }
    }

    #[test]
    fn triggers_compaction_at_80_percent() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(1000);

        // One token below the 80% mark: flag must stay down.
        t.latest_input_tokens = Some(799);
        t.recompute_flag();
        assert!(!t.compaction_needed);

        // Exactly at the 80% mark: flag must be raised.
        t.latest_input_tokens = Some(800);
        t.recompute_flag();
        assert!(t.compaction_needed);
    }

    #[test]
    fn does_not_trigger_without_limit() {
        let mut t = TokenTracker::new();
        t.latest_input_tokens = Some(10000);
        t.recompute_flag();
        assert!(!t.compaction_needed);
    }

    #[test]
    fn reset_clears_flag() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(100);
        t.latest_input_tokens = Some(90);
        t.recompute_flag();
        assert!(t.compaction_needed);

        t.reset_after_compaction();
        assert!(!t.compaction_needed);
        assert!(t.latest_input_tokens.is_none());
    }

    #[test]
    fn usage_ratio_calculation() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(1000);
        t.latest_input_tokens = Some(500);

        assert_eq!(t.usage_ratio(), Some(0.5));
    }

    #[test]
    fn observe_event_tokens_first_limit_later_triggers_compaction() {
        let mut t = TokenTracker::new();

        // Tokens arrive while the limit is still unknown.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        t.observe_event(&tokens_event, limit_of_1000);
        assert!(!t.compaction_needed);

        // The limit becomes known afterwards.
        let limit_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        t.observe_event(&limit_event, limit_of_1000);

        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_limit_first_tokens_later_triggers_compaction() {
        let mut t = TokenTracker::new();

        // The limit is learned before any tokens are seen.
        let limit_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        t.observe_event(&limit_event, limit_of_1000);
        assert!(!t.compaction_needed);

        // Tokens arrive afterwards.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        t.observe_event(&tokens_event, limit_of_1000);

        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_combined_message_updated_event_triggers_compaction() {
        let mut t = TokenTracker::new();

        // A single event carries both the model identity and the tokens.
        let combined = mk_message_updated(
            Some("provider-1"),
            Some("model-1"),
            Some(mk_token_usage(800)),
        );
        t.observe_event(&combined, limit_of_1000);

        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_tokens_without_any_limit_does_not_trigger_compaction() {
        let mut t = TokenTracker::new();

        // A part event never consults the lookup, so the limit stays unknown.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(10_000)));
        t.observe_event(&tokens_event, limit_of_1000);

        assert!(!t.compaction_needed);
        assert_eq!(t.context_limit, None);
    }

    #[test]
    fn to_log_token_usage_preserves_values_without_saturation() {
        let mut t = TokenTracker::new();
        let small = TokenUsage {
            total: Some(30),
            input: 10,
            output: 20,
            reasoning: 5,
            cache: None,
            extra: serde_json::Value::Null,
        };
        t.observe_tokens(&small);

        let (usage, saturated) = t.to_log_token_usage();
        let usage = usage.expect("usage should be present");
        assert!(!saturated);
        assert_eq!(usage.prompt, 10);
        assert_eq!(usage.completion, 20);
        assert_eq!(usage.total, 30);
        assert_eq!(usage.reasoning_tokens, Some(5));
    }

    #[test]
    fn to_log_token_usage_saturates_large_values() {
        let mut t = TokenTracker::new();
        let huge = TokenUsage {
            total: Some(u64::MAX),
            input: u64::MAX,
            output: u64::MAX,
            reasoning: u64::MAX,
            cache: None,
            extra: serde_json::Value::Null,
        };
        t.observe_tokens(&huge);

        let (usage, saturated) = t.to_log_token_usage();
        let usage = usage.expect("usage should be present");
        assert!(saturated);
        assert_eq!(usage.prompt, u32::MAX);
        assert_eq!(usage.completion, u32::MAX);
        assert_eq!(usage.total, u32::MAX);
        assert_eq!(usage.reasoning_tokens, Some(u32::MAX));
    }
}