1use std::collections::{HashMap, HashSet};
7use std::path::PathBuf;
8
9use regex::Regex;
10use tokio::fs;
11use tokio::io::AsyncWriteExt;
12
13use super::types::{
14 AgentContext, AgentContextError, AgentContextResult, CompressionResult, ContextFilter,
15 ContextInheritanceConfig, ContextInheritanceType, ContextUpdate,
16};
17use crate::conversation::message::Message;
18
/// In-memory registry of [`AgentContext`] values with optional on-disk
/// persistence as JSON files.
#[derive(Debug)]
pub struct AgentContextManager {
    /// Registered contexts, keyed by their `context_id`.
    contexts: HashMap<String, AgentContext>,

    /// Directory where `persist_context` writes and `load_context` reads
    /// `<context_id>.json` files.
    storage_dir: PathBuf,
}
36
37impl Default for AgentContextManager {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl AgentContextManager {
44 pub fn new() -> Self {
46 let storage_dir = dirs::data_local_dir()
47 .unwrap_or_else(|| PathBuf::from("."))
48 .join("aster")
49 .join("contexts");
50
51 Self {
52 contexts: HashMap::new(),
53 storage_dir,
54 }
55 }
56
57 pub fn with_storage_dir(storage_dir: impl Into<PathBuf>) -> Self {
59 Self {
60 contexts: HashMap::new(),
61 storage_dir: storage_dir.into(),
62 }
63 }
64
65 pub fn create_context(
67 &mut self,
68 parent: Option<&AgentContext>,
69 config: Option<ContextInheritanceConfig>,
70 ) -> AgentContext {
71 let config = config.unwrap_or_default();
72
73 let context = match parent {
74 Some(parent_ctx) => self.inherit(parent_ctx, &config),
75 None => AgentContext::new(),
76 };
77
78 self.contexts
80 .insert(context.context_id.clone(), context.clone());
81
82 context
83 }
84
85 pub fn inherit(
87 &self,
88 parent: &AgentContext,
89 config: &ContextInheritanceConfig,
90 ) -> AgentContext {
91 let mut context = AgentContext::new();
92 context.parent_context_id = Some(parent.context_id.clone());
93
94 match config.inheritance_type {
95 ContextInheritanceType::None => {
96 return context;
98 }
99 ContextInheritanceType::Full => {
100 let history = &parent.conversation_history;
102 context.conversation_history = match config.max_history_length {
103 Some(max) if history.len() > max => {
104 history.iter().rev().take(max).cloned().rev().collect()
105 }
106 _ => history.clone(),
107 };
108
109 let files = &parent.file_context;
110 context.file_context = match config.max_file_contexts {
111 Some(max) if files.len() > max => {
112 files.iter().rev().take(max).cloned().rev().collect()
113 }
114 _ => files.clone(),
115 };
116
117 let results = &parent.tool_results;
118 context.tool_results = match config.max_tool_results {
119 Some(max) if results.len() > max => {
120 results.iter().rev().take(max).cloned().rev().collect()
121 }
122 _ => results.clone(),
123 };
124
125 context.environment = parent.environment.clone();
126 context.system_prompt = parent.system_prompt.clone();
127 context.working_directory = parent.working_directory.clone();
128 }
129 ContextInheritanceType::Shallow | ContextInheritanceType::Selective => {
130 if config.inherit_conversation {
132 let history = &parent.conversation_history;
133 context.conversation_history = match config.max_history_length {
134 Some(max) if history.len() > max => {
135 history.iter().rev().take(max).cloned().rev().collect()
136 }
137 _ => history.clone(),
138 };
139 }
140
141 if config.inherit_files {
142 let files = &parent.file_context;
143 context.file_context = match config.max_file_contexts {
144 Some(max) if files.len() > max => {
145 files.iter().rev().take(max).cloned().rev().collect()
146 }
147 _ => files.clone(),
148 };
149 }
150
151 if config.inherit_tool_results {
152 let results = &parent.tool_results;
153 context.tool_results = match config.max_tool_results {
154 Some(max) if results.len() > max => {
155 results.iter().rev().take(max).cloned().rev().collect()
156 }
157 _ => results.clone(),
158 };
159 }
160
161 if config.inherit_environment {
162 context.environment = parent.environment.clone();
163 }
164
165 context.system_prompt = parent.system_prompt.clone();
166 context.working_directory = parent.working_directory.clone();
167 }
168 }
169
170 if config.filter_sensitive {
172 let filter = ContextFilter::with_defaults();
173 context = self.filter(&context, &filter);
174 }
175
176 if config.compress_context {
178 if let Some(target_tokens) = config.target_tokens {
179 let _ = self.compress(&mut context, target_tokens);
180 }
181 }
182
183 context
184 }
185
186 pub fn compress(
193 &self,
194 context: &mut AgentContext,
195 target_tokens: usize,
196 ) -> AgentContextResult<CompressionResult> {
197 let original_tokens = self.estimate_token_count(context);
198
199 if original_tokens <= target_tokens {
200 return Ok(CompressionResult {
201 original_tokens,
202 compressed_tokens: original_tokens,
203 ratio: 1.0,
204 messages_summarized: 0,
205 files_removed: 0,
206 tool_results_removed: 0,
207 });
208 }
209
210 let mut messages_summarized = 0;
211 let mut files_removed = 0;
212 let mut tool_results_removed = 0;
213
214 if context.tool_results.len() > 5 {
216 let removed = context.tool_results.len() - 5;
217 context.tool_results = context.tool_results.split_off(removed);
218 tool_results_removed = removed;
219 }
220
221 let current_tokens = self.estimate_token_count(context);
223 if current_tokens <= target_tokens {
224 return Ok(CompressionResult {
225 original_tokens,
226 compressed_tokens: current_tokens,
227 ratio: original_tokens as f64 / current_tokens as f64,
228 messages_summarized,
229 files_removed,
230 tool_results_removed,
231 });
232 }
233
234 if context.file_context.len() > 3 {
236 let removed = context.file_context.len() - 3;
237 context.file_context = context.file_context.split_off(removed);
238 files_removed = removed;
239 }
240
241 let current_tokens = self.estimate_token_count(context);
243 if current_tokens <= target_tokens {
244 return Ok(CompressionResult {
245 original_tokens,
246 compressed_tokens: current_tokens,
247 ratio: original_tokens as f64 / current_tokens as f64,
248 messages_summarized,
249 files_removed,
250 tool_results_removed,
251 });
252 }
253
254 if context.conversation_history.len() > 10 {
256 let to_summarize = context.conversation_history.len() - 10;
257 let older_messages: Vec<_> =
258 context.conversation_history.drain(..to_summarize).collect();
259
260 let summary = self.create_message_summary(&older_messages);
262 context.conversation_summary = Some(summary);
263 messages_summarized = to_summarize;
264 }
265
266 let compressed_tokens = self.estimate_token_count(context);
267 context.metadata.is_compressed = true;
268 context.metadata.compression_ratio =
269 Some(original_tokens as f64 / compressed_tokens as f64);
270 context.metadata.touch();
271
272 Ok(CompressionResult {
273 original_tokens,
274 compressed_tokens,
275 ratio: original_tokens as f64 / compressed_tokens as f64,
276 messages_summarized,
277 files_removed,
278 tool_results_removed,
279 })
280 }
281
282 pub fn filter(&self, context: &AgentContext, filter: &ContextFilter) -> AgentContext {
284 let mut filtered = context.clone();
285
286 let excluded_keys: HashSet<_> = filter
288 .excluded_env_keys
289 .iter()
290 .map(|k| k.to_uppercase())
291 .collect();
292
293 filtered
294 .environment
295 .retain(|key, _| !excluded_keys.contains(&key.to_uppercase()));
296
297 if !filter.excluded_file_patterns.is_empty() {
299 filtered.file_context.retain(|fc| {
300 let path_str = fc.path.to_string_lossy();
301 !filter
302 .excluded_file_patterns
303 .iter()
304 .any(|pattern| glob_match(pattern, &path_str))
305 });
306 }
307
308 if !filter.excluded_tools.is_empty() {
310 filtered
311 .tool_results
312 .retain(|tr| !filter.excluded_tools.contains(&tr.tool_name));
313 }
314
315 let patterns: Vec<Regex> = filter
317 .sensitive_patterns
318 .iter()
319 .filter_map(|p| Regex::new(p).ok())
320 .collect();
321
322 for fc in &mut filtered.file_context {
324 fc.content = mask_sensitive_content(&fc.content, &patterns);
325 }
326
327 for tr in &mut filtered.tool_results {
329 tr.content = mask_sensitive_content(&tr.content, &patterns);
330 }
331
332 filtered.metadata.touch();
333 filtered
334 }
335
336 pub fn merge(&self, contexts: Vec<&AgentContext>) -> AgentContext {
338 let mut merged = AgentContext::new();
339
340 for ctx in contexts {
341 merged
343 .conversation_history
344 .extend(ctx.conversation_history.clone());
345
346 for fc in &ctx.file_context {
348 if !merged.file_context.iter().any(|f| f.path == fc.path) {
349 merged.file_context.push(fc.clone());
350 }
351 }
352
353 merged.tool_results.extend(ctx.tool_results.clone());
355
356 merged.environment.extend(ctx.environment.clone());
358
359 if ctx.system_prompt.is_some() {
361 merged.system_prompt = ctx.system_prompt.clone();
362 }
363
364 if ctx.working_directory.as_os_str() != "." {
366 merged.working_directory = ctx.working_directory.clone();
367 }
368 }
369
370 merged.metadata.token_count = self.estimate_token_count(&merged);
372 merged.metadata.touch();
373
374 merged
375 }
376
377 pub fn get_context(&self, context_id: &str) -> Option<&AgentContext> {
379 self.contexts.get(context_id)
380 }
381
382 pub fn get_context_mut(&mut self, context_id: &str) -> Option<&mut AgentContext> {
384 self.contexts.get_mut(context_id)
385 }
386
387 pub fn update_context(
389 &mut self,
390 context_id: &str,
391 updates: ContextUpdate,
392 ) -> AgentContextResult<()> {
393 if !self.contexts.contains_key(context_id) {
395 return Err(AgentContextError::NotFound(context_id.to_string()));
396 }
397
398 {
400 let context = self.contexts.get_mut(context_id).unwrap();
401
402 if let Some(messages) = updates.add_messages {
403 context.conversation_history.extend(messages);
404 }
405
406 if let Some(files) = updates.add_files {
407 context.file_context.extend(files);
408 }
409
410 if let Some(results) = updates.add_tool_results {
411 context.tool_results.extend(results);
412 }
413
414 if let Some(env) = updates.set_environment {
415 context.environment.extend(env);
416 }
417
418 if let Some(prompt) = updates.set_system_prompt {
419 context.system_prompt = Some(prompt);
420 }
421
422 if let Some(dir) = updates.set_working_directory {
423 context.working_directory = dir;
424 }
425
426 if let Some(tags) = updates.add_tags {
427 for tag in tags {
428 context.metadata.add_tag(tag);
429 }
430 }
431
432 if let Some(custom) = updates.set_custom_metadata {
433 for (key, value) in custom {
434 context.metadata.set_custom(key, value);
435 }
436 }
437
438 context.metadata.touch();
439 }
440
441 let token_count = {
443 let ctx = self.contexts.get(context_id).unwrap();
444 self.estimate_token_count(ctx)
445 };
446
447 if let Some(ctx_mut) = self.contexts.get_mut(context_id) {
448 ctx_mut.metadata.token_count = token_count;
449 }
450
451 Ok(())
452 }
453
454 pub fn delete_context(&mut self, context_id: &str) -> bool {
456 self.contexts.remove(context_id).is_some()
457 }
458
459 pub async fn persist_context(&self, context: &AgentContext) -> AgentContextResult<()> {
461 fs::create_dir_all(&self.storage_dir).await?;
463
464 let file_path = self
465 .storage_dir
466 .join(format!("{}.json", context.context_id));
467
468 let json = serde_json::to_string_pretty(context)
469 .map_err(|e| AgentContextError::SerializationError(e.to_string()))?;
470
471 let mut file = fs::File::create(&file_path).await?;
472 file.write_all(json.as_bytes()).await?;
473 file.flush().await?;
474
475 Ok(())
476 }
477
478 pub async fn load_context(
480 &mut self,
481 context_id: &str,
482 ) -> AgentContextResult<Option<AgentContext>> {
483 let file_path = self.storage_dir.join(format!("{}.json", context_id));
484
485 if !file_path.exists() {
486 return Ok(None);
487 }
488
489 let json = fs::read_to_string(&file_path).await?;
490
491 let context: AgentContext = serde_json::from_str(&json)
492 .map_err(|e| AgentContextError::SerializationError(e.to_string()))?;
493
494 self.contexts
496 .insert(context_id.to_string(), context.clone());
497
498 Ok(Some(context))
499 }
500
501 pub fn estimate_token_count(&self, context: &AgentContext) -> usize {
506 let mut total_chars = 0;
507
508 for msg in &context.conversation_history {
510 for content in &msg.content {
511 total_chars += content.to_string().len();
512 }
513 }
514
515 if let Some(summary) = &context.conversation_summary {
517 total_chars += summary.len();
518 }
519
520 for fc in &context.file_context {
522 total_chars += fc.content.len();
523 }
524
525 for tr in &context.tool_results {
527 total_chars += tr.content.len();
528 }
529
530 if let Some(prompt) = &context.system_prompt {
532 total_chars += prompt.len();
533 }
534
535 total_chars / 4
537 }
538
539 pub fn update_token_count(&self, context: &mut AgentContext) {
541 context.metadata.token_count = self.estimate_token_count(context);
542 context.metadata.touch();
543 }
544
545 fn create_message_summary(&self, messages: &[Message]) -> String {
547 let mut summary = String::from("Previous conversation summary:\n");
548
549 for msg in messages {
550 let role = format!("{:?}", msg.role);
551 let content_preview: String = msg
552 .content
553 .iter()
554 .map(|c| c.to_string())
555 .collect::<Vec<_>>()
556 .join(" ");
557
558 let preview = if content_preview.chars().count() > 100 {
559 format!(
560 "{}...",
561 content_preview.chars().take(100).collect::<String>()
562 )
563 } else {
564 content_preview
565 };
566
567 summary.push_str(&format!("- {}: {}\n", role, preview));
568 }
569
570 summary
571 }
572
573 pub fn list_context_ids(&self) -> Vec<String> {
575 self.contexts.keys().cloned().collect()
576 }
577
578 pub fn storage_dir(&self) -> &PathBuf {
580 &self.storage_dir
581 }
582}
583
/// Returns `true` when `text` matches the wildcard `pattern` in full.
///
/// `*` matches any run of characters, including none and including path
/// separators. Every other character matches only itself, so characters such
/// as `+`, `(`, `[`, `?` or `$` in a pattern are literals.
///
/// The match is done directly on the bytes of both strings rather than by
/// building a regular expression, so no pattern can fail to compile and no
/// regex is constructed per call.
fn glob_match(pattern: &str, text: &str) -> bool {
    // Byte-wise comparison is sound for UTF-8: `*` is ASCII and literal runs
    // are whole characters, so matches always fall on character boundaries.
    let pat = pattern.as_bytes();
    let txt = text.as_bytes();

    let mut p = 0;
    let mut t = 0;
    // Pattern index of the most recent `*` and the text index up to which
    // that `*` is currently assumed to extend.
    let mut last_star: Option<(usize, usize)> = None;

    while t < txt.len() {
        if p < pat.len() && pat[p] == b'*' {
            // Start by letting the star match nothing.
            last_star = Some((p, t));
            p += 1;
        } else if p < pat.len() && pat[p] == txt[t] {
            p += 1;
            t += 1;
        } else if let Some((star_p, star_t)) = last_star {
            // Mismatch: let the last star absorb one more byte and retry.
            last_star = Some((star_p, star_t + 1));
            p = star_p + 1;
            t = star_t + 1;
        } else {
            return false;
        }
    }

    // Text is exhausted; only trailing stars may remain in the pattern.
    pat[p..].iter().all(|&b| b == b'*')
}
594
595fn mask_sensitive_content(content: &str, patterns: &[Regex]) -> String {
597 let mut result = content.to_string();
598
599 for pattern in patterns {
600 result = pattern.replace_all(&result, "[REDACTED]").to_string();
601 }
602
603 result
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::context::types::{ContextInheritanceType, FileContext, ToolExecutionResult};

    /// Builds a user message carrying `text`.
    fn user_message(text: &str) -> Message {
        Message::user().with_text(text)
    }

    /// A parentless context has an ID, no parent link, and no content.
    #[test]
    fn test_create_context_without_parent() {
        let mut mgr = AgentContextManager::new();
        let ctx = mgr.create_context(None, None);

        assert!(!ctx.context_id.is_empty());
        assert!(ctx.parent_context_id.is_none());
        assert!(ctx.is_empty());
    }

    /// With the default config a child links to its parent and copies
    /// history and environment.
    #[test]
    fn test_create_context_with_parent() {
        let mut mgr = AgentContextManager::new();

        let mut parent = AgentContext::new();
        parent.add_message(user_message("Hello"));
        parent.set_env("TEST_VAR", "test_value");

        let child = mgr.create_context(Some(&parent), Some(ContextInheritanceConfig::default()));

        assert!(child.parent_context_id.is_some());
        assert_eq!(child.parent_context_id.as_ref(), Some(&parent.context_id));
        assert_eq!(child.conversation_history.len(), 1);
        assert_eq!(child.get_env("TEST_VAR"), Some(&"test_value".to_string()));
    }

    /// The `none` config copies neither history nor environment.
    #[test]
    fn test_inherit_none() {
        let mgr = AgentContextManager::new();

        let mut parent = AgentContext::new();
        parent.add_message(user_message("Hello"));
        parent.set_env("TEST_VAR", "test_value");

        let child = mgr.inherit(&parent, &ContextInheritanceConfig::none());

        assert!(child.conversation_history.is_empty());
        assert!(child.environment.is_empty());
    }

    /// Selective inheritance honours the individual `inherit_*` flags.
    #[test]
    fn test_inherit_selective() {
        let mgr = AgentContextManager::new();

        let mut parent = AgentContext::new();
        parent.add_message(user_message("Hello"));
        parent.add_file_context(FileContext::new("/test.rs", "fn main() {}"));
        parent.set_env("TEST_VAR", "test_value");

        let selective = ContextInheritanceConfig {
            inherit_conversation: true,
            inherit_files: false,
            inherit_tool_results: false,
            inherit_environment: true,
            inheritance_type: ContextInheritanceType::Selective,
            ..Default::default()
        };

        let child = mgr.inherit(&parent, &selective);

        assert_eq!(child.conversation_history.len(), 1);
        assert!(child.file_context.is_empty());
        assert_eq!(child.get_env("TEST_VAR"), Some(&"test_value".to_string()));
    }

    /// `max_history_length` caps the number of inherited messages.
    #[test]
    fn test_inherit_with_max_history() {
        let mgr = AgentContextManager::new();

        let mut parent = AgentContext::new();
        (0..20).for_each(|i| parent.add_message(user_message(&format!("Message {}", i))));

        let capped = ContextInheritanceConfig {
            inherit_conversation: true,
            max_history_length: Some(5),
            inheritance_type: ContextInheritanceType::Selective,
            ..Default::default()
        };

        let child = mgr.inherit(&parent, &capped);

        assert_eq!(child.conversation_history.len(), 5);
    }

    /// A created context can be looked up by its ID.
    #[test]
    fn test_get_context() {
        let mut mgr = AgentContextManager::new();
        let id = mgr.create_context(None, None).context_id;

        let found = mgr.get_context(&id);
        assert!(found.is_some());
        assert_eq!(found.unwrap().context_id, id);
    }

    /// Updates append messages and merge environment entries.
    #[test]
    fn test_update_context() {
        let mut mgr = AgentContextManager::new();
        let id = mgr.create_context(None, None).context_id;

        let changes = ContextUpdate {
            add_messages: Some(vec![user_message("New message")]),
            set_environment: Some(HashMap::from([("KEY".to_string(), "value".to_string())])),
            ..Default::default()
        };

        mgr.update_context(&id, changes).unwrap();

        let stored = mgr.get_context(&id).unwrap();
        assert_eq!(stored.conversation_history.len(), 1);
        assert_eq!(stored.get_env("KEY"), Some(&"value".to_string()));
    }

    /// Deleting a context removes it from the registry.
    #[test]
    fn test_delete_context() {
        let mut mgr = AgentContextManager::new();
        let id = mgr.create_context(None, None).context_id;

        assert!(mgr.get_context(&id).is_some());
        assert!(mgr.delete_context(&id));
        assert!(mgr.get_context(&id).is_none());
    }

    /// The default filter drops `API_KEY` but keeps ordinary variables.
    #[test]
    fn test_filter_sensitive_env() {
        let mgr = AgentContextManager::new();

        let mut ctx = AgentContext::new();
        ctx.set_env("API_KEY", "secret123");
        ctx.set_env("NORMAL_VAR", "normal_value");

        let cleaned = mgr.filter(&ctx, &ContextFilter::with_defaults());

        assert!(cleaned.get_env("API_KEY").is_none());
        assert_eq!(
            cleaned.get_env("NORMAL_VAR"),
            Some(&"normal_value".to_string())
        );
    }

    /// The default filter redacts secrets inside file contents.
    #[test]
    fn test_filter_sensitive_content() {
        let mgr = AgentContextManager::new();

        let mut ctx = AgentContext::new();
        ctx.add_file_context(FileContext::new(
            "/config.rs",
            "let api_key = \"sk-12345\";",
        ));

        let cleaned = mgr.filter(&ctx, &ContextFilter::with_defaults());

        assert!(cleaned.file_context[0].content.contains("[REDACTED]"));
    }

    /// Merging concatenates histories and unions environments.
    #[test]
    fn test_merge_contexts() {
        let mgr = AgentContextManager::new();

        let mut first = AgentContext::new();
        first.add_message(user_message("Message 1"));
        first.set_env("VAR1", "value1");

        let mut second = AgentContext::new();
        second.add_message(user_message("Message 2"));
        second.set_env("VAR2", "value2");

        let combined = mgr.merge(vec![&first, &second]);

        assert_eq!(combined.conversation_history.len(), 2);
        assert_eq!(combined.get_env("VAR1"), Some(&"value1".to_string()));
        assert_eq!(combined.get_env("VAR2"), Some(&"value2".to_string()));
    }

    /// Files sharing a path appear only once after a merge.
    #[test]
    fn test_merge_deduplicates_files() {
        let mgr = AgentContextManager::new();

        let mut first = AgentContext::new();
        first.add_file_context(FileContext::new("/test.rs", "content1"));

        let mut second = AgentContext::new();
        second.add_file_context(FileContext::new("/test.rs", "content2"));
        second.add_file_context(FileContext::new("/other.rs", "other"));

        let combined = mgr.merge(vec![&first, &second]);

        assert_eq!(combined.file_context.len(), 2);
    }

    /// A context already under the target is left untouched.
    #[test]
    fn test_compress_already_small() {
        let mgr = AgentContextManager::new();

        let mut ctx = AgentContext::new();
        ctx.add_message(user_message("Small message"));

        let outcome = mgr.compress(&mut ctx, 10000).unwrap();

        assert_eq!(outcome.messages_summarized, 0);
        assert_eq!(outcome.files_removed, 0);
        assert_eq!(outcome.tool_results_removed, 0);
    }

    /// Oversized contexts lose their oldest tool results first.
    #[test]
    fn test_compress_removes_old_tool_results() {
        let mgr = AgentContextManager::new();

        let mut ctx = AgentContext::new();
        for call in 0..10 {
            ctx.add_tool_result(ToolExecutionResult::success(
                "bash",
                format!("call-{}", call),
                "x".repeat(1000),
                100,
            ));
        }

        let outcome = mgr.compress(&mut ctx, 100).unwrap();

        assert!(outcome.tool_results_removed > 0);
        assert!(ctx.tool_results.len() <= 5);
    }

    /// A short context yields a small, non-zero estimate.
    #[test]
    fn test_estimate_token_count() {
        let mgr = AgentContextManager::new();

        let mut ctx = AgentContext::new();
        ctx.add_message(user_message("Hello world"));
        ctx.system_prompt = Some("You are helpful".to_string());

        let estimate = mgr.estimate_token_count(&ctx);

        assert!(estimate > 0);
        assert!(estimate < 100);
    }

    /// Every created context is listed by ID.
    #[test]
    fn test_list_context_ids() {
        let mut mgr = AgentContextManager::new();

        let first = mgr.create_context(None, None);
        let second = mgr.create_context(None, None);

        let listed = mgr.list_context_ids();

        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&first.context_id));
        assert!(listed.contains(&second.context_id));
    }

    /// One hundred consecutive contexts all receive distinct IDs.
    #[test]
    fn test_unique_context_ids() {
        let mut mgr = AgentContextManager::new();
        let mut seen = std::collections::HashSet::new();

        for _ in 0..100 {
            let ctx = mgr.create_context(None, None);
            let is_new = seen.insert(ctx.context_id.clone());
            assert!(is_new, "Duplicate ID generated");
        }
    }

    /// A persisted context can be loaded back after the registry is cleared.
    #[tokio::test]
    async fn test_persist_and_load_context() {
        let scratch = tempfile::tempdir().unwrap();
        let mut mgr = AgentContextManager::with_storage_dir(scratch.path());

        let mut ctx = AgentContext::new();
        ctx.add_message(user_message("Test message"));
        ctx.set_env("TEST", "value");

        let id = ctx.context_id.clone();

        mgr.persist_context(&ctx).await.unwrap();

        // Make sure the data really comes from disk.
        mgr.contexts.clear();

        let restored = mgr.load_context(&id).await.unwrap();

        assert!(restored.is_some());
        let restored = restored.unwrap();
        assert_eq!(restored.context_id, id);
        assert_eq!(restored.conversation_history.len(), 1);
        assert_eq!(restored.get_env("TEST"), Some(&"value".to_string()));
    }

    /// Loading an ID with no file on disk yields `None` rather than an error.
    #[tokio::test]
    async fn test_load_nonexistent_context() {
        let scratch = tempfile::tempdir().unwrap();
        let mut mgr = AgentContextManager::with_storage_dir(scratch.path());

        let missing = mgr.load_context("nonexistent-id").await.unwrap();
        assert!(missing.is_none());
    }
}
929
930#[cfg(test)]
931mod property_tests {
932 use super::*;
933 use crate::agents::context::types::{ContextInheritanceType, FileContext, ToolExecutionResult};
934 use proptest::prelude::*;
935 use std::collections::HashSet;
936
937 fn arb_message() -> impl Strategy<Value = Message> {
940 prop::string::string_regex("[a-zA-Z0-9 ]{1,100}")
941 .unwrap()
942 .prop_map(|text| Message::user().with_text(text))
943 }
944
945 fn arb_file_context() -> impl Strategy<Value = FileContext> {
946 (
947 prop::string::string_regex("/[a-z]+/[a-z]+\\.[a-z]+").unwrap(),
948 prop::string::string_regex("[a-zA-Z0-9\\s]{1,500}").unwrap(),
949 )
950 .prop_map(|(path, content)| FileContext::new(path, content))
951 }
952
953 fn arb_tool_result() -> impl Strategy<Value = ToolExecutionResult> {
954 (
955 prop::string::string_regex("[a-z_]+").unwrap(),
956 prop::string::string_regex("[a-zA-Z0-9]{1,100}").unwrap(),
957 prop::bool::ANY,
958 )
959 .prop_map(|(tool_name, content, success)| {
960 if success {
961 ToolExecutionResult::success(
962 &tool_name,
963 uuid::Uuid::new_v4().to_string(),
964 content,
965 100,
966 )
967 } else {
968 ToolExecutionResult::failure(
969 &tool_name,
970 uuid::Uuid::new_v4().to_string(),
971 "error",
972 100,
973 )
974 }
975 })
976 }
977
978 fn arb_env_var() -> impl Strategy<Value = (String, String)> {
979 (
980 prop::string::string_regex("[A-Z_]{1,20}").unwrap(),
981 prop::string::string_regex("[a-zA-Z0-9]{1,50}").unwrap(),
982 )
983 }
984
985 fn arb_agent_context() -> impl Strategy<Value = AgentContext> {
986 (
987 prop::collection::vec(arb_message(), 0..10),
988 prop::collection::vec(arb_file_context(), 0..5),
989 prop::collection::vec(arb_tool_result(), 0..5),
990 prop::collection::vec(arb_env_var(), 0..5),
991 )
992 .prop_map(|(messages, files, tool_results, env_vars)| {
993 let mut ctx = AgentContext::new();
994 for msg in messages {
995 ctx.add_message(msg);
996 }
997 for file in files {
998 ctx.add_file_context(file);
999 }
1000 for result in tool_results {
1001 ctx.add_tool_result(result);
1002 }
1003 for (key, value) in env_vars {
1004 ctx.set_env(key, value);
1005 }
1006 ctx
1007 })
1008 }
1009
1010 fn arb_inheritance_config() -> impl Strategy<Value = ContextInheritanceConfig> {
1011 (
1012 prop::bool::ANY,
1013 prop::bool::ANY,
1014 prop::bool::ANY,
1015 prop::bool::ANY,
1016 prop::option::of(1usize..20),
1017 prop::option::of(1usize..10),
1018 prop::option::of(1usize..10),
1019 prop::sample::select(vec![
1020 ContextInheritanceType::Full,
1021 ContextInheritanceType::Shallow,
1022 ContextInheritanceType::Selective,
1023 ContextInheritanceType::None,
1024 ]),
1025 )
1026 .prop_map(
1027 |(
1028 inherit_conversation,
1029 inherit_files,
1030 inherit_tool_results,
1031 inherit_environment,
1032 max_history_length,
1033 max_file_contexts,
1034 max_tool_results,
1035 inheritance_type,
1036 )| {
1037 ContextInheritanceConfig {
1038 inherit_conversation,
1039 inherit_files,
1040 inherit_tool_results,
1041 inherit_environment,
1042 max_history_length,
1043 max_file_contexts,
1044 max_tool_results,
1045 filter_sensitive: false, compress_context: false,
1047 target_tokens: None,
1048 inheritance_type,
1049 }
1050 },
1051 )
1052 }
1053
1054 proptest! {
1055 #![proptest_config(ProptestConfig::with_cases(100))]
1056
1057 #[test]
1064 fn prop_context_unique_id_generation(count in 1usize..200) {
1065 let mut manager = AgentContextManager::new();
1066 let mut ids = HashSet::new();
1067
1068 for _ in 0..count {
1069 let context = manager.create_context(None, None);
1070 prop_assert!(
1072 ids.insert(context.context_id.clone()),
1073 "Duplicate context ID generated: {}",
1074 context.context_id
1075 );
1076 prop_assert!(!context.context_id.is_empty(), "Empty context ID generated");
1078 }
1079
1080 prop_assert_eq!(manager.list_context_ids().len(), count);
1082 }
1083
1084 #[test]
1092 fn prop_context_inheritance_consistency(
1093 parent in arb_agent_context(),
1094 config in arb_inheritance_config()
1095 ) {
1096 let manager = AgentContextManager::new();
1097 let child = manager.inherit(&parent, &config);
1098
1099 prop_assert_eq!(child.parent_context_id.as_ref(), Some(&parent.context_id));
1101
1102 match config.inheritance_type {
1103 ContextInheritanceType::None => {
1104 prop_assert!(child.conversation_history.is_empty());
1106 prop_assert!(child.file_context.is_empty());
1107 prop_assert!(child.tool_results.is_empty());
1108 prop_assert!(child.environment.is_empty());
1109 }
1110 ContextInheritanceType::Full => {
1111 let expected_history_len = match config.max_history_length {
1113 Some(max) => parent.conversation_history.len().min(max),
1114 None => parent.conversation_history.len(),
1115 };
1116 prop_assert_eq!(child.conversation_history.len(), expected_history_len);
1117
1118 let expected_files_len = match config.max_file_contexts {
1119 Some(max) => parent.file_context.len().min(max),
1120 None => parent.file_context.len(),
1121 };
1122 prop_assert_eq!(child.file_context.len(), expected_files_len);
1123
1124 let expected_results_len = match config.max_tool_results {
1125 Some(max) => parent.tool_results.len().min(max),
1126 None => parent.tool_results.len(),
1127 };
1128 prop_assert_eq!(child.tool_results.len(), expected_results_len);
1129
1130 prop_assert_eq!(child.environment.len(), parent.environment.len());
1132 }
1133 ContextInheritanceType::Shallow | ContextInheritanceType::Selective => {
1134 if config.inherit_conversation {
1136 let expected_len = match config.max_history_length {
1137 Some(max) => parent.conversation_history.len().min(max),
1138 None => parent.conversation_history.len(),
1139 };
1140 prop_assert_eq!(child.conversation_history.len(), expected_len);
1141 } else {
1142 prop_assert!(child.conversation_history.is_empty());
1143 }
1144
1145 if config.inherit_files {
1146 let expected_len = match config.max_file_contexts {
1147 Some(max) => parent.file_context.len().min(max),
1148 None => parent.file_context.len(),
1149 };
1150 prop_assert_eq!(child.file_context.len(), expected_len);
1151 } else {
1152 prop_assert!(child.file_context.is_empty());
1153 }
1154
1155 if config.inherit_tool_results {
1156 let expected_len = match config.max_tool_results {
1157 Some(max) => parent.tool_results.len().min(max),
1158 None => parent.tool_results.len(),
1159 };
1160 prop_assert_eq!(child.tool_results.len(), expected_len);
1161 } else {
1162 prop_assert!(child.tool_results.is_empty());
1163 }
1164
1165 if config.inherit_environment {
1166 prop_assert_eq!(child.environment.len(), parent.environment.len());
1167 } else {
1168 prop_assert!(child.environment.is_empty());
1169 }
1170 }
1171 }
1172 }
1173
1174 #[test]
1181 fn prop_context_compression_effectiveness(
1182 messages in prop::collection::vec(arb_message(), 15..30),
1183 files in prop::collection::vec(arb_file_context(), 5..10),
1184 tool_results in prop::collection::vec(arb_tool_result(), 8..15),
1185 target_tokens in 500usize..2000 ) {
1187 let manager = AgentContextManager::new();
1188
1189 let mut context = AgentContext::new();
1190 for msg in messages.clone() {
1191 context.add_message(msg);
1192 }
1193 for file in files.clone() {
1194 context.add_file_context(file);
1195 }
1196 for result in tool_results.clone() {
1197 context.add_tool_result(result);
1198 }
1199
1200 let original_tokens = manager.estimate_token_count(&context);
1201 let original_message_count = context.conversation_history.len();
1202 let original_file_count = context.file_context.len();
1203 let original_tool_count = context.tool_results.len();
1204
1205 if original_tokens > target_tokens {
1207 let result = manager.compress(&mut context, target_tokens).unwrap();
1208
1209 if result.tool_results_removed > 0 {
1215 prop_assert!(
1216 context.tool_results.len() <= 5,
1217 "Tool results should be limited to 5 after compression removed some"
1218 );
1219 }
1220
1221 if result.files_removed > 0 {
1223 prop_assert!(
1224 context.file_context.len() <= 3,
1225 "File contexts should be limited to 3 after compression removed some"
1226 );
1227 }
1228
1229 if result.messages_summarized > 0 {
1231 let remaining_count = context.conversation_history.len();
1233 prop_assert!(
1234 remaining_count <= original_message_count,
1235 "Message count should not increase after compression"
1236 );
1237
1238 prop_assert!(context.metadata.is_compressed);
1240 }
1241
1242 let something_removed = result.tool_results_removed > 0
1244 || result.files_removed > 0
1245 || result.messages_summarized > 0;
1246
1247 if original_tool_count > 5 || original_file_count > 3 || original_message_count > 10 {
1250 prop_assert!(
1251 something_removed,
1252 "Compression should remove content when over limits"
1253 );
1254 }
1255
1256 prop_assert!(
1258 result.ratio > 0.0,
1259 "Compression ratio should be positive"
1260 );
1261 }
1262 }
1263
1264 #[test]
1271 fn prop_sensitive_data_filtering(
1272 normal_env_vars in prop::collection::vec(
1273 (
1274 prop::string::string_regex("[A-Z]{3,10}_VAR").unwrap(),
1275 prop::string::string_regex("[a-z0-9]{5,20}").unwrap()
1276 ),
1277 1..5
1278 ),
1279 sensitive_env_keys in prop::sample::subsequence(
1280 vec!["API_KEY", "SECRET", "PASSWORD", "TOKEN", "PRIVATE_KEY"],
1281 1..4
1282 ),
1283 file_with_sensitive in prop::bool::ANY
1284 ) {
1285 let manager = AgentContextManager::new();
1286 let filter = ContextFilter::with_defaults();
1287
1288 let mut context = AgentContext::new();
1289
1290 for (key, value) in &normal_env_vars {
1292 context.set_env(key, value);
1293 }
1294
1295 for key in &sensitive_env_keys {
1297 context.set_env(*key, "sensitive_value_12345");
1298 }
1299
1300 if file_with_sensitive {
1302 context.add_file_context(FileContext::new(
1303 "/config.rs",
1304 "let api_key = \"sk-secret123\"; let password = \"hunter2\";",
1305 ));
1306 }
1307
1308 let filtered = manager.filter(&context, &filter);
1309
1310 for key in &sensitive_env_keys {
1312 prop_assert!(
1313 filtered.get_env(key).is_none(),
1314 "Sensitive env var {} should be filtered",
1315 key
1316 );
1317 }
1318
1319 for (key, value) in &normal_env_vars {
1321 let key_upper = key.to_uppercase();
1323 if !key_upper.contains("API") && !key_upper.contains("SECRET")
1324 && !key_upper.contains("PASSWORD") && !key_upper.contains("TOKEN")
1325 && !key_upper.contains("KEY")
1326 {
1327 prop_assert_eq!(
1328 filtered.get_env(key),
1329 Some(value),
1330 "Normal env var {} should be preserved",
1331 key
1332 );
1333 }
1334 }
1335
1336 if file_with_sensitive && !filtered.file_context.is_empty() {
1338 let content = &filtered.file_context[0].content;
1339 prop_assert!(
1340 content.contains("[REDACTED]") || !content.contains("api_key"),
1341 "Sensitive content in files should be redacted"
1342 );
1343 }
1344 }
1345
1346 #[test]
1353 fn prop_context_merge_completeness(
1354 contexts in prop::collection::vec(arb_agent_context(), 2..5)
1355 ) {
1356 let manager = AgentContextManager::new();
1357
1358 let total_messages: usize = contexts.iter()
1360 .map(|c| c.conversation_history.len())
1361 .sum();
1362
1363 let total_tool_results: usize = contexts.iter()
1364 .map(|c| c.tool_results.len())
1365 .sum();
1366
1367 let mut unique_file_paths = HashSet::new();
1369 for ctx in &contexts {
1370 for fc in &ctx.file_context {
1371 unique_file_paths.insert(fc.path.clone());
1372 }
1373 }
1374
1375 let mut all_env_keys = HashSet::new();
1377 for ctx in &contexts {
1378 for key in ctx.environment.keys() {
1379 all_env_keys.insert(key.clone());
1380 }
1381 }
1382
1383 let context_refs: Vec<&AgentContext> = contexts.iter().collect();
1384 let merged = manager.merge(context_refs);
1385
1386 prop_assert_eq!(
1388 merged.conversation_history.len(),
1389 total_messages,
1390 "All messages should be merged"
1391 );
1392
1393 prop_assert_eq!(
1395 merged.tool_results.len(),
1396 total_tool_results,
1397 "All tool results should be merged"
1398 );
1399
1400 prop_assert_eq!(
1402 merged.file_context.len(),
1403 unique_file_paths.len(),
1404 "Files should be deduplicated by path"
1405 );
1406
1407 for key in &all_env_keys {
1409 prop_assert!(
1410 merged.environment.contains_key(key),
1411 "Environment key {} should be present in merged context",
1412 key
1413 );
1414 }
1415
1416 let total_content_len: usize = merged.conversation_history.iter()
1419 .flat_map(|m| m.content.iter())
1420 .map(|c| c.to_string().len())
1421 .sum::<usize>()
1422 + merged.file_context.iter().map(|f| f.content.len()).sum::<usize>()
1423 + merged.tool_results.iter().map(|t| t.content.len()).sum::<usize>();
1424 if total_content_len >= 4 {
1425 prop_assert!(
1426 merged.metadata.token_count > 0,
1427 "Token count should be > 0 when content is substantial"
1428 );
1429 }
1430 }
1431 }
1432
1433 mod async_property_tests {
1435 use super::*;
1436 use tokio::runtime::Runtime;
1437
/// Property: a persisted context can be loaded back with its identity and sizes intact.
///
/// Runs 50 cases. Each case persists a generated context into its own temporary
/// storage directory, loads it again by id and compares the identifiers, the sizes
/// of the history / file / tool-result / environment collections, the system prompt
/// and the working directory with the original.
#[test]
fn prop_context_persistence_round_trip() {
    // The property body is synchronous, so the async manager calls are driven
    // through a runtime shared by all cases.
    let runtime = Runtime::new().unwrap();

    proptest!(ProptestConfig::with_cases(50), |(context in arb_agent_context())| {
        let outcome = runtime.block_on(async {
            // Each case gets its own temporary storage directory.
            let storage = tempfile::tempdir().unwrap();
            let mut manager = AgentContextManager::with_storage_dir(storage.path());

            manager.persist_context(&context).await.unwrap();
            let reloaded = manager.load_context(&context.context_id).await.unwrap();

            prop_assert!(reloaded.is_some(), "Context should be loadable after persistence");
            let reloaded = reloaded.unwrap();

            // Identity fields.
            prop_assert_eq!(reloaded.context_id, context.context_id);
            prop_assert_eq!(reloaded.agent_id, context.agent_id);
            prop_assert_eq!(reloaded.parent_context_id, context.parent_context_id);

            // Collections are compared by size only.
            prop_assert_eq!(
                reloaded.conversation_history.len(),
                context.conversation_history.len()
            );
            prop_assert_eq!(reloaded.file_context.len(), context.file_context.len());
            prop_assert_eq!(reloaded.tool_results.len(), context.tool_results.len());
            prop_assert_eq!(reloaded.environment.len(), context.environment.len());

            // Scalar settings.
            prop_assert_eq!(reloaded.system_prompt, context.system_prompt);
            prop_assert_eq!(reloaded.working_directory, context.working_directory);

            Ok(())
        });
        outcome?;
    });
}
1479
/// Property: the manager's token estimate tracks a quarter of the context's text length.
///
/// Runs 100 cases. The reference value is the combined byte length of all message
/// content, the conversation summary, file contents, tool-result contents and the
/// system prompt, divided by 4. When that reference exceeds 10, the estimate must be
/// within 20% of it (with a minimum allowance of 10); otherwise the estimate must not
/// exceed the reference by more than 10.
#[test]
fn prop_token_count_accuracy() {
    proptest!(ProptestConfig::with_cases(100), |(context in arb_agent_context())| {
        let manager = AgentContextManager::new();
        let estimated_tokens = manager.estimate_token_count(&context);

        // Byte length of each text source that feeds the reference value.
        let message_len: usize = context
            .conversation_history
            .iter()
            .flat_map(|msg| msg.content.iter())
            .map(|part| part.to_string().len())
            .sum();
        let summary_len = context
            .conversation_summary
            .as_ref()
            .map_or(0, |summary| summary.len());
        let file_len: usize = context
            .file_context
            .iter()
            .map(|file| file.content.len())
            .sum();
        let tool_len: usize = context
            .tool_results
            .iter()
            .map(|tool| tool.content.len())
            .sum();
        let prompt_len = context
            .system_prompt
            .as_ref()
            .map_or(0, |prompt| prompt.len());

        let expected_tokens =
            (message_len + summary_len + file_len + tool_len + prompt_len) / 4;

        if expected_tokens <= 10 {
            // For tiny contexts only an upper bound is checked.
            prop_assert!(estimated_tokens <= expected_tokens + 10);
        } else {
            let deviation = (estimated_tokens as i64 - expected_tokens as i64).abs();
            let allowed = (expected_tokens as f64 * 0.2).max(10.0) as i64;
            prop_assert!(
                deviation <= allowed,
                "Token count {} should be within 20% of expected {} (diff: {})",
                estimated_tokens,
                expected_tokens,
                deviation
            );
        }
    });
}
1539 }
1540}