Skip to main content

qail_core/
validator.rs

1//! Schema validator and fuzzy matching suggestions.
2//!
3//! Provides compile-time-like validation for Qail against a known schema.
4//! Used by CLI, LSP, and the encoder to catch errors before they hit the wire.
5
6use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// Validation error with structured information.
///
/// Every variant carries owned strings only, so the type is `Eq` in addition
/// to `PartialEq` and can be compared exactly in tests and de-duplicated by
/// callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// Referenced table does not exist in the schema.
    TableNotFound {
        /// Name of the missing table.
        table: String,
        /// Closest match from known tables.
        suggestion: Option<String>,
    },
    /// Referenced column does not exist in the table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Name of the missing column.
        column: String,
        /// Closest match from known columns.
        suggestion: Option<String>,
    },
    /// A literal value is not compatible with the declared type of the column
    /// it is compared with. Only reported for columns whose type is known.
    TypeMismatch {
        /// Table name.
        table: String,
        /// Column name.
        column: String,
        /// Expected type.
        expected: String,
        /// Actual type.
        got: String,
    },
    /// Invalid operator for column type (future)
    InvalidOperator {
        /// Column name.
        column: String,
        /// Operator string.
        operator: String,
        /// Explanation.
        reason: String,
    },
}
50
51impl std::fmt::Display for ValidationError {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            ValidationError::TableNotFound { table, suggestion } => {
55                if let Some(s) = suggestion {
56                    write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
57                } else {
58                    write!(f, "Table '{}' not found.", table)
59                }
60            }
61            ValidationError::ColumnNotFound {
62                table,
63                column,
64                suggestion,
65            } => {
66                if let Some(s) = suggestion {
67                    write!(
68                        f,
69                        "Column '{}' not found in table '{}'. Did you mean '{}'?",
70                        column, table, s
71                    )
72                } else {
73                    write!(f, "Column '{}' not found in table '{}'.", column, table)
74                }
75            }
76            ValidationError::TypeMismatch {
77                table,
78                column,
79                expected,
80                got,
81            } => {
82                write!(
83                    f,
84                    "Type mismatch for '{}.{}': expected {}, got {}",
85                    table, column, expected, got
86                )
87            }
88            ValidationError::InvalidOperator {
89                column,
90                operator,
91                reason,
92            } => {
93                write!(
94                    f,
95                    "Invalid operator '{}' for column '{}': {}",
96                    operator, column, reason
97                )
98            }
99        }
100    }
101}
102
// Marker impl: `Debug` and `Display` are already provided above, so the
// default `std::error::Error` methods are sufficient.
impl std::error::Error for ValidationError {}
104
/// Result of validation: `Ok(())` when nothing was flagged, otherwise every
/// collected [`ValidationError`].
pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Validates query elements against known schema and provides suggestions.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Known table names, in registration order.
    tables: Vec<String>,
    /// Column names keyed by table name. Tables registered through
    /// `add_table_name` alone have no entry here.
    columns: HashMap<String, Vec<String>>,
    /// Column types as a nested map: table name -> column name -> type string.
    /// Only populated by `add_table_with_types`.
    column_types: HashMap<String, HashMap<String, String>>,
}
118
impl Default for Validator {
    /// Equivalent to [`Validator::new`]: a validator with no registered tables.
    fn default() -> Self {
        Self::new()
    }
}
124
125impl Validator {
    /// Create an empty `Validator` with no tables, columns, or column types
    /// registered. Populate it with `add_table`, `add_table_name`, or
    /// `add_table_with_types`.
    pub fn new() -> Self {
        Self {
            tables: Vec::new(),
            columns: HashMap::new(),
            column_types: HashMap::new(),
        }
    }
134
135    /// Register a table and its columns.
136    pub fn add_table(&mut self, table: &str, cols: &[&str]) {
137        self.tables.push(table.to_string());
138        self.columns.insert(
139            table.to_string(),
140            cols.iter().map(|s| s.to_string()).collect(),
141        );
142    }
143
144    /// Register a table name only (no column metadata).
145    ///
146    /// Useful for views discovered from schema text where build-time parser
147    /// does not infer projected columns. Table existence is validated, while
148    /// column-level checks are skipped.
149    pub fn add_table_name(&mut self, table: &str) {
150        self.tables.push(table.to_string());
151    }
152
153    /// Register a table with column types (for future type validation).
154    ///
155    /// # Arguments
156    ///
157    /// * `table` — Table name to register.
158    /// * `cols` — Slice of `(column_name, column_type)` pairs.
159    pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
160        self.tables.push(table.to_string());
161        let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
162        self.columns.insert(table.to_string(), col_names);
163
164        let type_map: HashMap<String, String> = cols
165            .iter()
166            .map(|(name, typ)| (name.to_string(), typ.to_string()))
167            .collect();
168        self.column_types.insert(table.to_string(), type_map);
169    }
170
    /// Get list of all table names (for autocomplete).
    ///
    /// Names are returned in the order they were registered.
    pub fn table_names(&self) -> &[String] {
        &self.tables
    }
175
    /// Get column names for a table (for autocomplete).
    ///
    /// Returns `None` when no column list has been recorded for `table`,
    /// which includes tables registered by name only.
    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
        self.columns.get(table)
    }
180
181    /// Check if a table exists.
182    pub fn table_exists(&self, table: &str) -> bool {
183        self.tables.contains(&table.to_string())
184    }
185
186    /// Check if a table exists. If not, returns structured error with suggestion.
187    pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
188        if self.tables.contains(&table.to_string()) {
189            Ok(())
190        } else {
191            let suggestion = self.did_you_mean(table, &self.tables);
192            Err(ValidationError::TableNotFound {
193                table: table.to_string(),
194                suggestion,
195            })
196        }
197    }
198
199    /// Check if a column exists in a table. If not, returns a structured error.
200    ///
201    /// # Arguments
202    ///
203    /// * `table` — Table to look up.
204    /// * `column` — Column name to validate.
205    pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
206        // If table doesn't exist, skip column validation (table error takes precedence)
207        if !self.tables.contains(&table.to_string()) {
208            return Ok(());
209        }
210
211        // Always allow *
212        if column == "*" {
213            return Ok(());
214        }
215
216        // Skip SQL expressions — these are not plain column names.
217        // Covers: COUNT(*), id::text, leg_ids[1], brand_name AS alias,
218        //         count(distinct x), EXCLUDED.col, NOW(), etc.
219        if column.contains('(')
220            || column.contains('[')
221            || column.contains("::")
222            || column.contains(" AS ")
223            || column.contains(" as ")
224            || column.starts_with("distinct ")
225            || column.starts_with("DISTINCT ")
226        {
227            return Ok(());
228        }
229
230        // Qualified names like "table.column" — validate against the referenced table
231        if column.contains('.') {
232            let parts: Vec<&str> = column.split('.').collect();
233            if parts.len() == 3 && parts[0] == "public" {
234                let public_table = format!("{}.{}", parts[0], parts[1]);
235                if self.tables.contains(&public_table) {
236                    return self.validate_column(&public_table, parts[2]);
237                }
238                if self.tables.contains(&parts[1].to_string()) {
239                    return self.validate_column(parts[1], parts[2]);
240                }
241            }
242            if parts.len() == 2 {
243                // Only validate if the referenced table is known to the validator
244                if self.tables.contains(&parts[0].to_string()) {
245                    return self.validate_column(parts[0], parts[1]);
246                }
247                let public_table = format!("public.{}", parts[0]);
248                if self.tables.contains(&public_table) {
249                    return self.validate_column(&public_table, parts[1]);
250                }
251            }
252            // Unknown table prefix or complex dotted path — allow (might be JSON)
253            return Ok(());
254        }
255
256        if let Some(cols) = self.columns.get(table) {
257            if cols.contains(&column.to_string()) {
258                Ok(())
259            } else {
260                let suggestion = self.did_you_mean(column, cols);
261                Err(ValidationError::ColumnNotFound {
262                    table: table.to_string(),
263                    column: column.to_string(),
264                    suggestion,
265                })
266            }
267        } else {
268            Ok(())
269        }
270    }
271
272    /// Extract column name from an Expr for validation.
273    fn extract_column_name(expr: &Expr) -> Option<String> {
274        match expr {
275            Expr::Named(name) => Some(name.clone()),
276            Expr::Aliased { name, .. } => Some(name.clone()),
277            Expr::Aggregate { col, .. } => Some(col.clone()),
278            Expr::Cast { expr, .. } => Self::extract_column_name(expr),
279            Expr::JsonAccess { column, .. } => Some(column.clone()),
280            _ => None,
281        }
282    }
283
284    /// Get column type for a table.column
285    pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
286        self.column_types.get(table)?.get(column)
287    }
288
289    /// Validate that a Value's type matches the expected column type.
290    /// Returns Ok(()) if compatible, Err with TypeMismatch if not.
291    pub fn validate_value_type(
292        &self,
293        table: &str,
294        column: &str,
295        value: &crate::ast::Value,
296    ) -> Result<(), ValidationError> {
297        use crate::ast::Value;
298
299        // Get the expected type for this column
300        let expected_type = match self.get_column_type(table, column) {
301            Some(t) => t.to_uppercase(),
302            None => return Ok(()), // Column type unknown, skip validation
303        };
304
305        // NULL is always allowed (for nullable columns)
306        if matches!(value, Value::Null | Value::NullUuid) {
307            return Ok(());
308        }
309
310        // Param and NamedParam are runtime values - can't validate at compile time
311        if matches!(
312            value,
313            Value::Param(_)
314                | Value::NamedParam(_)
315                | Value::Function(_)
316                | Value::Subquery(_)
317                | Value::Expr(_)
318        ) {
319            return Ok(());
320        }
321
322        // Value::Array is used with IN/NOT IN operators — the array is a container
323        // of elements whose type should match the column, not the array itself.
324        // Skip type validation for arrays (the element types are checked at runtime).
325        if matches!(value, Value::Array(_)) {
326            return Ok(());
327        }
328
329        // Map Value variant to its type name
330        let value_type = match value {
331            Value::Bool(_) => "BOOLEAN",
332            Value::Int(_) => "INT",
333            Value::Float(_) => "FLOAT",
334            Value::String(_) => "TEXT",
335            Value::Uuid(_) => "UUID",
336            Value::Column(_) => return Ok(()), // Column reference, type checked elsewhere
337            Value::Interval { .. } => "INTERVAL",
338            Value::Timestamp(_) => "TIMESTAMP",
339            Value::Bytes(_) => "BYTEA",
340            Value::Vector(_) => "VECTOR",
341            Value::Json(_) => "JSONB",
342            _ => return Ok(()), // Unknown value type, skip
343        };
344
345        // Check compatibility
346        if !Self::types_compatible(&expected_type, value_type) {
347            return Err(ValidationError::TypeMismatch {
348                table: table.to_string(),
349                column: column.to_string(),
350                expected: expected_type,
351                got: value_type.to_string(),
352            });
353        }
354
355        Ok(())
356    }
357
358    /// Check if expected column type is compatible with value type.
359    /// Allows flexible matching (e.g., INT matches INT4, BIGINT, INTEGER, etc.)
360    fn types_compatible(expected: &str, value_type: &str) -> bool {
361        let expected = expected.to_uppercase();
362        let value_type = value_type.to_uppercase();
363
364        // Exact match
365        if expected == value_type {
366            return true;
367        }
368
369        // Integer family
370        let int_types = [
371            "INT",
372            "INT4",
373            "INT8",
374            "INTEGER",
375            "BIGINT",
376            "SMALLINT",
377            "SERIAL",
378            "BIGSERIAL",
379        ];
380        if int_types.contains(&expected.as_str()) && value_type == "INT" {
381            return true;
382        }
383
384        // Float family — includes "DOUBLE PRECISION" which is how PG stores f64
385        let float_types = [
386            "FLOAT",
387            "FLOAT4",
388            "FLOAT8",
389            "DOUBLE",
390            "DOUBLE PRECISION",
391            "DECIMAL",
392            "NUMERIC",
393            "REAL",
394        ];
395        if float_types.contains(&expected.as_str())
396            && (value_type == "FLOAT" || value_type == "INT")
397        {
398            return true;
399        }
400
401        // Text family - TEXT is very flexible in PostgreSQL
402        let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
403        if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
404            return true;
405        }
406
407        // Boolean
408        if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
409            return true;
410        }
411
412        // UUID
413        if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
414            return true;
415        }
416
417        // Timestamp family
418        let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
419        if ts_types.contains(&expected.as_str())
420            && (value_type == "TIMESTAMP" || value_type == "TEXT")
421        {
422            return true;
423        }
424
425        // JSONB/JSON accepts almost anything
426        if expected == "JSONB" || expected == "JSON" {
427            return true;
428        }
429
430        // Arrays
431        if expected.contains("[]") || expected.starts_with("_") {
432            return value_type == "ARRAY";
433        }
434
435        false
436    }
437
438    /// Validate an entire Qail against the schema.
439    pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
440        let mut errors = Vec::new();
441
442        if let Err(e) = self.validate_table(&cmd.table) {
443            errors.push(e);
444        }
445
446        // Collect aliases from the SELECT column list so that order_by / having
447        // references to computed aliases (e.g. count().alias("route_count"))
448        // are not flagged as column-not-found errors.
449        let mut aliases: Vec<String> = Vec::new();
450        for col in &cmd.columns {
451            if let Expr::Aliased { alias, .. } = col {
452                aliases.push(alias.clone());
453            }
454            if let Some(name) = Self::extract_column_name(col)
455                && let Err(e) = self.validate_column(&cmd.table, &name)
456            {
457                errors.push(e);
458            }
459        }
460
461        for cage in &cmd.cages {
462            // Skip column validation for Sort cages — ORDER BY can reference
463            // aliases from column_expr (e.g. count().alias("route_count")),
464            // which may not exist as physical columns in the primary table.
465            if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
466                continue;
467            }
468            for cond in &cage.conditions {
469                if let Some(name) = Self::extract_column_name(&cond.left) {
470                    // Skip validation for columns that match a computed alias
471                    if aliases.iter().any(|a| a == &name) {
472                        continue;
473                    }
474                    // For join conditions, column might be qualified (table.column)
475                    if name.contains('.') {
476                        let parts: Vec<&str> = name.split('.').collect();
477                        if parts.len() == 2 {
478                            if let Err(e) = self.validate_column(parts[0], parts[1]) {
479                                errors.push(e);
480                            }
481                            // Type validation for qualified column
482                            if let Err(e) =
483                                self.validate_value_type(parts[0], parts[1], &cond.value)
484                            {
485                                errors.push(e);
486                            }
487                        }
488                    } else {
489                        if let Err(e) = self.validate_column(&cmd.table, &name) {
490                            errors.push(e);
491                        }
492                        // Type validation for unqualified column
493                        if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
494                            errors.push(e);
495                        }
496                    }
497                }
498            }
499        }
500
501        for cond in &cmd.having {
502            if let Some(name) = Self::extract_column_name(&cond.left) {
503                if name.contains('(') || name == "*" {
504                    continue;
505                }
506                if name.contains('.') {
507                    let parts: Vec<&str> = name.split('.').collect();
508                    if parts.len() == 2 {
509                        if let Err(e) = self.validate_column(parts[0], parts[1]) {
510                            errors.push(e);
511                        }
512                        if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
513                            errors.push(e);
514                        }
515                    }
516                } else {
517                    if let Err(e) = self.validate_column(&cmd.table, &name) {
518                        errors.push(e);
519                    }
520                    if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
521                        errors.push(e);
522                    }
523                }
524            }
525        }
526
527        for join in &cmd.joins {
528            if let Err(e) = self.validate_table(&join.table) {
529                errors.push(e);
530            }
531
532            if let Some(conditions) = &join.on {
533                for cond in conditions {
534                    if let Some(name) = Self::extract_column_name(&cond.left)
535                        && name.contains('.')
536                    {
537                        let parts: Vec<&str> = name.split('.').collect();
538                        if parts.len() == 2
539                            && let Err(e) = self.validate_column(parts[0], parts[1])
540                        {
541                            errors.push(e);
542                        }
543                    }
544                    // Also check right side if it's a column reference
545                    if let crate::ast::Value::Column(col_name) = &cond.value
546                        && col_name.contains('.')
547                    {
548                        let parts: Vec<&str> = col_name.split('.').collect();
549                        if parts.len() == 2
550                            && let Err(e) = self.validate_column(parts[0], parts[1])
551                        {
552                            errors.push(e);
553                        }
554                    }
555                }
556            }
557        }
558
559        if let Some(returning) = &cmd.returning {
560            for col in returning {
561                if let Some(name) = Self::extract_column_name(col)
562                    && let Err(e) = self.validate_column(&cmd.table, &name)
563                {
564                    errors.push(e);
565                }
566            }
567        }
568
569        if errors.is_empty() {
570            Ok(())
571        } else {
572            Err(errors)
573        }
574    }
575
576    /// Find the best match with Levenshtein distance within threshold.
577    fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
578        let mut best_match = None;
579        let mut min_dist = usize::MAX;
580
581        for cand in candidates {
582            let cand_str = cand.as_ref();
583            let dist = levenshtein(input, cand_str);
584
585            // Dynamic threshold based on length
586            let threshold = match input.len() {
587                0..=2 => 0, // Precise match only for very short strings
588                3..=5 => 2, // Allow 2 char diff (e.g. usr -> users)
589                _ => 3,     // Allow 3 char diff for longer strings
590            };
591
592            if dist <= threshold && dist < min_dist {
593                min_dist = dist;
594                best_match = Some(cand_str.to_string());
595            }
596        }
597
598        best_match
599    }
600}
601
// Unit tests for table/column validation, fuzzy suggestions, whole-command
// validation, and error message formatting.
#[cfg(test)]
mod tests {
    use super::*;

    /// Misspelled table names produce `TableNotFound` with the closest table suggested.
    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );
    }

    /// Known columns and `*` pass; a misspelled column suggests the closest match.
    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    /// `table.column` names are resolved against the table named by the qualifier.
    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        // Qualified names should pass through
        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
    }

    /// A command with valid columns passes; a typo yields exactly one `ColumnNotFound`.
    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        let cmd = Qail::get("users").columns(["id", "emial"]); // typo
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    /// Columns referenced from HAVING conditions are validated too.
    #[test]
    fn test_validate_having_columns() {
        let mut v = Validator::new();
        v.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        });

        let errors = v.validate_command(&cmd).unwrap_err();
        assert!(errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl")
        ));
    }

    /// `Display` output includes the "Did you mean" hint when a suggestion is present.
    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}