1use crate::chat::ChatMessage;
45
46#[derive(Debug)]
51pub struct ContextWindow {
52 max_tokens: u32,
54 reserved_for_output: u32,
56 messages: Vec<TrackedMessage>,
58}
59
60#[derive(Debug, Clone)]
62struct TrackedMessage {
63 message: ChatMessage,
65 token_count: u32,
67 compactable: bool,
69}
70
71impl ContextWindow {
72 pub fn new(max_tokens: u32, reserved_for_output: u32) -> Self {
83 assert!(
84 reserved_for_output < max_tokens,
85 "reserved_for_output ({reserved_for_output}) must be less than max_tokens ({max_tokens})"
86 );
87 Self {
88 max_tokens,
89 reserved_for_output,
90 messages: Vec::new(),
91 }
92 }
93
94 pub fn push(&mut self, message: ChatMessage, tokens: u32) {
104 self.messages.push(TrackedMessage {
105 message,
106 token_count: tokens,
107 compactable: true,
108 });
109 }
110
111 pub fn available(&self) -> u32 {
115 let input_budget = self.max_tokens.saturating_sub(self.reserved_for_output);
116 input_budget.saturating_sub(self.total_tokens())
117 }
118
119 pub fn iter(&self) -> impl Iterator<Item = &ChatMessage> {
123 self.messages.iter().map(|t| &t.message)
124 }
125
126 pub fn messages(&self) -> Vec<&ChatMessage> {
130 self.messages.iter().map(|t| &t.message).collect()
131 }
132
133 pub fn messages_owned(&self) -> Vec<ChatMessage> {
137 self.messages.iter().map(|t| t.message.clone()).collect()
138 }
139
140 pub fn total_tokens(&self) -> u32 {
142 self.messages
143 .iter()
144 .map(|t| t.token_count)
145 .fold(0, u32::saturating_add)
146 }
147
148 pub fn len(&self) -> usize {
150 self.messages.len()
151 }
152
153 pub fn is_empty(&self) -> bool {
155 self.messages.is_empty()
156 }
157
158 #[allow(clippy::cast_precision_loss)]
180 pub fn needs_compaction(&self, threshold: f32) -> bool {
181 let input_budget = self.max_tokens.saturating_sub(self.reserved_for_output);
182 if input_budget == 0 {
183 return false;
184 }
185 let usage_ratio = self.total_tokens() as f32 / input_budget as f32;
188 usage_ratio > threshold
189 }
190
191 pub fn compact(&mut self) -> Vec<ChatMessage> {
201 let mut removed = Vec::new();
202 let mut retained = Vec::new();
203
204 for tracked in self.messages.drain(..) {
205 if tracked.compactable {
206 removed.push(tracked.message);
207 } else {
208 retained.push(tracked);
209 }
210 }
211
212 self.messages = retained;
213 removed
214 }
215
216 pub fn protect_recent(&mut self, n: usize) {
226 let len = self.messages.len();
227 let start = len.saturating_sub(n);
228 for msg in &mut self.messages[start..] {
229 msg.compactable = false;
230 }
231 }
232
233 pub fn protect(&mut self, index: usize) {
241 self.messages[index].compactable = false;
242 }
243
244 pub fn unprotect(&mut self, index: usize) {
252 self.messages[index].compactable = true;
253 }
254
255 pub fn is_protected(&self, index: usize) -> bool {
261 !self.messages[index].compactable
262 }
263
264 pub fn input_budget(&self) -> u32 {
266 self.max_tokens.saturating_sub(self.reserved_for_output)
267 }
268
269 pub fn max_tokens(&self) -> u32 {
271 self.max_tokens
272 }
273
274 pub fn reserved_for_output(&self) -> u32 {
276 self.reserved_for_output
277 }
278
279 pub fn clear(&mut self) {
281 self.messages.clear();
282 }
283
284 pub fn token_count(&self, index: usize) -> u32 {
290 self.messages[index].token_count
291 }
292
293 pub fn update_token_count(&mut self, index: usize, tokens: u32) {
302 self.messages[index].token_count = tokens;
303 }
304
305 pub fn force_fit(&mut self) -> Vec<ChatMessage> {
316 let mut removed = Vec::new();
317
318 while self.needs_compaction(1.0) {
320 let idx = self.messages.iter().position(|m| m.compactable);
321 match idx {
322 Some(i) => removed.push(self.messages.remove(i).message),
323 None => break, }
325 }
326
327 removed
328 }
329}
330
/// Roughly estimates the number of tokens in `text`, assuming about four
/// bytes per token and rounding up.
///
/// The estimate is based on the UTF-8 byte length (`str::len`), not on the
/// number of characters. Returns 0 for an empty string and at least 1 for any
/// non-empty string. Lengths that do not fit in a `u32` are clamped to
/// `u32::MAX` before dividing.
pub fn estimate_tokens(text: &str) -> u32 {
    const BYTES_PER_TOKEN: u32 = 4;

    // Clamp instead of truncating so oversized inputs never wrap around.
    let byte_len = u32::try_from(text.len()).unwrap_or(u32::MAX);
    // `div_ceil` yields 0 for 0 and >= 1 for anything else, so no special
    // cases are needed for empty or very short strings.
    byte_len.div_ceil(BYTES_PER_TOKEN)
}
355
356pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
360 use crate::chat::ContentBlock;
361
362 let content_tokens: u32 = message
363 .content
364 .iter()
365 .map(|block| match block {
366 ContentBlock::Text(text) => estimate_tokens(text),
367 ContentBlock::Image { .. } => 85,
369 ContentBlock::ToolCall(tc) => {
370 estimate_tokens(&tc.name) + estimate_tokens(&tc.arguments.to_string())
371 }
372 ContentBlock::ToolResult(tr) => estimate_tokens(&tr.content) + 10,
374 ContentBlock::Reasoning { content } => estimate_tokens(content),
375 })
376 .sum();
377
378 content_tokens + 4
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::ChatRole;

    // ---- helpers ---------------------------------------------------------

    fn user_msg(text: &str) -> ChatMessage {
        ChatMessage::user(text)
    }

    fn assistant_msg(text: &str) -> ChatMessage {
        ChatMessage::assistant(text)
    }

    fn system_msg(text: &str) -> ChatMessage {
        ChatMessage::system(text)
    }

    // ---- construction ----------------------------------------------------

    #[test]
    fn test_new_context_window() {
        let window = ContextWindow::new(8000, 1000);
        assert_eq!(window.max_tokens(), 8000);
        assert_eq!(window.reserved_for_output(), 1000);
        assert_eq!(window.input_budget(), 7000);
        assert!(window.is_empty());
        assert_eq!(window.len(), 0);
    }

    #[test]
    #[should_panic(expected = "reserved_for_output")]
    fn test_new_invalid_reserved() {
        // Reserving the whole window leaves no input budget.
        ContextWindow::new(1000, 1000);
    }

    #[test]
    #[should_panic(expected = "reserved_for_output")]
    fn test_new_reserved_exceeds_max() {
        ContextWindow::new(1000, 2000);
    }

    // ---- push / accounting -----------------------------------------------

    #[test]
    fn test_push_and_len() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        assert_eq!(window.len(), 2);
        assert!(!window.is_empty());
    }

    #[test]
    fn test_total_tokens() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);
        window.push(user_msg("How are you?"), 15);

        assert_eq!(window.total_tokens(), 33);
    }

    #[test]
    fn test_available_tokens() {
        let mut window = ContextWindow::new(8000, 1000);
        assert_eq!(window.available(), 7000);

        window.push(user_msg("Hello"), 100);
        assert_eq!(window.available(), 6900);

        window.push(assistant_msg("Hi"), 50);
        assert_eq!(window.available(), 6850);
    }

    #[test]
    fn test_available_saturates() {
        let mut window = ContextWindow::new(1000, 100);
        // 1000 tokens against a 900-token budget: clamps to 0, no underflow.
        window.push(user_msg("Large message"), 1000);
        assert_eq!(window.available(), 0);
    }

    #[test]
    fn test_messages() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        let messages = window.messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, ChatRole::User);
        assert_eq!(messages[1].role, ChatRole::Assistant);
    }

    #[test]
    fn test_messages_owned() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);

        let messages = window.messages_owned();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, ChatRole::User);
    }

    // ---- needs_compaction ------------------------------------------------

    #[test]
    fn test_needs_compaction_below_threshold() {
        let mut window = ContextWindow::new(1000, 200);
        // 400 / 800 = 50%
        window.push(user_msg("Hello"), 400);
        assert!(!window.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let mut window = ContextWindow::new(1000, 200);
        // 700 / 800 = 87.5%
        window.push(user_msg("Hello"), 700);
        assert!(window.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_at_threshold() {
        let mut window = ContextWindow::new(1000, 200);
        // 640 / 800 = exactly 80%; the comparison is strict.
        window.push(user_msg("Hello"), 640);
        assert!(!window.needs_compaction(0.8));
    }

    #[test]
    fn test_needs_compaction_smallest_budget() {
        // `new` rejects a zero input budget, so a budget of 1 is the smallest
        // reachable configuration.
        let mut window = ContextWindow::new(100, 99);
        assert_eq!(window.input_budget(), 1);
        assert!(!window.needs_compaction(0.8));

        // Exactly full: above 80%, but not above 100%.
        window.push(user_msg("x"), 1);
        assert_eq!(window.available(), 0);
        assert!(window.needs_compaction(0.8));
        assert!(!window.needs_compaction(1.0));

        // One token over.
        window.push(user_msg("y"), 1);
        assert!(window.needs_compaction(1.0));
    }

    // ---- compact ---------------------------------------------------------

    #[test]
    fn test_compact_all_compactable() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);
        window.push(user_msg("Bye"), 5);

        let removed = window.compact();

        assert_eq!(removed.len(), 3);
        assert!(window.is_empty());
        assert_eq!(window.total_tokens(), 0);
    }

    #[test]
    fn test_compact_with_protected() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(system_msg("System"), 20);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);
        window.push(user_msg("Question"), 15);

        window.protect(0);
        window.protect_recent(2);

        let removed = window.compact();

        // Only "Hello" (10 tokens) was unprotected.
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].role, ChatRole::User);
        assert_eq!(window.len(), 3);
        assert_eq!(window.total_tokens(), 20 + 8 + 15);
    }

    #[test]
    fn test_compact_none_compactable() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(system_msg("System"), 20);
        window.push(user_msg("Hello"), 10);

        window.protect_recent(2);

        let removed = window.compact();

        assert!(removed.is_empty());
        assert_eq!(window.len(), 2);
    }

    // ---- protection ------------------------------------------------------

    #[test]
    fn test_protect_recent() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);
        window.push(user_msg("3"), 10);
        window.push(user_msg("4"), 10);

        window.protect_recent(2);

        let removed = window.compact();

        assert_eq!(removed.len(), 2);
        assert_eq!(window.len(), 2);
    }

    #[test]
    fn test_protect_recent_more_than_len() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);

        // Asking for more than exists protects everything without panicking.
        window.protect_recent(10);
        let removed = window.compact();

        assert!(removed.is_empty());
        assert_eq!(window.len(), 2);
    }

    #[test]
    fn test_protect_index() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);
        window.push(user_msg("3"), 10);

        window.protect(1);
        let removed = window.compact();

        assert_eq!(removed.len(), 2);
        assert_eq!(window.len(), 1);
    }

    #[test]
    #[should_panic]
    fn test_protect_out_of_bounds_panics() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);

        window.protect(1);
    }

    #[test]
    fn test_unprotect() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);

        window.protect(0);
        assert!(window.is_protected(0));

        window.unprotect(0);
        assert!(!window.is_protected(0));

        let removed = window.compact();
        assert_eq!(removed.len(), 2);
    }

    #[test]
    fn test_is_protected() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("1"), 10);
        window.push(user_msg("2"), 10);

        // Freshly pushed messages are unprotected.
        assert!(!window.is_protected(0));
        assert!(!window.is_protected(1));

        window.protect(0);
        assert!(window.is_protected(0));
        assert!(!window.is_protected(1));
    }

    // ---- iteration -------------------------------------------------------

    #[test]
    fn test_iter() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        let collected: Vec<_> = window.iter().collect();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0].role, ChatRole::User);
        assert_eq!(collected[1].role, ChatRole::Assistant);
    }

    // ---- per-message token counts ----------------------------------------

    #[test]
    fn test_token_count() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 42);

        assert_eq!(window.token_count(0), 42);
    }

    #[test]
    fn test_update_token_count() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);

        assert_eq!(window.total_tokens(), 10);

        window.update_token_count(0, 15);

        assert_eq!(window.token_count(0), 15);
        assert_eq!(window.total_tokens(), 15);
    }

    #[test]
    fn test_clear() {
        let mut window = ContextWindow::new(8000, 1000);
        window.push(user_msg("Hello"), 10);
        window.push(assistant_msg("Hi"), 8);

        window.clear();

        assert!(window.is_empty());
        assert_eq!(window.total_tokens(), 0);
        assert_eq!(window.available(), 7000);
    }

    // ---- estimators ------------------------------------------------------

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_short() {
        // 2 bytes -> rounds up to 1
        assert_eq!(estimate_tokens("Hi"), 1);
    }

    #[test]
    fn test_estimate_tokens_medium() {
        // 11 bytes -> ceil(11 / 4) = 3
        assert_eq!(estimate_tokens("Hello world"), 3);
    }

    #[test]
    fn test_estimate_tokens_exact_multiple() {
        // 16 bytes -> exactly 4
        assert_eq!(estimate_tokens("1234567890123456"), 4);
    }

    #[test]
    fn test_estimate_tokens_minimum() {
        assert_eq!(estimate_tokens("a"), 1);
    }

    #[test]
    fn test_estimate_message_tokens() {
        let msg = user_msg("Hello world");
        // 3 content tokens + 4 overhead
        assert_eq!(estimate_message_tokens(&msg), 7);
    }

    #[test]
    fn test_estimate_message_tokens_empty() {
        let msg = ChatMessage {
            role: ChatRole::User,
            content: vec![],
        };
        // Overhead only.
        assert_eq!(estimate_message_tokens(&msg), 4);
    }

    // ---- trait checks ----------------------------------------------------

    #[test]
    fn test_context_window_debug() {
        let window = ContextWindow::new(8000, 1000);
        let debug = format!("{window:?}");
        assert!(debug.contains("ContextWindow"));
        assert!(debug.contains("8000"));
    }

    #[test]
    fn test_context_window_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ContextWindow>();
    }

    // ---- end-to-end scenarios --------------------------------------------

    #[test]
    fn test_typical_conversation_flow() {
        let mut window = ContextWindow::new(4000, 500);

        window.push(system_msg("You are a helpful assistant."), 15);
        window.protect(0);

        window.push(user_msg("What is 2+2?"), 20);
        window.push(assistant_msg("2+2 equals 4."), 25);
        window.push(user_msg("What about 3+3?"), 22);
        window.push(assistant_msg("3+3 equals 6."), 25);

        assert_eq!(window.len(), 5);
        assert_eq!(window.total_tokens(), 107);
        assert_eq!(window.available(), 3500 - 107);
        assert!(!window.needs_compaction(0.8));

        // 100 more messages at 30 tokens each push usage to 3107 / 3500.
        for i in 0..50 {
            window.push(user_msg(&format!("Question {i}")), 30);
            window.push(assistant_msg(&format!("Answer {i}")), 30);
        }
        assert!(window.needs_compaction(0.8));

        window.protect_recent(4);
        let removed = window.compact();

        // The system prompt and the four most recent messages survive.
        assert_eq!(removed.len(), 100);
        assert_eq!(window.len(), 5);
        assert_eq!(window.total_tokens(), 15 + 4 * 30);
        assert_eq!(window.messages()[0].role, ChatRole::System);
    }

    #[test]
    fn test_compact_then_add_summary() {
        let mut window = ContextWindow::new(1000, 100);

        window.push(system_msg("System"), 20);
        window.protect(0);

        for _ in 0..10 {
            window.push(user_msg("Message"), 80);
        }

        let removed = window.compact();
        assert_eq!(removed.len(), 10);
        assert_eq!(window.len(), 1);

        window.push(
            ChatMessage::system("Summary of previous conversation..."),
            50,
        );

        assert_eq!(window.len(), 2);
        assert_eq!(window.total_tokens(), 70);
    }

    // ---- force_fit -------------------------------------------------------

    #[test]
    fn test_force_fit_drops_oldest_first() {
        let mut window = ContextWindow::new(1000, 100);

        window.push(system_msg("System"), 20);
        window.protect(0);

        // Different sizes so the assertions can tell which one was dropped.
        window.push(user_msg("Old"), 500);
        window.push(user_msg("Newer"), 450);

        assert!(window.needs_compaction(1.0));

        let removed = window.force_fit();

        assert_eq!(removed.len(), 1);
        assert_eq!(window.len(), 2);
        assert_eq!(window.token_count(1), 450);
        assert_eq!(window.total_tokens(), 470);
        assert!(!window.needs_compaction(1.0));
    }

    #[test]
    fn test_force_fit_stops_when_under_budget() {
        let mut window = ContextWindow::new(1000, 100);

        window.push(user_msg("A"), 300);
        window.push(user_msg("B"), 300);
        window.push(user_msg("C"), 300);
        window.push(user_msg("D"), 200);

        assert!(window.needs_compaction(1.0));

        let removed = window.force_fit();

        // Dropping "A" alone brings 1100 down to 800, within the 900 budget.
        assert_eq!(removed.len(), 1);
        assert_eq!(window.len(), 3);
        assert_eq!(window.total_tokens(), 800);
    }

    #[test]
    fn test_force_fit_skips_protected() {
        let mut window = ContextWindow::new(1000, 100);

        window.push(system_msg("System"), 400);
        window.protect(0);

        window.push(user_msg("Old 1"), 300);
        window.push(user_msg("Old 2"), 300);

        let removed = window.force_fit();

        assert_eq!(removed.len(), 1);
        assert_eq!(window.len(), 2);
        assert!(window.is_protected(0));
        assert_eq!(window.total_tokens(), 700);
    }

    #[test]
    fn test_force_fit_noop_when_under_budget() {
        let mut window = ContextWindow::new(1000, 100);
        window.push(user_msg("Small"), 50);

        let removed = window.force_fit();

        assert!(removed.is_empty());
        assert_eq!(window.len(), 1);
    }

    #[test]
    fn test_force_fit_stops_when_only_protected_remain() {
        let mut window = ContextWindow::new(1000, 100);

        window.push(system_msg("Big system"), 600);
        window.protect(0);

        window.push(user_msg("Big user"), 400);
        window.protect(1);

        let removed = window.force_fit();

        // Nothing can be dropped, so the window stays over budget.
        assert!(removed.is_empty());
        assert_eq!(window.len(), 2);
        assert!(window.needs_compaction(1.0));
    }
}