1pub mod embedding;
27pub mod fs_provider;
28pub mod vector_provider;
29pub mod vector_store;
30
31pub use fs_provider::{FileSystemContextConfig, FileSystemContextProvider};
32pub use vector_provider::{VectorContextConfig, VectorContextProvider};
33pub use vector_store::{InMemoryVectorStore, VectorStore};
34
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
40pub enum ContextType {
41 Memory,
43 #[default]
45 Resource,
46 Skill,
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
52pub enum ContextDepth {
53 Abstract,
55 #[default]
57 Overview,
58 Full,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ContextQuery {
65 pub query: String,
67
68 #[serde(default)]
70 pub context_types: Vec<ContextType>,
71
72 #[serde(default)]
74 pub depth: ContextDepth,
75
76 #[serde(default = "default_max_results")]
78 pub max_results: usize,
79
80 #[serde(default = "default_max_tokens")]
82 pub max_tokens: usize,
83
84 #[serde(default)]
86 pub session_id: Option<String>,
87
88 #[serde(default)]
90 pub params: HashMap<String, serde_json::Value>,
91}
92
/// Default `ContextQuery::max_results`, shared by `ContextQuery::new` and the
/// serde field default.
const fn default_max_results() -> usize {
    10
}

/// Default `ContextQuery::max_tokens`, shared by `ContextQuery::new` and the
/// serde field default.
const fn default_max_tokens() -> usize {
    4000
}
100
101impl ContextQuery {
102 pub fn new(query: impl Into<String>) -> Self {
104 Self {
105 query: query.into(),
106 context_types: vec![ContextType::Resource],
107 depth: ContextDepth::default(),
108 max_results: default_max_results(),
109 max_tokens: default_max_tokens(),
110 session_id: None,
111 params: HashMap::new(),
112 }
113 }
114
115 pub fn with_types(mut self, types: impl IntoIterator<Item = ContextType>) -> Self {
117 self.context_types = types.into_iter().collect();
118 self
119 }
120
121 pub fn with_depth(mut self, depth: ContextDepth) -> Self {
123 self.depth = depth;
124 self
125 }
126
127 pub fn with_max_results(mut self, max: usize) -> Self {
129 self.max_results = max;
130 self
131 }
132
133 pub fn with_max_tokens(mut self, max: usize) -> Self {
135 self.max_tokens = max;
136 self
137 }
138
139 pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
141 self.session_id = Some(id.into());
142 self
143 }
144
145 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
147 self.params.insert(key.into(), value);
148 self
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ContextItem {
155 pub id: String,
157
158 pub context_type: ContextType,
160
161 pub content: String,
163
164 #[serde(default)]
166 pub token_count: usize,
167
168 #[serde(default)]
170 pub relevance: f32,
171
172 #[serde(default)]
174 pub source: Option<String>,
175
176 #[serde(default)]
178 pub metadata: HashMap<String, serde_json::Value>,
179}
180
181impl ContextItem {
182 pub fn new(
184 id: impl Into<String>,
185 context_type: ContextType,
186 content: impl Into<String>,
187 ) -> Self {
188 Self {
189 id: id.into(),
190 context_type,
191 content: content.into(),
192 token_count: 0,
193 relevance: 0.0,
194 source: None,
195 metadata: HashMap::new(),
196 }
197 }
198
199 pub fn with_token_count(mut self, count: usize) -> Self {
201 self.token_count = count;
202 self
203 }
204
205 pub fn with_relevance(mut self, score: f32) -> Self {
207 self.relevance = score.clamp(0.0, 1.0);
208 self
209 }
210
211 pub fn with_source(mut self, source: impl Into<String>) -> Self {
213 self.source = Some(source.into());
214 self
215 }
216
217 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
219 self.metadata.insert(key.into(), value);
220 self
221 }
222
223 pub fn to_xml(&self) -> String {
225 let source_attr = self
226 .source
227 .as_ref()
228 .map(|s| format!(" source=\"{}\"", s))
229 .unwrap_or_default();
230 let type_str = match self.context_type {
231 ContextType::Memory => "Memory",
232 ContextType::Resource => "Resource",
233 ContextType::Skill => "Skill",
234 };
235 format!(
236 "<context{} type=\"{}\">\n{}\n</context>",
237 source_attr, type_str, self.content
238 )
239 }
240}
241
242#[derive(Debug, Clone, Default, Serialize, Deserialize)]
244pub struct ContextResult {
245 pub items: Vec<ContextItem>,
247
248 pub total_tokens: usize,
250
251 pub provider: String,
253
254 pub truncated: bool,
256}
257
258impl ContextResult {
259 pub fn new(provider: impl Into<String>) -> Self {
261 Self {
262 items: Vec::new(),
263 total_tokens: 0,
264 provider: provider.into(),
265 truncated: false,
266 }
267 }
268
269 pub fn add_item(&mut self, item: ContextItem) {
271 self.total_tokens += item.token_count;
272 self.items.push(item);
273 }
274
275 pub fn is_empty(&self) -> bool {
277 self.items.is_empty()
278 }
279
280 pub fn to_xml(&self) -> String {
282 self.items
283 .iter()
284 .map(|item| item.to_xml())
285 .collect::<Vec<_>>()
286 .join("\n\n")
287 }
288}
289
290#[async_trait::async_trait]
292pub trait ContextProvider: Send + Sync {
293 fn name(&self) -> &str;
295
296 async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult>;
298
299 async fn on_turn_complete(
304 &self,
305 _session_id: &str,
306 _prompt: &str,
307 _response: &str,
308 ) -> anyhow::Result<()> {
309 Ok(())
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ContextType ----

    #[test]
    fn test_context_type_default() {
        let ct: ContextType = Default::default();
        assert_eq!(ct, ContextType::Resource);
    }

    #[test]
    fn test_context_type_serialization() {
        let ct = ContextType::Memory;
        let json = serde_json::to_string(&ct).unwrap();
        assert_eq!(json, "\"Memory\"");

        let parsed: ContextType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextType::Memory);
    }

    #[test]
    fn test_context_type_all_variants() {
        let types = [
            ContextType::Memory,
            ContextType::Resource,
            ContextType::Skill,
        ];
        for ct in types {
            let json = serde_json::to_string(&ct).unwrap();
            let parsed: ContextType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, ct);
        }
    }

    // ---- ContextDepth ----

    #[test]
    fn test_context_depth_default() {
        let cd: ContextDepth = Default::default();
        assert_eq!(cd, ContextDepth::Overview);
    }

    #[test]
    fn test_context_depth_serialization() {
        let cd = ContextDepth::Full;
        let json = serde_json::to_string(&cd).unwrap();
        assert_eq!(json, "\"Full\"");

        let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextDepth::Full);
    }

    #[test]
    fn test_context_depth_all_variants() {
        let depths = [
            ContextDepth::Abstract,
            ContextDepth::Overview,
            ContextDepth::Full,
        ];
        for cd in depths {
            let json = serde_json::to_string(&cd).unwrap();
            let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, cd);
        }
    }

    // ---- ContextQuery ----

    #[test]
    fn test_context_query_new() {
        let query = ContextQuery::new("test query");
        assert_eq!(query.query, "test query");
        assert_eq!(query.context_types, vec![ContextType::Resource]);
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
        assert!(query.session_id.is_none());
        assert!(query.params.is_empty());
    }

    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new("test")
            .with_types([ContextType::Memory, ContextType::Skill])
            .with_depth(ContextDepth::Full)
            .with_max_results(5)
            .with_max_tokens(2000)
            .with_session_id("sess-123")
            .with_param("custom", serde_json::json!("value"));

        assert_eq!(query.context_types.len(), 2);
        assert!(query.context_types.contains(&ContextType::Memory));
        assert!(query.context_types.contains(&ContextType::Skill));
        assert_eq!(query.depth, ContextDepth::Full);
        assert_eq!(query.max_results, 5);
        assert_eq!(query.max_tokens, 2000);
        assert_eq!(query.session_id, Some("sess-123".to_string()));
        assert_eq!(
            query.params.get("custom"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_context_query_serialization() {
        let query = ContextQuery::new("search term")
            .with_types([ContextType::Resource])
            .with_session_id("sess-456");

        let json = serde_json::to_string(&query).unwrap();
        let parsed: ContextQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.query, "search term");
        assert_eq!(parsed.session_id, Some("sess-456".to_string()));
        // The remaining fields must survive the round trip as well.
        assert_eq!(parsed.context_types, vec![ContextType::Resource]);
        assert_eq!(parsed.depth, ContextDepth::Overview);
        assert_eq!(parsed.max_results, 10);
        assert_eq!(parsed.max_tokens, 4000);
    }

    #[test]
    fn test_context_query_deserialization_with_defaults() {
        let json = r#"{"query": "minimal query"}"#;
        let query: ContextQuery = serde_json::from_str(json).unwrap();

        assert_eq!(query.query, "minimal query");
        // Unlike `ContextQuery::new`, the serde default for the types is empty.
        assert!(query.context_types.is_empty());
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
    }

    // ---- ContextItem ----

    #[test]
    fn test_context_item_new() {
        let item = ContextItem::new("item-1", ContextType::Resource, "Some content");
        assert_eq!(item.id, "item-1");
        assert_eq!(item.context_type, ContextType::Resource);
        assert_eq!(item.content, "Some content");
        assert_eq!(item.token_count, 0);
        assert_eq!(item.relevance, 0.0);
        assert!(item.source.is_none());
        assert!(item.metadata.is_empty());
    }

    #[test]
    fn test_context_item_builder() {
        let item = ContextItem::new("item-2", ContextType::Memory, "Memory content")
            .with_token_count(150)
            .with_relevance(0.85)
            .with_source("viking://memory/session-123")
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(item.token_count, 150);
        assert!((item.relevance - 0.85).abs() < f32::EPSILON);
        assert_eq!(item.source, Some("viking://memory/session-123".to_string()));
        assert_eq!(item.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_context_item_relevance_clamping() {
        let item1 = ContextItem::new("id", ContextType::Resource, "").with_relevance(1.5);
        assert!((item1.relevance - 1.0).abs() < f32::EPSILON);

        let item2 = ContextItem::new("id", ContextType::Resource, "").with_relevance(-0.5);
        assert!(item2.relevance.abs() < f32::EPSILON);
    }

    #[test]
    fn test_context_item_to_xml_without_source() {
        let item = ContextItem::new("id", ContextType::Resource, "Content here");
        let xml = item.to_xml();
        assert_eq!(xml, "<context type=\"Resource\">\nContent here\n</context>");
    }

    #[test]
    fn test_context_item_to_xml_with_source() {
        let item = ContextItem::new("id", ContextType::Memory, "Memory content")
            .with_source("viking://docs/auth");
        let xml = item.to_xml();
        assert_eq!(
            xml,
            "<context source=\"viking://docs/auth\" type=\"Memory\">\nMemory content\n</context>"
        );
    }

    #[test]
    fn test_context_item_to_xml_all_types() {
        let memory = ContextItem::new("m", ContextType::Memory, "m").to_xml();
        assert!(memory.contains("type=\"Memory\""));

        let resource = ContextItem::new("r", ContextType::Resource, "r").to_xml();
        assert!(resource.contains("type=\"Resource\""));

        let skill = ContextItem::new("s", ContextType::Skill, "s").to_xml();
        assert!(skill.contains("type=\"Skill\""));
    }

    #[test]
    fn test_context_item_serialization() {
        let item = ContextItem::new("item-3", ContextType::Skill, "Skill instructions")
            .with_token_count(200)
            .with_relevance(0.9)
            .with_source("viking://skills/code-review");

        let json = serde_json::to_string(&item).unwrap();
        let parsed: ContextItem = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "item-3");
        assert_eq!(parsed.context_type, ContextType::Skill);
        assert_eq!(parsed.content, "Skill instructions");
        assert_eq!(parsed.token_count, 200);
        assert!((parsed.relevance - 0.9).abs() < f32::EPSILON);
        assert_eq!(parsed.source.as_deref(), Some("viking://skills/code-review"));
    }

    // ---- ContextResult ----

    #[test]
    fn test_context_result_new() {
        let result = ContextResult::new("test-provider");
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert_eq!(result.provider, "test-provider");
        assert!(!result.truncated);
    }

    #[test]
    fn test_context_result_add_item() {
        let mut result = ContextResult::new("provider");
        let item = ContextItem::new("id", ContextType::Resource, "content").with_token_count(100);
        result.add_item(item);

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.total_tokens, 100);
    }

    #[test]
    fn test_context_result_add_multiple_items() {
        let mut result = ContextResult::new("provider");
        result.add_item(ContextItem::new("1", ContextType::Resource, "a").with_token_count(50));
        result.add_item(ContextItem::new("2", ContextType::Memory, "b").with_token_count(75));
        result.add_item(ContextItem::new("3", ContextType::Skill, "c").with_token_count(25));

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.total_tokens, 150);
    }

    #[test]
    fn test_context_result_is_empty() {
        let empty = ContextResult::new("provider");
        assert!(empty.is_empty());

        let mut non_empty = ContextResult::new("provider");
        non_empty.add_item(ContextItem::new("id", ContextType::Resource, "content"));
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_context_result_to_xml() {
        let mut result = ContextResult::new("provider");
        result.add_item(
            ContextItem::new("1", ContextType::Resource, "First content").with_source("source://1"),
        );
        result.add_item(ContextItem::new("2", ContextType::Memory, "Second content"));

        let xml = result.to_xml();
        assert!(xml.contains("<context source=\"source://1\" type=\"Resource\">"));
        assert!(xml.contains("First content"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("Second content"));
    }

    #[test]
    fn test_context_result_to_xml_empty() {
        let result = ContextResult::new("provider");
        let xml = result.to_xml();
        assert!(xml.is_empty());
    }

    #[test]
    fn test_context_result_serialization() {
        let mut result = ContextResult::new("test-provider");
        result.truncated = true;
        result.add_item(ContextItem::new("id", ContextType::Resource, "content"));

        let json = serde_json::to_string(&result).unwrap();
        let parsed: ContextResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.provider, "test-provider");
        assert!(parsed.truncated);
        assert_eq!(parsed.items.len(), 1);
    }

    #[test]
    fn test_context_result_default() {
        let result: ContextResult = Default::default();
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.provider.is_empty());
        assert!(!result.truncated);
    }

    // ---- ContextProvider ----

    /// Test provider that returns a fixed list of items for every query and
    /// relies on the trait's default `on_turn_complete`.
    struct MockContextProvider {
        name: String,
        items: Vec<ContextItem>,
    }

    impl MockContextProvider {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                items: Vec::new(),
            }
        }

        fn with_items(mut self, items: Vec<ContextItem>) -> Self {
            self.items = items;
            self
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockContextProvider {
        fn name(&self) -> &str {
            &self.name
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            let mut result = ContextResult::new(&self.name);
            for item in &self.items {
                result.add_item(item.clone());
            }
            Ok(result)
        }
    }

    #[tokio::test]
    async fn test_mock_context_provider() {
        let provider = MockContextProvider::new("mock").with_items(vec![ContextItem::new(
            "1",
            ContextType::Resource,
            "content",
        )]);

        assert_eq!(provider.name(), "mock");

        let query = ContextQuery::new("test");
        let result = provider.query(&query).await.unwrap();

        assert_eq!(result.provider, "mock");
        assert_eq!(result.items.len(), 1);
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_default() {
        let provider = MockContextProvider::new("mock");

        let result = provider
            .on_turn_complete("session-1", "prompt", "response")
            .await;
        assert!(result.is_ok());
    }

    /// Test provider that overrides `on_turn_complete` to record every
    /// `(session_id, prompt, response)` triple it receives.
    struct MockMemoryProvider {
        memories: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                memories: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockMemoryProvider {
        fn name(&self) -> &str {
            "memory-provider"
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            Ok(ContextResult::new("memory-provider"))
        }

        async fn on_turn_complete(
            &self,
            session_id: &str,
            prompt: &str,
            response: &str,
        ) -> anyhow::Result<()> {
            let mut memories = self.memories.write().await;
            memories.push((
                session_id.to_string(),
                prompt.to_string(),
                response.to_string(),
            ));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_custom() {
        let provider = MockMemoryProvider::new();

        provider
            .on_turn_complete("sess-1", "What is Rust?", "Rust is a systems language.")
            .await
            .unwrap();

        let memories = provider.memories.read().await;
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0].0, "sess-1");
        assert_eq!(memories[0].1, "What is Rust?");
        assert_eq!(memories[0].2, "Rust is a systems language.");
    }

    // ---- Integration-style checks ----

    #[tokio::test]
    async fn test_multiple_providers_query() {
        let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
            "p1-1",
            ContextType::Resource,
            "Resource from P1",
        )]);

        let provider2 = MockContextProvider::new("provider-2").with_items(vec![
            ContextItem::new("p2-1", ContextType::Memory, "Memory from P2"),
            ContextItem::new("p2-2", ContextType::Skill, "Skill from P2"),
        ]);

        let providers: [&dyn ContextProvider; 2] = [&provider1, &provider2];
        let query = ContextQuery::new("test");

        let mut all_items = Vec::new();
        for provider in providers {
            let result = provider.query(&query).await.unwrap();
            all_items.extend(result.items);
        }

        assert_eq!(all_items.len(), 3);
        assert!(all_items.iter().any(|i| i.id == "p1-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-2"));
    }

    #[test]
    fn test_context_result_xml_formatting_complex() {
        let mut result = ContextResult::new("openviking");
        result.add_item(
            ContextItem::new(
                "doc-1",
                ContextType::Resource,
                "Authentication uses JWT tokens stored in httpOnly cookies.",
            )
            .with_source("viking://docs/auth")
            .with_token_count(50),
        );
        result.add_item(
            ContextItem::new(
                "mem-1",
                ContextType::Memory,
                "User prefers TypeScript over JavaScript.",
            )
            .with_token_count(30),
        );

        let xml = result.to_xml();

        assert!(xml.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
        assert!(xml.contains("Authentication uses JWT tokens"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("User prefers TypeScript"));

        // Items are separated by exactly one blank line.
        assert!(xml.contains("</context>\n\n<context"));
    }
}