use std::collections::HashMap;
use std::sync::Arc;
use crate::connection::{CompareOp, SochConnection, WhereClause};
use sochdb_core::soch::SochValue;
/// Fluent builder that accumulates settings and sections for a [`ContextQuery`].
///
/// Start with `ContextQueryBuilder::new()` (or `with_connection`), chain
/// configuration and section methods, then call `build()` or `execute()`.
pub struct ContextQueryBuilder {
    // Session the context is assembled for; also echoed in the `# Context` header.
    session_id: Option<String>,
    // Agent identifier; only consulted by `to_canonical` when no session is set.
    agent_id: Option<String>,
    // Maximum number of estimated tokens the assembled sections may use.
    token_budget: usize,
    // Whether to prepend the `# Context` header to the assembled output.
    include_schema: bool,
    // Rendering format for row-producing sections.
    format: ContextFormat,
    // What to do with the first section that does not fit into the budget.
    truncation: TruncationStrategy,
    // Sections in insertion order; execution re-orders them by priority.
    sections: Vec<ContextSection>,
    // Named values referenced by `Variable` and `Search` sections.
    variables: HashMap<String, ContextValue>,
    // Database handle; when `None`, data-backed sections produce placeholder text.
    connection: Option<Arc<SochConnection>>,
}
/// Output format used when rendering rows fetched for a section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextFormat {
    /// Compact `table[count]{fields}` header followed by one ` [v1, v2]` line per row.
    Soch,
    /// A JSON array with one object per row.
    Json,
    /// A Markdown table with a header row and a `---` separator row.
    Markdown,
    /// `key: value` lines per row, with rows separated by a blank line.
    /// NOTE: `to_canonical` maps this to the canonical `Soch` format.
    Text,
}
/// How `ContextQuery::execute` handles the first section that would exceed the budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Truncate the overflowing section to the remaining budget and stop.
    TailDrop,
    /// Drop the overflowing section and keep trying the following sections.
    HeadDrop,
    /// Currently handled exactly like `TailDrop` by local execution.
    Proportional,
    /// Fail with `ContextQueryError::BudgetExceeded`.
    Strict,
}
/// A named value that sections can reference.
///
/// `Variable` sections render scalars as text: floats use two decimals, and
/// `Embedding`/`Binary` become a `<name>` placeholder. `Search` sections expect
/// their query variable to be an `Embedding`.
#[derive(Debug, Clone)]
pub enum ContextValue {
    String(String),
    Int(i64),
    Float(f64),
    Bool(bool),
    Embedding(Vec<f32>),
    Binary(Vec<u8>),
}
impl ContextQueryBuilder {
pub fn new() -> Self {
Self {
session_id: None,
agent_id: None,
token_budget: 4096,
include_schema: true,
format: ContextFormat::Soch,
truncation: TruncationStrategy::TailDrop,
sections: Vec::new(),
variables: HashMap::new(),
connection: None,
}
}
pub fn with_connection(conn: Arc<SochConnection>) -> Self {
Self {
session_id: None,
agent_id: None,
token_budget: 4096,
include_schema: true,
format: ContextFormat::Soch,
truncation: TruncationStrategy::TailDrop,
sections: Vec::new(),
variables: HashMap::new(),
connection: Some(conn),
}
}
pub fn connection(mut self, conn: Arc<SochConnection>) -> Self {
self.connection = Some(conn);
self
}
pub fn for_session(mut self, session_id: &str) -> Self {
self.session_id = Some(session_id.to_string());
self
}
pub fn for_agent(mut self, agent_id: &str) -> Self {
self.agent_id = Some(agent_id.to_string());
self
}
pub fn with_budget(mut self, budget: usize) -> Self {
self.token_budget = budget;
self
}
pub fn include_schema(mut self, include: bool) -> Self {
self.include_schema = include;
self
}
pub fn format(mut self, format: ContextFormat) -> Self {
self.format = format;
self
}
pub fn truncation(mut self, strategy: TruncationStrategy) -> Self {
self.truncation = strategy;
self
}
pub fn set_var(mut self, name: &str, value: ContextValue) -> Self {
self.variables.insert(name.to_string(), value);
self
}
pub fn section(self, name: &str, priority: i32) -> SectionBuilder {
SectionBuilder {
parent: self,
name: name.to_string(),
priority,
content: None,
filter: None,
transform: None,
}
}
pub fn literal(mut self, name: &str, priority: i32, text: &str) -> Self {
self.sections.push(ContextSection {
name: name.to_string(),
priority,
content: SectionContent::Literal(text.to_string()),
filter: None,
transform: None,
});
self
}
pub fn variable(mut self, name: &str, priority: i32, var_name: &str) -> Self {
self.sections.push(ContextSection {
name: name.to_string(),
priority,
content: SectionContent::Variable(var_name.to_string()),
filter: None,
transform: None,
});
self
}
pub fn build(self) -> ContextQuery {
ContextQuery {
session_id: self.session_id,
agent_id: self.agent_id,
token_budget: self.token_budget,
include_schema: self.include_schema,
format: self.format,
truncation: self.truncation,
sections: self.sections,
variables: self.variables,
connection: self.connection,
}
}
pub fn execute(self) -> Result<ContextQueryResult, ContextQueryError> {
let query = self.build();
query.execute()
}
}
impl Default for ContextQueryBuilder {
fn default() -> Self {
Self::new()
}
}
/// Builder for a single section.
///
/// Returned by `ContextQueryBuilder::section`. `done()` hands the finished
/// section back to the parent builder.
pub struct SectionBuilder {
    // The query builder this section will be appended to.
    parent: ContextQueryBuilder,
    name: String,
    priority: i32,
    // `None` until a content method (`get`, `last`, `search`, `select`) is called;
    // `done()` falls back to an empty literal.
    content: Option<SectionContent>,
    filter: Option<FilterExpr>,
    transform: Option<TransformExpr>,
}
impl SectionBuilder {
pub fn get(mut self, path: &str) -> Self {
self.content = Some(SectionContent::Get(path.to_string()));
self
}
pub fn last(mut self, count: usize, table: &str) -> Self {
self.content = Some(SectionContent::Last {
count,
table: table.to_string(),
});
self
}
pub fn search(mut self, collection: &str, query_var: &str, top_k: usize) -> Self {
self.content = Some(SectionContent::Search {
collection: collection.to_string(),
query_var: query_var.to_string(),
top_k,
min_score: None,
});
self
}
pub fn select(mut self, columns: &[&str], table: &str) -> Self {
self.content = Some(SectionContent::Select {
columns: columns.iter().map(|s| s.to_string()).collect(),
table: table.to_string(),
limit: None,
});
self
}
pub fn where_eq(mut self, column: &str, value: &str) -> Self {
let filter = FilterExpr::Eq(column.to_string(), value.to_string());
self.filter = match self.filter {
None => Some(filter),
Some(existing) => Some(FilterExpr::And(vec![existing, filter])),
};
self
}
pub fn where_gt(mut self, column: &str, value: i64) -> Self {
let filter = FilterExpr::Gt(column.to_string(), value);
self.filter = match self.filter {
None => Some(filter),
Some(existing) => Some(FilterExpr::And(vec![existing, filter])),
};
self
}
pub fn where_lt(mut self, column: &str, value: i64) -> Self {
let filter = FilterExpr::Lt(column.to_string(), value);
self.filter = match self.filter {
None => Some(filter),
Some(existing) => Some(FilterExpr::And(vec![existing, filter])),
};
self
}
pub fn where_like(mut self, column: &str, pattern: &str) -> Self {
let filter = FilterExpr::Like(column.to_string(), pattern.to_string());
self.filter = match self.filter {
None => Some(filter),
Some(existing) => Some(FilterExpr::And(vec![existing, filter])),
};
self
}
pub fn limit(mut self, limit_val: usize) -> Self {
if let Some(SectionContent::Select { ref mut limit, .. }) = self.content {
*limit = Some(limit_val);
}
self
}
pub fn min_score(mut self, score: f32) -> Self {
if let Some(SectionContent::Search {
ref mut min_score, ..
}) = self.content
{
*min_score = Some(score);
}
self
}
pub fn summarize(mut self, max_tokens: usize) -> Self {
self.transform = Some(TransformExpr::Summarize(max_tokens));
self
}
pub fn project(mut self, fields: &[&str]) -> Self {
self.transform = Some(TransformExpr::Project(
fields.iter().map(|s| s.to_string()).collect(),
));
self
}
pub fn done(mut self) -> ContextQueryBuilder {
let section = ContextSection {
name: self.name,
priority: self.priority,
content: self
.content
.unwrap_or(SectionContent::Literal(String::new())),
filter: self.filter,
transform: self.transform,
};
self.parent.sections.push(section);
self.parent
}
}
/// What a section renders.
#[derive(Debug, Clone)]
pub enum SectionContent {
    /// Path lookup resolved through the connection's path resolver.
    Get(String),
    /// The last `count` rows of `table`.
    /// NOTE(review): execution orders by `created_at`; confirm the sort direction.
    Last { count: usize, table: String },
    /// Similarity search over `collection` using the embedding in variable `query_var`.
    /// Local execution currently only emits a descriptive placeholder.
    Search {
        collection: String,
        query_var: String,
        top_k: usize,
        min_score: Option<f32>,
    },
    /// Projection of `columns` from `table`, optionally capped at `limit` rows.
    Select {
        columns: Vec<String>,
        table: String,
        limit: Option<usize>,
    },
    /// Fixed text included verbatim.
    Literal(String),
    /// Name of a query variable whose value is rendered as text.
    Variable(String),
}
/// Row filter attached to a section (used by `Last` and `Select` content).
///
/// `Eq`, `Like` and `In` compare as text; the ordering operators compare
/// against integers.
#[derive(Debug, Clone)]
pub enum FilterExpr {
    Eq(String, String),
    Gt(String, i64),
    Lt(String, i64),
    Ge(String, i64),
    Le(String, i64),
    Like(String, String),
    In(String, Vec<String>),
    And(Vec<FilterExpr>),
    Or(Vec<FilterExpr>),
}
/// Post-processing step for a section.
///
/// These are forwarded by `to_canonical`; `ContextQuery::execute` does not
/// apply them.
#[derive(Debug, Clone)]
pub enum TransformExpr {
    /// Summarize to at most this many tokens.
    Summarize(usize),
    /// Keep only the listed fields.
    Project(Vec<String>),
    /// Render through a template string.
    Template(String),
}
/// A fully-specified section of a [`ContextQuery`].
#[derive(Debug, Clone)]
pub struct ContextSection {
    /// Label reported back in `SectionResult::name`.
    pub name: String,
    /// Ordering key; lower values are processed (and budgeted) first.
    pub priority: i32,
    pub content: SectionContent,
    pub filter: Option<FilterExpr>,
    pub transform: Option<TransformExpr>,
}
/// An immutable context query produced by `ContextQueryBuilder::build`.
pub struct ContextQuery {
    pub session_id: Option<String>,
    pub agent_id: Option<String>,
    /// Maximum estimated tokens for all sections combined.
    pub token_budget: usize,
    /// Whether to prepend the `# Context` header.
    pub include_schema: bool,
    pub format: ContextFormat,
    pub truncation: TruncationStrategy,
    /// Sections in insertion order (execution sorts them by priority).
    pub sections: Vec<ContextSection>,
    pub variables: HashMap<String, ContextValue>,
    // Optional database handle; data-backed sections emit placeholders without it.
    connection: Option<Arc<SochConnection>>,
}
impl ContextQuery {
pub fn execute(&self) -> Result<ContextQueryResult, ContextQueryError> {
let mut sections = self.sections.clone();
sections.sort_by_key(|s| s.priority);
let mut results = Vec::new();
let mut total_tokens = 0;
for section in §ions {
let content = self.execute_section(section)?;
let tokens = estimate_tokens(&content);
if total_tokens + tokens > self.token_budget {
match self.truncation {
TruncationStrategy::Strict => {
return Err(ContextQueryError::BudgetExceeded {
budget: self.token_budget,
required: total_tokens + tokens,
});
}
TruncationStrategy::TailDrop => {
let remaining = self.token_budget - total_tokens;
let truncated = truncate_to_tokens(&content, remaining);
results.push(SectionResult {
name: section.name.clone(),
content: truncated.clone(),
tokens: estimate_tokens(&truncated),
truncated: true,
dropped: false,
});
break;
}
TruncationStrategy::HeadDrop => {
results.push(SectionResult {
name: section.name.clone(),
content: String::new(),
tokens: 0,
truncated: false,
dropped: true,
});
continue;
}
TruncationStrategy::Proportional => {
let remaining = self.token_budget - total_tokens;
let truncated = truncate_to_tokens(&content, remaining);
results.push(SectionResult {
name: section.name.clone(),
content: truncated.clone(),
tokens: estimate_tokens(&truncated),
truncated: true,
dropped: false,
});
break;
}
}
} else {
results.push(SectionResult {
name: section.name.clone(),
content: content.clone(),
tokens,
truncated: false,
dropped: false,
});
total_tokens += tokens;
}
}
let context = self.assemble_context(&results);
Ok(ContextQueryResult {
context,
token_count: total_tokens,
token_budget: self.token_budget,
sections: results,
session_id: self.session_id.clone(),
})
}
pub fn to_canonical(&self) -> sochdb_query::context_query::ContextSelectQuery {
use sochdb_query::context_query as cq;
let session = match (&self.session_id, &self.agent_id) {
(Some(sid), _) => cq::SessionReference::Session(sid.clone()),
(_, Some(aid)) => cq::SessionReference::Agent(aid.clone()),
(None, None) => cq::SessionReference::None,
};
let truncation = match self.truncation {
TruncationStrategy::Strict => cq::TruncationStrategy::Fail,
TruncationStrategy::TailDrop => cq::TruncationStrategy::TailDrop,
TruncationStrategy::HeadDrop => cq::TruncationStrategy::HeadDrop,
TruncationStrategy::Proportional => cq::TruncationStrategy::Proportional,
};
let format = match self.format {
ContextFormat::Soch => cq::OutputFormat::Soch,
ContextFormat::Json => cq::OutputFormat::Json,
ContextFormat::Markdown => cq::OutputFormat::Markdown,
ContextFormat::Text => cq::OutputFormat::Soch, };
let options = cq::ContextQueryOptions {
token_limit: self.token_budget,
include_schema: self.include_schema,
format,
truncation,
include_headers: true,
};
let sections = self
.sections
.iter()
.map(|s| self.convert_section(s))
.collect();
cq::ContextSelectQuery {
output_name: "context".to_string(),
session,
options,
sections,
}
}
fn convert_section(&self, section: &ContextSection) -> sochdb_query::context_query::ContextSection {
use sochdb_query::context_query as cq;
let content = match §ion.content {
SectionContent::Get(path) => cq::SectionContent::Get {
path: cq::PathExpression::parse(path).unwrap_or_else(|_| cq::PathExpression {
segments: path.split('.').map(|s| s.to_string()).collect(),
fields: vec![],
all_fields: true,
}),
},
SectionContent::Last { count, table } => cq::SectionContent::Last {
count: *count,
table: table.clone(),
where_clause: section.filter.as_ref().map(|f| self.convert_filter(f)),
},
SectionContent::Search {
collection,
query_var,
top_k,
min_score,
} => cq::SectionContent::Search {
collection: collection.clone(),
query: cq::SimilarityQuery::Variable(query_var.clone()),
top_k: *top_k,
min_score: *min_score,
},
SectionContent::Select {
columns,
table,
limit,
} => cq::SectionContent::Select {
columns: columns.clone(),
table: table.clone(),
where_clause: section.filter.as_ref().map(|f| self.convert_filter(f)),
limit: *limit,
},
SectionContent::Literal(text) => cq::SectionContent::Literal { value: text.clone() },
SectionContent::Variable(name) => cq::SectionContent::Variable { name: name.clone() },
};
let transform = section.transform.as_ref().map(|t| match t {
TransformExpr::Summarize(tokens) => cq::SectionTransform::Summarize { max_tokens: *tokens },
TransformExpr::Project(fields) => cq::SectionTransform::Project { fields: fields.clone() },
TransformExpr::Template(tpl) => cq::SectionTransform::Template { template: tpl.clone() },
});
cq::ContextSection {
name: section.name.clone(),
priority: section.priority,
content,
transform,
}
}
fn convert_filter(&self, filter: &FilterExpr) -> sochdb_query::soch_ql::WhereClause {
use sochdb_query::soch_ql as tq;
let (conditions, operator) = match filter {
FilterExpr::Eq(col, val) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::Eq,
value: tq::SochValue::Text(val.clone()),
}],
tq::LogicalOp::And,
),
FilterExpr::Gt(col, val) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::Gt,
value: tq::SochValue::Int(*val),
}],
tq::LogicalOp::And,
),
FilterExpr::Lt(col, val) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::Lt,
value: tq::SochValue::Int(*val),
}],
tq::LogicalOp::And,
),
FilterExpr::Ge(col, val) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::Ge,
value: tq::SochValue::Int(*val),
}],
tq::LogicalOp::And,
),
FilterExpr::Le(col, val) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::Le,
value: tq::SochValue::Int(*val),
}],
tq::LogicalOp::And,
),
FilterExpr::Like(col, pat) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::Like,
value: tq::SochValue::Text(pat.clone()),
}],
tq::LogicalOp::And,
),
FilterExpr::In(col, vals) => (
vec![tq::Condition {
column: col.clone(),
operator: tq::ComparisonOp::In,
value: tq::SochValue::Array(vals.iter().map(|v| tq::SochValue::Text(v.clone())).collect()),
}],
tq::LogicalOp::And,
),
FilterExpr::And(exprs) => {
let conditions: Vec<_> = exprs
.iter()
.flat_map(|e| self.convert_filter(e).conditions)
.collect();
(conditions, tq::LogicalOp::And)
}
FilterExpr::Or(exprs) => {
let conditions: Vec<_> = exprs
.iter()
.flat_map(|e| self.convert_filter(e).conditions)
.collect();
(conditions, tq::LogicalOp::Or)
}
};
tq::WhereClause { conditions, operator }
}
fn execute_section(&self, section: &ContextSection) -> Result<String, ContextQueryError> {
match §ion.content {
SectionContent::Literal(text) => Ok(text.clone()),
SectionContent::Variable(name) => self
.variables
.get(name)
.map(|v| match v {
ContextValue::String(s) => s.clone(),
ContextValue::Int(i) => i.to_string(),
ContextValue::Float(f) => format!("{:.2}", f),
ContextValue::Bool(b) => b.to_string(),
_ => format!("<{}>", name),
})
.ok_or_else(|| ContextQueryError::VariableNotFound(name.clone())),
SectionContent::Get(path) => self.execute_get(section, path),
SectionContent::Last { count, table } => self.execute_last(section, *count, table),
SectionContent::Search {
collection,
query_var,
top_k,
min_score,
} => self.execute_search(section, collection, query_var, *top_k, *min_score),
SectionContent::Select {
columns,
table,
limit,
} => self.execute_select(section, columns, table, *limit),
}
}
fn execute_get(
&self,
section: &ContextSection,
path: &str,
) -> Result<String, ContextQueryError> {
let conn = match &self.connection {
Some(c) => c,
None => {
return Ok(format!(
"# {}\n{}.data: <no connection - path: {}>\n",
section.name, path, path
));
}
};
let tch = conn.tch.read();
let resolution = tch.resolve(path);
match resolution {
crate::connection::PathResolution::Value(col_ref) => {
Ok(format!("# {}\n{}: {}\n", section.name, path, col_ref.name))
}
crate::connection::PathResolution::Array { schema, columns } => {
let mut output = format!("# {}\n", section.name);
output.push_str(&format!(
"{}[{}]{{{}}}\n",
schema.name,
0, columns
.iter()
.map(|c| c.name.as_str())
.collect::<Vec<_>>()
.join(", ")
));
Ok(output)
}
crate::connection::PathResolution::Partial { remaining } => Ok(format!(
"# {}\n{}: <partial match, remaining: {}>\n",
section.name, path, remaining
)),
crate::connection::PathResolution::NotFound => {
Ok(format!("# {}\n{}: <not found>\n", section.name, path))
}
}
}
fn execute_last(
&self,
section: &ContextSection,
count: usize,
table: &str,
) -> Result<String, ContextQueryError> {
let conn = match &self.connection {
Some(c) => c,
None => {
let filter_str = section
.filter
.as_ref()
.map(|f| format!(" WHERE {:?}", f))
.unwrap_or_default();
return Ok(format!(
"# {}\n{}[{}]{{...}}: (last {} rows{})\n",
section.name, table, count, count, filter_str
));
}
};
let where_clause = section.filter.as_ref().map(|f| self.filter_to_where(f));
let tch = conn.tch.read();
let cursor = tch.select(
table,
&[], where_clause.as_ref(),
Some(&("created_at".to_string(), false)), Some(count),
None,
);
self.format_cursor_as_toon(§ion.name, table, cursor)
}
fn execute_select(
&self,
section: &ContextSection,
columns: &[String],
table: &str,
limit: Option<usize>,
) -> Result<String, ContextQueryError> {
let conn = match &self.connection {
Some(c) => c,
None => {
let cols = columns.join(", ");
let limit_str = limit.map(|l| format!(" LIMIT {}", l)).unwrap_or_default();
return Ok(format!(
"# {}\nSELECT {} FROM {}{}\n",
section.name, cols, table, limit_str
));
}
};
let where_clause = section.filter.as_ref().map(|f| self.filter_to_where(f));
let tch = conn.tch.read();
let cursor = tch.select(
table,
columns,
where_clause.as_ref(),
None, limit,
None,
);
self.format_cursor_as_toon(§ion.name, table, cursor)
}
fn execute_search(
&self,
section: &ContextSection,
collection: &str,
query_var: &str,
top_k: usize,
min_score: Option<f32>,
) -> Result<String, ContextQueryError> {
let embedding = match self.variables.get(query_var) {
Some(ContextValue::Embedding(v)) => v.clone(),
Some(_) => {
return Err(ContextQueryError::TypeMismatch {
expected: "embedding".to_string(),
found: "other".to_string(),
});
}
None => {
return Ok(format!(
"# {}\n{}[{}]{{...}}: (top {} by similarity to ${})\n",
section.name, collection, top_k, top_k, query_var
));
}
};
let _conn = match &self.connection {
Some(c) => c,
None => {
return Ok(format!(
"# {}\n{}[{}]{{...}}: (top {} by similarity - no connection)\n",
section.name, collection, top_k, top_k
));
}
};
let _min_score = min_score.unwrap_or(0.0);
Ok(format!(
"# {}\n{}[{}]{{...}}: (top {} by similarity, embedding dim={})\n",
section.name,
collection,
top_k,
top_k,
embedding.len()
))
}
fn filter_to_where(&self, filter: &FilterExpr) -> WhereClause {
match filter {
FilterExpr::Eq(col, val) => WhereClause::Simple {
field: col.clone(),
op: CompareOp::Eq,
value: SochValue::Text(val.clone()),
},
FilterExpr::Gt(col, val) => WhereClause::Simple {
field: col.clone(),
op: CompareOp::Gt,
value: SochValue::Int(*val),
},
FilterExpr::Lt(col, val) => WhereClause::Simple {
field: col.clone(),
op: CompareOp::Lt,
value: SochValue::Int(*val),
},
FilterExpr::Ge(col, val) => WhereClause::Simple {
field: col.clone(),
op: CompareOp::Ge,
value: SochValue::Int(*val),
},
FilterExpr::Le(col, val) => WhereClause::Simple {
field: col.clone(),
op: CompareOp::Le,
value: SochValue::Int(*val),
},
FilterExpr::Like(col, val) => WhereClause::Simple {
field: col.clone(),
op: CompareOp::Like,
value: SochValue::Text(val.clone()),
},
FilterExpr::In(col, vals) => {
WhereClause::In {
field: col.clone(),
values: vals.iter().map(|v| SochValue::Text(v.clone())).collect(),
negated: false,
}
}
FilterExpr::And(filters) => {
let clauses: Vec<WhereClause> = filters
.iter()
.map(|f| self.filter_to_where(f))
.collect();
if clauses.is_empty() {
WhereClause::Simple {
field: String::new(),
op: CompareOp::Eq,
value: SochValue::Bool(true),
}
} else if clauses.len() == 1 {
clauses.into_iter().next().unwrap()
} else {
WhereClause::And(clauses)
}
}
FilterExpr::Or(filters) => {
let clauses: Vec<WhereClause> = filters
.iter()
.map(|f| self.filter_to_where(f))
.collect();
if clauses.is_empty() {
WhereClause::Simple {
field: String::new(),
op: CompareOp::Ne,
value: SochValue::Bool(true),
}
} else if clauses.len() == 1 {
clauses.into_iter().next().unwrap()
} else {
WhereClause::Or(clauses)
}
}
}
}
fn format_cursor_as_toon(
&self,
section_name: &str,
table: &str,
mut cursor: crate::connection::SochCursor,
) -> Result<String, ContextQueryError> {
let mut rows = Vec::new();
while let Some(row) = cursor.next() {
rows.push(row);
}
if rows.is_empty() {
return Ok(format!("# {}\n{}[0]{{}}\n", section_name, table));
}
let mut output = format!("# {}\n", section_name);
let fields: Vec<&String> = if let Some(first) = rows.first() {
first.keys().collect()
} else {
vec![]
};
match self.format {
ContextFormat::Soch => {
output.push_str(&format!(
"{}[{}]{{{}}}\n",
table,
rows.len(),
fields
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
));
for row in &rows {
let values: Vec<String> = fields
.iter()
.filter_map(|f| row.get(*f).map(format_soch_value))
.collect();
output.push_str(&format!(" [{}]\n", values.join(", ")));
}
}
ContextFormat::Json => {
output.push_str("[\n");
for (i, row) in rows.iter().enumerate() {
output.push_str(" {");
let pairs: Vec<String> = row
.iter()
.map(|(k, v)| format!("\"{}\": {}", k, format_json_value(v)))
.collect();
output.push_str(&pairs.join(", "));
output.push('}');
if i < rows.len() - 1 {
output.push(',');
}
output.push('\n');
}
output.push_str("]\n");
}
ContextFormat::Markdown => {
output.push_str("| ");
output.push_str(
&fields
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(" | "),
);
output.push_str(" |\n");
output.push_str("| ");
output.push_str(&fields.iter().map(|_| "---").collect::<Vec<_>>().join(" | "));
output.push_str(" |\n");
for row in &rows {
output.push_str("| ");
let values: Vec<String> = fields
.iter()
.filter_map(|f| row.get(*f).map(format_soch_value))
.collect();
output.push_str(&values.join(" | "));
output.push_str(" |\n");
}
}
ContextFormat::Text => {
for row in &rows {
for (k, v) in row {
output.push_str(&format!("{}: {}\n", k, format_soch_value(v)));
}
output.push('\n');
}
}
}
Ok(output)
}
fn assemble_context(&self, results: &[SectionResult]) -> String {
let mut context = String::new();
if self.include_schema {
context.push_str("# Context\n");
context.push_str(&format!(
"session: {}\n",
self.session_id.as_deref().unwrap_or("none")
));
context.push_str(&format!("budget: {} tokens\n\n", self.token_budget));
}
for result in results {
if !result.dropped && !result.content.is_empty() {
context.push_str(&result.content);
context.push('\n');
}
}
context
}
}
/// Output of `ContextQuery::execute`.
#[derive(Debug, Clone)]
pub struct ContextQueryResult {
    /// The assembled context text (header plus included sections).
    pub context: String,
    /// Estimated tokens used by the included sections, as reported by `execute`.
    pub token_count: usize,
    /// The budget the query ran with.
    pub token_budget: usize,
    /// Per-section outcome, in processing (priority) order.
    pub sections: Vec<SectionResult>,
    pub session_id: Option<String>,
}
/// Outcome of budgeting a single section.
#[derive(Debug, Clone)]
pub struct SectionResult {
    pub name: String,
    /// Text included in the context (empty when dropped).
    pub content: String,
    /// Estimated token count of `content`.
    pub tokens: usize,
    /// The section was cut down to fit the remaining budget.
    pub truncated: bool,
    /// The section was left out of the context entirely.
    pub dropped: bool,
}
impl ContextQueryResult {
pub fn utilization(&self) -> f64 {
(self.token_count as f64 / self.token_budget as f64) * 100.0
}
pub fn included_sections(&self) -> Vec<&str> {
self.sections
.iter()
.filter(|s| !s.dropped)
.map(|s| s.name.as_str())
.collect()
}
pub fn dropped_sections(&self) -> Vec<&str> {
self.sections
.iter()
.filter(|s| s.dropped)
.map(|s| s.name.as_str())
.collect()
}
pub fn truncated_sections(&self) -> Vec<&str> {
self.sections
.iter()
.filter(|s| s.truncated)
.map(|s| s.name.as_str())
.collect()
}
}
/// Errors returned while building or executing a context query.
#[derive(Debug, Clone)]
pub enum ContextQueryError {
    /// `Strict` truncation hit a section that does not fit.
    BudgetExceeded { budget: usize, required: usize },
    /// A section referenced a variable that was never set.
    VariableNotFound(String),
    SectionFailed { section: String, error: String },
    InvalidQuery(String),
    DatabaseError(String),
    /// A variable held a value of the wrong kind (e.g. a search query that is not an embedding).
    TypeMismatch { expected: String, found: String },
}
impl std::fmt::Display for ContextQueryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::BudgetExceeded { budget, required } => {
write!(
f,
"token budget exceeded: {} required, {} available",
required, budget
)
}
Self::VariableNotFound(name) => write!(f, "variable not found: ${}", name),
Self::SectionFailed { section, error } => {
write!(f, "section '{}' failed: {}", section, error)
}
Self::InvalidQuery(msg) => write!(f, "invalid query: {}", msg),
Self::DatabaseError(msg) => write!(f, "database error: {}", msg),
Self::TypeMismatch { expected, found } => {
write!(f, "type mismatch: expected {}, found {}", expected, found)
}
}
}
}
impl std::error::Error for ContextQueryError {}
use sochdb_query::token_budget::TokenEstimator;
fn get_estimator() -> TokenEstimator {
TokenEstimator::default()
}
fn estimate_tokens(text: &str) -> usize {
get_estimator().estimate_text(text)
}
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
get_estimator().truncate_to_tokens(text, max_tokens)
}
/// Renders a `SochValue` in the compact Soch text format.
///
/// * Null renders as `∅` and booleans as `T`/`F`.
/// * Floats use six decimals; binary payloads show only their byte length.
/// * Text is quoted when it contains a delimiter (`,` `;`), a line break, a
///   double quote or a backslash. Inside quotes, backslashes and double quotes
///   are escaped so the quoted form can be parsed back unambiguously.
fn format_soch_value(v: &SochValue) -> String {
    match v {
        SochValue::Null => "∅".to_string(),
        SochValue::Int(i) => i.to_string(),
        SochValue::UInt(u) => u.to_string(),
        SochValue::Float(f) => format!("{:.6}", f),
        SochValue::Text(s) => {
            let needs_quotes = s.contains(|c: char| matches!(c, ',' | ';' | '\n' | '\r' | '"' | '\\'));
            if needs_quotes {
                // Escape backslashes first so the quote escapes are not doubled.
                format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
            } else {
                s.clone()
            }
        }
        SochValue::Bool(b) => if *b { "T" } else { "F" }.to_string(),
        SochValue::Binary(b) => format!("b64:<{}bytes>", b.len()),
        SochValue::Array(arr) => {
            let items: Vec<String> = arr.iter().map(format_soch_value).collect();
            format!("[{}]", items.join(","))
        }
        SochValue::Object(map) => {
            let items: Vec<String> = map
                .iter()
                .map(|(k, v)| format!("{}:{}", k, format_soch_value(v)))
                .collect();
            format!("{{{}}}", items.join(","))
        }
        SochValue::Ref { table, id } => format!("ref({},{})", table, id),
    }
}
/// Renders a `SochValue` as a JSON fragment.
///
/// Strings, object keys and ref table names are escaped per RFC 8259, including
/// control characters. Non-finite floats become `null`, since JSON has no
/// NaN/Infinity. Binary payloads are not encoded; they render as a
/// `"<binary:N>"` placeholder string.
fn format_json_value(v: &SochValue) -> String {
    match v {
        SochValue::Null => "null".to_string(),
        SochValue::Int(i) => i.to_string(),
        SochValue::UInt(u) => u.to_string(),
        SochValue::Float(f) if f.is_finite() => f.to_string(),
        SochValue::Float(_) => "null".to_string(),
        SochValue::Text(s) => json_string(s),
        SochValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
        SochValue::Binary(b) => format!("\"<binary:{}>\"", b.len()),
        SochValue::Array(arr) => {
            let items: Vec<String> = arr.iter().map(format_json_value).collect();
            format!("[{}]", items.join(","))
        }
        SochValue::Object(map) => {
            let items: Vec<String> = map
                .iter()
                .map(|(k, v)| format!("{}:{}", json_string(k), format_json_value(v)))
                .collect();
            format!("{{{}}}", items.join(","))
        }
        SochValue::Ref { table, id } => {
            format!("{{\"$ref\":{},\"id\":{}}}", json_string(table), id)
        }
    }
}

/// Quotes `s` as a JSON string literal, escaping quotes, backslashes and control characters.
fn json_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining C0 control characters must use the \uXXXX form.
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_basic() {
        let query = ContextQueryBuilder::new()
            .for_session("sess_123")
            .with_budget(4096)
            .literal("SYSTEM", -1, "You are a helpful assistant")
            .build();
        assert_eq!(query.session_id, Some("sess_123".to_string()));
        assert_eq!(query.token_budget, 4096);
        assert_eq!(query.sections.len(), 1);
    }

    #[test]
    fn test_section_builder() {
        let query = ContextQueryBuilder::new()
            .section("USER", 0)
            .get("user.profile.{name, email}")
            .done()
            .section("HISTORY", 1)
            .last(10, "events")
            .where_eq("type", "tool_call")
            .done()
            .section("DOCS", 2)
            .search("knowledge_base", "query_embedding", 5)
            .min_score(0.7)
            .done()
            .build();
        assert_eq!(query.sections.len(), 3);
        assert_eq!(query.sections[0].priority, 0);
        assert_eq!(query.sections[1].priority, 1);
        assert_eq!(query.sections[2].priority, 2);
    }

    #[test]
    fn test_execute_with_literals() {
        let result = ContextQueryBuilder::new()
            .with_budget(1000)
            .literal("SYSTEM", 0, "You are a helpful assistant")
            .literal("USER", 1, "Hello, how are you?")
            .execute()
            .unwrap();
        assert!(result.token_count < 1000);
        assert!(result.context.contains("You are a helpful assistant"));
        assert!(result.context.contains("Hello, how are you?"));
    }

    #[test]
    fn test_variable_resolution() {
        let result = ContextQueryBuilder::new()
            .set_var("user_name", ContextValue::String("Alice".to_string()))
            .variable("GREETING", 0, "user_name")
            .execute()
            .unwrap();
        assert!(result.context.contains("Alice"));
    }

    #[test]
    fn test_missing_variable_errors() {
        let result = ContextQueryBuilder::new()
            .variable("GREETING", 0, "absent")
            .execute();
        assert!(matches!(
            result,
            Err(ContextQueryError::VariableNotFound(ref name)) if name == "absent"
        ));
    }

    #[test]
    fn test_sections_ordered_by_priority() {
        let result = ContextQueryBuilder::new()
            .include_schema(false)
            .literal("LATE", 5, "second")
            .literal("EARLY", 1, "first")
            .execute()
            .unwrap();
        assert_eq!(result.context, "first\nsecond\n");
    }

    #[test]
    fn test_budget_exceeded_strict() {
        let result = ContextQueryBuilder::new()
            .with_budget(10)
            .truncation(TruncationStrategy::Strict)
            .literal("LONG", 0, &"x".repeat(1000))
            .execute();
        assert!(matches!(
            result,
            Err(ContextQueryError::BudgetExceeded { .. })
        ));
    }

    #[test]
    fn test_budget_truncation() {
        let long_text = "x".repeat(1000);
        let result = ContextQueryBuilder::new()
            .with_budget(100)
            .truncation(TruncationStrategy::TailDrop)
            .literal("LONG", 0, &long_text)
            .execute()
            .unwrap();
        assert!(result.token_count <= 100);
        assert!(result.sections[0].truncated);
    }

    #[test]
    fn test_head_drop_skips_oversized_section() {
        let result = ContextQueryBuilder::new()
            .with_budget(50)
            .truncation(TruncationStrategy::HeadDrop)
            .literal("BIG", 0, &"x".repeat(1000))
            .literal("SMALL", 1, "fits")
            .execute()
            .unwrap();
        assert_eq!(result.dropped_sections(), vec!["BIG"]);
        assert_eq!(result.included_sections(), vec!["SMALL"]);
        assert!(result.context.contains("fits"));
        assert!(!result.context.contains("xxxx"));
    }

    #[test]
    fn test_format_options() {
        let query = ContextQueryBuilder::new()
            .format(ContextFormat::Markdown)
            .include_schema(false)
            .build();
        assert_eq!(query.format, ContextFormat::Markdown);
        assert!(!query.include_schema);
    }

    #[test]
    fn test_complex_filters() {
        let query = ContextQueryBuilder::new()
            .section("DATA", 0)
            .select(&["id", "name", "score"], "users")
            .where_gt("score", 80)
            .where_like("name", "A%")
            .limit(10)
            .done()
            .build();
        let section = &query.sections[0];
        assert!(matches!(&section.filter, Some(FilterExpr::And(_))));
    }

    #[test]
    fn test_result_methods() {
        let result = ContextQueryBuilder::new()
            .with_budget(1000)
            .literal("A", 0, "content a")
            .literal("B", 1, "content b")
            .execute()
            .unwrap();
        let included = result.included_sections();
        assert_eq!(included.len(), 2);
        assert!(included.contains(&"A"));
        assert!(included.contains(&"B"));
        assert!(result.dropped_sections().is_empty());
        assert!(result.truncated_sections().is_empty());
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("hello world!"), 3);
    }

    #[test]
    fn test_truncate_to_tokens() {
        let text = "This is a long text that needs truncation";
        let truncated = truncate_to_tokens(text, 5);
        assert!(truncated.len() < text.len());
        assert!(truncated.ends_with("..."));
    }
}