1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A single problem found while checking a command against the registered schema.
///
/// All payloads are plain strings, so the type derives `Eq` in addition to
/// `PartialEq`; this lets callers use it with APIs that require full equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The referenced table is not registered with the validator.
    TableNotFound {
        /// Table name exactly as it was written in the command.
        table: String,
        /// Closest registered table name, when one is similar enough.
        suggestion: Option<String>,
    },
    /// The referenced column is not registered for the given table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name exactly as it was written in the command.
        column: String,
        /// Closest registered column name, when one is similar enough.
        suggestion: Option<String>,
    },
    /// A literal value does not fit the declared type of the column.
    TypeMismatch {
        /// Table that owns the column.
        table: String,
        /// Column the value was compared against.
        column: String,
        /// Declared column type.
        expected: String,
        /// Type category of the supplied value.
        got: String,
    },
    /// An operator cannot be applied to the given column.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// Operator as written.
        operator: String,
        /// Human-readable explanation of why the operator is rejected.
        reason: String,
    },
}
36
37impl std::fmt::Display for ValidationError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 ValidationError::TableNotFound { table, suggestion } => {
41 if let Some(s) = suggestion {
42 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
43 } else {
44 write!(f, "Table '{}' not found.", table)
45 }
46 }
47 ValidationError::ColumnNotFound {
48 table,
49 column,
50 suggestion,
51 } => {
52 if let Some(s) = suggestion {
53 write!(
54 f,
55 "Column '{}' not found in table '{}'. Did you mean '{}'?",
56 column, table, s
57 )
58 } else {
59 write!(f, "Column '{}' not found in table '{}'.", column, table)
60 }
61 }
62 ValidationError::TypeMismatch {
63 table,
64 column,
65 expected,
66 got,
67 } => {
68 write!(
69 f,
70 "Type mismatch for '{}.{}': expected {}, got {}",
71 table, column, expected, got
72 )
73 }
74 ValidationError::InvalidOperator {
75 column,
76 operator,
77 reason,
78 } => {
79 write!(
80 f,
81 "Invalid operator '{}' for column '{}': {}",
82 operator, column, reason
83 )
84 }
85 }
86 }
87}
88
// Empty impl: the `Debug` and `Display` impls already satisfy the trait's
// requirements, so this only opts `ValidationError` into `std::error::Error`.
impl std::error::Error for ValidationError {}
90
/// Outcome of validating a whole command: `Ok(())` when nothing was wrong,
/// otherwise `Err` carrying every [`ValidationError`] that was collected.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
93
/// Schema-aware validator holding the known tables, their columns and
/// (optionally) the declared type of each column.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in registration order.
    tables: Vec<String>,
    /// Column names per table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types: table name -> column name -> type string.
    /// Only populated for tables registered together with type information.
    column_types: HashMap<String, HashMap<String, String>>,
}
102
103impl Default for Validator {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl Validator {
110 pub fn new() -> Self {
112 Self {
113 tables: Vec::new(),
114 columns: HashMap::new(),
115 column_types: HashMap::new(),
116 }
117 }
118
119 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
121 self.tables.push(table.to_string());
122 self.columns.insert(
123 table.to_string(),
124 cols.iter().map(|s| s.to_string()).collect(),
125 );
126 }
127
128 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
130 self.tables.push(table.to_string());
131 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
132 self.columns.insert(table.to_string(), col_names);
133
134 let type_map: HashMap<String, String> = cols
135 .iter()
136 .map(|(name, typ)| (name.to_string(), typ.to_string()))
137 .collect();
138 self.column_types.insert(table.to_string(), type_map);
139 }
140
141 pub fn table_names(&self) -> &[String] {
143 &self.tables
144 }
145
146 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
148 self.columns.get(table)
149 }
150
151 pub fn table_exists(&self, table: &str) -> bool {
153 self.tables.contains(&table.to_string())
154 }
155
156 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
158 if self.tables.contains(&table.to_string()) {
159 Ok(())
160 } else {
161 let suggestion = self.did_you_mean(table, &self.tables);
162 Err(ValidationError::TableNotFound {
163 table: table.to_string(),
164 suggestion,
165 })
166 }
167 }
168
169 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
171 if !self.tables.contains(&table.to_string()) {
173 return Ok(());
174 }
175
176 if column == "*" || column.contains('.') {
178 return Ok(());
179 }
180
181 if let Some(cols) = self.columns.get(table) {
182 if cols.contains(&column.to_string()) {
183 Ok(())
184 } else {
185 let suggestion = self.did_you_mean(column, cols);
186 Err(ValidationError::ColumnNotFound {
187 table: table.to_string(),
188 column: column.to_string(),
189 suggestion,
190 })
191 }
192 } else {
193 Ok(())
194 }
195 }
196
197 fn extract_column_name(expr: &Expr) -> Option<String> {
199 match expr {
200 Expr::Named(name) => Some(name.clone()),
201 Expr::Aliased { name, .. } => Some(name.clone()),
202 Expr::Aggregate { col, .. } => Some(col.clone()),
203 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
204 Expr::JsonAccess { column, .. } => Some(column.clone()),
205 _ => None,
206 }
207 }
208
209 pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
211 self.column_types.get(table)?.get(column)
212 }
213
214 pub fn validate_value_type(
217 &self,
218 table: &str,
219 column: &str,
220 value: &crate::ast::Value,
221 ) -> Result<(), ValidationError> {
222 use crate::ast::Value;
223
224 let expected_type = match self.get_column_type(table, column) {
226 Some(t) => t.to_uppercase(),
227 None => return Ok(()), };
229
230 if matches!(value, Value::Null | Value::NullUuid) {
232 return Ok(());
233 }
234
235 if matches!(value, Value::Param(_) | Value::NamedParam(_) | Value::Function(_) | Value::Subquery(_) | Value::Expr(_)) {
237 return Ok(());
238 }
239
240 let value_type = match value {
242 Value::Bool(_) => "BOOLEAN",
243 Value::Int(_) => "INT",
244 Value::Float(_) => "FLOAT",
245 Value::String(_) => "TEXT",
246 Value::Uuid(_) => "UUID",
247 Value::Array(_) => "ARRAY",
248 Value::Column(_) => return Ok(()), Value::Interval { .. } => "INTERVAL",
250 Value::Timestamp(_) => "TIMESTAMP",
251 Value::Bytes(_) => "BYTEA",
252 Value::Vector(_) => "VECTOR",
253 Value::Json(_) => "JSONB",
254 _ => return Ok(()), };
256
257 if !Self::types_compatible(&expected_type, value_type) {
259 return Err(ValidationError::TypeMismatch {
260 table: table.to_string(),
261 column: column.to_string(),
262 expected: expected_type,
263 got: value_type.to_string(),
264 });
265 }
266
267 Ok(())
268 }
269
/// Decides whether a value of category `value_type` may be used with a column
/// declared as `expected`.
///
/// Both names are compared case-insensitively. Before comparison a type
/// modifier in parentheses is dropped and runs of whitespace are collapsed,
/// so `varchar(255)`, `NUMERIC(10, 2)` and `timestamp(3) with time zone` are
/// treated like their base types instead of being rejected as unknown.
///
/// Rules, in order:
/// - identical names are compatible;
/// - integer columns accept `INT`;
/// - floating/decimal columns accept `FLOAT` and `INT`;
/// - text columns accept `TEXT`;
/// - boolean columns accept `BOOLEAN`;
/// - `UUID` columns accept `UUID` and `TEXT`;
/// - date/time columns accept `TIMESTAMP` and `TEXT`;
/// - `JSON`/`JSONB` columns accept anything;
/// - array columns (`...[]` or a leading `_`) accept only `ARRAY`;
/// - everything else is incompatible.
fn types_compatible(expected: &str, value_type: &str) -> bool {
    // Upper-case, strip a `( ... )` modifier, collapse whitespace.
    fn canonical(ty: &str) -> String {
        let upper = ty.to_uppercase();
        let stripped = match (upper.find('('), upper.rfind(')')) {
            // Both delimiters are ASCII, so slicing at their byte offsets is safe.
            (Some(open), Some(close)) if open < close => {
                format!("{}{}", &upper[..open], &upper[close + 1..])
            }
            _ => upper,
        };
        stripped.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    const INT_TYPES: &[&str] = &[
        "INT",
        "INT2",
        "INT4",
        "INT8",
        "INTEGER",
        "BIGINT",
        "SMALLINT",
        "SERIAL",
        "BIGSERIAL",
        "SMALLSERIAL",
    ];
    const FLOAT_TYPES: &[&str] = &[
        "FLOAT",
        "FLOAT4",
        "FLOAT8",
        "DOUBLE",
        "DOUBLE PRECISION",
        "DECIMAL",
        "NUMERIC",
        "REAL",
    ];
    const TEXT_TYPES: &[&str] = &[
        "TEXT",
        "VARCHAR",
        "CHAR",
        "CHARACTER",
        "CHARACTER VARYING",
        "BPCHAR",
        "CITEXT",
        "NAME",
    ];
    const TEMPORAL_TYPES: &[&str] = &[
        "TIMESTAMP",
        "TIMESTAMPTZ",
        "DATE",
        "TIME",
        "TIMETZ",
        "TIMESTAMP WITH TIME ZONE",
        "TIMESTAMP WITHOUT TIME ZONE",
        "TIME WITH TIME ZONE",
        "TIME WITHOUT TIME ZONE",
    ];

    let expected = canonical(expected);
    let value_type = canonical(value_type);
    let expected = expected.as_str();
    let value_type = value_type.as_str();

    if expected == value_type {
        return true;
    }

    if INT_TYPES.contains(&expected) && value_type == "INT" {
        return true;
    }

    // Integer literals are accepted for floating/decimal columns as well.
    if FLOAT_TYPES.contains(&expected) && (value_type == "FLOAT" || value_type == "INT") {
        return true;
    }

    if TEXT_TYPES.contains(&expected) && value_type == "TEXT" {
        return true;
    }

    if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
        return true;
    }

    // UUIDs and date/time values are commonly written as string literals.
    if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
        return true;
    }

    if TEMPORAL_TYPES.contains(&expected) && (value_type == "TIMESTAMP" || value_type == "TEXT") {
        return true;
    }

    // JSON columns can hold any kind of value.
    if expected == "JSONB" || expected == "JSON" {
        return true;
    }

    if expected.contains("[]") || expected.starts_with('_') {
        return value_type == "ARRAY";
    }

    false
}
327
328 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
330 let mut errors = Vec::new();
331
332 if let Err(e) = self.validate_table(&cmd.table) {
333 errors.push(e);
334 }
335
336 for col in &cmd.columns {
337 if let Some(name) = Self::extract_column_name(col)
338 && let Err(e) = self.validate_column(&cmd.table, &name)
339 {
340 errors.push(e);
341 }
342 }
343
344 for cage in &cmd.cages {
345 for cond in &cage.conditions {
346 if let Some(name) = Self::extract_column_name(&cond.left) {
347 if name.contains('.') {
349 let parts: Vec<&str> = name.split('.').collect();
350 if parts.len() == 2 {
351 if let Err(e) = self.validate_column(parts[0], parts[1]) {
352 errors.push(e);
353 }
354 if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
356 errors.push(e);
357 }
358 }
359 } else {
360 if let Err(e) = self.validate_column(&cmd.table, &name) {
361 errors.push(e);
362 }
363 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
365 errors.push(e);
366 }
367 }
368 }
369 }
370 }
371
372 for join in &cmd.joins {
373
374 if let Err(e) = self.validate_table(&join.table) {
375 errors.push(e);
376 }
377
378
379 if let Some(conditions) = &join.on {
380 for cond in conditions {
381 if let Some(name) = Self::extract_column_name(&cond.left)
382 && name.contains('.')
383 {
384 let parts: Vec<&str> = name.split('.').collect();
385 if parts.len() == 2
386 && let Err(e) = self.validate_column(parts[0], parts[1])
387 {
388 errors.push(e);
389 }
390 }
391 if let crate::ast::Value::Column(col_name) = &cond.value
393 && col_name.contains('.')
394 {
395 let parts: Vec<&str> = col_name.split('.').collect();
396 if parts.len() == 2
397 && let Err(e) = self.validate_column(parts[0], parts[1])
398 {
399 errors.push(e);
400 }
401 }
402 }
403 }
404 }
405
406 if let Some(returning) = &cmd.returning {
407 for col in returning {
408 if let Some(name) = Self::extract_column_name(col)
409 && let Err(e) = self.validate_column(&cmd.table, &name)
410 {
411 errors.push(e);
412 }
413 }
414 }
415
416 if errors.is_empty() {
417 Ok(())
418 } else {
419 Err(errors)
420 }
421 }
422
423 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
425 let mut best_match = None;
426 let mut min_dist = usize::MAX;
427
428 for cand in candidates {
429 let cand_str = cand.as_ref();
430 let dist = levenshtein(input, cand_str);
431
432 let threshold = match input.len() {
434 0..=2 => 0, 3..=5 => 2, _ => 3, };
438
439 if dist <= threshold && dist < min_dist {
440 min_dist = dist;
441 best_match = Some(cand_str.to_string());
442 }
443 }
444
445 best_match
446 }
447
448 #[deprecated(note = "Use validate_table() which returns ValidationError")]
454 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
455 self.validate_table(table).map_err(|e| e.to_string())
456 }
457
458 #[deprecated(note = "Use validate_column() which returns ValidationError")]
460 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
461 self.validate_column(table, column)
462 .map_err(|e| e.to_string())
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Misspelled table names are rejected with the closest registered name.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert!(validator.validate_table("users").is_ok());

        for typo in ["usr", "usrs"] {
            match validator.validate_table(typo).unwrap_err() {
                ValidationError::TableNotFound {
                    suggestion: Some(hint),
                    ..
                } => assert_eq!(hint, "users"),
                other => panic!("unexpected error for '{typo}': {other:?}"),
            }
        }
    }

    /// Misspelled column names are rejected with the closest registered name.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        assert!(validator.validate_column("users", "email").is_ok());
        assert!(validator.validate_column("users", "*").is_ok());

        match validator.validate_column("users", "emial").unwrap_err() {
            ValidationError::ColumnNotFound {
                suggestion: Some(hint),
                ..
            } => assert_eq!(hint, "email"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Qualified `table.column` names are accepted by `validate_column`.
    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        for qualified in ["users.id", "profiles.user_id"] {
            assert!(validator.validate_column("users", qualified).is_ok());
        }
    }

    /// A command with one misspelled column yields exactly one error.
    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let valid_cmd = Qail::get("users").columns(["id", "email"]);
        assert!(validator.validate_command(&valid_cmd).is_ok());

        let typo_cmd = Qail::get("users").columns(["id", "emial"]);
        let errors = validator.validate_command(&typo_cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        match &errors[0] {
            ValidationError::ColumnNotFound { column, .. } => assert_eq!(column, "emial"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// `Display` output includes the suggestion when one is present.
    #[test]
    fn test_error_display() {
        let table_err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            table_err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let column_err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            column_err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}