1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A single problem found while checking a query against the registered schema.
///
/// The `Display` implementation in this file renders each variant as a
/// human-readable message, including a "Did you mean …?" hint when a
/// suggestion is present.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
    /// The referenced table is not registered with the validator.
    TableNotFound {
        /// Table name as it was passed to the validator.
        table: String,
        /// Closest registered table name, when one is similar enough.
        suggestion: Option<String>,
    },
    /// The table is known and has a column list, but the column is not in it.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as it was passed to the validator.
        column: String,
        /// Closest column name of that table, when one is similar enough.
        suggestion: Option<String>,
    },
    /// A literal value's type category is not accepted by the column's declared type.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column the value was checked against.
        column: String,
        /// Declared column type (the validator reports it upper-cased).
        expected: String,
        /// Type category of the literal, e.g. `"INT"` or `"TEXT"`.
        got: String,
    },
    /// An operator was rejected for a column.
    ///
    /// NOTE(review): nothing in this part of the file constructs this variant;
    /// presumably it is produced elsewhere — confirm before relying on it.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The offending operator, as text.
        operator: String,
        /// Explanation shown to the user.
        reason: String,
    },
}
50
51impl std::fmt::Display for ValidationError {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 ValidationError::TableNotFound { table, suggestion } => {
55 if let Some(s) = suggestion {
56 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
57 } else {
58 write!(f, "Table '{}' not found.", table)
59 }
60 }
61 ValidationError::ColumnNotFound {
62 table,
63 column,
64 suggestion,
65 } => {
66 if let Some(s) = suggestion {
67 write!(
68 f,
69 "Column '{}' not found in table '{}'. Did you mean '{}'?",
70 column, table, s
71 )
72 } else {
73 write!(f, "Column '{}' not found in table '{}'.", column, table)
74 }
75 }
76 ValidationError::TypeMismatch {
77 table,
78 column,
79 expected,
80 got,
81 } => {
82 write!(
83 f,
84 "Type mismatch for '{}.{}': expected {}, got {}",
85 table, column, expected, got
86 )
87 }
88 ValidationError::InvalidOperator {
89 column,
90 operator,
91 reason,
92 } => {
93 write!(
94 f,
95 "Invalid operator '{}' for column '{}': {}",
96 operator, column, reason
97 )
98 }
99 }
100 }
101}
102
// Marker impl: `ValidationError` derives `Debug` and implements `Display`,
// so the default `std::error::Error` methods are sufficient.
impl std::error::Error for ValidationError {}
104
/// Outcome of validating a whole command: `Ok(())` when nothing was flagged,
/// otherwise every [`ValidationError`] that was collected.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Schema-aware validator.
///
/// Holds the registered table names, their column lists and (optionally) the
/// declared column types, and checks table/column references and literal
/// value types against them.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in the order they were added.
    tables: Vec<String>,
    /// Column names per table, keyed by table name. A table registered by
    /// name only has no entry here.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types, keyed by table name and then by column name.
    /// Only populated for tables registered together with their types.
    column_types: HashMap<String, HashMap<String, String>>,
}
118
119impl Default for Validator {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl Validator {
126 pub fn new() -> Self {
128 Self {
129 tables: Vec::new(),
130 columns: HashMap::new(),
131 column_types: HashMap::new(),
132 }
133 }
134
135 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
137 self.tables.push(table.to_string());
138 self.columns.insert(
139 table.to_string(),
140 cols.iter().map(|s| s.to_string()).collect(),
141 );
142 }
143
144 pub fn add_table_name(&mut self, table: &str) {
150 self.tables.push(table.to_string());
151 }
152
153 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
160 self.tables.push(table.to_string());
161 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
162 self.columns.insert(table.to_string(), col_names);
163
164 let type_map: HashMap<String, String> = cols
165 .iter()
166 .map(|(name, typ)| (name.to_string(), typ.to_string()))
167 .collect();
168 self.column_types.insert(table.to_string(), type_map);
169 }
170
171 pub fn table_names(&self) -> &[String] {
173 &self.tables
174 }
175
    /// Returns the column names registered for `table`.
    ///
    /// Yields `None` when no column list is stored for that name, which is the
    /// case both for unknown tables and for tables registered by name only.
    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
        self.columns.get(table)
    }
180
181 pub fn table_exists(&self, table: &str) -> bool {
183 self.tables.contains(&table.to_string())
184 }
185
186 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
188 if self.tables.contains(&table.to_string()) {
189 Ok(())
190 } else {
191 let suggestion = self.did_you_mean(table, &self.tables);
192 Err(ValidationError::TableNotFound {
193 table: table.to_string(),
194 suggestion,
195 })
196 }
197 }
198
199 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
206 if !self.tables.contains(&table.to_string()) {
208 return Ok(());
209 }
210
211 if column == "*" {
213 return Ok(());
214 }
215
216 if column.contains('(')
220 || column.contains('[')
221 || column.contains("::")
222 || column.contains(" AS ")
223 || column.contains(" as ")
224 || column.starts_with("distinct ")
225 || column.starts_with("DISTINCT ")
226 {
227 return Ok(());
228 }
229
230 if column.contains('.') {
232 let parts: Vec<&str> = column.split('.').collect();
233 if parts.len() == 2 {
234 if self.tables.contains(&parts[0].to_string()) {
236 return self.validate_column(parts[0], parts[1]);
237 }
238 }
239 return Ok(());
241 }
242
243 if let Some(cols) = self.columns.get(table) {
244 if cols.contains(&column.to_string()) {
245 Ok(())
246 } else {
247 let suggestion = self.did_you_mean(column, cols);
248 Err(ValidationError::ColumnNotFound {
249 table: table.to_string(),
250 column: column.to_string(),
251 suggestion,
252 })
253 }
254 } else {
255 Ok(())
256 }
257 }
258
259 fn extract_column_name(expr: &Expr) -> Option<String> {
261 match expr {
262 Expr::Named(name) => Some(name.clone()),
263 Expr::Aliased { name, .. } => Some(name.clone()),
264 Expr::Aggregate { col, .. } => Some(col.clone()),
265 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
266 Expr::JsonAccess { column, .. } => Some(column.clone()),
267 _ => None,
268 }
269 }
270
271 pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
273 self.column_types.get(table)?.get(column)
274 }
275
276 pub fn validate_value_type(
279 &self,
280 table: &str,
281 column: &str,
282 value: &crate::ast::Value,
283 ) -> Result<(), ValidationError> {
284 use crate::ast::Value;
285
286 let expected_type = match self.get_column_type(table, column) {
288 Some(t) => t.to_uppercase(),
289 None => return Ok(()), };
291
292 if matches!(value, Value::Null | Value::NullUuid) {
294 return Ok(());
295 }
296
297 if matches!(
299 value,
300 Value::Param(_)
301 | Value::NamedParam(_)
302 | Value::Function(_)
303 | Value::Subquery(_)
304 | Value::Expr(_)
305 ) {
306 return Ok(());
307 }
308
309 if matches!(value, Value::Array(_)) {
313 return Ok(());
314 }
315
316 let value_type = match value {
318 Value::Bool(_) => "BOOLEAN",
319 Value::Int(_) => "INT",
320 Value::Float(_) => "FLOAT",
321 Value::String(_) => "TEXT",
322 Value::Uuid(_) => "UUID",
323 Value::Column(_) => return Ok(()), Value::Interval { .. } => "INTERVAL",
325 Value::Timestamp(_) => "TIMESTAMP",
326 Value::Bytes(_) => "BYTEA",
327 Value::Vector(_) => "VECTOR",
328 Value::Json(_) => "JSONB",
329 _ => return Ok(()), };
331
332 if !Self::types_compatible(&expected_type, value_type) {
334 return Err(ValidationError::TypeMismatch {
335 table: table.to_string(),
336 column: column.to_string(),
337 expected: expected_type,
338 got: value_type.to_string(),
339 });
340 }
341
342 Ok(())
343 }
344
345 fn types_compatible(expected: &str, value_type: &str) -> bool {
348 let expected = expected.to_uppercase();
349 let value_type = value_type.to_uppercase();
350
351 if expected == value_type {
353 return true;
354 }
355
356 let int_types = [
358 "INT",
359 "INT4",
360 "INT8",
361 "INTEGER",
362 "BIGINT",
363 "SMALLINT",
364 "SERIAL",
365 "BIGSERIAL",
366 ];
367 if int_types.contains(&expected.as_str()) && value_type == "INT" {
368 return true;
369 }
370
371 let float_types = [
373 "FLOAT",
374 "FLOAT4",
375 "FLOAT8",
376 "DOUBLE",
377 "DOUBLE PRECISION",
378 "DECIMAL",
379 "NUMERIC",
380 "REAL",
381 ];
382 if float_types.contains(&expected.as_str())
383 && (value_type == "FLOAT" || value_type == "INT")
384 {
385 return true;
386 }
387
388 let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
390 if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
391 return true;
392 }
393
394 if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
396 return true;
397 }
398
399 if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
401 return true;
402 }
403
404 let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
406 if ts_types.contains(&expected.as_str())
407 && (value_type == "TIMESTAMP" || value_type == "TEXT")
408 {
409 return true;
410 }
411
412 if expected == "JSONB" || expected == "JSON" {
414 return true;
415 }
416
417 if expected.contains("[]") || expected.starts_with("_") {
419 return value_type == "ARRAY";
420 }
421
422 false
423 }
424
425 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
427 let mut errors = Vec::new();
428
429 if let Err(e) = self.validate_table(&cmd.table) {
430 errors.push(e);
431 }
432
433 let mut aliases: Vec<String> = Vec::new();
437 for col in &cmd.columns {
438 if let Expr::Aliased { alias, .. } = col {
439 aliases.push(alias.clone());
440 }
441 if let Some(name) = Self::extract_column_name(col)
442 && let Err(e) = self.validate_column(&cmd.table, &name)
443 {
444 errors.push(e);
445 }
446 }
447
448 for cage in &cmd.cages {
449 if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
453 continue;
454 }
455 for cond in &cage.conditions {
456 if let Some(name) = Self::extract_column_name(&cond.left) {
457 if aliases.iter().any(|a| a == &name) {
459 continue;
460 }
461 if name.contains('.') {
463 let parts: Vec<&str> = name.split('.').collect();
464 if parts.len() == 2 {
465 if let Err(e) = self.validate_column(parts[0], parts[1]) {
466 errors.push(e);
467 }
468 if let Err(e) =
470 self.validate_value_type(parts[0], parts[1], &cond.value)
471 {
472 errors.push(e);
473 }
474 }
475 } else {
476 if let Err(e) = self.validate_column(&cmd.table, &name) {
477 errors.push(e);
478 }
479 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
481 errors.push(e);
482 }
483 }
484 }
485 }
486 }
487
488 for cond in &cmd.having {
489 if let Some(name) = Self::extract_column_name(&cond.left) {
490 if name.contains('(') || name == "*" {
491 continue;
492 }
493 if name.contains('.') {
494 let parts: Vec<&str> = name.split('.').collect();
495 if parts.len() == 2 {
496 if let Err(e) = self.validate_column(parts[0], parts[1]) {
497 errors.push(e);
498 }
499 if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
500 errors.push(e);
501 }
502 }
503 } else {
504 if let Err(e) = self.validate_column(&cmd.table, &name) {
505 errors.push(e);
506 }
507 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
508 errors.push(e);
509 }
510 }
511 }
512 }
513
514 for join in &cmd.joins {
515 if let Err(e) = self.validate_table(&join.table) {
516 errors.push(e);
517 }
518
519 if let Some(conditions) = &join.on {
520 for cond in conditions {
521 if let Some(name) = Self::extract_column_name(&cond.left)
522 && name.contains('.')
523 {
524 let parts: Vec<&str> = name.split('.').collect();
525 if parts.len() == 2
526 && let Err(e) = self.validate_column(parts[0], parts[1])
527 {
528 errors.push(e);
529 }
530 }
531 if let crate::ast::Value::Column(col_name) = &cond.value
533 && col_name.contains('.')
534 {
535 let parts: Vec<&str> = col_name.split('.').collect();
536 if parts.len() == 2
537 && let Err(e) = self.validate_column(parts[0], parts[1])
538 {
539 errors.push(e);
540 }
541 }
542 }
543 }
544 }
545
546 if let Some(returning) = &cmd.returning {
547 for col in returning {
548 if let Some(name) = Self::extract_column_name(col)
549 && let Err(e) = self.validate_column(&cmd.table, &name)
550 {
551 errors.push(e);
552 }
553 }
554 }
555
556 if errors.is_empty() {
557 Ok(())
558 } else {
559 Err(errors)
560 }
561 }
562
563 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
565 let mut best_match = None;
566 let mut min_dist = usize::MAX;
567
568 for cand in candidates {
569 let cand_str = cand.as_ref();
570 let dist = levenshtein(input, cand_str);
571
572 let threshold = match input.len() {
574 0..=2 => 0, 3..=5 => 2, _ => 3, };
578
579 if dist <= threshold && dist < min_dist {
580 min_dist = dist;
581 best_match = Some(cand_str.to_string());
582 }
583 }
584
585 best_match
586 }
587}
588
// Unit tests for the validator: suggestions, qualified names, whole-command
// validation and error message formatting.
#[cfg(test)]
mod tests {
    use super::*;

    // Unknown table names that are close to a registered one get a suggestion.
    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        // Two characters missing.
        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        // One character missing.
        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );
    }

    // Known columns and `*` pass; a transposed-letter typo yields a suggestion.
    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    // `table.column` references are resolved against the named table, even
    // when it differs from the table passed as context.
    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
    }

    // A command with valid columns passes; a misspelled column is reported once.
    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        // "emial" is a deliberate typo.
        let cmd = Qail::get("users").columns(["id", "emial"]);
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    // Columns referenced in HAVING conditions are validated as well.
    #[test]
    fn test_validate_having_columns() {
        let mut v = Validator::new();
        v.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        });

        let errors = v.validate_command(&cmd).unwrap_err();
        assert!(errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl")
        ));
    }

    // Pins the exact user-facing wording of the "not found" messages.
    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}