// mcp_postgres/actions/batch.rs

1use crate::errors::{MCPError, Result as MCPResult};
2use crate::validation::{quote_ident, validate_identifier};
3use serde_json::{Value, json};
4use tokio_postgres::Client;
5
/// Upper bound on how many rows a single batch request may carry.
const MAX_BATCH_ROWS: usize = 1000;

/// Operators accepted in a structured `where_clauses` entry.
/// Matching against this list is exact and case-sensitive.
const ALLOWED_OPS: &[&str] = &[
    "=", "<", ">", "<=", ">=", "<>", // scalar comparisons
    "IN",   // list membership
    "LIKE", // pattern match
];

9fn format_sql_value(val: &Value) -> String {
10    match val {
11        Value::String(s) => format!("'{}'", s.replace("'", "''")),
12        Value::Number(n) => n.to_string(),
13        Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14        Value::Null => "NULL".to_string(),
15        Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16    }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20    validate_identifier(table, "table")?;
21    for col in columns {
22        validate_identifier(col, "column")?;
23    }
24    Ok(())
25}
26
27fn validate_where_clauses(
28    where_clauses: &[Value],
29) -> Result<Vec<(String, String, &Value)>, MCPError> {
30    if where_clauses.is_empty() {
31        return Err(MCPError::InvalidParams(
32            "'where_clauses' must not be empty".into(),
33        ));
34    }
35    let mut parsed = Vec::new();
36    for clause in where_clauses {
37        let obj = clause.as_object().ok_or_else(|| {
38            MCPError::InvalidParams(
39                "Each where_clause must be an object with 'column', 'op', and 'value'".into(),
40            )
41        })?;
42        let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
43            MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
44        })?;
45        let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
46            MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
47        })?;
48        let value = obj.get("value").ok_or_else(|| {
49            MCPError::InvalidParams("Each where_clause must have a 'value'".into())
50        })?;
51        validate_identifier(column, "where_clause.column")?;
52        if !ALLOWED_OPS.contains(&op) {
53            return Err(MCPError::InvalidParams(format!(
54                "Invalid operator '{op}' — allowed: {}",
55                ALLOWED_OPS.join(", ")
56            )));
57        }
58        parsed.push((column.to_string(), op.to_string(), value));
59    }
60    Ok(parsed)
61}
62
63fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
64    parsed
65        .iter()
66        .map(|(col, op, val)| {
67            if op == "IN" {
68                if let Some(arr) = val.as_array() {
69                    let items: Vec<String> = arr.iter().map(format_sql_value).collect();
70                    format!("{} IN ({})", quote_ident(col), items.join(", "))
71                } else {
72                    format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
73                }
74            } else {
75                format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
76            }
77        })
78        .collect::<Vec<_>>()
79        .join(" OR ")
80}
81
82/// Batch insert - high performance multi-row insertion
83/// Uses SET LOCAL inside a transaction to avoid session-level side effects.
84pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
85    let params = params
86        .as_ref()
87        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
88
89    let table = params
90        .get("table")
91        .and_then(|v| v.as_str())
92        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
93
94    let columns = params
95        .get("columns")
96        .and_then(|v| v.as_array())
97        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
98
99    let rows = params
100        .get("rows")
101        .and_then(|v| v.as_array())
102        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
103
104    if rows.is_empty() {
105        return Ok(json!({ "rows_affected": 0 }));
106    }
107
108    if rows.len() > MAX_BATCH_ROWS {
109        return Err(MCPError::InvalidParams(format!(
110            "Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})",
111            rows.len()
112        )));
113    }
114
115    let returning = params.get("returning").and_then(|v| v.as_str());
116
117    let column_count = columns.len();
118    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
119
120    if column_names.len() != column_count {
121        return Err(MCPError::InvalidParams(
122            "All column names must be strings".into(),
123        ));
124    }
125
126    validate_table_columns(table, &column_names)?;
127
128    let quoted_table = quote_ident(table);
129    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
130    let cols = quoted_cols.join(", ");
131
132    let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
133    use std::fmt::Write;
134    write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
135
136    for (i, row) in rows.iter().enumerate() {
137        let row_array = row
138            .as_array()
139            .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
140
141        if row_array.len() != column_count {
142            return Err(MCPError::InvalidParams(format!(
143                "Row {} has {} columns, expected {}",
144                i,
145                row_array.len(),
146                column_count
147            )));
148        }
149
150        if i > 0 {
151            sql.push(',');
152        }
153        sql.push('(');
154        for (j, val) in row_array.iter().enumerate() {
155            if j > 0 {
156                sql.push_str(", ");
157            }
158            match val {
159                Value::String(s) => {
160                    sql.push('\'');
161                    for ch in s.chars() {
162                        if ch == '\'' {
163                            sql.push_str("''");
164                        } else {
165                            sql.push(ch);
166                        }
167                    }
168                    sql.push('\'');
169                }
170                Value::Number(n) => {
171                    write!(sql, "{n}").unwrap();
172                }
173                Value::Bool(b) => {
174                    sql.push_str(if *b { "true" } else { "false" });
175                }
176                Value::Null => {
177                    sql.push_str("NULL");
178                }
179                Value::Array(_) | Value::Object(_) => {
180                    let s = val.to_string();
181                    sql.push('\'');
182                    for ch in s.chars() {
183                        if ch == '\'' {
184                            sql.push_str("''");
185                        } else {
186                            sql.push(ch);
187                        }
188                    }
189                    sql.push('\'');
190                }
191            }
192        }
193        sql.push(')');
194    }
195
196    client.execute("BEGIN", &[]).await?;
197    client
198        .execute("SET LOCAL synchronous_commit = OFF", &[])
199        .await?;
200
201    let result = if let Some(col) = returning {
202        validate_identifier(col, "returning")?;
203        let r = format!(" RETURNING {}", quote_ident(col));
204        sql.push_str(&r);
205        match client.query(&sql, &[]).await {
206            Ok(rows) => {
207                client.execute("COMMIT", &[]).await?;
208                let ids: Vec<Value> = rows
209                    .iter()
210                    .map(|r| {
211                        r.try_get::<_, i64>(0)
212                            .map(|id| json!(id))
213                            .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
214                            .unwrap_or(json!(null))
215                    })
216                    .collect();
217                json!({ "rows_affected": ids.len(), "inserted_ids": ids })
218            }
219            Err(e) => {
220                client.execute("ROLLBACK", &[]).await.ok();
221                return Err(MCPError::DatabaseError(e));
222            }
223        }
224    } else {
225        match client.execute(&sql, &[]).await {
226            Ok(rows_affected) => {
227                client.execute("COMMIT", &[]).await?;
228                json!({ "rows_affected": rows_affected })
229            }
230            Err(e) => {
231                client.execute("ROLLBACK", &[]).await.ok();
232                return Err(MCPError::DatabaseError(e));
233            }
234        }
235    };
236
237    Ok(result)
238}
239
240/// Batch update - bulk updates with structured WHERE conditions
241pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
242    let params = params
243        .as_ref()
244        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
245
246    let table = params
247        .get("table")
248        .and_then(|v| v.as_str())
249        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
250
251    let updates = params
252        .get("updates")
253        .and_then(|v| v.as_object())
254        .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
255
256    let where_clauses = params
257        .get("where_clauses")
258        .and_then(|v| v.as_array())
259        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
260
261    validate_identifier(table, "table")?;
262    let parsed_where = validate_where_clauses(where_clauses)?;
263
264    let quoted_table = quote_ident(table);
265    let mut set_clauses = Vec::new();
266    for (key, val) in updates {
267        validate_identifier(key, "updates key")?;
268        set_clauses.push(format!("{} = {}", quote_ident(key), format_sql_value(val)));
269    }
270
271    let where_sql = build_where_sql(&parsed_where);
272    let sql = format!(
273        "UPDATE {quoted_table} SET {} WHERE {where_sql}",
274        set_clauses.join(", ")
275    );
276
277    let rows_affected = client.execute(&sql, &[]).await?;
278
279    Ok(json!({ "rows_affected": rows_affected }))
280}
281
282/// Batch delete - bulk deletion with structured WHERE conditions
283pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
284    let params = params
285        .as_ref()
286        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
287
288    let table = params
289        .get("table")
290        .and_then(|v| v.as_str())
291        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
292
293    let where_clauses = params
294        .get("where_clauses")
295        .and_then(|v| v.as_array())
296        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
297
298    validate_identifier(table, "table")?;
299    let parsed_where = validate_where_clauses(where_clauses)?;
300
301    let returning = params.get("returning").and_then(|v| v.as_str());
302
303    let quoted_table = quote_ident(table);
304    let where_sql = build_where_sql(&parsed_where);
305    let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
306
307    if let Some(col) = returning {
308        validate_identifier(col, "returning")?;
309        sql.push_str(&format!(" RETURNING {}", quote_ident(col)));
310        let rows = client.query(&sql, &[]).await?;
311        let ids: Vec<Value> = rows
312            .iter()
313            .map(|r| {
314                r.try_get::<_, i64>(0)
315                    .map(|id| json!(id))
316                    .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
317                    .unwrap_or(json!(null))
318            })
319            .collect();
320        Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
321    } else {
322        let rows_affected = client.execute(&sql, &[]).await?;
323        Ok(json!({ "rows_affected": rows_affected }))
324    }
325}
326
327/// Batch insert with auto-batching for massive loads
328pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
329    let params = params
330        .as_ref()
331        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
332
333    let table = params
334        .get("table")
335        .and_then(|v| v.as_str())
336        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
337
338    let columns = params
339        .get("columns")
340        .and_then(|v| v.as_array())
341        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
342
343    let rows = params
344        .get("rows")
345        .and_then(|v| v.as_array())
346        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
347
348    let batch_size = params
349        .get("batch_size")
350        .and_then(|v| v.as_u64())
351        .unwrap_or(1000) as usize;
352
353    if rows.is_empty() {
354        return Ok(json!({"rows_affected": 0}));
355    }
356
357    if rows.len() > MAX_BATCH_ROWS {
358        return Err(MCPError::InvalidParams(format!(
359            "Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})",
360            rows.len()
361        )));
362    }
363
364    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
365    validate_table_columns(table, &column_names)?;
366
367    let quoted_table = quote_ident(table);
368    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
369
370    let mut total_affected = 0u64;
371
372    for batch in rows.chunks(batch_size) {
373        let mut sql = format!(
374            "INSERT INTO {quoted_table} ({}) VALUES ",
375            quoted_cols.join(", ")
376        );
377        let mut value_parts = Vec::new();
378
379        for row in batch {
380            let row_array = row
381                .as_array()
382                .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
383
384            let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
385            value_parts.push(format!("({})", row_values.join(", ")));
386        }
387
388        sql.push_str(&value_parts.join(", "));
389
390        let rows_affected = client.execute(&sql, &[]).await?;
391        total_affected += rows_affected;
392    }
393
394    #[allow(clippy::cast_precision_loss)]
395    let batches = (rows.len() as f64 / batch_size as f64).ceil() as u32;
396    Ok(json!({
397        "rows_affected": total_affected,
398        "batches": batches,
399    }))
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_sql_value() {
        assert_eq!(format_sql_value(&Value::String("test".into())), "'test'");
        assert_eq!(format_sql_value(&Value::Number(123.into())), "123");
        assert_eq!(format_sql_value(&Value::Bool(true)), "true");
        assert_eq!(format_sql_value(&Value::Null), "NULL");
    }

    #[test]
    fn test_format_sql_value_more_scalars() {
        assert_eq!(format_sql_value(&Value::Bool(false)), "false");
        assert_eq!(format_sql_value(&json!(1.5)), "1.5");
    }

    // Arrays/objects are the second path that interpolates caller text into SQL,
    // so their quote escaping is pinned here as well.
    #[test]
    fn test_format_sql_value_json_containers_are_quoted_and_escaped() {
        assert_eq!(format_sql_value(&json!({"a": 1})), r#"'{"a":1}'"#);
        assert_eq!(format_sql_value(&json!(["it's"])), r#"'["it''s"]'"#);
    }

    #[test]
    fn test_sql_injection_prevention() {
        let malicious = Value::String("'; DROP TABLE users; --".into());
        let result = format_sql_value(&malicious);
        assert_eq!(result, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let result = validate_table_columns("users; DROP TABLE", &["id"]);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid character")
        );
    }

    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        let result = validate_table_columns("users", &["id; DROP TABLE users"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_table_columns_accepts_valid() {
        assert!(validate_table_columns("users", &["id", "name"]).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let clauses = vec![
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_where_clauses_rejects_empty_list() {
        let result = validate_where_clauses(&[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must not be empty"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_malformed_entries() {
        assert!(validate_where_clauses(&[json!("id = 1")]).is_err());
        assert!(validate_where_clauses(&[json!({"column": "id", "op": "="})]).is_err());
        assert!(validate_where_clauses(&[json!({"column": 1, "op": "=", "value": 1})]).is_err());
    }

    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let clauses = vec![json!({"column": "id", "op": "EXECUTE", "value": "malicious"})];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid operator"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let clauses = vec![json!({"column": "id; DROP TABLE", "op": "=", "value": 1})];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
    }

    #[test]
    fn test_build_where_sql() {
        let v1 = Value::Number(1.into());
        let v2 = Value::String("active".into());
        let parsed = vec![
            ("id".to_string(), "=".to_string(), &v1),
            ("status".to_string(), "=".to_string(), &v2),
        ];
        let sql = build_where_sql(&parsed);
        assert_eq!(sql, r#""id" = 1 OR "status" = 'active'"#);
    }

    #[test]
    fn test_build_where_sql_in_op() {
        let values = json!(["a", "b"]);
        let parsed = vec![("status".to_string(), "IN".to_string(), &values)];
        let sql = build_where_sql(&parsed);
        assert_eq!(sql, r#""status" IN ('a', 'b')"#);
    }
}
484        let parsed = vec![("status".to_string(), "IN".to_string(), &values)];
485        let sql = build_where_sql(&parsed);
486        assert_eq!(sql, r#""status" IN ('a', 'b')"#);
487    }
488}