1use tracing::{debug, info, warn};
7
8use punch_types::{Message, Role, ToolDefinition};
9
/// Context window size, in estimated tokens, used by `ContextBudget::default()`.
const DEFAULT_WINDOW_SIZE: usize = 200_000;

/// Usage ratio (estimated tokens / window size) above which
/// `ContextBudget::check_trim_needed` returns `TrimAction::Moderate`.
const MODERATE_TRIM_THRESHOLD: f64 = 0.70;

/// Usage ratio (estimated tokens / window size) above which
/// `ContextBudget::check_trim_needed` returns `TrimAction::Aggressive`.
const AGGRESSIVE_TRIM_THRESHOLD: f64 = 0.90;

/// Number of trailing messages `ContextBudget::apply_trim` keeps for a moderate trim.
const MODERATE_KEEP_LAST: usize = 10;

/// Number of trailing messages `ContextBudget::apply_trim` keeps for an aggressive trim.
const AGGRESSIVE_KEEP_LAST: usize = 4;

/// Fraction of the window (converted to bytes at four bytes per token) returned by
/// `ContextBudget::per_result_cap`.
const PER_RESULT_CAP_FRACTION: f64 = 0.30;

/// Fraction of the window (converted to bytes at four bytes per token) returned by
/// `ContextBudget::single_result_max`.
///
/// `apply_context_guard` uses the smaller of this limit and the per-result cap, so
/// with the current values the per-result cap is always the effective one.
const SINGLE_RESULT_MAX_FRACTION: f64 = 0.50;

/// Fraction of the window (converted to bytes at four bytes per token) that all tool
/// results together may occupy before `apply_context_guard` shortens the oldest ones.
const TOTAL_TOOL_HEADROOM_FRACTION: f64 = 0.75;
33
/// Budget describing how much conversation content fits in a context window.
///
/// Token counts are estimates: the methods on this type treat four bytes of text
/// as one token.
#[derive(Debug, Clone)]
pub struct ContextBudget {
    /// Size of the context window, in (estimated) tokens.
    pub window_size: usize,
}
40
41impl Default for ContextBudget {
42 fn default() -> Self {
43 Self {
44 window_size: DEFAULT_WINDOW_SIZE,
45 }
46 }
47}
48
49impl ContextBudget {
50 pub fn new(window_size: usize) -> Self {
52 Self { window_size }
53 }
54
55 pub fn estimate_tokens(&self, messages: &[Message], tools: &[ToolDefinition]) -> usize {
60 let mut total_chars: usize = 0;
61
62 for msg in messages {
63 total_chars += msg.content.len();
64 for tc in &msg.tool_calls {
65 total_chars += tc.name.len();
66 total_chars += tc.input.to_string().len();
67 total_chars += tc.id.len();
68 }
69 for tr in &msg.tool_results {
70 total_chars += tr.content.len();
71 total_chars += tr.id.len();
72 }
73 }
74
75 for tool in tools {
76 total_chars += tool.name.len();
77 total_chars += tool.description.len();
78 total_chars += tool.input_schema.to_string().len();
79 }
80
81 total_chars / 4
83 }
84
85 pub fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
87 self.estimate_tokens(messages, &[])
88 }
89
90 pub fn per_result_cap(&self) -> usize {
92 ((self.window_size as f64) * PER_RESULT_CAP_FRACTION * 4.0) as usize
94 }
95
96 pub fn single_result_max(&self) -> usize {
98 ((self.window_size as f64) * SINGLE_RESULT_MAX_FRACTION * 4.0) as usize
99 }
100
101 pub fn total_tool_headroom(&self) -> usize {
103 ((self.window_size as f64) * TOTAL_TOOL_HEADROOM_FRACTION * 4.0) as usize
104 }
105
106 pub fn truncate_result(text: &str, max_chars: usize) -> String {
110 if text.len() <= max_chars {
111 return text.to_string();
112 }
113
114 let marker = "\n\n[truncated — result exceeded context budget]";
116 let keep = max_chars.saturating_sub(marker.len());
117
118 let boundary = find_char_boundary(text, keep);
120
121 let mut result = text[..boundary].to_string();
122 result.push_str(marker);
123 result
124 }
125
126 pub fn apply_context_guard(&self, messages: &mut [Message]) -> bool {
131 let headroom = self.total_tool_headroom();
132 let per_cap = self.per_result_cap();
133 let single_max = self.single_result_max();
134 let mut trimmed = false;
135
136 for msg in messages.iter_mut() {
138 if msg.role == Role::Tool {
139 for tr in msg.tool_results.iter_mut() {
140 let cap = per_cap.min(single_max);
141 if tr.content.len() > cap {
142 debug!(
143 tool_result_id = %tr.id,
144 original_len = tr.content.len(),
145 cap = cap,
146 "truncating oversized tool result"
147 );
148 tr.content = Self::truncate_result(&tr.content, cap);
149 trimmed = true;
150 }
151 }
152 }
153 }
154
155 let total_tool_chars: usize = messages
158 .iter()
159 .filter(|m| m.role == Role::Tool)
160 .flat_map(|m| &m.tool_results)
161 .map(|tr| tr.content.len())
162 .sum();
163
164 if total_tool_chars > headroom {
165 debug!(
166 total_tool_chars = total_tool_chars,
167 headroom = headroom,
168 "tool results exceed headroom, trimming oldest"
169 );
170
171 let tool_indices: Vec<usize> = messages
173 .iter()
174 .enumerate()
175 .filter(|(_, m)| m.role == Role::Tool)
176 .map(|(i, _)| i)
177 .collect();
178
179 let mut current_total = total_tool_chars;
180
181 for &idx in &tool_indices {
183 if current_total <= headroom {
184 break;
185 }
186 let msg = &mut messages[idx];
187 for tr in msg.tool_results.iter_mut() {
188 if current_total <= headroom {
189 break;
190 }
191 let old_len = tr.content.len();
192 if old_len > 200 {
194 tr.content = Self::truncate_result(&tr.content, 200);
195 current_total -= old_len - tr.content.len();
196 trimmed = true;
197 }
198 }
199 }
200 }
201
202 trimmed
203 }
204
205 pub fn check_trim_needed(
209 &self,
210 messages: &[Message],
211 tools: &[ToolDefinition],
212 ) -> Option<TrimAction> {
213 let tokens = self.estimate_tokens(messages, tools);
214 let ratio = tokens as f64 / self.window_size as f64;
215
216 if ratio > AGGRESSIVE_TRIM_THRESHOLD {
217 warn!(
218 tokens = tokens,
219 window = self.window_size,
220 ratio = format!("{:.1}%", ratio * 100.0),
221 "context usage critical — aggressive trim needed"
222 );
223 Some(TrimAction::Aggressive)
224 } else if ratio > MODERATE_TRIM_THRESHOLD {
225 info!(
226 tokens = tokens,
227 window = self.window_size,
228 ratio = format!("{:.1}%", ratio * 100.0),
229 "context usage high — moderate trim needed"
230 );
231 Some(TrimAction::Moderate)
232 } else {
233 None
234 }
235 }
236
237 pub fn apply_trim(&self, messages: &mut Vec<Message>, action: TrimAction) {
242 let keep = match action {
243 TrimAction::Moderate => MODERATE_KEEP_LAST,
244 TrimAction::Aggressive => AGGRESSIVE_KEEP_LAST,
245 };
246
247 if messages.len() <= keep {
248 return;
249 }
250
251 let original_len = messages.len();
252
253 let first = messages[0].clone();
255 let tail: Vec<Message> = messages
256 .iter()
257 .rev()
258 .take(keep)
259 .cloned()
260 .collect::<Vec<_>>()
261 .into_iter()
262 .rev()
263 .collect();
264
265 messages.clear();
266 messages.push(first);
267
268 if matches!(action, TrimAction::Aggressive) {
270 messages.push(Message::new(
271 Role::System,
272 format!(
273 "[Context trimmed: {} earlier messages removed to stay within context window. \
274 Conversation may reference prior context that is no longer visible.]",
275 original_len - 1 - tail.len()
276 ),
277 ));
278 }
279
280 messages.extend(tail);
281
282 info!(
283 original = original_len,
284 trimmed_to = messages.len(),
285 action = ?action,
286 "context window trimmed"
287 );
288 }
289}
290
/// Trimming strategy produced by `ContextBudget::check_trim_needed` and consumed by
/// `ContextBudget::apply_trim`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimAction {
    /// Keep the first message and the last `MODERATE_KEEP_LAST` messages.
    Moderate,
    /// Keep the first message and the last `AGGRESSIVE_KEEP_LAST` messages, and
    /// insert a system notice describing the removal.
    Aggressive,
}
299
/// Returns the largest byte index `<= pos` at which `s` can be split without
/// cutting a UTF-8 character in half.
///
/// Positions at or beyond the end of the string clamp to `s.len()`.
fn find_char_boundary(s: &str, pos: usize) -> usize {
    match s.len() {
        len if pos >= len => len,
        // Index 0 is always a boundary, so the search cannot come up empty.
        _ => (0..=pos)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0),
    }
}
311
#[cfg(test)]
mod tests {
    //! Unit tests for the context budget: token estimation, result truncation,
    //! the tool-result guard, and conversation trimming.

    use super::*;
    use punch_types::{Message, Role, ToolCallResult, ToolCategory, ToolDefinition};

    /// Builds a plain message with the given role and content.
    fn make_message(role: Role, content: &str) -> Message {
        Message::new(role, content)
    }

    /// Builds a `Role::Tool` message carrying the given tool results.
    fn make_tool_message(results: Vec<ToolCallResult>) -> Message {
        Message {
            role: Role::Tool,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: results,
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    /// Builds a minimal tool definition with the given name.
    fn make_tool_def(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
            category: ToolCategory::FileSystem,
        }
    }

    #[test]
    fn test_estimate_tokens_basic() {
        let budget = ContextBudget::new(200_000);
        let msg = make_message(Role::User, &"x".repeat(400));
        // 400 bytes / 4 bytes per token.
        let tokens = budget.estimate_tokens(&[msg], &[]);
        assert_eq!(tokens, 100);
    }

    #[test]
    fn test_estimate_tokens_with_tools() {
        let budget = ContextBudget::new(200_000);
        let msgs = vec![make_message(Role::User, "hello")];
        let tools = vec![make_tool_def("file_read")];
        let tokens_with = budget.estimate_tokens(&msgs, &tools);
        let tokens_without = budget.estimate_tokens(&msgs, &[]);
        assert!(tokens_with > tokens_without);
    }

    #[test]
    fn test_truncate_result_no_truncation() {
        let text = "short text";
        let result = ContextBudget::truncate_result(text, 100);
        assert_eq!(result, text);
    }

    #[test]
    fn test_truncate_result_with_truncation() {
        let text = "a".repeat(1000);
        let result = ContextBudget::truncate_result(&text, 200);
        // The marker is counted against the budget, so the limit must hold exactly.
        assert!(result.len() <= 200);
        assert!(result.starts_with("aaa"));
        assert!(result.contains("[truncated"));
        assert!(result.ends_with("[truncated — result exceeded context budget]"));
    }

    #[test]
    fn test_truncate_result_multibyte_stays_within_budget() {
        // Each "é" is two bytes, so an odd cut position falls inside a character.
        let text = "é".repeat(200);
        let result = ContextBudget::truncate_result(&text, 101);
        assert!(result.len() <= 101);
        assert!(result.starts_with('é'));
        assert!(result.contains("[truncated"));
    }

    #[test]
    fn test_per_result_cap() {
        let budget = ContextBudget::new(200_000);
        // 200_000 tokens * 0.30 * 4 bytes per token.
        assert_eq!(budget.per_result_cap(), 240_000);
    }

    #[test]
    fn test_single_result_max() {
        let budget = ContextBudget::new(200_000);
        // 200_000 tokens * 0.50 * 4 bytes per token.
        assert_eq!(budget.single_result_max(), 400_000);
    }

    #[test]
    fn test_total_tool_headroom() {
        let budget = ContextBudget::new(200_000);
        // 200_000 tokens * 0.75 * 4 bytes per token.
        assert_eq!(budget.total_tool_headroom(), 600_000);
    }

    #[test]
    fn test_check_trim_not_needed() {
        let budget = ContextBudget::new(200_000);
        let msgs = vec![make_message(Role::User, "hello")];
        assert!(budget.check_trim_needed(&msgs, &[]).is_none());
    }

    #[test]
    fn test_check_trim_moderate() {
        let budget = ContextBudget::new(1_000);
        // 3000 bytes -> 750 tokens -> 75% of the window.
        let msgs = vec![make_message(Role::User, &"x".repeat(3000))];
        let action = budget.check_trim_needed(&msgs, &[]);
        assert_eq!(action, Some(TrimAction::Moderate));
    }

    #[test]
    fn test_check_trim_aggressive() {
        let budget = ContextBudget::new(1_000);
        // 3800 bytes -> 950 tokens -> 95% of the window.
        let msgs = vec![make_message(Role::User, &"x".repeat(3800))];
        let action = budget.check_trim_needed(&msgs, &[]);
        assert_eq!(action, Some(TrimAction::Aggressive));
    }

    #[test]
    fn test_apply_trim_moderate() {
        let budget = ContextBudget::new(200_000);
        let mut msgs: Vec<Message> = (0..20)
            .map(|i| make_message(Role::User, &format!("message {}", i)))
            .collect();

        budget.apply_trim(&mut msgs, TrimAction::Moderate);

        // First message + last 10.
        assert_eq!(msgs.len(), 11);
        assert!(msgs[0].content.contains("message 0"));
        assert!(msgs[1].content.contains("message 10"));
        assert!(msgs.last().unwrap().content.contains("message 19"));
    }

    #[test]
    fn test_apply_trim_aggressive() {
        let budget = ContextBudget::new(200_000);
        let mut msgs: Vec<Message> = (0..20)
            .map(|i| make_message(Role::User, &format!("message {}", i)))
            .collect();

        budget.apply_trim(&mut msgs, TrimAction::Aggressive);

        // First message + trim notice + last 4.
        assert_eq!(msgs.len(), 6);
        assert!(msgs[0].content.contains("message 0"));
        assert!(msgs[1].role == Role::System);
        assert!(msgs[1].content.contains("Context trimmed"));
        // 20 original - first - 4 kept = 15 removed.
        assert!(msgs[1].content.contains("15 earlier messages removed"));
        assert!(msgs[2].content.contains("message 16"));
        assert!(msgs.last().unwrap().content.contains("message 19"));
    }

    #[test]
    fn test_apply_context_guard_truncates_oversized() {
        let budget = ContextBudget::new(100);
        let big_result = "x".repeat(500);
        let mut msgs = vec![make_tool_message(vec![ToolCallResult {
            id: "call_1".into(),
            content: big_result,
            is_error: false,
            image: None,
        }])];

        let trimmed = budget.apply_context_guard(&mut msgs);
        assert!(trimmed);
        assert!(msgs[0].tool_results[0].content.len() < 500);
        assert!(msgs[0].tool_results[0].content.len() <= budget.per_result_cap());
    }

    #[test]
    fn test_apply_context_guard_no_change_when_small() {
        let budget = ContextBudget::new(200_000);
        let mut msgs = vec![make_tool_message(vec![ToolCallResult {
            id: "call_1".into(),
            content: "small result".into(),
            is_error: false,
            image: None,
        }])];

        let trimmed = budget.apply_context_guard(&mut msgs);
        assert!(!trimmed);
        assert_eq!(msgs[0].tool_results[0].content, "small result");
    }

    #[test]
    fn test_find_char_boundary_ascii() {
        let s = "hello world";
        assert_eq!(find_char_boundary(s, 5), 5);
    }

    #[test]
    fn test_find_char_boundary_multibyte() {
        let s = "hello 世界";
        let boundary = find_char_boundary(s, 7);
        // Byte 7 is inside the first 3-byte character, which starts at byte 6.
        assert!(s.is_char_boundary(boundary));
        assert!(boundary <= 7);
        assert_eq!(boundary, 6);
    }

    #[test]
    fn test_default_context_budget() {
        let budget = ContextBudget::default();
        assert_eq!(budget.window_size, 200_000);
    }

    #[test]
    fn test_estimate_tokens_empty() {
        let budget = ContextBudget::new(200_000);
        let tokens = budget.estimate_tokens(&[], &[]);
        assert_eq!(tokens, 0);
    }

    #[test]
    fn test_estimate_tokens_with_tool_calls() {
        let budget = ContextBudget::new(200_000);
        let msg = Message {
            role: Role::Assistant,
            content: "thinking".into(),
            tool_calls: vec![punch_types::ToolCall {
                id: "call_1".into(),
                name: "file_read".into(),
                input: serde_json::json!({"path": "/tmp/test.txt"}),
            }],
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        };
        let tokens = budget.estimate_tokens(&[msg], &[]);
        assert!(tokens > 0);
    }

    #[test]
    fn test_estimate_tokens_with_tool_results() {
        let budget = ContextBudget::new(200_000);
        let msg = Message {
            role: Role::Tool,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: vec![punch_types::ToolCallResult {
                id: "call_1".into(),
                content: "x".repeat(400),
                is_error: false,
                image: None,
            }],
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        };
        let tokens = budget.estimate_tokens(&[msg], &[]);
        // 400 bytes of content alone is 100 tokens; the id adds a little more.
        assert!(tokens >= 100);
    }

    #[test]
    fn test_estimate_message_tokens() {
        let budget = ContextBudget::new(200_000);
        let msgs = vec![make_message(Role::User, &"x".repeat(800))];
        let tokens = budget.estimate_message_tokens(&msgs);
        // 800 bytes / 4 bytes per token.
        assert_eq!(tokens, 200);
    }

    #[test]
    fn test_per_result_cap_custom_window() {
        let budget = ContextBudget::new(100_000);
        assert_eq!(budget.per_result_cap(), 120_000);
    }

    #[test]
    fn test_single_result_max_custom_window() {
        let budget = ContextBudget::new(100_000);
        assert_eq!(budget.single_result_max(), 200_000);
    }

    #[test]
    fn test_truncate_result_exact_boundary() {
        let text = "a".repeat(100);
        let result = ContextBudget::truncate_result(&text, 100);
        // Exactly at the limit: nothing is cut.
        assert_eq!(result, text);
    }

    #[test]
    fn test_truncate_result_one_over() {
        let text = "a".repeat(101);
        let result = ContextBudget::truncate_result(&text, 100);
        // One byte over the limit still has to come back within the limit.
        assert!(result.len() <= 100);
        assert!(result.contains("[truncated"));
    }

    #[test]
    fn test_apply_trim_fewer_than_keep() {
        let budget = ContextBudget::new(200_000);
        let mut msgs: Vec<Message> = (0..3)
            .map(|i| make_message(Role::User, &format!("msg {}", i)))
            .collect();

        budget.apply_trim(&mut msgs, TrimAction::Moderate);
        // Fewer messages than the keep count: nothing is removed.
        assert_eq!(msgs.len(), 3);
    }

    #[test]
    fn test_apply_trim_preserves_first_message() {
        let budget = ContextBudget::new(200_000);
        let mut msgs: Vec<Message> = (0..30)
            .map(|i| make_message(Role::User, &format!("msg {}", i)))
            .collect();

        budget.apply_trim(&mut msgs, TrimAction::Moderate);
        assert!(msgs[0].content.contains("msg 0"));
    }

    #[test]
    fn test_apply_trim_aggressive_inserts_marker() {
        let budget = ContextBudget::new(200_000);
        let mut msgs: Vec<Message> = (0..15)
            .map(|i| make_message(Role::User, &format!("msg {}", i)))
            .collect();

        budget.apply_trim(&mut msgs, TrimAction::Aggressive);
        // First message + trim notice + last 4.
        assert_eq!(msgs.len(), 6);
        assert_eq!(msgs[1].role, Role::System);
        assert!(msgs[1].content.contains("Context trimmed"));
    }

    #[test]
    fn test_check_trim_below_moderate() {
        let budget = ContextBudget::new(10_000);
        // 24_000 bytes -> 6_000 tokens -> 60% of the window, below both thresholds.
        let msgs = vec![make_message(Role::User, &"x".repeat(24_000))];
        assert!(budget.check_trim_needed(&msgs, &[]).is_none());
    }

    #[test]
    fn test_apply_context_guard_total_headroom_exceeded() {
        let budget = ContextBudget::new(10);
        // With such a small window both results are far above every limit.
        let big_result = "y".repeat(500);
        let mut msgs = vec![
            make_tool_message(vec![ToolCallResult {
                id: "c1".into(),
                content: big_result.clone(),
                is_error: false,
                image: None,
            }]),
            make_tool_message(vec![ToolCallResult {
                id: "c2".into(),
                content: big_result,
                is_error: false,
                image: None,
            }]),
        ];

        let trimmed = budget.apply_context_guard(&mut msgs);
        assert!(trimmed);
        assert!(msgs[0].tool_results[0].content.len() < 500);
        assert!(msgs[1].tool_results[0].content.len() < 500);
    }

    #[test]
    fn test_apply_context_guard_trims_oldest_when_over_headroom() {
        // Window of 1_000 tokens: per-result cap 1_200 bytes, total headroom 3_000 bytes.
        let budget = ContextBudget::new(1_000);
        let mut msgs: Vec<Message> = (0..4)
            .map(|_| {
                make_tool_message(vec![ToolCallResult {
                    id: "call".into(),
                    content: "z".repeat(1_000),
                    is_error: false,
                    image: None,
                }])
            })
            .collect();

        // No single result exceeds the cap, but 4_000 bytes in total exceed the
        // headroom, so only the headroom pass can report a change.
        assert!(budget.apply_context_guard(&mut msgs));

        let lens: Vec<usize> = msgs
            .iter()
            .map(|m| m.tool_results[0].content.len())
            .collect();
        // Oldest results are shortened first; trimming stops once the total fits.
        assert_eq!(lens, vec![200, 200, 1_000, 1_000]);
    }

    #[test]
    fn test_find_char_boundary_at_end() {
        let s = "hello";
        assert_eq!(find_char_boundary(s, 100), s.len());
    }

    #[test]
    fn test_find_char_boundary_at_zero() {
        let s = "hello";
        assert_eq!(find_char_boundary(s, 0), 0);
    }

    #[test]
    fn test_trim_action_equality() {
        assert_eq!(TrimAction::Moderate, TrimAction::Moderate);
        assert_eq!(TrimAction::Aggressive, TrimAction::Aggressive);
        assert_ne!(TrimAction::Moderate, TrimAction::Aggressive);
    }
}