1use std::collections::HashMap;
11use std::sync::Arc;
12
13use super::errors::CheckError;
14use super::overload::{finalize_type, resolve_overload, substitute_type};
15use super::scope::ScopeStack;
16use crate::eval::proto_registry::ProtoTypeResolver;
17use crate::types::{
18 BinaryOp, CelType, CelValue, ComprehensionData, Expr, FunctionDecl, ListElement, MapEntry,
19 ResolvedProtoType, SpannedExpr, StructField, UnaryOp, VariableDecl,
20};
21
/// What a checked expression node refers to.
///
/// The checker records one of these (keyed by expression id) for nodes that
/// resolve to a declared variable, a proto type or enum, a struct type name,
/// or a function call.
#[derive(Debug, Clone)]
pub struct ReferenceInfo {
    /// Resolved name: an identifier / type name, or the function name for calls.
    pub name: String,
    /// Overload ids selected for a function call; empty for plain identifiers.
    pub overload_ids: Vec<String>,
    /// Constant value of the reference; in this checker only set for proto
    /// enum values.
    pub value: Option<CelValue>,
    /// Enum type name; in this checker set for enum values and enum
    /// constructor calls.
    pub enum_type: Option<String>,
}
34
35impl ReferenceInfo {
36 pub fn ident(name: impl Into<String>) -> Self {
38 Self {
39 name: name.into(),
40 overload_ids: Vec::new(),
41 value: None,
42 enum_type: None,
43 }
44 }
45
46 pub fn function(name: impl Into<String>, overload_ids: Vec<String>) -> Self {
48 Self {
49 name: name.into(),
50 overload_ids,
51 value: None,
52 enum_type: None,
53 }
54 }
55}
56
/// Outcome of type-checking an expression tree.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// Checked type of each expression node, keyed by expression id.
    pub type_map: HashMap<i64, CelType>,
    /// Resolved references, keyed by expression id; only nodes that resolved
    /// to a declaration, proto type/enum, struct type or function have one.
    pub reference_map: HashMap<i64, ReferenceInfo>,
    /// Diagnostics collected while checking; empty when checking succeeded.
    pub errors: Vec<CheckError>,
}
67
68impl CheckResult {
69 pub fn is_ok(&self) -> bool {
71 self.errors.is_empty()
72 }
73
74 pub fn get_type(&self, expr_id: i64) -> Option<&CelType> {
76 self.type_map.get(&expr_id)
77 }
78
79 pub fn get_reference(&self, expr_id: i64) -> Option<&ReferenceInfo> {
81 self.reference_map.get(&expr_id)
82 }
83}
84
/// Single-use type checker for one expression tree.
///
/// Build with [`Checker::new`], optionally attach a proto type resolver and
/// an abbreviation table, then consume it with [`Checker::check`].
pub struct Checker<'a> {
    /// Variable declarations; nested scopes hold comprehension and bind variables.
    scopes: ScopeStack,
    /// Function and operator declarations keyed by name (e.g. `"_+_"`).
    functions: &'a HashMap<String, FunctionDecl>,
    /// Namespace used to qualify identifiers; empty when unset.
    container: &'a str,
    /// Checked type of each expression id seen so far.
    type_map: HashMap<i64, CelType>,
    /// Resolved reference of each expression id that resolved to something.
    reference_map: HashMap<i64, ReferenceInfo>,
    /// Diagnostics collected so far.
    errors: Vec<CheckError>,
    /// Type-variable bindings accumulated by overload resolution
    /// (presumably keyed by type-variable name — TODO confirm in `overload`).
    substitutions: HashMap<Arc<str>, CelType>,
    /// Optional protobuf type information for messages and enums.
    type_resolver: Option<&'a dyn ProtoTypeResolver>,
    /// Optional short-name to qualified-name table.
    abbreviations: Option<&'a HashMap<String, String>>,
}
109
110impl<'a> Checker<'a> {
111 pub fn new(
118 variables: &HashMap<String, CelType>,
119 functions: &'a HashMap<String, FunctionDecl>,
120 container: &'a str,
121 ) -> Self {
122 let mut scopes = ScopeStack::new();
123
124 for (name, cel_type) in variables {
126 scopes.add_variable(name, cel_type.clone());
127 }
128
129 Self {
130 scopes,
131 functions,
132 container,
133 type_map: HashMap::new(),
134 reference_map: HashMap::new(),
135 errors: Vec::new(),
136 substitutions: HashMap::new(),
137 type_resolver: None,
138 abbreviations: None,
139 }
140 }
141
142 pub fn with_type_resolver(mut self, resolver: &'a dyn ProtoTypeResolver) -> Self {
144 self.type_resolver = Some(resolver);
145 self
146 }
147
148 pub fn with_abbreviations(mut self, abbreviations: &'a HashMap<String, String>) -> Self {
150 self.abbreviations = Some(abbreviations);
151 self
152 }
153
154 pub fn check(mut self, expr: &SpannedExpr) -> CheckResult {
156 self.check_expr(expr);
157 self.finalize_types();
158
159 CheckResult {
160 type_map: self.type_map,
161 reference_map: self.reference_map,
162 errors: self.errors,
163 }
164 }
165
166 fn set_type(&mut self, expr_id: i64, cel_type: CelType) {
168 self.type_map.insert(expr_id, cel_type);
169 }
170
171 fn set_reference(&mut self, expr_id: i64, reference: ReferenceInfo) {
173 self.reference_map.insert(expr_id, reference);
174 }
175
176 fn report_error(&mut self, error: CheckError) {
178 self.errors.push(error);
179 }
180
181 fn finalize_types(&mut self) {
183 for ty in self.type_map.values_mut() {
184 *ty = finalize_type(ty);
185 *ty = substitute_type(ty, &self.substitutions);
186 *ty = finalize_type(ty);
187 }
188 }
189
190 fn check_expr(&mut self, expr: &SpannedExpr) -> CelType {
192 let result = match &expr.node {
193 Expr::Null => CelType::Null,
194 Expr::Bool(_) => CelType::Bool,
195 Expr::Int(_) => CelType::Int,
196 Expr::UInt(_) => CelType::UInt,
197 Expr::Float(_) => CelType::Double,
198 Expr::String(_) => CelType::String,
199 Expr::Bytes(_) => CelType::Bytes,
200
201 Expr::Ident(name) => self.check_ident(name, expr),
202 Expr::RootIdent(name) => self.check_ident(name, expr),
203
204 Expr::List(elements) => self.check_list(elements, expr),
205 Expr::Map(entries) => self.check_map(entries, expr),
206
207 Expr::Unary { op, expr: inner } => self.check_unary(*op, inner, expr),
208 Expr::Binary { op, left, right } => self.check_binary(*op, left, right, expr),
209 Expr::Ternary {
210 cond,
211 then_expr,
212 else_expr,
213 } => self.check_ternary(cond, then_expr, else_expr, expr),
214
215 Expr::Member {
216 expr: obj,
217 field,
218 optional,
219 } => self.check_member(obj, field, *optional, expr),
220 Expr::Index {
221 expr: obj,
222 index,
223 optional,
224 } => self.check_index(obj, index, *optional, expr),
225 Expr::Call { expr: callee, args } => self.check_call(callee, args, expr),
226 Expr::Struct { type_name, fields } => self.check_struct(type_name, fields, expr),
227
228 Expr::Comprehension(comp) => self.check_comprehension(comp, expr),
229
230 Expr::MemberTestOnly { expr: obj, field } => self.check_member_test(obj, field, expr),
231
232 Expr::Bind {
233 var_name,
234 init,
235 body,
236 } => self.check_bind(var_name, init, body, expr),
237
238 Expr::Error => CelType::Error,
239 };
240
241 self.set_type(expr.id, result.clone());
242 result
243 }
244
245 fn check_ident(&mut self, name: &str, expr: &SpannedExpr) -> CelType {
252 if self.scopes.is_local(name) {
254 if let Some(decl) = self.scopes.resolve(name) {
255 let cel_type = decl.cel_type.clone();
256 self.set_reference(expr.id, ReferenceInfo::ident(name));
257 return cel_type;
258 }
259 }
260
261 if !self.container.is_empty() {
263 let mut container = self.container;
264 loop {
265 let qualified = format!("{}.{}", container, name);
266 if let Some(decl) = self.scopes.resolve(&qualified) {
267 let cel_type = decl.cel_type.clone();
268 self.set_reference(expr.id, ReferenceInfo::ident(&qualified));
269 return cel_type;
270 }
271 match container.rfind('.') {
272 Some(pos) => container = &container[..pos],
273 None => break,
274 }
275 }
276 }
277
278 if let Some(decl) = self.scopes.resolve(name) {
280 let cel_type = decl.cel_type.clone();
281 self.set_reference(expr.id, ReferenceInfo::ident(name));
282 return cel_type;
283 }
284
285 self.report_error(CheckError::undeclared_reference(
286 name,
287 expr.span.clone(),
288 expr.id,
289 ));
290 CelType::Error
291 }
292
293 fn check_list(&mut self, elements: &[ListElement], _expr: &SpannedExpr) -> CelType {
295 if elements.is_empty() {
296 return CelType::list(CelType::fresh_type_var());
297 }
298
299 let mut elem_types = Vec::new();
300 for elem in elements {
301 let elem_type = self.check_expr(&elem.expr);
302 elem_types.push(elem_type);
303 }
304
305 let joined = self.join_types(&elem_types);
306 CelType::list(joined)
307 }
308
309 fn check_map(&mut self, entries: &[MapEntry], _expr: &SpannedExpr) -> CelType {
311 if entries.is_empty() {
312 return CelType::map(CelType::fresh_type_var(), CelType::fresh_type_var());
313 }
314
315 let mut key_types = Vec::new();
316 let mut value_types = Vec::new();
317
318 for entry in entries {
319 let key_type = self.check_expr(&entry.key);
320 let value_type = self.check_expr(&entry.value);
321 let value_type = if entry.optional {
323 match value_type {
324 CelType::Optional(inner) => (*inner).clone(),
325 other => other,
326 }
327 } else {
328 value_type
329 };
330 key_types.push(key_type);
331 value_types.push(value_type);
332 }
333
334 let key_joined = self.join_types(&key_types);
335 let value_joined = self.join_types(&value_types);
336
337 CelType::map(key_joined, value_joined)
338 }
339
340 fn join_types(&self, types: &[CelType]) -> CelType {
345 if types.is_empty() {
346 return CelType::fresh_type_var();
347 }
348
349 let mut best = &types[0];
351 for candidate in types {
352 if type_specificity(candidate) > type_specificity(best) {
353 if types.iter().all(|t| candidate.is_assignable_from(t)) {
355 best = candidate;
356 }
357 }
358 }
359
360 let all_compatible = types
362 .iter()
363 .all(|t| best.is_assignable_from(t) || t.is_assignable_from(best));
364
365 if all_compatible {
366 best.clone()
367 } else {
368 CelType::Dyn
369 }
370 }
371
372 fn check_unary(&mut self, op: UnaryOp, inner: &SpannedExpr, expr: &SpannedExpr) -> CelType {
374 let inner_type = self.check_expr(inner);
375
376 let func_name = match op {
377 UnaryOp::Neg => "-_",
378 UnaryOp::Not => "!_",
379 };
380
381 self.resolve_function_call(func_name, None, &[inner_type], expr)
382 }
383
384 fn check_binary(
386 &mut self,
387 op: BinaryOp,
388 left: &SpannedExpr,
389 right: &SpannedExpr,
390 expr: &SpannedExpr,
391 ) -> CelType {
392 let left_type = self.check_expr(left);
393 let right_type = self.check_expr(right);
394
395 let func_name = binary_op_to_function(op);
396
397 self.resolve_function_call(func_name, None, &[left_type, right_type], expr)
398 }
399
400 fn check_ternary(
402 &mut self,
403 cond: &SpannedExpr,
404 then_expr: &SpannedExpr,
405 else_expr: &SpannedExpr,
406 expr: &SpannedExpr,
407 ) -> CelType {
408 let cond_type = self.check_expr(cond);
409 let then_type = self.check_expr(then_expr);
410 let else_type = self.check_expr(else_expr);
411
412 if !matches!(cond_type, CelType::Bool | CelType::Dyn | CelType::Error) {
414 self.report_error(CheckError::type_mismatch(
415 CelType::Bool,
416 cond_type,
417 cond.span.clone(),
418 cond.id,
419 ));
420 }
421
422 self.resolve_function_call("_?_:_", None, &[CelType::Bool, then_type, else_type], expr)
424 }
425
426 fn check_member(
428 &mut self,
429 obj: &SpannedExpr,
430 field: &str,
431 optional: bool,
432 expr: &SpannedExpr,
433 ) -> CelType {
434 if !self.leftmost_ident_resolves(obj) {
438 if let Some(qualified_name) = self.try_qualified_name(obj, field) {
439 if let Some(decl) = self.resolve_qualified(&qualified_name) {
441 let cel_type = decl.cel_type.clone();
442 self.set_reference(expr.id, ReferenceInfo::ident(&qualified_name));
443 return cel_type;
444 }
445
446 if let Some(resolved) = self.resolve_proto_qualified(&qualified_name, expr) {
448 return resolved;
449 }
450 }
451 }
452
453 let obj_type = self.check_expr(obj);
455
456 let (inner_type, was_optional) = match &obj_type {
458 CelType::Optional(inner) => ((**inner).clone(), true),
459 other => (other.clone(), false),
460 };
461
462 let result = match &inner_type {
464 CelType::Message(name) => {
465 if let Some(registry) = self.type_resolver {
467 if let Some(field_type) = registry.get_field_type(name, field) {
468 return self.wrap_optional_if_needed(field_type, optional, was_optional);
469 }
470 if registry.has_message(name) && !registry.is_extension(name, field) {
473 self.report_error(CheckError::undefined_field(
474 name,
475 field,
476 expr.span.clone(),
477 expr.id,
478 ));
479 return CelType::Error;
480 }
481 }
482 CelType::Dyn
484 }
485 CelType::Dyn | CelType::TypeVar(_) => {
486 CelType::Dyn
488 }
489 CelType::Map(_, value_type) => {
490 (**value_type).clone()
492 }
493 _ => {
494 self.report_error(CheckError::undefined_field(
496 &inner_type.display_name(),
497 field,
498 expr.span.clone(),
499 expr.id,
500 ));
501 return CelType::Error;
502 }
503 };
504
505 self.wrap_optional_if_needed(result, optional, was_optional)
506 }
507
508 fn wrap_optional_if_needed(
510 &self,
511 result: CelType,
512 optional: bool,
513 was_optional: bool,
514 ) -> CelType {
515 if optional || was_optional {
518 match &result {
519 CelType::Optional(_) => result, _ => CelType::optional(result),
521 }
522 } else {
523 result
524 }
525 }
526
527 fn resolve_proto_qualified(
529 &mut self,
530 qualified_name: &str,
531 expr: &SpannedExpr,
532 ) -> Option<CelType> {
533 let registry = self.type_resolver?;
534
535 let expanded_name = self.expand_abbreviation(qualified_name);
537 let name_to_resolve = expanded_name.as_deref().unwrap_or(qualified_name);
538 let parts: Vec<&str> = name_to_resolve.split('.').collect();
539
540 match registry.resolve_qualified(&parts, self.container)? {
541 ResolvedProtoType::EnumValue { enum_name, value } => {
542 self.set_reference(
543 expr.id,
544 ReferenceInfo {
545 name: name_to_resolve.to_string(),
546 overload_ids: vec![],
547 value: Some(CelValue::Int(value as i64)),
548 enum_type: Some(enum_name),
549 },
550 );
551 Some(CelType::Int)
552 }
553 ResolvedProtoType::Enum { name, cel_type } => {
554 self.set_reference(expr.id, ReferenceInfo::ident(&name));
555 Some(cel_type)
556 }
557 ResolvedProtoType::Message { name, cel_type } => {
558 self.set_reference(expr.id, ReferenceInfo::ident(&name));
559 Some(cel_type)
560 }
561 }
562 }
563
564 fn try_enum_constructor(
569 &mut self,
570 name: &str,
571 args: &[SpannedExpr],
572 expr: &SpannedExpr,
573 ) -> Option<CelType> {
574 let registry = self.type_resolver?;
575
576 let expanded_name = self.expand_abbreviation(name);
578 let name_to_resolve = expanded_name.as_deref().unwrap_or(name);
579 let parts: Vec<&str> = name_to_resolve.split('.').collect();
580
581 match registry.resolve_qualified(&parts, self.container)? {
583 ResolvedProtoType::Enum {
584 name: enum_name, ..
585 } => {
586 if args.len() != 1 {
588 return None;
589 }
590 let arg_type = self.check_expr(&args[0]);
591
592 match &arg_type {
593 CelType::Int | CelType::String | CelType::Dyn => {
594 self.set_reference(
596 expr.id,
597 ReferenceInfo {
598 name: enum_name.clone(),
599 overload_ids: vec!["enum_constructor".to_string()],
600 value: None,
601 enum_type: Some(enum_name),
602 },
603 );
604 Some(CelType::Int)
605 }
606 _ => None,
607 }
608 }
609 _ => None,
610 }
611 }
612
613 fn expand_abbreviation(&self, name: &str) -> Option<String> {
619 let abbrevs = self.abbreviations?;
620
621 let first_segment = name.split('.').next()?;
623
624 if let Some(qualified) = abbrevs.get(first_segment) {
626 if let Some(rest) = name
628 .strip_prefix(first_segment)
629 .and_then(|s| s.strip_prefix('.'))
630 {
631 Some(format!("{}.{}", qualified, rest))
632 } else {
633 Some(qualified.clone())
634 }
635 } else {
636 None
637 }
638 }
639
640 fn leftmost_ident_resolves(&self, expr: &SpannedExpr) -> bool {
645 match &expr.node {
646 Expr::Ident(name) => self.scopes.resolve(name).is_some(),
647 Expr::RootIdent(_) => false, Expr::Member { expr: inner, .. } => self.leftmost_ident_resolves(inner),
649 _ => true, }
651 }
652
653 fn try_qualified_name(&self, obj: &SpannedExpr, field: &str) -> Option<String> {
655 match &obj.node {
656 Expr::Ident(name) => Some(format!("{}.{}", name, field)),
657 Expr::RootIdent(name) => Some(format!(".{}.{}", name, field)),
658 Expr::Member {
659 expr: inner,
660 field: inner_field,
661 ..
662 } => {
663 let prefix = self.try_qualified_name(inner, inner_field)?;
664 Some(format!("{}.{}", prefix, field))
665 }
666 _ => None,
667 }
668 }
669
670 fn resolve_qualified(&self, name: &str) -> Option<&VariableDecl> {
677 if let Some(decl) = self.scopes.resolve(name) {
679 return Some(decl);
680 }
681
682 if !self.container.is_empty() {
684 let qualified = format!("{}.{}", self.container, name);
685 if let Some(decl) = self.scopes.resolve(&qualified) {
686 return Some(decl);
687 }
688 }
689
690 if let Some(abbrevs) = self.abbreviations {
692 if let Some(qualified) = abbrevs.get(name) {
693 if let Some(decl) = self.scopes.resolve(qualified) {
694 return Some(decl);
695 }
696 }
697 }
698
699 None
700 }
701
702 fn check_index(
704 &mut self,
705 obj: &SpannedExpr,
706 index: &SpannedExpr,
707 optional: bool,
708 expr: &SpannedExpr,
709 ) -> CelType {
710 let obj_type = self.check_expr(obj);
711 let index_type = self.check_expr(index);
712
713 let (inner_type, was_optional) = match &obj_type {
715 CelType::Optional(inner) => ((**inner).clone(), true),
716 other => (other.clone(), false),
717 };
718
719 let result = self.resolve_function_call("_[_]", None, &[inner_type, index_type], expr);
721
722 if optional || was_optional {
725 match &result {
726 CelType::Optional(_) => result, _ => CelType::optional(result),
728 }
729 } else {
730 result
731 }
732 }
733
734 fn check_call(
736 &mut self,
737 callee: &SpannedExpr,
738 args: &[SpannedExpr],
739 expr: &SpannedExpr,
740 ) -> CelType {
741 match &callee.node {
743 Expr::Member {
744 expr: receiver,
745 field: func_name,
746 ..
747 } => {
748 if let Some(qualified_name) = self.try_qualified_function_name(receiver, func_name)
750 {
751 if self.functions.contains_key(&qualified_name) {
752 let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
753 return self.resolve_function_call(&qualified_name, None, &arg_types, expr);
754 }
755
756 if let Some(result) = self.try_enum_constructor(&qualified_name, args, expr) {
758 return result;
759 }
760 }
761
762 let receiver_type = self.check_expr(receiver);
764 let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
765 self.resolve_function_call(func_name, Some(receiver_type), &arg_types, expr)
766 }
767 Expr::Ident(func_name) => {
768 if let Some(result) = self.try_enum_constructor(func_name, args, expr) {
770 return result;
771 }
772
773 let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
775 self.resolve_function_call(func_name, None, &arg_types, expr)
776 }
777 _ => {
778 let _ = self.check_expr(callee);
780 for arg in args {
781 self.check_expr(arg);
782 }
783 CelType::Dyn
784 }
785 }
786 }
787
788 fn try_qualified_function_name(&self, obj: &SpannedExpr, field: &str) -> Option<String> {
792 match &obj.node {
793 Expr::Ident(name) => Some(format!("{}.{}", name, field)),
794 Expr::Member {
795 expr: inner,
796 field: inner_field,
797 ..
798 } => {
799 let prefix = self.try_qualified_function_name(inner, inner_field)?;
800 Some(format!("{}.{}", prefix, field))
801 }
802 _ => None,
803 }
804 }
805
806 fn resolve_function_call(
808 &mut self,
809 name: &str,
810 receiver: Option<CelType>,
811 args: &[CelType],
812 expr: &SpannedExpr,
813 ) -> CelType {
814 if let Some(func) = self.functions.get(name) {
815 let func = func.clone(); self.resolve_with_function(&func, receiver, args, expr)
817 } else {
818 self.report_error(CheckError::undeclared_reference(
819 name,
820 expr.span.clone(),
821 expr.id,
822 ));
823 CelType::Error
824 }
825 }
826
827 fn resolve_with_function(
829 &mut self,
830 func: &FunctionDecl,
831 receiver: Option<CelType>,
832 args: &[CelType],
833 expr: &SpannedExpr,
834 ) -> CelType {
835 if let Some(result) =
836 resolve_overload(func, receiver.as_ref(), args, &mut self.substitutions)
837 {
838 self.set_reference(
839 expr.id,
840 ReferenceInfo::function(&func.name, result.overload_ids),
841 );
842 result.result_type
843 } else {
844 let all_args: Vec<_> = receiver
845 .iter()
846 .cloned()
847 .chain(args.iter().cloned())
848 .collect();
849 self.report_error(CheckError::no_matching_overload(
850 &func.name,
851 all_args,
852 expr.span.clone(),
853 expr.id,
854 ));
855 CelType::Error
856 }
857 }
858
859 fn check_struct(
861 &mut self,
862 type_name: &SpannedExpr,
863 fields: &[StructField],
864 expr: &SpannedExpr,
865 ) -> CelType {
866 let name = self.get_type_name(type_name);
868
869 for field in fields {
871 self.check_expr(&field.value);
872 }
873
874 if let Some(ref name) = name {
876 let expanded_name = self.expand_abbreviation(name);
878 let name_to_resolve = expanded_name.as_ref().unwrap_or(name);
879
880 let fq_name = if let Some(registry) = self.type_resolver {
882 registry
883 .resolve_message_name(name_to_resolve, self.container)
884 .unwrap_or_else(|| name_to_resolve.clone())
885 } else {
886 name_to_resolve.clone()
887 };
888
889 self.set_reference(expr.id, ReferenceInfo::ident(&fq_name));
892 self.set_reference(type_name.id, ReferenceInfo::ident(&fq_name));
893 CelType::message(&fq_name)
894 } else {
895 CelType::Dyn
896 }
897 }
898
899 fn get_type_name(&self, expr: &SpannedExpr) -> Option<String> {
901 match &expr.node {
902 Expr::Ident(name) => Some(name.clone()),
903 Expr::RootIdent(name) => Some(format!(".{}", name)),
904 Expr::Member {
905 expr: inner, field, ..
906 } => {
907 let prefix = self.get_type_name(inner)?;
908 Some(format!("{}.{}", prefix, field))
909 }
910 _ => None,
911 }
912 }
913
914 fn check_comprehension(&mut self, comp: &ComprehensionData, _expr: &SpannedExpr) -> CelType {
916 let range_type = self.check_expr(&comp.iter_range);
918
919 let (iter_type, iter_type2) = if !comp.iter_var2.is_empty() {
921 match &range_type {
923 CelType::List(elem) => (CelType::Int, (**elem).clone()),
924 CelType::Map(key, value) => ((**key).clone(), (**value).clone()),
925 CelType::Dyn => (CelType::Dyn, CelType::Dyn),
926 _ => (CelType::Dyn, CelType::Dyn),
927 }
928 } else {
929 let t = match &range_type {
931 CelType::List(elem) => (**elem).clone(),
932 CelType::Map(key, _) => (**key).clone(),
933 CelType::Optional(inner) => (**inner).clone(), CelType::Dyn => CelType::Dyn,
935 _ => CelType::Dyn,
936 };
937 (t, CelType::Dyn)
938 };
939
940 let accu_type = self.check_expr(&comp.accu_init);
942
943 self.scopes.enter_scope();
945
946 self.scopes.add_variable(&comp.iter_var, iter_type.clone());
948 if !comp.iter_var2.is_empty() {
949 self.scopes.add_variable(&comp.iter_var2, iter_type2);
950 }
951
952 self.scopes.add_variable(&comp.accu_var, accu_type.clone());
954
955 let cond_type = self.check_expr(&comp.loop_condition);
957 if !matches!(cond_type, CelType::Bool | CelType::Dyn | CelType::Error) {
958 self.report_error(CheckError::type_mismatch(
959 CelType::Bool,
960 cond_type,
961 comp.loop_condition.span.clone(),
962 comp.loop_condition.id,
963 ));
964 }
965
966 let step_type = self.check_expr(&comp.loop_step);
968
969 if contains_type_var_checker(&accu_type) && !contains_type_var_checker(&step_type) {
972 self.scopes.add_variable(&comp.accu_var, step_type);
973 }
974
975 let result_type = self.check_expr(&comp.result);
977
978 self.scopes.exit_scope();
980
981 result_type
982 }
983
984 fn check_member_test(&mut self, obj: &SpannedExpr, field: &str, expr: &SpannedExpr) -> CelType {
986 let obj_type = self.check_expr(obj);
988
989 let inner_type = match &obj_type {
991 CelType::Optional(inner) => (**inner).clone(),
992 other => other.clone(),
993 };
994
995 match &inner_type {
997 CelType::Message(name) => {
998 if let Some(registry) = self.type_resolver {
999 if registry.get_field_type(name, field).is_some() {
1000 } else if registry.has_message(name) && !registry.is_extension(name, field) {
1002 self.report_error(CheckError::undefined_field(
1003 name,
1004 field,
1005 expr.span.clone(),
1006 expr.id,
1007 ));
1008 }
1009 }
1010 }
1011 CelType::Map(_, _) | CelType::Dyn | CelType::TypeVar(_) => {
1012 }
1014 CelType::Error => {
1015 }
1017 _ => {
1018 self.report_error(CheckError::undefined_field(
1019 &inner_type.display_name(),
1020 field,
1021 expr.span.clone(),
1022 expr.id,
1023 ));
1024 }
1025 }
1026
1027 CelType::Bool
1029 }
1030
1031 fn check_bind(
1035 &mut self,
1036 var_name: &str,
1037 init: &SpannedExpr,
1038 body: &SpannedExpr,
1039 _expr: &SpannedExpr,
1040 ) -> CelType {
1041 let init_type = self.check_expr(init);
1043
1044 self.scopes.enter_scope();
1046 self.scopes.add_variable(var_name, init_type);
1047
1048 let body_type = self.check_expr(body);
1050
1051 self.scopes.exit_scope();
1053
1054 body_type
1056 }
1057}
1058
1059pub fn check(
1070 expr: &SpannedExpr,
1071 variables: &HashMap<String, CelType>,
1072 functions: &HashMap<String, FunctionDecl>,
1073 container: &str,
1074) -> CheckResult {
1075 let checker = Checker::new(variables, functions, container);
1076 checker.check(expr)
1077}
1078
1079pub fn check_with_type_resolver(
1084 expr: &SpannedExpr,
1085 variables: &HashMap<String, CelType>,
1086 functions: &HashMap<String, FunctionDecl>,
1087 container: &str,
1088 type_resolver: &dyn ProtoTypeResolver,
1089) -> CheckResult {
1090 let checker = Checker::new(variables, functions, container).with_type_resolver(type_resolver);
1091 checker.check(expr)
1092}
1093
1094pub fn check_with_abbreviations(
1098 expr: &SpannedExpr,
1099 variables: &HashMap<String, CelType>,
1100 functions: &HashMap<String, FunctionDecl>,
1101 container: &str,
1102 abbreviations: &HashMap<String, String>,
1103) -> CheckResult {
1104 let checker = Checker::new(variables, functions, container).with_abbreviations(abbreviations);
1105 checker.check(expr)
1106}
1107
1108pub fn check_with_type_resolver_and_abbreviations(
1112 expr: &SpannedExpr,
1113 variables: &HashMap<String, CelType>,
1114 functions: &HashMap<String, FunctionDecl>,
1115 container: &str,
1116 type_resolver: &dyn ProtoTypeResolver,
1117 abbreviations: &HashMap<String, String>,
1118) -> CheckResult {
1119 let checker = Checker::new(variables, functions, container)
1120 .with_type_resolver(type_resolver)
1121 .with_abbreviations(abbreviations);
1122 checker.check(expr)
1123}
1124
1125fn contains_type_var_checker(ty: &CelType) -> bool {
1127 match ty {
1128 CelType::TypeVar(_) => true,
1129 CelType::List(elem) => contains_type_var_checker(elem),
1130 CelType::Map(key, val) => contains_type_var_checker(key) || contains_type_var_checker(val),
1131 CelType::Optional(inner) => contains_type_var_checker(inner),
1132 CelType::Wrapper(inner) => contains_type_var_checker(inner),
1133 CelType::Type(inner) => contains_type_var_checker(inner),
1134 _ => false,
1135 }
1136}
1137
1138fn type_specificity(ty: &CelType) -> u32 {
1144 match ty {
1145 CelType::Dyn | CelType::TypeVar(_) => 0,
1146 CelType::Null => 1,
1147 CelType::Bool
1148 | CelType::Int
1149 | CelType::UInt
1150 | CelType::Double
1151 | CelType::String
1152 | CelType::Bytes
1153 | CelType::Timestamp
1154 | CelType::Duration => 2,
1155 CelType::Message(_) | CelType::Enum(_) => 2,
1156 CelType::List(elem) => 2 + type_specificity(elem),
1157 CelType::Map(key, val) => 2 + type_specificity(key) + type_specificity(val),
1158 CelType::Optional(inner) => 2 + type_specificity(inner),
1159 CelType::Wrapper(inner) => 3 + type_specificity(inner), CelType::Type(inner) => 2 + type_specificity(inner),
1161 CelType::Abstract { params, .. } => 2 + params.iter().map(type_specificity).sum::<u32>(),
1162 _ => 1,
1163 }
1164}
1165
1166fn binary_op_to_function(op: BinaryOp) -> &'static str {
1168 match op {
1169 BinaryOp::Add => "_+_",
1170 BinaryOp::Sub => "_-_",
1171 BinaryOp::Mul => "_*_",
1172 BinaryOp::Div => "_/_",
1173 BinaryOp::Mod => "_%_",
1174 BinaryOp::Eq => "_==_",
1175 BinaryOp::Ne => "_!=_",
1176 BinaryOp::Lt => "_<_",
1177 BinaryOp::Le => "_<=_",
1178 BinaryOp::Gt => "_>_",
1179 BinaryOp::Ge => "_>=_",
1180 BinaryOp::And => "_&&_",
1181 BinaryOp::Or => "_||_",
1182 BinaryOp::In => "@in",
1183 }
1184}
1185
1186#[cfg(test)]
1187mod tests {
1188 use std::any::Any;
1189
1190 use super::super::errors::CheckErrorKind;
1191 use super::super::standard_library::STANDARD_LIBRARY;
1192 use super::*;
1193 use crate::eval::proto_registry::ProtoTypeResolver;
1194 use crate::parser::parse;
1195 use crate::types::ResolvedProtoType;
1196
1197 fn standard_functions() -> HashMap<String, FunctionDecl> {
1199 STANDARD_LIBRARY
1200 .iter()
1201 .map(|f| (f.name.clone(), f.clone()))
1202 .collect()
1203 }
1204
1205 fn standard_variables() -> HashMap<String, CelType> {
1207 let mut vars = HashMap::new();
1208 vars.insert("bool".to_string(), CelType::type_of(CelType::Bool));
1209 vars.insert("int".to_string(), CelType::type_of(CelType::Int));
1210 vars.insert("uint".to_string(), CelType::type_of(CelType::UInt));
1211 vars.insert("double".to_string(), CelType::type_of(CelType::Double));
1212 vars.insert("string".to_string(), CelType::type_of(CelType::String));
1213 vars.insert("bytes".to_string(), CelType::type_of(CelType::Bytes));
1214 vars.insert(
1215 "list".to_string(),
1216 CelType::type_of(CelType::list(CelType::Dyn)),
1217 );
1218 vars.insert(
1219 "map".to_string(),
1220 CelType::type_of(CelType::map(CelType::Dyn, CelType::Dyn)),
1221 );
1222 vars.insert("null_type".to_string(), CelType::type_of(CelType::Null));
1223 vars.insert(
1224 "type".to_string(),
1225 CelType::type_of(CelType::type_of(CelType::Dyn)),
1226 );
1227 vars.insert("dyn".to_string(), CelType::type_of(CelType::Dyn));
1228 vars
1229 }
1230
1231 fn check_expr(source: &str) -> CheckResult {
1232 let result = parse(source);
1233 let ast = result.ast.expect("parse should succeed");
1234 let variables = standard_variables();
1235 let functions = standard_functions();
1236 check(&ast, &variables, &functions, "")
1237 }
1238
1239 fn check_expr_with_var(source: &str, var: &str, cel_type: CelType) -> CheckResult {
1240 let result = parse(source);
1241 let ast = result.ast.expect("parse should succeed");
1242 let mut variables = standard_variables();
1243 variables.insert(var.to_string(), cel_type);
1244 let functions = standard_functions();
1245 check(&ast, &variables, &functions, "")
1246 }
1247
1248 #[derive(Debug)]
1250 struct MockProtoResolver {
1251 fields: HashMap<(String, String), CelType>,
1253 messages: Vec<String>,
1255 }
1256
1257 impl MockProtoResolver {
1258 fn new() -> Self {
1259 Self {
1260 fields: HashMap::new(),
1261 messages: Vec::new(),
1262 }
1263 }
1264
1265 fn with_message(mut self, name: &str) -> Self {
1266 self.messages.push(name.to_string());
1267 self
1268 }
1269
1270 fn with_field(mut self, message: &str, field: &str, cel_type: CelType) -> Self {
1271 self.fields
1272 .insert((message.to_string(), field.to_string()), cel_type);
1273 self
1274 }
1275 }
1276
1277 impl ProtoTypeResolver for MockProtoResolver {
1278 fn get_field_type(&self, message: &str, field: &str) -> Option<CelType> {
1279 self.fields
1280 .get(&(message.to_string(), field.to_string()))
1281 .cloned()
1282 }
1283
1284 fn has_message(&self, message: &str) -> bool {
1285 self.messages.contains(&message.to_string())
1286 }
1287
1288 fn is_extension(&self, _message: &str, _ext_name: &str) -> bool {
1289 false
1290 }
1291
1292 fn get_enum_value(&self, _enum_name: &str, _value_name: &str) -> Option<i32> {
1293 None
1294 }
1295
1296 fn resolve_qualified(
1297 &self,
1298 _parts: &[&str],
1299 _container: &str,
1300 ) -> Option<ResolvedProtoType> {
1301 None
1302 }
1303
1304 fn resolve_message_name(&self, _name: &str, _container: &str) -> Option<String> {
1305 None
1306 }
1307
1308 fn as_any(&self) -> &dyn Any {
1309 self
1310 }
1311 }
1312
1313 fn check_expr_with_var_and_resolver(
1314 source: &str,
1315 var: &str,
1316 cel_type: CelType,
1317 resolver: &dyn ProtoTypeResolver,
1318 ) -> CheckResult {
1319 let result = parse(source);
1320 let ast = result.ast.expect("parse should succeed");
1321 let mut variables = standard_variables();
1322 variables.insert(var.to_string(), cel_type);
1323 let functions = standard_functions();
1324 check_with_type_resolver(&ast, &variables, &functions, "", resolver)
1325 }
1326
1327 #[test]
1328 fn test_literal_types() {
1329 assert_eq!(check_expr("null").get_type(1), Some(&CelType::Null));
1330 assert_eq!(check_expr("true").get_type(1), Some(&CelType::Bool));
1331 assert_eq!(check_expr("42").get_type(1), Some(&CelType::Int));
1332 assert_eq!(check_expr("42u").get_type(1), Some(&CelType::UInt));
1333 assert_eq!(check_expr("3.14").get_type(1), Some(&CelType::Double));
1334 assert_eq!(check_expr("\"hello\"").get_type(1), Some(&CelType::String));
1335 assert_eq!(check_expr("b\"hello\"").get_type(1), Some(&CelType::Bytes));
1336 }
1337
1338 #[test]
1339 fn test_undefined_variable() {
1340 let result = check_expr("x");
1341 assert!(!result.is_ok());
1342 assert!(result.errors.iter().any(|e| matches!(
1343 &e.kind,
1344 CheckErrorKind::UndeclaredReference { name, .. } if name == "x"
1345 )));
1346 }
1347
1348 #[test]
1349 fn test_defined_variable() {
1350 let result = check_expr_with_var("x", "x", CelType::Int);
1351 assert!(result.is_ok());
1352 assert_eq!(result.get_type(1), Some(&CelType::Int));
1353 }
1354
1355 #[test]
1356 fn test_binary_add_int() {
1357 let result = check_expr_with_var("x + 1", "x", CelType::Int);
1358 assert!(result.is_ok());
1359 let types: Vec<_> = result.type_map.values().collect();
1361 assert!(types.contains(&&CelType::Int));
1362 }
1363
/// A homogeneous int list literal yields exactly one `list(int)` type entry.
#[test]
fn test_list_literal() {
    let result = check_expr("[1, 2, 3]");
    assert!(result.is_ok());

    let mut lists = result
        .type_map
        .values()
        .filter(|ty| matches!(ty, CelType::List(_)));
    assert_eq!(lists.next(), Some(&CelType::list(CelType::Int)));
    assert_eq!(lists.next(), None);
}
1377
/// A string-to-int map literal yields exactly one `map(string, int)` entry.
#[test]
fn test_map_literal() {
    let result = check_expr(r#"{"a": 1, "b": 2}"#);
    assert!(result.is_ok());

    let mut maps = result
        .type_map
        .values()
        .filter(|ty| matches!(ty, CelType::Map(_, _)));
    assert_eq!(
        maps.next(),
        Some(&CelType::map(CelType::String, CelType::Int))
    );
    assert_eq!(maps.next(), None);
}
1391
/// An ordering comparison on ints type-checks and produces a `bool` node.
#[test]
fn test_comparison() {
    let result = check_expr_with_var("x > 0", "x", CelType::Int);
    assert!(result.is_ok());

    let has_bool = result
        .type_map
        .values()
        .any(|ty| matches!(ty, CelType::Bool));
    assert!(has_bool);
}
1404
/// A receiver-style `string.contains(string)` call type-checks to `bool`.
#[test]
fn test_method_call() {
    let result = check_expr(r#""hello".contains("lo")"#);
    assert!(result.is_ok());

    let has_bool = result
        .type_map
        .values()
        .any(|ty| matches!(ty, CelType::Bool));
    assert!(has_bool);
}
1417
/// `string.size()` type-checks and produces an `int` node.
#[test]
fn test_size_method() {
    let result = check_expr(r#""hello".size()"#);
    assert!(result.is_ok());

    let has_int = result
        .type_map
        .values()
        .any(|ty| matches!(ty, CelType::Int));
    assert!(has_int);
}
1429
/// A conditional with a `bool` guard and matching int branches type-checks.
#[test]
fn test_ternary() {
    let checked = check_expr_with_var("x ? 1 : 2", "x", CelType::Bool);
    assert!(checked.is_ok());
}
1435
/// Adding an int to a string fails with a no-matching-overload error on `_+_`.
#[test]
fn test_type_mismatch_addition() {
    let result = check_expr_with_var(r#"x + "str""#, "x", CelType::Int);
    assert!(!result.is_ok());

    let reports_add = result.errors.iter().any(|err| match &err.kind {
        CheckErrorKind::NoMatchingOverload { function, .. } => function == "_+_",
        _ => false,
    });
    assert!(reports_add);
}
1445
/// An empty list literal type-checks and records exactly one list type.
#[test]
fn test_empty_list() {
    let result = check_expr("[]");
    assert!(result.is_ok());

    let list_count = result
        .type_map
        .values()
        .filter(|ty| matches!(ty, CelType::List(_)))
        .count();
    assert_eq!(list_count, 1);
}
1458
/// The reference map records both the identifier and the operator function.
#[test]
fn test_reference_recording() {
    let result = check_expr_with_var("x + 1", "x", CelType::Int);
    assert!(result.is_ok());

    let names: Vec<&str> = result
        .reference_map
        .values()
        .map(|info| info.name.as_str())
        .collect();

    // Identifier reference for the variable.
    assert!(names.contains(&"x"));
    // Function reference for the addition operator.
    assert!(names.contains(&"_+_"));
}
1471
/// `has()` on a field the message type does not declare is an
/// undefined-field error naming that field.
#[test]
fn test_has_undefined_field() {
    let msg_type = CelType::Message("Msg".into());
    let resolver = MockProtoResolver::new()
        .with_message("Msg")
        .with_field("Msg", "name", CelType::String);

    let checked =
        check_expr_with_var_and_resolver("has(x.nonexistent)", "x", msg_type, &resolver);
    assert!(!checked.is_ok());

    let flagged = checked
        .errors
        .iter()
        .filter_map(|err| match &err.kind {
            CheckErrorKind::UndefinedField { field, .. } => Some(field),
            _ => None,
        })
        .any(|field| field == "nonexistent");
    assert!(flagged);
}
1490
/// `has()` on a declared message field type-checks and produces `bool`.
#[test]
fn test_has_valid_field() {
    let msg_type = CelType::Message("Msg".into());
    let resolver = MockProtoResolver::new()
        .with_message("Msg")
        .with_field("Msg", "name", CelType::String);

    let checked = check_expr_with_var_and_resolver("has(x.name)", "x", msg_type, &resolver);
    assert!(checked.is_ok());

    let has_bool = checked
        .type_map
        .values()
        .any(|ty| matches!(ty, CelType::Bool));
    assert!(has_bool);
}
1512
/// `has()` on a map-typed operand accepts any key name.
#[test]
fn test_has_map_field() {
    let map_type = CelType::map(CelType::String, CelType::Int);
    let checked = check_expr_with_var("has(m.anything)", "m", map_type);
    assert!(checked.is_ok());
}
1522
/// `has()` on a `dyn` operand is accepted without field validation.
#[test]
fn test_has_dyn_field() {
    let checked = check_expr_with_var("has(d.anything)", "d", CelType::Dyn);
    assert!(checked.is_ok());
}
1528}