Skip to main content

punch_runtime/
context_budget.rs

//! Context window budget management.
//!
//! Tracks estimated token usage and enforces limits to prevent context overflow.
//! Token counts use a bytes/4 heuristic (`String::len() / 4`): fast, but only approximate.
5
6use tracing::{debug, info, warn};
7
8use punch_types::{Message, Role, ToolDefinition};
9
// ---- Window size -----------------------------------------------------------

/// Context window assumed when no explicit size is configured, in tokens.
const DEFAULT_WINDOW_SIZE: usize = 200_000;

// ---- Trim thresholds (fraction of the window) ------------------------------

/// Estimated usage above this fraction of the window calls for a moderate trim.
const MODERATE_TRIM_THRESHOLD: f64 = 0.7;

/// Estimated usage above this fraction of the window calls for an aggressive trim.
const AGGRESSIVE_TRIM_THRESHOLD: f64 = 0.9;

// ---- Trailing messages retained by each trim level -------------------------

/// How many of the most recent messages a moderate trim retains.
const MODERATE_KEEP_LAST: usize = 10;

/// How many of the most recent messages an aggressive trim retains.
const AGGRESSIVE_KEEP_LAST: usize = 4;

// ---- Tool-result budgets (fraction of the window) --------------------------

/// Soft cap applied to each individual tool result.
const PER_RESULT_CAP_FRACTION: f64 = 0.3;

/// Hard ceiling for any single tool result.
const SINGLE_RESULT_MAX_FRACTION: f64 = 0.5;

/// Combined budget shared by all tool results in the conversation.
const TOTAL_TOOL_HEADROOM_FRACTION: f64 = 0.75;
33
/// Token budget for a model's context window, together with the helpers
/// (estimation, truncation, trimming) that enforce it.
#[derive(Clone, Debug)]
pub struct ContextBudget {
    /// Size of the context window, measured in (estimated) tokens.
    pub window_size: usize,
}
40
41impl Default for ContextBudget {
42    fn default() -> Self {
43        Self {
44            window_size: DEFAULT_WINDOW_SIZE,
45        }
46    }
47}
48
49impl ContextBudget {
50    /// Create a new context budget with the given window size.
51    pub fn new(window_size: usize) -> Self {
52        Self { window_size }
53    }
54
55    /// Estimate the token count of a set of messages and tool definitions.
56    ///
57    /// Uses the chars/4 heuristic: each character is roughly 0.25 tokens.
58    /// This is conservative (overestimates) which is safer than underestimating.
59    pub fn estimate_tokens(&self, messages: &[Message], tools: &[ToolDefinition]) -> usize {
60        let mut total_chars: usize = 0;
61
62        for msg in messages {
63            total_chars += msg.content.len();
64            for tc in &msg.tool_calls {
65                total_chars += tc.name.len();
66                total_chars += tc.input.to_string().len();
67                total_chars += tc.id.len();
68            }
69            for tr in &msg.tool_results {
70                total_chars += tr.content.len();
71                total_chars += tr.id.len();
72            }
73        }
74
75        for tool in tools {
76            total_chars += tool.name.len();
77            total_chars += tool.description.len();
78            total_chars += tool.input_schema.to_string().len();
79        }
80
81        // chars / 4 heuristic
82        total_chars / 4
83    }
84
85    /// Estimate tokens for messages only (no tool definitions).
86    pub fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
87        self.estimate_tokens(messages, &[])
88    }
89
90    /// Maximum chars allowed per individual tool result.
91    pub fn per_result_cap(&self) -> usize {
92        // Convert from tokens back to chars (* 4)
93        ((self.window_size as f64) * PER_RESULT_CAP_FRACTION * 4.0) as usize
94    }
95
96    /// Absolute maximum chars for a single tool result.
97    pub fn single_result_max(&self) -> usize {
98        ((self.window_size as f64) * SINGLE_RESULT_MAX_FRACTION * 4.0) as usize
99    }
100
101    /// Total chars available for all tool results combined.
102    pub fn total_tool_headroom(&self) -> usize {
103        ((self.window_size as f64) * TOTAL_TOOL_HEADROOM_FRACTION * 4.0) as usize
104    }
105
106    /// Truncate a tool result string to fit within max_chars.
107    ///
108    /// If truncation occurs, appends a `[truncated]` marker.
109    pub fn truncate_result(text: &str, max_chars: usize) -> String {
110        if text.len() <= max_chars {
111            return text.to_string();
112        }
113
114        // Leave room for the truncation marker
115        let marker = "\n\n[truncated — result exceeded context budget]";
116        let keep = max_chars.saturating_sub(marker.len());
117
118        // Find a safe char boundary
119        let boundary = find_char_boundary(text, keep);
120
121        let mut result = text[..boundary].to_string();
122        result.push_str(marker);
123        result
124    }
125
126    /// Apply the context guard to messages: trim oldest tool results when
127    /// total tool result content exceeds headroom.
128    ///
129    /// Returns the (possibly modified) messages and whether trimming occurred.
130    pub fn apply_context_guard(&self, messages: &mut [Message]) -> bool {
131        let headroom = self.total_tool_headroom();
132        let per_cap = self.per_result_cap();
133        let single_max = self.single_result_max();
134        let mut trimmed = false;
135
136        // First pass: truncate individual oversized tool results.
137        for msg in messages.iter_mut() {
138            if msg.role == Role::Tool {
139                for tr in msg.tool_results.iter_mut() {
140                    let cap = per_cap.min(single_max);
141                    if tr.content.len() > cap {
142                        debug!(
143                            tool_result_id = %tr.id,
144                            original_len = tr.content.len(),
145                            cap = cap,
146                            "truncating oversized tool result"
147                        );
148                        tr.content = Self::truncate_result(&tr.content, cap);
149                        trimmed = true;
150                    }
151                }
152            }
153        }
154
155        // Second pass: if total tool result content exceeds headroom,
156        // truncate oldest tool results first.
157        let total_tool_chars: usize = messages
158            .iter()
159            .filter(|m| m.role == Role::Tool)
160            .flat_map(|m| &m.tool_results)
161            .map(|tr| tr.content.len())
162            .sum();
163
164        if total_tool_chars > headroom {
165            debug!(
166                total_tool_chars = total_tool_chars,
167                headroom = headroom,
168                "tool results exceed headroom, trimming oldest"
169            );
170
171            // Collect indices of tool messages, oldest first (they're in chronological order).
172            let tool_indices: Vec<usize> = messages
173                .iter()
174                .enumerate()
175                .filter(|(_, m)| m.role == Role::Tool)
176                .map(|(i, _)| i)
177                .collect();
178
179            let mut current_total = total_tool_chars;
180
181            // Trim from oldest tool messages until we're under headroom.
182            for &idx in &tool_indices {
183                if current_total <= headroom {
184                    break;
185                }
186                let msg = &mut messages[idx];
187                for tr in msg.tool_results.iter_mut() {
188                    if current_total <= headroom {
189                        break;
190                    }
191                    let old_len = tr.content.len();
192                    // Aggressively truncate old results to 200 chars.
193                    if old_len > 200 {
194                        tr.content = Self::truncate_result(&tr.content, 200);
195                        current_total -= old_len - tr.content.len();
196                        trimmed = true;
197                    }
198                }
199            }
200        }
201
202        trimmed
203    }
204
205    /// Determine the trim action needed based on current token estimate.
206    ///
207    /// Returns `None` if no trimming needed, or the trim action to take.
208    pub fn check_trim_needed(
209        &self,
210        messages: &[Message],
211        tools: &[ToolDefinition],
212    ) -> Option<TrimAction> {
213        let tokens = self.estimate_tokens(messages, tools);
214        let ratio = tokens as f64 / self.window_size as f64;
215
216        if ratio > AGGRESSIVE_TRIM_THRESHOLD {
217            warn!(
218                tokens = tokens,
219                window = self.window_size,
220                ratio = format!("{:.1}%", ratio * 100.0),
221                "context usage critical — aggressive trim needed"
222            );
223            Some(TrimAction::Aggressive)
224        } else if ratio > MODERATE_TRIM_THRESHOLD {
225            info!(
226                tokens = tokens,
227                window = self.window_size,
228                ratio = format!("{:.1}%", ratio * 100.0),
229                "context usage high — moderate trim needed"
230            );
231            Some(TrimAction::Moderate)
232        } else {
233            None
234        }
235    }
236
237    /// Apply a trim action to messages. Returns the trimmed messages.
238    ///
239    /// Preserves the first message (usually the user's initial prompt) and
240    /// system markers, then keeps the last N messages.
241    pub fn apply_trim(&self, messages: &mut Vec<Message>, action: TrimAction) {
242        let keep = match action {
243            TrimAction::Moderate => MODERATE_KEEP_LAST,
244            TrimAction::Aggressive => AGGRESSIVE_KEEP_LAST,
245        };
246
247        if messages.len() <= keep {
248            return;
249        }
250
251        let original_len = messages.len();
252
253        // Always keep the first message (user's initial prompt).
254        let first = messages[0].clone();
255        let tail: Vec<Message> = messages
256            .iter()
257            .rev()
258            .take(keep)
259            .cloned()
260            .collect::<Vec<_>>()
261            .into_iter()
262            .rev()
263            .collect();
264
265        messages.clear();
266        messages.push(first);
267
268        // For aggressive trim, insert a summary marker.
269        if matches!(action, TrimAction::Aggressive) {
270            messages.push(Message::new(
271                Role::System,
272                format!(
273                    "[Context trimmed: {} earlier messages removed to stay within context window. \
274                     Conversation may reference prior context that is no longer visible.]",
275                    original_len - 1 - tail.len()
276                ),
277            ));
278        }
279
280        messages.extend(tail);
281
282        info!(
283            original = original_len,
284            trimmed_to = messages.len(),
285            action = ?action,
286            "context window trimmed"
287        );
288    }
289}
290
/// How hard the conversation history should be trimmed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrimAction {
    /// Chosen when estimated usage is above 70% (and at most 90%) of the window;
    /// keeps the first message plus about the last 10 messages.
    Moderate,
    /// Chosen when estimated usage is above 90% of the window; keeps the first
    /// message plus about the last 4 messages and inserts a summary marker.
    Aggressive,
}
299
/// Return the largest index `<= pos` that lies on a UTF-8 char boundary of `s`.
///
/// Any `pos` at or past the end of `s` clamps to `s.len()`. Index 0 is always a
/// boundary, so the search below always succeeds.
fn find_char_boundary(s: &str, pos: usize) -> usize {
    if pos >= s.len() {
        s.len()
    } else {
        (0..=pos)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0)
    }
}
311
312// ---------------------------------------------------------------------------
313// Tests
314// ---------------------------------------------------------------------------
315
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::{Message, Role, ToolCallResult, ToolCategory, ToolDefinition};

    // -----------------------------------------------------------------------
    // Fixtures
    // -----------------------------------------------------------------------

    /// Plain message with the given role and text.
    fn make_message(role: Role, content: &str) -> Message {
        Message::new(role, content)
    }

    /// Successful, image-free tool result.
    fn make_tool_result(id: &str, content: String) -> ToolCallResult {
        ToolCallResult {
            id: id.into(),
            content,
            is_error: false,
            image: None,
        }
    }

    /// `Role::Tool` message that carries only `results`.
    fn make_tool_message(results: Vec<ToolCallResult>) -> Message {
        Message {
            role: Role::Tool,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: results,
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    /// Minimal file-system tool definition named `name`.
    fn make_tool_def(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
            category: ToolCategory::FileSystem,
        }
    }

    /// `count` user messages whose contents read `"<prefix> <i>"`.
    fn numbered_user_messages(prefix: &str, count: usize) -> Vec<Message> {
        (0..count)
            .map(|i| make_message(Role::User, &format!("{} {}", prefix, i)))
            .collect()
    }

    // -----------------------------------------------------------------------
    // Construction
    // -----------------------------------------------------------------------

    #[test]
    fn test_default_context_budget() {
        assert_eq!(ContextBudget::default().window_size, 200_000);
    }

    // -----------------------------------------------------------------------
    // Token estimation
    // -----------------------------------------------------------------------

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(ContextBudget::new(200_000).estimate_tokens(&[], &[]), 0);
    }

    #[test]
    fn test_estimate_tokens_basic() {
        // 400 chars / 4 = 100 tokens
        let msgs = [make_message(Role::User, &"x".repeat(400))];
        assert_eq!(ContextBudget::new(200_000).estimate_tokens(&msgs, &[]), 100);
    }

    #[test]
    fn test_estimate_message_tokens() {
        // 800 / 4
        let msgs = vec![make_message(Role::User, &"x".repeat(800))];
        assert_eq!(ContextBudget::new(200_000).estimate_message_tokens(&msgs), 200);
    }

    #[test]
    fn test_estimate_tokens_with_tools() {
        let budget = ContextBudget::new(200_000);
        let msgs = [make_message(Role::User, "hello")];
        let tools = [make_tool_def("file_read")];
        let without_tools = budget.estimate_tokens(&msgs, &[]);
        assert!(budget.estimate_tokens(&msgs, &tools) > without_tools);
    }

    #[test]
    fn test_estimate_tokens_with_tool_calls() {
        let call = punch_types::ToolCall {
            id: "call_1".into(),
            name: "file_read".into(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let msg = Message {
            role: Role::Assistant,
            content: "thinking".into(),
            tool_calls: vec![call],
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        };
        assert!(ContextBudget::new(200_000).estimate_tokens(&[msg], &[]) > 0);
    }

    #[test]
    fn test_estimate_tokens_with_tool_results() {
        let msg = make_tool_message(vec![make_tool_result("call_1", "x".repeat(400))]);
        // 400+ chars / 4
        assert!(ContextBudget::new(200_000).estimate_tokens(&[msg], &[]) >= 100);
    }

    // -----------------------------------------------------------------------
    // Size caps (fraction of window * 4 chars/token)
    // -----------------------------------------------------------------------

    #[test]
    fn test_per_result_cap() {
        // 30% of 200K * 4 = 240K
        assert_eq!(ContextBudget::new(200_000).per_result_cap(), 240_000);
    }

    #[test]
    fn test_per_result_cap_custom_window() {
        // 30% of 100K * 4 = 120K
        assert_eq!(ContextBudget::new(100_000).per_result_cap(), 120_000);
    }

    #[test]
    fn test_single_result_max() {
        // 50% of 200K * 4 = 400K
        assert_eq!(ContextBudget::new(200_000).single_result_max(), 400_000);
    }

    #[test]
    fn test_single_result_max_custom_window() {
        // 50% of 100K * 4 = 200K
        assert_eq!(ContextBudget::new(100_000).single_result_max(), 200_000);
    }

    #[test]
    fn test_total_tool_headroom() {
        // 75% of 200K * 4 = 600K
        assert_eq!(ContextBudget::new(200_000).total_tool_headroom(), 600_000);
    }

    // -----------------------------------------------------------------------
    // truncate_result
    // -----------------------------------------------------------------------

    #[test]
    fn test_truncate_result_no_truncation() {
        let text = "short text";
        assert_eq!(ContextBudget::truncate_result(text, 100), text);
    }

    #[test]
    fn test_truncate_result_exact_boundary() {
        // Exactly at the limit: left untouched.
        let text = "a".repeat(100);
        assert_eq!(ContextBudget::truncate_result(&text, 100), text);
    }

    #[test]
    fn test_truncate_result_one_over() {
        let out = ContextBudget::truncate_result(&"a".repeat(101), 100);
        assert!(out.len() <= 150); // some slack for marker
        assert!(out.contains("[truncated"));
    }

    #[test]
    fn test_truncate_result_with_truncation() {
        let out = ContextBudget::truncate_result(&"a".repeat(1000), 200);
        assert!(out.len() <= 250); // some slack for marker
        assert!(out.contains("[truncated"));
    }

    // -----------------------------------------------------------------------
    // check_trim_needed
    // -----------------------------------------------------------------------

    #[test]
    fn test_check_trim_not_needed() {
        // Tiny message, far below 70%.
        let msgs = vec![make_message(Role::User, "hello")];
        assert_eq!(ContextBudget::new(200_000).check_trim_needed(&msgs, &[]), None);
    }

    #[test]
    fn test_check_trim_below_moderate() {
        // 24_000 chars -> 6_000 tokens -> 60% of a 10K window.
        let msgs = vec![make_message(Role::User, &"x".repeat(24_000))];
        assert_eq!(ContextBudget::new(10_000).check_trim_needed(&msgs, &[]), None);
    }

    #[test]
    fn test_check_trim_moderate() {
        // 3_000 chars -> 750 tokens -> 75% of a 1K window.
        let msgs = vec![make_message(Role::User, &"x".repeat(3000))];
        assert_eq!(
            ContextBudget::new(1_000).check_trim_needed(&msgs, &[]),
            Some(TrimAction::Moderate)
        );
    }

    #[test]
    fn test_check_trim_aggressive() {
        // 3_800 chars -> 950 tokens -> 95% of a 1K window.
        let msgs = vec![make_message(Role::User, &"x".repeat(3800))];
        assert_eq!(
            ContextBudget::new(1_000).check_trim_needed(&msgs, &[]),
            Some(TrimAction::Aggressive)
        );
    }

    // -----------------------------------------------------------------------
    // apply_trim
    // -----------------------------------------------------------------------

    #[test]
    fn test_apply_trim_fewer_than_keep() {
        let mut msgs = numbered_user_messages("msg", 3);
        ContextBudget::new(200_000).apply_trim(&mut msgs, TrimAction::Moderate);
        // Nothing to trim when there are fewer messages than the keep count.
        assert_eq!(msgs.len(), 3);
    }

    #[test]
    fn test_apply_trim_moderate() {
        let mut msgs = numbered_user_messages("message", 20);
        ContextBudget::new(200_000).apply_trim(&mut msgs, TrimAction::Moderate);

        // first + last 10
        assert_eq!(msgs.len(), 11);
        assert!(msgs[0].content.contains("message 0"));
        assert!(msgs.last().unwrap().content.contains("message 19"));
    }

    #[test]
    fn test_apply_trim_preserves_first_message() {
        let mut msgs = numbered_user_messages("msg", 30);
        ContextBudget::new(200_000).apply_trim(&mut msgs, TrimAction::Moderate);
        assert!(msgs[0].content.contains("msg 0"));
    }

    #[test]
    fn test_apply_trim_aggressive() {
        let mut msgs = numbered_user_messages("message", 20);
        ContextBudget::new(200_000).apply_trim(&mut msgs, TrimAction::Aggressive);

        // first + summary marker + last 4
        assert_eq!(msgs.len(), 6);
        assert!(msgs[0].content.contains("message 0"));
        assert_eq!(msgs[1].role, Role::System);
        assert!(msgs[1].content.contains("Context trimmed"));
        assert!(msgs.last().unwrap().content.contains("message 19"));
    }

    #[test]
    fn test_apply_trim_aggressive_inserts_marker() {
        let mut msgs = numbered_user_messages("msg", 15);
        ContextBudget::new(200_000).apply_trim(&mut msgs, TrimAction::Aggressive);

        // first + marker + last 4
        assert_eq!(msgs.len(), 6);
        assert_eq!(msgs[1].role, Role::System);
        assert!(msgs[1].content.contains("Context trimmed"));
    }

    // -----------------------------------------------------------------------
    // apply_context_guard
    // -----------------------------------------------------------------------

    #[test]
    fn test_apply_context_guard_no_change_when_small() {
        let mut msgs = vec![make_tool_message(vec![make_tool_result(
            "call_1",
            "small result".into(),
        )])];

        let changed = ContextBudget::new(200_000).apply_context_guard(&mut msgs);
        assert!(!changed);
        assert_eq!(msgs[0].tool_results[0].content, "small result");
    }

    #[test]
    fn test_apply_context_guard_truncates_oversized() {
        // 100-token window -> per-result cap of 0.30 * 100 * 4 = 120 chars.
        let mut msgs = vec![make_tool_message(vec![make_tool_result(
            "call_1",
            "x".repeat(500),
        )])];

        let changed = ContextBudget::new(100).apply_context_guard(&mut msgs);
        assert!(changed);
        assert!(msgs[0].tool_results[0].content.len() < 500);
    }

    #[test]
    fn test_apply_context_guard_total_headroom_exceeded() {
        // A tiny window forces the total-headroom path.
        let big_result = "y".repeat(500);
        let mut msgs: Vec<Message> = ["c1", "c2"]
            .iter()
            .map(|id| make_tool_message(vec![make_tool_result(id, big_result.clone())]))
            .collect();

        assert!(ContextBudget::new(10).apply_context_guard(&mut msgs));
    }

    // -----------------------------------------------------------------------
    // find_char_boundary
    // -----------------------------------------------------------------------

    #[test]
    fn test_find_char_boundary_ascii() {
        assert_eq!(find_char_boundary("hello world", 5), 5);
    }

    #[test]
    fn test_find_char_boundary_multibyte() {
        // '世' occupies bytes 6..9, so position 7 must back off to a real boundary.
        let s = "hello 世界";
        let boundary = find_char_boundary(s, 7);
        assert!(boundary <= 7);
        assert!(s.is_char_boundary(boundary));
    }

    #[test]
    fn test_find_char_boundary_at_end() {
        let s = "hello";
        assert_eq!(find_char_boundary(s, 100), s.len());
    }

    #[test]
    fn test_find_char_boundary_at_zero() {
        assert_eq!(find_char_boundary("hello", 0), 0);
    }

    // -----------------------------------------------------------------------
    // TrimAction
    // -----------------------------------------------------------------------

    #[test]
    fn test_trim_action_equality() {
        assert_eq!(TrimAction::Moderate, TrimAction::Moderate);
        assert_eq!(TrimAction::Aggressive, TrimAction::Aggressive);
        assert_ne!(TrimAction::Moderate, TrimAction::Aggressive);
    }
}