1use serde_json::Value;
4
/// A single problem found while validating a tool's JSON arguments.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Name of the tool whose input failed validation.
    pub tool: String,
    /// Parameter the problem refers to, e.g. `table` or `rows[3]`.
    pub param: String,
    /// Description of what is wrong with the parameter.
    pub error: String,
    /// Hint on how to fix the input, often an example payload.
    pub suggestion: String,
}

impl std::fmt::Display for ValidationError {
    /// Renders the error on two lines: the problem, then the suggestion.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The markers (U+274C CROSS MARK, U+1F4A1 ELECTRIC LIGHT BULB) are written
        // as escapes so the source stays ASCII and cannot be mangled into mojibake
        // by a tool that mis-detects the file encoding.
        write!(
            f,
            "\u{274C} Validation Error in tool '{}' parameter '{}': {}\n\u{1F4A1} Suggestion: {}",
            self.tool, self.param, self.error, self.suggestion
        )
    }
}
22
23pub fn validate_tool_input(tool_name: &str, arguments: &Value) -> Result<(), Vec<ValidationError>> {
24 let mut errors = Vec::new();
25
26 match tool_name {
27 "list_tables" => {
29 }
31 "describe_table" => {
32 if let Some(table) = arguments.get("table") {
33 if !table.is_string() {
34 errors.push(ValidationError {
35 tool: tool_name.to_string(),
36 param: "table".to_string(),
37 error: format!("Expected string, got {}", table.type_str()),
38 suggestion: "Example: {\"table\": \"users\"} or {\"table\": \"public.orders\"}".to_string(),
39 });
40 } else {
41 let table_str = table.as_str().unwrap();
42 if table_str.is_empty() {
43 errors.push(ValidationError {
44 tool: tool_name.to_string(),
45 param: "table".to_string(),
46 error: "Table name cannot be empty".to_string(),
47 suggestion: "Provide a valid table name like 'users' or 'public.products'".to_string(),
48 });
49 }
50 if table_str.len() > 255 {
51 errors.push(ValidationError {
52 tool: tool_name.to_string(),
53 param: "table".to_string(),
54 error: format!("Table name too long: {} characters (max 255)", table_str.len()),
55 suggestion: "Use a shorter table name".to_string(),
56 });
57 }
58 }
59 } else {
60 errors.push(ValidationError {
61 tool: tool_name.to_string(),
62 param: "table".to_string(),
63 error: "Required parameter missing".to_string(),
64 suggestion: "Include 'table' parameter: {\"table\": \"users\"}".to_string(),
65 });
66 }
67 }
68
69 "execute_query" | "execute_insert" | "execute_update" | "execute_delete" => {
71 if let Some(sql) = arguments.get("sql") {
72 if !sql.is_string() {
73 errors.push(ValidationError {
74 tool: tool_name.to_string(),
75 param: "sql".to_string(),
76 error: format!("Expected string SQL, got {}", sql.type_str()),
77 suggestion: "Example: {\"sql\": \"SELECT * FROM users LIMIT 10\"}".to_string(),
78 });
79 } else {
80 let sql_str = sql.as_str().unwrap();
81 if sql_str.is_empty() {
82 errors.push(ValidationError {
83 tool: tool_name.to_string(),
84 param: "sql".to_string(),
85 error: "SQL statement cannot be empty".to_string(),
86 suggestion: "Provide a valid SQL statement".to_string(),
87 });
88 }
89 if sql_str.len() > 10000 {
90 errors.push(ValidationError {
91 tool: tool_name.to_string(),
92 param: "sql".to_string(),
93 error: format!("SQL too long: {} characters (max 10,000)", sql_str.len()),
94 suggestion: "Break the query into smaller parts or use a subquery".to_string(),
95 });
96 }
97
98 let sql_upper = sql_str.trim().to_uppercase();
100 match tool_name {
101 "execute_query" => {
102 if !sql_upper.starts_with("SELECT") && !sql_upper.starts_with("WITH") {
103 errors.push(ValidationError {
104 tool: tool_name.to_string(),
105 param: "sql".to_string(),
106 error: "execute_query requires a SELECT statement".to_string(),
107 suggestion: "Use 'execute_query' only for SELECT queries. Use 'execute_insert', 'execute_update', or 'execute_delete' for modifications.".to_string(),
108 });
109 }
110 }
111 "execute_insert" => {
112 if !sql_upper.starts_with("INSERT") {
113 errors.push(ValidationError {
114 tool: tool_name.to_string(),
115 param: "sql".to_string(),
116 error: "execute_insert requires an INSERT statement".to_string(),
117 suggestion: "Example: {\"sql\": \"INSERT INTO users (email) VALUES ('user@example.com')\"}".to_string(),
118 });
119 }
120 }
121 "execute_update" => {
122 if !sql_upper.starts_with("UPDATE") {
123 errors.push(ValidationError {
124 tool: tool_name.to_string(),
125 param: "sql".to_string(),
126 error: "execute_update requires an UPDATE statement".to_string(),
127 suggestion: "Example: {\"sql\": \"UPDATE users SET status = 'active' WHERE id = 1\"}".to_string(),
128 });
129 }
130 if !sql_str.contains("WHERE") && !sql_str.contains("where") {
131 errors.push(ValidationError {
132 tool: tool_name.to_string(),
133 param: "sql".to_string(),
134 error: "UPDATE without WHERE clause will modify all rows".to_string(),
135 suggestion: "Add a WHERE clause: UPDATE users SET ... WHERE <condition>".to_string(),
136 });
137 }
138 }
139 "execute_delete" => {
140 if !sql_upper.starts_with("DELETE") {
141 errors.push(ValidationError {
142 tool: tool_name.to_string(),
143 param: "sql".to_string(),
144 error: "execute_delete requires a DELETE statement".to_string(),
145 suggestion: "Example: {\"sql\": \"DELETE FROM users WHERE id = 999\"}".to_string(),
146 });
147 }
148 if !sql_str.contains("WHERE") && !sql_str.contains("where") {
149 errors.push(ValidationError {
150 tool: tool_name.to_string(),
151 param: "sql".to_string(),
152 error: "DELETE without WHERE clause will delete all rows".to_string(),
153 suggestion: "Add a WHERE clause: DELETE FROM users WHERE <condition>".to_string(),
154 });
155 }
156 }
157 _ => {}
158 }
159 }
160 } else {
161 errors.push(ValidationError {
162 tool: tool_name.to_string(),
163 param: "sql".to_string(),
164 error: "Required parameter 'sql' missing".to_string(),
165 suggestion: format!("Include SQL: {{\"sql\": \"<{} statement>\"}}", tool_name.split('_').nth(1).unwrap_or("SQL")),
166 });
167 }
168 }
169
170 "batch_insert" | "batch_insert_copy" => {
172 validate_batch_insert(tool_name, arguments, &mut errors);
173 }
174
175 "explain_query" => {
177 if let Some(sql) = arguments.get("sql") {
178 if !sql.is_string() {
179 errors.push(ValidationError {
180 tool: tool_name.to_string(),
181 param: "sql".to_string(),
182 error: format!("Expected string, got {}", sql.type_str()),
183 suggestion: "Example: {\"sql\": \"SELECT * FROM users\"}".to_string(),
184 });
185 } else if sql.as_str().unwrap().is_empty() {
186 errors.push(ValidationError {
187 tool: tool_name.to_string(),
188 param: "sql".to_string(),
189 error: "SQL cannot be empty".to_string(),
190 suggestion: "Provide a SELECT query to explain".to_string(),
191 });
192 }
193 } else {
194 errors.push(ValidationError {
195 tool: tool_name.to_string(),
196 param: "sql".to_string(),
197 error: "Required parameter 'sql' missing".to_string(),
198 suggestion: "Include SQL: {\"sql\": \"SELECT * FROM users\"}".to_string(),
199 });
200 }
201
202 if let Some(format) = arguments.get("format") {
203 if let Some(fmt) = format.as_str() {
204 if !["json", "text", "xml", "yaml"].contains(&fmt) {
205 errors.push(ValidationError {
206 tool: tool_name.to_string(),
207 param: "format".to_string(),
208 error: format!("Invalid format '{}' (must be json, text, xml, or yaml)", fmt),
209 suggestion: "Use one of: json (default), text, xml, yaml".to_string(),
210 });
211 }
212 }
213 }
214 }
215
216 "get_setting" => {
218 if let Some(setting) = arguments.get("setting_name") {
219 if !setting.is_string() {
220 errors.push(ValidationError {
221 tool: tool_name.to_string(),
222 param: "setting_name".to_string(),
223 error: format!("Expected string, got {}", setting.type_str()),
224 suggestion: "Example: {\"setting_name\": \"max_connections\"}".to_string(),
225 });
226 } else if setting.as_str().unwrap().is_empty() {
227 errors.push(ValidationError {
228 tool: tool_name.to_string(),
229 param: "setting_name".to_string(),
230 error: "Setting name cannot be empty".to_string(),
231 suggestion: "Examples: max_connections, shared_buffers, work_mem, effective_cache_size".to_string(),
232 });
233 }
234 } else {
235 errors.push(ValidationError {
236 tool: tool_name.to_string(),
237 param: "setting_name".to_string(),
238 error: "Required parameter 'setting_name' missing".to_string(),
239 suggestion: "Include setting: {\"setting_name\": \"max_connections\"}".to_string(),
240 });
241 }
242 }
243
244 "get_object_details" => {
246 if let Some(table) = arguments.get("table") {
247 if !table.is_string() {
248 errors.push(ValidationError {
249 tool: tool_name.to_string(),
250 param: "table".to_string(),
251 error: format!("Expected string, got {}", table.type_str()),
252 suggestion: "Example: {\"table\": \"users\"}".to_string(),
253 });
254 } else if table.as_str().unwrap().is_empty() {
255 errors.push(ValidationError {
256 tool: tool_name.to_string(),
257 param: "table".to_string(),
258 error: "Table name cannot be empty".to_string(),
259 suggestion: "Provide a valid table name".to_string(),
260 });
261 }
262 } else {
263 errors.push(ValidationError {
264 tool: tool_name.to_string(),
265 param: "table".to_string(),
266 error: "Required parameter 'table' missing".to_string(),
267 suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
268 });
269 }
270
271 if let Some(schema) = arguments.get("schema") {
272 if !schema.is_string() {
273 errors.push(ValidationError {
274 tool: tool_name.to_string(),
275 param: "schema".to_string(),
276 error: format!("Expected string, got {}", schema.type_str()),
277 suggestion: "Example: {\"schema\": \"public\"}".to_string(),
278 });
279 }
280 }
281 }
282
283 _ => {
284 }
286 }
287
288 if errors.is_empty() {
289 Ok(())
290 } else {
291 Err(errors)
292 }
293}
294
295fn validate_batch_insert(tool_name: &str, arguments: &Value, errors: &mut Vec<ValidationError>) {
296 if let Some(table) = arguments.get("table") {
298 if !table.is_string() {
299 errors.push(ValidationError {
300 tool: tool_name.to_string(),
301 param: "table".to_string(),
302 error: format!("Expected string, got {}", table.type_str()),
303 suggestion: "Example: {\"table\": \"users\"}".to_string(),
304 });
305 } else if table.as_str().unwrap().is_empty() {
306 errors.push(ValidationError {
307 tool: tool_name.to_string(),
308 param: "table".to_string(),
309 error: "Table name cannot be empty".to_string(),
310 suggestion: "Provide a valid table name".to_string(),
311 });
312 }
313 } else {
314 errors.push(ValidationError {
315 tool: tool_name.to_string(),
316 param: "table".to_string(),
317 error: "Required parameter 'table' missing".to_string(),
318 suggestion: "Include table name: {\"table\": \"users\"}".to_string(),
319 });
320 }
321
322 if let Some(columns) = arguments.get("columns") {
324 if !columns.is_array() {
325 errors.push(ValidationError {
326 tool: tool_name.to_string(),
327 param: "columns".to_string(),
328 error: format!("Expected array, got {}", columns.type_str()),
329 suggestion: "Example: {\"columns\": [\"email\", \"name\", \"created_at\"]}".to_string(),
330 });
331 } else {
332 let cols = columns.as_array().unwrap();
333 if cols.is_empty() {
334 errors.push(ValidationError {
335 tool: tool_name.to_string(),
336 param: "columns".to_string(),
337 error: "Columns array cannot be empty".to_string(),
338 suggestion: "Provide at least one column name".to_string(),
339 });
340 }
341 for (i, col) in cols.iter().enumerate() {
342 if !col.is_string() {
343 errors.push(ValidationError {
344 tool: tool_name.to_string(),
345 param: format!("columns[{}]", i),
346 error: format!("Expected string column name, got {}", col.type_str()),
347 suggestion: "Column names must be strings".to_string(),
348 });
349 }
350 }
351 }
352 } else {
353 errors.push(ValidationError {
354 tool: tool_name.to_string(),
355 param: "columns".to_string(),
356 error: "Required parameter 'columns' missing".to_string(),
357 suggestion: "Include column names: {\"columns\": [\"email\", \"name\"]}".to_string(),
358 });
359 }
360
361 if let Some(rows) = arguments.get("rows") {
363 if !rows.is_array() {
364 errors.push(ValidationError {
365 tool: tool_name.to_string(),
366 param: "rows".to_string(),
367 error: format!("Expected array of arrays, got {}", rows.type_str()),
368 suggestion: "Example: {\"rows\": [[\"user@test.com\", \"John\"], [\"jane@test.com\", \"Jane\"]]}".to_string(),
369 });
370 } else {
371 let rows_arr = rows.as_array().unwrap();
372 if rows_arr.is_empty() {
373 errors.push(ValidationError {
374 tool: tool_name.to_string(),
375 param: "rows".to_string(),
376 error: "Rows array cannot be empty".to_string(),
377 suggestion: "Provide at least one row of data".to_string(),
378 });
379 } else {
380 let max_rows = if tool_name == "batch_insert" { 1000 } else { 100000 };
381 if rows_arr.len() > max_rows {
382 errors.push(ValidationError {
383 tool: tool_name.to_string(),
384 param: "rows".to_string(),
385 error: format!("Too many rows: {} (max {})", rows_arr.len(), max_rows),
386 suggestion: format!("Split into multiple calls with max {} rows each", max_rows),
387 });
388 }
389
390 for (i, row) in rows_arr.iter().enumerate() {
392 if !row.is_array() {
393 errors.push(ValidationError {
394 tool: tool_name.to_string(),
395 param: format!("rows[{}]", i),
396 error: format!("Row must be array, got {}", row.type_str()),
397 suggestion: "Each row must be an array of values".to_string(),
398 });
399 }
400 }
401 }
402 }
403 } else {
404 errors.push(ValidationError {
405 tool: tool_name.to_string(),
406 param: "rows".to_string(),
407 error: "Required parameter 'rows' missing".to_string(),
408 suggestion: "Include rows: {\"rows\": [[\"value1\", \"value2\"], ...]}".to_string(),
409 });
410 }
411
412 if tool_name == "batch_insert_copy" {
414 if let Some(batch_size) = arguments.get("batch_size") {
415 if !batch_size.is_number() {
416 errors.push(ValidationError {
417 tool: tool_name.to_string(),
418 param: "batch_size".to_string(),
419 error: format!("Expected integer, got {}", batch_size.type_str()),
420 suggestion: "Example: {\"batch_size\": 1000}".to_string(),
421 });
422 } else if let Some(size) = batch_size.as_i64() {
423 if !((100..=5000).contains(&size)) {
424 errors.push(ValidationError {
425 tool: tool_name.to_string(),
426 param: "batch_size".to_string(),
427 error: format!("Batch size {} out of range (must be 100-5000)", size),
428 suggestion: "Use default (1000) or set between 100 and 5000".to_string(),
429 });
430 }
431 }
432 }
433 }
434}
435
436trait ValueType {
438 fn type_str(&self) -> &str;
439}
440
441impl ValueType for Value {
442 fn type_str(&self) -> &str {
443 match self {
444 Value::Null => "null",
445 Value::Bool(_) => "boolean",
446 Value::Number(_) => "number",
447 Value::String(_) => "string",
448 Value::Array(_) => "array",
449 Value::Object(_) => "object",
450 }
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Runs the validator and returns the `param` of every reported error, in
    /// order (empty when the input is accepted).
    fn failing_params(tool: &str, args: Value) -> Vec<String> {
        match validate_tool_input(tool, &args) {
            Ok(()) => Vec::new(),
            Err(errors) => errors.into_iter().map(|e| e.param).collect(),
        }
    }

    #[test]
    fn test_missing_required_param() {
        let args = json!({});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_type() {
        let args = json!({"table": 123});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_input() {
        let args = json!({"table": "users"});
        let result = validate_tool_input("describe_table", &args);
        assert!(result.is_ok());
    }

    #[test]
    fn list_tables_and_unknown_tools_accept_anything() {
        assert!(validate_tool_input("list_tables", &json!({})).is_ok());
        assert!(validate_tool_input("no_such_tool", &json!({"x": 1})).is_ok());
    }

    #[test]
    fn describe_table_rejects_empty_and_overlong_names() {
        assert_eq!(failing_params("describe_table", json!({"table": ""})), ["table"]);
        let long_name = "t".repeat(256);
        assert_eq!(failing_params("describe_table", json!({"table": long_name})), ["table"]);
    }

    #[test]
    fn execute_tools_require_matching_statement_kind() {
        assert!(validate_tool_input("execute_query", &json!({"sql": "SELECT 1"})).is_ok());
        assert!(validate_tool_input(
            "execute_query",
            &json!({"sql": "  with t as (select 1) select * from t"})
        )
        .is_ok());
        assert_eq!(
            failing_params("execute_query", json!({"sql": "DELETE FROM users WHERE id = 1"})),
            ["sql"]
        );
        assert!(validate_tool_input(
            "execute_insert",
            &json!({"sql": "INSERT INTO users (email) VALUES ('a@b.c')"})
        )
        .is_ok());
        assert_eq!(
            failing_params("execute_insert", json!({"sql": "UPDATE users SET a = 1 WHERE id = 1"})),
            ["sql"]
        );
    }

    #[test]
    fn update_and_delete_require_where_clause() {
        assert!(validate_tool_input(
            "execute_update",
            &json!({"sql": "UPDATE users SET a = 1 WHERE id = 2"})
        )
        .is_ok());
        assert!(validate_tool_input(
            "execute_delete",
            &json!({"sql": "DELETE FROM users WHERE id = 1"})
        )
        .is_ok());
        assert_eq!(failing_params("execute_update", json!({"sql": "UPDATE users SET a = 1"})), ["sql"]);
        assert_eq!(failing_params("execute_delete", json!({"sql": "DELETE FROM users"})), ["sql"]);
    }

    #[test]
    fn missing_sql_is_reported() {
        assert_eq!(failing_params("execute_delete", json!({})), ["sql"]);
        assert_eq!(failing_params("explain_query", json!({})), ["sql"]);
    }

    #[test]
    fn explain_query_checks_format() {
        assert!(validate_tool_input(
            "explain_query",
            &json!({"sql": "SELECT 1", "format": "yaml"})
        )
        .is_ok());
        assert_eq!(
            failing_params("explain_query", json!({"sql": "SELECT 1", "format": "html"})),
            ["format"]
        );
    }

    #[test]
    fn get_setting_and_object_details() {
        assert!(validate_tool_input("get_setting", &json!({"setting_name": "work_mem"})).is_ok());
        assert_eq!(failing_params("get_setting", json!({})), ["setting_name"]);
        assert!(validate_tool_input(
            "get_object_details",
            &json!({"table": "users", "schema": "public"})
        )
        .is_ok());
        assert_eq!(
            failing_params("get_object_details", json!({"table": "users", "schema": 5})),
            ["schema"]
        );
    }

    #[test]
    fn batch_insert_validation() {
        let ok = json!({
            "table": "users",
            "columns": ["email", "name"],
            "rows": [["a@b.c", "A"], ["d@e.f", "D"]]
        });
        assert!(validate_tool_input("batch_insert", &ok).is_ok());

        assert_eq!(
            failing_params(
                "batch_insert",
                json!({"table": "users", "columns": ["email"], "rows": "nope"})
            ),
            ["rows"]
        );

        let too_many: Vec<Value> = (0..1001).map(|i| json!([i])).collect();
        assert_eq!(
            failing_params(
                "batch_insert",
                json!({"table": "t", "columns": ["id"], "rows": too_many})
            ),
            ["rows"]
        );

        assert_eq!(
            failing_params(
                "batch_insert_copy",
                json!({"table": "t", "columns": ["id"], "rows": [[1]], "batch_size": 50})
            ),
            ["batch_size"]
        );

        assert_eq!(failing_params("batch_insert", json!({})), ["table", "columns", "rows"]);
    }

    #[test]
    fn display_includes_problem_and_suggestion() {
        let error = ValidationError {
            tool: "describe_table".to_string(),
            param: "table".to_string(),
            error: "boom".to_string(),
            suggestion: "fix it".to_string(),
        };
        let text = error.to_string();
        assert!(text.contains("Validation Error in tool 'describe_table' parameter 'table': boom\n"));
        assert!(text.ends_with("Suggestion: fix it"));
    }
}