1use crate::types::*;
4use indexmap::IndexMap;
5
6#[derive(Default)]
7pub struct Unifier {
8 next_var: TyVarId,
9 subst: IndexMap<TyVarId, Ty>,
11 eff_subst: IndexMap<u32, EffectSet>,
15 next_eff_var: u32,
18}
19
20impl Unifier {
21 pub fn new() -> Self { Self::default() }
22
23 pub fn fresh(&mut self) -> Ty {
24 let v = self.next_var;
25 self.next_var += 1;
26 Ty::Var(v)
27 }
28
29 pub fn fresh_id(&mut self) -> TyVarId {
30 let v = self.next_var;
31 self.next_var += 1;
32 v
33 }
34
35 pub fn fresh_eff_id(&mut self) -> u32 {
38 let v = self.next_eff_var;
39 self.next_eff_var += 1;
40 v
41 }
42
43 pub fn resolve(&self, t: &Ty) -> Ty {
45 match t {
46 Ty::Var(v) => match self.subst.get(v) {
47 Some(t2) => self.resolve(t2),
48 None => Ty::Var(*v),
49 },
50 Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
51 Ty::List(inner) => Ty::List(Box::new(self.resolve(inner))),
52 Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| self.resolve(t)).collect()),
53 Ty::Record(fs) => {
54 let mut out = IndexMap::new();
55 for (k, v) in fs { out.insert(k.clone(), self.resolve(v)); }
56 Ty::Record(out)
57 }
58 Ty::Con(n, args) => Ty::Con(n.clone(), args.iter().map(|t| self.resolve(t)).collect()),
59 Ty::Function { params, effects, ret } => Ty::Function {
60 params: params.iter().map(|t| self.resolve(t)).collect(),
61 effects: self.resolve_effects(effects),
62 ret: Box::new(self.resolve(ret)),
63 },
64 }
65 }
66
67 pub fn resolve_effects(&self, eff: &EffectSet) -> EffectSet {
71 let mut out = EffectSet { concrete: eff.concrete.clone(), var: None };
72 let mut cur_var = eff.var;
73 while let Some(v) = cur_var {
74 match self.eff_subst.get(&v) {
75 Some(bound) => {
76 out.concrete.extend(bound.concrete.iter().cloned());
77 cur_var = bound.var;
78 }
79 None => { out.var = Some(v); break; }
80 }
81 }
82 out
83 }
84
85 pub fn unify_effects(&mut self, a: &EffectSet, b: &EffectSet) -> Result<(), UnifyError> {
96 let a = self.resolve_effects(a);
97 let b = self.resolve_effects(b);
98 match (a.var, b.var) {
99 (None, None) => {
100 if a.concrete == b.concrete { Ok(()) }
101 else { Err(UnifyError::EffectMismatch { a, b }) }
102 }
103 (Some(va), Some(vb)) if va == vb => {
104 if a.concrete == b.concrete { Ok(()) }
105 else { Err(UnifyError::EffectMismatch { a, b }) }
106 }
107 (Some(va), _) => {
108 if !a.concrete.is_subset(&b.concrete) {
111 if b.var.is_none() {
114 return Err(UnifyError::EffectMismatch { a, b });
115 }
116 }
122 let extra: std::collections::BTreeSet<String> =
123 b.concrete.difference(&a.concrete).cloned().collect();
124 let bound = EffectSet { concrete: extra, var: b.var };
125 self.eff_subst.insert(va, bound);
126 Ok(())
127 }
128 (None, Some(vb)) => {
129 if !b.concrete.is_subset(&a.concrete) {
130 return Err(UnifyError::EffectMismatch { a, b });
131 }
132 let extra: std::collections::BTreeSet<String> =
133 a.concrete.difference(&b.concrete).cloned().collect();
134 let bound = EffectSet { concrete: extra, var: None };
135 self.eff_subst.insert(vb, bound);
136 Ok(())
137 }
138 }
139 }
140
141 pub fn unify(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
142 let a = self.resolve(a);
143 let b = self.resolve(b);
144 match (&a, &b) {
145 (Ty::Var(v), other) | (other, Ty::Var(v)) => {
146 if let Ty::Var(w) = other {
147 if v == w { return Ok(()); }
148 }
149 if occurs(*v, other, self) {
150 return Err(UnifyError::Infinite { var: *v, ty: other.clone() });
151 }
152 self.subst.insert(*v, other.clone());
153 Ok(())
154 }
155 (Ty::Prim(p1), Ty::Prim(p2)) if p1 == p2 => Ok(()),
156 (Ty::Unit, Ty::Unit) | (Ty::Never, Ty::Never) => Ok(()),
157 (Ty::Never, _) | (_, Ty::Never) => Ok(()),
160 (Ty::List(t1), Ty::List(t2)) => self.unify(t1, t2),
161 (Ty::Tuple(xs), Ty::Tuple(ys)) if xs.len() == ys.len() => {
162 for (x, y) in xs.iter().zip(ys.iter()) { self.unify(x, y)?; }
163 Ok(())
164 }
165 (Ty::Record(a), Ty::Record(b)) => {
166 if a.len() != b.len() {
167 return Err(UnifyError::Mismatch { a: Ty::Record(a.clone()), b: Ty::Record(b.clone()) });
168 }
169 for (k, va) in a {
170 match b.get(k) {
171 Some(vb) => self.unify(va, vb)?,
172 None => return Err(UnifyError::Mismatch {
173 a: Ty::Record(a.clone()), b: Ty::Record(b.clone())
174 }),
175 }
176 }
177 Ok(())
178 }
179 (Ty::Con(n1, a1), Ty::Con(n2, a2)) if n1 == n2 && a1.len() == a2.len() => {
180 for (x, y) in a1.iter().zip(a2.iter()) { self.unify(x, y)?; }
181 Ok(())
182 }
183 (Ty::Function { params: p1, effects: e1, ret: r1 },
184 Ty::Function { params: p2, effects: e2, ret: r2 })
185 if p1.len() == p2.len() =>
186 {
187 for (x, y) in p1.iter().zip(p2.iter()) { self.unify(x, y)?; }
188 self.unify_effects(e1, e2)?;
189 self.unify(r1, r2)
190 }
191 _ => Err(UnifyError::Mismatch { a, b }),
192 }
193 }
194}
195
196fn occurs(v: TyVarId, t: &Ty, u: &Unifier) -> bool {
197 let t = u.resolve(t);
198 match t {
199 Ty::Var(w) => v == w,
200 Ty::Prim(_) | Ty::Unit | Ty::Never => false,
201 Ty::List(inner) => occurs(v, &inner, u),
202 Ty::Tuple(items) => items.iter().any(|t| occurs(v, t, u)),
203 Ty::Record(fs) => fs.values().any(|t| occurs(v, t, u)),
204 Ty::Con(_, args) => args.iter().any(|t| occurs(v, t, u)),
205 Ty::Function { params, ret, .. } => {
206 params.iter().any(|t| occurs(v, t, u)) || occurs(v, &ret, u)
207 }
208 }
209}
210
211#[derive(Debug, Clone)]
212pub enum UnifyError {
213 Mismatch { a: Ty, b: Ty },
214 Infinite { var: TyVarId, ty: Ty },
215 EffectMismatch { a: EffectSet, b: EffectSet },
216}