1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A problem found while checking a query against the registered schema.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
    /// The referenced table is not registered.
    TableNotFound {
        /// Table name as written in the query.
        table: String,
        /// Closest registered table name, when one was similar enough.
        suggestion: Option<String>,
    },
    /// The referenced column is not in the table's registered column list.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as written in the query.
        column: String,
        /// Closest registered column name, when one was similar enough.
        suggestion: Option<String>,
    },
    /// A literal value does not fit the column's declared type.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column the value was compared against.
        column: String,
        /// Declared column type.
        expected: String,
        /// Type inferred from the literal value.
        got: String,
    },
    /// An operator that cannot be applied to the given column.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The offending operator.
        operator: String,
        /// Human-readable explanation.
        reason: String,
    },
}

impl std::fmt::Display for ValidationError {
    /// Renders a one-line, user-facing description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two "not found" variants share an optional " Did you mean ...?" tail;
        // the remaining variants are rendered completely inside their own arm.
        let suggestion = match self {
            Self::TableNotFound { table, suggestion } => {
                write!(f, "Table '{table}' not found.")?;
                suggestion
            }
            Self::ColumnNotFound {
                table,
                column,
                suggestion,
            } => {
                write!(f, "Column '{column}' not found in table '{table}'.")?;
                suggestion
            }
            Self::TypeMismatch {
                table,
                column,
                expected,
                got,
            } => {
                return write!(
                    f,
                    "Type mismatch for '{table}.{column}': expected {expected}, got {got}"
                );
            }
            Self::InvalidOperator {
                column,
                operator,
                reason,
            } => {
                return write!(
                    f,
                    "Invalid operator '{operator}' for column '{column}': {reason}"
                );
            }
        };

        match suggestion {
            Some(name) => write!(f, " Did you mean '{name}'?"),
            None => Ok(()),
        }
    }
}

impl std::error::Error for ValidationError {}

/// Outcome of validating a whole command: `Ok(())`, or every problem that was found.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Schema registry used to check table names, column names and literal value types.
///
/// `Default` yields an empty registry (no tables, columns or types).
#[derive(Debug, Clone, Default)]
pub struct Validator {
    /// Registered table names, in registration order.
    tables: Vec<String>,
    /// Column names keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types keyed by table name, then by column name.
    column_types: HashMap<String, HashMap<String, String>>,
}
124
125impl Validator {
126 pub fn new() -> Self {
128 Self {
129 tables: Vec::new(),
130 columns: HashMap::new(),
131 column_types: HashMap::new(),
132 }
133 }
134
135 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
137 self.tables.push(table.to_string());
138 self.columns.insert(
139 table.to_string(),
140 cols.iter().map(|s| s.to_string()).collect(),
141 );
142 }
143
144 pub fn add_table_name(&mut self, table: &str) {
150 self.tables.push(table.to_string());
151 }
152
153 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
160 self.tables.push(table.to_string());
161 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
162 self.columns.insert(table.to_string(), col_names);
163
164 let type_map: HashMap<String, String> = cols
165 .iter()
166 .map(|(name, typ)| (name.to_string(), typ.to_string()))
167 .collect();
168 self.column_types.insert(table.to_string(), type_map);
169 }
170
171 pub fn table_names(&self) -> &[String] {
173 &self.tables
174 }
175
176 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
178 self.columns.get(table)
179 }
180
181 pub fn table_exists(&self, table: &str) -> bool {
183 self.tables.contains(&table.to_string())
184 }
185
186 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
188 if self.tables.contains(&table.to_string()) {
189 Ok(())
190 } else {
191 let suggestion = self.did_you_mean(table, &self.tables);
192 Err(ValidationError::TableNotFound {
193 table: table.to_string(),
194 suggestion,
195 })
196 }
197 }
198
199 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
206 if !self.tables.contains(&table.to_string()) {
208 return Ok(());
209 }
210
211 if column == "*" {
213 return Ok(());
214 }
215
216 if column.contains('(')
220 || column.contains('[')
221 || column.contains("::")
222 || column.contains(" AS ")
223 || column.contains(" as ")
224 || column.starts_with("distinct ")
225 || column.starts_with("DISTINCT ")
226 {
227 return Ok(());
228 }
229
230 if column.contains('.') {
232 let parts: Vec<&str> = column.split('.').collect();
233 if parts.len() == 3 && parts[0] == "public" {
234 let public_table = format!("{}.{}", parts[0], parts[1]);
235 if self.tables.contains(&public_table) {
236 return self.validate_column(&public_table, parts[2]);
237 }
238 if self.tables.contains(&parts[1].to_string()) {
239 return self.validate_column(parts[1], parts[2]);
240 }
241 }
242 if parts.len() == 2 {
243 if self.tables.contains(&parts[0].to_string()) {
245 return self.validate_column(parts[0], parts[1]);
246 }
247 let public_table = format!("public.{}", parts[0]);
248 if self.tables.contains(&public_table) {
249 return self.validate_column(&public_table, parts[1]);
250 }
251 }
252 return Ok(());
254 }
255
256 if let Some(cols) = self.columns.get(table) {
257 if cols.contains(&column.to_string()) {
258 Ok(())
259 } else {
260 let suggestion = self.did_you_mean(column, cols);
261 Err(ValidationError::ColumnNotFound {
262 table: table.to_string(),
263 column: column.to_string(),
264 suggestion,
265 })
266 }
267 } else {
268 Ok(())
269 }
270 }
271
272 fn extract_column_name(expr: &Expr) -> Option<String> {
274 match expr {
275 Expr::Named(name) => Some(name.clone()),
276 Expr::Aliased { name, .. } => Some(name.clone()),
277 Expr::Aggregate { col, .. } => Some(col.clone()),
278 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
279 Expr::JsonAccess { column, .. } => Some(column.clone()),
280 _ => None,
281 }
282 }
283
284 pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
286 self.column_types.get(table)?.get(column)
287 }
288
289 pub fn validate_value_type(
292 &self,
293 table: &str,
294 column: &str,
295 value: &crate::ast::Value,
296 ) -> Result<(), ValidationError> {
297 use crate::ast::Value;
298
299 let expected_type = match self.get_column_type(table, column) {
301 Some(t) => t.to_uppercase(),
302 None => return Ok(()), };
304
305 if matches!(value, Value::Null | Value::NullUuid) {
307 return Ok(());
308 }
309
310 if matches!(
312 value,
313 Value::Param(_)
314 | Value::NamedParam(_)
315 | Value::Function(_)
316 | Value::Subquery(_)
317 | Value::Expr(_)
318 ) {
319 return Ok(());
320 }
321
322 if matches!(value, Value::Array(_)) {
326 return Ok(());
327 }
328
329 let value_type = match value {
331 Value::Bool(_) => "BOOLEAN",
332 Value::Int(_) => "INT",
333 Value::Float(_) => "FLOAT",
334 Value::String(_) => "TEXT",
335 Value::Uuid(_) => "UUID",
336 Value::Column(_) => return Ok(()), Value::Interval { .. } => "INTERVAL",
338 Value::Timestamp(_) => "TIMESTAMP",
339 Value::Bytes(_) => "BYTEA",
340 Value::Vector(_) => "VECTOR",
341 Value::Json(_) => "JSONB",
342 _ => return Ok(()), };
344
345 if !Self::types_compatible(&expected_type, value_type) {
347 return Err(ValidationError::TypeMismatch {
348 table: table.to_string(),
349 column: column.to_string(),
350 expected: expected_type,
351 got: value_type.to_string(),
352 });
353 }
354
355 Ok(())
356 }
357
358 fn types_compatible(expected: &str, value_type: &str) -> bool {
361 let expected = expected.to_uppercase();
362 let value_type = value_type.to_uppercase();
363
364 if expected == value_type {
366 return true;
367 }
368
369 let int_types = [
371 "INT",
372 "INT4",
373 "INT8",
374 "INTEGER",
375 "BIGINT",
376 "SMALLINT",
377 "SERIAL",
378 "BIGSERIAL",
379 ];
380 if int_types.contains(&expected.as_str()) && value_type == "INT" {
381 return true;
382 }
383
384 let float_types = [
386 "FLOAT",
387 "FLOAT4",
388 "FLOAT8",
389 "DOUBLE",
390 "DOUBLE PRECISION",
391 "DECIMAL",
392 "NUMERIC",
393 "REAL",
394 ];
395 if float_types.contains(&expected.as_str())
396 && (value_type == "FLOAT" || value_type == "INT")
397 {
398 return true;
399 }
400
401 let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
403 if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
404 return true;
405 }
406
407 if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
409 return true;
410 }
411
412 if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
414 return true;
415 }
416
417 let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
419 if ts_types.contains(&expected.as_str())
420 && (value_type == "TIMESTAMP" || value_type == "TEXT")
421 {
422 return true;
423 }
424
425 if expected == "JSONB" || expected == "JSON" {
427 return true;
428 }
429
430 if expected.contains("[]") || expected.starts_with("_") {
432 return value_type == "ARRAY";
433 }
434
435 false
436 }
437
438 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
440 let mut errors = Vec::new();
441
442 if let Err(e) = self.validate_table(&cmd.table) {
443 errors.push(e);
444 }
445
446 let mut aliases: Vec<String> = Vec::new();
450 for col in &cmd.columns {
451 if let Expr::Aliased { alias, .. } = col {
452 aliases.push(alias.clone());
453 }
454 if let Some(name) = Self::extract_column_name(col)
455 && let Err(e) = self.validate_column(&cmd.table, &name)
456 {
457 errors.push(e);
458 }
459 }
460
461 for cage in &cmd.cages {
462 if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
466 continue;
467 }
468 for cond in &cage.conditions {
469 if let Some(name) = Self::extract_column_name(&cond.left) {
470 if aliases.iter().any(|a| a == &name) {
472 continue;
473 }
474 if name.contains('.') {
476 let parts: Vec<&str> = name.split('.').collect();
477 if parts.len() == 2 {
478 if let Err(e) = self.validate_column(parts[0], parts[1]) {
479 errors.push(e);
480 }
481 if let Err(e) =
483 self.validate_value_type(parts[0], parts[1], &cond.value)
484 {
485 errors.push(e);
486 }
487 }
488 } else {
489 if let Err(e) = self.validate_column(&cmd.table, &name) {
490 errors.push(e);
491 }
492 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
494 errors.push(e);
495 }
496 }
497 }
498 }
499 }
500
501 for cond in &cmd.having {
502 if let Some(name) = Self::extract_column_name(&cond.left) {
503 if name.contains('(') || name == "*" {
504 continue;
505 }
506 if name.contains('.') {
507 let parts: Vec<&str> = name.split('.').collect();
508 if parts.len() == 2 {
509 if let Err(e) = self.validate_column(parts[0], parts[1]) {
510 errors.push(e);
511 }
512 if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
513 errors.push(e);
514 }
515 }
516 } else {
517 if let Err(e) = self.validate_column(&cmd.table, &name) {
518 errors.push(e);
519 }
520 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
521 errors.push(e);
522 }
523 }
524 }
525 }
526
527 for join in &cmd.joins {
528 if let Err(e) = self.validate_table(&join.table) {
529 errors.push(e);
530 }
531
532 if let Some(conditions) = &join.on {
533 for cond in conditions {
534 if let Some(name) = Self::extract_column_name(&cond.left)
535 && name.contains('.')
536 {
537 let parts: Vec<&str> = name.split('.').collect();
538 if parts.len() == 2
539 && let Err(e) = self.validate_column(parts[0], parts[1])
540 {
541 errors.push(e);
542 }
543 }
544 if let crate::ast::Value::Column(col_name) = &cond.value
546 && col_name.contains('.')
547 {
548 let parts: Vec<&str> = col_name.split('.').collect();
549 if parts.len() == 2
550 && let Err(e) = self.validate_column(parts[0], parts[1])
551 {
552 errors.push(e);
553 }
554 }
555 }
556 }
557 }
558
559 if let Some(returning) = &cmd.returning {
560 for col in returning {
561 if let Some(name) = Self::extract_column_name(col)
562 && let Err(e) = self.validate_column(&cmd.table, &name)
563 {
564 errors.push(e);
565 }
566 }
567 }
568
569 if errors.is_empty() {
570 Ok(())
571 } else {
572 Err(errors)
573 }
574 }
575
576 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
578 let mut best_match = None;
579 let mut min_dist = usize::MAX;
580
581 for cand in candidates {
582 let cand_str = cand.as_ref();
583 let dist = levenshtein(input, cand_str);
584
585 let threshold = match input.len() {
587 0..=2 => 0, 3..=5 => 2, _ => 3, };
591
592 if dist <= threshold && dist < min_dist {
593 min_dist = dist;
594 best_match = Some(cand_str.to_string());
595 }
596 }
597
598 best_match
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );
    }

    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
    }

    #[test]
    fn test_expression_and_schema_qualified_columns() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);

        // Expressions are not validated as column names.
        assert!(v.validate_column("users", "COUNT(id)").is_ok());
        assert!(v.validate_column("users", "name AS n").is_ok());

        // `public.`-qualified names resolve to the bare table.
        assert!(v.validate_column("users", "public.users.id").is_ok());
        let err = v.validate_column("users", "public.users.nme").unwrap_err();
        assert!(matches!(
            err,
            ValidationError::ColumnNotFound { ref table, ref column, suggestion: Some(ref s) }
                if table == "users" && column == "nme" && s == "name"
        ));

        // Columns of unregistered tables are not checked.
        assert!(v.validate_column("ghosts", "anything").is_ok());
    }

    #[test]
    fn test_value_type_checks() {
        let mut v = Validator::new();
        v.add_table_with_types("users", &[("id", "INT"), ("active", "bool")]);

        assert_eq!(
            v.get_column_type("users", "id").map(String::as_str),
            Some("INT")
        );
        assert!(
            v.validate_value_type("users", "id", &crate::ast::Value::Int(7))
                .is_ok()
        );
        assert!(
            v.validate_value_type("users", "active", &crate::ast::Value::Bool(true))
                .is_ok()
        );
        assert!(
            v.validate_value_type("users", "id", &crate::ast::Value::Null)
                .is_ok()
        );
        // Columns without a declared type are not checked.
        assert!(
            v.validate_value_type("users", "missing", &crate::ast::Value::Bool(true))
                .is_ok()
        );

        let err = v
            .validate_value_type("users", "id", &crate::ast::Value::Bool(true))
            .unwrap_err();
        assert_eq!(
            err,
            ValidationError::TypeMismatch {
                table: "users".to_string(),
                column: "id".to_string(),
                expected: "INT".to_string(),
                got: "BOOLEAN".to_string(),
            }
        );
    }

    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        let cmd = Qail::get("users").columns(["id", "emial"]);
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    #[test]
    fn test_validate_having_columns() {
        let mut v = Validator::new();
        v.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        });

        let errors = v.validate_command(&cmd).unwrap_err();
        assert!(errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl")
        ));
    }

    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );

        let err = ValidationError::TypeMismatch {
            table: "users".to_string(),
            column: "id".to_string(),
            expected: "INT".to_string(),
            got: "BOOLEAN".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Type mismatch for 'users.id': expected INT, got BOOLEAN"
        );
    }
}