Skip to main content

mcp_sql/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use rmcp::handler::server::router::tool::ToolRouter;
5use rmcp::handler::server::wrapper::Parameters;
6use rmcp::model::*;
7use rmcp::{schemars, tool, tool_handler, tool_router, ServerHandler};
8use serde::Deserialize;
9
10use crate::db::convert::row_to_json;
11use crate::db::dialect;
12use crate::db::{DatabaseManager, DbBackend};
13use crate::error::McpSqlError;
14
15#[derive(Clone)]
16pub struct McpSqlServer {
17    db: Arc<DatabaseManager>,
18    allow_write: bool,
19    row_limit: u32,
20    query_timeout: Duration,
21    tool_router: ToolRouter<Self>,
22}
23
24// -- Tool parameter types --
25
26#[derive(Debug, Deserialize, schemars::JsonSchema)]
27pub struct DatabaseParam {
28    #[schemars(description = "Database name (optional if only one database is connected)")]
29    #[serde(default)]
30    pub database: Option<String>,
31}
32
33#[derive(Debug, Deserialize, schemars::JsonSchema)]
34pub struct DescribeTableParams {
35    #[schemars(description = "Table name to describe (use schema.table for PostgreSQL)")]
36    pub table: String,
37
38    #[schemars(description = "Database name (optional if only one database is connected)")]
39    #[serde(default)]
40    pub database: Option<String>,
41}
42
43#[derive(Debug, Deserialize, schemars::JsonSchema)]
44pub struct SampleDataParams {
45    #[schemars(description = "Table name to sample rows from")]
46    pub table: String,
47
48    #[schemars(description = "Database name (optional if only one database is connected)")]
49    #[serde(default)]
50    pub database: Option<String>,
51
52    #[schemars(description = "Number of sample rows to return (default: 5)")]
53    #[serde(default)]
54    pub limit: Option<u32>,
55}
56
57#[derive(Debug, Deserialize, schemars::JsonSchema)]
58pub struct QueryParams {
59    #[schemars(description = "SQL query to execute")]
60    pub sql: String,
61
62    #[schemars(description = "Database name (optional if only one database is connected)")]
63    #[serde(default)]
64    pub database: Option<String>,
65}
66
67impl McpSqlServer {
68    pub fn new(db: DatabaseManager, allow_write: bool, row_limit: u32, query_timeout_secs: u64) -> Self {
69        Self {
70            db: Arc::new(db),
71            allow_write,
72            row_limit,
73            query_timeout: Duration::from_secs(query_timeout_secs),
74            tool_router: Self::tool_router(),
75        }
76    }
77
78    fn err(&self, e: McpSqlError) -> ErrorData {
79        e.to_mcp_error()
80    }
81}
82
83#[tool_router]
84impl McpSqlServer {
85    #[tool(
86        name = "list_databases",
87        description = "List all connected databases with their names and types (postgres/sqlite/mysql)"
88    )]
89    async fn list_databases(&self) -> Result<CallToolResult, ErrorData> {
90        let databases: Vec<serde_json::Value> = self
91            .db
92            .databases
93            .iter()
94            .map(|d| {
95                serde_json::json!({
96                    "name": d.name,
97                    "type": d.backend.name(),
98                    "url": d.url_redacted,
99                })
100            })
101            .collect();
102
103        let text = serde_json::to_string_pretty(&databases)
104            .unwrap_or_else(|_| "[]".to_string());
105        Ok(CallToolResult::success(vec![Content::text(text)]))
106    }
107
108    #[tool(
109        name = "list_tables",
110        description = "List all tables in a database with approximate row counts"
111    )]
112    async fn list_tables(
113        &self,
114        Parameters(params): Parameters<DatabaseParam>,
115    ) -> Result<CallToolResult, ErrorData> {
116        let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
117        let tables = dialect::list_tables(&entry.pool, entry.backend)
118            .await
119            .map_err(|e| self.err(e))?;
120
121        let text = serde_json::to_string_pretty(&tables)
122            .unwrap_or_else(|_| "[]".to_string());
123        Ok(CallToolResult::success(vec![Content::text(text)]))
124    }
125
126    #[tool(
127        name = "describe_table",
128        description = "Describe a table's columns with name, type, nullable, default, and primary key info"
129    )]
130    async fn describe_table(
131        &self,
132        Parameters(params): Parameters<DescribeTableParams>,
133    ) -> Result<CallToolResult, ErrorData> {
134        let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
135        let columns = dialect::describe_table(&entry.pool, entry.backend, &params.table)
136            .await
137            .map_err(|e| self.err(e))?;
138
139        let text = serde_json::to_string_pretty(&columns)
140            .unwrap_or_else(|_| "[]".to_string());
141        Ok(CallToolResult::success(vec![Content::text(text)]))
142    }
143
144    #[tool(
145        name = "query",
146        description = "Execute a SQL query and return results as JSON. Read-only by default (SELECT/WITH/SHOW/PRAGMA only). Use --allow-write flag to enable write operations."
147    )]
148    async fn query(
149        &self,
150        Parameters(params): Parameters<QueryParams>,
151    ) -> Result<CallToolResult, ErrorData> {
152        let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
153        let sql = params.sql.trim();
154
155        // Read-only guard
156        if !self.allow_write {
157            check_read_only(sql).map_err(|e| self.err(e))?;
158        }
159
160        // Set transaction read only for backends that support it
161        if !self.allow_write && entry.backend != DbBackend::Sqlite {
162            let read_only_sql = match entry.backend {
163                DbBackend::Postgres => "SET TRANSACTION READ ONLY",
164                DbBackend::Mysql => "SET TRANSACTION READ ONLY",
165                DbBackend::Sqlite => unreachable!(),
166            };
167            // Best effort — some connection states may not support this
168            let _ = sqlx::query(read_only_sql).execute(&entry.pool).await;
169        }
170
171        // Inject LIMIT if not present
172        let limited_sql = inject_limit(sql, self.row_limit);
173
174        let rows = tokio::time::timeout(
175            self.query_timeout,
176            sqlx::query(&limited_sql).fetch_all(&entry.pool),
177        )
178        .await
179        .map_err(|_| self.err(McpSqlError::QueryTimeout(self.query_timeout.as_secs())))?
180        .map_err(|e| self.err(McpSqlError::Database(e)))?;
181
182        let results: Vec<serde_json::Value> = rows.iter().map(row_to_json).collect();
183        let text = serde_json::to_string_pretty(&serde_json::json!({
184            "rows": results,
185            "count": results.len(),
186        }))
187        .unwrap_or_else(|_| "{}".to_string());
188
189        Ok(CallToolResult::success(vec![Content::text(text)]))
190    }
191
192    #[tool(
193        name = "explain",
194        description = "Show the query execution plan for a SQL statement. Uses the appropriate EXPLAIN syntax for the database type."
195    )]
196    async fn explain(
197        &self,
198        Parameters(params): Parameters<QueryParams>,
199    ) -> Result<CallToolResult, ErrorData> {
200        let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
201        let prefix = dialect::explain_prefix(entry.backend);
202        let explain_sql = format!("{}{}", prefix, params.sql.trim());
203
204        let rows = tokio::time::timeout(
205            self.query_timeout,
206            sqlx::query(&explain_sql).fetch_all(&entry.pool),
207        )
208        .await
209        .map_err(|_| self.err(McpSqlError::QueryTimeout(self.query_timeout.as_secs())))?
210        .map_err(|e| self.err(McpSqlError::Database(e)))?;
211
212        let results: Vec<serde_json::Value> = rows.iter().map(row_to_json).collect();
213        let text = serde_json::to_string_pretty(&results)
214            .unwrap_or_else(|_| "[]".to_string());
215
216        Ok(CallToolResult::success(vec![Content::text(text)]))
217    }
218
219    #[tool(
220        name = "sample_data",
221        description = "Return sample rows from a table as JSON. Useful for previewing table contents without writing SQL."
222    )]
223    async fn sample_data(
224        &self,
225        Parameters(params): Parameters<SampleDataParams>,
226    ) -> Result<CallToolResult, ErrorData> {
227        let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
228        let limit = params.limit.unwrap_or(5);
229
230        let rows = tokio::time::timeout(
231            self.query_timeout,
232            dialect::sample_data(&entry.pool, entry.backend, &params.table, limit),
233        )
234        .await
235        .map_err(|_| self.err(McpSqlError::QueryTimeout(self.query_timeout.as_secs())))?
236        .map_err(|e| self.err(e))?;
237
238        let text = serde_json::to_string_pretty(&serde_json::json!({
239            "table": params.table,
240            "rows": rows,
241            "count": rows.len(),
242        }))
243        .unwrap_or_else(|_| "{}".to_string());
244
245        Ok(CallToolResult::success(vec![Content::text(text)]))
246    }
247}
248
249#[tool_handler]
250impl ServerHandler for McpSqlServer {
251    fn get_info(&self) -> ServerInfo {
252        ServerInfo {
253            protocol_version: ProtocolVersion::V_2024_11_05,
254            capabilities: ServerCapabilities::builder().enable_tools().build(),
255            server_info: Implementation {
256                name: "mcp-sql".to_string(),
257                version: env!("CARGO_PKG_VERSION").to_string(),
258                ..Default::default()
259            },
260            instructions: Some(
261                "SQL database server. Use list_databases to see connected databases, \
262                 list_tables to see tables, describe_table for schema details (includes foreign keys), \
263                 sample_data to preview table contents, query to run SQL, and explain for query plans."
264                    .to_string(),
265            ),
266        }
267    }
268}
269
270/// Heuristic check: only allow SELECT, WITH, SHOW, PRAGMA, EXPLAIN.
271fn check_read_only(sql: &str) -> Result<(), McpSqlError> {
272    let upper = sql.trim_start().to_uppercase();
273    let allowed_prefixes = ["SELECT", "WITH", "SHOW", "PRAGMA", "EXPLAIN"];
274    if allowed_prefixes.iter().any(|p| upper.starts_with(p)) {
275        Ok(())
276    } else {
277        Err(McpSqlError::ReadOnly(
278            "Only SELECT/WITH/SHOW/PRAGMA/EXPLAIN queries are allowed in read-only mode. \
279             Start the server with --allow-write to enable write operations."
280                .to_string(),
281        ))
282    }
283}
284
/// Inject a LIMIT clause if the query doesn't already have one.
///
/// Only statements starting with `SELECT` or `WITH` are rewritten; anything
/// else is returned unchanged. `LIMIT` is detected as a standalone word anywhere
/// in the text (case-insensitive, any surrounding whitespace or punctuation),
/// so a query that already limits itself — including inside a subquery or a
/// string literal — is left alone rather than risking a duplicate clause.
/// Trailing semicolons and whitespace are stripped before appending.
fn inject_limit(sql: &str, limit: u32) -> String {
    let upper = sql.trim_start().to_uppercase();
    if !(upper.starts_with("SELECT") || upper.starts_with("WITH")) {
        return sql.to_string();
    }

    // Word-based so `\nLIMIT 10` or `LIMIT\t10` count, while identifiers such
    // as `credit_limit` do not.
    let has_limit = upper
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|word| word == "LIMIT");
    if has_limit {
        return sql.to_string();
    }

    // Strip every trailing `;` together with the whitespace around them.
    let body = sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    // A `--` comment on the last line would swallow a LIMIT appended to it.
    let last_line = body.rsplit('\n').next().unwrap_or_default();
    let separator = if last_line.contains("--") { "\n" } else { " " };
    format!("{body}{separator}LIMIT {limit}")
}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_read_only() {
        let accepted = [
            "SELECT * FROM users",
            "  select * from users",
            "WITH cte AS (SELECT 1) SELECT * FROM cte",
            "SHOW TABLES",
            "PRAGMA table_info(users)",
            "EXPLAIN SELECT * FROM users",
        ];
        for sql in accepted {
            assert!(check_read_only(sql).is_ok(), "expected to pass the read-only guard: {sql}");
        }

        let rejected = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name = 'x'",
            "DELETE FROM users",
            "DROP TABLE users",
            "CREATE TABLE t (id INT)",
        ];
        for sql in rejected {
            assert!(check_read_only(sql).is_err(), "expected the read-only guard to reject: {sql}");
        }
    }

    #[test]
    fn test_inject_limit() {
        let cases: [(&str, u32, &str); 5] = [
            ("SELECT * FROM users", 100, "SELECT * FROM users LIMIT 100"),
            ("SELECT * FROM users;", 100, "SELECT * FROM users LIMIT 100"),
            ("SELECT * FROM users LIMIT 10", 100, "SELECT * FROM users LIMIT 10"),
            ("INSERT INTO users VALUES (1)", 100, "INSERT INTO users VALUES (1)"),
            (
                "WITH cte AS (SELECT 1) SELECT * FROM cte",
                50,
                "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 50",
            ),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(inject_limit(input, limit), expected, "input: {input}");
        }
    }
}