Skip to main content

fraiseql_core/validation/
compile_time.rs

1//! Compile-time validation for cross-field rules and schema consistency.
2//!
3//! This module validates Elo expressions at schema compilation time, ensuring:
4//! - Field references exist and are properly typed
5//! - Cross-field rules reference compatible types
6//! - SQL constraints can be generated
7//! - No circular dependencies or invalid rules
8//!
9//! Elo is an expression language by Bernard Lambeau: <https://elo-lang.org/>
10
11use std::collections::{HashMap, HashSet};
12
13/// Schema context for compile-time validation
14#[derive(Debug, Clone)]
15pub struct SchemaContext {
16    /// Type definitions: `type_name` -> fields
17    pub types:  HashMap<String, TypeDef>,
18    /// Field types: (`type_name`, `field_name`) -> `field_type`
19    pub fields: HashMap<(String, String), FieldType>,
20}
21
/// Definition of a single schema type: its name and the fields declared on it.
#[derive(Debug, Clone)]
pub struct TypeDef {
    /// Name of the GraphQL type.
    pub name:   String,
    /// Field names declared on this type, in declaration order.
    pub fields: Vec<String>,
}
30
/// Type of a schema field, as far as cross-field validation is concerned.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FieldType {
    /// UTF-8 text field.
    String,
    /// 64-bit signed integer field.
    Integer,
    /// 64-bit floating-point field.
    Float,
    /// Boolean field.
    Boolean,
    /// Calendar date without time.
    Date,
    /// Timestamp with timezone.
    DateTime,
    /// User-defined scalar type; the inner string holds the type name.
    Custom(String),
}

impl FieldType {
    /// Report whether values of `self` and `other` may be compared with each other.
    ///
    /// Identical types (including `Custom` types with the same name) are always
    /// comparable. Beyond that, only two mixed pairings are accepted, in either order:
    /// `Integer` with `Float`, and `Date` with `DateTime`.
    pub fn is_comparable_with(&self, other: &FieldType) -> bool {
        self == other
            || matches!(
                (self, other),
                (Self::Integer, Self::Float)
                    | (Self::Float, Self::Integer)
                    | (Self::Date, Self::DateTime)
                    | (Self::DateTime, Self::Date)
            )
    }
}
67
68/// Compile-time validation result
69#[derive(Debug, Clone)]
70pub struct CompileTimeValidationResult {
71    /// Whether the rule is valid.
72    pub valid:          bool,
73    /// All errors found during validation.
74    pub errors:         Vec<CompileTimeError>,
75    /// Non-fatal warnings from validation.
76    pub warnings:       Vec<String>,
77    /// SQL constraint expression derived from the rule, if generation succeeded.
78    pub sql_constraint: Option<String>,
79}
80
/// A single problem found while validating a rule at schema compilation time.
#[derive(Debug, Clone)]
pub struct CompileTimeError {
    /// Field path (or type name / rule text) the error is attached to.
    pub field:      String,
    /// Human-readable description of the problem.
    pub message:    String,
    /// Optional hint on how to fix the problem.
    pub suggestion: Option<String>,
}
91
92/// Compile-time validator for cross-field rules
93#[derive(Debug)]
94pub struct CompileTimeValidator {
95    context: SchemaContext,
96}
97
98impl CompileTimeValidator {
99    /// Create a new compile-time validator
100    pub const fn new(context: SchemaContext) -> Self {
101        Self { context }
102    }
103
104    /// Validate a cross-field rule
105    pub fn validate_cross_field_rule(
106        &self,
107        type_name: &str,
108        left_field: &str,
109        operator: &str,
110        right_field: &str,
111    ) -> CompileTimeValidationResult {
112        let mut errors = Vec::new();
113        let warnings = Vec::new();
114
115        // Check if type exists
116        if !self.context.types.contains_key(type_name) {
117            return CompileTimeValidationResult {
118                valid: false,
119                errors: vec![CompileTimeError {
120                    field:      type_name.to_string(),
121                    message:    format!("Type '{}' not found in schema", type_name),
122                    suggestion: Some("Check that the type is defined".to_string()),
123                }],
124                warnings,
125                sql_constraint: None,
126            };
127        }
128
129        // Check if left field exists
130        let left_key = (type_name.to_string(), left_field.to_string());
131        let Some(left_type) = self.context.fields.get(&left_key) else {
132            errors.push(CompileTimeError {
133                field:      left_field.to_string(),
134                message:    format!("Field '{}' not found in type '{}'", left_field, type_name),
135                suggestion: Some(self.suggest_field(type_name, left_field)),
136            });
137            return CompileTimeValidationResult {
138                valid: false,
139                errors,
140                warnings,
141                sql_constraint: None,
142            };
143        };
144
145        // Check if right field exists
146        let right_key = (type_name.to_string(), right_field.to_string());
147        let Some(right_type) = self.context.fields.get(&right_key) else {
148            errors.push(CompileTimeError {
149                field:      right_field.to_string(),
150                message:    format!("Field '{}' not found in type '{}'", right_field, type_name),
151                suggestion: Some(self.suggest_field(type_name, right_field)),
152            });
153            return CompileTimeValidationResult {
154                valid: false,
155                errors,
156                warnings,
157                sql_constraint: None,
158            };
159        };
160
161        // Check if types are comparable
162        if !left_type.is_comparable_with(right_type) {
163            errors.push(CompileTimeError {
164                field:      format!("{} {} {}", left_field, operator, right_field),
165                message:    format!("Cannot compare {:?} with {:?}", left_type, right_type),
166                suggestion: Some("Ensure both fields have comparable types".to_string()),
167            });
168            return CompileTimeValidationResult {
169                valid: false,
170                errors,
171                warnings,
172                sql_constraint: None,
173            };
174        }
175
176        // Generate SQL constraint
177        let sql_constraint = self.generate_sql_constraint(
178            type_name,
179            left_field,
180            operator,
181            right_field,
182            left_type,
183            right_type,
184        );
185
186        CompileTimeValidationResult {
187            valid: true,
188            errors,
189            warnings,
190            sql_constraint,
191        }
192    }
193
194    /// Validate an ELO expression at compile time
195    pub fn validate_elo_expression(
196        &self,
197        type_name: &str,
198        expression: &str,
199    ) -> CompileTimeValidationResult {
200        let mut errors = Vec::new();
201        let warnings = Vec::new();
202
203        // Check if type exists
204        if !self.context.types.contains_key(type_name) {
205            return CompileTimeValidationResult {
206                valid: false,
207                errors: vec![CompileTimeError {
208                    field:      type_name.to_string(),
209                    message:    format!("Type '{}' not found in schema", type_name),
210                    suggestion: None,
211                }],
212                warnings,
213                sql_constraint: None,
214            };
215        }
216
217        // Extract field references from expression
218        let field_refs = self.extract_field_references(expression);
219
220        // Validate each field reference
221        for field_name in field_refs {
222            let field_key = (type_name.to_string(), field_name.clone());
223            if !self.context.fields.contains_key(&field_key) {
224                errors.push(CompileTimeError {
225                    field:      field_name.clone(),
226                    message:    format!("Field '{}' not found in type '{}'", field_name, type_name),
227                    suggestion: Some(self.suggest_field(type_name, &field_name)),
228                });
229            }
230        }
231
232        // Check for valid operators
233        let valid_operators = vec!["<", ">", "<=", ">=", "==", "!=", "&&", "||", "!"];
234        for op in valid_operators {
235            if expression.contains(op) {
236                // Operator found and is valid
237            }
238        }
239
240        CompileTimeValidationResult {
241            valid: errors.is_empty(),
242            errors,
243            warnings,
244            sql_constraint: None,
245        }
246    }
247
248    /// Extract field references from an expression
249    fn extract_field_references(&self, expression: &str) -> Vec<String> {
250        let mut fields = HashSet::new();
251        let mut tokens = Vec::new();
252        let mut current_token = String::new();
253        let mut in_string = false;
254        let mut string_char = ' ';
255        let mut escape = false;
256
257        // First pass: tokenize the expression, respecting quotes
258        for ch in expression.chars() {
259            // Handle escape sequences
260            if escape {
261                escape = false;
262                current_token.push(ch);
263                continue;
264            }
265
266            if ch == '\\' && in_string {
267                escape = true;
268                current_token.push(ch);
269                continue;
270            }
271
272            // Track if we're inside a quoted string
273            if !in_string && (ch == '"' || ch == '\'') {
274                in_string = true;
275                string_char = ch;
276                current_token.push(ch);
277            } else if in_string && ch == string_char {
278                in_string = false;
279                current_token.push(ch);
280            } else if !in_string && (ch.is_whitespace() || ch == '(' || ch == ')') {
281                if !current_token.is_empty() {
282                    tokens.push(current_token.clone());
283                    current_token.clear();
284                }
285            } else {
286                current_token.push(ch);
287            }
288        }
289
290        if !current_token.is_empty() {
291            tokens.push(current_token);
292        }
293
294        // Second pass: extract field references from tokens
295        let infix_operators = ["matches", "in", "contains"];
296
297        for (i, token) in tokens.iter().enumerate() {
298            // Skip quoted strings
299            if token.starts_with('"') || token.starts_with('\'') {
300                continue;
301            }
302
303            // Skip if this token is an infix operator
304            if infix_operators.contains(&token.as_str()) {
305                continue;
306            }
307
308            // Skip if the previous token was an infix operator (it's the RHS of the operator)
309            if i > 0 && infix_operators.contains(&tokens[i - 1].as_str()) {
310                continue;
311            }
312
313            // Skip reserved keywords
314            if token == "true"
315                || token == "false"
316                || token == "null"
317                || token == "and"
318                || token == "or"
319                || token == "not"
320            {
321                continue;
322            }
323
324            // Skip if starts with uppercase (likely type names, not field references)
325            if token.chars().next().is_some_and(|ch| ch.is_uppercase()) {
326                continue;
327            }
328
329            // Extract field references (lowercase identifiers)
330            if token.chars().next().is_some_and(|ch| ch.is_lowercase()) {
331                fields.insert(token.clone());
332            }
333        }
334
335        fields.into_iter().collect()
336    }
337
338    /// Generate SQL constraint from cross-field rule
339    fn generate_sql_constraint(
340        &self,
341        _type_name: &str,
342        left_field: &str,
343        operator: &str,
344        right_field: &str,
345        left_type: &FieldType,
346        _right_type: &FieldType,
347    ) -> Option<String> {
348        // Map ELO operators to SQL operators
349        let sql_op = match operator {
350            "<" | "lt" => "<",
351            "<=" | "lte" => "<=",
352            ">" | "gt" => ">",
353            ">=" | "gte" => ">=",
354            "==" | "eq" => "=",
355            "!=" | "neq" => "!=",
356            _ => return None,
357        };
358
359        // Build constraint based on field type, quoting column names to avoid SQL injection.
360        let left_quoted = format!("\"{}\"", left_field.replace('"', "\"\""));
361        let right_quoted = format!("\"{}\"", right_field.replace('"', "\"\""));
362        let constraint = match left_type {
363            FieldType::Date
364            | FieldType::DateTime
365            | FieldType::Integer
366            | FieldType::Float
367            | FieldType::String => {
368                format!("CHECK ({} {} {})", left_quoted, sql_op, right_quoted)
369            },
370            _ => return None,
371        };
372
373        Some(constraint)
374    }
375
376    /// Suggest a field name if typo is likely
377    fn suggest_field(&self, type_name: &str, _attempted_field: &str) -> String {
378        let Some(type_def) = self.context.types.get(type_name) else {
379            return "Check schema definition".to_string();
380        };
381
382        // Simple suggestion: show available fields
383        let available = type_def.fields.join(", ");
384        format!("Available fields: {}", available)
385    }
386}
387
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use super::*;

    /// Build the schema fixture shared by all tests: a `User` type and a `DateRange` type.
    fn create_test_context() -> SchemaContext {
        let declared = [
            (
                "User",
                vec![
                    ("email", FieldType::String),
                    ("age", FieldType::Integer),
                    ("birthDate", FieldType::Date),
                    ("verified", FieldType::Boolean),
                ],
            ),
            (
                "DateRange",
                vec![("startDate", FieldType::Date), ("endDate", FieldType::Date)],
            ),
        ];

        let mut types = HashMap::new();
        let mut fields = HashMap::new();

        for (type_name, type_fields) in declared {
            types.insert(
                type_name.to_string(),
                TypeDef {
                    name:   type_name.to_string(),
                    fields: type_fields.iter().map(|(name, _)| (*name).to_string()).collect(),
                },
            );
            for (field_name, field_type) in type_fields {
                fields.insert((type_name.to_string(), field_name.to_string()), field_type);
            }
        }

        SchemaContext { types, fields }
    }

    /// Validator wired to the shared schema fixture.
    fn test_validator() -> CompileTimeValidator {
        CompileTimeValidator::new(create_test_context())
    }

    // ========== CROSS-FIELD RULE VALIDATION ==========

    #[test]
    fn test_valid_cross_field_comparison() {
        let result =
            test_validator().validate_cross_field_rule("DateRange", "startDate", "<", "endDate");

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.sql_constraint.is_some());
    }

    #[test]
    fn test_cross_field_type_mismatch() {
        let result = test_validator().validate_cross_field_rule("User", "age", "<", "verified");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
        assert_eq!(result.errors[0].field, "age < verified");
    }

    #[test]
    fn test_cross_field_left_field_not_found() {
        let result = test_validator().validate_cross_field_rule("User", "nonexistent", "<", "age");

        assert!(!result.valid);
        let first = &result.errors[0];
        assert_eq!(first.field, "nonexistent");
        assert!(first.message.contains("not found"));
    }

    #[test]
    fn test_cross_field_right_field_not_found() {
        let result = test_validator().validate_cross_field_rule("User", "age", "<", "nonexistent");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
    }

    #[test]
    fn test_cross_field_type_not_found() {
        let result =
            test_validator().validate_cross_field_rule("NonexistentType", "field", "<", "field2");

        assert!(!result.valid);
        assert!(result.errors[0].message.contains("not found"));
    }

    // ========== TYPE COMPATIBILITY ==========

    #[test]
    fn test_same_types_compatible() {
        assert!(FieldType::Integer.is_comparable_with(&FieldType::Integer));
    }

    #[test]
    fn test_numeric_types_compatible() {
        assert!(FieldType::Integer.is_comparable_with(&FieldType::Float));
    }

    #[test]
    fn test_date_datetime_compatible() {
        assert!(FieldType::Date.is_comparable_with(&FieldType::DateTime));
    }

    #[test]
    fn test_string_number_incompatible() {
        assert!(!FieldType::String.is_comparable_with(&FieldType::Integer));
    }

    #[test]
    fn test_boolean_incompatible_with_numbers() {
        assert!(!FieldType::Boolean.is_comparable_with(&FieldType::Integer));
    }

    // ========== SQL CONSTRAINT GENERATION ==========

    #[test]
    fn test_sql_constraint_generated() {
        let result =
            test_validator().validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        assert!(result.sql_constraint.is_some());
        let sql = result
            .sql_constraint
            .expect("sql_constraint must be Some when result.valid is true");
        for fragment in ["CHECK", "startDate", "<=", "endDate"] {
            assert!(sql.contains(fragment));
        }
    }

    #[test]
    fn test_sql_constraint_with_different_operators() {
        let validator = test_validator();

        for op in ["<", ">", "<=", ">=", "==", "!="] {
            let result =
                validator.validate_cross_field_rule("DateRange", "startDate", op, "endDate");

            assert!(result.valid);
            let sql =
                result.sql_constraint.expect("sql_constraint must be Some for valid operator");
            // `==` is emitted as a single `=` in SQL, so it needs its own check.
            assert!(sql.contains(op) || op == "==" && sql.contains('='));
        }
    }

    // ========== ELO EXPRESSION VALIDATION ==========

    #[test]
    fn test_valid_elo_expression() {
        let result =
            test_validator().validate_elo_expression("User", "age >= 18 && verified == true");

        assert!(result.valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_unknown_field() {
        let result = test_validator().validate_elo_expression("User", "nonexistent >= 18");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_type_not_found() {
        let result = test_validator().validate_elo_expression("NonexistentType", "age >= 18");

        assert!(!result.valid);
    }

    #[test]
    fn test_elo_field_reference_extraction() {
        let fields = test_validator().extract_field_references("age >= 18 && verified == true");

        assert!(fields.contains(&"age".to_string()));
        assert!(fields.contains(&"verified".to_string()));
    }

    #[test]
    fn test_elo_field_extraction_with_strings() {
        let fields =
            test_validator().extract_field_references("email matches \"pattern\" && age > 10");

        assert!(fields.contains(&"email".to_string()));
        assert!(fields.contains(&"age".to_string()));
        assert!(!fields.contains(&"pattern".to_string())); // Inside quotes
    }

    // ========== REAL-WORLD PATTERNS ==========

    #[test]
    fn test_date_range_validation() {
        let result =
            test_validator().validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        let sql = result
            .sql_constraint
            .expect("sql_constraint must be Some when result.valid is true");
        assert!(sql.contains("CHECK"));
    }

    #[test]
    fn test_age_constraint() {
        let result = test_validator().validate_elo_expression("User", "age >= 18 && age <= 120");

        assert!(result.valid);
    }

    #[test]
    fn test_email_field_validation() {
        let result = test_validator().validate_elo_expression(
            "User",
            "email matches \"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}$\"",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_complex_user_validation() {
        let result = test_validator().validate_elo_expression(
            "User",
            "email matches pattern && age >= 18 && verified == true",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_suggestion_on_typo() {
        let result = test_validator().validate_cross_field_rule("User", "typ0", "<", "age");

        assert!(!result.valid);
        let suggestion = result.errors[0].suggestion.as_ref();
        assert!(suggestion.is_some());
        assert!(
            suggestion
                .expect("suggestion must be Some for typo error")
                .contains("Available fields")
        );
    }
}