use std::mem;

use crate::chat::ChatMessage;

46#[derive(Debug)]
51pub struct ContextWindow {
52 max_tokens: u32,
54 reserved_for_output: u32,
56 messages: Vec<TrackedMessage>,
58}
59
60#[derive(Debug, Clone)]
62struct TrackedMessage {
63 message: ChatMessage,
65 token_count: u32,
67 compactable: bool,
69}
70
71impl ContextWindow {
72 pub fn new(max_tokens: u32, reserved_for_output: u32) -> Self {
83 assert!(
84 reserved_for_output < max_tokens,
85 "reserved_for_output ({reserved_for_output}) must be less than max_tokens ({max_tokens})"
86 );
87 Self {
88 max_tokens,
89 reserved_for_output,
90 messages: Vec::new(),
91 }
92 }
93
94 pub fn push(&mut self, message: ChatMessage, tokens: u32) {
104 self.messages.push(TrackedMessage {
105 message,
106 token_count: tokens,
107 compactable: true,
108 });
109 }
110
111 pub fn available(&self) -> u32 {
115 let input_budget = self.max_tokens.saturating_sub(self.reserved_for_output);
116 input_budget.saturating_sub(self.total_tokens())
117 }
118
119 pub fn iter(&self) -> impl Iterator<Item = &ChatMessage> {
123 self.messages.iter().map(|t| &t.message)
124 }
125
126 pub fn messages(&self) -> Vec<&ChatMessage> {
130 self.messages.iter().map(|t| &t.message).collect()
131 }
132
133 pub fn messages_owned(&self) -> Vec<ChatMessage> {
137 self.messages.iter().map(|t| t.message.clone()).collect()
138 }
139
140 pub fn total_tokens(&self) -> u32 {
142 self.messages
143 .iter()
144 .map(|t| t.token_count)
145 .fold(0, u32::saturating_add)
146 }
147
148 pub fn len(&self) -> usize {
150 self.messages.len()
151 }
152
153 pub fn is_empty(&self) -> bool {
155 self.messages.is_empty()
156 }
157
158 #[allow(clippy::cast_precision_loss)]
180 pub fn needs_compaction(&self, threshold: f32) -> bool {
181 let input_budget = self.max_tokens.saturating_sub(self.reserved_for_output);
182 if input_budget == 0 {
183 return false;
184 }
185 let usage_ratio = self.total_tokens() as f32 / input_budget as f32;
188 usage_ratio > threshold
189 }
190
191 pub fn compact(&mut self) -> Vec<ChatMessage> {
201 let mut removed = Vec::new();
202 let mut retained = Vec::new();
203
204 for tracked in self.messages.drain(..) {
205 if tracked.compactable {
206 removed.push(tracked.message);
207 } else {
208 retained.push(tracked);
209 }
210 }
211
212 self.messages = retained;
213 removed
214 }
215
216 pub fn protect_recent(&mut self, n: usize) {
226 let len = self.messages.len();
227 let start = len.saturating_sub(n);
228 for msg in &mut self.messages[start..] {
229 msg.compactable = false;
230 }
231 }
232
233 pub fn protect(&mut self, index: usize) {
241 self.messages[index].compactable = false;
242 }
243
244 pub fn unprotect(&mut self, index: usize) {
252 self.messages[index].compactable = true;
253 }
254
255 pub fn is_protected(&self, index: usize) -> bool {
261 !self.messages[index].compactable
262 }
263
264 pub fn input_budget(&self) -> u32 {
266 self.max_tokens.saturating_sub(self.reserved_for_output)
267 }
268
269 pub fn max_tokens(&self) -> u32 {
271 self.max_tokens
272 }
273
274 pub fn reserved_for_output(&self) -> u32 {
276 self.reserved_for_output
277 }
278
279 pub fn clear(&mut self) {
281 self.messages.clear();
282 }
283
284 pub fn token_count(&self, index: usize) -> u32 {
290 self.messages[index].token_count
291 }
292
293 pub fn update_token_count(&mut self, index: usize, tokens: u32) {
302 self.messages[index].token_count = tokens;
303 }
304}
305
/// Roughly estimates the number of tokens in `text`.
///
/// Heuristic: one token per four bytes of UTF-8, rounded up. `""` counts as
/// zero tokens and any non-empty string counts as at least one. Lengths beyond
/// `u32::MAX` bytes are clamped to `u32::MAX` before dividing.
pub fn estimate_tokens(text: &str) -> u32 {
    // A checked conversion clamps oversized lengths without a lossy `as` cast.
    let bytes = u32::try_from(text.len()).unwrap_or(u32::MAX);
    // `div_ceil` maps 0 -> 0 and every non-zero length to >= 1.
    bytes.div_ceil(4)
}

331pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
335 use crate::chat::ContentBlock;
336
337 let content_tokens: u32 = message
338 .content
339 .iter()
340 .map(|block| match block {
341 ContentBlock::Text(text) => estimate_tokens(text),
342 ContentBlock::Image { .. } => 85,
344 ContentBlock::ToolCall(tc) => {
345 estimate_tokens(&tc.name) + estimate_tokens(&tc.arguments.to_string())
346 }
347 ContentBlock::ToolResult(tr) => estimate_tokens(&tr.content) + 10,
349 ContentBlock::Reasoning { content } => estimate_tokens(content),
350 })
351 .sum();
352
353 content_tokens + 4
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    fn user_msg(text: &str) -> ChatMessage {
        ChatMessage::user(text)
    }

    fn assistant_msg(text: &str) -> ChatMessage {
        ChatMessage::assistant(text)
    }

    fn system_msg(text: &str) -> ChatMessage {
        ChatMessage::system(text)
    }

    // ---- construction --------------------------------------------------

    #[test]
    fn test_new_context_window() {
        let window = ContextWindow::new(8000, 1000);
        assert_eq!(window.max_tokens(), 8000);
        assert_eq!(window.reserved_for_output(), 1000);
        assert_eq!(window.input_budget(), 7000);
        assert!(window.is_empty());
        assert_eq!(window.len(), 0);
    }

    #[test]
    #[should_panic(expected = "reserved_for_output")]
    fn test_new_invalid_reserved() {
        ContextWindow::new(1000, 1000);
    }

    #[test]
    #[should_panic(expected = "reserved_for_output")]
    fn test_new_reserved_exceeds_max() {
        ContextWindow::new(1000, 2000);
    }

    // ---- pushing and accounting ---------------------------------------

    #[test]
    fn test_push_and_len() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        assert_eq!(window.len(), 2);
        assert!(!window.is_empty());
    }

    #[test]
    fn test_total_tokens() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);
        window.push(user_msg("How are you?"), 15);

        assert_eq!(window.total_tokens(), 33);
    }

    #[test]
    fn test_available_tokens() {
        let mut window = ContextWindow::new(8000, 1000);
        // An empty window has the whole input budget available.
        assert_eq!(window.available(), 7000);

        window.push(user_msg("Hello"), 100);
        assert_eq!(window.available(), 6900);

        window.push(assistant_msg("Hi"), 50);
        assert_eq!(window.available(), 6850);
    }

    #[test]
    fn test_available_saturates() {
        let mut window = ContextWindow::new(1000, 100);
        // 1000 tokens exceeds the 900-token input budget.
        window.push(user_msg("Large message"), 1000);
        assert_eq!(window.available(), 0);
    }

    #[test]
    fn test_messages() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        let messages = window.messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, ChatRole::User);
        assert_eq!(messages[1].role, ChatRole::Assistant);
    }

    #[test]
    fn test_messages_owned() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);

        let messages = window.messages_owned();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, ChatRole::User);
    }

    // ---- compaction threshold -----------------------------------------

    #[test]
    fn test_needs_compaction_below_threshold() {
        let mut window = ContextWindow::new(1000, 200);
        // 400 / 800 = 50%.
        window.push(user_msg("Hello"), 400);
        assert!(!window.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let mut window = ContextWindow::new(1000, 200);
        // 700 / 800 = 87.5%.
        window.push(user_msg("Hello"), 700);
        assert!(window.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_at_threshold() {
        let mut window = ContextWindow::new(1000, 200);
        // 640 / 800 = exactly 80%; the comparison is strict.
        window.push(user_msg("Hello"), 640);
        assert!(!window.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_zero_budget() {
        // `new` rejects a truly zero budget, so 1 token is the smallest
        // reachable budget; an empty window must not ask for compaction.
        let window = ContextWindow::new(100, 99);
        assert!(!window.needs_compaction(0.8));
    }

    // ---- compaction and protection ------------------------------------

    #[test]
    fn test_compact_all_compactable() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);
        window.push(user_msg("Bye"), 5);

        let removed = window.compact();

        assert_eq!(removed.len(), 3);
        assert!(window.is_empty());
        assert_eq!(window.total_tokens(), 0);
    }

    #[test]
    fn test_compact_with_protected() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(system_msg("System"), 20);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);
        window.push(user_msg("Question"), 15);

        // Protect the system prompt and the two most recent messages.
        window.protect(0);
        window.protect_recent(2);

        let removed = window.compact();

        // Only the "Hello" message (index 1) was compactable.
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].role, ChatRole::User);
        assert_eq!(window.len(), 3);
        assert_eq!(window.total_tokens(), 20 + 8 + 15);
    }

    #[test]
    fn test_compact_none_compactable() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(system_msg("System"), 20);
        window.push(user_msg("Hello"), 10);

        // Protecting the last two covers every message.
        window.protect_recent(2);

        let removed = window.compact();

        assert!(removed.is_empty());
        assert_eq!(window.len(), 2);
    }

    #[test]
    fn test_protect_recent() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);
        window.push(user_msg("3"), 10);
        window.push(user_msg("4"), 10);

        window.protect_recent(2);

        let removed = window.compact();

        // Messages "1" and "2" are removed; "3" and "4" are kept.
        assert_eq!(removed.len(), 2);
        assert_eq!(window.len(), 2);
    }

    #[test]
    fn test_protect_recent_more_than_len() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);

        window.protect_recent(10);
        let removed = window.compact();

        assert!(removed.is_empty());
        assert_eq!(window.len(), 2);
    }

    #[test]
    fn test_protect_index() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);
        window.push(user_msg("3"), 10);

        window.protect(1);
        let removed = window.compact();

        assert_eq!(removed.len(), 2);
        assert_eq!(window.len(), 1);
    }

    #[test]
    fn test_unprotect() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);

        window.protect(0);
        assert!(window.is_protected(0));

        window.unprotect(0);
        assert!(!window.is_protected(0));

        let removed = window.compact();
        assert_eq!(removed.len(), 2);
    }

    #[test]
    fn test_is_protected() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);

        // Freshly pushed messages are compactable.
        assert!(!window.is_protected(0));
        assert!(!window.is_protected(1));

        window.protect(0);
        assert!(window.is_protected(0));
        assert!(!window.is_protected(1));
    }

    // ---- iteration and per-message token counts ------------------------

    #[test]
    fn test_iter() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        let collected: Vec<_> = window.iter().collect();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0].role, ChatRole::User);
        assert_eq!(collected[1].role, ChatRole::Assistant);
    }

    #[test]
    fn test_token_count() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 42);

        assert_eq!(window.token_count(0), 42);
    }

    #[test]
    fn test_update_token_count() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);

        assert_eq!(window.total_tokens(), 10);

        window.update_token_count(0, 15);

        assert_eq!(window.token_count(0), 15);
        assert_eq!(window.total_tokens(), 15);
    }

    #[test]
    fn test_clear() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        window.clear();

        assert!(window.is_empty());
        assert_eq!(window.total_tokens(), 0);
        assert_eq!(window.available(), 7000);
    }

    // ---- estimation heuristics -----------------------------------------

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_short() {
        // 2 bytes rounds up to one token.
        assert_eq!(estimate_tokens("Hi"), 1);
    }

    #[test]
    fn test_estimate_tokens_medium() {
        // 11 bytes / 4, rounded up.
        assert_eq!(estimate_tokens("Hello world"), 3);
    }

    #[test]
    fn test_estimate_tokens_exact_multiple() {
        assert_eq!(estimate_tokens("1234567890123456"), 4);
    }

    #[test]
    fn test_estimate_tokens_minimum() {
        assert_eq!(estimate_tokens("a"), 1);
    }

    #[test]
    fn test_estimate_tokens_counts_bytes() {
        // "é" is two bytes in UTF-8: three of them are 6 bytes -> 2 tokens.
        assert_eq!(estimate_tokens("ééé"), 2);
    }

    #[test]
    fn test_estimate_message_tokens() {
        let msg = user_msg("Hello world");
        // 3 text tokens + 4 per-message overhead (assuming `ChatMessage::user`
        // builds a single text block).
        assert_eq!(estimate_message_tokens(&msg), 7);
    }

    #[test]
    fn test_estimate_message_tokens_empty() {
        let msg = ChatMessage {
            role: ChatRole::User,
            content: vec![],
        };
        // Only the per-message overhead remains.
        assert_eq!(estimate_message_tokens(&msg), 4);
    }

    // ---- trait impls ----------------------------------------------------

    #[test]
    fn test_context_window_debug() {
        let window = ContextWindow::new(8000, 1000);
        let debug = format!("{window:?}");
        assert!(debug.contains("ContextWindow"));
        assert!(debug.contains("8000"));
    }

    #[test]
    fn test_context_window_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ContextWindow>();
    }

    // ---- end-to-end scenarios ------------------------------------------

    #[test]
    fn test_typical_conversation_flow() {
        let mut window = ContextWindow::new(4000, 500);
        window.push(system_msg("You are a helpful assistant."), 15);
        // The system prompt must survive every compaction.
        window.protect(0);

        window.push(user_msg("What is 2+2?"), 20);
        window.push(assistant_msg("2+2 equals 4."), 25);
        window.push(user_msg("What about 3+3?"), 22);
        window.push(assistant_msg("3+3 equals 6."), 25);

        assert_eq!(window.len(), 5);
        assert_eq!(window.total_tokens(), 107);
        assert_eq!(window.available(), 3500 - 107);
        assert!(!window.needs_compaction(0.8));

        // 100 more messages at 30 tokens each: 107 + 3000 = 3107 of 3500 (~89%).
        for i in 0..50 {
            window.push(user_msg(&format!("Question {i}")), 30);
            window.push(assistant_msg(&format!("Answer {i}")), 30);
        }
        assert!(window.needs_compaction(0.8));

        // Keep the two most recent question/answer pairs.
        window.protect_recent(4);
        let removed = window.compact();

        // 105 messages minus the system prompt minus 4 recent = 100 removed.
        assert_eq!(removed.len(), 100);
        assert_eq!(window.len(), 5);
        assert_eq!(window.messages()[0].role, ChatRole::System);
        assert_eq!(window.total_tokens(), 15 + 4 * 30);
        assert!(!window.needs_compaction(0.8));
    }

    #[test]
    fn test_compact_then_add_summary() {
        let mut window = ContextWindow::new(1000, 100);
        window.push(system_msg("System"), 20);
        window.protect(0);

        for _ in 0..10 {
            window.push(user_msg("Message"), 80);
        }

        let removed = window.compact();
        assert_eq!(removed.len(), 10);
        assert_eq!(window.len(), 1);

        // Replace the removed history with a single summary message.
        window.push(
            ChatMessage::system("Summary of previous conversation..."),
            50,
        );

        assert_eq!(window.len(), 2);
        assert_eq!(window.total_tokens(), 70);
    }
}