use crate::soch_ql::{ComparisonOp, Condition, LogicalOp, SochValue, WhereClause};
use crate::token_budget::{BudgetSection, TokenBudgetConfig, TokenBudgetEnforcer, TokenEstimator};
use std::collections::HashMap;
/// A parsed `CONTEXT SELECT` query: the output name, the session it is scoped
/// to, query-wide options, and the sections that make up the context.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextSelectQuery {
/// Identifier that follows `CONTEXT SELECT` in the query text.
pub output_name: String,
/// Session/agent scope from the optional `FROM` clause.
pub session: SessionReference,
/// Options from the optional `WITH (...)` clause (defaults otherwise).
pub options: ContextQueryOptions,
/// Sections listed in the `SECTIONS (...)` clause, in source order.
pub sections: Vec<ContextSection>,
}
/// Scope named by the `FROM` clause of a context query.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SessionReference {
/// `FROM session($var)`; holds the variable name without the `$`.
Session(String),
/// `FROM agent($var)`; holds the variable name without the `$`.
Agent(String),
/// No `FROM` clause was given.
None,
}
/// Query-wide options set through the `WITH (key = value, ...)` clause.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextQueryOptions {
/// Token limit for the assembled context (`token_limit`).
pub token_limit: usize,
/// Value of the `include_schema` option.
pub include_schema: bool,
/// Output format selected with `format`.
pub format: OutputFormat,
/// Truncation strategy selected with `truncation`.
pub truncation: TruncationStrategy,
/// Value of the `include_headers` option.
pub include_headers: bool,
}
impl Default for ContextQueryOptions {
fn default() -> Self {
Self {
token_limit: 4096,
include_schema: true,
format: OutputFormat::Soch,
truncation: TruncationStrategy::TailDrop,
include_headers: true,
}
}
}
/// Serialization format requested for the assembled context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum OutputFormat {
/// Default format; the query parser selects it with the option value `toon`.
Soch,
/// Selected with the option value `json`.
Json,
/// Selected with the option value `markdown`.
Markdown,
}
/// How sections are handled when the content does not fit the token budget.
///
/// NOTE(review): the exact dropping behaviour of each strategy is implemented
/// by the budget enforcer outside this file section; confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TruncationStrategy {
/// Default; parsed from `tail_drop` / `taildrop`.
TailDrop,
/// Parsed from `head_drop` / `headdrop`.
HeadDrop,
/// Parsed from `proportional`.
Proportional,
/// Parsed from `fail`; sections get no minimum-token fallback when this is set.
Fail,
}
/// One named section of a context query.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextSection {
/// Section name (identifier before `PRIORITY`).
pub name: String,
/// Section priority; results are sorted ascending by this value and a
/// priority of 0 marks the section as required during budget allocation.
pub priority: i32,
/// What the section contains / how it is produced.
pub content: SectionContent,
/// Optional post-processing; the query parser always leaves this `None`.
pub transform: Option<SectionTransform>,
}
/// The data source of a section.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SectionContent {
/// `GET <path>`: fetch the value(s) at a path expression.
Get { path: PathExpression },
/// `LAST [n] FROM <table> [WHERE ...]`; the parser defaults `count` to 10.
Last {
count: usize,
table: String,
where_clause: Option<WhereClause>,
},
/// `SEARCH <collection> BY SIMILARITY(<$var | 'text'>) TOP [k]`; the parser
/// defaults `top_k` to 5 and always sets `min_score` to `None`.
Search {
collection: String,
query: SimilarityQuery,
top_k: usize,
min_score: Option<f32>,
},
/// `SELECT <cols | *> FROM <table> [WHERE ...] [LIMIT n]`.
Select {
columns: Vec<String>,
table: String,
where_clause: Option<WhereClause>,
limit: Option<usize>,
},
/// A quoted string used verbatim.
Literal { value: String },
/// A `$name` variable reference.
Variable { name: String },
/// Tool registry listing. NOTE(review): not produced by the parser shown
/// here; field semantics presumably include/exclude tool names.
ToolRegistry {
include: Vec<String>,
exclude: Vec<String>,
include_schema: bool,
},
/// Recent tool calls. NOTE(review): not produced by the parser shown here.
ToolCalls {
count: usize,
tool_filter: Option<String>,
status_filter: Option<String>,
include_outputs: bool,
},
}
/// A dotted path with an optional field projection, e.g. `user.profile.{name, email}`
/// or `user.profile.**`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PathExpression {
/// Path components obtained by splitting on `.`.
pub segments: Vec<String>,
/// Field names listed inside `{...}`; empty when no projection was given.
pub fields: Vec<String>,
/// `true` when no `{...}` projection was given (plain path or `.**` suffix).
pub all_fields: bool,
}
impl PathExpression {
    /// Parses a path expression.
    ///
    /// Accepted forms:
    /// * `a.b.{x, y}` — path with an explicit field projection;
    /// * `a.b.**` — path selecting all fields;
    /// * `a.b` — plain path, treated as selecting all fields.
    ///
    /// # Errors
    ///
    /// Returns [`ContextParseError::InvalidPath`] when a `{` is present but the
    /// input does not end with `}`, or when any path segment is empty
    /// (e.g. `""`, `a..b`, `{x}`).
    pub fn parse(input: &str) -> Result<Self, ContextParseError> {
        let input = input.trim();
        if let Some(brace_start) = input.find('{') {
            if !input.ends_with('}') {
                return Err(ContextParseError::InvalidPath(
                    "unclosed field projection".to_string(),
                ));
            }
            // Drop the separator dot(s) between the path and the `{`.
            let path_part = input[..brace_start].trim_end().trim_end_matches('.');
            let fields_part = &input[brace_start + 1..input.len() - 1];
            let segments = Self::parse_segments(path_part)?;
            let fields: Vec<String> = fields_part
                .split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect();
            Ok(PathExpression {
                segments,
                fields,
                all_fields: false,
            })
        } else {
            // A plain path and a `.**` path both mean "all fields".
            let path_part = input.strip_suffix(".**").unwrap_or(input);
            let segments = Self::parse_segments(path_part)?;
            Ok(PathExpression {
                segments,
                fields: vec![],
                all_fields: true,
            })
        }
    }

    /// Splits a dotted path into trimmed segments, rejecting empty ones.
    fn parse_segments(path: &str) -> Result<Vec<String>, ContextParseError> {
        path.split('.')
            .map(|segment| {
                let segment = segment.trim();
                if segment.is_empty() {
                    Err(ContextParseError::InvalidPath(format!(
                        "empty path segment in '{}'",
                        path
                    )))
                } else {
                    Ok(segment.to_string())
                }
            })
            .collect()
    }

    /// Renders the expression back to text: `base.**` when all fields are
    /// selected, `base.{f1, f2}` for a non-empty projection, otherwise `base`.
    pub fn to_path_string(&self) -> String {
        let base = self.segments.join(".");
        if self.all_fields {
            format!("{}.**", base)
        } else if !self.fields.is_empty() {
            format!("{}.{{{}}}", base, self.fields.join(", "))
        } else {
            base
        }
    }
}
/// The query side of a similarity search section.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SimilarityQuery {
/// A `$name` variable reference (name stored without the `$`).
Variable(String),
/// A literal embedding vector.
Embedding(Vec<f32>),
/// Free text given as a quoted string.
Text(String),
}
/// Optional post-processing attached to a section.
///
/// NOTE(review): the code applying these transforms is outside this file
/// section; the per-variant notes below are inferred from the field names.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SectionTransform {
/// Presumably summarizes content down to `max_tokens` — confirm.
Summarize { max_tokens: usize },
/// Presumably keeps only the listed fields — confirm.
Project { fields: Vec<String> },
/// Presumably renders content through `template` — confirm.
Template { template: String },
/// Presumably invokes a named custom function — confirm.
Custom { function: String },
}
/// A stored, versioned context query.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextRecipe {
/// Recipe identifier; together with `version` it forms the store key.
pub id: String,
/// Recipe name.
pub name: String,
/// Recipe description.
pub description: String,
/// Version label; the store keeps every saved version of an `id`.
pub version: String,
/// The query this recipe holds.
pub query: ContextSelectQuery,
/// Bookkeeping metadata.
pub metadata: RecipeMetadata,
/// Optional binding used by the store's session/agent lookups.
pub session_binding: Option<SessionBinding>,
}
/// Descriptive and usage metadata attached to a recipe.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RecipeMetadata {
/// Recipe author, if recorded.
pub author: Option<String>,
/// Creation timestamp as a string (format not enforced here).
pub created_at: Option<String>,
/// Last-update timestamp as a string (format not enforced here).
pub updated_at: Option<String>,
/// Free-form tags.
pub tags: Vec<String>,
/// Usage counter. NOTE(review): not updated by any code in this section.
pub usage_count: u64,
/// Average token count, if tracked.
pub avg_tokens: Option<f32>,
}
/// What a recipe is bound to, used by `ContextRecipeStore` lookups.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SessionBinding {
/// Bound to one exact session id.
Session(String),
/// Bound to one exact agent id.
Agent(String),
/// Bound to session ids matching a `*` glob pattern.
Pattern(String),
/// Explicitly unbound; never matched by session or agent lookups.
None,
}
/// In-memory store of context recipes with per-id version history.
pub struct ContextRecipeStore {
// Recipes keyed by "<id>:<version>".
recipes: std::sync::RwLock<HashMap<String, ContextRecipe>>,
// For each recipe id, the versions in the order they were saved.
versions: std::sync::RwLock<HashMap<String, Vec<String>>>,
}
impl ContextRecipeStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            recipes: std::sync::RwLock::new(HashMap::new()),
            versions: std::sync::RwLock::new(HashMap::new()),
        }
    }

    /// Saves `recipe` under `"<id>:<version>"` and records the version as the
    /// most recent one for its id.
    ///
    /// Re-saving an existing version overwrites the stored recipe and moves
    /// that version to the end of the history instead of duplicating it.
    ///
    /// # Errors
    ///
    /// Returns the poison message if either internal lock is poisoned.
    pub fn save(&self, recipe: ContextRecipe) -> Result<(), String> {
        let mut recipes = self.recipes.write().map_err(|e| e.to_string())?;
        let mut versions = self.versions.write().map_err(|e| e.to_string())?;
        let key = format!("{}:{}", recipe.id, recipe.version);
        let history = versions.entry(recipe.id.clone()).or_default();
        // Keep each version listed once while preserving "last saved is latest".
        history.retain(|v| v != &recipe.version);
        history.push(recipe.version.clone());
        recipes.insert(key, recipe);
        Ok(())
    }

    /// Returns the most recently saved version of `recipe_id`, if any.
    pub fn get_latest(&self, recipe_id: &str) -> Option<ContextRecipe> {
        // Release the `versions` lock before touching `recipes`: `save` takes
        // the locks in the opposite order, so holding both here can deadlock.
        let latest_version = {
            let versions = self.versions.read().ok()?;
            versions.get(recipe_id)?.last()?.clone()
        };
        let recipes = self.recipes.read().ok()?;
        let key = format!("{}:{}", recipe_id, latest_version);
        recipes.get(&key).cloned()
    }

    /// Returns the recipe saved as `version` of `recipe_id`, if any.
    pub fn get_version(&self, recipe_id: &str, version: &str) -> Option<ContextRecipe> {
        let recipes = self.recipes.read().ok()?;
        let key = format!("{}:{}", recipe_id, version);
        recipes.get(&key).cloned()
    }

    /// Lists the saved versions of `recipe_id` in save order (empty if the id
    /// is unknown or the lock is poisoned).
    pub fn list_versions(&self, recipe_id: &str) -> Vec<String> {
        self.versions
            .read()
            .ok()
            .and_then(|v| v.get(recipe_id).cloned())
            .unwrap_or_default()
    }

    /// Returns every stored recipe (all versions) bound to `session_id`, either
    /// exactly or through a glob pattern binding.
    pub fn find_by_session(&self, session_id: &str) -> Vec<ContextRecipe> {
        let recipes = match self.recipes.read() {
            Ok(r) => r,
            Err(_) => return Vec::new(),
        };
        recipes
            .values()
            .filter(|r| match &r.session_binding {
                Some(SessionBinding::Session(sid)) => sid == session_id,
                Some(SessionBinding::Pattern(pattern)) => glob_match(pattern, session_id),
                _ => false,
            })
            .cloned()
            .collect()
    }

    /// Returns every stored recipe (all versions) bound to `agent_id`.
    pub fn find_by_agent(&self, agent_id: &str) -> Vec<ContextRecipe> {
        let recipes = match self.recipes.read() {
            Ok(r) => r,
            Err(_) => return Vec::new(),
        };
        recipes
            .values()
            .filter(|r| matches!(&r.session_binding, Some(SessionBinding::Agent(aid)) if aid == agent_id))
            .cloned()
            .collect()
    }
}
impl Default for ContextRecipeStore {
fn default() -> Self {
Self::new()
}
}
/// Matches `input` against a glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters; every other character is literal.
///
/// A pattern without `*` must equal `input` exactly. Any number of `*`
/// wildcards is supported, and the literal prefix and suffix may not overlap
/// (so `ab*ba` does not match `aba`).
fn glob_match(pattern: &str, input: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == input;
    }
    // With at least one `*`, splitting always yields two or more parts.
    let parts: Vec<&str> = pattern.split('*').collect();
    let prefix = parts[0];
    let suffix = parts[parts.len() - 1];

    // Anchor the first and last literal, then look for the middle literals
    // left-to-right in whatever remains between them.
    let mut rest = match input.strip_prefix(prefix) {
        Some(remaining) => remaining,
        None => return false,
    };
    rest = match rest.strip_suffix(suffix) {
        Some(remaining) => remaining,
        None => return false,
    };
    for part in &parts[1..parts.len() - 1] {
        match rest.find(part) {
            Some(idx) => rest = &rest[idx + part.len()..],
            None => return false,
        }
    }
    true
}
/// One hit returned by a vector search.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
/// Caller-supplied id of the stored vector.
pub id: String,
/// Similarity score (cosine similarity for the indexes in this file).
pub score: f32,
/// Content stored alongside the vector.
pub content: String,
/// Metadata stored alongside the vector.
pub metadata: HashMap<String, SochValue>,
}
/// Abstraction over a vector store that can answer similarity queries.
pub trait VectorIndex: Send + Sync {
/// Returns up to `k` results from `collection` for `embedding`, dropping
/// results scoring below `min_score` when it is set.
fn search_by_embedding(
&self,
collection: &str,
embedding: &[f32],
k: usize,
min_score: Option<f32>,
) -> Result<Vec<VectorSearchResult>, String>;
/// Same as `search_by_embedding` but starting from raw text; both
/// implementations in this file return an error because they have no
/// embedding model.
fn search_by_text(
&self,
collection: &str,
text: &str,
k: usize,
min_score: Option<f32>,
) -> Result<Vec<VectorSearchResult>, String>;
/// Returns statistics for `collection`, or `None` if it does not exist.
fn stats(&self, collection: &str) -> Option<VectorIndexStats>;
}
/// Summary information about one vector collection.
#[derive(Debug, Clone)]
pub struct VectorIndexStats {
/// Number of stored vectors.
pub vector_count: usize,
/// Dimension every vector in the collection must have.
pub dimension: usize,
/// Name of the similarity metric (e.g. "cosine").
pub metric: String,
}
/// Brute-force in-memory vector index: every search scores all stored vectors
/// with cosine similarity.
pub struct SimpleVectorIndex {
// Collections by name.
collections: std::sync::RwLock<HashMap<String, VectorCollection>>,
}
/// Storage for one collection of `SimpleVectorIndex`.
struct VectorCollection {
// Each entry is (id, vector, content, metadata).
#[allow(clippy::type_complexity)]
vectors: Vec<(String, Vec<f32>, String, HashMap<String, SochValue>)>,
// Required length of every inserted vector.
dimension: usize,
}
impl SimpleVectorIndex {
    /// Creates an index with no collections.
    pub fn new() -> Self {
        let collections = std::sync::RwLock::new(HashMap::new());
        Self { collections }
    }

    /// Creates `name` with the given vector dimension; an already existing
    /// collection is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn create_collection(&self, name: &str, dimension: usize) {
        let mut guard = self.collections.write().unwrap();
        if !guard.contains_key(name) {
            guard.insert(
                name.to_string(),
                VectorCollection {
                    vectors: Vec::new(),
                    dimension,
                },
            );
        }
    }

    /// Appends a vector with its content and metadata to `collection`.
    ///
    /// # Errors
    ///
    /// Fails when the collection does not exist or `vector` has the wrong length.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn insert(
        &self,
        collection: &str,
        id: String,
        vector: Vec<f32>,
        content: String,
        metadata: HashMap<String, SochValue>,
    ) -> Result<(), String> {
        let mut guard = self.collections.write().unwrap();
        let target = match guard.get_mut(collection) {
            Some(existing) => existing,
            None => return Err(format!("Collection '{}' not found", collection)),
        };
        let expected = target.dimension;
        let actual = vector.len();
        if actual != expected {
            return Err(format!(
                "Vector dimension mismatch: expected {}, got {}",
                expected, actual
            ));
        }
        target.vectors.push((id, vector, content, metadata));
        Ok(())
    }

    /// Cosine similarity of `a` and `b`; returns 0.0 when either has zero norm.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a = magnitude(a);
        let norm_b = magnitude(b);
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        dot / (norm_a * norm_b)
    }
}
impl Default for SimpleVectorIndex {
fn default() -> Self {
Self::new()
}
}
impl VectorIndex for SimpleVectorIndex {
    /// Scores every vector in `collection` against `embedding` with cosine
    /// similarity and returns the `k` best, highest score first.
    ///
    /// # Errors
    ///
    /// Fails when the collection does not exist or when `embedding` does not
    /// have the collection's dimension (a shorter or longer query would
    /// otherwise be silently truncated and produce meaningless scores).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String> {
        let collections = self.collections.read().unwrap();
        let coll = collections
            .get(collection)
            .ok_or_else(|| format!("Collection '{}' not found", collection))?;
        if embedding.len() != coll.dimension {
            return Err(format!(
                "Vector dimension mismatch: expected {}, got {}",
                coll.dimension,
                embedding.len()
            ));
        }
        let mut scored: Vec<_> = coll
            .vectors
            .iter()
            .map(|(id, vec, content, meta)| {
                let score = Self::cosine_similarity(embedding, vec);
                (id, score, content, meta)
            })
            .filter(|(_, score, _, _)| min_score.map(|min| *score >= min).unwrap_or(true))
            .collect();
        // Descending by score; incomparable (NaN) scores are treated as equal.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(scored
            .into_iter()
            .take(k)
            .map(|(id, score, content, meta)| VectorSearchResult {
                id: id.clone(),
                score,
                content: content.clone(),
                metadata: meta.clone(),
            })
            .collect())
    }

    /// Always fails: this index has no embedding model to turn text into a vector.
    fn search_by_text(
        &self,
        _collection: &str,
        _text: &str,
        _k: usize,
        _min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String> {
        Err(
            "Text-based search requires an embedding model. Use search_by_embedding instead."
                .to_string(),
        )
    }

    /// Returns vector count, dimension and metric name for `collection`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
        let collections = self.collections.read().unwrap();
        collections.get(collection).map(|coll| VectorIndexStats {
            vector_count: coll.vectors.len(),
            dimension: coll.dimension,
            metric: "cosine".to_string(),
        })
    }
}
/// Vector index backed by `sochdb_index::vector::VectorIndex`, one per collection.
pub struct HnswVectorIndex {
// Collections by name.
collections: std::sync::RwLock<HashMap<String, HnswCollection>>,
}
/// Storage for one collection of `HnswVectorIndex`.
struct HnswCollection {
// The underlying index; vectors are registered under numeric edge ids.
index: sochdb_index::vector::VectorIndex,
// Edge id -> (caller id, content, metadata).
#[allow(clippy::type_complexity)]
metadata: HashMap<u128, (String, String, HashMap<String, SochValue>)>,
// Next edge id to hand out; incremented on every insert.
next_edge_id: u128,
// Required length of every inserted vector.
dimension: usize,
}
impl HnswVectorIndex {
    /// Creates an index with no collections.
    pub fn new() -> Self {
        Self {
            collections: std::sync::RwLock::new(HashMap::new()),
        }
    }

    /// Creates `name` as a cosine-distance collection of the given dimension;
    /// an already existing collection is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn create_collection(&self, name: &str, dimension: usize) {
        let mut collections = self.collections.write().unwrap();
        collections.entry(name.to_string()).or_insert_with(|| {
            let index = sochdb_index::vector::VectorIndex::with_dimension(
                sochdb_index::vector::DistanceMetric::Cosine,
                dimension,
            );
            HnswCollection {
                index,
                metadata: HashMap::new(),
                next_edge_id: 0,
                dimension,
            }
        });
    }

    /// Adds a vector with its content and metadata to `collection`.
    ///
    /// # Errors
    ///
    /// Fails when the collection does not exist, `vector` has the wrong
    /// length, or the underlying index rejects the vector. On failure no
    /// metadata is recorded, so counts stay consistent with the index.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn insert(
        &self,
        collection: &str,
        id: String,
        vector: Vec<f32>,
        content: String,
        metadata: HashMap<String, SochValue>,
    ) -> Result<(), String> {
        let mut collections = self.collections.write().unwrap();
        let coll = collections
            .get_mut(collection)
            .ok_or_else(|| format!("Collection '{}' not found", collection))?;
        if vector.len() != coll.dimension {
            return Err(format!(
                "Vector dimension mismatch: expected {}, got {}",
                coll.dimension,
                vector.len()
            ));
        }
        // The id is consumed even if the add fails, so an id is never reused.
        let edge_id = coll.next_edge_id;
        coll.next_edge_id += 1;
        let embedding = ndarray::Array1::from_vec(vector);
        // Register with the index first; only record metadata once that succeeded.
        coll.index.add(edge_id, embedding)?;
        coll.metadata.insert(edge_id, (id, content, metadata));
        Ok(())
    }

    /// Number of vectors stored in `collection`, or `None` if it does not exist.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn vector_count(&self, collection: &str) -> Option<usize> {
        let collections = self.collections.read().unwrap();
        collections.get(collection).map(|c| c.metadata.len())
    }
}
impl Default for HnswVectorIndex {
fn default() -> Self {
Self::new()
}
}
impl VectorIndex for HnswVectorIndex {
    /// Asks the underlying index for the `k` nearest vectors and converts each
    /// distance to a score as `1.0 - distance`, dropping hits below `min_score`.
    ///
    /// # Errors
    ///
    /// Fails when the collection does not exist, when `embedding` does not
    /// have the collection's dimension, or when the underlying search fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String> {
        let collections = self.collections.read().unwrap();
        let coll = collections
            .get(collection)
            .ok_or_else(|| format!("Collection '{}' not found", collection))?;
        // Reject wrongly sized queries up front, mirroring the check in `insert`,
        // rather than passing a mismatched array to the index.
        if embedding.len() != coll.dimension {
            return Err(format!(
                "Vector dimension mismatch: expected {}, got {}",
                coll.dimension,
                embedding.len()
            ));
        }
        let query = ndarray::Array1::from_vec(embedding.to_vec());
        let results = coll.index.search(&query, k)?;
        let mut search_results = Vec::with_capacity(results.len());
        for (edge_id, distance) in results {
            let score = 1.0 - distance;
            if let Some(min) = min_score {
                if score < min {
                    continue;
                }
            }
            // Hits without recorded metadata are skipped.
            if let Some((id, content, meta)) = coll.metadata.get(&edge_id) {
                search_results.push(VectorSearchResult {
                    id: id.clone(),
                    score,
                    content: content.clone(),
                    metadata: meta.clone(),
                });
            }
        }
        Ok(search_results)
    }

    /// Always fails: this index has no embedding model to turn text into a vector.
    fn search_by_text(
        &self,
        _collection: &str,
        _text: &str,
        _k: usize,
        _min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String> {
        Err(
            "Text-based search requires an embedding model. Use search_by_embedding instead."
                .to_string(),
        )
    }

    /// Returns vector count, dimension and metric name for `collection`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
        let collections = self.collections.read().unwrap();
        collections.get(collection).map(|coll| VectorIndexStats {
            vector_count: coll.metadata.len(),
            dimension: coll.dimension,
            metric: "cosine".to_string(),
        })
    }
}
/// An assembled context together with budget accounting.
///
/// NOTE(review): nothing in this file section constructs this type; field
/// descriptions follow the field names — confirm against the producer.
#[derive(Debug, Clone)]
pub struct ContextResult {
/// The assembled context text.
pub context: String,
/// Tokens used by `context`.
pub token_count: usize,
/// Token budget the context was assembled under.
pub token_budget: usize,
/// Sections that made it into the context.
pub sections_included: Vec<SectionResult>,
/// Names of sections that were truncated.
pub sections_truncated: Vec<String>,
/// Names of sections that were dropped.
pub sections_dropped: Vec<String>,
}
/// The outcome of one section after budget allocation.
#[derive(Debug, Clone)]
pub struct SectionResult {
/// Section name.
pub name: String,
/// Section priority (results are sorted ascending by it).
pub priority: i32,
/// Final content, truncated if `truncated` is set.
pub content: String,
/// Estimated token count of `content`.
pub tokens: usize,
/// Tokens charged for the section; `execute` sets it equal to `tokens`.
pub tokens_used: usize,
/// Whether the content was cut to fit the budget.
pub truncated: bool,
/// Row count; `execute` currently always stores 0.
pub row_count: usize,
}
/// Errors raised while executing a context query.
#[derive(Debug, Clone)]
pub enum ContextQueryError {
/// The query's session does not match the active one.
SessionMismatch { expected: String, actual: String },
/// A referenced `$variable` is not defined.
VariableNotFound(String),
/// A variable exists but has the wrong type.
InvalidVariableType { variable: String, expected: String },
/// A section needs more tokens than are available.
BudgetExceeded {
section: String,
requested: usize,
available: usize,
},
/// The agent's overall budget could not cover the result.
BudgetExhausted(String),
/// The agent lacks permission for a section.
PermissionDenied(String),
/// A path expression is invalid.
InvalidPath(String),
/// The query text failed to parse.
Parse(ContextParseError),
/// The result could not be formatted.
FormatError(String),
/// The query is structurally invalid.
InvalidQuery(String),
/// The vector index reported an error.
VectorSearchError(String),
}
impl std::fmt::Display for ContextQueryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SessionMismatch { expected, actual } => {
write!(f, "session mismatch: expected {}, got {}", expected, actual)
}
Self::VariableNotFound(name) => write!(f, "variable not found: {}", name),
Self::InvalidVariableType { variable, expected } => {
write!(
f,
"variable {} has invalid type, expected {}",
variable, expected
)
}
Self::BudgetExceeded {
section,
requested,
available,
} => {
write!(
f,
"section {} exceeds budget: {} > {}",
section, requested, available
)
}
Self::BudgetExhausted(msg) => write!(f, "budget exhausted: {}", msg),
Self::PermissionDenied(msg) => write!(f, "permission denied: {}", msg),
Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
Self::Parse(e) => write!(f, "parse error: {}", e),
Self::FormatError(e) => write!(f, "format error: {}", e),
Self::InvalidQuery(msg) => write!(f, "invalid query: {}", msg),
Self::VectorSearchError(e) => write!(f, "vector search error: {}", e),
}
}
}
impl std::error::Error for ContextQueryError {}
/// Errors raised while parsing a context query or path expression.
#[derive(Debug, Clone)]
pub enum ContextParseError {
    /// A specific token was required but another was found.
    UnexpectedToken { expected: String, found: String },
    /// A mandatory clause is missing.
    MissingClause(String),
    /// An option key or value is not recognised.
    InvalidOption(String),
    /// A path expression is malformed.
    InvalidPath(String),
    /// A section body is malformed.
    InvalidSection(String),
    /// Any other syntax problem.
    SyntaxError(String),
}

impl std::fmt::Display for ContextParseError {
    /// Writes a one-line, lowercase description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use ContextParseError as E;
        match self {
            E::UnexpectedToken { expected, found } => {
                f.write_fmt(format_args!("expected {}, found '{}'", expected, found))
            }
            E::MissingClause(clause) => f.write_fmt(format_args!("missing {} clause", clause)),
            E::InvalidOption(opt) => f.write_fmt(format_args!("invalid option: {}", opt)),
            E::InvalidPath(path) => f.write_fmt(format_args!("invalid path: {}", path)),
            E::InvalidSection(sec) => f.write_fmt(format_args!("invalid section: {}", sec)),
            E::SyntaxError(msg) => f.write_fmt(format_args!("syntax error: {}", msg)),
        }
    }
}

impl std::error::Error for ContextParseError {}
/// Recursive-descent parser for `CONTEXT SELECT` queries.
pub struct ContextQueryParser {
// Index of the current token in `tokens`.
pos: usize,
// Token stream produced by `tokenize`; always ends with `Token::Eof`.
tokens: Vec<Token>,
}
/// Lexical token of the context query language.
#[derive(Debug, Clone, PartialEq)]
enum Token {
/// A reserved word (matched case-insensitively by the parser).
Keyword(String),
/// A non-reserved identifier.
Ident(String),
/// A numeric literal.
Number(f64),
/// A quoted string literal, without the quotes.
String(String),
/// A single punctuation character.
Punct(char),
/// A `$name` variable, without the `$`.
Variable(String),
/// End of input.
Eof,
}
impl ContextQueryParser {
/// Tokenizes `input` and positions the parser at the first token.
pub fn new(input: &str) -> Self {
    Self {
        tokens: Self::tokenize(input),
        pos: 0,
    }
}
/// Parses a full query:
/// `CONTEXT SELECT <name> [FROM ...] [WITH (...)] SECTIONS (...)`.
///
/// # Errors
///
/// Returns a [`ContextParseError`] at the first token that does not fit the grammar.
pub fn parse(&mut self) -> Result<ContextSelectQuery, ContextParseError> {
    for keyword in &["CONTEXT", "SELECT"] {
        self.expect_keyword(keyword)?;
    }
    let output_name = self.expect_ident()?;

    // Both clauses are optional and fall back to their "absent" value.
    let mut session = SessionReference::None;
    if self.match_keyword("FROM") {
        session = self.parse_session_reference()?;
    }
    let mut options = ContextQueryOptions::default();
    if self.match_keyword("WITH") {
        options = self.parse_options()?;
    }

    self.expect_keyword("SECTIONS")?;
    let sections = self.parse_sections()?;
    Ok(ContextSelectQuery {
        output_name,
        session,
        options,
        sections,
    })
}
/// Parses `session($var)` or `agent($var)` after `FROM`.
fn parse_session_reference(&mut self) -> Result<SessionReference, ContextParseError> {
    // Pick the variant constructor first; the `($var)` tail is shared.
    let build: fn(String) -> SessionReference = if self.match_keyword("session") {
        SessionReference::Session
    } else if self.match_keyword("agent") {
        SessionReference::Agent
    } else {
        return Err(ContextParseError::SyntaxError(
            "expected 'session' or 'agent'".to_string(),
        ));
    };
    self.expect_punct('(')?;
    let var = self.expect_variable()?;
    self.expect_punct(')')?;
    Ok(build(var))
}
/// Parses the `WITH (key = value, ...)` option list on top of the defaults.
///
/// Recognised keys: `token_limit`, `include_schema`, `format`
/// (`soch`/`toon`, `json`, `markdown`), `truncation` and `include_headers`.
/// An empty list and a trailing comma are accepted.
///
/// # Errors
///
/// Returns [`ContextParseError::InvalidOption`] for an unknown key, an
/// unknown `format`/`truncation` value, or a `token_limit` that is not a
/// non-negative number; other malformed input yields the error of the
/// failing sub-parser.
fn parse_options(&mut self) -> Result<ContextQueryOptions, ContextParseError> {
    self.expect_punct('(')?;
    let mut options = ContextQueryOptions::default();
    loop {
        // Allows `WITH ()` and a trailing comma, like the section list does.
        if self.check_punct(')') {
            break;
        }
        let key = self.expect_ident()?;
        self.expect_punct('=')?;
        match key.as_str() {
            "token_limit" => match self.current().clone() {
                Token::Number(n) if n.is_finite() && n >= 0.0 => {
                    options.token_limit = n as usize;
                    self.advance();
                }
                // Previously a bad value was skipped silently and surfaced
                // later as an unrelated "unexpected token" error.
                other => {
                    return Err(ContextParseError::InvalidOption(format!(
                        "token_limit expects a non-negative number, found {:?}",
                        other
                    )))
                }
            },
            "include_schema" => {
                options.include_schema = self.parse_bool()?;
            }
            "format" => {
                let format = self.expect_ident()?;
                options.format = match format.to_lowercase().as_str() {
                    // `toon` is kept as an alias of the `Soch` format.
                    "soch" | "toon" => OutputFormat::Soch,
                    "json" => OutputFormat::Json,
                    "markdown" => OutputFormat::Markdown,
                    _ => return Err(ContextParseError::InvalidOption(format)),
                };
            }
            "truncation" => {
                let strategy = self.expect_ident()?;
                options.truncation = match strategy.to_lowercase().as_str() {
                    "tail_drop" | "taildrop" => TruncationStrategy::TailDrop,
                    "head_drop" | "headdrop" => TruncationStrategy::HeadDrop,
                    "proportional" => TruncationStrategy::Proportional,
                    "fail" => TruncationStrategy::Fail,
                    _ => return Err(ContextParseError::InvalidOption(strategy)),
                };
            }
            "include_headers" => {
                options.include_headers = self.parse_bool()?;
            }
            _ => return Err(ContextParseError::InvalidOption(key)),
        }
        if !self.match_punct(',') {
            break;
        }
    }
    self.expect_punct(')')?;
    Ok(options)
}
/// Parses the parenthesised, comma-separated section list; an empty list and
/// a trailing comma are accepted.
fn parse_sections(&mut self) -> Result<Vec<ContextSection>, ContextParseError> {
    self.expect_punct('(')?;
    let mut sections = Vec::new();
    while !self.check_punct(')') {
        sections.push(self.parse_section()?);
        if !self.match_punct(',') {
            break;
        }
    }
    self.expect_punct(')')?;
    Ok(sections)
}
/// Parses one section: `<name> PRIORITY [n]: <content>`.
///
/// The priority defaults to 0 when no number follows `PRIORITY`; the
/// transform is always left unset.
fn parse_section(&mut self) -> Result<ContextSection, ContextParseError> {
    let name = self.expect_ident()?;
    self.expect_keyword("PRIORITY")?;
    let priority = match self.current() {
        Token::Number(n) => {
            let value = *n as i32;
            self.advance();
            value
        }
        _ => 0,
    };
    self.expect_punct(':')?;
    let content = self.parse_section_content()?;
    Ok(ContextSection {
        name,
        priority,
        content,
        transform: None,
    })
}
/// Parses the body of a section after the `:`.
///
/// Dispatches on the leading token: `GET`, `LAST`, `SEARCH`, `SELECT`, a
/// `$variable`, or a quoted string literal. Anything else is an
/// `InvalidSection` error.
fn parse_section_content(&mut self) -> Result<SectionContent, ContextParseError> {
if self.match_keyword("GET") {
// GET <path>: the path runs up to the next top-level ',' or ')'.
let path_str = self.collect_until(&[',', ')']);
let path = PathExpression::parse(&path_str)?;
Ok(SectionContent::Get { path })
} else if self.match_keyword("LAST") {
// LAST [n] FROM <table> [WHERE ...]; n defaults to 10 when omitted.
let count = if let Token::Number(n) = self.current().clone() {
let val = n as usize;
self.advance();
val
} else {
10 };
self.expect_keyword("FROM")?;
let table = self.expect_ident()?;
let where_clause = if self.match_keyword("WHERE") {
Some(self.parse_where_clause()?)
} else {
None
};
Ok(SectionContent::Last {
count,
table,
where_clause,
})
} else if self.match_keyword("SEARCH") {
// SEARCH <collection> BY SIMILARITY(<$var | 'text'>) TOP [k].
let collection = self.expect_ident()?;
self.expect_keyword("BY")?;
self.expect_keyword("SIMILARITY")?;
self.expect_punct('(')?;
let query = if let Token::Variable(v) = self.current().clone() {
self.advance();
SimilarityQuery::Variable(v)
} else if let Token::String(s) = self.current().clone() {
self.advance();
SimilarityQuery::Text(s)
} else {
return Err(ContextParseError::SyntaxError(
"expected variable or string for similarity query".to_string(),
));
};
self.expect_punct(')')?;
self.expect_keyword("TOP")?;
// k defaults to 5 when no number follows TOP.
let top_k = if let Token::Number(n) = self.current().clone() {
let val = n as usize;
self.advance();
val
} else {
5 };
// The grammar has no syntax for a minimum score, so it is always None.
Ok(SectionContent::Search {
collection,
query,
top_k,
min_score: None,
})
} else if self.match_keyword("SELECT") {
// SELECT <cols | *> FROM <table> [WHERE ...] [LIMIT n].
let columns = self.parse_column_list()?;
self.expect_keyword("FROM")?;
let table = self.expect_ident()?;
let where_clause = if self.match_keyword("WHERE") {
Some(self.parse_where_clause()?)
} else {
None
};
// LIMIT without a following number leaves the limit unset.
let limit = if self.match_keyword("LIMIT") {
if let Token::Number(n) = self.current().clone() {
let val = n as usize;
self.advance();
Some(val)
} else {
None
}
} else {
None
};
Ok(SectionContent::Select {
columns,
table,
where_clause,
limit,
})
} else if let Token::Variable(v) = self.current().clone() {
// A bare $variable.
self.advance();
Ok(SectionContent::Variable { name: v })
} else if let Token::String(s) = self.current().clone() {
// A quoted literal.
self.advance();
Ok(SectionContent::Literal { value: s })
} else {
Err(ContextParseError::InvalidSection(
"expected GET, LAST, SEARCH, SELECT, or literal".to_string(),
))
}
}
fn parse_where_clause(&mut self) -> Result<WhereClause, ContextParseError> {
let mut conditions = Vec::new();
loop {
let column = self.expect_ident()?;
let operator = self.parse_comparison_op()?;
let value = self.parse_value()?;
conditions.push(Condition {
column,
operator,
value,
});
if !self.match_keyword("AND") && !self.match_keyword("OR") {
break;
}
}
Ok(WhereClause {
conditions,
operator: LogicalOp::And,
})
}
/// Parses a comparison operator: `=`, `>`, `>=`, `<`, `<=`, `LIKE` or `IN`.
fn parse_comparison_op(&mut self) -> Result<ComparisonOp, ContextParseError> {
    let leading = match self.current() {
        Token::Punct(c) => Some(*c),
        _ => None,
    };
    match leading {
        Some('=') => {
            self.advance();
            Ok(ComparisonOp::Eq)
        }
        Some('>') => {
            self.advance();
            // A directly following '=' upgrades the operator to ">=".
            Ok(if self.match_punct('=') {
                ComparisonOp::Ge
            } else {
                ComparisonOp::Gt
            })
        }
        Some('<') => {
            self.advance();
            Ok(if self.match_punct('=') {
                ComparisonOp::Le
            } else {
                ComparisonOp::Lt
            })
        }
        _ if self.match_keyword("LIKE") => Ok(ComparisonOp::Like),
        _ if self.match_keyword("IN") => Ok(ComparisonOp::In),
        _ => Err(ContextParseError::SyntaxError(
            "expected comparison operator".to_string(),
        )),
    }
}
/// Parses a literal value: number, string, `null`, `true`, `false`, or a
/// `$variable` (kept as the text `$name`).
///
/// Numbers without a fractional part become `Int`, all others `Float`.
fn parse_value(&mut self) -> Result<SochValue, ContextParseError> {
    let token = self.current().clone();
    let value = match token {
        Token::Number(n) if n.fract() == 0.0 => SochValue::Int(n as i64),
        Token::Number(n) => SochValue::Float(n),
        Token::String(s) => SochValue::Text(s),
        Token::Keyword(ref k) if k.eq_ignore_ascii_case("null") => SochValue::Null,
        Token::Keyword(ref k) if k.eq_ignore_ascii_case("true") => SochValue::Bool(true),
        Token::Keyword(ref k) if k.eq_ignore_ascii_case("false") => SochValue::Bool(false),
        Token::Variable(v) => SochValue::Text(format!("${}", v)),
        _ => return Err(ContextParseError::SyntaxError("expected value".to_string())),
    };
    // Every accepted token is exactly one token long.
    self.advance();
    Ok(value)
}
/// Parses either `*` or a comma-separated list of at least one column name.
fn parse_column_list(&mut self) -> Result<Vec<String>, ContextParseError> {
    if self.match_punct('*') {
        return Ok(vec!["*".to_string()]);
    }
    let mut columns = vec![self.expect_ident()?];
    while self.match_punct(',') {
        columns.push(self.expect_ident()?);
    }
    Ok(columns)
}
/// Parses the keyword `true` or `false` (case-insensitive).
fn parse_bool(&mut self) -> Result<bool, ContextParseError> {
    let parsed = match self.current() {
        Token::Keyword(k) if k.eq_ignore_ascii_case("true") => true,
        Token::Keyword(k) if k.eq_ignore_ascii_case("false") => false,
        _ => {
            return Err(ContextParseError::SyntaxError(
                "expected boolean".to_string(),
            ))
        }
    };
    self.advance();
    Ok(parsed)
}
/// Splits `input` into tokens, always ending with `Token::Eof`.
///
/// Whitespace and unrecognised characters are skipped. Words that match a
/// reserved word case-insensitively become `Keyword`s but keep their
/// original spelling: the parser compares keywords case-insensitively, and
/// a keyword token can also be used as an identifier (table, column,
/// section name or path segment), where upper-casing it would corrupt the
/// name.
fn tokenize(input: &str) -> Vec<Token> {
    const KEYWORDS: &[&str] = &[
        "CONTEXT",
        "SELECT",
        "FROM",
        "WITH",
        "SECTIONS",
        "PRIORITY",
        "GET",
        "LAST",
        "SEARCH",
        "BY",
        "SIMILARITY",
        "TOP",
        "WHERE",
        "AND",
        "OR",
        "LIKE",
        "IN",
        "LIMIT",
        "session",
        "agent",
        "true",
        "false",
        "null",
    ];
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();
    while let Some(&ch) = chars.peek() {
        match ch {
            ' ' | '\t' | '\n' | '\r' => {
                chars.next();
            }
            '(' | ')' | ',' | ':' | '=' | '<' | '>' | '*' | '{' | '}' | '.' => {
                tokens.push(Token::Punct(ch));
                chars.next();
            }
            '$' => {
                // `$name`: the name may be empty if nothing valid follows.
                chars.next();
                let mut name = String::new();
                while let Some(&c) = chars.peek() {
                    if c.is_alphanumeric() || c == '_' {
                        name.push(c);
                        chars.next();
                    } else {
                        break;
                    }
                }
                tokens.push(Token::Variable(name));
            }
            '\'' | '"' => {
                // Quoted string without escape handling; an unterminated
                // string runs to the end of the input.
                let quote = ch;
                chars.next();
                let mut s = String::new();
                while let Some(&c) = chars.peek() {
                    if c == quote {
                        chars.next();
                        break;
                    }
                    s.push(c);
                    chars.next();
                }
                tokens.push(Token::String(s));
            }
            '0'..='9' | '-' => {
                let mut num_str = String::new();
                if ch == '-' {
                    num_str.push(ch);
                    chars.next();
                }
                while let Some(&c) = chars.peek() {
                    if c.is_ascii_digit() || c == '.' {
                        num_str.push(c);
                        chars.next();
                    } else {
                        break;
                    }
                }
                // Text that is not a valid number (e.g. a lone '-') is dropped.
                if let Ok(n) = num_str.parse::<f64>() {
                    tokens.push(Token::Number(n));
                }
            }
            'a'..='z' | 'A'..='Z' | '_' => {
                let mut ident = String::new();
                while let Some(&c) = chars.peek() {
                    if c.is_alphanumeric() || c == '_' {
                        ident.push(c);
                        chars.next();
                    } else {
                        break;
                    }
                }
                if KEYWORDS.iter().any(|k| k.eq_ignore_ascii_case(&ident)) {
                    // Keep the original case (see the function docs).
                    tokens.push(Token::Keyword(ident));
                } else {
                    tokens.push(Token::Ident(ident));
                }
            }
            _ => {
                chars.next();
            }
        }
    }
    tokens.push(Token::Eof);
    tokens
}
/// Returns the token at the cursor, or `Eof` once the cursor is past the end.
fn current(&self) -> &Token {
    match self.tokens.get(self.pos) {
        Some(token) => token,
        None => &Token::Eof,
    }
}
/// Moves the cursor one token forward, stopping once it is past the last token.
fn advance(&mut self) {
    if self.pos >= self.tokens.len() {
        return;
    }
    self.pos += 1;
}
/// Consumes the keyword `kw` (case-insensitive) or fails with `UnexpectedToken`.
fn expect_keyword(&mut self, kw: &str) -> Result<(), ContextParseError> {
    if self.match_keyword(kw) {
        return Ok(());
    }
    Err(ContextParseError::UnexpectedToken {
        expected: kw.to_string(),
        found: format!("{:?}", self.current()),
    })
}
/// Consumes the keyword `kw` (case-insensitive) if it is the current token and
/// reports whether it did.
fn match_keyword(&mut self, kw: &str) -> bool {
    let hit = matches!(self.current(), Token::Keyword(k) if k.eq_ignore_ascii_case(kw));
    if hit {
        self.advance();
    }
    hit
}
/// Consumes an identifier and returns its text; keyword tokens are accepted
/// as identifiers too.
fn expect_ident(&mut self) -> Result<String, ContextParseError> {
    let text = match self.current() {
        Token::Ident(s) | Token::Keyword(s) => s.clone(),
        other => {
            return Err(ContextParseError::UnexpectedToken {
                expected: "identifier".to_string(),
                found: format!("{:?}", other),
            })
        }
    };
    self.advance();
    Ok(text)
}
/// Consumes a `$variable` token and returns its name (without the `$`).
fn expect_variable(&mut self) -> Result<String, ContextParseError> {
    if let Token::Variable(name) = self.current() {
        let name = name.clone();
        self.advance();
        return Ok(name);
    }
    Err(ContextParseError::UnexpectedToken {
        expected: "variable ($name)".to_string(),
        found: format!("{:?}", self.current()),
    })
}
/// Consumes the punctuation character `p` or fails with `UnexpectedToken`.
fn expect_punct(&mut self, p: char) -> Result<(), ContextParseError> {
    if self.check_punct(p) {
        self.advance();
        return Ok(());
    }
    Err(ContextParseError::UnexpectedToken {
        expected: p.to_string(),
        found: format!("{:?}", self.current()),
    })
}
fn match_punct(&mut self, p: char) -> bool {
match self.current() {
Token::Punct(c) if *c == p => {
self.advance();
true
}
_ => false,
}
}
/// Reports whether the current token is the punctuation character `p`,
/// without consuming it.
fn check_punct(&self, p: char) -> bool {
    match self.current() {
        Token::Punct(c) => *c == p,
        _ => false,
    }
}
/// Re-assembles source text from tokens up to (not including) the first
/// terminator that appears outside `{...}`, or up to end of input.
///
/// Punctuation is appended as-is; identifiers and keywords are appended with
/// a separating space unless they follow `.` or `{`. Number, string and
/// variable tokens are skipped. The result is trimmed.
fn collect_until(&mut self, terminators: &[char]) -> String {
    let mut out = String::new();
    let mut depth = 0;
    loop {
        match self.current().clone() {
            Token::Eof => break,
            Token::Punct(c) => {
                // Braces only adjust nesting and are never terminators.
                if c == '{' {
                    depth += 1;
                } else if c == '}' {
                    depth -= 1;
                } else if depth == 0 && terminators.contains(&c) {
                    break;
                }
                out.push(c);
            }
            Token::Ident(word) | Token::Keyword(word) => {
                let needs_space = !(out.is_empty() || out.ends_with(['.', '{']));
                if needs_space {
                    out.push(' ');
                }
                out.push_str(&word);
            }
            _ => {}
        }
        self.advance();
    }
    out.trim().to_string()
}
}
use crate::agent_context::{AgentContext, AuditOperation, ContextValue};
/// Executes context queries against a mutably borrowed `AgentContext`.
pub struct AgentContextIntegration<'a> {
// The agent context; receives audit entries and budget consumption.
context: &'a mut AgentContext,
// Allocates the token budget across sections.
budget_enforcer: TokenBudgetEnforcer,
// Estimates and truncates text by token count.
estimator: TokenEstimator,
// Optional vector index (presumably used by SEARCH sections — confirm).
vector_index: Option<std::sync::Arc<dyn VectorIndex>>,
// Optional embedding provider (presumably for text similarity queries — confirm).
embedding_provider: Option<std::sync::Arc<dyn EmbeddingProvider>>,
}
/// Source of text embeddings.
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single text.
    fn embed_text(&self, text: &str) -> Result<Vec<f32>, String>;

    /// Embeds each text in order by calling [`embed_text`](Self::embed_text),
    /// stopping at and returning the first error.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, String> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed_text(text)?);
        }
        Ok(embeddings)
    }

    /// Length of the vectors this provider produces.
    fn dimension(&self) -> usize;

    /// Name of the underlying embedding model.
    fn model_name(&self) -> &str;
}
impl<'a> AgentContextIntegration<'a> {
/// Wraps `context` without a vector index or embedding provider.
///
/// The total token budget is taken from the context's `max_tokens`, falling
/// back to 4096 when unset.
pub fn new(context: &'a mut AgentContext) -> Self {
    let total_budget = context.budget.max_tokens.unwrap_or(4096) as usize;
    let budget_enforcer = TokenBudgetEnforcer::new(TokenBudgetConfig {
        total_budget,
        ..Default::default()
    });
    Self {
        context,
        budget_enforcer,
        estimator: TokenEstimator::default(),
        vector_index: None,
        embedding_provider: None,
    }
}
/// Like [`new`](Self::new), but with a vector index attached.
pub fn with_vector_index(
    context: &'a mut AgentContext,
    vector_index: std::sync::Arc<dyn VectorIndex>,
) -> Self {
    let total_budget = context.budget.max_tokens.unwrap_or(4096) as usize;
    let enforcer_config = TokenBudgetConfig {
        total_budget,
        ..Default::default()
    };
    let budget_enforcer = TokenBudgetEnforcer::new(enforcer_config);
    let estimator = TokenEstimator::default();
    Self {
        context,
        budget_enforcer,
        estimator,
        vector_index: Some(vector_index),
        embedding_provider: None,
    }
}
/// Like [`new`](Self::new), but with both a vector index and an embedding
/// provider attached.
pub fn with_vector_and_embedding(
    context: &'a mut AgentContext,
    vector_index: std::sync::Arc<dyn VectorIndex>,
    embedding_provider: std::sync::Arc<dyn EmbeddingProvider>,
) -> Self {
    let total_budget = context.budget.max_tokens.unwrap_or(4096) as usize;
    let budget_enforcer = TokenBudgetEnforcer::new(TokenBudgetConfig {
        total_budget,
        ..Default::default()
    });
    let estimator = TokenEstimator::default();
    let vector_index = Some(vector_index);
    let embedding_provider = Some(embedding_provider);
    Self {
        context,
        budget_enforcer,
        estimator,
        vector_index,
        embedding_provider,
    }
}
/// Installs `provider`, replacing any previously set embedding provider.
pub fn set_embedding_provider(&mut self, provider: std::sync::Arc<dyn EmbeddingProvider>) {
    let _ = self.embedding_provider.replace(provider);
}
/// Installs `index`, replacing any previously set vector index.
pub fn set_vector_index(&mut self, index: std::sync::Arc<dyn VectorIndex>) {
    let _ = self.vector_index.replace(index);
}
/// Executes `query` and returns the budget-fitted sections.
///
/// Steps, in order:
/// 1. validate the query's session reference;
/// 2. append a `DbQuery` audit entry to the agent context;
/// 3. resolve the sections and check permissions for each;
/// 4. produce the content of every section;
/// 5. let the budget enforcer decide which sections are kept in full and
///    which are truncated;
/// 6. collect the results sorted by ascending priority and charge the
///    allocated tokens to the agent context's budget.
///
/// # Errors
///
/// Propagates errors from session validation, section resolution,
/// permission checks and content execution, and returns
/// [`ContextQueryError::BudgetExhausted`] when the agent context rejects
/// the token consumption.
pub fn execute(
    &mut self,
    query: &ContextSelectQuery,
) -> Result<ContextQueryResult, ContextQueryError> {
    self.validate_session(&query.session)?;
    // The audit entry is recorded as soon as the session is validated, so it
    // is present even if a later step fails.
    self.context.audit.push(crate::agent_context::AuditEntry {
        timestamp: std::time::SystemTime::now(),
        operation: AuditOperation::DbQuery,
        resource: format!("CONTEXT SELECT {}", query.output_name),
        result: crate::agent_context::AuditResult::Success,
        metadata: std::collections::HashMap::new(),
    });
    let resolved_sections = self.resolve_sections(&query.sections)?;
    for section in &resolved_sections {
        self.check_section_permissions(section)?;
    }
    let mut section_contents: Vec<(ContextSection, String)> = Vec::new();
    for section in &resolved_sections {
        let content = self.execute_section_content(section, query.options.token_limit)?;
        section_contents.push((section.clone(), content));
    }
    let budget_sections: Vec<BudgetSection> = section_contents
        .iter()
        .map(|(section, content)| {
            let estimated = self.estimator.estimate_text(content);
            // With `Fail` there is no truncation fallback; otherwise a
            // section may shrink to min(estimated, 100) tokens, but never
            // below a tenth of its estimate.
            let minimum = if query.options.truncation == TruncationStrategy::Fail {
                None
            } else {
                Some(estimated.min(100).max(estimated / 10))
            };
            BudgetSection {
                name: section.name.clone(),
                estimated_tokens: estimated,
                minimum_tokens: minimum,
                priority: section.priority,
                // Priority 0 marks a section as required.
                required: section.priority == 0,
                weight: 1.0,
            }
        })
        .collect();
    let allocation = self.budget_enforcer.allocate_sections(&budget_sections);
    let mut result = ContextQueryResult::new(query.output_name.clone());
    result.format = query.options.format;
    result.allocation_explain = Some(allocation.explain.clone());
    // Sections that fit in full.
    for (section, content) in section_contents.iter() {
        if allocation.full_sections.contains(&section.name) {
            let tokens = self.estimator.estimate_text(content);
            result.sections.push(SectionResult {
                name: section.name.clone(),
                priority: section.priority,
                content: content.clone(),
                tokens,
                tokens_used: tokens,
                truncated: false,
                row_count: 0,
            });
        }
    }
    // Sections that had to be cut down to their allocated size.
    for (section_name, _original, truncated_to) in &allocation.truncated_sections {
        if let Some((section, content)) = section_contents
            .iter()
            .find(|(s, _)| &s.name == section_name)
        {
            let truncated = self.estimator.truncate_to_tokens(content, *truncated_to);
            let actual_tokens = self.estimator.estimate_text(&truncated);
            result.sections.push(SectionResult {
                name: section.name.clone(),
                priority: section.priority,
                content: truncated,
                tokens: actual_tokens,
                tokens_used: actual_tokens,
                truncated: true,
                row_count: 0,
            });
        }
    }
    result.sections.sort_by_key(|s| s.priority);
    result.total_tokens = allocation.tokens_allocated;
    result.token_limit = query.options.token_limit;
    self.context
        .consume_budget(result.total_tokens as u64, 0)
        .map_err(|e| ContextQueryError::BudgetExhausted(e.to_string()))?;
    Ok(result)
}
/// Runs [`Self::execute`] and additionally returns a human-readable allocation explanation.
///
/// The explanation is rebuilt from the finished result: untruncated sections are reported as
/// full, truncated ones with their `tokens` / `tokens_used` values, and the list of dropped
/// sections is always empty. When the result carries no allocation decisions the text
/// `"No allocation explain available"` is returned instead.
///
/// # Errors
///
/// Returns whatever error [`Self::execute`] produces.
pub fn execute_explain(
    &mut self,
    query: &ContextSelectQuery,
) -> Result<(ContextQueryResult, String), ContextQueryError> {
    use crate::token_budget::BudgetAllocation;

    let result = self.execute(query)?;

    let explain = match &result.allocation_explain {
        None => String::from("No allocation explain available"),
        Some(decisions) => {
            let full_sections = result
                .sections
                .iter()
                .filter(|section| !section.truncated)
                .map(|section| section.name.clone())
                .collect();
            let truncated_sections = result
                .sections
                .iter()
                .filter(|section| section.truncated)
                .map(|section| (section.name.clone(), section.tokens, section.tokens_used))
                .collect();
            let tokens_remaining = result.token_limit.saturating_sub(result.total_tokens);

            BudgetAllocation {
                full_sections,
                truncated_sections,
                dropped_sections: Vec::new(),
                tokens_allocated: result.total_tokens,
                tokens_remaining,
                explain: decisions.clone(),
            }
            .explain_text()
        }
    };

    Ok((result, explain))
}
/// Checks that the session reference of a query is compatible with the current context.
///
/// * `Session(id)` is accepted when `id` starts with `$`, equals the context's session id,
///   or is the wildcard `*`.
/// * `Agent(id)` is rejected only when the context holds a string variable `agent_id` that
///   differs from `id` and `id` is not the wildcard `*`.
/// * `None` is always accepted.
///
/// # Errors
///
/// Returns [`ContextQueryError::SessionMismatch`] carrying the requested and the actual id.
fn validate_session(&self, session_ref: &SessionReference) -> Result<(), ContextQueryError> {
    match session_ref {
        SessionReference::None => Ok(()),
        SessionReference::Session(sid) => {
            let is_placeholder = sid.starts_with('$');
            let is_allowed = sid == &self.context.session_id || sid == "*";
            if is_placeholder || is_allowed {
                Ok(())
            } else {
                Err(ContextQueryError::SessionMismatch {
                    expected: sid.clone(),
                    actual: self.context.session_id.clone(),
                })
            }
        }
        SessionReference::Agent(aid) => match self.context.peek_var("agent_id") {
            Some(ContextValue::String(agent_id)) if aid != agent_id && aid != "*" => {
                Err(ContextQueryError::SessionMismatch {
                    expected: aid.clone(),
                    actual: agent_id.clone(),
                })
            }
            // No `agent_id` variable, a non-string one, a match, or a wildcard request.
            _ => Ok(()),
        },
    }
}
/// Returns a copy of `sections` in which every context-variable reference is resolved.
///
/// * `Literal` sections get `$variable` placeholders substituted.
/// * `Variable` sections become `Literal` sections holding the variable's string form.
/// * `Search` sections whose query is a variable are rewritten to a text query (string
///   variable) or an embedding query (list variable).
/// * All other sections are copied unchanged.
///
/// List variables used as embeddings must consist of numbers or numeric strings, the same
/// rule `execute_section_content` applies. Any other element is an error rather than being
/// dropped, because dropping elements would silently change the vector's dimension.
///
/// # Errors
///
/// * [`ContextQueryError::VariableNotFound`] when a referenced variable does not exist.
/// * [`ContextQueryError::InvalidVariableType`] when a search variable is neither a string
///   nor a list of numeric values.
fn resolve_sections(
    &self,
    sections: &[ContextSection],
) -> Result<Vec<ContextSection>, ContextQueryError> {
    let mut resolved = Vec::with_capacity(sections.len());
    for section in sections {
        let content = match &section.content {
            SectionContent::Literal { value } => SectionContent::Literal {
                value: self.resolve_variables(value),
            },
            SectionContent::Variable { name } => match self.context.peek_var(name) {
                Some(value) => SectionContent::Literal {
                    value: value.to_string(),
                },
                None => return Err(ContextQueryError::VariableNotFound(name.clone())),
            },
            SectionContent::Search {
                collection,
                query,
                top_k,
                min_score,
            } => {
                let resolved_query = match query {
                    SimilarityQuery::Variable(var) => {
                        let invalid_type = || ContextQueryError::InvalidVariableType {
                            variable: var.clone(),
                            expected: "string or vector".to_string(),
                        };
                        match self.context.peek_var(var) {
                            None => {
                                return Err(ContextQueryError::VariableNotFound(var.clone()));
                            }
                            Some(ContextValue::String(text)) => {
                                SimilarityQuery::Text(text.clone())
                            }
                            Some(ContextValue::List(items)) => {
                                let embedding = items
                                    .iter()
                                    .map(|item| match item {
                                        ContextValue::Number(n) => Ok(*n as f32),
                                        ContextValue::String(s) => {
                                            s.parse::<f32>().map_err(|_| invalid_type())
                                        }
                                        _ => Err(invalid_type()),
                                    })
                                    .collect::<Result<Vec<f32>, ContextQueryError>>()?;
                                SimilarityQuery::Embedding(embedding)
                            }
                            Some(_) => return Err(invalid_type()),
                        }
                    }
                    other => other.clone(),
                };
                SectionContent::Search {
                    collection: collection.clone(),
                    query: resolved_query,
                    top_k: *top_k,
                    min_score: *min_score,
                }
            }
            other => other.clone(),
        };

        let mut resolved_section = section.clone();
        resolved_section.content = content;
        resolved.push(resolved_section);
    }
    Ok(resolved)
}
/// Substitutes context variables in `input` by delegating to the context's
/// `substitute_vars` and returns the resulting string.
fn resolve_variables(&self, input: &str) -> String {
self.context.substitute_vars(input)
}
/// Verifies that the current context may read the resource a section refers to.
///
/// * `Get` with a path string starting with `/` is checked as a filesystem read; any other
///   `Get` path is checked as a database query on its first segment.
/// * `Last`, `Select` and `Search` are checked as database queries on their table or
///   collection.
/// * `Literal`, `Variable`, `ToolRegistry` and `ToolCalls` need no permission.
///
/// # Errors
///
/// * [`ContextQueryError::PermissionDenied`] when the context rejects the access.
/// * [`ContextQueryError::InvalidPath`] when a `Get` path has no segments.
fn check_section_permissions(&self, section: &ContextSection) -> Result<(), ContextQueryError> {
    match &section.content {
        SectionContent::Get { path } => {
            let path_str = path.to_path_string();
            if path_str.starts_with('/') {
                self.context
                    .check_fs_permission(&path_str, AuditOperation::FsRead)
                    .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
            } else {
                let table = path
                    .segments
                    .first()
                    .ok_or_else(|| ContextQueryError::InvalidPath("empty path".to_string()))?;
                self.context
                    .check_db_permission(table, AuditOperation::DbQuery)
                    .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
            }
        }
        // Tables and vector collections go through the same database permission check.
        SectionContent::Last {
            table: resource, ..
        }
        | SectionContent::Select {
            table: resource, ..
        }
        | SectionContent::Search {
            collection: resource,
            ..
        } => {
            self.context
                .check_db_permission(resource, AuditOperation::DbQuery)
                .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
        }
        // These sections only read state already held by the context.
        SectionContent::Literal { .. }
        | SectionContent::Variable { .. }
        | SectionContent::ToolRegistry { .. }
        | SectionContent::ToolCalls { .. } => {}
    }
    Ok(())
}
/// Produces the textual content of one section.
///
/// * `Literal` returns its value; `Variable` returns the variable's string form.
/// * `Get`, `Last` and `Select` return a bracketed placeholder describing the request.
/// * `Search` runs against the vector index when one is configured and formats the hits;
///   without an index it returns a placeholder. A failing search is not an error: it is
///   rendered into the section as `[<name>: search error: <message>]`.
/// * `ToolRegistry` and `ToolCalls` are formatted from the context's tool data.
///
/// `_budget` is accepted for interface compatibility and is currently unused.
///
/// # Errors
///
/// Returns [`ContextQueryError::VariableNotFound`] for an unknown `Variable` section and
/// propagates formatting errors from the helper formatters.
fn execute_section_content(
    &self,
    section: &ContextSection,
    _budget: usize,
) -> Result<String, ContextQueryError> {
    match &section.content {
        SectionContent::Literal { value } => Ok(value.clone()),
        SectionContent::Variable { name } => self
            .context
            .peek_var(name)
            .map(|v| v.to_string())
            .ok_or_else(|| ContextQueryError::VariableNotFound(name.clone())),
        SectionContent::Get { path } => Ok(format!(
            "[{}: path={}]",
            section.name,
            path.to_path_string()
        )),
        SectionContent::Last { count, table, .. } => {
            Ok(format!("[{}: last {} from {}]", section.name, count, table))
        }
        SectionContent::Search {
            collection,
            query: similarity_query,
            top_k,
            min_score,
        } => {
            let Some(index) = &self.vector_index else {
                return Ok(format!(
                    "[{}: search {} top {}]",
                    section.name, collection, top_k
                ));
            };

            let results = match similarity_query {
                SimilarityQuery::Embedding(emb) => {
                    index.search_by_embedding(collection, emb, *top_k, *min_score)
                }
                SimilarityQuery::Text(text) => self.search_by_text_with_embedding(
                    index, collection, text, *top_k, *min_score,
                ),
                // `resolve_sections` normally replaces variable queries before this point;
                // this arm covers sections that were not resolved first.
                SimilarityQuery::Variable(var_name) => match self.context.peek_var(var_name) {
                    Some(ContextValue::String(text)) => self.search_by_text_with_embedding(
                        index, collection, text, *top_k, *min_score,
                    ),
                    Some(ContextValue::List(list)) => {
                        let embedding: Result<Vec<f32>, &str> = list
                            .iter()
                            .map(|v| match v {
                                ContextValue::Number(n) => Ok(*n as f32),
                                ContextValue::String(s) => {
                                    s.parse::<f32>().map_err(|_| "not a number")
                                }
                                _ => Err("not a number"),
                            })
                            .collect();
                        embedding
                            .map_err(|_| "Variable is not a valid embedding vector".to_string())
                            .and_then(|emb| {
                                index.search_by_embedding(collection, &emb, *top_k, *min_score)
                            })
                    }
                    _ => Err(format!(
                        "Variable '{}' not found or has wrong type",
                        var_name
                    )),
                },
            };

            match results {
                Ok(search_results) => {
                    self.format_search_results(&section.name, &search_results)
                }
                // Deliberately best-effort: the failure is reported inside the section text.
                Err(e) => Ok(format!("[{}: search error: {}]", section.name, e)),
            }
        }
        SectionContent::Select { table, limit, .. } => {
            let limit_str = limit.map(|l| format!(" limit {}", l)).unwrap_or_default();
            Ok(format!(
                "[{}: select from {}{}]",
                section.name, table, limit_str
            ))
        }
        SectionContent::ToolRegistry {
            include,
            exclude,
            include_schema,
        } => self.format_tool_registry(include, exclude, *include_schema),
        SectionContent::ToolCalls {
            count,
            tool_filter,
            status_filter,
            include_outputs,
        } => self.format_tool_calls(
            *count,
            tool_filter.as_deref(),
            status_filter.as_deref(),
            *include_outputs,
        ),
    }
}
fn format_tool_registry(
&self,
include: &[String],
exclude: &[String],
include_schema: bool,
) -> Result<String, ContextQueryError> {
use std::fmt::Write;
let tools = &self.context.tool_registry;
let mut output = String::new();
writeln!(output, "[tool_registry ({} tools)]", tools.len())
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
for tool in tools {
if !include.is_empty() && !include.contains(&tool.name) {
continue;
}
if exclude.contains(&tool.name) {
continue;
}
writeln!(output, " [{}]", tool.name)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
writeln!(output, " description = {:?}", tool.description)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
if include_schema {
if let Some(schema) = &tool.parameters_schema {
writeln!(output, " parameters = {}", schema)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
}
}
}
Ok(output)
}
fn format_tool_calls(
&self,
count: usize,
tool_filter: Option<&str>,
status_filter: Option<&str>,
include_outputs: bool,
) -> Result<String, ContextQueryError> {
use std::fmt::Write;
let calls = &self.context.tool_calls;
let mut output = String::new();
let filtered: Vec<_> = calls
.iter()
.filter(|call| {
tool_filter.map(|f| call.tool_name == f).unwrap_or(true)
&& status_filter
.map(|s| match s {
"success" => call.result.is_some() && call.error.is_none(),
"error" => call.error.is_some(),
"pending" => call.result.is_none() && call.error.is_none(),
_ => true,
})
.unwrap_or(true)
})
.rev() .take(count)
.collect();
writeln!(output, "[tool_calls ({} calls)]", filtered.len())
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
for call in filtered {
writeln!(output, " [call {}]", call.call_id)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
writeln!(output, " tool = {:?}", call.tool_name)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
writeln!(output, " arguments = {:?}", call.arguments)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
if include_outputs {
if let Some(result) = &call.result {
writeln!(output, " result = {:?}", result)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
}
if let Some(error) = &call.error {
writeln!(output, " error = {:?}", error)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
}
}
}
Ok(output)
}
/// Runs a text similarity search against `collection`.
///
/// With an embedding provider configured, the text is embedded first and the index is
/// searched by vector; otherwise the request is handed to the index's own text search.
///
/// # Errors
///
/// Returns the error string produced by the embedding provider or by the index.
fn search_by_text_with_embedding(
    &self,
    index: &std::sync::Arc<dyn VectorIndex>,
    collection: &str,
    text: &str,
    k: usize,
    min_score: Option<f32>,
) -> Result<Vec<VectorSearchResult>, String> {
    let Some(provider) = &self.embedding_provider else {
        // No provider: let the index handle the raw text itself.
        return index.search_by_text(collection, text, k, min_score);
    };
    let query_vector = provider.embed_text(text)?;
    index.search_by_embedding(collection, &query_vector, k, min_score)
}
fn format_search_results(
&self,
section_name: &str,
results: &[VectorSearchResult],
) -> Result<String, ContextQueryError> {
use std::fmt::Write;
let mut output = String::new();
writeln!(output, "[{} ({} results)]", section_name, results.len())
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
for (i, result) in results.iter().enumerate() {
writeln!(output, " [result {} score={:.4}]", i + 1, result.score)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
writeln!(output, " id = {}", result.id)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
for line in result.content.lines() {
writeln!(output, " {}", line)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
}
if !result.metadata.is_empty() {
writeln!(output, " [metadata]")
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
for (key, value) in &result.metadata {
writeln!(output, " {} = {:?}", key, value)
.map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
}
}
}
Ok(output)
}
/// Truncates `content` to roughly `max_tokens` tokens using the given strategy.
///
/// The limit is approximated as four characters per token and is measured in characters
/// (not bytes), so multibyte text is cut at the same positions as ASCII text. Content that
/// already fits is returned unchanged.
///
/// * `TailDrop` keeps the beginning and appends `...`.
/// * `HeadDrop` keeps the end and prepends `...`.
/// * `Proportional` keeps the first and last quarter of the limit, each followed by `...`.
/// * `Fail` returns the content unchanged.
///
/// A limit shorter than the ellipsis keeps no content characters instead of underflowing.
#[allow(dead_code)] // `execute` does not call this helper.
fn truncate_content(
    &self,
    content: &str,
    max_tokens: usize,
    strategy: TruncationStrategy,
) -> String {
    const ELLIPSIS: &str = "...";

    let max_chars = max_tokens.saturating_mul(4);
    let total_chars = content.chars().count();
    if total_chars <= max_chars {
        return content.to_string();
    }

    // Number of content characters that remain once room for one ellipsis is reserved.
    let keep = max_chars.saturating_sub(ELLIPSIS.len());

    match strategy {
        TruncationStrategy::TailDrop => {
            let mut result: String = content.chars().take(keep).collect();
            result.push_str(ELLIPSIS);
            result
        }
        TruncationStrategy::HeadDrop => {
            // `total_chars > max_chars >= keep`, so this cannot underflow.
            let skip = total_chars - keep;
            let mut result = String::from(ELLIPSIS);
            result.extend(content.chars().skip(skip));
            result
        }
        TruncationStrategy::Proportional => {
            let quarter = max_chars / 4;
            let first: String = content.chars().take(quarter).collect();
            let last: String = content
                .chars()
                .skip(total_chars.saturating_sub(quarter))
                .collect();
            format!("{}...{}...", first, last)
        }
        TruncationStrategy::Fail => content.to_string(),
    }
}
/// Returns a snapshot of all context variables, each value converted to its string form.
pub fn get_session_context(&self) -> HashMap<String, String> {
    let mut snapshot = HashMap::new();
    for (name, value) in self.context.variables.iter() {
        snapshot.insert(name.clone(), value.to_string());
    }
    snapshot
}
/// Sets the context variable `name` to `value` by delegating to the context's `set_var`.
pub fn set_variable(&mut self, name: &str, value: ContextValue) {
self.context.set_var(name, value);
}
/// Returns how many tokens the context budget still allows.
///
/// With a configured maximum this is `max_tokens - tokens_used`, clamped at zero; without
/// one it is `u64::MAX`.
pub fn remaining_budget(&self) -> u64 {
    let budget = &self.context.budget;
    match budget.max_tokens {
        Some(limit) => limit.saturating_sub(budget.tokens_used),
        None => u64::MAX,
    }
}
}
/// Outcome of executing a `CONTEXT SELECT` query: the sections that made it into the
/// context plus token accounting.
#[derive(Debug, Clone)]
pub struct ContextQueryResult {
/// Name of the produced context, taken from the query's `output_name`.
pub output_name: String,
/// Sections included in the context.
pub sections: Vec<SectionResult>,
/// Total number of tokens accounted to this context.
pub total_tokens: usize,
/// Token limit the context was built against; `0` until set, which makes
/// `utilization` report `0.0`.
pub token_limit: usize,
/// Format used by `render`.
pub format: OutputFormat,
/// Allocation decisions recorded for this result, if any.
pub allocation_explain: Option<Vec<crate::token_budget::AllocationDecision>>,
}
impl ContextQueryResult {
    /// Creates an empty result named `output_name` with zeroed token counters, the `Soch`
    /// format and no allocation explanation.
    fn new(output_name: String) -> Self {
        Self {
            output_name,
            sections: Vec::new(),
            total_tokens: 0,
            token_limit: 0,
            format: OutputFormat::Soch,
            allocation_explain: None,
        }
    }

    /// Escapes `raw` for use inside a JSON string literal.
    ///
    /// Backslashes, double quotes and control characters below U+0020 are escaped so the
    /// surrounding document stays valid JSON.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }

    /// Renders the result as text in the configured [`OutputFormat`].
    ///
    /// * `Soch`: a header with the section count, then each section with its token count
    ///   (suffixed with `T` when truncated) and its content lines.
    /// * `Json`: an object with `name`, `total_tokens` and a `sections` array; all string
    ///   values are JSON-escaped.
    /// * `Markdown`: a title, a token summary and one heading per section.
    pub fn render(&self) -> String {
        let mut output = String::new();
        match self.format {
            OutputFormat::Soch => {
                output.push_str(&format!("{}[{}]:\n", self.output_name, self.sections.len()));
                for section in &self.sections {
                    output.push_str(&format!(
                        " {}[{}{}]:\n",
                        section.name,
                        section.tokens_used,
                        if section.truncated { "T" } else { "" }
                    ));
                    for line in section.content.lines() {
                        output.push_str(&format!(" {}\n", line));
                    }
                }
            }
            OutputFormat::Json => {
                output.push_str("{\n");
                output.push_str(&format!(
                    " \"name\": \"{}\",\n",
                    Self::escape_json(&self.output_name)
                ));
                output.push_str(&format!(" \"total_tokens\": {},\n", self.total_tokens));
                output.push_str(" \"sections\": [\n");
                for (i, section) in self.sections.iter().enumerate() {
                    output.push_str(&format!(
                        " {{\"name\": \"{}\", \"tokens\": {}, \"truncated\": {}, \"content\": \"{}\"}}",
                        Self::escape_json(&section.name),
                        section.tokens_used,
                        section.truncated,
                        Self::escape_json(&section.content)
                    ));
                    // Every element except the last is followed by a comma.
                    if i + 1 < self.sections.len() {
                        output.push(',');
                    }
                    output.push('\n');
                }
                output.push_str(" ]\n}");
            }
            OutputFormat::Markdown => {
                output.push_str(&format!("# {}\n\n", self.output_name));
                output.push_str(&format!(
                    "*Tokens: {}/{}*\n\n",
                    self.total_tokens, self.token_limit
                ));
                for section in &self.sections {
                    output.push_str(&format!("## {}", section.name));
                    if section.truncated {
                        output.push_str(" *(truncated)*");
                    }
                    output.push_str("\n\n");
                    output.push_str(&section.content);
                    output.push_str("\n\n");
                }
            }
        }
        output
    }

    /// Returns the share of the token limit that is in use, as a percentage.
    ///
    /// A limit of zero yields `0.0` instead of dividing by zero.
    pub fn utilization(&self) -> f64 {
        if self.token_limit == 0 {
            return 0.0;
        }
        (self.total_tokens as f64 / self.token_limit as f64) * 100.0
    }

    /// Returns `true` when at least one section was truncated.
    pub fn has_truncation(&self) -> bool {
        self.sections.iter().any(|s| s.truncated)
    }
}
/// Named priority levels for context sections.
///
/// The wrapped value orders sections: smaller numbers compare lower, so `CRITICAL` sorts
/// before `SUPPLEMENTARY`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SectionPriority(pub i32);

impl SectionPriority {
    /// Lowest numeric value of the predefined levels.
    pub const CRITICAL: Self = Self(-100);
    /// System-level content.
    pub const SYSTEM: Self = Self(-1);
    /// User-level content.
    pub const USER: Self = Self(0);
    /// Conversation or event history.
    pub const HISTORY: Self = Self(1);
    /// Retrieved knowledge.
    pub const KNOWLEDGE: Self = Self(2);
    /// Highest numeric value of the predefined levels.
    pub const SUPPLEMENTARY: Self = Self(10);
}
/// Fluent builder for [`ContextSelectQuery`] values.
///
/// Start with `ContextQueryBuilder::new`, chain option and section methods, and finish with
/// `build`.
pub struct ContextQueryBuilder {
// Name of the context the query will produce.
output_name: String,
// Session or agent the query is scoped to; `SessionReference::None` by default.
session: SessionReference,
// Query options, starting from `ContextQueryOptions::default()`.
options: ContextQueryOptions,
// Sections in the order they were added.
sections: Vec<ContextSection>,
}
impl ContextQueryBuilder {
pub fn new(output_name: &str) -> Self {
Self {
output_name: output_name.to_string(),
session: SessionReference::None,
options: ContextQueryOptions::default(),
sections: Vec::new(),
}
}
pub fn from_session(mut self, session_id: &str) -> Self {
self.session = SessionReference::Session(session_id.to_string());
self
}
pub fn from_agent(mut self, agent_id: &str) -> Self {
self.session = SessionReference::Agent(agent_id.to_string());
self
}
pub fn with_token_limit(mut self, limit: usize) -> Self {
self.options.token_limit = limit;
self
}
pub fn include_schema(mut self, include: bool) -> Self {
self.options.include_schema = include;
self
}
pub fn format(mut self, format: OutputFormat) -> Self {
self.options.format = format;
self
}
pub fn truncation(mut self, strategy: TruncationStrategy) -> Self {
self.options.truncation = strategy;
self
}
pub fn get(mut self, name: &str, priority: i32, path: &str) -> Self {
let path_expr = PathExpression::parse(path).unwrap_or(PathExpression {
segments: vec![path.to_string()],
fields: vec![],
all_fields: true,
});
self.sections.push(ContextSection {
name: name.to_string(),
priority,
content: SectionContent::Get { path: path_expr },
transform: None,
});
self
}
pub fn last(mut self, name: &str, priority: i32, count: usize, table: &str) -> Self {
self.sections.push(ContextSection {
name: name.to_string(),
priority,
content: SectionContent::Last {
count,
table: table.to_string(),
where_clause: None,
},
transform: None,
});
self
}
pub fn search(
mut self,
name: &str,
priority: i32,
collection: &str,
query_var: &str,
top_k: usize,
) -> Self {
self.sections.push(ContextSection {
name: name.to_string(),
priority,
content: SectionContent::Search {
collection: collection.to_string(),
query: SimilarityQuery::Variable(query_var.to_string()),
top_k,
min_score: None,
},
transform: None,
});
self
}
pub fn literal(mut self, name: &str, priority: i32, value: &str) -> Self {
self.sections.push(ContextSection {
name: name.to_string(),
priority,
content: SectionContent::Literal {
value: value.to_string(),
},
transform: None,
});
self
}
pub fn build(self) -> ContextSelectQuery {
ContextSelectQuery {
output_name: self.output_name,
session: self.session,
options: self.options,
sections: self.sections,
}
}
}
// Unit tests for path parsing, the CONTEXT SELECT parser, the query builder and the
// in-memory `SimpleVectorIndex`.
#[cfg(test)]
mod tests {
use super::*;
// A dotted path without a field list parses into segments and selects all fields.
#[test]
fn test_path_expression_simple() {
let path = PathExpression::parse("user.profile").unwrap();
assert_eq!(path.segments, vec!["user", "profile"]);
assert!(path.all_fields);
}
// A trailing `{a, b}` group becomes the explicit field list.
#[test]
fn test_path_expression_with_fields() {
let path = PathExpression::parse("user.profile.{name, email}").unwrap();
assert_eq!(path.segments, vec!["user", "profile"]);
assert_eq!(path.fields, vec!["name", "email"]);
assert!(!path.all_fields);
}
// A trailing `**` is not kept as a segment and selects all fields.
#[test]
fn test_path_expression_glob() {
let path = PathExpression::parse("user.**").unwrap();
assert_eq!(path.segments, vec!["user"]);
assert!(path.all_fields);
}
// Full query with FROM, WITH and a single GET section.
#[test]
fn test_parse_simple_query() {
let query = r#"
CONTEXT SELECT prompt_context
FROM session($SESSION_ID)
WITH (token_limit = 2048, include_schema = true)
SECTIONS (
USER PRIORITY 0: GET user.profile.{name, preferences}
)
"#;
let mut parser = ContextQueryParser::new(query);
let result = parser.parse().unwrap();
assert_eq!(result.output_name, "prompt_context");
// The parsed session name does not include the leading `$`.
assert!(matches!(result.session, SessionReference::Session(s) if s == "SESSION_ID"));
assert_eq!(result.options.token_limit, 2048);
assert!(result.options.include_schema);
assert_eq!(result.sections.len(), 1);
assert_eq!(result.sections[0].name, "USER");
assert_eq!(result.sections[0].priority, 0);
}
// Literal, LAST and SEARCH sections are parsed in declaration order.
#[test]
fn test_parse_multiple_sections() {
let query = r#"
CONTEXT SELECT context
SECTIONS (
A PRIORITY 0: "literal value",
B PRIORITY 1: LAST 10 FROM logs,
C PRIORITY 2: SEARCH docs BY SIMILARITY($query) TOP 5
)
"#;
let mut parser = ContextQueryParser::new(query);
let result = parser.parse().unwrap();
assert_eq!(result.sections.len(), 3);
assert_eq!(result.sections[0].name, "A");
assert!(
matches!(&result.sections[0].content, SectionContent::Literal { value } if value == "literal value")
);
assert_eq!(result.sections[1].name, "B");
assert!(
matches!(&result.sections[1].content, SectionContent::Last { count: 10, table, .. } if table == "logs")
);
assert_eq!(result.sections[2].name, "C");
assert!(
matches!(&result.sections[2].content, SectionContent::Search { collection, top_k: 5, .. } if collection == "docs")
);
}
// The builder records options and keeps every added section, including negative priorities.
#[test]
fn test_builder() {
let query = ContextQueryBuilder::new("prompt")
.from_session("sess123")
.with_token_limit(4096)
.include_schema(false)
.get("USER", 0, "user.profile.{name, email}")
.last("HISTORY", 1, 20, "events")
.search("DOCS", 2, "knowledge_base", "query_embedding", 10)
.literal("SYSTEM", -1, "You are a helpful assistant")
.build();
assert_eq!(query.output_name, "prompt");
assert_eq!(query.options.token_limit, 4096);
assert!(!query.options.include_schema);
assert_eq!(query.sections.len(), 4);
let system = query.sections.iter().find(|s| s.name == "SYSTEM").unwrap();
assert_eq!(system.priority, -1);
}
// `format = markdown` in the WITH clause selects the Markdown output format.
#[test]
fn test_output_format() {
let query = r#"
CONTEXT SELECT ctx
WITH (format = markdown)
SECTIONS ()
"#;
let mut parser = ContextQueryParser::new(query);
let result = parser.parse().unwrap();
assert_eq!(result.options.format, OutputFormat::Markdown);
}
// `truncation = proportional` in the WITH clause selects the Proportional strategy.
#[test]
fn test_truncation_strategy() {
let query = r#"
CONTEXT SELECT ctx
WITH (truncation = proportional)
SECTIONS ()
"#;
let mut parser = ContextQueryParser::new(query);
let result = parser.parse().unwrap();
assert_eq!(result.options.truncation, TruncationStrategy::Proportional);
}
// A freshly created collection reports its dimension, zero vectors and the cosine metric.
#[test]
fn test_simple_vector_index_creation() {
let index = SimpleVectorIndex::new();
index.create_collection("test", 3);
let stats = index.stats("test");
assert!(stats.is_some());
let stats = stats.unwrap();
assert_eq!(stats.dimension, 3);
assert_eq!(stats.vector_count, 0);
assert_eq!(stats.metric, "cosine");
}
// Search returns the `k` best matches ordered by descending score.
#[test]
fn test_simple_vector_index_insert_and_search() {
let index = SimpleVectorIndex::new();
index.create_collection("docs", 3);
index
.insert(
"docs",
"doc1".to_string(),
vec![1.0, 0.0, 0.0],
"Document about cats".to_string(),
HashMap::new(),
)
.unwrap();
index
.insert(
"docs",
"doc2".to_string(),
vec![0.9, 0.1, 0.0],
"Document about dogs".to_string(),
HashMap::new(),
)
.unwrap();
index
.insert(
"docs",
"doc3".to_string(),
vec![0.0, 0.0, 1.0],
"Document about cars".to_string(),
HashMap::new(),
)
.unwrap();
let results = index
.search_by_embedding("docs", &[1.0, 0.0, 0.0], 2, None)
.unwrap();
assert_eq!(results.len(), 2);
// The identical vector scores ~1.0 and ranks first; the near-identical one ranks second.
assert_eq!(results[0].id, "doc1"); assert!((results[0].score - 1.0).abs() < 0.001);
assert_eq!(results[1].id, "doc2"); assert!(results[1].score > 0.9); }
// `min_score` filters out hits below the threshold; only the identical vector remains.
#[test]
fn test_simple_vector_index_min_score_filter() {
let index = SimpleVectorIndex::new();
index.create_collection("docs", 3);
index
.insert(
"docs",
"a".to_string(),
vec![1.0, 0.0, 0.0],
"A".to_string(),
HashMap::new(),
)
.unwrap();
index
.insert(
"docs",
"b".to_string(),
vec![0.0, 1.0, 0.0],
"B".to_string(),
HashMap::new(),
)
.unwrap();
index
.insert(
"docs",
"c".to_string(),
vec![0.0, 0.0, 1.0],
"C".to_string(),
HashMap::new(),
)
.unwrap();
let results = index
.search_by_embedding("docs", &[1.0, 0.0, 0.0], 10, Some(0.9))
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, "a");
}
// Inserting a 2-element vector into a 3-dimensional collection is rejected.
#[test]
fn test_simple_vector_index_dimension_mismatch() {
let index = SimpleVectorIndex::new();
index.create_collection("docs", 3);
let result = index.insert(
"docs",
"bad".to_string(),
vec![1.0, 0.0], "Content".to_string(),
HashMap::new(),
);
assert!(result.is_err());
assert!(result.unwrap_err().contains("dimension mismatch"));
}
// Searching a collection that was never created yields a "not found" error.
#[test]
fn test_simple_vector_index_nonexistent_collection() {
let index = SimpleVectorIndex::new();
let result = index.search_by_embedding("nonexistent", &[1.0], 1, None);
assert!(result.is_err());
assert!(result.unwrap_err().contains("not found"));
}
// Metadata stored with a vector is returned with its search hit.
#[test]
fn test_vector_index_with_metadata() {
let index = SimpleVectorIndex::new();
index.create_collection("docs", 2);
let mut metadata = HashMap::new();
metadata.insert("author".to_string(), SochValue::Text("Alice".to_string()));
metadata.insert("year".to_string(), SochValue::Int(2024));
index
.insert(
"docs",
"doc1".to_string(),
vec![1.0, 0.0],
"Document content".to_string(),
metadata,
)
.unwrap();
let results = index
.search_by_embedding("docs", &[1.0, 0.0], 1, None)
.unwrap();
assert_eq!(results.len(), 1);
assert!(results[0].metadata.contains_key("author"));
assert!(results[0].metadata.contains_key("year"));
}
// Text search on the simple index fails with an error that mentions "embedding model".
#[test]
fn test_vector_index_text_search_unsupported() {
let index = SimpleVectorIndex::new();
index.create_collection("docs", 2);
let result = index.search_by_text("docs", "hello", 5, None);
assert!(result.is_err());
assert!(result.unwrap_err().contains("embedding model"));
}
}