1use std::collections::{HashMap, HashSet};
12
/// Schema information that validation rules are checked against.
#[derive(Debug, Clone)]
pub struct SchemaContext {
    /// Type definitions, keyed by type name. A type must be present here for any rule
    /// on it to be validated.
    pub types: HashMap<String, TypeDef>,
    /// Field types, keyed by `(type name, field name)`. The validator treats a field as
    /// existing only if its key is present in this map.
    pub fields: HashMap<(String, String), FieldType>,
}
21
/// A named type in the schema.
#[derive(Debug, Clone)]
pub struct TypeDef {
    /// The type's name.
    pub name: String,
    /// Names of the type's fields. The validator lists them, joined with `", "`, when
    /// suggesting alternatives for an unknown field.
    pub fields: Vec<String>,
}
30
/// The declared type of a schema field; decides which fields may be compared with
/// each other (see `is_comparable_with`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FieldType {
    /// Text values.
    String,
    /// Integer values; comparable with `Float`.
    Integer,
    /// Floating-point values; comparable with `Integer`.
    Float,
    /// Boolean values.
    Boolean,
    /// Date values; comparable with `DateTime`.
    Date,
    /// Date-and-time values; comparable with `Date`.
    DateTime,
    /// Any other type, carrying a string payload (presumably the type's name — confirm).
    /// Two `Custom` types are comparable only when their payloads are equal.
    Custom(String),
}
50
51impl FieldType {
52 pub fn is_comparable_with(&self, other: &FieldType) -> bool {
54 match (self, other) {
55 (a, b) if a == b => true,
57 (FieldType::Integer, FieldType::Float)
59 | (FieldType::Float, FieldType::Integer)
60 | (FieldType::Date, FieldType::DateTime)
61 | (FieldType::DateTime, FieldType::Date) => true,
62 _ => false,
64 }
65 }
66}
67
/// Outcome of validating a cross-field rule or an expression.
#[derive(Debug, Clone)]
pub struct CompileTimeValidationResult {
    /// `true` when validation found no errors.
    pub valid: bool,
    /// Problems that make the rule invalid.
    pub errors: Vec<CompileTimeError>,
    /// Non-fatal diagnostics; these do not make the rule invalid.
    pub warnings: Vec<String>,
    /// SQL `CHECK (...)` clause enforcing the rule, when one could be generated.
    pub sql_constraint: Option<String>,
}
80
/// A single validation failure.
#[derive(Debug, Clone)]
pub struct CompileTimeError {
    /// What the error refers to: a type name, a field name, or the whole comparison
    /// formatted as `"left operator right"`.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Optional hint on how to fix the problem.
    pub suggestion: Option<String>,
}
91
/// Validates cross-field rules and Elo expressions against a [`SchemaContext`].
#[derive(Debug)]
pub struct CompileTimeValidator {
    /// Schema that rules are checked against.
    context: SchemaContext,
}
97
98impl CompileTimeValidator {
99 pub const fn new(context: SchemaContext) -> Self {
101 Self { context }
102 }
103
104 pub fn validate_cross_field_rule(
106 &self,
107 type_name: &str,
108 left_field: &str,
109 operator: &str,
110 right_field: &str,
111 ) -> CompileTimeValidationResult {
112 let mut errors = Vec::new();
113 let warnings = Vec::new();
114
115 if !self.context.types.contains_key(type_name) {
117 return CompileTimeValidationResult {
118 valid: false,
119 errors: vec![CompileTimeError {
120 field: type_name.to_string(),
121 message: format!("Type '{}' not found in schema", type_name),
122 suggestion: Some("Check that the type is defined".to_string()),
123 }],
124 warnings,
125 sql_constraint: None,
126 };
127 }
128
129 let left_key = (type_name.to_string(), left_field.to_string());
131 let Some(left_type) = self.context.fields.get(&left_key) else {
132 errors.push(CompileTimeError {
133 field: left_field.to_string(),
134 message: format!("Field '{}' not found in type '{}'", left_field, type_name),
135 suggestion: Some(self.suggest_field(type_name, left_field)),
136 });
137 return CompileTimeValidationResult {
138 valid: false,
139 errors,
140 warnings,
141 sql_constraint: None,
142 };
143 };
144
145 let right_key = (type_name.to_string(), right_field.to_string());
147 let Some(right_type) = self.context.fields.get(&right_key) else {
148 errors.push(CompileTimeError {
149 field: right_field.to_string(),
150 message: format!("Field '{}' not found in type '{}'", right_field, type_name),
151 suggestion: Some(self.suggest_field(type_name, right_field)),
152 });
153 return CompileTimeValidationResult {
154 valid: false,
155 errors,
156 warnings,
157 sql_constraint: None,
158 };
159 };
160
161 if !left_type.is_comparable_with(right_type) {
163 errors.push(CompileTimeError {
164 field: format!("{} {} {}", left_field, operator, right_field),
165 message: format!("Cannot compare {:?} with {:?}", left_type, right_type),
166 suggestion: Some("Ensure both fields have comparable types".to_string()),
167 });
168 return CompileTimeValidationResult {
169 valid: false,
170 errors,
171 warnings,
172 sql_constraint: None,
173 };
174 }
175
176 let sql_constraint = self.generate_sql_constraint(
178 type_name,
179 left_field,
180 operator,
181 right_field,
182 left_type,
183 right_type,
184 );
185
186 CompileTimeValidationResult {
187 valid: true,
188 errors,
189 warnings,
190 sql_constraint,
191 }
192 }
193
194 pub fn validate_elo_expression(
196 &self,
197 type_name: &str,
198 expression: &str,
199 ) -> CompileTimeValidationResult {
200 let mut errors = Vec::new();
201 let warnings = Vec::new();
202
203 if !self.context.types.contains_key(type_name) {
205 return CompileTimeValidationResult {
206 valid: false,
207 errors: vec![CompileTimeError {
208 field: type_name.to_string(),
209 message: format!("Type '{}' not found in schema", type_name),
210 suggestion: None,
211 }],
212 warnings,
213 sql_constraint: None,
214 };
215 }
216
217 let field_refs = self.extract_field_references(expression);
219
220 for field_name in field_refs {
222 let field_key = (type_name.to_string(), field_name.clone());
223 if !self.context.fields.contains_key(&field_key) {
224 errors.push(CompileTimeError {
225 field: field_name.clone(),
226 message: format!("Field '{}' not found in type '{}'", field_name, type_name),
227 suggestion: Some(self.suggest_field(type_name, &field_name)),
228 });
229 }
230 }
231
232 let valid_operators = vec!["<", ">", "<=", ">=", "==", "!=", "&&", "||", "!"];
234 for op in valid_operators {
235 if expression.contains(op) {
236 }
238 }
239
240 CompileTimeValidationResult {
241 valid: errors.is_empty(),
242 errors,
243 warnings,
244 sql_constraint: None,
245 }
246 }
247
248 fn extract_field_references(&self, expression: &str) -> Vec<String> {
250 let mut fields = HashSet::new();
251 let mut tokens = Vec::new();
252 let mut current_token = String::new();
253 let mut in_string = false;
254 let mut string_char = ' ';
255 let mut escape = false;
256
257 for ch in expression.chars() {
259 if escape {
261 escape = false;
262 current_token.push(ch);
263 continue;
264 }
265
266 if ch == '\\' && in_string {
267 escape = true;
268 current_token.push(ch);
269 continue;
270 }
271
272 if !in_string && (ch == '"' || ch == '\'') {
274 in_string = true;
275 string_char = ch;
276 current_token.push(ch);
277 } else if in_string && ch == string_char {
278 in_string = false;
279 current_token.push(ch);
280 } else if !in_string && (ch.is_whitespace() || ch == '(' || ch == ')') {
281 if !current_token.is_empty() {
282 tokens.push(current_token.clone());
283 current_token.clear();
284 }
285 } else {
286 current_token.push(ch);
287 }
288 }
289
290 if !current_token.is_empty() {
291 tokens.push(current_token);
292 }
293
294 let infix_operators = ["matches", "in", "contains"];
296
297 for (i, token) in tokens.iter().enumerate() {
298 if token.starts_with('"') || token.starts_with('\'') {
300 continue;
301 }
302
303 if infix_operators.contains(&token.as_str()) {
305 continue;
306 }
307
308 if i > 0 && infix_operators.contains(&tokens[i - 1].as_str()) {
310 continue;
311 }
312
313 if token == "true"
315 || token == "false"
316 || token == "null"
317 || token == "and"
318 || token == "or"
319 || token == "not"
320 {
321 continue;
322 }
323
324 if token.chars().next().is_some_and(|ch| ch.is_uppercase()) {
326 continue;
327 }
328
329 if token.chars().next().is_some_and(|ch| ch.is_lowercase()) {
331 fields.insert(token.clone());
332 }
333 }
334
335 fields.into_iter().collect()
336 }
337
338 fn generate_sql_constraint(
340 &self,
341 _type_name: &str,
342 left_field: &str,
343 operator: &str,
344 right_field: &str,
345 left_type: &FieldType,
346 _right_type: &FieldType,
347 ) -> Option<String> {
348 let sql_op = match operator {
350 "<" | "lt" => "<",
351 "<=" | "lte" => "<=",
352 ">" | "gt" => ">",
353 ">=" | "gte" => ">=",
354 "==" | "eq" => "=",
355 "!=" | "neq" => "!=",
356 _ => return None,
357 };
358
359 let left_quoted = format!("\"{}\"", left_field.replace('"', "\"\""));
361 let right_quoted = format!("\"{}\"", right_field.replace('"', "\"\""));
362 let constraint = match left_type {
363 FieldType::Date
364 | FieldType::DateTime
365 | FieldType::Integer
366 | FieldType::Float
367 | FieldType::String => {
368 format!("CHECK ({} {} {})", left_quoted, sql_op, right_quoted)
369 },
370 _ => return None,
371 };
372
373 Some(constraint)
374 }
375
376 fn suggest_field(&self, type_name: &str, _attempted_field: &str) -> String {
378 let Some(type_def) = self.context.types.get(type_name) else {
379 return "Check schema definition".to_string();
380 };
381
382 let available = type_def.fields.join(", ");
384 format!("Available fields: {}", available)
385 }
386}
387
#[cfg(test)]
mod tests {
    // Tests may use `unwrap` freely; a failed unwrap is a test failure, not a crash risk.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Schema with a `User` type (String, Integer, Date and Boolean fields) and a
    /// `DateRange` type with two Date fields.
    fn create_test_context() -> SchemaContext {
        let mut types = HashMap::new();
        let mut fields = HashMap::new();

        types.insert(
            "User".to_string(),
            TypeDef {
                name: "User".to_string(),
                fields: vec![
                    "email".to_string(),
                    "age".to_string(),
                    "birthDate".to_string(),
                    "verified".to_string(),
                ],
            },
        );

        fields.insert(("User".to_string(), "email".to_string()), FieldType::String);
        fields.insert(("User".to_string(), "age".to_string()), FieldType::Integer);
        fields.insert(("User".to_string(), "birthDate".to_string()), FieldType::Date);
        fields.insert(("User".to_string(), "verified".to_string()), FieldType::Boolean);

        types.insert(
            "DateRange".to_string(),
            TypeDef {
                name: "DateRange".to_string(),
                fields: vec!["startDate".to_string(), "endDate".to_string()],
            },
        );

        fields.insert(("DateRange".to_string(), "startDate".to_string()), FieldType::Date);
        fields.insert(("DateRange".to_string(), "endDate".to_string()), FieldType::Date);

        SchemaContext { types, fields }
    }

    // ---- cross-field rules ----

    #[test]
    fn test_valid_cross_field_comparison() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("DateRange", "startDate", "<", "endDate");

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.sql_constraint.is_some());
    }

    #[test]
    fn test_cross_field_type_mismatch() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("User", "age", "<", "verified");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
        assert_eq!(result.errors[0].field, "age < verified");
    }

    #[test]
    fn test_cross_field_left_field_not_found() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("User", "nonexistent", "<", "age");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
        assert!(result.errors[0].message.contains("not found"));
    }

    #[test]
    fn test_cross_field_right_field_not_found() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("User", "age", "<", "nonexistent");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
    }

    #[test]
    fn test_cross_field_type_not_found() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("NonexistentType", "field", "<", "field2");

        assert!(!result.valid);
        assert!(result.errors[0].message.contains("not found"));
    }

    // ---- type compatibility ----

    #[test]
    fn test_same_types_compatible() {
        let left = FieldType::Integer;
        let right = FieldType::Integer;
        assert!(left.is_comparable_with(&right));
    }

    #[test]
    fn test_numeric_types_compatible() {
        let left = FieldType::Integer;
        let right = FieldType::Float;
        assert!(left.is_comparable_with(&right));
    }

    #[test]
    fn test_date_datetime_compatible() {
        let left = FieldType::Date;
        let right = FieldType::DateTime;
        assert!(left.is_comparable_with(&right));
    }

    #[test]
    fn test_string_number_incompatible() {
        let left = FieldType::String;
        let right = FieldType::Integer;
        assert!(!left.is_comparable_with(&right));
    }

    #[test]
    fn test_boolean_incompatible_with_numbers() {
        let left = FieldType::Boolean;
        let right = FieldType::Integer;
        assert!(!left.is_comparable_with(&right));
    }

    // ---- SQL generation ----

    #[test]
    fn test_sql_constraint_generated() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        assert!(result.sql_constraint.is_some());
        let sql = result
            .sql_constraint
            .expect("sql_constraint must be Some when result.valid is true");
        assert!(sql.contains("CHECK"));
        assert!(sql.contains("startDate"));
        assert!(sql.contains("<="));
        assert!(sql.contains("endDate"));
    }

    #[test]
    fn test_sql_constraint_with_different_operators() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        // Symbolic and word forms of each operator, paired with the SQL operator they map
        // to. Comparing the full clause catches e.g. `<` being emitted as `<=`, which a
        // substring check would miss.
        let cases = [
            ("<", "<"),
            ("lt", "<"),
            ("<=", "<="),
            ("lte", "<="),
            (">", ">"),
            ("gt", ">"),
            (">=", ">="),
            ("gte", ">="),
            ("==", "="),
            ("eq", "="),
            ("!=", "!="),
            ("neq", "!="),
        ];
        for (op, sql_op) in cases {
            let result =
                validator.validate_cross_field_rule("DateRange", "startDate", op, "endDate");

            assert!(result.valid, "operator {} should be valid", op);
            let sql =
                result.sql_constraint.expect("sql_constraint must be Some for valid operator");
            assert_eq!(sql, format!("CHECK (\"startDate\" {} \"endDate\")", sql_op));
        }
    }

    // ---- Elo expressions ----

    #[test]
    fn test_valid_elo_expression() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_elo_expression("User", "age >= 18 && verified == true");

        assert!(result.valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_unknown_field() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_elo_expression("User", "nonexistent >= 18");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_type_not_found() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_elo_expression("NonexistentType", "age >= 18");

        assert!(!result.valid);
    }

    #[test]
    fn test_elo_field_reference_extraction() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let fields = validator.extract_field_references("age >= 18 && verified == true");

        assert!(fields.contains(&"age".to_string()));
        assert!(fields.contains(&"verified".to_string()));
    }

    #[test]
    fn test_elo_field_extraction_with_strings() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let fields = validator.extract_field_references("email matches \"pattern\" && age > 10");

        assert!(fields.contains(&"email".to_string()));
        assert!(fields.contains(&"age".to_string()));
        // Text inside a string literal is never a field reference.
        assert!(!fields.contains(&"pattern".to_string()));
    }

    // ---- end-to-end scenarios ----

    #[test]
    fn test_date_range_validation() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        let sql = result
            .sql_constraint
            .expect("sql_constraint must be Some when result.valid is true");
        assert!(sql.contains("CHECK"));
    }

    #[test]
    fn test_age_constraint() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_elo_expression("User", "age >= 18 && age <= 120");

        assert!(result.valid);
    }

    #[test]
    fn test_email_field_validation() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_elo_expression(
            "User",
            "email matches \"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}$\"",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_complex_user_validation() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        // `pattern` follows `matches`, so it is treated as an operand, not a field.
        let result = validator.validate_elo_expression(
            "User",
            "email matches pattern && age >= 18 && verified == true",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_suggestion_on_typo() {
        let context = create_test_context();
        let validator = CompileTimeValidator::new(context);

        let result = validator.validate_cross_field_rule("User", "typ0", "<", "age");

        assert!(!result.valid);
        assert!(result.errors[0].suggestion.is_some());
        assert!(
            result.errors[0]
                .suggestion
                .as_ref()
                .expect("suggestion must be Some for typo error")
                .contains("Available fields")
        );
    }
}