Skip to main content

mcp_postgres/actions/
query.rs

1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::Result as MCPResult;
4
5const MAX_SQL_LEN: usize = 10_000;
6
7fn validate_sql(sql: &str, allowed_prefix: &str, label: &str) -> std::result::Result<(), crate::errors::MCPError> {
8    if sql.is_empty() {
9        return Err(crate::errors::MCPError::InvalidParams("'sql' parameter must not be empty".into()));
10    }
11    if sql.len() > MAX_SQL_LEN {
12        return Err(crate::errors::MCPError::InvalidParams(
13            format!("SQL exceeds maximum length of {MAX_SQL_LEN} characters (got {})", sql.len())
14        ));
15    }
16    let trimmed = sql.trim();
17    let first_word = trimmed.split_whitespace().next().unwrap_or("").to_uppercase();
18    if first_word != allowed_prefix {
19        return Err(crate::errors::MCPError::InvalidParams(
20            format!("Invalid {label} query: expected '{allowed_prefix}'")
21        ));
22    }
23    // Reject multi-statement: find the first unquoted ';' that is not trailing
24    let body = trimmed.strip_suffix(';').unwrap_or(trimmed);
25    let mut in_string = false;
26    for (i, ch) in body.char_indices() {
27        if ch == '\'' {
28            in_string = !in_string;
29        }
30        if !in_string && ch == ';' {
31            let ctx_end = (i + 20).min(sql.len());
32            return Err(crate::errors::MCPError::InvalidParams(
33                format!("Multi-statement queries are not allowed: {label} contained ';' at position {i} (context: ...{}...)", &sql[i..ctx_end])
34            ));
35        }
36    }
37    Ok(())
38}
39
40/// 6. Execute query
41pub async fn execute_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
42    let sql = params
43        .as_ref()
44        .and_then(|p| p.get("sql"))
45        .and_then(|v| v.as_str())
46        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
47
48    validate_sql(sql, "SELECT", "SELECT")?;
49
50    let rows = client.query(sql, &[]).await?;
51
52    let results: Vec<Value> = rows
53        .iter()
54        .map(|row| {
55            let values: Vec<Value> = (0..row.len())
56                .map(|i| {
57                    // Try type inference: prefer native JSON types over raw strings
58                    if let Ok(v) = row.try_get::<_, bool>(i) {
59                        json!(v)
60                    } else if let Ok(v) = row.try_get::<_, i32>(i) {
61                        json!(v)
62                    } else if let Ok(v) = row.try_get::<_, i64>(i) {
63                        json!(v)
64                    } else if let Ok(v) = row.try_get::<_, f32>(i) {
65                        json!(v)
66                    } else if let Ok(v) = row.try_get::<_, f64>(i) {
67                        json!(v)
68                    } else if let Ok(v) = row.try_get::<_, String>(i) {
69                        Value::String(v)
70                    } else if let Ok(v) = row.try_get::<_, Option<String>>(i) {
71                        v.map(Value::String).unwrap_or(Value::Null)
72                    } else {
73                        Value::Null
74                    }
75                })
76                .collect();
77            Value::Array(values)
78        })
79        .collect();
80
81    Ok(json!({ "rows": results }))
82}
83
84/// 7. Execute insert
85pub async fn execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
86    let sql = params
87        .as_ref()
88        .and_then(|p| p.get("sql"))
89        .and_then(|v| v.as_str())
90        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
91
92    validate_sql(sql, "INSERT", "INSERT")?;
93
94    let rows_affected = client.execute(sql, &[]).await?;
95
96    Ok(json!({ "rows_affected": rows_affected }))
97}
98
99/// 8. Execute update
100pub async fn execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
101    let sql = params
102        .as_ref()
103        .and_then(|p| p.get("sql"))
104        .and_then(|v| v.as_str())
105        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
106
107    validate_sql(sql, "UPDATE", "UPDATE")?;
108
109    let rows_affected = client.execute(sql, &[]).await?;
110
111    Ok(json!({ "rows_affected": rows_affected }))
112}
113
114/// 9. Execute delete
115pub async fn execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
116    let sql = params
117        .as_ref()
118        .and_then(|p| p.get("sql"))
119        .and_then(|v| v.as_str())
120        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
121
122    validate_sql(sql, "DELETE", "DELETE")?;
123
124    let rows_affected = client.execute(sql, &[]).await?;
125
126    Ok(json!({ "rows_affected": rows_affected }))
127}
128
129/// 10. Explain query
130///
131/// Supports EXPLAIN with optional ANALYZE, BUFFERS, and FORMAT options.
132/// Only SELECT queries can be explained.
133pub async fn explain_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
134    let sql = params
135        .as_ref()
136        .and_then(|p| p.get("sql"))
137        .and_then(|v| v.as_str())
138        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
139
140    validate_sql(sql, "SELECT", "SELECT")?;
141
142    let analyze = params
143        .as_ref()
144        .and_then(|p| p.get("analyze"))
145        .and_then(|v| v.as_bool())
146        .unwrap_or(false);
147
148    let buffers = params
149        .as_ref()
150        .and_then(|p| p.get("buffers"))
151        .and_then(|v| v.as_bool())
152        .unwrap_or(false);
153
154    let format = params
155        .as_ref()
156        .and_then(|p| p.get("format"))
157        .and_then(|v| v.as_str())
158        .unwrap_or("json");
159
160    let format_upper = format.to_uppercase();
161    if format_upper == "XML" {
162        return Err(crate::errors::MCPError::InvalidParams(
163            "XML format is not supported — use TEXT, YAML, or JSON".into()
164        ));
165    }
166
167    let mut opts = Vec::new();
168    opts.push(format!("FORMAT {}", format_upper));
169    if analyze {
170        opts.push("ANALYZE".to_string());
171    }
172    if buffers {
173        opts.push("BUFFERS".to_string());
174    }
175
176    let mut explain_sql = String::with_capacity(sql.len() + 64);
177    explain_sql.push_str("EXPLAIN (");
178    explain_sql.push_str(&opts.join(", "));
179    explain_sql.push_str(") ");
180    explain_sql.push_str(sql);
181
182    let rows = client.query(&explain_sql, &[]).await?;
183
184    if rows.is_empty() {
185        return Ok(json!({ "plan": null }));
186    }
187
188    if format.eq_ignore_ascii_case("json") {
189        let plan: serde_json::Value = rows[0].get(0);
190        Ok(json!({
191            "plan": plan,
192            "options": { "analyze": analyze, "buffers": buffers, "format": format }
193        }))
194    } else {
195        let lines: Vec<String> = rows.iter().map(|r| r.get::<_, String>(0)).collect();
196        Ok(json!({
197            "plan": lines.join("\n"),
198            "options": { "analyze": analyze, "buffers": buffers, "format": format }
199        }))
200    }
201}
202
203/// 26. Async execute insert (with synchronous_commit=off for high-volume operations)
204///
205/// High-performance insert for WHERE predicate affecting more than 100 rows.
206/// Disables synchronous_commit temporarily for maximum throughput.
207/// Significant performance benefit when WHERE condition matches > 100 rows.
208/// Returns rows affected count.
209pub async fn async_execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
210    let sql = params
211        .as_ref()
212        .and_then(|p| p.get("sql"))
213        .and_then(|v| v.as_str())
214        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
215
216    validate_sql(sql, "INSERT", "INSERT")?;
217
218    async_sync_commit_execute(client, sql).await
219}
220
221/// 27. Async execute update (with synchronous_commit=off for high-volume operations)
222///
223/// High-performance update for WHERE predicate affecting more than 100 rows.
224/// Disables synchronous_commit temporarily for maximum throughput.
225/// Significant performance benefit when WHERE condition matches > 100 rows.
226/// Always include WHERE clause to prevent accidental updates.
227/// Returns rows affected count.
228pub async fn async_execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
229    let sql = params
230        .as_ref()
231        .and_then(|p| p.get("sql"))
232        .and_then(|v| v.as_str())
233        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
234
235    validate_sql(sql, "UPDATE", "UPDATE")?;
236
237    async_sync_commit_execute(client, sql).await
238}
239
240/// 28. Async execute delete (with synchronous_commit=off for high-volume operations)
241///
242/// High-performance delete for WHERE predicate affecting more than 100 rows.
243/// Disables synchronous_commit temporarily for maximum throughput.
244/// Significant performance benefit when WHERE condition matches > 100 rows.
245/// Always include WHERE clause - deleting without one removes all rows.
246/// Returns rows affected count.
247pub async fn async_execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
248    let sql = params
249        .as_ref()
250        .and_then(|p| p.get("sql"))
251        .and_then(|v| v.as_str())
252        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
253
254    validate_sql(sql, "DELETE", "DELETE")?;
255
256    async_sync_commit_execute(client, sql).await
257}
258
259/// Execute a DML statement inside a transaction with SET LOCAL synchronous_commit = OFF.
260/// The SET LOCAL is scoped to the transaction, so it auto-resets on COMMIT/ROLLBACK,
261/// preventing session-state leakage when the connection returns to the pool.
262async fn async_sync_commit_execute(client: &Client, sql: &str) -> MCPResult<Value> {
263    client.execute("BEGIN", &[]).await?;
264    client.execute("SET LOCAL synchronous_commit = OFF", &[]).await?;
265    match client.execute(sql, &[]).await {
266        Ok(rows_affected) => {
267            client.execute("COMMIT", &[]).await?;
268            Ok(json!({ "rows_affected": rows_affected }))
269        }
270        Err(e) => {
271            client.execute("ROLLBACK", &[]).await.ok();
272            Err(crate::errors::MCPError::DatabaseError(e))
273        }
274    }
275}