Skip to main content

mcp_postgres/actions/
batch.rs

1use crate::errors::{MCPError, Result as MCPResult};
2use crate::validation::{quote_ident, validate_identifier};
3use serde_json::{Value, json};
4use tokio_postgres::Client;
5
/// Maximum number of rows accepted by [`async_batch_insert`], which renders the
/// whole batch into a single `INSERT ... VALUES` statement.
const MAX_BATCH_ROWS: usize = 1_000;

/// Operators a structured `where_clauses` entry may use. The operator text is
/// spliced into the generated SQL as-is, so anything outside this allow-list is
/// rejected by `validate_where_clauses` before any SQL is built.
const ALLOWED_OPS: &[&str] = &[
    // ordering / equality comparisons
    "=", "<", ">", "<=", ">=", "<>",
    // set membership and pattern matching
    "IN", "LIKE",
];
8
9fn format_sql_value(val: &Value) -> String {
10    match val {
11        Value::String(s) => format!("'{}'", s.replace("'", "''")),
12        Value::Number(n) => n.to_string(),
13        Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14        Value::Null => "NULL".to_string(),
15        Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16    }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20    validate_identifier(table, "table")?;
21    for col in columns {
22        validate_identifier(col, "column")?;
23    }
24    Ok(())
25}
26
27fn validate_where_clauses(
28    where_clauses: &[Value],
29) -> Result<Vec<(String, String, &Value)>, MCPError> {
30    if where_clauses.is_empty() {
31        return Err(MCPError::InvalidParams(
32            "'where_clauses' must not be empty".into(),
33        ));
34    }
35    let mut parsed = Vec::new();
36    for clause in where_clauses {
37        let obj = clause.as_object().ok_or_else(|| {
38            MCPError::InvalidParams(
39                "Each where_clause must be an object with 'column', 'op', and 'value'".into(),
40            )
41        })?;
42        let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
43            MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
44        })?;
45        let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
46            MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
47        })?;
48        let value = obj.get("value").ok_or_else(|| {
49            MCPError::InvalidParams("Each where_clause must have a 'value'".into())
50        })?;
51        validate_identifier(column, "where_clause.column")?;
52        if !ALLOWED_OPS.contains(&op) {
53            return Err(MCPError::InvalidParams(format!(
54                "Invalid operator '{op}' — allowed: {}",
55                ALLOWED_OPS.join(", ")
56            )));
57        }
58        parsed.push((column.to_string(), op.to_string(), value));
59    }
60    Ok(parsed)
61}
62
63fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
64    parsed
65        .iter()
66        .map(|(col, op, val)| {
67            if op == "IN" {
68                if let Some(arr) = val.as_array() {
69                    let items: Vec<String> = arr.iter().map(format_sql_value).collect();
70                    format!("{} IN ({})", quote_ident(col), items.join(", "))
71                } else {
72                    format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
73                }
74            } else {
75                format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
76            }
77        })
78        .collect::<Vec<_>>()
79        .join(" AND ")
80}
81
82/// Batch insert - high performance multi-row insertion
83/// Uses SET LOCAL inside a transaction to avoid session-level side effects.
84pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
85    let params = params
86        .as_ref()
87        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
88
89    let table = params
90        .get("table")
91        .and_then(|v| v.as_str())
92        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
93
94    let columns = params
95        .get("columns")
96        .and_then(|v| v.as_array())
97        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
98
99    let rows = params
100        .get("rows")
101        .and_then(|v| v.as_array())
102        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
103
104    if rows.is_empty() {
105        return Ok(json!({ "rows_affected": 0 }));
106    }
107
108    if rows.len() > MAX_BATCH_ROWS {
109        return Err(MCPError::InvalidParams(format!(
110            "Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})",
111            rows.len()
112        )));
113    }
114
115    let returning = params.get("returning").and_then(|v| v.as_str());
116
117    let column_count = columns.len();
118    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
119
120    if column_names.len() != column_count {
121        return Err(MCPError::InvalidParams(
122            "All column names must be strings".into(),
123        ));
124    }
125
126    validate_table_columns(table, &column_names)?;
127
128    let quoted_table = quote_ident(table);
129    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
130    let cols = quoted_cols.join(", ");
131
132    let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
133    use std::fmt::Write;
134    write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
135
136    for (i, row) in rows.iter().enumerate() {
137        let row_array = row
138            .as_array()
139            .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
140
141        if row_array.len() != column_count {
142            return Err(MCPError::InvalidParams(format!(
143                "Row {} has {} columns, expected {}",
144                i,
145                row_array.len(),
146                column_count
147            )));
148        }
149
150        if i > 0 {
151            sql.push(',');
152        }
153        sql.push('(');
154        for (j, val) in row_array.iter().enumerate() {
155            if j > 0 {
156                sql.push_str(", ");
157            }
158            match val {
159                Value::String(s) => {
160                    sql.push('\'');
161                    for ch in s.chars() {
162                        if ch == '\'' {
163                            sql.push_str("''");
164                        } else {
165                            sql.push(ch);
166                        }
167                    }
168                    sql.push('\'');
169                }
170                Value::Number(n) => {
171                    write!(sql, "{n}").unwrap();
172                }
173                Value::Bool(b) => {
174                    sql.push_str(if *b { "true" } else { "false" });
175                }
176                Value::Null => {
177                    sql.push_str("NULL");
178                }
179                Value::Array(_) | Value::Object(_) => {
180                    let s = val.to_string();
181                    sql.push('\'');
182                    for ch in s.chars() {
183                        if ch == '\'' {
184                            sql.push_str("''");
185                        } else {
186                            sql.push(ch);
187                        }
188                    }
189                    sql.push('\'');
190                }
191            }
192        }
193        sql.push(')');
194    }
195
196    client.execute("BEGIN", &[]).await?;
197    client
198        .execute("SET LOCAL synchronous_commit = OFF", &[])
199        .await?;
200
201    let result = if let Some(col) = returning {
202        validate_identifier(col, "returning")?;
203        let r = format!(" RETURNING {}", quote_ident(col));
204        sql.push_str(&r);
205        match client.query(&sql, &[]).await {
206            Ok(rows) => {
207                client.execute("COMMIT", &[]).await?;
208                let ids: Vec<Value> = rows
209                    .iter()
210                    .map(|r| {
211                        r.try_get::<_, i64>(0)
212                            .map(|id| json!(id))
213                            .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
214                            .unwrap_or(json!(null))
215                    })
216                    .collect();
217                json!({ "rows_affected": ids.len(), "inserted_ids": ids })
218            }
219            Err(e) => {
220                client.execute("ROLLBACK", &[]).await.ok();
221                return Err(MCPError::DatabaseError(e));
222            }
223        }
224    } else {
225        match client.execute(&sql, &[]).await {
226            Ok(rows_affected) => {
227                client.execute("COMMIT", &[]).await?;
228                json!({ "rows_affected": rows_affected })
229            }
230            Err(e) => {
231                client.execute("ROLLBACK", &[]).await.ok();
232                return Err(MCPError::DatabaseError(e));
233            }
234        }
235    };
236
237    Ok(result)
238}
239
240/// Batch update - bulk updates with structured WHERE conditions
241pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
242    let params = params
243        .as_ref()
244        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
245
246    let table = params
247        .get("table")
248        .and_then(|v| v.as_str())
249        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
250
251    let updates = params
252        .get("updates")
253        .and_then(|v| v.as_object())
254        .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
255
256    let where_clauses = params
257        .get("where_clauses")
258        .and_then(|v| v.as_array())
259        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
260
261    validate_identifier(table, "table")?;
262    let parsed_where = validate_where_clauses(where_clauses)?;
263
264    let quoted_table = quote_ident(table);
265    let mut set_clauses = Vec::new();
266    for (key, val) in updates {
267        validate_identifier(key, "updates key")?;
268        set_clauses.push(format!("{} = {}", quote_ident(key), format_sql_value(val)));
269    }
270
271    let where_sql = build_where_sql(&parsed_where);
272    let sql = format!(
273        "UPDATE {quoted_table} SET {} WHERE {where_sql}",
274        set_clauses.join(", ")
275    );
276
277    let rows_affected = client.execute(&sql, &[]).await?;
278
279    Ok(json!({ "rows_affected": rows_affected }))
280}
281
282/// Batch delete - bulk deletion with structured WHERE conditions
283pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
284    let params = params
285        .as_ref()
286        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
287
288    let table = params
289        .get("table")
290        .and_then(|v| v.as_str())
291        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
292
293    let where_clauses = params
294        .get("where_clauses")
295        .and_then(|v| v.as_array())
296        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
297
298    validate_identifier(table, "table")?;
299    let parsed_where = validate_where_clauses(where_clauses)?;
300
301    let returning = params.get("returning").and_then(|v| v.as_str());
302
303    let quoted_table = quote_ident(table);
304    let where_sql = build_where_sql(&parsed_where);
305    let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
306
307    if let Some(col) = returning {
308        validate_identifier(col, "returning")?;
309        sql.push_str(&format!(" RETURNING {}", quote_ident(col)));
310        let rows = client.query(&sql, &[]).await?;
311        let ids: Vec<Value> = rows
312            .iter()
313            .map(|r| {
314                r.try_get::<_, i64>(0)
315                    .map(|id| json!(id))
316                    .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
317                    .unwrap_or(json!(null))
318            })
319            .collect();
320        Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
321    } else {
322        let rows_affected = client.execute(&sql, &[]).await?;
323        Ok(json!({ "rows_affected": rows_affected }))
324    }
325}
326
/// Upper bound on the number of rows a single `async_batch_insert_copy` request
/// may carry.
///
/// `async_batch_insert` is capped at 1,000 rows and renders them into one SQL
/// statement. The copy variant instead splits its input into chunks of
/// `batch_size` rows (default 1,000, max 5,000), so each generated statement
/// stays small even for imports of 10K+ rows. The decoded request as a whole is
/// additionally limited by the JSON-RPC request size cap (16 MiB,
/// `MAX_REQUEST_BYTES`).
const MAX_BATCH_COPY_ROWS: usize = 100_000;
333
334/// Batch insert with auto-batching for massive loads
335pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
336    let params = params
337        .as_ref()
338        .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
339
340    let table = params
341        .get("table")
342        .and_then(|v| v.as_str())
343        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
344
345    let columns = params
346        .get("columns")
347        .and_then(|v| v.as_array())
348        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
349
350    let rows = params
351        .get("rows")
352        .and_then(|v| v.as_array())
353        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
354
355    const MAX_BATCH_SIZE: usize = 5_000;
356    let batch_size = (params
357        .get("batch_size")
358        .and_then(|v| v.as_u64())
359        .unwrap_or(1000) as usize)
360        .min(MAX_BATCH_SIZE);
361
362    if rows.is_empty() {
363        return Ok(json!({"rows_affected": 0}));
364    }
365
366    if rows.len() > MAX_BATCH_COPY_ROWS {
367        return Err(MCPError::InvalidParams(format!(
368            "Batch copy size exceeds maximum of {MAX_BATCH_COPY_ROWS} rows (got {})",
369            rows.len()
370        )));
371    }
372
373    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
374    validate_table_columns(table, &column_names)?;
375
376    let quoted_table = quote_ident(table);
377    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
378
379    // Wrap the entire import in a transaction with synchronous_commit=OFF
380    // for throughput, matching async_batch_insert behavior.
381    client.execute("BEGIN", &[]).await?;
382    client
383        .execute("SET LOCAL synchronous_commit = OFF", &[])
384        .await?;
385
386    let mut total_affected = 0u64;
387
388    for batch in rows.chunks(batch_size) {
389        let mut sql = format!(
390            "INSERT INTO {quoted_table} ({}) VALUES ",
391            quoted_cols.join(", ")
392        );
393        let mut value_parts = Vec::new();
394
395        for row in batch {
396            let row_array = row
397                .as_array()
398                .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
399
400            let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
401            value_parts.push(format!("({})", row_values.join(", ")));
402        }
403
404        sql.push_str(&value_parts.join(", "));
405
406        match client.execute(&sql, &[]).await {
407            Ok(n) => total_affected += n,
408            Err(e) => {
409                client.execute("ROLLBACK", &[]).await.ok();
410                return Err(MCPError::DatabaseError(e));
411            }
412        }
413    }
414
415    client.execute("COMMIT", &[]).await?;
416
417    #[allow(clippy::cast_precision_loss)]
418    let batches = (rows.len() as f64 / batch_size as f64).ceil() as u32;
419    Ok(json!({
420        "rows_affected": total_affected,
421        "batches": batches,
422    }))
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(column, op, value)` tuple in the shape produced by
    /// `validate_where_clauses`.
    fn clause<'a>(column: &str, op: &str, value: &'a Value) -> (String, String, &'a Value) {
        (column.to_owned(), op.to_owned(), value)
    }

    #[test]
    fn test_format_sql_value() {
        let cases = [
            (Value::String("test".into()), "'test'"),
            (Value::Number(123.into()), "123"),
            (Value::Bool(true), "true"),
            (Value::Null, "NULL"),
        ];
        for (input, expected) in &cases {
            assert_eq!(format_sql_value(input), *expected);
        }
    }

    #[test]
    fn test_sql_injection_prevention() {
        let payload = Value::String("'; DROP TABLE users; --".into());
        assert_eq!(format_sql_value(&payload), "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let message = validate_table_columns("users; DROP TABLE", &["id"])
            .unwrap_err()
            .to_string();
        assert!(message.contains("invalid character"));
    }

    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        assert!(validate_table_columns("users", &["id; DROP TABLE users"]).is_err());
    }

    #[test]
    fn test_validate_table_columns_accepts_valid() {
        assert!(validate_table_columns("users", &["id", "name"]).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let clauses = [
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        assert!(validate_where_clauses(&clauses).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let clauses = [json!({"column": "id", "op": "EXECUTE", "value": "malicious"})];
        let message = validate_where_clauses(&clauses).unwrap_err().to_string();
        assert!(message.contains("Invalid operator"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let clauses = [json!({"column": "id; DROP TABLE", "op": "=", "value": 1})];
        assert!(validate_where_clauses(&clauses).is_err());
    }

    #[test]
    fn test_build_where_sql() {
        let id = json!(1);
        let status = json!("active");
        let sql = build_where_sql(&[clause("id", "=", &id), clause("status", "=", &status)]);
        assert_eq!(sql, r#""id" = 1 AND "status" = 'active'"#);
    }

    #[test]
    fn test_build_where_sql_in_op() {
        let values = json!(["a", "b"]);
        let sql = build_where_sql(&[clause("status", "IN", &values)]);
        assert_eq!(sql, r#""status" IN ('a', 'b')"#);
    }
}