Skip to main content

codetether_agent/rlm/
context_trace.rs

//! Context tracing for RLM iterations.
//!
//! Tracks the token budget and the context events of each RLM iteration to
//! enable analysis of context-window usage and optimization opportunities.
//!
//! # Events Traced
//!
//! - SystemPrompt: Initial system message with context summary
//! - GrepResult: Results from grep operations
//! - LlmQueryResult: Results from sub-LLM calls
//! - AssistantCode: Code generated by the assistant
//! - ExecutionOutput: Output from code execution
//! - ToolCall: A tool call
//! - ToolResult: The result of a tool call
//! - Final: Final answer returned
//!
//! # Usage
//!
//! ```ignore
//! use codetether_agent::rlm::context_trace::{ContextTrace, ContextEvent};
//!
//! let mut trace = ContextTrace::new(max_tokens);
//!
//! // Log events as they occur
//! trace.log_event(ContextEvent::SystemPrompt {
//!     content: system_prompt.clone(),
//!     tokens: count_tokens(&system_prompt),
//! });
//!
//! trace.log_event(ContextEvent::GrepResult {
//!     pattern: "async fn".to_string(),
//!     matches: 5,
//!     tokens: 150,
//! });
//!
//! // Get summary statistics
//! let stats = trace.summary();
//! println!("Total tokens: {}", stats.total_tokens);
//! println!("Budget used: {:.1}%", stats.budget_used_percent);
//! ```
38
39use serde::{Deserialize, Serialize};
40use std::collections::VecDeque;
41
/// Maximum number of events to keep in the trace buffer.
///
/// Once the buffer holds this many events, `ContextTrace::log_event` drops the
/// oldest entries before appending a new one.
const MAX_EVENTS: usize = 1000;
44
/// A context event logged during RLM iteration.
///
/// Every variant carries a `tokens` field holding the token count attributed
/// to that event; the remaining fields are descriptive payload supplied by the
/// caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContextEvent {
    /// System prompt with context summary
    SystemPrompt {
        /// Text of the system prompt.
        content: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Result from a grep operation
    GrepResult {
        /// The pattern that was searched for.
        pattern: String,
        /// Number of matches (presumably for `pattern`; confirm against caller).
        matches: usize,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Result from a sub-LLM query
    LlmQueryResult {
        /// The query sent to the sub-LLM.
        query: String,
        /// Preview of the response (presumably truncated by the caller; confirm).
        response_preview: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Code generated by the assistant
    AssistantCode {
        /// The generated code.
        code: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Output from code execution
    ExecutionOutput {
        /// The captured output.
        output: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Final answer
    Final {
        /// The final answer text.
        answer: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Tool call
    ToolCall {
        /// Name of the tool being called.
        name: String,
        /// Preview of the call arguments (presumably truncated by the caller; confirm).
        arguments_preview: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
    /// Tool result
    ToolResult {
        /// Identifier of the tool call (presumably the call this result answers; confirm).
        tool_call_id: String,
        /// Preview of the result (presumably truncated by the caller; confirm).
        result_preview: String,
        /// Token count attributed to this event.
        tokens: usize,
    },
}
93
94impl ContextEvent {
95    /// Get the token count for this event.
96    pub fn tokens(&self) -> usize {
97        match self {
98            Self::SystemPrompt { tokens, .. } => *tokens,
99            Self::GrepResult { tokens, .. } => *tokens,
100            Self::LlmQueryResult { tokens, .. } => *tokens,
101            Self::AssistantCode { tokens, .. } => *tokens,
102            Self::ExecutionOutput { tokens, .. } => *tokens,
103            Self::Final { tokens, .. } => *tokens,
104            Self::ToolCall { tokens, .. } => *tokens,
105            Self::ToolResult { tokens, .. } => *tokens,
106        }
107    }
108    
109    /// Get a human-readable label for this event type.
110    pub fn label(&self) -> &'static str {
111        match self {
112            Self::SystemPrompt { .. } => "system_prompt",
113            Self::GrepResult { .. } => "grep_result",
114            Self::LlmQueryResult { .. } => "llm_query_result",
115            Self::AssistantCode { .. } => "assistant_code",
116            Self::ExecutionOutput { .. } => "execution_output",
117            Self::Final { .. } => "final",
118            Self::ToolCall { .. } => "tool_call",
119            Self::ToolResult { .. } => "tool_result",
120        }
121    }
122}
123
/// Context trace for a single RLM analysis run.
///
/// Holds a bounded buffer of [`ContextEvent`]s together with running totals
/// that are maintained by `log_event` and `next_iteration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTrace {
    /// Maximum token budget, as passed to `new`.
    max_tokens: usize,
    /// Events logged during the run, oldest first; `log_event` caps the
    /// length at `MAX_EVENTS`.
    events: VecDeque<ContextEvent>,
    /// Total tokens accumulated by the events currently held in `events`
    /// (the tokens of evicted events are subtracted again).
    total_tokens: usize,
    /// Iteration number; starts at 0 and is incremented by `next_iteration`.
    iteration: usize,
}
136
137impl ContextTrace {
138    /// Create a new context trace with the given token budget.
139    pub fn new(max_tokens: usize) -> Self {
140        Self {
141            max_tokens,
142            events: VecDeque::with_capacity(64),
143            total_tokens: 0,
144            iteration: 0,
145        }
146    }
147
148    /// Log a context event.
149    pub fn log_event(&mut self, event: ContextEvent) {
150        self.total_tokens += event.tokens();
151        
152        // Evict oldest events if buffer is full
153        while self.events.len() >= MAX_EVENTS {
154            if let Some(evicted) = self.events.pop_front() {
155                self.total_tokens = self.total_tokens.saturating_sub(evicted.tokens());
156            }
157        }
158        
159        self.events.push_back(event);
160    }
161
162    /// Increment the iteration counter.
163    pub fn next_iteration(&mut self) {
164        self.iteration += 1;
165    }
166
167    /// Get the current iteration number.
168    pub fn iteration(&self) -> usize {
169        self.iteration
170    }
171
172    /// Get the total tokens used.
173    pub fn total_tokens(&self) -> usize {
174        self.total_tokens
175    }
176
177    /// Get the remaining token budget.
178    pub fn remaining_tokens(&self) -> usize {
179        self.max_tokens.saturating_sub(self.total_tokens)
180    }
181
182    /// Get the percentage of budget used.
183    pub fn budget_used_percent(&self) -> f32 {
184        if self.max_tokens == 0 {
185            0.0
186        } else {
187            (self.total_tokens as f32 / self.max_tokens as f32) * 100.0
188        }
189    }
190
191    /// Check if the budget is exceeded.
192    pub fn is_over_budget(&self) -> bool {
193        self.total_tokens > self.max_tokens
194    }
195
196    /// Get all events.
197    pub fn events(&self) -> &VecDeque<ContextEvent> {
198        &self.events
199    }
200
201    /// Get events by type.
202    pub fn events_of_type(&self, label: &str) -> Vec<&ContextEvent> {
203        self.events
204            .iter()
205            .filter(|e| e.label() == label)
206            .collect()
207    }
208
209    /// Get a summary of the trace.
210    pub fn summary(&self) -> ContextTraceSummary {
211        let mut event_counts = std::collections::HashMap::new();
212        let mut event_tokens = std::collections::HashMap::new();
213        
214        for event in &self.events {
215            let label = event.label().to_string();
216            *event_counts.entry(label.clone()).or_insert(0) += 1;
217            *event_tokens.entry(label).or_insert(0) += event.tokens();
218        }
219        
220        ContextTraceSummary {
221            total_tokens: self.total_tokens,
222            max_tokens: self.max_tokens,
223            budget_used_percent: self.budget_used_percent(),
224            iteration: self.iteration,
225            event_counts,
226            event_tokens,
227            events_len: self.events.len(),
228        }
229    }
230
231    /// Estimate token count from text (rough approximation).
232    ///
233    /// Uses ~4 characters per token as a rough estimate.
234    pub fn estimate_tokens(text: &str) -> usize {
235        (text.chars().count() / 4).max(1)
236    }
237
238    /// Create an event from text with automatic token estimation.
239    pub fn event_from_text(event: ContextEvent, text: &str) -> ContextEvent {
240        let tokens = Self::estimate_tokens(text);
241        match event {
242            ContextEvent::SystemPrompt { content, .. } => {
243                ContextEvent::SystemPrompt { content, tokens }
244            }
245            ContextEvent::GrepResult { pattern, matches, .. } => {
246                ContextEvent::GrepResult { pattern, matches, tokens }
247            }
248            ContextEvent::LlmQueryResult { query, response_preview, .. } => {
249                ContextEvent::LlmQueryResult { query, response_preview, tokens }
250            }
251            ContextEvent::AssistantCode { code, .. } => {
252                ContextEvent::AssistantCode { code, tokens }
253            }
254            ContextEvent::ExecutionOutput { output, .. } => {
255                ContextEvent::ExecutionOutput { output, tokens }
256            }
257            ContextEvent::Final { answer, .. } => {
258                ContextEvent::Final { answer, tokens }
259            }
260            ContextEvent::ToolCall { name, arguments_preview, .. } => {
261                ContextEvent::ToolCall { name, arguments_preview, tokens }
262            }
263            ContextEvent::ToolResult { tool_call_id, result_preview, .. } => {
264                ContextEvent::ToolResult { tool_call_id, result_preview, tokens }
265            }
266        }
267    }
268}
269
/// Summary statistics for a context trace.
///
/// A snapshot built by `ContextTrace::summary`; the per-type maps are keyed by
/// the labels returned from `ContextEvent::label`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTraceSummary {
    /// Total tokens used
    pub total_tokens: usize,
    /// Maximum token budget
    pub max_tokens: usize,
    /// Percentage of budget used
    pub budget_used_percent: f32,
    /// Current iteration number
    pub iteration: usize,
    /// Count of events by type
    pub event_counts: std::collections::HashMap<String, usize>,
    /// Tokens by event type
    pub event_tokens: std::collections::HashMap<String, usize>,
    /// Total number of events
    pub events_len: usize,
}
288
289impl ContextTraceSummary {
290    /// Format as a human-readable string.
291    pub fn format(&self) -> String {
292        let mut lines = vec![
293            format!("Context Trace Summary (iteration {})", self.iteration),
294            format!("  Budget: {}/{} tokens ({:.1}%)", 
295                self.total_tokens, self.max_tokens, self.budget_used_percent),
296            format!("  Events: {}", self.events_len),
297        ];
298        
299        if !self.event_counts.is_empty() {
300            lines.push("  By type:".to_string());
301            for (label, count) in &self.event_counts {
302                let tokens = self.event_tokens.get(label).copied().unwrap_or(0);
303                lines.push(format!("    {}: {} events, {} tokens", label, count, tokens));
304            }
305        }
306        
307        lines.join("\n")
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh trace has used nothing and has its whole budget left.
    #[test]
    fn new_trace_has_zero_tokens() {
        let trace = ContextTrace::new(1000);
        assert_eq!(trace.total_tokens(), 0);
        assert_eq!(trace.remaining_tokens(), 1000);
    }

    /// Logging an event adds its tokens to the total and shrinks the remainder.
    #[test]
    fn log_event_adds_tokens() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(ContextEvent::SystemPrompt {
            content: "test".to_string(),
            tokens: 100,
        });
        assert_eq!(trace.total_tokens(), 100);
        assert_eq!(trace.remaining_tokens(), 900);
    }

    /// Exceeding the budget is reported and the remainder bottoms out at zero.
    #[test]
    fn budget_exceeded_check() {
        let mut trace = ContextTrace::new(100);
        trace.log_event(ContextEvent::SystemPrompt {
            content: "test".to_string(),
            tokens: 150,
        });
        assert!(trace.is_over_budget());
        assert_eq!(trace.remaining_tokens(), 0);
    }

    /// A zero budget reports 0% used rather than dividing by zero.
    #[test]
    fn zero_budget_reports_zero_percent() {
        let trace = ContextTrace::new(0);
        assert_eq!(trace.budget_used_percent(), 0.0);
    }

    /// `events_of_type` selects events by their label.
    #[test]
    fn event_type_filtering() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(ContextEvent::SystemPrompt {
            content: "test".to_string(),
            tokens: 100,
        });
        trace.log_event(ContextEvent::GrepResult {
            pattern: "async".to_string(),
            matches: 5,
            tokens: 50,
        });
        trace.log_event(ContextEvent::SystemPrompt {
            content: "test2".to_string(),
            tokens: 75,
        });

        let system_events = trace.events_of_type("system_prompt");
        assert_eq!(system_events.len(), 2);

        let grep_events = trace.events_of_type("grep_result");
        assert_eq!(grep_events.len(), 1);
    }

    /// The summary mirrors the trace's totals, iteration and per-type counts.
    #[test]
    fn summary_statistics() {
        let mut trace = ContextTrace::new(1000);
        trace.log_event(ContextEvent::SystemPrompt {
            content: "test".to_string(),
            tokens: 100,
        });
        trace.log_event(ContextEvent::GrepResult {
            pattern: "async".to_string(),
            matches: 5,
            tokens: 50,
        });
        trace.next_iteration();

        let summary = trace.summary();
        assert_eq!(summary.total_tokens, 150);
        assert_eq!(summary.iteration, 1);
        assert_eq!(summary.events_len, 2);
        assert_eq!(summary.event_counts.get("system_prompt"), Some(&1));
        assert_eq!(summary.event_tokens.get("grep_result"), Some(&50));
        // Floating-point percentage: compare with a tolerance, never exactly.
        assert!(
            (summary.budget_used_percent - 15.0).abs() < 1e-4,
            "unexpected budget_used_percent: {}",
            summary.budget_used_percent
        );
    }

    /// Token estimation is ~4 chars per token, rounded down, with a floor of 1.
    #[test]
    fn estimate_tokens_approximation() {
        assert_eq!(ContextTrace::estimate_tokens("test"), 1);
        assert_eq!(ContextTrace::estimate_tokens("test test test test"), 4);
        assert_eq!(ContextTrace::estimate_tokens("12345678"), 2);
        // The floor applies to empty text as well.
        assert_eq!(ContextTrace::estimate_tokens(""), 1);
    }

    /// Overflowing the buffer evicts the oldest events and their tokens.
    #[test]
    fn evict_old_events_when_full() {
        let mut trace = ContextTrace::new(10000);

        // Add more than MAX_EVENTS, one token each.
        for i in 0..(MAX_EVENTS + 100) {
            trace.log_event(ContextEvent::SystemPrompt {
                content: format!("event {}", i),
                tokens: 1,
            });
        }

        // The buffer is capped at exactly MAX_EVENTS...
        assert_eq!(trace.events.len(), MAX_EVENTS);

        // ...and the total only counts the retained events.
        assert_eq!(trace.total_tokens, MAX_EVENTS);

        // The first 100 events were evicted, so the oldest survivor is #100.
        match trace.events.front() {
            Some(ContextEvent::SystemPrompt { content, .. }) => {
                assert_eq!(content.as_str(), "event 100");
            }
            other => panic!("unexpected oldest event: {:?}", other),
        }
    }
}
413}