mir_analyzer/generic.rs
//! Generic type inference — infer template bindings from argument types and
//! substitute them into return types.
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use mir_codebase::storage::{FnParam, TemplateParam};
7use mir_types::{Atomic, Union};
8
9// ---------------------------------------------------------------------------
10// Public API
11// ---------------------------------------------------------------------------
12
13/// Infer template parameter bindings by matching parameter types against
14/// argument types.
15///
16/// For example, given `function identity<T>(T $x): T` called with `"hello"`,
17/// this returns `{ T → string }`.
18pub fn infer_template_bindings(
19 template_params: &[TemplateParam],
20 params: &[FnParam],
21 arg_types: &[Union],
22) -> HashMap<Arc<str>, Union> {
23 let mut bindings: HashMap<Arc<str>, Union> = HashMap::new();
24
25 for (param, arg_ty) in params.iter().zip(arg_types.iter()) {
26 if let Some(param_ty) = ¶m.ty {
27 infer_from_pair(param_ty, arg_ty, &mut bindings);
28 }
29 }
30
31 // For any template not bound through arguments, fall back to its bound
32 // (or mixed if no bound is declared).
33 for tp in template_params {
34 bindings
35 .entry(tp.name.clone())
36 .or_insert_with(|| tp.bound.clone().unwrap_or_else(Union::mixed));
37 }
38
39 bindings
40}
41
42/// Check that each binding satisfies the template's declared bound.
43/// Returns a list of `(template_name, inferred_type, bound)` for violations.
44pub fn check_template_bounds<'a>(
45 bindings: &'a HashMap<Arc<str>, Union>,
46 template_params: &'a [TemplateParam],
47) -> Vec<(&'a Arc<str>, &'a Union, &'a Union)> {
48 let mut violations = Vec::new();
49 for tp in template_params {
50 if let Some(bound) = &tp.bound {
51 if let Some(inferred) = bindings.get(&tp.name) {
52 if !bound.is_mixed()
53 && !inferred.is_mixed()
54 && !inferred.is_subtype_of_simple(bound)
55 {
56 violations.push((&tp.name, inferred, bound));
57 }
58 }
59 }
60 }
61 violations
62}
63
64/// Build template bindings from a receiver's concrete type params.
65///
66/// Zips `class_template_params` (e.g. `[T]` declared on the class) with
67/// `receiver_type_params` (e.g. `[User]` from `Collection<User>`) to produce
68/// `{ T → User }`. If the receiver supplies fewer type params than the class
69/// declares, the trailing template params are left unbound. If the receiver
70/// supplies more, the extras are ignored.
71pub fn build_class_bindings(
72 class_template_params: &[TemplateParam],
73 receiver_type_params: &[Union],
74) -> HashMap<Arc<str>, Union> {
75 class_template_params
76 .iter()
77 .zip(receiver_type_params.iter())
78 .map(|(tp, ty)| (tp.name.clone(), ty.clone()))
79 .collect()
80}
81
82// ---------------------------------------------------------------------------
83// Internal helpers
84// ---------------------------------------------------------------------------
85
86/// Recursively match `param_ty` (which may contain template placeholders)
87/// against `arg_ty` (a concrete type), updating `bindings`.
88fn infer_from_pair(param_ty: &Union, arg_ty: &Union, bindings: &mut HashMap<Arc<str>, Union>) {
89 for p_atomic in ¶m_ty.types {
90 match p_atomic {
91 // Direct template placeholder: T → bind T = arg_ty
92 Atomic::TTemplateParam { name, .. } => {
93 // Merge if already partially bound
94 let entry = bindings.entry(name.clone()).or_insert_with(Union::empty);
95 *entry = Union::merge(entry, arg_ty);
96 }
97
98 // array<K, V> matched against array<k_ty, v_ty>
99 Atomic::TArray { key: pk, value: pv } => {
100 for a_atomic in &arg_ty.types {
101 match a_atomic {
102 Atomic::TArray { key: ak, value: av }
103 | Atomic::TNonEmptyArray { key: ak, value: av } => {
104 infer_from_pair(pk, ak, bindings);
105 infer_from_pair(pv, av, bindings);
106 }
107 _ => {}
108 }
109 }
110 }
111
112 // list<T> matched against list<t_ty>
113 Atomic::TList { value: pv } | Atomic::TNonEmptyList { value: pv } => {
114 for a_atomic in &arg_ty.types {
115 match a_atomic {
116 Atomic::TList { value: av } | Atomic::TNonEmptyList { value: av } => {
117 infer_from_pair(pv, av, bindings);
118 }
119 _ => {}
120 }
121 }
122 }
123
124 // ClassName<T> matched against ClassName<t_ty>
125 Atomic::TNamedObject {
126 fqcn: pfqcn,
127 type_params: pp,
128 } => {
129 for a_atomic in &arg_ty.types {
130 if let Atomic::TNamedObject {
131 fqcn: afqcn,
132 type_params: ap,
133 } = a_atomic
134 {
135 if pfqcn == afqcn {
136 for (p_param, a_param) in pp.iter().zip(ap.iter()) {
137 infer_from_pair(p_param, a_param, bindings);
138 }
139 }
140 }
141 }
142 }
143
144 _ => {}
145 }
146 }
147}