//! Batch insert, update, and delete actions (mcp_postgres/actions/batch.rs).

1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::{MCPError, Result as MCPResult};
4use crate::validation::{validate_identifier, quote_identifier};
5
/// Largest number of rows a single batch request is allowed to carry.
const MAX_BATCH_ROWS: usize = 1_000;

/// Operators accepted in the `op` field of a structured `where_clauses` entry.
/// Anything outside this list is rejected before SQL is built.
const ALLOWED_OPS: &[&str] = &[
    "=", "<", ">", "<=", ">=", "<>",
    "IN", "LIKE",
];
8
9fn format_sql_value(val: &Value) -> String {
10    match val {
11        Value::String(s) => format!("'{}'", s.replace("'", "''")),
12        Value::Number(n) => n.to_string(),
13        Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14        Value::Null => "NULL".to_string(),
15        Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16    }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20    validate_identifier(table, "table")?;
21    for col in columns {
22        validate_identifier(col, "column")?;
23    }
24    Ok(())
25}
26
27fn validate_where_clauses(where_clauses: &[Value]) -> Result<Vec<(String, String, &Value)>, MCPError> {
28    if where_clauses.is_empty() {
29        return Err(MCPError::InvalidParams("'where_clauses' must not be empty".into()));
30    }
31    let mut parsed = Vec::new();
32    for clause in where_clauses {
33        let obj = clause.as_object().ok_or_else(|| {
34            MCPError::InvalidParams("Each where_clause must be an object with 'column', 'op', and 'value'".into())
35        })?;
36        let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
37            MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
38        })?;
39        let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
40            MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
41        })?;
42        let value = obj.get("value").ok_or_else(|| {
43            MCPError::InvalidParams("Each where_clause must have a 'value'".into())
44        })?;
45        validate_identifier(column, "where_clause.column")?;
46        if !ALLOWED_OPS.contains(&op) {
47            return Err(MCPError::InvalidParams(
48                format!("Invalid operator '{op}' — allowed: {}", ALLOWED_OPS.join(", "))
49            ));
50        }
51        parsed.push((column.to_string(), op.to_string(), value));
52    }
53    Ok(parsed)
54}
55
56fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
57    parsed.iter().map(|(col, op, val)| {
58        if op == "IN" {
59            if let Some(arr) = val.as_array() {
60                let items: Vec<String> = arr.iter().map(format_sql_value).collect();
61                format!("{} IN ({})", quote_identifier(col), items.join(", "))
62            } else {
63                format!("{} {} {}", quote_identifier(col), op, format_sql_value(val))
64            }
65        } else {
66            format!("{} {} {}", quote_identifier(col), op, format_sql_value(val))
67        }
68    }).collect::<Vec<_>>().join(" OR ")
69}
70
71/// Batch insert - high performance multi-row insertion
72/// Uses SET LOCAL inside a transaction to avoid session-level side effects.
73pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
74    let params = params.as_ref().ok_or_else(|| {
75        MCPError::InvalidParams("Missing parameters".into())
76    })?;
77
78    let table = params
79        .get("table")
80        .and_then(|v| v.as_str())
81        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
82
83    let columns = params
84        .get("columns")
85        .and_then(|v| v.as_array())
86        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
87
88    let rows = params
89        .get("rows")
90        .and_then(|v| v.as_array())
91        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
92
93    if rows.is_empty() {
94        return Ok(json!({ "rows_affected": 0 }));
95    }
96
97    if rows.len() > MAX_BATCH_ROWS {
98        return Err(MCPError::InvalidParams(
99            format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
100        ));
101    }
102
103    let returning = params.get("returning").and_then(|v| v.as_str());
104
105    let column_count = columns.len();
106    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
107
108    if column_names.len() != column_count {
109        return Err(MCPError::InvalidParams("All column names must be strings".into()));
110    }
111
112    validate_table_columns(table, &column_names)?;
113
114    let quoted_table = quote_identifier(table);
115    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_identifier(c)).collect();
116    let cols = quoted_cols.join(", ");
117
118    let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
119    use std::fmt::Write;
120    write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
121
122    for (i, row) in rows.iter().enumerate() {
123        let row_array = row.as_array().ok_or_else(|| {
124            MCPError::InvalidParams("Each row must be an array".into())
125        })?;
126
127        if row_array.len() != column_count {
128            return Err(MCPError::InvalidParams(
129                format!("Row {} has {} columns, expected {}", i, row_array.len(), column_count),
130            ));
131        }
132
133        if i > 0 {
134            sql.push(',');
135        }
136        sql.push('(');
137        for (j, val) in row_array.iter().enumerate() {
138            if j > 0 {
139                sql.push_str(", ");
140            }
141            match val {
142                Value::String(s) => {
143                    sql.push('\'');
144                    for ch in s.chars() {
145                        if ch == '\'' {
146                            sql.push_str("''");
147                        } else {
148                            sql.push(ch);
149                        }
150                    }
151                    sql.push('\'');
152                }
153                Value::Number(n) => {
154                    write!(sql, "{n}").unwrap();
155                }
156                Value::Bool(b) => {
157                    sql.push_str(if *b { "true" } else { "false" });
158                }
159                Value::Null => {
160                    sql.push_str("NULL");
161                }
162                Value::Array(_) | Value::Object(_) => {
163                    let s = val.to_string();
164                    sql.push('\'');
165                    for ch in s.chars() {
166                        if ch == '\'' {
167                            sql.push_str("''");
168                        } else {
169                            sql.push(ch);
170                        }
171                    }
172                    sql.push('\'');
173                }
174            }
175        }
176        sql.push(')');
177    }
178
179    client.execute("BEGIN", &[]).await?;
180    client.execute("SET LOCAL synchronous_commit = OFF", &[]).await?;
181
182    let result = if let Some(col) = returning {
183        validate_identifier(col, "returning")?;
184        let r = format!(" RETURNING {}", quote_identifier(col));
185        sql.push_str(&r);
186        match client.query(&sql, &[]).await {
187            Ok(rows) => {
188                client.execute("COMMIT", &[]).await?;
189                let ids: Vec<Value> = rows.iter().map(|r| {
190                    r.try_get::<_, i64>(0).map(|id| json!(id))
191                        .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
192                        .unwrap_or(json!(null))
193                }).collect();
194                json!({ "rows_affected": ids.len(), "inserted_ids": ids })
195            }
196            Err(e) => {
197                client.execute("ROLLBACK", &[]).await.ok();
198                return Err(MCPError::DatabaseError(e));
199            }
200        }
201    } else {
202        match client.execute(&sql, &[]).await {
203            Ok(rows_affected) => {
204                client.execute("COMMIT", &[]).await?;
205                json!({ "rows_affected": rows_affected })
206            }
207            Err(e) => {
208                client.execute("ROLLBACK", &[]).await.ok();
209                return Err(MCPError::DatabaseError(e));
210            }
211        }
212    };
213
214    Ok(result)
215}
216
217/// Batch update - bulk updates with structured WHERE conditions
218pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
219    let params = params.as_ref().ok_or_else(|| {
220        MCPError::InvalidParams("Missing parameters".into())
221    })?;
222
223    let table = params
224        .get("table")
225        .and_then(|v| v.as_str())
226        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
227
228    let updates = params
229        .get("updates")
230        .and_then(|v| v.as_object())
231        .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
232
233    let where_clauses = params
234        .get("where_clauses")
235        .and_then(|v| v.as_array())
236        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
237
238    validate_identifier(table, "table")?;
239    let parsed_where = validate_where_clauses(where_clauses)?;
240
241    let quoted_table = quote_identifier(table);
242    let mut set_clauses = Vec::new();
243    for (key, val) in updates {
244        validate_identifier(key, "updates key")?;
245        set_clauses.push(format!("{} = {}", quote_identifier(key), format_sql_value(val)));
246    }
247
248    let where_sql = build_where_sql(&parsed_where);
249    let sql = format!("UPDATE {quoted_table} SET {} WHERE {where_sql}", set_clauses.join(", "));
250
251    let rows_affected = client.execute(&sql, &[]).await?;
252
253    Ok(json!({ "rows_affected": rows_affected }))
254}
255
256/// Batch delete - bulk deletion with structured WHERE conditions
257pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
258    let params = params.as_ref().ok_or_else(|| {
259        MCPError::InvalidParams("Missing parameters".into())
260    })?;
261
262    let table = params
263        .get("table")
264        .and_then(|v| v.as_str())
265        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
266
267    let where_clauses = params
268        .get("where_clauses")
269        .and_then(|v| v.as_array())
270        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
271
272    validate_identifier(table, "table")?;
273    let parsed_where = validate_where_clauses(where_clauses)?;
274
275    let returning = params.get("returning").and_then(|v| v.as_str());
276
277    let quoted_table = quote_identifier(table);
278    let where_sql = build_where_sql(&parsed_where);
279    let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
280
281    if let Some(col) = returning {
282        validate_identifier(col, "returning")?;
283        sql.push_str(&format!(" RETURNING {}", quote_identifier(col)));
284        let rows = client.query(&sql, &[]).await?;
285        let ids: Vec<Value> = rows.iter().map(|r| {
286            r.try_get::<_, i64>(0).map(|id| json!(id))
287                .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
288                .unwrap_or(json!(null))
289        }).collect();
290        Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
291    } else {
292        let rows_affected = client.execute(&sql, &[]).await?;
293        Ok(json!({ "rows_affected": rows_affected }))
294    }
295}
296
297/// Batch insert with auto-batching for massive loads
298pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
299    let params = params.as_ref().ok_or_else(|| {
300        MCPError::InvalidParams("Missing parameters".into())
301    })?;
302
303    let table = params
304        .get("table")
305        .and_then(|v| v.as_str())
306        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
307
308    let columns = params
309        .get("columns")
310        .and_then(|v| v.as_array())
311        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
312
313    let rows = params
314        .get("rows")
315        .and_then(|v| v.as_array())
316        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
317
318    let batch_size = params
319        .get("batch_size")
320        .and_then(|v| v.as_u64())
321        .unwrap_or(1000) as usize;
322
323    if rows.is_empty() {
324        return Ok(json!({"rows_affected": 0}));
325    }
326
327    if rows.len() > MAX_BATCH_ROWS {
328        return Err(MCPError::InvalidParams(
329            format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
330        ));
331    }
332
333    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
334    validate_table_columns(table, &column_names)?;
335
336    let quoted_table = quote_identifier(table);
337    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_identifier(c)).collect();
338
339    let mut total_affected = 0u64;
340
341    for batch in rows.chunks(batch_size) {
342        let mut sql = format!("INSERT INTO {quoted_table} ({}) VALUES ", quoted_cols.join(", "));
343        let mut value_parts = Vec::new();
344
345        for row in batch {
346            let row_array = row.as_array().ok_or_else(|| {
347                MCPError::InvalidParams("Each row must be an array".into())
348            })?;
349
350            let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
351            value_parts.push(format!("({})", row_values.join(", ")));
352        }
353
354        sql.push_str(&value_parts.join(", "));
355
356        let rows_affected = client.execute(&sql, &[]).await?;
357        total_affected += rows_affected;
358    }
359
360    Ok(json!({
361        "rows_affected": total_affected,
362        "batches": (rows.len() as f64 / batch_size as f64).ceil() as u32
363    }))
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_sql_value() {
        assert_eq!(format_sql_value(&Value::String("test".into())), "'test'");
        assert_eq!(format_sql_value(&Value::Number(123.into())), "123");
        assert_eq!(format_sql_value(&Value::Bool(true)), "true");
        assert_eq!(format_sql_value(&Value::Null), "NULL");
    }

    /// Arrays and objects are stored as their JSON text, quoted like any other string.
    #[test]
    fn test_format_sql_value_serializes_json_containers() {
        assert_eq!(format_sql_value(&json!([1, 2])), "'[1,2]'");
        assert_eq!(format_sql_value(&json!({"a": "it's"})), r#"'{"a":"it''s"}'"#);
    }

    #[test]
    fn test_sql_injection_prevention() {
        let malicious = Value::String("'; DROP TABLE users; --".into());
        let result = format_sql_value(&malicious);
        assert_eq!(result, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let result = validate_table_columns("users; DROP TABLE", &["id"]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("invalid character"));
    }

    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        let result = validate_table_columns("users", &["id; DROP TABLE users"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_table_columns_accepts_valid() {
        assert!(validate_table_columns("users", &["id", "name"]).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let clauses = vec![
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_ok());
    }

    /// The empty-list guard is what keeps UPDATE/DELETE from running unconditioned.
    #[test]
    fn test_validate_where_clauses_rejects_empty() {
        let result = validate_where_clauses(&[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must not be empty"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_non_object_clause() {
        let clauses = vec![json!("id = 1")];
        assert!(validate_where_clauses(&clauses).is_err());
    }

    #[test]
    fn test_validate_where_clauses_requires_value() {
        let clauses = vec![json!({"column": "id", "op": "="})];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("'value'"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let clauses = vec![
            json!({"column": "id", "op": "EXECUTE", "value": "malicious"}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid operator"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let clauses = vec![
            json!({"column": "id; DROP TABLE", "op": "=", "value": 1}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
    }

    #[test]
    fn test_build_where_sql() {
        let v1 = Value::Number(1.into());
        let v2 = Value::String("active".into());
        let parsed = vec![
            ("id".to_string(), "=".to_string(), &v1),
            ("status".to_string(), "=".to_string(), &v2),
        ];
        let sql = build_where_sql(&parsed);
        assert_eq!(sql, r#""id" = 1 OR "status" = 'active'"#);
    }

    #[test]
    fn test_build_where_sql_in_op() {
        let values = json!(["a", "b"]);
        let parsed = vec![
            ("status".to_string(), "IN".to_string(), &values),
        ];
        let sql = build_where_sql(&parsed);
        assert_eq!(sql, r#""status" IN ('a', 'b')"#);
    }
}