// mcp_postgres/validation.rs

//! Input validation for tool parameters.
//!
//! Validates tool arguments against defined schemas and provides helpful error messages.

4use serde_json::Value;
5
/// A single problem found while validating a tool's arguments.
///
/// Rendered through [`std::fmt::Display`] as a two-line, user-facing message:
/// the error itself, followed by a suggestion for fixing it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the tool whose arguments were being validated.
    pub tool: String,
    /// Name of the offending parameter; may carry an index such as `rows[3]`.
    pub param: String,
    /// Description of what is wrong.
    pub error: String,
    /// Hint telling the caller how to fix the problem.
    pub suggestion: String,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "❌ Validation Error in tool '{}' parameter '{}': {}\n💡 Suggestion: {}",
            self.tool, self.param, self.error, self.suggestion
        )
    }
}

/// Allows a `ValidationError` to travel through `Box<dyn Error>` and `?`.
impl std::error::Error for ValidationError {}

24pub fn validate_tool_input(tool_name: &str, arguments: &Value) -> Result<(), Vec<ValidationError>> {
25    let mut errors = Vec::new();
26
27    match tool_name {
28        // Schema Inspection Tools
29        "list_tables" => {
30            // No required parameters
31        }
32        "describe_table" => {
33            if let Some(table) = arguments.get("table") {
34                if !table.is_string() {
35                    errors.push(ValidationError {
36                        tool: tool_name.to_string(),
37                        param: "table".to_string(),
38                        error: "Expected string, got ".to_string() + &table.type_str(),
39                        suggestion: "Example: {\"table\": \"users\"} or {\"table\": \"public.orders\"}".to_string(),
40                    });
41                } else {
42                    let table_str = table.as_str().unwrap();
43                    if table_str.is_empty() {
44                        errors.push(ValidationError {
45                            tool: tool_name.to_string(),
46                            param: "table".to_string(),
47                            error: "Table name cannot be empty".to_string(),
48                            suggestion: "Provide a valid table name like 'users' or 'public.products'".to_string(),
49                        });
50                    }
51                    if table_str.len() > 255 {
52                        errors.push(ValidationError {
53                            tool: tool_name.to_string(),
54                            param: "table".to_string(),
55                            error: format!("Table name too long: {} characters (max 255)", table_str.len()),
56                            suggestion: "Use a shorter table name".to_string(),
57                        });
58                    }
59                }
60            } else {
61                errors.push(ValidationError {
62                    tool: tool_name.to_string(),
63                    param: "table".to_string(),
64                    error: "Required parameter missing".to_string(),
65                    suggestion: "Include 'table' parameter: {\"table\": \"users\"}".to_string(),
66                });
67            }
68        }
69
70        // Query Execution Tools
71        "execute_query" | "execute_insert" | "execute_update" | "execute_delete" => {
72            if let Some(sql) = arguments.get("sql") {
73                if !sql.is_string() {
74                    errors.push(ValidationError {
75                        tool: tool_name.to_string(),
76                        param: "sql".to_string(),
77                        error: format!("Expected string SQL, got {}", sql.type_str()),
78                        suggestion: "Example: {\"sql\": \"SELECT * FROM users LIMIT 10\"}".to_string(),
79                    });
80                } else {
81                    let sql_str = sql.as_str().unwrap();
82                    if sql_str.is_empty() {
83                        errors.push(ValidationError {
84                            tool: tool_name.to_string(),
85                            param: "sql".to_string(),
86                            error: "SQL statement cannot be empty".to_string(),
87                            suggestion: "Provide a valid SQL statement".to_string(),
88                        });
89                    }
90                    if sql_str.len() > 10000 {
91                        errors.push(ValidationError {
92                            tool: tool_name.to_string(),
93                            param: "sql".to_string(),
94                            error: format!("SQL too long: {} characters (max 10,000)", sql_str.len()),
95                            suggestion: "Break the query into smaller parts or use a subquery".to_string(),
96                        });
97                    }
98
99                    // Validate SQL type for specific tools
100                    let sql_upper = sql_str.trim().to_uppercase();
101                    match tool_name {
102                        "execute_query" => {
103                            if !sql_upper.starts_with("SELECT") && !sql_upper.starts_with("WITH") {
104                                errors.push(ValidationError {
105                                    tool: tool_name.to_string(),
106                                    param: "sql".to_string(),
107                                    error: "execute_query requires a SELECT statement".to_string(),
108                                    suggestion: "Use 'execute_query' only for SELECT queries. Use 'execute_insert', 'execute_update', or 'execute_delete' for modifications.".to_string(),
109                                });
110                            }
111                        }
112                        "execute_insert" => {
113                            if !sql_upper.starts_with("INSERT") {
114                                errors.push(ValidationError {
115                                    tool: tool_name.to_string(),
116                                    param: "sql".to_string(),
117                                    error: "execute_insert requires an INSERT statement".to_string(),
118                                    suggestion: "Example: {\"sql\": \"INSERT INTO users (email) VALUES ('user@example.com')\"}".to_string(),
119                                });
120                            }
121                        }
122                        "execute_update" => {
123                            if !sql_upper.starts_with("UPDATE") {
124                                errors.push(ValidationError {
125                                    tool: tool_name.to_string(),
126                                    param: "sql".to_string(),
127                                    error: "execute_update requires an UPDATE statement".to_string(),
128                                    suggestion: "Example: {\"sql\": \"UPDATE users SET status = 'active' WHERE id = 1\"}".to_string(),
129                                });
130                            }
131                            if !sql_str.contains("WHERE") && !sql_str.contains("where") {
132                                errors.push(ValidationError {
133                                    tool: tool_name.to_string(),
134                                    param: "sql".to_string(),
135                                    error: "UPDATE without WHERE clause will modify all rows".to_string(),
136                                    suggestion: "Add a WHERE clause: UPDATE users SET ... WHERE <condition>".to_string(),
137                                });
138                            }
139                        }
140                        "execute_delete" => {
141                            if !sql_upper.starts_with("DELETE") {
142                                errors.push(ValidationError {
143                                    tool: tool_name.to_string(),
144                                    param: "sql".to_string(),
145                                    error: "execute_delete requires a DELETE statement".to_string(),
146                                    suggestion: "Example: {\"sql\": \"DELETE FROM users WHERE id = 999\"}".to_string(),
147                                });
148                            }
149                            if !sql_str.contains("WHERE") && !sql_str.contains("where") {
150                                errors.push(ValidationError {
151                                    tool: tool_name.to_string(),
152                                    param: "sql".to_string(),
153                                    error: "DELETE without WHERE clause will delete all rows".to_string(),
154                                    suggestion: "Add a WHERE clause: DELETE FROM users WHERE <condition>".to_string(),
155                                });
156                            }
157                        }
158                        _ => {}
159                    }
160                }
161            } else {
162                errors.push(ValidationError {
163                    tool: tool_name.to_string(),
164                    param: "sql".to_string(),
165                    error: "Required parameter 'sql' missing".to_string(),
166                    suggestion: format!("Include SQL: {{\"sql\": \"<{} statement>\"}}", tool_name.split('_').nth(1).unwrap_or("SQL")),
167                });
168            }
169        }
170
171        // Batch Insert Tools
172        "batch_insert" | "batch_insert_copy" => {
173            validate_batch_insert(tool_name, arguments, &mut errors);
174        }
175
176        // Explain Query Tool
177        "explain_query" => {
178            if let Some(sql) = arguments.get("sql") {
179                if !sql.is_string() {
180                    errors.push(ValidationError {
181                        tool: tool_name.to_string(),
182                        param: "sql".to_string(),
183                        error: "Expected string, got ".to_string() + &sql.type_str(),
184                        suggestion: "Example: {\"sql\": \"SELECT * FROM users\"}".to_string(),
185                    });
186                } else if sql.as_str().unwrap().is_empty() {
187                    errors.push(ValidationError {
188                        tool: tool_name.to_string(),
189                        param: "sql".to_string(),
190                        error: "SQL cannot be empty".to_string(),
191                        suggestion: "Provide a SELECT query to explain".to_string(),
192                    });
193                }
194            } else {
195                errors.push(ValidationError {
196                    tool: tool_name.to_string(),
197                    param: "sql".to_string(),
198                    error: "Required parameter 'sql' missing".to_string(),
199                    suggestion: "Include SQL: {\"sql\": \"SELECT * FROM users\"}".to_string(),
200                });
201            }
202
203            if let Some(format) = arguments.get("format") {
204                if let Some(fmt) = format.as_str() {
205                    if !["json", "text", "xml", "yaml"].contains(&fmt) {
206                        errors.push(ValidationError {
207                            tool: tool_name.to_string(),
208                            param: "format".to_string(),
209                            error: format!("Invalid format '{}' (must be json, text, xml, or yaml)", fmt),
210                            suggestion: "Use one of: json (default), text, xml, yaml".to_string(),
211                        });
212                    }
213                }
214            }
215        }
216
217        // Configuration Tool
218        "get_setting" => {
219            if let Some(setting) = arguments.get("setting_name") {
220                if !setting.is_string() {
221                    errors.push(ValidationError {
222                        tool: tool_name.to_string(),
223                        param: "setting_name".to_string(),
224                        error: format!("Expected string, got {}", setting.type_str()),
225                        suggestion: "Example: {\"setting_name\": \"max_connections\"}".to_string(),
226                    });
227                } else if setting.as_str().unwrap().is_empty() {
228                    errors.push(ValidationError {
229                        tool: tool_name.to_string(),
230                        param: "setting_name".to_string(),
231                        error: "Setting name cannot be empty".to_string(),
232                        suggestion: "Examples: max_connections, shared_buffers, work_mem, effective_cache_size".to_string(),
233                    });
234                }
235            } else {
236                errors.push(ValidationError {
237                    tool: tool_name.to_string(),
238                    param: "setting_name".to_string(),
239                    error: "Required parameter 'setting_name' missing".to_string(),
240                    suggestion: "Include setting: {\"setting_name\": \"max_connections\"}".to_string(),
241                });
242            }
243        }
244
245        // Object Details Tool
246        "get_object_details" => {
247            if let Some(table) = arguments.get("table") {
248                if !table.is_string() {
249                    errors.push(ValidationError {
250                        tool: tool_name.to_string(),
251                        param: "table".to_string(),
252                        error: format!("Expected string, got {}", table.type_str()),
253                        suggestion: "Example: {\"table\": \"users\"}".to_string(),
254                    });
255                } else if table.as_str().unwrap().is_empty() {
256                    errors.push(ValidationError {
257                        tool: tool_name.to_string(),
258                        param: "table".to_string(),
259                        error: "Table name cannot be empty".to_string(),
260                        suggestion: "Provide a valid table name".to_string(),
261                    });
262                }
263            } else {
264                errors.push(ValidationError {
265                    tool: tool_name.to_string(),
266                    param: "table".to_string(),
267                    error: "Required parameter 'table' missing".to_string(),
268                    suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
269                });
270            }
271
272            if let Some(schema) = arguments.get("schema") {
273                if !schema.is_string() {
274                    errors.push(ValidationError {
275                        tool: tool_name.to_string(),
276                        param: "schema".to_string(),
277                        error: format!("Expected string, got {}", schema.type_str()),
278                        suggestion: "Example: {\"schema\": \"public\"}".to_string(),
279                    });
280                }
281            }
282        }
283
284        _ => {
285            // Unknown tool - no specific validation
286        }
287    }
288
289    if errors.is_empty() {
290        Ok(())
291    } else {
292        Err(errors)
293    }
294}
295
296fn validate_batch_insert(tool_name: &str, arguments: &Value, errors: &mut Vec<ValidationError>) {
297    // Validate table
298    if let Some(table) = arguments.get("table") {
299        if !table.is_string() {
300            errors.push(ValidationError {
301                tool: tool_name.to_string(),
302                param: "table".to_string(),
303                error: format!("Expected string, got {}", table.type_str()),
304                suggestion: "Example: {\"table\": \"users\"}".to_string(),
305            });
306        } else if table.as_str().unwrap().is_empty() {
307            errors.push(ValidationError {
308                tool: tool_name.to_string(),
309                param: "table".to_string(),
310                error: "Table name cannot be empty".to_string(),
311                suggestion: "Provide a valid table name".to_string(),
312            });
313        }
314    } else {
315        errors.push(ValidationError {
316            tool: tool_name.to_string(),
317            param: "table".to_string(),
318            error: "Required parameter 'table' missing".to_string(),
319            suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
320        });
321    }
322
323    // Validate columns
324    if let Some(columns) = arguments.get("columns") {
325        if !columns.is_array() {
326            errors.push(ValidationError {
327                tool: tool_name.to_string(),
328                param: "columns".to_string(),
329                error: format!("Expected array, got {}", columns.type_str()),
330                suggestion: "Example: {\"columns\": [\"email\", \"name\", \"created_at\"]}".to_string(),
331            });
332        } else {
333            let cols = columns.as_array().unwrap();
334            if cols.is_empty() {
335                errors.push(ValidationError {
336                    tool: tool_name.to_string(),
337                    param: "columns".to_string(),
338                    error: "Columns array cannot be empty".to_string(),
339                    suggestion: "Provide at least one column name".to_string(),
340                });
341            }
342            for (i, col) in cols.iter().enumerate() {
343                if !col.is_string() {
344                    errors.push(ValidationError {
345                        tool: tool_name.to_string(),
346                        param: format!("columns[{}]", i),
347                        error: format!("Expected string column name, got {}", col.type_str()),
348                        suggestion: "Column names must be strings".to_string(),
349                    });
350                }
351            }
352        }
353    } else {
354        errors.push(ValidationError {
355            tool: tool_name.to_string(),
356            param: "columns".to_string(),
357            error: "Required parameter 'columns' missing".to_string(),
358            suggestion: "Include column names: {\"columns\": [\"email\", \"name\"]}".to_string(),
359        });
360    }
361
362    // Validate rows
363    if let Some(rows) = arguments.get("rows") {
364        if !rows.is_array() {
365            errors.push(ValidationError {
366                tool: tool_name.to_string(),
367                param: "rows".to_string(),
368                error: format!("Expected array of arrays, got {}", rows.type_str()),
369                suggestion: "Example: {\"rows\": [[\"user@test.com\", \"John\"], [\"jane@test.com\", \"Jane\"]]}".to_string(),
370            });
371        } else {
372            let rows_arr = rows.as_array().unwrap();
373            if rows_arr.is_empty() {
374                errors.push(ValidationError {
375                    tool: tool_name.to_string(),
376                    param: "rows".to_string(),
377                    error: "Rows array cannot be empty".to_string(),
378                    suggestion: "Provide at least one row of data".to_string(),
379                });
380            } else {
381                let max_rows = if tool_name == "batch_insert" { 1000 } else { 100000 };
382                if rows_arr.len() > max_rows {
383                    errors.push(ValidationError {
384                        tool: tool_name.to_string(),
385                        param: "rows".to_string(),
386                        error: format!("Too many rows: {} (max {})", rows_arr.len(), max_rows),
387                        suggestion: format!("Split into multiple calls with max {} rows each", max_rows),
388                    });
389                }
390
391                // Validate each row is an array
392                for (i, row) in rows_arr.iter().enumerate() {
393                    if !row.is_array() {
394                        errors.push(ValidationError {
395                            tool: tool_name.to_string(),
396                            param: format!("rows[{}]", i),
397                            error: format!("Row must be array, got {}", row.type_str()),
398                            suggestion: "Each row must be an array of values".to_string(),
399                        });
400                    }
401                }
402            }
403        }
404    } else {
405        errors.push(ValidationError {
406            tool: tool_name.to_string(),
407            param: "rows".to_string(),
408            error: "Required parameter 'rows' missing".to_string(),
409            suggestion: "Include rows: {\"rows\": [[\"value1\", \"value2\"], ...]}".to_string(),
410        });
411    }
412
413    // Validate batch_size for batch_insert_copy
414    if tool_name == "batch_insert_copy" {
415        if let Some(batch_size) = arguments.get("batch_size") {
416            if !batch_size.is_number() {
417                errors.push(ValidationError {
418                    tool: tool_name.to_string(),
419                    param: "batch_size".to_string(),
420                    error: format!("Expected integer, got {}", batch_size.type_str()),
421                    suggestion: "Example: {\"batch_size\": 1000}".to_string(),
422                });
423            } else if let Some(size) = batch_size.as_i64() {
424                if size < 100 || size > 5000 {
425                    errors.push(ValidationError {
426                        tool: tool_name.to_string(),
427                        param: "batch_size".to_string(),
428                        error: format!("Batch size {} out of range (must be 100-5000)", size),
429                        suggestion: "Use default (1000) or set between 100 and 5000".to_string(),
430                    });
431                }
432            }
433        }
434    }
435}
436
437/// Helper trait to get type name from Value
438trait ValueType {
439    fn type_str(&self) -> &str;
440}
441
442impl ValueType for Value {
443    fn type_str(&self) -> &str {
444        match self {
445            Value::Null => "null",
446            Value::Bool(_) => "boolean",
447            Value::Number(_) => "number",
448            Value::String(_) => "string",
449            Value::Array(_) => "array",
450            Value::Object(_) => "object",
451        }
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the `param` of every reported error, or an empty list on success.
    fn failing_params(tool: &str, args: &Value) -> Vec<String> {
        match validate_tool_input(tool, args) {
            Ok(()) => Vec::new(),
            Err(errors) => errors.into_iter().map(|e| e.param).collect(),
        }
    }

    /// A well-formed `batch_insert*` payload with `row_count` single-column rows.
    fn batch_args(row_count: usize) -> Value {
        let rows: Vec<Value> = (0..row_count)
            .map(|i| json!([format!("user{}@test.com", i)]))
            .collect();
        json!({"table": "users", "columns": ["email"], "rows": rows})
    }

    #[test]
    fn test_missing_required_param() {
        let args = json!({});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_type() {
        let args = json!({"table": 123});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_input() {
        let args = json!({"table": "users"});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_ok());
    }

    #[test]
    fn test_list_tables_and_unknown_tools_need_no_params() {
        assert!(validate_tool_input("list_tables", &json!({})).is_ok());
        assert!(validate_tool_input("not_a_real_tool", &json!({})).is_ok());
    }

    #[test]
    fn test_execute_query_requires_select_or_with() {
        let insert = json!({"sql": "INSERT INTO users (email) VALUES ('a@b.c')"});
        assert_eq!(failing_params("execute_query", &insert), vec!["sql"]);

        let cte = json!({"sql": "WITH t AS (SELECT 1) SELECT * FROM t"});
        assert!(validate_tool_input("execute_query", &cte).is_ok());
    }

    #[test]
    fn test_missing_sql_is_reported() {
        assert_eq!(failing_params("execute_query", &json!({})), vec!["sql"]);
    }

    #[test]
    fn test_delete_requires_where() {
        let no_where = json!({"sql": "DELETE FROM users"});
        assert_eq!(failing_params("execute_delete", &no_where), vec!["sql"]);

        let with_where = json!({"sql": "DELETE FROM users WHERE id = 999"});
        assert!(validate_tool_input("execute_delete", &with_where).is_ok());
    }

    #[test]
    fn test_explain_query_rejects_unknown_format() {
        let args = json!({"sql": "SELECT 1", "format": "csv"});
        assert_eq!(failing_params("explain_query", &args), vec!["format"]);
    }

    #[test]
    fn test_batch_insert_row_limits() {
        assert!(validate_tool_input("batch_insert", &batch_args(2)).is_ok());
        assert_eq!(failing_params("batch_insert", &batch_args(1001)), vec!["rows"]);
        assert!(validate_tool_input("batch_insert_copy", &batch_args(1001)).is_ok());
    }

    #[test]
    fn test_batch_size_out_of_range() {
        let mut args = batch_args(1);
        args["batch_size"] = json!(50);
        assert_eq!(failing_params("batch_insert_copy", &args), vec!["batch_size"]);
    }

    #[test]
    fn test_error_display_names_tool_and_param() {
        let errors = validate_tool_input("describe_table", &json!({})).unwrap_err();
        let rendered = errors[0].to_string();
        assert!(rendered.contains("Validation Error in tool 'describe_table' parameter 'table'"));
        assert!(rendered.contains("Suggestion:"));
    }
}