1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::{MCPError, Result as MCPResult};
4use crate::validation::{validate_identifier, quote_identifier};
5
/// Largest number of rows a single batch-insert call will accept.
const MAX_BATCH_ROWS: usize = 1_000;

/// Operators permitted in structured `where_clauses`. The operator text is interpolated
/// into SQL verbatim, so anything outside this list must be rejected before building SQL.
const ALLOWED_OPS: &[&str] = &[
    "=", "<", ">", "<=", ">=", "<>",
    "IN", "LIKE",
];
8
9fn format_sql_value(val: &Value) -> String {
10 match val {
11 Value::String(s) => format!("'{}'", s.replace("'", "''")),
12 Value::Number(n) => n.to_string(),
13 Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14 Value::Null => "NULL".to_string(),
15 Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16 }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20 validate_identifier(table, "table")?;
21 for col in columns {
22 validate_identifier(col, "column")?;
23 }
24 Ok(())
25}
26
27fn validate_where_clauses(where_clauses: &[Value]) -> Result<Vec<(String, String, &Value)>, MCPError> {
28 if where_clauses.is_empty() {
29 return Err(MCPError::InvalidParams("'where_clauses' must not be empty".into()));
30 }
31 let mut parsed = Vec::new();
32 for clause in where_clauses {
33 let obj = clause.as_object().ok_or_else(|| {
34 MCPError::InvalidParams("Each where_clause must be an object with 'column', 'op', and 'value'".into())
35 })?;
36 let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
37 MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
38 })?;
39 let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
40 MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
41 })?;
42 let value = obj.get("value").ok_or_else(|| {
43 MCPError::InvalidParams("Each where_clause must have a 'value'".into())
44 })?;
45 validate_identifier(column, "where_clause.column")?;
46 if !ALLOWED_OPS.contains(&op) {
47 return Err(MCPError::InvalidParams(
48 format!("Invalid operator '{op}' — allowed: {}", ALLOWED_OPS.join(", "))
49 ));
50 }
51 parsed.push((column.to_string(), op.to_string(), value));
52 }
53 Ok(parsed)
54}
55
56fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
57 parsed.iter().map(|(col, op, val)| {
58 if op == "IN" {
59 if let Some(arr) = val.as_array() {
60 let items: Vec<String> = arr.iter().map(format_sql_value).collect();
61 format!("{} IN ({})", quote_identifier(col), items.join(", "))
62 } else {
63 format!("{} {} {}", quote_identifier(col), op, format_sql_value(val))
64 }
65 } else {
66 format!("{} {} {}", quote_identifier(col), op, format_sql_value(val))
67 }
68 }).collect::<Vec<_>>().join(" OR ")
69}
70
71pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
74 let params = params.as_ref().ok_or_else(|| {
75 MCPError::InvalidParams("Missing parameters".into())
76 })?;
77
78 let table = params
79 .get("table")
80 .and_then(|v| v.as_str())
81 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
82
83 let columns = params
84 .get("columns")
85 .and_then(|v| v.as_array())
86 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
87
88 let rows = params
89 .get("rows")
90 .and_then(|v| v.as_array())
91 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
92
93 if rows.is_empty() {
94 return Ok(json!({ "rows_affected": 0 }));
95 }
96
97 if rows.len() > MAX_BATCH_ROWS {
98 return Err(MCPError::InvalidParams(
99 format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
100 ));
101 }
102
103 let returning = params.get("returning").and_then(|v| v.as_str());
104
105 let column_count = columns.len();
106 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
107
108 if column_names.len() != column_count {
109 return Err(MCPError::InvalidParams("All column names must be strings".into()));
110 }
111
112 validate_table_columns(table, &column_names)?;
113
114 let quoted_table = quote_identifier(table);
115 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_identifier(c)).collect();
116 let cols = quoted_cols.join(", ");
117
118 let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
119 use std::fmt::Write;
120 write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
121
122 for (i, row) in rows.iter().enumerate() {
123 let row_array = row.as_array().ok_or_else(|| {
124 MCPError::InvalidParams("Each row must be an array".into())
125 })?;
126
127 if row_array.len() != column_count {
128 return Err(MCPError::InvalidParams(
129 format!("Row {} has {} columns, expected {}", i, row_array.len(), column_count),
130 ));
131 }
132
133 if i > 0 {
134 sql.push(',');
135 }
136 sql.push('(');
137 for (j, val) in row_array.iter().enumerate() {
138 if j > 0 {
139 sql.push_str(", ");
140 }
141 match val {
142 Value::String(s) => {
143 sql.push('\'');
144 for ch in s.chars() {
145 if ch == '\'' {
146 sql.push_str("''");
147 } else {
148 sql.push(ch);
149 }
150 }
151 sql.push('\'');
152 }
153 Value::Number(n) => {
154 write!(sql, "{n}").unwrap();
155 }
156 Value::Bool(b) => {
157 sql.push_str(if *b { "true" } else { "false" });
158 }
159 Value::Null => {
160 sql.push_str("NULL");
161 }
162 Value::Array(_) | Value::Object(_) => {
163 let s = val.to_string();
164 sql.push('\'');
165 for ch in s.chars() {
166 if ch == '\'' {
167 sql.push_str("''");
168 } else {
169 sql.push(ch);
170 }
171 }
172 sql.push('\'');
173 }
174 }
175 }
176 sql.push(')');
177 }
178
179 client.execute("BEGIN", &[]).await?;
180 client.execute("SET LOCAL synchronous_commit = OFF", &[]).await?;
181
182 let result = if let Some(col) = returning {
183 validate_identifier(col, "returning")?;
184 let r = format!(" RETURNING {}", quote_identifier(col));
185 sql.push_str(&r);
186 match client.query(&sql, &[]).await {
187 Ok(rows) => {
188 client.execute("COMMIT", &[]).await?;
189 let ids: Vec<Value> = rows.iter().map(|r| {
190 if let Ok(id) = r.try_get::<_, i64>(0) {
191 json!(id)
192 } else if let Ok(id) = r.try_get::<_, i32>(0) {
193 json!(id)
194 } else {
195 json!(null)
196 }
197 }).collect();
198 json!({ "rows_affected": ids.len(), "inserted_ids": ids })
199 }
200 Err(e) => {
201 client.execute("ROLLBACK", &[]).await.ok();
202 return Err(MCPError::DatabaseError(e));
203 }
204 }
205 } else {
206 match client.execute(&sql, &[]).await {
207 Ok(rows_affected) => {
208 client.execute("COMMIT", &[]).await?;
209 json!({ "rows_affected": rows_affected })
210 }
211 Err(e) => {
212 client.execute("ROLLBACK", &[]).await.ok();
213 return Err(MCPError::DatabaseError(e));
214 }
215 }
216 };
217
218 Ok(result)
219}
220
221pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
223 let params = params.as_ref().ok_or_else(|| {
224 MCPError::InvalidParams("Missing parameters".into())
225 })?;
226
227 let table = params
228 .get("table")
229 .and_then(|v| v.as_str())
230 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
231
232 let updates = params
233 .get("updates")
234 .and_then(|v| v.as_object())
235 .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
236
237 let where_clauses = params
238 .get("where_clauses")
239 .and_then(|v| v.as_array())
240 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
241
242 validate_identifier(table, "table")?;
243 let parsed_where = validate_where_clauses(where_clauses)?;
244
245 let quoted_table = quote_identifier(table);
246 let mut set_clauses = Vec::new();
247 for (key, val) in updates {
248 validate_identifier(key, "updates key")?;
249 set_clauses.push(format!("{} = {}", quote_identifier(key), format_sql_value(val)));
250 }
251
252 let where_sql = build_where_sql(&parsed_where);
253 let sql = format!("UPDATE {quoted_table} SET {} WHERE {where_sql}", set_clauses.join(", "));
254
255 let rows_affected = client.execute(&sql, &[]).await?;
256
257 Ok(json!({ "rows_affected": rows_affected }))
258}
259
260pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
262 let params = params.as_ref().ok_or_else(|| {
263 MCPError::InvalidParams("Missing parameters".into())
264 })?;
265
266 let table = params
267 .get("table")
268 .and_then(|v| v.as_str())
269 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
270
271 let where_clauses = params
272 .get("where_clauses")
273 .and_then(|v| v.as_array())
274 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
275
276 validate_identifier(table, "table")?;
277 let parsed_where = validate_where_clauses(where_clauses)?;
278
279 let returning = params.get("returning").and_then(|v| v.as_str());
280
281 let quoted_table = quote_identifier(table);
282 let where_sql = build_where_sql(&parsed_where);
283 let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
284
285 if let Some(col) = returning {
286 validate_identifier(col, "returning")?;
287 sql.push_str(&format!(" RETURNING {}", quote_identifier(col)));
288 let rows = client.query(&sql, &[]).await?;
289 let ids: Vec<Value> = rows.iter().map(|r| {
290 if let Ok(id) = r.try_get::<_, i64>(0) {
291 json!(id)
292 } else if let Ok(id) = r.try_get::<_, i32>(0) {
293 json!(id)
294 } else {
295 json!(null)
296 }
297 }).collect();
298 Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
299 } else {
300 let rows_affected = client.execute(&sql, &[]).await?;
301 Ok(json!({ "rows_affected": rows_affected }))
302 }
303}
304
305pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
307 let params = params.as_ref().ok_or_else(|| {
308 MCPError::InvalidParams("Missing parameters".into())
309 })?;
310
311 let table = params
312 .get("table")
313 .and_then(|v| v.as_str())
314 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
315
316 let columns = params
317 .get("columns")
318 .and_then(|v| v.as_array())
319 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
320
321 let rows = params
322 .get("rows")
323 .and_then(|v| v.as_array())
324 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
325
326 let batch_size = params
327 .get("batch_size")
328 .and_then(|v| v.as_u64())
329 .unwrap_or(1000) as usize;
330
331 if rows.is_empty() {
332 return Ok(json!({"rows_affected": 0}));
333 }
334
335 if rows.len() > MAX_BATCH_ROWS {
336 return Err(MCPError::InvalidParams(
337 format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
338 ));
339 }
340
341 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
342 validate_table_columns(table, &column_names)?;
343
344 let quoted_table = quote_identifier(table);
345 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_identifier(c)).collect();
346
347 let mut total_affected = 0u64;
348
349 for batch in rows.chunks(batch_size) {
350 let mut sql = format!("INSERT INTO {quoted_table} ({}) VALUES ", quoted_cols.join(", "));
351 let mut value_parts = Vec::new();
352
353 for row in batch {
354 let row_array = row.as_array().ok_or_else(|| {
355 MCPError::InvalidParams("Each row must be an array".into())
356 })?;
357
358 let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
359 value_parts.push(format!("({})", row_values.join(", ")));
360 }
361
362 sql.push_str(&value_parts.join(", "));
363
364 let rows_affected = client.execute(&sql, &[]).await?;
365 total_affected += rows_affected;
366 }
367
368 Ok(json!({
369 "rows_affected": total_affected,
370 "batches": (rows.len() as f64 / batch_size as f64).ceil() as u32
371 }))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    // These tests exercise SQL-text generation and parameter validation only; none of them
    // need a database connection.

    #[test]
    fn test_format_sql_value() {
        let cases = [
            (Value::String("test".into()), "'test'"),
            (Value::Number(123.into()), "123"),
            (Value::Bool(true), "true"),
            (Value::Null, "NULL"),
        ];
        for (input, expected) in &cases {
            assert_eq!(format_sql_value(input), *expected);
        }
    }

    #[test]
    fn test_sql_injection_prevention() {
        // The embedded quote must be doubled so it cannot close the literal early.
        let payload = Value::String("'; DROP TABLE users; --".into());
        let rendered = format_sql_value(&payload);
        assert_eq!(rendered, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let err = validate_table_columns("users; DROP TABLE", &["id"])
            .expect_err("a table name containing SQL must be rejected");
        assert!(err.to_string().contains("invalid character"));
    }

    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        let outcome = validate_table_columns("users", &["id; DROP TABLE users"]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_table_columns_accepts_valid() {
        let outcome = validate_table_columns("users", &["id", "name"]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let conditions = [
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        assert!(validate_where_clauses(&conditions).is_ok());
    }

    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let conditions = [json!({"column": "id", "op": "EXECUTE", "value": "malicious"})];
        let message = validate_where_clauses(&conditions).unwrap_err().to_string();
        assert!(message.contains("Invalid operator"));
    }

    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let conditions = [json!({"column": "id; DROP TABLE", "op": "=", "value": 1})];
        assert!(validate_where_clauses(&conditions).is_err());
    }

    #[test]
    fn test_build_where_sql() {
        let one = json!(1);
        let active = json!("active");
        let conditions = [
            ("id".to_string(), "=".to_string(), &one),
            ("status".to_string(), "=".to_string(), &active),
        ];
        assert_eq!(
            build_where_sql(&conditions),
            r#""id" = 1 OR "status" = 'active'"#
        );
    }

    #[test]
    fn test_build_where_sql_in_op() {
        let choices = json!(["a", "b"]);
        let conditions = [("status".to_string(), "IN".to_string(), &choices)];
        assert_eq!(build_where_sql(&conditions), r#""status" IN ('a', 'b')"#);
    }
}