Skip to main content

a3s_code_core/context/
mod.rs

//! Context Provider Extension Point
//!
//! This module provides the extension point for integrating context databases
//! like OpenViking into the agent loop. Context providers can supply memory,
//! resources, and skills to augment the LLM's context.
//!
//! ## Usage
//!
//! Implement the `ContextProvider` trait and register it with a session:
//!
//! ```ignore
//! use a3s_code::context::{ContextProvider, ContextQuery, ContextResult};
//!
//! struct MyProvider { /* ... */ }
//!
//! #[async_trait::async_trait]
//! impl ContextProvider for MyProvider {
//!     fn name(&self) -> &str { "my-provider" }
//!
//!     async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult> {
//!         // Retrieve relevant context...
//!     }
//! }
//! ```
25
26pub mod assembler;
27pub mod fs_provider;
28pub mod ripgrep_provider;
29pub mod static_provider;
30
31pub use assembler::{
32    ContextAssembler, ContextAssembly, ContextAssemblyPolicy, ContextBudget, ContextSourcePolicy,
33};
34pub use fs_provider::{FileSystemContextConfig, FileSystemContextProvider};
35pub use ripgrep_provider::{RipgrepContextConfig, RipgrepContextProvider};
36pub use static_provider::StaticContextProvider;
37
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40
/// Metadata key under which a [`ContextItem`] stores its provenance label.
pub const CONTEXT_PROVENANCE_METADATA: &str = "a3s.context.provenance";
/// Metadata key under which a [`ContextItem`] stores its priority score.
pub const CONTEXT_PRIORITY_METADATA: &str = "a3s.context.priority";
/// Metadata key under which a [`ContextItem`] stores its trust score.
pub const CONTEXT_TRUST_METADATA: &str = "a3s.context.trust";
/// Metadata key under which a [`ContextItem`] stores its freshness score.
pub const CONTEXT_FRESHNESS_METADATA: &str = "a3s.context.freshness";
45
46/// Type of context being queried
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
48pub enum ContextType {
49    /// Session/user history, extracted insights
50    Memory,
51    /// Documentation, code, knowledge base
52    #[default]
53    Resource,
54    /// Agent capabilities, behavior instructions
55    Skill,
56}
57
58/// Retrieval depth for tiered context (L0/L1/L2 pattern)
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
60pub enum ContextDepth {
61    /// ~100 tokens - high-level summary
62    Abstract,
63    /// ~2k tokens - key details (default)
64    #[default]
65    Overview,
66    /// Variable - complete content
67    Full,
68}
69
70/// Query to a context provider
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ContextQuery {
73    /// The query string to search for relevant context
74    pub query: String,
75
76    /// Types of context to retrieve
77    #[serde(default)]
78    pub context_types: Vec<ContextType>,
79
80    /// Desired retrieval depth
81    #[serde(default)]
82    pub depth: ContextDepth,
83
84    /// Maximum number of results to return
85    #[serde(default = "default_max_results")]
86    pub max_results: usize,
87
88    /// Maximum total tokens across all results
89    #[serde(default = "default_max_tokens")]
90    pub max_tokens: usize,
91
92    /// Optional session ID for session-specific context
93    #[serde(default)]
94    pub session_id: Option<String>,
95
96    /// Additional provider-specific parameters
97    #[serde(default)]
98    pub params: HashMap<String, serde_json::Value>,
99}
100
/// Default value for `ContextQuery::max_results`, used both by serde and by
/// `ContextQuery::new`.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
104
/// Default value for `ContextQuery::max_tokens`, used both by serde and by
/// `ContextQuery::new`.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 4000;
    DEFAULT_MAX_TOKENS
}
108
109impl ContextQuery {
110    /// Create a new context query with defaults
111    pub fn new(query: impl Into<String>) -> Self {
112        Self {
113            query: query.into(),
114            context_types: vec![ContextType::Resource],
115            depth: ContextDepth::default(),
116            max_results: default_max_results(),
117            max_tokens: default_max_tokens(),
118            session_id: None,
119            params: HashMap::new(),
120        }
121    }
122
123    /// Set the context types to retrieve
124    pub fn with_types(mut self, types: impl IntoIterator<Item = ContextType>) -> Self {
125        self.context_types = types.into_iter().collect();
126        self
127    }
128
129    /// Set the retrieval depth
130    pub fn with_depth(mut self, depth: ContextDepth) -> Self {
131        self.depth = depth;
132        self
133    }
134
135    /// Set the maximum number of results
136    pub fn with_max_results(mut self, max: usize) -> Self {
137        self.max_results = max;
138        self
139    }
140
141    /// Set the maximum total tokens
142    pub fn with_max_tokens(mut self, max: usize) -> Self {
143        self.max_tokens = max;
144        self
145    }
146
147    /// Set the session ID
148    pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
149        self.session_id = Some(id.into());
150        self
151    }
152
153    /// Add a custom parameter
154    pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
155        self.params.insert(key.into(), value);
156        self
157    }
158}
159
160/// A single piece of retrieved context
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ContextItem {
163    /// Unique identifier for this context item
164    pub id: String,
165
166    /// Type of context
167    pub context_type: ContextType,
168
169    /// The actual content
170    pub content: String,
171
172    /// Estimated token count (informational)
173    #[serde(default)]
174    pub token_count: usize,
175
176    /// Relevance score (0.0 to 1.0)
177    #[serde(default)]
178    pub relevance: f32,
179
180    /// Optional source URI (e.g., "viking://docs/auth")
181    #[serde(default)]
182    pub source: Option<String>,
183
184    /// Additional metadata
185    #[serde(default)]
186    pub metadata: HashMap<String, serde_json::Value>,
187}
188
189impl ContextItem {
190    /// Create a new context item
191    pub fn new(
192        id: impl Into<String>,
193        context_type: ContextType,
194        content: impl Into<String>,
195    ) -> Self {
196        Self {
197            id: id.into(),
198            context_type,
199            content: content.into(),
200            token_count: 0,
201            relevance: 0.0,
202            source: None,
203            metadata: HashMap::new(),
204        }
205    }
206
207    /// Set the token count
208    pub fn with_token_count(mut self, count: usize) -> Self {
209        self.token_count = count;
210        self
211    }
212
213    /// Set the relevance score
214    pub fn with_relevance(mut self, score: f32) -> Self {
215        self.relevance = score.clamp(0.0, 1.0);
216        self
217    }
218
219    /// Set the source URI
220    pub fn with_source(mut self, source: impl Into<String>) -> Self {
221        self.source = Some(source.into());
222        self
223    }
224
225    /// Add metadata
226    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
227        self.metadata.insert(key.into(), value);
228        self
229    }
230
231    /// Set a human-readable provenance label for diagnostics and ranking.
232    pub fn with_provenance(mut self, provenance: impl Into<String>) -> Self {
233        self.metadata.insert(
234            CONTEXT_PROVENANCE_METADATA.to_string(),
235            serde_json::Value::String(provenance.into()),
236        );
237        self
238    }
239
240    /// Set priority score (0.0 to 1.0) for harness-controlled ranking.
241    pub fn with_priority(mut self, priority: f32) -> Self {
242        self.metadata.insert(
243            CONTEXT_PRIORITY_METADATA.to_string(),
244            serde_json::json!(priority.clamp(0.0, 1.0)),
245        );
246        self
247    }
248
249    /// Set trust score (0.0 to 1.0) for harness-controlled ranking.
250    pub fn with_trust(mut self, trust: f32) -> Self {
251        self.metadata.insert(
252            CONTEXT_TRUST_METADATA.to_string(),
253            serde_json::json!(trust.clamp(0.0, 1.0)),
254        );
255        self
256    }
257
258    /// Set freshness score (0.0 to 1.0) for harness-controlled ranking.
259    pub fn with_freshness(mut self, freshness: f32) -> Self {
260        self.metadata.insert(
261            CONTEXT_FRESHNESS_METADATA.to_string(),
262            serde_json::json!(freshness.clamp(0.0, 1.0)),
263        );
264        self
265    }
266
267    pub fn provenance(&self) -> Option<&str> {
268        self.metadata
269            .get(CONTEXT_PROVENANCE_METADATA)
270            .and_then(serde_json::Value::as_str)
271    }
272
273    pub fn priority(&self) -> f32 {
274        metadata_score(self.metadata.get(CONTEXT_PRIORITY_METADATA))
275    }
276
277    pub fn trust(&self) -> f32 {
278        metadata_score(self.metadata.get(CONTEXT_TRUST_METADATA))
279    }
280
281    pub fn freshness(&self) -> f32 {
282        metadata_score(self.metadata.get(CONTEXT_FRESHNESS_METADATA))
283    }
284
285    /// Format as XML tag for system prompt injection
286    pub fn to_xml(&self) -> String {
287        let source_attr = self
288            .source
289            .as_ref()
290            .map(|s| format!(" source=\"{}\"", s))
291            .unwrap_or_default();
292        let type_str = match self.context_type {
293            ContextType::Memory => "Memory",
294            ContextType::Resource => "Resource",
295            ContextType::Skill => "Skill",
296        };
297        format!(
298            "<context{} type=\"{}\">\n{}\n</context>",
299            source_attr, type_str, self.content
300        )
301    }
302}
303
304fn metadata_score(value: Option<&serde_json::Value>) -> f32 {
305    value
306        .and_then(serde_json::Value::as_f64)
307        .map(|score| (score as f32).clamp(0.0, 1.0))
308        .unwrap_or(0.0)
309}
310
311/// Result from a context provider query
312#[derive(Debug, Clone, Default, Serialize, Deserialize)]
313pub struct ContextResult {
314    /// Retrieved context items
315    pub items: Vec<ContextItem>,
316
317    /// Total tokens across all items
318    pub total_tokens: usize,
319
320    /// Name of the provider that returned these results
321    pub provider: String,
322
323    /// Whether results were truncated due to limits
324    pub truncated: bool,
325}
326
327impl ContextResult {
328    /// Create a new empty result
329    pub fn new(provider: impl Into<String>) -> Self {
330        Self {
331            items: Vec::new(),
332            total_tokens: 0,
333            provider: provider.into(),
334            truncated: false,
335        }
336    }
337
338    /// Add an item to the result
339    pub fn add_item(&mut self, item: ContextItem) {
340        self.total_tokens += item.token_count;
341        self.items.push(item);
342    }
343
344    /// Check if the result is empty
345    pub fn is_empty(&self) -> bool {
346        self.items.is_empty()
347    }
348
349    /// Format all items as XML for system prompt injection
350    pub fn to_xml(&self) -> String {
351        self.items
352            .iter()
353            .map(|item| item.to_xml())
354            .collect::<Vec<_>>()
355            .join("\n\n")
356    }
357}
358
359/// Context provider trait - implement this for OpenViking, RAG systems, etc.
360#[async_trait::async_trait]
361pub trait ContextProvider: Send + Sync {
362    /// Provider name (used for identification and logging)
363    fn name(&self) -> &str;
364
365    /// Query the provider for relevant context
366    async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult>;
367
368    /// Called after each turn for memory extraction (optional)
369    ///
370    /// Providers can implement this to extract and store memories from
371    /// the conversation.
372    async fn on_turn_complete(
373        &self,
374        _session_id: &str,
375        _prompt: &str,
376        _response: &str,
377    ) -> anyhow::Result<()> {
378        Ok(())
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    // ------------------------------------------------------------------------
    // ContextType
    // ------------------------------------------------------------------------

    /// The derived default variant is `Resource`.
    #[test]
    fn test_context_type_default() {
        assert_eq!(ContextType::default(), ContextType::Resource);
    }

    /// A variant serializes to its bare name and parses back.
    #[test]
    fn test_context_type_serialization() {
        let encoded = serde_json::to_string(&ContextType::Memory).unwrap();
        assert_eq!(encoded, "\"Memory\"");

        let decoded: ContextType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, ContextType::Memory);
    }

    /// Every variant survives a JSON round-trip.
    #[test]
    fn test_context_type_all_variants() {
        for variant in [ContextType::Memory, ContextType::Resource, ContextType::Skill] {
            let encoded = serde_json::to_string(&variant).unwrap();
            let decoded: ContextType = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, variant);
        }
    }

    // ------------------------------------------------------------------------
    // ContextDepth
    // ------------------------------------------------------------------------

    /// The derived default variant is `Overview`.
    #[test]
    fn test_context_depth_default() {
        assert_eq!(ContextDepth::default(), ContextDepth::Overview);
    }

    /// A variant serializes to its bare name and parses back.
    #[test]
    fn test_context_depth_serialization() {
        let encoded = serde_json::to_string(&ContextDepth::Full).unwrap();
        assert_eq!(encoded, "\"Full\"");

        let decoded: ContextDepth = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, ContextDepth::Full);
    }

    /// Every variant survives a JSON round-trip.
    #[test]
    fn test_context_depth_all_variants() {
        for variant in [ContextDepth::Abstract, ContextDepth::Overview, ContextDepth::Full] {
            let encoded = serde_json::to_string(&variant).unwrap();
            let decoded: ContextDepth = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, variant);
        }
    }

    // ------------------------------------------------------------------------
    // ContextQuery
    // ------------------------------------------------------------------------

    /// `new` fills in the documented defaults.
    #[test]
    fn test_context_query_new() {
        let fresh = ContextQuery::new("test query");

        assert_eq!(fresh.query, "test query");
        assert_eq!(fresh.context_types, vec![ContextType::Resource]);
        assert_eq!(fresh.depth, ContextDepth::Overview);
        assert_eq!(fresh.max_results, 10);
        assert_eq!(fresh.max_tokens, 4000);
        assert!(fresh.session_id.is_none());
        assert!(fresh.params.is_empty());
    }

    /// Each builder method overrides exactly its own field.
    #[test]
    fn test_context_query_builder() {
        let built = ContextQuery::new("test")
            .with_types([ContextType::Memory, ContextType::Skill])
            .with_depth(ContextDepth::Full)
            .with_max_results(5)
            .with_max_tokens(2000)
            .with_session_id("sess-123")
            .with_param("custom", serde_json::json!("value"));

        assert_eq!(built.context_types.len(), 2);
        for expected in [ContextType::Memory, ContextType::Skill] {
            assert!(built.context_types.contains(&expected));
        }
        assert_eq!(built.depth, ContextDepth::Full);
        assert_eq!(built.max_results, 5);
        assert_eq!(built.max_tokens, 2000);
        assert_eq!(built.session_id.as_deref(), Some("sess-123"));
        assert_eq!(
            built.params.get("custom"),
            Some(&serde_json::json!("value"))
        );
    }

    /// A query survives a JSON round-trip.
    #[test]
    fn test_context_query_serialization() {
        let original = ContextQuery::new("search term")
            .with_types([ContextType::Resource])
            .with_session_id("sess-456");

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ContextQuery = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.query, "search term");
        assert_eq!(decoded.session_id.as_deref(), Some("sess-456"));
    }

    /// Only `query` is required; the serde defaults cover the rest.
    #[test]
    fn test_context_query_deserialization_with_defaults() {
        let decoded: ContextQuery = serde_json::from_str(r#"{"query": "minimal query"}"#).unwrap();

        assert_eq!(decoded.query, "minimal query");
        // The serde default for the type list is an empty vec, unlike `new`.
        assert!(decoded.context_types.is_empty());
        assert_eq!(decoded.depth, ContextDepth::Overview);
        assert_eq!(decoded.max_results, 10);
        assert_eq!(decoded.max_tokens, 4000);
    }

    // ------------------------------------------------------------------------
    // ContextItem
    // ------------------------------------------------------------------------

    /// `new` leaves every optional field at its zero value.
    #[test]
    fn test_context_item_new() {
        let fresh = ContextItem::new("item-1", ContextType::Resource, "Some content");

        assert_eq!(fresh.id, "item-1");
        assert_eq!(fresh.context_type, ContextType::Resource);
        assert_eq!(fresh.content, "Some content");
        assert_eq!(fresh.token_count, 0);
        assert_eq!(fresh.relevance, 0.0);
        assert!(fresh.source.is_none());
        assert!(fresh.metadata.is_empty());
    }

    /// Builders set their fields; out-of-range scores are clamped.
    #[test]
    fn test_context_item_builder() {
        let built = ContextItem::new("item-2", ContextType::Memory, "Memory content")
            .with_token_count(150)
            .with_relevance(0.85)
            .with_source("viking://memory/session-123")
            .with_provenance("memory")
            .with_priority(0.7)
            .with_trust(1.2)
            .with_freshness(-1.0)
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(built.token_count, 150);
        assert!((built.relevance - 0.85).abs() < f32::EPSILON);
        assert_eq!(built.source.as_deref(), Some("viking://memory/session-123"));
        assert_eq!(built.provenance(), Some("memory"));
        assert!((built.priority() - 0.7).abs() < f32::EPSILON);
        assert!((built.trust() - 1.0).abs() < f32::EPSILON);
        assert!(built.freshness().abs() < f32::EPSILON);
        assert_eq!(built.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    /// Relevance is clamped into 0.0..=1.0 on both sides.
    #[test]
    fn test_context_item_relevance_clamping() {
        let too_high = ContextItem::new("id", ContextType::Resource, "").with_relevance(1.5);
        assert!((too_high.relevance - 1.0).abs() < f32::EPSILON);

        let too_low = ContextItem::new("id", ContextType::Resource, "").with_relevance(-0.5);
        assert!(too_low.relevance.abs() < f32::EPSILON);
    }

    /// Without a source, the tag carries only the type attribute.
    #[test]
    fn test_context_item_to_xml_without_source() {
        let rendered = ContextItem::new("id", ContextType::Resource, "Content here").to_xml();
        assert_eq!(
            rendered,
            "<context type=\"Resource\">\nContent here\n</context>"
        );
    }

    /// With a source, the source attribute precedes the type attribute.
    #[test]
    fn test_context_item_to_xml_with_source() {
        let rendered = ContextItem::new("id", ContextType::Memory, "Memory content")
            .with_source("viking://docs/auth")
            .to_xml();
        assert_eq!(
            rendered,
            "<context source=\"viking://docs/auth\" type=\"Memory\">\nMemory content\n</context>"
        );
    }

    /// Each context type renders under its own name.
    #[test]
    fn test_context_item_to_xml_all_types() {
        let cases = [
            (ContextType::Memory, "m", "type=\"Memory\""),
            (ContextType::Resource, "r", "type=\"Resource\""),
            (ContextType::Skill, "s", "type=\"Skill\""),
        ];
        for (kind, text, needle) in cases {
            let rendered = ContextItem::new(text, kind, text).to_xml();
            assert!(rendered.contains(needle));
        }
    }

    /// An item survives a JSON round-trip.
    #[test]
    fn test_context_item_serialization() {
        let original = ContextItem::new("item-3", ContextType::Skill, "Skill instructions")
            .with_token_count(200)
            .with_relevance(0.9)
            .with_source("viking://skills/code-review");

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ContextItem = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.id, "item-3");
        assert_eq!(decoded.context_type, ContextType::Skill);
        assert_eq!(decoded.content, "Skill instructions");
        assert_eq!(decoded.token_count, 200);
    }

    // ------------------------------------------------------------------------
    // ContextResult
    // ------------------------------------------------------------------------

    /// `new` yields an empty, non-truncated result for the named provider.
    #[test]
    fn test_context_result_new() {
        let fresh = ContextResult::new("test-provider");

        assert!(fresh.items.is_empty());
        assert_eq!(fresh.total_tokens, 0);
        assert_eq!(fresh.provider, "test-provider");
        assert!(!fresh.truncated);
    }

    /// Adding one item records it and its token count.
    #[test]
    fn test_context_result_add_item() {
        let mut outcome = ContextResult::new("provider");
        outcome.add_item(
            ContextItem::new("id", ContextType::Resource, "content").with_token_count(100),
        );

        assert_eq!(outcome.items.len(), 1);
        assert_eq!(outcome.total_tokens, 100);
    }

    /// Token counts accumulate across items.
    #[test]
    fn test_context_result_add_multiple_items() {
        let mut outcome = ContextResult::new("provider");
        let entries = [
            ("1", ContextType::Resource, "a", 50),
            ("2", ContextType::Memory, "b", 75),
            ("3", ContextType::Skill, "c", 25),
        ];
        for (id, kind, text, tokens) in entries {
            outcome.add_item(ContextItem::new(id, kind, text).with_token_count(tokens));
        }

        assert_eq!(outcome.items.len(), 3);
        assert_eq!(outcome.total_tokens, 150);
    }

    /// `is_empty` reflects whether any item was added.
    #[test]
    fn test_context_result_is_empty() {
        assert!(ContextResult::new("provider").is_empty());

        let mut populated = ContextResult::new("provider");
        populated.add_item(ContextItem::new("id", ContextType::Resource, "content"));
        assert!(!populated.is_empty());
    }

    /// Every item appears in the combined XML.
    #[test]
    fn test_context_result_to_xml() {
        let mut outcome = ContextResult::new("provider");
        outcome.add_item(
            ContextItem::new("1", ContextType::Resource, "First content").with_source("source://1"),
        );
        outcome.add_item(ContextItem::new("2", ContextType::Memory, "Second content"));

        let rendered = outcome.to_xml();
        let expected_fragments = [
            "<context source=\"source://1\" type=\"Resource\">",
            "First content",
            "<context type=\"Memory\">",
            "Second content",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment));
        }
    }

    /// An empty result renders as an empty string.
    #[test]
    fn test_context_result_to_xml_empty() {
        assert!(ContextResult::new("provider").to_xml().is_empty());
    }

    /// A result survives a JSON round-trip.
    #[test]
    fn test_context_result_serialization() {
        let mut original = ContextResult::new("test-provider");
        original.truncated = true;
        original.add_item(ContextItem::new("id", ContextType::Resource, "content"));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ContextResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.provider, "test-provider");
        assert!(decoded.truncated);
        assert_eq!(decoded.items.len(), 1);
    }

    /// The derived default is fully empty.
    #[test]
    fn test_context_result_default() {
        let blank = ContextResult::default();

        assert!(blank.items.is_empty());
        assert_eq!(blank.total_tokens, 0);
        assert!(blank.provider.is_empty());
        assert!(!blank.truncated);
    }

    // ------------------------------------------------------------------------
    // ContextProvider trait (mock implementations)
    // ------------------------------------------------------------------------

    /// Provider that returns a fixed list of items for every query.
    struct MockContextProvider {
        name: String,
        items: Vec<ContextItem>,
    }

    impl MockContextProvider {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                items: Vec::new(),
            }
        }

        fn with_items(self, items: Vec<ContextItem>) -> Self {
            Self { items, ..self }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockContextProvider {
        fn name(&self) -> &str {
            &self.name
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            let mut outcome = ContextResult::new(&self.name);
            self.items
                .iter()
                .cloned()
                .for_each(|item| outcome.add_item(item));
            Ok(outcome)
        }
    }

    /// The mock reports its name and returns its canned items.
    #[tokio::test]
    async fn test_mock_context_provider() {
        let canned = vec![ContextItem::new("1", ContextType::Resource, "content")];
        let provider = MockContextProvider::new("mock").with_items(canned);

        assert_eq!(provider.name(), "mock");

        let outcome = provider.query(&ContextQuery::new("test")).await.unwrap();

        assert_eq!(outcome.provider, "mock");
        assert_eq!(outcome.items.len(), 1);
    }

    /// The trait's default `on_turn_complete` succeeds without doing anything.
    #[tokio::test]
    async fn test_context_provider_on_turn_complete_default() {
        let provider = MockContextProvider::new("mock");

        let outcome = provider
            .on_turn_complete("session-1", "prompt", "response")
            .await;
        assert!(outcome.is_ok());
    }

    /// Provider that records every completed turn it is told about.
    struct MockMemoryProvider {
        memories: Arc<RwLock<Vec<(String, String, String)>>>,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                memories: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockMemoryProvider {
        fn name(&self) -> &str {
            "memory-provider"
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            Ok(ContextResult::new("memory-provider"))
        }

        async fn on_turn_complete(
            &self,
            session_id: &str,
            prompt: &str,
            response: &str,
        ) -> anyhow::Result<()> {
            let turn = (
                session_id.to_string(),
                prompt.to_string(),
                response.to_string(),
            );
            self.memories.write().await.push(turn);
            Ok(())
        }
    }

    /// An overridden `on_turn_complete` receives the turn's data.
    #[tokio::test]
    async fn test_context_provider_on_turn_complete_custom() {
        let provider = MockMemoryProvider::new();

        provider
            .on_turn_complete("sess-1", "What is Rust?", "Rust is a systems language.")
            .await
            .unwrap();

        let recorded = provider.memories.read().await;
        assert_eq!(recorded.len(), 1);

        let (session, prompt, response) = &recorded[0];
        assert_eq!(session.as_str(), "sess-1");
        assert_eq!(prompt.as_str(), "What is Rust?");
        assert_eq!(response.as_str(), "Rust is a systems language.");
    }

    // ------------------------------------------------------------------------
    // Integration-style tests
    // ------------------------------------------------------------------------

    /// Items from several providers can be merged into one list.
    #[tokio::test]
    async fn test_multiple_providers_query() {
        let first = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
            "p1-1",
            ContextType::Resource,
            "Resource from P1",
        )]);

        let second = MockContextProvider::new("provider-2").with_items(vec![
            ContextItem::new("p2-1", ContextType::Memory, "Memory from P2"),
            ContextItem::new("p2-2", ContextType::Skill, "Skill from P2"),
        ]);

        let providers: Vec<&dyn ContextProvider> = vec![&first, &second];
        let query = ContextQuery::new("test");

        let mut collected = Vec::new();
        for provider in providers {
            let outcome = provider.query(&query).await.unwrap();
            collected.extend(outcome.items);
        }

        assert_eq!(collected.len(), 3);
        for expected_id in ["p1-1", "p2-1", "p2-2"] {
            assert!(collected.iter().any(|item| item.id == expected_id));
        }
    }

    /// Combined XML keeps each item's tag and separates items with a blank line.
    #[test]
    fn test_context_result_xml_formatting_complex() {
        let doc_item = ContextItem::new(
            "doc-1",
            ContextType::Resource,
            "Authentication uses JWT tokens stored in httpOnly cookies.",
        )
        .with_source("viking://docs/auth")
        .with_token_count(50);

        let memory_item = ContextItem::new(
            "mem-1",
            ContextType::Memory,
            "User prefers TypeScript over JavaScript.",
        )
        .with_token_count(30);

        let mut outcome = ContextResult::new("openviking");
        outcome.add_item(doc_item);
        outcome.add_item(memory_item);

        let rendered = outcome.to_xml();

        // Structure of each item.
        assert!(rendered.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
        assert!(rendered.contains("Authentication uses JWT tokens"));
        assert!(rendered.contains("<context type=\"Memory\">"));
        assert!(rendered.contains("User prefers TypeScript"));

        // Items are separated by a blank line.
        assert!(rendered.contains("</context>\n\n<context"));
    }
}