1use crate::{
2 error::TypeError,
3 types::{BuiltinTypeId, Scheme, Type, TypeKind, TypeVar, TypeVarId, Types},
4};
5use rex_ast::expr::Symbol;
6use rex_lexer::span::Span;
7use rex_util::gas::GasMeter;
8use rpds::HashTrieMapSync;
9
/// A substitution: maps type-variable ids to the types they are bound to.
///
/// Backed by rpds' persistent hash trie map, so `insert` returns a new map and leaves the
/// original untouched (see `compose_subst` / `Unifier::into_subst`, which rebind the result).
pub type Subst = HashTrieMapSync<TypeVarId, Type>;
11
/// Stateful unifier used during type inference.
///
/// Variable bindings are kept in a dense table indexed by [`TypeVarId`] and are updated in
/// place as unification proceeds. The unifier can optionally charge a [`GasMeter`] for
/// inference nodes and unification steps, and can cap the nesting depth of
/// `with_infer_depth` calls.
#[derive(Debug)]
pub(crate) struct Unifier<'g> {
    /// `subs[id]` is the type currently bound to variable `id`; `None` (or an index past the
    /// end of the table) means the variable is unbound. Grown on demand by `bind_var`.
    subs: Vec<Option<Type>>,
    /// Gas meter charged by `charge_infer_node` / `charge_unify_step`; `None` disables metering.
    gas: Option<&'g mut GasMeter>,
    /// Maximum allowed nesting of `with_infer_depth`; `None` means unbounded.
    max_infer_depth: Option<usize>,
    /// Current nesting depth of `with_infer_depth` calls.
    infer_depth: usize,
}
25
26impl<'g> Unifier<'g> {
27 pub(crate) fn new(max_infer_depth: Option<usize>) -> Self {
28 Self {
29 subs: Vec::new(),
30 gas: None,
31 max_infer_depth,
32 infer_depth: 0,
33 }
34 }
35
36 pub(crate) fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
37 Self {
38 subs: Vec::new(),
39 gas: Some(gas),
40 max_infer_depth,
41 infer_depth: 0,
42 }
43 }
44
45 pub(crate) fn with_infer_depth<T>(
46 &mut self,
47 span: Span,
48 f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
49 ) -> Result<T, TypeError> {
50 if let Some(max) = self.max_infer_depth
51 && self.infer_depth >= max
52 {
53 return Err(TypeError::Spanned {
54 span,
55 error: Box::new(TypeError::Internal(format!(
56 "maximum inference depth exceeded (max {max})"
57 ))),
58 });
59 }
60 self.infer_depth += 1;
61 let res = f(self);
62 self.infer_depth = self.infer_depth.saturating_sub(1);
63 res
64 }
65
66 pub(crate) fn charge_infer_node(&mut self) -> Result<(), TypeError> {
67 let Some(gas) = self.gas.as_mut() else {
68 return Ok(());
69 };
70 let cost = gas.costs.infer_node;
71 gas.charge(cost)?;
72 Ok(())
73 }
74
75 fn charge_unify_step(&mut self) -> Result<(), TypeError> {
76 let Some(gas) = self.gas.as_mut() else {
77 return Ok(());
78 };
79 let cost = gas.costs.unify_step;
80 gas.charge(cost)?;
81 Ok(())
82 }
83
84 fn bind_var(&mut self, id: TypeVarId, ty: Type) {
85 if id >= self.subs.len() {
86 self.subs.resize(id + 1, None);
87 }
88 self.subs[id] = Some(ty);
89 }
90
91 fn prune(&mut self, ty: &Type) -> Type {
92 match ty.as_ref() {
93 TypeKind::Var(tv) => {
94 let bound = self.subs.get(tv.id).and_then(|t| t.clone());
95 match bound {
96 Some(bound) => {
97 let pruned = self.prune(&bound);
98 self.bind_var(tv.id, pruned.clone());
99 pruned
100 }
101 None => ty.clone(),
102 }
103 }
104 TypeKind::Con(_) => ty.clone(),
105 TypeKind::App(l, r) => {
106 let l = self.prune(l);
107 let r = self.prune(r);
108 Type::app(l, r)
109 }
110 TypeKind::Fun(a, b) => {
111 let a = self.prune(a);
112 let b = self.prune(b);
113 Type::fun(a, b)
114 }
115 TypeKind::Tuple(ts) => {
116 Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
117 }
118 TypeKind::Record(fields) => Type::new(TypeKind::Record(
119 fields
120 .iter()
121 .map(|(name, ty)| (name.clone(), self.prune(ty)))
122 .collect(),
123 )),
124 }
125 }
126
127 pub(crate) fn apply_type(&mut self, ty: &Type) -> Type {
128 self.prune(ty)
129 }
130
131 fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
132 match self.prune(ty).as_ref() {
133 TypeKind::Var(tv) => tv.id == id,
134 TypeKind::Con(_) => false,
135 TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
136 TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
137 TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
138 TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
139 }
140 }
141
142 pub(crate) fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
143 self.charge_unify_step()?;
144 let t1 = self.prune(t1);
145 let t2 = self.prune(t2);
146 match (t1.as_ref(), t2.as_ref()) {
147 (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
148 (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
149 if self.occurs(tv.id, &Type::new(other.clone())) {
150 Err(TypeError::Occurs(
151 tv.id,
152 Type::new(other.clone()).to_string(),
153 ))
154 } else {
155 self.bind_var(tv.id, Type::new(other.clone()));
156 Ok(())
157 }
158 }
159 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
160 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
161 self.unify(l1, l2)?;
162 self.unify(r1, r2)
163 }
164 (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
165 self.unify(a1, a2)?;
166 self.unify(b1, b2)
167 }
168 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
169 if ts1.len() != ts2.len() {
170 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
171 }
172 for (a, b) in ts1.iter().zip(ts2.iter()) {
173 self.unify(a, b)?;
174 }
175 Ok(())
176 }
177 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
178 if f1.len() != f2.len() {
179 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
180 }
181 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
182 if n1 != n2 {
183 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
184 }
185 self.unify(t1, t2)?;
186 }
187 Ok(())
188 }
189 (TypeKind::Record(fields), TypeKind::App(head, arg))
190 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
191 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
192 let elem_ty = record_elem_type_unifier(fields, self)?;
193 self.unify(arg, &elem_ty)
194 }
195 TypeKind::Var(tv) => {
196 self.unify(
197 &Type::new(TypeKind::Var(tv.clone())),
198 &Type::builtin(BuiltinTypeId::Dict),
199 )?;
200 let elem_ty = record_elem_type_unifier(fields, self)?;
201 self.unify(arg, &elem_ty)
202 }
203 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
204 },
205 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
206 }
207 }
208
209 pub(crate) fn into_subst(mut self) -> Subst {
210 let mut out = Subst::new_sync();
211 for id in 0..self.subs.len() {
212 if let Some(ty) = self.subs[id].clone() {
213 let pruned = self.prune(&ty);
214 out = out.insert(id, pruned);
215 }
216 }
217 out
218 }
219}
220
221pub fn compose_subst(a: Subst, b: Subst) -> Subst {
226 if subst_is_empty(&a) {
227 return b;
228 }
229 if subst_is_empty(&b) {
230 return a;
231 }
232 let mut res = Subst::new_sync();
233 for (k, v) in b.iter() {
234 res = res.insert(*k, v.apply(&a));
235 }
236 for (k, v) in a.iter() {
237 res = res.insert(*k, v.clone());
238 }
239 res
240}
241
242pub(crate) fn subst_is_empty(s: &Subst) -> bool {
243 s.iter().next().is_none()
244}
245
246pub(crate) fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
247 let s = match unify(&existing.typ, &declared.typ) {
248 Ok(s) => s,
249 Err(_) => return false,
250 };
251
252 let existing_preds = existing.preds.apply(&s);
253 let declared_preds = declared.preds.apply(&s);
254
255 let mut lhs: Vec<(Symbol, String)> = existing_preds
256 .iter()
257 .map(|p| (p.class.clone(), p.typ.to_string()))
258 .collect();
259 let mut rhs: Vec<(Symbol, String)> = declared_preds
260 .iter()
261 .map(|p| (p.class.clone(), p.typ.to_string()))
262 .collect();
263 lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
264 rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
265 lhs == rhs
266}
267
268fn record_elem_type_unifier(
269 fields: &[(Symbol, Type)],
270 unifier: &mut Unifier<'_>,
271) -> Result<Type, TypeError> {
272 let mut iter = fields.iter();
273 let first = match iter.next() {
274 Some((_, ty)) => ty.clone(),
275 None => return Err(TypeError::UnsupportedExpr("empty record")),
276 };
277 for (_, ty) in iter {
278 unifier.unify(&first, ty)?;
279 }
280 Ok(unifier.apply_type(&first))
281}
282
283pub(crate) fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
284 if let TypeKind::Var(var) = t.as_ref()
285 && var.id == tv.id
286 {
287 return Ok(Subst::new_sync());
288 }
289 if t.ftv().contains(&tv.id) {
290 Err(TypeError::Occurs(tv.id, t.to_string()))
291 } else {
292 Ok(Subst::new_sync().insert(tv.id, t.clone()))
293 }
294}
295
296fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
297 let mut iter = fields.iter();
298 let first = match iter.next() {
299 Some((_, ty)) => ty.clone(),
300 None => return Err(TypeError::UnsupportedExpr("empty record")),
301 };
302 let mut subst = Subst::new_sync();
303 let mut current = first;
304 for (_, ty) in iter {
305 let s_next = unify(¤t.apply(&subst), &ty.apply(&subst))?;
306 subst = compose_subst(s_next, subst);
307 current = current.apply(&subst);
308 }
309 Ok((subst.clone(), current.apply(&subst)))
310}
311
312pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
319 match (t1.as_ref(), t2.as_ref()) {
320 (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
321 let s1 = unify(l1, l2)?;
322 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
323 Ok(compose_subst(s2, s1))
324 }
325 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
326 if f1.len() != f2.len() {
327 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
328 }
329 let mut subst = Subst::new_sync();
330 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
331 if n1 != n2 {
332 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
333 }
334 let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
335 subst = compose_subst(s_next, subst);
336 }
337 Ok(subst)
338 }
339 (TypeKind::Record(fields), TypeKind::App(head, arg))
340 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
341 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
342 let (s_fields, elem_ty) = record_elem_type(fields)?;
343 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
344 Ok(compose_subst(s_arg, s_fields))
345 }
346 TypeKind::Var(tv) => {
347 let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
348 let arg = arg.apply(&s_head);
349 let (s_fields, elem_ty) = record_elem_type(fields)?;
350 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
351 Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
352 }
353 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
354 },
355 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
356 let s1 = unify(l1, l2)?;
357 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
358 Ok(compose_subst(s2, s1))
359 }
360 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
361 if ts1.len() != ts2.len() {
362 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
363 }
364 let mut s = Subst::new_sync();
365 for (a, b) in ts1.iter().zip(ts2.iter()) {
366 let s_next = unify(&a.apply(&s), &b.apply(&s))?;
367 s = compose_subst(s_next, s);
368 }
369 Ok(s)
370 }
371 (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
372 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
373 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
374 }
375}