1use crate::types::{StackType, Type};
9use std::collections::HashMap;
10
/// Bindings for type variables: maps a variable name to the type that replaces it.
pub type TypeSubst = HashMap<String, Type>;

/// Bindings for row variables: maps a variable name to the stack type that replaces it.
pub type RowSubst = HashMap<String, StackType>;
16
17#[derive(Debug, Clone, PartialEq)]
19pub struct Subst {
20 pub types: TypeSubst,
21 pub rows: RowSubst,
22}
23
24impl Subst {
25 pub fn empty() -> Self {
27 Subst {
28 types: HashMap::new(),
29 rows: HashMap::new(),
30 }
31 }
32
33 pub fn apply_type(&self, ty: &Type) -> Type {
35 match ty {
36 Type::Var(name) => self.types.get(name).cloned().unwrap_or(ty.clone()),
37 _ => ty.clone(),
38 }
39 }
40
41 pub fn apply_stack(&self, stack: &StackType) -> StackType {
43 match stack {
44 StackType::Empty => StackType::Empty,
45 StackType::Cons { rest, top } => {
46 let new_rest = self.apply_stack(rest);
47 let new_top = self.apply_type(top);
48 StackType::Cons {
49 rest: Box::new(new_rest),
50 top: new_top,
51 }
52 }
53 StackType::RowVar(name) => self.rows.get(name).cloned().unwrap_or(stack.clone()),
54 }
55 }
56
57 pub fn compose(&self, other: &Subst) -> Subst {
60 let mut types = HashMap::new();
61 let mut rows = HashMap::new();
62
63 for (k, v) in &self.types {
65 types.insert(k.clone(), other.apply_type(v));
66 }
67
68 for (k, v) in &other.types {
70 let v_subst = self.apply_type(v);
71 types.insert(k.clone(), v_subst);
72 }
73
74 for (k, v) in &self.rows {
76 rows.insert(k.clone(), other.apply_stack(v));
77 }
78
79 for (k, v) in &other.rows {
81 let v_subst = self.apply_stack(v);
82 rows.insert(k.clone(), v_subst);
83 }
84
85 Subst { types, rows }
86 }
87}
88
89fn occurs_in_type(var: &str, ty: &Type) -> bool {
103 match ty {
104 Type::Var(name) => name == var,
105 Type::Int
107 | Type::Float
108 | Type::Bool
109 | Type::String
110 | Type::Symbol
111 | Type::Channel
112 | Type::Union(_) => false,
113 Type::Quotation(effect) => {
114 occurs_in_stack(var, &effect.inputs) || occurs_in_stack(var, &effect.outputs)
116 }
117 Type::Closure { effect, captures } => {
118 occurs_in_stack(var, &effect.inputs)
120 || occurs_in_stack(var, &effect.outputs)
121 || captures.iter().any(|t| occurs_in_type(var, t))
122 }
123 }
124}
125
126fn occurs_in_stack(var: &str, stack: &StackType) -> bool {
128 match stack {
129 StackType::Empty => false,
130 StackType::RowVar(name) => name == var,
131 StackType::Cons { rest, top: _ } => {
132 occurs_in_stack(var, rest)
135 }
136 }
137}
138
139pub fn unify_types(t1: &Type, t2: &Type) -> Result<Subst, String> {
141 match (t1, t2) {
142 (Type::Int, Type::Int)
144 | (Type::Float, Type::Float)
145 | (Type::Bool, Type::Bool)
146 | (Type::String, Type::String)
147 | (Type::Symbol, Type::Symbol)
148 | (Type::Channel, Type::Channel) => Ok(Subst::empty()),
149
150 (Type::Union(name1), Type::Union(name2)) => {
152 if name1 == name2 {
153 Ok(Subst::empty())
154 } else {
155 Err(format!(
156 "Type mismatch: cannot unify Union({}) with Union({})",
157 name1, name2
158 ))
159 }
160 }
161
162 (Type::Var(name), ty) | (ty, Type::Var(name)) => {
164 if matches!(ty, Type::Var(ty_name) if ty_name == name) {
166 return Ok(Subst::empty());
167 }
168
169 if occurs_in_type(name, ty) {
171 return Err(format!(
172 "Occurs check failed: cannot unify {:?} with {:?} (would create infinite type)",
173 Type::Var(name.clone()),
174 ty
175 ));
176 }
177
178 let mut subst = Subst::empty();
179 subst.types.insert(name.clone(), ty.clone());
180 Ok(subst)
181 }
182
183 (Type::Quotation(effect1), Type::Quotation(effect2)) => {
185 let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
187
188 let out1 = s_in.apply_stack(&effect1.outputs);
190 let out2 = s_in.apply_stack(&effect2.outputs);
191 let s_out = unify_stacks(&out1, &out2)?;
192
193 Ok(s_in.compose(&s_out))
195 }
196
197 (
201 Type::Closure {
202 effect: effect1, ..
203 },
204 Type::Closure {
205 effect: effect2, ..
206 },
207 ) => {
208 let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
210
211 let out1 = s_in.apply_stack(&effect1.outputs);
213 let out2 = s_in.apply_stack(&effect2.outputs);
214 let s_out = unify_stacks(&out1, &out2)?;
215
216 Ok(s_in.compose(&s_out))
218 }
219
220 (Type::Quotation(quot_effect), Type::Closure { effect, .. })
224 | (Type::Closure { effect, .. }, Type::Quotation(quot_effect)) => {
225 let s_in = unify_stacks("_effect.inputs, &effect.inputs)?;
227
228 let out1 = s_in.apply_stack("_effect.outputs);
230 let out2 = s_in.apply_stack(&effect.outputs);
231 let s_out = unify_stacks(&out1, &out2)?;
232
233 Ok(s_in.compose(&s_out))
235 }
236
237 _ => Err(format!("Type mismatch: cannot unify {} with {}", t1, t2)),
239 }
240}
241
242pub fn unify_stacks(s1: &StackType, s2: &StackType) -> Result<Subst, String> {
244 match (s1, s2) {
245 (StackType::Empty, StackType::Empty) => Ok(Subst::empty()),
247
248 (StackType::RowVar(name), stack) | (stack, StackType::RowVar(name)) => {
250 if matches!(stack, StackType::RowVar(stack_name) if stack_name == name) {
252 return Ok(Subst::empty());
253 }
254
255 if occurs_in_stack(name, stack) {
257 return Err(format!(
258 "Occurs check failed: cannot unify {} with {} (would create infinite stack type)",
259 StackType::RowVar(name.clone()),
260 stack
261 ));
262 }
263
264 let mut subst = Subst::empty();
265 subst.rows.insert(name.clone(), stack.clone());
266 Ok(subst)
267 }
268
269 (
271 StackType::Cons {
272 rest: rest1,
273 top: top1,
274 },
275 StackType::Cons {
276 rest: rest2,
277 top: top2,
278 },
279 ) => {
280 let s_top = unify_types(top1, top2)?;
282
283 let rest1_subst = s_top.apply_stack(rest1);
285 let rest2_subst = s_top.apply_stack(rest2);
286 let s_rest = unify_stacks(&rest1_subst, &rest2_subst)?;
287
288 Ok(s_top.compose(&s_rest))
290 }
291
292 _ => Err(format!(
294 "Stack shape mismatch: cannot unify {} with {}",
295 s1, s2
296 )),
297 }
298}
299
300#[cfg(test)]
301mod tests;