Skip to main content

rex_typesystem/
unification.rs

1use crate::{
2    error::TypeError,
3    types::{BuiltinTypeId, Scheme, Type, TypeKind, TypeVar, TypeVarId, Types},
4};
5use rex_ast::{Span, Symbol};
6use rpds::HashTrieMapSync;
7
/// A substitution: maps type-variable ids to the types they are bound to.
///
/// Backed by `rpds`'s persistent hash trie: `insert` returns a new map and
/// leaves the original untouched, so substitutions can be shared and composed
/// without defensive copies.
pub type Subst = HashTrieMapSync<TypeVarId, Type>;
9
10#[derive(Debug)]
11pub(crate) struct Unifier {
12    // `subs[id] = Some(t)` means type variable `id` has been bound to `t`.
13    //
14    // This is intentionally a dense `Vec` rather than a `BTreeMap`: inference
15    // generates `TypeVarId`s from a monotonic counter, so the common case is
16    // “small id space, lots of lookups”. This makes the cost model obvious:
17    // you pay O(max_id) space, and you get O(1) binds/queries.
18    subs: Vec<Option<Type>>,
19    max_infer_depth: Option<usize>,
20    infer_depth: usize,
21}
22
23impl Unifier {
24    pub(crate) fn new(max_infer_depth: Option<usize>) -> Self {
25        Self {
26            subs: Vec::new(),
27            max_infer_depth,
28            infer_depth: 0,
29        }
30    }
31
32    pub(crate) fn with_infer_depth<T>(
33        &mut self,
34        span: Span,
35        f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
36    ) -> Result<T, TypeError> {
37        if let Some(max) = self.max_infer_depth
38            && self.infer_depth >= max
39        {
40            return Err(TypeError::Spanned {
41                span,
42                error: Box::new(TypeError::Internal(format!(
43                    "maximum inference depth exceeded (max {max})"
44                ))),
45            });
46        }
47        self.infer_depth += 1;
48        let res = f(self);
49        self.infer_depth = self.infer_depth.saturating_sub(1);
50        res
51    }
52
53    fn bind_var(&mut self, id: TypeVarId, ty: Type) {
54        if id >= self.subs.len() {
55            self.subs.resize(id + 1, None);
56        }
57        self.subs[id] = Some(ty);
58    }
59
60    fn prune(&mut self, ty: &Type) -> Type {
61        match ty.as_ref() {
62            TypeKind::Var(tv) => {
63                let bound = self.subs.get(tv.id).and_then(|t| t.clone());
64                match bound {
65                    Some(bound) => {
66                        let pruned = self.prune(&bound);
67                        self.bind_var(tv.id, pruned.clone());
68                        pruned
69                    }
70                    None => ty.clone(),
71                }
72            }
73            TypeKind::Con(_) => ty.clone(),
74            TypeKind::App(l, r) => {
75                let l = self.prune(l);
76                let r = self.prune(r);
77                Type::app(l, r)
78            }
79            TypeKind::Fun(a, b) => {
80                let a = self.prune(a);
81                let b = self.prune(b);
82                Type::fun(a, b)
83            }
84            TypeKind::Tuple(ts) => {
85                Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
86            }
87            TypeKind::Record(fields) => Type::new(TypeKind::Record(
88                fields
89                    .iter()
90                    .map(|(name, ty)| (name.clone(), self.prune(ty)))
91                    .collect(),
92            )),
93        }
94    }
95
96    pub(crate) fn apply_type(&mut self, ty: &Type) -> Type {
97        self.prune(ty)
98    }
99
100    fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
101        match self.prune(ty).as_ref() {
102            TypeKind::Var(tv) => tv.id == id,
103            TypeKind::Con(_) => false,
104            TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
105            TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
106            TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
107            TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
108        }
109    }
110
111    pub(crate) fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
112        let t1 = self.prune(t1);
113        let t2 = self.prune(t2);
114        match (t1.as_ref(), t2.as_ref()) {
115            (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
116            (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
117                if self.occurs(tv.id, &Type::new(other.clone())) {
118                    Err(TypeError::Occurs(
119                        tv.id,
120                        Type::new(other.clone()).to_string(),
121                    ))
122                } else {
123                    self.bind_var(tv.id, Type::new(other.clone()));
124                    Ok(())
125                }
126            }
127            (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
128            (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
129                self.unify(l1, l2)?;
130                self.unify(r1, r2)
131            }
132            (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
133                self.unify(a1, a2)?;
134                self.unify(b1, b2)
135            }
136            (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
137                if ts1.len() != ts2.len() {
138                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
139                }
140                for (a, b) in ts1.iter().zip(ts2.iter()) {
141                    self.unify(a, b)?;
142                }
143                Ok(())
144            }
145            (TypeKind::Record(f1), TypeKind::Record(f2)) => {
146                if f1.len() != f2.len() {
147                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
148                }
149                for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
150                    if n1 != n2 {
151                        return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
152                    }
153                    self.unify(t1, t2)?;
154                }
155                Ok(())
156            }
157            (TypeKind::Record(fields), TypeKind::App(head, arg))
158            | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
159                TypeKind::Con(c) if c.is_builtin(BuiltinTypeId::Dict) => {
160                    let elem_ty = record_elem_type_unifier(fields, self)?;
161                    self.unify(arg, &elem_ty)
162                }
163                TypeKind::Var(tv) => {
164                    self.unify(
165                        &Type::new(TypeKind::Var(tv.clone())),
166                        &Type::builtin(BuiltinTypeId::Dict),
167                    )?;
168                    let elem_ty = record_elem_type_unifier(fields, self)?;
169                    self.unify(arg, &elem_ty)
170                }
171                _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
172            },
173            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
174        }
175    }
176
177    pub(crate) fn into_subst(mut self) -> Subst {
178        let mut out = Subst::new_sync();
179        for id in 0..self.subs.len() {
180            if let Some(ty) = self.subs[id].clone() {
181                let pruned = self.prune(&ty);
182                out = out.insert(id, pruned);
183            }
184        }
185        out
186    }
187}
188
189/// Compose substitutions `a` after `b`.
190///
191/// If `t.apply(&b)` is “apply `b` first”, then:
192/// `t.apply(&compose_subst(a, b)) == t.apply(&b).apply(&a)`.
193pub fn compose_subst(a: Subst, b: Subst) -> Subst {
194    if subst_is_empty(&a) {
195        return b;
196    }
197    if subst_is_empty(&b) {
198        return a;
199    }
200    let mut res = Subst::new_sync();
201    for (k, v) in b.iter() {
202        res = res.insert(*k, v.apply(&a));
203    }
204    for (k, v) in a.iter() {
205        res = res.insert(*k, v.clone());
206    }
207    res
208}
209
210pub(crate) fn subst_is_empty(s: &Subst) -> bool {
211    s.iter().next().is_none()
212}
213
214pub(crate) fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
215    let s = match unify(&existing.typ, &declared.typ) {
216        Ok(s) => s,
217        Err(_) => return false,
218    };
219
220    let existing_preds = existing.preds.apply(&s);
221    let declared_preds = declared.preds.apply(&s);
222
223    let mut lhs: Vec<(Symbol, String)> = existing_preds
224        .iter()
225        .map(|p| (p.class.clone(), p.typ.to_string()))
226        .collect();
227    let mut rhs: Vec<(Symbol, String)> = declared_preds
228        .iter()
229        .map(|p| (p.class.clone(), p.typ.to_string()))
230        .collect();
231    lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
232    rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
233    lhs == rhs
234}
235
236fn record_elem_type_unifier(
237    fields: &[(Symbol, Type)],
238    unifier: &mut Unifier,
239) -> Result<Type, TypeError> {
240    let mut iter = fields.iter();
241    let first = match iter.next() {
242        Some((_, ty)) => ty.clone(),
243        None => return Err(TypeError::UnsupportedExpr("empty record")),
244    };
245    for (_, ty) in iter {
246        unifier.unify(&first, ty)?;
247    }
248    Ok(unifier.apply_type(&first))
249}
250
251pub(crate) fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
252    if let TypeKind::Var(var) = t.as_ref()
253        && var.id == tv.id
254    {
255        return Ok(Subst::new_sync());
256    }
257    if t.ftv().contains(&tv.id) {
258        Err(TypeError::Occurs(tv.id, t.to_string()))
259    } else {
260        Ok(Subst::new_sync().insert(tv.id, t.clone()))
261    }
262}
263
264fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
265    let mut iter = fields.iter();
266    let first = match iter.next() {
267        Some((_, ty)) => ty.clone(),
268        None => return Err(TypeError::UnsupportedExpr("empty record")),
269    };
270    let mut subst = Subst::new_sync();
271    let mut current = first;
272    for (_, ty) in iter {
273        let s_next = unify(&current.apply(&subst), &ty.apply(&subst))?;
274        subst = compose_subst(s_next, subst);
275        current = current.apply(&subst);
276    }
277    Ok((subst.clone(), current.apply(&subst)))
278}
279
280/// Compute a most-general unifier for two types.
281///
282/// This is the “pure” unifier: it returns an explicit substitution map and is
283/// easy to read/compose in isolation. The type inference engine uses `Unifier`
284/// directly to avoid allocating and composing persistent maps at every
285/// unification step.
286pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
287    match (t1.as_ref(), t2.as_ref()) {
288        (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
289            let s1 = unify(l1, l2)?;
290            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
291            Ok(compose_subst(s2, s1))
292        }
293        (TypeKind::Record(f1), TypeKind::Record(f2)) => {
294            if f1.len() != f2.len() {
295                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
296            }
297            let mut subst = Subst::new_sync();
298            for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
299                if n1 != n2 {
300                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
301                }
302                let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
303                subst = compose_subst(s_next, subst);
304            }
305            Ok(subst)
306        }
307        (TypeKind::Record(fields), TypeKind::App(head, arg))
308        | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
309            TypeKind::Con(c) if c.is_builtin(BuiltinTypeId::Dict) => {
310                let (s_fields, elem_ty) = record_elem_type(fields)?;
311                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
312                Ok(compose_subst(s_arg, s_fields))
313            }
314            TypeKind::Var(tv) => {
315                let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
316                let arg = arg.apply(&s_head);
317                let (s_fields, elem_ty) = record_elem_type(fields)?;
318                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
319                Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
320            }
321            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
322        },
323        (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
324            let s1 = unify(l1, l2)?;
325            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
326            Ok(compose_subst(s2, s1))
327        }
328        (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
329            if ts1.len() != ts2.len() {
330                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
331            }
332            let mut s = Subst::new_sync();
333            for (a, b) in ts1.iter().zip(ts2.iter()) {
334                let s_next = unify(&a.apply(&s), &b.apply(&s))?;
335                s = compose_subst(s_next, s);
336            }
337            Ok(s)
338        }
339        (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
340        (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
341        _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
342    }
343}