Skip to main content

openjd_expr/
function_library.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5//! Function library: signature-based multiple dispatch for expression evaluation.
6
7use crate::error::{ExpressionError, ExpressionErrorKind};
8
/// Translate an operator's dunder method name into the symbol users actually typed.
///
/// Used only to phrase error messages (e.g. `__add__` → `+`). Returns `None` for any
/// name that is not one of the recognised operator methods.
fn friendly_op_name(name: &str) -> Option<&'static str> {
    const OPERATOR_SYMBOLS: &[(&str, &str)] = &[
        ("__add__", "+"),
        ("__sub__", "-"),
        ("__mul__", "*"),
        ("__truediv__", "/"),
        ("__floordiv__", "//"),
        ("__mod__", "%"),
        ("__pow__", "**"),
        ("__neg__", "-"),
        ("__pos__", "+"),
        ("__eq__", "=="),
        ("__ne__", "!="),
        ("__lt__", "<"),
        ("__le__", "<="),
        ("__gt__", ">"),
        ("__ge__", ">="),
        ("__contains__", "in"),
        ("__not_contains__", "not in"),
    ];
    OPERATOR_SYMBOLS
        .iter()
        .find(|(dunder, _)| *dunder == name)
        .map(|(_, symbol)| *symbol)
}
31
32use crate::types::ExprType;
33use crate::value::ExprValue;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37/// Type alias for a boxed function implementation.
38///
39/// Uses `Arc<dyn Fn>` rather than a bare `fn` pointer so that closures
40/// (capturing environment) can be registered alongside plain functions.
41/// `Arc` (not `Box`) keeps `FunctionLibrary` `Clone`.
42pub type FunctionImpl = Arc<
43    dyn Fn(&mut dyn EvalContext, &[ExprValue]) -> Result<ExprValue, ExpressionError> + Send + Sync,
44>;
45
46/// A registered function overload.
47#[derive(Clone)]
48pub struct FunctionEntry {
49    pub signature: ExprType, // TypeCode::Signature
50    pub implementation: FunctionImpl,
51}
52
53/// Trait for the evaluator context that function implementations can access.
54pub trait EvalContext {
55    fn path_format(&self) -> crate::path_mapping::PathFormat;
56    fn count_op(&mut self) -> Result<(), ExpressionError>;
57    fn count_ops(&mut self, n: usize) -> Result<(), ExpressionError>;
58    fn count_string_ops(&mut self, len: usize) -> Result<(), ExpressionError>;
59    /// Pre-check that an allocation of `bytes` would not exceed the memory limit.
60    /// Call before large allocations to avoid temporarily exceeding the limit.
61    fn check_memory(&self, bytes: usize) -> Result<(), ExpressionError>;
62    fn get_or_compile_regex(&mut self, pattern: &str) -> Result<regex::Regex, ExpressionError> {
63        regex::RegexBuilder::new(pattern)
64            .size_limit(1 << 20)
65            .build()
66            .map_err(|e| ExpressionError::new(format!("Invalid regex: {e}")))
67    }
68}
69
70/// Registry of functions available in expressions.
71#[derive(Clone, Default)]
72pub struct FunctionLibrary {
73    functions: HashMap<String, Vec<FunctionEntry>>,
74    pub host_context_enabled: bool,
75}
76
77impl FunctionLibrary {
78    pub fn new() -> Self {
79        Self {
80            functions: HashMap::new(),
81            host_context_enabled: false,
82        }
83    }
84
85    /// Register a function overload with an `ExprType::Signature`.
86    ///
87    /// Accepts either a bare `fn` pointer or a closure — anything implementing
88    /// `Fn(&mut dyn EvalContext, &[ExprValue]) -> Result<ExprValue, ExpressionError>`
89    /// with `Send + Sync + 'static`. Closures enable wrapping host state
90    /// (e.g., AWS clients, config) without plumbing it through `EvalContext`.
91    pub fn register<F>(&mut self, name: &str, signature: ExprType, implementation: F)
92    where
93        F: Fn(&mut dyn EvalContext, &[ExprValue]) -> Result<ExprValue, ExpressionError>
94            + Send
95            + Sync
96            + 'static,
97    {
98        self.functions
99            .entry(name.to_string())
100            .or_default()
101            .push(FunctionEntry {
102                signature,
103                implementation: Arc::new(implementation),
104            });
105    }
106
107    /// Register from spec notation: `lib.register_sig("len", "(list[T1]) -> int", len_list)`
108    ///
109    /// Returns `Err` if `sig_str` cannot be parsed as a valid type
110    /// signature. The function is **not** registered on failure.
111    pub fn register_sig<F>(
112        &mut self,
113        name: &str,
114        sig_str: &str,
115        implementation: F,
116    ) -> Result<(), String>
117    where
118        F: Fn(&mut dyn EvalContext, &[ExprValue]) -> Result<ExprValue, ExpressionError>
119            + Send
120            + Sync
121            + 'static,
122    {
123        let signature = ExprType::parse(sig_str)?;
124        self.register(name, signature, implementation);
125        Ok(())
126    }
127
128    /// Get all entries for a function name.
129    pub fn get_signatures(&self, name: &str) -> &[FunctionEntry] {
130        self.functions
131            .get(name)
132            .map(|v| v.as_slice())
133            .unwrap_or(&[])
134    }
135
136    /// Merge another library's registrations into this one.
137    #[must_use]
138    pub fn merge(mut self, other: FunctionLibrary) -> Self {
139        for (name, entries) in other.functions {
140            self.functions.entry(name).or_default().extend(entries);
141        }
142        self
143    }
144
145    /// Get all registered function names.
146    pub fn function_names(&self) -> impl Iterator<Item = &str> {
147        self.functions.keys().map(|s| s.as_str())
148    }
149
150    /// Dispatch a function call: exact → coerced → generic.
151    pub fn call(
152        &self,
153        name: &str,
154        args: &[ExprValue],
155        ctx: &mut dyn EvalContext,
156    ) -> Result<ExprValue, ExpressionError> {
157        self.call_inner(name, args, ctx, false)
158    }
159
160    /// Dispatch a method call (skip receiver coercion on first arg).
161    pub fn call_method(
162        &self,
163        name: &str,
164        args: &[ExprValue],
165        ctx: &mut dyn EvalContext,
166    ) -> Result<ExprValue, ExpressionError> {
167        self.call_inner(name, args, ctx, true)
168    }
169
170    fn call_inner(
171        &self,
172        name: &str,
173        args: &[ExprValue],
174        ctx: &mut dyn EvalContext,
175        skip_receiver_coercion: bool,
176    ) -> Result<ExprValue, ExpressionError> {
177        let entries = self.get_signatures(name);
178        if entries.is_empty() {
179            return Err(ExpressionError::from_kind(
180                ExpressionErrorKind::UnknownFunction {
181                    name: name.to_string(),
182                },
183            ));
184        }
185        let arg_types: Vec<ExprType> = args
186            .iter()
187            .map(|a| {
188                let t = a.expr_type();
189                if t.code() == crate::types::TypeCode::Unresolved && !t.params().is_empty() {
190                    t.params()[0].clone() // unwrap unresolved for matching
191                } else {
192                    t
193                }
194            })
195            .collect();
196        let any_unresolved = args.iter().any(|a| a.is_unresolved());
197
198        // Phase 1: exact match (non-generic)
199        for entry in entries {
200            if entry.signature.is_symbolic() {
201                continue;
202            }
203            if entry.signature.match_call(&arg_types).is_some() {
204                if any_unresolved {
205                    let ret = entry.signature.sig_return().clone();
206                    return Ok(ExprValue::unresolved(ret));
207                }
208                return (entry.implementation)(ctx, args);
209            }
210        }
211
212        // Phase 2: coerced match (non-generic)
213        for entry in entries {
214            if entry.signature.is_symbolic() {
215                continue;
216            }
217            if any_unresolved {
218                if try_coerce_types(
219                    &arg_types,
220                    entry.signature.sig_params(),
221                    skip_receiver_coercion,
222                )
223                .is_some()
224                {
225                    let ret = entry.signature.sig_return().clone();
226                    return Ok(ExprValue::unresolved(ret));
227                }
228            } else if let Some(coerced) = try_coerce_args(
229                args,
230                &arg_types,
231                entry.signature.sig_params(),
232                skip_receiver_coercion,
233            ) {
234                return (entry.implementation)(ctx, &coerced);
235            }
236        }
237
238        // Phase 3: generic match
239        for entry in entries {
240            if !entry.signature.is_symbolic() {
241                continue;
242            }
243            if let Some(bindings) = entry.signature.match_call(&arg_types) {
244                if any_unresolved {
245                    let ret = entry.signature.sig_return().substitute(&bindings);
246                    return Ok(ExprValue::unresolved(ret));
247                }
248                return (entry.implementation)(ctx, args);
249            }
250            // Try with coercion for generics too
251            if let Some(coerced_types) = try_coerce_types(
252                &arg_types,
253                entry.signature.sig_params(),
254                skip_receiver_coercion,
255            ) {
256                if let Some(bindings) = entry.signature.match_call(&coerced_types) {
257                    let coerced = coerce_values(args, &coerced_types);
258                    if any_unresolved {
259                        let ret = entry.signature.sig_return().substitute(&bindings);
260                        return Ok(ExprValue::unresolved(ret));
261                    }
262                    return (entry.implementation)(ctx, &coerced);
263                }
264            }
265        }
266
267        // Build helpful error message
268        let n_args = arg_types.len();
269        let valid_arities: Vec<usize> = entries
270            .iter()
271            .map(|e| e.signature.sig_params().len())
272            .collect::<std::collections::BTreeSet<_>>()
273            .into_iter()
274            .collect();
275
276        // Check if it's a wrong arg count issue
277        if !valid_arities.contains(&n_args) {
278            let arity_str = if valid_arities.len() == 1 {
279                format!("{}", valid_arities[0])
280            } else {
281                valid_arities
282                    .iter()
283                    .map(|a| a.to_string())
284                    .collect::<Vec<_>>()
285                    .join(", ")
286            };
287            let plural = if valid_arities.len() > 1 {
288                "arguments"
289            } else {
290                "argument(s)"
291            };
292            return Err(ExpressionError::new(format!(
293                "{name}() takes {arity_str} {plural}, but {n_args} were given"
294            )));
295        }
296
297        // Check if it's a receiver type mismatch (method on wrong type)
298        if skip_receiver_coercion && !arg_types.is_empty() {
299            let receiver_type = &arg_types[0];
300            let valid_types: Vec<String> = entries
301                .iter()
302                .filter(|e| e.signature.sig_params().len() == n_args)
303                .filter_map(|e| {
304                    let first = e.signature.sig_params().first()?;
305                    if first.is_symbolic() {
306                        None
307                    } else {
308                        Some(first.to_string())
309                    }
310                })
311                .collect::<std::collections::BTreeSet<_>>()
312                .into_iter()
313                .collect();
314            if !valid_types.is_empty() {
315                return Err(ExpressionError::new(format!(
316                    "{name}() is not available for {receiver_type}. Available for: {}",
317                    valid_types.join(", ")
318                )));
319            }
320        }
321
322        let type_strs: Vec<String> = arg_types.iter().map(|t| t.to_string()).collect();
323        let display_name = friendly_op_name(name);
324        if let Some(op) = display_name {
325            Err(ExpressionError::new(format!(
326                "Cannot use '{}' operator with {}",
327                op,
328                type_strs.join(" and ")
329            )))
330        } else {
331            Err(ExpressionError::new(format!(
332                "No matching signature for {}({})",
333                name,
334                type_strs.join(", ")
335            )))
336        }
337    }
338
339    /// Derive the return type without evaluation (for static type checking).
340    /// Accepts union types in arg_types and returns a union of all possible return types.
341    pub fn derive_return_type(&self, name: &str, arg_types: &[ExprType]) -> Option<ExprType> {
342        // Fast path: if no union args, try exact match first (matches Python's singleton optimization)
343        let has_unions = arg_types
344            .iter()
345            .any(|t| t.code() == crate::types::TypeCode::Union);
346        if !has_unions {
347            for entry in self.get_signatures(name) {
348                if let Some(bindings) = entry.signature.match_call(arg_types) {
349                    return Some(entry.signature.sig_return().substitute(&bindings));
350                }
351            }
352            // Try coercion
353            for entry in self.get_signatures(name) {
354                if let Some(coerced) =
355                    try_coerce_types(arg_types, entry.signature.sig_params(), false)
356                {
357                    if let Some(bindings) = entry.signature.match_call(&coerced) {
358                        return Some(entry.signature.sig_return().substitute(&bindings));
359                    }
360                }
361            }
362            return None;
363        }
364
365        // Union path: flatten each arg's types into sets, then recurse per-signature.
366        // This prunes early: if arg 0 doesn't match a signature's param 0, we skip
367        // all combinations involving that arg type without expanding further args.
368        let arg_type_sets: Vec<Vec<ExprType>> = arg_types
369            .iter()
370            .map(|t| {
371                if t.code() == crate::types::TypeCode::Union {
372                    let mut members = t.params().to_vec();
373                    // Apply implicit coercions: int→float, path→string
374                    let mut coerced = Vec::new();
375                    for m in &members {
376                        if m.code() == crate::types::TypeCode::Int {
377                            coerced.push(ExprType::FLOAT);
378                        } else if m.code() == crate::types::TypeCode::Path {
379                            coerced.push(ExprType::STRING);
380                        }
381                    }
382                    for c in coerced {
383                        if !members.contains(&c) {
384                            members.push(c);
385                        }
386                    }
387                    members
388                } else {
389                    let mut members = vec![t.clone()];
390                    if t.code() == crate::types::TypeCode::Int {
391                        members.push(ExprType::FLOAT);
392                    } else if t.code() == crate::types::TypeCode::Path {
393                        members.push(ExprType::STRING);
394                    }
395                    members
396                }
397            })
398            .collect();
399
400        let mut result_types = Vec::new();
401        for entry in self.get_signatures(name) {
402            let sig_params = entry.signature.sig_params();
403            if sig_params.len() != arg_type_sets.len() {
404                continue;
405            }
406            Self::match_signature_recursive(
407                &entry.signature,
408                sig_params,
409                &arg_type_sets,
410                0,
411                HashMap::new(),
412                &mut result_types,
413            );
414        }
415
416        if result_types.is_empty() {
417            return None;
418        }
419        result_types.sort_by_key(|a| a.to_string());
420        result_types.dedup();
421        if result_types.len() == 1 {
422            Some(result_types.into_iter().next().unwrap())
423        } else {
424            Some(ExprType::union(result_types))
425        }
426    }
427
428    /// Recursively match a signature against argument type sets, pruning
429    /// non-matching branches early instead of expanding the full Cartesian product.
430    fn match_signature_recursive(
431        sig: &ExprType,
432        sig_params: &[ExprType],
433        arg_type_sets: &[Vec<ExprType>],
434        idx: usize,
435        bindings: HashMap<crate::types::TypeCode, ExprType>,
436        result_types: &mut Vec<ExprType>,
437    ) {
438        if idx == arg_type_sets.len() {
439            result_types.push(sig.sig_return().substitute(&bindings));
440            return;
441        }
442        let param = &sig_params[idx];
443        for arg_type in &arg_type_sets[idx] {
444            if let Some(new_binds) = param.match_type(arg_type) {
445                let mut merged = bindings.clone();
446                let mut conflict = false;
447                for (k, v) in new_binds {
448                    if let Some(existing) = merged.get(&k) {
449                        if *existing != v {
450                            conflict = true;
451                            break;
452                        }
453                    }
454                    merged.insert(k, v);
455                }
456                if !conflict {
457                    Self::match_signature_recursive(
458                        sig,
459                        sig_params,
460                        arg_type_sets,
461                        idx + 1,
462                        merged,
463                        result_types,
464                    );
465                }
466            }
467        }
468    }
469
470    /// Get the type of a property access.
471    pub fn get_property_type(&self, base_type: &ExprType, property_name: &str) -> Option<ExprType> {
472        self.derive_return_type(
473            &format!("__property_{property_name}__"),
474            std::slice::from_ref(base_type),
475        )
476    }
477}
478
479/// Implicit coercion rules: int→float, path→string.
480fn can_coerce(from: &ExprType, to: &ExprType) -> bool {
481    (from.code() == crate::types::TypeCode::Int && to.code() == crate::types::TypeCode::Float)
482        || (from.code() == crate::types::TypeCode::Path
483            && to.code() == crate::types::TypeCode::String)
484}
485
486/// Try to coerce argument types to match parameter types.
487fn try_coerce_types(
488    arg_types: &[ExprType],
489    param_types: &[ExprType],
490    skip_first: bool,
491) -> Option<Vec<ExprType>> {
492    if arg_types.len() != param_types.len() {
493        return None;
494    }
495    let mut coerced = Vec::with_capacity(arg_types.len());
496    for (i, (at, pt)) in arg_types.iter().zip(param_types.iter()).enumerate() {
497        if at == pt || pt.code() == crate::types::TypeCode::Any {
498            coerced.push(at.clone());
499        } else if can_coerce(at, pt) && !(skip_first && i == 0) {
500            coerced.push(pt.clone());
501        } else if pt.is_symbolic() {
502            coerced.push(at.clone());
503        } else {
504            return None;
505        }
506    }
507    Some(coerced)
508}
509
510/// Try to coerce argument values to match parameter types.
511fn try_coerce_args(
512    args: &[ExprValue],
513    arg_types: &[ExprType],
514    param_types: &[ExprType],
515    skip_first: bool,
516) -> Option<Vec<ExprValue>> {
517    if args.len() != param_types.len() {
518        return None;
519    }
520    let mut coerced = Vec::with_capacity(args.len());
521    let mut any_changed = false;
522    for (i, (at, pt)) in arg_types.iter().zip(param_types.iter()).enumerate() {
523        if at == pt {
524            coerced.push(args[i].clone());
525        } else if can_coerce(at, pt) && !(skip_first && i == 0) {
526            any_changed = true;
527            match (&args[i], pt.code()) {
528                (ExprValue::Int(v), crate::types::TypeCode::Float) => {
529                    coerced.push(ExprValue::Float(
530                        crate::value::Float64::new(*v as f64).unwrap(),
531                    ));
532                }
533                (ExprValue::Path { value, .. }, crate::types::TypeCode::String) => {
534                    coerced.push(ExprValue::String(value.clone()));
535                }
536                _ => return None,
537            }
538        } else {
539            return None;
540        }
541    }
542    if any_changed {
543        Some(coerced)
544    } else {
545        None
546    }
547}
548
549/// Coerce values to match target types (for generic dispatch after type matching).
550fn coerce_values(args: &[ExprValue], target_types: &[ExprType]) -> Vec<ExprValue> {
551    args.iter()
552        .zip(target_types.iter())
553        .map(|(a, t)| {
554            if can_coerce(&a.expr_type(), t) {
555                match (a, t.code()) {
556                    (ExprValue::Int(v), crate::types::TypeCode::Float) => {
557                        ExprValue::Float(crate::value::Float64::new(*v as f64).unwrap())
558                    }
559                    (ExprValue::Path { value, .. }, crate::types::TypeCode::String) => {
560                        ExprValue::String(value.clone())
561                    }
562                    _ => a.clone(),
563                }
564            } else {
565                a.clone()
566            }
567        })
568        .collect()
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ExprType;

    // -- Helpers --

    /// Implementation for tests where only registration or typing matters:
    /// echoes the first argument, or null when there is none.
    fn dummy_impl(
        _ctx: &mut dyn EvalContext,
        args: &[ExprValue],
    ) -> Result<ExprValue, ExpressionError> {
        Ok(args.first().cloned().unwrap_or(ExprValue::Null))
    }

    /// Build a library with each `(name, signature)` pair registered against `dummy_impl`.
    fn library_of(signatures: &[(&str, &str)]) -> FunctionLibrary {
        let mut lib = FunctionLibrary::new();
        for &(name, sig) in signatures {
            lib.register_sig(name, sig, dummy_impl).unwrap();
        }
        lib
    }

    /// Wrap an `f64` as an expression float value.
    fn float(v: f64) -> ExprValue {
        ExprValue::Float(crate::value::Float64::new(v).unwrap())
    }

    // -- Registration tests --

    #[test]
    fn register_and_get() {
        let lib = library_of(&[("len", "(string) -> int"), ("len", "(list[T1]) -> int")]);
        assert_eq!(lib.get_signatures("len").len(), 2);
        assert!(lib.get_signatures("missing").is_empty());
    }

    #[test]
    fn merge_libraries() {
        let a = library_of(&[("foo", "(int) -> int")]);
        let b = library_of(&[("bar", "(string) -> string"), ("foo", "(float) -> float")]);
        let merged = a.merge(b);
        assert_eq!(merged.get_signatures("foo").len(), 2);
        assert_eq!(merged.get_signatures("bar").len(), 1);
    }

    #[test]
    fn register_sig_parses() {
        let lib = library_of(&[("__getitem__", "(list[T1], int) -> T1")]);
        let sigs = lib.get_signatures("__getitem__");
        assert_eq!(sigs.len(), 1);
        assert!(sigs[0].signature.is_symbolic());
    }

    #[test]
    fn function_names() {
        let lib = library_of(&[("alpha", "(int) -> int"), ("beta", "(int) -> int")]);
        let names: std::collections::HashSet<&str> = lib.function_names().collect();
        assert!(names.contains("alpha"));
        assert!(names.contains("beta"));
    }

    // -- Dispatch tests --

    /// Context with no limits and POSIX paths.
    struct MockCtx;
    impl EvalContext for MockCtx {
        fn path_format(&self) -> crate::path_mapping::PathFormat {
            crate::path_mapping::PathFormat::Posix
        }
        fn count_op(&mut self) -> Result<(), ExpressionError> {
            Ok(())
        }
        fn count_ops(&mut self, _n: usize) -> Result<(), ExpressionError> {
            Ok(())
        }
        fn count_string_ops(&mut self, _len: usize) -> Result<(), ExpressionError> {
            Ok(())
        }
        fn check_memory(&self, _bytes: usize) -> Result<(), ExpressionError> {
            Ok(())
        }
    }

    fn add_int(
        _ctx: &mut dyn EvalContext,
        args: &[ExprValue],
    ) -> Result<ExprValue, ExpressionError> {
        if let (ExprValue::Int(a), ExprValue::Int(b)) = (&args[0], &args[1]) {
            Ok(ExprValue::Int(a + b))
        } else {
            Err(ExpressionError::type_error("type error"))
        }
    }

    fn add_float(
        _ctx: &mut dyn EvalContext,
        args: &[ExprValue],
    ) -> Result<ExprValue, ExpressionError> {
        if let (ExprValue::Float(a), ExprValue::Float(b)) = (&args[0], &args[1]) {
            let sum = crate::value::Float64::new(a.value() + b.value())?;
            Ok(ExprValue::Float(sum))
        } else {
            Err(ExpressionError::type_error("type error"))
        }
    }

    fn list_len(
        _ctx: &mut dyn EvalContext,
        args: &[ExprValue],
    ) -> Result<ExprValue, ExpressionError> {
        let count = args[0].list_len().unwrap_or(0);
        Ok(ExprValue::Int(count as i64))
    }

    #[test]
    fn dispatch_exact_match() {
        let mut lib = FunctionLibrary::new();
        lib.register_sig("__add__", "(int, int) -> int", add_int)
            .unwrap();
        lib.register_sig("__add__", "(float, float) -> float", add_float)
            .unwrap();
        let mut ctx = MockCtx;
        let sum = lib
            .call("__add__", &[ExprValue::Int(2), ExprValue::Int(3)], &mut ctx)
            .unwrap();
        assert_eq!(sum, ExprValue::Int(5));
    }

    #[test]
    fn dispatch_coerced_match() {
        let mut lib = FunctionLibrary::new();
        lib.register_sig("__add__", "(float, float) -> float", add_float)
            .unwrap();
        let mut ctx = MockCtx;
        // int + float → the int is coerced to float
        let sum = lib
            .call("__add__", &[ExprValue::Int(2), float(3.0)], &mut ctx)
            .unwrap();
        assert!(matches!(sum, ExprValue::Float(_)));
    }

    #[test]
    fn dispatch_generic_match() {
        let mut lib = FunctionLibrary::new();
        lib.register_sig("len", "(list[T1]) -> int", list_len)
            .unwrap();
        let mut ctx = MockCtx;
        let list = ExprValue::make_list(vec![ExprValue::Int(1), ExprValue::Int(2)], ExprType::INT)
            .unwrap();
        let length = lib.call("len", &[list], &mut ctx).unwrap();
        assert_eq!(length, ExprValue::Int(2));
    }

    #[test]
    fn dispatch_unresolved_returns_unresolved() {
        let mut lib = FunctionLibrary::new();
        lib.register_sig("__add__", "(int, int) -> int", add_int)
            .unwrap();
        let mut ctx = MockCtx;
        let args = [ExprValue::unresolved(ExprType::INT), ExprValue::Int(1)];
        let result = lib.call("__add__", &args, &mut ctx).unwrap();
        assert!(result.is_unresolved());
        assert_eq!(result.expr_type(), ExprType::unresolved(ExprType::INT));
    }

    #[test]
    fn dispatch_no_match_errors() {
        let mut lib = FunctionLibrary::new();
        lib.register_sig("__add__", "(int, int) -> int", add_int)
            .unwrap();
        let mut ctx = MockCtx;
        let args = [ExprValue::String("a".into()), ExprValue::Int(1)];
        assert!(lib.call("__add__", &args, &mut ctx).is_err());
    }

    #[test]
    fn dispatch_unknown_function_errors() {
        let lib = FunctionLibrary::new();
        let mut ctx = MockCtx;
        assert!(lib
            .call("nonexistent", &[ExprValue::Int(1)], &mut ctx)
            .is_err());
    }

    // -- Static typing tests --

    #[test]
    fn derive_return_type_exact() {
        let lib = library_of(&[("__add__", "(int, int) -> int")]);
        assert_eq!(
            lib.derive_return_type("__add__", &[ExprType::INT, ExprType::INT]),
            Some(ExprType::INT)
        );
    }

    #[test]
    fn derive_return_type_generic() {
        let lib = library_of(&[("__getitem__", "(list[T1], int) -> T1")]);
        let args = [ExprType::list(ExprType::STRING), ExprType::INT];
        assert_eq!(
            lib.derive_return_type("__getitem__", &args),
            Some(ExprType::STRING)
        );
    }

    #[test]
    fn derive_return_type_coerced() {
        let lib = library_of(&[("__add__", "(float, float) -> float")]);
        // int + float → coerce → float
        assert_eq!(
            lib.derive_return_type("__add__", &[ExprType::INT, ExprType::FLOAT]),
            Some(ExprType::FLOAT)
        );
    }

    #[test]
    fn derive_return_type_union_args() {
        let lib = library_of(&[
            ("__add__", "(int, int) -> int"),
            ("__add__", "(float, float) -> float"),
        ]);
        // (int | float) + int → int (from int+int) | float (from float coerced)
        let union_arg = ExprType::union(vec![ExprType::INT, ExprType::FLOAT]);
        let derived = lib
            .derive_return_type("__add__", &[union_arg, ExprType::INT])
            .unwrap();
        assert_eq!(
            derived,
            ExprType::union(vec![ExprType::INT, ExprType::FLOAT])
        );
    }

    #[test]
    fn derive_return_type_union_collapses_to_single() {
        let lib = library_of(&[("len", "(string) -> int"), ("len", "(list[T1]) -> int")]);
        // len(string | list[int]) → int (both overloads return int)
        let union_arg = ExprType::union(vec![ExprType::STRING, ExprType::list(ExprType::INT)]);
        assert_eq!(
            lib.derive_return_type("len", &[union_arg]),
            Some(ExprType::INT)
        );
    }

    #[test]
    fn get_property_type_path() {
        let lib = library_of(&[
            ("__property_name__", "(path) -> string"),
            ("__property_parent__", "(path) -> path"),
        ]);
        assert_eq!(
            lib.get_property_type(&ExprType::PATH, "name"),
            Some(ExprType::STRING)
        );
        assert_eq!(
            lib.get_property_type(&ExprType::PATH, "parent"),
            Some(ExprType::PATH)
        );
        assert_eq!(lib.get_property_type(&ExprType::INT, "name"), None);
    }
}