1use std::{cell::RefCell, fmt::Display};
2
3use ena::unify::{InPlaceUnificationTable, UnifyKey, UnifyValue};
4use snafu::Snafu;
5use tracing::trace;
6
7use super::{
8 Constraint, Evidence, NodeId, TypeInference,
9 ty::{ClosedRow, Row, RowCombination, RowUniVar, RowVar, Type, TypeUniVar, TypeVar},
10};
11
12#[derive(Debug, PartialEq, Eq)]
13pub enum TypeErrorKind {
14 TypeNotEqual((Type, Type)),
15 InfiniteType(TypeUniVar, Type),
16 RowsNotEqual((Row, Row)),
17 CheckIntroducedExtraVariablesOrConstraints {
18 extra_types: Vec<TypeVar>,
19 extra_row: Vec<RowVar>,
20 extra_evidence: Vec<Evidence>,
21 },
22}
23
24#[derive(Debug, Snafu, PartialEq, Eq)]
25#[snafu(display("{node_id}@{kind}"))]
26pub struct TypeError {
27 pub kind: TypeErrorKind,
28 pub node_id: NodeId,
29}
30
31impl Display for TypeErrorKind {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match &self {
34 TypeErrorKind::TypeNotEqual((l, r)) => write!(f, "types not equal: {l:?} != {r:?}"),
35 TypeErrorKind::InfiniteType(_type_uni_var, _) => todo!(),
36 TypeErrorKind::RowsNotEqual((l, r)) => write!(f, "rows not equal: {l:?} != {r:?}"),
37 TypeErrorKind::CheckIntroducedExtraVariablesOrConstraints { .. } => todo!(),
38 }
39 }
40}
41
42impl From<(ClosedRow, ClosedRow)> for TypeErrorKind {
43 fn from((left, right): (ClosedRow, ClosedRow)) -> Self {
44 TypeErrorKind::RowsNotEqual((Row::Closed(left), Row::Closed(right)))
45 }
46}
47
48impl TypeInference {
50 pub(crate) fn unification(&mut self, constraints: Vec<Constraint>) -> Result<(), TypeError> {
51 for constr in constraints {
52 match constr {
53 Constraint::TypeEqual(node_id, left, right) => self
54 .unify_ty_ty(left, right)
55 .map_err(|kind| TypeError { kind, node_id })?,
56 Constraint::RowCombine(node_id, row_comb) => self
57 .unify_row_comb(row_comb)
58 .map_err(|kind| TypeError { kind, node_id })?,
59 }
60 }
61 Ok(())
62 }
63
64 fn normalize_closed_row(&mut self, closed: ClosedRow) -> ClosedRow {
65 ClosedRow {
66 fields: closed.fields,
67 values: closed
68 .values
69 .into_iter()
70 .map(|ty| self.normalize_ty(ty))
71 .collect(),
72 }
73 }
74
75 fn normalize_row(&mut self, row: Row) -> Row {
76 match row {
77 Row::Unifier(var) => match self.row_unification_table.probe_value(var) {
78 Some(Row::Closed(closed)) => Row::Closed(self.normalize_closed_row(closed)),
79 Some(row) => row,
80 None => row,
81 },
82 Row::Open(var) => Row::Open(var),
83 Row::Closed(closed) => Row::Closed(self.normalize_closed_row(closed)),
84 }
85 }
86
87 fn dispatch_any_solved(&mut self, var: RowUniVar, row: ClosedRow) -> Result<(), TypeErrorKind> {
88 let var = self.row_unification_table.find(var);
89 let mut changed_combs = vec![];
90 trace!(?var,?self.partial_row_combs,"dispatch_any_solved");
91 self.partial_row_combs = std::mem::take(&mut self.partial_row_combs)
92 .into_iter()
93 .filter_map(|comb| match comb {
94 RowCombination {
95 left: Row::Unifier(left),
96 right,
97 goal,
98 } if self.row_unification_table.find(left) == var => {
99 changed_combs.push(RowCombination {
100 left: Row::Closed(row.clone()),
101 right,
102 goal,
103 });
104 None
105 }
106 RowCombination {
107 left,
108 right: Row::Unifier(right),
109 goal,
110 } if self.row_unification_table.find(right) == var => {
111 changed_combs.push(RowCombination {
112 left,
113 right: Row::Closed(row.clone()),
114 goal,
115 });
116 None
117 }
118 RowCombination {
119 left,
120 right,
121 goal: Row::Unifier(goal),
122 } if self.row_unification_table.find(goal) == var => {
123 changed_combs.push(RowCombination {
124 left,
125 right,
126 goal: Row::Closed(row.clone()),
127 });
128 None
129 }
130 comb => Some(comb),
131 })
132 .collect();
133
134 for row_comb in changed_combs {
135 self.unify_row_comb(row_comb)?;
136 }
137 Ok(())
138 }
139
140 fn normalize_ty(&mut self, ty: Type) -> Type {
141 match ty {
142 Type::Unit => Type::Unit,
143 Type::Int => Type::Int,
144 Type::Float => Type::Float,
145 Type::String => Type::String,
146 Type::Var(var) => Type::Var(var),
147 Type::Abs(arg, ret) => {
148 let arg = self.normalize_ty(*arg);
149 let ret = self.normalize_ty(*ret);
150 Type::abstraction(arg, ret)
151 }
152 Type::Unifier(v) => match self.unification_table.probe_value(v) {
153 Some(ty) => self.normalize_ty(ty),
154 None => Type::Unifier(v),
155 },
156 Type::Label(label, ty) => {
157 let ty = self.normalize_ty(*ty);
158 Type::label(label, ty)
159 }
160 Type::Prod(row) => Type::Prod(self.normalize_row(row)),
161 Type::Sum(row) => Type::Sum(self.normalize_row(row)),
162 Type::DataFrame => Type::DataFrame,
163 }
164 }
165
166 fn unify_ty_ty(&mut self, unnorm_left: Type, unnorm_right: Type) -> Result<(), TypeErrorKind> {
167 trace!(?unnorm_left, ?unnorm_right, "unify_ty_ty");
168 let left = self.normalize_ty(unnorm_left);
169 let right = self.normalize_ty(unnorm_right);
170 match (left, right) {
171 (Type::Unit, Type::Unit) => Ok(()),
172 (Type::Int, Type::Int) => Ok(()),
173 (Type::Float, Type::Float) => Ok(()),
174 (Type::String, Type::String) => Ok(()),
175 (Type::DataFrame, Type::DataFrame) => Ok(()),
176 (Type::Var(a), Type::Var(b)) => (a == b)
177 .then_some(())
178 .ok_or(TypeErrorKind::TypeNotEqual((Type::Var(a), Type::Var(b)))),
179 (Type::Abs(a_arg, a_ret), Type::Abs(b_arg, b_ret)) => {
180 self.unify_ty_ty(*a_arg, *b_arg)?;
181 self.unify_ty_ty(*a_ret, *b_ret)
182 }
183 (Type::Unifier(a), Type::Unifier(b)) => self
184 .unification_table
185 .unify_var_var(a, b)
186 .map_err(TypeErrorKind::TypeNotEqual),
187 (Type::Unifier(v), ty) | (ty, Type::Unifier(v)) => {
188 ty.occurs_check(v)
189 .map_err(|ty| TypeErrorKind::InfiniteType(v, ty))?;
190 self.unification_table
191 .unify_var_value(v, Some(ty))
192 .map_err(TypeErrorKind::TypeNotEqual)
193 }
194 (Type::Prod(left), Type::Prod(right)) | (Type::Sum(left), Type::Sum(right)) => {
195 self.unify_row_row(left, right)
196 }
197 (Type::Label(field, ty), Type::Prod(row))
198 | (Type::Prod(row), Type::Label(field, ty))
199 | (Type::Label(field, ty), Type::Sum(row))
200 | (Type::Sum(row), Type::Label(field, ty)) => self.unify_row_row(
201 Row::Closed(ClosedRow {
202 fields: vec![field],
203 values: vec![*ty],
204 }),
205 row,
206 ),
207 (left, right) => Err(TypeErrorKind::TypeNotEqual((left, right))),
208 }
209 }
210
211 fn diff_and_unify(
215 &mut self,
216 goal: ClosedRow,
217 sub: ClosedRow,
218 ) -> Result<ClosedRow, TypeErrorKind> {
219 let mut diff_fields = vec![];
220 let mut diff_values = vec![];
221 for (field, value) in goal.fields_and_values() {
222 match sub.fields.binary_search(field) {
223 Ok(indx) => {
224 self.unify_ty_ty(value.clone(), sub.values[indx].clone())?;
225 }
226 Err(_) => {
227 diff_fields.push(field.clone());
228 diff_values.push(value.clone());
229 }
230 }
231 }
232 let mut extra_fields = vec![];
233 let mut extra_values = vec![];
234 for (field, value) in sub.fields_and_values() {
236 if goal.fields.binary_search(field).is_err() {
237 extra_fields.push(field.clone());
238 extra_values.push(value.clone());
239 }
240 }
241 if !extra_fields.is_empty() {
242 let expected = Row::Closed(ClosedRow::merge(
243 ClosedRow {
244 fields: extra_fields,
245 values: extra_values,
246 },
247 goal.clone(),
248 ));
249 let goal = Row::Closed(goal);
250 return Err(TypeErrorKind::RowsNotEqual((goal, expected)));
251 }
252
253 Ok(ClosedRow {
254 fields: diff_fields,
255 values: diff_values,
256 })
257 }
258
259 fn unify_row_row(&mut self, left: Row, right: Row) -> Result<(), TypeErrorKind> {
260 trace!(?left, ?right, "unify_row_row");
261 let left = self.normalize_row(left);
262 let right = self.normalize_row(right);
263 match (left, right) {
264 (Row::Open(left), Row::Open(right)) => {
265 (left == right)
266 .then_some(())
267 .ok_or(TypeErrorKind::RowsNotEqual((
268 Row::Open(left),
269 Row::Open(right),
270 )))
271 }
272 (Row::Unifier(left), Row::Unifier(right)) => self
273 .row_unification_table
274 .unify_var_var(left, right)
275 .map_err(TypeErrorKind::RowsNotEqual),
276 (Row::Unifier(var), Row::Open(row)) | (Row::Open(row), Row::Unifier(var)) => self
277 .row_unification_table
278 .unify_var_value(var, Some(Row::Open(row)))
279 .map_err(TypeErrorKind::RowsNotEqual),
280 (Row::Unifier(var), Row::Closed(row)) | (Row::Closed(row), Row::Unifier(var)) => {
281 self.row_unification_table
282 .unify_var_value(var, Some(Row::Closed(row.clone())))
283 .map_err(TypeErrorKind::RowsNotEqual)?;
284 self.dispatch_any_solved(var, row)
285 }
286 (Row::Closed(left), Row::Closed(right)) => {
287 if left.fields != right.fields {
289 return Err(TypeErrorKind::from((left, right)));
290 }
291
292 for (left_ty, right_ty) in left.values.into_iter().zip(right.values) {
295 self.unify_ty_ty(left_ty, right_ty)?;
296 }
297 Ok(())
298 }
299 (Row::Open(var), Row::Closed(row)) | (Row::Closed(row), Row::Open(var)) => Err(
300 TypeErrorKind::RowsNotEqual((Row::Open(var), Row::Closed(row))),
301 ),
302 }
303 }
304
305 fn unify_row_comb(&mut self, row_comb: RowCombination) -> Result<(), TypeErrorKind> {
306 let left = self.normalize_row(row_comb.left);
307 let right = self.normalize_row(row_comb.right);
308 let goal = self.normalize_row(row_comb.goal);
309 trace!(?left, ?right, ?goal, "unify_row_comb");
310 match (left, right, goal) {
311 (Row::Closed(left), Row::Closed(right), goal) => {
313 let calc_goal = ClosedRow::merge(left, right);
314 self.unify_row_row(Row::Closed(calc_goal), goal)
315 }
316 (Row::Unifier(var), Row::Closed(sub), Row::Closed(goal))
318 | (Row::Closed(sub), Row::Unifier(var), Row::Closed(goal)) => {
319 let diff_row = self.diff_and_unify(goal, sub)?;
320 self.unify_row_row(Row::Unifier(var), Row::Closed(diff_row))
321 }
322 (left, right, goal) => {
324 let new_comb = RowCombination { left, right, goal };
325 let mut poss_uni = None;
327 self.partial_row_combs = std::mem::take(&mut self.partial_row_combs)
328 .into_iter()
329 .map(|comb| {
330 let comb = RowCombination {
331 left: self.normalize_row(comb.left),
332 right: self.normalize_row(comb.right),
333 goal: self.normalize_row(comb.goal),
334 };
335 if comb.is_unifiable(&new_comb) {
336 poss_uni = Some(comb.clone());
337 } else if comb.is_comm_unifiable(&new_comb) {
339 poss_uni = Some(RowCombination {
341 left: comb.right.clone(),
342 right: comb.left.clone(),
343 goal: comb.goal.clone(),
344 });
345 }
346 comb
347 })
348 .collect();
349
350 match poss_uni {
351 Some(match_comb) => {
353 self.unify_row_row(new_comb.left, match_comb.left)?;
354 self.unify_row_row(new_comb.right, match_comb.right)?;
355 self.unify_row_row(new_comb.goal, match_comb.goal)?;
356 }
357 None => {
359 self.partial_row_combs.insert(new_comb);
360 }
361 }
362 Ok(())
363 }
364 }
365 }
366}
367
368pub struct UnificationTable<K: ena::unify::UnifyKey> {
369 table: RefCell<InPlaceUnificationTable<K>>,
370 keys: Vec<K>,
371}
372
373impl<K: ena::unify::UnifyKey> std::fmt::Debug for UnificationTable<K> {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 let mut table = self.table.borrow_mut();
376 let vars: Vec<_> = self
377 .keys
378 .iter()
379 .map(|key| {
380 let root = table.find(*key);
381 (key, root, table.probe_value(root))
382 })
383 .collect();
384
385 f.debug_struct("UnificationTable")
386 .field("vars", &vars)
387 .finish()
388 }
389}
390
391impl<K: UnifyKey> Default for UnificationTable<K> {
392 fn default() -> Self {
393 Self {
394 table: Default::default(),
395 keys: Default::default(),
396 }
397 }
398}
399
400impl<K: UnifyKey> UnificationTable<K> {
401 pub fn new_key(&mut self, value: <K as UnifyKey>::Value) -> K {
402 let k = self.table.borrow_mut().new_key(value);
403 self.keys.push(k);
404 k
405 }
406
407 pub fn unify_var_var<K1, K2>(
408 &mut self,
409 a: K1,
410 b: K2,
411 ) -> Result<(), <<K as UnifyKey>::Value as UnifyValue>::Error>
412 where
413 K1: Into<K>,
414 K2: Into<K>,
415 {
416 self.table.borrow_mut().unify_var_var(a, b)
417 }
418 pub fn unify_var_value<K1>(
419 &mut self,
420 a_id: K1,
421 b: <K as UnifyKey>::Value,
422 ) -> Result<(), <<K as UnifyKey>::Value as UnifyValue>::Error>
423 where
424 K1: Into<K>,
425 {
426 self.table.borrow_mut().unify_var_value(a_id, b)
427 }
428
429 pub fn find<K1>(&mut self, id: K1) -> K
430 where
431 K1: Into<K>,
432 {
433 self.table.borrow_mut().find(id)
434 }
435
436 pub fn probe_value<K1>(&mut self, id: K1) -> <K as UnifyKey>::Value
437 where
438 K1: Into<K>,
439 {
440 self.table.borrow_mut().probe_value(id)
441 }
442}