Skip to main content

blr_lang/compiler/crust/
unification.rs

1use std::{cell::RefCell, fmt::Display};
2
3use ena::unify::{InPlaceUnificationTable, UnifyKey, UnifyValue};
4use snafu::Snafu;
5use tracing::trace;
6
7use super::{
8    Constraint, Evidence, NodeId, TypeInference,
9    ty::{ClosedRow, Row, RowCombination, RowUniVar, RowVar, Type, TypeUniVar, TypeVar},
10};
11
12#[derive(Debug, PartialEq, Eq)]
13pub enum TypeErrorKind {
14    TypeNotEqual((Type, Type)),
15    InfiniteType(TypeUniVar, Type),
16    RowsNotEqual((Row, Row)),
17    CheckIntroducedExtraVariablesOrConstraints {
18        extra_types: Vec<TypeVar>,
19        extra_row: Vec<RowVar>,
20        extra_evidence: Vec<Evidence>,
21    },
22}
23
24#[derive(Debug, Snafu, PartialEq, Eq)]
25#[snafu(display("{node_id}@{kind}"))]
26pub struct TypeError {
27    pub kind: TypeErrorKind,
28    pub node_id: NodeId,
29}
30
31impl Display for TypeErrorKind {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match &self {
34            TypeErrorKind::TypeNotEqual((l, r)) => write!(f, "types not equal: {l:?} != {r:?}"),
35            TypeErrorKind::InfiniteType(_type_uni_var, _) => todo!(),
36            TypeErrorKind::RowsNotEqual((l, r)) => write!(f, "rows not equal: {l:?} != {r:?}"),
37            TypeErrorKind::CheckIntroducedExtraVariablesOrConstraints { .. } => todo!(),
38        }
39    }
40}
41
42impl From<(ClosedRow, ClosedRow)> for TypeErrorKind {
43    fn from((left, right): (ClosedRow, ClosedRow)) -> Self {
44        TypeErrorKind::RowsNotEqual((Row::Closed(left), Row::Closed(right)))
45    }
46}
47
48/// Constraint solving
49impl TypeInference {
50    pub(crate) fn unification(&mut self, constraints: Vec<Constraint>) -> Result<(), TypeError> {
51        for constr in constraints {
52            match constr {
53                Constraint::TypeEqual(node_id, left, right) => self
54                    .unify_ty_ty(left, right)
55                    .map_err(|kind| TypeError { kind, node_id })?,
56                Constraint::RowCombine(node_id, row_comb) => self
57                    .unify_row_comb(row_comb)
58                    .map_err(|kind| TypeError { kind, node_id })?,
59            }
60        }
61        Ok(())
62    }
63
64    fn normalize_closed_row(&mut self, closed: ClosedRow) -> ClosedRow {
65        ClosedRow {
66            fields: closed.fields,
67            values: closed
68                .values
69                .into_iter()
70                .map(|ty| self.normalize_ty(ty))
71                .collect(),
72        }
73    }
74
75    fn normalize_row(&mut self, row: Row) -> Row {
76        match row {
77            Row::Unifier(var) => match self.row_unification_table.probe_value(var) {
78                Some(Row::Closed(closed)) => Row::Closed(self.normalize_closed_row(closed)),
79                Some(row) => row,
80                None => row,
81            },
82            Row::Open(var) => Row::Open(var),
83            Row::Closed(closed) => Row::Closed(self.normalize_closed_row(closed)),
84        }
85    }
86
87    fn dispatch_any_solved(&mut self, var: RowUniVar, row: ClosedRow) -> Result<(), TypeErrorKind> {
88        let var = self.row_unification_table.find(var);
89        let mut changed_combs = vec![];
90        trace!(?var,?self.partial_row_combs,"dispatch_any_solved");
91        self.partial_row_combs = std::mem::take(&mut self.partial_row_combs)
92            .into_iter()
93            .filter_map(|comb| match comb {
94                RowCombination {
95                    left: Row::Unifier(left),
96                    right,
97                    goal,
98                } if self.row_unification_table.find(left) == var => {
99                    changed_combs.push(RowCombination {
100                        left: Row::Closed(row.clone()),
101                        right,
102                        goal,
103                    });
104                    None
105                }
106                RowCombination {
107                    left,
108                    right: Row::Unifier(right),
109                    goal,
110                } if self.row_unification_table.find(right) == var => {
111                    changed_combs.push(RowCombination {
112                        left,
113                        right: Row::Closed(row.clone()),
114                        goal,
115                    });
116                    None
117                }
118                RowCombination {
119                    left,
120                    right,
121                    goal: Row::Unifier(goal),
122                } if self.row_unification_table.find(goal) == var => {
123                    changed_combs.push(RowCombination {
124                        left,
125                        right,
126                        goal: Row::Closed(row.clone()),
127                    });
128                    None
129                }
130                comb => Some(comb),
131            })
132            .collect();
133
134        for row_comb in changed_combs {
135            self.unify_row_comb(row_comb)?;
136        }
137        Ok(())
138    }
139
140    fn normalize_ty(&mut self, ty: Type) -> Type {
141        match ty {
142            Type::Unit => Type::Unit,
143            Type::Int => Type::Int,
144            Type::Float => Type::Float,
145            Type::String => Type::String,
146            Type::Var(var) => Type::Var(var),
147            Type::Abs(arg, ret) => {
148                let arg = self.normalize_ty(*arg);
149                let ret = self.normalize_ty(*ret);
150                Type::abstraction(arg, ret)
151            }
152            Type::Unifier(v) => match self.unification_table.probe_value(v) {
153                Some(ty) => self.normalize_ty(ty),
154                None => Type::Unifier(v),
155            },
156            Type::Label(label, ty) => {
157                let ty = self.normalize_ty(*ty);
158                Type::label(label, ty)
159            }
160            Type::Prod(row) => Type::Prod(self.normalize_row(row)),
161            Type::Sum(row) => Type::Sum(self.normalize_row(row)),
162            Type::DataFrame => Type::DataFrame,
163        }
164    }
165
166    fn unify_ty_ty(&mut self, unnorm_left: Type, unnorm_right: Type) -> Result<(), TypeErrorKind> {
167        trace!(?unnorm_left, ?unnorm_right, "unify_ty_ty");
168        let left = self.normalize_ty(unnorm_left);
169        let right = self.normalize_ty(unnorm_right);
170        match (left, right) {
171            (Type::Unit, Type::Unit) => Ok(()),
172            (Type::Int, Type::Int) => Ok(()),
173            (Type::Float, Type::Float) => Ok(()),
174            (Type::String, Type::String) => Ok(()),
175            (Type::DataFrame, Type::DataFrame) => Ok(()),
176            (Type::Var(a), Type::Var(b)) => (a == b)
177                .then_some(())
178                .ok_or(TypeErrorKind::TypeNotEqual((Type::Var(a), Type::Var(b)))),
179            (Type::Abs(a_arg, a_ret), Type::Abs(b_arg, b_ret)) => {
180                self.unify_ty_ty(*a_arg, *b_arg)?;
181                self.unify_ty_ty(*a_ret, *b_ret)
182            }
183            (Type::Unifier(a), Type::Unifier(b)) => self
184                .unification_table
185                .unify_var_var(a, b)
186                .map_err(TypeErrorKind::TypeNotEqual),
187            (Type::Unifier(v), ty) | (ty, Type::Unifier(v)) => {
188                ty.occurs_check(v)
189                    .map_err(|ty| TypeErrorKind::InfiniteType(v, ty))?;
190                self.unification_table
191                    .unify_var_value(v, Some(ty))
192                    .map_err(TypeErrorKind::TypeNotEqual)
193            }
194            (Type::Prod(left), Type::Prod(right)) | (Type::Sum(left), Type::Sum(right)) => {
195                self.unify_row_row(left, right)
196            }
197            (Type::Label(field, ty), Type::Prod(row))
198            | (Type::Prod(row), Type::Label(field, ty))
199            | (Type::Label(field, ty), Type::Sum(row))
200            | (Type::Sum(row), Type::Label(field, ty)) => self.unify_row_row(
201                Row::Closed(ClosedRow {
202                    fields: vec![field],
203                    values: vec![*ty],
204                }),
205                row,
206            ),
207            (left, right) => Err(TypeErrorKind::TypeNotEqual((left, right))),
208        }
209    }
210
211    /// Calculate the set difference of the goal row and the sub row, returning it as a new row.
212    /// Unify the subset of the goal row that matches the sub row.
213    /// When goal row is not a superset of sub it is a type error.
214    fn diff_and_unify(
215        &mut self,
216        goal: ClosedRow,
217        sub: ClosedRow,
218    ) -> Result<ClosedRow, TypeErrorKind> {
219        let mut diff_fields = vec![];
220        let mut diff_values = vec![];
221        for (field, value) in goal.fields_and_values() {
222            match sub.fields.binary_search(field) {
223                Ok(indx) => {
224                    self.unify_ty_ty(value.clone(), sub.values[indx].clone())?;
225                }
226                Err(_) => {
227                    diff_fields.push(field.clone());
228                    diff_values.push(value.clone());
229                }
230            }
231        }
232        let mut extra_fields = vec![];
233        let mut extra_values = vec![];
234        // Find any fields in sub that do not exist in goal, this is a type error
235        for (field, value) in sub.fields_and_values() {
236            if goal.fields.binary_search(field).is_err() {
237                extra_fields.push(field.clone());
238                extra_values.push(value.clone());
239            }
240        }
241        if !extra_fields.is_empty() {
242            let expected = Row::Closed(ClosedRow::merge(
243                ClosedRow {
244                    fields: extra_fields,
245                    values: extra_values,
246                },
247                goal.clone(),
248            ));
249            let goal = Row::Closed(goal);
250            return Err(TypeErrorKind::RowsNotEqual((goal, expected)));
251        }
252
253        Ok(ClosedRow {
254            fields: diff_fields,
255            values: diff_values,
256        })
257    }
258
259    fn unify_row_row(&mut self, left: Row, right: Row) -> Result<(), TypeErrorKind> {
260        trace!(?left, ?right, "unify_row_row");
261        let left = self.normalize_row(left);
262        let right = self.normalize_row(right);
263        match (left, right) {
264            (Row::Open(left), Row::Open(right)) => {
265                (left == right)
266                    .then_some(())
267                    .ok_or(TypeErrorKind::RowsNotEqual((
268                        Row::Open(left),
269                        Row::Open(right),
270                    )))
271            }
272            (Row::Unifier(left), Row::Unifier(right)) => self
273                .row_unification_table
274                .unify_var_var(left, right)
275                .map_err(TypeErrorKind::RowsNotEqual),
276            (Row::Unifier(var), Row::Open(row)) | (Row::Open(row), Row::Unifier(var)) => self
277                .row_unification_table
278                .unify_var_value(var, Some(Row::Open(row)))
279                .map_err(TypeErrorKind::RowsNotEqual),
280            (Row::Unifier(var), Row::Closed(row)) | (Row::Closed(row), Row::Unifier(var)) => {
281                self.row_unification_table
282                    .unify_var_value(var, Some(Row::Closed(row.clone())))
283                    .map_err(TypeErrorKind::RowsNotEqual)?;
284                self.dispatch_any_solved(var, row)
285            }
286            (Row::Closed(left), Row::Closed(right)) => {
287                // Check that our rows are unifiable
288                if left.fields != right.fields {
289                    return Err(TypeErrorKind::from((left, right)));
290                }
291
292                // If they are, our values are already in order so we can walk them and unify the
293                // types
294                for (left_ty, right_ty) in left.values.into_iter().zip(right.values) {
295                    self.unify_ty_ty(left_ty, right_ty)?;
296                }
297                Ok(())
298            }
299            (Row::Open(var), Row::Closed(row)) | (Row::Closed(row), Row::Open(var)) => Err(
300                TypeErrorKind::RowsNotEqual((Row::Open(var), Row::Closed(row))),
301            ),
302        }
303    }
304
305    fn unify_row_comb(&mut self, row_comb: RowCombination) -> Result<(), TypeErrorKind> {
306        let left = self.normalize_row(row_comb.left);
307        let right = self.normalize_row(row_comb.right);
308        let goal = self.normalize_row(row_comb.goal);
309        trace!(?left, ?right, ?goal, "unify_row_comb");
310        match (left, right, goal) {
311            // 0 (and 1) variable(s) case
312            (Row::Closed(left), Row::Closed(right), goal) => {
313                let calc_goal = ClosedRow::merge(left, right);
314                self.unify_row_row(Row::Closed(calc_goal), goal)
315            }
316            // 1 variable cases
317            (Row::Unifier(var), Row::Closed(sub), Row::Closed(goal))
318            | (Row::Closed(sub), Row::Unifier(var), Row::Closed(goal)) => {
319                let diff_row = self.diff_and_unify(goal, sub)?;
320                self.unify_row_row(Row::Unifier(var), Row::Closed(diff_row))
321            }
322            // 2+ variable cases
323            (left, right, goal) => {
324                let new_comb = RowCombination { left, right, goal };
325                // Check if we've already seen an combination that we can unify against
326                let mut poss_uni = None;
327                self.partial_row_combs = std::mem::take(&mut self.partial_row_combs)
328                    .into_iter()
329                    .map(|comb| {
330                        let comb = RowCombination {
331                            left: self.normalize_row(comb.left),
332                            right: self.normalize_row(comb.right),
333                            goal: self.normalize_row(comb.goal),
334                        };
335                        if comb.is_unifiable(&new_comb) {
336                            poss_uni = Some(comb.clone());
337                        //Row combinations commute so we have to check for that possible unification
338                        } else if comb.is_comm_unifiable(&new_comb) {
339                            // We commute our combination so we unify the correct rows later
340                            poss_uni = Some(RowCombination {
341                                left: comb.right.clone(),
342                                right: comb.left.clone(),
343                                goal: comb.goal.clone(),
344                            });
345                        }
346                        comb
347                    })
348                    .collect();
349
350                match poss_uni {
351                    // Unify if we have a match
352                    Some(match_comb) => {
353                        self.unify_row_row(new_comb.left, match_comb.left)?;
354                        self.unify_row_row(new_comb.right, match_comb.right)?;
355                        self.unify_row_row(new_comb.goal, match_comb.goal)?;
356                    }
357                    // Otherwise add our combination to our list of partial combinations
358                    None => {
359                        self.partial_row_combs.insert(new_comb);
360                    }
361                }
362                Ok(())
363            }
364        }
365    }
366}
367
368pub struct UnificationTable<K: ena::unify::UnifyKey> {
369    table: RefCell<InPlaceUnificationTable<K>>,
370    keys: Vec<K>,
371}
372
373impl<K: ena::unify::UnifyKey> std::fmt::Debug for UnificationTable<K> {
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        let mut table = self.table.borrow_mut();
376        let vars: Vec<_> = self
377            .keys
378            .iter()
379            .map(|key| {
380                let root = table.find(*key);
381                (key, root, table.probe_value(root))
382            })
383            .collect();
384
385        f.debug_struct("UnificationTable")
386            .field("vars", &vars)
387            .finish()
388    }
389}
390
391impl<K: UnifyKey> Default for UnificationTable<K> {
392    fn default() -> Self {
393        Self {
394            table: Default::default(),
395            keys: Default::default(),
396        }
397    }
398}
399
400impl<K: UnifyKey> UnificationTable<K> {
401    pub fn new_key(&mut self, value: <K as UnifyKey>::Value) -> K {
402        let k = self.table.borrow_mut().new_key(value);
403        self.keys.push(k);
404        k
405    }
406
407    pub fn unify_var_var<K1, K2>(
408        &mut self,
409        a: K1,
410        b: K2,
411    ) -> Result<(), <<K as UnifyKey>::Value as UnifyValue>::Error>
412    where
413        K1: Into<K>,
414        K2: Into<K>,
415    {
416        self.table.borrow_mut().unify_var_var(a, b)
417    }
418    pub fn unify_var_value<K1>(
419        &mut self,
420        a_id: K1,
421        b: <K as UnifyKey>::Value,
422    ) -> Result<(), <<K as UnifyKey>::Value as UnifyValue>::Error>
423    where
424        K1: Into<K>,
425    {
426        self.table.borrow_mut().unify_var_value(a_id, b)
427    }
428
429    pub fn find<K1>(&mut self, id: K1) -> K
430    where
431        K1: Into<K>,
432    {
433        self.table.borrow_mut().find(id)
434    }
435
436    pub fn probe_value<K1>(&mut self, id: K1) -> <K as UnifyKey>::Value
437    where
438        K1: Into<K>,
439    {
440        self.table.borrow_mut().probe_value(id)
441    }
442}