use rmcp::model::{Content, IntoContents};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::ToolError;
use crate::{mcp::McpServerSqlite, traits::SqliteServerTool};
// Stateless marker type that the `list_triggers` tool implementation hangs off.
// NOTE: plain `//` comments are used on purpose — schemars copies `///` doc
// comments into the generated JSON schema as a description.
#[derive(
    Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
    Serialize, Deserialize, JsonSchema,
)]
pub struct ListTriggersTool;
impl SqliteServerTool for ListTriggersTool {
const NAME: &str = "list_triggers";
type Context = McpServerSqlite;
type Error = ToolError<ListTriggersError>;
type Input = ListTriggersInput;
type Output = ListTriggersOutput;
fn handle(
ctx: &Self::Context,
input: Self::Input,
) -> Result<Self::Output, Self::Error> {
let conn = ctx
.connection()
.map_err(|source| ToolError::Connection { source })?;
let triggers = match &input.table_name {
Some(table) => {
let mut stmt = conn
.prepare(
"SELECT name, tbl_name, sql FROM sqlite_master \
WHERE type = 'trigger' AND tbl_name = ?1 \
ORDER BY tbl_name, name",
)
.map_err(|source| {
ToolError::Tool(ListTriggersError::Query { source })
})?;
stmt.query_map([table], |row| Ok(trigger_info_from_row(row)))
.map_err(|source| {
ToolError::Tool(ListTriggersError::Query { source })
})?
.collect::<Result<Vec<_>, _>>()
.map_err(|source| {
ToolError::Tool(ListTriggersError::Query { source })
})?
}
None => {
let mut stmt = conn
.prepare(
"SELECT name, tbl_name, sql FROM sqlite_master \
WHERE type = 'trigger' \
ORDER BY tbl_name, name",
)
.map_err(|source| {
ToolError::Tool(ListTriggersError::Query { source })
})?;
stmt.query_map([], |row| Ok(trigger_info_from_row(row)))
.map_err(|source| {
ToolError::Tool(ListTriggersError::Query { source })
})?
.collect::<Result<Vec<_>, _>>()
.map_err(|source| {
ToolError::Tool(ListTriggersError::Query { source })
})?
}
};
Ok(ListTriggersOutput { triggers })
}
}
/// Builds a [`TriggerInfo`] from a `(name, tbl_name, sql)` row.
///
/// This is best-effort: a column that cannot be read as text falls back to an
/// empty string (a NULL `sql` is treated the same way), so the function never
/// fails. Timing and event are derived from the SQL text by
/// `parse_timing_and_event`.
fn trigger_info_from_row(row: &rusqlite::Row<'_>) -> TriggerInfo {
    let name = row.get::<_, String>(0).unwrap_or_default();
    let table_name = row.get::<_, String>(1).unwrap_or_default();
    // Read error -> None, NULL -> None, and None -> "".
    let sql = row
        .get::<_, Option<String>>(2)
        .ok()
        .flatten()
        .unwrap_or_default();

    let (timing, event) = parse_timing_and_event(&name, &sql);

    TriggerInfo { name, table_name, event, timing, sql }
}
/// Extracts the trigger timing and event from a `CREATE TRIGGER` statement.
///
/// Returns `(timing, event)` where `timing` is `"BEFORE"`, `"AFTER"` or
/// `"INSTEAD OF"` and `event` is `"INSERT"`, `"UPDATE"` or `"DELETE"`.
/// Anything that cannot be recognised is reported as `"UNKNOWN"`.
///
/// Matching is case-insensitive. The trigger name is looked up only *after*
/// the `TRIGGER` keyword, so short names such as `t` (which also occur inside
/// `CREATE TRIGGER`) are located correctly. A closing quote/bracket directly
/// after the name (`"name"`, `` `name` ``, `[name]`, `'name'`) is skipped.
///
/// SQLite allows the timing keyword to be omitted, in which case the trigger
/// fires `BEFORE` the event; that default is reported when an event keyword
/// directly follows the name.
fn parse_timing_and_event(trigger_name: &str, sql: &str) -> (String, String) {
    const TRIGGER_KEYWORD: &str = "TRIGGER";
    const CLOSING_QUOTES: &[char] = &['"', '`', ']', '\''];

    // All slicing below is done on `upper` itself, so offsets stay valid even
    // if upper-casing changed the byte length of non-ASCII text.
    let upper = sql.to_uppercase();
    let search_name = trigger_name.to_uppercase();

    // Skip `CREATE ... TRIGGER` first; otherwise a name like `t` would match
    // a letter inside those keywords.
    let after_keyword = match upper.find(TRIGGER_KEYWORD) {
        Some(pos) => &upper[pos + TRIGGER_KEYWORD.len()..],
        None => return (unknown(), unknown()),
    };
    let after_name = match after_keyword.find(&search_name) {
        Some(pos) => &after_keyword[pos + search_name.len()..],
        None => return (unknown(), unknown()),
    };
    // A quoted identifier leaves its closing delimiter right after the name.
    let after_name = after_name
        .strip_prefix(CLOSING_QUOTES)
        .unwrap_or(after_name);
    let rest = after_name.trim_start();

    // Optional timing keyword; `INSTEAD OF` may be separated by any amount of
    // whitespace.
    let (timing, rest) = if let Some(tail) = rest.strip_prefix("BEFORE") {
        (Some("BEFORE"), tail)
    } else if let Some(tail) = rest.strip_prefix("AFTER") {
        (Some("AFTER"), tail)
    } else if let Some(tail) = rest
        .strip_prefix("INSTEAD")
        .and_then(|tail| tail.trim_start().strip_prefix("OF"))
    {
        (Some("INSTEAD OF"), tail)
    } else {
        (None, rest)
    };

    let rest = rest.trim_start();
    let event = ["INSERT", "UPDATE", "DELETE"]
        .iter()
        .copied()
        .find(|keyword| rest.starts_with(*keyword));

    match (timing, event) {
        (Some(timing), Some(event)) => (timing.to_owned(), event.to_owned()),
        (Some(timing), None) => (timing.to_owned(), unknown()),
        // Timing omitted: SQLite defaults to BEFORE.
        (None, Some(event)) => ("BEFORE".to_owned(), event.to_owned()),
        (None, None) => (unknown(), unknown()),
    }
}

/// Placeholder reported when timing or event cannot be determined.
fn unknown() -> String {
    "UNKNOWN".to_owned()
}
// Arguments accepted by the `list_triggers` tool.
// NOTE: plain `//` comments are used on purpose — schemars copies `///` doc
// comments into the generated JSON schema as a description.
#[derive(
    Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
    Serialize, Deserialize, schemars::JsonSchema,
)]
pub struct ListTriggersInput {
    // `Some(name)` restricts the listing to that table's triggers;
    // `None` lists all triggers.
    #[schemars(description = "Optional table name to filter triggers by. \
                              Omit to list all triggers.")]
    pub table_name: Option<String>,
}
// Result payload of the `list_triggers` tool.
// NOTE: plain `//` comments are used on purpose — schemars copies `///` doc
// comments into the generated JSON schema as a description.
#[derive(
    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash,
    Serialize, Deserialize, schemars::JsonSchema,
)]
pub struct ListTriggersOutput {
    // One entry per matching trigger, in the order the query returned them.
    pub triggers: Vec<TriggerInfo>,
}
// Description of a single trigger as reported by the `list_triggers` tool.
// NOTE: plain `//` comments are used on purpose — schemars copies `///` doc
// comments into the generated JSON schema as a description.
#[derive(
    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash,
    Serialize, Deserialize, schemars::JsonSchema,
)]
pub struct TriggerInfo {
    // Trigger name (`sqlite_master.name`).
    pub name: String,
    // Table the trigger is attached to (`sqlite_master.tbl_name`).
    pub table_name: String,
    // "INSERT", "UPDATE" or "DELETE"; "UNKNOWN" if it could not be parsed.
    pub event: String,
    // "BEFORE", "AFTER" or "INSTEAD OF"; "UNKNOWN" if it could not be parsed.
    pub timing: String,
    // Text of the `sql` column; empty when the column was NULL or unreadable.
    pub sql: String,
}
#[derive(Debug, thiserror::Error)]
pub enum ListTriggersError {
#[error("failed to list triggers: {source}")]
Query {
source: rusqlite::Error,
},
}
impl IntoContents for ListTriggersError {
    /// Renders the error as a single text content item holding its `Display`
    /// message.
    fn into_contents(self) -> Vec<Content> {
        let message = self.to_string();
        vec![Content::text(message)]
    }
}