1pub mod assembler;
27pub mod fs_provider;
28pub mod ripgrep_provider;
29pub mod static_provider;
30
31pub use assembler::{
32 ContextAssembler, ContextAssembly, ContextAssemblyPolicy, ContextBudget, ContextSourcePolicy,
33};
34pub use fs_provider::{FileSystemContextConfig, FileSystemContextProvider};
35pub use ripgrep_provider::{RipgrepContextConfig, RipgrepContextProvider};
36pub use static_provider::StaticContextProvider;
37
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40
/// Metadata key under which [`ContextItem::with_provenance`] stores the item's
/// provenance string; read back by [`ContextItem::provenance`].
pub const CONTEXT_PROVENANCE_METADATA: &str = "a3s.context.provenance";
/// Metadata key for the priority score, written (clamped to `[0.0, 1.0]`) by
/// [`ContextItem::with_priority`] and read by [`ContextItem::priority`].
pub const CONTEXT_PRIORITY_METADATA: &str = "a3s.context.priority";
/// Metadata key for the trust score, written (clamped to `[0.0, 1.0]`) by
/// [`ContextItem::with_trust`] and read by [`ContextItem::trust`].
pub const CONTEXT_TRUST_METADATA: &str = "a3s.context.trust";
/// Metadata key for the freshness score, written (clamped to `[0.0, 1.0]`) by
/// [`ContextItem::with_freshness`] and read by [`ContextItem::freshness`].
pub const CONTEXT_FRESHNESS_METADATA: &str = "a3s.context.freshness";
45
/// Category of a piece of context produced by a [`ContextProvider`].
///
/// Serialized by serde as the bare variant name (e.g. `"Memory"`) and rendered
/// as the `type` attribute by [`ContextItem::to_xml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ContextType {
    /// Memory context (presumably recalled from earlier turns — confirm with providers).
    Memory,
    /// Resource context; this is the default category.
    #[default]
    Resource,
    /// Skill context.
    Skill,
}
57
/// Requested level of detail for returned context.
///
/// NOTE(review): the variant names suggest increasing detail
/// (`Abstract` < `Overview` < `Full`); how each provider honours them is not
/// visible here — confirm against the provider implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ContextDepth {
    /// Least detailed level (assumed).
    Abstract,
    /// The default depth.
    #[default]
    Overview,
    /// Most detailed level (assumed).
    Full,
}
69
/// A request for context sent to a [`ContextProvider`].
///
/// Build one with [`ContextQuery::new`] plus the `with_*` builders, or
/// deserialize it; missing JSON fields fall back to the serde defaults below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextQuery {
    /// Free-text query string.
    pub query: String,

    /// Context categories requested.
    ///
    /// Note the asymmetry: `ContextQuery::new` sets `[Resource]`, but a
    /// deserialized query without this field gets an empty list.
    /// NOTE(review): presumably an empty list means "no type filter" —
    /// confirm against the providers.
    #[serde(default)]
    pub context_types: Vec<ContextType>,

    /// Requested level of detail; defaults to [`ContextDepth::Overview`].
    #[serde(default)]
    pub depth: ContextDepth,

    /// Maximum number of items requested; defaults to 10.
    #[serde(default = "default_max_results")]
    pub max_results: usize,

    /// Token budget requested; defaults to 4000.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,

    /// Optional session identifier associated with the query.
    #[serde(default)]
    pub session_id: Option<String>,

    /// Free-form extra parameters; their interpretation is assumed to be
    /// provider-specific.
    #[serde(default)]
    pub params: HashMap<String, serde_json::Value>,
}
100
/// Default item limit for a [`ContextQuery`]: used by `ContextQuery::new` and
/// as the serde default of `max_results`.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}

/// Default token budget for a [`ContextQuery`]: used by `ContextQuery::new`
/// and as the serde default of `max_tokens`.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 4_000;
    DEFAULT_MAX_TOKENS
}
108
109impl ContextQuery {
110 pub fn new(query: impl Into<String>) -> Self {
112 Self {
113 query: query.into(),
114 context_types: vec![ContextType::Resource],
115 depth: ContextDepth::default(),
116 max_results: default_max_results(),
117 max_tokens: default_max_tokens(),
118 session_id: None,
119 params: HashMap::new(),
120 }
121 }
122
123 pub fn with_types(mut self, types: impl IntoIterator<Item = ContextType>) -> Self {
125 self.context_types = types.into_iter().collect();
126 self
127 }
128
129 pub fn with_depth(mut self, depth: ContextDepth) -> Self {
131 self.depth = depth;
132 self
133 }
134
135 pub fn with_max_results(mut self, max: usize) -> Self {
137 self.max_results = max;
138 self
139 }
140
141 pub fn with_max_tokens(mut self, max: usize) -> Self {
143 self.max_tokens = max;
144 self
145 }
146
147 pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
149 self.session_id = Some(id.into());
150 self
151 }
152
153 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
155 self.params.insert(key.into(), value);
156 self
157 }
158}
159
/// A single piece of context produced by a [`ContextProvider`].
///
/// Fields marked `#[serde(default)]` may be absent from JSON and then take
/// `0`, `0.0`, `None` or an empty map respectively.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextItem {
    /// Identifier of the item (uniqueness is not enforced by this type).
    pub id: String,

    /// Category of the item; rendered as the `type` attribute by `to_xml`.
    pub context_type: ContextType,

    /// Text content; emitted verbatim inside the `<context>` element by `to_xml`.
    pub content: String,

    /// Token count reported for `content` and summed by `ContextResult::add_item`.
    /// It is supplied by the producer, not computed from `content` here.
    #[serde(default)]
    pub token_count: usize,

    /// Relevance score. `with_relevance` clamps it into `[0.0, 1.0]`; direct
    /// field writes and deserialized values are not clamped.
    #[serde(default)]
    pub relevance: f32,

    /// Optional origin of the item (e.g. a URI); rendered as the `source`
    /// attribute by `to_xml` when present.
    #[serde(default)]
    pub source: Option<String>,

    /// Free-form metadata. Also holds the provenance/priority/trust/freshness
    /// values under the `a3s.context.*` keys defined in this module.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
188
189impl ContextItem {
190 pub fn new(
192 id: impl Into<String>,
193 context_type: ContextType,
194 content: impl Into<String>,
195 ) -> Self {
196 Self {
197 id: id.into(),
198 context_type,
199 content: content.into(),
200 token_count: 0,
201 relevance: 0.0,
202 source: None,
203 metadata: HashMap::new(),
204 }
205 }
206
207 pub fn with_token_count(mut self, count: usize) -> Self {
209 self.token_count = count;
210 self
211 }
212
213 pub fn with_relevance(mut self, score: f32) -> Self {
215 self.relevance = score.clamp(0.0, 1.0);
216 self
217 }
218
219 pub fn with_source(mut self, source: impl Into<String>) -> Self {
221 self.source = Some(source.into());
222 self
223 }
224
225 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
227 self.metadata.insert(key.into(), value);
228 self
229 }
230
231 pub fn with_provenance(mut self, provenance: impl Into<String>) -> Self {
233 self.metadata.insert(
234 CONTEXT_PROVENANCE_METADATA.to_string(),
235 serde_json::Value::String(provenance.into()),
236 );
237 self
238 }
239
240 pub fn with_priority(mut self, priority: f32) -> Self {
242 self.metadata.insert(
243 CONTEXT_PRIORITY_METADATA.to_string(),
244 serde_json::json!(priority.clamp(0.0, 1.0)),
245 );
246 self
247 }
248
249 pub fn with_trust(mut self, trust: f32) -> Self {
251 self.metadata.insert(
252 CONTEXT_TRUST_METADATA.to_string(),
253 serde_json::json!(trust.clamp(0.0, 1.0)),
254 );
255 self
256 }
257
258 pub fn with_freshness(mut self, freshness: f32) -> Self {
260 self.metadata.insert(
261 CONTEXT_FRESHNESS_METADATA.to_string(),
262 serde_json::json!(freshness.clamp(0.0, 1.0)),
263 );
264 self
265 }
266
267 pub fn provenance(&self) -> Option<&str> {
268 self.metadata
269 .get(CONTEXT_PROVENANCE_METADATA)
270 .and_then(serde_json::Value::as_str)
271 }
272
273 pub fn priority(&self) -> f32 {
274 metadata_score(self.metadata.get(CONTEXT_PRIORITY_METADATA))
275 }
276
277 pub fn trust(&self) -> f32 {
278 metadata_score(self.metadata.get(CONTEXT_TRUST_METADATA))
279 }
280
281 pub fn freshness(&self) -> f32 {
282 metadata_score(self.metadata.get(CONTEXT_FRESHNESS_METADATA))
283 }
284
285 pub fn to_xml(&self) -> String {
287 let source_attr = self
288 .source
289 .as_ref()
290 .map(|s| format!(" source=\"{}\"", s))
291 .unwrap_or_default();
292 let type_str = match self.context_type {
293 ContextType::Memory => "Memory",
294 ContextType::Resource => "Resource",
295 ContextType::Skill => "Skill",
296 };
297 format!(
298 "<context{} type=\"{}\">\n{}\n</context>",
299 source_attr, type_str, self.content
300 )
301 }
302}
303
304fn metadata_score(value: Option<&serde_json::Value>) -> f32 {
305 value
306 .and_then(serde_json::Value::as_f64)
307 .map(|score| (score as f32).clamp(0.0, 1.0))
308 .unwrap_or(0.0)
309}
310
311#[derive(Debug, Clone, Default, Serialize, Deserialize)]
313pub struct ContextResult {
314 pub items: Vec<ContextItem>,
316
317 pub total_tokens: usize,
319
320 pub provider: String,
322
323 pub truncated: bool,
325}
326
327impl ContextResult {
328 pub fn new(provider: impl Into<String>) -> Self {
330 Self {
331 items: Vec::new(),
332 total_tokens: 0,
333 provider: provider.into(),
334 truncated: false,
335 }
336 }
337
338 pub fn add_item(&mut self, item: ContextItem) {
340 self.total_tokens += item.token_count;
341 self.items.push(item);
342 }
343
344 pub fn is_empty(&self) -> bool {
346 self.items.is_empty()
347 }
348
349 pub fn to_xml(&self) -> String {
351 self.items
352 .iter()
353 .map(|item| item.to_xml())
354 .collect::<Vec<_>>()
355 .join("\n\n")
356 }
357}
358
/// A source of [`ContextItem`]s that can be queried with a [`ContextQuery`].
///
/// Implementors must be `Send + Sync`. The trait is declared through
/// `async_trait` so its async methods can be called via `&dyn ContextProvider`.
#[async_trait::async_trait]
pub trait ContextProvider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Runs `query` against this provider and returns the matching items.
    ///
    /// # Errors
    /// Implementation-defined; failures are reported as `anyhow::Error`.
    async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult>;

    /// Hook that receives the prompt and response of a completed turn.
    ///
    /// The default implementation ignores its arguments and returns `Ok(())`.
    /// Stateful providers (e.g. a memory provider) can override it to record
    /// the turn.
    /// NOTE(review): the call site is not visible here — presumably invoked once
    /// per completed turn; confirm.
    async fn on_turn_complete(
        &self,
        _session_id: &str,
        _prompt: &str,
        _response: &str,
    ) -> anyhow::Result<()> {
        Ok(())
    }
}
381
#[cfg(test)]
mod tests {
    //! Unit tests for the context types, builders, XML rendering and the
    //! `ContextProvider` trait's default behaviour.

    use super::*;

    #[test]
    fn test_context_type_default() {
        let ct: ContextType = Default::default();
        assert_eq!(ct, ContextType::Resource);
    }

    #[test]
    fn test_context_type_serialization() {
        let ct = ContextType::Memory;
        let json = serde_json::to_string(&ct).unwrap();
        assert_eq!(json, "\"Memory\"");

        let parsed: ContextType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextType::Memory);
    }

    #[test]
    fn test_context_type_all_variants() {
        let types = [
            ContextType::Memory,
            ContextType::Resource,
            ContextType::Skill,
        ];
        for ct in types {
            let json = serde_json::to_string(&ct).unwrap();
            let parsed: ContextType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, ct);
        }
    }

    #[test]
    fn test_context_depth_default() {
        let cd: ContextDepth = Default::default();
        assert_eq!(cd, ContextDepth::Overview);
    }

    #[test]
    fn test_context_depth_serialization() {
        let cd = ContextDepth::Full;
        let json = serde_json::to_string(&cd).unwrap();
        assert_eq!(json, "\"Full\"");

        let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextDepth::Full);
    }

    #[test]
    fn test_context_depth_all_variants() {
        let depths = [
            ContextDepth::Abstract,
            ContextDepth::Overview,
            ContextDepth::Full,
        ];
        for cd in depths {
            let json = serde_json::to_string(&cd).unwrap();
            let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, cd);
        }
    }

    #[test]
    fn test_context_query_new() {
        let query = ContextQuery::new("test query");
        assert_eq!(query.query, "test query");
        assert_eq!(query.context_types, vec![ContextType::Resource]);
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
        assert!(query.session_id.is_none());
        assert!(query.params.is_empty());
    }

    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new("test")
            .with_types([ContextType::Memory, ContextType::Skill])
            .with_depth(ContextDepth::Full)
            .with_max_results(5)
            .with_max_tokens(2000)
            .with_session_id("sess-123")
            .with_param("custom", serde_json::json!("value"));

        assert_eq!(query.context_types.len(), 2);
        assert!(query.context_types.contains(&ContextType::Memory));
        assert!(query.context_types.contains(&ContextType::Skill));
        assert_eq!(query.depth, ContextDepth::Full);
        assert_eq!(query.max_results, 5);
        assert_eq!(query.max_tokens, 2000);
        assert_eq!(query.session_id, Some("sess-123".to_string()));
        assert_eq!(
            query.params.get("custom"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_context_query_serialization() {
        let query = ContextQuery::new("search term")
            .with_types([ContextType::Resource])
            .with_session_id("sess-456");

        let json = serde_json::to_string(&query).unwrap();
        let parsed: ContextQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.query, "search term");
        assert_eq!(parsed.session_id, Some("sess-456".to_string()));
    }

    #[test]
    fn test_context_query_deserialization_with_defaults() {
        let json = r#"{"query": "minimal query"}"#;
        let query: ContextQuery = serde_json::from_str(json).unwrap();

        assert_eq!(query.query, "minimal query");
        // Unlike `ContextQuery::new`, a deserialized query defaults to no types.
        assert!(query.context_types.is_empty());
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
    }

    #[test]
    fn test_context_item_new() {
        let item = ContextItem::new("item-1", ContextType::Resource, "Some content");
        assert_eq!(item.id, "item-1");
        assert_eq!(item.context_type, ContextType::Resource);
        assert_eq!(item.content, "Some content");
        assert_eq!(item.token_count, 0);
        assert_eq!(item.relevance, 0.0);
        assert!(item.source.is_none());
        assert!(item.metadata.is_empty());
    }

    #[test]
    fn test_context_item_builder() {
        let item = ContextItem::new("item-2", ContextType::Memory, "Memory content")
            .with_token_count(150)
            .with_relevance(0.85)
            .with_source("viking://memory/session-123")
            .with_provenance("memory")
            .with_priority(0.7)
            .with_trust(1.2)
            .with_freshness(-1.0)
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(item.token_count, 150);
        assert!((item.relevance - 0.85).abs() < f32::EPSILON);
        assert_eq!(item.source, Some("viking://memory/session-123".to_string()));
        assert_eq!(item.provenance(), Some("memory"));
        assert!((item.priority() - 0.7).abs() < f32::EPSILON);
        assert!((item.trust() - 1.0).abs() < f32::EPSILON);
        assert!(item.freshness().abs() < f32::EPSILON);
        assert_eq!(item.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_context_item_metadata_scores_read_defensively() {
        // Raw metadata bypasses the builders' clamping, so the getters must
        // clamp on read and fall back to 0.0 for non-numeric or missing values.
        let item = ContextItem::new("id", ContextType::Resource, "")
            .with_metadata(CONTEXT_PRIORITY_METADATA, serde_json::json!("high"))
            .with_metadata(CONTEXT_TRUST_METADATA, serde_json::json!(5.0))
            .with_metadata(CONTEXT_FRESHNESS_METADATA, serde_json::json!(-3));

        assert_eq!(item.priority(), 0.0);
        assert!((item.trust() - 1.0).abs() < f32::EPSILON);
        assert_eq!(item.freshness(), 0.0);
        assert!(item.provenance().is_none());
    }

    #[test]
    fn test_context_item_relevance_clamping() {
        let item1 = ContextItem::new("id", ContextType::Resource, "").with_relevance(1.5);
        assert!((item1.relevance - 1.0).abs() < f32::EPSILON);

        let item2 = ContextItem::new("id", ContextType::Resource, "").with_relevance(-0.5);
        assert!(item2.relevance.abs() < f32::EPSILON);
    }

    #[test]
    fn test_context_item_to_xml_without_source() {
        let item = ContextItem::new("id", ContextType::Resource, "Content here");
        let xml = item.to_xml();
        assert_eq!(xml, "<context type=\"Resource\">\nContent here\n</context>");
    }

    #[test]
    fn test_context_item_to_xml_with_source() {
        let item = ContextItem::new("id", ContextType::Memory, "Memory content")
            .with_source("viking://docs/auth");
        let xml = item.to_xml();
        assert_eq!(
            xml,
            "<context source=\"viking://docs/auth\" type=\"Memory\">\nMemory content\n</context>"
        );
    }

    #[test]
    fn test_context_item_to_xml_all_types() {
        let memory = ContextItem::new("m", ContextType::Memory, "m").to_xml();
        assert!(memory.contains("type=\"Memory\""));

        let resource = ContextItem::new("r", ContextType::Resource, "r").to_xml();
        assert!(resource.contains("type=\"Resource\""));

        let skill = ContextItem::new("s", ContextType::Skill, "s").to_xml();
        assert!(skill.contains("type=\"Skill\""));
    }

    #[test]
    fn test_context_item_serialization() {
        let item = ContextItem::new("item-3", ContextType::Skill, "Skill instructions")
            .with_token_count(200)
            .with_relevance(0.9)
            .with_source("viking://skills/code-review");

        let json = serde_json::to_string(&item).unwrap();
        let parsed: ContextItem = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "item-3");
        assert_eq!(parsed.context_type, ContextType::Skill);
        assert_eq!(parsed.content, "Skill instructions");
        assert_eq!(parsed.token_count, 200);
    }

    #[test]
    fn test_context_result_new() {
        let result = ContextResult::new("test-provider");
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert_eq!(result.provider, "test-provider");
        assert!(!result.truncated);
    }

    #[test]
    fn test_context_result_add_item() {
        let mut result = ContextResult::new("provider");
        let item = ContextItem::new("id", ContextType::Resource, "content").with_token_count(100);
        result.add_item(item);

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.total_tokens, 100);
    }

    #[test]
    fn test_context_result_add_multiple_items() {
        let mut result = ContextResult::new("provider");
        result.add_item(ContextItem::new("1", ContextType::Resource, "a").with_token_count(50));
        result.add_item(ContextItem::new("2", ContextType::Memory, "b").with_token_count(75));
        result.add_item(ContextItem::new("3", ContextType::Skill, "c").with_token_count(25));

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.total_tokens, 150);
    }

    #[test]
    fn test_context_result_is_empty() {
        let empty = ContextResult::new("provider");
        assert!(empty.is_empty());

        let mut non_empty = ContextResult::new("provider");
        non_empty.add_item(ContextItem::new("id", ContextType::Resource, "content"));
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_context_result_to_xml() {
        let mut result = ContextResult::new("provider");
        result.add_item(
            ContextItem::new("1", ContextType::Resource, "First content").with_source("source://1"),
        );
        result.add_item(ContextItem::new("2", ContextType::Memory, "Second content"));

        let xml = result.to_xml();
        assert!(xml.contains("<context source=\"source://1\" type=\"Resource\">"));
        assert!(xml.contains("First content"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("Second content"));
    }

    #[test]
    fn test_context_result_to_xml_empty() {
        let result = ContextResult::new("provider");
        let xml = result.to_xml();
        assert!(xml.is_empty());
    }

    #[test]
    fn test_context_result_serialization() {
        let mut result = ContextResult::new("test-provider");
        result.truncated = true;
        result.add_item(ContextItem::new("id", ContextType::Resource, "content"));

        let json = serde_json::to_string(&result).unwrap();
        let parsed: ContextResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.provider, "test-provider");
        assert!(parsed.truncated);
        assert_eq!(parsed.items.len(), 1);
    }

    #[test]
    fn test_context_result_default() {
        let result: ContextResult = Default::default();
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.provider.is_empty());
        assert!(!result.truncated);
    }

    /// Provider that returns a fixed list of items for every query.
    struct MockContextProvider {
        name: String,
        items: Vec<ContextItem>,
    }

    impl MockContextProvider {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                items: Vec::new(),
            }
        }

        fn with_items(mut self, items: Vec<ContextItem>) -> Self {
            self.items = items;
            self
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockContextProvider {
        fn name(&self) -> &str {
            &self.name
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            let mut result = ContextResult::new(&self.name);
            for item in &self.items {
                result.add_item(item.clone());
            }
            Ok(result)
        }
    }

    #[tokio::test]
    async fn test_mock_context_provider() {
        let provider = MockContextProvider::new("mock").with_items(vec![ContextItem::new(
            "1",
            ContextType::Resource,
            "content",
        )]);

        assert_eq!(provider.name(), "mock");

        let query = ContextQuery::new("test");
        let result = provider.query(&query).await.unwrap();

        assert_eq!(result.provider, "mock");
        assert_eq!(result.items.len(), 1);
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_default() {
        let provider = MockContextProvider::new("mock");

        let result = provider
            .on_turn_complete("session-1", "prompt", "response")
            .await;
        assert!(result.is_ok());
    }

    /// Provider that records every completed turn, to exercise an overridden
    /// `on_turn_complete`.
    struct MockMemoryProvider {
        memories: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                memories: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockMemoryProvider {
        fn name(&self) -> &str {
            "memory-provider"
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            Ok(ContextResult::new("memory-provider"))
        }

        async fn on_turn_complete(
            &self,
            session_id: &str,
            prompt: &str,
            response: &str,
        ) -> anyhow::Result<()> {
            let mut memories = self.memories.write().await;
            memories.push((
                session_id.to_string(),
                prompt.to_string(),
                response.to_string(),
            ));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_custom() {
        let provider = MockMemoryProvider::new();

        provider
            .on_turn_complete("sess-1", "What is Rust?", "Rust is a systems language.")
            .await
            .unwrap();

        let memories = provider.memories.read().await;
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0].0, "sess-1");
        assert_eq!(memories[0].1, "What is Rust?");
        assert_eq!(memories[0].2, "Rust is a systems language.");
    }

    #[tokio::test]
    async fn test_multiple_providers_query() {
        let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
            "p1-1",
            ContextType::Resource,
            "Resource from P1",
        )]);

        let provider2 = MockContextProvider::new("provider-2").with_items(vec![
            ContextItem::new("p2-1", ContextType::Memory, "Memory from P2"),
            ContextItem::new("p2-2", ContextType::Skill, "Skill from P2"),
        ]);

        let providers: Vec<&dyn ContextProvider> = vec![&provider1, &provider2];
        let query = ContextQuery::new("test");

        let mut all_items = Vec::new();
        for provider in providers {
            let result = provider.query(&query).await.unwrap();
            all_items.extend(result.items);
        }

        assert_eq!(all_items.len(), 3);
        assert!(all_items.iter().any(|i| i.id == "p1-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-2"));
    }

    #[test]
    fn test_context_result_xml_formatting_complex() {
        let mut result = ContextResult::new("openviking");
        result.add_item(
            ContextItem::new(
                "doc-1",
                ContextType::Resource,
                "Authentication uses JWT tokens stored in httpOnly cookies.",
            )
            .with_source("viking://docs/auth")
            .with_token_count(50),
        );
        result.add_item(
            ContextItem::new(
                "mem-1",
                ContextType::Memory,
                "User prefers TypeScript over JavaScript.",
            )
            .with_token_count(30),
        );

        let xml = result.to_xml();

        assert!(xml.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
        assert!(xml.contains("Authentication uses JWT tokens"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("User prefers TypeScript"));

        // Items are separated by exactly one blank line.
        assert!(xml.contains("</context>\n\n<context"));
    }
}