1use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
31pub enum ContextType {
32 Memory,
34 #[default]
36 Resource,
37 Skill,
39}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
43pub enum ContextDepth {
44 Abstract,
46 #[default]
48 Overview,
49 Full,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ContextQuery {
56 pub query: String,
58
59 #[serde(default)]
61 pub context_types: Vec<ContextType>,
62
63 #[serde(default)]
65 pub depth: ContextDepth,
66
67 #[serde(default = "default_max_results")]
69 pub max_results: usize,
70
71 #[serde(default = "default_max_tokens")]
73 pub max_tokens: usize,
74
75 #[serde(default)]
77 pub session_id: Option<String>,
78
79 #[serde(default)]
81 pub params: HashMap<String, serde_json::Value>,
82}
83
/// Default for `ContextQuery::max_results`.
///
/// Used both as the serde field default and by `ContextQuery::new`.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
87
/// Default for `ContextQuery::max_tokens`.
///
/// Used both as the serde field default and by `ContextQuery::new`.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 4000;
    DEFAULT_MAX_TOKENS
}
91
92impl ContextQuery {
93 pub fn new(query: impl Into<String>) -> Self {
95 Self {
96 query: query.into(),
97 context_types: vec![ContextType::Resource],
98 depth: ContextDepth::default(),
99 max_results: default_max_results(),
100 max_tokens: default_max_tokens(),
101 session_id: None,
102 params: HashMap::new(),
103 }
104 }
105
106 pub fn with_types(mut self, types: impl IntoIterator<Item = ContextType>) -> Self {
108 self.context_types = types.into_iter().collect();
109 self
110 }
111
112 pub fn with_depth(mut self, depth: ContextDepth) -> Self {
114 self.depth = depth;
115 self
116 }
117
118 pub fn with_max_results(mut self, max: usize) -> Self {
120 self.max_results = max;
121 self
122 }
123
124 pub fn with_max_tokens(mut self, max: usize) -> Self {
126 self.max_tokens = max;
127 self
128 }
129
130 pub fn with_session_id(mut self, id: impl Into<String>) -> Self {
132 self.session_id = Some(id.into());
133 self
134 }
135
136 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
138 self.params.insert(key.into(), value);
139 self
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct ContextItem {
146 pub id: String,
148
149 pub context_type: ContextType,
151
152 pub content: String,
154
155 #[serde(default)]
157 pub token_count: usize,
158
159 #[serde(default)]
161 pub relevance: f32,
162
163 #[serde(default)]
165 pub source: Option<String>,
166
167 #[serde(default)]
169 pub metadata: HashMap<String, serde_json::Value>,
170}
171
172impl ContextItem {
173 pub fn new(
175 id: impl Into<String>,
176 context_type: ContextType,
177 content: impl Into<String>,
178 ) -> Self {
179 Self {
180 id: id.into(),
181 context_type,
182 content: content.into(),
183 token_count: 0,
184 relevance: 0.0,
185 source: None,
186 metadata: HashMap::new(),
187 }
188 }
189
190 pub fn with_token_count(mut self, count: usize) -> Self {
192 self.token_count = count;
193 self
194 }
195
196 pub fn with_relevance(mut self, score: f32) -> Self {
198 self.relevance = score.clamp(0.0, 1.0);
199 self
200 }
201
202 pub fn with_source(mut self, source: impl Into<String>) -> Self {
204 self.source = Some(source.into());
205 self
206 }
207
208 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
210 self.metadata.insert(key.into(), value);
211 self
212 }
213
214 pub fn to_xml(&self) -> String {
216 let source_attr = self
217 .source
218 .as_ref()
219 .map(|s| format!(" source=\"{}\"", s))
220 .unwrap_or_default();
221 let type_str = match self.context_type {
222 ContextType::Memory => "Memory",
223 ContextType::Resource => "Resource",
224 ContextType::Skill => "Skill",
225 };
226 format!(
227 "<context{} type=\"{}\">\n{}\n</context>",
228 source_attr, type_str, self.content
229 )
230 }
231}
232
233#[derive(Debug, Clone, Default, Serialize, Deserialize)]
235pub struct ContextResult {
236 pub items: Vec<ContextItem>,
238
239 pub total_tokens: usize,
241
242 pub provider: String,
244
245 pub truncated: bool,
247}
248
249impl ContextResult {
250 pub fn new(provider: impl Into<String>) -> Self {
252 Self {
253 items: Vec::new(),
254 total_tokens: 0,
255 provider: provider.into(),
256 truncated: false,
257 }
258 }
259
260 pub fn add_item(&mut self, item: ContextItem) {
262 self.total_tokens += item.token_count;
263 self.items.push(item);
264 }
265
266 pub fn is_empty(&self) -> bool {
268 self.items.is_empty()
269 }
270
271 pub fn to_xml(&self) -> String {
273 self.items
274 .iter()
275 .map(|item| item.to_xml())
276 .collect::<Vec<_>>()
277 .join("\n\n")
278 }
279}
280
281#[async_trait::async_trait]
283pub trait ContextProvider: Send + Sync {
284 fn name(&self) -> &str;
286
287 async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult>;
289
290 async fn on_turn_complete(
295 &self,
296 _session_id: &str,
297 _prompt: &str,
298 _response: &str,
299 ) -> anyhow::Result<()> {
300 Ok(())
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // ContextType
    // ---------------------------------------------------------------------

    #[test]
    fn test_context_type_default() {
        let ct: ContextType = Default::default();
        assert_eq!(ct, ContextType::Resource);
    }

    #[test]
    fn test_context_type_serialization() {
        let ct = ContextType::Memory;
        let json = serde_json::to_string(&ct).unwrap();
        assert_eq!(json, "\"Memory\"");

        let parsed: ContextType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextType::Memory);
    }

    #[test]
    fn test_context_type_all_variants() {
        // A fixed array is enough to iterate the variants; no Vec allocation needed.
        let types = [
            ContextType::Memory,
            ContextType::Resource,
            ContextType::Skill,
        ];
        for ct in types {
            let json = serde_json::to_string(&ct).unwrap();
            let parsed: ContextType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, ct);
        }
    }

    // ---------------------------------------------------------------------
    // ContextDepth
    // ---------------------------------------------------------------------

    #[test]
    fn test_context_depth_default() {
        let cd: ContextDepth = Default::default();
        assert_eq!(cd, ContextDepth::Overview);
    }

    #[test]
    fn test_context_depth_serialization() {
        let cd = ContextDepth::Full;
        let json = serde_json::to_string(&cd).unwrap();
        assert_eq!(json, "\"Full\"");

        let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ContextDepth::Full);
    }

    #[test]
    fn test_context_depth_all_variants() {
        let depths = [
            ContextDepth::Abstract,
            ContextDepth::Overview,
            ContextDepth::Full,
        ];
        for cd in depths {
            let json = serde_json::to_string(&cd).unwrap();
            let parsed: ContextDepth = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, cd);
        }
    }

    // ---------------------------------------------------------------------
    // ContextQuery
    // ---------------------------------------------------------------------

    #[test]
    fn test_context_query_new() {
        let query = ContextQuery::new("test query");
        assert_eq!(query.query, "test query");
        assert_eq!(query.context_types, vec![ContextType::Resource]);
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
        assert!(query.session_id.is_none());
        assert!(query.params.is_empty());
    }

    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new("test")
            .with_types([ContextType::Memory, ContextType::Skill])
            .with_depth(ContextDepth::Full)
            .with_max_results(5)
            .with_max_tokens(2000)
            .with_session_id("sess-123")
            .with_param("custom", serde_json::json!("value"));

        assert_eq!(query.context_types.len(), 2);
        assert!(query.context_types.contains(&ContextType::Memory));
        assert!(query.context_types.contains(&ContextType::Skill));
        assert_eq!(query.depth, ContextDepth::Full);
        assert_eq!(query.max_results, 5);
        assert_eq!(query.max_tokens, 2000);
        assert_eq!(query.session_id, Some("sess-123".to_string()));
        assert_eq!(
            query.params.get("custom"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_context_query_serialization() {
        let query = ContextQuery::new("search term")
            .with_types([ContextType::Resource])
            .with_session_id("sess-456");

        let json = serde_json::to_string(&query).unwrap();
        let parsed: ContextQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.query, "search term");
        assert_eq!(parsed.session_id, Some("sess-456".to_string()));
    }

    #[test]
    fn test_context_query_deserialization_with_defaults() {
        let json = r#"{"query": "minimal query"}"#;
        let query: ContextQuery = serde_json::from_str(json).unwrap();

        assert_eq!(query.query, "minimal query");
        // Unlike `ContextQuery::new` (which starts with `[Resource]`), a missing
        // `context_types` field deserializes to an empty list.
        assert!(query.context_types.is_empty());
        assert_eq!(query.depth, ContextDepth::Overview);
        assert_eq!(query.max_results, 10);
        assert_eq!(query.max_tokens, 4000);
    }

    // ---------------------------------------------------------------------
    // ContextItem
    // ---------------------------------------------------------------------

    #[test]
    fn test_context_item_new() {
        let item = ContextItem::new("item-1", ContextType::Resource, "Some content");
        assert_eq!(item.id, "item-1");
        assert_eq!(item.context_type, ContextType::Resource);
        assert_eq!(item.content, "Some content");
        assert_eq!(item.token_count, 0);
        assert_eq!(item.relevance, 0.0);
        assert!(item.source.is_none());
        assert!(item.metadata.is_empty());
    }

    #[test]
    fn test_context_item_builder() {
        let item = ContextItem::new("item-2", ContextType::Memory, "Memory content")
            .with_token_count(150)
            .with_relevance(0.85)
            .with_source("viking://memory/session-123")
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(item.token_count, 150);
        assert!((item.relevance - 0.85).abs() < f32::EPSILON);
        assert_eq!(item.source, Some("viking://memory/session-123".to_string()));
        assert_eq!(item.metadata.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_context_item_relevance_clamping() {
        let item1 = ContextItem::new("id", ContextType::Resource, "").with_relevance(1.5);
        assert!((item1.relevance - 1.0).abs() < f32::EPSILON);

        let item2 = ContextItem::new("id", ContextType::Resource, "").with_relevance(-0.5);
        assert!(item2.relevance.abs() < f32::EPSILON);
    }

    #[test]
    fn test_context_item_to_xml_without_source() {
        let item = ContextItem::new("id", ContextType::Resource, "Content here");
        let xml = item.to_xml();
        assert_eq!(xml, "<context type=\"Resource\">\nContent here\n</context>");
    }

    #[test]
    fn test_context_item_to_xml_with_source() {
        let item = ContextItem::new("id", ContextType::Memory, "Memory content")
            .with_source("viking://docs/auth");
        let xml = item.to_xml();
        assert_eq!(
            xml,
            "<context source=\"viking://docs/auth\" type=\"Memory\">\nMemory content\n</context>"
        );
    }

    #[test]
    fn test_context_item_to_xml_all_types() {
        let memory = ContextItem::new("m", ContextType::Memory, "m").to_xml();
        assert!(memory.contains("type=\"Memory\""));

        let resource = ContextItem::new("r", ContextType::Resource, "r").to_xml();
        assert!(resource.contains("type=\"Resource\""));

        let skill = ContextItem::new("s", ContextType::Skill, "s").to_xml();
        assert!(skill.contains("type=\"Skill\""));
    }

    #[test]
    fn test_context_item_serialization() {
        let item = ContextItem::new("item-3", ContextType::Skill, "Skill instructions")
            .with_token_count(200)
            .with_relevance(0.9)
            .with_source("viking://skills/code-review");

        let json = serde_json::to_string(&item).unwrap();
        let parsed: ContextItem = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, "item-3");
        assert_eq!(parsed.context_type, ContextType::Skill);
        assert_eq!(parsed.content, "Skill instructions");
        assert_eq!(parsed.token_count, 200);
    }

    // ---------------------------------------------------------------------
    // ContextResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_context_result_new() {
        let result = ContextResult::new("test-provider");
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert_eq!(result.provider, "test-provider");
        assert!(!result.truncated);
    }

    #[test]
    fn test_context_result_add_item() {
        let mut result = ContextResult::new("provider");
        let item = ContextItem::new("id", ContextType::Resource, "content").with_token_count(100);
        result.add_item(item);

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.total_tokens, 100);
    }

    #[test]
    fn test_context_result_add_multiple_items() {
        let mut result = ContextResult::new("provider");
        result.add_item(ContextItem::new("1", ContextType::Resource, "a").with_token_count(50));
        result.add_item(ContextItem::new("2", ContextType::Memory, "b").with_token_count(75));
        result.add_item(ContextItem::new("3", ContextType::Skill, "c").with_token_count(25));

        assert_eq!(result.items.len(), 3);
        assert_eq!(result.total_tokens, 150);
    }

    #[test]
    fn test_context_result_is_empty() {
        let empty = ContextResult::new("provider");
        assert!(empty.is_empty());

        let mut non_empty = ContextResult::new("provider");
        non_empty.add_item(ContextItem::new("id", ContextType::Resource, "content"));
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_context_result_to_xml() {
        let mut result = ContextResult::new("provider");
        result.add_item(
            ContextItem::new("1", ContextType::Resource, "First content").with_source("source://1"),
        );
        result.add_item(ContextItem::new("2", ContextType::Memory, "Second content"));

        let xml = result.to_xml();
        assert!(xml.contains("<context source=\"source://1\" type=\"Resource\">"));
        assert!(xml.contains("First content"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("Second content"));
    }

    #[test]
    fn test_context_result_to_xml_empty() {
        let result = ContextResult::new("provider");
        let xml = result.to_xml();
        assert!(xml.is_empty());
    }

    #[test]
    fn test_context_result_serialization() {
        let mut result = ContextResult::new("test-provider");
        result.truncated = true;
        result.add_item(ContextItem::new("id", ContextType::Resource, "content"));

        let json = serde_json::to_string(&result).unwrap();
        let parsed: ContextResult = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.provider, "test-provider");
        assert!(parsed.truncated);
        assert_eq!(parsed.items.len(), 1);
    }

    #[test]
    fn test_context_result_default() {
        let result: ContextResult = Default::default();
        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.provider.is_empty());
        assert!(!result.truncated);
    }

    // ---------------------------------------------------------------------
    // ContextProvider
    // ---------------------------------------------------------------------

    /// Provider that returns a fixed list of items for every query and keeps the
    /// trait's default `on_turn_complete`.
    struct MockContextProvider {
        name: String,
        items: Vec<ContextItem>,
    }

    impl MockContextProvider {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                items: Vec::new(),
            }
        }

        fn with_items(mut self, items: Vec<ContextItem>) -> Self {
            self.items = items;
            self
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockContextProvider {
        fn name(&self) -> &str {
            &self.name
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            let mut result = ContextResult::new(&self.name);
            for item in &self.items {
                result.add_item(item.clone());
            }
            Ok(result)
        }
    }

    #[tokio::test]
    async fn test_mock_context_provider() {
        let provider = MockContextProvider::new("mock").with_items(vec![ContextItem::new(
            "1",
            ContextType::Resource,
            "content",
        )]);

        assert_eq!(provider.name(), "mock");

        let query = ContextQuery::new("test");
        let result = provider.query(&query).await.unwrap();

        assert_eq!(result.provider, "mock");
        assert_eq!(result.items.len(), 1);
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_default() {
        let provider = MockContextProvider::new("mock");

        // The default trait implementation is a no-op that must still succeed.
        let result = provider
            .on_turn_complete("session-1", "prompt", "response")
            .await;
        assert!(result.is_ok());
    }

    /// Provider that overrides `on_turn_complete` to record every turn it is given.
    struct MockMemoryProvider {
        memories: std::sync::Arc<tokio::sync::RwLock<Vec<(String, String, String)>>>,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                memories: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ContextProvider for MockMemoryProvider {
        fn name(&self) -> &str {
            "memory-provider"
        }

        async fn query(&self, _query: &ContextQuery) -> anyhow::Result<ContextResult> {
            Ok(ContextResult::new("memory-provider"))
        }

        async fn on_turn_complete(
            &self,
            session_id: &str,
            prompt: &str,
            response: &str,
        ) -> anyhow::Result<()> {
            let mut memories = self.memories.write().await;
            memories.push((
                session_id.to_string(),
                prompt.to_string(),
                response.to_string(),
            ));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_context_provider_on_turn_complete_custom() {
        let provider = MockMemoryProvider::new();

        provider
            .on_turn_complete("sess-1", "What is Rust?", "Rust is a systems language.")
            .await
            .unwrap();

        let memories = provider.memories.read().await;
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0].0, "sess-1");
        assert_eq!(memories[0].1, "What is Rust?");
        assert_eq!(memories[0].2, "Rust is a systems language.");
    }

    #[tokio::test]
    async fn test_multiple_providers_query() {
        let provider1 = MockContextProvider::new("provider-1").with_items(vec![ContextItem::new(
            "p1-1",
            ContextType::Resource,
            "Resource from P1",
        )]);

        let provider2 = MockContextProvider::new("provider-2").with_items(vec![
            ContextItem::new("p2-1", ContextType::Memory, "Memory from P2"),
            ContextItem::new("p2-2", ContextType::Skill, "Skill from P2"),
        ]);

        // Providers are used through trait objects, as a heterogeneous caller would.
        let providers: [&dyn ContextProvider; 2] = [&provider1, &provider2];
        let query = ContextQuery::new("test");

        let mut all_items = Vec::new();
        for provider in providers {
            let result = provider.query(&query).await.unwrap();
            all_items.extend(result.items);
        }

        assert_eq!(all_items.len(), 3);
        assert!(all_items.iter().any(|i| i.id == "p1-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-1"));
        assert!(all_items.iter().any(|i| i.id == "p2-2"));
    }

    #[test]
    fn test_context_result_xml_formatting_complex() {
        let mut result = ContextResult::new("openviking");
        result.add_item(
            ContextItem::new(
                "doc-1",
                ContextType::Resource,
                "Authentication uses JWT tokens stored in httpOnly cookies.",
            )
            .with_source("viking://docs/auth")
            .with_token_count(50),
        );
        result.add_item(
            ContextItem::new(
                "mem-1",
                ContextType::Memory,
                "User prefers TypeScript over JavaScript.",
            )
            .with_token_count(30),
        );

        let xml = result.to_xml();

        assert!(xml.contains("<context source=\"viking://docs/auth\" type=\"Resource\">"));
        assert!(xml.contains("Authentication uses JWT tokens"));
        assert!(xml.contains("<context type=\"Memory\">"));
        assert!(xml.contains("User prefers TypeScript"));

        // Consecutive items are separated by exactly one blank line.
        assert!(xml.contains("</context>\n\n<context"));
    }
}