Skip to main content

qail_core/
validator.rs

1//! Schema validator and fuzzy matching suggestions.
2//!
3//! Provides compile-time-like validation for Qail against a known schema.
4//! Used by CLI, LSP, and the encoder to catch errors before they hit the wire.
5
6use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// Validation error with structured information.
///
/// Each variant carries enough context to render a human-readable message via
/// its [`std::fmt::Display`] implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// Referenced table does not exist in the schema.
    TableNotFound {
        /// Name of the missing table.
        table: String,
        /// Closest match from known tables, if any is close enough.
        suggestion: Option<String>,
    },
    /// Referenced column does not exist in the table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Name of the missing column.
        column: String,
        /// Closest match from known columns, if any is close enough.
        suggestion: Option<String>,
    },
    /// A literal value's type is incompatible with the column's declared type.
    ///
    /// Only reported for tables registered together with column types.
    TypeMismatch {
        /// Table name.
        table: String,
        /// Column name.
        column: String,
        /// Declared column type (upper-cased).
        expected: String,
        /// Type name of the supplied value.
        got: String,
    },
    /// Invalid operator for column type.
    ///
    /// Reserved for future use: nothing in this module constructs it yet.
    InvalidOperator {
        /// Column name.
        column: String,
        /// Operator string.
        operator: String,
        /// Explanation.
        reason: String,
    },
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ValidationError::TableNotFound { table, suggestion } => match suggestion {
                Some(s) => write!(f, "Table '{table}' not found. Did you mean '{s}'?"),
                None => write!(f, "Table '{table}' not found."),
            },
            ValidationError::ColumnNotFound {
                table,
                column,
                suggestion,
            } => match suggestion {
                Some(s) => write!(
                    f,
                    "Column '{column}' not found in table '{table}'. Did you mean '{s}'?"
                ),
                None => write!(f, "Column '{column}' not found in table '{table}'."),
            },
            ValidationError::TypeMismatch {
                table,
                column,
                expected,
                got,
            } => write!(
                f,
                "Type mismatch for '{table}.{column}': expected {expected}, got {got}"
            ),
            ValidationError::InvalidOperator {
                column,
                operator,
                reason,
            } => write!(
                f,
                "Invalid operator '{operator}' for column '{column}': {reason}"
            ),
        }
    }
}

impl std::error::Error for ValidationError {}

/// Result of validating a whole command: `Ok(())`, or every error that was
/// found (the error vector is never empty).
pub type ValidationResult = Result<(), Vec<ValidationError>>;
107
/// Validates query elements against a known schema and suggests fixes for typos.
///
/// Tables are registered with [`Validator::add_table`] (column names only) or
/// [`Validator::add_table_with_types`] (column names plus SQL types, which also
/// enables literal-value type checks).
#[derive(Debug, Clone)]
pub struct Validator {
    /// Known table names, in registration order.
    tables: Vec<String>,
    /// Column names keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types keyed by table name, then by column name.
    /// Only tables registered with types have an entry here.
    column_types: HashMap<String, HashMap<String, String>>,
}
119
120impl Default for Validator {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl Validator {
127    /// Create a new Validator with known tables and columns.
128    pub fn new() -> Self {
129        Self {
130            tables: Vec::new(),
131            columns: HashMap::new(),
132            column_types: HashMap::new(),
133        }
134    }
135
136    /// Register a table and its columns.
137    pub fn add_table(&mut self, table: &str, cols: &[&str]) {
138        self.tables.push(table.to_string());
139        self.columns.insert(
140            table.to_string(),
141            cols.iter().map(|s| s.to_string()).collect(),
142        );
143    }
144
145    /// Register a table with column types (for future type validation).
146    ///
147    /// # Arguments
148    ///
149    /// * `table` — Table name to register.
150    /// * `cols` — Slice of `(column_name, column_type)` pairs.
151    pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
152        self.tables.push(table.to_string());
153        let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
154        self.columns.insert(table.to_string(), col_names);
155
156        let type_map: HashMap<String, String> = cols
157            .iter()
158            .map(|(name, typ)| (name.to_string(), typ.to_string()))
159            .collect();
160        self.column_types.insert(table.to_string(), type_map);
161    }
162
163    /// Get list of all table names (for autocomplete).
164    pub fn table_names(&self) -> &[String] {
165        &self.tables
166    }
167
168    /// Get column names for a table (for autocomplete).
169    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
170        self.columns.get(table)
171    }
172
173    /// Check if a table exists.
174    pub fn table_exists(&self, table: &str) -> bool {
175        self.tables.contains(&table.to_string())
176    }
177
178    /// Check if a table exists. If not, returns structured error with suggestion.
179    pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
180        if self.tables.contains(&table.to_string()) {
181            Ok(())
182        } else {
183            let suggestion = self.did_you_mean(table, &self.tables);
184            Err(ValidationError::TableNotFound {
185                table: table.to_string(),
186                suggestion,
187            })
188        }
189    }
190
191    /// Check if a column exists in a table. If not, returns a structured error.
192    ///
193    /// # Arguments
194    ///
195    /// * `table` — Table to look up.
196    /// * `column` — Column name to validate.
197    pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
198        // If table doesn't exist, skip column validation (table error takes precedence)
199        if !self.tables.contains(&table.to_string()) {
200            return Ok(());
201        }
202
203        // Always allow *
204        if column == "*" {
205            return Ok(());
206        }
207
208        // Skip SQL expressions — these are not plain column names.
209        // Covers: COUNT(*), id::text, leg_ids[1], brand_name AS alias,
210        //         count(distinct x), EXCLUDED.col, NOW(), etc.
211        if column.contains('(')
212            || column.contains('[')
213            || column.contains("::")
214            || column.contains(" AS ")
215            || column.contains(" as ")
216            || column.starts_with("distinct ")
217            || column.starts_with("DISTINCT ")
218        {
219            return Ok(());
220        }
221
222        // Qualified names like "table.column" — validate against the referenced table
223        if column.contains('.') {
224            let parts: Vec<&str> = column.split('.').collect();
225            if parts.len() == 2 {
226                // Only validate if the referenced table is known to the validator
227                if self.tables.contains(&parts[0].to_string()) {
228                    return self.validate_column(parts[0], parts[1]);
229                }
230            }
231            // Unknown table prefix or complex dotted path — allow (might be JSON)
232            return Ok(());
233        }
234
235        if let Some(cols) = self.columns.get(table) {
236            if cols.contains(&column.to_string()) {
237                Ok(())
238            } else {
239                let suggestion = self.did_you_mean(column, cols);
240                Err(ValidationError::ColumnNotFound {
241                    table: table.to_string(),
242                    column: column.to_string(),
243                    suggestion,
244                })
245            }
246        } else {
247            Ok(())
248        }
249    }
250
251    /// Extract column name from an Expr for validation.
252    fn extract_column_name(expr: &Expr) -> Option<String> {
253        match expr {
254            Expr::Named(name) => Some(name.clone()),
255            Expr::Aliased { name, .. } => Some(name.clone()),
256            Expr::Aggregate { col, .. } => Some(col.clone()),
257            Expr::Cast { expr, .. } => Self::extract_column_name(expr),
258            Expr::JsonAccess { column, .. } => Some(column.clone()),
259            _ => None,
260        }
261    }
262
263    /// Get column type for a table.column
264    pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
265        self.column_types.get(table)?.get(column)
266    }
267
268    /// Validate that a Value's type matches the expected column type.
269    /// Returns Ok(()) if compatible, Err with TypeMismatch if not.
270    pub fn validate_value_type(
271        &self,
272        table: &str,
273        column: &str,
274        value: &crate::ast::Value,
275    ) -> Result<(), ValidationError> {
276        use crate::ast::Value;
277
278        // Get the expected type for this column
279        let expected_type = match self.get_column_type(table, column) {
280            Some(t) => t.to_uppercase(),
281            None => return Ok(()), // Column type unknown, skip validation
282        };
283
284        // NULL is always allowed (for nullable columns)
285        if matches!(value, Value::Null | Value::NullUuid) {
286            return Ok(());
287        }
288
289        // Param and NamedParam are runtime values - can't validate at compile time
290        if matches!(
291            value,
292            Value::Param(_)
293                | Value::NamedParam(_)
294                | Value::Function(_)
295                | Value::Subquery(_)
296                | Value::Expr(_)
297        ) {
298            return Ok(());
299        }
300
301        // Value::Array is used with IN/NOT IN operators — the array is a container
302        // of elements whose type should match the column, not the array itself.
303        // Skip type validation for arrays (the element types are checked at runtime).
304        if matches!(value, Value::Array(_)) {
305            return Ok(());
306        }
307
308        // Map Value variant to its type name
309        let value_type = match value {
310            Value::Bool(_) => "BOOLEAN",
311            Value::Int(_) => "INT",
312            Value::Float(_) => "FLOAT",
313            Value::String(_) => "TEXT",
314            Value::Uuid(_) => "UUID",
315            Value::Column(_) => return Ok(()), // Column reference, type checked elsewhere
316            Value::Interval { .. } => "INTERVAL",
317            Value::Timestamp(_) => "TIMESTAMP",
318            Value::Bytes(_) => "BYTEA",
319            Value::Vector(_) => "VECTOR",
320            Value::Json(_) => "JSONB",
321            _ => return Ok(()), // Unknown value type, skip
322        };
323
324        // Check compatibility
325        if !Self::types_compatible(&expected_type, value_type) {
326            return Err(ValidationError::TypeMismatch {
327                table: table.to_string(),
328                column: column.to_string(),
329                expected: expected_type,
330                got: value_type.to_string(),
331            });
332        }
333
334        Ok(())
335    }
336
337    /// Check if expected column type is compatible with value type.
338    /// Allows flexible matching (e.g., INT matches INT4, BIGINT, INTEGER, etc.)
339    fn types_compatible(expected: &str, value_type: &str) -> bool {
340        let expected = expected.to_uppercase();
341        let value_type = value_type.to_uppercase();
342
343        // Exact match
344        if expected == value_type {
345            return true;
346        }
347
348        // Integer family
349        let int_types = [
350            "INT",
351            "INT4",
352            "INT8",
353            "INTEGER",
354            "BIGINT",
355            "SMALLINT",
356            "SERIAL",
357            "BIGSERIAL",
358        ];
359        if int_types.contains(&expected.as_str()) && value_type == "INT" {
360            return true;
361        }
362
363        // Float family — includes "DOUBLE PRECISION" which is how PG stores f64
364        let float_types = [
365            "FLOAT",
366            "FLOAT4",
367            "FLOAT8",
368            "DOUBLE",
369            "DOUBLE PRECISION",
370            "DECIMAL",
371            "NUMERIC",
372            "REAL",
373        ];
374        if float_types.contains(&expected.as_str())
375            && (value_type == "FLOAT" || value_type == "INT")
376        {
377            return true;
378        }
379
380        // Text family - TEXT is very flexible in PostgreSQL
381        let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
382        if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
383            return true;
384        }
385
386        // Boolean
387        if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
388            return true;
389        }
390
391        // UUID
392        if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
393            return true;
394        }
395
396        // Timestamp family
397        let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
398        if ts_types.contains(&expected.as_str())
399            && (value_type == "TIMESTAMP" || value_type == "TEXT")
400        {
401            return true;
402        }
403
404        // JSONB/JSON accepts almost anything
405        if expected == "JSONB" || expected == "JSON" {
406            return true;
407        }
408
409        // Arrays
410        if expected.contains("[]") || expected.starts_with("_") {
411            return value_type == "ARRAY";
412        }
413
414        false
415    }
416
417    /// Validate an entire Qail against the schema.
418    pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
419        let mut errors = Vec::new();
420
421        if let Err(e) = self.validate_table(&cmd.table) {
422            errors.push(e);
423        }
424
425        // Collect aliases from the SELECT column list so that order_by / having
426        // references to computed aliases (e.g. count().alias("route_count"))
427        // are not flagged as column-not-found errors.
428        let mut aliases: Vec<String> = Vec::new();
429        for col in &cmd.columns {
430            if let Expr::Aliased { alias, .. } = col {
431                aliases.push(alias.clone());
432            }
433            if let Some(name) = Self::extract_column_name(col)
434                && let Err(e) = self.validate_column(&cmd.table, &name)
435            {
436                errors.push(e);
437            }
438        }
439
440        for cage in &cmd.cages {
441            // Skip column validation for Sort cages — ORDER BY can reference
442            // aliases from column_expr (e.g. count().alias("route_count")),
443            // which may not exist as physical columns in the primary table.
444            if matches!(cage.kind, crate::ast::CageKind::Sort(_)) {
445                continue;
446            }
447            for cond in &cage.conditions {
448                if let Some(name) = Self::extract_column_name(&cond.left) {
449                    // Skip validation for columns that match a computed alias
450                    if aliases.iter().any(|a| a == &name) {
451                        continue;
452                    }
453                    // For join conditions, column might be qualified (table.column)
454                    if name.contains('.') {
455                        let parts: Vec<&str> = name.split('.').collect();
456                        if parts.len() == 2 {
457                            if let Err(e) = self.validate_column(parts[0], parts[1]) {
458                                errors.push(e);
459                            }
460                            // Type validation for qualified column
461                            if let Err(e) =
462                                self.validate_value_type(parts[0], parts[1], &cond.value)
463                            {
464                                errors.push(e);
465                            }
466                        }
467                    } else {
468                        if let Err(e) = self.validate_column(&cmd.table, &name) {
469                            errors.push(e);
470                        }
471                        // Type validation for unqualified column
472                        if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
473                            errors.push(e);
474                        }
475                    }
476                }
477            }
478        }
479
480        for cond in &cmd.having {
481            if let Some(name) = Self::extract_column_name(&cond.left) {
482                if name.contains('(') || name == "*" {
483                    continue;
484                }
485                if name.contains('.') {
486                    let parts: Vec<&str> = name.split('.').collect();
487                    if parts.len() == 2 {
488                        if let Err(e) = self.validate_column(parts[0], parts[1]) {
489                            errors.push(e);
490                        }
491                        if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
492                            errors.push(e);
493                        }
494                    }
495                } else {
496                    if let Err(e) = self.validate_column(&cmd.table, &name) {
497                        errors.push(e);
498                    }
499                    if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
500                        errors.push(e);
501                    }
502                }
503            }
504        }
505
506        for join in &cmd.joins {
507            if let Err(e) = self.validate_table(&join.table) {
508                errors.push(e);
509            }
510
511            if let Some(conditions) = &join.on {
512                for cond in conditions {
513                    if let Some(name) = Self::extract_column_name(&cond.left)
514                        && name.contains('.')
515                    {
516                        let parts: Vec<&str> = name.split('.').collect();
517                        if parts.len() == 2
518                            && let Err(e) = self.validate_column(parts[0], parts[1])
519                        {
520                            errors.push(e);
521                        }
522                    }
523                    // Also check right side if it's a column reference
524                    if let crate::ast::Value::Column(col_name) = &cond.value
525                        && col_name.contains('.')
526                    {
527                        let parts: Vec<&str> = col_name.split('.').collect();
528                        if parts.len() == 2
529                            && let Err(e) = self.validate_column(parts[0], parts[1])
530                        {
531                            errors.push(e);
532                        }
533                    }
534                }
535            }
536        }
537
538        if let Some(returning) = &cmd.returning {
539            for col in returning {
540                if let Some(name) = Self::extract_column_name(col)
541                    && let Err(e) = self.validate_column(&cmd.table, &name)
542                {
543                    errors.push(e);
544                }
545            }
546        }
547
548        if errors.is_empty() {
549            Ok(())
550        } else {
551            Err(errors)
552        }
553    }
554
555    /// Find the best match with Levenshtein distance within threshold.
556    fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
557        let mut best_match = None;
558        let mut min_dist = usize::MAX;
559
560        for cand in candidates {
561            let cand_str = cand.as_ref();
562            let dist = levenshtein(input, cand_str);
563
564            // Dynamic threshold based on length
565            let threshold = match input.len() {
566                0..=2 => 0, // Precise match only for very short strings
567                3..=5 => 2, // Allow 2 char diff (e.g. usr -> users)
568                _ => 3,     // Allow 3 char diff for longer strings
569            };
570
571            if dist <= threshold && dist < min_dist {
572                min_dist = dist;
573                best_match = Some(cand_str.to_string());
574            }
575        }
576
577        best_match
578    }
579
580    // =========================================================================
581    // Legacy API (for backward compatibility)
582    // =========================================================================
583
584    /// Legacy: validate_table that returns String error
585    #[deprecated(note = "Use validate_table() which returns ValidationError")]
586    pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
587        self.validate_table(table).map_err(|e| e.to_string())
588    }
589
590    /// Legacy: validate_column that returns String error
591    #[deprecated(note = "Use validate_column() which returns ValidationError")]
592    pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
593        self.validate_column(table, column)
594            .map_err(|e| e.to_string())
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_did_you_mean_table() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("orders", &["id", "total"]);

        assert!(v.validate_table("users").is_ok());

        let err = v.validate_table("usr").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        let err = v.validate_table("usrs").unwrap_err();
        assert!(
            matches!(err, ValidationError::TableNotFound { suggestion: Some(ref s), .. } if s == "users")
        );

        // Nothing within the edit-distance threshold: no suggestion.
        let err = v.validate_table("zzzzzzzz").unwrap_err();
        assert!(matches!(
            err,
            ValidationError::TableNotFound {
                suggestion: None,
                ..
            }
        ));
    }

    #[test]
    fn test_did_you_mean_column() {
        let mut v = Validator::new();
        v.add_table("users", &["email", "password"]);

        assert!(v.validate_column("users", "email").is_ok());
        assert!(v.validate_column("users", "*").is_ok());

        let err = v.validate_column("users", "emial").unwrap_err();
        assert!(
            matches!(err, ValidationError::ColumnNotFound { suggestion: Some(ref s), .. } if s == "email")
        );
    }

    #[test]
    fn test_qualified_column_name() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);
        v.add_table("profiles", &["user_id", "avatar"]);

        // Qualified names are validated against the table they name...
        assert!(v.validate_column("users", "users.id").is_ok());
        assert!(v.validate_column("users", "profiles.user_id").is_ok());
        let err = v.validate_column("users", "profiles.avatr").unwrap_err();
        assert!(matches!(
            &err,
            ValidationError::ColumnNotFound { table, column, .. }
                if table == "profiles" && column == "avatr"
        ));

        // ...while unknown prefixes (aliases) and multi-dot paths pass through.
        assert!(v.validate_column("users", "p.anything").is_ok());
        assert!(v.validate_column("users", "users.meta.key").is_ok());
    }

    #[test]
    fn test_expression_columns_are_skipped() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "name"]);

        for expr in [
            "COUNT(*)",
            "id::text",
            "tags[1]",
            "name AS n",
            "name as n",
            "DISTINCT name",
            "distinct name",
        ] {
            assert!(v.validate_column("users", expr).is_ok(), "{expr} should be skipped");
        }

        // Unknown tables are reported by validate_table, not per column.
        assert!(v.validate_column("missing", "whatever").is_ok());
    }

    #[test]
    fn test_validate_value_type() {
        let mut v = Validator::new();
        v.add_table_with_types("orders", &[("id", "int4"), ("paid", "bool")]);

        assert_eq!(
            v.get_column_type("orders", "id").map(String::as_str),
            Some("int4")
        );
        assert!(v
            .validate_value_type("orders", "id", &crate::ast::Value::Int(7))
            .is_ok());
        assert!(v
            .validate_value_type("orders", "paid", &crate::ast::Value::Bool(true))
            .is_ok());
        assert!(v
            .validate_value_type("orders", "id", &crate::ast::Value::Null)
            .is_ok());

        let err = v
            .validate_value_type("orders", "id", &crate::ast::Value::Bool(true))
            .unwrap_err();
        assert_eq!(
            err,
            ValidationError::TypeMismatch {
                table: "orders".to_string(),
                column: "id".to_string(),
                expected: "INT4".to_string(),
                got: "BOOLEAN".to_string(),
            }
        );
    }

    #[test]
    fn test_validate_command() {
        let mut v = Validator::new();
        v.add_table("users", &["id", "email", "name"]);

        let cmd = Qail::get("users").columns(["id", "email"]);
        assert!(v.validate_command(&cmd).is_ok());

        let cmd = Qail::get("users").columns(["id", "emial"]); // typo
        let errors = v.validate_command(&cmd).unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(
            matches!(&errors[0], ValidationError::ColumnNotFound { column, .. } if column == "emial")
        );
    }

    #[test]
    fn test_validate_having_columns() {
        let mut v = Validator::new();
        v.add_table("orders", &["id", "status", "total"]);

        let mut cmd = Qail::get("orders");
        cmd.having.push(crate::ast::Condition {
            left: Expr::Named("totl".to_string()),
            op: crate::ast::Operator::Eq,
            value: crate::ast::Value::Int(1),
            is_array_unnest: false,
        });

        let errors = v.validate_command(&cmd).unwrap_err();
        assert!(errors.iter().any(
            |e| matches!(e, ValidationError::ColumnNotFound { column, .. } if column == "totl")
        ));
    }

    #[test]
    fn test_error_display() {
        let err = ValidationError::TableNotFound {
            table: "usrs".to_string(),
            suggestion: Some("users".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let err = ValidationError::ColumnNotFound {
            table: "users".to_string(),
            column: "emial".to_string(),
            suggestion: Some("email".to_string()),
        };
        assert_eq!(
            err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}