use rmcp::model::{Content, IntoContents};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::ToolError;
use crate::{mcp::McpServerSqlite, traits::SqliteServerTool};
// Stateless marker type for the `search_fts` MCP tool; the behaviour lives in
// its `SqliteServerTool` implementation. Plain `//` comments are used because
// schemars would copy `///` doc comments into the generated JSON schema.
#[derive(Clone, Copy, Debug, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct SearchFtsTool;
impl SqliteServerTool for SearchFtsTool {
const NAME: &str = "search_fts";
type Context = McpServerSqlite;
type Error = ToolError<SearchFtsError>;
type Input = SearchFtsInput;
type Output = SearchFtsOutput;
fn handle(
ctx: &Self::Context,
input: Self::Input,
) -> Result<Self::Output, Self::Error> {
let conn = ctx
.connection()
.map_err(|source| ToolError::Connection { source })?;
let limit = input.limit.unwrap_or(10);
let snippet_tokens = input.snippet_tokens.unwrap_or(32);
let hl_start = input.highlight_start.as_deref().unwrap_or("<b>");
let hl_end = input.highlight_end.as_deref().unwrap_or("</b>");
let column_count =
fts_column_count(&conn, &input.fts_table).map_err(|source| {
ToolError::Tool(SearchFtsError::Query { source })
})?;
let snippet_exprs = (0..column_count)
.map(|i| {
format!(
"snippet({tbl}, {i}, '{hl_start}', '{hl_end}', '...', {tokens})",
tbl = input.fts_table,
tokens = snippet_tokens,
)
})
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT rowid, rank, {snippets} \
FROM [{tbl}] \
WHERE [{tbl}] MATCH ?1 \
ORDER BY rank \
LIMIT ?2",
tbl = input.fts_table,
snippets = snippet_exprs,
);
let mut stmt = conn.prepare(&sql).map_err(|source| {
ToolError::Tool(SearchFtsError::Query { source })
})?;
let results = stmt
.query_map(rusqlite::params![input.query, limit], |row| {
let rowid = row.get::<_, i64>(0)?;
let rank = row.get::<_, f64>(1)?;
let snippets = (0..column_count)
.map(|i| row.get::<_, String>(2 + i))
.collect::<Result<Vec<_>, _>>()?;
Ok(FtsMatch {
rowid,
rank,
snippets,
})
})
.map_err(|source| {
ToolError::Tool(SearchFtsError::Query { source })
})?
.collect::<Result<Vec<_>, _>>()
.map_err(|source| {
ToolError::Tool(SearchFtsError::Query { source })
})?;
Ok(SearchFtsOutput { results })
}
}
/// Returns how many columns SQLite's `table_info` pragma reports for
/// `fts_table`. The result is `0` when no rows are reported, e.g. for an
/// unknown table.
///
/// The table name is passed as a bound parameter to the `pragma_table_info`
/// table-valued function. It is therefore never spliced into SQL text, and
/// names containing quote or bracket characters are handled correctly.
///
/// # Errors
///
/// Returns any [`rusqlite::Error`] raised while running the count query.
/// Unlike a plain row count, row-level errors are propagated rather than
/// silently counted.
fn fts_column_count(
    conn: &rusqlite::Connection,
    fts_table: &str,
) -> Result<usize, rusqlite::Error> {
    let count: i64 = conn.query_row(
        "SELECT COUNT(*) FROM pragma_table_info(?1)",
        rusqlite::params![fts_table],
        |row| row.get(0),
    )?;
    // COUNT(*) is never negative, so this conversion cannot fail in practice.
    Ok(usize::try_from(count).unwrap_or_default())
}
#[derive(
Clone,
Debug,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Serialize,
Deserialize,
schemars::JsonSchema,
)]
pub struct SearchFtsInput {
#[schemars(description = "The FTS5 virtual table to search")]
pub fts_table: String,
#[schemars(description = "The FTS5 MATCH query string")]
pub query: String,
#[schemars(description = "Maximum number of results (default 10)")]
pub limit: Option<i64>,
#[schemars(description = "Max tokens per snippet (default 32)")]
pub snippet_tokens: Option<i32>,
#[schemars(
description = "String before matched terms in snippets (default <b>)"
)]
pub highlight_start: Option<String>,
#[schemars(
description = "String after matched terms in snippets (default </b>)"
)]
pub highlight_end: Option<String>,
}
#[derive(
Clone,
Debug,
PartialEq,
PartialOrd,
Serialize,
Deserialize,
schemars::JsonSchema,
)]
pub struct SearchFtsOutput {
pub results: Vec<FtsMatch>,
}
#[derive(
Clone,
Debug,
PartialEq,
PartialOrd,
Serialize,
Deserialize,
schemars::JsonSchema,
)]
pub struct FtsMatch {
pub rowid: i64,
pub rank: f64,
pub snippets: Vec<String>,
}
#[derive(Debug, thiserror::Error)]
pub enum SearchFtsError {
#[error("FTS search failed: {source}")]
Query {
source: rusqlite::Error,
},
}
impl IntoContents for SearchFtsError {
    /// Converts the error into MCP content: a single text item holding the
    /// error's `Display` message.
    fn into_contents(self) -> Vec<Content> {
        let message = self.to_string();
        vec![Content::text(message)]
    }
}