Skip to main content

fraiseql_core/validation/
compile_time.rs

//! Compile-time validation for cross-field rules and schema consistency.
//!
//! This module validates Elo expressions at schema compilation time, ensuring:
//! - Field references exist and are properly typed
//! - Cross-field rules reference compatible types
//! - SQL constraints can be generated
//! - No circular dependencies or invalid rules
//!
//! Elo is an expression language by Bernard Lambeau: <https://elo-lang.org/>

11use std::collections::{HashMap, HashSet};
12
13/// Schema context for compile-time validation
14#[derive(Debug, Clone)]
15pub struct SchemaContext {
16    /// Type definitions: type_name -> fields
17    pub types:  HashMap<String, TypeDef>,
18    /// Field types: (type_name, field_name) -> field_type
19    pub fields: HashMap<(String, String), FieldType>,
20}
21
/// A named schema type and the names of the fields it lists.
#[derive(Debug, Clone)]
pub struct TypeDef {
    /// Name of the type.
    pub name: String,
    /// Names of the type's fields, kept in the order given.
    pub fields: Vec<String>,
}
28
/// Type of a schema field, as far as compile-time validation is concerned.
///
/// For comparisons (see [`FieldType::is_comparable_with`]):
/// - every type compares with itself;
/// - `Integer` and `Float` compare with each other;
/// - `Date` and `DateTime` compare with each other;
/// - a `Custom` type compares only with a `Custom` carrying the same name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldType {
    String,
    Integer,
    Float,
    Boolean,
    Date,
    DateTime,
    /// A type outside the built-in variants, carried by name.
    Custom(String),
}

impl FieldType {
    /// Returns `true` when a value of type `self` may be compared with a value of type `other`.
    ///
    /// The relation is symmetric. Pairs that are neither identical nor in the same
    /// numeric/temporal group are rejected, e.g. `String` vs `Integer` or `Boolean` vs `Integer`.
    pub fn is_comparable_with(&self, other: &FieldType) -> bool {
        let is_numeric = |t: &FieldType| matches!(t, FieldType::Integer | FieldType::Float);
        let is_temporal = |t: &FieldType| matches!(t, FieldType::Date | FieldType::DateTime);

        self == other
            || (is_numeric(self) && is_numeric(other))
            || (is_temporal(self) && is_temporal(other))
    }
}
58
/// Outcome of validating one cross-field rule or Elo expression.
#[derive(Debug, Clone)]
pub struct CompileTimeValidationResult {
    /// Whether the rule or expression passed validation.
    pub valid: bool,
    /// Problems that make the rule or expression invalid.
    pub errors: Vec<CompileTimeError>,
    /// Non-fatal diagnostics.
    pub warnings: Vec<String>,
    /// SQL `CHECK` constraint derived from the rule, if one was produced.
    pub sql_constraint: Option<String>,
}

/// A single validation problem plus an optional hint for fixing it.
#[derive(Debug, Clone)]
pub struct CompileTimeError {
    /// What the problem is about: a field name, a type name, or the text of a rule.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Optional remediation hint.
    pub suggestion: Option<String>,
}
75
76/// Compile-time validator for cross-field rules
77#[derive(Debug)]
78pub struct CompileTimeValidator {
79    context: SchemaContext,
80}
81
82impl CompileTimeValidator {
83    /// Create a new compile-time validator
84    pub fn new(context: SchemaContext) -> Self {
85        Self { context }
86    }
87
88    /// Validate a cross-field rule
89    pub fn validate_cross_field_rule(
90        &self,
91        type_name: &str,
92        left_field: &str,
93        operator: &str,
94        right_field: &str,
95    ) -> CompileTimeValidationResult {
96        let mut errors = Vec::new();
97        let warnings = Vec::new();
98
99        // Check if type exists
100        if !self.context.types.contains_key(type_name) {
101            return CompileTimeValidationResult {
102                valid: false,
103                errors: vec![CompileTimeError {
104                    field:      type_name.to_string(),
105                    message:    format!("Type '{}' not found in schema", type_name),
106                    suggestion: Some("Check that the type is defined".to_string()),
107                }],
108                warnings,
109                sql_constraint: None,
110            };
111        }
112
113        // Check if left field exists
114        let left_key = (type_name.to_string(), left_field.to_string());
115        let Some(left_type) = self.context.fields.get(&left_key) else {
116            errors.push(CompileTimeError {
117                field:      left_field.to_string(),
118                message:    format!("Field '{}' not found in type '{}'", left_field, type_name),
119                suggestion: Some(self.suggest_field(type_name, left_field)),
120            });
121            return CompileTimeValidationResult {
122                valid: false,
123                errors,
124                warnings,
125                sql_constraint: None,
126            };
127        };
128
129        // Check if right field exists
130        let right_key = (type_name.to_string(), right_field.to_string());
131        let Some(right_type) = self.context.fields.get(&right_key) else {
132            errors.push(CompileTimeError {
133                field:      right_field.to_string(),
134                message:    format!("Field '{}' not found in type '{}'", right_field, type_name),
135                suggestion: Some(self.suggest_field(type_name, right_field)),
136            });
137            return CompileTimeValidationResult {
138                valid: false,
139                errors,
140                warnings,
141                sql_constraint: None,
142            };
143        };
144
145        // Check if types are comparable
146        if !left_type.is_comparable_with(right_type) {
147            errors.push(CompileTimeError {
148                field:      format!("{} {} {}", left_field, operator, right_field),
149                message:    format!("Cannot compare {:?} with {:?}", left_type, right_type),
150                suggestion: Some(format!("Ensure both fields have comparable types")),
151            });
152            return CompileTimeValidationResult {
153                valid: false,
154                errors,
155                warnings,
156                sql_constraint: None,
157            };
158        }
159
160        // Generate SQL constraint
161        let sql_constraint = self.generate_sql_constraint(
162            type_name,
163            left_field,
164            operator,
165            right_field,
166            left_type,
167            right_type,
168        );
169
170        CompileTimeValidationResult {
171            valid: true,
172            errors,
173            warnings,
174            sql_constraint,
175        }
176    }
177
178    /// Validate an ELO expression at compile time
179    pub fn validate_elo_expression(
180        &self,
181        type_name: &str,
182        expression: &str,
183    ) -> CompileTimeValidationResult {
184        let mut errors = Vec::new();
185        let warnings = Vec::new();
186
187        // Check if type exists
188        if !self.context.types.contains_key(type_name) {
189            return CompileTimeValidationResult {
190                valid: false,
191                errors: vec![CompileTimeError {
192                    field:      type_name.to_string(),
193                    message:    format!("Type '{}' not found in schema", type_name),
194                    suggestion: None,
195                }],
196                warnings,
197                sql_constraint: None,
198            };
199        }
200
201        // Extract field references from expression
202        let field_refs = self.extract_field_references(expression);
203
204        // Validate each field reference
205        for field_name in field_refs {
206            let field_key = (type_name.to_string(), field_name.clone());
207            if !self.context.fields.contains_key(&field_key) {
208                errors.push(CompileTimeError {
209                    field:      field_name.clone(),
210                    message:    format!("Field '{}' not found in type '{}'", field_name, type_name),
211                    suggestion: Some(self.suggest_field(type_name, &field_name)),
212                });
213            }
214        }
215
216        // Check for valid operators
217        let valid_operators = vec!["<", ">", "<=", ">=", "==", "!=", "&&", "||", "!"];
218        for op in valid_operators {
219            if expression.contains(op) {
220                // Operator found and is valid
221            }
222        }
223
224        CompileTimeValidationResult {
225            valid: errors.is_empty(),
226            errors,
227            warnings,
228            sql_constraint: None,
229        }
230    }
231
232    /// Extract field references from an expression
233    fn extract_field_references(&self, expression: &str) -> Vec<String> {
234        let mut fields = HashSet::new();
235        let mut tokens = Vec::new();
236        let mut current_token = String::new();
237        let mut in_string = false;
238        let mut string_char = ' ';
239        let mut escape = false;
240
241        // First pass: tokenize the expression, respecting quotes
242        for ch in expression.chars() {
243            // Handle escape sequences
244            if escape {
245                escape = false;
246                current_token.push(ch);
247                continue;
248            }
249
250            if ch == '\\' && in_string {
251                escape = true;
252                current_token.push(ch);
253                continue;
254            }
255
256            // Track if we're inside a quoted string
257            if !in_string && (ch == '"' || ch == '\'') {
258                in_string = true;
259                string_char = ch;
260                current_token.push(ch);
261            } else if in_string && ch == string_char {
262                in_string = false;
263                current_token.push(ch);
264            } else if !in_string && (ch.is_whitespace() || ch == '(' || ch == ')') {
265                if !current_token.is_empty() {
266                    tokens.push(current_token.clone());
267                    current_token.clear();
268                }
269            } else {
270                current_token.push(ch);
271            }
272        }
273
274        if !current_token.is_empty() {
275            tokens.push(current_token);
276        }
277
278        // Second pass: extract field references from tokens
279        let infix_operators = ["matches", "in", "contains"];
280
281        for (i, token) in tokens.iter().enumerate() {
282            // Skip quoted strings
283            if token.starts_with('"') || token.starts_with('\'') {
284                continue;
285            }
286
287            // Skip if this token is an infix operator
288            if infix_operators.contains(&token.as_str()) {
289                continue;
290            }
291
292            // Skip if the previous token was an infix operator (it's the RHS of the operator)
293            if i > 0 && infix_operators.contains(&tokens[i - 1].as_str()) {
294                continue;
295            }
296
297            // Skip reserved keywords
298            if token == "true"
299                || token == "false"
300                || token == "null"
301                || token == "and"
302                || token == "or"
303                || token == "not"
304            {
305                continue;
306            }
307
308            // Skip if starts with uppercase (likely type names, not field references)
309            if token.chars().next().is_some_and(|ch| ch.is_uppercase()) {
310                continue;
311            }
312
313            // Extract field references (lowercase identifiers)
314            if token.chars().next().is_some_and(|ch| ch.is_lowercase()) {
315                fields.insert(token.clone());
316            }
317        }
318
319        fields.into_iter().collect()
320    }
321
322    /// Generate SQL constraint from cross-field rule
323    fn generate_sql_constraint(
324        &self,
325        _type_name: &str,
326        left_field: &str,
327        operator: &str,
328        right_field: &str,
329        left_type: &FieldType,
330        _right_type: &FieldType,
331    ) -> Option<String> {
332        // Map ELO operators to SQL operators
333        let sql_op = match operator {
334            "<" | "lt" => "<",
335            "<=" | "lte" => "<=",
336            ">" | "gt" => ">",
337            ">=" | "gte" => ">=",
338            "==" | "eq" => "=",
339            "!=" | "neq" => "!=",
340            _ => return None,
341        };
342
343        // Build constraint based on field type
344        let constraint = match left_type {
345            FieldType::Date | FieldType::DateTime => {
346                format!("CHECK ({} {} {})", left_field, sql_op, right_field)
347            },
348            FieldType::Integer | FieldType::Float => {
349                format!("CHECK ({} {} {})", left_field, sql_op, right_field)
350            },
351            FieldType::String => {
352                format!("CHECK ({} {} {})", left_field, sql_op, right_field)
353            },
354            _ => return None,
355        };
356
357        Some(constraint)
358    }
359
360    /// Suggest a field name if typo is likely
361    fn suggest_field(&self, type_name: &str, _attempted_field: &str) -> String {
362        let Some(type_def) = self.context.types.get(type_name) else {
363            return "Check schema definition".to_string();
364        };
365
366        // Simple suggestion: show available fields
367        let available = type_def.fields.join(", ");
368        format!("Available fields: {}", available)
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Register `type_name` in both maps of `context`, listing its fields in the given order.
    fn add_type(context: &mut SchemaContext, type_name: &str, type_fields: &[(&str, FieldType)]) {
        context.types.insert(
            type_name.to_string(),
            TypeDef {
                name:   type_name.to_string(),
                fields: type_fields.iter().map(|(name, _)| (*name).to_string()).collect(),
            },
        );
        for (name, field_type) in type_fields {
            context
                .fields
                .insert((type_name.to_string(), (*name).to_string()), field_type.clone());
        }
    }

    /// Schema with `User` (email, age, birthDate, verified) and `DateRange` (startDate, endDate).
    fn create_test_context() -> SchemaContext {
        let mut context = SchemaContext {
            types:  HashMap::new(),
            fields: HashMap::new(),
        };
        add_type(
            &mut context,
            "User",
            &[
                ("email", FieldType::String),
                ("age", FieldType::Integer),
                ("birthDate", FieldType::Date),
                ("verified", FieldType::Boolean),
            ],
        );
        add_type(
            &mut context,
            "DateRange",
            &[("startDate", FieldType::Date), ("endDate", FieldType::Date)],
        );
        context
    }

    fn validator() -> CompileTimeValidator {
        CompileTimeValidator::new(create_test_context())
    }

    // ========== CROSS-FIELD RULE VALIDATION ==========

    #[test]
    fn test_valid_cross_field_comparison() {
        let result = validator().validate_cross_field_rule("DateRange", "startDate", "<", "endDate");

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.sql_constraint.is_some());
    }

    #[test]
    fn test_cross_field_type_mismatch() {
        let result = validator().validate_cross_field_rule("User", "age", "<", "verified");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
        assert_eq!(result.errors[0].field, "age < verified");
    }

    #[test]
    fn test_cross_field_left_field_not_found() {
        let result = validator().validate_cross_field_rule("User", "nonexistent", "<", "age");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
        assert!(result.errors[0].message.contains("not found"));
    }

    #[test]
    fn test_cross_field_right_field_not_found() {
        let result = validator().validate_cross_field_rule("User", "age", "<", "nonexistent");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
    }

    #[test]
    fn test_cross_field_type_not_found() {
        let result =
            validator().validate_cross_field_rule("NonexistentType", "field", "<", "field2");

        assert!(!result.valid);
        assert!(result.errors[0].message.contains("not found"));
    }

    #[test]
    fn test_boolean_rule_is_valid_without_sql_constraint() {
        let result = validator().validate_cross_field_rule("User", "verified", "==", "verified");

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.sql_constraint.is_none());
    }

    // ========== TYPE COMPATIBILITY ==========

    #[test]
    fn test_same_types_compatible() {
        assert!(FieldType::Integer.is_comparable_with(&FieldType::Integer));
    }

    #[test]
    fn test_numeric_types_compatible() {
        assert!(FieldType::Integer.is_comparable_with(&FieldType::Float));
    }

    #[test]
    fn test_date_datetime_compatible() {
        assert!(FieldType::Date.is_comparable_with(&FieldType::DateTime));
    }

    #[test]
    fn test_string_number_incompatible() {
        assert!(!FieldType::String.is_comparable_with(&FieldType::Integer));
    }

    #[test]
    fn test_boolean_incompatible_with_numbers() {
        assert!(!FieldType::Boolean.is_comparable_with(&FieldType::Integer));
    }

    #[test]
    fn test_custom_types_compatible_only_by_name() {
        let money = FieldType::Custom("Money".to_string());

        assert!(money.is_comparable_with(&FieldType::Custom("Money".to_string())));
        assert!(!money.is_comparable_with(&FieldType::Custom("Euro".to_string())));
    }

    // ========== SQL CONSTRAINT GENERATION ==========

    #[test]
    fn test_sql_constraint_generated() {
        let result =
            validator().validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        let sql = result.sql_constraint.expect("a Date rule should produce a constraint");
        assert!(sql.contains("CHECK"));
        assert!(sql.contains("startDate"));
        assert!(sql.contains("<="));
        assert!(sql.contains("endDate"));
    }

    #[test]
    fn test_sql_constraint_with_different_operators() {
        let validator = validator();
        // Exact output per operator: a substring check such as `contains("=")` would also
        // accept `<=`, `>=` or `!=` for the `==` case.
        let cases = [
            ("<", "<"),
            (">", ">"),
            ("<=", "<="),
            (">=", ">="),
            ("==", "="),
            ("!=", "!="),
            ("lt", "<"),
            ("gt", ">"),
            ("lte", "<="),
            ("gte", ">="),
            ("eq", "="),
            ("neq", "!="),
        ];
        for (op, sql_op) in cases {
            let result =
                validator.validate_cross_field_rule("DateRange", "startDate", op, "endDate");

            assert!(result.valid, "operator '{}' should be accepted", op);
            assert_eq!(
                result.sql_constraint,
                Some(format!("CHECK (startDate {} endDate)", sql_op)),
                "operator '{}'",
                op
            );
        }
    }

    // ========== ELO EXPRESSION VALIDATION ==========

    #[test]
    fn test_valid_elo_expression() {
        let result = validator().validate_elo_expression("User", "age >= 18 && verified == true");

        assert!(result.valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_unknown_field() {
        let result = validator().validate_elo_expression("User", "nonexistent >= 18");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_reports_every_unknown_field() {
        let result = validator().validate_elo_expression("User", "foo > 1 && bar < 2");

        assert!(!result.valid);
        let mut unknown: Vec<&str> = result.errors.iter().map(|e| e.field.as_str()).collect();
        unknown.sort_unstable();
        assert_eq!(unknown, ["bar", "foo"]);
    }

    #[test]
    fn test_elo_expression_type_not_found() {
        let result = validator().validate_elo_expression("NonexistentType", "age >= 18");

        assert!(!result.valid);
    }

    #[test]
    fn test_elo_field_reference_extraction() {
        let fields = validator().extract_field_references("age >= 18 && verified == true");

        assert_eq!(fields.len(), 2);
        assert!(fields.contains(&"age".to_string()));
        assert!(fields.contains(&"verified".to_string()));
    }

    #[test]
    fn test_elo_field_extraction_with_strings() {
        let fields =
            validator().extract_field_references("email matches \"pattern\" && age > 10");

        assert_eq!(fields.len(), 2);
        assert!(fields.contains(&"email".to_string()));
        assert!(fields.contains(&"age".to_string()));
        assert!(!fields.contains(&"pattern".to_string())); // Inside quotes
    }

    #[test]
    fn test_elo_field_extraction_with_single_quoted_strings() {
        let fields = validator().extract_field_references("email matches 'pattern' && age > 10");

        assert!(fields.contains(&"email".to_string()));
        assert!(fields.contains(&"age".to_string()));
        assert!(!fields.contains(&"pattern".to_string()));
    }

    #[test]
    fn test_elo_field_extraction_skips_keywords() {
        let mut fields =
            validator().extract_field_references("age > 1 and verified or not email == null");
        fields.sort_unstable();

        assert_eq!(fields, ["age", "email", "verified"]);
    }

    // ========== REAL-WORLD PATTERNS ==========

    #[test]
    fn test_date_range_validation() {
        let result =
            validator().validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        let sql = result.sql_constraint.expect("a Date rule should produce a constraint");
        assert!(sql.contains("CHECK"));
    }

    #[test]
    fn test_age_constraint() {
        let result = validator().validate_elo_expression("User", "age >= 18 && age <= 120");

        assert!(result.valid);
    }

    #[test]
    fn test_email_field_validation() {
        let result = validator().validate_elo_expression(
            "User",
            "email matches \"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}$\"",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_complex_user_validation() {
        let result = validator().validate_elo_expression(
            "User",
            "email matches pattern && age >= 18 && verified == true",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_suggestion_on_typo() {
        let result = validator().validate_cross_field_rule("User", "typ0", "<", "age");

        assert!(!result.valid);
        let suggestion = result.errors[0].suggestion.as_ref().expect("a suggestion is expected");
        assert!(suggestion.contains("Available fields"));
    }
}