1use serde_json::Value;
5
/// One problem found while validating a tool's JSON arguments.
///
/// [`validate_tool_input`] returns a list of these so that every problem can be reported at once.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the tool whose arguments were checked.
    pub tool: String,
    /// Parameter the problem concerns, e.g. `"table"` or an indexed path such as `"rows[3]"`.
    pub param: String,
    /// Description of what is wrong with the parameter.
    pub error: String,
    /// Hint on how to fix the problem, often including an example value.
    pub suggestion: String,
}
13
14impl std::fmt::Display for ValidationError {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(
17 f,
18 "ā Validation Error in tool '{}' parameter '{}': {}\nš” Suggestion: {}",
19 self.tool, self.param, self.error, self.suggestion
20 )
21 }
22}
23
24pub fn validate_tool_input(tool_name: &str, arguments: &Value) -> Result<(), Vec<ValidationError>> {
25 let mut errors = Vec::new();
26
27 match tool_name {
28 "list_tables" => {
30 }
32 "describe_table" => {
33 if let Some(table) = arguments.get("table") {
34 if !table.is_string() {
35 errors.push(ValidationError {
36 tool: tool_name.to_string(),
37 param: "table".to_string(),
38 error: "Expected string, got ".to_string() + &table.type_str(),
39 suggestion: "Example: {\"table\": \"users\"} or {\"table\": \"public.orders\"}".to_string(),
40 });
41 } else {
42 let table_str = table.as_str().unwrap();
43 if table_str.is_empty() {
44 errors.push(ValidationError {
45 tool: tool_name.to_string(),
46 param: "table".to_string(),
47 error: "Table name cannot be empty".to_string(),
48 suggestion: "Provide a valid table name like 'users' or 'public.products'".to_string(),
49 });
50 }
51 if table_str.len() > 255 {
52 errors.push(ValidationError {
53 tool: tool_name.to_string(),
54 param: "table".to_string(),
55 error: format!("Table name too long: {} characters (max 255)", table_str.len()),
56 suggestion: "Use a shorter table name".to_string(),
57 });
58 }
59 }
60 } else {
61 errors.push(ValidationError {
62 tool: tool_name.to_string(),
63 param: "table".to_string(),
64 error: "Required parameter missing".to_string(),
65 suggestion: "Include 'table' parameter: {\"table\": \"users\"}".to_string(),
66 });
67 }
68 }
69
70 "execute_query" | "execute_insert" | "execute_update" | "execute_delete" => {
72 if let Some(sql) = arguments.get("sql") {
73 if !sql.is_string() {
74 errors.push(ValidationError {
75 tool: tool_name.to_string(),
76 param: "sql".to_string(),
77 error: format!("Expected string SQL, got {}", sql.type_str()),
78 suggestion: "Example: {\"sql\": \"SELECT * FROM users LIMIT 10\"}".to_string(),
79 });
80 } else {
81 let sql_str = sql.as_str().unwrap();
82 if sql_str.is_empty() {
83 errors.push(ValidationError {
84 tool: tool_name.to_string(),
85 param: "sql".to_string(),
86 error: "SQL statement cannot be empty".to_string(),
87 suggestion: "Provide a valid SQL statement".to_string(),
88 });
89 }
90 if sql_str.len() > 10000 {
91 errors.push(ValidationError {
92 tool: tool_name.to_string(),
93 param: "sql".to_string(),
94 error: format!("SQL too long: {} characters (max 10,000)", sql_str.len()),
95 suggestion: "Break the query into smaller parts or use a subquery".to_string(),
96 });
97 }
98
99 let sql_upper = sql_str.trim().to_uppercase();
101 match tool_name {
102 "execute_query" => {
103 if !sql_upper.starts_with("SELECT") && !sql_upper.starts_with("WITH") {
104 errors.push(ValidationError {
105 tool: tool_name.to_string(),
106 param: "sql".to_string(),
107 error: "execute_query requires a SELECT statement".to_string(),
108 suggestion: "Use 'execute_query' only for SELECT queries. Use 'execute_insert', 'execute_update', or 'execute_delete' for modifications.".to_string(),
109 });
110 }
111 }
112 "execute_insert" => {
113 if !sql_upper.starts_with("INSERT") {
114 errors.push(ValidationError {
115 tool: tool_name.to_string(),
116 param: "sql".to_string(),
117 error: "execute_insert requires an INSERT statement".to_string(),
118 suggestion: "Example: {\"sql\": \"INSERT INTO users (email) VALUES ('user@example.com')\"}".to_string(),
119 });
120 }
121 }
122 "execute_update" => {
123 if !sql_upper.starts_with("UPDATE") {
124 errors.push(ValidationError {
125 tool: tool_name.to_string(),
126 param: "sql".to_string(),
127 error: "execute_update requires an UPDATE statement".to_string(),
128 suggestion: "Example: {\"sql\": \"UPDATE users SET status = 'active' WHERE id = 1\"}".to_string(),
129 });
130 }
131 if !sql_str.contains("WHERE") && !sql_str.contains("where") {
132 errors.push(ValidationError {
133 tool: tool_name.to_string(),
134 param: "sql".to_string(),
135 error: "UPDATE without WHERE clause will modify all rows".to_string(),
136 suggestion: "Add a WHERE clause: UPDATE users SET ... WHERE <condition>".to_string(),
137 });
138 }
139 }
140 "execute_delete" => {
141 if !sql_upper.starts_with("DELETE") {
142 errors.push(ValidationError {
143 tool: tool_name.to_string(),
144 param: "sql".to_string(),
145 error: "execute_delete requires a DELETE statement".to_string(),
146 suggestion: "Example: {\"sql\": \"DELETE FROM users WHERE id = 999\"}".to_string(),
147 });
148 }
149 if !sql_str.contains("WHERE") && !sql_str.contains("where") {
150 errors.push(ValidationError {
151 tool: tool_name.to_string(),
152 param: "sql".to_string(),
153 error: "DELETE without WHERE clause will delete all rows".to_string(),
154 suggestion: "Add a WHERE clause: DELETE FROM users WHERE <condition>".to_string(),
155 });
156 }
157 }
158 _ => {}
159 }
160 }
161 } else {
162 errors.push(ValidationError {
163 tool: tool_name.to_string(),
164 param: "sql".to_string(),
165 error: "Required parameter 'sql' missing".to_string(),
166 suggestion: format!("Include SQL: {{\"sql\": \"<{} statement>\"}}", tool_name.split('_').nth(1).unwrap_or("SQL")),
167 });
168 }
169 }
170
171 "batch_insert" | "batch_insert_copy" => {
173 validate_batch_insert(tool_name, arguments, &mut errors);
174 }
175
176 "explain_query" => {
178 if let Some(sql) = arguments.get("sql") {
179 if !sql.is_string() {
180 errors.push(ValidationError {
181 tool: tool_name.to_string(),
182 param: "sql".to_string(),
183 error: "Expected string, got ".to_string() + &sql.type_str(),
184 suggestion: "Example: {\"sql\": \"SELECT * FROM users\"}".to_string(),
185 });
186 } else if sql.as_str().unwrap().is_empty() {
187 errors.push(ValidationError {
188 tool: tool_name.to_string(),
189 param: "sql".to_string(),
190 error: "SQL cannot be empty".to_string(),
191 suggestion: "Provide a SELECT query to explain".to_string(),
192 });
193 }
194 } else {
195 errors.push(ValidationError {
196 tool: tool_name.to_string(),
197 param: "sql".to_string(),
198 error: "Required parameter 'sql' missing".to_string(),
199 suggestion: "Include SQL: {\"sql\": \"SELECT * FROM users\"}".to_string(),
200 });
201 }
202
203 if let Some(format) = arguments.get("format") {
204 if let Some(fmt) = format.as_str() {
205 if !["json", "text", "xml", "yaml"].contains(&fmt) {
206 errors.push(ValidationError {
207 tool: tool_name.to_string(),
208 param: "format".to_string(),
209 error: format!("Invalid format '{}' (must be json, text, xml, or yaml)", fmt),
210 suggestion: "Use one of: json (default), text, xml, yaml".to_string(),
211 });
212 }
213 }
214 }
215 }
216
217 "get_setting" => {
219 if let Some(setting) = arguments.get("setting_name") {
220 if !setting.is_string() {
221 errors.push(ValidationError {
222 tool: tool_name.to_string(),
223 param: "setting_name".to_string(),
224 error: format!("Expected string, got {}", setting.type_str()),
225 suggestion: "Example: {\"setting_name\": \"max_connections\"}".to_string(),
226 });
227 } else if setting.as_str().unwrap().is_empty() {
228 errors.push(ValidationError {
229 tool: tool_name.to_string(),
230 param: "setting_name".to_string(),
231 error: "Setting name cannot be empty".to_string(),
232 suggestion: "Examples: max_connections, shared_buffers, work_mem, effective_cache_size".to_string(),
233 });
234 }
235 } else {
236 errors.push(ValidationError {
237 tool: tool_name.to_string(),
238 param: "setting_name".to_string(),
239 error: "Required parameter 'setting_name' missing".to_string(),
240 suggestion: "Include setting: {\"setting_name\": \"max_connections\"}".to_string(),
241 });
242 }
243 }
244
245 "get_object_details" => {
247 if let Some(table) = arguments.get("table") {
248 if !table.is_string() {
249 errors.push(ValidationError {
250 tool: tool_name.to_string(),
251 param: "table".to_string(),
252 error: format!("Expected string, got {}", table.type_str()),
253 suggestion: "Example: {\"table\": \"users\"}".to_string(),
254 });
255 } else if table.as_str().unwrap().is_empty() {
256 errors.push(ValidationError {
257 tool: tool_name.to_string(),
258 param: "table".to_string(),
259 error: "Table name cannot be empty".to_string(),
260 suggestion: "Provide a valid table name".to_string(),
261 });
262 }
263 } else {
264 errors.push(ValidationError {
265 tool: tool_name.to_string(),
266 param: "table".to_string(),
267 error: "Required parameter 'table' missing".to_string(),
268 suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
269 });
270 }
271
272 if let Some(schema) = arguments.get("schema") {
273 if !schema.is_string() {
274 errors.push(ValidationError {
275 tool: tool_name.to_string(),
276 param: "schema".to_string(),
277 error: format!("Expected string, got {}", schema.type_str()),
278 suggestion: "Example: {\"schema\": \"public\"}".to_string(),
279 });
280 }
281 }
282 }
283
284 _ => {
285 }
287 }
288
289 if errors.is_empty() {
290 Ok(())
291 } else {
292 Err(errors)
293 }
294}
295
296fn validate_batch_insert(tool_name: &str, arguments: &Value, errors: &mut Vec<ValidationError>) {
297 if let Some(table) = arguments.get("table") {
299 if !table.is_string() {
300 errors.push(ValidationError {
301 tool: tool_name.to_string(),
302 param: "table".to_string(),
303 error: format!("Expected string, got {}", table.type_str()),
304 suggestion: "Example: {\"table\": \"users\"}".to_string(),
305 });
306 } else if table.as_str().unwrap().is_empty() {
307 errors.push(ValidationError {
308 tool: tool_name.to_string(),
309 param: "table".to_string(),
310 error: "Table name cannot be empty".to_string(),
311 suggestion: "Provide a valid table name".to_string(),
312 });
313 }
314 } else {
315 errors.push(ValidationError {
316 tool: tool_name.to_string(),
317 param: "table".to_string(),
318 error: "Required parameter 'table' missing".to_string(),
319 suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
320 });
321 }
322
323 if let Some(columns) = arguments.get("columns") {
325 if !columns.is_array() {
326 errors.push(ValidationError {
327 tool: tool_name.to_string(),
328 param: "columns".to_string(),
329 error: format!("Expected array, got {}", columns.type_str()),
330 suggestion: "Example: {\"columns\": [\"email\", \"name\", \"created_at\"]}".to_string(),
331 });
332 } else {
333 let cols = columns.as_array().unwrap();
334 if cols.is_empty() {
335 errors.push(ValidationError {
336 tool: tool_name.to_string(),
337 param: "columns".to_string(),
338 error: "Columns array cannot be empty".to_string(),
339 suggestion: "Provide at least one column name".to_string(),
340 });
341 }
342 for (i, col) in cols.iter().enumerate() {
343 if !col.is_string() {
344 errors.push(ValidationError {
345 tool: tool_name.to_string(),
346 param: format!("columns[{}]", i),
347 error: format!("Expected string column name, got {}", col.type_str()),
348 suggestion: "Column names must be strings".to_string(),
349 });
350 }
351 }
352 }
353 } else {
354 errors.push(ValidationError {
355 tool: tool_name.to_string(),
356 param: "columns".to_string(),
357 error: "Required parameter 'columns' missing".to_string(),
358 suggestion: "Include column names: {\"columns\": [\"email\", \"name\"]}".to_string(),
359 });
360 }
361
362 if let Some(rows) = arguments.get("rows") {
364 if !rows.is_array() {
365 errors.push(ValidationError {
366 tool: tool_name.to_string(),
367 param: "rows".to_string(),
368 error: format!("Expected array of arrays, got {}", rows.type_str()),
369 suggestion: "Example: {\"rows\": [[\"user@test.com\", \"John\"], [\"jane@test.com\", \"Jane\"]]}".to_string(),
370 });
371 } else {
372 let rows_arr = rows.as_array().unwrap();
373 if rows_arr.is_empty() {
374 errors.push(ValidationError {
375 tool: tool_name.to_string(),
376 param: "rows".to_string(),
377 error: "Rows array cannot be empty".to_string(),
378 suggestion: "Provide at least one row of data".to_string(),
379 });
380 } else {
381 let max_rows = if tool_name == "batch_insert" { 1000 } else { 100000 };
382 if rows_arr.len() > max_rows {
383 errors.push(ValidationError {
384 tool: tool_name.to_string(),
385 param: "rows".to_string(),
386 error: format!("Too many rows: {} (max {})", rows_arr.len(), max_rows),
387 suggestion: format!("Split into multiple calls with max {} rows each", max_rows),
388 });
389 }
390
391 for (i, row) in rows_arr.iter().enumerate() {
393 if !row.is_array() {
394 errors.push(ValidationError {
395 tool: tool_name.to_string(),
396 param: format!("rows[{}]", i),
397 error: format!("Row must be array, got {}", row.type_str()),
398 suggestion: "Each row must be an array of values".to_string(),
399 });
400 }
401 }
402 }
403 }
404 } else {
405 errors.push(ValidationError {
406 tool: tool_name.to_string(),
407 param: "rows".to_string(),
408 error: "Required parameter 'rows' missing".to_string(),
409 suggestion: "Include rows: {\"rows\": [[\"value1\", \"value2\"], ...]}".to_string(),
410 });
411 }
412
413 if tool_name == "batch_insert_copy" {
415 if let Some(batch_size) = arguments.get("batch_size") {
416 if !batch_size.is_number() {
417 errors.push(ValidationError {
418 tool: tool_name.to_string(),
419 param: "batch_size".to_string(),
420 error: format!("Expected integer, got {}", batch_size.type_str()),
421 suggestion: "Example: {\"batch_size\": 1000}".to_string(),
422 });
423 } else if let Some(size) = batch_size.as_i64() {
424 if size < 100 || size > 5000 {
425 errors.push(ValidationError {
426 tool: tool_name.to_string(),
427 param: "batch_size".to_string(),
428 error: format!("Batch size {} out of range (must be 100-5000)", size),
429 suggestion: "Use default (1000) or set between 100 and 5000".to_string(),
430 });
431 }
432 }
433 }
434 }
435}
436
437trait ValueType {
439 fn type_str(&self) -> &str;
440}
441
442impl ValueType for Value {
443 fn type_str(&self) -> &str {
444 match self {
445 Value::Null => "null",
446 Value::Bool(_) => "boolean",
447 Value::Number(_) => "number",
448 Value::String(_) => "string",
449 Value::Array(_) => "array",
450 Value::Object(_) => "object",
451 }
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_missing_required_param() {
        let args = json!({});
        let errors = validate_tool_input("describe_table", &args).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].tool, "describe_table");
        assert_eq!(errors[0].param, "table");
    }

    #[test]
    fn test_invalid_type() {
        let args = json!({"table": 123});
        let errors = validate_tool_input("describe_table", &args).unwrap_err();
        assert_eq!(errors[0].param, "table");
        assert!(errors[0].error.contains("number"));
    }

    #[test]
    fn test_valid_input() {
        let args = json!({"table": "users"});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_ok());
    }

    #[test]
    fn test_unknown_tool_is_accepted() {
        assert!(validate_tool_input("no_such_tool", &json!({})).is_ok());
    }

    #[test]
    fn test_delete_requires_where_clause() {
        let without_where = json!({"sql": "DELETE FROM users"});
        assert!(validate_tool_input("execute_delete", &without_where).is_err());
        let with_where = json!({"sql": "DELETE FROM users WHERE id = 1"});
        assert!(validate_tool_input("execute_delete", &with_where).is_ok());
    }

    #[test]
    fn test_execute_query_rejects_modifications() {
        let args = json!({"sql": "INSERT INTO users (email) VALUES ('a@b.c')"});
        let errors = validate_tool_input("execute_query", &args).unwrap_err();
        assert_eq!(errors[0].param, "sql");
    }

    #[test]
    fn test_batch_insert_valid() {
        let args = json!({
            "table": "users",
            "columns": ["email", "name"],
            "rows": [["a@b.c", "A"], ["d@e.f", "D"]]
        });
        assert!(validate_tool_input("batch_insert", &args).is_ok());
    }

    #[test]
    fn test_batch_insert_rejects_non_array_row() {
        let args = json!({"table": "users", "columns": ["email"], "rows": ["a@b.c"]});
        let errors = validate_tool_input("batch_insert", &args).unwrap_err();
        assert!(errors.iter().any(|e| e.param == "rows[0]"));
    }
}