opencode_orchestrator_mcp/
token_tracker.rs1use opencode_rs::types::event::Event;
7use opencode_rs::types::message::Part;
8use opencode_rs::types::message::TokenUsage;
9
/// Tracks the latest input-token count of a session against the context
/// limit of the model in use, and records when compaction becomes necessary.
#[derive(Clone, Debug)]
pub struct TokenTracker {
    /// Provider id copied from the most recent `MessageUpdated` event that
    /// carried both a provider id and a model id.
    pub provider_id: Option<String>,
    /// Model id copied from that same event.
    pub model_id: Option<String>,
    /// Result of the context-limit lookup for the current provider/model
    /// pair; `None` until a lookup has run or when the lookup had no answer.
    pub context_limit: Option<u64>,
    /// `input` value of the most recently observed token usage.
    pub latest_input_tokens: Option<u64>,
    /// Becomes `true` once `latest_input_tokens / context_limit` reaches
    /// `threshold`. Outside of test-only helpers this type never clears it.
    pub compaction_needed: bool,
    /// Usage ratio at or above which `compaction_needed` is raised.
    threshold: f64,
}
26
27impl Default for TokenTracker {
28 fn default() -> Self {
29 Self::with_threshold(0.80)
30 }
31}
32
33impl TokenTracker {
34 pub fn with_threshold(threshold: f64) -> Self {
38 Self {
39 provider_id: None,
40 model_id: None,
41 context_limit: None,
42 latest_input_tokens: None,
43 compaction_needed: false,
44 threshold,
45 }
46 }
47
48 #[cfg(test)]
50 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn observe_event<F>(&mut self, ev: &Event, context_limit_lookup: F)
59 where
60 F: Fn(&str, &str) -> Option<u64>,
61 {
62 match ev {
63 Event::MessageUpdated { properties } => {
64 if let Some(pid) = properties.info.provider_id.as_ref()
66 && let Some(mid) = properties.info.model_id.as_ref()
67 {
68 self.provider_id = Some(pid.clone());
69 self.model_id = Some(mid.clone());
70 self.context_limit = context_limit_lookup(pid, mid);
71 if properties.info.tokens.is_none() {
73 self.recompute_flag();
74 }
75 }
76
77 if let Some(tokens) = &properties.info.tokens {
79 self.observe_tokens(tokens);
80 }
81 }
82 Event::MessagePartUpdated { properties } => {
83 if let Some(part) = properties.part.as_ref()
85 && let Part::StepFinish {
86 tokens: Some(tokens),
87 ..
88 } = part
89 {
90 self.observe_tokens(tokens);
91 }
92 }
93 _ => {}
94 }
95 }
96
97 pub fn observe_tokens(&mut self, tokens: &TokenUsage) {
99 self.latest_input_tokens = Some(tokens.input);
100 self.recompute_flag();
101 }
102
103 fn recompute_flag(&mut self) {
105 if let (Some(input), Some(limit)) = (self.latest_input_tokens, self.context_limit)
106 && limit > 0
107 {
108 let ratio = input as f64 / limit as f64;
109 if ratio >= self.threshold {
110 self.compaction_needed = true;
111 tracing::info!(
112 "Context limit threshold reached: {}/{} ({:.1}%)",
113 input,
114 limit,
115 ratio * 100.0
116 );
117 }
118 }
119 }
120}
121
#[cfg(test)]
impl TokenTracker {
    /// Forgets the last observed input-token count and clears the compaction
    /// flag. Provider, model, context limit and threshold are left as they are.
    pub fn reset_after_compaction(&mut self) {
        self.latest_input_tokens = None;
        self.compaction_needed = false;
    }

    /// Returns `latest_input_tokens / context_limit`, or `None` when either
    /// value is unknown or the limit is zero.
    pub fn usage_ratio(&self) -> Option<f64> {
        let used = self.latest_input_tokens?;
        let capacity = self.context_limit.filter(|&limit| limit > 0)?;
        Some(used as f64 / capacity as f64)
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;
    use opencode_rs::types::event::MessagePartEventProps;
    use opencode_rs::types::event::MessageUpdatedProps;
    use opencode_rs::types::message::MessageInfo;
    use opencode_rs::types::message::MessageTime;

    /// Lookup stub that reports a limit of 1000 for every provider/model pair.
    fn limit_of_1000(_provider: &str, _model: &str) -> Option<u64> {
        Some(1000)
    }

    /// Token usage with the given `input` count and every other field empty.
    fn mk_token_usage(input: u64) -> TokenUsage {
        TokenUsage {
            total: None,
            input,
            output: 0,
            reasoning: 0,
            cache: None,
            extra: serde_json::Value::Null,
        }
    }

    /// `MessageUpdated` event for an assistant message with the given
    /// provider/model ids and optional token usage.
    fn mk_message_updated(
        provider_id: Option<&str>,
        model_id: Option<&str>,
        tokens: Option<TokenUsage>,
    ) -> Event {
        let info = MessageInfo {
            id: "msg-1".to_string(),
            session_id: None,
            role: "assistant".to_string(),
            time: MessageTime {
                created: 0,
                completed: None,
            },
            agent: None,
            variant: None,
            format: None,
            model: None,
            system: None,
            tools: Vec::new(),
            parent_id: None,
            model_id: model_id.map(String::from),
            provider_id: provider_id.map(String::from),
            path: None,
            cost: None,
            tokens,
            structured: None,
            finish: None,
            extra: serde_json::Value::Null,
        };

        Event::MessageUpdated {
            properties: MessageUpdatedProps {
                info,
                extra: serde_json::Value::Null,
            },
        }
    }

    /// `MessagePartUpdated` event whose part is a `StepFinish` carrying the
    /// given optional token usage.
    fn mk_message_part_step_finish(tokens: Option<TokenUsage>) -> Event {
        let part = Part::StepFinish {
            id: None,
            reason: "done".to_string(),
            snapshot: None,
            cost: 0.0,
            tokens,
        };

        Event::MessagePartUpdated {
            properties: Box::new(MessagePartEventProps {
                session_id: None,
                message_id: None,
                index: None,
                part: Some(part),
                delta: None,
                extra: serde_json::Value::Null,
            }),
        }
    }

    #[test]
    fn triggers_compaction_at_80_percent() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(1000);

        // Just below the threshold: flag stays down.
        t.latest_input_tokens = Some(799);
        t.recompute_flag();
        assert!(!t.compaction_needed);

        // Exactly at the threshold: flag goes up.
        t.latest_input_tokens = Some(800);
        t.recompute_flag();
        assert!(t.compaction_needed);
    }

    #[test]
    fn does_not_trigger_without_limit() {
        let mut t = TokenTracker::new();
        t.latest_input_tokens = Some(10000);
        t.recompute_flag();
        assert!(!t.compaction_needed);
    }

    #[test]
    fn reset_clears_flag() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(100);
        t.latest_input_tokens = Some(90);
        t.recompute_flag();
        assert!(t.compaction_needed);

        t.reset_after_compaction();
        assert!(!t.compaction_needed);
        assert!(t.latest_input_tokens.is_none());
    }

    #[test]
    fn usage_ratio_calculation() {
        let mut t = TokenTracker::new();
        t.context_limit = Some(1000);
        t.latest_input_tokens = Some(500);

        assert_eq!(t.usage_ratio(), Some(0.5));
    }

    #[test]
    fn observe_event_tokens_first_limit_later_triggers_compaction() {
        let mut t = TokenTracker::new();

        // Tokens arrive while the limit is still unknown.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        t.observe_event(&tokens_event, limit_of_1000);
        assert!(!t.compaction_needed);

        // Learning the limit afterwards must re-evaluate the stored tokens.
        let limit_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        t.observe_event(&limit_event, limit_of_1000);

        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_limit_first_tokens_later_triggers_compaction() {
        let mut t = TokenTracker::new();

        // The limit is known before any tokens have been seen.
        let limit_event = mk_message_updated(Some("provider-1"), Some("model-1"), None);
        t.observe_event(&limit_event, limit_of_1000);
        assert!(!t.compaction_needed);

        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(800)));
        t.observe_event(&tokens_event, limit_of_1000);

        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_combined_message_updated_event_triggers_compaction() {
        let mut t = TokenTracker::new();

        // One event carries both the model identity and the token usage.
        let combined = mk_message_updated(
            Some("provider-1"),
            Some("model-1"),
            Some(mk_token_usage(800)),
        );
        t.observe_event(&combined, limit_of_1000);

        assert!(t.compaction_needed);
    }

    #[test]
    fn observe_event_tokens_without_any_limit_does_not_trigger_compaction() {
        let mut t = TokenTracker::new();

        // No `MessageUpdated` event is observed, so the lookup never runs.
        let tokens_event = mk_message_part_step_finish(Some(mk_token_usage(10_000)));
        t.observe_event(&tokens_event, limit_of_1000);

        assert!(!t.compaction_needed);
        assert_eq!(t.context_limit, None);
    }
}