1use serde::{Deserialize, Serialize};
40use std::collections::VecDeque;
41
/// Upper bound on how many events a `ContextTrace` keeps; once reached, the
/// oldest events are evicted to make room for new ones.
const MAX_EVENTS: usize = 1_000;
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum ContextEvent {
48 SystemPrompt {
50 content: String,
51 tokens: usize,
52 },
53 GrepResult {
55 pattern: String,
56 matches: usize,
57 tokens: usize,
58 },
59 LlmQueryResult {
61 query: String,
62 response_preview: String,
63 tokens: usize,
64 },
65 AssistantCode {
67 code: String,
68 tokens: usize,
69 },
70 ExecutionOutput {
72 output: String,
73 tokens: usize,
74 },
75 Final {
77 answer: String,
78 tokens: usize,
79 },
80 ToolCall {
82 name: String,
83 arguments_preview: String,
84 tokens: usize,
85 },
86 ToolResult {
88 tool_call_id: String,
89 result_preview: String,
90 tokens: usize,
91 },
92}
93
94impl ContextEvent {
95 pub fn tokens(&self) -> usize {
97 match self {
98 Self::SystemPrompt { tokens, .. } => *tokens,
99 Self::GrepResult { tokens, .. } => *tokens,
100 Self::LlmQueryResult { tokens, .. } => *tokens,
101 Self::AssistantCode { tokens, .. } => *tokens,
102 Self::ExecutionOutput { tokens, .. } => *tokens,
103 Self::Final { tokens, .. } => *tokens,
104 Self::ToolCall { tokens, .. } => *tokens,
105 Self::ToolResult { tokens, .. } => *tokens,
106 }
107 }
108
109 pub fn label(&self) -> &'static str {
111 match self {
112 Self::SystemPrompt { .. } => "system_prompt",
113 Self::GrepResult { .. } => "grep_result",
114 Self::LlmQueryResult { .. } => "llm_query_result",
115 Self::AssistantCode { .. } => "assistant_code",
116 Self::ExecutionOutput { .. } => "execution_output",
117 Self::Final { .. } => "final",
118 Self::ToolCall { .. } => "tool_call",
119 Self::ToolResult { .. } => "tool_result",
120 }
121 }
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct ContextTrace {
127 max_tokens: usize,
129 events: VecDeque<ContextEvent>,
131 total_tokens: usize,
133 iteration: usize,
135}
136
137impl ContextTrace {
138 pub fn new(max_tokens: usize) -> Self {
140 Self {
141 max_tokens,
142 events: VecDeque::with_capacity(64),
143 total_tokens: 0,
144 iteration: 0,
145 }
146 }
147
148 pub fn log_event(&mut self, event: ContextEvent) {
150 self.total_tokens += event.tokens();
151
152 while self.events.len() >= MAX_EVENTS {
154 if let Some(evicted) = self.events.pop_front() {
155 self.total_tokens = self.total_tokens.saturating_sub(evicted.tokens());
156 }
157 }
158
159 self.events.push_back(event);
160 }
161
162 pub fn next_iteration(&mut self) {
164 self.iteration += 1;
165 }
166
167 pub fn iteration(&self) -> usize {
169 self.iteration
170 }
171
172 pub fn total_tokens(&self) -> usize {
174 self.total_tokens
175 }
176
177 pub fn remaining_tokens(&self) -> usize {
179 self.max_tokens.saturating_sub(self.total_tokens)
180 }
181
182 pub fn budget_used_percent(&self) -> f32 {
184 if self.max_tokens == 0 {
185 0.0
186 } else {
187 (self.total_tokens as f32 / self.max_tokens as f32) * 100.0
188 }
189 }
190
191 pub fn is_over_budget(&self) -> bool {
193 self.total_tokens > self.max_tokens
194 }
195
196 pub fn events(&self) -> &VecDeque<ContextEvent> {
198 &self.events
199 }
200
201 pub fn events_of_type(&self, label: &str) -> Vec<&ContextEvent> {
203 self.events
204 .iter()
205 .filter(|e| e.label() == label)
206 .collect()
207 }
208
209 pub fn summary(&self) -> ContextTraceSummary {
211 let mut event_counts = std::collections::HashMap::new();
212 let mut event_tokens = std::collections::HashMap::new();
213
214 for event in &self.events {
215 let label = event.label().to_string();
216 *event_counts.entry(label.clone()).or_insert(0) += 1;
217 *event_tokens.entry(label).or_insert(0) += event.tokens();
218 }
219
220 ContextTraceSummary {
221 total_tokens: self.total_tokens,
222 max_tokens: self.max_tokens,
223 budget_used_percent: self.budget_used_percent(),
224 iteration: self.iteration,
225 event_counts,
226 event_tokens,
227 events_len: self.events.len(),
228 }
229 }
230
231 pub fn estimate_tokens(text: &str) -> usize {
235 (text.chars().count() / 4).max(1)
236 }
237
238 pub fn event_from_text(event: ContextEvent, text: &str) -> ContextEvent {
240 let tokens = Self::estimate_tokens(text);
241 match event {
242 ContextEvent::SystemPrompt { content, .. } => {
243 ContextEvent::SystemPrompt { content, tokens }
244 }
245 ContextEvent::GrepResult { pattern, matches, .. } => {
246 ContextEvent::GrepResult { pattern, matches, tokens }
247 }
248 ContextEvent::LlmQueryResult { query, response_preview, .. } => {
249 ContextEvent::LlmQueryResult { query, response_preview, tokens }
250 }
251 ContextEvent::AssistantCode { code, .. } => {
252 ContextEvent::AssistantCode { code, tokens }
253 }
254 ContextEvent::ExecutionOutput { output, .. } => {
255 ContextEvent::ExecutionOutput { output, tokens }
256 }
257 ContextEvent::Final { answer, .. } => {
258 ContextEvent::Final { answer, tokens }
259 }
260 ContextEvent::ToolCall { name, arguments_preview, .. } => {
261 ContextEvent::ToolCall { name, arguments_preview, tokens }
262 }
263 ContextEvent::ToolResult { tool_call_id, result_preview, .. } => {
264 ContextEvent::ToolResult { tool_call_id, result_preview, tokens }
265 }
266 }
267 }
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ContextTraceSummary {
273 pub total_tokens: usize,
275 pub max_tokens: usize,
277 pub budget_used_percent: f32,
279 pub iteration: usize,
281 pub event_counts: std::collections::HashMap<String, usize>,
283 pub event_tokens: std::collections::HashMap<String, usize>,
285 pub events_len: usize,
287}
288
289impl ContextTraceSummary {
290 pub fn format(&self) -> String {
292 let mut lines = vec![
293 format!("Context Trace Summary (iteration {})", self.iteration),
294 format!(" Budget: {}/{} tokens ({:.1}%)",
295 self.total_tokens, self.max_tokens, self.budget_used_percent),
296 format!(" Events: {}", self.events_len),
297 ];
298
299 if !self.event_counts.is_empty() {
300 lines.push(" By type:".to_string());
301 for (label, count) in &self.event_counts {
302 let tokens = self.event_tokens.get(label).copied().unwrap_or(0);
303 lines.push(format!(" {}: {} events, {} tokens", label, count, tokens));
304 }
305 }
306
307 lines.join("\n")
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `SystemPrompt` event carrying `tokens` tokens.
    fn prompt(content: &str, tokens: usize) -> ContextEvent {
        ContextEvent::SystemPrompt {
            content: content.to_string(),
            tokens,
        }
    }

    /// Shorthand for a `GrepResult` event carrying `tokens` tokens.
    fn grep(pattern: &str, matches: usize, tokens: usize) -> ContextEvent {
        ContextEvent::GrepResult {
            pattern: pattern.to_string(),
            matches,
            tokens,
        }
    }

    #[test]
    fn new_trace_has_zero_tokens() {
        let trace = ContextTrace::new(1000);
        assert_eq!(trace.total_tokens(), 0);
        assert_eq!(trace.remaining_tokens(), 1000);
        assert_eq!(trace.iteration(), 0);
        assert!(trace.events().is_empty());
    }

    #[test]
    fn log_event_adds_tokens() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        assert_eq!(trace.total_tokens(), 100);
        assert_eq!(trace.remaining_tokens(), 900);
    }

    #[test]
    fn budget_exceeded_check() {
        let mut trace = ContextTrace::new(100);
        trace.log_event(prompt("test", 150));
        assert!(trace.is_over_budget());
        // Remaining budget saturates at zero instead of underflowing.
        assert_eq!(trace.remaining_tokens(), 0);
    }

    #[test]
    fn exactly_at_budget_is_not_over() {
        let mut trace = ContextTrace::new(100);
        trace.log_event(prompt("test", 100));
        assert!(!trace.is_over_budget());
        assert_eq!(trace.remaining_tokens(), 0);
    }

    #[test]
    fn zero_budget_reports_zero_percent() {
        let mut trace = ContextTrace::new(0);
        trace.log_event(prompt("test", 10));
        assert_eq!(trace.budget_used_percent(), 0.0);
        assert!(trace.is_over_budget());
    }

    #[test]
    fn event_type_filtering() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        trace.log_event(grep("async", 5, 50));
        trace.log_event(prompt("test2", 75));

        let system_events = trace.events_of_type("system_prompt");
        assert_eq!(system_events.len(), 2);

        let grep_events = trace.events_of_type("grep_result");
        assert_eq!(grep_events.len(), 1);

        assert!(trace.events_of_type("tool_call").is_empty());
    }

    #[test]
    fn summary_statistics() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        trace.log_event(grep("async", 5, 50));
        trace.next_iteration();

        let summary = trace.summary();
        assert_eq!(summary.total_tokens, 150);
        assert_eq!(summary.iteration, 1);
        assert_eq!(summary.events_len, 2);
        assert_eq!(summary.event_counts.get("system_prompt"), Some(&1));
        assert_eq!(summary.event_tokens.get("grep_result"), Some(&50));
        // Compare floats with a tolerance; exact f32 equality is brittle.
        assert!((summary.budget_used_percent - 15.0).abs() < 1e-4);
    }

    #[test]
    fn estimate_tokens_approximation() {
        assert_eq!(ContextTrace::estimate_tokens("test"), 1);
        assert_eq!(ContextTrace::estimate_tokens("test test test test"), 4);
        assert_eq!(ContextTrace::estimate_tokens("12345678"), 2);
        // Empty text still counts as one token.
        assert_eq!(ContextTrace::estimate_tokens(""), 1);
        // Counts chars, not bytes: 8 two-byte chars -> 2 tokens.
        assert_eq!(ContextTrace::estimate_tokens("éééééééé"), 2);
    }

    #[test]
    fn event_from_text_replaces_only_tokens() {
        let event = ContextTrace::event_from_text(grep("async", 5, 0), "12345678");
        assert_eq!(event.tokens(), 2);
        match event {
            ContextEvent::GrepResult { pattern, matches, .. } => {
                assert_eq!(pattern, "async");
                assert_eq!(matches, 5);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn evict_old_events_when_full() {
        let mut trace = ContextTrace::new(10_000);

        for i in 0..(MAX_EVENTS + 100) {
            trace.log_event(prompt(&format!("event {}", i), 1));
        }

        assert_eq!(trace.events().len(), MAX_EVENTS);
        // Tokens of evicted events are no longer counted.
        assert_eq!(trace.total_tokens(), MAX_EVENTS);
        // The first 100 events were evicted, oldest first.
        match trace.events().front() {
            Some(ContextEvent::SystemPrompt { content, .. }) => {
                assert_eq!(content, "event 100");
            }
            other => panic!("unexpected front event: {:?}", other),
        }
    }
}