// codetether_agent/rlm/context_trace.rs
1//! Context tracing for RLM iterations.
2//!
3//! Tracks token budget and context events per RLM iteration to enable
4//! analysis of context window usage and optimization opportunities.
5//!
6//! # Events Traced
7//!
8//! - SystemPrompt: Initial system message with context summary
9//! - GrepResult: Results from grep operations
10//! - LlmQueryResult: Results from sub-LLM calls
11//! - AssistantCode: Code generated by the assistant
12//! - Final: Final answer returned
13//!
14//! # Usage
15//!
16//! ```ignore
17//! use codetether_agent::rlm::context_trace::{ContextTrace, ContextEvent};
18//!
19//! let mut trace = ContextTrace::new(max_tokens);
20//!
21//! // Log events as they occur
22//! trace.log_event(ContextEvent::SystemPrompt {
23//!     content: system_prompt.clone(),
24//!     tokens: count_tokens(&system_prompt),
25//! });
26//!
27//! trace.log_event(ContextEvent::GrepResult {
28//!     pattern: "async fn".to_string(),
29//!     matches: 5,
30//!     tokens: 150,
31//! });
32//!
33//! // Get summary statistics
34//! let stats = trace.summary();
35//! println!("Total tokens: {}", stats.total_tokens);
36//! println!("Budget used: {:.1}%", stats.budget_used_percent);
37//! ```
38
39use serde::{Deserialize, Serialize};
40use std::collections::VecDeque;
41
/// Maximum number of events to keep in the trace buffer.
/// When the buffer is full, `ContextTrace::log_event` evicts the oldest
/// events and subtracts their tokens from the running total.
const MAX_EVENTS: usize = 1000;
44
/// A context event logged during RLM iteration.
///
/// Every variant carries a `tokens` field: the (possibly estimated) number
/// of context-window tokens the event contributes. Variants are serialized
/// via serde for trace export/import.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContextEvent {
    /// System prompt with context summary
    SystemPrompt { content: String, tokens: usize },
    /// Result from a grep operation; `matches` is the number of hits
    GrepResult {
        pattern: String,
        matches: usize,
        tokens: usize,
    },
    /// Result from a sub-LLM query; stores a preview of the response
    LlmQueryResult {
        query: String,
        response_preview: String,
        tokens: usize,
    },
    /// Code generated by the assistant
    AssistantCode { code: String, tokens: usize },
    /// Output from code execution
    ExecutionOutput { output: String, tokens: usize },
    /// Final answer
    Final { answer: String, tokens: usize },
    /// Tool call; stores a preview of the call arguments
    ToolCall {
        name: String,
        arguments_preview: String,
        tokens: usize,
    },
    /// Tool result, keyed by the originating tool call id
    ToolResult {
        tool_call_id: String,
        result_preview: String,
        tokens: usize,
    },
}
81
82impl ContextEvent {
83    /// Get the token count for this event.
84    pub fn tokens(&self) -> usize {
85        match self {
86            Self::SystemPrompt { tokens, .. } => *tokens,
87            Self::GrepResult { tokens, .. } => *tokens,
88            Self::LlmQueryResult { tokens, .. } => *tokens,
89            Self::AssistantCode { tokens, .. } => *tokens,
90            Self::ExecutionOutput { tokens, .. } => *tokens,
91            Self::Final { tokens, .. } => *tokens,
92            Self::ToolCall { tokens, .. } => *tokens,
93            Self::ToolResult { tokens, .. } => *tokens,
94        }
95    }
96
97    /// Get a human-readable label for this event type.
98    pub fn label(&self) -> &'static str {
99        match self {
100            Self::SystemPrompt { .. } => "system_prompt",
101            Self::GrepResult { .. } => "grep_result",
102            Self::LlmQueryResult { .. } => "llm_query_result",
103            Self::AssistantCode { .. } => "assistant_code",
104            Self::ExecutionOutput { .. } => "execution_output",
105            Self::Final { .. } => "final",
106            Self::ToolCall { .. } => "tool_call",
107            Self::ToolResult { .. } => "tool_result",
108        }
109    }
110}
111
/// Context trace for a single RLM analysis run.
///
/// Holds a bounded buffer of [`ContextEvent`]s plus a running token total so
/// callers can monitor context-window pressure against `max_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTrace {
    /// Maximum token budget
    max_tokens: usize,
    /// Events logged during the run (oldest first; capped at `MAX_EVENTS`)
    events: VecDeque<ContextEvent>,
    /// Running token total for the events currently retained in `events`
    /// (decremented when old events are evicted in `log_event`)
    total_tokens: usize,
    /// Iteration number
    iteration: usize,
}
124
125impl ContextTrace {
126    /// Create a new context trace with the given token budget.
127    pub fn new(max_tokens: usize) -> Self {
128        Self {
129            max_tokens,
130            events: VecDeque::with_capacity(64),
131            total_tokens: 0,
132            iteration: 0,
133        }
134    }
135
136    /// Log a context event.
137    pub fn log_event(&mut self, event: ContextEvent) {
138        self.total_tokens += event.tokens();
139
140        // Evict oldest events if buffer is full
141        while self.events.len() >= MAX_EVENTS {
142            if let Some(evicted) = self.events.pop_front() {
143                self.total_tokens = self.total_tokens.saturating_sub(evicted.tokens());
144            }
145        }
146
147        self.events.push_back(event);
148    }
149
150    /// Increment the iteration counter.
151    pub fn next_iteration(&mut self) {
152        self.iteration += 1;
153    }
154
155    /// Get the current iteration number.
156    pub fn iteration(&self) -> usize {
157        self.iteration
158    }
159
160    /// Get the total tokens used.
161    pub fn total_tokens(&self) -> usize {
162        self.total_tokens
163    }
164
165    /// Get the remaining token budget.
166    pub fn remaining_tokens(&self) -> usize {
167        self.max_tokens.saturating_sub(self.total_tokens)
168    }
169
170    /// Get the percentage of budget used.
171    pub fn budget_used_percent(&self) -> f32 {
172        if self.max_tokens == 0 {
173            0.0
174        } else {
175            // Use integer math to avoid floating point noise (e.g. 15.000001).
176            // We keep 1 decimal place of precision.
177            let total = self.total_tokens as u64;
178            let max = self.max_tokens as u64;
179            let tenths = (total.saturating_mul(1000) + (max / 2)).saturating_div(max);
180            (tenths as f32) / 10.0
181        }
182    }
183
184    /// Check if the budget is exceeded.
185    pub fn is_over_budget(&self) -> bool {
186        self.total_tokens > self.max_tokens
187    }
188
189    /// Get all events.
190    pub fn events(&self) -> &VecDeque<ContextEvent> {
191        &self.events
192    }
193
194    /// Get events by type.
195    pub fn events_of_type(&self, label: &str) -> Vec<&ContextEvent> {
196        self.events.iter().filter(|e| e.label() == label).collect()
197    }
198
199    /// Get a summary of the trace.
200    pub fn summary(&self) -> ContextTraceSummary {
201        let mut event_counts = std::collections::HashMap::new();
202        let mut event_tokens = std::collections::HashMap::new();
203
204        for event in &self.events {
205            let label = event.label().to_string();
206            *event_counts.entry(label.clone()).or_insert(0) += 1;
207            *event_tokens.entry(label).or_insert(0) += event.tokens();
208        }
209
210        ContextTraceSummary {
211            total_tokens: self.total_tokens,
212            max_tokens: self.max_tokens,
213            budget_used_percent: self.budget_used_percent(),
214            iteration: self.iteration,
215            event_counts,
216            event_tokens,
217            events_len: self.events.len(),
218        }
219    }
220
221    /// Estimate token count from text (rough approximation).
222    ///
223    /// Uses ~4 characters per token as a rough estimate.
224    pub fn estimate_tokens(text: &str) -> usize {
225        (text.chars().count() / 4).max(1)
226    }
227
228    /// Create an event from text with automatic token estimation.
229    pub fn event_from_text(event: ContextEvent, text: &str) -> ContextEvent {
230        let tokens = Self::estimate_tokens(text);
231        match event {
232            ContextEvent::SystemPrompt { content, .. } => {
233                ContextEvent::SystemPrompt { content, tokens }
234            }
235            ContextEvent::GrepResult {
236                pattern, matches, ..
237            } => ContextEvent::GrepResult {
238                pattern,
239                matches,
240                tokens,
241            },
242            ContextEvent::LlmQueryResult {
243                query,
244                response_preview,
245                ..
246            } => ContextEvent::LlmQueryResult {
247                query,
248                response_preview,
249                tokens,
250            },
251            ContextEvent::AssistantCode { code, .. } => {
252                ContextEvent::AssistantCode { code, tokens }
253            }
254            ContextEvent::ExecutionOutput { output, .. } => {
255                ContextEvent::ExecutionOutput { output, tokens }
256            }
257            ContextEvent::Final { answer, .. } => ContextEvent::Final { answer, tokens },
258            ContextEvent::ToolCall {
259                name,
260                arguments_preview,
261                ..
262            } => ContextEvent::ToolCall {
263                name,
264                arguments_preview,
265                tokens,
266            },
267            ContextEvent::ToolResult {
268                tool_call_id,
269                result_preview,
270                ..
271            } => ContextEvent::ToolResult {
272                tool_call_id,
273                result_preview,
274                tokens,
275            },
276        }
277    }
278}
279
/// Summary statistics for a context trace.
///
/// Produced by [`ContextTrace::summary`]; a serializable snapshot suitable
/// for logging or export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextTraceSummary {
    /// Total tokens used
    pub total_tokens: usize,
    /// Maximum token budget
    pub max_tokens: usize,
    /// Percentage of budget used (one decimal place of precision)
    pub budget_used_percent: f32,
    /// Current iteration number
    pub iteration: usize,
    /// Count of events by type (keyed by event label, e.g. "grep_result")
    pub event_counts: std::collections::HashMap<String, usize>,
    /// Tokens by event type (same keys as `event_counts`)
    pub event_tokens: std::collections::HashMap<String, usize>,
    /// Total number of events
    pub events_len: usize,
}
298
299impl ContextTraceSummary {
300    /// Format as a human-readable string.
301    pub fn format(&self) -> String {
302        let mut lines = vec![
303            format!("Context Trace Summary (iteration {})", self.iteration),
304            format!(
305                "  Budget: {}/{} tokens ({:.1}%)",
306                self.total_tokens, self.max_tokens, self.budget_used_percent
307            ),
308            format!("  Events: {}", self.events_len),
309        ];
310
311        if !self.event_counts.is_empty() {
312            lines.push("  By type:".to_string());
313            for (label, count) in &self.event_counts {
314                let tokens = self.event_tokens.get(label).copied().unwrap_or(0);
315                lines.push(format!(
316                    "    {}: {} events, {} tokens",
317                    label, count, tokens
318                ));
319            }
320        }
321
322        lines.join("\n")
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created trace has consumed none of its budget.
    #[test]
    fn new_trace_has_zero_tokens() {
        let t = ContextTrace::new(1000);
        assert_eq!(t.total_tokens(), 0);
        assert_eq!(t.remaining_tokens(), 1000);
    }

    /// Logging an event moves its token count into the running total.
    #[test]
    fn log_event_adds_tokens() {
        let mut t = ContextTrace::new(1000);
        t.log_event(ContextEvent::SystemPrompt {
            content: String::from("test"),
            tokens: 100,
        });
        assert_eq!(t.total_tokens(), 100);
        assert_eq!(t.remaining_tokens(), 900);
    }

    /// A single event larger than the budget flips the over-budget flag.
    #[test]
    fn budget_exceeded_check() {
        let mut t = ContextTrace::new(100);
        t.log_event(ContextEvent::SystemPrompt {
            content: String::from("test"),
            tokens: 150,
        });
        assert!(t.is_over_budget());
    }

    /// `events_of_type` returns only events whose label matches.
    #[test]
    fn event_type_filtering() {
        let mut t = ContextTrace::new(1000);
        t.log_event(ContextEvent::SystemPrompt {
            content: String::from("test"),
            tokens: 100,
        });
        t.log_event(ContextEvent::GrepResult {
            pattern: String::from("async"),
            matches: 5,
            tokens: 50,
        });
        t.log_event(ContextEvent::SystemPrompt {
            content: String::from("test2"),
            tokens: 75,
        });

        assert_eq!(t.events_of_type("system_prompt").len(), 2);
        assert_eq!(t.events_of_type("grep_result").len(), 1);
    }

    /// Summary aggregates totals, iteration, and rounded budget percentage.
    #[test]
    fn summary_statistics() {
        let mut t = ContextTrace::new(1000);
        t.log_event(ContextEvent::SystemPrompt {
            content: String::from("test"),
            tokens: 100,
        });
        t.log_event(ContextEvent::GrepResult {
            pattern: String::from("async"),
            matches: 5,
            tokens: 50,
        });
        t.next_iteration();

        let s = t.summary();
        assert_eq!(s.total_tokens, 150);
        assert_eq!(s.iteration, 1);
        assert_eq!(s.budget_used_percent, 15.0);
    }

    /// The estimator uses ~4 chars/token with a floor of 1.
    #[test]
    fn estimate_tokens_approximation() {
        assert_eq!(ContextTrace::estimate_tokens("test"), 1);
        assert_eq!(ContextTrace::estimate_tokens("test test test test"), 4);
        assert_eq!(ContextTrace::estimate_tokens("12345678"), 2);
    }

    /// Pushing past capacity evicts old events and keeps totals bounded.
    #[test]
    fn evict_old_events_when_full() {
        let mut t = ContextTrace::new(10000);

        for i in 0..MAX_EVENTS + 100 {
            t.log_event(ContextEvent::SystemPrompt {
                content: format!("event {}", i),
                tokens: 1,
            });
        }

        // Buffer length capped at MAX_EVENTS.
        assert!(t.events.len() <= MAX_EVENTS);
        // Token total only counts retained events, so it is bounded too.
        assert!(t.total_tokens <= MAX_EVENTS);
    }
}