use serde::Deserialize;
use serde_json::{Value, json};
use crate::error::ToolError;
use crate::protocol::ServerState;
use crate::tools::{TOOL_OUTPUT_CAP_BYTES, is_safe_identifier, value_to_json};
const DEFAULT_K: usize = 10;
const HARD_CAP_K: usize = 100;
pub fn metadata() -> Value {
json!({
"name": "bm25_search",
"description": "Find the top-k rows ranked by BM25 keyword relevance against \
an FTS-indexed TEXT column. Requires a `CREATE INDEX … USING fts \
(column)` to exist on the column; errors otherwise. Uses any-term \
OR semantics (a row matches if it contains ANY of the query \
terms). Returns the table's columns for the k highest-scoring \
rows, in descending BM25 order. Pairs naturally with `vector_search` \
for hybrid retrieval — the LLM can call both and fuse results \
client-side, or compose them in a single SQL via the `query` tool.",
"inputSchema": {
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Table name. Must match `[A-Za-z_][A-Za-z0-9_]*`.",
},
"column": {
"type": "string",
"description": "FTS-indexed TEXT column on the table. Must match \
`[A-Za-z_][A-Za-z0-9_]*`.",
},
"query": {
"type": "string",
"description": "Free-text query. Tokenized the same way the index \
was built (ASCII split + lowercase for the MVP).",
},
"k": {
"type": "integer",
"description": "Number of top-ranked rows to return (1..=100, default 10).",
"minimum": 1,
"maximum": 100,
},
},
"required": ["table", "column", "query"],
"additionalProperties": false,
}
})
}
#[derive(Deserialize)]
struct Args {
table: String,
column: String,
query: String,
#[serde(default)]
k: Option<usize>,
}
pub fn handle(args: Value, state: &mut ServerState) -> Result<String, ToolError> {
let args: Args = serde_json::from_value(args)
.map_err(|e| ToolError::new(format!("invalid arguments: {e}")))?;
if !is_safe_identifier(&args.table) {
return Err(ToolError::new(format!(
"invalid table name `{}`",
args.table
)));
}
if !is_safe_identifier(&args.column) {
return Err(ToolError::new(format!(
"invalid column name `{}`",
args.column
)));
}
if args.query.is_empty() {
return Err(ToolError::new(
"query must be a non-empty string".to_string(),
));
}
let k = args.k.unwrap_or(DEFAULT_K).clamp(1, HARD_CAP_K);
{
let db = state.conn.database();
let table = db
.get_table(args.table.clone())
.map_err(|e| ToolError::new(format!("table `{}` not found: {e}", args.table)))?;
let target = table
.columns
.iter()
.find(|c| c.column_name == args.column)
.ok_or_else(|| {
ToolError::new(format!(
"column `{}` not found on table `{}`",
args.column, args.table,
))
})?;
if !matches!(target.datatype, sqlrite::sql::db::table::DataType::Text) {
return Err(ToolError::new(format!(
"column `{}` on table `{}` is `{}`, not a TEXT column",
args.column, args.table, target.datatype,
)));
}
if !table
.fts_indexes
.iter()
.any(|i| i.column_name == args.column)
{
return Err(ToolError::new(format!(
"column `{}` on table `{}` has no FTS index — \
run `CREATE INDEX <name> ON {} USING fts ({})` first",
args.column, args.table, args.table, args.column,
)));
}
}
let query_lit = sql_string_literal(&args.query);
let sql = format!(
"SELECT * FROM {tbl} WHERE fts_match({col}, {q}) \
ORDER BY bm25_score({col}, {q}) DESC LIMIT {k}",
tbl = args.table,
col = args.column,
q = query_lit,
k = k,
);
let stmt = state.conn.prepare(&sql)?;
let mut rows = stmt.query()?;
let columns = rows.columns().to_vec();
let mut out: Vec<Value> = Vec::with_capacity(k);
let mut size_estimate = 0;
let mut byte_truncated = false;
while let Some(row) = rows.next()? {
let mut obj = serde_json::Map::with_capacity(columns.len());
for (i, col) in columns.iter().enumerate() {
let v: sqlrite::Value = row.get(i)?;
let json_val = value_to_json(&v);
size_estimate += col.len() + 8 + json_val.to_string().len();
obj.insert(col.clone(), json_val);
}
if size_estimate > TOOL_OUTPUT_CAP_BYTES {
byte_truncated = true;
break;
}
out.push(Value::Object(obj));
}
let mut result = json!({
"table": args.table,
"column": args.column,
"query": args.query,
"k_requested": k,
"rows": out,
});
if byte_truncated {
result["truncated"] = json!(true);
result["truncation_reason"] = json!(format!(
"response truncated at {} bytes ({} of {} rows shown)",
TOOL_OUTPUT_CAP_BYTES,
out.len(),
k,
));
}
serde_json::to_string_pretty(&result)
.map_err(|e| ToolError::new(format!("internal: failed to serialize results: {e}")))
}
fn sql_string_literal(s: &str) -> String {
let mut out = String::with_capacity(s.len() + 2);
out.push('\'');
for c in s.chars() {
if c == '\'' {
out.push_str("''");
} else {
out.push(c);
}
}
out.push('\'');
out
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sql_string_literal_doubles_quotes() {
assert_eq!(sql_string_literal("rust"), "'rust'");
assert_eq!(sql_string_literal("it's fast"), "'it''s fast'");
assert_eq!(sql_string_literal(""), "''");
}
}