1use std::sync::Arc;
2use std::time::Duration;
3
4use rmcp::handler::server::router::tool::ToolRouter;
5use rmcp::handler::server::wrapper::Parameters;
6use rmcp::model::*;
7use rmcp::{schemars, tool, tool_handler, tool_router, ServerHandler};
8use serde::Deserialize;
9
10use crate::db::convert::row_to_json;
11use crate::db::dialect;
12use crate::db::{DatabaseManager, DbBackend};
13use crate::error::McpSqlError;
14
15#[derive(Clone)]
16pub struct McpSqlServer {
17 db: Arc<DatabaseManager>,
18 allow_write: bool,
19 row_limit: u32,
20 query_timeout: Duration,
21 tool_router: ToolRouter<Self>,
22}
23
24#[derive(Debug, Deserialize, schemars::JsonSchema)]
27pub struct DatabaseParam {
28 #[schemars(description = "Database name (optional if only one database is connected)")]
29 #[serde(default)]
30 pub database: Option<String>,
31}
32
33#[derive(Debug, Deserialize, schemars::JsonSchema)]
34pub struct DescribeTableParams {
35 #[schemars(description = "Table name to describe (use schema.table for PostgreSQL)")]
36 pub table: String,
37
38 #[schemars(description = "Database name (optional if only one database is connected)")]
39 #[serde(default)]
40 pub database: Option<String>,
41}
42
43#[derive(Debug, Deserialize, schemars::JsonSchema)]
44pub struct SampleDataParams {
45 #[schemars(description = "Table name to sample rows from")]
46 pub table: String,
47
48 #[schemars(description = "Database name (optional if only one database is connected)")]
49 #[serde(default)]
50 pub database: Option<String>,
51
52 #[schemars(description = "Number of sample rows to return (default: 5)")]
53 #[serde(default)]
54 pub limit: Option<u32>,
55}
56
57#[derive(Debug, Deserialize, schemars::JsonSchema)]
58pub struct QueryParams {
59 #[schemars(description = "SQL query to execute")]
60 pub sql: String,
61
62 #[schemars(description = "Database name (optional if only one database is connected)")]
63 #[serde(default)]
64 pub database: Option<String>,
65}
66
67impl McpSqlServer {
68 pub fn new(db: DatabaseManager, allow_write: bool, row_limit: u32, query_timeout_secs: u64) -> Self {
69 Self {
70 db: Arc::new(db),
71 allow_write,
72 row_limit,
73 query_timeout: Duration::from_secs(query_timeout_secs),
74 tool_router: Self::tool_router(),
75 }
76 }
77
78 fn err(&self, e: McpSqlError) -> ErrorData {
79 e.to_mcp_error()
80 }
81}
82
83#[tool_router]
84impl McpSqlServer {
85 #[tool(
86 name = "list_databases",
87 description = "List all connected databases with their names and types (postgres/sqlite/mysql)"
88 )]
89 async fn list_databases(&self) -> Result<CallToolResult, ErrorData> {
90 let databases: Vec<serde_json::Value> = self
91 .db
92 .databases
93 .iter()
94 .map(|d| {
95 serde_json::json!({
96 "name": d.name,
97 "type": d.backend.name(),
98 "url": d.url_redacted,
99 })
100 })
101 .collect();
102
103 let text = serde_json::to_string_pretty(&databases)
104 .unwrap_or_else(|_| "[]".to_string());
105 Ok(CallToolResult::success(vec![Content::text(text)]))
106 }
107
108 #[tool(
109 name = "list_tables",
110 description = "List all tables in a database with approximate row counts"
111 )]
112 async fn list_tables(
113 &self,
114 Parameters(params): Parameters<DatabaseParam>,
115 ) -> Result<CallToolResult, ErrorData> {
116 let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
117 let tables = dialect::list_tables(&entry.pool, entry.backend)
118 .await
119 .map_err(|e| self.err(e))?;
120
121 let text = serde_json::to_string_pretty(&tables)
122 .unwrap_or_else(|_| "[]".to_string());
123 Ok(CallToolResult::success(vec![Content::text(text)]))
124 }
125
126 #[tool(
127 name = "describe_table",
128 description = "Describe a table's columns with name, type, nullable, default, and primary key info"
129 )]
130 async fn describe_table(
131 &self,
132 Parameters(params): Parameters<DescribeTableParams>,
133 ) -> Result<CallToolResult, ErrorData> {
134 let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
135 let columns = dialect::describe_table(&entry.pool, entry.backend, ¶ms.table)
136 .await
137 .map_err(|e| self.err(e))?;
138
139 let text = serde_json::to_string_pretty(&columns)
140 .unwrap_or_else(|_| "[]".to_string());
141 Ok(CallToolResult::success(vec![Content::text(text)]))
142 }
143
144 #[tool(
145 name = "query",
146 description = "Execute a SQL query and return results as JSON. Read-only by default (SELECT/WITH/SHOW/PRAGMA only). Use --allow-write flag to enable write operations."
147 )]
148 async fn query(
149 &self,
150 Parameters(params): Parameters<QueryParams>,
151 ) -> Result<CallToolResult, ErrorData> {
152 let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
153 let sql = params.sql.trim();
154
155 if !self.allow_write {
157 check_read_only(sql).map_err(|e| self.err(e))?;
158 }
159
160 if !self.allow_write && entry.backend != DbBackend::Sqlite {
162 let read_only_sql = match entry.backend {
163 DbBackend::Postgres => "SET TRANSACTION READ ONLY",
164 DbBackend::Mysql => "SET TRANSACTION READ ONLY",
165 DbBackend::Sqlite => unreachable!(),
166 };
167 let _ = sqlx::query(read_only_sql).execute(&entry.pool).await;
169 }
170
171 let limited_sql = inject_limit(sql, self.row_limit);
173
174 let rows = tokio::time::timeout(
175 self.query_timeout,
176 sqlx::query(&limited_sql).fetch_all(&entry.pool),
177 )
178 .await
179 .map_err(|_| self.err(McpSqlError::QueryTimeout(self.query_timeout.as_secs())))?
180 .map_err(|e| self.err(McpSqlError::Database(e)))?;
181
182 let results: Vec<serde_json::Value> = rows.iter().map(row_to_json).collect();
183 let text = serde_json::to_string_pretty(&serde_json::json!({
184 "rows": results,
185 "count": results.len(),
186 }))
187 .unwrap_or_else(|_| "{}".to_string());
188
189 Ok(CallToolResult::success(vec![Content::text(text)]))
190 }
191
192 #[tool(
193 name = "explain",
194 description = "Show the query execution plan for a SQL statement. Uses the appropriate EXPLAIN syntax for the database type."
195 )]
196 async fn explain(
197 &self,
198 Parameters(params): Parameters<QueryParams>,
199 ) -> Result<CallToolResult, ErrorData> {
200 let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
201 let prefix = dialect::explain_prefix(entry.backend);
202 let explain_sql = format!("{}{}", prefix, params.sql.trim());
203
204 let rows = tokio::time::timeout(
205 self.query_timeout,
206 sqlx::query(&explain_sql).fetch_all(&entry.pool),
207 )
208 .await
209 .map_err(|_| self.err(McpSqlError::QueryTimeout(self.query_timeout.as_secs())))?
210 .map_err(|e| self.err(McpSqlError::Database(e)))?;
211
212 let results: Vec<serde_json::Value> = rows.iter().map(row_to_json).collect();
213 let text = serde_json::to_string_pretty(&results)
214 .unwrap_or_else(|_| "[]".to_string());
215
216 Ok(CallToolResult::success(vec![Content::text(text)]))
217 }
218
219 #[tool(
220 name = "sample_data",
221 description = "Return sample rows from a table as JSON. Useful for previewing table contents without writing SQL."
222 )]
223 async fn sample_data(
224 &self,
225 Parameters(params): Parameters<SampleDataParams>,
226 ) -> Result<CallToolResult, ErrorData> {
227 let entry = self.db.resolve(params.database.as_deref()).map_err(|e| self.err(e))?;
228 let limit = params.limit.unwrap_or(5);
229
230 let rows = tokio::time::timeout(
231 self.query_timeout,
232 dialect::sample_data(&entry.pool, entry.backend, ¶ms.table, limit),
233 )
234 .await
235 .map_err(|_| self.err(McpSqlError::QueryTimeout(self.query_timeout.as_secs())))?
236 .map_err(|e| self.err(e))?;
237
238 let text = serde_json::to_string_pretty(&serde_json::json!({
239 "table": params.table,
240 "rows": rows,
241 "count": rows.len(),
242 }))
243 .unwrap_or_else(|_| "{}".to_string());
244
245 Ok(CallToolResult::success(vec![Content::text(text)]))
246 }
247}
248
249#[tool_handler]
250impl ServerHandler for McpSqlServer {
251 fn get_info(&self) -> ServerInfo {
252 ServerInfo {
253 protocol_version: ProtocolVersion::V_2024_11_05,
254 capabilities: ServerCapabilities::builder().enable_tools().build(),
255 server_info: Implementation {
256 name: "mcp-sql".to_string(),
257 version: env!("CARGO_PKG_VERSION").to_string(),
258 ..Default::default()
259 },
260 instructions: Some(
261 "SQL database server. Use list_databases to see connected databases, \
262 list_tables to see tables, describe_table for schema details (includes foreign keys), \
263 sample_data to preview table contents, query to run SQL, and explain for query plans."
264 .to_string(),
265 ),
266 }
267 }
268}
269
270fn check_read_only(sql: &str) -> Result<(), McpSqlError> {
272 let upper = sql.trim_start().to_uppercase();
273 let allowed_prefixes = ["SELECT", "WITH", "SHOW", "PRAGMA", "EXPLAIN"];
274 if allowed_prefixes.iter().any(|p| upper.starts_with(p)) {
275 Ok(())
276 } else {
277 Err(McpSqlError::ReadOnly(
278 "Only SELECT/WITH/SHOW/PRAGMA/EXPLAIN queries are allowed in read-only mode. \
279 Start the server with --allow-write to enable write operations."
280 .to_string(),
281 ))
282 }
283}
284
/// Appends `LIMIT {limit}` to a `SELECT`/`WITH` query that has no row-limiting clause.
///
/// Other statements are returned unchanged. So is any query that already contains a
/// `LIMIT` or `FETCH` keyword anywhere, including inside a subquery or string literal.
/// The check is deliberately conservative, because a second limit clause would be a
/// syntax error.
///
/// Any trailing run of semicolons and whitespace is removed before the clause is
/// appended.
fn inject_limit(sql: &str, limit: u32) -> String {
    let upper = sql.trim_start().to_uppercase();
    if !upper.starts_with("SELECT") && !upper.starts_with("WITH") {
        return sql.to_string();
    }
    // Split on non-identifier characters so the keyword is found after a newline,
    // tab or ')'. `contains(" LIMIT ")` missed "...\nLIMIT 10" and produced
    // "... LIMIT 10 LIMIT 100".
    let has_limit_clause = upper
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|word| word == "LIMIT" || word == "FETCH");
    if has_limit_clause {
        return sql.to_string();
    }
    // Strip every trailing ';' and whitespace (e.g. "… ;\n") so the clause attaches
    // to the statement itself rather than after a terminator.
    let body = sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    format!("{body} LIMIT {limit}")
}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_read_only() {
        // Statements the read-only filter must let through.
        let accepted = [
            "SELECT * FROM users",
            " select * from users",
            "WITH cte AS (SELECT 1) SELECT * FROM cte",
            "SHOW TABLES",
            "PRAGMA table_info(users)",
            "EXPLAIN SELECT * FROM users",
        ];
        for sql in accepted {
            assert!(check_read_only(sql).is_ok(), "expected {sql:?} to be accepted");
        }

        // Statements that must be refused without --allow-write.
        let rejected = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name = 'x'",
            "DELETE FROM users",
            "DROP TABLE users",
            "CREATE TABLE t (id INT)",
        ];
        for sql in rejected {
            assert!(check_read_only(sql).is_err(), "expected {sql:?} to be rejected");
        }
    }

    #[test]
    fn test_inject_limit() {
        // (input, limit, expected output)
        let cases = [
            ("SELECT * FROM users", 100, "SELECT * FROM users LIMIT 100"),
            ("SELECT * FROM users;", 100, "SELECT * FROM users LIMIT 100"),
            ("SELECT * FROM users LIMIT 10", 100, "SELECT * FROM users LIMIT 10"),
            ("INSERT INTO users VALUES (1)", 100, "INSERT INTO users VALUES (1)"),
            (
                "WITH cte AS (SELECT 1) SELECT * FROM cte",
                50,
                "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 50",
            ),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(inject_limit(input, limit), expected, "input: {input:?}");
        }
    }
}