Skip to main content

rex_typesystem/
unification.rs

1use crate::{
2    error::TypeError,
3    types::{BuiltinTypeId, Scheme, Type, TypeKind, TypeVar, TypeVarId, Types},
4};
5use rex_ast::expr::Symbol;
6use rex_lexer::span::Span;
7use rex_util::gas::GasMeter;
8use rpds::HashTrieMapSync;
9
10pub type Subst = HashTrieMapSync<TypeVarId, Type>;
11
12#[derive(Debug)]
13pub(crate) struct Unifier<'g> {
14    // `subs[id] = Some(t)` means type variable `id` has been bound to `t`.
15    //
16    // This is intentionally a dense `Vec` rather than a `BTreeMap`: inference
17    // generates `TypeVarId`s from a monotonic counter, so the common case is
18    // “small id space, lots of lookups”. This makes the cost model obvious:
19    // you pay O(max_id) space, and you get O(1) binds/queries.
20    subs: Vec<Option<Type>>,
21    gas: Option<&'g mut GasMeter>,
22    max_infer_depth: Option<usize>,
23    infer_depth: usize,
24}
25
26impl<'g> Unifier<'g> {
27    pub(crate) fn new(max_infer_depth: Option<usize>) -> Self {
28        Self {
29            subs: Vec::new(),
30            gas: None,
31            max_infer_depth,
32            infer_depth: 0,
33        }
34    }
35
36    pub(crate) fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
37        Self {
38            subs: Vec::new(),
39            gas: Some(gas),
40            max_infer_depth,
41            infer_depth: 0,
42        }
43    }
44
45    pub(crate) fn with_infer_depth<T>(
46        &mut self,
47        span: Span,
48        f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
49    ) -> Result<T, TypeError> {
50        if let Some(max) = self.max_infer_depth
51            && self.infer_depth >= max
52        {
53            return Err(TypeError::Spanned {
54                span,
55                error: Box::new(TypeError::Internal(format!(
56                    "maximum inference depth exceeded (max {max})"
57                ))),
58            });
59        }
60        self.infer_depth += 1;
61        let res = f(self);
62        self.infer_depth = self.infer_depth.saturating_sub(1);
63        res
64    }
65
66    pub(crate) fn charge_infer_node(&mut self) -> Result<(), TypeError> {
67        let Some(gas) = self.gas.as_mut() else {
68            return Ok(());
69        };
70        let cost = gas.costs.infer_node;
71        gas.charge(cost)?;
72        Ok(())
73    }
74
75    fn charge_unify_step(&mut self) -> Result<(), TypeError> {
76        let Some(gas) = self.gas.as_mut() else {
77            return Ok(());
78        };
79        let cost = gas.costs.unify_step;
80        gas.charge(cost)?;
81        Ok(())
82    }
83
84    fn bind_var(&mut self, id: TypeVarId, ty: Type) {
85        if id >= self.subs.len() {
86            self.subs.resize(id + 1, None);
87        }
88        self.subs[id] = Some(ty);
89    }
90
91    fn prune(&mut self, ty: &Type) -> Type {
92        match ty.as_ref() {
93            TypeKind::Var(tv) => {
94                let bound = self.subs.get(tv.id).and_then(|t| t.clone());
95                match bound {
96                    Some(bound) => {
97                        let pruned = self.prune(&bound);
98                        self.bind_var(tv.id, pruned.clone());
99                        pruned
100                    }
101                    None => ty.clone(),
102                }
103            }
104            TypeKind::Con(_) => ty.clone(),
105            TypeKind::App(l, r) => {
106                let l = self.prune(l);
107                let r = self.prune(r);
108                Type::app(l, r)
109            }
110            TypeKind::Fun(a, b) => {
111                let a = self.prune(a);
112                let b = self.prune(b);
113                Type::fun(a, b)
114            }
115            TypeKind::Tuple(ts) => {
116                Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
117            }
118            TypeKind::Record(fields) => Type::new(TypeKind::Record(
119                fields
120                    .iter()
121                    .map(|(name, ty)| (name.clone(), self.prune(ty)))
122                    .collect(),
123            )),
124        }
125    }
126
127    pub(crate) fn apply_type(&mut self, ty: &Type) -> Type {
128        self.prune(ty)
129    }
130
131    fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
132        match self.prune(ty).as_ref() {
133            TypeKind::Var(tv) => tv.id == id,
134            TypeKind::Con(_) => false,
135            TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
136            TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
137            TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
138            TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
139        }
140    }
141
142    pub(crate) fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
143        self.charge_unify_step()?;
144        let t1 = self.prune(t1);
145        let t2 = self.prune(t2);
146        match (t1.as_ref(), t2.as_ref()) {
147            (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
148            (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
149                if self.occurs(tv.id, &Type::new(other.clone())) {
150                    Err(TypeError::Occurs(
151                        tv.id,
152                        Type::new(other.clone()).to_string(),
153                    ))
154                } else {
155                    self.bind_var(tv.id, Type::new(other.clone()));
156                    Ok(())
157                }
158            }
159            (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
160            (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
161                self.unify(l1, l2)?;
162                self.unify(r1, r2)
163            }
164            (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
165                self.unify(a1, a2)?;
166                self.unify(b1, b2)
167            }
168            (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
169                if ts1.len() != ts2.len() {
170                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
171                }
172                for (a, b) in ts1.iter().zip(ts2.iter()) {
173                    self.unify(a, b)?;
174                }
175                Ok(())
176            }
177            (TypeKind::Record(f1), TypeKind::Record(f2)) => {
178                if f1.len() != f2.len() {
179                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
180                }
181                for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
182                    if n1 != n2 {
183                        return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
184                    }
185                    self.unify(t1, t2)?;
186                }
187                Ok(())
188            }
189            (TypeKind::Record(fields), TypeKind::App(head, arg))
190            | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
191                TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
192                    let elem_ty = record_elem_type_unifier(fields, self)?;
193                    self.unify(arg, &elem_ty)
194                }
195                TypeKind::Var(tv) => {
196                    self.unify(
197                        &Type::new(TypeKind::Var(tv.clone())),
198                        &Type::builtin(BuiltinTypeId::Dict),
199                    )?;
200                    let elem_ty = record_elem_type_unifier(fields, self)?;
201                    self.unify(arg, &elem_ty)
202                }
203                _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
204            },
205            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
206        }
207    }
208
209    pub(crate) fn into_subst(mut self) -> Subst {
210        let mut out = Subst::new_sync();
211        for id in 0..self.subs.len() {
212            if let Some(ty) = self.subs[id].clone() {
213                let pruned = self.prune(&ty);
214                out = out.insert(id, pruned);
215            }
216        }
217        out
218    }
219}
220
221/// Compose substitutions `a` after `b`.
222///
223/// If `t.apply(&b)` is “apply `b` first”, then:
224/// `t.apply(&compose_subst(a, b)) == t.apply(&b).apply(&a)`.
225pub fn compose_subst(a: Subst, b: Subst) -> Subst {
226    if subst_is_empty(&a) {
227        return b;
228    }
229    if subst_is_empty(&b) {
230        return a;
231    }
232    let mut res = Subst::new_sync();
233    for (k, v) in b.iter() {
234        res = res.insert(*k, v.apply(&a));
235    }
236    for (k, v) in a.iter() {
237        res = res.insert(*k, v.clone());
238    }
239    res
240}
241
242pub(crate) fn subst_is_empty(s: &Subst) -> bool {
243    s.iter().next().is_none()
244}
245
246pub(crate) fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
247    let s = match unify(&existing.typ, &declared.typ) {
248        Ok(s) => s,
249        Err(_) => return false,
250    };
251
252    let existing_preds = existing.preds.apply(&s);
253    let declared_preds = declared.preds.apply(&s);
254
255    let mut lhs: Vec<(Symbol, String)> = existing_preds
256        .iter()
257        .map(|p| (p.class.clone(), p.typ.to_string()))
258        .collect();
259    let mut rhs: Vec<(Symbol, String)> = declared_preds
260        .iter()
261        .map(|p| (p.class.clone(), p.typ.to_string()))
262        .collect();
263    lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
264    rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
265    lhs == rhs
266}
267
268fn record_elem_type_unifier(
269    fields: &[(Symbol, Type)],
270    unifier: &mut Unifier<'_>,
271) -> Result<Type, TypeError> {
272    let mut iter = fields.iter();
273    let first = match iter.next() {
274        Some((_, ty)) => ty.clone(),
275        None => return Err(TypeError::UnsupportedExpr("empty record")),
276    };
277    for (_, ty) in iter {
278        unifier.unify(&first, ty)?;
279    }
280    Ok(unifier.apply_type(&first))
281}
282
283pub(crate) fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
284    if let TypeKind::Var(var) = t.as_ref()
285        && var.id == tv.id
286    {
287        return Ok(Subst::new_sync());
288    }
289    if t.ftv().contains(&tv.id) {
290        Err(TypeError::Occurs(tv.id, t.to_string()))
291    } else {
292        Ok(Subst::new_sync().insert(tv.id, t.clone()))
293    }
294}
295
296fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
297    let mut iter = fields.iter();
298    let first = match iter.next() {
299        Some((_, ty)) => ty.clone(),
300        None => return Err(TypeError::UnsupportedExpr("empty record")),
301    };
302    let mut subst = Subst::new_sync();
303    let mut current = first;
304    for (_, ty) in iter {
305        let s_next = unify(&current.apply(&subst), &ty.apply(&subst))?;
306        subst = compose_subst(s_next, subst);
307        current = current.apply(&subst);
308    }
309    Ok((subst.clone(), current.apply(&subst)))
310}
311
312/// Compute a most-general unifier for two types.
313///
314/// This is the “pure” unifier: it returns an explicit substitution map and is
315/// easy to read/compose in isolation. The type inference engine uses `Unifier`
316/// directly to avoid allocating and composing persistent maps at every
317/// unification step.
318pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
319    match (t1.as_ref(), t2.as_ref()) {
320        (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
321            let s1 = unify(l1, l2)?;
322            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
323            Ok(compose_subst(s2, s1))
324        }
325        (TypeKind::Record(f1), TypeKind::Record(f2)) => {
326            if f1.len() != f2.len() {
327                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
328            }
329            let mut subst = Subst::new_sync();
330            for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
331                if n1 != n2 {
332                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
333                }
334                let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
335                subst = compose_subst(s_next, subst);
336            }
337            Ok(subst)
338        }
339        (TypeKind::Record(fields), TypeKind::App(head, arg))
340        | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
341            TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
342                let (s_fields, elem_ty) = record_elem_type(fields)?;
343                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
344                Ok(compose_subst(s_arg, s_fields))
345            }
346            TypeKind::Var(tv) => {
347                let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
348                let arg = arg.apply(&s_head);
349                let (s_fields, elem_ty) = record_elem_type(fields)?;
350                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
351                Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
352            }
353            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
354        },
355        (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
356            let s1 = unify(l1, l2)?;
357            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
358            Ok(compose_subst(s2, s1))
359        }
360        (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
361            if ts1.len() != ts2.len() {
362                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
363            }
364            let mut s = Subst::new_sync();
365            for (a, b) in ts1.iter().zip(ts2.iter()) {
366                let s_next = unify(&a.apply(&s), &b.apply(&s))?;
367                s = compose_subst(s_next, s);
368            }
369            Ok(s)
370        }
371        (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
372        (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
373        _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
374    }
375}