Skip to main content

qail_core/
validator.rs

//! Schema validator and fuzzy matching suggestions.
//!
//! Provides compile-time-like validation for Qail against a known schema.
//! Used by CLI, LSP, and the encoder to catch errors before they hit the wire.

6use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// Validation error with structured information.
///
/// The `Display` output is a human-readable message. Where a close match exists,
/// it includes a "Did you mean ...?" hint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// Referenced table does not exist in the schema.
    TableNotFound {
        /// Name of the missing table.
        table: String,
        /// Closest match from known tables, if any is close enough.
        suggestion: Option<String>,
    },
    /// Referenced column does not exist in the table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Name of the missing column.
        column: String,
        /// Closest match from known columns, if any is close enough.
        suggestion: Option<String>,
    },
    /// A literal value's type is incompatible with the column's declared type.
    ///
    /// Only reported for columns registered with their types.
    TypeMismatch {
        /// Table name.
        table: String,
        /// Column name.
        column: String,
        /// Declared column type (uppercased).
        expected: String,
        /// Type name of the supplied value (e.g. `TEXT`, `INT`).
        got: String,
    },
    /// Invalid operator for the column's type (reserved for future operator checks).
    InvalidOperator {
        /// Column name.
        column: String,
        /// Operator string.
        operator: String,
        /// Explanation.
        reason: String,
    },
}

51impl std::fmt::Display for ValidationError {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            ValidationError::TableNotFound { table, suggestion } => {
55                if let Some(s) = suggestion {
56                    write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
57                } else {
58                    write!(f, "Table '{}' not found.", table)
59                }
60            }
61            ValidationError::ColumnNotFound {
62                table,
63                column,
64                suggestion,
65            } => {
66                if let Some(s) = suggestion {
67                    write!(
68                        f,
69                        "Column '{}' not found in table '{}'. Did you mean '{}'?",
70                        column, table, s
71                    )
72                } else {
73                    write!(f, "Column '{}' not found in table '{}'.", column, table)
74                }
75            }
76            ValidationError::TypeMismatch {
77                table,
78                column,
79                expected,
80                got,
81            } => {
82                write!(
83                    f,
84                    "Type mismatch for '{}.{}': expected {}, got {}",
85                    table, column, expected, got
86                )
87            }
88            ValidationError::InvalidOperator {
89                column,
90                operator,
91                reason,
92            } => {
93                write!(
94                    f,
95                    "Invalid operator '{}' for column '{}': {}",
96                    operator, column, reason
97                )
98            }
99        }
100    }
101}
102
103impl std::error::Error for ValidationError {}
104
105/// Result of validation
106pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Validates query elements against a known schema and suggests fixes for typos.
///
/// Populate it with `add_table`, `add_table_name` or `add_table_with_types`.
/// Then call `validate_command`, or one of the finer-grained `validate_*` methods.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Known table (and view) names, in registration order.
    tables: Vec<String>,
    /// Column names indexed by table name.
    ///
    /// Tables registered without columns have no entry here, which disables
    /// column-level checks for them.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types, as a nested map: table name → (column name → type name).
    ///
    /// Read by `get_column_type` and used for value type validation.
    column_types: HashMap<String, HashMap<String, String>>,
}

120impl Default for Validator {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl Validator {
127    /// Create a new Validator with known tables and columns.
128    pub fn new() -> Self {
129        Self {
130            tables: Vec::new(),
131            columns: HashMap::new(),
132            column_types: HashMap::new(),
133        }
134    }
135
136    /// Register a table and its columns.
137    pub fn add_table(&mut self, table: &str, cols: &[&str]) {
138        self.tables.push(table.to_string());
139        self.columns.insert(
140            table.to_string(),
141            cols.iter().map(|s| s.to_string()).collect(),
142        );
143    }
144
145    /// Register a table name only (no column metadata).
146    ///
147    /// Useful for views discovered from schema text where build-time parser
148    /// does not infer projected columns. Table existence is validated, while
149    /// column-level checks are skipped.
150    pub fn add_table_name(&mut self, table: &str) {
151        self.tables.push(table.to_string());
152    }
153
154    /// Register a table with column types (for future type validation).
155    ///
156    /// # Arguments
157    ///
158    /// * `table` — Table name to register.
159    /// * `cols` — Slice of `(column_name, column_type)` pairs.
160    pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
161        self.tables.push(table.to_string());
162        let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
163        self.columns.insert(table.to_string(), col_names);
164
165        let type_map: HashMap<String, String> = cols
166            .iter()
167            .map(|(name, typ)| (name.to_string(), typ.to_string()))
168            .collect();
169        self.column_types.insert(table.to_string(), type_map);
170    }
171
172    /// Get list of all table names (for autocomplete).
173    pub fn table_names(&self) -> &[String] {
174        &self.tables
175    }
176
177    /// Get column names for a table (for autocomplete).
178    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
179        self.columns.get(table)
180    }
181
182    /// Check if a table exists.
183    pub fn table_exists(&self, table: &str) -> bool {
184        self.tables.contains(&table.to_string())
185    }
186
187    /// Check if a table exists. If not, returns structured error with suggestion.
188    pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
189        if self.tables.contains(&table.to_string()) {
190            Ok(())
191        } else {
192            let suggestion = self.did_you_mean(table, &self.tables);
193            Err(ValidationError::TableNotFound {
194                table: table.to_string(),
195                suggestion,
196            })
197        }
198    }
199
200    /// Check if a column exists in a table. If not, returns a structured error.
201    ///
202    /// # Arguments
203    ///
204    /// * `table` — Table to look up.
205    /// * `column` — Column name to validate.
206    pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
207        // If table doesn't exist, skip column validation (table error takes precedence)
208        if !self.tables.contains(&table.to_string()) {
209            return Ok(());
210        }
211
212        // Always allow *
213        if column == "*" {
214            return Ok(());
215        }
216
217        // Skip SQL expressions — these are not plain column names.
218        // Covers: COUNT(*), id::text, leg_ids[1], brand_name AS alias,
219        //         count(distinct x), EXCLUDED.col, NOW(), etc.
220        if column.contains('(')
221            || column.contains('[')
222            || column.contains("::")
223            || column.contains(" AS ")
224            || column.contains(" as ")
225            || column.starts_with("distinct ")
226            || column.starts_with("DISTINCT ")
227        {
228            return Ok(());
229        }
230
231        // Qualified names like "table.column" — validate against the referenced table
232        if column.contains('.') {
233            let parts: Vec<&str> = column.split('.').collect();
234            if parts.len() == 2 {
235                // Only validate if the referenced table is known to the validator
236                if self.tables.contains(&parts[0].to_string()) {
237                    return self.validate_column(parts[0], parts[1]);
238                }
239            }
240            // Unknown table prefix or complex dotted path — allow (might be JSON)
241            return Ok(());
242        }
243
244        if let Some(cols) = self.columns.get(table) {
245            if cols.contains(&column.to_string()) {
246                Ok(())
247            } else {
248                let suggestion = self.did_you_mean(column, cols);
249                Err(ValidationError::ColumnNotFound {
250                    table: table.to_string(),
251                    column: column.to_string(),
252                    suggestion,
253                })
254            }
255        } else {
256            Ok(())
257        }
258    }
259
260    /// Extract column name from an Expr for validation.
261    fn extract_column_name(expr: &Expr) -> Option<String> {
262        match expr {
263            Expr::Named(name) => Some(name.clone()),
264            Expr::Aliased { name, .. } => Some(name.clone()),
265            Expr::Aggregate { col, .. } => Some(col.clone()),
266            Expr::Cast { expr, .. } => Self::extract_column_name(expr),
267            Expr::JsonAccess { column, .. } => Some(column.clone()),
268            _ => None,
269        }
270    }
271
272    /// Get column type for a table.column
273    pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
274        self.column_types.get(table)?.get(column)
275    }
276
277    /// Validate that a Value's type matches the expected column type.
278    /// Returns Ok(()) if compatible, Err with TypeMismatch if not.
279    pub fn validate_value_type(
280        &self,
281        table: &str,
282        column: &str,
283        value: &crate::ast::Value,
284    ) -> Result<(), ValidationError> {
285        use crate::ast::Value;
286
287        // Get the expected type for this column
288        let expected_type = match self.get_column_type(table, column) {
289            Some(t) => t.to_uppercase(),
290            None => return Ok(()), // Column type unknown, skip validation
291        };
292
293        // NULL is always allowed (for nullable columns)
294        if matches!(value, Value::Null | Value::NullUuid) {
295            return Ok(());
296        }
297
298        // Param and NamedParam are runtime values - can't validate at compile time
299        if matches!(
300            value,
301            Value::Param(_)
302                | Value::NamedParam(_)
303                | Value::Function(_)
304                | Value::Subquery(_)
305                | Value::Expr(_)
306        ) {
307            return Ok(());
308        }
309
310        // Value::Array is used with IN/NOT IN operators — the array is a container
311        // of elements whose type should match the column, not the array itself.
312        // Skip type validation for arrays (the element types are checked at runtime).
313        if matches!(value, Value::Array(_)) {
314            return Ok(());
315        }
316
317        // Map Value variant to its type name
318        let value_type = match value {
319            Value::Bool(_) => "BOOLEAN",
320            Value::Int(_) => "INT",
321            Value::Float(_) => "FLOAT",
322            Value::String(_) => "TEXT",
323            Value::Uuid(_) => "UUID",
324            Value::Column(_) => return Ok(()), // Column reference, type checked elsewhere
325            Value::Interval { .. } => "INTERVAL",
326            Value::Timestamp(_) => "TIMESTAMP",
327            Value::Bytes(_) => "BYTEA",
328            Value::Vector(_) => "VECTOR",
329            Value::Json(_) => "JSONB",
330            _ => return Ok(()), // Unknown value type, skip
331        };
332
333        // Check compatibility
334        if !Self::types_compatible(&expected_type, value_type) {
335            return Err(ValidationError::TypeMismatch {
336                table: table.to_string(),
337                column: column.to_string(),
338                expected: expected_type,
339                got: value_type.to_string(),
340            });
341        }
342
343        Ok(())
344    }
345
346    /// Check if expected column type is compatible with value type.
347    /// Allows flexible matching (e.g., INT matches INT4, BIGINT, INTEGER, etc.)
348    fn types_compatible(expected: &str, value_type: &str) -> bool {
349        let expected = expected.to_uppercase();
350        let value_type = value_type.to_uppercase();
351
352        // Exact match
353        if expected == value_type {
354            return true;
355        }
356
357        // Integer family
358        let int_types = [
359            "INT",
360            "INT4",
361            "INT8",
362            "INTEGER",
363            "BIGINT",
364            "SMALLINT",
365            "SERIAL",
366            "BIGSERIAL",
367        ];
368        if int_types.contains(&expected.as_str()) && value_type == "INT" {
369            return true;
370        }
371
372        // Float family — includes "DOUBLE PRECISION" which is how PG stores f64
373        let float_types = [
374            "FLOAT",
375            "FLOAT4",
376            "FLOAT8",
377            "DOUBLE",
378            "DOUBLE PRECISION",
379            "DECIMAL",
380            "NUMERIC",
381            "REAL",
382        ];
383        if float_types.contains(&expected.as_str())
384            && (value_type == "FLOAT" || value_type == "INT")
385        {
386            return true;
387        }
388
389        // Text family - TEXT is very flexible in PostgreSQL
390        let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
391        if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
392            return true;
393        }
394
395        // Boolean
396        if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
397            return true;
398        }
399
400        // UUID
401        if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
402            return true;
403        }
404
405        // Timestamp family
406        let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
407        if ts_types.contains(&expected.as_str())
408            && (value_type == "TIMESTAMP" || value_type == "TEXT")
409        {
410            return true;
411        }
412
413        // JSONB/JSON accepts almost anything
414        if expected == "JSONB" || expected == "JSON" {
415            return true;
416        }
417
418        // Arrays
419        if expected.contains("[]") || expected.starts_with("_") {
420            return value_type == "ARRAY";
421        }
422
423        false
424    }
425
426    /// Validate an entire Qail against the schema.
427    pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
428        let mut errors = Vec::new();
429
430        if let Err(e) = self.validate_table(&cmd.table) {
431            errors.push(e);
432        }
433
434        // Collect aliases from the SELECT column list so that order_by / having
435        // references to computed aliases (e.g. count().alias("route_count"))
436        // are not flagged as column-not-found errors.
437        let mut aliases: Vec<String> = Vec::new();
438        for col in &cmd.columns {
439            if let Expr::Aliased { alias, .. } = col {
440                aliases.push(alias.clone());
441            }
442            if let Some(name) = Self::extract_column_name(col)
443                && let Err(e) = self.validate_column(&cmd.table, &name)
444            {
445                errors.push(e);
446            }
447        }
448
449        for cage in &cmd.cages {
450            // Skip column validation for Sort cages — ORDER BY can reference
451            // aliases from column_expr (e.g. count().alias("route_count")),
452            // which may not exist as physical columns in the primary table.
453            if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
454                continue;
455            }
456            for cond in &cage.conditions {
457                if let Some(name) = Self::extract_column_name(&cond.left) {
458                    // Skip validation for columns that match a computed alias
459                    if aliases.iter().any(|a| a == &name) {
460                        continue;
461                    }
462                    // For join conditions, column might be qualified (table.column)
463                    if name.contains('.') {
464                        let parts: Vec<&str> = name.split('.').collect();
465                        if parts.len() == 2 {
466                            if let Err(e) = self.validate_column(parts[0], parts[1]) {
467                                errors.push(e);
468                            }
469                            // Type validation for qualified column
470                            if let Err(e) =
471                                self.validate_value_type(parts[0], parts[1], &cond.value)
472                            {
473                                errors.push(e);
474                            }
475                        }
476                    } else {
477                        if let Err(e) = self.validate_column(&cmd.table, &name) {
478                            errors.push(e);
479                        }
480                        // Type validation for unqualified column
481                        if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
482                            errors.push(e);
483                        }
484                    }
485                }
486            }
487        }
488
489        for cond in &cmd.having {
490            if let Some(name) = Self::extract_column_name(&cond.left) {
491                if name.contains('(') || name == "*" {
492                    continue;
493                }
494                if name.contains('.') {
495                    let parts: Vec<&str> = name.split('.').collect();
496                    if parts.len() == 2 {
497                        if let Err(e) = self.validate_column(parts[0], parts[1]) {
498                            errors.push(e);
499                        }
500                        if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
501                            errors.push(e);
502                        }
503                    }
504                } else {
505                    if let Err(e) = self.validate_column(&cmd.table, &name) {
506                        errors.push(e);
507                    }
508                    if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
509                        errors.push(e);
510                    }
511                }
512            }
513        }
514
515        for join in &cmd.joins {
516            if let Err(e) = self.validate_table(&join.table) {
517                errors.push(e);
518            }
519
520            if let Some(conditions) = &join.on {
521                for cond in conditions {
522                    if let Some(name) = Self::extract_column_name(&cond.left)
523                        && name.contains('.')
524                    {
525                        let parts: Vec<&str> = name.split('.').collect();
526                        if parts.len() == 2
527                            && let Err(e) = self.validate_column(parts[0], parts[1])
528                        {
529                            errors.push(e);
530                        }
531                    }
532                    // Also check right side if it's a column reference
533                    if let crate::ast::Value::Column(col_name) = &cond.value
534                        && col_name.contains('.')
535                    {
536                        let parts: Vec<&str> = col_name.split('.').collect();
537                        if parts.len() == 2
538                            && let Err(e) = self.validate_column(parts[0], parts[1])
539                        {
540                            errors.push(e);
541                        }
542                    }
543                }
544            }
545        }
546
547        if let Some(returning) = &cmd.returning {
548            for col in returning {
549                if let Some(name) = Self::extract_column_name(col)
550                    && let Err(e) = self.validate_column(&cmd.table, &name)
551                {
552                    errors.push(e);
553                }
554            }
555        }
556
557        if errors.is_empty() {
558            Ok(())
559        } else {
560            Err(errors)
561        }
562    }
563
564    /// Find the best match with Levenshtein distance within threshold.
565    fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
566        let mut best_match = None;
567        let mut min_dist = usize::MAX;
568
569        for cand in candidates {
570            let cand_str = cand.as_ref();
571            let dist = levenshtein(input, cand_str);
572
573            // Dynamic threshold based on length
574            let threshold = match input.len() {
575                0..=2 => 0, // Precise match only for very short strings
576                3..=5 => 2, // Allow 2 char diff (e.g. usr -> users)
577                _ => 3,     // Allow 3 char diff for longer strings
578            };
579
580            if dist <= threshold && dist < min_dist {
581                min_dist = dist;
582                best_match = Some(cand_str.to_string());
583            }
584        }
585
586        best_match
587    }
588
589    // =========================================================================
590    // Legacy API (for backward compatibility)
591    // =========================================================================
592
593    /// Legacy: validate_table that returns String error
594    #[deprecated(note = "Use validate_table() which returns ValidationError")]
595    pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
596        self.validate_table(table).map_err(|e| e.to_string())
597    }
598
599    /// Legacy: validate_column that returns String error
600    #[deprecated(note = "Use validate_column() which returns ValidationError")]
601    pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
602        self.validate_column(table, column)
603            .map_err(|e| e.to_string())
604    }
605}
606
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Condition, Operator, Value};

    /// Table typos within the edit-distance threshold suggest the real table.
    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        for typo in ["usr", "usrs"] {
            let err = v.validate_table(typo).unwrap_err();
            assert!(
                matches!(&err, ValidationError::TableNotFound { suggestion: Some(s), .. } if s == "users"),
                "unexpected error for {typo:?}: {err:?}"
            );
        }
    }

    /// Column typos suggest the closest registered column; `*` always passes.
    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        for valid in ["email", "*"] {
            assert!(v.validate_column("users", valid).is_ok(), "{valid:?} should be valid");
        }

        match v.validate_column("users", "emial") {
            Err(ValidationError::ColumnNotFound { suggestion: Some(s), .. }) => {
                assert_eq!(s, "email")
            }
            other => panic!("expected ColumnNotFound with a suggestion, got {other:?}"),
        }
    }

    /// Qualified `table.column` names resolve against the referenced table.
    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        for qualified in ["users.id", "profiles.user_id"] {
            assert!(
                v.validate_column("users", qualified).is_ok(),
                "{qualified:?} should pass through"
            );
        }
    }

    /// A whole command passes when clean and reports exactly the typo otherwise.
    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let clean = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&clean).is_ok());

        let with_typo = Qail::get("users").columns(["id", "emial"]);
        let errors = v.validate_command(&with_typo).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    /// HAVING conditions are validated against the primary table's columns.
    #[test]
    fn test_validate_having_columns() {
        let mut v = Validator::new();
        v.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(Condition {
            left: Expr::Named("totl".to_string()),
            op: Operator::Eq,
            value: Value::Int(1),
            is_array_unnest: false,
        });

        let errors = v.validate_command(&cmd).unwrap_err();
        let flagged = errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl"),
        );
        assert!(flagged, "expected 'totl' to be flagged, got {errors:?}");
    }

    /// `Display` renders the suggestion hint for not-found errors.
    #[test]
    fn test_error_display() {
        let cases = [
            (
                ValidationError::TableNotFound {
                    table: "usrs".to_string(),
                    suggestion: Some("users".to_string()),
                },
                "Table 'usrs' not found. Did you mean 'users'?",
            ),
            (
                ValidationError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "emial".to_string(),
                    suggestion: Some("email".to_string()),
                },
                "Column 'emial' not found in table 'users'. Did you mean 'email'?",
            ),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }
}