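//! SQL statement actions for the `mcp_postgres` server (`mcp_postgres/actions/query.rs`):
//! SELECT queries, INSERT/UPDATE/DELETE, EXPLAIN, and `synchronous_commit = off` DML variants.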

use serde_json::{json, Value};
use tokio_postgres::Client;

use crate::errors::Result as MCPResult;

5const MAX_SQL_LEN: usize = 10_000;
6
7fn validate_sql(sql: &str, allowed_prefix: &str, label: &str) -> std::result::Result<(), crate::errors::MCPError> {
8    if sql.is_empty() {
9        return Err(crate::errors::MCPError::InvalidParams("'sql' parameter must not be empty".into()));
10    }
11    if sql.len() > MAX_SQL_LEN {
12        return Err(crate::errors::MCPError::InvalidParams(
13            format!("SQL exceeds maximum length of {MAX_SQL_LEN} characters (got {})", sql.len())
14        ));
15    }
16    let trimmed = sql.trim();
17    let first_word = trimmed.split_whitespace().next().unwrap_or("");
18    if !first_word.eq_ignore_ascii_case(allowed_prefix) {
19        return Err(crate::errors::MCPError::InvalidParams(
20            format!("Invalid {label} query: expected '{allowed_prefix}'")
21        ));
22    }
23    // Reject multi-statement: find the first unquoted ';' that is not trailing
24    let body = trimmed.strip_suffix(';').unwrap_or(trimmed);
25    let mut in_string = false;
26    for (i, ch) in body.char_indices() {
27        if ch == '\'' {
28            in_string = !in_string;
29        }
30        if !in_string && ch == ';' {
31            let ctx_end = (i + 20).min(sql.len());
32            return Err(crate::errors::MCPError::InvalidParams(
33                format!("Multi-statement queries are not allowed: {label} contained ';' at position {i} (context: ...{}...)", &sql[i..ctx_end])
34            ));
35        }
36    }
37    Ok(())
38}
39
40/// 6. Execute query
41pub async fn execute_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
42    let sql = params
43        .as_ref()
44        .and_then(|p| p.get("sql"))
45        .and_then(|v| v.as_str())
46        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
47
48    validate_sql(sql, "SELECT", "SELECT")?;
49
50    let rows = client.query(sql, &[]).await?;
51
52    let results: Vec<Value> = rows
53        .iter()
54        .map(|row| {
55            let values: Vec<Value> = (0..row.len())
56                .map(|i| {
57                    // Try type inference: prefer native JSON types over raw strings
58                    row.try_get::<_, bool>(i).map(|v| json!(v))
59                        .or_else(|_| row.try_get::<_, i32>(i).map(|v| json!(v)))
60                        .or_else(|_| row.try_get::<_, i64>(i).map(|v| json!(v)))
61                        .or_else(|_| row.try_get::<_, f32>(i).map(|v| json!(v)))
62                        .or_else(|_| row.try_get::<_, f64>(i).map(|v| json!(v)))
63                        .or_else(|_| row.try_get::<_, String>(i).map(Value::String))
64                        .or_else(|_| row.try_get::<_, Option<String>>(i).map(|v| v.map(Value::String).unwrap_or(Value::Null)))
65                        .unwrap_or(Value::Null)
66                })
67                .collect();
68            Value::Array(values)
69        })
70        .collect();
71
72    Ok(json!({ "rows": results }))
73}
74
75/// 7. Execute insert
76pub async fn execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
77    let sql = params
78        .as_ref()
79        .and_then(|p| p.get("sql"))
80        .and_then(|v| v.as_str())
81        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
82
83    validate_sql(sql, "INSERT", "INSERT")?;
84
85    let rows_affected = client.execute(sql, &[]).await?;
86
87    Ok(json!({ "rows_affected": rows_affected }))
88}
89
90/// 8. Execute update
91pub async fn execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
92    let sql = params
93        .as_ref()
94        .and_then(|p| p.get("sql"))
95        .and_then(|v| v.as_str())
96        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
97
98    validate_sql(sql, "UPDATE", "UPDATE")?;
99
100    let rows_affected = client.execute(sql, &[]).await?;
101
102    Ok(json!({ "rows_affected": rows_affected }))
103}
104
105/// 9. Execute delete
106pub async fn execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
107    let sql = params
108        .as_ref()
109        .and_then(|p| p.get("sql"))
110        .and_then(|v| v.as_str())
111        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
112
113    validate_sql(sql, "DELETE", "DELETE")?;
114
115    let rows_affected = client.execute(sql, &[]).await?;
116
117    Ok(json!({ "rows_affected": rows_affected }))
118}
119
120/// 10. Explain query
121///
122/// Supports EXPLAIN with optional ANALYZE, BUFFERS, and FORMAT options.
123/// Only SELECT queries can be explained.
124pub async fn explain_query(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
125    let sql = params
126        .as_ref()
127        .and_then(|p| p.get("sql"))
128        .and_then(|v| v.as_str())
129        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
130
131    validate_sql(sql, "SELECT", "SELECT")?;
132
133    let analyze = params
134        .as_ref()
135        .and_then(|p| p.get("analyze"))
136        .and_then(|v| v.as_bool())
137        .unwrap_or(false);
138
139    let buffers = params
140        .as_ref()
141        .and_then(|p| p.get("buffers"))
142        .and_then(|v| v.as_bool())
143        .unwrap_or(false);
144
145    let format = params
146        .as_ref()
147        .and_then(|p| p.get("format"))
148        .and_then(|v| v.as_str())
149        .unwrap_or("json");
150
151    if format.eq_ignore_ascii_case("xml") {
152        return Err(crate::errors::MCPError::InvalidParams(
153            "XML format is not supported — use TEXT, YAML, or JSON".into()
154        ));
155    }
156
157    let mut explain_sql = String::with_capacity(sql.len() + 80);
158    explain_sql.push_str("EXPLAIN (FORMAT ");
159    explain_sql.push_str(&format.to_uppercase());
160    if analyze {
161        explain_sql.push_str(", ANALYZE");
162    }
163    if buffers {
164        explain_sql.push_str(", BUFFERS");
165    }
166    explain_sql.push_str(") ");
167    explain_sql.push_str(sql);
168
169    let rows = client.query(&explain_sql, &[]).await?;
170
171    if rows.is_empty() {
172        return Ok(json!({ "plan": null }));
173    }
174
175    if format.eq_ignore_ascii_case("json") {
176        let plan: serde_json::Value = rows[0].get(0);
177        Ok(json!({
178            "plan": plan,
179            "options": { "analyze": analyze, "buffers": buffers, "format": format }
180        }))
181    } else {
182        let mut plan = String::new();
183        for (i, row) in rows.iter().enumerate() {
184            if i > 0 { plan.push('\n'); }
185            plan.push_str(&row.get::<_, String>(0));
186        }
187        Ok(json!({
188            "plan": plan,
189            "options": { "analyze": analyze, "buffers": buffers, "format": format }
190        }))
191    }
192}
193
194/// 26. Async execute insert (with synchronous_commit=off for high-volume operations)
195///
196/// High-performance insert for WHERE predicate affecting more than 100 rows.
197/// Disables synchronous_commit temporarily for maximum throughput.
198/// Significant performance benefit when WHERE condition matches > 100 rows.
199/// Returns rows affected count.
200pub async fn async_execute_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
201    let sql = params
202        .as_ref()
203        .and_then(|p| p.get("sql"))
204        .and_then(|v| v.as_str())
205        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
206
207    validate_sql(sql, "INSERT", "INSERT")?;
208
209    async_sync_commit_execute(client, sql).await
210}
211
212/// 27. Async execute update (with synchronous_commit=off for high-volume operations)
213///
214/// High-performance update for WHERE predicate affecting more than 100 rows.
215/// Disables synchronous_commit temporarily for maximum throughput.
216/// Significant performance benefit when WHERE condition matches > 100 rows.
217/// Always include WHERE clause to prevent accidental updates.
218/// Returns rows affected count.
219pub async fn async_execute_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
220    let sql = params
221        .as_ref()
222        .and_then(|p| p.get("sql"))
223        .and_then(|v| v.as_str())
224        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
225
226    validate_sql(sql, "UPDATE", "UPDATE")?;
227
228    async_sync_commit_execute(client, sql).await
229}
230
231/// 28. Async execute delete (with synchronous_commit=off for high-volume operations)
232///
233/// High-performance delete for WHERE predicate affecting more than 100 rows.
234/// Disables synchronous_commit temporarily for maximum throughput.
235/// Significant performance benefit when WHERE condition matches > 100 rows.
236/// Always include WHERE clause - deleting without one removes all rows.
237/// Returns rows affected count.
238pub async fn async_execute_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
239    let sql = params
240        .as_ref()
241        .and_then(|p| p.get("sql"))
242        .and_then(|v| v.as_str())
243        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'sql' parameter".into()))?;
244
245    validate_sql(sql, "DELETE", "DELETE")?;
246
247    async_sync_commit_execute(client, sql).await
248}
249
250/// Execute a DML statement inside a transaction with SET LOCAL synchronous_commit = OFF.
251/// The SET LOCAL is scoped to the transaction, so it auto-resets on COMMIT/ROLLBACK,
252/// preventing session-state leakage when the connection returns to the pool.
253async fn async_sync_commit_execute(client: &Client, sql: &str) -> MCPResult<Value> {
254    client.execute("BEGIN", &[]).await?;
255    client.execute("SET LOCAL synchronous_commit = OFF", &[]).await?;
256    match client.execute(sql, &[]).await {
257        Ok(rows_affected) => {
258            client.execute("COMMIT", &[]).await?;
259            Ok(json!({ "rows_affected": rows_affected }))
260        }
261        Err(e) => {
262            client.execute("ROLLBACK", &[]).await.ok();
263            Err(crate::errors::MCPError::DatabaseError(e))
264        }
265    }
266}