Skip to main content

mcp_postgres/actions/
batch.rs

1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::{MCPError, Result as MCPResult};
4use crate::validation::{validate_identifier, quote_identifier};
5
/// Largest `rows` array accepted by one call to `async_batch_insert` or
/// `async_batch_insert_copy`; bigger requests are rejected with `InvalidParams`.
const MAX_BATCH_ROWS: usize = 1_000;
/// Operators permitted in structured `where_clauses`. Matching is exact and
/// case-sensitive (`in` / `like` are rejected); the order here is the order shown
/// in the "Invalid operator" error message.
const ALLOWED_OPS: &[&str] = &[
    "=", "<", ">", "<=", ">=", "<>", // scalar comparisons
    "IN", "LIKE",
];
8
9fn format_sql_value(val: &Value) -> String {
10    match val {
11        Value::String(s) => format!("'{}'", s.replace("'", "''")),
12        Value::Number(n) => n.to_string(),
13        Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14        Value::Null => "NULL".to_string(),
15        Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16    }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20    validate_identifier(table, "table")?;
21    for col in columns {
22        validate_identifier(col, "column")?;
23    }
24    Ok(())
25}
26
27fn validate_where_clauses(where_clauses: &[Value]) -> Result<Vec<(String, String, &Value)>, MCPError> {
28    if where_clauses.is_empty() {
29        return Err(MCPError::InvalidParams("'where_clauses' must not be empty".into()));
30    }
31    let mut parsed = Vec::new();
32    for clause in where_clauses {
33        let obj = clause.as_object().ok_or_else(|| {
34            MCPError::InvalidParams("Each where_clause must be an object with 'column', 'op', and 'value'".into())
35        })?;
36        let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
37            MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
38        })?;
39        let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
40            MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
41        })?;
42        let value = obj.get("value").ok_or_else(|| {
43            MCPError::InvalidParams("Each where_clause must have a 'value'".into())
44        })?;
45        validate_identifier(column, "where_clause.column")?;
46        if !ALLOWED_OPS.contains(&op) {
47            return Err(MCPError::InvalidParams(
48                format!("Invalid operator '{op}' — allowed: {}", ALLOWED_OPS.join(", "))
49            ));
50        }
51        parsed.push((column.to_string(), op.to_string(), value));
52    }
53    Ok(parsed)
54}
55
56fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
57    parsed.iter().map(|(col, op, val)| {
58        if op == "IN" {
59            if let Some(arr) = val.as_array() {
60                let items: Vec<String> = arr.iter().map(format_sql_value).collect();
61                format!("{} IN ({})", quote_identifier(col), items.join(", "))
62            } else {
63                format!("{} {} {}", quote_identifier(col), op, format_sql_value(val))
64            }
65        } else {
66            format!("{} {} {}", quote_identifier(col), op, format_sql_value(val))
67        }
68    }).collect::<Vec<_>>().join(" OR ")
69}
70
71/// Batch insert - high performance multi-row insertion
72/// Uses SET LOCAL inside a transaction to avoid session-level side effects.
73pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
74    let params = params.as_ref().ok_or_else(|| {
75        MCPError::InvalidParams("Missing parameters".into())
76    })?;
77
78    let table = params
79        .get("table")
80        .and_then(|v| v.as_str())
81        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
82
83    let columns = params
84        .get("columns")
85        .and_then(|v| v.as_array())
86        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
87
88    let rows = params
89        .get("rows")
90        .and_then(|v| v.as_array())
91        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
92
93    if rows.is_empty() {
94        return Ok(json!({ "rows_affected": 0 }));
95    }
96
97    if rows.len() > MAX_BATCH_ROWS {
98        return Err(MCPError::InvalidParams(
99            format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
100        ));
101    }
102
103    let returning = params.get("returning").and_then(|v| v.as_str());
104
105    let column_count = columns.len();
106    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
107
108    if column_names.len() != column_count {
109        return Err(MCPError::InvalidParams("All column names must be strings".into()));
110    }
111
112    validate_table_columns(table, &column_names)?;
113
114    let quoted_table = quote_identifier(table);
115    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_identifier(c)).collect();
116    let cols = quoted_cols.join(", ");
117
118    let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
119    use std::fmt::Write;
120    write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
121
122    for (i, row) in rows.iter().enumerate() {
123        let row_array = row.as_array().ok_or_else(|| {
124            MCPError::InvalidParams("Each row must be an array".into())
125        })?;
126
127        if row_array.len() != column_count {
128            return Err(MCPError::InvalidParams(
129                format!("Row {} has {} columns, expected {}", i, row_array.len(), column_count),
130            ));
131        }
132
133        if i > 0 {
134            sql.push(',');
135        }
136        sql.push('(');
137        for (j, val) in row_array.iter().enumerate() {
138            if j > 0 {
139                sql.push_str(", ");
140            }
141            match val {
142                Value::String(s) => {
143                    sql.push('\'');
144                    for ch in s.chars() {
145                        if ch == '\'' {
146                            sql.push_str("''");
147                        } else {
148                            sql.push(ch);
149                        }
150                    }
151                    sql.push('\'');
152                }
153                Value::Number(n) => {
154                    write!(sql, "{n}").unwrap();
155                }
156                Value::Bool(b) => {
157                    sql.push_str(if *b { "true" } else { "false" });
158                }
159                Value::Null => {
160                    sql.push_str("NULL");
161                }
162                Value::Array(_) | Value::Object(_) => {
163                    let s = val.to_string();
164                    sql.push('\'');
165                    for ch in s.chars() {
166                        if ch == '\'' {
167                            sql.push_str("''");
168                        } else {
169                            sql.push(ch);
170                        }
171                    }
172                    sql.push('\'');
173                }
174            }
175        }
176        sql.push(')');
177    }
178
179    client.execute("BEGIN", &[]).await?;
180    client.execute("SET LOCAL synchronous_commit = OFF", &[]).await?;
181
182    let result = if let Some(col) = returning {
183        validate_identifier(col, "returning")?;
184        let r = format!(" RETURNING {}", quote_identifier(col));
185        sql.push_str(&r);
186        match client.query(&sql, &[]).await {
187            Ok(rows) => {
188                client.execute("COMMIT", &[]).await?;
189                let ids: Vec<Value> = rows.iter().map(|r| {
190                    if let Ok(id) = r.try_get::<_, i64>(0) {
191                        json!(id)
192                    } else if let Ok(id) = r.try_get::<_, i32>(0) {
193                        json!(id)
194                    } else {
195                        json!(null)
196                    }
197                }).collect();
198                json!({ "rows_affected": ids.len(), "inserted_ids": ids })
199            }
200            Err(e) => {
201                client.execute("ROLLBACK", &[]).await.ok();
202                return Err(MCPError::DatabaseError(e));
203            }
204        }
205    } else {
206        match client.execute(&sql, &[]).await {
207            Ok(rows_affected) => {
208                client.execute("COMMIT", &[]).await?;
209                json!({ "rows_affected": rows_affected })
210            }
211            Err(e) => {
212                client.execute("ROLLBACK", &[]).await.ok();
213                return Err(MCPError::DatabaseError(e));
214            }
215        }
216    };
217
218    Ok(result)
219}
220
221/// Batch update - bulk updates with structured WHERE conditions
222pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
223    let params = params.as_ref().ok_or_else(|| {
224        MCPError::InvalidParams("Missing parameters".into())
225    })?;
226
227    let table = params
228        .get("table")
229        .and_then(|v| v.as_str())
230        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
231
232    let updates = params
233        .get("updates")
234        .and_then(|v| v.as_object())
235        .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
236
237    let where_clauses = params
238        .get("where_clauses")
239        .and_then(|v| v.as_array())
240        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
241
242    validate_identifier(table, "table")?;
243    let parsed_where = validate_where_clauses(where_clauses)?;
244
245    let quoted_table = quote_identifier(table);
246    let mut set_clauses = Vec::new();
247    for (key, val) in updates {
248        validate_identifier(key, "updates key")?;
249        set_clauses.push(format!("{} = {}", quote_identifier(key), format_sql_value(val)));
250    }
251
252    let where_sql = build_where_sql(&parsed_where);
253    let sql = format!("UPDATE {quoted_table} SET {} WHERE {where_sql}", set_clauses.join(", "));
254
255    let rows_affected = client.execute(&sql, &[]).await?;
256
257    Ok(json!({ "rows_affected": rows_affected }))
258}
259
260/// Batch delete - bulk deletion with structured WHERE conditions
261pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
262    let params = params.as_ref().ok_or_else(|| {
263        MCPError::InvalidParams("Missing parameters".into())
264    })?;
265
266    let table = params
267        .get("table")
268        .and_then(|v| v.as_str())
269        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
270
271    let where_clauses = params
272        .get("where_clauses")
273        .and_then(|v| v.as_array())
274        .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
275
276    validate_identifier(table, "table")?;
277    let parsed_where = validate_where_clauses(where_clauses)?;
278
279    let returning = params.get("returning").and_then(|v| v.as_str());
280
281    let quoted_table = quote_identifier(table);
282    let where_sql = build_where_sql(&parsed_where);
283    let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
284
285    if let Some(col) = returning {
286        validate_identifier(col, "returning")?;
287        sql.push_str(&format!(" RETURNING {}", quote_identifier(col)));
288        let rows = client.query(&sql, &[]).await?;
289        let ids: Vec<Value> = rows.iter().map(|r| {
290            if let Ok(id) = r.try_get::<_, i64>(0) {
291                json!(id)
292            } else if let Ok(id) = r.try_get::<_, i32>(0) {
293                json!(id)
294            } else {
295                json!(null)
296            }
297        }).collect();
298        Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
299    } else {
300        let rows_affected = client.execute(&sql, &[]).await?;
301        Ok(json!({ "rows_affected": rows_affected }))
302    }
303}
304
305/// Batch insert with auto-batching for massive loads
306pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
307    let params = params.as_ref().ok_or_else(|| {
308        MCPError::InvalidParams("Missing parameters".into())
309    })?;
310
311    let table = params
312        .get("table")
313        .and_then(|v| v.as_str())
314        .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
315
316    let columns = params
317        .get("columns")
318        .and_then(|v| v.as_array())
319        .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
320
321    let rows = params
322        .get("rows")
323        .and_then(|v| v.as_array())
324        .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
325
326    let batch_size = params
327        .get("batch_size")
328        .and_then(|v| v.as_u64())
329        .unwrap_or(1000) as usize;
330
331    if rows.is_empty() {
332        return Ok(json!({"rows_affected": 0}));
333    }
334
335    if rows.len() > MAX_BATCH_ROWS {
336        return Err(MCPError::InvalidParams(
337            format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
338        ));
339    }
340
341    let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
342    validate_table_columns(table, &column_names)?;
343
344    let quoted_table = quote_identifier(table);
345    let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_identifier(c)).collect();
346
347    let mut total_affected = 0u64;
348
349    for batch in rows.chunks(batch_size) {
350        let mut sql = format!("INSERT INTO {quoted_table} ({}) VALUES ", quoted_cols.join(", "));
351        let mut value_parts = Vec::new();
352
353        for row in batch {
354            let row_array = row.as_array().ok_or_else(|| {
355                MCPError::InvalidParams("Each row must be an array".into())
356            })?;
357
358            let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
359            value_parts.push(format!("({})", row_values.join(", ")));
360        }
361
362        sql.push_str(&value_parts.join(", "));
363
364        let rows_affected = client.execute(&sql, &[]).await?;
365        total_affected += rows_affected;
366    }
367
368    Ok(json!({
369        "rows_affected": total_affected,
370        "batches": (rows.len() as f64 / batch_size as f64).ceil() as u32
371    }))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_sql_value() {
        assert_eq!(format_sql_value(&Value::String("test".into())), "'test'");
        assert_eq!(format_sql_value(&Value::Number(123.into())), "123");
        assert_eq!(format_sql_value(&Value::Bool(true)), "true");
        assert_eq!(format_sql_value(&Value::Null), "NULL");
    }

    #[test]
    fn test_format_sql_value_more_scalars() {
        assert_eq!(format_sql_value(&Value::Bool(false)), "false");
        assert_eq!(format_sql_value(&json!(-3)), "-3");
        assert_eq!(format_sql_value(&json!(1.5)), "1.5");
    }

    #[test]
    fn test_format_sql_value_json_containers_are_quoted() {
        assert_eq!(format_sql_value(&json!([1, "a"])), r#"'[1,"a"]'"#);
        // Quotes inside the serialized JSON must be doubled like any other string.
        assert_eq!(format_sql_value(&json!({"k": "it's"})), r#"'{"k":"it''s"}'"#);
    }

    #[test]
    fn test_sql_injection_prevention() {
        let malicious = Value::String("'; DROP TABLE users; --".into());
        let result = format_sql_value(&malicious);
        assert_eq!(result, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let result = validate_table_columns("users; DROP TABLE", &["id"]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("invalid character"));
    }

    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        let result = validate_table_columns("users", &["id; DROP TABLE users"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_table_columns_accepts_valid() {
        assert!(validate_table_columns("users", &["id", "name"]).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let clauses = vec![
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let clauses = vec![
            json!({"column": "id", "op": "EXECUTE", "value": "malicious"}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid operator"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let clauses = vec![
            json!({"column": "id; DROP TABLE", "op": "=", "value": 1}),
        ];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_where_clauses_rejects_empty_list() {
        let clauses: Vec<Value> = Vec::new();
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must not be empty"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_missing_value() {
        let clauses = vec![json!({"column": "id", "op": "="})];
        let result = validate_where_clauses(&clauses);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("'value'"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_non_object() {
        let clauses = vec![json!("id = 1")];
        assert!(validate_where_clauses(&clauses).is_err());
    }

    #[test]
    fn test_build_where_sql() {
        let v1 = Value::Number(1.into());
        let v2 = Value::String("active".into());
        let parsed = vec![
            ("id".to_string(), "=".to_string(), &v1),
            ("status".to_string(), "=".to_string(), &v2),
        ];
        let sql = build_where_sql(&parsed);
        assert_eq!(sql, r#""id" = 1 OR "status" = 'active'"#);
    }

    #[test]
    fn test_build_where_sql_in_op() {
        let values = json!(["a", "b"]);
        let parsed = vec![
            ("status".to_string(), "IN".to_string(), &values),
        ];
        let sql = build_where_sql(&parsed);
        assert_eq!(sql, r#""status" IN ('a', 'b')"#);
    }

    #[test]
    fn test_build_where_sql_in_op_escapes_items() {
        let values = json!(["o'k"]);
        let parsed = vec![("name".to_string(), "IN".to_string(), &values)];
        assert_eq!(build_where_sql(&parsed), r#""name" IN ('o''k')"#);
    }
}