Skip to main content

qail_core/
validator.rs

//! Schema validator and fuzzy-matching suggestions.
//!
//! Provides compile-time-like validation of Qail commands against a known schema.
//! Used by the CLI, the LSP, and the encoder to catch errors before they hit the wire.
5
6use crate::ast::{Expr, Qail};
7use std::collections::HashMap;
8use strsim::levenshtein;
9
/// Validation error with structured information.
///
/// Every variant carries only owned strings, so the type supports full
/// equality (`Eq`) in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The referenced table is not registered in the schema.
    TableNotFound {
        /// Table name as written in the query.
        table: String,
        /// Closest registered table name, if one is similar enough.
        suggestion: Option<String>,
    },
    /// The referenced column is not registered for the given table.
    ColumnNotFound {
        /// Table the column was looked up in.
        table: String,
        /// Column name as written in the query.
        column: String,
        /// Closest registered column name, if one is similar enough.
        suggestion: Option<String>,
    },
    /// Type mismatch (future: when schema includes types)
    TypeMismatch {
        /// Table owning the column.
        table: String,
        /// Column whose declared type was violated.
        column: String,
        /// Declared column type.
        expected: String,
        /// Type name of the value that was supplied.
        got: String,
    },
    /// Invalid operator for column type (future)
    InvalidOperator {
        /// Column the operator was applied to.
        column: String,
        /// Operator as written in the query.
        operator: String,
        /// Human-readable explanation of why the operator is invalid.
        reason: String,
    },
}
36
37impl std::fmt::Display for ValidationError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            ValidationError::TableNotFound { table, suggestion } => {
41                if let Some(s) = suggestion {
42                    write!(f, "Table '{}' not found. Did you mean '{}'?", table, s)
43                } else {
44                    write!(f, "Table '{}' not found.", table)
45                }
46            }
47            ValidationError::ColumnNotFound {
48                table,
49                column,
50                suggestion,
51            } => {
52                if let Some(s) = suggestion {
53                    write!(
54                        f,
55                        "Column '{}' not found in table '{}'. Did you mean '{}'?",
56                        column, table, s
57                    )
58                } else {
59                    write!(f, "Column '{}' not found in table '{}'.", column, table)
60                }
61            }
62            ValidationError::TypeMismatch {
63                table,
64                column,
65                expected,
66                got,
67            } => {
68                write!(
69                    f,
70                    "Type mismatch for '{}.{}': expected {}, got {}",
71                    table, column, expected, got
72                )
73            }
74            ValidationError::InvalidOperator {
75                column,
76                operator,
77                reason,
78            } => {
79                write!(
80                    f,
81                    "Invalid operator '{}' for column '{}': {}",
82                    operator, column, reason
83                )
84            }
85        }
86    }
87}
88
89impl std::error::Error for ValidationError {}
90
91/// Result of validation
92pub type ValidationResult = Result<(), Vec<ValidationError>>;
93
/// Validates query elements against known schema and provides suggestions.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Registered table names, in registration order.
    tables: Vec<String>,
    /// Column names per table, keyed by table name.
    columns: HashMap<String, Vec<String>>,
    /// Declared column types per table: table name -> (column name -> type name).
    /// Only populated for tables registered together with type information.
    column_types: HashMap<String, HashMap<String, String>>,
}
102
103impl Default for Validator {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl Validator {
110    /// Create a new Validator with known tables and columns.
111    pub fn new() -> Self {
112        Self {
113            tables: Vec::new(),
114            columns: HashMap::new(),
115            column_types: HashMap::new(),
116        }
117    }
118
119    /// Register a table and its columns.
120    pub fn add_table(&mut self, table: &str, cols: &[&str]) {
121        self.tables.push(table.to_string());
122        self.columns.insert(
123            table.to_string(),
124            cols.iter().map(|s| s.to_string()).collect(),
125        );
126    }
127
128    /// Register a table with column types (for future type validation).
129    pub fn add_table_with_types(&mut self, table: &str, cols: &[(&str, &str)]) {
130        self.tables.push(table.to_string());
131        let col_names: Vec<String> = cols.iter().map(|(name, _)| name.to_string()).collect();
132        self.columns.insert(table.to_string(), col_names);
133
134        let type_map: HashMap<String, String> = cols
135            .iter()
136            .map(|(name, typ)| (name.to_string(), typ.to_string()))
137            .collect();
138        self.column_types.insert(table.to_string(), type_map);
139    }
140
141    /// Get list of all table names (for autocomplete).
142    pub fn table_names(&self) -> &[String] {
143        &self.tables
144    }
145
146    /// Get column names for a table (for autocomplete).
147    pub fn column_names(&self, table: &str) -> Option<&Vec<String>> {
148        self.columns.get(table)
149    }
150
151    /// Check if a table exists.
152    pub fn table_exists(&self, table: &str) -> bool {
153        self.tables.contains(&table.to_string())
154    }
155
156    /// Check if a table exists. If not, returns structured error with suggestion.
157    pub fn validate_table(&self, table: &str) -> Result<(), ValidationError> {
158        if self.tables.contains(&table.to_string()) {
159            Ok(())
160        } else {
161            let suggestion = self.did_you_mean(table, &self.tables);
162            Err(ValidationError::TableNotFound {
163                table: table.to_string(),
164                suggestion,
165            })
166        }
167    }
168
169    /// Check if a column exists in a table. If not, returns structured error.
170    pub fn validate_column(&self, table: &str, column: &str) -> Result<(), ValidationError> {
171        // If table doesn't exist, skip column validation (table error takes precedence)
172        if !self.tables.contains(&table.to_string()) {
173            return Ok(());
174        }
175
176        // Always allow * and qualified names like "table.column"
177        if column == "*" || column.contains('.') {
178            return Ok(());
179        }
180
181        if let Some(cols) = self.columns.get(table) {
182            if cols.contains(&column.to_string()) {
183                Ok(())
184            } else {
185                let suggestion = self.did_you_mean(column, cols);
186                Err(ValidationError::ColumnNotFound {
187                    table: table.to_string(),
188                    column: column.to_string(),
189                    suggestion,
190                })
191            }
192        } else {
193            Ok(())
194        }
195    }
196
197    /// Extract column name from an Expr for validation.
198    fn extract_column_name(expr: &Expr) -> Option<String> {
199        match expr {
200            Expr::Named(name) => Some(name.clone()),
201            Expr::Aliased { name, .. } => Some(name.clone()),
202            Expr::Aggregate { col, .. } => Some(col.clone()),
203            Expr::Cast { expr, .. } => Self::extract_column_name(expr),
204            Expr::JsonAccess { column, .. } => Some(column.clone()),
205            _ => None,
206        }
207    }
208    
209    /// Get column type for a table.column
210    pub fn get_column_type(&self, table: &str, column: &str) -> Option<&String> {
211        self.column_types.get(table)?.get(column)
212    }
213    
214    /// Validate that a Value's type matches the expected column type.
215    /// Returns Ok(()) if compatible, Err with TypeMismatch if not.
216    pub fn validate_value_type(
217        &self,
218        table: &str,
219        column: &str,
220        value: &crate::ast::Value,
221    ) -> Result<(), ValidationError> {
222        use crate::ast::Value;
223        
224        // Get the expected type for this column
225        let expected_type = match self.get_column_type(table, column) {
226            Some(t) => t.to_uppercase(),
227            None => return Ok(()), // Column type unknown, skip validation
228        };
229        
230        // NULL is always allowed (for nullable columns)
231        if matches!(value, Value::Null | Value::NullUuid) {
232            return Ok(());
233        }
234        
235        // Param and NamedParam are runtime values - can't validate at compile time
236        if matches!(value, Value::Param(_) | Value::NamedParam(_) | Value::Function(_) | Value::Subquery(_) | Value::Expr(_)) {
237            return Ok(());
238        }
239        
240        // Map Value variant to its type name
241        let value_type = match value {
242            Value::Bool(_) => "BOOLEAN",
243            Value::Int(_) => "INT",
244            Value::Float(_) => "FLOAT",
245            Value::String(_) => "TEXT",
246            Value::Uuid(_) => "UUID",
247            Value::Array(_) => "ARRAY",
248            Value::Column(_) => return Ok(()), // Column reference, type checked elsewhere
249            Value::Interval { .. } => "INTERVAL",
250            Value::Timestamp(_) => "TIMESTAMP",
251            Value::Bytes(_) => "BYTEA",
252            Value::Vector(_) => "VECTOR",
253            Value::Json(_) => "JSONB",
254            _ => return Ok(()), // Unknown value type, skip
255        };
256        
257        // Check compatibility
258        if !Self::types_compatible(&expected_type, value_type) {
259            return Err(ValidationError::TypeMismatch {
260                table: table.to_string(),
261                column: column.to_string(),
262                expected: expected_type,
263                got: value_type.to_string(),
264            });
265        }
266        
267        Ok(())
268    }
269    
270    /// Check if expected column type is compatible with value type.
271    /// Allows flexible matching (e.g., INT matches INT4, BIGINT, INTEGER, etc.)
272    fn types_compatible(expected: &str, value_type: &str) -> bool {
273        let expected = expected.to_uppercase();
274        let value_type = value_type.to_uppercase();
275        
276        // Exact match
277        if expected == value_type {
278            return true;
279        }
280        
281        // Integer family
282        let int_types = ["INT", "INT4", "INT8", "INTEGER", "BIGINT", "SMALLINT", "SERIAL", "BIGSERIAL"];
283        if int_types.contains(&expected.as_str()) && value_type == "INT" {
284            return true;
285        }
286        
287        // Float family
288        let float_types = ["FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DECIMAL", "NUMERIC", "REAL"];
289        if float_types.contains(&expected.as_str()) && (value_type == "FLOAT" || value_type == "INT") {
290            return true;
291        }
292        
293        // Text family - TEXT is very flexible in PostgreSQL
294        let text_types = ["TEXT", "VARCHAR", "CHAR", "CHARACTER", "CITEXT", "NAME"];
295        if text_types.contains(&expected.as_str()) && value_type == "TEXT" {
296            return true;
297        }
298        
299        // Boolean
300        if (expected == "BOOLEAN" || expected == "BOOL") && value_type == "BOOLEAN" {
301            return true;
302        }
303        
304        // UUID
305        if expected == "UUID" && (value_type == "UUID" || value_type == "TEXT") {
306            return true;
307        }
308        
309        // Timestamp family
310        let ts_types = ["TIMESTAMP", "TIMESTAMPTZ", "DATE", "TIME", "TIMETZ"];
311        if ts_types.contains(&expected.as_str()) && (value_type == "TIMESTAMP" || value_type == "TEXT") {
312            return true;
313        }
314        
315        // JSONB/JSON accepts almost anything
316        if expected == "JSONB" || expected == "JSON" {
317            return true;
318        }
319        
320        // Arrays
321        if expected.contains("[]") || expected.starts_with("_") {
322            return value_type == "ARRAY";
323        }
324        
325        false
326    }
327
328    /// Validate an entire Qail against the schema.
329    pub fn validate_command(&self, cmd: &Qail) -> ValidationResult {
330        let mut errors = Vec::new();
331
332        if let Err(e) = self.validate_table(&cmd.table) {
333            errors.push(e);
334        }
335
336        for col in &cmd.columns {
337            if let Some(name) = Self::extract_column_name(col)
338                && let Err(e) = self.validate_column(&cmd.table, &name)
339            {
340                errors.push(e);
341            }
342        }
343
344        for cage in &cmd.cages {
345            for cond in &cage.conditions {
346                if let Some(name) = Self::extract_column_name(&cond.left) {
347                    // For join conditions, column might be qualified (table.column)
348                    if name.contains('.') {
349                        let parts: Vec<&str> = name.split('.').collect();
350                        if parts.len() == 2 {
351                            if let Err(e) = self.validate_column(parts[0], parts[1]) {
352                                errors.push(e);
353                            }
354                            // Type validation for qualified column
355                            if let Err(e) = self.validate_value_type(parts[0], parts[1], &cond.value) {
356                                errors.push(e);
357                            }
358                        }
359                    } else {
360                        if let Err(e) = self.validate_column(&cmd.table, &name) {
361                            errors.push(e);
362                        }
363                        // Type validation for unqualified column
364                        if let Err(e) = self.validate_value_type(&cmd.table, &name, &cond.value) {
365                            errors.push(e);
366                        }
367                    }
368                }
369            }
370        }
371
372        for join in &cmd.joins {
373
374            if let Err(e) = self.validate_table(&join.table) {
375                errors.push(e);
376            }
377
378
379            if let Some(conditions) = &join.on {
380                for cond in conditions {
381                    if let Some(name) = Self::extract_column_name(&cond.left)
382                        && name.contains('.')
383                    {
384                        let parts: Vec<&str> = name.split('.').collect();
385                        if parts.len() == 2
386                            && let Err(e) = self.validate_column(parts[0], parts[1])
387                        {
388                            errors.push(e);
389                        }
390                    }
391                    // Also check right side if it's a column reference
392                    if let crate::ast::Value::Column(col_name) = &cond.value
393                        && col_name.contains('.')
394                    {
395                        let parts: Vec<&str> = col_name.split('.').collect();
396                        if parts.len() == 2
397                            && let Err(e) = self.validate_column(parts[0], parts[1])
398                        {
399                            errors.push(e);
400                        }
401                    }
402                }
403            }
404        }
405
406        if let Some(returning) = &cmd.returning {
407            for col in returning {
408                if let Some(name) = Self::extract_column_name(col)
409                    && let Err(e) = self.validate_column(&cmd.table, &name)
410                {
411                    errors.push(e);
412                }
413            }
414        }
415
416        if errors.is_empty() {
417            Ok(())
418        } else {
419            Err(errors)
420        }
421    }
422
423    /// Find the best match with Levenshtein distance within threshold.
424    fn did_you_mean(&self, input: &str, candidates: &[impl AsRef<str>]) -> Option<String> {
425        let mut best_match = None;
426        let mut min_dist = usize::MAX;
427
428        for cand in candidates {
429            let cand_str = cand.as_ref();
430            let dist = levenshtein(input, cand_str);
431
432            // Dynamic threshold based on length
433            let threshold = match input.len() {
434                0..=2 => 0, // Precise match only for very short strings
435                3..=5 => 2, // Allow 2 char diff (e.g. usr -> users)
436                _ => 3,     // Allow 3 char diff for longer strings
437            };
438
439            if dist <= threshold && dist < min_dist {
440                min_dist = dist;
441                best_match = Some(cand_str.to_string());
442            }
443        }
444
445        best_match
446    }
447
448    // =========================================================================
449    // Legacy API (for backward compatibility)
450    // =========================================================================
451
452    /// Legacy: validate_table that returns String error
453    #[deprecated(note = "Use validate_table() which returns ValidationError")]
454    pub fn validate_table_legacy(&self, table: &str) -> Result<(), String> {
455        self.validate_table(table).map_err(|e| e.to_string())
456    }
457
458    /// Legacy: validate_column that returns String error
459    #[deprecated(note = "Use validate_column() which returns ValidationError")]
460    pub fn validate_column_legacy(&self, table: &str, column: &str) -> Result<(), String> {
461        self.validate_column(table, column)
462            .map_err(|e| e.to_string())
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// A misspelled table name yields a `TableNotFound` carrying the closest match.
    #[test]
    fn test_did_you_mean_table() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("orders", &["id", "total"]);

        assert!(validator.validate_table("users").is_ok());

        for typo in ["usr", "usrs"] {
            let Err(ValidationError::TableNotFound { suggestion, .. }) =
                validator.validate_table(typo)
            else {
                panic!("expected TableNotFound for {typo:?}");
            };
            assert_eq!(suggestion.as_deref(), Some("users"));
        }
    }

    /// A misspelled column name yields a `ColumnNotFound` carrying the closest match.
    #[test]
    fn test_did_you_mean_column() {
        let mut validator = Validator::new();
        validator.add_table("users", &["email", "password"]);

        assert!(validator.validate_column("users", "email").is_ok());
        assert!(validator.validate_column("users", "*").is_ok());

        let Err(ValidationError::ColumnNotFound { suggestion, .. }) =
            validator.validate_column("users", "emial")
        else {
            panic!("expected ColumnNotFound for \"emial\"");
        };
        assert_eq!(suggestion.as_deref(), Some("email"));
    }

    /// Qualified `table.column` names are passed through by `validate_column`.
    #[test]
    fn test_qualified_column_name() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "name"]);
        validator.add_table("profiles", &["user_id", "avatar"]);

        for qualified in ["users.id", "profiles.user_id"] {
            assert!(validator.validate_column("users", qualified).is_ok());
        }
    }

    /// `validate_command` accepts known columns and reports exactly one error for a typo.
    #[test]
    fn test_validate_command() {
        let mut validator = Validator::new();
        validator.add_table("users", &["id", "email", "name"]);

        let valid = Qail::get("users").columns(["id", "email"]);
        assert!(validator.validate_command(&valid).is_ok());

        let with_typo = Qail::get("users").columns(["id", "emial"]); // typo
        let errors = validator.validate_command(&with_typo).unwrap_err();
        match errors.as_slice() {
            [ValidationError::ColumnNotFound { column, .. }] => assert_eq!(column, "emial"),
            other => panic!("expected a single ColumnNotFound, got {other:?}"),
        }
    }

    /// `Display` output for the "not found" variants includes the suggestion.
    #[test]
    fn test_error_display() {
        let table_err = ValidationError::TableNotFound {
            table: String::from("usrs"),
            suggestion: Some(String::from("users")),
        };
        assert_eq!(
            table_err.to_string(),
            "Table 'usrs' not found. Did you mean 'users'?"
        );

        let column_err = ValidationError::ColumnNotFound {
            table: String::from("users"),
            column: String::from("emial"),
            suggestion: Some(String::from("email")),
        };
        assert_eq!(
            column_err.to_string(),
            "Column 'emial' not found in table 'users'. Did you mean 'email'?"
        );
    }
}