qail_core/
validator.rs

1//! Schema validator and fuzzy matching suggestions.
2//!
3//! Provides compile-time-like validation for Qail against a known schema.
4//! Used by CLI, LSP, and the encoder to catch errors before they hit the wire.
5
6use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// Validation error with structured information.
///
/// Every variant carries the identifiers involved so callers can build their
/// own diagnostics; the [`std::fmt::Display`] impl renders a ready-made
/// human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The referenced table is not registered in the schema.
    TableNotFound {
        /// Table name as it was passed to the validator.
        table: String,
        /// Closest registered table name, when one is similar enough.
        suggestion: Option<String>,
    },
    /// The referenced column is not registered for the given table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as it was passed to the validator.
        column: String,
        /// Closest registered column name, when one is similar enough.
        suggestion: Option<String>,
    },
    /// A literal value's type is not compatible with the column's declared type.
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column whose declared type was violated.
        column: String,
        /// Declared column type.
        expected: String,
        /// Type name derived from the supplied value.
        got: String,
    },
    /// Invalid operator for column type (reserved; nothing in this file
    /// constructs it yet).
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// Offending operator.
        operator: String,
        /// Explanation of why the operator is not valid.
        reason: String,
    },
}

impl std::fmt::Display for ValidationError {
    /// Render the error as a single-line message, appending a
    /// "Did you mean ...?" hint whenever a suggestion is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TableNotFound {
                table,
                suggestion: Some(hint),
            } => write!(f, "Table '{}' not found. Did you mean '{}'?", table, hint),
            Self::TableNotFound {
                table,
                suggestion: None,
            } => write!(f, "Table '{}' not found.", table),
            Self::ColumnNotFound {
                table,
                column,
                suggestion: Some(hint),
            } => write!(
                f,
                "Column '{}' not found in table '{}'. Did you mean '{}'?",
                column, table, hint
            ),
            Self::ColumnNotFound {
                table,
                column,
                suggestion: None,
            } => write!(f, "Column '{}' not found in table '{}'.", column, table),
            Self::TypeMismatch {
                table,
                column,
                expected,
                got,
            } => write!(
                f,
                "Type mismatch for '{}.{}': expected {}, got {}",
                table, column, expected, got
            ),
            Self::InvalidOperator {
                column,
                operator,
                reason,
            } => write!(
                f,
                "Invalid operator '{}' for column '{}': {}",
                operator, column, reason
            ),
        }
    }
}

impl std::error::Error for ValidationError {}

/// Result of validation: `Ok(())` when nothing was wrong, otherwise every
/// error that was collected.
pub type ValidationResult = Result<(), Vec<ValidationError>>;
93
/// Validates query elements against known schema and provides suggestions.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in the order they were added.
    tables: Vec<String>,
    /// Column names per table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types: table name -> (column name -> type name).
    /// Only filled for tables registered together with their types.
    column_types: HashMap<String, HashMap<String, String>>,
}
102
103impl Default for Validator {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl Validator {
110    /// Create a new Validator with known tables and columns.
111    pub fn new() -> Self {
112        Self {
113            tables: Vec::new(),
114            columns: HashMap::new(),
115            column_types: HashMap::new(),
116        }
117    }
118
119    /// Register a table and its columns.
120    pub fn add_table(&mut self, table: &str, cols: &[&str]) {
121        self.tables.push(table.to_string());
122        self.columns.insert(
123            table.to_string(),
124            cols.iter().map(|s| s.to_string()).collect(),
125        );
126    }
127
128    /// Register a table with column types (for future type validation).
129    pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
130        self.tables.push(table.to_string());
131        let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
132        self.columns.insert(table.to_string(), col_names);
133
134        let type_map: HashMap<String, String> = cols
135            .iter()
136            .map(|(name, typ)| (name.to_string(), typ.to_string()))
137            .collect();
138        self.column_types.insert(table.to_string(), type_map);
139    }
140
141    /// Get list of all table names (for autocomplete).
142    pub fn table_names(&self) -> &[String] {
143        &self.tables
144    }
145
146    /// Get column names for a table (for autocomplete).
147    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
148        self.columns.get(table)
149    }
150
151    /// Check if a table exists.
152    pub fn table_exists(&self, table: &str) -> bool {
153        self.tables.contains(&table.to_string())
154    }
155
156    /// Check if a table exists. If not, returns structured error with suggestion.
157    pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
158        if self.tables.contains(&table.to_string()) {
159            Ok(())
160        } else {
161            let suggestion = self.did_you_mean(table, &self.tables);
162            Err(ValidationError::TableNotFound {
163                table: table.to_string(),
164                suggestion,
165            })
166        }
167    }
168
169    /// Check if a column exists in a table. If not, returns structured error.
170    pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
171        // If table doesn't exist, skip column validation (table error takes precedence)
172        if !self.tables.contains(&table.to_string()) {
173            return Ok(());
174        }
175
176        // Always allow * and qualified names like "table.column"
177        if column == "*" || column.contains('.') {
178            return Ok(());
179        }
180
181        if let Some(cols) = self.columns.get(table) {
182            if cols.contains(&column.to_string()) {
183                Ok(())
184            } else {
185                let suggestion = self.did_you_mean(column, cols);
186                Err(ValidationError::ColumnNotFound {
187                    table: table.to_string(),
188                    column: column.to_string(),
189                    suggestion,
190                })
191            }
192        } else {
193            Ok(())
194        }
195    }
196
197    /// Extract column name from an Expr for validation.
198    fn extract_column_name(expr: &Expr) -> Option<String> {
199        match expr {
200            Expr::Named(name) => Some(name.clone()),
201            Expr::Aliased { name, .. } => Some(name.clone()),
202            Expr::Aggregate { col, .. } => Some(col.clone()),
203            Expr::Cast { expr, .. } => Self::extract_column_name(expr),
204            Expr::JsonAccess { column, .. } => Some(column.clone()),
205            _ => None,
206        }
207    }
208    
209    /// Get column type for a table.column
210    pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
211        self.column_types.get(table)?.get(column)
212    }
213    
214    /// Validate that a Value's type matches the expected column type.
215    /// Returns Ok(()) if compatible, Err with TypeMismatch if not.
216    pub fn validate_value_type(
217        &self,
218        table: &str,
219        column: &str,
220        value: &crate::ast::Value,
221    ) -> Result<(), ValidationError> {
222        use crate::ast::Value;
223        
224        // Get the expected type for this column
225        let expected_type = match self.get_column_type(table, column) {
226            Some(t) => t.to_uppercase(),
227            None => return Ok(()), // Column type unknown, skip validation
228        };
229        
230        // NULL is always allowed (for nullable columns)
231        if matches!(value, Value::Null | Value::NullUuid) {
232            return Ok(());
233        }
234        
235        // Param and NamedParam are runtime values - can't validate at compile time
236        if matches!(value, Value::Param(_) | Value::NamedParam(_) | Value::Function(_) | Value::Subquery(_) | Value::Expr(_)) {
237            return Ok(());
238        }
239        
240        // Map Value variant to its type name
241        let value_type = match value {
242            Value::Bool(_) => "BOOLEAN",
243            Value::Int(_) => "INT",
244            Value::Float(_) => "FLOAT",
245            Value::String(_) => "TEXT",
246            Value::Uuid(_) => "UUID",
247            Value::Array(_) => "ARRAY",
248            Value::Column(_) => return Ok(()), // Column reference, type checked elsewhere
249            Value::Interval { .. } => "INTERVAL",
250            Value::Timestamp(_) => "TIMESTAMP",
251            Value::Bytes(_) => "BYTEA",
252            Value::Vector(_) => "VECTOR",
253            _ => return Ok(()), // Unknown value type, skip
254        };
255        
256        // Check compatibility
257        if !Self::types_compatible(&expected_type, value_type) {
258            return Err(ValidationError::TypeMismatch {
259                table: table.to_string(),
260                column: column.to_string(),
261                expected: expected_type,
262                got: value_type.to_string(),
263            });
264        }
265        
266        Ok(())
267    }
268    
269    /// Check if expected column type is compatible with value type.
270    /// Allows flexible matching (e.g., INT matches INT4, BIGINT, INTEGER, etc.)
271    fn types_compatible(expected: &str, value_type: &str) -> bool {
272        let expected = expected.to_uppercase();
273        let value_type = value_type.to_uppercase();
274        
275        // Exact match
276        if expected == value_type {
277            return true;
278        }
279        
280        // Integer family
281        let int_types = ["INT", "INT4", "INT8", "INTEGER", "BIGINT", "SMALLINT", "SERIAL", "BIGSERIAL"];
282        if int_types.contains(&expected.as_str()) && value_type == "INT" {
283            return true;
284        }
285        
286        // Float family
287        let float_types = ["FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DECIMAL", "NUMERIC", "REAL"];
288        if float_types.contains(&expected.as_str()) && (value_type == "FLOAT" || value_type == "INT") {
289            return true;
290        }
291        
292        // Text family - TEXT is very flexible in PostgreSQL
293        let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
294        if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
295            return true;
296        }
297        
298        // Boolean
299        if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
300            return true;
301        }
302        
303        // UUID
304        if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
305            return true;
306        }
307        
308        // Timestamp family
309        let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
310        if ts_types.contains(&expected.as_str()) && (value_type == "TIMESTAMP" || value_type == "TEXT") {
311            return true;
312        }
313        
314        // JSONB/JSON accepts almost anything
315        if expected == "JSONB" || expected == "JSON" {
316            return true;
317        }
318        
319        // Arrays
320        if expected.contains("[]") || expected.starts_with("_") {
321            return value_type == "ARRAY";
322        }
323        
324        false
325    }
326
327    /// Validate an entire Qail against the schema.
328    pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
329        let mut errors = Vec::new();
330
331        if let Err(e) = self.validate_table(&cmd.table) {
332            errors.push(e);
333        }
334
335        for col in &cmd.columns {
336            if let Some(name) = Self::extract_column_name(col)
337                && let Err(e) = self.validate_column(&cmd.table, &name)
338            {
339                errors.push(e);
340            }
341        }
342
343        for cage in &cmd.cages {
344            for cond in &cage.conditions {
345                if let Some(name) = Self::extract_column_name(&cond.left) {
346                    // For join conditions, column might be qualified (table.column)
347                    if name.contains('.') {
348                        let parts: Vec<&str> = name.split('.').collect();
349                        if parts.len() == 2 {
350                            if let Err(e) = self.validate_column(parts[0], parts[1]) {
351                                errors.push(e);
352                            }
353                            // Type validation for qualified column
354                            if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
355                                errors.push(e);
356                            }
357                        }
358                    } else {
359                        if let Err(e) = self.validate_column(&cmd.table, &name) {
360                            errors.push(e);
361                        }
362                        // Type validation for unqualified column
363                        if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
364                            errors.push(e);
365                        }
366                    }
367                }
368            }
369        }
370
371        for join in &cmd.joins {
372
373            if let Err(e) = self.validate_table(&join.table) {
374                errors.push(e);
375            }
376
377
378            if let Some(conditions) = &join.on {
379                for cond in conditions {
380                    if let Some(name) = Self::extract_column_name(&cond.left)
381                        && name.contains('.')
382                    {
383                        let parts: Vec<&str> = name.split('.').collect();
384                        if parts.len() == 2
385                            && let Err(e) = self.validate_column(parts[0], parts[1])
386                        {
387                            errors.push(e);
388                        }
389                    }
390                    // Also check right side if it's a column reference
391                    if let crate::ast::Value::Column(col_name) = &cond.value
392                        && col_name.contains('.')
393                    {
394                        let parts: Vec<&str> = col_name.split('.').collect();
395                        if parts.len() == 2
396                            && let Err(e) = self.validate_column(parts[0], parts[1])
397                        {
398                            errors.push(e);
399                        }
400                    }
401                }
402            }
403        }
404
405        if let Some(returning) = &cmd.returning {
406            for col in returning {
407                if let Some(name) = Self::extract_column_name(col)
408                    && let Err(e) = self.validate_column(&cmd.table, &name)
409                {
410                    errors.push(e);
411                }
412            }
413        }
414
415        if errors.is_empty() {
416            Ok(())
417        } else {
418            Err(errors)
419        }
420    }
421
422    /// Find the best match with Levenshtein distance within threshold.
423    fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
424        let mut best_match = None;
425        let mut min_dist = usize::MAX;
426
427        for cand in candidates {
428            let cand_str = cand.as_ref();
429            let dist = levenshtein(input, cand_str);
430
431            // Dynamic threshold based on length
432            let threshold = match input.len() {
433                0..=2 => 0, // Precise match only for very short strings
434                3..=5 => 2, // Allow 2 char diff (e.g. usr -> users)
435                _ => 3,     // Allow 3 char diff for longer strings
436            };
437
438            if dist <= threshold && dist < min_dist {
439                min_dist = dist;
440                best_match = Some(cand_str.to_string());
441            }
442        }
443
444        best_match
445    }
446
447    // =========================================================================
448    // Legacy API (for backward compatibility)
449    // =========================================================================
450
451    /// Legacy: validate_table that returns String error
452    #[deprecated(note = "Use validate_table() which returns ValidationError")]
453    pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
454        self.validate_table(table).map_err(|e| e.to_string())
455    }
456
457    /// Legacy: validate_column that returns String error
458    #[deprecated(note = "Use validate_column() which returns ValidationError")]
459    pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
460        self.validate_column(table, column)
461            .map_err(|e| e.to_string())
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// A misspelled table name is rejected with the closest table suggested.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert!(validator.validate_table("users").is_ok());

        for typo in ["usr", "usrs"] {
            match validator.validate_table(typo) {
                Err(ValidationError::TableNotFound { suggestion, .. }) => {
                    assert_eq!(suggestion.as_deref(), Some("users"));
                }
                other => panic!("expected TableNotFound for {:?}, got {:?}", typo, other),
            }
        }
    }

    /// A misspelled column name is rejected with the closest column suggested.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        for known in ["email", "*"] {
            assert!(validator.validate_column("users", known).is_ok());
        }

        match validator.validate_column("users", "emial") {
            Err(ValidationError::ColumnNotFound { suggestion, .. }) => {
                assert_eq!(suggestion.as_deref(), Some("email"));
            }
            other => panic!("expected ColumnNotFound, got {:?}", other),
        }
    }

    /// Qualified `table.column` names are not checked by `validate_column`.
    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        // Qualified names should pass through
        for qualified in ["users.id", "profiles.user_id"] {
            assert!(validator.validate_column("users", qualified).is_ok());
        }
    }

    /// A whole command validates cleanly, and a typo yields exactly one error.
    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let valid = Qail::get("users").columns(["id", "email"]);
        assert!(validator.validate_command(&valid).is_ok());

        let with_typo = Qail::get("users").columns(["id", "emial"]); // typo
        let errors = validator.validate_command(&with_typo).unwrap_err();
        assert_eq!(errors.len(), 1);
        match &errors[0] {
            ValidationError::ColumnNotFound { column, .. } => assert_eq!(column, "emial"),
            other => panic!("expected ColumnNotFound, got {:?}", other),
        }
    }

    /// `Display` renders the messages with the suggestion appended.
    #[test]
    fn test_error_display() {
        let table_err = ValidationError::TableNotFound {
            table: String::from("usrs"),
            suggestion: Some(String::from("users")),
        };
        assert_eq!(
            table_err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let column_err = ValidationError::ColumnNotFound {
            table: String::from("users"),
            column: String::from("emial"),
            suggestion: Some(String::from("email")),
        };
        assert_eq!(
            column_err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}