Skip to main content

mir_analyzer/
generic.rs

//! Generic type inference — infer template bindings from argument types and
//! substitute them into return types.
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use mir_codebase::storage::{FnParam, TemplateParam};
7use mir_types::{Atomic, Union};
8
9// ---------------------------------------------------------------------------
10// Public API
11// ---------------------------------------------------------------------------
12
13/// Infer template parameter bindings by matching parameter types against
14/// argument types.
15///
16/// For example, given `function identity<T>(T $x): T` called with `"hello"`,
17/// this returns `{ T → string }`.
18pub fn infer_template_bindings(
19    template_params: &[TemplateParam],
20    params: &[FnParam],
21    arg_types: &[Union],
22) -> HashMap<Arc<str>, Union> {
23    let mut bindings: HashMap<Arc<str>, Union> = HashMap::new();
24
25    for (param, arg_ty) in params.iter().zip(arg_types.iter()) {
26        if let Some(param_ty) = &param.ty {
27            infer_from_pair(param_ty, arg_ty, &mut bindings);
28        }
29    }
30
31    // For any template not bound through arguments, fall back to its bound
32    // (or mixed if no bound is declared).
33    for tp in template_params {
34        bindings
35            .entry(tp.name.clone())
36            .or_insert_with(|| tp.bound.clone().unwrap_or_else(Union::mixed));
37    }
38
39    bindings
40}
41
42/// Check that each binding satisfies the template's declared bound.
43/// Returns a list of `(template_name, inferred_type, bound)` for violations.
44pub fn check_template_bounds<'a>(
45    bindings: &'a HashMap<Arc<str>, Union>,
46    template_params: &'a [TemplateParam],
47) -> Vec<(&'a Arc<str>, &'a Union, &'a Union)> {
48    let mut violations = Vec::new();
49    for tp in template_params {
50        if let Some(bound) = &tp.bound {
51            if let Some(inferred) = bindings.get(&tp.name) {
52                if !bound.is_mixed()
53                    && !inferred.is_mixed()
54                    && !inferred.is_subtype_of_simple(bound)
55                {
56                    violations.push((&tp.name, inferred, bound));
57                }
58            }
59        }
60    }
61    violations
62}
63
64/// Build template bindings from a receiver's concrete type params.
65///
66/// Zips `class_template_params` (e.g. `[T]` declared on the class) with
67/// `receiver_type_params` (e.g. `[User]` from `Collection<User>`) to produce
68/// `{ T → User }`. If the receiver supplies fewer type params than the class
69/// declares, the trailing template params are left unbound. If the receiver
70/// supplies more, the extras are ignored.
71pub fn build_class_bindings(
72    class_template_params: &[TemplateParam],
73    receiver_type_params: &[Union],
74) -> HashMap<Arc<str>, Union> {
75    class_template_params
76        .iter()
77        .zip(receiver_type_params.iter())
78        .map(|(tp, ty)| (tp.name.clone(), ty.clone()))
79        .collect()
80}
81
82// ---------------------------------------------------------------------------
83// Internal helpers
84// ---------------------------------------------------------------------------
85
86/// Recursively match `param_ty` (which may contain template placeholders)
87/// against `arg_ty` (a concrete type), updating `bindings`.
88fn infer_from_pair(param_ty: &Union, arg_ty: &Union, bindings: &mut HashMap<Arc<str>, Union>) {
89    for p_atomic in &param_ty.types {
90        match p_atomic {
91            // Direct template placeholder: T → bind T = arg_ty
92            Atomic::TTemplateParam { name, .. } => {
93                // Merge if already partially bound
94                let entry = bindings.entry(name.clone()).or_insert_with(Union::empty);
95                *entry = Union::merge(entry, arg_ty);
96            }
97
98            // array<K, V> matched against array<k_ty, v_ty>
99            Atomic::TArray { key: pk, value: pv } => {
100                for a_atomic in &arg_ty.types {
101                    match a_atomic {
102                        Atomic::TArray { key: ak, value: av }
103                        | Atomic::TNonEmptyArray { key: ak, value: av } => {
104                            infer_from_pair(pk, ak, bindings);
105                            infer_from_pair(pv, av, bindings);
106                        }
107                        _ => {}
108                    }
109                }
110            }
111
112            // list<T> matched against list<t_ty>
113            Atomic::TList { value: pv } | Atomic::TNonEmptyList { value: pv } => {
114                for a_atomic in &arg_ty.types {
115                    match a_atomic {
116                        Atomic::TList { value: av } | Atomic::TNonEmptyList { value: av } => {
117                            infer_from_pair(pv, av, bindings);
118                        }
119                        _ => {}
120                    }
121                }
122            }
123
124            // ClassName<T> matched against ClassName<t_ty>
125            Atomic::TNamedObject {
126                fqcn: pfqcn,
127                type_params: pp,
128            } => {
129                for a_atomic in &arg_ty.types {
130                    if let Atomic::TNamedObject {
131                        fqcn: afqcn,
132                        type_params: ap,
133                    } = a_atomic
134                    {
135                        if pfqcn == afqcn {
136                            for (p_param, a_param) in pp.iter().zip(ap.iter()) {
137                                infer_from_pair(p_param, a_param, bindings);
138                            }
139                        }
140                    }
141                }
142            }
143
144            _ => {}
145        }
146    }
147}