1use crate::{
2 error::TypeError,
3 types::{BuiltinTypeId, Scheme, Type, TypeKind, TypeVar, TypeVarId, Types},
4};
5use rex_ast::{Span, Symbol};
6use rpds::HashTrieMapSync;
7
/// A substitution: a persistent map from type-variable ids to the types they
/// stand for. `insert` returns a new map rather than mutating in place, so
/// substitutions in this module are built by rebinding (`s = s.insert(..)`).
pub type Subst = HashTrieMapSync<TypeVarId, Type>;
9
/// Mutable unification state that binds type variables in place.
///
/// Bindings are stored in `subs`, indexed by type-variable id; the depth
/// fields implement the nesting guard used by `Unifier::with_infer_depth`.
#[derive(Debug)]
pub(crate) struct Unifier {
    /// Binding for each type variable, indexed by its id; `None` means unbound.
    subs: Vec<Option<Type>>,
    /// Optional cap on nested `with_infer_depth` calls; `None` means unlimited.
    max_infer_depth: Option<usize>,
    /// Number of `with_infer_depth` frames currently active.
    infer_depth: usize,
}
22
23impl Unifier {
24 pub(crate) fn new(max_infer_depth: Option<usize>) -> Self {
25 Self {
26 subs: Vec::new(),
27 max_infer_depth,
28 infer_depth: 0,
29 }
30 }
31
32 pub(crate) fn with_infer_depth<T>(
33 &mut self,
34 span: Span,
35 f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
36 ) -> Result<T, TypeError> {
37 if let Some(max) = self.max_infer_depth
38 && self.infer_depth >= max
39 {
40 return Err(TypeError::Spanned {
41 span,
42 error: Box::new(TypeError::Internal(format!(
43 "maximum inference depth exceeded (max {max})"
44 ))),
45 });
46 }
47 self.infer_depth += 1;
48 let res = f(self);
49 self.infer_depth = self.infer_depth.saturating_sub(1);
50 res
51 }
52
53 fn bind_var(&mut self, id: TypeVarId, ty: Type) {
54 if id >= self.subs.len() {
55 self.subs.resize(id + 1, None);
56 }
57 self.subs[id] = Some(ty);
58 }
59
60 fn prune(&mut self, ty: &Type) -> Type {
61 match ty.as_ref() {
62 TypeKind::Var(tv) => {
63 let bound = self.subs.get(tv.id).and_then(|t| t.clone());
64 match bound {
65 Some(bound) => {
66 let pruned = self.prune(&bound);
67 self.bind_var(tv.id, pruned.clone());
68 pruned
69 }
70 None => ty.clone(),
71 }
72 }
73 TypeKind::Con(_) => ty.clone(),
74 TypeKind::App(l, r) => {
75 let l = self.prune(l);
76 let r = self.prune(r);
77 Type::app(l, r)
78 }
79 TypeKind::Fun(a, b) => {
80 let a = self.prune(a);
81 let b = self.prune(b);
82 Type::fun(a, b)
83 }
84 TypeKind::Tuple(ts) => {
85 Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
86 }
87 TypeKind::Record(fields) => Type::new(TypeKind::Record(
88 fields
89 .iter()
90 .map(|(name, ty)| (name.clone(), self.prune(ty)))
91 .collect(),
92 )),
93 }
94 }
95
96 pub(crate) fn apply_type(&mut self, ty: &Type) -> Type {
97 self.prune(ty)
98 }
99
100 fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
101 match self.prune(ty).as_ref() {
102 TypeKind::Var(tv) => tv.id == id,
103 TypeKind::Con(_) => false,
104 TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
105 TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
106 TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
107 TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
108 }
109 }
110
111 pub(crate) fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
112 let t1 = self.prune(t1);
113 let t2 = self.prune(t2);
114 match (t1.as_ref(), t2.as_ref()) {
115 (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
116 (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
117 if self.occurs(tv.id, &Type::new(other.clone())) {
118 Err(TypeError::Occurs(
119 tv.id,
120 Type::new(other.clone()).to_string(),
121 ))
122 } else {
123 self.bind_var(tv.id, Type::new(other.clone()));
124 Ok(())
125 }
126 }
127 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
128 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
129 self.unify(l1, l2)?;
130 self.unify(r1, r2)
131 }
132 (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
133 self.unify(a1, a2)?;
134 self.unify(b1, b2)
135 }
136 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
137 if ts1.len() != ts2.len() {
138 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
139 }
140 for (a, b) in ts1.iter().zip(ts2.iter()) {
141 self.unify(a, b)?;
142 }
143 Ok(())
144 }
145 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
146 if f1.len() != f2.len() {
147 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
148 }
149 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
150 if n1 != n2 {
151 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
152 }
153 self.unify(t1, t2)?;
154 }
155 Ok(())
156 }
157 (TypeKind::Record(fields), TypeKind::App(head, arg))
158 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
159 TypeKind::Con(c) if c.is_builtin(BuiltinTypeId::Dict) => {
160 let elem_ty = record_elem_type_unifier(fields, self)?;
161 self.unify(arg, &elem_ty)
162 }
163 TypeKind::Var(tv) => {
164 self.unify(
165 &Type::new(TypeKind::Var(tv.clone())),
166 &Type::builtin(BuiltinTypeId::Dict),
167 )?;
168 let elem_ty = record_elem_type_unifier(fields, self)?;
169 self.unify(arg, &elem_ty)
170 }
171 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
172 },
173 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
174 }
175 }
176
177 pub(crate) fn into_subst(mut self) -> Subst {
178 let mut out = Subst::new_sync();
179 for id in 0..self.subs.len() {
180 if let Some(ty) = self.subs[id].clone() {
181 let pruned = self.prune(&ty);
182 out = out.insert(id, pruned);
183 }
184 }
185 out
186 }
187}
188
189pub fn compose_subst(a: Subst, b: Subst) -> Subst {
194 if subst_is_empty(&a) {
195 return b;
196 }
197 if subst_is_empty(&b) {
198 return a;
199 }
200 let mut res = Subst::new_sync();
201 for (k, v) in b.iter() {
202 res = res.insert(*k, v.apply(&a));
203 }
204 for (k, v) in a.iter() {
205 res = res.insert(*k, v.clone());
206 }
207 res
208}
209
210pub(crate) fn subst_is_empty(s: &Subst) -> bool {
211 s.iter().next().is_none()
212}
213
214pub(crate) fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
215 let s = match unify(&existing.typ, &declared.typ) {
216 Ok(s) => s,
217 Err(_) => return false,
218 };
219
220 let existing_preds = existing.preds.apply(&s);
221 let declared_preds = declared.preds.apply(&s);
222
223 let mut lhs: Vec<(Symbol, String)> = existing_preds
224 .iter()
225 .map(|p| (p.class.clone(), p.typ.to_string()))
226 .collect();
227 let mut rhs: Vec<(Symbol, String)> = declared_preds
228 .iter()
229 .map(|p| (p.class.clone(), p.typ.to_string()))
230 .collect();
231 lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
232 rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
233 lhs == rhs
234}
235
236fn record_elem_type_unifier(
237 fields: &[(Symbol, Type)],
238 unifier: &mut Unifier,
239) -> Result<Type, TypeError> {
240 let mut iter = fields.iter();
241 let first = match iter.next() {
242 Some((_, ty)) => ty.clone(),
243 None => return Err(TypeError::UnsupportedExpr("empty record")),
244 };
245 for (_, ty) in iter {
246 unifier.unify(&first, ty)?;
247 }
248 Ok(unifier.apply_type(&first))
249}
250
251pub(crate) fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
252 if let TypeKind::Var(var) = t.as_ref()
253 && var.id == tv.id
254 {
255 return Ok(Subst::new_sync());
256 }
257 if t.ftv().contains(&tv.id) {
258 Err(TypeError::Occurs(tv.id, t.to_string()))
259 } else {
260 Ok(Subst::new_sync().insert(tv.id, t.clone()))
261 }
262}
263
264fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
265 let mut iter = fields.iter();
266 let first = match iter.next() {
267 Some((_, ty)) => ty.clone(),
268 None => return Err(TypeError::UnsupportedExpr("empty record")),
269 };
270 let mut subst = Subst::new_sync();
271 let mut current = first;
272 for (_, ty) in iter {
273 let s_next = unify(¤t.apply(&subst), &ty.apply(&subst))?;
274 subst = compose_subst(s_next, subst);
275 current = current.apply(&subst);
276 }
277 Ok((subst.clone(), current.apply(&subst)))
278}
279
280pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
287 match (t1.as_ref(), t2.as_ref()) {
288 (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
289 let s1 = unify(l1, l2)?;
290 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
291 Ok(compose_subst(s2, s1))
292 }
293 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
294 if f1.len() != f2.len() {
295 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
296 }
297 let mut subst = Subst::new_sync();
298 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
299 if n1 != n2 {
300 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
301 }
302 let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
303 subst = compose_subst(s_next, subst);
304 }
305 Ok(subst)
306 }
307 (TypeKind::Record(fields), TypeKind::App(head, arg))
308 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
309 TypeKind::Con(c) if c.is_builtin(BuiltinTypeId::Dict) => {
310 let (s_fields, elem_ty) = record_elem_type(fields)?;
311 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
312 Ok(compose_subst(s_arg, s_fields))
313 }
314 TypeKind::Var(tv) => {
315 let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
316 let arg = arg.apply(&s_head);
317 let (s_fields, elem_ty) = record_elem_type(fields)?;
318 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
319 Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
320 }
321 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
322 },
323 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
324 let s1 = unify(l1, l2)?;
325 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
326 Ok(compose_subst(s2, s1))
327 }
328 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
329 if ts1.len() != ts2.len() {
330 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
331 }
332 let mut s = Subst::new_sync();
333 for (a, b) in ts1.iter().zip(ts2.iter()) {
334 let s_next = unify(&a.apply(&s), &b.apply(&s))?;
335 s = compose_subst(s_next, s);
336 }
337 Ok(s)
338 }
339 (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
340 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
341 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
342 }
343}