1pub mod fs_provider;
27
28pub use fs_provider::{FileSystemContextConfig, FileSystemContextProvider};
29
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32
/// Category of a [`ContextItem`], and of what a [`ContextQuery`] asks for.
///
/// Serialized by serde as the bare variant name (e.g. `"Memory"`).
/// Variant order is significant: the derived `Hash` and serde's variant
/// indices (used by non-self-describing formats) follow declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ContextType {
    /// Memory items. NOTE(review): presumably remembered conversation facts — confirm.
    Memory,
    /// Resource items; this is the `Default` variant.
    #[default]
    Resource,
    /// Skill items. NOTE(review): presumably reusable instructions — confirm.
    Skill,
}
44
/// Requested level of detail for returned context.
///
/// Serialized by serde as the bare variant name (e.g. `"Full"`).
/// NOTE(review): variants presumably go from least (`Abstract`) to most
/// (`Full`) detail; how each level is honored is up to the provider — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ContextDepth {
    /// Least-detailed level (presumably a short summary — TODO confirm).
    Abstract,
    /// Intermediate level; this is the `Default` variant.
    #[default]
    Overview,
    /// Full content.
    Full,
}
56
/// A request for context sent to a [`ContextProvider`].
///
/// Build one with [`ContextQuery::new`] plus the `with_*` methods, or
/// deserialize it. Only `query` is required when deserializing; the other
/// fields fall back to the serde defaults declared on each field.
/// Field declaration order is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextQuery {
    /// The query text.
    pub query: String,

    /// Context types to search.
    ///
    /// Deserializes to an empty list when absent, whereas
    /// [`ContextQuery::new`] starts with `[ContextType::Resource]`.
    /// NOTE(review): how an empty list is interpreted is provider-specific — confirm.
    #[serde(default)]
    pub context_types: Vec<ContextType>,

    /// Requested detail level; defaults to [`ContextDepth::Overview`].
    #[serde(default)]
    pub depth: ContextDepth,

    /// Maximum number of items to return; defaults to 10.
    #[serde(default = "default_max_results")]
    pub max_results: usize,

    /// Token limit for the returned context; defaults to 4000.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,

    /// Optional session identifier; `None` when absent.
    #[serde(default)]
    pub session_id: Option<String>,

    /// Free-form extra parameters (presumably provider-specific — confirm).
    #[serde(default)]
    pub params: HashMap<String, serde_json::Value>,
}
87
/// Result cap used by `ContextQuery::new` and as the serde default for
/// `ContextQuery::max_results`.
const DEFAULT_MAX_RESULTS: usize = 10;

/// Returns [`DEFAULT_MAX_RESULTS`]; exists because `#[serde(default = ...)]`
/// needs a function path rather than a constant.
fn default_max_results() -> usize {
    DEFAULT_MAX_RESULTS
}
91
/// Token limit used by `ContextQuery::new` and as the serde default for
/// `ContextQuery::max_tokens`.
const DEFAULT_MAX_TOKENS: usize = 4000;

/// Returns [`DEFAULT_MAX_TOKENS`]; exists because `#[serde(default = ...)]`
/// needs a function path rather than a constant.
fn default_max_tokens() -> usize {
    DEFAULT_MAX_TOKENS
}
95
96impl ContextQuery {
97 pub fn new(query: impl Into<String>) -> Self {
99 Self {
100 query: query.into(),
101 context_types: vec![ContextType::Resource],
102 depth: ContextDepth::default(),
103 max_results: default_max_results(),
104 max_tokens: default_max_tokens(),
105 session_id: None,
106 params: HashMap::new(),
107 }
108 }
109
110 pub fn with_types(mut self, types: impl IntoIterator<Item = ContextType>) -> Self {
112 self.context_types = types.into_iter().collect();
113 self
114 }
115
116 pub fn with_depth(mut self, depth: ContextDepth) -> Self {
118 self.depth = depth;
119 self
120 }
121
122 pub fn with_max_results(mut self, max: usize) -> Self {
124 self.max_results = max;
125 self
126 }
127
128 pub fn with_max_tokens(mut self, max: usize) -> Self {
130 self.max_tokens = max;
131 self
132 }
133
134 pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
136 self.session_id = Some(id.into());
137 self
138 }
139
140 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
142 self.params.insert(key.into(), value);
143 self
144 }
145}
146
/// A single piece of context returned by a [`ContextProvider`].
///
/// `id`, `context_type` and `content` are required when deserializing; the
/// remaining fields fall back to their `Default` values when absent.
/// Field declaration order is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextItem {
    /// Identifier of the item.
    pub id: String,

    /// Category of the item; rendered as the `type` attribute by `to_xml`.
    pub context_type: ContextType,

    /// The item's text; placed between the `<context>` tags by `to_xml`.
    pub content: String,

    /// Token count of the item; summed by `ContextResult::add_item`.
    #[serde(default)]
    pub token_count: usize,

    /// Relevance score. `with_relevance` clamps it to `[0.0, 1.0]`, but
    /// deserialization does not, so deserialized values are unchecked.
    #[serde(default)]
    pub relevance: f32,

    /// Optional source identifier (e.g. `viking://docs/auth` in the tests);
    /// rendered as the `source` attribute by `to_xml`.
    #[serde(default)]
    pub source: Option<String>,

    /// Free-form metadata entries.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
175
176impl ContextItem {
177 pub fn new(
179 id: impl Into<String>,
180 context_type: ContextType,
181 content: impl Into<String>,
182 ) -> Self {
183 Self {
184 id: id.into(),
185 context_type,
186 content: content.into(),
187 token_count: 0,
188 relevance: 0.0,
189 source: None,
190 metadata: HashMap::new(),
191 }
192 }
193
194 pub fn with_token_count(mut self, count: usize) -> Self {
196 self.token_count = count;
197 self
198 }
199
200 pub fn with_relevance(mut self, score: f32) -> Self {
202 self.relevance = score.clamp(0.0, 1.0);
203 self
204 }
205
206 pub fn with_source(mut self, source: impl Into<String>) -> Self {
208 self.source = Some(source.into());
209 self
210 }
211
212 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
214 self.metadata.insert(key.into(), value);
215 self
216 }
217
218 pub fn to_xml(&self) -> String {
220 let source_attr = self
221 .source
222 .as_ref()
223 .map(|s| format!(" source=\"{}\"", s))
224 .unwrap_or_default();
225 let type_str = match self.context_type {
226 ContextType::Memory => "Memory",
227 ContextType::Resource => "Resource",
228 ContextType::Skill => "Skill",
229 };
230 format!(
231 "<context{} type=\"{}\">\n{}\n</context>",
232 source_attr, type_str, self.content
233 )
234 }
235}
236
/// The items a [`ContextProvider`] returned for one query.
///
/// No field has a serde default, so all four must be present when
/// deserializing. Field declaration order is also the serialized field order.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextResult {
    /// Returned items, in insertion order.
    pub items: Vec<ContextItem>,

    /// Sum of the items' token counts as maintained by `add_item`; pushing
    /// to `items` directly does not update it.
    pub total_tokens: usize,

    /// Name of the provider that produced this result.
    pub provider: String,

    /// Truncation flag; starts `false` and is not set anywhere in this file.
    /// NOTE(review): presumably set by providers that dropped items to fit limits — confirm.
    pub truncated: bool,
}
252
253impl ContextResult {
254 pub fn new(provider: impl Into<String>) -> Self {
256 Self {
257 items: Vec::new(),
258 total_tokens: 0,
259 provider: provider.into(),
260 truncated: false,
261 }
262 }
263
264 pub fn add_item(&mut self, item: ContextItem) {
266 self.total_tokens += item.token_count;
267 self.items.push(item);
268 }
269
270 pub fn is_empty(&self) -> bool {
272 self.items.is_empty()
273 }
274
275 pub fn to_xml(&self) -> String {
277 self.items
278 .iter()
279 .map(|item| item.to_xml())
280 .collect::<Vec<_>>()
281 .join("\n\n")
282 }
283}
284
285#[async_trait::async_trait]
287pub trait ContextProvider: Send + Sync {
288 fn name(&self) -> &str;
290
291 async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult>;
293
294 async fn on_turn_complete(
299 &self,
300 _session_id: &str,
301 _prompt: &str,
302 _response: &str,
303 ) -> anyhow::Result<()> {
304 Ok(())
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    // ----------------------------------------------------------------------
    // ContextType
    // ----------------------------------------------------------------------

    /// `Resource` is the default context type.
    #[test]
    fn test_context_type_default() {
        let ct: ContextType = Default::default();
        assert_eq!(ct, ContextType::Resource);
    }

    /// A context type serializes to its bare variant name and round-trips.
    #[test]
    fn test_context_type_serialization() {
        let ct = ContextType::Memory;
        let json = serde_json::to_string(&ct).unwrap();
        assert_eq!(json, "\"Memory\"");

        let parsed: ContextType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextType::Memory);
    }

    /// Every variant survives a JSON round trip.
    #[test]
    fn test_context_type_all_variants() {
        let types = vec![
            ContextType::Memory,
            ContextType::Resource,
            ContextType::Skill,
        ];
        for ct in types {
            let json = serde_json::to_string(&ct).unwrap();
            let parsed: ContextType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, ct);
        }
    }

    // ----------------------------------------------------------------------
    // ContextDepth
    // ----------------------------------------------------------------------

    /// `Overview` is the default depth.
    #[test]
    fn test_context_depth_default() {
        let cd: ContextDepth = Default::default();
        assert_eq!(cd, ContextDepth::Overview);
    }

    /// A depth serializes to its bare variant name and round-trips.
    #[test]
    fn test_context_depth_serialization() {
        let cd = ContextDepth::Full;
        let json = serde_json::to_string(&cd).unwrap();
        assert_eq!(json, "\"Full\"");

        let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextDepth::Full);
    }

    /// Every variant survives a JSON round trip.
    #[test]
    fn test_context_depth_all_variants() {
        let depths = vec![
            ContextDepth::Abstract,
            ContextDepth::Overview,
            ContextDepth::Full,
        ];
        for cd in depths {
            let json = serde_json::to_string(&cd).unwrap();
            let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, cd);
        }
    }

    // ----------------------------------------------------------------------
    // ContextQuery
    // ----------------------------------------------------------------------

    /// `ContextQuery::new` fills in the documented defaults.
    #[test]
    fn test_context_query_new() {
        let query = ContextQuery::new("test query");
        assert_eq!(query.query, "test query");
        assert_eq!(query.context_types, vec![ContextType::Resource]);
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
        assert!(query.session_id.is_none());
        assert!(query.params.is_empty());
    }

    /// Each builder method sets its field.
    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new("test")
            .with_types([ContextType::Memory, ContextType::Skill])
            .with_depth(ContextDepth::Full)
            .with_max_results(5)
            .with_max_tokens(2000)
            .with_session_id("sess-123")
            .with_param("custom", serde_json::json!("value"));

        assert_eq!(query.context_types.len(), 2);
        assert!(query.context_types.contains(&ContextType::Memory));
        assert!(query.context_types.contains(&ContextType::Skill));
        assert_eq!(query.depth, ContextDepth::Full);
        assert_eq!(query.max_results, 5);
        assert_eq!(query.max_tokens, 2000);
        assert_eq!(query.session_id, Some("sess-123".to_string()));
        assert_eq!(
            query.params.get("custom"),
            Some(&serde_json::json!("value"))
        );
    }

    /// A built query round-trips through JSON.
    #[test]
    fn test_context_query_serialization() {
        let query = ContextQuery::new("search term")
            .with_types([ContextType::Resource])
            .with_session_id("sess-456");

        let json = serde_json::to_string(&query).unwrap();
        let parsed: ContextQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.query, "search term");
        assert_eq!(parsed.session_id, Some("sess-456".to_string()));
    }

    /// Only `query` is required in JSON. Note that the serde default for
    /// `context_types` is an empty list, unlike `ContextQuery::new`.
    #[test]
    fn test_context_query_deserialization_with_defaults() {
        let json = r#"{"query": "minimal query"}"#;
        let query: ContextQuery = serde_json::from_str(json).unwrap();

        assert_eq!(query.query, "minimal query");
        assert!(query.context_types.is_empty());
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
    }

    // ----------------------------------------------------------------------
    // ContextItem
    // ----------------------------------------------------------------------

    /// `ContextItem::new` sets the given fields and zero/empty defaults.
    #[test]
    fn test_context_item_new() {
        let item = ContextItem::new("item-1", ContextType::Resource, "Some content");
        assert_eq!(item.id, "item-1");
        assert_eq!(item.context_type, ContextType::Resource);
        assert_eq!(item.content, "Some content");
        assert_eq!(item.token_count, 0);
        assert_eq!(item.relevance, 0.0);
        assert!(item.source.is_none());
        assert!(item.metadata.is_empty());
    }

    /// Each builder method sets its field.
    #[test]
    fn test_context_item_builder() {
        let item = ContextItem::new("item-2", ContextType::Memory, "Memory content")
            .with_token_count(150)
            .with_relevance(0.85)
            .with_source("viking://memory/session-123")
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(item.token_count, 150);
        assert!((item.relevance - 0.85).abs() < f32::EPSILON);
        assert_eq!(item.source, Some("viking://memory/session-123".to_string()));
        assert_eq!(item.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    /// Relevance above 1.0 or below 0.0 is clamped into range.
    #[test]
    fn test_context_item_relevance_clamping() {
        let item1 = ContextItem::new("id", ContextType::Resource, "").with_relevance(1.5);
        assert!((item1.relevance - 1.0).abs() < f32::EPSILON);

        let item2 = ContextItem::new("id", ContextType::Resource, "").with_relevance(-0.5);
        assert!(item2.relevance.abs() < f32::EPSILON);
    }

    /// Without a source, `to_xml` emits only the `type` attribute.
    #[test]
    fn test_context_item_to_xml_without_source() {
        let item = ContextItem::new("id", ContextType::Resource, "Content here");
        let xml = item.to_xml();
        assert_eq!(xml, "<context type=\"Resource\">\nContent here\n</context>");
    }

    /// With a source, `to_xml` emits `source` before `type`.
    #[test]
    fn test_context_item_to_xml_with_source() {
        let item = ContextItem::new("id", ContextType::Memory, "Memory content")
            .with_source("viking://docs/auth");
        let xml = item.to_xml();
        assert_eq!(
            xml,
            "<context source=\"viking://docs/auth\" type=\"Memory\">\nMemory content\n</context>"
        );
    }

    /// Each context type renders its own `type` attribute value.
    #[test]
    fn test_context_item_to_xml_all_types() {
        let memory = ContextItem::new("m", ContextType::Memory, "m").to_xml();
        assert!(memory.contains("type=\"Memory\""));

        let resource = ContextItem::new("r", ContextType::Resource, "r").to_xml();
        assert!(resource.contains("type=\"Resource\""));

        let skill = ContextItem::new("s", ContextType::Skill, "s").to_xml();
        assert!(skill.contains("type=\"Skill\""));
    }

    /// A built item round-trips through JSON.
    #[test]
    fn test_context_item_serialization() {
        let item = ContextItem::new("item-3", ContextType::Skill, "Skill instructions")
            .with_token_count(200)
            .with_relevance(0.9)
            .with_source("viking://skills/code-review");

        let json = serde_json::to_string(&item).unwrap();
        let parsed: ContextItem = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "item-3");
        assert_eq!(parsed.context_type, ContextType::Skill);
        assert_eq!(parsed.content, "Skill instructions");
        assert_eq!(parsed.token_count, 200);
    }

    // ----------------------------------------------------------------------
    // ContextResult
    // ----------------------------------------------------------------------

    /// `ContextResult::new` starts empty, untruncated, with the given provider.
    #[test]
    fn test_context_result_new() {
        let result = ContextResult::new("test-provider");
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert_eq!(result.provider, "test-provider");
        assert!(!result.truncated);
    }

    /// `add_item` stores the item and adds its token count.
    #[test]
    fn test_context_result_add_item() {
        let mut result = ContextResult::new("provider");
        let item = ContextItem::new("id", ContextType::Resource, "content").with_token_count(100);
        result.add_item(item);

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.total_tokens, 100);
    }

    /// Token counts accumulate across several `add_item` calls.
    #[test]
    fn test_context_result_add_multiple_items() {
        let mut result = ContextResult::new("provider");
        result.add_item(ContextItem::new("1", ContextType::Resource, "a").with_token_count(50));
        result.add_item(ContextItem::new("2", ContextType::Memory, "b").with_token_count(75));
        result.add_item(ContextItem::new("3", ContextType::Skill, "c").with_token_count(25));

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.total_tokens, 150);
    }

    /// `is_empty` reflects whether any item was added.
    #[test]
    fn test_context_result_is_empty() {
        let empty = ContextResult::new("provider");
        assert!(empty.is_empty());

        let mut non_empty = ContextResult::new("provider");
        non_empty.add_item(ContextItem::new("id", ContextType::Resource, "content"));
        assert!(!non_empty.is_empty());
    }

    /// `to_xml` includes every item's rendered element.
    #[test]
    fn test_context_result_to_xml() {
        let mut result = ContextResult::new("provider");
        result.add_item(
            ContextItem::new("1", ContextType::Resource, "First content").with_source("source://1"),
        );
        result.add_item(ContextItem::new("2", ContextType::Memory, "Second content"));

        let xml = result.to_xml();
        assert!(xml.contains("<context source=\"source://1\" type=\"Resource\">"));
        assert!(xml.contains("First content"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("Second content"));
    }

    /// An empty result renders to an empty string.
    #[test]
    fn test_context_result_to_xml_empty() {
        let result = ContextResult::new("provider");
        let xml = result.to_xml();
        assert!(xml.is_empty());
    }

    /// A result, including the `truncated` flag, round-trips through JSON.
    #[test]
    fn test_context_result_serialization() {
        let mut result = ContextResult::new("test-provider");
        result.truncated = true;
        result.add_item(ContextItem::new("id", ContextType::Resource, "content"));

        let json = serde_json::to_string(&result).unwrap();
        let parsed: ContextResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.provider, "test-provider");
        assert!(parsed.truncated);
        assert_eq!(parsed.items.len(), 1);
    }

    /// The derived `Default` matches an empty result with an empty provider name.
    #[test]
    fn test_context_result_default() {
        let result: ContextResult = Default::default();
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.provider.is_empty());
        assert!(!result.truncated);
    }

    // ----------------------------------------------------------------------
    // ContextProvider
    // ----------------------------------------------------------------------

    /// Provider that returns a fixed list of items for any query and relies
    /// on the trait's default `on_turn_complete`.
    struct MockContextProvider {
        name: String,
        items: Vec<ContextItem>,
    }

    impl MockContextProvider {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                items: Vec::new(),
            }
        }

        fn with_items(mut self, items: Vec<ContextItem>) -> Self {
            self.items = items;
            self
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockContextProvider {
        fn name(&self) -> &str {
            &self.name
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            let mut result = ContextResult::new(&self.name);
            for item in &self.items {
                result.add_item(item.clone());
            }
            Ok(result)
        }
    }

    /// A provider's name and query result are exposed through the trait.
    #[tokio::test]
    async fn test_mock_context_provider() {
        let provider = MockContextProvider::new("mock").with_items(vec![ContextItem::new(
            "1",
            ContextType::Resource,
            "content",
        )]);

        assert_eq!(provider.name(), "mock");

        let query = ContextQuery::new("test");
        let result = provider.query(&query).await.unwrap();

        assert_eq!(result.provider, "mock");
        assert_eq!(result.items.len(), 1);
    }

    /// The default `on_turn_complete` succeeds without doing anything.
    #[tokio::test]
    async fn test_context_provider_on_turn_complete_default() {
        let provider = MockContextProvider::new("mock");

        let result = provider
            .on_turn_complete("session-1", "prompt", "response")
            .await;
        assert!(result.is_ok());
    }

    /// Provider that overrides `on_turn_complete` to record each
    /// `(session_id, prompt, response)` triple.
    struct MockMemoryProvider {
        memories: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                memories: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockMemoryProvider {
        fn name(&self) -> &str {
            "memory-provider"
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            Ok(ContextResult::new("memory-provider"))
        }

        async fn on_turn_complete(
            &self,
            session_id: &str,
            prompt: &str,
            response: &str,
        ) -> anyhow::Result<()> {
            let mut memories = self.memories.write().await;
            memories.push((
                session_id.to_string(),
                prompt.to_string(),
                response.to_string(),
            ));
            Ok(())
        }
    }

    /// An overridden `on_turn_complete` receives the arguments it was called with.
    #[tokio::test]
    async fn test_context_provider_on_turn_complete_custom() {
        let provider = MockMemoryProvider::new();

        provider
            .on_turn_complete("sess-1", "What is Rust?", "Rust is a systems language.")
            .await
            .unwrap();

        let memories = provider.memories.read().await;
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0].0, "sess-1");
        assert_eq!(memories[0].1, "What is Rust?");
        assert_eq!(memories[0].2, "Rust is a systems language.");
    }

    /// Providers can be queried through `&dyn ContextProvider` and their
    /// items combined.
    #[tokio::test]
    async fn test_multiple_providers_query() {
        let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
            "p1-1",
            ContextType::Resource,
            "Resource from P1",
        )]);

        let provider2 = MockContextProvider::new("provider-2").with_items(vec![
            ContextItem::new("p2-1", ContextType::Memory, "Memory from P2"),
            ContextItem::new("p2-2", ContextType::Skill, "Skill from P2"),
        ]);

        let providers: Vec<&dyn ContextProvider> = vec![&provider1, &provider2];
        let query = ContextQuery::new("test");

        let mut all_items = Vec::new();
        for provider in providers {
            let result = provider.query(&query).await.unwrap();
            all_items.extend(result.items);
        }

        assert_eq!(all_items.len(), 3);
        assert!(all_items.iter().any(|i| i.id == "p1-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-2"));
    }

    /// Items in a combined result render in order, separated by a blank line.
    #[test]
    fn test_context_result_xml_formatting_complex() {
        let mut result = ContextResult::new("openviking");
        result.add_item(
            ContextItem::new(
                "doc-1",
                ContextType::Resource,
                "Authentication uses JWT tokens stored in httpOnly cookies.",
            )
            .with_source("viking://docs/auth")
            .with_token_count(50),
        );
        result.add_item(
            ContextItem::new(
                "mem-1",
                ContextType::Memory,
                "User prefers TypeScript over JavaScript.",
            )
            .with_token_count(30),
        );

        let xml = result.to_xml();

        assert!(xml.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
        assert!(xml.contains("Authentication uses JWT tokens"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("User prefers TypeScript"));

        // Consecutive elements are joined by exactly one blank line.
        assert!(xml.contains("</context>\n\n<context"));
    }
}