1use crate::token_budget::{BudgetSection, TokenBudgetConfig, TokenBudgetEnforcer, TokenEstimator};
52use crate::soch_ql::{ComparisonOp, Condition, LogicalOp, SochValue, WhereClause};
53use std::collections::HashMap;
54
/// A parsed `CONTEXT SELECT` query.
///
/// Produced by `ContextQueryParser::parse`; bundles the output name, the
/// session scope, the query-wide options and the declared sections.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextSelectQuery {
    /// Identifier that follows `CONTEXT SELECT` in the query text.
    pub output_name: String,
    /// Session or agent named by the optional `FROM` clause.
    pub session: SessionReference,
    /// Options from the optional `WITH (...)` clause, or the defaults.
    pub options: ContextQueryOptions,
    /// Sections from the `SECTIONS (...)` clause, in declaration order.
    pub sections: Vec<ContextSection>,
}
71
/// What a context query is scoped to, as written in its `FROM` clause.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SessionReference {
    /// `FROM session($var)`; holds the variable name without the `$`.
    Session(String),
    /// `FROM agent($var)`; holds the variable name without the `$`.
    Agent(String),
    /// The query has no `FROM` clause.
    None,
}
82
/// Query-wide options set through the `WITH (...)` clause.
///
/// See the `Default` impl for the values used when an option is omitted.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextQueryOptions {
    /// Token budget for the assembled context (`token_limit = N`).
    pub token_limit: usize,
    /// `include_schema = true|false`.
    /// NOTE(review): presumably controls whether schema info is emitted — confirm in the executor.
    pub include_schema: bool,
    /// Output format (`format = ...`).
    pub format: OutputFormat,
    /// How to react when content does not fit the budget (`truncation = ...`).
    pub truncation: TruncationStrategy,
    /// `include_headers = true|false`.
    /// NOTE(review): presumably toggles per-section headers — confirm in the formatter.
    pub include_headers: bool,
}
97
98impl Default for ContextQueryOptions {
99 fn default() -> Self {
100 Self {
101 token_limit: 4096,
102 include_schema: true,
103 format: OutputFormat::Soch,
104 truncation: TruncationStrategy::TailDrop,
105 include_headers: true,
106 }
107 }
108}
109
/// Serialization format requested for the assembled context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum OutputFormat {
    /// The default format (selected by the parser with `format = toon`).
    Soch,
    /// `format = json`.
    Json,
    /// `format = markdown`.
    Markdown,
}
120
/// Policy applied when the content exceeds the token budget.
///
/// NOTE(review): the per-variant behavior below is inferred from the names;
/// the enforcing code is not in this part of the file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TruncationStrategy {
    /// `truncation = tail_drop` (the default); presumably drops content from the end.
    TailDrop,
    /// `truncation = head_drop`; presumably drops content from the start.
    HeadDrop,
    /// `truncation = proportional`; presumably shrinks sections proportionally.
    Proportional,
    /// `truncation = fail`; presumably reports an error instead of truncating.
    Fail,
}
133
/// One named entry of the `SECTIONS (...)` clause.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextSection {
    /// Section identifier as written in the query.
    pub name: String,
    /// Number following `PRIORITY`; the parser uses 0 when it is omitted.
    /// NOTE(review): ordering semantics (lower vs. higher first) are decided elsewhere.
    pub priority: i32,
    /// Where the section's content comes from.
    pub content: SectionContent,
    /// Optional post-processing step; the parser shown here always sets `None`.
    pub transform: Option<SectionTransform>,
}
150
/// Source of a section's content.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SectionContent {
    /// `GET <path>`: fetch the data addressed by a path expression.
    Get { path: PathExpression },

    /// `LAST <n> FROM <table> [WHERE ...]`.
    Last {
        /// Row count; the parser defaults to 10 when no number is given.
        count: usize,
        table: String,
        where_clause: Option<WhereClause>,
    },

    /// `SEARCH <collection> BY SIMILARITY(<query>) TOP <k>`.
    Search {
        collection: String,
        query: SimilarityQuery,
        /// Result cap; the parser defaults to 5 when no number is given.
        top_k: usize,
        /// Minimum score filter; the parser shown here always sets `None`.
        min_score: Option<f32>,
    },

    /// `SELECT <columns> FROM <table> [WHERE ...] [LIMIT <n>]`.
    Select {
        /// Column names, or a single `"*"` entry.
        columns: Vec<String>,
        table: String,
        where_clause: Option<WhereClause>,
        limit: Option<usize>,
    },

    /// A quoted string used verbatim.
    Literal { value: String },

    /// A `$name` reference; `name` excludes the `$`.
    Variable { name: String },

    /// Tool registry listing.
    /// NOTE(review): not produced by the parser in view; field meanings are
    /// inferred from their names — confirm against the executor.
    ToolRegistry {
        /// Presumably tool names to include.
        include: Vec<String>,
        /// Presumably tool names to exclude.
        exclude: Vec<String>,
        /// Presumably whether tool schemas are emitted.
        include_schema: bool,
    },

    /// Recent tool invocations.
    /// NOTE(review): not produced by the parser in view; field meanings are
    /// inferred from their names — confirm against the executor.
    ToolCalls {
        /// Presumably the maximum number of calls to include.
        count: usize,
        /// Presumably restricts to one tool name.
        tool_filter: Option<String>,
        /// Presumably restricts to one call status.
        status_filter: Option<String>,
        /// Presumably whether call outputs are included.
        include_outputs: bool,
    },
}
213
/// A dotted path with an optional field projection, e.g.
/// `user.profile.{name, email}` or `user.profile.**`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PathExpression {
    /// Dot-separated components of the path.
    pub segments: Vec<String>,
    /// Field names from a `{a, b}` projection; empty when there is none.
    pub fields: Vec<String>,
    /// `true` when no `{...}` projection was given (plain path or `.**`).
    pub all_fields: bool,
}
224
225impl PathExpression {
226 pub fn parse(input: &str) -> Result<Self, ContextParseError> {
229 let input = input.trim();
230
231 if let Some(brace_start) = input.find('{') {
233 if !input.ends_with('}') {
234 return Err(ContextParseError::InvalidPath(
235 "unclosed field projection".to_string(),
236 ));
237 }
238
239 let path_part = &input[..brace_start].trim_end_matches('.');
240 let fields_part = &input[brace_start + 1..input.len() - 1];
241
242 let segments: Vec<String> = path_part.split('.').map(|s| s.to_string()).collect();
243 let fields: Vec<String> = fields_part
244 .split(',')
245 .map(|s| s.trim().to_string())
246 .filter(|s| !s.is_empty())
247 .collect();
248
249 Ok(PathExpression {
250 segments,
251 fields,
252 all_fields: false,
253 })
254 } else if let Some(path_part) = input.strip_suffix(".**") {
255 let segments: Vec<String> = path_part.split('.').map(|s| s.to_string()).collect();
257
258 Ok(PathExpression {
259 segments,
260 fields: vec![],
261 all_fields: true,
262 })
263 } else {
264 let segments: Vec<String> = input.split('.').map(|s| s.to_string()).collect();
266
267 Ok(PathExpression {
268 segments,
269 fields: vec![],
270 all_fields: true,
271 })
272 }
273 }
274
275 pub fn to_path_string(&self) -> String {
277 let base = self.segments.join(".");
278 if self.all_fields {
279 format!("{}.**", base)
280 } else if !self.fields.is_empty() {
281 format!("{}.{{{}}}", base, self.fields.join(", "))
282 } else {
283 base
284 }
285 }
286}
287
/// The query operand of a `SEARCH ... BY SIMILARITY(...)` section.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SimilarityQuery {
    /// A `$name` reference; holds the name without the `$`.
    Variable(String),
    /// An explicit embedding vector (not produced by the parser in view).
    Embedding(Vec<f32>),
    /// A quoted text query.
    Text(String),
}
298
/// Optional post-processing applied to a section's content.
///
/// NOTE(review): none of these variants is constructed or interpreted in the
/// visible part of the file; the descriptions below are inferred from names.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SectionTransform {
    /// Presumably condenses the content to at most `max_tokens` tokens.
    Summarize { max_tokens: usize },
    /// Presumably keeps only the listed fields.
    Project { fields: Vec<String> },
    /// Presumably renders the content through a template string.
    Template { template: String },
    /// Presumably invokes a named custom function.
    Custom { function: String },
}
311
/// A stored, versioned context query.
///
/// `ContextRecipeStore` keys recipes by the `id`/`version` pair.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextRecipe {
    /// Recipe identifier; shared by all versions of the same recipe.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Version label; the store treats it as an opaque string.
    pub version: String,
    /// The query this recipe runs.
    pub query: ContextSelectQuery,
    /// Bookkeeping data (author, timestamps, tags, usage).
    pub metadata: RecipeMetadata,
    /// Optional binding used by the store's `find_by_session` / `find_by_agent`.
    pub session_binding: Option<SessionBinding>,
}
340
/// Descriptive bookkeeping attached to a [`ContextRecipe`].
///
/// All fields default to empty / zero via the derived `Default`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RecipeMetadata {
    /// Author of the recipe, if recorded.
    pub author: Option<String>,
    /// Creation timestamp as a string.
    /// NOTE(review): the format is not established in this file.
    pub created_at: Option<String>,
    /// Last-update timestamp as a string (format likewise unspecified here).
    pub updated_at: Option<String>,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Usage counter.
    /// NOTE(review): nothing in view increments it — confirm who maintains it.
    pub usage_count: u64,
    /// Presumably the average token count observed when running the recipe.
    pub avg_tokens: Option<f32>,
}
357
/// What a recipe is bound to, used when looking up recipes in the store.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SessionBinding {
    /// Matches exactly one session id.
    Session(String),
    /// Matches exactly one agent id.
    Agent(String),
    /// Matches session ids against a `*` glob pattern (see `glob_match`).
    Pattern(String),
    /// Explicitly unbound; never matched by the store's lookups.
    None,
}
370
/// In-memory store of versioned [`ContextRecipe`]s behind `RwLock`s.
pub struct ContextRecipeStore {
    /// Recipes keyed by `"{id}:{version}"`.
    recipes: std::sync::RwLock<HashMap<String, ContextRecipe>>,
    /// For each recipe id, its version labels in the order they were saved.
    versions: std::sync::RwLock<HashMap<String, Vec<String>>>,
}
378
379impl ContextRecipeStore {
380 pub fn new() -> Self {
382 Self {
383 recipes: std::sync::RwLock::new(HashMap::new()),
384 versions: std::sync::RwLock::new(HashMap::new()),
385 }
386 }
387
388 pub fn save(&self, recipe: ContextRecipe) -> Result<(), String> {
390 let mut recipes = self.recipes.write().map_err(|e| e.to_string())?;
391 let mut versions = self.versions.write().map_err(|e| e.to_string())?;
392
393 let key = format!("{}:{}", recipe.id, recipe.version);
394 recipes.insert(key.clone(), recipe.clone());
395
396 versions
397 .entry(recipe.id.clone())
398 .or_default()
399 .push(recipe.version.clone());
400
401 Ok(())
402 }
403
404 pub fn get_latest(&self, recipe_id: &str) -> Option<ContextRecipe> {
406 let versions = self.versions.read().ok()?;
407 let latest_version = versions.get(recipe_id)?.last()?;
408
409 let recipes = self.recipes.read().ok()?;
410 let key = format!("{}:{}", recipe_id, latest_version);
411 recipes.get(&key).cloned()
412 }
413
414 pub fn get_version(&self, recipe_id: &str, version: &str) -> Option<ContextRecipe> {
416 let recipes = self.recipes.read().ok()?;
417 let key = format!("{}:{}", recipe_id, version);
418 recipes.get(&key).cloned()
419 }
420
421 pub fn list_versions(&self, recipe_id: &str) -> Vec<String> {
423 self.versions
424 .read()
425 .ok()
426 .and_then(|v| v.get(recipe_id).cloned())
427 .unwrap_or_default()
428 }
429
430 pub fn find_by_session(&self, session_id: &str) -> Vec<ContextRecipe> {
432 let recipes = match self.recipes.read() {
433 Ok(r) => r,
434 Err(_) => return Vec::new(),
435 };
436
437 recipes
438 .values()
439 .filter(|r| match &r.session_binding {
440 Some(SessionBinding::Session(sid)) => sid == session_id,
441 Some(SessionBinding::Pattern(pattern)) => {
442 glob_match(pattern, session_id)
443 }
444 _ => false,
445 })
446 .cloned()
447 .collect()
448 }
449
450 pub fn find_by_agent(&self, agent_id: &str) -> Vec<ContextRecipe> {
452 let recipes = match self.recipes.read() {
453 Ok(r) => r,
454 Err(_) => return Vec::new(),
455 };
456
457 recipes
458 .values()
459 .filter(|r| matches!(&r.session_binding, Some(SessionBinding::Agent(aid)) if aid == agent_id))
460 .cloned()
461 .collect()
462 }
463}
464
465impl Default for ContextRecipeStore {
466 fn default() -> Self {
467 Self::new()
468 }
469}
470
/// Matches `input` against a glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters; every other character is literal.
///
/// A pattern without `*` must equal `input` exactly. Any number of `*`
/// wildcards is supported. The literal prefix and suffix around the wildcards
/// may not overlap in `input`, so `ab*ba` matches `abba` but not `aba`.
fn glob_match(pattern: &str, input: &str) -> bool {
    let first = match pattern.find('*') {
        Some(at) => at,
        None => return pattern == input,
    };
    // `find` succeeded, so `rfind` does too; the fallback only keeps this total.
    let last = pattern.rfind('*').unwrap_or(first);

    let prefix = &pattern[..first];
    let suffix = &pattern[last + 1..];

    // Consume the literal prefix; whatever follows has to satisfy the rest.
    let mut remaining = match input.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };

    // Literal pieces between wildcards must occur in order. Taking the
    // leftmost occurrence each time leaves the most room for later pieces.
    if first < last {
        for piece in pattern[first + 1..last].split('*').filter(|p| !p.is_empty()) {
            match remaining.find(piece) {
                Some(at) => remaining = &remaining[at + piece.len()..],
                None => return false,
            }
        }
    }

    // The suffix is checked against the unconsumed tail only, which is what
    // prevents it from overlapping the prefix or a middle piece.
    remaining.ends_with(suffix)
}
485
/// One hit returned by a [`VectorIndex`] search.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// Caller-supplied id of the stored item.
    pub id: String,
    /// Similarity score; the implementations in this file report cosine
    /// similarity (the HNSW index computes it as `1.0 - distance`).
    pub score: f32,
    /// Text content stored with the vector.
    pub content: String,
    /// Metadata stored with the vector.
    pub metadata: HashMap<String, SochValue>,
}
502
/// Similarity search over named collections of vectors.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait VectorIndex: Send + Sync {
    /// Returns up to `k` items of `collection` most similar to `embedding`.
    ///
    /// When `min_score` is given, hits scoring below it are left out.
    /// Errors are reported as strings (e.g. unknown collection).
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String>;

    /// Same as [`search_by_embedding`](Self::search_by_embedding) but starts
    /// from raw text. Both implementations in this file return an error
    /// because they have no embedding model.
    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String>;

    /// Returns size/dimension/metric information, or `None` if the
    /// collection does not exist.
    fn stats(&self, collection: &str) -> Option<VectorIndexStats>;
}
531
/// Summary information about one vector collection.
#[derive(Debug, Clone)]
pub struct VectorIndexStats {
    /// Number of stored vectors.
    pub vector_count: usize,
    /// Dimension every vector in the collection must have.
    pub dimension: usize,
    /// Name of the similarity metric (`"cosine"` for the indexes in this file).
    pub metric: String,
}
542
/// In-memory vector index that scores every stored vector on each search
/// (linear scan with cosine similarity).
pub struct SimpleVectorIndex {
    /// Collections by name, guarded by a read/write lock.
    collections: std::sync::RwLock<HashMap<String, VectorCollection>>,
}
551
/// Storage for one collection of a [`SimpleVectorIndex`].
struct VectorCollection {
    /// Stored entries as `(id, vector, content, metadata)` tuples.
    #[allow(clippy::type_complexity)]
    vectors: Vec<(String, Vec<f32>, String, HashMap<String, SochValue>)>,
    /// Required length of every vector in `vectors`.
    dimension: usize,
}
560
561impl SimpleVectorIndex {
562 pub fn new() -> Self {
564 Self {
565 collections: std::sync::RwLock::new(HashMap::new()),
566 }
567 }
568
569 pub fn create_collection(&self, name: &str, dimension: usize) {
571 let mut collections = self.collections.write().unwrap();
572 collections
573 .entry(name.to_string())
574 .or_insert_with(|| VectorCollection {
575 vectors: Vec::new(),
576 dimension,
577 });
578 }
579
580 pub fn insert(
582 &self,
583 collection: &str,
584 id: String,
585 vector: Vec<f32>,
586 content: String,
587 metadata: HashMap<String, SochValue>,
588 ) -> Result<(), String> {
589 let mut collections = self.collections.write().unwrap();
590 let coll = collections
591 .get_mut(collection)
592 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
593
594 if vector.len() != coll.dimension {
595 return Err(format!(
596 "Vector dimension mismatch: expected {}, got {}",
597 coll.dimension,
598 vector.len()
599 ));
600 }
601
602 coll.vectors.push((id, vector, content, metadata));
603 Ok(())
604 }
605
606 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
608 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
609 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
610 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
611 if norm_a == 0.0 || norm_b == 0.0 {
612 0.0
613 } else {
614 dot / (norm_a * norm_b)
615 }
616 }
617}
618
619impl Default for SimpleVectorIndex {
620 fn default() -> Self {
621 Self::new()
622 }
623}
624
625impl VectorIndex for SimpleVectorIndex {
626 fn search_by_embedding(
627 &self,
628 collection: &str,
629 embedding: &[f32],
630 k: usize,
631 min_score: Option<f32>,
632 ) -> Result<Vec<VectorSearchResult>, String> {
633 let collections = self.collections.read().unwrap();
634 let coll = collections
635 .get(collection)
636 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
637
638 let mut scored: Vec<_> = coll
640 .vectors
641 .iter()
642 .map(|(id, vec, content, meta)| {
643 let score = Self::cosine_similarity(embedding, vec);
644 (id, score, content, meta)
645 })
646 .filter(|(_, score, _, _)| min_score.map(|min| *score >= min).unwrap_or(true))
647 .collect();
648
649 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
651
652 Ok(scored
654 .into_iter()
655 .take(k)
656 .map(|(id, score, content, meta)| VectorSearchResult {
657 id: id.clone(),
658 score,
659 content: content.clone(),
660 metadata: meta.clone(),
661 })
662 .collect())
663 }
664
665 fn search_by_text(
666 &self,
667 _collection: &str,
668 _text: &str,
669 _k: usize,
670 _min_score: Option<f32>,
671 ) -> Result<Vec<VectorSearchResult>, String> {
672 Err(
674 "Text-based search requires an embedding model. Use search_by_embedding instead."
675 .to_string(),
676 )
677 }
678
679 fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
680 let collections = self.collections.read().unwrap();
681 collections.get(collection).map(|coll| VectorIndexStats {
682 vector_count: coll.vectors.len(),
683 dimension: coll.dimension,
684 metric: "cosine".to_string(),
685 })
686 }
687}
688
/// Vector index whose collections are backed by
/// `sochdb_index::vector::VectorIndex`.
///
/// NOTE(review): the type name suggests the backing index is an HNSW graph;
/// that is a property of `sochdb_index`, not visible here.
pub struct HnswVectorIndex {
    /// Collections by name, guarded by a read/write lock.
    collections: std::sync::RwLock<HashMap<String, HnswCollection>>,
}
701
/// One collection of an [`HnswVectorIndex`].
struct HnswCollection {
    /// Backing index; it addresses vectors by numeric `u128` ids.
    index: sochdb_index::vector::VectorIndex,
    /// Maps each numeric id to the caller's `(id, content, metadata)`.
    #[allow(clippy::type_complexity)]
    metadata: HashMap<u128, (String, String, HashMap<String, SochValue>)>,
    /// Next numeric id to hand to the backing index.
    next_edge_id: u128,
    /// Required length of every inserted vector.
    dimension: usize,
}
714
715impl HnswVectorIndex {
716 pub fn new() -> Self {
718 Self {
719 collections: std::sync::RwLock::new(HashMap::new()),
720 }
721 }
722
723 pub fn create_collection(&self, name: &str, dimension: usize) {
725 let mut collections = self.collections.write().unwrap();
726 collections.entry(name.to_string()).or_insert_with(|| {
727 let index = sochdb_index::vector::VectorIndex::with_dimension(
728 sochdb_index::vector::DistanceMetric::Cosine,
729 dimension,
730 );
731 HnswCollection {
732 index,
733 metadata: HashMap::new(),
734 next_edge_id: 0,
735 dimension,
736 }
737 });
738 }
739
740 pub fn insert(
742 &self,
743 collection: &str,
744 id: String,
745 vector: Vec<f32>,
746 content: String,
747 metadata: HashMap<String, SochValue>,
748 ) -> Result<(), String> {
749 let mut collections = self.collections.write().unwrap();
750 let coll = collections
751 .get_mut(collection)
752 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
753
754 if vector.len() != coll.dimension {
755 return Err(format!(
756 "Vector dimension mismatch: expected {}, got {}",
757 coll.dimension,
758 vector.len()
759 ));
760 }
761
762 let edge_id = coll.next_edge_id;
764 coll.next_edge_id += 1;
765 coll.metadata.insert(edge_id, (id, content, metadata));
766
767 let embedding = ndarray::Array1::from_vec(vector);
769
770 coll.index.add(edge_id, embedding)?;
772
773 Ok(())
774 }
775
776 pub fn vector_count(&self, collection: &str) -> Option<usize> {
778 let collections = self.collections.read().unwrap();
779 collections.get(collection).map(|c| c.metadata.len())
780 }
781}
782
783impl Default for HnswVectorIndex {
784 fn default() -> Self {
785 Self::new()
786 }
787}
788
789impl VectorIndex for HnswVectorIndex {
790 fn search_by_embedding(
791 &self,
792 collection: &str,
793 embedding: &[f32],
794 k: usize,
795 min_score: Option<f32>,
796 ) -> Result<Vec<VectorSearchResult>, String> {
797 let collections = self.collections.read().unwrap();
798 let coll = collections
799 .get(collection)
800 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
801
802 let query = ndarray::Array1::from_vec(embedding.to_vec());
804
805 let results = coll.index.search(&query, k)?;
807
808 let mut search_results = Vec::with_capacity(results.len());
811 for (edge_id, distance) in results {
812 let score = 1.0 - distance;
814
815 if let Some(min) = min_score {
817 if score < min {
818 continue;
819 }
820 }
821
822 if let Some((id, content, meta)) = coll.metadata.get(&edge_id) {
824 search_results.push(VectorSearchResult {
825 id: id.clone(),
826 score,
827 content: content.clone(),
828 metadata: meta.clone(),
829 });
830 }
831 }
832
833 Ok(search_results)
834 }
835
836 fn search_by_text(
837 &self,
838 _collection: &str,
839 _text: &str,
840 _k: usize,
841 _min_score: Option<f32>,
842 ) -> Result<Vec<VectorSearchResult>, String> {
843 Err(
845 "Text-based search requires an embedding model. Use search_by_embedding instead."
846 .to_string(),
847 )
848 }
849
850 fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
851 let collections = self.collections.read().unwrap();
852 collections.get(collection).map(|coll| VectorIndexStats {
853 vector_count: coll.metadata.len(),
854 dimension: coll.dimension,
855 metric: "cosine".to_string(),
856 })
857 }
858}
859
/// Outcome of executing a context query.
///
/// NOTE(review): the code that fills this struct is not in the visible part
/// of the file; field descriptions are inferred from names and types.
#[derive(Debug, Clone)]
pub struct ContextResult {
    /// The assembled context text.
    pub context: String,
    /// Presumably the token count of `context`.
    pub token_count: usize,
    /// Presumably the token limit the query ran with.
    pub token_budget: usize,
    /// Per-section details for the sections that made it into `context`.
    pub sections_included: Vec<SectionResult>,
    /// Presumably names of sections that were shortened to fit.
    pub sections_truncated: Vec<String>,
    /// Presumably names of sections left out entirely.
    pub sections_dropped: Vec<String>,
}
880
/// Per-section outcome inside a [`ContextResult`].
///
/// NOTE(review): the code that fills this struct is not in the visible part
/// of the file; field descriptions are inferred from names and types.
#[derive(Debug, Clone)]
pub struct SectionResult {
    /// Section name from the query.
    pub name: String,
    /// Section priority from the query.
    pub priority: i32,
    /// Rendered content of the section.
    pub content: String,
    /// Token count of the section.
    /// TODO confirm how this differs from `tokens_used` (e.g. before vs. after truncation).
    pub tokens: usize,
    /// Tokens charged against the budget — see the TODO on `tokens`.
    pub tokens_used: usize,
    /// Whether the section was shortened.
    pub truncated: bool,
    /// Presumably the number of rows/items the section contributed.
    pub row_count: usize,
}
899
/// Errors raised while executing a context query.
///
/// The `Display` impl defines the human-readable message of each variant.
#[derive(Debug, Clone)]
pub enum ContextQueryError {
    /// The session in use differs from the one the query expects.
    SessionMismatch { expected: String, actual: String },
    /// A `$variable` referenced by the query is not defined.
    VariableNotFound(String),
    /// A variable exists but has the wrong type; `expected` names the wanted type.
    InvalidVariableType { variable: String, expected: String },
    /// A single section needs more tokens than are available.
    BudgetExceeded {
        section: String,
        requested: usize,
        available: usize,
    },
    /// The overall token budget ran out.
    BudgetExhausted(String),
    /// Access to a resource was denied.
    PermissionDenied(String),
    /// A path expression could not be resolved.
    InvalidPath(String),
    /// The query text failed to parse.
    Parse(ContextParseError),
    /// Output formatting failed.
    FormatError(String),
    /// The query is well-formed but not valid.
    InvalidQuery(String),
    /// The vector index reported an error.
    VectorSearchError(String),
}
934
935impl std::fmt::Display for ContextQueryError {
936 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
937 match self {
938 Self::SessionMismatch { expected, actual } => {
939 write!(f, "session mismatch: expected {}, got {}", expected, actual)
940 }
941 Self::VariableNotFound(name) => write!(f, "variable not found: {}", name),
942 Self::InvalidVariableType { variable, expected } => {
943 write!(
944 f,
945 "variable {} has invalid type, expected {}",
946 variable, expected
947 )
948 }
949 Self::BudgetExceeded {
950 section,
951 requested,
952 available,
953 } => {
954 write!(
955 f,
956 "section {} exceeds budget: {} > {}",
957 section, requested, available
958 )
959 }
960 Self::BudgetExhausted(msg) => write!(f, "budget exhausted: {}", msg),
961 Self::PermissionDenied(msg) => write!(f, "permission denied: {}", msg),
962 Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
963 Self::Parse(e) => write!(f, "parse error: {}", e),
964 Self::FormatError(e) => write!(f, "format error: {}", e),
965 Self::InvalidQuery(msg) => write!(f, "invalid query: {}", msg),
966 Self::VectorSearchError(e) => write!(f, "vector search error: {}", e),
967 }
968 }
969}
970
// Marker impl: uses the default `Error` methods, so no source error is
// exposed (not even for the `Parse` variant).
impl std::error::Error for ContextQueryError {}
972
/// Errors produced while parsing a context query or a path expression.
#[derive(Debug, Clone)]
pub enum ContextParseError {
    /// A specific token was required but something else was found; `found`
    /// holds the `Debug` rendering of the offending token.
    UnexpectedToken { expected: String, found: String },
    /// A required clause is missing; holds the clause name.
    MissingClause(String),
    /// An unknown option key or an unsupported option value.
    InvalidOption(String),
    /// A malformed path expression.
    InvalidPath(String),
    /// A section body that is none of the supported forms.
    InvalidSection(String),
    /// Any other syntax problem, with a free-form message.
    SyntaxError(String),
}
989
990impl std::fmt::Display for ContextParseError {
991 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
992 match self {
993 Self::UnexpectedToken { expected, found } => {
994 write!(f, "expected {}, found '{}'", expected, found)
995 }
996 Self::MissingClause(clause) => write!(f, "missing {} clause", clause),
997 Self::InvalidOption(opt) => write!(f, "invalid option: {}", opt),
998 Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
999 Self::InvalidSection(sec) => write!(f, "invalid section: {}", sec),
1000 Self::SyntaxError(msg) => write!(f, "syntax error: {}", msg),
1001 }
1002 }
1003}
1004
// Marker impl: uses the default `Error` methods; there is no source error.
impl std::error::Error for ContextParseError {}
1006
/// Recursive-descent parser for `CONTEXT SELECT` queries.
///
/// The whole input is tokenized up front; parsing then walks the token list.
pub struct ContextQueryParser {
    /// Index of the current token in `tokens`.
    pos: usize,
    /// Token stream; the tokenizer always terminates it with `Token::Eof`.
    tokens: Vec<Token>,
}
1014
/// Lexical token of the context query language.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Reserved word; the tokenizer stores it upper-cased.
    Keyword(String),
    /// Any other identifier, spelled as written.
    Ident(String),
    /// Numeric literal (integers and decimals are both kept as `f64`).
    Number(f64),
    /// Quoted string without its quotes.
    String(String),
    /// Single punctuation character.
    Punct(char),
    /// `$name` reference; holds the name without the `$`.
    Variable(String),
    /// End of input.
    Eof,
}
1033
1034impl ContextQueryParser {
1035 pub fn new(input: &str) -> Self {
1037 let tokens = Self::tokenize(input);
1038 Self { pos: 0, tokens }
1039 }
1040
1041 pub fn parse(&mut self) -> Result<ContextSelectQuery, ContextParseError> {
1043 self.expect_keyword("CONTEXT")?;
1045 self.expect_keyword("SELECT")?;
1046 let output_name = self.expect_ident()?;
1047
1048 let session = if self.match_keyword("FROM") {
1050 self.parse_session_reference()?
1051 } else {
1052 SessionReference::None
1053 };
1054
1055 let options = if self.match_keyword("WITH") {
1057 self.parse_options()?
1058 } else {
1059 ContextQueryOptions::default()
1060 };
1061
1062 self.expect_keyword("SECTIONS")?;
1064 let sections = self.parse_sections()?;
1065
1066 Ok(ContextSelectQuery {
1067 output_name,
1068 session,
1069 options,
1070 sections,
1071 })
1072 }
1073
1074 fn parse_session_reference(&mut self) -> Result<SessionReference, ContextParseError> {
1076 if self.match_keyword("session") {
1077 self.expect_punct('(')?;
1078 let var = self.expect_variable()?;
1079 self.expect_punct(')')?;
1080 Ok(SessionReference::Session(var))
1081 } else if self.match_keyword("agent") {
1082 self.expect_punct('(')?;
1083 let var = self.expect_variable()?;
1084 self.expect_punct(')')?;
1085 Ok(SessionReference::Agent(var))
1086 } else {
1087 Err(ContextParseError::SyntaxError(
1088 "expected 'session' or 'agent'".to_string(),
1089 ))
1090 }
1091 }
1092
1093 fn parse_options(&mut self) -> Result<ContextQueryOptions, ContextParseError> {
1095 self.expect_punct('(')?;
1096 let mut options = ContextQueryOptions::default();
1097
1098 loop {
1099 let key = self.expect_ident()?;
1100 self.expect_punct('=')?;
1101
1102 match key.as_str() {
1103 "token_limit" => {
1104 if let Token::Number(n) = self.current().clone() {
1105 options.token_limit = n as usize;
1106 self.advance();
1107 }
1108 }
1109 "include_schema" => {
1110 options.include_schema = self.parse_bool()?;
1111 }
1112 "format" => {
1113 let format = self.expect_ident()?;
1114 options.format = match format.to_lowercase().as_str() {
1115 "toon" => OutputFormat::Soch,
1116 "json" => OutputFormat::Json,
1117 "markdown" => OutputFormat::Markdown,
1118 _ => return Err(ContextParseError::InvalidOption(format)),
1119 };
1120 }
1121 "truncation" => {
1122 let strategy = self.expect_ident()?;
1123 options.truncation = match strategy.to_lowercase().as_str() {
1124 "tail_drop" | "taildrop" => TruncationStrategy::TailDrop,
1125 "head_drop" | "headdrop" => TruncationStrategy::HeadDrop,
1126 "proportional" => TruncationStrategy::Proportional,
1127 "fail" => TruncationStrategy::Fail,
1128 _ => return Err(ContextParseError::InvalidOption(strategy)),
1129 };
1130 }
1131 "include_headers" => {
1132 options.include_headers = self.parse_bool()?;
1133 }
1134 _ => return Err(ContextParseError::InvalidOption(key)),
1135 }
1136
1137 if !self.match_punct(',') {
1138 break;
1139 }
1140 }
1141
1142 self.expect_punct(')')?;
1143 Ok(options)
1144 }
1145
1146 fn parse_sections(&mut self) -> Result<Vec<ContextSection>, ContextParseError> {
1148 self.expect_punct('(')?;
1149 let mut sections = Vec::new();
1150
1151 loop {
1152 if self.check_punct(')') {
1153 break;
1154 }
1155
1156 let section = self.parse_section()?;
1157 sections.push(section);
1158
1159 if !self.match_punct(',') {
1160 break;
1161 }
1162 }
1163
1164 self.expect_punct(')')?;
1165 Ok(sections)
1166 }
1167
1168 fn parse_section(&mut self) -> Result<ContextSection, ContextParseError> {
1170 let name = self.expect_ident()?;
1172
1173 self.expect_keyword("PRIORITY")?;
1174 let priority = if let Token::Number(n) = self.current().clone() {
1175 let val = n as i32;
1176 self.advance();
1177 val
1178 } else {
1179 0
1180 };
1181
1182 self.expect_punct(':')?;
1183
1184 let content = self.parse_section_content()?;
1185
1186 Ok(ContextSection {
1187 name,
1188 priority,
1189 content,
1190 transform: None,
1191 })
1192 }
1193
1194 fn parse_section_content(&mut self) -> Result<SectionContent, ContextParseError> {
1196 if self.match_keyword("GET") {
1197 let path_str = self.collect_until(&[',', ')']);
1199 let path = PathExpression::parse(&path_str)?;
1200 Ok(SectionContent::Get { path })
1201 } else if self.match_keyword("LAST") {
1202 let count = if let Token::Number(n) = self.current().clone() {
1204 let val = n as usize;
1205 self.advance();
1206 val
1207 } else {
1208 10 };
1210
1211 self.expect_keyword("FROM")?;
1212 let table = self.expect_ident()?;
1213
1214 let where_clause = if self.match_keyword("WHERE") {
1215 Some(self.parse_where_clause()?)
1216 } else {
1217 None
1218 };
1219
1220 Ok(SectionContent::Last {
1221 count,
1222 table,
1223 where_clause,
1224 })
1225 } else if self.match_keyword("SEARCH") {
1226 let collection = self.expect_ident()?;
1228 self.expect_keyword("BY")?;
1229 self.expect_keyword("SIMILARITY")?;
1230
1231 self.expect_punct('(')?;
1232 let query = if let Token::Variable(v) = self.current().clone() {
1233 self.advance();
1234 SimilarityQuery::Variable(v)
1235 } else if let Token::String(s) = self.current().clone() {
1236 self.advance();
1237 SimilarityQuery::Text(s)
1238 } else {
1239 return Err(ContextParseError::SyntaxError(
1240 "expected variable or string for similarity query".to_string(),
1241 ));
1242 };
1243 self.expect_punct(')')?;
1244
1245 self.expect_keyword("TOP")?;
1246 let top_k = if let Token::Number(n) = self.current().clone() {
1247 let val = n as usize;
1248 self.advance();
1249 val
1250 } else {
1251 5 };
1253
1254 Ok(SectionContent::Search {
1255 collection,
1256 query,
1257 top_k,
1258 min_score: None,
1259 })
1260 } else if self.match_keyword("SELECT") {
1261 let columns = self.parse_column_list()?;
1263 self.expect_keyword("FROM")?;
1264 let table = self.expect_ident()?;
1265
1266 let where_clause = if self.match_keyword("WHERE") {
1267 Some(self.parse_where_clause()?)
1268 } else {
1269 None
1270 };
1271
1272 let limit = if self.match_keyword("LIMIT") {
1273 if let Token::Number(n) = self.current().clone() {
1274 let val = n as usize;
1275 self.advance();
1276 Some(val)
1277 } else {
1278 None
1279 }
1280 } else {
1281 None
1282 };
1283
1284 Ok(SectionContent::Select {
1285 columns,
1286 table,
1287 where_clause,
1288 limit,
1289 })
1290 } else if let Token::Variable(v) = self.current().clone() {
1291 self.advance();
1292 Ok(SectionContent::Variable { name: v })
1293 } else if let Token::String(s) = self.current().clone() {
1294 self.advance();
1295 Ok(SectionContent::Literal { value: s })
1296 } else {
1297 Err(ContextParseError::InvalidSection(
1298 "expected GET, LAST, SEARCH, SELECT, or literal".to_string(),
1299 ))
1300 }
1301 }
1302
1303 fn parse_where_clause(&mut self) -> Result<WhereClause, ContextParseError> {
1305 let mut conditions = Vec::new();
1306
1307 loop {
1308 let column = self.expect_ident()?;
1309 let operator = self.parse_comparison_op()?;
1310 let value = self.parse_value()?;
1311
1312 conditions.push(Condition {
1313 column,
1314 operator,
1315 value,
1316 });
1317
1318 if !self.match_keyword("AND") && !self.match_keyword("OR") {
1319 break;
1320 }
1321 }
1322
1323 Ok(WhereClause {
1324 conditions,
1325 operator: LogicalOp::And,
1326 })
1327 }
1328
1329 fn parse_comparison_op(&mut self) -> Result<ComparisonOp, ContextParseError> {
1331 match self.current() {
1332 Token::Punct('=') => {
1333 self.advance();
1334 Ok(ComparisonOp::Eq)
1335 }
1336 Token::Punct('>') => {
1337 self.advance();
1338 if self.check_punct('=') {
1339 self.advance();
1340 Ok(ComparisonOp::Ge)
1341 } else {
1342 Ok(ComparisonOp::Gt)
1343 }
1344 }
1345 Token::Punct('<') => {
1346 self.advance();
1347 if self.check_punct('=') {
1348 self.advance();
1349 Ok(ComparisonOp::Le)
1350 } else {
1351 Ok(ComparisonOp::Lt)
1352 }
1353 }
1354 _ => {
1355 if self.match_keyword("LIKE") {
1356 Ok(ComparisonOp::Like)
1357 } else if self.match_keyword("IN") {
1358 Ok(ComparisonOp::In)
1359 } else {
1360 Err(ContextParseError::SyntaxError(
1361 "expected comparison operator".to_string(),
1362 ))
1363 }
1364 }
1365 }
1366 }
1367
1368 fn parse_value(&mut self) -> Result<SochValue, ContextParseError> {
1370 match self.current().clone() {
1371 Token::Number(n) => {
1372 self.advance();
1373 if n.fract() == 0.0 {
1374 Ok(SochValue::Int(n as i64))
1375 } else {
1376 Ok(SochValue::Float(n))
1377 }
1378 }
1379 Token::String(s) => {
1380 self.advance();
1381 Ok(SochValue::Text(s))
1382 }
1383 Token::Keyword(k) if k.eq_ignore_ascii_case("null") => {
1384 self.advance();
1385 Ok(SochValue::Null)
1386 }
1387 Token::Keyword(k) if k.eq_ignore_ascii_case("true") => {
1388 self.advance();
1389 Ok(SochValue::Bool(true))
1390 }
1391 Token::Keyword(k) if k.eq_ignore_ascii_case("false") => {
1392 self.advance();
1393 Ok(SochValue::Bool(false))
1394 }
1395 Token::Variable(v) => {
1396 self.advance();
1397 Ok(SochValue::Text(format!("${}", v)))
1399 }
1400 _ => Err(ContextParseError::SyntaxError("expected value".to_string())),
1401 }
1402 }
1403
1404 fn parse_column_list(&mut self) -> Result<Vec<String>, ContextParseError> {
1406 let mut columns = Vec::new();
1407
1408 if self.check_punct('*') {
1409 self.advance();
1410 columns.push("*".to_string());
1411 } else {
1412 loop {
1413 columns.push(self.expect_ident()?);
1414 if !self.match_punct(',') {
1415 break;
1416 }
1417 }
1418 }
1419
1420 Ok(columns)
1421 }
1422
1423 fn parse_bool(&mut self) -> Result<bool, ContextParseError> {
1425 match self.current() {
1426 Token::Keyword(k) if k.eq_ignore_ascii_case("true") => {
1427 self.advance();
1428 Ok(true)
1429 }
1430 Token::Keyword(k) if k.eq_ignore_ascii_case("false") => {
1431 self.advance();
1432 Ok(false)
1433 }
1434 _ => Err(ContextParseError::SyntaxError(
1435 "expected boolean".to_string(),
1436 )),
1437 }
1438 }
1439
1440 fn tokenize(input: &str) -> Vec<Token> {
1442 let mut tokens = Vec::new();
1443 let mut chars = input.chars().peekable();
1444
1445 while let Some(&ch) = chars.peek() {
1446 match ch {
1447 ' ' | '\t' | '\n' | '\r' => {
1449 chars.next();
1450 }
1451
1452 '(' | ')' | ',' | ':' | '=' | '<' | '>' | '*' | '{' | '}' | '.' => {
1454 tokens.push(Token::Punct(ch));
1455 chars.next();
1456 }
1457
1458 '$' => {
1460 chars.next();
1461 let mut name = String::new();
1462 while let Some(&c) = chars.peek() {
1463 if c.is_alphanumeric() || c == '_' {
1464 name.push(c);
1465 chars.next();
1466 } else {
1467 break;
1468 }
1469 }
1470 tokens.push(Token::Variable(name));
1471 }
1472
1473 '\'' | '"' => {
1475 let quote = ch;
1476 chars.next();
1477 let mut s = String::new();
1478 while let Some(&c) = chars.peek() {
1479 if c == quote {
1480 chars.next(); break;
1482 }
1483 s.push(c);
1484 chars.next();
1485 }
1486 tokens.push(Token::String(s));
1487 }
1488
1489 '0'..='9' | '-' => {
1491 let mut num_str = String::new();
1492 if ch == '-' {
1493 num_str.push(ch);
1494 chars.next();
1495 }
1496 while let Some(&c) = chars.peek() {
1497 if c.is_ascii_digit() || c == '.' {
1498 num_str.push(c);
1499 chars.next();
1500 } else {
1501 break;
1502 }
1503 }
1504 if let Ok(n) = num_str.parse::<f64>() {
1505 tokens.push(Token::Number(n));
1506 }
1507 }
1508
1509 'a'..='z' | 'A'..='Z' | '_' => {
1511 let mut ident = String::new();
1512 while let Some(&c) = chars.peek() {
1513 if c.is_alphanumeric() || c == '_' {
1514 ident.push(c);
1515 chars.next();
1516 } else {
1517 break;
1518 }
1519 }
1520
1521 let keywords = [
1523 "CONTEXT",
1524 "SELECT",
1525 "FROM",
1526 "WITH",
1527 "SECTIONS",
1528 "PRIORITY",
1529 "GET",
1530 "LAST",
1531 "SEARCH",
1532 "BY",
1533 "SIMILARITY",
1534 "TOP",
1535 "WHERE",
1536 "AND",
1537 "OR",
1538 "LIKE",
1539 "IN",
1540 "LIMIT",
1541 "session",
1542 "agent",
1543 "true",
1544 "false",
1545 "null",
1546 ];
1547
1548 if keywords.iter().any(|k| k.eq_ignore_ascii_case(&ident)) {
1549 tokens.push(Token::Keyword(ident.to_uppercase()));
1550 } else {
1551 tokens.push(Token::Ident(ident));
1552 }
1553 }
1554
1555 _ => {
1557 chars.next();
1558 }
1559 }
1560 }
1561
1562 tokens.push(Token::Eof);
1563 tokens
1564 }
1565
1566 fn current(&self) -> &Token {
1568 self.tokens.get(self.pos).unwrap_or(&Token::Eof)
1569 }
1570
1571 fn advance(&mut self) {
1572 if self.pos < self.tokens.len() {
1573 self.pos += 1;
1574 }
1575 }
1576
1577 fn expect_keyword(&mut self, kw: &str) -> Result<(), ContextParseError> {
1578 match self.current() {
1579 Token::Keyword(k) if k.eq_ignore_ascii_case(kw) => {
1580 self.advance();
1581 Ok(())
1582 }
1583 other => Err(ContextParseError::UnexpectedToken {
1584 expected: kw.to_string(),
1585 found: format!("{:?}", other),
1586 }),
1587 }
1588 }
1589
1590 fn match_keyword(&mut self, kw: &str) -> bool {
1591 match self.current() {
1592 Token::Keyword(k) if k.eq_ignore_ascii_case(kw) => {
1593 self.advance();
1594 true
1595 }
1596 _ => false,
1597 }
1598 }
1599
1600 fn expect_ident(&mut self) -> Result<String, ContextParseError> {
1601 match self.current().clone() {
1602 Token::Ident(s) => {
1603 self.advance();
1604 Ok(s)
1605 }
1606 Token::Keyword(s) => {
1607 self.advance();
1609 Ok(s)
1610 }
1611 other => Err(ContextParseError::UnexpectedToken {
1612 expected: "identifier".to_string(),
1613 found: format!("{:?}", other),
1614 }),
1615 }
1616 }
1617
1618 fn expect_variable(&mut self) -> Result<String, ContextParseError> {
1619 match self.current().clone() {
1620 Token::Variable(v) => {
1621 self.advance();
1622 Ok(v)
1623 }
1624 other => Err(ContextParseError::UnexpectedToken {
1625 expected: "variable ($name)".to_string(),
1626 found: format!("{:?}", other),
1627 }),
1628 }
1629 }
1630
1631 fn expect_punct(&mut self, p: char) -> Result<(), ContextParseError> {
1632 match self.current() {
1633 Token::Punct(c) if *c == p => {
1634 self.advance();
1635 Ok(())
1636 }
1637 other => Err(ContextParseError::UnexpectedToken {
1638 expected: p.to_string(),
1639 found: format!("{:?}", other),
1640 }),
1641 }
1642 }
1643
1644 fn match_punct(&mut self, p: char) -> bool {
1645 match self.current() {
1646 Token::Punct(c) if *c == p => {
1647 self.advance();
1648 true
1649 }
1650 _ => false,
1651 }
1652 }
1653
1654 fn check_punct(&self, p: char) -> bool {
1655 matches!(self.current(), Token::Punct(c) if *c == p)
1656 }
1657
    /// Re-assembles consecutive tokens into a string, stopping (without
    /// consuming) at the first punctuation token listed in `terminators` that
    /// appears outside any `{ ... }` group, or at end of input.
    ///
    /// Only punctuation, identifier and keyword tokens contribute text; every
    /// other token kind is consumed and dropped. The returned string is
    /// trimmed of surrounding whitespace.
    fn collect_until(&mut self, terminators: &[char]) -> String {
        let mut result = String::new();
        // Brace nesting level; terminators are only honoured at level 0.
        // NOTE(review): an unmatched '}' drives this negative, after which no
        // terminator matches and collection runs to Eof — confirm intended.
        let mut depth = 0;

        loop {
            match self.current() {
                Token::Punct('{') => {
                    depth += 1;
                    result.push('{');
                    self.advance();
                }
                // Handled before the terminator arm, so '}' is always copied
                // into the result and can never act as a terminator.
                Token::Punct('}') => {
                    depth -= 1;
                    result.push('}');
                    self.advance();
                }
                Token::Punct(c) if depth == 0 && terminators.contains(c) => {
                    break;
                }
                Token::Punct(c) => {
                    result.push(*c);
                    self.advance();
                }
                Token::Ident(s) | Token::Keyword(s) => {
                    // Separate words with a space, except directly after '.'
                    // or '{' so that dotted paths and field lists stay compact.
                    if !result.is_empty() && !result.ends_with(['.', '{']) {
                        result.push(' ');
                    }
                    result.push_str(s);
                    self.advance();
                }
                Token::Eof => break,
                // Any other token kind is skipped without contributing text.
                _ => {
                    self.advance();
                }
            }
        }

        result.trim().to_string()
    }
1697}
1698
1699use crate::agent_context::{AgentContext, AuditOperation, ContextValue};
1704
/// Executes `ContextSelectQuery` values against a mutably borrowed
/// `AgentContext`, applying token-budget allocation to the produced sections.
pub struct AgentContextIntegration<'a> {
    /// Agent context borrowed for the lifetime of this integration; used for
    /// variables, permission checks, audit entries and budget accounting.
    context: &'a mut AgentContext,
    /// Allocates the token budget across the sections of a query.
    budget_enforcer: TokenBudgetEnforcer,
    /// Estimates token counts and truncates text to a token limit.
    estimator: TokenEstimator,
    /// Optional vector index used by `SEARCH` sections; when absent, search
    /// sections render a placeholder string instead.
    vector_index: Option<std::sync::Arc<dyn VectorIndex>>,
    /// Optional provider used to turn query text into an embedding before
    /// searching the vector index.
    embedding_provider: Option<std::sync::Arc<dyn EmbeddingProvider>>,
}
1726
/// Source of text embeddings that can be shared across threads.
pub trait EmbeddingProvider: Send + Sync {
    /// Produces the embedding vector for `text`, or an error message.
    fn embed_text(&self, text: &str) -> Result<Vec<f32>, String>;

    /// Embeds every entry of `texts` in order.
    ///
    /// The default implementation calls [`embed_text`](Self::embed_text) once
    /// per entry and stops at the first error, which is returned as-is.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, String> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed_text(text)?);
        }
        Ok(embeddings)
    }

    /// Length of the vectors this provider produces.
    fn dimension(&self) -> usize;

    /// Name of the underlying embedding model.
    fn model_name(&self) -> &str;
}
1745
1746impl<'a> AgentContextIntegration<'a> {
1747 pub fn new(context: &'a mut AgentContext) -> Self {
1749 let config = TokenBudgetConfig {
1750 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1751 ..Default::default()
1752 };
1753
1754 Self {
1755 context,
1756 budget_enforcer: TokenBudgetEnforcer::new(config),
1757 estimator: TokenEstimator::default(),
1758 vector_index: None,
1759 embedding_provider: None,
1760 }
1761 }
1762
1763 pub fn with_vector_index(
1765 context: &'a mut AgentContext,
1766 vector_index: std::sync::Arc<dyn VectorIndex>,
1767 ) -> Self {
1768 let config = TokenBudgetConfig {
1769 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1770 ..Default::default()
1771 };
1772
1773 Self {
1774 context,
1775 budget_enforcer: TokenBudgetEnforcer::new(config),
1776 estimator: TokenEstimator::default(),
1777 vector_index: Some(vector_index),
1778 embedding_provider: None,
1779 }
1780 }
1781
1782 pub fn with_vector_and_embedding(
1784 context: &'a mut AgentContext,
1785 vector_index: std::sync::Arc<dyn VectorIndex>,
1786 embedding_provider: std::sync::Arc<dyn EmbeddingProvider>,
1787 ) -> Self {
1788 let config = TokenBudgetConfig {
1789 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1790 ..Default::default()
1791 };
1792
1793 Self {
1794 context,
1795 budget_enforcer: TokenBudgetEnforcer::new(config),
1796 estimator: TokenEstimator::default(),
1797 vector_index: Some(vector_index),
1798 embedding_provider: Some(embedding_provider),
1799 }
1800 }
1801
    /// Installs `provider` as the embedding provider, replacing any previous one.
    pub fn set_embedding_provider(&mut self, provider: std::sync::Arc<dyn EmbeddingProvider>) {
        self.embedding_provider = Some(provider);
    }
1806
    /// Installs `index` as the vector index, replacing any previous one.
    pub fn set_vector_index(&mut self, index: std::sync::Arc<dyn VectorIndex>) {
        self.vector_index = Some(index);
    }
1811
1812 pub fn execute(
1814 &mut self,
1815 query: &ContextSelectQuery,
1816 ) -> Result<ContextQueryResult, ContextQueryError> {
1817 self.validate_session(&query.session)?;
1819
1820 self.context.audit.push(crate::agent_context::AuditEntry {
1822 timestamp: std::time::SystemTime::now(),
1823 operation: AuditOperation::DbQuery,
1824 resource: format!("CONTEXT SELECT {}", query.output_name),
1825 result: crate::agent_context::AuditResult::Success,
1826 metadata: std::collections::HashMap::new(),
1827 });
1828
1829 let resolved_sections = self.resolve_sections(&query.sections)?;
1831
1832 for section in &resolved_sections {
1834 self.check_section_permissions(section)?;
1835 }
1836
1837 let mut section_contents: Vec<(ContextSection, String)> = Vec::new();
1839 for section in &resolved_sections {
1840 let content = self.execute_section_content(section, query.options.token_limit)?;
1841 section_contents.push((section.clone(), content));
1842 }
1843
1844 let budget_sections: Vec<BudgetSection> = section_contents
1846 .iter()
1847 .map(|(section, content)| {
1848 let estimated = self.estimator.estimate_text(content);
1849 let minimum = if query.options.truncation == TruncationStrategy::Fail {
1851 None
1852 } else {
1853 Some(estimated.min(100).max(estimated / 10))
1854 };
1855 BudgetSection {
1856 name: section.name.clone(),
1857 estimated_tokens: estimated,
1858 minimum_tokens: minimum,
1859 priority: section.priority,
1860 required: section.priority == 0, weight: 1.0,
1862 }
1863 })
1864 .collect();
1865
1866 let allocation = self.budget_enforcer.allocate_sections(&budget_sections);
1868
1869 let mut result = ContextQueryResult::new(query.output_name.clone());
1871 result.format = query.options.format;
1872 result.allocation_explain = Some(allocation.explain.clone());
1873
1874 for (section, content) in section_contents.iter() {
1876 if allocation.full_sections.contains(§ion.name) {
1877 let tokens = self.estimator.estimate_text(content);
1878 result.sections.push(SectionResult {
1879 name: section.name.clone(),
1880 priority: section.priority,
1881 content: content.clone(),
1882 tokens,
1883 tokens_used: tokens,
1884 truncated: false,
1885 row_count: 0,
1886 });
1887 }
1888 }
1889
1890 for (section_name, _original, truncated_to) in &allocation.truncated_sections {
1892 if let Some((section, content)) = section_contents
1893 .iter()
1894 .find(|(s, _)| &s.name == section_name)
1895 {
1896 let truncated = self.estimator.truncate_to_tokens(content, *truncated_to);
1898 let actual_tokens = self.estimator.estimate_text(&truncated);
1899 result.sections.push(SectionResult {
1900 name: section.name.clone(),
1901 priority: section.priority,
1902 content: truncated,
1903 tokens: actual_tokens,
1904 tokens_used: actual_tokens,
1905 truncated: true,
1906 row_count: 0,
1907 });
1908 }
1909 }
1910
1911 result.sections.sort_by_key(|s| s.priority);
1913
1914 result.total_tokens = allocation.tokens_allocated;
1915 result.token_limit = query.options.token_limit;
1916
1917 self.context
1919 .consume_budget(result.total_tokens as u64, 0)
1920 .map_err(|e| ContextQueryError::BudgetExhausted(e.to_string()))?;
1921
1922 Ok(result)
1923 }
1924
1925 pub fn execute_explain(
1927 &mut self,
1928 query: &ContextSelectQuery,
1929 ) -> Result<(ContextQueryResult, String), ContextQueryError> {
1930 let result = self.execute(query)?;
1931 let explain = result
1932 .allocation_explain
1933 .as_ref()
1934 .map(|decisions| {
1935 use crate::token_budget::BudgetAllocation;
1936 let allocation = BudgetAllocation {
1937 full_sections: result
1938 .sections
1939 .iter()
1940 .filter(|s| !s.truncated)
1941 .map(|s| s.name.clone())
1942 .collect(),
1943 truncated_sections: result
1944 .sections
1945 .iter()
1946 .filter(|s| s.truncated)
1947 .map(|s| (s.name.clone(), s.tokens, s.tokens_used))
1948 .collect(),
1949 dropped_sections: Vec::new(),
1950 tokens_allocated: result.total_tokens,
1951 tokens_remaining: result.token_limit.saturating_sub(result.total_tokens),
1952 explain: decisions.clone(),
1953 };
1954 allocation.explain_text()
1955 })
1956 .unwrap_or_else(|| "No allocation explain available".to_string());
1957 Ok((result, explain))
1958 }
1959
1960 fn validate_session(&self, session_ref: &SessionReference) -> Result<(), ContextQueryError> {
1962 match session_ref {
1963 SessionReference::Session(sid) => {
1964 if sid.starts_with('$') {
1966 return Ok(());
1967 }
1968 if sid != &self.context.session_id && sid != "*" {
1970 return Err(ContextQueryError::SessionMismatch {
1971 expected: sid.clone(),
1972 actual: self.context.session_id.clone(),
1973 });
1974 }
1975 }
1976 SessionReference::Agent(aid) => {
1977 if let Some(ContextValue::String(agent_id)) = self.context.peek_var("agent_id")
1979 && aid != agent_id
1980 && aid != "*"
1981 {
1982 return Err(ContextQueryError::SessionMismatch {
1983 expected: aid.clone(),
1984 actual: agent_id.clone(),
1985 });
1986 }
1987 }
1988 SessionReference::None => {}
1989 }
1990 Ok(())
1991 }
1992
1993 fn resolve_sections(
1995 &self,
1996 sections: &[ContextSection],
1997 ) -> Result<Vec<ContextSection>, ContextQueryError> {
1998 let mut resolved = Vec::new();
1999
2000 for section in sections {
2001 let mut resolved_section = section.clone();
2002
2003 resolved_section.content = match §ion.content {
2005 SectionContent::Literal { value } => {
2006 let resolved_value = self.resolve_variables(value);
2007 SectionContent::Literal {
2008 value: resolved_value,
2009 }
2010 }
2011 SectionContent::Variable { name } => {
2012 if let Some(value) = self.context.peek_var(name) {
2013 SectionContent::Literal {
2014 value: value.to_string(),
2015 }
2016 } else {
2017 return Err(ContextQueryError::VariableNotFound(name.clone()));
2018 }
2019 }
2020 SectionContent::Search {
2021 collection,
2022 query,
2023 top_k,
2024 min_score,
2025 } => {
2026 let resolved_query = match query {
2027 SimilarityQuery::Variable(var) => {
2028 if let Some(value) = self.context.peek_var(var) {
2029 match value {
2030 ContextValue::String(s) => SimilarityQuery::Text(s.clone()),
2031 ContextValue::List(l) => {
2032 let vec: Vec<f32> = l
2033 .iter()
2034 .filter_map(|v| match v {
2035 ContextValue::Number(n) => Some(*n as f32),
2036 _ => None,
2037 })
2038 .collect();
2039 SimilarityQuery::Embedding(vec)
2040 }
2041 _ => {
2042 return Err(ContextQueryError::InvalidVariableType {
2043 variable: var.clone(),
2044 expected: "string or vector".to_string(),
2045 });
2046 }
2047 }
2048 } else {
2049 return Err(ContextQueryError::VariableNotFound(var.clone()));
2050 }
2051 }
2052 other => other.clone(),
2053 };
2054 SectionContent::Search {
2055 collection: collection.clone(),
2056 query: resolved_query,
2057 top_k: *top_k,
2058 min_score: *min_score,
2059 }
2060 }
2061 other => other.clone(),
2062 };
2063
2064 resolved.push(resolved_section);
2065 }
2066
2067 Ok(resolved)
2068 }
2069
    /// Substitutes context variables in `input` by delegating to the
    /// context's `substitute_vars`.
    fn resolve_variables(&self, input: &str) -> String {
        self.context.substitute_vars(input)
    }
2074
2075 fn check_section_permissions(&self, section: &ContextSection) -> Result<(), ContextQueryError> {
2077 match §ion.content {
2078 SectionContent::Get { path } => {
2079 let path_str = path.to_path_string();
2081 if path_str.starts_with('/') {
2082 self.context
2083 .check_fs_permission(&path_str, AuditOperation::FsRead)
2084 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2085 } else {
2086 let table = path
2088 .segments
2089 .first()
2090 .ok_or_else(|| ContextQueryError::InvalidPath("empty path".to_string()))?;
2091 self.context
2092 .check_db_permission(table, AuditOperation::DbQuery)
2093 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2094 }
2095 }
2096 SectionContent::Last { table, .. } | SectionContent::Select { table, .. } => {
2097 self.context
2098 .check_db_permission(table, AuditOperation::DbQuery)
2099 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2100 }
2101 SectionContent::Search { collection, .. } => {
2102 self.context
2103 .check_db_permission(collection, AuditOperation::DbQuery)
2104 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2105 }
2106 SectionContent::Literal { .. } | SectionContent::Variable { .. } => {
2107 }
2109 SectionContent::ToolRegistry { .. } | SectionContent::ToolCalls { .. } => {
2110 }
2112 }
2113 Ok(())
2114 }
2115
2116 fn execute_section_content(
2118 &self,
2119 section: &ContextSection,
2120 _budget: usize,
2121 ) -> Result<String, ContextQueryError> {
2122 match §ion.content {
2125 SectionContent::Literal { value } => Ok(value.clone()),
2126 SectionContent::Variable { name } => self
2127 .context
2128 .peek_var(name)
2129 .map(|v| v.to_string())
2130 .ok_or_else(|| ContextQueryError::VariableNotFound(name.clone())),
2131 SectionContent::Get { path } => {
2132 Ok(format!(
2134 "[{}: path={}]",
2135 section.name,
2136 path.to_path_string()
2137 ))
2138 }
2139 SectionContent::Last { count, table, .. } => {
2140 Ok(format!("[{}: last {} from {}]", section.name, count, table))
2142 }
2143 SectionContent::Search {
2144 collection,
2145 query: similarity_query,
2146 top_k,
2147 min_score,
2148 } => {
2149 match &self.vector_index {
2151 Some(index) => {
2152 let results = match similarity_query {
2154 SimilarityQuery::Embedding(emb) => {
2155 index.search_by_embedding(collection, emb, *top_k, *min_score)
2156 }
2157 SimilarityQuery::Text(text) => {
2158 self.search_by_text_with_embedding(
2160 index, collection, text, *top_k, *min_score,
2161 )
2162 }
2163 SimilarityQuery::Variable(var_name) => {
2164 match self.context.peek_var(var_name) {
2166 Some(ContextValue::String(text)) => {
2167 self.search_by_text_with_embedding(
2168 index, collection, text, *top_k, *min_score,
2169 )
2170 }
2171 Some(ContextValue::List(list)) => {
2172 let embedding: Result<Vec<f32>, _> = list
2174 .iter()
2175 .map(|v| match v {
2176 ContextValue::Number(n) => Ok(*n as f32),
2177 ContextValue::String(s) => {
2178 s.parse::<f32>().map_err(|_| "not a number")
2179 }
2180 _ => Err("not a number"),
2181 })
2182 .collect();
2183
2184 match embedding {
2185 Ok(emb) => index.search_by_embedding(
2186 collection, &emb, *top_k, *min_score,
2187 ),
2188 Err(_) => {
2189 Err("Variable is not a valid embedding vector"
2190 .to_string())
2191 }
2192 }
2193 }
2194 _ => Err(format!(
2195 "Variable '{}' not found or has wrong type",
2196 var_name
2197 )),
2198 }
2199 }
2200 };
2201
2202 match results {
2203 Ok(search_results) => {
2204 self.format_search_results(§ion.name, &search_results)
2206 }
2207 Err(e) => {
2208 Ok(format!("[{}: search error: {}]", section.name, e))
2210 }
2211 }
2212 }
2213 None => {
2214 Ok(format!(
2216 "[{}: search {} top {}]",
2217 section.name, collection, top_k
2218 ))
2219 }
2220 }
2221 }
2222 SectionContent::Select { table, limit, .. } => {
2223 let limit_str = limit.map(|l| format!(" limit {}", l)).unwrap_or_default();
2225 Ok(format!(
2226 "[{}: select from {}{}]",
2227 section.name, table, limit_str
2228 ))
2229 }
2230 SectionContent::ToolRegistry {
2231 include,
2232 exclude,
2233 include_schema,
2234 } => {
2235 self.format_tool_registry(include, exclude, *include_schema)
2237 }
2238 SectionContent::ToolCalls {
2239 count,
2240 tool_filter,
2241 status_filter,
2242 include_outputs,
2243 } => {
2244 self.format_tool_calls(*count, tool_filter.as_deref(), status_filter.as_deref(), *include_outputs)
2246 }
2247 }
2248 }
2249
2250 fn format_tool_registry(
2252 &self,
2253 include: &[String],
2254 exclude: &[String],
2255 include_schema: bool,
2256 ) -> Result<String, ContextQueryError> {
2257 use std::fmt::Write;
2258
2259 let tools = &self.context.tool_registry;
2261 let mut output = String::new();
2262
2263 writeln!(output, "[tool_registry ({} tools)]", tools.len())
2264 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2265
2266 for tool in tools {
2267 if !include.is_empty() && !include.contains(&tool.name) {
2269 continue;
2270 }
2271 if exclude.contains(&tool.name) {
2272 continue;
2273 }
2274
2275 writeln!(output, " [{}]", tool.name)
2276 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2277 writeln!(output, " description = {:?}", tool.description)
2278 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2279
2280 if include_schema {
2281 if let Some(schema) = &tool.parameters_schema {
2282 writeln!(output, " parameters = {}", schema)
2283 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2284 }
2285 }
2286 }
2287
2288 Ok(output)
2289 }
2290
2291 fn format_tool_calls(
2293 &self,
2294 count: usize,
2295 tool_filter: Option<&str>,
2296 status_filter: Option<&str>,
2297 include_outputs: bool,
2298 ) -> Result<String, ContextQueryError> {
2299 use std::fmt::Write;
2300
2301 let calls = &self.context.tool_calls;
2303 let mut output = String::new();
2304
2305 let filtered: Vec<_> = calls
2307 .iter()
2308 .filter(|call| {
2309 tool_filter.map(|f| call.tool_name == f).unwrap_or(true)
2310 && status_filter
2311 .map(|s| {
2312 match s {
2313 "success" => call.result.is_some() && call.error.is_none(),
2314 "error" => call.error.is_some(),
2315 "pending" => call.result.is_none() && call.error.is_none(),
2316 _ => true,
2317 }
2318 })
2319 .unwrap_or(true)
2320 })
2321 .rev() .take(count)
2323 .collect();
2324
2325 writeln!(output, "[tool_calls ({} calls)]", filtered.len())
2326 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2327
2328 for call in filtered {
2329 writeln!(output, " [call {}]", call.call_id)
2330 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2331 writeln!(output, " tool = {:?}", call.tool_name)
2332 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2333 writeln!(output, " arguments = {:?}", call.arguments)
2334 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2335
2336 if include_outputs {
2337 if let Some(result) = &call.result {
2338 writeln!(output, " result = {:?}", result)
2339 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2340 }
2341 if let Some(error) = &call.error {
2342 writeln!(output, " error = {:?}", error)
2343 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2344 }
2345 }
2346 }
2347
2348 Ok(output)
2349 }
2350
2351 fn search_by_text_with_embedding(
2357 &self,
2358 index: &std::sync::Arc<dyn VectorIndex>,
2359 collection: &str,
2360 text: &str,
2361 k: usize,
2362 min_score: Option<f32>,
2363 ) -> Result<Vec<VectorSearchResult>, String> {
2364 match &self.embedding_provider {
2365 Some(provider) => {
2366 let embedding = provider.embed_text(text)?;
2368 index.search_by_embedding(collection, &embedding, k, min_score)
2370 }
2371 None => {
2372 index.search_by_text(collection, text, k, min_score)
2374 }
2375 }
2376 }
2377
2378 fn format_search_results(
2380 &self,
2381 section_name: &str,
2382 results: &[VectorSearchResult],
2383 ) -> Result<String, ContextQueryError> {
2384 use std::fmt::Write;
2385
2386 let mut output = String::new();
2387 writeln!(output, "[{} ({} results)]", section_name, results.len())
2388 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2389
2390 for (i, result) in results.iter().enumerate() {
2391 writeln!(output, " [result {} score={:.4}]", i + 1, result.score)
2392 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2393 writeln!(output, " id = {}", result.id)
2394 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2395
2396 for line in result.content.lines() {
2398 writeln!(output, " {}", line)
2399 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2400 }
2401
2402 if !result.metadata.is_empty() {
2404 writeln!(output, " [metadata]")
2405 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2406 for (key, value) in &result.metadata {
2407 writeln!(output, " {} = {:?}", key, value)
2408 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2409 }
2410 }
2411 }
2412
2413 Ok(output)
2414 }
2415
2416 #[allow(dead_code)]
2418 fn truncate_content(
2419 &self,
2420 content: &str,
2421 max_tokens: usize,
2422 strategy: TruncationStrategy,
2423 ) -> String {
2424 let max_chars = max_tokens * 4;
2426
2427 if content.len() <= max_chars {
2428 return content.to_string();
2429 }
2430
2431 match strategy {
2432 TruncationStrategy::TailDrop => {
2433 let mut result: String = content.chars().take(max_chars - 3).collect();
2434 result.push_str("...");
2435 result
2436 }
2437 TruncationStrategy::HeadDrop => {
2438 let skip = content.len() - max_chars + 3;
2439 let mut result = "...".to_string();
2440 result.extend(content.chars().skip(skip));
2441 result
2442 }
2443 TruncationStrategy::Proportional => {
2444 let quarter = max_chars / 4;
2446 let first: String = content.chars().take(quarter).collect();
2447 let last: String = content
2448 .chars()
2449 .skip(content.len().saturating_sub(quarter))
2450 .collect();
2451 format!("{}...{}...", first, last)
2452 }
2453 TruncationStrategy::Fail => {
2454 content.to_string() }
2456 }
2457 }
2458
2459 pub fn get_session_context(&self) -> HashMap<String, String> {
2461 self.context
2462 .variables
2463 .iter()
2464 .map(|(k, v)| (k.clone(), v.to_string()))
2465 .collect()
2466 }
2467
    /// Sets the context variable `name` to `value` via the context's `set_var`.
    pub fn set_variable(&mut self, name: &str, value: ContextValue) {
        self.context.set_var(name, value);
    }
2472
2473 pub fn remaining_budget(&self) -> u64 {
2475 self.context
2476 .budget
2477 .max_tokens
2478 .map(|max| max.saturating_sub(self.context.budget.tokens_used))
2479 .unwrap_or(u64::MAX)
2480 }
2481}
2482
/// Outcome of executing a context query: the sections that were kept plus
/// token accounting and rendering settings.
#[derive(Debug, Clone)]
pub struct ContextQueryResult {
    /// Name given to the query output.
    pub output_name: String,
    /// Sections included in the result.
    pub sections: Vec<SectionResult>,
    /// Total number of tokens attributed to this result.
    pub total_tokens: usize,
    /// Token limit the query was run with.
    pub token_limit: usize,
    /// Format used by `render`.
    pub format: OutputFormat,
    /// Budget allocation decisions, when available.
    pub allocation_explain: Option<Vec<crate::token_budget::AllocationDecision>>,
}
2499
2500impl ContextQueryResult {
2501 fn new(output_name: String) -> Self {
2502 Self {
2503 output_name,
2504 sections: Vec::new(),
2505 total_tokens: 0,
2506 token_limit: 0,
2507 format: OutputFormat::Soch,
2508 allocation_explain: None,
2509 }
2510 }
2511
2512 pub fn render(&self) -> String {
2514 let mut output = String::new();
2515
2516 match self.format {
2517 OutputFormat::Soch => {
2518 output.push_str(&format!("{}[{}]:\n", self.output_name, self.sections.len()));
2520 for section in &self.sections {
2521 output.push_str(&format!(
2522 " {}[{}{}]:\n",
2523 section.name,
2524 section.tokens_used,
2525 if section.truncated { "T" } else { "" }
2526 ));
2527 for line in section.content.lines() {
2528 output.push_str(&format!(" {}\n", line));
2529 }
2530 }
2531 }
2532 OutputFormat::Json => {
2533 output.push_str("{\n");
2534 output.push_str(&format!(" \"name\": \"{}\",\n", self.output_name));
2535 output.push_str(&format!(" \"total_tokens\": {},\n", self.total_tokens));
2536 output.push_str(" \"sections\": [\n");
2537 for (i, section) in self.sections.iter().enumerate() {
2538 output.push_str(&format!(" {{\"name\": \"{}\", \"tokens\": {}, \"truncated\": {}, \"content\": \"{}\"}}",
2539 section.name,
2540 section.tokens_used,
2541 section.truncated,
2542 section.content.replace('"', "\\\"").replace('\n', "\\n")
2543 ));
2544 if i < self.sections.len() - 1 {
2545 output.push(',');
2546 }
2547 output.push('\n');
2548 }
2549 output.push_str(" ]\n}");
2550 }
2551 OutputFormat::Markdown => {
2552 output.push_str(&format!("# {}\n\n", self.output_name));
2553 output.push_str(&format!(
2554 "*Tokens: {}/{}*\n\n",
2555 self.total_tokens, self.token_limit
2556 ));
2557 for section in &self.sections {
2558 output.push_str(&format!("## {}", section.name));
2559 if section.truncated {
2560 output.push_str(" *(truncated)*");
2561 }
2562 output.push_str("\n\n");
2563 output.push_str(§ion.content);
2564 output.push_str("\n\n");
2565 }
2566 }
2567 }
2568
2569 output
2570 }
2571
2572 pub fn utilization(&self) -> f64 {
2574 if self.token_limit == 0 {
2575 return 0.0;
2576 }
2577 (self.total_tokens as f64 / self.token_limit as f64) * 100.0
2578 }
2579
2580 pub fn has_truncation(&self) -> bool {
2582 self.sections.iter().any(|s| s.truncated)
2583 }
2584}
2585
/// Newtype around an `i32` section priority. Ordering follows the wrapped
/// integer, so smaller values compare as lower.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SectionPriority(pub i32);
2589
2590impl SectionPriority {
2591 pub const CRITICAL: SectionPriority = SectionPriority(-100);
2592 pub const SYSTEM: SectionPriority = SectionPriority(-1);
2593 pub const USER: SectionPriority = SectionPriority(0);
2594 pub const HISTORY: SectionPriority = SectionPriority(1);
2595 pub const KNOWLEDGE: SectionPriority = SectionPriority(2);
2596 pub const SUPPLEMENTARY: SectionPriority = SectionPriority(10);
2597}
2598
/// Fluent builder that assembles a `ContextSelectQuery` piece by piece.
pub struct ContextQueryBuilder {
    /// Name of the query output.
    output_name: String,
    /// Session or agent the query is bound to.
    session: SessionReference,
    /// Query options (token limit, format, truncation, ...).
    options: ContextQueryOptions,
    /// Sections added so far, in insertion order.
    sections: Vec<ContextSection>,
}
2610
2611impl ContextQueryBuilder {
2612 pub fn new(output_name: &str) -> Self {
2614 Self {
2615 output_name: output_name.to_string(),
2616 session: SessionReference::None,
2617 options: ContextQueryOptions::default(),
2618 sections: Vec::new(),
2619 }
2620 }
2621
2622 pub fn from_session(mut self, session_id: &str) -> Self {
2624 self.session = SessionReference::Session(session_id.to_string());
2625 self
2626 }
2627
2628 pub fn from_agent(mut self, agent_id: &str) -> Self {
2630 self.session = SessionReference::Agent(agent_id.to_string());
2631 self
2632 }
2633
2634 pub fn with_token_limit(mut self, limit: usize) -> Self {
2636 self.options.token_limit = limit;
2637 self
2638 }
2639
2640 pub fn include_schema(mut self, include: bool) -> Self {
2642 self.options.include_schema = include;
2643 self
2644 }
2645
2646 pub fn format(mut self, format: OutputFormat) -> Self {
2648 self.options.format = format;
2649 self
2650 }
2651
2652 pub fn truncation(mut self, strategy: TruncationStrategy) -> Self {
2654 self.options.truncation = strategy;
2655 self
2656 }
2657
2658 pub fn get(mut self, name: &str, priority: i32, path: &str) -> Self {
2660 let path_expr = PathExpression::parse(path).unwrap_or(PathExpression {
2661 segments: vec![path.to_string()],
2662 fields: vec![],
2663 all_fields: true,
2664 });
2665
2666 self.sections.push(ContextSection {
2667 name: name.to_string(),
2668 priority,
2669 content: SectionContent::Get { path: path_expr },
2670 transform: None,
2671 });
2672 self
2673 }
2674
2675 pub fn last(mut self, name: &str, priority: i32, count: usize, table: &str) -> Self {
2677 self.sections.push(ContextSection {
2678 name: name.to_string(),
2679 priority,
2680 content: SectionContent::Last {
2681 count,
2682 table: table.to_string(),
2683 where_clause: None,
2684 },
2685 transform: None,
2686 });
2687 self
2688 }
2689
2690 pub fn search(
2692 mut self,
2693 name: &str,
2694 priority: i32,
2695 collection: &str,
2696 query_var: &str,
2697 top_k: usize,
2698 ) -> Self {
2699 self.sections.push(ContextSection {
2700 name: name.to_string(),
2701 priority,
2702 content: SectionContent::Search {
2703 collection: collection.to_string(),
2704 query: SimilarityQuery::Variable(query_var.to_string()),
2705 top_k,
2706 min_score: None,
2707 },
2708 transform: None,
2709 });
2710 self
2711 }
2712
2713 pub fn literal(mut self, name: &str, priority: i32, value: &str) -> Self {
2715 self.sections.push(ContextSection {
2716 name: name.to_string(),
2717 priority,
2718 content: SectionContent::Literal {
2719 value: value.to_string(),
2720 },
2721 transform: None,
2722 });
2723 self
2724 }
2725
2726 pub fn build(self) -> ContextSelectQuery {
2728 ContextSelectQuery {
2729 output_name: self.output_name,
2730 session: self.session,
2731 options: self.options,
2732 sections: self.sections,
2733 }
2734 }
2735}
2736
2737#[cfg(test)]
2742mod tests {
2743 use super::*;
2744
    /// A dotted path parses into its segments and selects all fields.
    #[test]
    fn test_path_expression_simple() {
        let path = PathExpression::parse("user.profile").unwrap();
        assert_eq!(path.segments, vec!["user", "profile"]);
        assert!(path.all_fields);
    }
2751
    /// A trailing `{a, b}` group is parsed as an explicit field list.
    #[test]
    fn test_path_expression_with_fields() {
        let path = PathExpression::parse("user.profile.{name, email}").unwrap();
        assert_eq!(path.segments, vec!["user", "profile"]);
        assert_eq!(path.fields, vec!["name", "email"]);
        assert!(!path.all_fields);
    }
2759
    /// A trailing `**` is not kept as a segment and selects all fields.
    #[test]
    fn test_path_expression_glob() {
        let path = PathExpression::parse("user.**").unwrap();
        assert_eq!(path.segments, vec!["user"]);
        assert!(path.all_fields);
    }
2766
    /// Parses a query with a session reference, options and one GET section,
    /// and checks each parsed part.
    #[test]
    fn test_parse_simple_query() {
        let query = r#"
 CONTEXT SELECT prompt_context
 FROM session($SESSION_ID)
 WITH (token_limit = 2048, include_schema = true)
 SECTIONS (
 USER PRIORITY 0: GET user.profile.{name, preferences}
 )
 "#;

        let mut parser = ContextQueryParser::new(query);
        let result = parser.parse().unwrap();

        assert_eq!(result.output_name, "prompt_context");
        // The `$` prefix is not part of the stored session name.
        assert!(matches!(result.session, SessionReference::Session(s) if s == "SESSION_ID"));
        assert_eq!(result.options.token_limit, 2048);
        assert!(result.options.include_schema);
        assert_eq!(result.sections.len(), 1);
        assert_eq!(result.sections[0].name, "USER");
        assert_eq!(result.sections[0].priority, 0);
    }
2789
    /// Parses three comma-separated sections (literal, LAST, SEARCH) and
    /// checks the content kind and key values of each.
    #[test]
    fn test_parse_multiple_sections() {
        let query = r#"
 CONTEXT SELECT context
 SECTIONS (
 A PRIORITY 0: "literal value",
 B PRIORITY 1: LAST 10 FROM logs,
 C PRIORITY 2: SEARCH docs BY SIMILARITY($query) TOP 5
 )
 "#;

        let mut parser = ContextQueryParser::new(query);
        let result = parser.parse().unwrap();

        assert_eq!(result.sections.len(), 3);

        // Section A: quoted literal.
        assert_eq!(result.sections[0].name, "A");
        assert!(
            matches!(&result.sections[0].content, SectionContent::Literal { value } if value == "literal value")
        );

        // Section B: LAST n FROM table.
        assert_eq!(result.sections[1].name, "B");
        assert!(
            matches!(&result.sections[1].content, SectionContent::Last { count: 10, table, .. } if table == "logs")
        );

        // Section C: similarity search with TOP k.
        assert_eq!(result.sections[2].name, "C");
        assert!(
            matches!(&result.sections[2].content, SectionContent::Search { collection, top_k: 5, .. } if collection == "docs")
        );
    }
2824
    /// Builds a query with the fluent builder and checks name, options,
    /// section count and the priority of one section.
    #[test]
    fn test_builder() {
        let query = ContextQueryBuilder::new("prompt")
            .from_session("sess123")
            .with_token_limit(4096)
            .include_schema(false)
            .get("USER", 0, "user.profile.{name, email}")
            .last("HISTORY", 1, 20, "events")
            .search("DOCS", 2, "knowledge_base", "query_embedding", 10)
            .literal("SYSTEM", -1, "You are a helpful assistant")
            .build();

        assert_eq!(query.output_name, "prompt");
        assert_eq!(query.options.token_limit, 4096);
        assert!(!query.options.include_schema);
        assert_eq!(query.sections.len(), 4);

        let system = query.sections.iter().find(|s| s.name == "SYSTEM").unwrap();
        assert_eq!(system.priority, -1);
    }
2846
    /// `format = markdown` in the WITH clause selects the Markdown format.
    #[test]
    fn test_output_format() {
        let query = r#"
 CONTEXT SELECT ctx
 WITH (format = markdown)
 SECTIONS ()
 "#;

        let mut parser = ContextQueryParser::new(query);
        let result = parser.parse().unwrap();

        assert_eq!(result.options.format, OutputFormat::Markdown);
    }
2860
2861 #[test]
2862 fn test_truncation_strategy() {
2863 let query = r#"
2864 CONTEXT SELECT ctx
2865 WITH (truncation = proportional)
2866 SECTIONS ()
2867 "#;
2868
2869 let mut parser = ContextQueryParser::new(query);
2870 let result = parser.parse().unwrap();
2871
2872 assert_eq!(result.options.truncation, TruncationStrategy::Proportional);
2873 }
2874
2875 #[test]
2880 fn test_simple_vector_index_creation() {
2881 let index = SimpleVectorIndex::new();
2882 index.create_collection("test", 3);
2883
2884 let stats = index.stats("test");
2885 assert!(stats.is_some());
2886 let stats = stats.unwrap();
2887 assert_eq!(stats.dimension, 3);
2888 assert_eq!(stats.vector_count, 0);
2889 assert_eq!(stats.metric, "cosine");
2890 }
2891
2892 #[test]
2893 fn test_simple_vector_index_insert_and_search() {
2894 let index = SimpleVectorIndex::new();
2895 index.create_collection("docs", 3);
2896
2897 index
2899 .insert(
2900 "docs",
2901 "doc1".to_string(),
2902 vec![1.0, 0.0, 0.0],
2903 "Document about cats".to_string(),
2904 HashMap::new(),
2905 )
2906 .unwrap();
2907
2908 index
2909 .insert(
2910 "docs",
2911 "doc2".to_string(),
2912 vec![0.9, 0.1, 0.0],
2913 "Document about dogs".to_string(),
2914 HashMap::new(),
2915 )
2916 .unwrap();
2917
2918 index
2919 .insert(
2920 "docs",
2921 "doc3".to_string(),
2922 vec![0.0, 0.0, 1.0],
2923 "Document about cars".to_string(),
2924 HashMap::new(),
2925 )
2926 .unwrap();
2927
2928 let results = index
2930 .search_by_embedding("docs", &[1.0, 0.0, 0.0], 2, None)
2931 .unwrap();
2932
2933 assert_eq!(results.len(), 2);
2934 assert_eq!(results[0].id, "doc1"); assert!((results[0].score - 1.0).abs() < 0.001);
2936 assert_eq!(results[1].id, "doc2"); assert!(results[1].score > 0.9); }
2939
2940 #[test]
2941 fn test_simple_vector_index_min_score_filter() {
2942 let index = SimpleVectorIndex::new();
2943 index.create_collection("docs", 3);
2944
2945 index
2946 .insert(
2947 "docs",
2948 "a".to_string(),
2949 vec![1.0, 0.0, 0.0],
2950 "A".to_string(),
2951 HashMap::new(),
2952 )
2953 .unwrap();
2954 index
2955 .insert(
2956 "docs",
2957 "b".to_string(),
2958 vec![0.0, 1.0, 0.0],
2959 "B".to_string(),
2960 HashMap::new(),
2961 )
2962 .unwrap();
2963 index
2964 .insert(
2965 "docs",
2966 "c".to_string(),
2967 vec![0.0, 0.0, 1.0],
2968 "C".to_string(),
2969 HashMap::new(),
2970 )
2971 .unwrap();
2972
2973 let results = index
2975 .search_by_embedding("docs", &[1.0, 0.0, 0.0], 10, Some(0.9))
2976 .unwrap();
2977
2978 assert_eq!(results.len(), 1);
2979 assert_eq!(results[0].id, "a");
2980 }
2981
2982 #[test]
2983 fn test_simple_vector_index_dimension_mismatch() {
2984 let index = SimpleVectorIndex::new();
2985 index.create_collection("docs", 3);
2986
2987 let result = index.insert(
2988 "docs",
2989 "bad".to_string(),
2990 vec![1.0, 0.0], "Content".to_string(),
2992 HashMap::new(),
2993 );
2994
2995 assert!(result.is_err());
2996 assert!(result.unwrap_err().contains("dimension mismatch"));
2997 }
2998
2999 #[test]
3000 fn test_simple_vector_index_nonexistent_collection() {
3001 let index = SimpleVectorIndex::new();
3002
3003 let result = index.search_by_embedding("nonexistent", &[1.0], 1, None);
3004 assert!(result.is_err());
3005 assert!(result.unwrap_err().contains("not found"));
3006 }
3007
3008 #[test]
3009 fn test_vector_index_with_metadata() {
3010 let index = SimpleVectorIndex::new();
3011 index.create_collection("docs", 2);
3012
3013 let mut metadata = HashMap::new();
3014 metadata.insert("author".to_string(), SochValue::Text("Alice".to_string()));
3015 metadata.insert("year".to_string(), SochValue::Int(2024));
3016
3017 index
3018 .insert(
3019 "docs",
3020 "doc1".to_string(),
3021 vec![1.0, 0.0],
3022 "Document content".to_string(),
3023 metadata,
3024 )
3025 .unwrap();
3026
3027 let results = index
3028 .search_by_embedding("docs", &[1.0, 0.0], 1, None)
3029 .unwrap();
3030
3031 assert_eq!(results.len(), 1);
3032 assert!(results[0].metadata.contains_key("author"));
3033 assert!(results[0].metadata.contains_key("year"));
3034 }
3035
3036 #[test]
3037 fn test_vector_index_text_search_unsupported() {
3038 let index = SimpleVectorIndex::new();
3039 index.create_collection("docs", 2);
3040
3041 let result = index.search_by_text("docs", "hello", 5, None);
3043 assert!(result.is_err());
3044 assert!(result.unwrap_err().contains("embedding model"));
3045 }
3046}