1use crate::errors::{MCPError, Result as MCPResult};
2use crate::validation::{quote_ident, validate_identifier};
3use serde_json::{Value, json};
4use tokio_postgres::Client;
5
/// Largest number of rows a single `async_batch_insert` call will accept.
const MAX_BATCH_ROWS: usize = 1_000;

/// Operators that a structured `where_clauses` entry may use in its `op` field.
/// Matching is exact (case-sensitive).
const ALLOWED_OPS: &[&str] = &["=", "<", ">", "<=", ">=", "<>", "IN", "LIKE"];
8
9fn format_sql_value(val: &Value) -> String {
10 match val {
11 Value::String(s) => format!("'{}'", s.replace("'", "''")),
12 Value::Number(n) => n.to_string(),
13 Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14 Value::Null => "NULL".to_string(),
15 Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16 }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20 validate_identifier(table, "table")?;
21 for col in columns {
22 validate_identifier(col, "column")?;
23 }
24 Ok(())
25}
26
27fn validate_where_clauses(
28 where_clauses: &[Value],
29) -> Result<Vec<(String, String, &Value)>, MCPError> {
30 if where_clauses.is_empty() {
31 return Err(MCPError::InvalidParams(
32 "'where_clauses' must not be empty".into(),
33 ));
34 }
35 let mut parsed = Vec::new();
36 for clause in where_clauses {
37 let obj = clause.as_object().ok_or_else(|| {
38 MCPError::InvalidParams(
39 "Each where_clause must be an object with 'column', 'op', and 'value'".into(),
40 )
41 })?;
42 let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
43 MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
44 })?;
45 let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
46 MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
47 })?;
48 let value = obj.get("value").ok_or_else(|| {
49 MCPError::InvalidParams("Each where_clause must have a 'value'".into())
50 })?;
51 validate_identifier(column, "where_clause.column")?;
52 if !ALLOWED_OPS.contains(&op) {
53 return Err(MCPError::InvalidParams(format!(
54 "Invalid operator '{op}' — allowed: {}",
55 ALLOWED_OPS.join(", ")
56 )));
57 }
58 parsed.push((column.to_string(), op.to_string(), value));
59 }
60 Ok(parsed)
61}
62
63fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
64 parsed
65 .iter()
66 .map(|(col, op, val)| {
67 if op == "IN" {
68 if let Some(arr) = val.as_array() {
69 let items: Vec<String> = arr.iter().map(format_sql_value).collect();
70 format!("{} IN ({})", quote_ident(col), items.join(", "))
71 } else {
72 format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
73 }
74 } else {
75 format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
76 }
77 })
78 .collect::<Vec<_>>()
79 .join(" AND ")
80}
81
82pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
85 let params = params
86 .as_ref()
87 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
88
89 let table = params
90 .get("table")
91 .and_then(|v| v.as_str())
92 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
93
94 let columns = params
95 .get("columns")
96 .and_then(|v| v.as_array())
97 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
98
99 let rows = params
100 .get("rows")
101 .and_then(|v| v.as_array())
102 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
103
104 if rows.is_empty() {
105 return Ok(json!({ "rows_affected": 0 }));
106 }
107
108 if rows.len() > MAX_BATCH_ROWS {
109 return Err(MCPError::InvalidParams(format!(
110 "Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})",
111 rows.len()
112 )));
113 }
114
115 let returning = params.get("returning").and_then(|v| v.as_str());
116
117 let column_count = columns.len();
118 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
119
120 if column_names.len() != column_count {
121 return Err(MCPError::InvalidParams(
122 "All column names must be strings".into(),
123 ));
124 }
125
126 validate_table_columns(table, &column_names)?;
127
128 let quoted_table = quote_ident(table);
129 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
130 let cols = quoted_cols.join(", ");
131
132 let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
133 use std::fmt::Write;
134 write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
135
136 for (i, row) in rows.iter().enumerate() {
137 let row_array = row
138 .as_array()
139 .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
140
141 if row_array.len() != column_count {
142 return Err(MCPError::InvalidParams(format!(
143 "Row {} has {} columns, expected {}",
144 i,
145 row_array.len(),
146 column_count
147 )));
148 }
149
150 if i > 0 {
151 sql.push(',');
152 }
153 sql.push('(');
154 for (j, val) in row_array.iter().enumerate() {
155 if j > 0 {
156 sql.push_str(", ");
157 }
158 match val {
159 Value::String(s) => {
160 sql.push('\'');
161 for ch in s.chars() {
162 if ch == '\'' {
163 sql.push_str("''");
164 } else {
165 sql.push(ch);
166 }
167 }
168 sql.push('\'');
169 }
170 Value::Number(n) => {
171 write!(sql, "{n}").unwrap();
172 }
173 Value::Bool(b) => {
174 sql.push_str(if *b { "true" } else { "false" });
175 }
176 Value::Null => {
177 sql.push_str("NULL");
178 }
179 Value::Array(_) | Value::Object(_) => {
180 let s = val.to_string();
181 sql.push('\'');
182 for ch in s.chars() {
183 if ch == '\'' {
184 sql.push_str("''");
185 } else {
186 sql.push(ch);
187 }
188 }
189 sql.push('\'');
190 }
191 }
192 }
193 sql.push(')');
194 }
195
196 client.execute("BEGIN", &[]).await?;
197 client
198 .execute("SET LOCAL synchronous_commit = OFF", &[])
199 .await?;
200
201 let result = if let Some(col) = returning {
202 validate_identifier(col, "returning")?;
203 let r = format!(" RETURNING {}", quote_ident(col));
204 sql.push_str(&r);
205 match client.query(&sql, &[]).await {
206 Ok(rows) => {
207 client.execute("COMMIT", &[]).await?;
208 let ids: Vec<Value> = rows
209 .iter()
210 .map(|r| {
211 r.try_get::<_, i64>(0)
212 .map(|id| json!(id))
213 .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
214 .unwrap_or(json!(null))
215 })
216 .collect();
217 json!({ "rows_affected": ids.len(), "inserted_ids": ids })
218 }
219 Err(e) => {
220 client.execute("ROLLBACK", &[]).await.ok();
221 return Err(MCPError::DatabaseError(e));
222 }
223 }
224 } else {
225 match client.execute(&sql, &[]).await {
226 Ok(rows_affected) => {
227 client.execute("COMMIT", &[]).await?;
228 json!({ "rows_affected": rows_affected })
229 }
230 Err(e) => {
231 client.execute("ROLLBACK", &[]).await.ok();
232 return Err(MCPError::DatabaseError(e));
233 }
234 }
235 };
236
237 Ok(result)
238}
239
240pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
242 let params = params
243 .as_ref()
244 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
245
246 let table = params
247 .get("table")
248 .and_then(|v| v.as_str())
249 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
250
251 let updates = params
252 .get("updates")
253 .and_then(|v| v.as_object())
254 .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
255
256 let where_clauses = params
257 .get("where_clauses")
258 .and_then(|v| v.as_array())
259 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
260
261 validate_identifier(table, "table")?;
262 let parsed_where = validate_where_clauses(where_clauses)?;
263
264 let quoted_table = quote_ident(table);
265 let mut set_clauses = Vec::new();
266 for (key, val) in updates {
267 validate_identifier(key, "updates key")?;
268 set_clauses.push(format!("{} = {}", quote_ident(key), format_sql_value(val)));
269 }
270
271 let where_sql = build_where_sql(&parsed_where);
272 let sql = format!(
273 "UPDATE {quoted_table} SET {} WHERE {where_sql}",
274 set_clauses.join(", ")
275 );
276
277 let rows_affected = client.execute(&sql, &[]).await?;
278
279 Ok(json!({ "rows_affected": rows_affected }))
280}
281
282pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
284 let params = params
285 .as_ref()
286 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
287
288 let table = params
289 .get("table")
290 .and_then(|v| v.as_str())
291 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
292
293 let where_clauses = params
294 .get("where_clauses")
295 .and_then(|v| v.as_array())
296 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
297
298 validate_identifier(table, "table")?;
299 let parsed_where = validate_where_clauses(where_clauses)?;
300
301 let returning = params.get("returning").and_then(|v| v.as_str());
302
303 let quoted_table = quote_ident(table);
304 let where_sql = build_where_sql(&parsed_where);
305 let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
306
307 if let Some(col) = returning {
308 validate_identifier(col, "returning")?;
309 sql.push_str(&format!(" RETURNING {}", quote_ident(col)));
310 let rows = client.query(&sql, &[]).await?;
311 let ids: Vec<Value> = rows
312 .iter()
313 .map(|r| {
314 r.try_get::<_, i64>(0)
315 .map(|id| json!(id))
316 .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
317 .unwrap_or(json!(null))
318 })
319 .collect();
320 Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
321 } else {
322 let rows_affected = client.execute(&sql, &[]).await?;
323 Ok(json!({ "rows_affected": rows_affected }))
324 }
325}
326
327const MAX_BATCH_COPY_ROWS: usize = 100_000;
333
334pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
336 let params = params
337 .as_ref()
338 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
339
340 let table = params
341 .get("table")
342 .and_then(|v| v.as_str())
343 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
344
345 let columns = params
346 .get("columns")
347 .and_then(|v| v.as_array())
348 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
349
350 let rows = params
351 .get("rows")
352 .and_then(|v| v.as_array())
353 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
354
355 const MAX_BATCH_SIZE: usize = 5_000;
356 let batch_size = (params
357 .get("batch_size")
358 .and_then(|v| v.as_u64())
359 .unwrap_or(1000) as usize)
360 .min(MAX_BATCH_SIZE);
361
362 if rows.is_empty() {
363 return Ok(json!({"rows_affected": 0}));
364 }
365
366 if rows.len() > MAX_BATCH_COPY_ROWS {
367 return Err(MCPError::InvalidParams(format!(
368 "Batch copy size exceeds maximum of {MAX_BATCH_COPY_ROWS} rows (got {})",
369 rows.len()
370 )));
371 }
372
373 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
374 validate_table_columns(table, &column_names)?;
375
376 let quoted_table = quote_ident(table);
377 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
378
379 client.execute("BEGIN", &[]).await?;
382 client
383 .execute("SET LOCAL synchronous_commit = OFF", &[])
384 .await?;
385
386 let mut total_affected = 0u64;
387
388 for batch in rows.chunks(batch_size) {
389 let mut sql = format!(
390 "INSERT INTO {quoted_table} ({}) VALUES ",
391 quoted_cols.join(", ")
392 );
393 let mut value_parts = Vec::new();
394
395 for row in batch {
396 let row_array = row
397 .as_array()
398 .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
399
400 let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
401 value_parts.push(format!("({})", row_values.join(", ")));
402 }
403
404 sql.push_str(&value_parts.join(", "));
405
406 match client.execute(&sql, &[]).await {
407 Ok(n) => total_affected += n,
408 Err(e) => {
409 client.execute("ROLLBACK", &[]).await.ok();
410 return Err(MCPError::DatabaseError(e));
411 }
412 }
413 }
414
415 client.execute("COMMIT", &[]).await?;
416
417 #[allow(clippy::cast_precision_loss)]
418 let batches = (rows.len() as f64 / batch_size as f64).ceil() as u32;
419 Ok(json!({
420 "rows_affected": total_affected,
421 "batches": batches,
422 }))
423}
424
/// Unit tests for the pure (non-database) helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Each scalar JSON kind maps to the expected SQL literal.
    #[test]
    fn test_format_sql_value() {
        let cases = [
            (json!("test"), "'test'"),
            (json!(123), "123"),
            (json!(true), "true"),
            (Value::Null, "NULL"),
        ];
        for (input, expected) in &cases {
            assert_eq!(format_sql_value(input), *expected);
        }
    }

    /// A quote inside a string is doubled, so the payload stays inside the literal.
    #[test]
    fn test_sql_injection_prevention() {
        let rendered = format_sql_value(&json!("'; DROP TABLE users; --"));
        assert_eq!(rendered, "'''; DROP TABLE users; --'");
    }

    /// A table name carrying SQL is rejected with an "invalid character" message.
    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let outcome = validate_table_columns("users; DROP TABLE", &["id"]);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("invalid character"));
    }

    /// A column name carrying SQL is rejected.
    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        let columns = ["id; DROP TABLE users"];
        assert!(validate_table_columns("users", &columns).is_err());
    }

    /// Plain identifiers pass validation.
    #[test]
    fn test_validate_table_columns_accepts_valid() {
        let columns = ["id", "name"];
        assert!(validate_table_columns("users", &columns).is_ok());
    }

    /// Well-formed structured clauses, including an `IN` list, are accepted.
    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let clauses = [
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        assert!(validate_where_clauses(&clauses).is_ok());
    }

    /// An operator outside the allow-list is rejected with a descriptive message.
    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let clauses = [json!({"column": "id", "op": "EXECUTE", "value": "malicious"})];
        let outcome = validate_where_clauses(&clauses);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Invalid operator"));
    }

    /// A clause whose column name carries SQL is rejected.
    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let clauses = [json!({"column": "id; DROP TABLE", "op": "=", "value": 1})];
        assert!(validate_where_clauses(&clauses).is_err());
    }

    /// Multiple clauses are quoted, rendered and joined with `AND`.
    #[test]
    fn test_build_where_sql() {
        let id_value = json!(1);
        let status_value = json!("active");
        let parsed = [
            (String::from("id"), String::from("="), &id_value),
            (String::from("status"), String::from("="), &status_value),
        ];
        assert_eq!(
            build_where_sql(&parsed),
            r#""id" = 1 AND "status" = 'active'"#
        );
    }

    /// An `IN` clause expands its array into a parenthesized list.
    #[test]
    fn test_build_where_sql_in_op() {
        let candidates = json!(["a", "b"]);
        let parsed = [(String::from("status"), String::from("IN"), &candidates)];
        assert_eq!(build_where_sql(&parsed), r#""status" IN ('a', 'b')"#);
    }
}