Skip to main content

mcp_postgres/
validation.rs

//! Input validation for tool parameters.
//!
//! Validates tool arguments against per-tool rules and produces helpful error messages.
3use serde_json::Value;
4
/// A single problem found while validating a tool's arguments.
///
/// The `Display` output is meant for end users: it names the tool and parameter,
/// states what is wrong, and offers a concrete suggestion on a second line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the tool whose arguments were validated.
    pub tool: String,
    /// Offending parameter; may carry an index such as `rows[3]` or `columns[0]`.
    pub param: String,
    /// Description of what is wrong.
    pub error: String,
    /// Hint on how to fix the problem.
    pub suggestion: String,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "❌ Validation Error in tool '{}' parameter '{}': {}\n💡 Suggestion: {}",
            self.tool, self.param, self.error, self.suggestion
        )
    }
}

// Lets callers use `?`, `Box<dyn Error>` and error-reporting crates with this type.
impl std::error::Error for ValidationError {}
22
23pub fn validate_tool_input(tool_name: &str, arguments: &Value) -> Result<(), Vec<ValidationError>> {
24    let mut errors = Vec::new();
25
26    match tool_name {
27        // Schema Inspection Tools
28        "list_tables" => {
29            // No required parameters
30        }
31        "describe_table" => {
32            if let Some(table) = arguments.get("table") {
33                if !table.is_string() {
34                    errors.push(ValidationError {
35                        tool: tool_name.to_string(),
36                        param: "table".to_string(),
37                        error: format!("Expected string, got {}", table.type_str()),
38                        suggestion: "Example: {\"table\": \"users\"} or {\"table\": \"public.orders\"}".to_string(),
39                    });
40                } else {
41                    let table_str = table.as_str().unwrap();
42                    if table_str.is_empty() {
43                        errors.push(ValidationError {
44                            tool: tool_name.to_string(),
45                            param: "table".to_string(),
46                            error: "Table name cannot be empty".to_string(),
47                            suggestion: "Provide a valid table name like 'users' or 'public.products'".to_string(),
48                        });
49                    }
50                    if table_str.len() > 255 {
51                        errors.push(ValidationError {
52                            tool: tool_name.to_string(),
53                            param: "table".to_string(),
54                            error: format!("Table name too long: {} characters (max 255)", table_str.len()),
55                            suggestion: "Use a shorter table name".to_string(),
56                        });
57                    }
58                }
59            } else {
60                errors.push(ValidationError {
61                    tool: tool_name.to_string(),
62                    param: "table".to_string(),
63                    error: "Required parameter missing".to_string(),
64                    suggestion: "Include 'table' parameter: {\"table\": \"users\"}".to_string(),
65                });
66            }
67        }
68
69        // Query Execution Tools
70        "execute_query" | "execute_insert" | "execute_update" | "execute_delete" => {
71            if let Some(sql) = arguments.get("sql") {
72                if !sql.is_string() {
73                    errors.push(ValidationError {
74                        tool: tool_name.to_string(),
75                        param: "sql".to_string(),
76                        error: format!("Expected string SQL, got {}", sql.type_str()),
77                        suggestion: "Example: {\"sql\": \"SELECT * FROM users LIMIT 10\"}".to_string(),
78                    });
79                } else {
80                    let sql_str = sql.as_str().unwrap();
81                    if sql_str.is_empty() {
82                        errors.push(ValidationError {
83                            tool: tool_name.to_string(),
84                            param: "sql".to_string(),
85                            error: "SQL statement cannot be empty".to_string(),
86                            suggestion: "Provide a valid SQL statement".to_string(),
87                        });
88                    }
89                    if sql_str.len() > 10000 {
90                        errors.push(ValidationError {
91                            tool: tool_name.to_string(),
92                            param: "sql".to_string(),
93                            error: format!("SQL too long: {} characters (max 10,000)", sql_str.len()),
94                            suggestion: "Break the query into smaller parts or use a subquery".to_string(),
95                        });
96                    }
97
98                    // Validate SQL type for specific tools
99                    let sql_upper = sql_str.trim().to_uppercase();
100                    match tool_name {
101                        "execute_query" => {
102                            if !sql_upper.starts_with("SELECT") && !sql_upper.starts_with("WITH") {
103                                errors.push(ValidationError {
104                                    tool: tool_name.to_string(),
105                                    param: "sql".to_string(),
106                                    error: "execute_query requires a SELECT statement".to_string(),
107                                    suggestion: "Use 'execute_query' only for SELECT queries. Use 'execute_insert', 'execute_update', or 'execute_delete' for modifications.".to_string(),
108                                });
109                            }
110                        }
111                        "execute_insert" => {
112                            if !sql_upper.starts_with("INSERT") {
113                                errors.push(ValidationError {
114                                    tool: tool_name.to_string(),
115                                    param: "sql".to_string(),
116                                    error: "execute_insert requires an INSERT statement".to_string(),
117                                    suggestion: "Example: {\"sql\": \"INSERT INTO users (email) VALUES ('user@example.com')\"}".to_string(),
118                                });
119                            }
120                        }
121                        "execute_update" => {
122                            if !sql_upper.starts_with("UPDATE") {
123                                errors.push(ValidationError {
124                                    tool: tool_name.to_string(),
125                                    param: "sql".to_string(),
126                                    error: "execute_update requires an UPDATE statement".to_string(),
127                                    suggestion: "Example: {\"sql\": \"UPDATE users SET status = 'active' WHERE id = 1\"}".to_string(),
128                                });
129                            }
130                            if !sql_str.contains("WHERE") && !sql_str.contains("where") {
131                                errors.push(ValidationError {
132                                    tool: tool_name.to_string(),
133                                    param: "sql".to_string(),
134                                    error: "UPDATE without WHERE clause will modify all rows".to_string(),
135                                    suggestion: "Add a WHERE clause: UPDATE users SET ... WHERE <condition>".to_string(),
136                                });
137                            }
138                        }
139                        "execute_delete" => {
140                            if !sql_upper.starts_with("DELETE") {
141                                errors.push(ValidationError {
142                                    tool: tool_name.to_string(),
143                                    param: "sql".to_string(),
144                                    error: "execute_delete requires a DELETE statement".to_string(),
145                                    suggestion: "Example: {\"sql\": \"DELETE FROM users WHERE id = 999\"}".to_string(),
146                                });
147                            }
148                            if !sql_str.contains("WHERE") && !sql_str.contains("where") {
149                                errors.push(ValidationError {
150                                    tool: tool_name.to_string(),
151                                    param: "sql".to_string(),
152                                    error: "DELETE without WHERE clause will delete all rows".to_string(),
153                                    suggestion: "Add a WHERE clause: DELETE FROM users WHERE <condition>".to_string(),
154                                });
155                            }
156                        }
157                        _ => {}
158                    }
159                }
160            } else {
161                errors.push(ValidationError {
162                    tool: tool_name.to_string(),
163                    param: "sql".to_string(),
164                    error: "Required parameter 'sql' missing".to_string(),
165                    suggestion: format!("Include SQL: {{\"sql\": \"<{} statement>\"}}", tool_name.split('_').nth(1).unwrap_or("SQL")),
166                });
167            }
168        }
169
170        // Batch Insert Tools
171        "batch_insert" | "batch_insert_copy" => {
172            validate_batch_insert(tool_name, arguments, &mut errors);
173        }
174
175        // Explain Query Tool
176        "explain_query" => {
177            if let Some(sql) = arguments.get("sql") {
178                if !sql.is_string() {
179                    errors.push(ValidationError {
180                        tool: tool_name.to_string(),
181                        param: "sql".to_string(),
182                        error: format!("Expected string, got {}", sql.type_str()),
183                        suggestion: "Example: {\"sql\": \"SELECT * FROM users\"}".to_string(),
184                    });
185                } else if sql.as_str().unwrap().is_empty() {
186                    errors.push(ValidationError {
187                        tool: tool_name.to_string(),
188                        param: "sql".to_string(),
189                        error: "SQL cannot be empty".to_string(),
190                        suggestion: "Provide a SELECT query to explain".to_string(),
191                    });
192                }
193            } else {
194                errors.push(ValidationError {
195                    tool: tool_name.to_string(),
196                    param: "sql".to_string(),
197                    error: "Required parameter 'sql' missing".to_string(),
198                    suggestion: "Include SQL: {\"sql\": \"SELECT * FROM users\"}".to_string(),
199                });
200            }
201
202            if let Some(format) = arguments.get("format") {
203                if let Some(fmt) = format.as_str() {
204                    if !["json", "text", "xml", "yaml"].contains(&fmt) {
205                        errors.push(ValidationError {
206                            tool: tool_name.to_string(),
207                            param: "format".to_string(),
208                            error: format!("Invalid format '{}' (must be json, text, xml, or yaml)", fmt),
209                            suggestion: "Use one of: json (default), text, xml, yaml".to_string(),
210                        });
211                    }
212                }
213            }
214        }
215
216        // Configuration Tool
217        "get_setting" => {
218            if let Some(setting) = arguments.get("setting_name") {
219                if !setting.is_string() {
220                    errors.push(ValidationError {
221                        tool: tool_name.to_string(),
222                        param: "setting_name".to_string(),
223                        error: format!("Expected string, got {}", setting.type_str()),
224                        suggestion: "Example: {\"setting_name\": \"max_connections\"}".to_string(),
225                    });
226                } else if setting.as_str().unwrap().is_empty() {
227                    errors.push(ValidationError {
228                        tool: tool_name.to_string(),
229                        param: "setting_name".to_string(),
230                        error: "Setting name cannot be empty".to_string(),
231                        suggestion: "Examples: max_connections, shared_buffers, work_mem, effective_cache_size".to_string(),
232                    });
233                }
234            } else {
235                errors.push(ValidationError {
236                    tool: tool_name.to_string(),
237                    param: "setting_name".to_string(),
238                    error: "Required parameter 'setting_name' missing".to_string(),
239                    suggestion: "Include setting: {\"setting_name\": \"max_connections\"}".to_string(),
240                });
241            }
242        }
243
244        // Object Details Tool
245        "get_object_details" => {
246            if let Some(table) = arguments.get("table") {
247                if !table.is_string() {
248                    errors.push(ValidationError {
249                        tool: tool_name.to_string(),
250                        param: "table".to_string(),
251                        error: format!("Expected string, got {}", table.type_str()),
252                        suggestion: "Example: {\"table\": \"users\"}".to_string(),
253                    });
254                } else if table.as_str().unwrap().is_empty() {
255                    errors.push(ValidationError {
256                        tool: tool_name.to_string(),
257                        param: "table".to_string(),
258                        error: "Table name cannot be empty".to_string(),
259                        suggestion: "Provide a valid table name".to_string(),
260                    });
261                }
262            } else {
263                errors.push(ValidationError {
264                    tool: tool_name.to_string(),
265                    param: "table".to_string(),
266                    error: "Required parameter 'table' missing".to_string(),
267                    suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
268                });
269            }
270
271            if let Some(schema) = arguments.get("schema") {
272                if !schema.is_string() {
273                    errors.push(ValidationError {
274                        tool: tool_name.to_string(),
275                        param: "schema".to_string(),
276                        error: format!("Expected string, got {}", schema.type_str()),
277                        suggestion: "Example: {\"schema\": \"public\"}".to_string(),
278                    });
279                }
280            }
281        }
282
283        _ => {
284            // Unknown tool - no specific validation
285        }
286    }
287
288    if errors.is_empty() {
289        Ok(())
290    } else {
291        Err(errors)
292    }
293}
294
295fn validate_batch_insert(tool_name: &str, arguments: &Value, errors: &mut Vec<ValidationError>) {
296    // Validate table
297    if let Some(table) = arguments.get("table") {
298        if !table.is_string() {
299            errors.push(ValidationError {
300                tool: tool_name.to_string(),
301                param: "table".to_string(),
302                error: format!("Expected string, got {}", table.type_str()),
303                suggestion: "Example: {\"table\": \"users\"}".to_string(),
304            });
305        } else if table.as_str().unwrap().is_empty() {
306            errors.push(ValidationError {
307                tool: tool_name.to_string(),
308                param: "table".to_string(),
309                error: "Table name cannot be empty".to_string(),
310                suggestion: "Provide a valid table name".to_string(),
311            });
312        }
313    } else {
314        errors.push(ValidationError {
315            tool: tool_name.to_string(),
316            param: "table".to_string(),
317            error: "Required parameter 'table' missing".to_string(),
318            suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
319        });
320    }
321
322    // Validate columns
323    if let Some(columns) = arguments.get("columns") {
324        if !columns.is_array() {
325            errors.push(ValidationError {
326                tool: tool_name.to_string(),
327                param: "columns".to_string(),
328                error: format!("Expected array, got {}", columns.type_str()),
329                suggestion: "Example: {\"columns\": [\"email\", \"name\", \"created_at\"]}".to_string(),
330            });
331        } else {
332            let cols = columns.as_array().unwrap();
333            if cols.is_empty() {
334                errors.push(ValidationError {
335                    tool: tool_name.to_string(),
336                    param: "columns".to_string(),
337                    error: "Columns array cannot be empty".to_string(),
338                    suggestion: "Provide at least one column name".to_string(),
339                });
340            }
341            for (i, col) in cols.iter().enumerate() {
342                if !col.is_string() {
343                    errors.push(ValidationError {
344                        tool: tool_name.to_string(),
345                        param: format!("columns[{}]", i),
346                        error: format!("Expected string column name, got {}", col.type_str()),
347                        suggestion: "Column names must be strings".to_string(),
348                    });
349                }
350            }
351        }
352    } else {
353        errors.push(ValidationError {
354            tool: tool_name.to_string(),
355            param: "columns".to_string(),
356            error: "Required parameter 'columns' missing".to_string(),
357            suggestion: "Include column names: {\"columns\": [\"email\", \"name\"]}".to_string(),
358        });
359    }
360
361    // Validate rows
362    if let Some(rows) = arguments.get("rows") {
363        if !rows.is_array() {
364            errors.push(ValidationError {
365                tool: tool_name.to_string(),
366                param: "rows".to_string(),
367                error: format!("Expected array of arrays, got {}", rows.type_str()),
368                suggestion: "Example: {\"rows\": [[\"user@test.com\", \"John\"], [\"jane@test.com\", \"Jane\"]]}".to_string(),
369            });
370        } else {
371            let rows_arr = rows.as_array().unwrap();
372            if rows_arr.is_empty() {
373                errors.push(ValidationError {
374                    tool: tool_name.to_string(),
375                    param: "rows".to_string(),
376                    error: "Rows array cannot be empty".to_string(),
377                    suggestion: "Provide at least one row of data".to_string(),
378                });
379            } else {
380                let max_rows = if tool_name == "batch_insert" { 1000 } else { 100000 };
381                if rows_arr.len() > max_rows {
382                    errors.push(ValidationError {
383                        tool: tool_name.to_string(),
384                        param: "rows".to_string(),
385                        error: format!("Too many rows: {} (max {})", rows_arr.len(), max_rows),
386                        suggestion: format!("Split into multiple calls with max {} rows each", max_rows),
387                    });
388                }
389
390                // Validate each row is an array
391                for (i, row) in rows_arr.iter().enumerate() {
392                    if !row.is_array() {
393                        errors.push(ValidationError {
394                            tool: tool_name.to_string(),
395                            param: format!("rows[{}]", i),
396                            error: format!("Row must be array, got {}", row.type_str()),
397                            suggestion: "Each row must be an array of values".to_string(),
398                        });
399                    }
400                }
401            }
402        }
403    } else {
404        errors.push(ValidationError {
405            tool: tool_name.to_string(),
406            param: "rows".to_string(),
407            error: "Required parameter 'rows' missing".to_string(),
408            suggestion: "Include rows: {\"rows\": [[\"value1\", \"value2\"], ...]}".to_string(),
409        });
410    }
411
412    // Validate batch_size for batch_insert_copy
413    if tool_name == "batch_insert_copy" {
414        if let Some(batch_size) = arguments.get("batch_size") {
415            if !batch_size.is_number() {
416                errors.push(ValidationError {
417                    tool: tool_name.to_string(),
418                    param: "batch_size".to_string(),
419                    error: format!("Expected integer, got {}", batch_size.type_str()),
420                    suggestion: "Example: {\"batch_size\": 1000}".to_string(),
421                });
422            } else if let Some(size) = batch_size.as_i64() {
423                if !((100..=5000).contains(&size)) {
424                    errors.push(ValidationError {
425                        tool: tool_name.to_string(),
426                        param: "batch_size".to_string(),
427                        error: format!("Batch size {} out of range (must be 100-5000)", size),
428                        suggestion: "Use default (1000) or set between 100 and 5000".to_string(),
429                    });
430                }
431            }
432        }
433    }
434}
435
436/// Helper trait to get type name from Value
437trait ValueType {
438    fn type_str(&self) -> &str;
439}
440
441impl ValueType for Value {
442    fn type_str(&self) -> &str {
443        match self {
444            Value::Null => "null",
445            Value::Bool(_) => "boolean",
446            Value::Number(_) => "number",
447            Value::String(_) => "string",
448            Value::Array(_) => "array",
449            Value::Object(_) => "object",
450        }
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the validation errors, failing the test if validation unexpectedly passed.
    fn errors_for(tool: &str, args: &Value) -> Vec<ValidationError> {
        validate_tool_input(tool, args).expect_err("expected validation to fail")
    }

    #[test]
    fn test_missing_required_param() {
        let errors = errors_for("describe_table", &json!({}));
        assert!(errors.iter().any(|e| e.param == "table"));
    }

    #[test]
    fn test_invalid_type() {
        let errors = errors_for("describe_table", &json!({"table": 123}));
        assert!(errors.iter().any(|e| e.param == "table" && e.error.contains("number")));
    }

    #[test]
    fn test_valid_input() {
        assert!(validate_tool_input("describe_table", &json!({"table": "users"})).is_ok());
    }

    #[test]
    fn test_execute_query_rejects_modification() {
        let args = json!({"sql": "INSERT INTO users (email) VALUES ('a@b.c')"});
        let errors = errors_for("execute_query", &args);
        assert!(errors.iter().any(|e| e.param == "sql" && e.error.contains("SELECT")));
    }

    #[test]
    fn test_delete_without_where_is_rejected() {
        let errors = errors_for("execute_delete", &json!({"sql": "DELETE FROM users"}));
        assert!(errors.iter().any(|e| e.error.contains("without WHERE")));
    }

    #[test]
    fn test_update_with_where_passes() {
        let args = json!({"sql": "UPDATE users SET status = 'active' WHERE id = 1"});
        assert!(validate_tool_input("execute_update", &args).is_ok());
    }

    #[test]
    fn test_batch_insert_valid() {
        let args = json!({
            "table": "users",
            "columns": ["email", "name"],
            "rows": [["a@b.c", "A"], ["d@e.f", "D"]]
        });
        assert!(validate_tool_input("batch_insert", &args).is_ok());
    }

    #[test]
    fn test_batch_insert_row_limit() {
        let rows = vec![json!(["a@b.c"]); 1001];
        let args = json!({"table": "users", "columns": ["email"], "rows": rows});
        let errors = errors_for("batch_insert", &args);
        assert!(errors.iter().any(|e| e.param == "rows" && e.error.contains("Too many rows")));
    }

    #[test]
    fn test_batch_size_out_of_range() {
        let args = json!({
            "table": "users",
            "columns": ["email"],
            "rows": [["a@b.c"]],
            "batch_size": 50
        });
        let errors = errors_for("batch_insert_copy", &args);
        assert!(errors.iter().any(|e| e.param == "batch_size"));
    }

    #[test]
    fn test_explain_query_invalid_format() {
        let args = json!({"sql": "SELECT 1", "format": "csv"});
        let errors = errors_for("explain_query", &args);
        assert!(errors.iter().any(|e| e.param == "format"));
    }

    #[test]
    fn test_unknown_tool_passes() {
        assert!(validate_tool_input("no_such_tool", &json!(null)).is_ok());
    }

    #[test]
    fn test_error_display_names_tool_and_param() {
        let errors = errors_for("describe_table", &json!({}));
        let text = errors[0].to_string();
        assert!(text.contains("Validation Error in tool 'describe_table' parameter 'table'"));
        assert!(text.contains("Suggestion: "));
    }
}