Skip to main content

mcp_postgres/actions/
batch.rs

1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::Result as MCPResult;
4
/// Maximum number of rows accepted by a single call to the batch insert actions.
const MAX_BATCH_ROWS: usize = 1_000;
/// Maximum length, in bytes (`str::len`), accepted for the `table` parameter.
const MAX_IDENTIFIER_LEN: usize = 255;
7
8/// Format JSON value as SQL-safe string
9fn format_sql_value(val: &Value) -> String {
10    match val {
11        Value::String(s) => format!("'{}'", s.replace("'", "''")),
12        Value::Number(n) => n.to_string(),
13        Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14        Value::Null => "NULL".to_string(),
15        Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16    }
17}
18
19/// Batch insert - high performance multi-row insertion
20/// Applies synchronous_commit = OFF at query level for maximum throughput during bulk loads
21pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
22    let params = params.as_ref().ok_or_else(|| {
23        crate::errors::MCPError::InvalidParams("Missing parameters".into())
24    })?;
25
26    let table = params
27        .get("table")
28        .and_then(|v| v.as_str())
29        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
30
31    if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
32        return Err(crate::errors::MCPError::InvalidParams(
33            format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
34        ));
35    }
36
37    let columns = params
38        .get("columns")
39        .and_then(|v| v.as_array())
40        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'columns'".into()))?;
41
42    let rows = params
43        .get("rows")
44        .and_then(|v| v.as_array())
45        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'rows'".into()))?;
46
47    if rows.is_empty() {
48        return Ok(json!({ "rows_affected": 0 }));
49    }
50
51    if rows.len() > MAX_BATCH_ROWS {
52        return Err(crate::errors::MCPError::InvalidParams(
53            format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
54        ));
55    }
56
57    let returning = params.get("returning").and_then(|v| v.as_str());
58
59    let column_count = columns.len();
60    let column_names: Vec<&str> = columns
61        .iter()
62        .filter_map(|c| c.as_str())
63        .collect();
64
65    if column_names.len() != column_count {
66        return Err(crate::errors::MCPError::InvalidParams(
67            "All column names must be strings".into(),
68        ));
69    }
70
71    // Build VALUES clause
72    let cols = column_names.join(", ");
73    let total_capacity = 64 + cols.len() + rows.len() * (column_count * 16 + 4);
74    let mut sql = String::with_capacity(total_capacity);
75    use std::fmt::Write;
76    write!(sql, "INSERT INTO {table} ({cols}) VALUES ").unwrap();
77
78    for (i, row) in rows.iter().enumerate() {
79        let row_array = row.as_array().ok_or_else(|| {
80            crate::errors::MCPError::InvalidParams("Each row must be an array".into())
81        })?;
82
83        if row_array.len() != column_count {
84            return Err(crate::errors::MCPError::InvalidParams(
85                format!("Row has {} columns, expected {}", row_array.len(), column_count),
86            ));
87        }
88
89        if i > 0 {
90            sql.push(',');
91        }
92        sql.push('(');
93        for (j, val) in row_array.iter().enumerate() {
94            if j > 0 {
95                sql.push_str(", ");
96            }
97            match val {
98                Value::String(s) => {
99                    sql.push('\'');
100                    for ch in s.chars() {
101                        if ch == '\'' {
102                            sql.push_str("''");
103                        } else {
104                            sql.push(ch);
105                        }
106                    }
107                    sql.push('\'');
108                }
109                Value::Number(n) => {
110                    write!(sql, "{n}").unwrap();
111                }
112                Value::Bool(b) => {
113                    sql.push_str(if *b { "true" } else { "false" });
114                }
115                Value::Null => {
116                    sql.push_str("NULL");
117                }
118                Value::Array(_) | Value::Object(_) => {
119                    let s = val.to_string();
120                    sql.push('\'');
121                    for ch in s.chars() {
122                        if ch == '\'' {
123                            sql.push_str("''");
124                        } else {
125                            sql.push(ch);
126                        }
127                    }
128                    sql.push('\'');
129                }
130            }
131        }
132        sql.push(')');
133    }
134
135    // Temporarily disable synchronous commit for bulk insert throughput,
136    // then restore the original setting to avoid session-level side effects.
137    let orig_sync = client
138        .query_one("SHOW synchronous_commit", &[])
139        .await
140        .map(|r| r.get::<_, String>(0))
141        .unwrap_or_else(|_| "on".to_string());
142    client.execute("SET synchronous_commit = OFF", &[]).await?;
143
144    let result = if let Some(col) = returning {
145        let r = format!(" RETURNING {}", col);
146        sql.push_str(&r);
147        let rows = client.query(&sql, &[]).await;
148        client
149            .execute(&format!("SET synchronous_commit = {}", orig_sync), &[])
150            .await
151            .ok();
152        let rows = rows?;
153        let ids: Vec<Value> = rows.iter().map(|r| {
154            if let Ok(id) = r.try_get::<_, i64>(0) {
155                json!(id)
156            } else if let Ok(id) = r.try_get::<_, i32>(0) {
157                json!(id)
158            } else {
159                json!(null)
160            }
161        }).collect();
162        json!({
163            "rows_affected": ids.len(),
164            "inserted_ids": ids
165        })
166    } else {
167        let rows_affected = client.execute(&sql, &[]).await;
168        client
169            .execute(&format!("SET synchronous_commit = {}", orig_sync), &[])
170            .await
171            .ok();
172        json!({
173            "rows_affected": rows_affected?
174        })
175    };
176
177    Ok(result)
178}
179
180/// Batch update - bulk updates with WHERE conditions
181pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
182    let params = params.as_ref().ok_or_else(|| {
183        crate::errors::MCPError::InvalidParams("Missing parameters".into())
184    })?;
185
186    let table = params
187        .get("table")
188        .and_then(|v| v.as_str())
189        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
190
191    if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
192        return Err(crate::errors::MCPError::InvalidParams(
193            format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
194        ));
195    }
196
197    let updates = params
198        .get("updates")
199        .and_then(|v| v.as_object())
200        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'updates'".into()))?;
201
202    let where_clauses = params
203        .get("where_clauses")
204        .and_then(|v| v.as_array())
205        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
206
207    if where_clauses.is_empty() {
208        return Ok(json!({ "rows_affected": 0 }));
209    }
210
211    let mut total_affected = 0u64;
212
213    for where_clause in where_clauses {
214        let where_str = where_clause
215            .as_str()
216            .ok_or_else(|| crate::errors::MCPError::InvalidParams("Where clause must be string".into()))?;
217
218        let mut set_clauses = Vec::new();
219        for (key, val) in updates {
220            let val_str = format_sql_value(val);
221            set_clauses.push(format!("{} = {}", key, val_str));
222        }
223
224        let sql = format!(
225            "UPDATE {} SET {} WHERE {}",
226            table,
227            set_clauses.join(", "),
228            where_str
229        );
230
231        let rows_affected = client.execute(&sql, &[]).await?;
232        total_affected += rows_affected;
233    }
234
235    Ok(json!({
236        "rows_affected": total_affected
237    }))
238}
239
240/// Batch delete - bulk deletion with combined WHERE clauses
241pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
242    let params = params.as_ref().ok_or_else(|| {
243        crate::errors::MCPError::InvalidParams("Missing parameters".into())
244    })?;
245
246    let table = params
247        .get("table")
248        .and_then(|v| v.as_str())
249        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
250
251    if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
252        return Err(crate::errors::MCPError::InvalidParams(
253            format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
254        ));
255    }
256
257    let where_clauses = params
258        .get("where_clauses")
259        .and_then(|v| v.as_array())
260        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
261
262    if where_clauses.is_empty() {
263        return Ok(json!({ "rows_affected": 0 }));
264    }
265
266    let returning = params.get("returning").and_then(|v| v.as_str());
267
268    let where_conditions: Vec<String> = where_clauses
269        .iter()
270        .filter_map(|c| c.as_str().map(|s| format!("({})", s)))
271        .collect();
272
273    let mut sql = format!(
274        "DELETE FROM {} WHERE {}",
275        table,
276        where_conditions.join(" OR ")
277    );
278
279    if let Some(col) = returning {
280        sql.push_str(&format!(" RETURNING {}", col));
281        let rows = client.query(&sql, &[]).await?;
282        let ids: Vec<Value> = rows.iter().map(|r| {
283            if let Ok(id) = r.try_get::<_, i64>(0) {
284                json!(id)
285            } else if let Ok(id) = r.try_get::<_, i32>(0) {
286                json!(id)
287            } else {
288                json!(null)
289            }
290        }).collect();
291        Ok(json!({
292            "rows_affected": ids.len(),
293            "inserted_ids": ids
294        }))
295    } else {
296        let rows_affected = client.execute(&sql, &[]).await?;
297        Ok(json!({
298            "rows_affected": rows_affected
299        }))
300    }
301}
302
303/// Batch insert with auto-batching for massive loads
304pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
305    let params = params.as_ref().ok_or_else(|| {
306        crate::errors::MCPError::InvalidParams("Missing parameters".into())
307    })?;
308
309    let table = params
310        .get("table")
311        .and_then(|v| v.as_str())
312        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
313
314    if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
315        return Err(crate::errors::MCPError::InvalidParams(
316            format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
317        ));
318    }
319
320    let columns = params
321        .get("columns")
322        .and_then(|v| v.as_array())
323        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'columns'".into()))?;
324
325    let rows = params
326        .get("rows")
327        .and_then(|v| v.as_array())
328        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'rows'".into()))?;
329
330    let batch_size = params
331        .get("batch_size")
332        .and_then(|v| v.as_u64())
333        .unwrap_or(1000) as usize;
334
335    if rows.is_empty() {
336        return Ok(json!({"rows_affected": 0}));
337    }
338
339    if rows.len() > MAX_BATCH_ROWS {
340        return Err(crate::errors::MCPError::InvalidParams(
341            format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
342        ));
343    }
344
345    let column_names: Vec<&str> = columns
346        .iter()
347        .filter_map(|c| c.as_str())
348        .collect();
349
350    let mut total_affected = 0u64;
351
352    // Process in batches
353    for batch in rows.chunks(batch_size) {
354        let mut sql = format!("INSERT INTO {} ({}) VALUES ", table, column_names.join(", "));
355        let mut value_parts = Vec::new();
356
357        for row in batch {
358            let row_array = row.as_array().ok_or_else(|| {
359                crate::errors::MCPError::InvalidParams("Each row must be an array".into())
360            })?;
361
362            let row_values: Vec<String> = row_array
363                .iter()
364                .map(format_sql_value)
365                .collect();
366
367            value_parts.push(format!("({})", row_values.join(", ")));
368        }
369
370        sql.push_str(&value_parts.join(", "));
371
372        let rows_affected = client.execute(&sql, &[]).await?;
373        total_affected += rows_affected;
374    }
375
376    Ok(json!({
377        "rows_affected": total_affected,
378        "batches": (rows.len() as f64 / batch_size as f64).ceil() as u32
379    }))
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for the literal formatter, which needs no database connection.

    #[test]
    fn test_format_sql_value() {
        let cases = [
            (Value::String("test".into()), "'test'"),
            (Value::Number(123.into()), "123"),
            (Value::Bool(true), "true"),
            (Value::Null, "NULL"),
        ];
        for (input, expected) in &cases {
            assert_eq!(format_sql_value(input), *expected);
        }
    }

    #[test]
    fn test_sql_injection_prevention() {
        let payload = Value::String("'; DROP TABLE users; --".into());
        assert_eq!(format_sql_value(&payload), "'''; DROP TABLE users; --'");
    }
}