1use std::collections::HashSet;
2
3use ena::unify::{EqUnifyValue, UnifyKey};
4
5use super::{Evidence, Label};
6
7#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
8pub struct ClosedRow {
9 pub fields: Vec<Label>,
10 pub values: Vec<Type>,
11}
12impl ClosedRow {
13 pub fn merge(left: ClosedRow, right: ClosedRow) -> ClosedRow {
15 let mut left_fields = left.fields.into_iter().peekable();
16 let mut left_values = left.values.into_iter();
17 let mut right_fields = right.fields.into_iter().peekable();
18 let mut right_values = right.values.into_iter();
19
20 let mut fields = vec![];
21 let mut values = vec![];
22
23 loop {
26 match (left_fields.peek(), right_fields.peek()) {
27 (Some(left), Some(right)) => {
28 if left <= right {
29 fields.push(left_fields.next().unwrap());
30 values.push(left_values.next().unwrap());
31 } else {
32 fields.push(right_fields.next().unwrap());
33 values.push(right_values.next().unwrap());
34 }
35 }
36 (Some(_), None) => {
37 fields.extend(left_fields);
38 values.extend(left_values);
39 break;
40 }
41 (None, Some(_)) => {
42 fields.extend(right_fields);
43 values.extend(right_values);
44 break;
45 }
46 (None, None) => {
47 break;
48 }
49 }
50 }
51
52 ClosedRow { fields, values }
53 }
54
55 pub fn mentions(
57 &self,
58 unbound_tys: &HashSet<TypeUniVar>,
59 unbound_rows: &HashSet<RowUniVar>,
60 ) -> bool {
61 for ty in self.values.iter() {
62 if ty.mentions(unbound_tys, unbound_rows) {
63 return true;
64 }
65 }
66 false
67 }
68
69 pub fn fields_and_values(&self) -> impl Iterator<Item = (&Label, &Type)> {
70 self.fields.iter().zip(self.values.iter())
71 }
72}
73impl EqUnifyValue for ClosedRow {}
74
75#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
76pub enum Row {
77 Unifier(RowUniVar),
79 Open(RowVar),
81 Closed(ClosedRow),
83}
84
85impl EqUnifyValue for Row {}
86impl Row {
87 pub fn single<S: ToString>(lbl: S, ty: Type) -> Self {
88 Row::Closed(ClosedRow {
89 fields: vec![lbl.to_string()],
90 values: vec![ty],
91 })
92 }
93
94 pub fn equatable(&self, other: &Self) -> bool {
98 match (self, other) {
99 (Row::Unifier(a), Row::Unifier(b)) => a == b,
101 (Row::Open(a), Row::Open(b)) => a == b,
103 (Row::Closed(a), Row::Closed(b)) => a.fields == b.fields,
105 _ => false,
107 }
108 }
109}
110
111#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, Hash)]
115pub enum Type {
116 Unit,
118 Int,
120 Float,
122 String,
124 Unifier(TypeUniVar),
126 Var(TypeVar),
128 Abs(Box<Self>, Box<Self>),
130 Prod(Row),
132 Sum(Row),
134 Label(Label, Box<Self>),
136 DataFrame,
138}
139
140impl EqUnifyValue for Type {}
141impl Type {
142 pub fn abstraction(arg: Self, ret: Self) -> Self {
143 Self::Abs(Box::new(arg), Box::new(ret))
144 }
145 pub fn abstractions<T>(arg: T, ret: Self) -> Self
146 where
147 T: IntoIterator<Item = Type>,
148 <T as IntoIterator>::IntoIter: DoubleEndedIterator<Item = Type>,
149 {
150 arg.into_iter()
151 .rfold(ret, |ret, param| Self::Abs(Box::new(param), Box::new(ret)))
152 }
153
154 pub fn label(label: Label, value: Self) -> Self {
155 Self::Label(label, Box::new(value))
156 }
157
158 pub fn occurs_check(&self, var: TypeUniVar) -> Result<(), Type> {
159 match self {
160 Type::Unit
161 | Type::Int
162 | Type::Float
163 | Type::String
164 | Type::Var(_)
165 | Type::DataFrame => Ok(()),
166 Type::Unifier(v) => {
167 if *v == var {
168 Err(Type::Unifier(*v))
169 } else {
170 Ok(())
171 }
172 }
173 Type::Abs(arg, ret) => {
174 arg.occurs_check(var).map_err(|_| self.clone())?;
175 ret.occurs_check(var).map_err(|_| self.clone())
176 }
177 Type::Label(_, ty) => ty.occurs_check(var).map_err(|_| self.clone()),
178 Type::Prod(row) | Type::Sum(row) => match row {
179 Row::Unifier(_) => Ok(()),
180 Row::Open(_) => Ok(()),
181 Row::Closed(closed_row) => {
182 for ty in closed_row.values.iter() {
183 ty.occurs_check(var).map_err(|_| self.clone())?
184 }
185 Ok(())
186 }
187 },
188 }
189 }
190
191 pub fn mentions(
192 &self,
193 unbound_tys: &HashSet<TypeUniVar>,
194 unbound_rows: &HashSet<RowUniVar>,
195 ) -> bool {
196 match self {
197 Type::Unit
198 | Type::Int
199 | Type::Float
200 | Type::String
201 | Type::Var(_)
202 | Type::DataFrame => false,
203 Type::Unifier(v) => unbound_tys.contains(v),
204 Type::Abs(arg, ret) => {
205 arg.mentions(unbound_tys, unbound_rows) || ret.mentions(unbound_tys, unbound_rows)
206 }
207 Type::Label(_, ty) => ty.mentions(unbound_tys, unbound_rows),
208 Type::Prod(row) | Type::Sum(row) => match row {
209 Row::Unifier(var) => unbound_rows.contains(var),
210 Row::Open(_) => false,
212 Row::Closed(row) => row.mentions(unbound_tys, unbound_rows),
213 },
214 }
215 }
216}
217
/// A row variable, identified by a numeric index.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct RowVar(pub u32);
220
/// A row unification variable, identified by its numeric `id`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct RowUniVar {
    /// Numeric identifier of this variable.
    pub id: u32,
}

impl RowUniVar {
    /// Creates the row unification variable with the given `id`.
    pub fn new(id: u32) -> Self {
        RowUniVar { id }
    }
}
230
231impl UnifyKey for RowUniVar {
232 type Value = Option<Row>;
233
234 fn index(&self) -> u32 {
235 self.id
236 }
237
238 fn from_index(id: u32) -> Self {
239 Self::new(id)
240 }
241
242 fn tag() -> &'static str {
243 "RowUniVar"
244 }
245}
246
/// A type variable, identified by a numeric index.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeVar(pub u32);
249
/// A type unification variable, identified by its numeric `id`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeUniVar {
    /// Numeric identifier of this variable.
    pub id: u32,
}

impl TypeUniVar {
    /// Creates the type unification variable with the given `id`.
    fn new(id: u32) -> Self {
        TypeUniVar { id }
    }
}
259impl UnifyKey for TypeUniVar {
260 type Value = Option<Type>;
261
262 fn index(&self) -> u32 {
263 self.id
264 }
265
266 fn from_index(id: u32) -> Self {
267 Self::new(id)
268 }
269
270 fn tag() -> &'static str {
271 "TypeUniVar"
272 }
273}
274
275#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
276pub struct RowCombination {
277 pub left: Row,
278 pub right: Row,
279 pub goal: Row,
280}
281impl RowCombination {
282 pub fn is_unifiable(&self, other: &Self) -> bool {
289 let left_equatable = self.left.equatable(&other.left);
290 let right_equatable = self.right.equatable(&other.right);
291 let goal_equatable = self.goal.equatable(&other.goal);
292 (goal_equatable && (left_equatable || right_equatable))
293 || (left_equatable && right_equatable)
294 }
295
296 pub fn is_comm_unifiable(&self, other: &Self) -> bool {
299 let left_equatable = self.left.equatable(&other.right);
300 let right_equatable = self.right.equatable(&other.left);
301 let goal_equatable = self.goal.equatable(&other.goal);
302 (goal_equatable && (left_equatable || right_equatable))
303 || (left_equatable && right_equatable)
304 }
305
306 pub fn into_evidence(self) -> Evidence {
307 Evidence::RowEquation {
308 left: self.left,
309 right: self.right,
310 goal: self.goal,
311 }
312 }
313}