1use crate::context::compressor::MessageCompressor;
30use crate::context::summarizer::{Summarizer, SummarizerClient, DEFAULT_SUMMARY_BUDGET};
31use crate::context::token_estimator::TokenEstimator;
32use crate::context::types::{
33 CompressionConfig, CompressionDetails, CompressionResult, ContextConfig, ContextError,
34 ContextExport, ContextStats, ContextUsage, ConversationTurn, TokenUsage,
35};
36use crate::conversation::message::{Message, MessageContent};
37use std::sync::Arc;
38
39const TOOL_REFERENCE_PLACEHOLDER: &str = "[Tool reference collapsed]";
45
46const SUMMARY_PREFIX: &str = "[Previous conversation summary]\n";
48
49pub struct EnhancedContextManager {
59 config: ContextConfig,
61
62 turns: Vec<ConversationTurn>,
64
65 system_prompt: String,
67
68 compression_count: usize,
70
71 saved_tokens: usize,
73
74 summarizer_client: Option<Arc<dyn SummarizerClient>>,
76}
77
78impl EnhancedContextManager {
79 pub fn new(config: ContextConfig) -> Self {
93 Self {
94 config,
95 turns: Vec::new(),
96 system_prompt: String::new(),
97 compression_count: 0,
98 saved_tokens: 0,
99 summarizer_client: None,
100 }
101 }
102
103 pub fn with_default_config() -> Self {
105 Self::new(ContextConfig::default())
106 }
107
108 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
114 self.system_prompt = prompt.into();
115 }
116
117 pub fn system_prompt(&self) -> &str {
119 &self.system_prompt
120 }
121
122 pub fn set_summarizer_client(&mut self, client: Arc<dyn SummarizerClient>) {
128 self.summarizer_client = Some(client);
129 }
130
131 pub fn has_summarizer_client(&self) -> bool {
133 self.summarizer_client.is_some() && self.config.enable_ai_summary
134 }
135
136 pub fn add_turn(&mut self, user: Message, assistant: Message, api_usage: Option<TokenUsage>) {
151 let user_tokens = TokenEstimator::estimate_message_tokens(&user);
153 let assistant_tokens = TokenEstimator::estimate_message_tokens(&assistant);
154 let total_tokens = user_tokens + assistant_tokens;
155
156 let (final_user, final_assistant, final_tokens) = if self
158 .config
159 .enable_incremental_compression
160 {
161 let compression_config = CompressionConfig {
162 code_block_max_lines: self.config.code_block_max_lines,
163 tool_output_max_chars: self.config.tool_output_max_chars,
164 ..Default::default()
165 };
166
167 let compressed_user = MessageCompressor::compress_message(&user, &compression_config);
168 let compressed_assistant =
169 MessageCompressor::compress_message(&assistant, &compression_config);
170
171 let compressed_user_tokens = TokenEstimator::estimate_message_tokens(&compressed_user);
172 let compressed_assistant_tokens =
173 TokenEstimator::estimate_message_tokens(&compressed_assistant);
174 let compressed_total = compressed_user_tokens + compressed_assistant_tokens;
175
176 (compressed_user, compressed_assistant, compressed_total)
177 } else {
178 (user, assistant, total_tokens)
179 };
180
181 let mut turn = ConversationTurn::new(final_user, final_assistant, final_tokens);
183 turn.original_tokens = total_tokens;
184
185 if final_tokens < total_tokens {
187 turn.compressed = true;
188 self.saved_tokens += total_tokens - final_tokens;
189 }
190
191 if let Some(usage) = api_usage {
193 turn.api_usage = Some(usage);
194 }
195
196 self.turns.push(turn);
197 }
198
199 pub fn turn_count(&self) -> usize {
201 self.turns.len()
202 }
203
204 pub fn turns(&self) -> &[ConversationTurn] {
206 &self.turns
207 }
208
209 pub fn turns_mut(&mut self) -> &mut Vec<ConversationTurn> {
211 &mut self.turns
212 }
213
214 pub fn get_messages(&self) -> Vec<Message> {
229 let mut messages: Vec<Message> = Vec::new();
230
231 if !self.system_prompt.is_empty() {
233 messages.push(Message::user().with_text(&self.system_prompt));
234 }
235
236 let summarized_turns: Vec<&ConversationTurn> =
238 self.turns.iter().filter(|t| t.summarized).collect();
239
240 if !summarized_turns.is_empty() {
241 let combined_summary = summarized_turns
243 .iter()
244 .filter_map(|t| t.summary.as_ref())
245 .cloned()
246 .collect::<Vec<_>>()
247 .join("\n\n");
248
249 if !combined_summary.is_empty() {
250 let summary_text = format!("{}{}", SUMMARY_PREFIX, combined_summary);
251 messages.push(Message::user().with_text(summary_text));
252 }
253 }
254
255 for turn in &self.turns {
257 if !turn.summarized {
258 messages.push(turn.user.clone());
259 messages.push(turn.assistant.clone());
260 }
261 }
262
263 messages
264 }
265
266 pub fn get_messages_collapsed(&self) -> Vec<Message> {
271 self.get_messages()
272 .into_iter()
273 .map(|msg| Self::collapse_tool_references(&msg))
274 .collect()
275 }
276
277 pub fn get_used_tokens(&self) -> usize {
285 let system_tokens = TokenEstimator::estimate_tokens(&self.system_prompt);
286 let turn_tokens: usize = self.turns.iter().map(|t| t.token_estimate).sum();
287 system_tokens + turn_tokens
288 }
289
290 pub fn get_available_tokens(&self) -> usize {
292 let available = self.config.available_tokens();
293 let used = self.get_used_tokens();
294 available.saturating_sub(used)
295 }
296
297 fn should_compress(&self) -> bool {
299 let used = self.get_used_tokens();
300 let threshold = self.config.summarize_token_threshold();
301 used > threshold
302 }
303
304 pub async fn maybe_compress(&mut self) -> Result<(), ContextError> {
317 if self.should_compress() {
318 self.compact().await?;
319 }
320 Ok(())
321 }
322
323 pub async fn compact(&mut self) -> Result<(), ContextError> {
333 let total_turns = self.turns.len();
334 if total_turns == 0 {
335 return Ok(());
336 }
337
338 let keep_recent = self.config.keep_recent_messages.min(total_turns);
340 let turns_to_summarize = total_turns.saturating_sub(keep_recent);
341
342 if turns_to_summarize == 0 {
343 return Ok(());
344 }
345
346 let unsummarized_indices: Vec<usize> = self
348 .turns
349 .iter()
350 .enumerate()
351 .take(turns_to_summarize)
352 .filter(|(_, t)| !t.summarized)
353 .map(|(i, _)| i)
354 .collect();
355
356 if unsummarized_indices.is_empty() {
357 return Ok(());
358 }
359
360 let turns_for_summary: Vec<ConversationTurn> = unsummarized_indices
362 .iter()
363 .map(|&i| self.turns[i].clone())
364 .collect();
365
366 let summary = if self.has_summarizer_client() {
368 let client = self.summarizer_client.as_ref().unwrap();
369 Summarizer::generate_ai_summary(
370 &turns_for_summary,
371 client.as_ref(),
372 DEFAULT_SUMMARY_BUDGET,
373 )
374 .await?
375 } else {
376 Summarizer::create_simple_summary(&turns_for_summary)
377 };
378
379 let original_tokens: usize = turns_for_summary.iter().map(|t| t.token_estimate).sum();
381 let summary_tokens = TokenEstimator::estimate_tokens(&summary);
382
383 for &idx in &unsummarized_indices {
385 let turn = &mut self.turns[idx];
386 turn.mark_summarized(summary.clone(), summary_tokens / unsummarized_indices.len());
387 }
388
389 self.compression_count += 1;
391 self.saved_tokens += original_tokens.saturating_sub(summary_tokens);
392
393 Ok(())
394 }
395
396 pub fn export(&self) -> ContextExport {
406 ContextExport::new(
407 self.system_prompt.clone(),
408 self.turns.clone(),
409 self.config.clone(),
410 self.compression_count,
411 self.saved_tokens,
412 )
413 }
414
415 pub fn import(&mut self, data: ContextExport) {
423 self.system_prompt = data.system_prompt;
424 self.turns = data.turns;
425 self.config = data.config;
426 self.compression_count = data.compression_count;
427 self.saved_tokens = data.saved_tokens;
428 }
429
430 pub fn clear(&mut self) {
435 self.turns.clear();
436 self.compression_count = 0;
437 self.saved_tokens = 0;
438 }
439
440 pub fn reset(&mut self) {
442 self.clear();
443 self.system_prompt.clear();
444 }
445
446 pub fn get_stats(&self) -> ContextStats {
452 let total_messages = self.turns.len() * 2; let estimated_tokens = self.get_used_tokens();
454 let summarized_messages = self.turns.iter().filter(|t| t.summarized).count() * 2;
455
456 let original_tokens: usize = self.turns.iter().map(|t| t.original_tokens).sum();
457 let current_tokens: usize = self.turns.iter().map(|t| t.token_estimate).sum();
458
459 let compression_ratio = if original_tokens > 0 {
460 current_tokens as f64 / original_tokens as f64
461 } else {
462 1.0
463 };
464
465 ContextStats {
466 total_messages,
467 estimated_tokens,
468 summarized_messages,
469 compression_ratio,
470 saved_tokens: self.saved_tokens,
471 compression_count: self.compression_count,
472 }
473 }
474
475 pub fn get_compression_details(&self) -> CompressionDetails {
477 let total_turns = self.turns.len();
478 let summarized_turns = self.turns.iter().filter(|t| t.summarized).count();
479 let compressed_turns = self.turns.iter().filter(|t| t.compressed).count();
480 let recent_turns = total_turns.saturating_sub(summarized_turns);
481
482 let original_tokens: usize = self.turns.iter().map(|t| t.original_tokens).sum();
483 let current_tokens: usize = self.turns.iter().map(|t| t.token_estimate).sum();
484
485 let compression_ratio = if original_tokens > 0 {
486 current_tokens as f64 / original_tokens as f64
487 } else {
488 1.0
489 };
490
491 CompressionDetails {
492 total_turns,
493 summarized_turns,
494 compressed_turns,
495 recent_turns,
496 compression_ratio,
497 saved_tokens: self.saved_tokens,
498 }
499 }
500
501 pub fn get_context_usage(&self) -> ContextUsage {
503 let used = self.get_used_tokens();
504 let total = self.config.max_tokens;
505 ContextUsage::new(used, total)
506 }
507
508 pub fn is_near_limit(&self) -> bool {
512 let usage = self.get_context_usage();
513 usage.percentage > (self.config.summarize_threshold * 100.0)
514 }
515
516 pub fn get_formatted_report(&self) -> String {
518 let stats = self.get_stats();
519 let usage = self.get_context_usage();
520 let details = self.get_compression_details();
521
522 format!(
523 "Context Statistics:\n\
524 - Total messages: {}\n\
525 - Estimated tokens: {} / {} ({:.1}%)\n\
526 - Available tokens: {}\n\
527 - Summarized messages: {}\n\
528 - Compression ratio: {:.2}\n\
529 - Tokens saved: {}\n\
530 - Compression operations: {}\n\
531 \n\
532 Compression Details:\n\
533 - Total turns: {}\n\
534 - Summarized turns: {}\n\
535 - Compressed turns: {}\n\
536 - Recent turns: {}",
537 stats.total_messages,
538 usage.used,
539 usage.total,
540 usage.percentage,
541 usage.available,
542 stats.summarized_messages,
543 stats.compression_ratio,
544 stats.saved_tokens,
545 stats.compression_count,
546 details.total_turns,
547 details.summarized_turns,
548 details.compressed_turns,
549 details.recent_turns,
550 )
551 }
552
553 pub fn analyze_compression(&self) -> CompressionResult {
555 let original_tokens: usize = self.turns.iter().map(|t| t.original_tokens).sum();
556 let current_tokens: usize = self.turns.iter().map(|t| t.token_estimate).sum();
557
558 CompressionResult::new(original_tokens, current_tokens, "context_compression")
559 }
560
561 pub fn collapse_tool_references(message: &Message) -> Message {
578 let mut has_non_reference = false;
579 let mut has_reference = false;
580
581 for content in &message.content {
583 match content {
584 MessageContent::ToolResponse(resp) => {
585 if Self::is_tool_reference_response(resp) {
587 has_reference = true;
588 } else {
589 has_non_reference = true;
590 }
591 }
592 _ => {
593 has_non_reference = true;
594 }
595 }
596 }
597
598 if !has_reference {
600 return message.clone();
601 }
602
603 let mut new_content: Vec<MessageContent> = Vec::new();
605 let mut reference_collapsed = false;
606
607 for content in &message.content {
608 match content {
609 MessageContent::ToolResponse(resp) => {
610 if Self::is_tool_reference_response(resp) {
611 if !reference_collapsed {
613 new_content.push(MessageContent::text(TOOL_REFERENCE_PLACEHOLDER));
614 reference_collapsed = true;
615 }
616 } else {
617 new_content.push(content.clone());
618 }
619 }
620 _ => {
621 new_content.push(content.clone());
622 }
623 }
624 }
625
626 if (new_content.is_empty() || (!has_non_reference && reference_collapsed))
628 && new_content.is_empty()
629 {
630 new_content.push(MessageContent::text(TOOL_REFERENCE_PLACEHOLDER));
631 }
632
633 Message {
634 id: message.id.clone(),
635 role: message.role.clone(),
636 created: message.created,
637 content: new_content,
638 metadata: message.metadata,
639 }
640 }
641
642 fn is_tool_reference_response(resp: &crate::conversation::message::ToolResponse) -> bool {
647 if let Ok(result) = &resp.tool_result {
648 for content in &result.content {
649 if let Some(text) = content.as_text() {
650 if text.text.contains("tool_reference")
652 || text.text.contains("[Reference to tool")
653 || text.text.starts_with("ref:")
654 {
655 return true;
656 }
657 }
658 }
659 }
660 false
661 }
662
663 pub fn config(&self) -> &ContextConfig {
669 &self.config
670 }
671
672 pub fn config_mut(&mut self) -> &mut ContextConfig {
674 &mut self.config
675 }
676
677 pub fn set_config(&mut self, config: ContextConfig) {
679 self.config = config;
680 }
681}
682
683impl Default for EnhancedContextManager {
684 fn default() -> Self {
685 Self::with_default_config()
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message. `is_user` selects the user role; otherwise
    /// the assistant role is used.
    fn create_test_message(text: &str, is_user: bool) -> Message {
        match is_user {
            true => Message::user().with_text(text),
            false => Message::assistant().with_text(text),
        }
    }

    /// Records one text-only exchange without API usage data.
    fn push_text_turn(manager: &mut EnhancedContextManager, user_text: &str, assistant_text: &str) {
        let user = create_test_message(user_text, true);
        let assistant = create_test_message(assistant_text, false);
        manager.add_turn(user, assistant, None);
    }

    #[test]
    fn test_new_manager() {
        let manager = EnhancedContextManager::new(ContextConfig::default());

        assert_eq!(manager.turn_count(), 0);
        assert!(manager.system_prompt().is_empty());
        assert!(!manager.has_summarizer_client());
    }

    #[test]
    fn test_set_system_prompt() {
        let mut manager = EnhancedContextManager::default();
        manager.set_system_prompt("You are a helpful assistant.");

        assert_eq!(manager.system_prompt(), "You are a helpful assistant.");
    }

    #[test]
    fn test_add_turn() {
        let mut manager = EnhancedContextManager::default();
        push_text_turn(&mut manager, "Hello", "Hi there!");

        assert_eq!(manager.turn_count(), 1);
        assert!(manager.get_used_tokens() > 0);
    }

    #[test]
    fn test_add_turn_with_usage() {
        let mut manager = EnhancedContextManager::default();
        manager.add_turn(
            create_test_message("Hello", true),
            create_test_message("Hi there!", false),
            Some(TokenUsage::new(10, 20)),
        );

        assert_eq!(manager.turn_count(), 1);
        let recorded = manager.turns()[0]
            .api_usage
            .as_ref()
            .expect("API usage should be attached to the turn");
        assert_eq!(recorded.input_tokens, 10);
    }

    #[test]
    fn test_get_messages_empty() {
        let manager = EnhancedContextManager::default();
        assert!(manager.get_messages().is_empty());
    }

    #[test]
    fn test_get_messages_with_system_prompt() {
        let mut manager = EnhancedContextManager::default();
        manager.set_system_prompt("System prompt");

        assert_eq!(manager.get_messages().len(), 1);
    }

    #[test]
    fn test_get_messages_with_turns() {
        let mut manager = EnhancedContextManager::default();
        manager.set_system_prompt("System prompt");
        push_text_turn(&mut manager, "Hello", "Hi!");

        // System prompt + user + assistant.
        assert_eq!(manager.get_messages().len(), 3);
    }

    #[test]
    fn test_get_used_tokens() {
        let mut manager = EnhancedContextManager::default();
        manager.set_system_prompt("Short prompt");

        let before = manager.get_used_tokens();
        assert!(before > 0);

        push_text_turn(&mut manager, "Hello world", "Hi there!");
        assert!(manager.get_used_tokens() > before);
    }

    #[test]
    fn test_get_available_tokens() {
        let manager = EnhancedContextManager::new(ContextConfig {
            max_tokens: 1000,
            reserve_tokens: 200,
            ..Default::default()
        });

        assert_eq!(manager.get_available_tokens(), 800);
    }

    #[test]
    fn test_export_import() {
        let mut source = EnhancedContextManager::default();
        source.set_system_prompt("Test prompt");
        push_text_turn(&mut source, "Hello", "Hi!");

        let snapshot = source.export();
        assert_eq!(snapshot.system_prompt, "Test prompt");
        assert_eq!(snapshot.turns.len(), 1);

        let mut restored = EnhancedContextManager::default();
        restored.import(snapshot);

        assert_eq!(restored.system_prompt(), "Test prompt");
        assert_eq!(restored.turn_count(), 1);
    }

    #[test]
    fn test_clear() {
        let mut manager = EnhancedContextManager::default();
        manager.set_system_prompt("Test prompt");
        push_text_turn(&mut manager, "Hello", "Hi!");

        manager.clear();

        assert_eq!(manager.turn_count(), 0);
        // `clear` keeps the system prompt.
        assert_eq!(manager.system_prompt(), "Test prompt");
    }

    #[test]
    fn test_reset() {
        let mut manager = EnhancedContextManager::default();
        manager.set_system_prompt("Test prompt");
        push_text_turn(&mut manager, "Hello", "Hi!");

        manager.reset();

        assert_eq!(manager.turn_count(), 0);
        // `reset` also drops the system prompt.
        assert!(manager.system_prompt().is_empty());
    }

    #[test]
    fn test_get_stats() {
        let mut manager = EnhancedContextManager::default();
        push_text_turn(&mut manager, "Hello", "Hi!");

        let stats = manager.get_stats();
        // One turn counts as two messages.
        assert_eq!(stats.total_messages, 2);
        assert!(stats.estimated_tokens > 0);
        assert_eq!(stats.summarized_messages, 0);
    }

    #[test]
    fn test_get_compression_details() {
        let mut manager = EnhancedContextManager::default();
        push_text_turn(&mut manager, "Hello", "Hi!");

        let details = manager.get_compression_details();
        assert_eq!(details.total_turns, 1);
        assert_eq!(details.summarized_turns, 0);
        assert_eq!(details.recent_turns, 1);
    }

    #[test]
    fn test_get_context_usage() {
        let mut manager = EnhancedContextManager::new(ContextConfig {
            max_tokens: 1000,
            ..Default::default()
        });
        push_text_turn(&mut manager, "Hello", "Hi!");

        let usage = manager.get_context_usage();
        assert!(usage.used > 0);
        assert_eq!(usage.total, 1000);
        assert!(usage.percentage > 0.0);
    }

    #[test]
    fn test_is_near_limit() {
        let mut manager = EnhancedContextManager::new(ContextConfig {
            max_tokens: 100,
            summarize_threshold: 0.5,
            ..Default::default()
        });
        assert!(!manager.is_near_limit());

        let filler = "A".repeat(200);
        push_text_turn(&mut manager, &filler, &filler);

        assert!(manager.is_near_limit());
    }

    #[test]
    fn test_get_formatted_report() {
        let mut manager = EnhancedContextManager::default();
        push_text_turn(&mut manager, "Hello", "Hi!");

        let report = manager.get_formatted_report();
        for section in &["Context Statistics", "Total messages", "Compression Details"] {
            assert!(report.contains(*section), "report is missing {:?}", section);
        }
    }

    #[test]
    fn test_analyze_compression() {
        let mut manager = EnhancedContextManager::default();
        push_text_turn(&mut manager, "Hello", "Hi!");

        let result = manager.analyze_compression();
        assert!(result.original_tokens > 0);
        assert!(result.compressed_tokens > 0);
    }

    #[test]
    fn test_collapse_tool_references_no_references() {
        let message = Message::user().with_text("Hello world");
        let collapsed = EnhancedContextManager::collapse_tool_references(&message);

        assert_eq!(collapsed.content.len(), 1);
    }

    #[test]
    fn test_should_compress() {
        let mut manager = EnhancedContextManager::new(ContextConfig {
            max_tokens: 100,
            summarize_threshold: 0.5,
            ..Default::default()
        });
        assert!(!manager.should_compress());

        let filler = "A".repeat(200);
        push_text_turn(&mut manager, &filler, &filler);

        assert!(manager.should_compress());
    }

    #[tokio::test]
    async fn test_compact_empty() {
        let mut manager = EnhancedContextManager::default();
        assert!(manager.compact().await.is_ok());
    }

    #[tokio::test]
    async fn test_compact_with_turns() {
        let mut manager = EnhancedContextManager::new(ContextConfig {
            keep_recent_messages: 1,
            ..Default::default()
        });

        for i in 0..5 {
            push_text_turn(
                &mut manager,
                &format!("Message {}", i),
                &format!("Response {}", i),
            );
        }

        assert!(manager.compact().await.is_ok());
        assert!(manager.turns().iter().any(|t| t.summarized));
    }

    #[tokio::test]
    async fn test_maybe_compress_below_threshold() {
        let mut manager = EnhancedContextManager::new(ContextConfig {
            max_tokens: 100000,
            summarize_threshold: 0.9,
            ..Default::default()
        });
        push_text_turn(&mut manager, "Hello", "Hi!");

        assert!(manager.maybe_compress().await.is_ok());
        // Usage is far below the threshold, so no compaction should have run.
        assert_eq!(manager.compression_count, 0);
    }
}