Skip to main content

mcp_postgres/actions/
query.rs

1use crate::errors::Result as MCPResult;
2use serde_json::{Value, json};
3use tokio_postgres::Client;
4
5const MAX_SQL_LEN: usize = 10_000;
6
7fn validate_sql(
8    sql: &str,
9    allowed_prefix: &str,
10    label: &str,
11) -> std::result::Result<(), crate::errors::MCPError> {
12    if sql.is_empty() {
13        return Err(crate::errors::MCPError::InvalidParams(
14            "'sql' parameter must not be empty".into(),
15        ));
16    }
17    if sql.len() > MAX_SQL_LEN {
18        return Err(crate::errors::MCPError::InvalidParams(format!(
19            "SQL exceeds maximum length of {MAX_SQL_LEN} characters (got {})",
20            sql.len()
21        )));
22    }
23    let trimmed = sql.trim();
24    let first_word = trimmed.split_whitespace().next().unwrap_or("");
25    if !first_word.eq_ignore_ascii_case(allowed_prefix) {
26        return Err(crate::errors::MCPError::InvalidParams(format!(
27            "Invalid {label} query: expected '{allowed_prefix}'"
28        )));
29    }
30    // Reject multi-statement: find the first unquoted ';' that is not trailing
31    let body = trimmed.strip_suffix(';').unwrap_or(trimmed);
32    let mut in_string = false;
33    for (i, ch) in body.char_indices() {
34        if ch == '\'' {
35            in_string = !in_string;
36        }
37        if !in_string && ch == ';' {
38            let ctx_end = (i + 20).min(sql.len());
39            return Err(crate::errors::MCPError::InvalidParams(format!(
40                "Multi-statement queries are not allowed: {label} contained ';' at position {i} (context: ...{}...)",
41                &sql[i..ctx_end]
42            )));
43        }
44    }
45    Ok(())
46}
47
48/// 6. Execute query
49pub async fn execute_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
50    let sql = params
51        .as_ref()
52        .and_then(|p| p.get("sql"))
53        .and_then(|v| v.as_str())
54        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
55
56    validate_sql(sql, "SELECT", "SELECT")?;
57
58    let rows = client.query(sql, &[]).await?;
59
60    let results: Vec<Value> = rows
61        .iter()
62        .map(|row| {
63            let values: Vec<Value> = (0..row.len())
64                .map(|i| {
65                    // Try type inference: prefer native JSON types over raw strings
66                    row.try_get::<_, bool>(i)
67                        .map(|v| json!(v))
68                        .or_else(|_| row.try_get::<_, i32>(i).map(|v| json!(v)))
69                        .or_else(|_| row.try_get::<_, i64>(i).map(|v| json!(v)))
70                        .or_else(|_| row.try_get::<_, f32>(i).map(|v| json!(v)))
71                        .or_else(|_| row.try_get::<_, f64>(i).map(|v| json!(v)))
72                        .or_else(|_| row.try_get::<_, String>(i).map(Value::String))
73                        .or_else(|_| {
74                            row.try_get::<_, Option<String>>(i)
75                                .map(|v| v.map(Value::String).unwrap_or(Value::Null))
76                        })
77                        .unwrap_or(Value::Null)
78                })
79                .collect();
80            Value::Array(values)
81        })
82        .collect();
83
84    Ok(json!({ "rows": results }))
85}
86
87/// 7. Execute insert
88pub async fn execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
89    let sql = params
90        .as_ref()
91        .and_then(|p| p.get("sql"))
92        .and_then(|v| v.as_str())
93        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
94
95    validate_sql(sql, "INSERT", "INSERT")?;
96
97    let rows_affected = client.execute(sql, &[]).await?;
98
99    Ok(json!({ "rows_affected": rows_affected }))
100}
101
102/// 8. Execute update
103pub async fn execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
104    let sql = params
105        .as_ref()
106        .and_then(|p| p.get("sql"))
107        .and_then(|v| v.as_str())
108        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
109
110    validate_sql(sql, "UPDATE", "UPDATE")?;
111
112    let rows_affected = client.execute(sql, &[]).await?;
113
114    Ok(json!({ "rows_affected": rows_affected }))
115}
116
117/// 9. Execute delete
118pub async fn execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
119    let sql = params
120        .as_ref()
121        .and_then(|p| p.get("sql"))
122        .and_then(|v| v.as_str())
123        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
124
125    validate_sql(sql, "DELETE", "DELETE")?;
126
127    let rows_affected = client.execute(sql, &[]).await?;
128
129    Ok(json!({ "rows_affected": rows_affected }))
130}
131
132/// 10. Explain query
133///
134/// Supports EXPLAIN with optional ANALYZE, BUFFERS, and FORMAT options.
135/// Only SELECT queries can be explained.
136pub async fn explain_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
137    let sql = params
138        .as_ref()
139        .and_then(|p| p.get("sql"))
140        .and_then(|v| v.as_str())
141        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
142
143    validate_sql(sql, "SELECT", "SELECT")?;
144
145    let analyze = params
146        .as_ref()
147        .and_then(|p| p.get("analyze"))
148        .and_then(|v| v.as_bool())
149        .unwrap_or(false);
150
151    let buffers = params
152        .as_ref()
153        .and_then(|p| p.get("buffers"))
154        .and_then(|v| v.as_bool())
155        .unwrap_or(false);
156
157    let format = params
158        .as_ref()
159        .and_then(|p| p.get("format"))
160        .and_then(|v| v.as_str())
161        .unwrap_or("json");
162
163    if format.eq_ignore_ascii_case("xml") {
164        return Err(crate::errors::MCPError::InvalidParams(
165            "XML format is not supported — use TEXT, YAML, or JSON".into(),
166        ));
167    }
168
169    let mut explain_sql = String::with_capacity(sql.len() + 80);
170    explain_sql.push_str("EXPLAIN (FORMAT ");
171    explain_sql.push_str(&format.to_uppercase());
172    if analyze {
173        explain_sql.push_str(", ANALYZE");
174    }
175    if buffers {
176        explain_sql.push_str(", BUFFERS");
177    }
178    explain_sql.push_str(") ");
179    explain_sql.push_str(sql);
180
181    let rows = client.query(&explain_sql, &[]).await?;
182
183    if rows.is_empty() {
184        return Ok(json!({ "plan": null }));
185    }
186
187    if format.eq_ignore_ascii_case("json") {
188        let plan: serde_json::Value = rows[0].get(0);
189        Ok(json!({
190            "plan": plan,
191            "options": { "analyze": analyze, "buffers": buffers, "format": format }
192        }))
193    } else {
194        let mut plan = String::new();
195        for (i, row) in rows.iter().enumerate() {
196            if i > 0 {
197                plan.push('\n');
198            }
199            plan.push_str(&row.get::<_, String>(0));
200        }
201        Ok(json!({
202            "plan": plan,
203            "options": { "analyze": analyze, "buffers": buffers, "format": format }
204        }))
205    }
206}
207
208/// 26. Async execute insert (with synchronous_commit=off for high-volume operations)
209///
210/// High-performance insert for WHERE predicate affecting more than 100 rows.
211/// Disables synchronous_commit temporarily for maximum throughput.
212/// Significant performance benefit when WHERE condition matches > 100 rows.
213/// Returns rows affected count.
214pub async fn async_execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
215    let sql = params
216        .as_ref()
217        .and_then(|p| p.get("sql"))
218        .and_then(|v| v.as_str())
219        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
220
221    validate_sql(sql, "INSERT", "INSERT")?;
222
223    async_sync_commit_execute(client, sql).await
224}
225
226/// 27. Async execute update (with synchronous_commit=off for high-volume operations)
227///
228/// High-performance update for WHERE predicate affecting more than 100 rows.
229/// Disables synchronous_commit temporarily for maximum throughput.
230/// Significant performance benefit when WHERE condition matches > 100 rows.
231/// Always include WHERE clause to prevent accidental updates.
232/// Returns rows affected count.
233pub async fn async_execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
234    let sql = params
235        .as_ref()
236        .and_then(|p| p.get("sql"))
237        .and_then(|v| v.as_str())
238        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
239
240    validate_sql(sql, "UPDATE", "UPDATE")?;
241
242    async_sync_commit_execute(client, sql).await
243}
244
245/// 28. Async execute delete (with synchronous_commit=off for high-volume operations)
246///
247/// High-performance delete for WHERE predicate affecting more than 100 rows.
248/// Disables synchronous_commit temporarily for maximum throughput.
249/// Significant performance benefit when WHERE condition matches > 100 rows.
250/// Always include WHERE clause - deleting without one removes all rows.
251/// Returns rows affected count.
252pub async fn async_execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
253    let sql = params
254        .as_ref()
255        .and_then(|p| p.get("sql"))
256        .and_then(|v| v.as_str())
257        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
258
259    validate_sql(sql, "DELETE", "DELETE")?;
260
261    async_sync_commit_execute(client, sql).await
262}
263
264/// Execute a DML statement inside a transaction with SET LOCAL synchronous_commit = OFF.
265/// The SET LOCAL is scoped to the transaction, so it auto-resets on COMMIT/ROLLBACK,
266/// preventing session-state leakage when the connection returns to the pool.
267async fn async_sync_commit_execute(client: &Client, sql: &str) -> MCPResult<Value> {
268    client.execute("BEGIN", &[]).await?;
269    client
270        .execute("SET LOCAL synchronous_commit = OFF", &[])
271        .await?;
272    match client.execute(sql, &[]).await {
273        Ok(rows_affected) => {
274            client.execute("COMMIT", &[]).await?;
275            Ok(json!({ "rows_affected": rows_affected }))
276        }
277        Err(e) => {
278            client.execute("ROLLBACK", &[]).await.ok();
279            Err(crate::errors::MCPError::DatabaseError(e))
280        }
281    }
282}