1use std::collections::{HashMap, HashSet};
12
13#[derive(Debug, Clone)]
15pub struct SchemaContext {
16 pub types: HashMap<String, TypeDef>,
18 pub fields: HashMap<(String, String), FieldType>,
20}
21
/// A named type together with the names of its fields.
#[derive(Debug, Clone)]
pub struct TypeDef {
    /// Name of the type.
    pub name: String,
    /// Names of the fields declared on this type.
    pub fields: Vec<String>,
}
28
/// Type of a schema field as seen by the validator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldType {
    String,
    Integer,
    Float,
    Boolean,
    Date,
    DateTime,
    /// Any other type; two `Custom` values are equal when their strings are equal.
    Custom(String),
}

impl FieldType {
    /// Returns `true` when a value of type `self` may be compared with one of type `other`.
    ///
    /// Identical types are always comparable. Beyond that, `Integer` and `Float` are
    /// mutually comparable, and so are `Date` and `DateTime`.
    pub fn is_comparable_with(&self, other: &FieldType) -> bool {
        let numeric = |t: &FieldType| matches!(t, FieldType::Integer | FieldType::Float);
        let temporal = |t: &FieldType| matches!(t, FieldType::Date | FieldType::DateTime);

        self == other || (numeric(self) && numeric(other)) || (temporal(self) && temporal(other))
    }
}
58
59#[derive(Debug, Clone)]
61pub struct CompileTimeValidationResult {
62 pub valid: bool,
63 pub errors: Vec<CompileTimeError>,
64 pub warnings: Vec<String>,
65 pub sql_constraint: Option<String>,
66}
67
/// A single validation failure.
#[derive(Debug, Clone)]
pub struct CompileTimeError {
    /// What the error refers to: a type name, a field name, or (for type mismatches)
    /// the whole `left op right` rule.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Optional hint on how to fix the problem.
    pub suggestion: Option<String>,
}
75
76#[derive(Debug)]
78pub struct CompileTimeValidator {
79 context: SchemaContext,
80}
81
82impl CompileTimeValidator {
83 pub fn new(context: SchemaContext) -> Self {
85 Self { context }
86 }
87
88 pub fn validate_cross_field_rule(
90 &self,
91 type_name: &str,
92 left_field: &str,
93 operator: &str,
94 right_field: &str,
95 ) -> CompileTimeValidationResult {
96 let mut errors = Vec::new();
97 let warnings = Vec::new();
98
99 if !self.context.types.contains_key(type_name) {
101 return CompileTimeValidationResult {
102 valid: false,
103 errors: vec![CompileTimeError {
104 field: type_name.to_string(),
105 message: format!("Type '{}' not found in schema", type_name),
106 suggestion: Some("Check that the type is defined".to_string()),
107 }],
108 warnings,
109 sql_constraint: None,
110 };
111 }
112
113 let left_key = (type_name.to_string(), left_field.to_string());
115 let Some(left_type) = self.context.fields.get(&left_key) else {
116 errors.push(CompileTimeError {
117 field: left_field.to_string(),
118 message: format!("Field '{}' not found in type '{}'", left_field, type_name),
119 suggestion: Some(self.suggest_field(type_name, left_field)),
120 });
121 return CompileTimeValidationResult {
122 valid: false,
123 errors,
124 warnings,
125 sql_constraint: None,
126 };
127 };
128
129 let right_key = (type_name.to_string(), right_field.to_string());
131 let Some(right_type) = self.context.fields.get(&right_key) else {
132 errors.push(CompileTimeError {
133 field: right_field.to_string(),
134 message: format!("Field '{}' not found in type '{}'", right_field, type_name),
135 suggestion: Some(self.suggest_field(type_name, right_field)),
136 });
137 return CompileTimeValidationResult {
138 valid: false,
139 errors,
140 warnings,
141 sql_constraint: None,
142 };
143 };
144
145 if !left_type.is_comparable_with(right_type) {
147 errors.push(CompileTimeError {
148 field: format!("{} {} {}", left_field, operator, right_field),
149 message: format!("Cannot compare {:?} with {:?}", left_type, right_type),
150 suggestion: Some(format!("Ensure both fields have comparable types")),
151 });
152 return CompileTimeValidationResult {
153 valid: false,
154 errors,
155 warnings,
156 sql_constraint: None,
157 };
158 }
159
160 let sql_constraint = self.generate_sql_constraint(
162 type_name,
163 left_field,
164 operator,
165 right_field,
166 left_type,
167 right_type,
168 );
169
170 CompileTimeValidationResult {
171 valid: true,
172 errors,
173 warnings,
174 sql_constraint,
175 }
176 }
177
178 pub fn validate_elo_expression(
180 &self,
181 type_name: &str,
182 expression: &str,
183 ) -> CompileTimeValidationResult {
184 let mut errors = Vec::new();
185 let warnings = Vec::new();
186
187 if !self.context.types.contains_key(type_name) {
189 return CompileTimeValidationResult {
190 valid: false,
191 errors: vec![CompileTimeError {
192 field: type_name.to_string(),
193 message: format!("Type '{}' not found in schema", type_name),
194 suggestion: None,
195 }],
196 warnings,
197 sql_constraint: None,
198 };
199 }
200
201 let field_refs = self.extract_field_references(expression);
203
204 for field_name in field_refs {
206 let field_key = (type_name.to_string(), field_name.clone());
207 if !self.context.fields.contains_key(&field_key) {
208 errors.push(CompileTimeError {
209 field: field_name.clone(),
210 message: format!("Field '{}' not found in type '{}'", field_name, type_name),
211 suggestion: Some(self.suggest_field(type_name, &field_name)),
212 });
213 }
214 }
215
216 let valid_operators = vec!["<", ">", "<=", ">=", "==", "!=", "&&", "||", "!"];
218 for op in valid_operators {
219 if expression.contains(op) {
220 }
222 }
223
224 CompileTimeValidationResult {
225 valid: errors.is_empty(),
226 errors,
227 warnings,
228 sql_constraint: None,
229 }
230 }
231
232 fn extract_field_references(&self, expression: &str) -> Vec<String> {
234 let mut fields = HashSet::new();
235 let mut tokens = Vec::new();
236 let mut current_token = String::new();
237 let mut in_string = false;
238 let mut string_char = ' ';
239 let mut escape = false;
240
241 for ch in expression.chars() {
243 if escape {
245 escape = false;
246 current_token.push(ch);
247 continue;
248 }
249
250 if ch == '\\' && in_string {
251 escape = true;
252 current_token.push(ch);
253 continue;
254 }
255
256 if !in_string && (ch == '"' || ch == '\'') {
258 in_string = true;
259 string_char = ch;
260 current_token.push(ch);
261 } else if in_string && ch == string_char {
262 in_string = false;
263 current_token.push(ch);
264 } else if !in_string && (ch.is_whitespace() || ch == '(' || ch == ')') {
265 if !current_token.is_empty() {
266 tokens.push(current_token.clone());
267 current_token.clear();
268 }
269 } else {
270 current_token.push(ch);
271 }
272 }
273
274 if !current_token.is_empty() {
275 tokens.push(current_token);
276 }
277
278 let infix_operators = ["matches", "in", "contains"];
280
281 for (i, token) in tokens.iter().enumerate() {
282 if token.starts_with('"') || token.starts_with('\'') {
284 continue;
285 }
286
287 if infix_operators.contains(&token.as_str()) {
289 continue;
290 }
291
292 if i > 0 && infix_operators.contains(&tokens[i - 1].as_str()) {
294 continue;
295 }
296
297 if token == "true"
299 || token == "false"
300 || token == "null"
301 || token == "and"
302 || token == "or"
303 || token == "not"
304 {
305 continue;
306 }
307
308 if token.chars().next().is_some_and(|ch| ch.is_uppercase()) {
310 continue;
311 }
312
313 if token.chars().next().is_some_and(|ch| ch.is_lowercase()) {
315 fields.insert(token.clone());
316 }
317 }
318
319 fields.into_iter().collect()
320 }
321
322 fn generate_sql_constraint(
324 &self,
325 _type_name: &str,
326 left_field: &str,
327 operator: &str,
328 right_field: &str,
329 left_type: &FieldType,
330 _right_type: &FieldType,
331 ) -> Option<String> {
332 let sql_op = match operator {
334 "<" | "lt" => "<",
335 "<=" | "lte" => "<=",
336 ">" | "gt" => ">",
337 ">=" | "gte" => ">=",
338 "==" | "eq" => "=",
339 "!=" | "neq" => "!=",
340 _ => return None,
341 };
342
343 let constraint = match left_type {
345 FieldType::Date | FieldType::DateTime => {
346 format!("CHECK ({} {} {})", left_field, sql_op, right_field)
347 },
348 FieldType::Integer | FieldType::Float => {
349 format!("CHECK ({} {} {})", left_field, sql_op, right_field)
350 },
351 FieldType::String => {
352 format!("CHECK ({} {} {})", left_field, sql_op, right_field)
353 },
354 _ => return None,
355 };
356
357 Some(constraint)
358 }
359
360 fn suggest_field(&self, type_name: &str, _attempted_field: &str) -> String {
362 let Some(type_def) = self.context.types.get(type_name) else {
363 return "Check schema definition".to_string();
364 };
365
366 let available = type_def.fields.join(", ");
368 format!("Available fields: {}", available)
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema with `User` (email, age, birthDate, verified) and `DateRange`
    /// (startDate, endDate).
    fn create_test_context() -> SchemaContext {
        let mut types = HashMap::new();
        let mut fields = HashMap::new();

        types.insert(
            "User".to_string(),
            TypeDef {
                name: "User".to_string(),
                fields: vec![
                    "email".to_string(),
                    "age".to_string(),
                    "birthDate".to_string(),
                    "verified".to_string(),
                ],
            },
        );

        fields.insert(("User".to_string(), "email".to_string()), FieldType::String);
        fields.insert(("User".to_string(), "age".to_string()), FieldType::Integer);
        fields.insert(("User".to_string(), "birthDate".to_string()), FieldType::Date);
        fields.insert(("User".to_string(), "verified".to_string()), FieldType::Boolean);

        types.insert(
            "DateRange".to_string(),
            TypeDef {
                name: "DateRange".to_string(),
                fields: vec!["startDate".to_string(), "endDate".to_string()],
            },
        );

        fields.insert(("DateRange".to_string(), "startDate".to_string()), FieldType::Date);
        fields.insert(("DateRange".to_string(), "endDate".to_string()), FieldType::Date);

        SchemaContext { types, fields }
    }

    /// Validator over the schema from [`create_test_context`].
    fn make_validator() -> CompileTimeValidator {
        CompileTimeValidator::new(create_test_context())
    }

    #[test]
    fn test_valid_cross_field_comparison() {
        let result =
            make_validator().validate_cross_field_rule("DateRange", "startDate", "<", "endDate");

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.sql_constraint.is_some());
    }

    #[test]
    fn test_cross_field_type_mismatch() {
        let result = make_validator().validate_cross_field_rule("User", "age", "<", "verified");

        assert!(!result.valid);
        assert!(!result.errors.is_empty());
        assert_eq!(result.errors[0].field, "age < verified");
    }

    #[test]
    fn test_cross_field_left_field_not_found() {
        let result = make_validator().validate_cross_field_rule("User", "nonexistent", "<", "age");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
        assert!(result.errors[0].message.contains("not found"));
    }

    #[test]
    fn test_cross_field_right_field_not_found() {
        let result = make_validator().validate_cross_field_rule("User", "age", "<", "nonexistent");

        assert!(!result.valid);
        assert_eq!(result.errors[0].field, "nonexistent");
    }

    #[test]
    fn test_cross_field_type_not_found() {
        let result =
            make_validator().validate_cross_field_rule("NonexistentType", "field", "<", "field2");

        assert!(!result.valid);
        assert!(result.errors[0].message.contains("not found"));
    }

    #[test]
    fn test_same_types_compatible() {
        let left = FieldType::Integer;
        let right = FieldType::Integer;
        assert!(left.is_comparable_with(&right));
    }

    #[test]
    fn test_numeric_types_compatible() {
        let left = FieldType::Integer;
        let right = FieldType::Float;
        assert!(left.is_comparable_with(&right));
    }

    #[test]
    fn test_date_datetime_compatible() {
        let left = FieldType::Date;
        let right = FieldType::DateTime;
        assert!(left.is_comparable_with(&right));
    }

    #[test]
    fn test_string_number_incompatible() {
        let left = FieldType::String;
        let right = FieldType::Integer;
        assert!(!left.is_comparable_with(&right));
    }

    #[test]
    fn test_boolean_incompatible_with_numbers() {
        let left = FieldType::Boolean;
        let right = FieldType::Integer;
        assert!(!left.is_comparable_with(&right));
    }

    #[test]
    fn test_sql_constraint_generated() {
        let result =
            make_validator().validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        assert_eq!(
            result.sql_constraint.as_deref(),
            Some("CHECK (startDate <= endDate)")
        );
    }

    #[test]
    fn test_sql_constraint_with_different_operators() {
        let validator = make_validator();
        let cases = [
            ("<", "CHECK (startDate < endDate)"),
            (">", "CHECK (startDate > endDate)"),
            ("<=", "CHECK (startDate <= endDate)"),
            (">=", "CHECK (startDate >= endDate)"),
            ("==", "CHECK (startDate = endDate)"),
            ("!=", "CHECK (startDate != endDate)"),
        ];

        for (op, expected) in cases {
            let result = validator.validate_cross_field_rule("DateRange", "startDate", op, "endDate");

            assert!(result.valid, "operator {}", op);
            assert_eq!(result.sql_constraint.as_deref(), Some(expected), "operator {}", op);
        }
    }

    #[test]
    fn test_sql_constraint_word_operator_aliases() {
        let validator = make_validator();
        let cases = [
            ("lt", "<"),
            ("lte", "<="),
            ("gt", ">"),
            ("gte", ">="),
            ("eq", "="),
            ("neq", "!="),
        ];

        for (alias, sql_op) in cases {
            let result =
                validator.validate_cross_field_rule("DateRange", "startDate", alias, "endDate");
            let expected = format!("CHECK (startDate {} endDate)", sql_op);

            assert!(result.valid, "alias {}", alias);
            assert_eq!(result.sql_constraint, Some(expected), "alias {}", alias);
        }
    }

    #[test]
    fn test_unrecognized_operator_has_no_sql_constraint() {
        let result =
            make_validator().validate_cross_field_rule("DateRange", "startDate", "~", "endDate");

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.sql_constraint.is_none());
    }

    #[test]
    fn test_boolean_comparison_has_no_sql_constraint() {
        let result =
            make_validator().validate_cross_field_rule("User", "verified", "==", "verified");

        assert!(result.valid);
        assert!(result.sql_constraint.is_none());
    }

    #[test]
    fn test_valid_elo_expression() {
        let result =
            make_validator().validate_elo_expression("User", "age >= 18 && verified == true");

        assert!(result.valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_elo_expression_unknown_field() {
        let result = make_validator().validate_elo_expression("User", "nonexistent >= 18");

        assert!(!result.valid);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "nonexistent");
        assert!(result.errors[0]
            .suggestion
            .as_deref()
            .is_some_and(|s| s.contains("Available fields")));
    }

    #[test]
    fn test_elo_expression_type_not_found() {
        let result = make_validator().validate_elo_expression("NonexistentType", "age >= 18");

        assert!(!result.valid);
    }

    #[test]
    fn test_elo_field_reference_extraction() {
        let fields = make_validator().extract_field_references("age >= 18 && verified == true");

        assert!(fields.contains(&"age".to_string()));
        assert!(fields.contains(&"verified".to_string()));
    }

    #[test]
    fn test_elo_field_extraction_with_strings() {
        let validator = make_validator();

        let fields = validator.extract_field_references("email matches \"pattern\" && age > 10");
        assert!(fields.contains(&"email".to_string()));
        assert!(fields.contains(&"age".to_string()));
        assert!(!fields.contains(&"pattern".to_string()));

        let fields = validator.extract_field_references("email == 'someone'");
        assert_eq!(fields, vec!["email".to_string()]);
    }

    #[test]
    fn test_elo_field_extraction_skips_keywords_and_type_names() {
        let fields = make_validator().extract_field_references("not verified and DateRange or null");

        assert_eq!(fields, vec!["verified".to_string()]);
    }

    #[test]
    fn test_elo_field_extraction_skips_infix_operand() {
        let fields = make_validator().extract_field_references("email matches pattern");

        assert_eq!(fields, vec!["email".to_string()]);
    }

    #[test]
    fn test_elo_field_extraction_deduplicates() {
        let fields = make_validator().extract_field_references("age > 1 && age < 5");

        assert_eq!(fields, vec!["age".to_string()]);
    }

    #[test]
    fn test_date_range_validation() {
        let result =
            make_validator().validate_cross_field_rule("DateRange", "startDate", "<=", "endDate");

        assert!(result.valid);
        let sql = result.sql_constraint.unwrap();
        assert!(sql.contains("CHECK"));
    }

    #[test]
    fn test_age_constraint() {
        let result = make_validator().validate_elo_expression("User", "age >= 18 && age <= 120");

        assert!(result.valid);
    }

    #[test]
    fn test_email_field_validation() {
        let result = make_validator().validate_elo_expression(
            "User",
            "email matches \"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}$\"",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_complex_user_validation() {
        let result = make_validator().validate_elo_expression(
            "User",
            "email matches pattern && age >= 18 && verified == true",
        );

        assert!(result.valid);
    }

    #[test]
    fn test_suggestion_on_typo() {
        let result = make_validator().validate_cross_field_rule("User", "typ0", "<", "age");

        assert!(!result.valid);
        let suggestion = result.errors[0].suggestion.as_ref();
        assert!(suggestion.is_some());
        assert!(suggestion.unwrap().contains("Available fields"));
    }
}