Skip to main content

mcp_postgres/actions/
query.rs

1use crate::errors::Result as MCPResult;
2use serde_json::{Value, json};
3use tokio_postgres::types::Type;
4use tokio_postgres::{Client, Row};
5
/// Largest SQL text accepted by `validate_sql`, measured in bytes (`str::len`).
const MAX_SQL_LEN: usize = 10 * 1_000;
7
8/// Decode a single result cell to JSON based on its PostgreSQL column type.
9///
10/// Numeric and boolean types map to native JSON numbers/bools; temporal,
11/// numeric-decimal, uuid and text types map to strings; json/jsonb pass
12/// through as structured JSON; bytea becomes a hex string. Unknown types fall
13/// back to their text representation, and only truly undecodable values (e.g.
14/// arrays) become null.
15fn decode_cell(row: &Row, i: usize) -> Value {
16    let ty = row.columns()[i].type_().clone();
17    match ty {
18        Type::BOOL => match row.try_get::<_, Option<bool>>(i) {
19            Ok(Some(v)) => json!(v),
20            _ => Value::Null,
21        },
22        Type::INT2 => match row.try_get::<_, Option<i16>>(i) {
23            Ok(Some(v)) => json!(v),
24            _ => Value::Null,
25        },
26        Type::INT4 => match row.try_get::<_, Option<i32>>(i) {
27            Ok(Some(v)) => json!(v),
28            _ => Value::Null,
29        },
30        Type::INT8 => match row.try_get::<_, Option<i64>>(i) {
31            Ok(Some(v)) => json!(v),
32            _ => Value::Null,
33        },
34        Type::OID => match row.try_get::<_, Option<u32>>(i) {
35            Ok(Some(v)) => json!(v),
36            _ => Value::Null,
37        },
38        Type::FLOAT4 => match row.try_get::<_, Option<f32>>(i) {
39            Ok(Some(v)) => json!(v),
40            _ => Value::Null,
41        },
42        Type::FLOAT8 => match row.try_get::<_, Option<f64>>(i) {
43            Ok(Some(v)) => json!(v),
44            _ => Value::Null,
45        },
46        // Decimal as a string to preserve full precision.
47        Type::NUMERIC => str_cell::<rust_decimal::Decimal>(row, i),
48        Type::UUID => str_cell::<uuid::Uuid>(row, i),
49        Type::TIMESTAMP => str_cell::<chrono::NaiveDateTime>(row, i),
50        Type::TIMESTAMPTZ => str_cell::<chrono::DateTime<chrono::Utc>>(row, i),
51        Type::DATE => str_cell::<chrono::NaiveDate>(row, i),
52        Type::TIME => str_cell::<chrono::NaiveTime>(row, i),
53        Type::JSON | Type::JSONB => match row.try_get::<_, Option<Value>>(i) {
54            Ok(Some(v)) => v,
55            _ => Value::Null,
56        },
57        Type::BYTEA => match row.try_get::<_, Option<Vec<u8>>>(i) {
58            Ok(Some(b)) => Value::String(to_hex(&b)),
59            _ => Value::Null,
60        },
61        Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => {
62            match row.try_get::<_, Option<String>>(i) {
63                Ok(Some(v)) => Value::String(v),
64                _ => Value::Null,
65            }
66        }
67        // Fallback: enums, citext, and other text-output types decode as String.
68        _ => match row.try_get::<_, Option<String>>(i) {
69            Ok(Some(v)) => Value::String(v),
70            _ => Value::Null,
71        },
72    }
73}
74
75/// Decode an optional value whose Rust type implements `Display`, rendering it
76/// as a JSON string (or null when SQL NULL / undecodable).
77fn str_cell<T>(row: &Row, i: usize) -> Value
78where
79    T: std::fmt::Display + for<'a> tokio_postgres::types::FromSql<'a>,
80{
81    match row.try_get::<_, Option<T>>(i) {
82        Ok(Some(v)) => Value::String(v.to_string()),
83        _ => Value::Null,
84    }
85}
86
/// Format bytes the way PostgreSQL prints `bytea` in hex output mode:
/// a `\x` prefix followed by two lowercase hex digits per byte.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(2 + 2 * bytes.len());
    out.push_str("\\x");
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
96
97pub(crate) fn validate_sql(
98    sql: &str,
99    allowed_prefix: &str,
100    label: &str,
101) -> std::result::Result<(), crate::errors::MCPError> {
102    if sql.is_empty() {
103        return Err(crate::errors::MCPError::InvalidParams(
104            "'sql' parameter must not be empty".into(),
105        ));
106    }
107    if sql.len() > MAX_SQL_LEN {
108        return Err(crate::errors::MCPError::InvalidParams(format!(
109            "SQL exceeds maximum length of {MAX_SQL_LEN} characters (got {})",
110            sql.len()
111        )));
112    }
113    let trimmed = sql.trim();
114    let first_word = trimmed.split_whitespace().next().unwrap_or("");
115    if !first_word.eq_ignore_ascii_case(allowed_prefix) {
116        return Err(crate::errors::MCPError::InvalidParams(format!(
117            "Invalid {label} query: expected '{allowed_prefix}'"
118        )));
119    }
120    // Reject multi-statement: find the first statement-terminating ';' that is
121    // not inside a string literal, quoted identifier, dollar-quoted string, or
122    // comment. A single trailing ';' is allowed.
123    let body = trimmed.strip_suffix(';').unwrap_or(trimmed);
124    if let Some(i) = first_unquoted_semicolon(body) {
125        let ctx_end = (i + 20).min(body.len());
126        let ctx = body.get(i..ctx_end).unwrap_or("");
127        return Err(crate::errors::MCPError::InvalidParams(format!(
128            "Multi-statement queries are not allowed: {label} contained ';' at position {i} (context: ...{ctx}...)"
129        )));
130    }
131    Ok(())
132}
133
/// Byte index of the first `;` in `sql` that lies outside any string literal,
/// quoted identifier, dollar-quoted string, or comment. Returns `None` if there
/// is no such terminator.
///
/// The scan follows PostgreSQL's lexical rules:
/// * `'…'` strings and `"…"` identifiers escape their quote by doubling it.
///   `E'…'` escape strings additionally honour backslash escapes (`\'`, `\\`).
///   Plain strings assume `standard_conforming_strings = on` (the PostgreSQL
///   default), i.e. a backslash is an ordinary character there.
/// * `--` comments run to the next `\n` or `\r`; `/* … */` comments nest.
/// * `$tag$ … $tag$` quotes need a tag that does not start with a digit (`$1`
///   is a positional parameter).
/// * Identifiers and keywords are consumed whole. So `a$b$` is one identifier
///   (not a dollar quote), and `somE'…'` is an identifier followed by a plain
///   string, whereas a lone `E'…'` is an escape string.
fn first_unquoted_semicolon(sql: &str) -> Option<usize> {
    let bytes = sql.as_bytes();
    let n = bytes.len();
    let mut i = 0;
    while i < n {
        let c = bytes[i];
        match c {
            b';' => return Some(i),
            b'\'' => i = skip_quoted(bytes, i + 1, b'\'', false),
            b'"' => i = skip_quoted(bytes, i + 1, b'"', false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: PostgreSQL ends it at either newline byte.
                i += 2;
                while i < n && !matches!(bytes[i], b'\n' | b'\r') {
                    i += 1;
                }
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i + 2),
            b'$' => i = skip_dollar_quote(sql, i),
            _ if is_ident_start(c) => {
                let start = i;
                i += 1;
                while i < n && is_ident_cont(bytes[i]) {
                    i += 1;
                }
                // A token consisting of just `E`/`e` directly followed by a
                // quote opens an escape string with backslash escapes.
                if i - start == 1 && matches!(c, b'E' | b'e') && bytes.get(i) == Some(&b'\'') {
                    i = skip_quoted(bytes, i + 1, b'\'', true);
                }
            }
            _ => i += 1,
        }
    }
    None
}

/// Bytes that may start an identifier/keyword: ASCII letters, `_`, and any
/// non-ASCII byte (PostgreSQL treats bytes 0x80-0xFF as letters).
fn is_ident_start(c: u8) -> bool {
    c.is_ascii_alphabetic() || c == b'_' || c >= 0x80
}

/// Bytes that may continue an identifier/keyword: start bytes, digits and `$`.
fn is_ident_cont(c: u8) -> bool {
    is_ident_start(c) || c.is_ascii_digit() || c == b'$'
}

/// Skip the body of a quoted token whose opening `quote` ends just before
/// `i`; returns the index just past the closing quote, or the input length
/// if unterminated. A doubled quote is an escaped quote; with
/// `backslash_escapes`, a backslash also escapes the following byte.
fn skip_quoted(bytes: &[u8], mut i: usize, quote: u8, backslash_escapes: bool) -> usize {
    while i < bytes.len() {
        let c = bytes[i];
        if backslash_escapes && c == b'\\' {
            i += 2;
        } else if c == quote {
            if bytes.get(i + 1) == Some(&quote) {
                i += 2;
            } else {
                return i + 1;
            }
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Skip a (possibly nested) block comment whose opening `/*` ends just before
/// `i`; returns the index just past the matching `*/`, or the input length.
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let mut depth = 1usize;
    while i < bytes.len() && depth > 0 {
        match (bytes[i], bytes.get(i + 1).copied()) {
            (b'/', Some(b'*')) => {
                depth += 1;
                i += 2;
            }
            (b'*', Some(b'/')) => {
                depth -= 1;
                i += 2;
            }
            _ => i += 1,
        }
    }
    i
}

/// Called with `start` on a `$`. If it opens a dollar quote `$tag$ … $tag$`,
/// returns the index just past the closing delimiter (or the input length if
/// unterminated). Otherwise returns `start + 1`, treating the `$` as an
/// ordinary byte.
fn skip_dollar_quote(sql: &str, start: usize) -> usize {
    let bytes = sql.as_bytes();
    let mut end = start + 1;
    // `$` followed by a digit is a positional parameter such as `$1`.
    if matches!(bytes.get(end), Some(d) if d.is_ascii_digit()) {
        return start + 1;
    }
    // Tag characters follow identifier rules, minus `$`.
    while end < bytes.len() && is_ident_cont(bytes[end]) && bytes[end] != b'$' {
        end += 1;
    }
    if bytes.get(end) != Some(&b'$') {
        return start + 1;
    }
    // Both ends are ASCII `$`, so these slices lie on char boundaries.
    let tag = &sql[start..=end];
    match sql[end + 1..].find(tag) {
        Some(off) => end + 1 + off + tag.len(),
        None => bytes.len(), // unterminated — consume the rest
    }
}
218
219/// 6. Execute query
220pub async fn execute_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
221    let sql = params
222        .as_ref()
223        .and_then(|p| p.get("sql"))
224        .and_then(|v| v.as_str())
225        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
226
227    validate_sql(sql, "SELECT", "SELECT")?;
228
229    let rows = client.query(sql, &[]).await?;
230
231    let results: Vec<Value> = rows
232        .iter()
233        .map(|row| {
234            let values: Vec<Value> = (0..row.len()).map(|i| decode_cell(row, i)).collect();
235            Value::Array(values)
236        })
237        .collect();
238
239    Ok(json!({ "rows": results }))
240}
241
242/// 7. Execute insert
243pub async fn execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
244    let sql = params
245        .as_ref()
246        .and_then(|p| p.get("sql"))
247        .and_then(|v| v.as_str())
248        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
249
250    validate_sql(sql, "INSERT", "INSERT")?;
251
252    let rows_affected = client.execute(sql, &[]).await?;
253
254    Ok(json!({ "rows_affected": rows_affected }))
255}
256
257/// 8. Execute update
258pub async fn execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
259    let sql = params
260        .as_ref()
261        .and_then(|p| p.get("sql"))
262        .and_then(|v| v.as_str())
263        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
264
265    validate_sql(sql, "UPDATE", "UPDATE")?;
266
267    let rows_affected = client.execute(sql, &[]).await?;
268
269    Ok(json!({ "rows_affected": rows_affected }))
270}
271
272/// 9. Execute delete
273pub async fn execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
274    let sql = params
275        .as_ref()
276        .and_then(|p| p.get("sql"))
277        .and_then(|v| v.as_str())
278        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
279
280    validate_sql(sql, "DELETE", "DELETE")?;
281
282    let rows_affected = client.execute(sql, &[]).await?;
283
284    Ok(json!({ "rows_affected": rows_affected }))
285}
286
287/// 10. Explain query
288///
289/// Supports EXPLAIN with optional ANALYZE, BUFFERS, and FORMAT options.
290/// Only SELECT queries can be explained.
291pub async fn explain_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
292    let sql = params
293        .as_ref()
294        .and_then(|p| p.get("sql"))
295        .and_then(|v| v.as_str())
296        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
297
298    validate_sql(sql, "SELECT", "SELECT")?;
299
300    let analyze = params
301        .as_ref()
302        .and_then(|p| p.get("analyze"))
303        .and_then(|v| v.as_bool())
304        .unwrap_or(false);
305
306    let buffers = params
307        .as_ref()
308        .and_then(|p| p.get("buffers"))
309        .and_then(|v| v.as_bool())
310        .unwrap_or(false);
311
312    let format = params
313        .as_ref()
314        .and_then(|p| p.get("format"))
315        .and_then(|v| v.as_str())
316        .unwrap_or("json");
317
318    if format.eq_ignore_ascii_case("xml") {
319        return Err(crate::errors::MCPError::InvalidParams(
320            "XML format is not supported — use TEXT, YAML, or JSON".into(),
321        ));
322    }
323
324    let mut explain_sql = String::with_capacity(sql.len() + 80);
325    explain_sql.push_str("EXPLAIN (FORMAT ");
326    explain_sql.push_str(&format.to_uppercase());
327    if analyze {
328        explain_sql.push_str(", ANALYZE");
329    }
330    if buffers {
331        explain_sql.push_str(", BUFFERS");
332    }
333    explain_sql.push_str(") ");
334    explain_sql.push_str(sql);
335
336    let rows = client.query(&explain_sql, &[]).await?;
337
338    if rows.is_empty() {
339        return Ok(json!({ "plan": null }));
340    }
341
342    if format.eq_ignore_ascii_case("json") {
343        let plan: serde_json::Value = rows[0].get(0);
344        Ok(json!({
345            "plan": plan,
346            "options": { "analyze": analyze, "buffers": buffers, "format": format }
347        }))
348    } else {
349        let mut plan = String::new();
350        for (i, row) in rows.iter().enumerate() {
351            if i > 0 {
352                plan.push('\n');
353            }
354            plan.push_str(&row.get::<_, String>(0));
355        }
356        Ok(json!({
357            "plan": plan,
358            "options": { "analyze": analyze, "buffers": buffers, "format": format }
359        }))
360    }
361}
362
363/// 26. Async execute insert (with synchronous_commit=off for high-volume operations)
364///
365/// High-performance insert for WHERE predicate affecting more than 100 rows.
366/// Disables synchronous_commit temporarily for maximum throughput.
367/// Significant performance benefit when WHERE condition matches > 100 rows.
368/// Returns rows affected count.
369pub async fn async_execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
370    let sql = params
371        .as_ref()
372        .and_then(|p| p.get("sql"))
373        .and_then(|v| v.as_str())
374        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
375
376    validate_sql(sql, "INSERT", "INSERT")?;
377
378    async_sync_commit_execute(client, sql).await
379}
380
381/// 27. Async execute update (with synchronous_commit=off for high-volume operations)
382///
383/// High-performance update for WHERE predicate affecting more than 100 rows.
384/// Disables synchronous_commit temporarily for maximum throughput.
385/// Significant performance benefit when WHERE condition matches > 100 rows.
386/// Always include WHERE clause to prevent accidental updates.
387/// Returns rows affected count.
388pub async fn async_execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
389    let sql = params
390        .as_ref()
391        .and_then(|p| p.get("sql"))
392        .and_then(|v| v.as_str())
393        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
394
395    validate_sql(sql, "UPDATE", "UPDATE")?;
396
397    async_sync_commit_execute(client, sql).await
398}
399
400/// 28. Async execute delete (with synchronous_commit=off for high-volume operations)
401///
402/// High-performance delete for WHERE predicate affecting more than 100 rows.
403/// Disables synchronous_commit temporarily for maximum throughput.
404/// Significant performance benefit when WHERE condition matches > 100 rows.
405/// Always include WHERE clause - deleting without one removes all rows.
406/// Returns rows affected count.
407pub async fn async_execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
408    let sql = params
409        .as_ref()
410        .and_then(|p| p.get("sql"))
411        .and_then(|v| v.as_str())
412        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
413
414    validate_sql(sql, "DELETE", "DELETE")?;
415
416    async_sync_commit_execute(client, sql).await
417}
418
419/// Execute a DML statement inside a transaction with SET LOCAL synchronous_commit = OFF.
420/// The SET LOCAL is scoped to the transaction, so it auto-resets on COMMIT/ROLLBACK,
421/// preventing session-state leakage when the connection returns to the pool.
422async fn async_sync_commit_execute(client: &Client, sql: &str) -> MCPResult<Value> {
423    client.execute("BEGIN", &[]).await?;
424    client
425        .execute("SET LOCAL synchronous_commit = OFF", &[])
426        .await?;
427    match client.execute(sql, &[]).await {
428        Ok(rows_affected) => {
429            client.execute("COMMIT", &[]).await?;
430            Ok(json!({ "rows_affected": rows_affected }))
431        }
432        Err(e) => {
433            client.execute("ROLLBACK", &[]).await.ok();
434            Err(crate::errors::MCPError::DatabaseError(e))
435        }
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unquoted_semicolon_detected() {
        let stacked = "SELECT 1; DROP TABLE x";
        assert_eq!(first_unquoted_semicolon(stacked), Some(8));
    }

    #[test]
    fn test_semicolon_in_string_ignored() {
        for sql in ["SELECT ';not a stmt'", "SELECT 'a''b; c'"] {
            assert_eq!(first_unquoted_semicolon(sql), None);
        }
    }

    #[test]
    fn test_semicolon_in_identifier_ignored() {
        let sql = "SELECT \"weird;col\" FROM t";
        assert_eq!(first_unquoted_semicolon(sql), None);
    }

    #[test]
    fn test_semicolon_in_comments_ignored() {
        for sql in ["SELECT 1 -- a; b\n", "SELECT 1 /* a; b */"] {
            assert_eq!(first_unquoted_semicolon(sql), None);
        }
    }

    #[test]
    fn test_semicolon_in_dollar_quote_ignored() {
        for sql in ["SELECT $$a; b$$", "SELECT $tag$a; b$tag$"] {
            assert_eq!(first_unquoted_semicolon(sql), None);
        }
    }

    #[test]
    fn test_validate_sql_allows_trailing_semicolon() {
        for sql in ["SELECT 1;", "SELECT ';'"] {
            assert!(validate_sql(sql, "SELECT", "SELECT").is_ok());
        }
    }

    #[test]
    fn test_validate_sql_rejects_stacked() {
        let outcome = validate_sql("SELECT 1; DROP TABLE x", "SELECT", "SELECT");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_sql_prefix() {
        assert!(validate_sql("DELETE FROM x WHERE id=1", "DELETE", "DELETE").is_ok());
        assert!(validate_sql("SELECT 1", "DELETE", "DELETE").is_err());
    }

    #[test]
    fn test_to_hex() {
        let cases: [(&[u8], &str); 3] = [
            (&[0xde, 0xad, 0xbe, 0xef], "\\xdeadbeef"),
            (&[], "\\x"),
            (&[0x00, 0x0f], "\\x000f"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_hex(input), expected);
        }
    }
}