1use crate::errors::Result as MCPResult;
2use serde_json::{Value, json};
3use tokio_postgres::types::Type;
4use tokio_postgres::{Client, Row};
5
/// Upper bound on the length of a SQL string accepted by `validate_sql`.
/// Compared against `str::len`, so the unit is bytes, not characters.
const MAX_SQL_LEN: usize = 10_000;
7
8fn decode_cell(row: &Row, i: usize) -> Value {
16 let ty = row.columns()[i].type_().clone();
17 match ty {
18 Type::BOOL => match row.try_get::<_, Option<bool>>(i) {
19 Ok(Some(v)) => json!(v),
20 _ => Value::Null,
21 },
22 Type::INT2 => match row.try_get::<_, Option<i16>>(i) {
23 Ok(Some(v)) => json!(v),
24 _ => Value::Null,
25 },
26 Type::INT4 => match row.try_get::<_, Option<i32>>(i) {
27 Ok(Some(v)) => json!(v),
28 _ => Value::Null,
29 },
30 Type::INT8 => match row.try_get::<_, Option<i64>>(i) {
31 Ok(Some(v)) => json!(v),
32 _ => Value::Null,
33 },
34 Type::OID => match row.try_get::<_, Option<u32>>(i) {
35 Ok(Some(v)) => json!(v),
36 _ => Value::Null,
37 },
38 Type::FLOAT4 => match row.try_get::<_, Option<f32>>(i) {
39 Ok(Some(v)) => json!(v),
40 _ => Value::Null,
41 },
42 Type::FLOAT8 => match row.try_get::<_, Option<f64>>(i) {
43 Ok(Some(v)) => json!(v),
44 _ => Value::Null,
45 },
46 Type::NUMERIC => str_cell::<rust_decimal::Decimal>(row, i),
48 Type::UUID => str_cell::<uuid::Uuid>(row, i),
49 Type::TIMESTAMP => str_cell::<chrono::NaiveDateTime>(row, i),
50 Type::TIMESTAMPTZ => str_cell::<chrono::DateTime<chrono::Utc>>(row, i),
51 Type::DATE => str_cell::<chrono::NaiveDate>(row, i),
52 Type::TIME => str_cell::<chrono::NaiveTime>(row, i),
53 Type::JSON | Type::JSONB => match row.try_get::<_, Option<Value>>(i) {
54 Ok(Some(v)) => v,
55 _ => Value::Null,
56 },
57 Type::BYTEA => match row.try_get::<_, Option<Vec<u8>>>(i) {
58 Ok(Some(b)) => Value::String(to_hex(&b)),
59 _ => Value::Null,
60 },
61 Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => {
62 match row.try_get::<_, Option<String>>(i) {
63 Ok(Some(v)) => Value::String(v),
64 _ => Value::Null,
65 }
66 }
67 _ => match row.try_get::<_, Option<String>>(i) {
69 Ok(Some(v)) => Value::String(v),
70 _ => Value::Null,
71 },
72 }
73}
74
75fn str_cell<T>(row: &Row, i: usize) -> Value
78where
79 T: std::fmt::Display + for<'a> tokio_postgres::types::FromSql<'a>,
80{
81 match row.try_get::<_, Option<T>>(i) {
82 Ok(Some(v)) => Value::String(v.to_string()),
83 _ => Value::Null,
84 }
85}
86
/// Formats `bytes` as a literal `\x` followed by two lowercase hex digits per
/// byte (the same shape as PostgreSQL's `bytea` hex output).
fn to_hex(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(2 + 2 * bytes.len());
    out.push_str("\\x");
    for &byte in bytes {
        // High nibble first, then low nibble.
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
96
97pub(crate) fn validate_sql(
98 sql: &str,
99 allowed_prefix: &str,
100 label: &str,
101) -> std::result::Result<(), crate::errors::MCPError> {
102 if sql.is_empty() {
103 return Err(crate::errors::MCPError::InvalidParams(
104 "'sql' parameter must not be empty".into(),
105 ));
106 }
107 if sql.len() > MAX_SQL_LEN {
108 return Err(crate::errors::MCPError::InvalidParams(format!(
109 "SQL exceeds maximum length of {MAX_SQL_LEN} characters (got {})",
110 sql.len()
111 )));
112 }
113 let trimmed = sql.trim();
114 let first_word = trimmed.split_whitespace().next().unwrap_or("");
115 if !first_word.eq_ignore_ascii_case(allowed_prefix) {
116 return Err(crate::errors::MCPError::InvalidParams(format!(
117 "Invalid {label} query: expected '{allowed_prefix}'"
118 )));
119 }
120 let body = trimmed.strip_suffix(';').unwrap_or(trimmed);
124 if let Some(i) = first_unquoted_semicolon(body) {
125 let ctx_end = (i + 20).min(body.len());
126 let ctx = body.get(i..ctx_end).unwrap_or("");
127 return Err(crate::errors::MCPError::InvalidParams(format!(
128 "Multi-statement queries are not allowed: {label} contained ';' at position {i} (context: ...{ctx}...)"
129 )));
130 }
131 Ok(())
132}
133
/// Returns the byte offset of the first `;` in `sql` that is outside every
/// string literal, quoted identifier, comment and dollar-quoted body, or
/// `None` if there is none.
///
/// Lexical forms recognised, following PostgreSQL's rules:
/// - `'...'` with `''` as an escaped quote. `E'...'` / `e'...'` escape
///   strings additionally treat `\` as escaping the next byte.
/// - `"..."` with `""` as an escaped quote.
/// - `-- ...` up to the next newline, and nested `/* ... */`.
/// - `$tag$ ... $tag$`. The opening `$` must not continue an identifier
///   (`a$b$c` is one identifier), and the tag must not start with a digit
///   (`$1` is a positional parameter).
///
/// An unterminated quote or comment swallows the rest of the input.
///
/// Assumes `standard_conforming_strings = on` (the PostgreSQL default). With
/// it off, plain `'...'` strings also honour backslash escapes, which this
/// scanner does not model.
fn first_unquoted_semicolon(sql: &str) -> Option<usize> {
    /// Bytes that can continue an unquoted identifier. Non-ASCII bytes count
    /// because PostgreSQL accepts non-ASCII letters in identifiers.
    fn continues_ident(c: u8) -> bool {
        c.is_ascii_alphanumeric() || c == b'_' || c == b'$' || c >= 0x80
    }
    /// Bytes allowed inside a dollar-quote tag (identifier bytes except `$`).
    fn in_tag(c: u8) -> bool {
        c.is_ascii_alphanumeric() || c == b'_' || c >= 0x80
    }

    let b = sql.as_bytes();
    let n = b.len();
    let mut i = 0;
    while i < n {
        match b[i] {
            b'\'' => {
                // `E'` starts an escape string only when the `E` is a token of
                // its own (not the tail of an identifier like `nameE`).
                let escape_string = i >= 1
                    && matches!(b[i - 1], b'E' | b'e')
                    && (i < 2 || !continues_ident(b[i - 2]));
                i += 1;
                while i < n {
                    match b[i] {
                        // Skip the escaped byte (which may be a quote).
                        b'\\' if escape_string => i += 2,
                        b'\'' if i + 1 < n && b[i + 1] == b'\'' => i += 2,
                        b'\'' => {
                            i += 1;
                            break;
                        }
                        _ => i += 1,
                    }
                }
            }
            b'"' => {
                i += 1;
                while i < n {
                    if b[i] == b'"' {
                        if i + 1 < n && b[i + 1] == b'"' {
                            i += 2;
                            continue;
                        }
                        i += 1;
                        break;
                    }
                    i += 1;
                }
            }
            b'-' if i + 1 < n && b[i + 1] == b'-' => {
                i += 2;
                while i < n && b[i] != b'\n' {
                    i += 1;
                }
            }
            b'/' if i + 1 < n && b[i + 1] == b'*' => {
                i += 2;
                // PostgreSQL block comments nest.
                let mut depth = 1usize;
                while i < n && depth > 0 {
                    if i + 1 < n && b[i] == b'/' && b[i + 1] == b'*' {
                        depth += 1;
                        i += 2;
                    } else if i + 1 < n && b[i] == b'*' && b[i + 1] == b'/' {
                        depth -= 1;
                        i += 2;
                    } else {
                        i += 1;
                    }
                }
            }
            // A `$` that continues an identifier falls through to `_` below.
            b'$' if i == 0 || !continues_ident(b[i - 1]) => {
                let mut j = i + 1;
                // A digit right after `$` means a parameter such as `$1`.
                let may_be_tag = j < n && !b[j].is_ascii_digit();
                if may_be_tag {
                    while j < n && in_tag(b[j]) {
                        j += 1;
                    }
                }
                if may_be_tag && j < n && b[j] == b'$' {
                    // `i` and `j` both index an ASCII `$`, so these slices fall
                    // on char boundaries.
                    let tag = &sql[i..=j];
                    i = match sql[j + 1..].find(tag) {
                        Some(off) => j + 1 + off + tag.len(),
                        None => n,
                    };
                } else {
                    i += 1;
                }
            }
            b';' => return Some(i),
            _ => i += 1,
        }
    }
    None
}
218
219pub async fn execute_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
221 let sql = params
222 .as_ref()
223 .and_then(|p| p.get("sql"))
224 .and_then(|v| v.as_str())
225 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
226
227 validate_sql(sql, "SELECT", "SELECT")?;
228
229 let rows = client.query(sql, &[]).await?;
230
231 let results: Vec<Value> = rows
232 .iter()
233 .map(|row| {
234 let values: Vec<Value> = (0..row.len()).map(|i| decode_cell(row, i)).collect();
235 Value::Array(values)
236 })
237 .collect();
238
239 Ok(json!({ "rows": results }))
240}
241
242pub async fn execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
244 let sql = params
245 .as_ref()
246 .and_then(|p| p.get("sql"))
247 .and_then(|v| v.as_str())
248 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
249
250 validate_sql(sql, "INSERT", "INSERT")?;
251
252 let rows_affected = client.execute(sql, &[]).await?;
253
254 Ok(json!({ "rows_affected": rows_affected }))
255}
256
257pub async fn execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
259 let sql = params
260 .as_ref()
261 .and_then(|p| p.get("sql"))
262 .and_then(|v| v.as_str())
263 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
264
265 validate_sql(sql, "UPDATE", "UPDATE")?;
266
267 let rows_affected = client.execute(sql, &[]).await?;
268
269 Ok(json!({ "rows_affected": rows_affected }))
270}
271
272pub async fn execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
274 let sql = params
275 .as_ref()
276 .and_then(|p| p.get("sql"))
277 .and_then(|v| v.as_str())
278 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
279
280 validate_sql(sql, "DELETE", "DELETE")?;
281
282 let rows_affected = client.execute(sql, &[]).await?;
283
284 Ok(json!({ "rows_affected": rows_affected }))
285}
286
287pub async fn explain_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
292 let sql = params
293 .as_ref()
294 .and_then(|p| p.get("sql"))
295 .and_then(|v| v.as_str())
296 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
297
298 validate_sql(sql, "SELECT", "SELECT")?;
299
300 let analyze = params
301 .as_ref()
302 .and_then(|p| p.get("analyze"))
303 .and_then(|v| v.as_bool())
304 .unwrap_or(false);
305
306 let buffers = params
307 .as_ref()
308 .and_then(|p| p.get("buffers"))
309 .and_then(|v| v.as_bool())
310 .unwrap_or(false);
311
312 let format = params
313 .as_ref()
314 .and_then(|p| p.get("format"))
315 .and_then(|v| v.as_str())
316 .unwrap_or("json");
317
318 if format.eq_ignore_ascii_case("xml") {
319 return Err(crate::errors::MCPError::InvalidParams(
320 "XML format is not supported — use TEXT, YAML, or JSON".into(),
321 ));
322 }
323
324 let mut explain_sql = String::with_capacity(sql.len() + 80);
325 explain_sql.push_str("EXPLAIN (FORMAT ");
326 explain_sql.push_str(&format.to_uppercase());
327 if analyze {
328 explain_sql.push_str(", ANALYZE");
329 }
330 if buffers {
331 explain_sql.push_str(", BUFFERS");
332 }
333 explain_sql.push_str(") ");
334 explain_sql.push_str(sql);
335
336 let rows = client.query(&explain_sql, &[]).await?;
337
338 if rows.is_empty() {
339 return Ok(json!({ "plan": null }));
340 }
341
342 if format.eq_ignore_ascii_case("json") {
343 let plan: serde_json::Value = rows[0].get(0);
344 Ok(json!({
345 "plan": plan,
346 "options": { "analyze": analyze, "buffers": buffers, "format": format }
347 }))
348 } else {
349 let mut plan = String::new();
350 for (i, row) in rows.iter().enumerate() {
351 if i > 0 {
352 plan.push('\n');
353 }
354 plan.push_str(&row.get::<_, String>(0));
355 }
356 Ok(json!({
357 "plan": plan,
358 "options": { "analyze": analyze, "buffers": buffers, "format": format }
359 }))
360 }
361}
362
363pub async fn async_execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
370 let sql = params
371 .as_ref()
372 .and_then(|p| p.get("sql"))
373 .and_then(|v| v.as_str())
374 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
375
376 validate_sql(sql, "INSERT", "INSERT")?;
377
378 async_sync_commit_execute(client, sql).await
379}
380
381pub async fn async_execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
389 let sql = params
390 .as_ref()
391 .and_then(|p| p.get("sql"))
392 .and_then(|v| v.as_str())
393 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
394
395 validate_sql(sql, "UPDATE", "UPDATE")?;
396
397 async_sync_commit_execute(client, sql).await
398}
399
400pub async fn async_execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
408 let sql = params
409 .as_ref()
410 .and_then(|p| p.get("sql"))
411 .and_then(|v| v.as_str())
412 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
413
414 validate_sql(sql, "DELETE", "DELETE")?;
415
416 async_sync_commit_execute(client, sql).await
417}
418
419async fn async_sync_commit_execute(client: &Client, sql: &str) -> MCPResult<Value> {
423 client.execute("BEGIN", &[]).await?;
424 client
425 .execute("SET LOCAL synchronous_commit = OFF", &[])
426 .await?;
427 match client.execute(sql, &[]).await {
428 Ok(rows_affected) => {
429 client.execute("COMMIT", &[]).await?;
430 Ok(json!({ "rows_affected": rows_affected }))
431 }
432 Err(e) => {
433 client.execute("ROLLBACK", &[]).await.ok();
434 Err(crate::errors::MCPError::DatabaseError(e))
435 }
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unquoted_semicolon_detected() {
        let stacked = "SELECT 1; DROP TABLE x";
        assert_eq!(first_unquoted_semicolon(stacked), Some(8));
    }

    #[test]
    fn test_semicolon_in_string_ignored() {
        for sql in ["SELECT ';not a stmt'", "SELECT 'a''b; c'"] {
            assert_eq!(first_unquoted_semicolon(sql), None, "input: {sql}");
        }
    }

    #[test]
    fn test_semicolon_in_identifier_ignored() {
        let sql = "SELECT \"weird;col\" FROM t";
        assert_eq!(first_unquoted_semicolon(sql), None);
    }

    #[test]
    fn test_semicolon_in_comments_ignored() {
        for sql in ["SELECT 1 -- a; b\n", "SELECT 1 /* a; b */"] {
            assert_eq!(first_unquoted_semicolon(sql), None, "input: {sql}");
        }
    }

    #[test]
    fn test_semicolon_in_dollar_quote_ignored() {
        for sql in ["SELECT $$a; b$$", "SELECT $tag$a; b$tag$"] {
            assert_eq!(first_unquoted_semicolon(sql), None, "input: {sql}");
        }
    }

    #[test]
    fn test_validate_sql_allows_trailing_semicolon() {
        for sql in ["SELECT 1;", "SELECT ';'"] {
            assert!(validate_sql(sql, "SELECT", "SELECT").is_ok(), "input: {sql}");
        }
    }

    #[test]
    fn test_validate_sql_rejects_stacked() {
        let outcome = validate_sql("SELECT 1; DROP TABLE x", "SELECT", "SELECT");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_sql_prefix() {
        let matching = validate_sql("DELETE FROM x WHERE id=1", "DELETE", "DELETE");
        let mismatched = validate_sql("SELECT 1", "DELETE", "DELETE");
        assert!(matching.is_ok());
        assert!(mismatched.is_err());
    }

    #[test]
    fn test_to_hex() {
        let cases: [(&[u8], &str); 3] = [
            (&[0xde, 0xad, 0xbe, 0xef], "\\xdeadbeef"),
            (&[], "\\x"),
            (&[0x00, 0x0f], "\\x000f"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_hex(input), expected);
        }
    }
}