// opencode_orchestrator_mcp/token_tracker.rs

//! Token tracking for the context-limit safeguard.
//!
//! Monitors token usage during session runs and detects when the configured
//! threshold (80% by default) is reached, so that server-side summarization
//! can be triggered.

6use opencode_rs::types::event::Event;
7use opencode_rs::types::message::Part;
8use opencode_rs::types::message::TokenUsage;
9
/// Narrows a `u64` count to `u32`, clamping instead of wrapping.
///
/// Returns the narrowed value together with a flag that is `true` when
/// `value` was larger than `u32::MAX` and had to be clamped to it.
fn sat_u32(value: u64) -> (u32, bool) {
    // `try_from` performs the range check itself, so no lossy `as` cast is needed.
    match u32::try_from(value) {
        Ok(narrowed) => (narrowed, false),
        Err(_) => (u32::MAX, true),
    }
}

18/// Tracks token usage during a session run to detect context limit threshold.
19#[derive(Debug, Clone)]
20pub struct TokenTracker {
21    /// Provider ID from message events
22    pub provider_id: Option<String>,
23    /// Model ID from message events
24    pub model_id: Option<String>,
25    /// Context limit for the current model (from cached limits)
26    pub context_limit: Option<u64>,
27    /// Latest observed input token count
28    pub latest_input_tokens: Option<u64>,
29    /// Latest observed full token usage
30    pub latest_tokens: Option<TokenUsage>,
31    /// Flag indicating compaction/summarization is needed
32    pub compaction_needed: bool,
33    /// Threshold at which to trigger summarization (0.0 - 1.0)
34    threshold: f64,
35}
36
37impl Default for TokenTracker {
38    fn default() -> Self {
39        Self::with_threshold(0.80)
40    }
41}
42
43impl TokenTracker {
44    /// Create a new token tracker with a custom compaction threshold.
45    ///
46    /// The threshold should be between 0.0 and 1.0 (e.g., 0.80 for 80%).
47    pub fn with_threshold(threshold: f64) -> Self {
48        Self {
49            provider_id: None,
50            model_id: None,
51            context_limit: None,
52            latest_input_tokens: None,
53            latest_tokens: None,
54            compaction_needed: false,
55            threshold,
56        }
57    }
58
59    /// Create a new token tracker with default threshold (80%).
60    #[cfg(test)]
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    /// Observe an SSE event and update token tracking.
66    ///
67    /// The `context_limit_lookup` function is called to look up the context limit
68    /// for a given (`provider_id`, `model_id`) pair from the cached limits.
69    pub fn observe_event<F>(&mut self, ev: &Event, context_limit_lookup: F)
70    where
71        F: Fn(&str, &str) -> Option<u64>,
72    {
73        match ev {
74            Event::MessageUpdated { properties } => {
75                // Extract provider/model info
76                if let Some(pid) = properties.info.provider_id.as_ref()
77                    && let Some(mid) = properties.info.model_id.as_ref()
78                {
79                    self.provider_id = Some(pid.clone());
80                    self.model_id = Some(mid.clone());
81                    self.context_limit = context_limit_lookup(pid, mid);
82                    // Recompute threshold if this event didn't carry tokens
83                    if properties.info.tokens.is_none() {
84                        self.recompute_flag();
85                    }
86                }
87
88                // Extract token usage
89                if let Some(tokens) = &properties.info.tokens {
90                    self.observe_tokens(tokens);
91                }
92            }
93            Event::MessagePartUpdated { properties } => {
94                // Check for StepFinish with token info
95                if let Some(part) = properties.part.as_ref()
96                    && let Part::StepFinish {
97                        tokens: Some(tokens),
98                        ..
99                    } = part
100                {
101                    self.observe_tokens(tokens);
102                }
103            }
104            _ => {}
105        }
106    }
107
108    /// Observe token usage and update threshold flag.
109    pub fn observe_tokens(&mut self, tokens: &TokenUsage) {
110        self.latest_input_tokens = Some(tokens.input);
111        self.latest_tokens = Some(tokens.clone());
112        self.recompute_flag();
113    }
114
115    pub fn to_log_token_usage(&self) -> (Option<agentic_logging::TokenUsage>, bool) {
116        let Some(tokens) = &self.latest_tokens else {
117            return (None, false);
118        };
119
120        let (prompt, prompt_saturated) = sat_u32(tokens.input);
121        let (completion, completion_saturated) = sat_u32(tokens.output);
122        let total_raw = tokens
123            .total
124            .unwrap_or_else(|| tokens.input.saturating_add(tokens.output));
125        let (total, total_saturated) = sat_u32(total_raw);
126        let (reasoning, reasoning_saturated) = sat_u32(tokens.reasoning);
127        let saturated =
128            prompt_saturated || completion_saturated || total_saturated || reasoning_saturated;
129
130        (
131            Some(agentic_logging::TokenUsage {
132                prompt,
133                completion,
134                total,
135                reasoning_tokens: (tokens.reasoning > 0).then_some(reasoning),
136            }),
137            saturated,
138        )
139    }
140
141    /// Recompute the `compaction_needed` flag based on current state.
142    fn recompute_flag(&mut self) {
143        if let (Some(input), Some(limit)) = (self.latest_input_tokens, self.context_limit)
144            && limit > 0
145        {
146            let ratio = input as f64 / limit as f64;
147            if ratio >= self.threshold {
148                self.compaction_needed = true;
149                tracing::info!(
150                    "Context limit threshold reached: {}/{} ({:.1}%)",
151                    input,
152                    limit,
153                    ratio * 100.0
154                );
155            }
156        }
157    }
158}
159
// Helpers that are compiled only for the test build.
#[cfg(test)]
impl TokenTracker {
    /// Clears the compaction flag and forgets the recorded token counts.
    ///
    /// Provider, model and context limit are left untouched.
    pub fn reset_after_compaction(&mut self) {
        self.latest_tokens = None;
        self.latest_input_tokens = None;
        self.compaction_needed = false;
    }

    /// Returns input tokens divided by the context limit.
    ///
    /// Yields `None` when either value is unknown or the limit is zero. The
    /// result can exceed 1.0 when input tokens are larger than the limit.
    pub fn usage_ratio(&self) -> Option<f64> {
        let input = self.latest_input_tokens?;
        let limit = self.context_limit?;
        if limit == 0 {
            return None;
        }
        Some(input as f64 / limit as f64)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use opencode_rs::types::event::MessagePartEventProps;
    use opencode_rs::types::event::MessageUpdatedProps;
    use opencode_rs::types::message::MessageInfo;
    use opencode_rs::types::message::MessageTime;

    /// Limit lookup stub: every provider/model pair has a 1000-token context.
    fn limit_1000(_provider: &str, _model: &str) -> Option<u64> {
        Some(1000)
    }

    /// Builds a usage record with every count spelled out.
    fn mk_usage(total: Option<u64>, input: u64, output: u64, reasoning: u64) -> TokenUsage {
        TokenUsage {
            total,
            input,
            output,
            reasoning,
            cache: None,
            extra: serde_json::Value::Null,
        }
    }

    /// Builds a usage record that only reports input tokens.
    fn mk_token_usage(input: u64) -> TokenUsage {
        mk_usage(None, input, 0, 0)
    }

    /// Builds a `MessageUpdated` event for an assistant message.
    fn mk_message_updated(
        provider_id: Option<&str>,
        model_id: Option<&str>,
        tokens: Option<TokenUsage>,
    ) -> Event {
        let info = MessageInfo {
            id: "msg-1".to_string(),
            session_id: None,
            role: "assistant".to_string(),
            time: MessageTime {
                created: 0,
                completed: None,
            },
            agent: None,
            variant: None,
            format: None,
            model: None,
            system: None,
            tools: std::collections::HashMap::new(),
            parent_id: None,
            model_id: model_id.map(str::to_string),
            provider_id: provider_id.map(str::to_string),
            path: None,
            cost: None,
            tokens,
            structured: None,
            finish: None,
            extra: serde_json::Value::Null,
        };

        Event::MessageUpdated {
            properties: Box::new(MessageUpdatedProps {
                info,
                extra: serde_json::Value::Null,
            }),
        }
    }

    /// Builds a `MessagePartUpdated` event wrapping a `StepFinish` part.
    fn mk_message_part_step_finish(tokens: Option<TokenUsage>) -> Event {
        let part = Part::StepFinish {
            id: None,
            reason: "done".to_string(),
            snapshot: None,
            cost: 0.0,
            tokens,
        };

        Event::MessagePartUpdated {
            properties: Box::new(MessagePartEventProps {
                session_id: None,
                message_id: None,
                index: None,
                part: Some(part),
                delta: None,
                extra: serde_json::Value::Null,
            }),
        }
    }

    #[test]
    fn triggers_compaction_at_80_percent() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(1000);

        // Just under the threshold (79.9%): flag stays down.
        t.latest_input_tokens = Some(799);
        t.recompute_flag();
        assert!(!t.compaction_needed);

        // Exactly on the threshold (80.0%): flag goes up.
        t.latest_input_tokens = Some(800);
        t.recompute_flag();
        assert!(t.compaction_needed);
    }

    #[test]
    fn does_not_trigger_without_limit() {
        let mut t = TokenTracker::new();
        t.latest_input_tokens = Some(10000);
        t.recompute_flag();
        assert!(!t.compaction_needed);
    }

    #[test]
    fn reset_clears_flag() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(100);
        t.latest_input_tokens = Some(90);
        t.recompute_flag();
        assert!(t.compaction_needed);

        t.reset_after_compaction();
        assert!(!t.compaction_needed);
        assert!(t.latest_input_tokens.is_none());
    }

    #[test]
    fn usage_ratio_calculation() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(1000);
        t.latest_input_tokens = Some(500);

        assert_eq!(t.usage_ratio(), Some(0.5));
    }

    #[test]
    fn observe_event_tokens_first_limit_later_triggers_compaction() {
        let mut t = TokenTracker::new();

        // Usage shows up first (StepFinish); no limit is known yet.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        t.observe_event(&tokens_event, limit_1000);
        assert!(!t.compaction_needed);

        // Model identity follows in a message that carries no usage.
        let model_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        t.observe_event(&model_event, limit_1000);

        // 800 / 1000 = 80%, so the flag must now be raised.
        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_limit_first_tokens_later_triggers_compaction() {
        let mut t = TokenTracker::new();

        // Model identity shows up first; nothing to compare against yet.
        let model_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        t.observe_event(&model_event, limit_1000);
        assert!(!t.compaction_needed);

        // Usage follows.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        t.observe_event(&tokens_event, limit_1000);

        // 800 / 1000 = 80%, so the flag must now be raised.
        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_combined_message_updated_event_triggers_compaction() {
        let mut t = TokenTracker::new();

        // One event carries both the model identity and the usage.
        let combined = mk_message_updated(
            Some("provider-1"),
            Some("model-1"),
            Some(mk_token_usage(800)),
        );
        t.observe_event(&combined, limit_1000);

        // 800 / 1000 = 80%, so the flag must be raised.
        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_tokens_without_any_limit_does_not_trigger_compaction() {
        // The lookup is never consulted because no model identity arrives.
        let mut t = TokenTracker::new();

        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(10_000)));
        t.observe_event(&tokens_event, limit_1000);

        // Without a context limit the flag must stay down.
        assert!(!t.compaction_needed);
        assert_eq!(t.context_limit, None);
    }

    #[test]
    fn to_log_token_usage_preserves_values_without_saturation() {
        let mut t = TokenTracker::new();
        t.observe_tokens(&mk_usage(Some(30), 10, 20, 5));

        let (usage, saturated) = t.to_log_token_usage();
        let usage = usage.expect("usage should be present");
        assert!(!saturated);
        assert_eq!(usage.prompt, 10);
        assert_eq!(usage.completion, 20);
        assert_eq!(usage.total, 30);
        assert_eq!(usage.reasoning_tokens, Some(5));
    }

    #[test]
    fn to_log_token_usage_saturates_large_values() {
        let mut t = TokenTracker::new();
        t.observe_tokens(&mk_usage(Some(u64::MAX), u64::MAX, u64::MAX, u64::MAX));

        let (usage, saturated) = t.to_log_token_usage();
        let usage = usage.expect("usage should be present");
        assert!(saturated);
        assert_eq!(usage.prompt, u32::MAX);
        assert_eq!(usage.completion, u32::MAX);
        assert_eq!(usage.total, u32::MAX);
        assert_eq!(usage.reasoning_tokens, Some(u32::MAX));
    }
}