1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A problem found while checking a query against the registered schema.
///
/// Every variant owns its identifiers as `String`s, so an error can be kept
/// around independently of the validator and the command it was produced
/// from. All fields are plain strings, which is why the type can be `Eq` as
/// well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The referenced table is not registered.
    TableNotFound {
        /// Table name as it was written in the query.
        table: String,
        /// A registered table name that is close to `table`, if any.
        suggestion: Option<String>,
    },
    /// The referenced column is not registered for the table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as it was written in the query.
        column: String,
        /// A registered column name that is close to `column`, if any.
        suggestion: Option<String>,
    },
    /// A literal value does not fit the declared type of a column.
    TypeMismatch {
        /// Table that owns the column.
        table: String,
        /// Column the value was compared against.
        column: String,
        /// Declared type of the column.
        expected: String,
        /// Type category of the value that was supplied.
        got: String,
    },
    /// An operator cannot be used with the given column.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The offending operator, rendered as text.
        operator: String,
        /// Explanation of why the operator is rejected.
        reason: String,
    },
}
50
51impl std::fmt::Display for ValidationError {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 ValidationError::TableNotFound { table, suggestion } => {
55 if let Some(s) = suggestion {
56 write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
57 } else {
58 write!(f, "Table '{}' not found.", table)
59 }
60 }
61 ValidationError::ColumnNotFound {
62 table,
63 column,
64 suggestion,
65 } => {
66 if let Some(s) = suggestion {
67 write!(
68 f,
69 "Column '{}' not found in table '{}'. Did you mean '{}'?",
70 column, table, s
71 )
72 } else {
73 write!(f, "Column '{}' not found in table '{}'.", column, table)
74 }
75 }
76 ValidationError::TypeMismatch {
77 table,
78 column,
79 expected,
80 got,
81 } => {
82 write!(
83 f,
84 "Type mismatch for '{}.{}': expected {}, got {}",
85 table, column, expected, got
86 )
87 }
88 ValidationError::InvalidOperator {
89 column,
90 operator,
91 reason,
92 } => {
93 write!(
94 f,
95 "Invalid operator '{}' for column '{}': {}",
96 operator, column, reason
97 )
98 }
99 }
100 }
101}
102
103impl std::error::Error for ValidationError {}
104
105pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Checks table and column references (and literal value types) of a query
/// against a schema that has been registered up front.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Names of all registered tables, in registration order.
    tables: Vec<String>,
    /// Column names per table. A table may be registered without an entry
    /// here, in which case its columns are not checked.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types, keyed by table name and then column name.
    /// Only populated for tables registered together with type information;
    /// read back by `get_column_type`, so no lint suppression is needed.
    column_types: HashMap<String, HashMap<String, String>>,
}
119
120impl Default for Validator {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl Validator {
127 pub fn new() -> Self {
129 Self {
130 tables: Vec::new(),
131 columns: HashMap::new(),
132 column_types: HashMap::new(),
133 }
134 }
135
136 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
138 self.tables.push(table.to_string());
139 self.columns.insert(
140 table.to_string(),
141 cols.iter().map(|s| s.to_string()).collect(),
142 );
143 }
144
145 pub fn add_table_name(&mut self, table: &str) {
151 self.tables.push(table.to_string());
152 }
153
154 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
161 self.tables.push(table.to_string());
162 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
163 self.columns.insert(table.to_string(), col_names);
164
165 let type_map: HashMap<String, String> = cols
166 .iter()
167 .map(|(name, typ)| (name.to_string(), typ.to_string()))
168 .collect();
169 self.column_types.insert(table.to_string(), type_map);
170 }
171
172 pub fn table_names(&self) -> &[String] {
174 &self.tables
175 }
176
177 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
179 self.columns.get(table)
180 }
181
182 pub fn table_exists(&self, table: &str) -> bool {
184 self.tables.contains(&table.to_string())
185 }
186
187 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
189 if self.tables.contains(&table.to_string()) {
190 Ok(())
191 } else {
192 let suggestion = self.did_you_mean(table, &self.tables);
193 Err(ValidationError::TableNotFound {
194 table: table.to_string(),
195 suggestion,
196 })
197 }
198 }
199
200 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
207 if !self.tables.contains(&table.to_string()) {
209 return Ok(());
210 }
211
212 if column == "*" {
214 return Ok(());
215 }
216
217 if column.contains('(')
221 || column.contains('[')
222 || column.contains("::")
223 || column.contains(" AS ")
224 || column.contains(" as ")
225 || column.starts_with("distinct ")
226 || column.starts_with("DISTINCT ")
227 {
228 return Ok(());
229 }
230
231 if column.contains('.') {
233 let parts: Vec<&str> = column.split('.').collect();
234 if parts.len() == 2 {
235 if self.tables.contains(&parts[0].to_string()) {
237 return self.validate_column(parts[0], parts[1]);
238 }
239 }
240 return Ok(());
242 }
243
244 if let Some(cols) = self.columns.get(table) {
245 if cols.contains(&column.to_string()) {
246 Ok(())
247 } else {
248 let suggestion = self.did_you_mean(column, cols);
249 Err(ValidationError::ColumnNotFound {
250 table: table.to_string(),
251 column: column.to_string(),
252 suggestion,
253 })
254 }
255 } else {
256 Ok(())
257 }
258 }
259
260 fn extract_column_name(expr: &Expr) -> Option<String> {
262 match expr {
263 Expr::Named(name) => Some(name.clone()),
264 Expr::Aliased { name, .. } => Some(name.clone()),
265 Expr::Aggregate { col, .. } => Some(col.clone()),
266 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
267 Expr::JsonAccess { column, .. } => Some(column.clone()),
268 _ => None,
269 }
270 }
271
272 pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
274 self.column_types.get(table)?.get(column)
275 }
276
277 pub fn validate_value_type(
280 &self,
281 table: &str,
282 column: &str,
283 value: &crate::ast::Value,
284 ) -> Result<(), ValidationError> {
285 use crate::ast::Value;
286
287 let expected_type = match self.get_column_type(table, column) {
289 Some(t) => t.to_uppercase(),
290 None => return Ok(()), };
292
293 if matches!(value, Value::Null | Value::NullUuid) {
295 return Ok(());
296 }
297
298 if matches!(
300 value,
301 Value::Param(_)
302 | Value::NamedParam(_)
303 | Value::Function(_)
304 | Value::Subquery(_)
305 | Value::Expr(_)
306 ) {
307 return Ok(());
308 }
309
310 if matches!(value, Value::Array(_)) {
314 return Ok(());
315 }
316
317 let value_type = match value {
319 Value::Bool(_) => "BOOLEAN",
320 Value::Int(_) => "INT",
321 Value::Float(_) => "FLOAT",
322 Value::String(_) => "TEXT",
323 Value::Uuid(_) => "UUID",
324 Value::Column(_) => return Ok(()), Value::Interval { .. } => "INTERVAL",
326 Value::Timestamp(_) => "TIMESTAMP",
327 Value::Bytes(_) => "BYTEA",
328 Value::Vector(_) => "VECTOR",
329 Value::Json(_) => "JSONB",
330 _ => return Ok(()), };
332
333 if !Self::types_compatible(&expected_type, value_type) {
335 return Err(ValidationError::TypeMismatch {
336 table: table.to_string(),
337 column: column.to_string(),
338 expected: expected_type,
339 got: value_type.to_string(),
340 });
341 }
342
343 Ok(())
344 }
345
/// Decides whether a value of category `value_type` may be stored in a
/// column whose declared type is `expected`.
///
/// `expected` is compared case-insensitively and after removing any type
/// modifier, so `varchar(255)`, `NUMERIC(10, 2)`, `vector(1536)` and
/// `timestamp(3) with time zone` are treated like their unmodified base
/// types. `value_type` is one of the category names produced by the value
/// classification (`INT`, `FLOAT`, `TEXT`, ...).
///
/// Rules, in order:
/// - identical names are always compatible;
/// - integer columns accept `INT`;
/// - floating-point / numeric columns accept `FLOAT` and `INT`;
/// - text columns accept `TEXT`;
/// - boolean columns accept `BOOLEAN`;
/// - `UUID` columns accept `UUID` and `TEXT`;
/// - date/time columns accept `TIMESTAMP` and `TEXT`;
/// - `JSON` / `JSONB` columns accept anything;
/// - array columns (`type[]` or `_type`) accept only `ARRAY`;
/// - everything else is incompatible.
fn types_compatible(expected: &str, value_type: &str) -> bool {
    /// Upper-cases a declared type, drops every parenthesised modifier and
    /// collapses runs of whitespace: `"numeric(10, 2)"` becomes `"NUMERIC"`.
    fn normalize_declared_type(raw: &str) -> String {
        let mut depth = 0usize;
        let mut stripped = String::with_capacity(raw.len());
        for ch in raw.to_uppercase().chars() {
            match ch {
                '(' => depth += 1,
                ')' => depth = depth.saturating_sub(1),
                _ if depth == 0 => stripped.push(ch),
                _ => {}
            }
        }
        stripped.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    let expected = normalize_declared_type(expected);
    let expected = expected.as_str();
    let value_type = value_type.to_uppercase();
    let value_type = value_type.as_str();

    if expected == value_type {
        return true;
    }

    match expected {
        "INT" | "INT2" | "INT4" | "INT8" | "INTEGER" | "BIGINT" | "SMALLINT" | "SERIAL"
        | "BIGSERIAL" | "SMALLSERIAL" => value_type == "INT",

        // An integer literal is a valid value for any fractional column.
        "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "DECIMAL"
        | "NUMERIC" | "REAL" => matches!(value_type, "FLOAT" | "INT"),

        "TEXT" | "VARCHAR" | "CHARACTER VARYING" | "CHAR" | "CHARACTER" | "BPCHAR"
        | "CITEXT" | "NAME" => value_type == "TEXT",

        "BOOLEAN" | "BOOL" => value_type == "BOOLEAN",

        // UUIDs and date/time values are commonly written as string literals.
        "UUID" => matches!(value_type, "UUID" | "TEXT"),
        "TIMESTAMP"
        | "TIMESTAMPTZ"
        | "TIMESTAMP WITH TIME ZONE"
        | "TIMESTAMP WITHOUT TIME ZONE"
        | "DATE"
        | "TIME"
        | "TIMETZ"
        | "TIME WITH TIME ZONE"
        | "TIME WITHOUT TIME ZONE" => matches!(value_type, "TIMESTAMP" | "TEXT"),

        // Any literal is accepted for a JSON column.
        "JSONB" | "JSON" => true,

        // `type[]` or the internal `_type` spelling denotes an array column.
        _ if expected.contains("[]") || expected.starts_with('_') => value_type == "ARRAY",

        _ => false,
    }
}
425
426 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
428 let mut errors = Vec::new();
429
430 if let Err(e) = self.validate_table(&cmd.table) {
431 errors.push(e);
432 }
433
434 let mut aliases: Vec<String> = Vec::new();
438 for col in &cmd.columns {
439 if let Expr::Aliased { alias, .. } = col {
440 aliases.push(alias.clone());
441 }
442 if let Some(name) = Self::extract_column_name(col)
443 && let Err(e) = self.validate_column(&cmd.table, &name)
444 {
445 errors.push(e);
446 }
447 }
448
449 for cage in &cmd.cages {
450 if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
454 continue;
455 }
456 for cond in &cage.conditions {
457 if let Some(name) = Self::extract_column_name(&cond.left) {
458 if aliases.iter().any(|a| a == &name) {
460 continue;
461 }
462 if name.contains('.') {
464 let parts: Vec<&str> = name.split('.').collect();
465 if parts.len() == 2 {
466 if let Err(e) = self.validate_column(parts[0], parts[1]) {
467 errors.push(e);
468 }
469 if let Err(e) =
471 self.validate_value_type(parts[0], parts[1], &cond.value)
472 {
473 errors.push(e);
474 }
475 }
476 } else {
477 if let Err(e) = self.validate_column(&cmd.table, &name) {
478 errors.push(e);
479 }
480 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
482 errors.push(e);
483 }
484 }
485 }
486 }
487 }
488
489 for cond in &cmd.having {
490 if let Some(name) = Self::extract_column_name(&cond.left) {
491 if name.contains('(') || name == "*" {
492 continue;
493 }
494 if name.contains('.') {
495 let parts: Vec<&str> = name.split('.').collect();
496 if parts.len() == 2 {
497 if let Err(e) = self.validate_column(parts[0], parts[1]) {
498 errors.push(e);
499 }
500 if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
501 errors.push(e);
502 }
503 }
504 } else {
505 if let Err(e) = self.validate_column(&cmd.table, &name) {
506 errors.push(e);
507 }
508 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
509 errors.push(e);
510 }
511 }
512 }
513 }
514
515 for join in &cmd.joins {
516 if let Err(e) = self.validate_table(&join.table) {
517 errors.push(e);
518 }
519
520 if let Some(conditions) = &join.on {
521 for cond in conditions {
522 if let Some(name) = Self::extract_column_name(&cond.left)
523 && name.contains('.')
524 {
525 let parts: Vec<&str> = name.split('.').collect();
526 if parts.len() == 2
527 && let Err(e) = self.validate_column(parts[0], parts[1])
528 {
529 errors.push(e);
530 }
531 }
532 if let crate::ast::Value::Column(col_name) = &cond.value
534 && col_name.contains('.')
535 {
536 let parts: Vec<&str> = col_name.split('.').collect();
537 if parts.len() == 2
538 && let Err(e) = self.validate_column(parts[0], parts[1])
539 {
540 errors.push(e);
541 }
542 }
543 }
544 }
545 }
546
547 if let Some(returning) = &cmd.returning {
548 for col in returning {
549 if let Some(name) = Self::extract_column_name(col)
550 && let Err(e) = self.validate_column(&cmd.table, &name)
551 {
552 errors.push(e);
553 }
554 }
555 }
556
557 if errors.is_empty() {
558 Ok(())
559 } else {
560 Err(errors)
561 }
562 }
563
564 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
566 let mut best_match = None;
567 let mut min_dist = usize::MAX;
568
569 for cand in candidates {
570 let cand_str = cand.as_ref();
571 let dist = levenshtein(input, cand_str);
572
573 let threshold = match input.len() {
575 0..=2 => 0, 3..=5 => 2, _ => 3, };
579
580 if dist <= threshold && dist < min_dist {
581 min_dist = dist;
582 best_match = Some(cand_str.to_string());
583 }
584 }
585
586 best_match
587 }
588
589 #[deprecated(note = "Use validate_table() which returns ValidationError")]
595 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
596 self.validate_table(table).map_err(|e| e.to_string())
597 }
598
599 #[deprecated(note = "Use validate_column() which returns ValidationError")]
601 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
602 self.validate_column(table, column)
603 .map_err(|e| e.to_string())
604 }
605}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the suggestion attached to a "not found" error, if any.
    fn suggestion_of(err: &ValidationError) -> Option<&str> {
        match err {
            ValidationError::TableNotFound { suggestion, .. }
            | ValidationError::ColumnNotFound { suggestion, .. } => suggestion.as_deref(),
            _ => None,
        }
    }

    /// Misspelled table names are rejected with the nearest table suggested.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert!(validator.validate_table("users").is_ok());

        for typo in ["usr", "usrs"] {
            let err = validator.validate_table(typo).unwrap_err();
            assert!(matches!(err, ValidationError::TableNotFound { .. }));
            assert_eq!(suggestion_of(&err), Some("users"));
        }
    }

    /// Misspelled column names are rejected with the nearest column suggested;
    /// `*` is always accepted.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        assert!(validator.validate_column("users", "email").is_ok());
        assert!(validator.validate_column("users", "*").is_ok());

        let err = validator.validate_column("users", "emial").unwrap_err();
        assert!(matches!(err, ValidationError::ColumnNotFound { .. }));
        assert_eq!(suggestion_of(&err), Some("email"));
    }

    /// `table.column` names are resolved against the table named in the
    /// qualifier, not the table passed as the first argument.
    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        assert!(validator.validate_column("users", "users.id").is_ok());
        assert!(validator.validate_column("users", "profiles.user_id").is_ok());
    }

    /// A command with valid columns passes; one misspelled column yields
    /// exactly one `ColumnNotFound` error.
    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let valid = Qail::get("users").columns(["id", "email"]);
        assert!(validator.validate_command(&valid).is_ok());

        let misspelled = Qail::get("users").columns(["id", "emial"]);
        let errors = validator.validate_command(&misspelled).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            errors.first(),
            Some(ValidationError::ColumnNotFound { column, .. }) if column == "emial"
        ));
    }

    /// Columns referenced in `having` conditions are validated as well.
    #[test]
    fn test_validate_having_columns() {
        let mut validator = Validator::new();
        validator.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        let misspelled_condition = crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        };
        cmd.having.push(misspelled_condition);

        let errors = validator.validate_command(&cmd).unwrap_err();
        let reported = errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl"),
        );
        assert!(reported);
    }

    /// The rendered messages include the "Did you mean" hint.
    #[test]
    fn test_error_display() {
        let table_err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            table_err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let column_err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            column_err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}