1use crate::errors::{MCPError, Result as MCPResult};
2use crate::validation::{quote_ident, validate_identifier};
3use serde_json::{Value, json};
4use tokio_postgres::Client;
5
/// Largest number of rows a single batch-insert call will accept.
const MAX_BATCH_ROWS: usize = 1_000;

/// Operators accepted in structured `where_clauses`. The listed order is also the order
/// shown to callers in the "Invalid operator" error message.
const ALLOWED_OPS: &[&str] = &["=", "<", ">", "<=", ">=", "<>", "IN", "LIKE"];
8
9fn format_sql_value(val: &Value) -> String {
10 match val {
11 Value::String(s) => format!("'{}'", s.replace("'", "''")),
12 Value::Number(n) => n.to_string(),
13 Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14 Value::Null => "NULL".to_string(),
15 Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16 }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20 validate_identifier(table, "table")?;
21 for col in columns {
22 validate_identifier(col, "column")?;
23 }
24 Ok(())
25}
26
27fn validate_where_clauses(
28 where_clauses: &[Value],
29) -> Result<Vec<(String, String, &Value)>, MCPError> {
30 if where_clauses.is_empty() {
31 return Err(MCPError::InvalidParams(
32 "'where_clauses' must not be empty".into(),
33 ));
34 }
35 let mut parsed = Vec::new();
36 for clause in where_clauses {
37 let obj = clause.as_object().ok_or_else(|| {
38 MCPError::InvalidParams(
39 "Each where_clause must be an object with 'column', 'op', and 'value'".into(),
40 )
41 })?;
42 let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
43 MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
44 })?;
45 let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
46 MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
47 })?;
48 let value = obj.get("value").ok_or_else(|| {
49 MCPError::InvalidParams("Each where_clause must have a 'value'".into())
50 })?;
51 validate_identifier(column, "where_clause.column")?;
52 if !ALLOWED_OPS.contains(&op) {
53 return Err(MCPError::InvalidParams(format!(
54 "Invalid operator '{op}' — allowed: {}",
55 ALLOWED_OPS.join(", ")
56 )));
57 }
58 parsed.push((column.to_string(), op.to_string(), value));
59 }
60 Ok(parsed)
61}
62
63fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
64 parsed
65 .iter()
66 .map(|(col, op, val)| {
67 if op == "IN" {
68 if let Some(arr) = val.as_array() {
69 let items: Vec<String> = arr.iter().map(format_sql_value).collect();
70 format!("{} IN ({})", quote_ident(col), items.join(", "))
71 } else {
72 format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
73 }
74 } else {
75 format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
76 }
77 })
78 .collect::<Vec<_>>()
79 .join(" OR ")
80}
81
82pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
85 let params = params
86 .as_ref()
87 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
88
89 let table = params
90 .get("table")
91 .and_then(|v| v.as_str())
92 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
93
94 let columns = params
95 .get("columns")
96 .and_then(|v| v.as_array())
97 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
98
99 let rows = params
100 .get("rows")
101 .and_then(|v| v.as_array())
102 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
103
104 if rows.is_empty() {
105 return Ok(json!({ "rows_affected": 0 }));
106 }
107
108 if rows.len() > MAX_BATCH_ROWS {
109 return Err(MCPError::InvalidParams(format!(
110 "Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})",
111 rows.len()
112 )));
113 }
114
115 let returning = params.get("returning").and_then(|v| v.as_str());
116
117 let column_count = columns.len();
118 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
119
120 if column_names.len() != column_count {
121 return Err(MCPError::InvalidParams(
122 "All column names must be strings".into(),
123 ));
124 }
125
126 validate_table_columns(table, &column_names)?;
127
128 let quoted_table = quote_ident(table);
129 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
130 let cols = quoted_cols.join(", ");
131
132 let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
133 use std::fmt::Write;
134 write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
135
136 for (i, row) in rows.iter().enumerate() {
137 let row_array = row
138 .as_array()
139 .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
140
141 if row_array.len() != column_count {
142 return Err(MCPError::InvalidParams(format!(
143 "Row {} has {} columns, expected {}",
144 i,
145 row_array.len(),
146 column_count
147 )));
148 }
149
150 if i > 0 {
151 sql.push(',');
152 }
153 sql.push('(');
154 for (j, val) in row_array.iter().enumerate() {
155 if j > 0 {
156 sql.push_str(", ");
157 }
158 match val {
159 Value::String(s) => {
160 sql.push('\'');
161 for ch in s.chars() {
162 if ch == '\'' {
163 sql.push_str("''");
164 } else {
165 sql.push(ch);
166 }
167 }
168 sql.push('\'');
169 }
170 Value::Number(n) => {
171 write!(sql, "{n}").unwrap();
172 }
173 Value::Bool(b) => {
174 sql.push_str(if *b { "true" } else { "false" });
175 }
176 Value::Null => {
177 sql.push_str("NULL");
178 }
179 Value::Array(_) | Value::Object(_) => {
180 let s = val.to_string();
181 sql.push('\'');
182 for ch in s.chars() {
183 if ch == '\'' {
184 sql.push_str("''");
185 } else {
186 sql.push(ch);
187 }
188 }
189 sql.push('\'');
190 }
191 }
192 }
193 sql.push(')');
194 }
195
196 client.execute("BEGIN", &[]).await?;
197 client
198 .execute("SET LOCAL synchronous_commit = OFF", &[])
199 .await?;
200
201 let result = if let Some(col) = returning {
202 validate_identifier(col, "returning")?;
203 let r = format!(" RETURNING {}", quote_ident(col));
204 sql.push_str(&r);
205 match client.query(&sql, &[]).await {
206 Ok(rows) => {
207 client.execute("COMMIT", &[]).await?;
208 let ids: Vec<Value> = rows
209 .iter()
210 .map(|r| {
211 r.try_get::<_, i64>(0)
212 .map(|id| json!(id))
213 .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
214 .unwrap_or(json!(null))
215 })
216 .collect();
217 json!({ "rows_affected": ids.len(), "inserted_ids": ids })
218 }
219 Err(e) => {
220 client.execute("ROLLBACK", &[]).await.ok();
221 return Err(MCPError::DatabaseError(e));
222 }
223 }
224 } else {
225 match client.execute(&sql, &[]).await {
226 Ok(rows_affected) => {
227 client.execute("COMMIT", &[]).await?;
228 json!({ "rows_affected": rows_affected })
229 }
230 Err(e) => {
231 client.execute("ROLLBACK", &[]).await.ok();
232 return Err(MCPError::DatabaseError(e));
233 }
234 }
235 };
236
237 Ok(result)
238}
239
240pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
242 let params = params
243 .as_ref()
244 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
245
246 let table = params
247 .get("table")
248 .and_then(|v| v.as_str())
249 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
250
251 let updates = params
252 .get("updates")
253 .and_then(|v| v.as_object())
254 .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
255
256 let where_clauses = params
257 .get("where_clauses")
258 .and_then(|v| v.as_array())
259 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
260
261 validate_identifier(table, "table")?;
262 let parsed_where = validate_where_clauses(where_clauses)?;
263
264 let quoted_table = quote_ident(table);
265 let mut set_clauses = Vec::new();
266 for (key, val) in updates {
267 validate_identifier(key, "updates key")?;
268 set_clauses.push(format!("{} = {}", quote_ident(key), format_sql_value(val)));
269 }
270
271 let where_sql = build_where_sql(&parsed_where);
272 let sql = format!(
273 "UPDATE {quoted_table} SET {} WHERE {where_sql}",
274 set_clauses.join(", ")
275 );
276
277 let rows_affected = client.execute(&sql, &[]).await?;
278
279 Ok(json!({ "rows_affected": rows_affected }))
280}
281
282pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
284 let params = params
285 .as_ref()
286 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
287
288 let table = params
289 .get("table")
290 .and_then(|v| v.as_str())
291 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
292
293 let where_clauses = params
294 .get("where_clauses")
295 .and_then(|v| v.as_array())
296 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
297
298 validate_identifier(table, "table")?;
299 let parsed_where = validate_where_clauses(where_clauses)?;
300
301 let returning = params.get("returning").and_then(|v| v.as_str());
302
303 let quoted_table = quote_ident(table);
304 let where_sql = build_where_sql(&parsed_where);
305 let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
306
307 if let Some(col) = returning {
308 validate_identifier(col, "returning")?;
309 sql.push_str(&format!(" RETURNING {}", quote_ident(col)));
310 let rows = client.query(&sql, &[]).await?;
311 let ids: Vec<Value> = rows
312 .iter()
313 .map(|r| {
314 r.try_get::<_, i64>(0)
315 .map(|id| json!(id))
316 .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
317 .unwrap_or(json!(null))
318 })
319 .collect();
320 Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
321 } else {
322 let rows_affected = client.execute(&sql, &[]).await?;
323 Ok(json!({ "rows_affected": rows_affected }))
324 }
325}
326
327pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
329 let params = params
330 .as_ref()
331 .ok_or_else(|| MCPError::InvalidParams("Missing parameters".into()))?;
332
333 let table = params
334 .get("table")
335 .and_then(|v| v.as_str())
336 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
337
338 let columns = params
339 .get("columns")
340 .and_then(|v| v.as_array())
341 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
342
343 let rows = params
344 .get("rows")
345 .and_then(|v| v.as_array())
346 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
347
348 let batch_size = params
349 .get("batch_size")
350 .and_then(|v| v.as_u64())
351 .unwrap_or(1000) as usize;
352
353 if rows.is_empty() {
354 return Ok(json!({"rows_affected": 0}));
355 }
356
357 if rows.len() > MAX_BATCH_ROWS {
358 return Err(MCPError::InvalidParams(format!(
359 "Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})",
360 rows.len()
361 )));
362 }
363
364 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
365 validate_table_columns(table, &column_names)?;
366
367 let quoted_table = quote_ident(table);
368 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
369
370 let mut total_affected = 0u64;
371
372 for batch in rows.chunks(batch_size) {
373 let mut sql = format!(
374 "INSERT INTO {quoted_table} ({}) VALUES ",
375 quoted_cols.join(", ")
376 );
377 let mut value_parts = Vec::new();
378
379 for row in batch {
380 let row_array = row
381 .as_array()
382 .ok_or_else(|| MCPError::InvalidParams("Each row must be an array".into()))?;
383
384 let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
385 value_parts.push(format!("({})", row_values.join(", ")));
386 }
387
388 sql.push_str(&value_parts.join(", "));
389
390 let rows_affected = client.execute(&sql, &[]).await?;
391 total_affected += rows_affected;
392 }
393
394 #[allow(clippy::cast_precision_loss)]
395 let batches = (rows.len() as f64 / batch_size as f64).ceil() as u32;
396 Ok(json!({
397 "rows_affected": total_affected,
398 "batches": batches,
399 }))
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_sql_value() {
        let cases = [
            (Value::String("test".into()), "'test'"),
            (Value::Number(123.into()), "123"),
            (Value::Bool(true), "true"),
            (Value::Null, "NULL"),
        ];
        for (input, expected) in &cases {
            assert_eq!(format_sql_value(input), *expected);
        }
    }

    #[test]
    fn test_sql_injection_prevention() {
        let payload = "'; DROP TABLE users; --";
        let rendered = format_sql_value(&Value::String(payload.to_string()));
        assert_eq!(rendered, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let err = validate_table_columns("users; DROP TABLE", &["id"])
            .expect_err("a table name containing ';' must be rejected");
        assert!(err.to_string().contains("invalid character"));
    }

    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        assert!(validate_table_columns("users", &["id; DROP TABLE users"]).is_err());
    }

    #[test]
    fn test_validate_table_columns_accepts_valid() {
        let outcome = validate_table_columns("users", &["id", "name"]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let clauses = [
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        assert!(validate_where_clauses(&clauses).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let clauses = [json!({"column": "id", "op": "EXECUTE", "value": "malicious"})];
        let message = match validate_where_clauses(&clauses) {
            Ok(_) => panic!("operator EXECUTE must be rejected"),
            Err(e) => e.to_string(),
        };
        assert!(message.contains("Invalid operator"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let clauses = [json!({"column": "id; DROP TABLE", "op": "=", "value": 1})];
        assert!(validate_where_clauses(&clauses).is_err());
    }

    #[test]
    fn test_build_where_sql() {
        let id = Value::Number(1.into());
        let status = Value::String("active".into());
        let parsed = [
            (String::from("id"), String::from("="), &id),
            (String::from("status"), String::from("="), &status),
        ];
        assert_eq!(build_where_sql(&parsed), r#""id" = 1 OR "status" = 'active'"#);
    }

    #[test]
    fn test_build_where_sql_in_op() {
        let values = json!(["a", "b"]);
        let parsed = [(String::from("status"), String::from("IN"), &values)];
        assert_eq!(build_where_sql(&parsed), r#""status" IN ('a', 'b')"#);
    }
}