mcp_postgres/actions/
batch.rs1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::Result as MCPResult;
4
/// Largest number of rows accepted by a single batch insert request.
const MAX_BATCH_ROWS: usize = 1_000;
/// Largest accepted length, in bytes, of the `table` parameter.
const MAX_IDENTIFIER_LEN: usize = 255;
7
8fn format_sql_value(val: &Value) -> String {
10 match val {
11 Value::String(s) => format!("'{}'", s.replace("'", "''")),
12 Value::Number(n) => n.to_string(),
13 Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14 Value::Null => "NULL".to_string(),
15 Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16 }
17}
18
19pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
22 let params = params.as_ref().ok_or_else(|| {
23 crate::errors::MCPError::InvalidParams("Missing parameters".into())
24 })?;
25
26 let table = params
27 .get("table")
28 .and_then(|v| v.as_str())
29 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
30
31 if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
32 return Err(crate::errors::MCPError::InvalidParams(
33 format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
34 ));
35 }
36
37 let columns = params
38 .get("columns")
39 .and_then(|v| v.as_array())
40 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'columns'".into()))?;
41
42 let rows = params
43 .get("rows")
44 .and_then(|v| v.as_array())
45 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'rows'".into()))?;
46
47 if rows.is_empty() {
48 return Ok(json!({ "rows_affected": 0 }));
49 }
50
51 if rows.len() > MAX_BATCH_ROWS {
52 return Err(crate::errors::MCPError::InvalidParams(
53 format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
54 ));
55 }
56
57 let returning = params.get("returning").and_then(|v| v.as_str());
58
59 let column_count = columns.len();
60 let column_names: Vec<&str> = columns
61 .iter()
62 .filter_map(|c| c.as_str())
63 .collect();
64
65 if column_names.len() != column_count {
66 return Err(crate::errors::MCPError::InvalidParams(
67 "All column names must be strings".into(),
68 ));
69 }
70
71 let cols = column_names.join(", ");
73 let total_capacity = 64 + cols.len() + rows.len() * (column_count * 16 + 4);
74 let mut sql = String::with_capacity(total_capacity);
75 use std::fmt::Write;
76 write!(sql, "INSERT INTO {table} ({cols}) VALUES ").unwrap();
77
78 for (i, row) in rows.iter().enumerate() {
79 let row_array = row.as_array().ok_or_else(|| {
80 crate::errors::MCPError::InvalidParams("Each row must be an array".into())
81 })?;
82
83 if row_array.len() != column_count {
84 return Err(crate::errors::MCPError::InvalidParams(
85 format!("Row has {} columns, expected {}", row_array.len(), column_count),
86 ));
87 }
88
89 if i > 0 {
90 sql.push(',');
91 }
92 sql.push('(');
93 for (j, val) in row_array.iter().enumerate() {
94 if j > 0 {
95 sql.push_str(", ");
96 }
97 match val {
98 Value::String(s) => {
99 sql.push('\'');
100 for ch in s.chars() {
101 if ch == '\'' {
102 sql.push_str("''");
103 } else {
104 sql.push(ch);
105 }
106 }
107 sql.push('\'');
108 }
109 Value::Number(n) => {
110 write!(sql, "{n}").unwrap();
111 }
112 Value::Bool(b) => {
113 sql.push_str(if *b { "true" } else { "false" });
114 }
115 Value::Null => {
116 sql.push_str("NULL");
117 }
118 Value::Array(_) | Value::Object(_) => {
119 let s = val.to_string();
120 sql.push('\'');
121 for ch in s.chars() {
122 if ch == '\'' {
123 sql.push_str("''");
124 } else {
125 sql.push(ch);
126 }
127 }
128 sql.push('\'');
129 }
130 }
131 }
132 sql.push(')');
133 }
134
135 let orig_sync = client
138 .query_one("SHOW synchronous_commit", &[])
139 .await
140 .map(|r| r.get::<_, String>(0))
141 .unwrap_or_else(|_| "on".to_string());
142 client.execute("SET synchronous_commit = OFF", &[]).await?;
143
144 let result = if let Some(col) = returning {
145 let r = format!(" RETURNING {}", col);
146 sql.push_str(&r);
147 let rows = client.query(&sql, &[]).await;
148 client
149 .execute(&format!("SET synchronous_commit = {}", orig_sync), &[])
150 .await
151 .ok();
152 let rows = rows?;
153 let ids: Vec<Value> = rows.iter().map(|r| {
154 if let Ok(id) = r.try_get::<_, i64>(0) {
155 json!(id)
156 } else if let Ok(id) = r.try_get::<_, i32>(0) {
157 json!(id)
158 } else {
159 json!(null)
160 }
161 }).collect();
162 json!({
163 "rows_affected": ids.len(),
164 "inserted_ids": ids
165 })
166 } else {
167 let rows_affected = client.execute(&sql, &[]).await;
168 client
169 .execute(&format!("SET synchronous_commit = {}", orig_sync), &[])
170 .await
171 .ok();
172 json!({
173 "rows_affected": rows_affected?
174 })
175 };
176
177 Ok(result)
178}
179
180pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
182 let params = params.as_ref().ok_or_else(|| {
183 crate::errors::MCPError::InvalidParams("Missing parameters".into())
184 })?;
185
186 let table = params
187 .get("table")
188 .and_then(|v| v.as_str())
189 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
190
191 if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
192 return Err(crate::errors::MCPError::InvalidParams(
193 format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
194 ));
195 }
196
197 let updates = params
198 .get("updates")
199 .and_then(|v| v.as_object())
200 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'updates'".into()))?;
201
202 let where_clauses = params
203 .get("where_clauses")
204 .and_then(|v| v.as_array())
205 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
206
207 if where_clauses.is_empty() {
208 return Ok(json!({ "rows_affected": 0 }));
209 }
210
211 let mut total_affected = 0u64;
212
213 for where_clause in where_clauses {
214 let where_str = where_clause
215 .as_str()
216 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Where clause must be string".into()))?;
217
218 let mut set_clauses = Vec::new();
219 for (key, val) in updates {
220 let val_str = format_sql_value(val);
221 set_clauses.push(format!("{} = {}", key, val_str));
222 }
223
224 let sql = format!(
225 "UPDATE {} SET {} WHERE {}",
226 table,
227 set_clauses.join(", "),
228 where_str
229 );
230
231 let rows_affected = client.execute(&sql, &[]).await?;
232 total_affected += rows_affected;
233 }
234
235 Ok(json!({
236 "rows_affected": total_affected
237 }))
238}
239
240pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
242 let params = params.as_ref().ok_or_else(|| {
243 crate::errors::MCPError::InvalidParams("Missing parameters".into())
244 })?;
245
246 let table = params
247 .get("table")
248 .and_then(|v| v.as_str())
249 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
250
251 if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
252 return Err(crate::errors::MCPError::InvalidParams(
253 format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
254 ));
255 }
256
257 let where_clauses = params
258 .get("where_clauses")
259 .and_then(|v| v.as_array())
260 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
261
262 if where_clauses.is_empty() {
263 return Ok(json!({ "rows_affected": 0 }));
264 }
265
266 let returning = params.get("returning").and_then(|v| v.as_str());
267
268 let where_conditions: Vec<String> = where_clauses
269 .iter()
270 .filter_map(|c| c.as_str().map(|s| format!("({})", s)))
271 .collect();
272
273 let mut sql = format!(
274 "DELETE FROM {} WHERE {}",
275 table,
276 where_conditions.join(" OR ")
277 );
278
279 if let Some(col) = returning {
280 sql.push_str(&format!(" RETURNING {}", col));
281 let rows = client.query(&sql, &[]).await?;
282 let ids: Vec<Value> = rows.iter().map(|r| {
283 if let Ok(id) = r.try_get::<_, i64>(0) {
284 json!(id)
285 } else if let Ok(id) = r.try_get::<_, i32>(0) {
286 json!(id)
287 } else {
288 json!(null)
289 }
290 }).collect();
291 Ok(json!({
292 "rows_affected": ids.len(),
293 "inserted_ids": ids
294 }))
295 } else {
296 let rows_affected = client.execute(&sql, &[]).await?;
297 Ok(json!({
298 "rows_affected": rows_affected
299 }))
300 }
301}
302
303pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
305 let params = params.as_ref().ok_or_else(|| {
306 crate::errors::MCPError::InvalidParams("Missing parameters".into())
307 })?;
308
309 let table = params
310 .get("table")
311 .and_then(|v| v.as_str())
312 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
313
314 if table.is_empty() || table.len() > MAX_IDENTIFIER_LEN {
315 return Err(crate::errors::MCPError::InvalidParams(
316 format!("'table' must be 1-{MAX_IDENTIFIER_LEN} characters")
317 ));
318 }
319
320 let columns = params
321 .get("columns")
322 .and_then(|v| v.as_array())
323 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'columns'".into()))?;
324
325 let rows = params
326 .get("rows")
327 .and_then(|v| v.as_array())
328 .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'rows'".into()))?;
329
330 let batch_size = params
331 .get("batch_size")
332 .and_then(|v| v.as_u64())
333 .unwrap_or(1000) as usize;
334
335 if rows.is_empty() {
336 return Ok(json!({"rows_affected": 0}));
337 }
338
339 if rows.len() > MAX_BATCH_ROWS {
340 return Err(crate::errors::MCPError::InvalidParams(
341 format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
342 ));
343 }
344
345 let column_names: Vec<&str> = columns
346 .iter()
347 .filter_map(|c| c.as_str())
348 .collect();
349
350 let mut total_affected = 0u64;
351
352 for batch in rows.chunks(batch_size) {
354 let mut sql = format!("INSERT INTO {} ({}) VALUES ", table, column_names.join(", "));
355 let mut value_parts = Vec::new();
356
357 for row in batch {
358 let row_array = row.as_array().ok_or_else(|| {
359 crate::errors::MCPError::InvalidParams("Each row must be an array".into())
360 })?;
361
362 let row_values: Vec<String> = row_array
363 .iter()
364 .map(format_sql_value)
365 .collect();
366
367 value_parts.push(format!("({})", row_values.join(", ")));
368 }
369
370 sql.push_str(&value_parts.join(", "));
371
372 let rows_affected = client.execute(&sql, &[]).await?;
373 total_affected += rows_affected;
374 }
375
376 Ok(json!({
377 "rows_affected": total_affected,
378 "batches": (rows.len() as f64 / batch_size as f64).ceil() as u32
379 }))
380}
381
// Unit tests for the SQL literal rendering used by every batch action.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_sql_value() {
        assert_eq!(format_sql_value(&Value::String("test".into())), "'test'");
        assert_eq!(format_sql_value(&Value::Number(123.into())), "123");
        assert_eq!(format_sql_value(&Value::Bool(true)), "true");
        assert_eq!(format_sql_value(&Value::Null), "NULL");
    }

    #[test]
    fn test_format_sql_value_remaining_variants() {
        assert_eq!(format_sql_value(&Value::Bool(false)), "false");
        assert_eq!(format_sql_value(&json!(-1.5)), "-1.5");
        // Arrays and objects are rendered as their compact JSON text inside a string literal.
        assert_eq!(format_sql_value(&json!([1, "a"])), "'[1,\"a\"]'");
        assert_eq!(
            format_sql_value(&json!({"k": "it's"})),
            "'{\"k\":\"it''s\"}'"
        );
    }

    #[test]
    fn test_embedded_quotes_are_doubled() {
        assert_eq!(
            format_sql_value(&Value::String("O'Brien".into())),
            "'O''Brien'"
        );
    }

    #[test]
    fn test_sql_injection_prevention() {
        let malicious = Value::String("'; DROP TABLE users; --".into());
        let result = format_sql_value(&malicious);
        assert_eq!(result, "'''; DROP TABLE users; --'");
    }
}