Skip to main content

blr_lang/compiler/crust/
ty.rs

1use std::collections::HashSet;
2
3use ena::unify::{EqUnifyValue, UnifyKey};
4
5use super::{Evidence, Label};
6
7#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
8pub struct ClosedRow {
9    pub fields: Vec<Label>,
10    pub values: Vec<Type>,
11}
12impl ClosedRow {
13    /// Merge two disjoint rows into a new row.
14    pub fn merge(left: ClosedRow, right: ClosedRow) -> ClosedRow {
15        let mut left_fields = left.fields.into_iter().peekable();
16        let mut left_values = left.values.into_iter();
17        let mut right_fields = right.fields.into_iter().peekable();
18        let mut right_values = right.values.into_iter();
19
20        let mut fields = vec![];
21        let mut values = vec![];
22
23        // Since our input rows are already sorted we can explit that and not worry about resorting
24        // them here, we just have to merge our two sorted rows.
25        loop {
26            match (left_fields.peek(), right_fields.peek()) {
27                (Some(left), Some(right)) => {
28                    if left <= right {
29                        fields.push(left_fields.next().unwrap());
30                        values.push(left_values.next().unwrap());
31                    } else {
32                        fields.push(right_fields.next().unwrap());
33                        values.push(right_values.next().unwrap());
34                    }
35                }
36                (Some(_), None) => {
37                    fields.extend(left_fields);
38                    values.extend(left_values);
39                    break;
40                }
41                (None, Some(_)) => {
42                    fields.extend(right_fields);
43                    values.extend(right_values);
44                    break;
45                }
46                (None, None) => {
47                    break;
48                }
49            }
50        }
51
52        ClosedRow { fields, values }
53    }
54
55    /// Check if our closed row mentions any of our unbound types or rows.
56    pub fn mentions(
57        &self,
58        unbound_tys: &HashSet<TypeUniVar>,
59        unbound_rows: &HashSet<RowUniVar>,
60    ) -> bool {
61        for ty in self.values.iter() {
62            if ty.mentions(unbound_tys, unbound_rows) {
63                return true;
64            }
65        }
66        false
67    }
68
69    pub fn fields_and_values(&self) -> impl Iterator<Item = (&Label, &Type)> {
70        self.fields.iter().zip(self.values.iter())
71    }
72}
73impl EqUnifyValue for ClosedRow {}
74
75#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
76pub enum Row {
77    /// A row to solve for?
78    Unifier(RowUniVar),
79    /// An unknown row?
80    Open(RowVar),
81    /// A known row?
82    Closed(ClosedRow),
83}
84
85impl EqUnifyValue for Row {}
86impl Row {
87    pub fn single<S: ToString>(lbl: S, ty: Type) -> Self {
88        Row::Closed(ClosedRow {
89            fields: vec![lbl.to_string()],
90            values: vec![ty],
91        })
92    }
93
94    /// This is not strcit equality (like we get with Eq).
95    /// This instead checks a looser sense of equality
96    /// that is helpful during unification.
97    pub fn equatable(&self, other: &Self) -> bool {
98        match (self, other) {
99            // Unifier rows are equatable when their variables are equal
100            (Row::Unifier(a), Row::Unifier(b)) => a == b,
101            // Open rows are equatable when their variables are equal
102            (Row::Open(a), Row::Open(b)) => a == b,
103            // Closed rows are equatable when their fields are equal
104            (Row::Closed(a), Row::Closed(b)) => a.fields == b.fields,
105            // Anything else is not equatable
106            _ => false,
107        }
108    }
109}
110
111/// Our type
112/// Each AST node in our input will be annotated by a value of `Type`
113/// after type inference succeeeds.
114#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, Hash)]
115pub enum Type {
116    /// Empty Type
117    Unit,
118    /// Type of integers
119    Int,
120    /// Type of floating point numbers
121    Float,
122    /// Type of strings
123    String,
124    /// A type variable, stands for a value of Type
125    Unifier(TypeUniVar),
126    /// A rigid type variable, cannot be unified like a normal type variable.
127    Var(TypeVar),
128    /// A curried abstraction type
129    Abs(Box<Self>, Box<Self>),
130    /// A product type
131    Prod(Row),
132    /// A sum type
133    Sum(Row),
134    /// Type of singleton rows
135    Label(Label, Box<Self>),
136    /// DataFrame
137    DataFrame,
138}
139
140impl EqUnifyValue for Type {}
141impl Type {
142    pub fn abstraction(arg: Self, ret: Self) -> Self {
143        Self::Abs(Box::new(arg), Box::new(ret))
144    }
145    pub fn abstractions<T>(arg: T, ret: Self) -> Self
146    where
147        T: IntoIterator<Item = Type>,
148        <T as IntoIterator>::IntoIter: DoubleEndedIterator<Item = Type>,
149    {
150        arg.into_iter()
151            .rfold(ret, |ret, param| Self::Abs(Box::new(param), Box::new(ret)))
152    }
153
154    pub fn label(label: Label, value: Self) -> Self {
155        Self::Label(label, Box::new(value))
156    }
157
158    pub fn occurs_check(&self, var: TypeUniVar) -> Result<(), Type> {
159        match self {
160            Type::Unit
161            | Type::Int
162            | Type::Float
163            | Type::String
164            | Type::Var(_)
165            | Type::DataFrame => Ok(()),
166            Type::Unifier(v) => {
167                if *v == var {
168                    Err(Type::Unifier(*v))
169                } else {
170                    Ok(())
171                }
172            }
173            Type::Abs(arg, ret) => {
174                arg.occurs_check(var).map_err(|_| self.clone())?;
175                ret.occurs_check(var).map_err(|_| self.clone())
176            }
177            Type::Label(_, ty) => ty.occurs_check(var).map_err(|_| self.clone()),
178            Type::Prod(row) | Type::Sum(row) => match row {
179                Row::Unifier(_) => Ok(()),
180                Row::Open(_) => Ok(()),
181                Row::Closed(closed_row) => {
182                    for ty in closed_row.values.iter() {
183                        ty.occurs_check(var).map_err(|_| self.clone())?
184                    }
185                    Ok(())
186                }
187            },
188        }
189    }
190
191    pub fn mentions(
192        &self,
193        unbound_tys: &HashSet<TypeUniVar>,
194        unbound_rows: &HashSet<RowUniVar>,
195    ) -> bool {
196        match self {
197            Type::Unit
198            | Type::Int
199            | Type::Float
200            | Type::String
201            | Type::Var(_)
202            | Type::DataFrame => false,
203            Type::Unifier(v) => unbound_tys.contains(v),
204            Type::Abs(arg, ret) => {
205                arg.mentions(unbound_tys, unbound_rows) || ret.mentions(unbound_tys, unbound_rows)
206            }
207            Type::Label(_, ty) => ty.mentions(unbound_tys, unbound_rows),
208            Type::Prod(row) | Type::Sum(row) => match row {
209                Row::Unifier(var) => unbound_rows.contains(var),
210                // Rigid variables can only exist as bound, so they cannot appear in unbound_rows.
211                Row::Open(_) => false,
212                Row::Closed(row) => row.mentions(unbound_tys, unbound_rows),
213            },
214        }
215    }
216}
217
/// A rigid row variable, wrapping a numeric id.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct RowVar(pub u32);
220
/// A row unification variable, identified by `id`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct RowUniVar {
    /// Index of this variable.
    pub id: u32,
}

impl RowUniVar {
    /// Create the row unification variable with the given `id`.
    pub fn new(id: u32) -> Self {
        RowUniVar { id }
    }
}
230
231impl UnifyKey for RowUniVar {
232    type Value = Option<Row>;
233
234    fn index(&self) -> u32 {
235        self.id
236    }
237
238    fn from_index(id: u32) -> Self {
239        Self::new(id)
240    }
241
242    fn tag() -> &'static str {
243        "RowUniVar"
244    }
245}
246
/// A rigid type variable, wrapping a numeric id.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeVar(pub u32);
249
/// A type unification variable, identified by `id`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TypeUniVar {
    /// Index of this variable.
    pub id: u32,
}
impl TypeUniVar {
    /// Create the type unification variable with the given `id`.
    ///
    /// Public for parity with `RowUniVar::new`. The `id` field is already public, so this
    /// exposes nothing new.
    pub fn new(id: u32) -> Self {
        Self { id }
    }
}
259impl UnifyKey for TypeUniVar {
260    type Value = Option<Type>;
261
262    fn index(&self) -> u32 {
263        self.id
264    }
265
266    fn from_index(id: u32) -> Self {
267        Self::new(id)
268    }
269
270    fn tag() -> &'static str {
271        "TypeUniVar"
272    }
273}
274
275#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
276pub struct RowCombination {
277    pub left: Row,
278    pub right: Row,
279    pub goal: Row,
280}
281impl RowCombination {
282    /// Two rows are unifiable if two of their components are equatable.
283    /// A row can be uniquely determined by two of it's components (the third is calculated from
284    /// the two). Because of this whenever rows agree on two components we can unify both rows and
285    /// possible learn new information about the third row.
286    ///
287    /// This only works because our row combinations are commutative.
288    pub fn is_unifiable(&self, other: &Self) -> bool {
289        let left_equatable = self.left.equatable(&other.left);
290        let right_equatable = self.right.equatable(&other.right);
291        let goal_equatable = self.goal.equatable(&other.goal);
292        (goal_equatable && (left_equatable || right_equatable))
293            || (left_equatable && right_equatable)
294    }
295
296    /// Check unifiability the same way as `is_unifiable` but commutes the arguments.
297    /// So we check left against right, and right against left. Goal is still checked against goal.
298    pub fn is_comm_unifiable(&self, other: &Self) -> bool {
299        let left_equatable = self.left.equatable(&other.right);
300        let right_equatable = self.right.equatable(&other.left);
301        let goal_equatable = self.goal.equatable(&other.goal);
302        (goal_equatable && (left_equatable || right_equatable))
303            || (left_equatable && right_equatable)
304    }
305
306    pub fn into_evidence(self) -> Evidence {
307        Evidence::RowEquation {
308            left: self.left,
309            right: self.right,
310            goal: self.goal,
311        }
312    }
313}