use rmcp::model::{Content, IntoContents};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::ToolError;
use crate::{mcp::McpServerSqlite, traits::SqliteServerTool};
// Marker type for the `execute` tool: it carries no state, and all of its
// behaviour lives in its `SqliteServerTool` implementation.
//
// Plain `//` comments are used here on purpose: `///` doc comments on a
// `JsonSchema`-derived type would be copied into the generated schema.
#[derive(
    Default,
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Hash,
    Deserialize,
    Serialize,
    JsonSchema,
)]
pub struct ExecuteTool;
impl SqliteServerTool for ExecuteTool {
const NAME: &str = "execute";
type Context = McpServerSqlite;
type Error = ToolError<ExecuteError>;
type Input = ExecuteInput;
type Output = ExecuteOutput;
fn handle(
ctx: &Self::Context,
input: Self::Input,
) -> Result<Self::Output, Self::Error> {
let conn = ctx
.connection()
.map_err(|source| ToolError::Connection { source })?;
let mut stmt = conn.prepare(&input.query).map_err(|source| {
if matches!(
source,
rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error {
code: rusqlite::ffi::ErrorCode::AuthorizationForStatementDenied,
..
},
_,
)
) {
ToolError::AccessDenied {
message: format!(
"the configured access control policy denied this \
statement: {}",
input.query,
),
}
} else {
ToolError::Tool(ExecuteError::Prepare { source })
}
})?;
let column_count = stmt.column_count();
let column_names = (0..column_count)
.map(|i| stmt.column_name(i).unwrap_or("?").to_owned())
.collect::<Vec<String>>();
let rows = stmt
.query_map([], |row: &rusqlite::Row<'_>| {
let columns = column_names
.iter()
.enumerate()
.map(|(i, name)| {
let value = row
.get::<_, rusqlite::types::Value>(i)
.map(Value::from)
.unwrap_or(Value::Null);
(name.clone(), value)
})
.collect();
Ok(Row { columns })
})
.map_err(|source| ToolError::Tool(ExecuteError::Query { source }))?
.collect::<Result<Vec<_>, _>>()
.map_err(|source| {
ToolError::Tool(ExecuteError::Query { source })
})?;
let rows_changed = conn.changes();
Ok(ExecuteOutput { rows, rows_changed })
}
}
// Arguments accepted by the `execute` tool.
//
// `//` is used instead of `///` so that no extra `description` is injected
// into the JSON schema that schemars derives for this type.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Hash,
    Deserialize,
    Serialize,
    schemars::JsonSchema,
)]
pub struct ExecuteInput {
    // Raw SQL text; the schema description below is what clients see.
    #[schemars(description = "The SQL query to execute")]
    pub query: String,
}
// Result returned by the `execute` tool: the rows the statement produced
// plus a count of changed rows.
#[derive(
    Debug,
    Clone,
    PartialEq,
    PartialOrd,
    Deserialize,
    Serialize,
    schemars::JsonSchema,
)]
pub struct ExecuteOutput {
    // One entry per result row, in the order they were read.
    pub rows: Vec<Row>,
    // Changed-row count reported for the statement.
    pub rows_changed: u64,
}
// A single result row, keyed by column name. A `BTreeMap` keeps the
// serialized keys in sorted order.
#[derive(
    Debug,
    Clone,
    PartialEq,
    PartialOrd,
    Deserialize,
    Serialize,
    schemars::JsonSchema,
)]
pub struct Row {
    // Column name -> value for this row.
    pub columns: std::collections::BTreeMap<String, Value>,
}
// A single SQLite column value, serialized adjacently tagged as
// `{"kind": <variant>, "value": <payload>}`.
//
// `serde_as` must stay ahead of the derive so it can rewrite the field
// attributes before serde sees them. Variant order is significant: it
// determines the derived `PartialOrd` ordering.
#[serde_with::serde_as]
#[derive(
    Debug,
    Clone,
    PartialEq,
    PartialOrd,
    Deserialize,
    Serialize,
    schemars::JsonSchema,
)]
#[serde(tag = "kind", content = "value")]
pub enum Value {
    Null,
    Integer(i64),
    Real(f64),
    Text(String),
    // Raw bytes, carried on the wire as a hex string and described as a
    // string in the schema.
    Blob(
        #[schemars(with = "String")]
        #[serde_as(as = "serde_with::hex::Hex")]
        Vec<u8>,
    ),
}
impl From<rusqlite::types::Value> for Value {
fn from(value: rusqlite::types::Value) -> Self {
match value {
rusqlite::types::Value::Null => Self::Null,
rusqlite::types::Value::Integer(n) => Self::Integer(n),
rusqlite::types::Value::Real(f) => Self::Real(f),
rusqlite::types::Value::Text(s) => Self::Text(s),
rusqlite::types::Value::Blob(b) => Self::Blob(b),
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ExecuteError {
#[error("failed to prepare statement: {source}")]
Prepare {
source: rusqlite::Error,
},
#[error("failed to read query results: {source}")]
Query {
source: rusqlite::Error,
},
}
impl IntoContents for ExecuteError {
    /// Renders the error's `Display` message as exactly one text content item.
    fn into_contents(self) -> Vec<Content> {
        let message = self.to_string();
        std::iter::once(Content::text(message)).collect()
    }
}