Skip to main content

opencode_orchestrator_mcp/
token_tracker.rs

//! Token tracking for context limit safeguard.
//!
//! Monitors token usage during session runs and detects when input tokens reach a
//! configurable fraction of the model's context limit (80% by default), so that
//! server-side summarization can be triggered.
5
6use opencode_rs::types::event::Event;
7use opencode_rs::types::message::Part;
8use opencode_rs::types::message::TokenUsage;
9
/// Running token-usage state for a single session run.
///
/// Used to decide when the model's context window is full enough that
/// server-side summarization should be requested.
#[derive(Debug, Clone)]
pub struct TokenTracker {
    /// Provider ID from the latest message event that carried both a provider and a model ID.
    pub provider_id: Option<String>,
    /// Model ID from the latest message event that carried both a provider and a model ID.
    pub model_id: Option<String>,
    /// Context-window size of the current model, from the cached limits (`None` if unknown).
    pub context_limit: Option<u64>,
    /// Input-token count from the most recently observed token usage.
    pub latest_input_tokens: Option<u64>,
    /// Becomes `true` once usage reaches `threshold`; compaction/summarization is needed.
    pub compaction_needed: bool,
    /// Fraction of `context_limit` (nominally 0.0 – 1.0) at which summarization triggers.
    threshold: f64,
}
26
27impl Default for TokenTracker {
28    fn default() -> Self {
29        Self::with_threshold(0.80)
30    }
31}
32
33impl TokenTracker {
34    /// Create a new token tracker with a custom compaction threshold.
35    ///
36    /// The threshold should be between 0.0 and 1.0 (e.g., 0.80 for 80%).
37    pub fn with_threshold(threshold: f64) -> Self {
38        Self {
39            provider_id: None,
40            model_id: None,
41            context_limit: None,
42            latest_input_tokens: None,
43            compaction_needed: false,
44            threshold,
45        }
46    }
47
48    /// Create a new token tracker with default threshold (80%).
49    #[cfg(test)]
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Observe an SSE event and update token tracking.
55    ///
56    /// The `context_limit_lookup` function is called to look up the context limit
57    /// for a given (`provider_id`, `model_id`) pair from the cached limits.
58    pub fn observe_event<F>(&mut self, ev: &Event, context_limit_lookup: F)
59    where
60        F: Fn(&str, &str) -> Option<u64>,
61    {
62        match ev {
63            Event::MessageUpdated { properties } => {
64                // Extract provider/model info
65                if let Some(pid) = properties.info.provider_id.as_ref()
66                    && let Some(mid) = properties.info.model_id.as_ref()
67                {
68                    self.provider_id = Some(pid.clone());
69                    self.model_id = Some(mid.clone());
70                    self.context_limit = context_limit_lookup(pid, mid);
71                    // Recompute threshold if this event didn't carry tokens
72                    if properties.info.tokens.is_none() {
73                        self.recompute_flag();
74                    }
75                }
76
77                // Extract token usage
78                if let Some(tokens) = &properties.info.tokens {
79                    self.observe_tokens(tokens);
80                }
81            }
82            Event::MessagePartUpdated { properties } => {
83                // Check for StepFinish with token info
84                if let Some(part) = properties.part.as_ref()
85                    && let Part::StepFinish {
86                        tokens: Some(tokens),
87                        ..
88                    } = part
89                {
90                    self.observe_tokens(tokens);
91                }
92            }
93            _ => {}
94        }
95    }
96
97    /// Observe token usage and update threshold flag.
98    pub fn observe_tokens(&mut self, tokens: &TokenUsage) {
99        self.latest_input_tokens = Some(tokens.input);
100        self.recompute_flag();
101    }
102
103    /// Recompute the `compaction_needed` flag based on current state.
104    fn recompute_flag(&mut self) {
105        if let (Some(input), Some(limit)) = (self.latest_input_tokens, self.context_limit)
106            && limit > 0
107        {
108            let ratio = input as f64 / limit as f64;
109            if ratio >= self.threshold {
110                self.compaction_needed = true;
111                tracing::info!(
112                    "Context limit threshold reached: {}/{} ({:.1}%)",
113                    input,
114                    limit,
115                    ratio * 100.0
116                );
117            }
118        }
119    }
120}
121
// Helpers used only by the unit tests.
#[cfg(test)]
impl TokenTracker {
    /// Simulate a completed compaction/summarization.
    ///
    /// Forgets the last input-token count and clears the flag. The model info and
    /// context limit are kept.
    pub fn reset_after_compaction(&mut self) {
        self.latest_input_tokens = None;
        self.compaction_needed = false;
    }

    /// Fraction of the context limit used by the latest input tokens (may exceed 1.0).
    ///
    /// Returns `None` when either value is unknown or the limit is zero.
    pub fn usage_ratio(&self) -> Option<f64> {
        let input = self.latest_input_tokens?;
        let limit = self.context_limit.filter(|&l| l > 0)?;
        Some(input as f64 / limit as f64)
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;
    use opencode_rs::types::event::MessagePartEventProps;
    use opencode_rs::types::event::MessageUpdatedProps;
    use opencode_rs::types::message::MessageInfo;
    use opencode_rs::types::message::MessageTime;

    /// Token usage with only the input count set.
    fn mk_token_usage(input: u64) -> TokenUsage {
        TokenUsage {
            total: None,
            input,
            output: 0,
            reasoning: 0,
            cache: None,
            extra: serde_json::Value::Null,
        }
    }

    /// Assistant `MessageUpdated` event with optional provider/model IDs and tokens.
    fn mk_message_updated(
        provider_id: Option<&str>,
        model_id: Option<&str>,
        tokens: Option<TokenUsage>,
    ) -> Event {
        Event::MessageUpdated {
            properties: MessageUpdatedProps {
                info: MessageInfo {
                    id: "msg-1".to_string(),
                    session_id: None,
                    role: "assistant".to_string(),
                    time: MessageTime {
                        created: 0,
                        completed: None,
                    },
                    agent: None,
                    variant: None,
                    format: None,
                    model: None,
                    system: None,
                    tools: vec![],
                    parent_id: None,
                    model_id: model_id.map(str::to_string),
                    provider_id: provider_id.map(str::to_string),
                    path: None,
                    cost: None,
                    tokens,
                    structured: None,
                    finish: None,
                    extra: serde_json::Value::Null,
                },
                extra: serde_json::Value::Null,
            },
        }
    }

    /// `MessagePartUpdated` event carrying a `StepFinish` part with optional tokens.
    fn mk_message_part_step_finish(tokens: Option<TokenUsage>) -> Event {
        Event::MessagePartUpdated {
            properties: Box::new(MessagePartEventProps {
                session_id: None,
                message_id: None,
                index: None,
                part: Some(Part::StepFinish {
                    id: None,
                    reason: "done".to_string(),
                    snapshot: None,
                    cost: 0.0,
                    tokens,
                }),
                delta: None,
                extra: serde_json::Value::Null,
            }),
        }
    }

    #[test]
    fn triggers_compaction_at_80_percent() {
        let mut tracker = TokenTracker::new();
        tracker.context_limit = Some(1000);

        // 79.9% - should not trigger
        tracker.latest_input_tokens = Some(799);
        tracker.recompute_flag();
        assert!(!tracker.compaction_needed);

        // 80.0% - should trigger
        tracker.latest_input_tokens = Some(800);
        tracker.recompute_flag();
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn custom_threshold_is_respected() {
        let mut tracker = TokenTracker::with_threshold(0.5);
        tracker.context_limit = Some(1000);

        tracker.observe_tokens(&mk_token_usage(499));
        assert!(!tracker.compaction_needed);

        tracker.observe_tokens(&mk_token_usage(500));
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn does_not_trigger_without_limit() {
        let mut tracker = TokenTracker::new();
        tracker.latest_input_tokens = Some(10000);
        tracker.recompute_flag();
        assert!(!tracker.compaction_needed);
    }

    #[test]
    fn zero_context_limit_never_triggers() {
        let lookup = |_: &str, _: &str| Some(0);
        let mut tracker = TokenTracker::new();

        let ev = mk_message_updated(Some("provider-1"), Some("model-1"), Some(mk_token_usage(10)));
        tracker.observe_event(&ev, lookup);

        assert_eq!(tracker.context_limit, Some(0));
        assert!(!tracker.compaction_needed);
        assert_eq!(tracker.usage_ratio(), None);
    }

    #[test]
    fn compaction_flag_stays_set_when_usage_drops() {
        let mut tracker = TokenTracker::new();
        tracker.context_limit = Some(1000);

        tracker.observe_tokens(&mk_token_usage(900));
        assert!(tracker.compaction_needed);

        // Lower usage is recorded, but the request is not withdrawn.
        tracker.observe_tokens(&mk_token_usage(100));
        assert!(tracker.compaction_needed);
        assert_eq!(tracker.latest_input_tokens, Some(100));
    }

    #[test]
    fn reset_clears_flag() {
        let mut tracker = TokenTracker::new();
        tracker.context_limit = Some(100);
        tracker.latest_input_tokens = Some(90);
        tracker.recompute_flag();
        assert!(tracker.compaction_needed);

        tracker.reset_after_compaction();
        assert!(!tracker.compaction_needed);
        assert!(tracker.latest_input_tokens.is_none());
    }

    #[test]
    fn usage_ratio_calculation() {
        let mut tracker = TokenTracker::new();
        tracker.context_limit = Some(1000);
        tracker.latest_input_tokens = Some(500);

        assert_eq!(tracker.usage_ratio(), Some(0.5));
    }

    #[test]
    fn observe_event_tokens_first_limit_later_triggers_compaction() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // Tokens arrive first via StepFinish, but no context_limit yet
        let ev_tokens = mk_message_part_step_finish(Some(mk_token_usage(800)));
        tracker.observe_event(&ev_tokens, lookup);
        assert!(!tracker.compaction_needed); // Can't trigger without limit

        // Model info arrives later without tokens
        let ev_limit = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        tracker.observe_event(&ev_limit, lookup);

        // Should now trigger because 800/1000 = 80%
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_limit_first_tokens_later_triggers_compaction() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // Model info arrives first
        let ev_limit = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        tracker.observe_event(&ev_limit, lookup);
        assert!(!tracker.compaction_needed); // No tokens yet

        // Tokens arrive later
        let ev_tokens = mk_message_part_step_finish(Some(mk_token_usage(800)));
        tracker.observe_event(&ev_tokens, lookup);

        // Should trigger because 800/1000 = 80%
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_combined_message_updated_event_triggers_compaction() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // Single event with both model info and tokens
        let ev = mk_message_updated(
            Some("provider-1"),
            Some("model-1"),
            Some(mk_token_usage(800)),
        );
        tracker.observe_event(&ev, lookup);

        // Should trigger because 800/1000 = 80%
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_model_switch_rechecks_against_new_limit() {
        let lookup = |_: &str, model: &str| -> Option<u64> {
            if model == "small" {
                Some(1000)
            } else {
                Some(100_000)
            }
        };
        let mut tracker = TokenTracker::new();

        // 800 tokens is well below 80% of the large model's limit.
        let ev_large = mk_message_updated(
            Some("provider-1"),
            Some("large"),
            Some(mk_token_usage(800)),
        );
        tracker.observe_event(&ev_large, lookup);
        assert!(!tracker.compaction_needed);

        // Switching to the small model re-checks the same tokens against the new limit.
        let ev_small = mk_message_updated(Some("provider-1"), Some("small"), None);
        tracker.observe_event(&ev_small, lookup);
        assert_eq!(tracker.model_id.as_deref(), Some("small"));
        assert_eq!(tracker.context_limit, Some(1000));
        assert!(tracker.compaction_needed);
    }

    #[test]
    fn observe_event_provider_without_model_does_not_set_limit() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // Provider without model: model info is ignored, but tokens are still recorded.
        let ev = mk_message_updated(Some("provider-1"), None, Some(mk_token_usage(900)));
        tracker.observe_event(&ev, lookup);

        assert_eq!(tracker.provider_id, None);
        assert_eq!(tracker.context_limit, None);
        assert_eq!(tracker.latest_input_tokens, Some(900));
        assert!(!tracker.compaction_needed);
    }

    #[test]
    fn observe_event_step_finish_without_tokens_is_ignored() {
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        let ev = mk_message_part_step_finish(None);
        tracker.observe_event(&ev, lookup);

        assert_eq!(tracker.latest_input_tokens, None);
        assert!(!tracker.compaction_needed);
    }

    #[test]
    fn observe_event_tokens_without_any_limit_does_not_trigger_compaction() {
        // Lookup won't be called since no model info event arrives
        let lookup = |_: &str, _: &str| Some(1000);
        let mut tracker = TokenTracker::new();

        // Tokens arrive but no model info ever comes
        let ev_tokens = mk_message_part_step_finish(Some(mk_token_usage(10_000)));
        tracker.observe_event(&ev_tokens, lookup);

        // Should NOT trigger because context_limit is None
        assert!(!tracker.compaction_needed);
        assert_eq!(tracker.context_limit, None);
    }
}