1use std::collections::HashMap;
11use std::sync::Arc;
12
13use super::errors::CheckError;
14use super::overload::{
15 finalize_type, format_expected_args, resolve_overload, substitute_type, OverloadMismatch,
16};
17use super::scope::ScopeStack;
18use crate::eval::proto_registry::ProtoTypeResolver;
19use crate::types::{
20 BinaryOp, CelType, CelValue, ComprehensionData, Expr, FunctionDecl, ListElement, MapEntry,
21 ResolvedProtoType, SpannedExpr, StructField, UnaryOp, VariableDecl,
22};
23
/// Resolution record the checker attaches to an expression id.
///
/// Instances are stored in `CheckResult::reference_map` and describe what an
/// identifier, function call, struct type name, or enum reference resolved to.
#[derive(Debug, Clone)]
pub struct ReferenceInfo {
    /// Name the expression resolved to (possibly a qualified name).
    pub name: String,
    /// Overload ids selected for a function reference; empty for plain identifiers.
    pub overload_ids: Vec<String>,
    /// Constant value of the reference; the checker sets it for resolved enum values.
    pub value: Option<CelValue>,
    /// Enum type name; set for enum values and for enum constructor calls.
    pub enum_type: Option<String>,
}
36
37impl ReferenceInfo {
38 pub fn ident(name: impl Into<String>) -> Self {
40 Self {
41 name: name.into(),
42 overload_ids: Vec::new(),
43 value: None,
44 enum_type: None,
45 }
46 }
47
48 pub fn function(name: impl Into<String>, overload_ids: Vec<String>) -> Self {
50 Self {
51 name: name.into(),
52 overload_ids,
53 value: None,
54 enum_type: None,
55 }
56 }
57}
58
/// Outcome of running the checker over one expression tree.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// Inferred type for every checked expression, keyed by expression id.
    pub type_map: HashMap<i64, CelType>,
    /// Resolution records (identifiers, functions, enums), keyed by expression id.
    pub reference_map: HashMap<i64, ReferenceInfo>,
    /// Errors collected while checking, in the order they were reported.
    pub errors: Vec<CheckError>,
}
69
70impl CheckResult {
71 pub fn is_ok(&self) -> bool {
73 self.errors.is_empty()
74 }
75
76 pub fn get_type(&self, expr_id: i64) -> Option<&CelType> {
78 self.type_map.get(&expr_id)
79 }
80
81 pub fn get_reference(&self, expr_id: i64) -> Option<&ReferenceInfo> {
83 self.reference_map.get(&expr_id)
84 }
85}
86
/// Type checker for a single expression tree.
///
/// Built with `Checker::new`, optionally extended with `with_type_resolver` /
/// `with_abbreviations`, and consumed by `check`.
pub struct Checker<'a> {
    /// Variable declarations visible to the expression, including nested scopes.
    scopes: ScopeStack,
    /// Function declarations, keyed by function name.
    functions: &'a HashMap<String, FunctionDecl>,
    /// Container (namespace) used to qualify names; may be empty.
    container: &'a str,
    /// Inferred type per expression id.
    type_map: HashMap<i64, CelType>,
    /// Resolution record per expression id.
    reference_map: HashMap<i64, ReferenceInfo>,
    /// Errors reported so far.
    errors: Vec<CheckError>,
    /// Type-variable bindings accumulated by overload resolution.
    substitutions: HashMap<Arc<str>, CelType>,
    /// Optional resolver for protobuf message/enum types.
    type_resolver: Option<&'a dyn ProtoTypeResolver>,
    /// Optional map from a short name to its qualified name.
    abbreviations: Option<&'a HashMap<String, String>>,
}
111
112impl<'a> Checker<'a> {
113 pub fn new(
120 variables: &HashMap<String, CelType>,
121 functions: &'a HashMap<String, FunctionDecl>,
122 container: &'a str,
123 ) -> Self {
124 let mut scopes = ScopeStack::new();
125
126 for (name, cel_type) in variables {
128 scopes.add_variable(name, cel_type.clone());
129 }
130
131 Self {
132 scopes,
133 functions,
134 container,
135 type_map: HashMap::new(),
136 reference_map: HashMap::new(),
137 errors: Vec::new(),
138 substitutions: HashMap::new(),
139 type_resolver: None,
140 abbreviations: None,
141 }
142 }
143
144 pub fn with_type_resolver(mut self, resolver: &'a dyn ProtoTypeResolver) -> Self {
146 self.type_resolver = Some(resolver);
147 self
148 }
149
150 pub fn with_abbreviations(mut self, abbreviations: &'a HashMap<String, String>) -> Self {
152 self.abbreviations = Some(abbreviations);
153 self
154 }
155
156 pub fn check(mut self, expr: &SpannedExpr) -> CheckResult {
158 self.check_expr(expr);
159 self.finalize_types();
160
161 CheckResult {
162 type_map: self.type_map,
163 reference_map: self.reference_map,
164 errors: self.errors,
165 }
166 }
167
168 fn set_type(&mut self, expr_id: i64, cel_type: CelType) {
170 self.type_map.insert(expr_id, cel_type);
171 }
172
173 fn set_reference(&mut self, expr_id: i64, reference: ReferenceInfo) {
175 self.reference_map.insert(expr_id, reference);
176 }
177
178 fn report_error(&mut self, error: CheckError) {
180 self.errors.push(error);
181 }
182
183 fn finalize_types(&mut self) {
185 for ty in self.type_map.values_mut() {
186 *ty = finalize_type(ty);
187 *ty = substitute_type(ty, &self.substitutions);
188 *ty = finalize_type(ty);
189 }
190 }
191
192 fn check_expr(&mut self, expr: &SpannedExpr) -> CelType {
194 let result = match &expr.node {
195 Expr::Null => CelType::Null,
196 Expr::Bool(_) => CelType::Bool,
197 Expr::Int(_) => CelType::Int,
198 Expr::UInt(_) => CelType::UInt,
199 Expr::Float(_) => CelType::Double,
200 Expr::String(_) => CelType::String,
201 Expr::Bytes(_) => CelType::Bytes,
202
203 Expr::Ident(name) => self.check_ident(name, expr),
204 Expr::RootIdent(name) => self.check_ident(name, expr),
205
206 Expr::List(elements) => self.check_list(elements, expr),
207 Expr::Map(entries) => self.check_map(entries, expr),
208
209 Expr::Unary { op, expr: inner } => self.check_unary(*op, inner, expr),
210 Expr::Binary { op, left, right } => self.check_binary(*op, left, right, expr),
211 Expr::Ternary {
212 cond,
213 then_expr,
214 else_expr,
215 } => self.check_ternary(cond, then_expr, else_expr, expr),
216
217 Expr::Member {
218 expr: obj,
219 field,
220 optional,
221 } => self.check_member(obj, field, *optional, expr),
222 Expr::Index {
223 expr: obj,
224 index,
225 optional,
226 } => self.check_index(obj, index, *optional, expr),
227 Expr::Call { expr: callee, args } => self.check_call(callee, args, expr),
228 Expr::Struct { type_name, fields } => self.check_struct(type_name, fields, expr),
229
230 Expr::Comprehension(comp) => self.check_comprehension(comp, expr),
231
232 Expr::MemberTestOnly { expr: obj, field } => self.check_member_test(obj, field, expr),
233
234 Expr::Bind {
235 var_name,
236 init,
237 body,
238 } => self.check_bind(var_name, init, body, expr),
239
240 Expr::Error => CelType::Error,
241 };
242
243 self.set_type(expr.id, result.clone());
244 result
245 }
246
247 fn check_ident(&mut self, name: &str, expr: &SpannedExpr) -> CelType {
254 if self.scopes.is_local(name) {
256 if let Some(decl) = self.scopes.resolve(name) {
257 let cel_type = decl.cel_type.clone();
258 self.set_reference(expr.id, ReferenceInfo::ident(name));
259 return cel_type;
260 }
261 }
262
263 if !self.container.is_empty() {
265 let mut container = self.container;
266 loop {
267 let qualified = format!("{}.{}", container, name);
268 if let Some(decl) = self.scopes.resolve(&qualified) {
269 let cel_type = decl.cel_type.clone();
270 self.set_reference(expr.id, ReferenceInfo::ident(&qualified));
271 return cel_type;
272 }
273 match container.rfind('.') {
274 Some(pos) => container = &container[..pos],
275 None => break,
276 }
277 }
278 }
279
280 if let Some(decl) = self.scopes.resolve(name) {
282 let cel_type = decl.cel_type.clone();
283 self.set_reference(expr.id, ReferenceInfo::ident(name));
284 return cel_type;
285 }
286
287 self.report_error(CheckError::undeclared_reference(
288 name,
289 expr.span.clone(),
290 expr.id,
291 ));
292 CelType::Error
293 }
294
295 fn check_list(&mut self, elements: &[ListElement], _expr: &SpannedExpr) -> CelType {
297 if elements.is_empty() {
298 return CelType::list(CelType::fresh_type_var());
299 }
300
301 let mut elem_types = Vec::new();
302 for elem in elements {
303 let elem_type = self.check_expr(&elem.expr);
304 elem_types.push(elem_type);
305 }
306
307 let joined = self.join_types(&elem_types);
308 CelType::list(joined)
309 }
310
311 fn check_map(&mut self, entries: &[MapEntry], _expr: &SpannedExpr) -> CelType {
313 if entries.is_empty() {
314 return CelType::map(CelType::fresh_type_var(), CelType::fresh_type_var());
315 }
316
317 let mut key_types = Vec::new();
318 let mut value_types = Vec::new();
319
320 for entry in entries {
321 let key_type = self.check_expr(&entry.key);
322 let value_type = self.check_expr(&entry.value);
323 let value_type = if entry.optional {
325 match value_type {
326 CelType::Optional(inner) => (*inner).clone(),
327 other => other,
328 }
329 } else {
330 value_type
331 };
332 key_types.push(key_type);
333 value_types.push(value_type);
334 }
335
336 let key_joined = self.join_types(&key_types);
337 let value_joined = self.join_types(&value_types);
338
339 CelType::map(key_joined, value_joined)
340 }
341
342 fn join_types(&self, types: &[CelType]) -> CelType {
347 if types.is_empty() {
348 return CelType::fresh_type_var();
349 }
350
351 let mut best = &types[0];
353 for candidate in types {
354 if type_specificity(candidate) > type_specificity(best) {
355 if types.iter().all(|t| candidate.is_assignable_from(t)) {
357 best = candidate;
358 }
359 }
360 }
361
362 let all_compatible = types
364 .iter()
365 .all(|t| best.is_assignable_from(t) || t.is_assignable_from(best));
366
367 if all_compatible {
368 best.clone()
369 } else {
370 CelType::Dyn
371 }
372 }
373
374 fn check_unary(&mut self, op: UnaryOp, inner: &SpannedExpr, expr: &SpannedExpr) -> CelType {
376 let inner_type = self.check_expr(inner);
377
378 let func_name = match op {
379 UnaryOp::Neg => "-_",
380 UnaryOp::Not => "!_",
381 };
382
383 self.resolve_function_call(func_name, None, &[inner_type], expr)
384 }
385
386 fn check_binary(
388 &mut self,
389 op: BinaryOp,
390 left: &SpannedExpr,
391 right: &SpannedExpr,
392 expr: &SpannedExpr,
393 ) -> CelType {
394 let left_type = self.check_expr(left);
395 let right_type = self.check_expr(right);
396
397 let func_name = binary_op_to_function(op);
398
399 self.resolve_function_call(func_name, None, &[left_type, right_type], expr)
400 }
401
402 fn check_ternary(
404 &mut self,
405 cond: &SpannedExpr,
406 then_expr: &SpannedExpr,
407 else_expr: &SpannedExpr,
408 expr: &SpannedExpr,
409 ) -> CelType {
410 let cond_type = self.check_expr(cond);
411 let then_type = self.check_expr(then_expr);
412 let else_type = self.check_expr(else_expr);
413
414 if !matches!(cond_type, CelType::Bool | CelType::Dyn | CelType::Error) {
416 self.report_error(CheckError::type_mismatch(
417 CelType::Bool,
418 cond_type,
419 cond.span.clone(),
420 cond.id,
421 ));
422 }
423
424 self.resolve_function_call("_?_:_", None, &[CelType::Bool, then_type, else_type], expr)
426 }
427
428 fn check_member(
430 &mut self,
431 obj: &SpannedExpr,
432 field: &str,
433 optional: bool,
434 expr: &SpannedExpr,
435 ) -> CelType {
436 if !self.leftmost_ident_resolves(obj) {
440 if let Some(qualified_name) = self.try_qualified_name(obj, field) {
441 if let Some(decl) = self.resolve_qualified(&qualified_name) {
443 let cel_type = decl.cel_type.clone();
444 self.set_reference(expr.id, ReferenceInfo::ident(&qualified_name));
445 return cel_type;
446 }
447
448 if let Some(resolved) = self.resolve_proto_qualified(&qualified_name, expr) {
450 return resolved;
451 }
452 }
453 }
454
455 let obj_type = self.check_expr(obj);
457
458 let (inner_type, was_optional) = match &obj_type {
460 CelType::Optional(inner) => ((**inner).clone(), true),
461 other => (other.clone(), false),
462 };
463
464 let result = match &inner_type {
466 CelType::Message(name) => {
467 if let Some(registry) = self.type_resolver {
469 if let Some(field_type) = registry.get_field_type(name, field) {
470 return self.wrap_optional_if_needed(field_type, optional, was_optional);
471 }
472 if registry.has_message(name) && !registry.is_extension(name, field) {
475 self.report_error(CheckError::undefined_field(
476 name,
477 field,
478 expr.span.clone(),
479 expr.id,
480 ));
481 return CelType::Error;
482 }
483 }
484 CelType::Dyn
486 }
487 CelType::Dyn | CelType::TypeVar(_) => {
488 CelType::Dyn
490 }
491 CelType::Map(_, value_type) => {
492 (**value_type).clone()
494 }
495 _ => {
496 self.report_error(CheckError::undefined_field(
498 &inner_type.display_name(),
499 field,
500 expr.span.clone(),
501 expr.id,
502 ));
503 return CelType::Error;
504 }
505 };
506
507 self.wrap_optional_if_needed(result, optional, was_optional)
508 }
509
510 fn wrap_optional_if_needed(
512 &self,
513 result: CelType,
514 optional: bool,
515 was_optional: bool,
516 ) -> CelType {
517 if optional || was_optional {
520 match &result {
521 CelType::Optional(_) => result, _ => CelType::optional(result),
523 }
524 } else {
525 result
526 }
527 }
528
529 fn resolve_proto_qualified(
531 &mut self,
532 qualified_name: &str,
533 expr: &SpannedExpr,
534 ) -> Option<CelType> {
535 let registry = self.type_resolver?;
536
537 let expanded_name = self.expand_abbreviation(qualified_name);
539 let name_to_resolve = expanded_name.as_deref().unwrap_or(qualified_name);
540 let parts: Vec<&str> = name_to_resolve.split('.').collect();
541
542 match registry.resolve_qualified(&parts, self.container)? {
543 ResolvedProtoType::EnumValue { enum_name, value } => {
544 self.set_reference(
545 expr.id,
546 ReferenceInfo {
547 name: name_to_resolve.to_string(),
548 overload_ids: vec![],
549 value: Some(CelValue::Int(value as i64)),
550 enum_type: Some(enum_name),
551 },
552 );
553 Some(CelType::Int)
554 }
555 ResolvedProtoType::Enum { name, cel_type } => {
556 self.set_reference(expr.id, ReferenceInfo::ident(&name));
557 Some(cel_type)
558 }
559 ResolvedProtoType::Message { name, cel_type } => {
560 self.set_reference(expr.id, ReferenceInfo::ident(&name));
561 Some(cel_type)
562 }
563 }
564 }
565
566 fn try_enum_constructor(
571 &mut self,
572 name: &str,
573 args: &[SpannedExpr],
574 expr: &SpannedExpr,
575 ) -> Option<CelType> {
576 let registry = self.type_resolver?;
577
578 let expanded_name = self.expand_abbreviation(name);
580 let name_to_resolve = expanded_name.as_deref().unwrap_or(name);
581 let parts: Vec<&str> = name_to_resolve.split('.').collect();
582
583 match registry.resolve_qualified(&parts, self.container)? {
585 ResolvedProtoType::Enum {
586 name: enum_name, ..
587 } => {
588 if args.len() != 1 {
590 return None;
591 }
592 let arg_type = self.check_expr(&args[0]);
593
594 match &arg_type {
595 CelType::Int | CelType::String | CelType::Dyn => {
596 self.set_reference(
598 expr.id,
599 ReferenceInfo {
600 name: enum_name.clone(),
601 overload_ids: vec!["enum_constructor".to_string()],
602 value: None,
603 enum_type: Some(enum_name),
604 },
605 );
606 Some(CelType::Int)
607 }
608 _ => None,
609 }
610 }
611 _ => None,
612 }
613 }
614
615 fn expand_abbreviation(&self, name: &str) -> Option<String> {
621 let abbrevs = self.abbreviations?;
622
623 let first_segment = name.split('.').next()?;
625
626 if let Some(qualified) = abbrevs.get(first_segment) {
628 if let Some(rest) = name
630 .strip_prefix(first_segment)
631 .and_then(|s| s.strip_prefix('.'))
632 {
633 Some(format!("{}.{}", qualified, rest))
634 } else {
635 Some(qualified.clone())
636 }
637 } else {
638 None
639 }
640 }
641
642 fn leftmost_ident_resolves(&self, expr: &SpannedExpr) -> bool {
647 match &expr.node {
648 Expr::Ident(name) => self.scopes.resolve(name).is_some(),
649 Expr::RootIdent(_) => false, Expr::Member { expr: inner, .. } => self.leftmost_ident_resolves(inner),
651 _ => true, }
653 }
654
655 fn try_qualified_name(&self, obj: &SpannedExpr, field: &str) -> Option<String> {
657 match &obj.node {
658 Expr::Ident(name) => Some(format!("{}.{}", name, field)),
659 Expr::RootIdent(name) => Some(format!(".{}.{}", name, field)),
660 Expr::Member {
661 expr: inner,
662 field: inner_field,
663 ..
664 } => {
665 let prefix = self.try_qualified_name(inner, inner_field)?;
666 Some(format!("{}.{}", prefix, field))
667 }
668 _ => None,
669 }
670 }
671
672 fn resolve_qualified(&self, name: &str) -> Option<&VariableDecl> {
679 if let Some(decl) = self.scopes.resolve(name) {
681 return Some(decl);
682 }
683
684 if !self.container.is_empty() {
686 let qualified = format!("{}.{}", self.container, name);
687 if let Some(decl) = self.scopes.resolve(&qualified) {
688 return Some(decl);
689 }
690 }
691
692 if let Some(abbrevs) = self.abbreviations {
694 if let Some(qualified) = abbrevs.get(name) {
695 if let Some(decl) = self.scopes.resolve(qualified) {
696 return Some(decl);
697 }
698 }
699 }
700
701 None
702 }
703
704 fn check_index(
706 &mut self,
707 obj: &SpannedExpr,
708 index: &SpannedExpr,
709 optional: bool,
710 expr: &SpannedExpr,
711 ) -> CelType {
712 let obj_type = self.check_expr(obj);
713 let index_type = self.check_expr(index);
714
715 let (inner_type, was_optional) = match &obj_type {
717 CelType::Optional(inner) => ((**inner).clone(), true),
718 other => (other.clone(), false),
719 };
720
721 let result = self.resolve_function_call("_[_]", None, &[inner_type, index_type], expr);
723
724 if optional || was_optional {
727 match &result {
728 CelType::Optional(_) => result, _ => CelType::optional(result),
730 }
731 } else {
732 result
733 }
734 }
735
    /// Type-checks a call expression.
    ///
    /// * `a.b.f(args)`: the dotted path is first tried as a namespaced
    ///   function name and then as an enum constructor; only if both fail is
    ///   the receiver checked and `f` resolved as a method on it.
    /// * `f(args)`: tried as an enum constructor, then resolved as a global
    ///   function.
    /// * any other callee: callee and arguments are checked and the call is
    ///   typed as `Dyn`.
    fn check_call(
        &mut self,
        callee: &SpannedExpr,
        args: &[SpannedExpr],
        expr: &SpannedExpr,
    ) -> CelType {
        match &callee.node {
            Expr::Member {
                expr: receiver,
                field: func_name,
                ..
            } => {
                // Speculative resolutions run before the receiver is checked,
                // so a namespace prefix is not reported as an undeclared
                // variable.
                if let Some(qualified_name) = self.try_qualified_function_name(receiver, func_name)
                {
                    if self.functions.contains_key(&qualified_name) {
                        let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
                        return self.resolve_function_call(&qualified_name, None, &arg_types, expr);
                    }

                    // May type-check the argument itself before giving up.
                    if let Some(result) = self.try_enum_constructor(&qualified_name, args, expr) {
                        return result;
                    }
                }

                // Ordinary method call: receiver first, then arguments.
                let receiver_type = self.check_expr(receiver);
                let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
                self.resolve_function_call(func_name, Some(receiver_type), &arg_types, expr)
            }
            Expr::Ident(func_name) => {
                if let Some(result) = self.try_enum_constructor(func_name, args, expr) {
                    return result;
                }

                let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
                self.resolve_function_call(func_name, None, &arg_types, expr)
            }
            _ => {
                // Still check the sub-expressions so their types and errors
                // are recorded.
                let _ = self.check_expr(callee);
                for arg in args {
                    self.check_expr(arg);
                }
                CelType::Dyn
            }
        }
    }
789
790 fn try_qualified_function_name(&self, obj: &SpannedExpr, field: &str) -> Option<String> {
794 match &obj.node {
795 Expr::Ident(name) => Some(format!("{}.{}", name, field)),
796 Expr::Member {
797 expr: inner,
798 field: inner_field,
799 ..
800 } => {
801 let prefix = self.try_qualified_function_name(inner, inner_field)?;
802 Some(format!("{}.{}", prefix, field))
803 }
804 _ => None,
805 }
806 }
807
808 fn resolve_function_call(
810 &mut self,
811 name: &str,
812 receiver: Option<CelType>,
813 args: &[CelType],
814 expr: &SpannedExpr,
815 ) -> CelType {
816 if let Some(func) = self.functions.get(name) {
817 let func = func.clone(); self.resolve_with_function(&func, receiver, args, expr)
819 } else {
820 self.report_error(CheckError::undeclared_reference(
821 name,
822 expr.span.clone(),
823 expr.id,
824 ));
825 CelType::Error
826 }
827 }
828
829 fn resolve_with_function(
831 &mut self,
832 func: &FunctionDecl,
833 receiver: Option<CelType>,
834 args: &[CelType],
835 expr: &SpannedExpr,
836 ) -> CelType {
837 match resolve_overload(func, receiver.as_ref(), args, &mut self.substitutions) {
838 Ok(result) => {
839 self.set_reference(
840 expr.id,
841 ReferenceInfo::function(&func.name, result.overload_ids),
842 );
843 result.result_type
844 }
845 Err(OverloadMismatch::ArityMismatch { expected }) => {
846 let got = args.len();
847 let msg = format!(
848 "'{}' expects {}, got {}",
849 func.name,
850 format_expected_args(&expected),
851 got
852 );
853 self.report_error(CheckError::other(msg, expr.span.clone(), expr.id));
854 CelType::Error
855 }
856 Err(_) => {
857 let all_args: Vec<_> = receiver
858 .iter()
859 .cloned()
860 .chain(args.iter().cloned())
861 .collect();
862 self.report_error(CheckError::no_matching_overload(
863 &func.name,
864 all_args,
865 expr.span.clone(),
866 expr.id,
867 ));
868 CelType::Error
869 }
870 }
871 }
872
873 fn check_struct(
875 &mut self,
876 type_name: &SpannedExpr,
877 fields: &[StructField],
878 expr: &SpannedExpr,
879 ) -> CelType {
880 let name = self.get_type_name(type_name);
882
883 for field in fields {
885 self.check_expr(&field.value);
886 }
887
888 if let Some(ref name) = name {
890 let expanded_name = self.expand_abbreviation(name);
892 let name_to_resolve = expanded_name.as_ref().unwrap_or(name);
893
894 let fq_name = if let Some(registry) = self.type_resolver {
896 registry
897 .resolve_message_name(name_to_resolve, self.container)
898 .unwrap_or_else(|| name_to_resolve.clone())
899 } else {
900 name_to_resolve.clone()
901 };
902
903 self.set_reference(expr.id, ReferenceInfo::ident(&fq_name));
906 self.set_reference(type_name.id, ReferenceInfo::ident(&fq_name));
907 CelType::message(&fq_name)
908 } else {
909 CelType::Dyn
910 }
911 }
912
913 fn get_type_name(&self, expr: &SpannedExpr) -> Option<String> {
915 match &expr.node {
916 Expr::Ident(name) => Some(name.clone()),
917 Expr::RootIdent(name) => Some(format!(".{}", name)),
918 Expr::Member {
919 expr: inner, field, ..
920 } => {
921 let prefix = self.get_type_name(inner)?;
922 Some(format!("{}.{}", prefix, field))
923 }
924 _ => None,
925 }
926 }
927
928 fn check_comprehension(&mut self, comp: &ComprehensionData, _expr: &SpannedExpr) -> CelType {
930 let range_type = self.check_expr(&comp.iter_range);
932
933 let (iter_type, iter_type2) = if !comp.iter_var2.is_empty() {
935 match &range_type {
937 CelType::List(elem) => (CelType::Int, (**elem).clone()),
938 CelType::Map(key, value) => ((**key).clone(), (**value).clone()),
939 CelType::Dyn => (CelType::Dyn, CelType::Dyn),
940 _ => (CelType::Dyn, CelType::Dyn),
941 }
942 } else {
943 let t = match &range_type {
945 CelType::List(elem) => (**elem).clone(),
946 CelType::Map(key, _) => (**key).clone(),
947 CelType::Optional(inner) => (**inner).clone(), CelType::Dyn => CelType::Dyn,
949 _ => CelType::Dyn,
950 };
951 (t, CelType::Dyn)
952 };
953
954 let accu_type = self.check_expr(&comp.accu_init);
956
957 self.scopes.enter_scope();
959
960 self.scopes.add_variable(&comp.iter_var, iter_type.clone());
962 if !comp.iter_var2.is_empty() {
963 self.scopes.add_variable(&comp.iter_var2, iter_type2);
964 }
965
966 self.scopes.add_variable(&comp.accu_var, accu_type.clone());
968
969 let cond_type = self.check_expr(&comp.loop_condition);
971 if !matches!(cond_type, CelType::Bool | CelType::Dyn | CelType::Error) {
972 self.report_error(CheckError::type_mismatch(
973 CelType::Bool,
974 cond_type,
975 comp.loop_condition.span.clone(),
976 comp.loop_condition.id,
977 ));
978 }
979
980 let step_type = self.check_expr(&comp.loop_step);
982
983 if contains_type_var_checker(&accu_type) && !contains_type_var_checker(&step_type) {
986 self.scopes.add_variable(&comp.accu_var, step_type);
987 }
988
989 let result_type = self.check_expr(&comp.result);
991
992 self.scopes.exit_scope();
994
995 result_type
996 }
997
998 fn check_member_test(&mut self, obj: &SpannedExpr, field: &str, expr: &SpannedExpr) -> CelType {
1000 let obj_type = self.check_expr(obj);
1002
1003 let inner_type = match &obj_type {
1005 CelType::Optional(inner) => (**inner).clone(),
1006 other => other.clone(),
1007 };
1008
1009 match &inner_type {
1011 CelType::Message(name) => {
1012 if let Some(registry) = self.type_resolver {
1013 if registry.get_field_type(name, field).is_some() {
1014 } else if registry.has_message(name) && !registry.is_extension(name, field) {
1016 self.report_error(CheckError::undefined_field(
1017 name,
1018 field,
1019 expr.span.clone(),
1020 expr.id,
1021 ));
1022 }
1023 }
1024 }
1025 CelType::Map(_, _) | CelType::Dyn | CelType::TypeVar(_) => {
1026 }
1028 CelType::Error => {
1029 }
1031 _ => {
1032 self.report_error(CheckError::undefined_field(
1033 &inner_type.display_name(),
1034 field,
1035 expr.span.clone(),
1036 expr.id,
1037 ));
1038 }
1039 }
1040
1041 CelType::Bool
1043 }
1044
1045 fn check_bind(
1049 &mut self,
1050 var_name: &str,
1051 init: &SpannedExpr,
1052 body: &SpannedExpr,
1053 _expr: &SpannedExpr,
1054 ) -> CelType {
1055 let init_type = self.check_expr(init);
1057
1058 self.scopes.enter_scope();
1060 self.scopes.add_variable(var_name, init_type);
1061
1062 let body_type = self.check_expr(body);
1064
1065 self.scopes.exit_scope();
1067
1068 body_type
1070 }
1071}
1072
1073pub fn check(
1084 expr: &SpannedExpr,
1085 variables: &HashMap<String, CelType>,
1086 functions: &HashMap<String, FunctionDecl>,
1087 container: &str,
1088) -> CheckResult {
1089 let checker = Checker::new(variables, functions, container);
1090 checker.check(expr)
1091}
1092
1093pub fn check_with_type_resolver(
1098 expr: &SpannedExpr,
1099 variables: &HashMap<String, CelType>,
1100 functions: &HashMap<String, FunctionDecl>,
1101 container: &str,
1102 type_resolver: &dyn ProtoTypeResolver,
1103) -> CheckResult {
1104 let checker = Checker::new(variables, functions, container).with_type_resolver(type_resolver);
1105 checker.check(expr)
1106}
1107
1108pub fn check_with_abbreviations(
1112 expr: &SpannedExpr,
1113 variables: &HashMap<String, CelType>,
1114 functions: &HashMap<String, FunctionDecl>,
1115 container: &str,
1116 abbreviations: &HashMap<String, String>,
1117) -> CheckResult {
1118 let checker = Checker::new(variables, functions, container).with_abbreviations(abbreviations);
1119 checker.check(expr)
1120}
1121
1122pub fn check_with_type_resolver_and_abbreviations(
1126 expr: &SpannedExpr,
1127 variables: &HashMap<String, CelType>,
1128 functions: &HashMap<String, FunctionDecl>,
1129 container: &str,
1130 type_resolver: &dyn ProtoTypeResolver,
1131 abbreviations: &HashMap<String, String>,
1132) -> CheckResult {
1133 let checker = Checker::new(variables, functions, container)
1134 .with_type_resolver(type_resolver)
1135 .with_abbreviations(abbreviations);
1136 checker.check(expr)
1137}
1138
1139fn contains_type_var_checker(ty: &CelType) -> bool {
1141 match ty {
1142 CelType::TypeVar(_) => true,
1143 CelType::List(elem) => contains_type_var_checker(elem),
1144 CelType::Map(key, val) => contains_type_var_checker(key) || contains_type_var_checker(val),
1145 CelType::Optional(inner) => contains_type_var_checker(inner),
1146 CelType::Wrapper(inner) => contains_type_var_checker(inner),
1147 CelType::Type(inner) => contains_type_var_checker(inner),
1148 _ => false,
1149 }
1150}
1151
1152fn type_specificity(ty: &CelType) -> u32 {
1158 match ty {
1159 CelType::Dyn | CelType::TypeVar(_) => 0,
1160 CelType::Null => 1,
1161 CelType::Bool
1162 | CelType::Int
1163 | CelType::UInt
1164 | CelType::Double
1165 | CelType::String
1166 | CelType::Bytes
1167 | CelType::Timestamp
1168 | CelType::Duration => 2,
1169 CelType::Message(_) | CelType::Enum(_) => 2,
1170 CelType::List(elem) => 2 + type_specificity(elem),
1171 CelType::Map(key, val) => 2 + type_specificity(key) + type_specificity(val),
1172 CelType::Optional(inner) => 2 + type_specificity(inner),
1173 CelType::Wrapper(inner) => 3 + type_specificity(inner), CelType::Type(inner) => 2 + type_specificity(inner),
1175 CelType::Abstract { params, .. } => 2 + params.iter().map(type_specificity).sum::<u32>(),
1176 _ => 1,
1177 }
1178}
1179
1180fn binary_op_to_function(op: BinaryOp) -> &'static str {
1182 match op {
1183 BinaryOp::Add => "_+_",
1184 BinaryOp::Sub => "_-_",
1185 BinaryOp::Mul => "_*_",
1186 BinaryOp::Div => "_/_",
1187 BinaryOp::Mod => "_%_",
1188 BinaryOp::Eq => "_==_",
1189 BinaryOp::Ne => "_!=_",
1190 BinaryOp::Lt => "_<_",
1191 BinaryOp::Le => "_<=_",
1192 BinaryOp::Gt => "_>_",
1193 BinaryOp::Ge => "_>=_",
1194 BinaryOp::And => "_&&_",
1195 BinaryOp::Or => "_||_",
1196 BinaryOp::In => "@in",
1197 }
1198}
1199
1200#[cfg(test)]
1201mod tests {
1202 use std::any::Any;
1203
1204 use super::super::errors::CheckErrorKind;
1205 use super::super::standard_library::STANDARD_LIBRARY;
1206 use super::*;
1207 use crate::eval::proto_registry::ProtoTypeResolver;
1208 use crate::parser::parse;
1209 use crate::types::ResolvedProtoType;
1210
1211 fn standard_functions() -> HashMap<String, FunctionDecl> {
1213 STANDARD_LIBRARY
1214 .iter()
1215 .map(|f| (f.name.clone(), f.clone()))
1216 .collect()
1217 }
1218
1219 fn standard_variables() -> HashMap<String, CelType> {
1221 let mut vars = HashMap::new();
1222 vars.insert("bool".to_string(), CelType::type_of(CelType::Bool));
1223 vars.insert("int".to_string(), CelType::type_of(CelType::Int));
1224 vars.insert("uint".to_string(), CelType::type_of(CelType::UInt));
1225 vars.insert("double".to_string(), CelType::type_of(CelType::Double));
1226 vars.insert("string".to_string(), CelType::type_of(CelType::String));
1227 vars.insert("bytes".to_string(), CelType::type_of(CelType::Bytes));
1228 vars.insert(
1229 "list".to_string(),
1230 CelType::type_of(CelType::list(CelType::Dyn)),
1231 );
1232 vars.insert(
1233 "map".to_string(),
1234 CelType::type_of(CelType::map(CelType::Dyn, CelType::Dyn)),
1235 );
1236 vars.insert("null_type".to_string(), CelType::type_of(CelType::Null));
1237 vars.insert(
1238 "type".to_string(),
1239 CelType::type_of(CelType::type_of(CelType::Dyn)),
1240 );
1241 vars.insert("dyn".to_string(), CelType::type_of(CelType::Dyn));
1242 vars
1243 }
1244
1245 fn check_expr(source: &str) -> CheckResult {
1246 let result = parse(source);
1247 let ast = result.ast.expect("parse should succeed");
1248 let variables = standard_variables();
1249 let functions = standard_functions();
1250 check(&ast, &variables, &functions, "")
1251 }
1252
1253 fn check_expr_with_var(source: &str, var: &str, cel_type: CelType) -> CheckResult {
1254 let result = parse(source);
1255 let ast = result.ast.expect("parse should succeed");
1256 let mut variables = standard_variables();
1257 variables.insert(var.to_string(), cel_type);
1258 let functions = standard_functions();
1259 check(&ast, &variables, &functions, "")
1260 }
1261
    /// In-memory `ProtoTypeResolver` stand-in for tests.
    #[derive(Debug)]
    struct MockProtoResolver {
        /// Field types keyed by `(message name, field name)`.
        fields: HashMap<(String, String), CelType>,
        /// Names of the messages this resolver claims to know.
        messages: Vec<String>,
    }
1270
1271 impl MockProtoResolver {
1272 fn new() -> Self {
1273 Self {
1274 fields: HashMap::new(),
1275 messages: Vec::new(),
1276 }
1277 }
1278
1279 fn with_message(mut self, name: &str) -> Self {
1280 self.messages.push(name.to_string());
1281 self
1282 }
1283
1284 fn with_field(mut self, message: &str, field: &str, cel_type: CelType) -> Self {
1285 self.fields
1286 .insert((message.to_string(), field.to_string()), cel_type);
1287 self
1288 }
1289 }
1290
1291 impl ProtoTypeResolver for MockProtoResolver {
1292 fn get_field_type(&self, message: &str, field: &str) -> Option<CelType> {
1293 self.fields
1294 .get(&(message.to_string(), field.to_string()))
1295 .cloned()
1296 }
1297
1298 fn has_message(&self, message: &str) -> bool {
1299 self.messages.contains(&message.to_string())
1300 }
1301
1302 fn is_extension(&self, _message: &str, _ext_name: &str) -> bool {
1303 false
1304 }
1305
1306 fn get_enum_value(&self, _enum_name: &str, _value_name: &str) -> Option<i32> {
1307 None
1308 }
1309
1310 fn resolve_qualified(
1311 &self,
1312 _parts: &[&str],
1313 _container: &str,
1314 ) -> Option<ResolvedProtoType> {
1315 None
1316 }
1317
1318 fn resolve_message_name(&self, _name: &str, _container: &str) -> Option<String> {
1319 None
1320 }
1321
1322 fn as_any(&self) -> &dyn Any {
1323 self
1324 }
1325 }
1326
1327 fn check_expr_with_var_and_resolver(
1328 source: &str,
1329 var: &str,
1330 cel_type: CelType,
1331 resolver: &dyn ProtoTypeResolver,
1332 ) -> CheckResult {
1333 let result = parse(source);
1334 let ast = result.ast.expect("parse should succeed");
1335 let mut variables = standard_variables();
1336 variables.insert(var.to_string(), cel_type);
1337 let functions = standard_functions();
1338 check_with_type_resolver(&ast, &variables, &functions, "", resolver)
1339 }
1340
1341 #[test]
1342 fn test_literal_types() {
1343 assert_eq!(check_expr("null").get_type(1), Some(&CelType::Null));
1344 assert_eq!(check_expr("true").get_type(1), Some(&CelType::Bool));
1345 assert_eq!(check_expr("42").get_type(1), Some(&CelType::Int));
1346 assert_eq!(check_expr("42u").get_type(1), Some(&CelType::UInt));
1347 assert_eq!(check_expr("3.14").get_type(1), Some(&CelType::Double));
1348 assert_eq!(check_expr("\"hello\"").get_type(1), Some(&CelType::String));
1349 assert_eq!(check_expr("b\"hello\"").get_type(1), Some(&CelType::Bytes));
1350 }
1351
1352 #[test]
1353 fn test_undefined_variable() {
1354 let result = check_expr("x");
1355 assert!(!result.is_ok());
1356 assert!(result.errors.iter().any(|e| matches!(
1357 &e.kind,
1358 CheckErrorKind::UndeclaredReference { name, .. } if name == "x"
1359 )));
1360 }
1361
1362 #[test]
1363 fn test_defined_variable() {
1364 let result = check_expr_with_var("x", "x", CelType::Int);
1365 assert!(result.is_ok());
1366 assert_eq!(result.get_type(1), Some(&CelType::Int));
1367 }
1368
1369 #[test]
1370 fn test_binary_add_int() {
1371 let result = check_expr_with_var("x + 1", "x", CelType::Int);
1372 assert!(result.is_ok());
1373 let types: Vec<_> = result.type_map.values().collect();
1375 assert!(types.contains(&&CelType::Int));
1376 }
1377
1378 #[test]
1379 fn test_list_literal() {
1380 let result = check_expr("[1, 2, 3]");
1381 assert!(result.is_ok());
1382 let list_types: Vec<_> = result
1384 .type_map
1385 .values()
1386 .filter(|t| matches!(t, CelType::List(_)))
1387 .collect();
1388 assert_eq!(list_types.len(), 1);
1389 assert_eq!(list_types[0], &CelType::list(CelType::Int));
1390 }
1391
1392 #[test]
1393 fn test_map_literal() {
1394 let result = check_expr("{\"a\": 1, \"b\": 2}");
1395 assert!(result.is_ok());
1396 let map_types: Vec<_> = result
1398 .type_map
1399 .values()
1400 .filter(|t| matches!(t, CelType::Map(_, _)))
1401 .collect();
1402 assert_eq!(map_types.len(), 1);
1403 assert_eq!(map_types[0], &CelType::map(CelType::String, CelType::Int));
1404 }
1405
1406 #[test]
1407 fn test_comparison() {
1408 let result = check_expr_with_var("x > 0", "x", CelType::Int);
1409 assert!(result.is_ok());
1410 let bool_types: Vec<_> = result
1412 .type_map
1413 .values()
1414 .filter(|t| matches!(t, CelType::Bool))
1415 .collect();
1416 assert!(!bool_types.is_empty());
1417 }
1418
1419 #[test]
1420 fn test_method_call() {
1421 let result = check_expr("\"hello\".contains(\"lo\")");
1422 assert!(result.is_ok());
1423 let bool_types: Vec<_> = result
1425 .type_map
1426 .values()
1427 .filter(|t| matches!(t, CelType::Bool))
1428 .collect();
1429 assert!(!bool_types.is_empty());
1430 }
1431
1432 #[test]
1433 fn test_size_method() {
1434 let result = check_expr("\"hello\".size()");
1435 assert!(result.is_ok());
1436 let int_types: Vec<_> = result
1437 .type_map
1438 .values()
1439 .filter(|t| matches!(t, CelType::Int))
1440 .collect();
1441 assert!(!int_types.is_empty());
1442 }
1443
1444 #[test]
1445 fn test_ternary() {
1446 let result = check_expr_with_var("x ? 1 : 2", "x", CelType::Bool);
1447 assert!(result.is_ok());
1448 }
1449
1450 #[test]
1451 fn test_type_mismatch_addition() {
1452 let result = check_expr_with_var("x + \"str\"", "x", CelType::Int);
1453 assert!(!result.is_ok());
1454 assert!(result.errors.iter().any(|e| matches!(
1455 &e.kind,
1456 CheckErrorKind::NoMatchingOverload { function, .. } if function == "_+_"
1457 )));
1458 }
1459
1460 #[test]
1461 fn test_empty_list() {
1462 let result = check_expr("[]");
1463 assert!(result.is_ok());
1464 let list_types: Vec<_> = result
1466 .type_map
1467 .values()
1468 .filter(|t| matches!(t, CelType::List(_)))
1469 .collect();
1470 assert_eq!(list_types.len(), 1);
1471 }
1472
1473 #[test]
1474 fn test_reference_recording() {
1475 let result = check_expr_with_var("x + 1", "x", CelType::Int);
1476 assert!(result.is_ok());
1477
1478 let refs: Vec<_> = result.reference_map.values().collect();
1480 assert!(refs.iter().any(|r| r.name == "x"));
1481
1482 assert!(refs.iter().any(|r| r.name == "_+_"));
1484 }
1485
1486 #[test]
1487 fn test_has_undefined_field() {
1488 let resolver =
1489 MockProtoResolver::new()
1490 .with_message("Msg")
1491 .with_field("Msg", "name", CelType::String);
1492 let result = check_expr_with_var_and_resolver(
1493 "has(x.nonexistent)",
1494 "x",
1495 CelType::Message("Msg".into()),
1496 &resolver,
1497 );
1498 assert!(!result.is_ok());
1499 assert!(result.errors.iter().any(|e| matches!(
1500 &e.kind,
1501 CheckErrorKind::UndefinedField { field, .. } if field == "nonexistent"
1502 )));
1503 }
1504
1505 #[test]
1506 fn test_has_valid_field() {
1507 let resolver =
1508 MockProtoResolver::new()
1509 .with_message("Msg")
1510 .with_field("Msg", "name", CelType::String);
1511 let result = check_expr_with_var_and_resolver(
1512 "has(x.name)",
1513 "x",
1514 CelType::Message("Msg".into()),
1515 &resolver,
1516 );
1517 assert!(result.is_ok());
1518 let bool_types: Vec<_> = result
1520 .type_map
1521 .values()
1522 .filter(|t| matches!(t, CelType::Bool))
1523 .collect();
1524 assert!(!bool_types.is_empty());
1525 }
1526
1527 #[test]
1528 fn test_has_map_field() {
1529 let result = check_expr_with_var(
1530 "has(m.anything)",
1531 "m",
1532 CelType::map(CelType::String, CelType::Int),
1533 );
1534 assert!(result.is_ok());
1535 }
1536
1537 #[test]
1538 fn test_has_dyn_field() {
1539 let result = check_expr_with_var("has(d.anything)", "d", CelType::Dyn);
1540 assert!(result.is_ok());
1541 }
1542
1543 #[test]
1544 fn test_arity_error_size_standalone_no_args() {
1545 let result = check_expr("size()");
1547 assert!(!result.is_ok());
1548 assert!(result.errors.iter().any(|e| matches!(
1549 &e.kind,
1550 CheckErrorKind::Other(msg) if msg.contains("'size' expects exactly 1 argument, got 0")
1551 )));
1552 }
1553
1554 #[test]
1555 fn test_arity_error_method_extra_args() {
1556 let result = check_expr("\"hello\".size(1)");
1558 assert!(!result.is_ok());
1559 assert!(result.errors.iter().any(|e| matches!(
1560 &e.kind,
1561 CheckErrorKind::Other(msg) if msg.contains("'size' expects exactly 0 arguments, got 1")
1562 )));
1563 }
1564
1565 #[test]
1566 fn test_arity_error_contains_no_args() {
1567 let result = check_expr("\"hello\".contains()");
1569 assert!(!result.is_ok());
1570 assert!(result.errors.iter().any(|e| matches!(
1571 &e.kind,
1572 CheckErrorKind::Other(msg) if msg.contains("'contains' expects exactly 1 argument, got 0")
1573 )));
1574 }
1575
1576 #[test]
1577 fn test_type_mismatch_still_reported_for_correct_arity() {
1578 let result = check_expr("size(123)");
1580 assert!(!result.is_ok());
1581 assert!(result.errors.iter().any(|e| matches!(
1582 &e.kind,
1583 CheckErrorKind::NoMatchingOverload { function, .. } if function == "size"
1584 )));
1585 }
1586
1587 #[test]
1588 fn test_operator_type_mismatch_unchanged() {
1589 let result = check_expr_with_var("x + \"a\"", "x", CelType::Int);
1591 assert!(!result.is_ok());
1592 assert!(result.errors.iter().any(|e| matches!(
1593 &e.kind,
1594 CheckErrorKind::NoMatchingOverload { function, .. } if function == "_+_"
1595 )));
1596 }
1597}