1use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// A single problem found while validating a query against the registered schema.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
    /// The referenced table is not registered with the validator.
    TableNotFound {
        /// Table name as written in the query.
        table: String,
        /// Closest registered table name, when one was close enough to suggest.
        suggestion: Option<String>,
    },
    /// The referenced column is not registered for the given table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as written in the query.
        column: String,
        /// Closest registered column name, when one was close enough to suggest.
        suggestion: Option<String>,
    },
    /// A literal value's type is incompatible with the column's declared type.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column whose declared type was violated.
        column: String,
        /// Declared column type.
        expected: String,
        /// Type inferred from the supplied value.
        got: String,
    },
    /// An operator cannot be applied to the given column.
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// The offending operator.
        operator: String,
        /// Human-readable explanation.
        reason: String,
    },
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both "not found" variants render as "<sentence>." optionally followed by
        // " Did you mean '<name>'?", so write the sentence and remember the hint.
        let hint = match self {
            Self::TableNotFound { table, suggestion } => {
                write!(f, "Table '{table}' not found.")?;
                suggestion
            }
            Self::ColumnNotFound {
                table,
                column,
                suggestion,
            } => {
                write!(f, "Column '{column}' not found in table '{table}'.")?;
                suggestion
            }
            Self::TypeMismatch {
                table,
                column,
                expected,
                got,
            } => {
                return write!(
                    f,
                    "Type mismatch for '{table}.{column}': expected {expected}, got {got}"
                );
            }
            Self::InvalidOperator {
                column,
                operator,
                reason,
            } => {
                return write!(
                    f,
                    "Invalid operator '{operator}' for column '{column}': {reason}"
                );
            }
        };

        match hint {
            Some(name) => write!(f, " Did you mean '{name}'?"),
            None => Ok(()),
        }
    }
}

impl std::error::Error for ValidationError {}

/// Outcome of validating a whole command: `Err` carries every problem found.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Schema registry used to check [`Qail`] commands for unknown tables,
/// unknown columns, and literal values of an incompatible type.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in registration order.
    tables: Vec<String>,
    /// Registered column names, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared SQL type of each column (table -> column -> type). Only tables
    /// registered with type information have an entry. This field is read by
    /// `get_column_type`, so no `dead_code` suppression is needed.
    column_types: HashMap<String, HashMap<String, String>>,
}
119
120impl Default for Validator {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl Validator {
127 pub fn new() -> Self {
129 Self {
130 tables: Vec::new(),
131 columns: HashMap::new(),
132 column_types: HashMap::new(),
133 }
134 }
135
136 pub fn add_table(&mut self, table: &str, cols: &[&str]) {
138 self.tables.push(table.to_string());
139 self.columns.insert(
140 table.to_string(),
141 cols.iter().map(|s| s.to_string()).collect(),
142 );
143 }
144
145 pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
152 self.tables.push(table.to_string());
153 let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
154 self.columns.insert(table.to_string(), col_names);
155
156 let type_map: HashMap<String, String> = cols
157 .iter()
158 .map(|(name, typ)| (name.to_string(), typ.to_string()))
159 .collect();
160 self.column_types.insert(table.to_string(), type_map);
161 }
162
163 pub fn table_names(&self) -> &[String] {
165 &self.tables
166 }
167
168 pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
170 self.columns.get(table)
171 }
172
173 pub fn table_exists(&self, table: &str) -> bool {
175 self.tables.contains(&table.to_string())
176 }
177
178 pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
180 if self.tables.contains(&table.to_string()) {
181 Ok(())
182 } else {
183 let suggestion = self.did_you_mean(table, &self.tables);
184 Err(ValidationError::TableNotFound {
185 table: table.to_string(),
186 suggestion,
187 })
188 }
189 }
190
191 pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
198 if !self.tables.contains(&table.to_string()) {
200 return Ok(());
201 }
202
203 if column == "*" {
205 return Ok(());
206 }
207
208 if column.contains('(')
212 || column.contains('[')
213 || column.contains("::")
214 || column.contains(" AS ")
215 || column.contains(" as ")
216 || column.starts_with("distinct ")
217 || column.starts_with("DISTINCT ")
218 {
219 return Ok(());
220 }
221
222 if column.contains('.') {
224 let parts: Vec<&str> = column.split('.').collect();
225 if parts.len() == 2 {
226 if self.tables.contains(&parts[0].to_string()) {
228 return self.validate_column(parts[0], parts[1]);
229 }
230 }
231 return Ok(());
233 }
234
235 if let Some(cols) = self.columns.get(table) {
236 if cols.contains(&column.to_string()) {
237 Ok(())
238 } else {
239 let suggestion = self.did_you_mean(column, cols);
240 Err(ValidationError::ColumnNotFound {
241 table: table.to_string(),
242 column: column.to_string(),
243 suggestion,
244 })
245 }
246 } else {
247 Ok(())
248 }
249 }
250
251 fn extract_column_name(expr: &Expr) -> Option<String> {
253 match expr {
254 Expr::Named(name) => Some(name.clone()),
255 Expr::Aliased { name, .. } => Some(name.clone()),
256 Expr::Aggregate { col, .. } => Some(col.clone()),
257 Expr::Cast { expr, .. } => Self::extract_column_name(expr),
258 Expr::JsonAccess { column, .. } => Some(column.clone()),
259 _ => None,
260 }
261 }
262
263 pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
265 self.column_types.get(table)?.get(column)
266 }
267
268 pub fn validate_value_type(
271 &self,
272 table: &str,
273 column: &str,
274 value: &crate::ast::Value,
275 ) -> Result<(), ValidationError> {
276 use crate::ast::Value;
277
278 let expected_type = match self.get_column_type(table, column) {
280 Some(t) => t.to_uppercase(),
281 None => return Ok(()), };
283
284 if matches!(value, Value::Null | Value::NullUuid) {
286 return Ok(());
287 }
288
289 if matches!(
291 value,
292 Value::Param(_)
293 | Value::NamedParam(_)
294 | Value::Function(_)
295 | Value::Subquery(_)
296 | Value::Expr(_)
297 ) {
298 return Ok(());
299 }
300
301 if matches!(value, Value::Array(_)) {
305 return Ok(());
306 }
307
308 let value_type = match value {
310 Value::Bool(_) => "BOOLEAN",
311 Value::Int(_) => "INT",
312 Value::Float(_) => "FLOAT",
313 Value::String(_) => "TEXT",
314 Value::Uuid(_) => "UUID",
315 Value::Column(_) => return Ok(()), Value::Interval { .. } => "INTERVAL",
317 Value::Timestamp(_) => "TIMESTAMP",
318 Value::Bytes(_) => "BYTEA",
319 Value::Vector(_) => "VECTOR",
320 Value::Json(_) => "JSONB",
321 _ => return Ok(()), };
323
324 if !Self::types_compatible(&expected_type, value_type) {
326 return Err(ValidationError::TypeMismatch {
327 table: table.to_string(),
328 column: column.to_string(),
329 expected: expected_type,
330 got: value_type.to_string(),
331 });
332 }
333
334 Ok(())
335 }
336
337 fn types_compatible(expected: &str, value_type: &str) -> bool {
340 let expected = expected.to_uppercase();
341 let value_type = value_type.to_uppercase();
342
343 if expected == value_type {
345 return true;
346 }
347
348 let int_types = [
350 "INT",
351 "INT4",
352 "INT8",
353 "INTEGER",
354 "BIGINT",
355 "SMALLINT",
356 "SERIAL",
357 "BIGSERIAL",
358 ];
359 if int_types.contains(&expected.as_str()) && value_type == "INT" {
360 return true;
361 }
362
363 let float_types = [
365 "FLOAT",
366 "FLOAT4",
367 "FLOAT8",
368 "DOUBLE",
369 "DOUBLE PRECISION",
370 "DECIMAL",
371 "NUMERIC",
372 "REAL",
373 ];
374 if float_types.contains(&expected.as_str())
375 && (value_type == "FLOAT" || value_type == "INT")
376 {
377 return true;
378 }
379
380 let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
382 if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
383 return true;
384 }
385
386 if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
388 return true;
389 }
390
391 if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
393 return true;
394 }
395
396 let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
398 if ts_types.contains(&expected.as_str())
399 && (value_type == "TIMESTAMP" || value_type == "TEXT")
400 {
401 return true;
402 }
403
404 if expected == "JSONB" || expected == "JSON" {
406 return true;
407 }
408
409 if expected.contains("[]") || expected.starts_with("_") {
411 return value_type == "ARRAY";
412 }
413
414 false
415 }
416
417 pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
419 let mut errors = Vec::new();
420
421 if let Err(e) = self.validate_table(&cmd.table) {
422 errors.push(e);
423 }
424
425 let mut aliases: Vec<String> = Vec::new();
429 for col in &cmd.columns {
430 if let Expr::Aliased { alias, .. } = col {
431 aliases.push(alias.clone());
432 }
433 if let Some(name) = Self::extract_column_name(col)
434 && let Err(e) = self.validate_column(&cmd.table, &name)
435 {
436 errors.push(e);
437 }
438 }
439
440 for cage in &cmd.cages {
441 if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
445 continue;
446 }
447 for cond in &cage.conditions {
448 if let Some(name) = Self::extract_column_name(&cond.left) {
449 if aliases.iter().any(|a| a == &name) {
451 continue;
452 }
453 if name.contains('.') {
455 let parts: Vec<&str> = name.split('.').collect();
456 if parts.len() == 2 {
457 if let Err(e) = self.validate_column(parts[0], parts[1]) {
458 errors.push(e);
459 }
460 if let Err(e) =
462 self.validate_value_type(parts[0], parts[1], &cond.value)
463 {
464 errors.push(e);
465 }
466 }
467 } else {
468 if let Err(e) = self.validate_column(&cmd.table, &name) {
469 errors.push(e);
470 }
471 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
473 errors.push(e);
474 }
475 }
476 }
477 }
478 }
479
480 for cond in &cmd.having {
481 if let Some(name) = Self::extract_column_name(&cond.left) {
482 if name.contains('(') || name == "*" {
483 continue;
484 }
485 if name.contains('.') {
486 let parts: Vec<&str> = name.split('.').collect();
487 if parts.len() == 2 {
488 if let Err(e) = self.validate_column(parts[0], parts[1]) {
489 errors.push(e);
490 }
491 if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
492 errors.push(e);
493 }
494 }
495 } else {
496 if let Err(e) = self.validate_column(&cmd.table, &name) {
497 errors.push(e);
498 }
499 if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
500 errors.push(e);
501 }
502 }
503 }
504 }
505
506 for join in &cmd.joins {
507 if let Err(e) = self.validate_table(&join.table) {
508 errors.push(e);
509 }
510
511 if let Some(conditions) = &join.on {
512 for cond in conditions {
513 if let Some(name) = Self::extract_column_name(&cond.left)
514 && name.contains('.')
515 {
516 let parts: Vec<&str> = name.split('.').collect();
517 if parts.len() == 2
518 && let Err(e) = self.validate_column(parts[0], parts[1])
519 {
520 errors.push(e);
521 }
522 }
523 if let crate::ast::Value::Column(col_name) = &cond.value
525 && col_name.contains('.')
526 {
527 let parts: Vec<&str> = col_name.split('.').collect();
528 if parts.len() == 2
529 && let Err(e) = self.validate_column(parts[0], parts[1])
530 {
531 errors.push(e);
532 }
533 }
534 }
535 }
536 }
537
538 if let Some(returning) = &cmd.returning {
539 for col in returning {
540 if let Some(name) = Self::extract_column_name(col)
541 && let Err(e) = self.validate_column(&cmd.table, &name)
542 {
543 errors.push(e);
544 }
545 }
546 }
547
548 if errors.is_empty() {
549 Ok(())
550 } else {
551 Err(errors)
552 }
553 }
554
555 fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
557 let mut best_match = None;
558 let mut min_dist = usize::MAX;
559
560 for cand in candidates {
561 let cand_str = cand.as_ref();
562 let dist = levenshtein(input, cand_str);
563
564 let threshold = match input.len() {
566 0..=2 => 0, 3..=5 => 2, _ => 3, };
570
571 if dist <= threshold && dist < min_dist {
572 min_dist = dist;
573 best_match = Some(cand_str.to_string());
574 }
575 }
576
577 best_match
578 }
579
580 #[deprecated(note = "Use validate_table() which returns ValidationError")]
586 pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
587 self.validate_table(table).map_err(|e| e.to_string())
588 }
589
590 #[deprecated(note = "Use validate_column() which returns ValidationError")]
592 pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
593 self.validate_column(table, column)
594 .map_err(|e| e.to_string())
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `err` is a `TableNotFound` whose suggestion is `expected`.
    fn assert_table_suggestion(err: ValidationError, expected: &str) {
        match err {
            ValidationError::TableNotFound {
                suggestion: Some(s),
                ..
            } => assert_eq!(s, expected),
            other => panic!("expected TableNotFound with a suggestion, got {other:?}"),
        }
    }

    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert!(validator.validate_table("users").is_ok());

        // Both typos are close enough to "users" to be suggested.
        for typo in ["usr", "usrs"] {
            assert_table_suggestion(validator.validate_table(typo).unwrap_err(), "users");
        }
    }

    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        for accepted in ["email", "*"] {
            assert!(
                validator.validate_column("users", accepted).is_ok(),
                "{accepted} should be accepted"
            );
        }

        match validator.validate_column("users", "emial").unwrap_err() {
            ValidationError::ColumnNotFound {
                suggestion: Some(s),
                ..
            } => assert_eq!(s, "email"),
            other => panic!("expected ColumnNotFound with a suggestion, got {other:?}"),
        }
    }

    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        // A qualifier naming another registered table is resolved against that table.
        for qualified in ["users.id", "profiles.user_id"] {
            assert!(
                validator.validate_column("users", qualified).is_ok(),
                "{qualified} should resolve"
            );
        }
    }

    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let good = Qail::get("users").columns(["id", "email"]);
        assert!(validator.validate_command(&good).is_ok());

        let bad = Qail::get("users").columns(["id", "emial"]);
        let errors = validator.validate_command(&bad).unwrap_err();
        match errors.as_slice() {
            [ValidationError::ColumnNotFound { column, .. }] => assert_eq!(column, "emial"),
            other => panic!("expected exactly one ColumnNotFound, got {other:?}"),
        }
    }

    #[test]
    fn test_validate_having_columns() {
        let mut validator = Validator::new();
        validator.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        });

        let errors = validator.validate_command(&cmd).unwrap_err();
        let flagged = errors.iter().any(|e| {
            matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl")
        });
        assert!(flagged, "expected 'totl' to be reported, got {errors:?}");
    }

    #[test]
    fn test_error_display() {
        let cases = [
            (
                ValidationError::TableNotFound {
                    table: "usrs".to_string(),
                    suggestion: Some("users".to_string()),
                },
                "Table 'usrs' not found. Did you mean 'users'?",
            ),
            (
                ValidationError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "emial".to_string(),
                    suggestion: Some("email".to_string()),
                },
                "Column 'emial' not found in table 'users'. Did you mean 'email'?",
            ),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }
}