use std::collections::HashMap;
use std::sync::Arc;
use elicitation::{ColumnEntry, ColumnValue, RowData};
use futures::future::BoxFuture;
use rmcp::{
ErrorData,
model::{CallToolRequestParams, CallToolResult, Content, Tool},
service::RequestContext,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sqlx::Column as _;
use sqlx::Row as _;
use sqlx::TypeInfo as _;
use tokio::sync::Mutex;
use tracing::instrument;
use uuid::Uuid;
use crate::QueryResultData;
/// Arguments for the `<plugin>__connect` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct ConnectParams {
    /// Database connection URL, e.g. `postgres://user:pass@host/db`, `sqlite::memory:`
    /// or `mysql://user:pass@host/db`. May embed credentials.
    pub database_url: String,
}
/// Arguments for tools that only need a pool handle (`disconnect`, `pool_stats`, `begin`).
#[derive(Debug, Deserialize, JsonSchema)]
pub struct PoolIdParams {
    /// Pool handle returned by `connect`.
    pub pool_id: Uuid,
}
/// Arguments for running a SQL statement directly against a pool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct PoolSqlParams {
    /// Pool handle returned by `connect`.
    pub pool_id: Uuid,
    /// SQL text to run; use the driver's placeholder syntax for bind values.
    pub sql: String,
    /// Positional bind values, in placeholder order. Omit (or pass `[]`) for none.
    #[serde(default)]
    pub args: Vec<serde_json::Value>,
}
/// Arguments for tools that only need a transaction handle (`commit`, `rollback`).
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TxIdParams {
    /// Transaction handle returned by `begin`.
    pub tx_id: Uuid,
}
/// Arguments for running a SQL statement inside an open transaction.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TxSqlParams {
    /// Transaction handle returned by `begin`.
    pub tx_id: Uuid,
    /// SQL text to run; use the driver's placeholder syntax for bind values.
    pub sql: String,
    /// Positional bind values, in placeholder order. Omit (or pass `[]`) for none.
    #[serde(default)]
    pub args: Vec<serde_json::Value>,
}
/// Response of the `connect` tool.
#[derive(Debug, Serialize, JsonSchema)]
pub struct ConnectResult {
    /// Handle identifying the newly opened pool in later calls.
    pub pool_id: Uuid,
}
/// Response of the `begin` tool.
#[derive(Debug, Serialize, JsonSchema)]
pub struct BeginResult {
    /// Handle identifying the newly opened transaction in later `tx_*` calls.
    pub tx_id: Uuid,
}
/// Response of the `pool_stats` tool: a snapshot of sqlx pool counters.
#[derive(Debug, Serialize, JsonSchema)]
pub struct PoolStatsResult {
    /// Value of sqlx `Pool::size()`: connections currently open, idle ones included.
    pub size: u32,
    /// Value of sqlx `Pool::num_idle()`: open connections not currently in use.
    pub num_idle: usize,
    /// Value of sqlx `Pool::is_closed()`.
    pub is_closed: bool,
}
/// Deserialize a tool call's JSON `arguments` object into `T`.
///
/// Absent arguments are treated as an empty object, leaving it to `T` to decide whether
/// that is acceptable (e.g. missing required fields). Any deserialization failure is
/// reported as an `invalid_params` protocol error carrying serde's message.
fn parse_args<T: for<'de> Deserialize<'de>>(
    params: &CallToolRequestParams,
) -> Result<T, ErrorData> {
    let object = match params.arguments.as_ref() {
        Some(map) => map.clone(),
        None => Default::default(),
    };
    serde_json::from_value::<T>(serde_json::Value::Object(object))
        .map_err(|err| ErrorData::invalid_params(err.to_string(), None))
}
fn json_result<T: Serialize>(value: &T) -> CallToolResult {
match serde_json::to_string(value) {
Ok(s) => CallToolResult::success(vec![Content::text(s)]),
Err(e) => CallToolResult::error(vec![Content::text(format!("serialize error: {e}"))]),
}
}
/// Decode a PostgreSQL row into dynamically-typed [`RowData`].
///
/// The PostgreSQL type name reported for each column picks the Rust type used to read
/// it; any type not listed explicitly is read as text. SQL `NULL`s, and values that
/// cannot be read as the chosen type, come back as [`ColumnValue::Null`].
pub fn decode_pg_row(row: &sqlx::postgres::PgRow) -> RowData {
    RowData::new(
        row.columns()
            .iter()
            .enumerate()
            .map(|(idx, column)| {
                let decoded = match column.type_info().name() {
                    "BOOL" => row
                        .try_get::<bool, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::Bool),
                    "INT2" | "SMALLINT" | "SMALLSERIAL" => row
                        .try_get::<i16, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::SmallInt),
                    "INT4" | "INT" | "INTEGER" | "SERIAL" => row
                        .try_get::<i32, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::Integer),
                    "INT8" | "BIGINT" | "BIGSERIAL" => row
                        .try_get::<i64, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::BigInt),
                    "FLOAT4" | "REAL" => row
                        .try_get::<f32, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::Real),
                    "FLOAT8" | "DOUBLE PRECISION" => row
                        .try_get::<f64, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::Double),
                    "BYTEA" => row
                        .try_get::<Vec<u8>, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::Blob),
                    // Everything else (TEXT, VARCHAR, ...) is attempted as a string.
                    _ => row
                        .try_get::<String, _>(idx)
                        .map_or(ColumnValue::Null, ColumnValue::Text),
                };
                ColumnEntry::new(column.name().to_owned(), decoded)
            })
            .collect(),
    )
}
/// Decode a SQLite row into dynamically-typed [`RowData`].
///
/// Dispatch is on the column type reported by sqlx. sqlx reports `NULL` for columns
/// SQLite gives no usable declared type (expression columns such as `count(*)` or
/// literals). That is a statement-level attribute, not a property of the value, so in
/// that case the storage class of the actual value (`INTEGER`, `REAL`, `TEXT`, `BLOB`)
/// is used instead. A value that is SQL `NULL`, or that cannot be decoded as the
/// selected type, becomes [`ColumnValue::Null`].
pub fn decode_sqlite_row(row: &sqlx::sqlite::SqliteRow) -> RowData {
    let columns = row
        .columns()
        .iter()
        .enumerate()
        .map(|(i, col)| {
            let name = col.name().to_string();
            let declared = col.type_info().name().to_uppercase();
            let type_name = if declared == "NULL" {
                // Fall back to the runtime storage class of this particular value; a
                // genuinely NULL (or unreadable) value keeps "NULL" and decodes to Null.
                match row.try_get_raw(i) {
                    Ok(raw) if !sqlx::ValueRef::is_null(&raw) => {
                        sqlx::ValueRef::type_info(&raw).name().to_uppercase()
                    }
                    _ => declared,
                }
            } else {
                declared
            };
            let value = match type_name.as_str() {
                "NULL" => ColumnValue::Null,
                "BOOLEAN" => row
                    .try_get::<bool, _>(i)
                    .map(ColumnValue::Bool)
                    .unwrap_or(ColumnValue::Null),
                "INTEGER" | "INT" | "TINYINT" | "SMALLINT" | "MEDIUMINT" | "BIGINT" | "INT2"
                | "INT8" => row
                    .try_get::<i64, _>(i)
                    .map(ColumnValue::BigInt)
                    .unwrap_or(ColumnValue::Null),
                "REAL" | "FLOAT" | "DOUBLE" | "NUMERIC" | "DECIMAL" => row
                    .try_get::<f64, _>(i)
                    .map(ColumnValue::Double)
                    .unwrap_or(ColumnValue::Null),
                "BLOB" => row
                    .try_get::<Vec<u8>, _>(i)
                    .map(ColumnValue::Blob)
                    .unwrap_or(ColumnValue::Null),
                // TEXT and date/time types (stored as text) are read as strings.
                _ => row
                    .try_get::<String, _>(i)
                    .map(ColumnValue::Text)
                    .unwrap_or(ColumnValue::Null),
            };
            ColumnEntry::new(name, value)
        })
        .collect();
    RowData::new(columns)
}
pub fn decode_mysql_row(row: &sqlx::mysql::MySqlRow) -> RowData {
let columns = row
.columns()
.iter()
.enumerate()
.map(|(i, col)| {
let name = col.name().to_string();
let value = match col.type_info().name().to_uppercase().as_str() {
"NULL" => ColumnValue::Null,
"BOOLEAN" | "BOOL" | "TINYINT(1)" => row
.try_get::<bool, _>(i)
.map(ColumnValue::Bool)
.unwrap_or(ColumnValue::Null),
"SMALLINT" | "SMALLINT UNSIGNED" | "YEAR" => row
.try_get::<i16, _>(i)
.map(ColumnValue::SmallInt)
.unwrap_or(ColumnValue::Null),
"INT" | "INTEGER" | "MEDIUMINT" | "INT UNSIGNED" => row
.try_get::<i32, _>(i)
.map(ColumnValue::Integer)
.unwrap_or(ColumnValue::Null),
"BIGINT" | "BIGINT UNSIGNED" => row
.try_get::<i64, _>(i)
.map(ColumnValue::BigInt)
.unwrap_or(ColumnValue::Null),
"FLOAT" => row
.try_get::<f32, _>(i)
.map(ColumnValue::Real)
.unwrap_or(ColumnValue::Null),
"DOUBLE" | "REAL" | "DECIMAL" | "NUMERIC" => row
.try_get::<f64, _>(i)
.map(ColumnValue::Double)
.unwrap_or(ColumnValue::Null),
"BLOB" | "TINYBLOB" | "MEDIUMBLOB" | "LONGBLOB" | "BINARY" | "VARBINARY" => row
.try_get::<Vec<u8>, _>(i)
.map(ColumnValue::Blob)
.unwrap_or(ColumnValue::Null),
_ => row
.try_get::<String, _>(i)
.map(ColumnValue::Text)
.unwrap_or(ColumnValue::Null),
};
ColumnEntry::new(name, value)
})
.collect();
RowData::new(columns)
}
/// Convert a PostgreSQL execution result into the driver-neutral [`QueryResultData`].
///
/// `PgQueryResult` exposes only an affected-row count, so `last_insert_id` is always
/// `None` here; generated keys have to be fetched with `RETURNING` instead.
pub fn pg_query_result(r: sqlx::postgres::PgQueryResult) -> QueryResultData {
    let rows_affected = r.rows_affected();
    QueryResultData {
        rows_affected,
        last_insert_id: None,
    }
}
/// Convert a SQLite execution result into the driver-neutral [`QueryResultData`].
///
/// `last_insert_id` is always populated with the rowid sqlx reports for the statement.
pub fn sqlite_query_result(r: sqlx::sqlite::SqliteQueryResult) -> QueryResultData {
    let rows_affected = r.rows_affected();
    let last_insert_id = Some(r.last_insert_rowid());
    QueryResultData {
        rows_affected,
        last_insert_id,
    }
}
/// Convert a MySQL execution result into the driver-neutral [`QueryResultData`].
///
/// sqlx reports MySQL's last insert id as a `u64`. Ids that do not fit in `i64` are
/// reported as `None` instead of being wrapped into a negative number by an `as` cast.
pub fn mysql_query_result(r: sqlx::mysql::MySqlQueryResult) -> QueryResultData {
    QueryResultData {
        rows_affected: r.rows_affected(),
        last_insert_id: i64::try_from(r.last_insert_id()).ok(),
    }
}
/// Convert JSON tool arguments into PostgreSQL bind parameters, in order.
///
/// Mapping:
/// - `null` → NULL, bound as a nullable `TEXT`.
/// - booleans → `bool`.
/// - integers fitting `i64` → `i64`; other numbers → `f64`, which may lose precision.
/// - strings → `String`.
/// - arrays and objects → their serialized JSON text (usable as e.g. `$1::jsonb`)
///   instead of being silently replaced by NULL.
///
/// NOTE(review): NULL is bound with a `TEXT` type, so Postgres may reject it where a
/// non-text parameter type is required. An explicit cast in the SQL works around this.
///
/// # Panics
/// Only if sqlx fails to encode one of these primitive values, which is not expected.
pub fn pg_args_from_json(args: &[serde_json::Value]) -> sqlx::postgres::PgArguments {
    use sqlx::Arguments as _;
    let mut out = sqlx::postgres::PgArguments::default();
    for v in args {
        match v {
            serde_json::Value::Null => out.add(Option::<String>::None).expect("pg bind null"),
            serde_json::Value::Bool(b) => out.add(*b).expect("pg bind bool"),
            serde_json::Value::Number(n) => {
                if let Some(i) = n.as_i64() {
                    out.add(i).expect("pg bind i64");
                } else if let Some(f) = n.as_f64() {
                    out.add(f).expect("pg bind f64");
                } else {
                    out.add(Option::<String>::None).expect("pg bind null");
                }
            }
            serde_json::Value::String(s) => out.add(s.clone()).expect("pg bind string"),
            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
                out.add(v.to_string()).expect("pg bind json text")
            }
        }
    }
    out
}
/// Convert JSON tool arguments into SQLite bind parameters, in order.
///
/// Mapping:
/// - `null` → NULL.
/// - booleans → `bool`.
/// - integers fitting `i64` → `i64`; other numbers → `f64`, which may lose precision.
/// - strings → `String`.
/// - arrays and objects → their serialized JSON text (usable with SQLite's JSON
///   functions) instead of being silently replaced by NULL.
///
/// # Panics
/// Only if sqlx fails to encode one of these primitive values, which is not expected.
pub fn sqlite_args_from_json(args: &[serde_json::Value]) -> sqlx::sqlite::SqliteArguments<'static> {
    use sqlx::Arguments as _;
    let mut out = sqlx::sqlite::SqliteArguments::default();
    for v in args {
        match v {
            serde_json::Value::Null => {
                out.add(Option::<String>::None).expect("sqlite bind null")
            }
            serde_json::Value::Bool(b) => out.add(*b).expect("sqlite bind bool"),
            serde_json::Value::Number(n) => {
                if let Some(i) = n.as_i64() {
                    out.add(i).expect("sqlite bind i64");
                } else if let Some(f) = n.as_f64() {
                    out.add(f).expect("sqlite bind f64");
                } else {
                    out.add(Option::<String>::None).expect("sqlite bind null");
                }
            }
            serde_json::Value::String(s) => out.add(s.clone()).expect("sqlite bind string"),
            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
                out.add(v.to_string()).expect("sqlite bind json text")
            }
        }
    }
    out
}
/// Convert JSON tool arguments into MySQL bind parameters, in order.
///
/// Mapping:
/// - `null` → NULL.
/// - booleans → `bool`.
/// - integers fitting `i64` → `i64`; other numbers → `f64`, which may lose precision.
/// - strings → `String`.
/// - arrays and objects → their serialized JSON text (accepted by MySQL `JSON`
///   columns) instead of being silently replaced by NULL.
///
/// # Panics
/// Only if sqlx fails to encode one of these primitive values, which is not expected.
pub fn mysql_args_from_json(args: &[serde_json::Value]) -> sqlx::mysql::MySqlArguments {
    use sqlx::Arguments as _;
    let mut out = sqlx::mysql::MySqlArguments::default();
    for v in args {
        match v {
            serde_json::Value::Null => {
                out.add(Option::<String>::None).expect("mysql bind null")
            }
            serde_json::Value::Bool(b) => out.add(*b).expect("mysql bind bool"),
            serde_json::Value::Number(n) => {
                if let Some(i) = n.as_i64() {
                    out.add(i).expect("mysql bind i64");
                } else if let Some(f) = n.as_f64() {
                    out.add(f).expect("mysql bind f64");
                } else {
                    out.add(Option::<String>::None).expect("mysql bind null");
                }
            }
            serde_json::Value::String(s) => out.add(s.clone()).expect("mysql bind string"),
            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
                out.add(v.to_string()).expect("mysql bind json text")
            }
        }
    }
    out
}
// Generates a stateful MCP plugin type for one sqlx driver.
//
// Each invocation defines `$Plugin`, which holds a pool registry and a transaction
// registry, and implements `elicitation::ElicitPlugin` for it. The exposed tools are
// `<plugin_name>__{connect, disconnect, pool_stats, execute, fetch_all, fetch_one,
// fetch_optional, begin, commit, rollback, tx_execute, tx_fetch_all, tx_fetch_one,
// tx_fetch_optional}`.
//
// The driver-specific pieces are passed in as free functions:
// - `row_decoder` turns a row into `RowData`.
// - `result_converter` turns an execution result into `QueryResultData`.
// - `args_converter` turns JSON arguments into bind parameters.
macro_rules! impl_driver_plugin {
    (
        struct_name = $Plugin:ident,
        plugin_name = $name:literal,
        pool_type = $Pool:ty,
        db_type = $Db:ty,
        row_decoder = $decode_row:ident,
        result_converter = $build_result:ident,
        args_converter = $args_from_json:ident,
        driver_label = $driver_label:literal
    ) => {
        #[doc = concat!("Stateful MCP plugin for `", $driver_label, "` database operations.")]
        ///
        /// Connection pools and open transactions live in in-memory registries keyed by
        /// server-generated UUIDs (`pool_id` / `tx_id`) that clients pass back on later
        /// calls.
        #[allow(dead_code)]
        pub struct $Plugin {
            /// Open pools, keyed by the `pool_id` returned from `connect`.
            pools: Arc<Mutex<HashMap<Uuid, $Pool>>>,
            /// Open transactions, keyed by the `tx_id` returned from `begin`.
            transactions: Arc<Mutex<HashMap<Uuid, sqlx::Transaction<'static, $Db>>>>,
        }

        impl $Plugin {
            /// Create a plugin with empty pool and transaction registries.
            pub fn new() -> Self {
                Self {
                    pools: Arc::new(Mutex::new(HashMap::new())),
                    transactions: Arc::new(Mutex::new(HashMap::new())),
                }
            }

            /// Return a handle to the pool registered under `id`.
            ///
            /// The pool handle is cloned so the registry lock is released before any
            /// query runs.
            async fn pool_by_id(&self, id: &Uuid) -> Result<$Pool, ErrorData> {
                let found = self.pools.lock().await.get(id).cloned();
                found.ok_or_else(|| ErrorData::invalid_params("pool_id not found", None))
            }

            /// Remove the transaction registered under `id` from the registry and return it.
            ///
            /// `tx_*` statements take the transaction out for the duration of the query,
            /// so the registry lock is not held across it, and put it back afterwards.
            /// While it is out, concurrent calls with the same `tx_id` see
            /// "tx_id not found". If the request future is dropped mid-query, the
            /// transaction is dropped with it. sqlx rolls back transactions dropped
            /// without a commit.
            async fn take_tx(
                &self,
                id: &Uuid,
            ) -> Result<sqlx::Transaction<'static, $Db>, ErrorData> {
                let found = self.transactions.lock().await.remove(id);
                found.ok_or_else(|| ErrorData::invalid_params("tx_id not found", None))
            }

            /// Put a transaction back into the registry after a `tx_*` statement ran.
            async fn restore_tx(&self, id: Uuid, tx: sqlx::Transaction<'static, $Db>) {
                self.transactions.lock().await.insert(id, tx);
            }

            /// Report a database error as a failed tool result (not a protocol error).
            fn db_error(e: sqlx::Error) -> CallToolResult {
                CallToolResult::error(vec![Content::text(e.to_string())])
            }
        }

        impl Default for $Plugin {
            fn default() -> Self {
                Self::new()
            }
        }

        impl elicitation::ElicitPlugin for $Plugin {
            fn name(&self) -> &'static str {
                $name
            }

            fn list_tools(&self) -> Vec<Tool> {
                // Render `T`'s derived JSON Schema as a tool `input_schema` object so
                // clients can discover the expected arguments. Falls back to an empty
                // object if the schema does not serialize to a JSON object.
                fn schema<T: JsonSchema>() -> Arc<serde_json::Map<String, serde_json::Value>> {
                    match serde_json::to_value(schemars::schema_for!(T)) {
                        Ok(serde_json::Value::Object(map)) => Arc::new(map),
                        _ => Arc::new(serde_json::Map::new()),
                    }
                }
                fn tool(
                    name: &'static str,
                    desc: &'static str,
                    input_schema: Arc<serde_json::Map<String, serde_json::Value>>,
                ) -> Tool {
                    Tool::new(name, desc, input_schema)
                }
                let connect = schema::<ConnectParams>();
                let by_pool = schema::<PoolIdParams>();
                let pool_sql = schema::<PoolSqlParams>();
                let by_tx = schema::<TxIdParams>();
                let tx_sql = schema::<TxSqlParams>();
                vec![
                    tool(
                        concat!($name, "__connect"),
                        concat!(
                            "Connect to a ",
                            $driver_label,
                            " database. Returns `{ pool_id }` for subsequent calls."
                        ),
                        connect,
                    ),
                    tool(
                        concat!($name, "__disconnect"),
                        concat!(
                            "Close and remove a named ",
                            $driver_label,
                            " pool from the registry."
                        ),
                        Arc::clone(&by_pool),
                    ),
                    tool(
                        concat!($name, "__pool_stats"),
                        concat!(
                            "Return `{ size, num_idle, is_closed }` for a named ",
                            $driver_label,
                            " pool."
                        ),
                        Arc::clone(&by_pool),
                    ),
                    tool(
                        concat!($name, "__execute"),
                        concat!(
                            "Execute a non-returning SQL statement against a ",
                            $driver_label,
                            " pool. Returns `{ rows_affected, last_insert_id }`."
                        ),
                        Arc::clone(&pool_sql),
                    ),
                    tool(
                        concat!($name, "__fetch_all"),
                        concat!(
                            "Execute a SELECT against a ",
                            $driver_label,
                            " pool. Returns all rows as a JSON array of RowData."
                        ),
                        Arc::clone(&pool_sql),
                    ),
                    tool(
                        concat!($name, "__fetch_one"),
                        concat!(
                            "Execute a SELECT against a ",
                            $driver_label,
                            " pool. Returns the first row or errors if none found."
                        ),
                        Arc::clone(&pool_sql),
                    ),
                    tool(
                        concat!($name, "__fetch_optional"),
                        concat!(
                            "Execute a SELECT against a ",
                            $driver_label,
                            " pool. Returns the first row or null if none found."
                        ),
                        pool_sql,
                    ),
                    tool(
                        concat!($name, "__begin"),
                        concat!(
                            "Begin a transaction on a ",
                            $driver_label,
                            " pool. Returns `{ tx_id }` for subsequent tx_ calls."
                        ),
                        by_pool,
                    ),
                    tool(
                        concat!($name, "__commit"),
                        concat!(
                            "Commit a ",
                            $driver_label,
                            " transaction and remove it from the registry."
                        ),
                        Arc::clone(&by_tx),
                    ),
                    tool(
                        concat!($name, "__rollback"),
                        concat!(
                            "Roll back a ",
                            $driver_label,
                            " transaction and remove it from the registry."
                        ),
                        by_tx,
                    ),
                    tool(
                        concat!($name, "__tx_execute"),
                        concat!(
                            "Execute a non-returning SQL statement within a ",
                            $driver_label,
                            " transaction."
                        ),
                        Arc::clone(&tx_sql),
                    ),
                    tool(
                        concat!($name, "__tx_fetch_all"),
                        concat!(
                            "Execute a SELECT within a ",
                            $driver_label,
                            " transaction. Returns all rows as a JSON array of RowData."
                        ),
                        Arc::clone(&tx_sql),
                    ),
                    tool(
                        concat!($name, "__tx_fetch_one"),
                        concat!(
                            "Execute a SELECT within a ",
                            $driver_label,
                            " transaction. Returns the first row or errors if none found."
                        ),
                        Arc::clone(&tx_sql),
                    ),
                    tool(
                        concat!($name, "__tx_fetch_optional"),
                        concat!(
                            "Execute a SELECT within a ",
                            $driver_label,
                            " transaction. Returns the first row or null if none found."
                        ),
                        tx_sql,
                    ),
                ]
            }

            // `params` is skipped so that its `Debug` output is not recorded on the
            // span. That output includes `database_url`, which may carry credentials,
            // plus SQL bind arguments. Only the tool name is recorded.
            #[instrument(skip(self, params, _ctx), fields(tool = %params.name))]
            fn call_tool<'a>(
                &'a self,
                params: CallToolRequestParams,
                _ctx: RequestContext<rmcp::RoleServer>,
            ) -> BoxFuture<'a, Result<CallToolResult, ErrorData>> {
                Box::pin(async move {
                    let verb = params
                        .name
                        .strip_prefix(concat!($name, "__"))
                        .ok_or_else(|| {
                            ErrorData::invalid_params(
                                format!("unknown tool: {}", params.name),
                                None,
                            )
                        })?;
                    match verb {
                        "connect" => {
                            let p: ConnectParams = parse_args(&params)?;
                            // NOTE(review): the runtime driver registry is only needed for
                            // `sqlx::Any` pools; the typed pool below likely does not need
                            // it. Confirm before removing.
                            sqlx::any::install_default_drivers();
                            let pool = <$Pool>::connect(&p.database_url).await.map_err(|e| {
                                ErrorData::invalid_params(format!("connect failed: {e}"), None)
                            })?;
                            let id = Uuid::new_v4();
                            self.pools.lock().await.insert(id, pool);
                            Ok(json_result(&ConnectResult { pool_id: id }))
                        }
                        "disconnect" => {
                            let p: PoolIdParams = parse_args(&params)?;
                            let removed = self.pools.lock().await.remove(&p.pool_id);
                            let pool = removed.ok_or_else(|| {
                                ErrorData::invalid_params("pool_id not found", None)
                            })?;
                            // Per sqlx docs, `Pool::close` waits for checked-out connections
                            // to be returned. This call therefore does not finish while a
                            // transaction begun on this pool is still registered.
                            pool.close().await;
                            Ok(CallToolResult::success(vec![Content::text(
                                r#"{"ok":true}"#,
                            )]))
                        }
                        "pool_stats" => {
                            let p: PoolIdParams = parse_args(&params)?;
                            let pool = self.pool_by_id(&p.pool_id).await?;
                            Ok(json_result(&PoolStatsResult {
                                size: pool.size(),
                                num_idle: pool.num_idle(),
                                is_closed: pool.is_closed(),
                            }))
                        }
                        "execute" => {
                            let p: PoolSqlParams = parse_args(&params)?;
                            let pool = self.pool_by_id(&p.pool_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).execute(&pool).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .execute(&pool)
                                    .await
                            };
                            Ok(match result {
                                Ok(r) => json_result(&$build_result(r)),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "fetch_all" => {
                            let p: PoolSqlParams = parse_args(&params)?;
                            let pool = self.pool_by_id(&p.pool_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).fetch_all(&pool).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .fetch_all(&pool)
                                    .await
                            };
                            Ok(match result {
                                Ok(rows) => {
                                    let data: Vec<RowData> = rows.iter().map($decode_row).collect();
                                    json_result(&data)
                                }
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "fetch_one" => {
                            let p: PoolSqlParams = parse_args(&params)?;
                            let pool = self.pool_by_id(&p.pool_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).fetch_one(&pool).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .fetch_one(&pool)
                                    .await
                            };
                            Ok(match result {
                                Ok(row) => json_result(&$decode_row(&row)),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "fetch_optional" => {
                            let p: PoolSqlParams = parse_args(&params)?;
                            let pool = self.pool_by_id(&p.pool_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).fetch_optional(&pool).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .fetch_optional(&pool)
                                    .await
                            };
                            Ok(match result {
                                Ok(maybe) => json_result(&maybe.as_ref().map($decode_row)),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "begin" => {
                            let p: PoolIdParams = parse_args(&params)?;
                            let pool = self.pool_by_id(&p.pool_id).await?;
                            Ok(match pool.begin().await {
                                Ok(tx) => {
                                    let id = Uuid::new_v4();
                                    self.transactions.lock().await.insert(id, tx);
                                    json_result(&BeginResult { tx_id: id })
                                }
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "commit" => {
                            let p: TxIdParams = parse_args(&params)?;
                            let tx = self.take_tx(&p.tx_id).await?;
                            Ok(match tx.commit().await {
                                Ok(()) => CallToolResult::success(vec![Content::text(
                                    r#"{"ok":true}"#,
                                )]),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "rollback" => {
                            let p: TxIdParams = parse_args(&params)?;
                            let tx = self.take_tx(&p.tx_id).await?;
                            Ok(match tx.rollback().await {
                                Ok(()) => CallToolResult::success(vec![Content::text(
                                    r#"{"ok":true}"#,
                                )]),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "tx_execute" => {
                            let p: TxSqlParams = parse_args(&params)?;
                            let mut tx = self.take_tx(&p.tx_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).execute(&mut *tx).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .execute(&mut *tx)
                                    .await
                            };
                            self.restore_tx(p.tx_id, tx).await;
                            Ok(match result {
                                Ok(r) => json_result(&$build_result(r)),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "tx_fetch_all" => {
                            let p: TxSqlParams = parse_args(&params)?;
                            let mut tx = self.take_tx(&p.tx_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).fetch_all(&mut *tx).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .fetch_all(&mut *tx)
                                    .await
                            };
                            self.restore_tx(p.tx_id, tx).await;
                            Ok(match result {
                                Ok(rows) => {
                                    let data: Vec<RowData> = rows.iter().map($decode_row).collect();
                                    json_result(&data)
                                }
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "tx_fetch_one" => {
                            let p: TxSqlParams = parse_args(&params)?;
                            let mut tx = self.take_tx(&p.tx_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).fetch_one(&mut *tx).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .fetch_one(&mut *tx)
                                    .await
                            };
                            self.restore_tx(p.tx_id, tx).await;
                            Ok(match result {
                                Ok(row) => json_result(&$decode_row(&row)),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        "tx_fetch_optional" => {
                            let p: TxSqlParams = parse_args(&params)?;
                            let mut tx = self.take_tx(&p.tx_id).await?;
                            let result = if p.args.is_empty() {
                                sqlx::query(&p.sql).fetch_optional(&mut *tx).await
                            } else {
                                sqlx::query_with(&p.sql, $args_from_json(&p.args))
                                    .fetch_optional(&mut *tx)
                                    .await
                            };
                            self.restore_tx(p.tx_id, tx).await;
                            Ok(match result {
                                Ok(maybe) => json_result(&maybe.as_ref().map($decode_row)),
                                Err(e) => Self::db_error(e),
                            })
                        }
                        other => Err(ErrorData::invalid_params(
                            format!("unknown {} tool: {other}", $name),
                            None,
                        )),
                    }
                })
            }
        }
    };
}
// PostgreSQL plugin: defines `SqlxPgPlugin`, whose tools are named `pg__<verb>`
// (e.g. `pg__connect`, `pg__fetch_all`), backed by `sqlx::PgPool`.
impl_driver_plugin!(
    struct_name = SqlxPgPlugin,
    plugin_name = "pg",
    pool_type = sqlx::PgPool,
    db_type = sqlx::Postgres,
    row_decoder = decode_pg_row,
    result_converter = pg_query_result,
    args_converter = pg_args_from_json,
    driver_label = "PostgreSQL"
);
// SQLite plugin: defines `SqlxSqlitePlugin`, whose tools are named `sqlite__<verb>`
// (e.g. `sqlite__connect`, `sqlite__fetch_all`), backed by `sqlx::SqlitePool`.
impl_driver_plugin!(
    struct_name = SqlxSqlitePlugin,
    plugin_name = "sqlite",
    pool_type = sqlx::SqlitePool,
    db_type = sqlx::Sqlite,
    row_decoder = decode_sqlite_row,
    result_converter = sqlite_query_result,
    args_converter = sqlite_args_from_json,
    driver_label = "SQLite"
);
// MySQL plugin: defines `SqlxMySqlPlugin`, whose tools are named `mysql__<verb>`
// (e.g. `mysql__connect`, `mysql__fetch_all`), backed by `sqlx::MySqlPool`.
impl_driver_plugin!(
    struct_name = SqlxMySqlPlugin,
    plugin_name = "mysql",
    pool_type = sqlx::MySqlPool,
    db_type = sqlx::MySql,
    row_decoder = decode_mysql_row,
    result_converter = mysql_query_result,
    args_converter = mysql_args_from_json,
    driver_label = "MySQL"
);