use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// Upper bound on how many events a `ContextTrace` retains; once this many are
/// stored, the oldest events are evicted first to make room for new ones.
const MAX_EVENTS: usize = 1_000;
/// A single entry recorded in a [`ContextTrace`].
///
/// Every variant carries a `tokens` count; [`ContextTrace::log_event`] adds it to
/// the trace's running total. [`ContextEvent::label`] maps each variant to a
/// snake_case name used for filtering and summaries.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`; for serde formats
/// that encode enum variants by index, reordering variants would change the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContextEvent {
/// A system prompt; `content` presumably holds the prompt text — TODO confirm.
SystemPrompt {
content: String,
tokens: usize,
},
/// Result of a grep-style search: the `pattern` searched for and the number of
/// `matches` (presumably; the search itself happens outside this file).
GrepResult {
pattern: String,
matches: usize,
tokens: usize,
},
/// Result of an LLM query. `response_preview` looks like a shortened copy of the
/// response, but nothing in this file truncates it — confirm with the caller.
LlmQueryResult {
query: String,
response_preview: String,
tokens: usize,
},
/// Code produced by the assistant.
AssistantCode {
code: String,
tokens: usize,
},
/// Output captured from an execution (presumably of `AssistantCode`; nothing
/// here links the two).
ExecutionOutput {
output: String,
tokens: usize,
},
/// Presumably the final answer of a run.
Final {
answer: String,
tokens: usize,
},
/// A tool invocation identified by `name`, with a preview of its arguments.
ToolCall {
name: String,
arguments_preview: String,
tokens: usize,
},
/// A tool's result. `tool_call_id` presumably identifies the originating call.
/// NOTE(review): `ToolCall` carries no id field, so this file cannot pair them.
ToolResult {
tool_call_id: String,
result_preview: String,
tokens: usize,
},
}
impl ContextEvent {
pub fn tokens(&self) -> usize {
match self {
Self::SystemPrompt { tokens, .. } => *tokens,
Self::GrepResult { tokens, .. } => *tokens,
Self::LlmQueryResult { tokens, .. } => *tokens,
Self::AssistantCode { tokens, .. } => *tokens,
Self::ExecutionOutput { tokens, .. } => *tokens,
Self::Final { tokens, .. } => *tokens,
Self::ToolCall { tokens, .. } => *tokens,
Self::ToolResult { tokens, .. } => *tokens,
}
}
pub fn label(&self) -> &'static str {
match self {
Self::SystemPrompt { .. } => "system_prompt",
Self::GrepResult { .. } => "grep_result",
Self::LlmQueryResult { .. } => "llm_query_result",
Self::AssistantCode { .. } => "assistant_code",
Self::ExecutionOutput { .. } => "execution_output",
Self::Final { .. } => "final",
Self::ToolCall { .. } => "tool_call",
Self::ToolResult { .. } => "tool_result",
}
}
}
/// Tracks logged [`ContextEvent`]s and the tokens they are charged, measured
/// against a token budget.
///
/// NOTE(review): the derived `Deserialize` restores fields verbatim, so a
/// deserialized trace is only consistent if `total_tokens` matched its events
/// when it was serialized; nothing here recomputes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTrace {
/// Token budget; compared against `total_tokens` by `remaining_tokens`,
/// `budget_used_percent` and `is_over_budget`.
max_tokens: usize,
/// Retained events, oldest at the front; `log_event` caps the length at
/// `MAX_EVENTS` by evicting from the front.
events: VecDeque<ContextEvent>,
/// Sum of `tokens()` over the events currently held in `events` (tokens of
/// evicted events are subtracted in `log_event`).
total_tokens: usize,
/// Counter advanced by `next_iteration`; starts at 0.
iteration: usize,
}
impl ContextTrace {
    /// Creates an empty trace whose token budget is `max_tokens`.
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            events: VecDeque::with_capacity(64),
            total_tokens: 0,
            iteration: 0,
        }
    }

    /// Appends `event` and charges its tokens to the running total.
    ///
    /// The log keeps at most [`MAX_EVENTS`] events. When it is full, the oldest
    /// events are evicted first and their tokens are subtracted, so
    /// [`total_tokens`](Self::total_tokens) covers only the events still retained.
    pub fn log_event(&mut self, event: ContextEvent) {
        // Evict before charging the new event so the running total never briefly
        // includes tokens that are about to leave. A loop (rather than a single
        // pop) also trims a trace deserialized with more than MAX_EVENTS entries.
        while self.events.len() >= MAX_EVENTS {
            match self.events.pop_front() {
                Some(evicted) => {
                    self.total_tokens = self.total_tokens.saturating_sub(evicted.tokens());
                }
                // Only reachable if MAX_EVENTS were 0; prevents an endless loop.
                None => break,
            }
        }
        self.total_tokens = self.total_tokens.saturating_add(event.tokens());
        self.events.push_back(event);
    }

    /// Advances the iteration counter by one.
    pub fn next_iteration(&mut self) {
        self.iteration += 1;
    }

    /// Returns the current iteration number (0 for a new trace).
    pub fn iteration(&self) -> usize {
        self.iteration
    }

    /// Returns the tokens charged by the currently retained events.
    pub fn total_tokens(&self) -> usize {
        self.total_tokens
    }

    /// Returns the tokens left in the budget, or 0 once the budget is exhausted.
    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.total_tokens)
    }

    /// Returns `total_tokens` as a percentage of `max_tokens`.
    ///
    /// Returns `0.0` when `max_tokens` is 0, to avoid dividing by zero. Note that
    /// `is_over_budget` still reports `true` in that case once any tokens are logged.
    pub fn budget_used_percent(&self) -> f32 {
        if self.max_tokens == 0 {
            0.0
        } else {
            // Scale before dividing, in f64, so exact ratios stay exact. The old
            // `(total as f32 / max as f32) * 100.0` rounded twice in f32 and gave
            // 15.000001 instead of 15.0 for 150/1000.
            (self.total_tokens as f64 * 100.0 / self.max_tokens as f64) as f32
        }
    }

    /// Returns `true` when the retained events' tokens strictly exceed the budget.
    pub fn is_over_budget(&self) -> bool {
        self.total_tokens > self.max_tokens
    }

    /// Returns the retained events, oldest first.
    pub fn events(&self) -> &VecDeque<ContextEvent> {
        &self.events
    }

    /// Returns the retained events whose [`ContextEvent::label`] equals `label`,
    /// oldest first.
    pub fn events_of_type(&self, label: &str) -> Vec<&ContextEvent> {
        self.events
            .iter()
            .filter(|e| e.label() == label)
            .collect()
    }

    /// Builds a snapshot of the trace, including per-label event counts and
    /// token sums over the retained events.
    pub fn summary(&self) -> ContextTraceSummary {
        let mut event_counts: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        let mut event_tokens: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        for event in &self.events {
            let label = event.label();
            *event_counts.entry(label.to_string()).or_default() += 1;
            *event_tokens.entry(label.to_string()).or_default() += event.tokens();
        }
        ContextTraceSummary {
            total_tokens: self.total_tokens,
            max_tokens: self.max_tokens,
            budget_used_percent: self.budget_used_percent(),
            iteration: self.iteration,
            event_counts,
            event_tokens,
            events_len: self.events.len(),
        }
    }

    /// Roughly estimates the token count of `text` as one token per four
    /// `char`s (Unicode scalar values, not bytes).
    ///
    /// The count is rounded down and never below 1, so even empty text counts
    /// as one token.
    pub fn estimate_tokens(text: &str) -> usize {
        (text.chars().count() / 4).max(1)
    }

    /// Returns `event` with its `tokens` field replaced by
    /// [`estimate_tokens`](Self::estimate_tokens)`(text)`; all other fields are kept.
    pub fn event_from_text(mut event: ContextEvent, text: &str) -> ContextEvent {
        let estimated = Self::estimate_tokens(text);
        // Exhaustive on purpose: adding a variant must force an update here.
        match &mut event {
            ContextEvent::SystemPrompt { tokens, .. }
            | ContextEvent::GrepResult { tokens, .. }
            | ContextEvent::LlmQueryResult { tokens, .. }
            | ContextEvent::AssistantCode { tokens, .. }
            | ContextEvent::ExecutionOutput { tokens, .. }
            | ContextEvent::Final { tokens, .. }
            | ContextEvent::ToolCall { tokens, .. }
            | ContextEvent::ToolResult { tokens, .. } => *tokens = estimated,
        }
        event
    }
}
/// Point-in-time snapshot of a [`ContextTrace`], produced by `ContextTrace::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTraceSummary {
/// Tokens charged by the events retained at snapshot time.
pub total_tokens: usize,
/// The trace's token budget.
pub max_tokens: usize,
/// Copied from `ContextTrace::budget_used_percent` (0.0 when `max_tokens` is 0).
pub budget_used_percent: f32,
/// Iteration counter at snapshot time.
pub iteration: usize,
/// Number of retained events per `ContextEvent::label()` string.
pub event_counts: std::collections::HashMap<String, usize>,
/// Sum of `tokens()` per `ContextEvent::label()` string.
pub event_tokens: std::collections::HashMap<String, usize>,
/// Number of events retained in the trace at snapshot time.
pub events_len: usize,
}
impl ContextTraceSummary {
    /// Renders the summary as a multi-line, human-readable report.
    ///
    /// The per-type breakdown is sorted by label so the output is deterministic;
    /// `HashMap` iteration order is unspecified and can differ between runs.
    pub fn format(&self) -> String {
        let mut lines = vec![
            format!("Context Trace Summary (iteration {})", self.iteration),
            format!(
                " Budget: {}/{} tokens ({:.1}%)",
                self.total_tokens, self.max_tokens, self.budget_used_percent
            ),
            format!(" Events: {}", self.events_len),
        ];
        if !self.event_counts.is_empty() {
            lines.push(" By type:".to_string());
            let mut by_label: Vec<(&String, &usize)> = self.event_counts.iter().collect();
            // Labels are unique map keys, so an unstable sort is fully deterministic.
            by_label.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
            for (label, count) in by_label {
                // A label absent from `event_tokens` (e.g. a hand-built summary) shows 0.
                let tokens = self.event_tokens.get(label).copied().unwrap_or(0);
                lines.push(format!(" {}: {} events, {} tokens", label, count, tokens));
            }
        }
        lines.join("\n")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a system-prompt event with the given content and token count.
    fn prompt(content: &str, tokens: usize) -> ContextEvent {
        ContextEvent::SystemPrompt {
            content: content.to_string(),
            tokens,
        }
    }

    #[test]
    fn new_trace_has_zero_tokens() {
        let trace = ContextTrace::new(1000);
        assert_eq!(trace.total_tokens(), 0);
        assert_eq!(trace.remaining_tokens(), 1000);
        assert_eq!(trace.iteration(), 0);
    }

    #[test]
    fn log_event_adds_tokens() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        assert_eq!(trace.total_tokens(), 100);
        assert_eq!(trace.remaining_tokens(), 900);
    }

    #[test]
    fn budget_exceeded_check() {
        let mut trace = ContextTrace::new(100);
        trace.log_event(prompt("test", 150));
        assert!(trace.is_over_budget());
        assert_eq!(trace.remaining_tokens(), 0);
    }

    #[test]
    fn zero_budget_reports_zero_percent_but_over_budget() {
        let mut trace = ContextTrace::new(0);
        trace.log_event(prompt("test", 5));
        assert_eq!(trace.budget_used_percent(), 0.0);
        assert!(trace.is_over_budget());
        assert_eq!(trace.remaining_tokens(), 0);
    }

    #[test]
    fn event_type_filtering() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        trace.log_event(ContextEvent::GrepResult {
            pattern: "async".to_string(),
            matches: 5,
            tokens: 50,
        });
        trace.log_event(prompt("test2", 75));
        let system_events = trace.events_of_type("system_prompt");
        assert_eq!(system_events.len(), 2);
        let grep_events = trace.events_of_type("grep_result");
        assert_eq!(grep_events.len(), 1);
        assert!(trace.events_of_type("tool_call").is_empty());
    }

    #[test]
    fn summary_statistics() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        trace.log_event(ContextEvent::GrepResult {
            pattern: "async".to_string(),
            matches: 5,
            tokens: 50,
        });
        trace.next_iteration();
        let summary = trace.summary();
        assert_eq!(summary.total_tokens, 150);
        assert_eq!(summary.iteration, 1);
        assert_eq!(summary.events_len, 2);
        assert_eq!(summary.event_counts.get("system_prompt"), Some(&1));
        assert_eq!(summary.event_tokens.get("grep_result"), Some(&50));
        // f32 arithmetic is not exact; compare with a tolerance rather than `==`.
        assert!((summary.budget_used_percent - 15.0).abs() < 1e-4);
    }

    #[test]
    fn summary_format_lists_event_types() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(prompt("test", 100));
        let text = trace.summary().format();
        assert!(text.contains("Context Trace Summary (iteration 0)"));
        assert!(text.contains(" system_prompt: 1 events, 100 tokens"));
    }

    #[test]
    fn estimate_tokens_approximation() {
        assert_eq!(ContextTrace::estimate_tokens("test"), 1);
        assert_eq!(ContextTrace::estimate_tokens("test test test test"), 4);
        assert_eq!(ContextTrace::estimate_tokens("12345678"), 2);
        // Never below one token, even for empty text.
        assert_eq!(ContextTrace::estimate_tokens(""), 1);
        // Counts chars, not bytes: eight two-byte chars give 2 tokens, not 4.
        assert_eq!(ContextTrace::estimate_tokens("éééééééé"), 2);
    }

    #[test]
    fn event_from_text_replaces_token_count() {
        let event = ContextTrace::event_from_text(
            ContextEvent::GrepResult {
                pattern: "async".to_string(),
                matches: 3,
                tokens: 0,
            },
            "12345678",
        );
        assert_eq!(event.tokens(), 2);
        assert_eq!(event.label(), "grep_result");
        match event {
            ContextEvent::GrepResult { pattern, matches, .. } => {
                assert_eq!(pattern, "async");
                assert_eq!(matches, 3);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn evict_old_events_when_full() {
        let mut trace = ContextTrace::new(10000);
        for i in 0..(MAX_EVENTS + 100) {
            trace.log_event(prompt(&format!("event {}", i), 1));
        }
        assert_eq!(trace.events().len(), MAX_EVENTS);
        // Evicted events' tokens are subtracted, so only retained events count.
        assert_eq!(trace.total_tokens(), MAX_EVENTS);
        // The first 100 events were evicted oldest-first.
        match trace.events().front() {
            Some(ContextEvent::SystemPrompt { content, .. }) => {
                assert_eq!(content, "event 100");
            }
            other => panic!("unexpected front event: {:?}", other),
        }
    }
}