Skip to main content

qail_core/
validator.rs

1//! Schema validator and fuzzy matching suggestions.
2//!
3//! Provides compile-time-like validation for Qail against a known schema.
4//! Used by CLI, LSP, and the encoder to catch errors before they hit the wire.
5
6use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// Validation error with structured information.
///
/// Each variant carries enough context to render a human-readable message
/// (see the `Display` impl) and to be matched on programmatically by the CLI,
/// LSP, or encoder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// Referenced table does not exist in the schema.
    TableNotFound {
        /// Name of the missing table.
        table: String,
        /// Closest match from known tables, if one is close enough.
        suggestion: Option<String>,
    },
    /// Referenced column does not exist in the table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Name of the missing column.
        column: String,
        /// Closest match from known columns, if one is close enough.
        suggestion: Option<String>,
    },
    /// A literal compared against a column is not compatible with the
    /// column's declared SQL type (only checked for tables registered with
    /// column types).
    TypeMismatch {
        /// Table name.
        table: String,
        /// Column name.
        column: String,
        /// Declared column type (upper-cased).
        expected: String,
        /// Type family of the supplied literal (e.g. `INT`, `TEXT`).
        got: String,
    },
    /// Invalid operator for the column's type.
    ///
    /// Reserved for operator/type checks; the `Validator` in this module does
    /// not emit it yet.
    InvalidOperator {
        /// Column name.
        column: String,
        /// Operator string.
        operator: String,
        /// Explanation.
        reason: String,
    },
}
50
51impl std::fmt::Display for ValidationError {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            ValidationError::TableNotFound { table, suggestion } => {
55                if let Some(s) = suggestion {
56                    write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
57                } else {
58                    write!(f, "Table '{}' not found.", table)
59                }
60            }
61            ValidationError::ColumnNotFound {
62                table,
63                column,
64                suggestion,
65            } => {
66                if let Some(s) = suggestion {
67                    write!(
68                        f,
69                        "Column '{}' not found in table '{}'. Did you mean '{}'?",
70                        column, table, s
71                    )
72                } else {
73                    write!(f, "Column '{}' not found in table '{}'.", column, table)
74                }
75            }
76            ValidationError::TypeMismatch {
77                table,
78                column,
79                expected,
80                got,
81            } => {
82                write!(
83                    f,
84                    "Type mismatch for '{}.{}': expected {}, got {}",
85                    table, column, expected, got
86                )
87            }
88            ValidationError::InvalidOperator {
89                column,
90                operator,
91                reason,
92            } => {
93                write!(
94                    f,
95                    "Invalid operator '{}' for column '{}': {}",
96                    operator, column, reason
97                )
98            }
99        }
100    }
101}
102
/// Lets `ValidationError` be used as a standard error (for example boxed as
/// `Box<dyn std::error::Error>`); its message comes from the `Display` impl.
/// The impl body is empty, so the default `source()` applies: no underlying
/// cause is reported.
impl std::error::Error for ValidationError {}
104
/// Outcome of validating a whole command.
///
/// `Ok(())` when no problems were found; otherwise `Err` carries every
/// [`ValidationError`] that was collected rather than stopping at the first.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Validates query elements against a known schema and provides suggestions.
///
/// Tables (and optionally their columns and declared SQL types) are registered
/// through the `add_table*` methods; lookups then report missing tables,
/// missing columns, and literal/column type mismatches, with Levenshtein-based
/// "did you mean" hints for near-miss names.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Known table names, in registration order (exposed via `table_names`).
    tables: Vec<String>,
    /// Column names per table. Tables registered only by name have no entry
    /// here, which disables column-level checks for them.
    columns: HashMap<String, Vec<String>>,
    /// Declared SQL type per column, keyed by table name and then column name
    /// (a nested map, not a `"table.column"` key). Only populated by
    /// `add_table_with_types`.
    column_types: HashMap<String, HashMap<String, String>>,
}
118
119impl Default for Validator {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl Validator {
126    /// Create a new Validator with known tables and columns.
127    pub fn new() -> Self {
128        Self {
129            tables: Vec::new(),
130            columns: HashMap::new(),
131            column_types: HashMap::new(),
132        }
133    }
134
135    /// Register a table and its columns.
136    pub fn add_table(&mut self, table: &str, cols: &[&str]) {
137        self.tables.push(table.to_string());
138        self.columns.insert(
139            table.to_string(),
140            cols.iter().map(|s| s.to_string()).collect(),
141        );
142    }
143
144    /// Register a table name only (no column metadata).
145    ///
146    /// Useful for views discovered from schema text where build-time parser
147    /// does not infer projected columns. Table existence is validated, while
148    /// column-level checks are skipped.
149    pub fn add_table_name(&mut self, table: &str) {
150        self.tables.push(table.to_string());
151    }
152
153    /// Register a table with column types (for future type validation).
154    ///
155    /// # Arguments
156    ///
157    /// * `table` — Table name to register.
158    /// * `cols` — Slice of `(column_name, column_type)` pairs.
159    pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
160        self.tables.push(table.to_string());
161        let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
162        self.columns.insert(table.to_string(), col_names);
163
164        let type_map: HashMap<String, String> = cols
165            .iter()
166            .map(|(name, typ)| (name.to_string(), typ.to_string()))
167            .collect();
168        self.column_types.insert(table.to_string(), type_map);
169    }
170
171    /// Get list of all table names (for autocomplete).
172    pub fn table_names(&self) -> &[String] {
173        &self.tables
174    }
175
176    /// Get column names for a table (for autocomplete).
177    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
178        self.columns.get(table)
179    }
180
181    /// Check if a table exists.
182    pub fn table_exists(&self, table: &str) -> bool {
183        self.tables.contains(&table.to_string())
184    }
185
186    /// Check if a table exists. If not, returns structured error with suggestion.
187    pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
188        if self.tables.contains(&table.to_string()) {
189            Ok(())
190        } else {
191            let suggestion = self.did_you_mean(table, &self.tables);
192            Err(ValidationError::TableNotFound {
193                table: table.to_string(),
194                suggestion,
195            })
196        }
197    }
198
199    /// Check if a column exists in a table. If not, returns a structured error.
200    ///
201    /// # Arguments
202    ///
203    /// * `table` — Table to look up.
204    /// * `column` — Column name to validate.
205    pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
206        // If table doesn't exist, skip column validation (table error takes precedence)
207        if !self.tables.contains(&table.to_string()) {
208            return Ok(());
209        }
210
211        // Always allow *
212        if column == "*" {
213            return Ok(());
214        }
215
216        // Skip SQL expressions — these are not plain column names.
217        // Covers: COUNT(*), id::text, leg_ids[1], brand_name AS alias,
218        //         count(distinct x), EXCLUDED.col, NOW(), etc.
219        if column.contains('(')
220            || column.contains('[')
221            || column.contains("::")
222            || column.contains(" AS ")
223            || column.contains(" as ")
224            || column.starts_with("distinct ")
225            || column.starts_with("DISTINCT ")
226        {
227            return Ok(());
228        }
229
230        // Qualified names like "table.column" — validate against the referenced table
231        if column.contains('.') {
232            let parts: Vec<&str> = column.split('.').collect();
233            if parts.len() == 2 {
234                // Only validate if the referenced table is known to the validator
235                if self.tables.contains(&parts[0].to_string()) {
236                    return self.validate_column(parts[0], parts[1]);
237                }
238            }
239            // Unknown table prefix or complex dotted path — allow (might be JSON)
240            return Ok(());
241        }
242
243        if let Some(cols) = self.columns.get(table) {
244            if cols.contains(&column.to_string()) {
245                Ok(())
246            } else {
247                let suggestion = self.did_you_mean(column, cols);
248                Err(ValidationError::ColumnNotFound {
249                    table: table.to_string(),
250                    column: column.to_string(),
251                    suggestion,
252                })
253            }
254        } else {
255            Ok(())
256        }
257    }
258
259    /// Extract column name from an Expr for validation.
260    fn extract_column_name(expr: &Expr) -> Option<String> {
261        match expr {
262            Expr::Named(name) => Some(name.clone()),
263            Expr::Aliased { name, .. } => Some(name.clone()),
264            Expr::Aggregate { col, .. } => Some(col.clone()),
265            Expr::Cast { expr, .. } => Self::extract_column_name(expr),
266            Expr::JsonAccess { column, .. } => Some(column.clone()),
267            _ => None,
268        }
269    }
270
271    /// Get column type for a table.column
272    pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
273        self.column_types.get(table)?.get(column)
274    }
275
276    /// Validate that a Value's type matches the expected column type.
277    /// Returns Ok(()) if compatible, Err with TypeMismatch if not.
278    pub fn validate_value_type(
279        &self,
280        table: &str,
281        column: &str,
282        value: &crate::ast::Value,
283    ) -> Result<(), ValidationError> {
284        use crate::ast::Value;
285
286        // Get the expected type for this column
287        let expected_type = match self.get_column_type(table, column) {
288            Some(t) => t.to_uppercase(),
289            None => return Ok(()), // Column type unknown, skip validation
290        };
291
292        // NULL is always allowed (for nullable columns)
293        if matches!(value, Value::Null | Value::NullUuid) {
294            return Ok(());
295        }
296
297        // Param and NamedParam are runtime values - can't validate at compile time
298        if matches!(
299            value,
300            Value::Param(_)
301                | Value::NamedParam(_)
302                | Value::Function(_)
303                | Value::Subquery(_)
304                | Value::Expr(_)
305        ) {
306            return Ok(());
307        }
308
309        // Value::Array is used with IN/NOT IN operators — the array is a container
310        // of elements whose type should match the column, not the array itself.
311        // Skip type validation for arrays (the element types are checked at runtime).
312        if matches!(value, Value::Array(_)) {
313            return Ok(());
314        }
315
316        // Map Value variant to its type name
317        let value_type = match value {
318            Value::Bool(_) => "BOOLEAN",
319            Value::Int(_) => "INT",
320            Value::Float(_) => "FLOAT",
321            Value::String(_) => "TEXT",
322            Value::Uuid(_) => "UUID",
323            Value::Column(_) => return Ok(()), // Column reference, type checked elsewhere
324            Value::Interval { .. } => "INTERVAL",
325            Value::Timestamp(_) => "TIMESTAMP",
326            Value::Bytes(_) => "BYTEA",
327            Value::Vector(_) => "VECTOR",
328            Value::Json(_) => "JSONB",
329            _ => return Ok(()), // Unknown value type, skip
330        };
331
332        // Check compatibility
333        if !Self::types_compatible(&expected_type, value_type) {
334            return Err(ValidationError::TypeMismatch {
335                table: table.to_string(),
336                column: column.to_string(),
337                expected: expected_type,
338                got: value_type.to_string(),
339            });
340        }
341
342        Ok(())
343    }
344
345    /// Check if expected column type is compatible with value type.
346    /// Allows flexible matching (e.g., INT matches INT4, BIGINT, INTEGER, etc.)
347    fn types_compatible(expected: &str, value_type: &str) -> bool {
348        let expected = expected.to_uppercase();
349        let value_type = value_type.to_uppercase();
350
351        // Exact match
352        if expected == value_type {
353            return true;
354        }
355
356        // Integer family
357        let int_types = [
358            "INT",
359            "INT4",
360            "INT8",
361            "INTEGER",
362            "BIGINT",
363            "SMALLINT",
364            "SERIAL",
365            "BIGSERIAL",
366        ];
367        if int_types.contains(&expected.as_str()) && value_type == "INT" {
368            return true;
369        }
370
371        // Float family — includes "DOUBLE PRECISION" which is how PG stores f64
372        let float_types = [
373            "FLOAT",
374            "FLOAT4",
375            "FLOAT8",
376            "DOUBLE",
377            "DOUBLE PRECISION",
378            "DECIMAL",
379            "NUMERIC",
380            "REAL",
381        ];
382        if float_types.contains(&expected.as_str())
383            && (value_type == "FLOAT" || value_type == "INT")
384        {
385            return true;
386        }
387
388        // Text family - TEXT is very flexible in PostgreSQL
389        let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
390        if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
391            return true;
392        }
393
394        // Boolean
395        if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
396            return true;
397        }
398
399        // UUID
400        if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
401            return true;
402        }
403
404        // Timestamp family
405        let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
406        if ts_types.contains(&expected.as_str())
407            && (value_type == "TIMESTAMP" || value_type == "TEXT")
408        {
409            return true;
410        }
411
412        // JSONB/JSON accepts almost anything
413        if expected == "JSONB" || expected == "JSON" {
414            return true;
415        }
416
417        // Arrays
418        if expected.contains("[]") || expected.starts_with("_") {
419            return value_type == "ARRAY";
420        }
421
422        false
423    }
424
425    /// Validate an entire Qail against the schema.
426    pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
427        let mut errors = Vec::new();
428
429        if let Err(e) = self.validate_table(&cmd.table) {
430            errors.push(e);
431        }
432
433        // Collect aliases from the SELECT column list so that order_by / having
434        // references to computed aliases (e.g. count().alias("route_count"))
435        // are not flagged as column-not-found errors.
436        let mut aliases: Vec<String> = Vec::new();
437        for col in &cmd.columns {
438            if let Expr::Aliased { alias, .. } = col {
439                aliases.push(alias.clone());
440            }
441            if let Some(name) = Self::extract_column_name(col)
442                && let Err(e) = self.validate_column(&cmd.table, &name)
443            {
444                errors.push(e);
445            }
446        }
447
448        for cage in &cmd.cages {
449            // Skip column validation for Sort cages — ORDER BY can reference
450            // aliases from column_expr (e.g. count().alias("route_count")),
451            // which may not exist as physical columns in the primary table.
452            if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
453                continue;
454            }
455            for cond in &cage.conditions {
456                if let Some(name) = Self::extract_column_name(&cond.left) {
457                    // Skip validation for columns that match a computed alias
458                    if aliases.iter().any(|a| a == &name) {
459                        continue;
460                    }
461                    // For join conditions, column might be qualified (table.column)
462                    if name.contains('.') {
463                        let parts: Vec<&str> = name.split('.').collect();
464                        if parts.len() == 2 {
465                            if let Err(e) = self.validate_column(parts[0], parts[1]) {
466                                errors.push(e);
467                            }
468                            // Type validation for qualified column
469                            if let Err(e) =
470                                self.validate_value_type(parts[0], parts[1], &cond.value)
471                            {
472                                errors.push(e);
473                            }
474                        }
475                    } else {
476                        if let Err(e) = self.validate_column(&cmd.table, &name) {
477                            errors.push(e);
478                        }
479                        // Type validation for unqualified column
480                        if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
481                            errors.push(e);
482                        }
483                    }
484                }
485            }
486        }
487
488        for cond in &cmd.having {
489            if let Some(name) = Self::extract_column_name(&cond.left) {
490                if name.contains('(') || name == "*" {
491                    continue;
492                }
493                if name.contains('.') {
494                    let parts: Vec<&str> = name.split('.').collect();
495                    if parts.len() == 2 {
496                        if let Err(e) = self.validate_column(parts[0], parts[1]) {
497                            errors.push(e);
498                        }
499                        if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
500                            errors.push(e);
501                        }
502                    }
503                } else {
504                    if let Err(e) = self.validate_column(&cmd.table, &name) {
505                        errors.push(e);
506                    }
507                    if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
508                        errors.push(e);
509                    }
510                }
511            }
512        }
513
514        for join in &cmd.joins {
515            if let Err(e) = self.validate_table(&join.table) {
516                errors.push(e);
517            }
518
519            if let Some(conditions) = &join.on {
520                for cond in conditions {
521                    if let Some(name) = Self::extract_column_name(&cond.left)
522                        && name.contains('.')
523                    {
524                        let parts: Vec<&str> = name.split('.').collect();
525                        if parts.len() == 2
526                            && let Err(e) = self.validate_column(parts[0], parts[1])
527                        {
528                            errors.push(e);
529                        }
530                    }
531                    // Also check right side if it's a column reference
532                    if let crate::ast::Value::Column(col_name) = &cond.value
533                        && col_name.contains('.')
534                    {
535                        let parts: Vec<&str> = col_name.split('.').collect();
536                        if parts.len() == 2
537                            && let Err(e) = self.validate_column(parts[0], parts[1])
538                        {
539                            errors.push(e);
540                        }
541                    }
542                }
543            }
544        }
545
546        if let Some(returning) = &cmd.returning {
547            for col in returning {
548                if let Some(name) = Self::extract_column_name(col)
549                    && let Err(e) = self.validate_column(&cmd.table, &name)
550                {
551                    errors.push(e);
552                }
553            }
554        }
555
556        if errors.is_empty() {
557            Ok(())
558        } else {
559            Err(errors)
560        }
561    }
562
563    /// Find the best match with Levenshtein distance within threshold.
564    fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
565        let mut best_match = None;
566        let mut min_dist = usize::MAX;
567
568        for cand in candidates {
569            let cand_str = cand.as_ref();
570            let dist = levenshtein(input, cand_str);
571
572            // Dynamic threshold based on length
573            let threshold = match input.len() {
574                0..=2 => 0, // Precise match only for very short strings
575                3..=5 => 2, // Allow 2 char diff (e.g. usr -> users)
576                _ => 3,     // Allow 3 char diff for longer strings
577            };
578
579            if dist <= threshold && dist < min_dist {
580                min_dist = dist;
581                best_match = Some(cand_str.to_string());
582            }
583        }
584
585        best_match
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );
    }

    #[test]
    fn test_no_suggestion_when_too_far() {
        let mut v = Validator::new();
        v.add_table("users", &["id"]);

        let err = v.validate_table("zzzzzzzz").unwrap_err();
        assert_eq!(
            err,
            ValidationError::TableNotFound {
                table: "zzzzzzzz".to_string(),
                suggestion: None,
            }
        );
    }

    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    #[test]
    fn test_unknown_table_skips_column_checks() {
        let v = Validator::new();
        assert!(v.validate_column("ghost", "anything").is_ok());
    }

    #[test]
    fn test_table_name_only_skips_column_checks() {
        let mut v = Validator::new();
        v.add_table_name("active_users");

        assert!(v.table_exists("active_users"));
        assert!(v.validate_table("active_users").is_ok());
        assert!(v.validate_column("active_users", "whatever").is_ok());
        assert!(v.column_names("active_users").is_none());
    }

    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        // Qualified names should pass through
        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
    }

    #[test]
    fn test_validate_value_type() {
        let mut v = Validator::new();
        v.add_table_with_types(
            "items",
            &[("id", "INTEGER"), ("active", "BOOLEAN"), ("price", "NUMERIC")],
        );

        assert!(v
            .validate_value_type("items", "id", &crate::ast::Value::Int(1))
            .is_ok());
        assert!(v
            .validate_value_type("items", "price", &crate::ast::Value::Int(1))
            .is_ok());
        assert!(v
            .validate_value_type("items", "active", &crate::ast::Value::Null)
            .is_ok());
        // Columns without declared types are never flagged.
        assert!(v
            .validate_value_type("items", "missing", &crate::ast::Value::Int(1))
            .is_ok());

        let err = v
            .validate_value_type("items", "active", &crate::ast::Value::Int(1))
            .unwrap_err();
        assert_eq!(
            err,
            ValidationError::TypeMismatch {
                table: "items".to_string(),
                column: "active".to_string(),
                expected: "BOOLEAN".to_string(),
                got: "INT".to_string(),
            }
        );
    }

    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        let cmd = Qail::get("users").columns(["id", "emial"]); // typo
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    #[test]
    fn test_validate_command_unknown_table() {
        let v = Validator::new();
        let cmd = Qail::get("ghosts");

        let errors = v.validate_command(&cmd).unwrap_err();
        assert!(errors.iter().any(
            |e| matches!(e, ValidationError::TableNotFound { table, .. } if table == "ghosts")
        ));
    }

    #[test]
    fn test_validate_having_columns() {
        let mut v = Validator::new();
        v.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        });

        let errors = v.validate_command(&cmd).unwrap_err();
        assert!(errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl")
        ));
    }

    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }

    #[test]
    fn test_error_display_other_variants() {
        let err = ValidationError::TableNotFound {
            table: "ghosts".to_string(),
            suggestion: None,
        };
        assert_eq!(err.to_string(), "Table 'ghosts' not found.");

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "x".to_string(),
            suggestion: None,
        };
        assert_eq!(err.to_string(), "Column 'x' not found in table 'users'.");

        let err = ValidationError::TypeMismatch {
            table: "items".to_string(),
            column: "active".to_string(),
            expected: "BOOLEAN".to_string(),
            got: "INT".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Type mismatch for 'items.active': expected BOOLEAN, got INT"
        );

        let err = ValidationError::InvalidOperator {
            column: "tags".to_string(),
            operator: "LIKE".to_string(),
            reason: "not supported for arrays".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid operator 'LIKE' for column 'tags': not supported for arrays"
        );
    }
}