1use serde::{Deserialize, Serialize};
40use std::collections::VecDeque;
41
/// Upper bound on the number of events a `ContextTrace` keeps; once it is
/// reached, the oldest events are evicted first.
const MAX_EVENTS: usize = 1_000;
44
/// A single entry recorded in a [`ContextTrace`].
///
/// Every variant carries a `tokens` count; `ContextTrace::log_event` adds it
/// to the trace's running total, which is compared against the trace budget.
/// The doc line on each variant gives the string that [`ContextEvent::label`]
/// returns for it.
///
/// NOTE(review): the `*_preview` field names suggest truncated excerpts, but
/// nothing in this type enforces truncation — confirm at the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContextEvent {
    /// Labelled `"system_prompt"`.
    SystemPrompt { content: String, tokens: usize },
    /// Labelled `"grep_result"`.
    GrepResult {
        pattern: String,
        // presumably the number of hits for `pattern` — confirm at call sites
        matches: usize,
        tokens: usize,
    },
    /// Labelled `"llm_query_result"`.
    LlmQueryResult {
        query: String,
        response_preview: String,
        tokens: usize,
    },
    /// Labelled `"assistant_code"`.
    AssistantCode { code: String, tokens: usize },
    /// Labelled `"execution_output"`.
    ExecutionOutput { output: String, tokens: usize },
    /// Labelled `"final"`.
    Final { answer: String, tokens: usize },
    /// Labelled `"tool_call"`.
    ToolCall {
        name: String,
        arguments_preview: String,
        tokens: usize,
    },
    /// Labelled `"tool_result"`.
    ToolResult {
        tool_call_id: String,
        result_preview: String,
        tokens: usize,
    },
}
81
82impl ContextEvent {
83 pub fn tokens(&self) -> usize {
85 match self {
86 Self::SystemPrompt { tokens, .. } => *tokens,
87 Self::GrepResult { tokens, .. } => *tokens,
88 Self::LlmQueryResult { tokens, .. } => *tokens,
89 Self::AssistantCode { tokens, .. } => *tokens,
90 Self::ExecutionOutput { tokens, .. } => *tokens,
91 Self::Final { tokens, .. } => *tokens,
92 Self::ToolCall { tokens, .. } => *tokens,
93 Self::ToolResult { tokens, .. } => *tokens,
94 }
95 }
96
97 pub fn label(&self) -> &'static str {
99 match self {
100 Self::SystemPrompt { .. } => "system_prompt",
101 Self::GrepResult { .. } => "grep_result",
102 Self::LlmQueryResult { .. } => "llm_query_result",
103 Self::AssistantCode { .. } => "assistant_code",
104 Self::ExecutionOutput { .. } => "execution_output",
105 Self::Final { .. } => "final",
106 Self::ToolCall { .. } => "tool_call",
107 Self::ToolResult { .. } => "tool_result",
108 }
109 }
110}
111
/// Bounded log of [`ContextEvent`]s together with the token budget they are
/// counted against.
///
/// NOTE(review): deriving `Deserialize` lets a payload set `total_tokens`
/// independently of `events`, so the invariant maintained by
/// `ContextTrace::log_event` (total == sum of retained event tokens) is not
/// checked on load — consider validating after deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTrace {
    // Token budget; `is_over_budget` reports `total_tokens > max_tokens`.
    max_tokens: usize,
    // Retained events, oldest at the front; `log_event` pops from the front
    // to keep at most `MAX_EVENTS` entries.
    events: VecDeque<ContextEvent>,
    // Token sum of the retained events; `log_event` adds new tokens and
    // subtracts the tokens of evicted events.
    total_tokens: usize,
    // Counter advanced by `next_iteration`.
    iteration: usize,
}
124
125impl ContextTrace {
126 pub fn new(max_tokens: usize) -> Self {
128 Self {
129 max_tokens,
130 events: VecDeque::with_capacity(64),
131 total_tokens: 0,
132 iteration: 0,
133 }
134 }
135
136 pub fn log_event(&mut self, event: ContextEvent) {
138 self.total_tokens += event.tokens();
139
140 while self.events.len() >= MAX_EVENTS {
142 if let Some(evicted) = self.events.pop_front() {
143 self.total_tokens = self.total_tokens.saturating_sub(evicted.tokens());
144 }
145 }
146
147 self.events.push_back(event);
148 }
149
150 pub fn next_iteration(&mut self) {
152 self.iteration += 1;
153 }
154
155 pub fn iteration(&self) -> usize {
157 self.iteration
158 }
159
160 pub fn total_tokens(&self) -> usize {
162 self.total_tokens
163 }
164
165 pub fn remaining_tokens(&self) -> usize {
167 self.max_tokens.saturating_sub(self.total_tokens)
168 }
169
170 pub fn budget_used_percent(&self) -> f32 {
172 if self.max_tokens == 0 {
173 0.0
174 } else {
175 let total = self.total_tokens as u64;
178 let max = self.max_tokens as u64;
179 let tenths = (total.saturating_mul(1000) + (max / 2)).saturating_div(max);
180 (tenths as f32) / 10.0
181 }
182 }
183
184 pub fn is_over_budget(&self) -> bool {
186 self.total_tokens > self.max_tokens
187 }
188
189 pub fn events(&self) -> &VecDeque<ContextEvent> {
191 &self.events
192 }
193
194 pub fn events_of_type(&self, label: &str) -> Vec<&ContextEvent> {
196 self.events.iter().filter(|e| e.label() == label).collect()
197 }
198
199 pub fn summary(&self) -> ContextTraceSummary {
201 let mut event_counts = std::collections::HashMap::new();
202 let mut event_tokens = std::collections::HashMap::new();
203
204 for event in &self.events {
205 let label = event.label().to_string();
206 *event_counts.entry(label.clone()).or_insert(0) += 1;
207 *event_tokens.entry(label).or_insert(0) += event.tokens();
208 }
209
210 ContextTraceSummary {
211 total_tokens: self.total_tokens,
212 max_tokens: self.max_tokens,
213 budget_used_percent: self.budget_used_percent(),
214 iteration: self.iteration,
215 event_counts,
216 event_tokens,
217 events_len: self.events.len(),
218 }
219 }
220
221 pub fn estimate_tokens(text: &str) -> usize {
225 (text.chars().count() / 4).max(1)
226 }
227
228 pub fn event_from_text(event: ContextEvent, text: &str) -> ContextEvent {
230 let tokens = Self::estimate_tokens(text);
231 match event {
232 ContextEvent::SystemPrompt { content, .. } => {
233 ContextEvent::SystemPrompt { content, tokens }
234 }
235 ContextEvent::GrepResult {
236 pattern, matches, ..
237 } => ContextEvent::GrepResult {
238 pattern,
239 matches,
240 tokens,
241 },
242 ContextEvent::LlmQueryResult {
243 query,
244 response_preview,
245 ..
246 } => ContextEvent::LlmQueryResult {
247 query,
248 response_preview,
249 tokens,
250 },
251 ContextEvent::AssistantCode { code, .. } => {
252 ContextEvent::AssistantCode { code, tokens }
253 }
254 ContextEvent::ExecutionOutput { output, .. } => {
255 ContextEvent::ExecutionOutput { output, tokens }
256 }
257 ContextEvent::Final { answer, .. } => ContextEvent::Final { answer, tokens },
258 ContextEvent::ToolCall {
259 name,
260 arguments_preview,
261 ..
262 } => ContextEvent::ToolCall {
263 name,
264 arguments_preview,
265 tokens,
266 },
267 ContextEvent::ToolResult {
268 tool_call_id,
269 result_preview,
270 ..
271 } => ContextEvent::ToolResult {
272 tool_call_id,
273 result_preview,
274 tokens,
275 },
276 }
277 }
278}
279
/// Point-in-time snapshot of a [`ContextTrace`], as built by
/// `ContextTrace::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTraceSummary {
    /// Token sum of the retained events at snapshot time.
    pub total_tokens: usize,
    /// Token budget of the trace.
    pub max_tokens: usize,
    /// Budget usage in percent, rounded to one decimal place.
    pub budget_used_percent: f32,
    /// Iteration counter at snapshot time.
    pub iteration: usize,
    /// Number of retained events per label (see `ContextEvent::label`).
    pub event_counts: std::collections::HashMap<String, usize>,
    /// Token sum of the retained events per label.
    pub event_tokens: std::collections::HashMap<String, usize>,
    /// Number of retained events.
    pub events_len: usize,
}
298
299impl ContextTraceSummary {
300 pub fn format(&self) -> String {
302 let mut lines = vec![
303 format!("Context Trace Summary (iteration {})", self.iteration),
304 format!(
305 " Budget: {}/{} tokens ({:.1}%)",
306 self.total_tokens, self.max_tokens, self.budget_used_percent
307 ),
308 format!(" Events: {}", self.events_len),
309 ];
310
311 if !self.event_counts.is_empty() {
312 lines.push(" By type:".to_string());
313 for (label, count) in &self.event_counts {
314 let tokens = self.event_tokens.get(label).copied().unwrap_or(0);
315 lines.push(format!(
316 " {}: {} events, {} tokens",
317 label, count, tokens
318 ));
319 }
320 }
321
322 lines.join("\n")
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SystemPrompt` event with the given content and token count.
    fn prompt(content: &str, tokens: usize) -> ContextEvent {
        ContextEvent::SystemPrompt {
            content: content.to_string(),
            tokens,
        }
    }

    /// Builds a `GrepResult` event.
    fn grep(pattern: &str, matches: usize, tokens: usize) -> ContextEvent {
        ContextEvent::GrepResult {
            pattern: pattern.to_string(),
            matches,
            tokens,
        }
    }

    #[test]
    fn new_trace_has_zero_tokens() {
        let trace = ContextTrace::new(1000);
        assert_eq!(trace.total_tokens(), 0);
        assert_eq!(trace.remaining_tokens(), 1000);
        assert!(trace.events().is_empty());
    }

    #[test]
    fn log_event_adds_tokens() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        assert_eq!(trace.total_tokens(), 100);
        assert_eq!(trace.remaining_tokens(), 900);
    }

    #[test]
    fn budget_exceeded_check() {
        let mut trace = ContextTrace::new(100);
        trace.log_event(prompt("test", 150));
        assert!(trace.is_over_budget());
        // Remaining budget clamps at zero instead of underflowing.
        assert_eq!(trace.remaining_tokens(), 0);
    }

    #[test]
    fn exactly_at_budget_is_not_over_budget() {
        let mut trace = ContextTrace::new(100);
        trace.log_event(prompt("test", 100));
        assert!(!trace.is_over_budget());
        assert_eq!(trace.remaining_tokens(), 0);
    }

    #[test]
    fn event_type_filtering() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        trace.log_event(grep("async", 5, 50));
        trace.log_event(prompt("test2", 75));

        let system_events = trace.events_of_type("system_prompt");
        assert_eq!(system_events.len(), 2);

        let grep_events = trace.events_of_type("grep_result");
        assert_eq!(grep_events.len(), 1);

        assert!(trace.events_of_type("no_such_label").is_empty());
    }

    #[test]
    fn summary_statistics() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        trace.log_event(grep("async", 5, 50));
        trace.next_iteration();

        let summary = trace.summary();
        assert_eq!(summary.total_tokens, 150);
        assert_eq!(summary.iteration, 1);
        assert_eq!(summary.budget_used_percent, 15.0);
        assert_eq!(summary.events_len, 2);
        assert_eq!(summary.event_counts.get("system_prompt"), Some(&1));
        assert_eq!(summary.event_tokens.get("grep_result"), Some(&50));
    }

    #[test]
    fn zero_budget_reports_zero_percent() {
        let trace = ContextTrace::new(0);
        assert_eq!(trace.budget_used_percent(), 0.0);
    }

    #[test]
    fn estimate_tokens_approximation() {
        assert_eq!(ContextTrace::estimate_tokens("test"), 1);
        assert_eq!(ContextTrace::estimate_tokens("test test test test"), 4);
        assert_eq!(ContextTrace::estimate_tokens("12345678"), 2);
        // Never estimates below one token, even for empty text.
        assert_eq!(ContextTrace::estimate_tokens(""), 1);
        // Counts chars, not bytes: 8 chars (16 bytes) -> 2 tokens.
        assert_eq!(ContextTrace::estimate_tokens("éééééééé"), 2);
    }

    #[test]
    fn event_from_text_replaces_only_tokens() {
        let event = ContextTrace::event_from_text(grep("async", 3, 0), "12345678");
        assert_eq!(event.tokens(), 2);
        match event {
            ContextEvent::GrepResult {
                pattern, matches, ..
            } => {
                assert_eq!(pattern, "async");
                assert_eq!(matches, 3);
            }
            other => panic!("variant changed: {:?}", other),
        }
    }

    #[test]
    fn evict_old_events_when_full() {
        let mut trace = ContextTrace::new(10000);

        for i in 0..(MAX_EVENTS + 100) {
            trace.log_event(prompt(&format!("event {}", i), 1));
        }

        // Exactly MAX_EVENTS are retained, and evicted tokens are subtracted.
        assert_eq!(trace.events().len(), MAX_EVENTS);
        assert_eq!(trace.total_tokens(), MAX_EVENTS);

        // The oldest 100 events were the ones evicted.
        match trace.events().front() {
            Some(ContextEvent::SystemPrompt { content, .. }) => {
                assert_eq!(content.as_str(), "event 100")
            }
            other => panic!("unexpected front event: {:?}", other),
        }
    }
}