1pub mod fs_provider;
27pub mod ripgrep_provider;
28
29pub use fs_provider::{FileSystemContextConfig, FileSystemContextProvider};
30pub use ripgrep_provider::{RipgrepContextConfig, RipgrepContextProvider};
31
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
37pub enum ContextType {
38 Memory,
40 #[default]
42 Resource,
43 Skill,
45}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
49pub enum ContextDepth {
50 Abstract,
52 #[default]
54 Overview,
55 Full,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ContextQuery {
62 pub query: String,
64
65 #[serde(default)]
67 pub context_types: Vec<ContextType>,
68
69 #[serde(default)]
71 pub depth: ContextDepth,
72
73 #[serde(default = "default_max_results")]
75 pub max_results: usize,
76
77 #[serde(default = "default_max_tokens")]
79 pub max_tokens: usize,
80
81 #[serde(default)]
83 pub session_id: Option<String>,
84
85 #[serde(default)]
87 pub params: HashMap<String, serde_json::Value>,
88}
89
90fn default_max_results() -> usize {
91 10
92}
93
94fn default_max_tokens() -> usize {
95 4000
96}
97
98impl ContextQuery {
99 pub fn new(query: impl Into<String>) -> Self {
101 Self {
102 query: query.into(),
103 context_types: vec![ContextType::Resource],
104 depth: ContextDepth::default(),
105 max_results: default_max_results(),
106 max_tokens: default_max_tokens(),
107 session_id: None,
108 params: HashMap::new(),
109 }
110 }
111
112 pub fn with_types(mut self, types: impl IntoIterator<Item = ContextType>) -> Self {
114 self.context_types = types.into_iter().collect();
115 self
116 }
117
118 pub fn with_depth(mut self, depth: ContextDepth) -> Self {
120 self.depth = depth;
121 self
122 }
123
124 pub fn with_max_results(mut self, max: usize) -> Self {
126 self.max_results = max;
127 self
128 }
129
130 pub fn with_max_tokens(mut self, max: usize) -> Self {
132 self.max_tokens = max;
133 self
134 }
135
136 pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
138 self.session_id = Some(id.into());
139 self
140 }
141
142 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
144 self.params.insert(key.into(), value);
145 self
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct ContextItem {
152 pub id: String,
154
155 pub context_type: ContextType,
157
158 pub content: String,
160
161 #[serde(default)]
163 pub token_count: usize,
164
165 #[serde(default)]
167 pub relevance: f32,
168
169 #[serde(default)]
171 pub source: Option<String>,
172
173 #[serde(default)]
175 pub metadata: HashMap<String, serde_json::Value>,
176}
177
178impl ContextItem {
179 pub fn new(
181 id: impl Into<String>,
182 context_type: ContextType,
183 content: impl Into<String>,
184 ) -> Self {
185 Self {
186 id: id.into(),
187 context_type,
188 content: content.into(),
189 token_count: 0,
190 relevance: 0.0,
191 source: None,
192 metadata: HashMap::new(),
193 }
194 }
195
196 pub fn with_token_count(mut self, count: usize) -> Self {
198 self.token_count = count;
199 self
200 }
201
202 pub fn with_relevance(mut self, score: f32) -> Self {
204 self.relevance = score.clamp(0.0, 1.0);
205 self
206 }
207
208 pub fn with_source(mut self, source: impl Into<String>) -> Self {
210 self.source = Some(source.into());
211 self
212 }
213
214 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
216 self.metadata.insert(key.into(), value);
217 self
218 }
219
220 pub fn to_xml(&self) -> String {
222 let source_attr = self
223 .source
224 .as_ref()
225 .map(|s| format!(" source=\"{}\"", s))
226 .unwrap_or_default();
227 let type_str = match self.context_type {
228 ContextType::Memory => "Memory",
229 ContextType::Resource => "Resource",
230 ContextType::Skill => "Skill",
231 };
232 format!(
233 "<context{} type=\"{}\">\n{}\n</context>",
234 source_attr, type_str, self.content
235 )
236 }
237}
238
239#[derive(Debug, Clone, Default, Serialize, Deserialize)]
241pub struct ContextResult {
242 pub items: Vec<ContextItem>,
244
245 pub total_tokens: usize,
247
248 pub provider: String,
250
251 pub truncated: bool,
253}
254
255impl ContextResult {
256 pub fn new(provider: impl Into<String>) -> Self {
258 Self {
259 items: Vec::new(),
260 total_tokens: 0,
261 provider: provider.into(),
262 truncated: false,
263 }
264 }
265
266 pub fn add_item(&mut self, item: ContextItem) {
268 self.total_tokens += item.token_count;
269 self.items.push(item);
270 }
271
272 pub fn is_empty(&self) -> bool {
274 self.items.is_empty()
275 }
276
277 pub fn to_xml(&self) -> String {
279 self.items
280 .iter()
281 .map(|item| item.to_xml())
282 .collect::<Vec<_>>()
283 .join("\n\n")
284 }
285}
286
287#[async_trait::async_trait]
289pub trait ContextProvider: Send + Sync {
290 fn name(&self) -> &str;
292
293 async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult>;
295
296 async fn on_turn_complete(
301 &self,
302 _session_id: &str,
303 _prompt: &str,
304 _response: &str,
305 ) -> anyhow::Result<()> {
306 Ok(())
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// Serializes `value` to JSON and parses it back, panicking on any serde error.
    fn json_round_trip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    // ----- ContextType -----

    #[test]
    fn test_context_type_default() {
        assert_eq!(ContextType::default(), ContextType::Resource);
    }

    #[test]
    fn test_context_type_serialization() {
        let encoded = serde_json::to_string(&ContextType::Memory).unwrap();
        assert_eq!(encoded, "\"Memory\"");

        let decoded: ContextType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, ContextType::Memory);
    }

    #[test]
    fn test_context_type_all_variants() {
        for variant in [ContextType::Memory, ContextType::Resource, ContextType::Skill] {
            assert_eq!(json_round_trip(&variant), variant);
        }
    }

    // ----- ContextDepth -----

    #[test]
    fn test_context_depth_default() {
        assert_eq!(ContextDepth::default(), ContextDepth::Overview);
    }

    #[test]
    fn test_context_depth_serialization() {
        let encoded = serde_json::to_string(&ContextDepth::Full).unwrap();
        assert_eq!(encoded, "\"Full\"");

        let decoded: ContextDepth = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, ContextDepth::Full);
    }

    #[test]
    fn test_context_depth_all_variants() {
        for depth in [
            ContextDepth::Abstract,
            ContextDepth::Overview,
            ContextDepth::Full,
        ] {
            assert_eq!(json_round_trip(&depth), depth);
        }
    }

    // ----- ContextQuery -----

    #[test]
    fn test_context_query_new() {
        let q = ContextQuery::new("test query");
        assert_eq!(q.query, "test query");
        assert_eq!(q.context_types, vec![ContextType::Resource]);
        assert_eq!(q.depth, ContextDepth::Overview);
        assert_eq!(q.max_results, 10);
        assert_eq!(q.max_tokens, 4000);
        assert_eq!(q.session_id, None);
        assert!(q.params.is_empty());
    }

    #[test]
    fn test_context_query_builder() {
        let q = ContextQuery::new("test")
            .with_types([ContextType::Memory, ContextType::Skill])
            .with_depth(ContextDepth::Full)
            .with_max_results(5)
            .with_max_tokens(2000)
            .with_session_id("sess-123")
            .with_param("custom", serde_json::json!("value"));

        assert_eq!(q.context_types.len(), 2);
        for wanted in [ContextType::Memory, ContextType::Skill] {
            assert!(q.context_types.contains(&wanted));
        }
        assert_eq!(q.depth, ContextDepth::Full);
        assert_eq!(q.max_results, 5);
        assert_eq!(q.max_tokens, 2000);
        assert_eq!(q.session_id.as_deref(), Some("sess-123"));
        assert_eq!(q.params.get("custom"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_context_query_serialization() {
        let original = ContextQuery::new("search term")
            .with_types([ContextType::Resource])
            .with_session_id("sess-456");

        let restored = json_round_trip(&original);
        assert_eq!(restored.query, "search term");
        assert_eq!(restored.session_id.as_deref(), Some("sess-456"));
    }

    #[test]
    fn test_context_query_deserialization_with_defaults() {
        let q: ContextQuery = serde_json::from_str(r#"{"query": "minimal query"}"#).unwrap();

        assert_eq!(q.query, "minimal query");
        // Unlike `ContextQuery::new`, deserialization leaves the type list empty.
        assert!(q.context_types.is_empty());
        assert_eq!(q.depth, ContextDepth::Overview);
        assert_eq!(q.max_results, 10);
        assert_eq!(q.max_tokens, 4000);
    }

    // ----- ContextItem -----

    #[test]
    fn test_context_item_new() {
        let ContextItem {
            id,
            context_type,
            content,
            token_count,
            relevance,
            source,
            metadata,
        } = ContextItem::new("item-1", ContextType::Resource, "Some content");

        assert_eq!(id, "item-1");
        assert_eq!(context_type, ContextType::Resource);
        assert_eq!(content, "Some content");
        assert_eq!(token_count, 0);
        assert_eq!(relevance, 0.0);
        assert!(source.is_none());
        assert!(metadata.is_empty());
    }

    #[test]
    fn test_context_item_builder() {
        let item = ContextItem::new("item-2", ContextType::Memory, "Memory content")
            .with_token_count(150)
            .with_relevance(0.85)
            .with_source("viking://memory/session-123")
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(item.token_count, 150);
        assert!((item.relevance - 0.85).abs() < f32::EPSILON);
        assert_eq!(item.source.as_deref(), Some("viking://memory/session-123"));
        assert_eq!(item.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_context_item_relevance_clamping() {
        let relevance_for = |score: f32| {
            ContextItem::new("id", ContextType::Resource, "")
                .with_relevance(score)
                .relevance
        };

        assert!((relevance_for(1.5) - 1.0).abs() < f32::EPSILON);
        assert!(relevance_for(-0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_context_item_to_xml_without_source() {
        let xml = ContextItem::new("id", ContextType::Resource, "Content here").to_xml();
        assert_eq!(xml, "<context type=\"Resource\">\nContent here\n</context>");
    }

    #[test]
    fn test_context_item_to_xml_with_source() {
        let xml = ContextItem::new("id", ContextType::Memory, "Memory content")
            .with_source("viking://docs/auth")
            .to_xml();
        assert_eq!(
            xml,
            "<context source=\"viking://docs/auth\" type=\"Memory\">\nMemory content\n</context>"
        );
    }

    #[test]
    fn test_context_item_to_xml_all_types() {
        let cases = [
            ("m", ContextType::Memory, "type=\"Memory\""),
            ("r", ContextType::Resource, "type=\"Resource\""),
            ("s", ContextType::Skill, "type=\"Skill\""),
        ];
        for (text, kind, expected_attr) in cases {
            let xml = ContextItem::new(text, kind, text).to_xml();
            assert!(xml.contains(expected_attr), "missing {expected_attr} in {xml}");
        }
    }

    #[test]
    fn test_context_item_serialization() {
        let original = ContextItem::new("item-3", ContextType::Skill, "Skill instructions")
            .with_token_count(200)
            .with_relevance(0.9)
            .with_source("viking://skills/code-review");

        let restored = json_round_trip(&original);
        assert_eq!(restored.id, "item-3");
        assert_eq!(restored.context_type, ContextType::Skill);
        assert_eq!(restored.content, "Skill instructions");
        assert_eq!(restored.token_count, 200);
    }

    // ----- ContextResult -----

    #[test]
    fn test_context_result_new() {
        let result = ContextResult::new("test-provider");
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert_eq!(result.provider, "test-provider");
        assert!(!result.truncated);
    }

    #[test]
    fn test_context_result_add_item() {
        let mut result = ContextResult::new("provider");
        result.add_item(
            ContextItem::new("id", ContextType::Resource, "content").with_token_count(100),
        );

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.total_tokens, 100);
    }

    #[test]
    fn test_context_result_add_multiple_items() {
        let specs = [
            ("1", ContextType::Resource, "a", 50),
            ("2", ContextType::Memory, "b", 75),
            ("3", ContextType::Skill, "c", 25),
        ];
        let mut result = ContextResult::new("provider");
        for (id, kind, body, tokens) in specs {
            result.add_item(ContextItem::new(id, kind, body).with_token_count(tokens));
        }

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.total_tokens, 150);
    }

    #[test]
    fn test_context_result_is_empty() {
        assert!(ContextResult::new("provider").is_empty());

        let mut populated = ContextResult::new("provider");
        populated.add_item(ContextItem::new("id", ContextType::Resource, "content"));
        assert!(!populated.is_empty());
    }

    #[test]
    fn test_context_result_to_xml() {
        let mut result = ContextResult::new("provider");
        result.add_item(
            ContextItem::new("1", ContextType::Resource, "First content").with_source("source://1"),
        );
        result.add_item(ContextItem::new("2", ContextType::Memory, "Second content"));

        let xml = result.to_xml();
        for fragment in [
            "<context source=\"source://1\" type=\"Resource\">",
            "First content",
            "<context type=\"Memory\">",
            "Second content",
        ] {
            assert!(xml.contains(fragment), "missing {fragment:?} in {xml}");
        }
    }

    #[test]
    fn test_context_result_to_xml_empty() {
        assert!(ContextResult::new("provider").to_xml().is_empty());
    }

    #[test]
    fn test_context_result_serialization() {
        let mut original = ContextResult::new("test-provider");
        original.truncated = true;
        original.add_item(ContextItem::new("id", ContextType::Resource, "content"));

        let restored = json_round_trip(&original);
        assert_eq!(restored.provider, "test-provider");
        assert!(restored.truncated);
        assert_eq!(restored.items.len(), 1);
    }

    #[test]
    fn test_context_result_default() {
        let result = ContextResult::default();
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.provider.is_empty());
        assert!(!result.truncated);
    }

    // ----- ContextProvider -----

    /// Provider stub that answers every query with a fixed list of items.
    struct MockContextProvider {
        name: String,
        items: Vec<ContextItem>,
    }

    impl MockContextProvider {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_owned(),
                items: Vec::new(),
            }
        }

        fn with_items(self, items: Vec<ContextItem>) -> Self {
            Self { items, ..self }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockContextProvider {
        fn name(&self) -> &str {
            &self.name
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            let mut result = ContextResult::new(self.name.as_str());
            self.items
                .iter()
                .cloned()
                .for_each(|item| result.add_item(item));
            Ok(result)
        }
    }

    #[tokio::test]
    async fn test_mock_context_provider() {
        let provider = MockContextProvider::new("mock").with_items(vec![ContextItem::new(
            "1",
            ContextType::Resource,
            "content",
        )]);
        assert_eq!(provider.name(), "mock");

        let result = provider.query(&ContextQuery::new("test")).await.unwrap();
        assert_eq!(result.provider, "mock");
        assert_eq!(result.items.len(), 1);
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_default() {
        let outcome = MockContextProvider::new("mock")
            .on_turn_complete("session-1", "prompt", "response")
            .await;
        assert!(outcome.is_ok());
    }

    /// `(session_id, prompt, response)` triples recorded by `MockMemoryProvider`.
    type TurnLog = Vec<(String, String, String)>;

    /// Provider stub that records every completed turn it is notified about.
    struct MockMemoryProvider {
        memories: Arc<RwLock<TurnLog>>,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                memories: Arc::new(RwLock::new(TurnLog::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockMemoryProvider {
        fn name(&self) -> &str {
            "memory-provider"
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            Ok(ContextResult::new("memory-provider"))
        }

        async fn on_turn_complete(
            &self,
            session_id: &str,
            prompt: &str,
            response: &str,
        ) -> anyhow::Result<()> {
            let turn = (
                session_id.to_owned(),
                prompt.to_owned(),
                response.to_owned(),
            );
            self.memories.write().await.push(turn);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_custom() {
        let provider = MockMemoryProvider::new();
        provider
            .on_turn_complete("sess-1", "What is Rust?", "Rust is a systems language.")
            .await
            .unwrap();

        let log = provider.memories.read().await;
        assert_eq!(
            *log,
            vec![(
                "sess-1".to_owned(),
                "What is Rust?".to_owned(),
                "Rust is a systems language.".to_owned(),
            )]
        );
    }

    #[tokio::test]
    async fn test_multiple_providers_query() {
        let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
            "p1-1",
            ContextType::Resource,
            "Resource from P1",
        )]);
        let provider2 = MockContextProvider::new("provider-2").with_items(vec![
            ContextItem::new("p2-1", ContextType::Memory, "Memory from P2"),
            ContextItem::new("p2-2", ContextType::Skill, "Skill from P2"),
        ]);

        let providers: [&dyn ContextProvider; 2] = [&provider1, &provider2];
        let query = ContextQuery::new("test");

        let mut all_items = Vec::new();
        for provider in providers {
            all_items.extend(provider.query(&query).await.unwrap().items);
        }

        assert_eq!(all_items.len(), 3);
        for expected_id in ["p1-1", "p2-1", "p2-2"] {
            assert!(all_items.iter().any(|item| item.id == expected_id));
        }
    }

    #[test]
    fn test_context_result_xml_formatting_complex() {
        let mut result = ContextResult::new("openviking");
        result.add_item(
            ContextItem::new(
                "doc-1",
                ContextType::Resource,
                "Authentication uses JWT tokens stored in httpOnly cookies.",
            )
            .with_source("viking://docs/auth")
            .with_token_count(50),
        );
        result.add_item(
            ContextItem::new(
                "mem-1",
                ContextType::Memory,
                "User prefers TypeScript over JavaScript.",
            )
            .with_token_count(30),
        );

        let xml = result.to_xml();
        assert!(xml.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
        assert!(xml.contains("Authentication uses JWT tokens"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("User prefers TypeScript"));
        // Consecutive items are separated by exactly one blank line.
        assert!(xml.contains("</context>\n\n<context"));
    }
}