Skip to main content

cel_core/checker/
checker.rs

1//! Core type checker implementation.
2//!
3//! This module provides the main `Checker` struct and the `check` function
4//! for type-checking CEL expressions.
5//!
6//! The checker is independent and takes raw data (variables, functions, container)
7//! rather than a type environment struct. This allows it to be used as a building
8//! block for higher-level APIs.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use crate::types::{BinaryOp, Expr, FunctionDecl, ListElement, MapEntry, SpannedExpr, StructField, UnaryOp, VariableDecl};
14use crate::types::{CelType, CelValue};
15use crate::types::{ProtoTypeRegistry, ResolvedProtoType};
16use super::errors::CheckError;
17use super::overload::{finalize_type, resolve_overload, substitute_type};
18use super::scope::ScopeStack;
19
20/// Reference information for a resolved identifier or function.
21#[derive(Debug, Clone)]
22pub struct ReferenceInfo {
23    /// The fully qualified name.
24    pub name: String,
25    /// Matching overload IDs for function calls.
26    pub overload_ids: Vec<String>,
27    /// Constant value for enum constants.
28    pub value: Option<CelValue>,
29}
30
31impl ReferenceInfo {
32    /// Create a new identifier reference.
33    pub fn ident(name: impl Into<String>) -> Self {
34        Self {
35            name: name.into(),
36            overload_ids: Vec::new(),
37            value: None,
38        }
39    }
40
41    /// Create a new function reference with overload IDs.
42    pub fn function(name: impl Into<String>, overload_ids: Vec<String>) -> Self {
43        Self {
44            name: name.into(),
45            overload_ids,
46            value: None,
47        }
48    }
49}
50
51/// Result of type checking an expression.
52#[derive(Debug, Clone)]
53pub struct CheckResult {
54    /// Map from expression ID to inferred type.
55    pub type_map: HashMap<i64, CelType>,
56    /// Map from expression ID to resolved reference.
57    pub reference_map: HashMap<i64, ReferenceInfo>,
58    /// Errors encountered during type checking.
59    pub errors: Vec<CheckError>,
60}
61
62impl CheckResult {
63    /// Check if type checking was successful (no errors).
64    pub fn is_ok(&self) -> bool {
65        self.errors.is_empty()
66    }
67
68    /// Get the type for an expression ID.
69    pub fn get_type(&self, expr_id: i64) -> Option<&CelType> {
70        self.type_map.get(&expr_id)
71    }
72
73    /// Get the reference for an expression ID.
74    pub fn get_reference(&self, expr_id: i64) -> Option<&ReferenceInfo> {
75        self.reference_map.get(&expr_id)
76    }
77}
78
/// Type checker for CEL expressions.
///
/// The checker takes raw data (variables, functions, container) rather than
/// a type environment struct, making it independent and reusable.
///
/// A checker is consumed by `check`, which hands its accumulated maps and
/// errors back to the caller as a `CheckResult`.
pub struct Checker<'a> {
    /// Scope stack for variable resolution (managed internally).
    /// Seeded with the caller's variables in `new`; nested scopes are pushed
    /// and popped while checking comprehensions and binds.
    scopes: ScopeStack,
    /// Function declarations indexed by name (borrowed from the caller).
    functions: &'a HashMap<String, FunctionDecl>,
    /// Container namespace for qualified name resolution; an empty string
    /// disables the container-prefixed lookup.
    container: &'a str,
    /// Map from expression ID to inferred type.
    type_map: HashMap<i64, CelType>,
    /// Map from expression ID to resolved reference.
    reference_map: HashMap<i64, ReferenceInfo>,
    /// Type checking errors, in the order they were reported.
    errors: Vec<CheckError>,
    /// Type parameter substitutions, updated during overload resolution and
    /// applied to every recorded type when checking finishes.
    substitutions: HashMap<Arc<str>, CelType>,
    /// Proto type registry for resolving protobuf types (`None` until
    /// `with_proto_types` is called).
    proto_types: Option<&'a ProtoTypeRegistry>,
}
101
102impl<'a> Checker<'a> {
103    /// Create a new type checker with the given data.
104    ///
105    /// # Arguments
106    /// * `variables` - Variable declarations (name -> type)
107    /// * `functions` - Function declarations indexed by name
108    /// * `container` - Container namespace for qualified name resolution
109    pub fn new(
110        variables: &HashMap<String, CelType>,
111        functions: &'a HashMap<String, FunctionDecl>,
112        container: &'a str,
113    ) -> Self {
114        let mut scopes = ScopeStack::new();
115
116        // Add variables to the initial scope
117        for (name, cel_type) in variables {
118            scopes.add_variable(name, cel_type.clone());
119        }
120
121        Self {
122            scopes,
123            functions,
124            container,
125            type_map: HashMap::new(),
126            reference_map: HashMap::new(),
127            errors: Vec::new(),
128            substitutions: HashMap::new(),
129            proto_types: None,
130        }
131    }
132
133    /// Set the proto type registry for resolving protobuf types.
134    pub fn with_proto_types(mut self, registry: &'a ProtoTypeRegistry) -> Self {
135        self.proto_types = Some(registry);
136        self
137    }
138
139    /// Type check an expression and return the result.
140    pub fn check(mut self, expr: &SpannedExpr) -> CheckResult {
141        self.check_expr(expr);
142        self.finalize_types();
143
144        CheckResult {
145            type_map: self.type_map,
146            reference_map: self.reference_map,
147            errors: self.errors,
148        }
149    }
150
151    /// Store the type for an expression.
152    fn set_type(&mut self, expr_id: i64, cel_type: CelType) {
153        self.type_map.insert(expr_id, cel_type);
154    }
155
156    /// Store a reference for an expression.
157    fn set_reference(&mut self, expr_id: i64, reference: ReferenceInfo) {
158        self.reference_map.insert(expr_id, reference);
159    }
160
161    /// Report an error.
162    fn report_error(&mut self, error: CheckError) {
163        self.errors.push(error);
164    }
165
166    /// Finalize all types by replacing unbound type variables with Dyn.
167    fn finalize_types(&mut self) {
168        for ty in self.type_map.values_mut() {
169            *ty = finalize_type(ty);
170            *ty = substitute_type(ty, &self.substitutions);
171            *ty = finalize_type(ty);
172        }
173    }
174
175    /// Type check an expression and return its type.
176    fn check_expr(&mut self, expr: &SpannedExpr) -> CelType {
177        let result = match &expr.node {
178            Expr::Null => CelType::Null,
179            Expr::Bool(_) => CelType::Bool,
180            Expr::Int(_) => CelType::Int,
181            Expr::UInt(_) => CelType::UInt,
182            Expr::Float(_) => CelType::Double,
183            Expr::String(_) => CelType::String,
184            Expr::Bytes(_) => CelType::Bytes,
185
186            Expr::Ident(name) => self.check_ident(name, expr),
187            Expr::RootIdent(name) => self.check_ident(name, expr),
188
189            Expr::List(elements) => self.check_list(elements, expr),
190            Expr::Map(entries) => self.check_map(entries, expr),
191
192            Expr::Unary { op, expr: inner } => self.check_unary(*op, inner, expr),
193            Expr::Binary { op, left, right } => self.check_binary(*op, left, right, expr),
194            Expr::Ternary { cond, then_expr, else_expr } => {
195                self.check_ternary(cond, then_expr, else_expr, expr)
196            }
197
198            Expr::Member { expr: obj, field, optional } => {
199                self.check_member(obj, field, *optional, expr)
200            }
201            Expr::Index { expr: obj, index, optional } => {
202                self.check_index(obj, index, *optional, expr)
203            }
204            Expr::Call { expr: callee, args } => self.check_call(callee, args, expr),
205            Expr::Struct { type_name, fields } => self.check_struct(type_name, fields, expr),
206
207            Expr::Comprehension {
208                iter_var,
209                iter_var2,
210                iter_range,
211                accu_var,
212                accu_init,
213                loop_condition,
214                loop_step,
215                result,
216            } => self.check_comprehension(
217                iter_var,
218                iter_var2,
219                iter_range,
220                accu_var,
221                accu_init,
222                loop_condition,
223                loop_step,
224                result,
225                expr,
226            ),
227
228            Expr::MemberTestOnly { expr: obj, field } => {
229                self.check_member_test(obj, field, expr)
230            }
231
232            Expr::Bind { var_name, init, body } => {
233                self.check_bind(var_name, init, body, expr)
234            }
235
236            Expr::Error => CelType::Error,
237        };
238
239        self.set_type(expr.id, result.clone());
240        result
241    }
242
243    /// Check an identifier expression.
244    fn check_ident(&mut self, name: &str, expr: &SpannedExpr) -> CelType {
245        if let Some(decl) = self.scopes.resolve(name) {
246            let cel_type = decl.cel_type.clone();
247            self.set_reference(expr.id, ReferenceInfo::ident(name));
248            cel_type
249        } else {
250            self.report_error(CheckError::undeclared_reference(name, expr.span.clone(), expr.id));
251            CelType::Error
252        }
253    }
254
255    /// Check a list literal.
256    fn check_list(&mut self, elements: &[ListElement], _expr: &SpannedExpr) -> CelType {
257        if elements.is_empty() {
258            return CelType::list(CelType::fresh_type_var());
259        }
260
261        let mut elem_types = Vec::new();
262        for elem in elements {
263            let elem_type = self.check_expr(&elem.expr);
264            elem_types.push(elem_type);
265        }
266
267        let joined = self.join_types(&elem_types);
268        CelType::list(joined)
269    }
270
271    /// Check a map literal.
272    fn check_map(&mut self, entries: &[MapEntry], _expr: &SpannedExpr) -> CelType {
273        if entries.is_empty() {
274            return CelType::map(CelType::fresh_type_var(), CelType::fresh_type_var());
275        }
276
277        let mut key_types = Vec::new();
278        let mut value_types = Vec::new();
279
280        for entry in entries {
281            let key_type = self.check_expr(&entry.key);
282            let value_type = self.check_expr(&entry.value);
283            key_types.push(key_type);
284            value_types.push(value_type);
285        }
286
287        let key_joined = self.join_types(&key_types);
288        let value_joined = self.join_types(&value_types);
289
290        CelType::map(key_joined, value_joined)
291    }
292
293    /// Join multiple types into a common type.
294    fn join_types(&self, types: &[CelType]) -> CelType {
295        if types.is_empty() {
296            return CelType::fresh_type_var();
297        }
298
299        let first = &types[0];
300
301        // Check if all types are the same
302        if types.iter().all(|t| t == first || matches!(t, CelType::Dyn) || matches!(first, CelType::Dyn)) {
303            if matches!(first, CelType::Dyn) {
304                // If first is Dyn, try to find a concrete type
305                for t in types {
306                    if !matches!(t, CelType::Dyn) {
307                        return t.clone();
308                    }
309                }
310            }
311            return first.clone();
312        }
313
314        // Check if all types are assignable to each other
315        let all_compatible = types.iter().all(|t| {
316            first.is_assignable_from(t) || t.is_assignable_from(first)
317        });
318
319        if all_compatible {
320            first.clone()
321        } else {
322            CelType::Dyn
323        }
324    }
325
326    /// Check a unary operation.
327    fn check_unary(&mut self, op: UnaryOp, inner: &SpannedExpr, expr: &SpannedExpr) -> CelType {
328        let inner_type = self.check_expr(inner);
329
330        let func_name = match op {
331            UnaryOp::Neg => "-_",
332            UnaryOp::Not => "!_",
333        };
334
335        self.resolve_function_call(func_name, None, &[inner_type], expr)
336    }
337
338    /// Check a binary operation.
339    fn check_binary(
340        &mut self,
341        op: BinaryOp,
342        left: &SpannedExpr,
343        right: &SpannedExpr,
344        expr: &SpannedExpr,
345    ) -> CelType {
346        let left_type = self.check_expr(left);
347        let right_type = self.check_expr(right);
348
349        let func_name = binary_op_to_function(op);
350
351        self.resolve_function_call(func_name, None, &[left_type, right_type], expr)
352    }
353
354    /// Check a ternary expression.
355    fn check_ternary(
356        &mut self,
357        cond: &SpannedExpr,
358        then_expr: &SpannedExpr,
359        else_expr: &SpannedExpr,
360        expr: &SpannedExpr,
361    ) -> CelType {
362        let cond_type = self.check_expr(cond);
363        let then_type = self.check_expr(then_expr);
364        let else_type = self.check_expr(else_expr);
365
366        // Condition must be bool
367        if !matches!(cond_type, CelType::Bool | CelType::Dyn | CelType::Error) {
368            self.report_error(CheckError::type_mismatch(
369                CelType::Bool,
370                cond_type,
371                cond.span.clone(),
372                cond.id,
373            ));
374        }
375
376        // Use ternary operator for type resolution
377        self.resolve_function_call("_?_:_", None, &[CelType::Bool, then_type, else_type], expr)
378    }
379
380    /// Check a member access expression.
381    fn check_member(&mut self, obj: &SpannedExpr, field: &str, optional: bool, expr: &SpannedExpr) -> CelType {
382        // First, try to resolve as qualified identifier (e.g., pkg.Type)
383        if let Some(qualified_name) = self.try_qualified_name(obj, field) {
384            // Try variable/type resolution first
385            if let Some(decl) = self.resolve_qualified(&qualified_name) {
386                let cel_type = decl.cel_type.clone();
387                self.set_reference(expr.id, ReferenceInfo::ident(&qualified_name));
388                return cel_type;
389            }
390
391            // Try proto type resolution (enum values, message types)
392            if let Some(resolved) = self.resolve_proto_qualified(&qualified_name, expr) {
393                return resolved;
394            }
395        }
396
397        // Otherwise, it's a field access
398        let obj_type = self.check_expr(obj);
399
400        // Unwrap optional types for field access
401        let (inner_type, was_optional) = match &obj_type {
402            CelType::Optional(inner) => ((**inner).clone(), true),
403            other => (other.clone(), false),
404        };
405
406        // Check for well-known types with fields
407        let result = match &inner_type {
408            CelType::Message(name) => {
409                // Try to get field type from proto registry
410                if let Some(registry) = self.proto_types {
411                    if let Some(field_type) = registry.get_field_type(name, field) {
412                        return self.wrap_optional_if_needed(field_type, optional, was_optional);
413                    }
414                }
415                // Fall back to Dyn if no registry or field not found
416                CelType::Dyn
417            }
418            CelType::Dyn | CelType::TypeVar(_) => {
419                // For Dyn and type variables, we can't verify field existence statically
420                CelType::Dyn
421            }
422            CelType::Map(_, value_type) => {
423                // Map field access returns the value type
424                (**value_type).clone()
425            }
426            _ => {
427                // Other types don't support field access
428                self.report_error(CheckError::undefined_field(
429                    &inner_type.display_name(),
430                    field,
431                    expr.span.clone(),
432                    expr.id,
433                ));
434                return CelType::Error;
435            }
436        };
437
438        self.wrap_optional_if_needed(result, optional, was_optional)
439    }
440
441    /// Wrap a type in optional if needed.
442    fn wrap_optional_if_needed(&self, result: CelType, optional: bool, was_optional: bool) -> CelType {
443        // Wrap in optional if using optional select (.?) or receiver was optional
444        // But flatten nested optionals - CEL semantics say chaining doesn't create optional<optional<T>>
445        if optional || was_optional {
446            match &result {
447                CelType::Optional(_) => result, // Already optional, don't double-wrap
448                _ => CelType::optional(result),
449            }
450        } else {
451            result
452        }
453    }
454
455    /// Try to resolve a qualified name as a proto type.
456    fn resolve_proto_qualified(&mut self, qualified_name: &str, expr: &SpannedExpr) -> Option<CelType> {
457        let registry = self.proto_types?;
458        let parts: Vec<&str> = qualified_name.split('.').collect();
459
460        match registry.resolve_qualified(&parts, self.container)? {
461            ResolvedProtoType::EnumValue { enum_name: _, value } => {
462                self.set_reference(expr.id, ReferenceInfo {
463                    name: qualified_name.to_string(),
464                    overload_ids: vec![],
465                    value: Some(CelValue::Int(value as i64)),
466                });
467                Some(CelType::Int)
468            }
469            ResolvedProtoType::Enum { name, cel_type } => {
470                self.set_reference(expr.id, ReferenceInfo::ident(&name));
471                Some(cel_type)
472            }
473            ResolvedProtoType::Message { name, cel_type } => {
474                self.set_reference(expr.id, ReferenceInfo::ident(&name));
475                Some(cel_type)
476            }
477        }
478    }
479
480    /// Try to build a qualified name from a member chain.
481    fn try_qualified_name(&self, obj: &SpannedExpr, field: &str) -> Option<String> {
482        match &obj.node {
483            Expr::Ident(name) => Some(format!("{}.{}", name, field)),
484            Expr::RootIdent(name) => Some(format!(".{}.{}", name, field)),
485            Expr::Member { expr: inner, field: inner_field, .. } => {
486                let prefix = self.try_qualified_name(inner, inner_field)?;
487                Some(format!("{}.{}", prefix, field))
488            }
489            _ => None,
490        }
491    }
492
493    /// Try to resolve a qualified name (e.g., `pkg.Type`).
494    ///
495    /// This checks for the name in the following order:
496    /// 1. As-is
497    /// 2. Prepended with container
498    fn resolve_qualified(&self, name: &str) -> Option<&VariableDecl> {
499        // Try as-is first
500        if let Some(decl) = self.scopes.resolve(name) {
501            return Some(decl);
502        }
503
504        // Try with container prefix
505        if !self.container.is_empty() {
506            let qualified = format!("{}.{}", self.container, name);
507            if let Some(decl) = self.scopes.resolve(&qualified) {
508                return Some(decl);
509            }
510        }
511
512        None
513    }
514
515    /// Check an index access expression.
516    fn check_index(&mut self, obj: &SpannedExpr, index: &SpannedExpr, optional: bool, expr: &SpannedExpr) -> CelType {
517        let obj_type = self.check_expr(obj);
518        let index_type = self.check_expr(index);
519
520        // Unwrap optional types for index access
521        let (inner_type, was_optional) = match &obj_type {
522            CelType::Optional(inner) => ((**inner).clone(), true),
523            other => (other.clone(), false),
524        };
525
526        // Resolve index operation on inner type
527        let result = self.resolve_function_call("_[_]", None, &[inner_type, index_type], expr);
528
529        // Wrap in optional if using optional index ([?]) or receiver was optional
530        // But flatten nested optionals - CEL semantics say chaining doesn't create optional<optional<T>>
531        if optional || was_optional {
532            match &result {
533                CelType::Optional(_) => result, // Already optional, don't double-wrap
534                _ => CelType::optional(result),
535            }
536        } else {
537            result
538        }
539    }
540
541    /// Check a function call expression.
542    fn check_call(&mut self, callee: &SpannedExpr, args: &[SpannedExpr], expr: &SpannedExpr) -> CelType {
543        // Determine if this is a method call or standalone call
544        match &callee.node {
545            Expr::Member { expr: receiver, field: func_name, .. } => {
546                // First, try to resolve as a namespaced function (e.g., math.greatest)
547                if let Some(qualified_name) = self.try_qualified_function_name(receiver, func_name) {
548                    if self.functions.contains_key(&qualified_name) {
549                        let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
550                        return self.resolve_function_call(&qualified_name, None, &arg_types, expr);
551                    }
552                }
553
554                // Fall back to method call: receiver.method(args)
555                let receiver_type = self.check_expr(receiver);
556                let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
557                self.resolve_function_call(func_name, Some(receiver_type), &arg_types, expr)
558            }
559            Expr::Ident(func_name) => {
560                // Standalone call: func(args)
561                let arg_types: Vec<_> = args.iter().map(|a| self.check_expr(a)).collect();
562                self.resolve_function_call(func_name, None, &arg_types, expr)
563            }
564            _ => {
565                // Expression call (unusual)
566                let _ = self.check_expr(callee);
567                for arg in args {
568                    self.check_expr(arg);
569                }
570                CelType::Dyn
571            }
572        }
573    }
574
575    /// Try to build a qualified function name from a member chain.
576    ///
577    /// This supports namespaced functions like `math.greatest` or `strings.quote`.
578    fn try_qualified_function_name(&self, obj: &SpannedExpr, field: &str) -> Option<String> {
579        match &obj.node {
580            Expr::Ident(name) => Some(format!("{}.{}", name, field)),
581            Expr::Member { expr: inner, field: inner_field, .. } => {
582                let prefix = self.try_qualified_function_name(inner, inner_field)?;
583                Some(format!("{}.{}", prefix, field))
584            }
585            _ => None,
586        }
587    }
588
589    /// Resolve a function call and return the result type.
590    fn resolve_function_call(
591        &mut self,
592        name: &str,
593        receiver: Option<CelType>,
594        args: &[CelType],
595        expr: &SpannedExpr,
596    ) -> CelType {
597        if let Some(func) = self.functions.get(name) {
598            let func = func.clone(); // Clone to avoid borrow conflict
599            self.resolve_with_function(&func, receiver, args, expr)
600        } else {
601            self.report_error(CheckError::undeclared_reference(
602                name,
603                expr.span.clone(),
604                expr.id,
605            ));
606            CelType::Error
607        }
608    }
609
610    /// Resolve a function call with a known function declaration.
611    fn resolve_with_function(
612        &mut self,
613        func: &FunctionDecl,
614        receiver: Option<CelType>,
615        args: &[CelType],
616        expr: &SpannedExpr,
617    ) -> CelType {
618        if let Some(result) = resolve_overload(
619            func,
620            receiver.as_ref(),
621            args,
622            &mut self.substitutions,
623        ) {
624            self.set_reference(expr.id, ReferenceInfo::function(&func.name, result.overload_ids));
625            result.result_type
626        } else {
627            let all_args: Vec<_> = receiver.iter().cloned().chain(args.iter().cloned()).collect();
628            self.report_error(CheckError::no_matching_overload(
629                &func.name,
630                all_args,
631                expr.span.clone(),
632                expr.id,
633            ));
634            CelType::Error
635        }
636    }
637
638    /// Check a struct literal expression.
639    fn check_struct(
640        &mut self,
641        type_name: &SpannedExpr,
642        fields: &[StructField],
643        expr: &SpannedExpr,
644    ) -> CelType {
645        // Get the type name
646        let name = self.get_type_name(type_name);
647
648        // Check field values
649        for field in fields {
650            self.check_expr(&field.value);
651        }
652
653        // Return a message type
654        if let Some(ref name) = name {
655            // Try to resolve to fully qualified name using proto registry
656            let fq_name = if let Some(registry) = self.proto_types {
657                registry.resolve_message_name(name, self.container)
658                    .unwrap_or_else(|| name.clone())
659            } else {
660                name.clone()
661            };
662
663            self.set_reference(expr.id, ReferenceInfo::ident(&fq_name));
664            CelType::message(&fq_name)
665        } else {
666            CelType::Dyn
667        }
668    }
669
670    /// Get the type name from a type expression.
671    fn get_type_name(&self, expr: &SpannedExpr) -> Option<String> {
672        match &expr.node {
673            Expr::Ident(name) => Some(name.clone()),
674            Expr::RootIdent(name) => Some(format!(".{}", name)),
675            Expr::Member { expr: inner, field, .. } => {
676                let prefix = self.get_type_name(inner)?;
677                Some(format!("{}.{}", prefix, field))
678            }
679            _ => None,
680        }
681    }
682
683    /// Check a comprehension expression.
684    fn check_comprehension(
685        &mut self,
686        iter_var: &str,
687        iter_var2: &str,
688        iter_range: &SpannedExpr,
689        accu_var: &str,
690        accu_init: &SpannedExpr,
691        loop_condition: &SpannedExpr,
692        loop_step: &SpannedExpr,
693        result: &SpannedExpr,
694        _expr: &SpannedExpr,
695    ) -> CelType {
696        // Check iter_range in outer scope
697        let range_type = self.check_expr(iter_range);
698
699        // Determine iteration variable type from range
700        let iter_type = match &range_type {
701            CelType::List(elem) => (**elem).clone(),
702            CelType::Map(key, _) => (**key).clone(),
703            CelType::Optional(inner) => (**inner).clone(), // For optMap/optFlatMap macros
704            CelType::Dyn => CelType::Dyn,
705            _ => CelType::Dyn,
706        };
707
708        // Check accu_init in outer scope
709        let accu_type = self.check_expr(accu_init);
710
711        // Enter new scope for comprehension body
712        self.scopes.enter_scope();
713
714        // Bind iteration variable(s)
715        self.scopes.add_variable(iter_var, iter_type.clone());
716        if !iter_var2.is_empty() {
717            // For two-variable iteration (maps), bind second var
718            let iter_type2 = match &range_type {
719                CelType::Map(_, value) => (**value).clone(),
720                _ => CelType::Dyn,
721            };
722            self.scopes.add_variable(iter_var2, iter_type2);
723        }
724
725        // Bind accumulator variable
726        self.scopes.add_variable(accu_var, accu_type.clone());
727
728        // Check loop_condition (must be bool)
729        let cond_type = self.check_expr(loop_condition);
730        if !matches!(cond_type, CelType::Bool | CelType::Dyn | CelType::Error) {
731            self.report_error(CheckError::type_mismatch(
732                CelType::Bool,
733                cond_type,
734                loop_condition.span.clone(),
735                loop_condition.id,
736            ));
737        }
738
739        // Check loop_step (should match accu type)
740        let _ = self.check_expr(loop_step);
741
742        // Check result
743        let result_type = self.check_expr(result);
744
745        // Exit comprehension scope
746        self.scopes.exit_scope();
747
748        result_type
749    }
750
751    /// Check a member test (has() macro result).
752    fn check_member_test(&mut self, obj: &SpannedExpr, _field: &str, _expr: &SpannedExpr) -> CelType {
753        // Check the object
754        let _ = self.check_expr(obj);
755
756        // has() always returns bool
757        CelType::Bool
758    }
759
760    /// Check a bind expression (cel.bind macro result).
761    ///
762    /// `cel.bind(var, init, body)` binds a variable to a value for use in the body.
763    fn check_bind(
764        &mut self,
765        var_name: &str,
766        init: &SpannedExpr,
767        body: &SpannedExpr,
768        _expr: &SpannedExpr,
769    ) -> CelType {
770        // Check the initializer expression
771        let init_type = self.check_expr(init);
772
773        // Enter a new scope with the bound variable
774        self.scopes.enter_scope();
775        self.scopes.add_variable(var_name, init_type);
776
777        // Check the body in the new scope
778        let body_type = self.check_expr(body);
779
780        // Exit the scope
781        self.scopes.exit_scope();
782
783        // The type of the bind expression is the type of the body
784        body_type
785    }
786}
787
788/// Check an expression and return the result.
789///
790/// This is the main entry point for type checking. It takes raw data rather
791/// than a type environment struct, making it independent and reusable.
792///
793/// # Arguments
794/// * `expr` - The expression to type check
795/// * `variables` - Variable declarations (name -> type)
796/// * `functions` - Function declarations indexed by name
797/// * `container` - Container namespace for qualified name resolution
798pub fn check(
799    expr: &SpannedExpr,
800    variables: &HashMap<String, CelType>,
801    functions: &HashMap<String, FunctionDecl>,
802    container: &str,
803) -> CheckResult {
804    let checker = Checker::new(variables, functions, container);
805    checker.check(expr)
806}
807
808/// Check an expression with proto type registry.
809///
810/// This is like `check`, but also takes a proto type registry for resolving
811/// protobuf types during type checking.
812pub fn check_with_proto_types(
813    expr: &SpannedExpr,
814    variables: &HashMap<String, CelType>,
815    functions: &HashMap<String, FunctionDecl>,
816    container: &str,
817    proto_types: &ProtoTypeRegistry,
818) -> CheckResult {
819    let checker = Checker::new(variables, functions, container)
820        .with_proto_types(proto_types);
821    checker.check(expr)
822}
823
824/// Convert a binary operator to its function name.
825fn binary_op_to_function(op: BinaryOp) -> &'static str {
826    match op {
827        BinaryOp::Add => "_+_",
828        BinaryOp::Sub => "_-_",
829        BinaryOp::Mul => "_*_",
830        BinaryOp::Div => "_/_",
831        BinaryOp::Mod => "_%_",
832        BinaryOp::Eq => "_==_",
833        BinaryOp::Ne => "_!=_",
834        BinaryOp::Lt => "_<_",
835        BinaryOp::Le => "_<=_",
836        BinaryOp::Gt => "_>_",
837        BinaryOp::Ge => "_>=_",
838        BinaryOp::And => "_&&_",
839        BinaryOp::Or => "_||_",
840        BinaryOp::In => "@in",
841    }
842}
843
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::errors::CheckErrorKind;
    use super::super::standard_library::STANDARD_LIBRARY;
    use crate::parser::parse;

    /// Standard library declarations, keyed by function name.
    fn standard_functions() -> HashMap<String, FunctionDecl> {
        let mut by_name = HashMap::new();
        for decl in STANDARD_LIBRARY.iter() {
            by_name.insert(decl.name.clone(), decl.clone());
        }
        by_name
    }

    /// Variables for the built-in type identifiers (`int`, `string`, ...).
    ///
    /// Each identifier is declared as `type_of(<the type it denotes>)`.
    fn standard_variables() -> HashMap<String, CelType> {
        let denoted = vec![
            ("bool", CelType::Bool),
            ("int", CelType::Int),
            ("uint", CelType::UInt),
            ("double", CelType::Double),
            ("string", CelType::String),
            ("bytes", CelType::Bytes),
            ("list", CelType::list(CelType::Dyn)),
            ("map", CelType::map(CelType::Dyn, CelType::Dyn)),
            ("null_type", CelType::Null),
            ("type", CelType::type_of(CelType::Dyn)),
            ("dyn", CelType::Dyn),
        ];

        denoted
            .into_iter()
            .map(|(name, ty)| (name.to_string(), CelType::type_of(ty)))
            .collect()
    }

    /// Parse `source` and check it against the standard environment with an
    /// empty container. Panics if parsing yields no AST.
    fn check_expr(source: &str) -> CheckResult {
        let ast = parse(source).ast.expect("parse should succeed");
        check(&ast, &standard_variables(), &standard_functions(), "")
    }

    /// Same as `check_expr`, with one extra variable `var` of type `cel_type`
    /// declared on top of the standard environment.
    fn check_expr_with_var(source: &str, var: &str, cel_type: CelType) -> CheckResult {
        let ast = parse(source).ast.expect("parse should succeed");

        let mut variables = standard_variables();
        variables.insert(var.to_string(), cel_type);

        check(&ast, &variables, &standard_functions(), "")
    }

    #[test]
    fn test_literal_types() {
        // Each source is a single literal, so its expression ID is 1.
        let cases = vec![
            ("null", CelType::Null),
            ("true", CelType::Bool),
            ("42", CelType::Int),
            ("42u", CelType::UInt),
            ("3.14", CelType::Double),
            ("\"hello\"", CelType::String),
            ("b\"hello\"", CelType::Bytes),
        ];

        for (source, expected) in cases {
            assert_eq!(check_expr(source).get_type(1), Some(&expected));
        }
    }

    #[test]
    fn test_undefined_variable() {
        let result = check_expr("x");
        assert!(!result.is_ok());

        let reported = result.errors.iter().any(|err| match &err.kind {
            CheckErrorKind::UndeclaredReference { name, .. } => name == "x",
            _ => false,
        });
        assert!(reported);
    }

    #[test]
    fn test_defined_variable() {
        let result = check_expr_with_var("x", "x", CelType::Int);
        assert!(result.is_ok());
        assert_eq!(result.get_type(1), Some(&CelType::Int));
    }

    #[test]
    fn test_binary_add_int() {
        let result = check_expr_with_var("x + 1", "x", CelType::Int);
        assert!(result.is_ok());

        // Only requires that some node in the expression was typed `Int`;
        // it does not pin down which expression ID carries it.
        let has_int = result.type_map.values().any(|t| *t == CelType::Int);
        assert!(has_int);
    }

    #[test]
    fn test_list_literal() {
        let result = check_expr("[1, 2, 3]");
        assert!(result.is_ok());

        // Exactly one node is list-typed, and it is `list(int)`.
        let mut lists = result
            .type_map
            .values()
            .filter(|t| matches!(t, CelType::List(_)));
        assert_eq!(lists.next(), Some(&CelType::list(CelType::Int)));
        assert!(lists.next().is_none());
    }

    #[test]
    fn test_map_literal() {
        let result = check_expr("{\"a\": 1, \"b\": 2}");
        assert!(result.is_ok());

        // Exactly one node is map-typed, and it is `map(string, int)`.
        let mut maps = result
            .type_map
            .values()
            .filter(|t| matches!(t, CelType::Map(_, _)));
        assert_eq!(maps.next(), Some(&CelType::map(CelType::String, CelType::Int)));
        assert!(maps.next().is_none());
    }

    #[test]
    fn test_comparison() {
        let result = check_expr_with_var("x > 0", "x", CelType::Int);
        assert!(result.is_ok());

        // At least one node must have been typed `Bool`.
        let has_bool = result.type_map.values().any(|t| matches!(t, CelType::Bool));
        assert!(has_bool);
    }

    #[test]
    fn test_method_call() {
        let result = check_expr("\"hello\".contains(\"lo\")");
        assert!(result.is_ok());

        // At least one node must have been typed `Bool`.
        let has_bool = result.type_map.values().any(|t| matches!(t, CelType::Bool));
        assert!(has_bool);
    }

    #[test]
    fn test_size_method() {
        let result = check_expr("\"hello\".size()");
        assert!(result.is_ok());

        // At least one node must have been typed `Int`.
        let has_int = result.type_map.values().any(|t| matches!(t, CelType::Int));
        assert!(has_int);
    }

    #[test]
    fn test_ternary() {
        let result = check_expr_with_var("x ? 1 : 2", "x", CelType::Bool);
        assert!(result.is_ok());
    }

    #[test]
    fn test_type_mismatch_addition() {
        let result = check_expr_with_var("x + \"str\"", "x", CelType::Int);
        assert!(!result.is_ok());

        let reported = result.errors.iter().any(|err| match &err.kind {
            CheckErrorKind::NoMatchingOverload { function, .. } => function == "_+_",
            _ => false,
        });
        assert!(reported);
    }

    #[test]
    fn test_empty_list() {
        let result = check_expr("[]");
        assert!(result.is_ok());

        // Exactly one node is list-typed; the element type is not asserted.
        let list_count = result
            .type_map
            .values()
            .filter(|t| matches!(t, CelType::List(_)))
            .count();
        assert_eq!(list_count, 1);
    }

    #[test]
    fn test_reference_recording() {
        let result = check_expr_with_var("x + 1", "x", CelType::Int);
        assert!(result.is_ok());

        let recorded = |wanted: &str| result.reference_map.values().any(|r| r.name == wanted);

        // The identifier is recorded under its own name...
        assert!(recorded("x"));
        // ...and the operator under its function name.
        assert!(recorded("_+_"));
    }
}