use serde::Deserialize;
use serde_json::{Value, json};
use crate::error::ToolError;
use crate::protocol::ServerState;
use crate::tools::{TOOL_OUTPUT_CAP_BYTES, is_safe_identifier, value_to_json};
const DEFAULT_K: usize = 10;
const HARD_CAP_K: usize = 100;
pub fn metadata() -> Value {
json!({
"name": "vector_search",
"description": "Find the k nearest rows to a query embedding in a VECTOR column. \
Uses an HNSW index automatically if one is built on the column \
(CREATE INDEX … USING hnsw); otherwise falls back to a brute-force \
scan. Returns the table's columns for the k closest rows, in \
ascending distance order. Supported metrics: `l2` (Euclidean, \
default), `cosine`, `dot`. (The numeric distance value is not \
included in the response — the engine doesn't yet support \
function calls in SELECT projections.)",
"inputSchema": {
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Table name. Must match `[A-Za-z_][A-Za-z0-9_]*`.",
},
"column": {
"type": "string",
"description": "VECTOR column on the table. Must match `[A-Za-z_][A-Za-z0-9_]*`.",
},
"embedding": {
"type": "array",
"items": { "type": "number" },
"description": "Query vector. Length must match the column's declared dimension.",
},
"k": {
"type": "integer",
"description": "Number of nearest neighbors to return (1..=100, default 10).",
"minimum": 1,
"maximum": 100,
},
"metric": {
"type": "string",
"enum": ["l2", "cosine", "dot"],
"description": "Distance metric (default `l2`).",
},
},
"required": ["table", "column", "embedding"],
"additionalProperties": false,
}
})
}
#[derive(Deserialize)]
struct Args {
table: String,
column: String,
embedding: Vec<f64>,
#[serde(default)]
k: Option<usize>,
#[serde(default)]
metric: Option<String>,
}
pub fn handle(args: Value, state: &mut ServerState) -> Result<String, ToolError> {
let args: Args = serde_json::from_value(args)
.map_err(|e| ToolError::new(format!("invalid arguments: {e}")))?;
if !is_safe_identifier(&args.table) {
return Err(ToolError::new(format!(
"invalid table name `{}`",
args.table
)));
}
if !is_safe_identifier(&args.column) {
return Err(ToolError::new(format!(
"invalid column name `{}`",
args.column
)));
}
if args.embedding.is_empty() {
return Err(ToolError::new(
"embedding must be a non-empty array of numbers".to_string(),
));
}
let metric_fn = match args.metric.as_deref().unwrap_or("l2") {
"l2" => "vec_distance_l2",
"cosine" => "vec_distance_cosine",
"dot" => "vec_distance_dot",
other => {
return Err(ToolError::new(format!(
"unsupported metric `{other}`. Use `l2`, `cosine`, or `dot`."
)));
}
};
let k = args.k.unwrap_or(DEFAULT_K).clamp(1, HARD_CAP_K);
let embedding_lit = format_vector_literal(&args.embedding);
{
let db = state.conn.database();
let table = db
.get_table(args.table.clone())
.map_err(|e| ToolError::new(format!("table `{}` not found: {e}", args.table)))?;
let target = table
.columns
.iter()
.find(|c| c.column_name == args.column)
.ok_or_else(|| {
ToolError::new(format!(
"column `{}` not found on table `{}`",
args.column, args.table,
))
})?;
if let sqlrite::sql::db::table::DataType::Vector(dim) = target.datatype {
if args.embedding.len() != dim as usize {
return Err(ToolError::new(format!(
"embedding has {} dimensions but column `{}` is VECTOR({}) — \
dimension mismatch",
args.embedding.len(),
args.column,
dim,
)));
}
} else {
return Err(ToolError::new(format!(
"column `{}` on table `{}` is `{}`, not a VECTOR column",
args.column, args.table, target.datatype,
)));
}
}
let sql = format!(
"SELECT * FROM {tbl} ORDER BY {fn_}({col}, {emb}) ASC LIMIT {k}",
fn_ = metric_fn,
col = args.column,
emb = embedding_lit,
tbl = args.table,
k = k,
);
let stmt = state.conn.prepare(&sql)?;
let mut rows = stmt.query()?;
let columns = rows.columns().to_vec();
let mut out: Vec<Value> = Vec::with_capacity(k);
let mut size_estimate = 0;
let mut byte_truncated = false;
while let Some(row) = rows.next()? {
let mut obj = serde_json::Map::with_capacity(columns.len());
for (i, col) in columns.iter().enumerate() {
let v: sqlrite::Value = row.get(i)?;
let json_val = value_to_json(&v);
size_estimate += col.len() + 8 + json_val.to_string().len();
obj.insert(col.clone(), json_val);
}
if size_estimate > TOOL_OUTPUT_CAP_BYTES {
byte_truncated = true;
break;
}
out.push(Value::Object(obj));
}
let mut result = json!({
"table": args.table,
"column": args.column,
"metric": args.metric.unwrap_or_else(|| "l2".to_string()),
"k_requested": k,
"rows": out,
});
if byte_truncated {
result["truncated"] = json!(true);
result["truncation_reason"] = json!(format!(
"response truncated at {} bytes ({} of {} rows shown)",
TOOL_OUTPUT_CAP_BYTES,
out.len(),
k,
));
}
serde_json::to_string_pretty(&result)
.map_err(|e| ToolError::new(format!("internal: failed to serialize results: {e}")))
}
fn format_vector_literal(v: &[f64]) -> String {
let mut s = String::with_capacity(v.len() * 6 + 2);
s.push('[');
for (i, x) in v.iter().enumerate() {
if i > 0 {
s.push_str(", ");
}
if !x.is_finite() {
s.push_str("0.0");
} else {
let formatted = format!("{x}");
if formatted.contains('.') || formatted.contains('e') || formatted.contains('E') {
s.push_str(&formatted);
} else {
s.push_str(&formatted);
s.push_str(".0");
}
}
}
s.push(']');
s
}