Skip to main content

lex_types/
unifier.rs

//! Substitution-based unification for type variables and effect-row variables.
2
3use crate::types::*;
4use indexmap::IndexMap;
5
6#[derive(Default)]
7pub struct Unifier {
8    next_var: TyVarId,
9    /// Substitutions: `subst[v] = t` means var `v` was bound to `t`.
10    subst: IndexMap<TyVarId, Ty>,
11    /// Effect-row substitutions: `eff_subst[v] = set` means effect-var
12    /// `v` was bound to `set` (which may itself carry another var,
13    /// so resolve_effects walks the chain).
14    eff_subst: IndexMap<u32, EffectSet>,
15    /// Counter for fresh effect-row variables, separate from type
16    /// variables to keep the namespaces clean.
17    next_eff_var: u32,
18}
19
20impl Unifier {
21    pub fn new() -> Self { Self::default() }
22
23    pub fn fresh(&mut self) -> Ty {
24        let v = self.next_var;
25        self.next_var += 1;
26        Ty::Var(v)
27    }
28
29    pub fn fresh_id(&mut self) -> TyVarId {
30        let v = self.next_var;
31        self.next_var += 1;
32        v
33    }
34
35    /// Allocate a fresh effect-row variable for use in polymorphic
36    /// signatures (e.g. `list.map[T, U, E]`'s `E`).
37    pub fn fresh_eff_id(&mut self) -> u32 {
38        let v = self.next_eff_var;
39        self.next_eff_var += 1;
40        v
41    }
42
43    /// Resolve a type by following substitutions. Recursive; structural.
44    pub fn resolve(&self, t: &Ty) -> Ty {
45        match t {
46            Ty::Var(v) => match self.subst.get(v) {
47                Some(t2) => self.resolve(t2),
48                None => Ty::Var(*v),
49            },
50            Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
51            Ty::List(inner) => Ty::List(Box::new(self.resolve(inner))),
52            Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| self.resolve(t)).collect()),
53            Ty::Record(fs) => {
54                let mut out = IndexMap::new();
55                for (k, v) in fs { out.insert(k.clone(), self.resolve(v)); }
56                Ty::Record(out)
57            }
58            Ty::Con(n, args) => Ty::Con(n.clone(), args.iter().map(|t| self.resolve(t)).collect()),
59            Ty::Function { params, effects, ret } => Ty::Function {
60                params: params.iter().map(|t| self.resolve(t)).collect(),
61                effects: self.resolve_effects(effects),
62                ret: Box::new(self.resolve(ret)),
63            },
64        }
65    }
66
67    /// Resolve an effect set by chasing the `var` substitution chain.
68    /// Concrete effects accumulate along the chain; the returned set's
69    /// `var` is the terminal unbound var, or `None` if fully concrete.
70    pub fn resolve_effects(&self, eff: &EffectSet) -> EffectSet {
71        let mut out = EffectSet { concrete: eff.concrete.clone(), var: None };
72        let mut cur_var = eff.var;
73        while let Some(v) = cur_var {
74            match self.eff_subst.get(&v) {
75                Some(bound) => {
76                    out.concrete.extend(bound.concrete.iter().cloned());
77                    cur_var = bound.var;
78                }
79                None => { out.var = Some(v); break; }
80            }
81        }
82        out
83    }
84
85    /// Unify two effect sets. Variables are existentially bound at the
86    /// signature site; at call sites they bind to the actual closure's
87    /// effects.
88    ///
89    /// Cases (after resolving):
90    ///   - both fully concrete: must be equal
91    ///   - exactly one carries a var: var := the *missing* effects
92    ///     (i.e. the other side's concrete minus this side's), with
93    ///     the other side's residual var if any
94    ///   - both carry a var: bind one to the other (alias)
95    pub fn unify_effects(&mut self, a: &EffectSet, b: &EffectSet) -> Result<(), UnifyError> {
96        let a = self.resolve_effects(a);
97        let b = self.resolve_effects(b);
98        match (a.var, b.var) {
99            (None, None) => {
100                if a.concrete == b.concrete { Ok(()) }
101                else { Err(UnifyError::EffectMismatch { a, b }) }
102            }
103            (Some(va), Some(vb)) if va == vb => {
104                if a.concrete == b.concrete { Ok(()) }
105                else { Err(UnifyError::EffectMismatch { a, b }) }
106            }
107            (Some(va), _) => {
108                // Bind va so a + bound = b. Means bound has b.concrete
109                // minus a.concrete, plus b.var.
110                if !a.concrete.is_subset(&b.concrete) {
111                    // a says "at least these" but b doesn't have all
112                    // of them and isn't open enough to absorb. Reject.
113                    if b.var.is_none() {
114                        return Err(UnifyError::EffectMismatch { a, b });
115                    }
116                    // b has a var; we can absorb a's extras into b's var
117                    // by binding b's var symmetrically. Easier: just
118                    // bind va to (b's concrete + b's var) and rely on
119                    // the chain to track. Tighter handling possible
120                    // later.
121                }
122                let extra: std::collections::BTreeSet<crate::types::EffectKind> =
123                    b.concrete.difference(&a.concrete).cloned().collect();
124                let bound = EffectSet { concrete: extra, var: b.var };
125                self.eff_subst.insert(va, bound);
126                Ok(())
127            }
128            (None, Some(vb)) => {
129                if !b.concrete.is_subset(&a.concrete) {
130                    return Err(UnifyError::EffectMismatch { a, b });
131                }
132                let extra: std::collections::BTreeSet<crate::types::EffectKind> =
133                    a.concrete.difference(&b.concrete).cloned().collect();
134                let bound = EffectSet { concrete: extra, var: None };
135                self.eff_subst.insert(vb, bound);
136                Ok(())
137            }
138        }
139    }
140
141    pub fn unify(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
142        let a = self.resolve(a);
143        let b = self.resolve(b);
144        match (&a, &b) {
145            (Ty::Var(v), other) | (other, Ty::Var(v)) => {
146                if let Ty::Var(w) = other {
147                    if v == w { return Ok(()); }
148                }
149                if occurs(*v, other, self) {
150                    return Err(UnifyError::Infinite { var: *v, ty: other.clone() });
151                }
152                self.subst.insert(*v, other.clone());
153                Ok(())
154            }
155            (Ty::Prim(p1), Ty::Prim(p2)) if p1 == p2 => Ok(()),
156            (Ty::Unit, Ty::Unit) | (Ty::Never, Ty::Never) => Ok(()),
157            // Never is a subtype of everything (bottom). For unification we
158            // treat it as compatible with any type.
159            (Ty::Never, _) | (_, Ty::Never) => Ok(()),
160            (Ty::List(t1), Ty::List(t2)) => self.unify(t1, t2),
161            (Ty::Tuple(xs), Ty::Tuple(ys)) if xs.len() == ys.len() => {
162                for (x, y) in xs.iter().zip(ys.iter()) { self.unify(x, y)?; }
163                Ok(())
164            }
165            (Ty::Record(a), Ty::Record(b)) => {
166                if a.len() != b.len() {
167                    return Err(UnifyError::Mismatch { a: Ty::Record(a.clone()), b: Ty::Record(b.clone()) });
168                }
169                for (k, va) in a {
170                    match b.get(k) {
171                        Some(vb) => self.unify(va, vb)?,
172                        None => return Err(UnifyError::Mismatch {
173                            a: Ty::Record(a.clone()), b: Ty::Record(b.clone())
174                        }),
175                    }
176                }
177                Ok(())
178            }
179            (Ty::Con(n1, a1), Ty::Con(n2, a2)) if n1 == n2 && a1.len() == a2.len() => {
180                for (x, y) in a1.iter().zip(a2.iter()) { self.unify(x, y)?; }
181                Ok(())
182            }
183            (Ty::Function { params: p1, effects: e1, ret: r1 },
184             Ty::Function { params: p2, effects: e2, ret: r2 })
185                if p1.len() == p2.len() =>
186            {
187                for (x, y) in p1.iter().zip(p2.iter()) { self.unify(x, y)?; }
188                self.unify_effects(e1, e2)?;
189                self.unify(r1, r2)
190            }
191            _ => Err(UnifyError::Mismatch { a, b }),
192        }
193    }
194}
195
196fn occurs(v: TyVarId, t: &Ty, u: &Unifier) -> bool {
197    let t = u.resolve(t);
198    match t {
199        Ty::Var(w) => v == w,
200        Ty::Prim(_) | Ty::Unit | Ty::Never => false,
201        Ty::List(inner) => occurs(v, &inner, u),
202        Ty::Tuple(items) => items.iter().any(|t| occurs(v, t, u)),
203        Ty::Record(fs) => fs.values().any(|t| occurs(v, t, u)),
204        Ty::Con(_, args) => args.iter().any(|t| occurs(v, t, u)),
205        Ty::Function { params, ret, .. } => {
206            params.iter().any(|t| occurs(v, t, u)) || occurs(v, &ret, u)
207        }
208    }
209}
210
211#[derive(Debug, Clone)]
212pub enum UnifyError {
213    Mismatch { a: Ty, b: Ty },
214    Infinite { var: TyVarId, ty: Ty },
215    EffectMismatch { a: EffectSet, b: EffectSet },
216}