1use serde_json::{json, Value};
2use tokio_postgres::Client;
3use crate::errors::{MCPError, Result as MCPResult};
4use crate::validation::{validate_identifier, quote_ident};
5
/// Largest number of rows the batch insert entry points accept in one request.
const MAX_BATCH_ROWS: usize = 1_000;

/// Operators a structured `where_clauses` entry is allowed to use.
const ALLOWED_OPS: &[&str] = &[
    "=", "<", ">", "<=", ">=", "<>", "IN", "LIKE",
];
8
9fn format_sql_value(val: &Value) -> String {
10 match val {
11 Value::String(s) => format!("'{}'", s.replace("'", "''")),
12 Value::Number(n) => n.to_string(),
13 Value::Bool(b) => if *b { "true" } else { "false" }.to_string(),
14 Value::Null => "NULL".to_string(),
15 Value::Array(_) | Value::Object(_) => format!("'{}'", val.to_string().replace("'", "''")),
16 }
17}
18
19fn validate_table_columns(table: &str, columns: &[&str]) -> Result<(), MCPError> {
20 validate_identifier(table, "table")?;
21 for col in columns {
22 validate_identifier(col, "column")?;
23 }
24 Ok(())
25}
26
27fn validate_where_clauses(where_clauses: &[Value]) -> Result<Vec<(String, String, &Value)>, MCPError> {
28 if where_clauses.is_empty() {
29 return Err(MCPError::InvalidParams("'where_clauses' must not be empty".into()));
30 }
31 let mut parsed = Vec::new();
32 for clause in where_clauses {
33 let obj = clause.as_object().ok_or_else(|| {
34 MCPError::InvalidParams("Each where_clause must be an object with 'column', 'op', and 'value'".into())
35 })?;
36 let column = obj.get("column").and_then(|v| v.as_str()).ok_or_else(|| {
37 MCPError::InvalidParams("Each where_clause must have a string 'column'".into())
38 })?;
39 let op = obj.get("op").and_then(|v| v.as_str()).ok_or_else(|| {
40 MCPError::InvalidParams("Each where_clause must have a string 'op'".into())
41 })?;
42 let value = obj.get("value").ok_or_else(|| {
43 MCPError::InvalidParams("Each where_clause must have a 'value'".into())
44 })?;
45 validate_identifier(column, "where_clause.column")?;
46 if !ALLOWED_OPS.contains(&op) {
47 return Err(MCPError::InvalidParams(
48 format!("Invalid operator '{op}' — allowed: {}", ALLOWED_OPS.join(", "))
49 ));
50 }
51 parsed.push((column.to_string(), op.to_string(), value));
52 }
53 Ok(parsed)
54}
55
56fn build_where_sql(parsed: &[(String, String, &Value)]) -> String {
57 parsed.iter().map(|(col, op, val)| {
58 if op == "IN" {
59 if let Some(arr) = val.as_array() {
60 let items: Vec<String> = arr.iter().map(format_sql_value).collect();
61 format!("{} IN ({})", quote_ident(col), items.join(", "))
62 } else {
63 format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
64 }
65 } else {
66 format!("{} {} {}", quote_ident(col), op, format_sql_value(val))
67 }
68 }).collect::<Vec<_>>().join(" OR ")
69}
70
71pub async fn async_batch_insert(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
74 let params = params.as_ref().ok_or_else(|| {
75 MCPError::InvalidParams("Missing parameters".into())
76 })?;
77
78 let table = params
79 .get("table")
80 .and_then(|v| v.as_str())
81 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
82
83 let columns = params
84 .get("columns")
85 .and_then(|v| v.as_array())
86 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
87
88 let rows = params
89 .get("rows")
90 .and_then(|v| v.as_array())
91 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
92
93 if rows.is_empty() {
94 return Ok(json!({ "rows_affected": 0 }));
95 }
96
97 if rows.len() > MAX_BATCH_ROWS {
98 return Err(MCPError::InvalidParams(
99 format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
100 ));
101 }
102
103 let returning = params.get("returning").and_then(|v| v.as_str());
104
105 let column_count = columns.len();
106 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
107
108 if column_names.len() != column_count {
109 return Err(MCPError::InvalidParams("All column names must be strings".into()));
110 }
111
112 validate_table_columns(table, &column_names)?;
113
114 let quoted_table = quote_ident(table);
115 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
116 let cols = quoted_cols.join(", ");
117
118 let mut sql = String::with_capacity(64 + cols.len() + rows.len() * (column_count * 16 + 4));
119 use std::fmt::Write;
120 write!(sql, "INSERT INTO {quoted_table} ({cols}) VALUES ").unwrap();
121
122 for (i, row) in rows.iter().enumerate() {
123 let row_array = row.as_array().ok_or_else(|| {
124 MCPError::InvalidParams("Each row must be an array".into())
125 })?;
126
127 if row_array.len() != column_count {
128 return Err(MCPError::InvalidParams(
129 format!("Row {} has {} columns, expected {}", i, row_array.len(), column_count),
130 ));
131 }
132
133 if i > 0 {
134 sql.push(',');
135 }
136 sql.push('(');
137 for (j, val) in row_array.iter().enumerate() {
138 if j > 0 {
139 sql.push_str(", ");
140 }
141 match val {
142 Value::String(s) => {
143 sql.push('\'');
144 for ch in s.chars() {
145 if ch == '\'' {
146 sql.push_str("''");
147 } else {
148 sql.push(ch);
149 }
150 }
151 sql.push('\'');
152 }
153 Value::Number(n) => {
154 write!(sql, "{n}").unwrap();
155 }
156 Value::Bool(b) => {
157 sql.push_str(if *b { "true" } else { "false" });
158 }
159 Value::Null => {
160 sql.push_str("NULL");
161 }
162 Value::Array(_) | Value::Object(_) => {
163 let s = val.to_string();
164 sql.push('\'');
165 for ch in s.chars() {
166 if ch == '\'' {
167 sql.push_str("''");
168 } else {
169 sql.push(ch);
170 }
171 }
172 sql.push('\'');
173 }
174 }
175 }
176 sql.push(')');
177 }
178
179 client.execute("BEGIN", &[]).await?;
180 client.execute("SET LOCAL synchronous_commit = OFF", &[]).await?;
181
182 let result = if let Some(col) = returning {
183 validate_identifier(col, "returning")?;
184 let r = format!(" RETURNING {}", quote_ident(col));
185 sql.push_str(&r);
186 match client.query(&sql, &[]).await {
187 Ok(rows) => {
188 client.execute("COMMIT", &[]).await?;
189 let ids: Vec<Value> = rows.iter().map(|r| {
190 r.try_get::<_, i64>(0).map(|id| json!(id))
191 .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
192 .unwrap_or(json!(null))
193 }).collect();
194 json!({ "rows_affected": ids.len(), "inserted_ids": ids })
195 }
196 Err(e) => {
197 client.execute("ROLLBACK", &[]).await.ok();
198 return Err(MCPError::DatabaseError(e));
199 }
200 }
201 } else {
202 match client.execute(&sql, &[]).await {
203 Ok(rows_affected) => {
204 client.execute("COMMIT", &[]).await?;
205 json!({ "rows_affected": rows_affected })
206 }
207 Err(e) => {
208 client.execute("ROLLBACK", &[]).await.ok();
209 return Err(MCPError::DatabaseError(e));
210 }
211 }
212 };
213
214 Ok(result)
215}
216
217pub async fn async_batch_update(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
219 let params = params.as_ref().ok_or_else(|| {
220 MCPError::InvalidParams("Missing parameters".into())
221 })?;
222
223 let table = params
224 .get("table")
225 .and_then(|v| v.as_str())
226 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
227
228 let updates = params
229 .get("updates")
230 .and_then(|v| v.as_object())
231 .ok_or_else(|| MCPError::InvalidParams("Missing 'updates'".into()))?;
232
233 let where_clauses = params
234 .get("where_clauses")
235 .and_then(|v| v.as_array())
236 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
237
238 validate_identifier(table, "table")?;
239 let parsed_where = validate_where_clauses(where_clauses)?;
240
241 let quoted_table = quote_ident(table);
242 let mut set_clauses = Vec::new();
243 for (key, val) in updates {
244 validate_identifier(key, "updates key")?;
245 set_clauses.push(format!("{} = {}", quote_ident(key), format_sql_value(val)));
246 }
247
248 let where_sql = build_where_sql(&parsed_where);
249 let sql = format!("UPDATE {quoted_table} SET {} WHERE {where_sql}", set_clauses.join(", "));
250
251 let rows_affected = client.execute(&sql, &[]).await?;
252
253 Ok(json!({ "rows_affected": rows_affected }))
254}
255
256pub async fn async_batch_delete(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
258 let params = params.as_ref().ok_or_else(|| {
259 MCPError::InvalidParams("Missing parameters".into())
260 })?;
261
262 let table = params
263 .get("table")
264 .and_then(|v| v.as_str())
265 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
266
267 let where_clauses = params
268 .get("where_clauses")
269 .and_then(|v| v.as_array())
270 .ok_or_else(|| MCPError::InvalidParams("Missing 'where_clauses'".into()))?;
271
272 validate_identifier(table, "table")?;
273 let parsed_where = validate_where_clauses(where_clauses)?;
274
275 let returning = params.get("returning").and_then(|v| v.as_str());
276
277 let quoted_table = quote_ident(table);
278 let where_sql = build_where_sql(&parsed_where);
279 let mut sql = format!("DELETE FROM {quoted_table} WHERE {where_sql}");
280
281 if let Some(col) = returning {
282 validate_identifier(col, "returning")?;
283 sql.push_str(&format!(" RETURNING {}", quote_ident(col)));
284 let rows = client.query(&sql, &[]).await?;
285 let ids: Vec<Value> = rows.iter().map(|r| {
286 r.try_get::<_, i64>(0).map(|id| json!(id))
287 .or_else(|_| r.try_get::<_, i32>(0).map(|id| json!(id)))
288 .unwrap_or(json!(null))
289 }).collect();
290 Ok(json!({ "rows_affected": ids.len(), "inserted_ids": ids }))
291 } else {
292 let rows_affected = client.execute(&sql, &[]).await?;
293 Ok(json!({ "rows_affected": rows_affected }))
294 }
295}
296
297pub async fn async_batch_insert_copy(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
299 let params = params.as_ref().ok_or_else(|| {
300 MCPError::InvalidParams("Missing parameters".into())
301 })?;
302
303 let table = params
304 .get("table")
305 .and_then(|v| v.as_str())
306 .ok_or_else(|| MCPError::InvalidParams("Missing 'table'".into()))?;
307
308 let columns = params
309 .get("columns")
310 .and_then(|v| v.as_array())
311 .ok_or_else(|| MCPError::InvalidParams("Missing 'columns'".into()))?;
312
313 let rows = params
314 .get("rows")
315 .and_then(|v| v.as_array())
316 .ok_or_else(|| MCPError::InvalidParams("Missing 'rows'".into()))?;
317
318 let batch_size = params
319 .get("batch_size")
320 .and_then(|v| v.as_u64())
321 .unwrap_or(1000) as usize;
322
323 if rows.is_empty() {
324 return Ok(json!({"rows_affected": 0}));
325 }
326
327 if rows.len() > MAX_BATCH_ROWS {
328 return Err(MCPError::InvalidParams(
329 format!("Batch size exceeds maximum of {MAX_BATCH_ROWS} rows (got {})", rows.len())
330 ));
331 }
332
333 let column_names: Vec<&str> = columns.iter().filter_map(|c| c.as_str()).collect();
334 validate_table_columns(table, &column_names)?;
335
336 let quoted_table = quote_ident(table);
337 let quoted_cols: Vec<String> = column_names.iter().map(|c| quote_ident(c)).collect();
338
339 let mut total_affected = 0u64;
340
341 for batch in rows.chunks(batch_size) {
342 let mut sql = format!("INSERT INTO {quoted_table} ({}) VALUES ", quoted_cols.join(", "));
343 let mut value_parts = Vec::new();
344
345 for row in batch {
346 let row_array = row.as_array().ok_or_else(|| {
347 MCPError::InvalidParams("Each row must be an array".into())
348 })?;
349
350 let row_values: Vec<String> = row_array.iter().map(format_sql_value).collect();
351 value_parts.push(format!("({})", row_values.join(", ")));
352 }
353
354 sql.push_str(&value_parts.join(", "));
355
356 let rows_affected = client.execute(&sql, &[]).await?;
357 total_affected += rows_affected;
358 }
359
360 Ok(json!({
361 "rows_affected": total_affected,
362 "batches": (rows.len() as f64 / batch_size as f64).ceil() as u32
363 }))
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalars render as the matching SQL literal.
    #[test]
    fn test_format_sql_value() {
        let cases = [
            (Value::String("test".into()), "'test'"),
            (Value::Number(123.into()), "123"),
            (Value::Bool(true), "true"),
            (Value::Null, "NULL"),
        ];
        for (input, expected) in &cases {
            assert_eq!(format_sql_value(input), *expected);
        }
    }

    /// Embedded single quotes are doubled so the literal cannot be closed early.
    #[test]
    fn test_sql_injection_prevention() {
        let rendered = format_sql_value(&Value::String("'; DROP TABLE users; --".into()));
        assert_eq!(rendered, "'''; DROP TABLE users; --'");
    }

    /// A table name carrying SQL is rejected with an "invalid character" error.
    #[test]
    fn test_validate_table_columns_rejects_injection() {
        let err = validate_table_columns("users; DROP TABLE", &["id"])
            .expect_err("table name containing SQL must be rejected");
        assert!(err.to_string().contains("invalid character"));
    }

    /// A column name carrying SQL is rejected.
    #[test]
    fn test_validate_table_columns_rejects_sql_in_column() {
        let outcome = validate_table_columns("users", &["id; DROP TABLE users"]);
        assert!(outcome.is_err());
    }

    /// Plain identifiers pass validation.
    #[test]
    fn test_validate_table_columns_accepts_valid() {
        let outcome = validate_table_columns("users", &["id", "name"]);
        assert!(outcome.is_ok());
    }

    /// Well-formed structured clauses are accepted.
    #[test]
    fn test_validate_where_clauses_accepts_structured() {
        let input = vec![
            json!({"column": "id", "op": "=", "value": 1}),
            json!({"column": "status", "op": "IN", "value": ["active", "pending"]}),
        ];
        assert!(validate_where_clauses(&input).is_ok());
    }

    /// Operators outside the allow-list are rejected with a descriptive error.
    #[test]
    fn test_validate_where_clauses_rejects_invalid_op() {
        let input = vec![json!({"column": "id", "op": "EXECUTE", "value": "malicious"})];
        let err = validate_where_clauses(&input)
            .expect_err("operator outside the allow-list must be rejected");
        assert!(err.to_string().contains("Invalid operator"));
    }

    /// A clause whose column carries SQL is rejected.
    #[test]
    fn test_validate_where_clauses_rejects_sql_in_column() {
        let input = vec![json!({"column": "id; DROP TABLE", "op": "=", "value": 1})];
        assert!(validate_where_clauses(&input).is_err());
    }

    /// Multiple clauses are quoted, rendered, and joined with OR.
    #[test]
    fn test_build_where_sql() {
        let id_value = Value::Number(1.into());
        let status_value = Value::String("active".into());
        let clauses = vec![
            ("id".to_string(), "=".to_string(), &id_value),
            ("status".to_string(), "=".to_string(), &status_value),
        ];
        assert_eq!(build_where_sql(&clauses), r#""id" = 1 OR "status" = 'active'"#);
    }

    /// `IN` with an array renders a parenthesized literal list.
    #[test]
    fn test_build_where_sql_in_op() {
        let candidates = json!(["a", "b"]);
        let clauses = vec![("status".to_string(), "IN".to_string(), &candidates)];
        assert_eq!(build_where_sql(&clauses), r#""status" IN ('a', 'b')"#);
    }
}