1use crate::types::{StackType, Type};
9use std::collections::HashMap;
10
11pub type TypeSubst = HashMap<String, Type>;
13
14pub type RowSubst = HashMap<String, StackType>;
16
17#[derive(Debug, Clone, PartialEq)]
19pub struct Subst {
20 pub types: TypeSubst,
21 pub rows: RowSubst,
22}
23
24impl Subst {
25 pub fn empty() -> Self {
27 Subst {
28 types: HashMap::new(),
29 rows: HashMap::new(),
30 }
31 }
32
33 pub fn apply_type(&self, ty: &Type) -> Type {
35 match ty {
36 Type::Var(name) => self.types.get(name).cloned().unwrap_or(ty.clone()),
37 _ => ty.clone(),
38 }
39 }
40
41 pub fn apply_stack(&self, stack: &StackType) -> StackType {
43 match stack {
44 StackType::Empty => StackType::Empty,
45 StackType::Cons { rest, top } => {
46 let new_rest = self.apply_stack(rest);
47 let new_top = self.apply_type(top);
48 StackType::Cons {
49 rest: Box::new(new_rest),
50 top: new_top,
51 }
52 }
53 StackType::RowVar(name) => self.rows.get(name).cloned().unwrap_or(stack.clone()),
54 }
55 }
56
57 pub fn compose(&self, other: &Subst) -> Subst {
60 let mut types = HashMap::new();
61 let mut rows = HashMap::new();
62
63 for (k, v) in &self.types {
65 types.insert(k.clone(), other.apply_type(v));
66 }
67
68 for (k, v) in &other.types {
70 let v_subst = self.apply_type(v);
71 types.insert(k.clone(), v_subst);
72 }
73
74 for (k, v) in &self.rows {
76 rows.insert(k.clone(), other.apply_stack(v));
77 }
78
79 for (k, v) in &other.rows {
81 let v_subst = self.apply_stack(v);
82 rows.insert(k.clone(), v_subst);
83 }
84
85 Subst { types, rows }
86 }
87}
88
89fn occurs_in_type(var: &str, ty: &Type) -> bool {
103 match ty {
104 Type::Var(name) => name == var,
105 Type::Int
107 | Type::Float
108 | Type::Bool
109 | Type::String
110 | Type::Symbol
111 | Type::Channel
112 | Type::Union(_) => false,
113 Type::Quotation(effect) => {
114 occurs_in_stack(var, &effect.inputs) || occurs_in_stack(var, &effect.outputs)
116 }
117 Type::Closure { effect, captures } => {
118 occurs_in_stack(var, &effect.inputs)
120 || occurs_in_stack(var, &effect.outputs)
121 || captures.iter().any(|t| occurs_in_type(var, t))
122 }
123 }
124}
125
126fn occurs_in_stack(var: &str, stack: &StackType) -> bool {
128 match stack {
129 StackType::Empty => false,
130 StackType::RowVar(name) => name == var,
131 StackType::Cons { rest, top: _ } => {
132 occurs_in_stack(var, rest)
135 }
136 }
137}
138
139pub fn unify_types(t1: &Type, t2: &Type) -> Result<Subst, String> {
141 match (t1, t2) {
142 (Type::Int, Type::Int)
144 | (Type::Float, Type::Float)
145 | (Type::Bool, Type::Bool)
146 | (Type::String, Type::String)
147 | (Type::Symbol, Type::Symbol)
148 | (Type::Channel, Type::Channel) => Ok(Subst::empty()),
149
150 (Type::Union(name1), Type::Union(name2)) => {
152 if name1 == name2 {
153 Ok(Subst::empty())
154 } else {
155 Err(format!(
156 "Type mismatch: cannot unify Union({}) with Union({})",
157 name1, name2
158 ))
159 }
160 }
161
162 (Type::Var(name), ty) | (ty, Type::Var(name)) => {
164 if matches!(ty, Type::Var(ty_name) if ty_name == name) {
166 return Ok(Subst::empty());
167 }
168
169 if occurs_in_type(name, ty) {
171 return Err(format!(
172 "Occurs check failed: cannot unify {:?} with {:?} (would create infinite type)",
173 Type::Var(name.clone()),
174 ty
175 ));
176 }
177
178 let mut subst = Subst::empty();
179 subst.types.insert(name.clone(), ty.clone());
180 Ok(subst)
181 }
182
183 (Type::Quotation(effect1), Type::Quotation(effect2)) => {
185 let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
187
188 let out1 = s_in.apply_stack(&effect1.outputs);
190 let out2 = s_in.apply_stack(&effect2.outputs);
191 let s_out = unify_stacks(&out1, &out2)?;
192
193 Ok(s_in.compose(&s_out))
195 }
196
197 (
201 Type::Closure {
202 effect: effect1, ..
203 },
204 Type::Closure {
205 effect: effect2, ..
206 },
207 ) => {
208 let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
210
211 let out1 = s_in.apply_stack(&effect1.outputs);
213 let out2 = s_in.apply_stack(&effect2.outputs);
214 let s_out = unify_stacks(&out1, &out2)?;
215
216 Ok(s_in.compose(&s_out))
218 }
219
220 (Type::Quotation(quot_effect), Type::Closure { effect, .. })
224 | (Type::Closure { effect, .. }, Type::Quotation(quot_effect)) => {
225 let s_in = unify_stacks("_effect.inputs, &effect.inputs)?;
227
228 let out1 = s_in.apply_stack("_effect.outputs);
230 let out2 = s_in.apply_stack(&effect.outputs);
231 let s_out = unify_stacks(&out1, &out2)?;
232
233 Ok(s_in.compose(&s_out))
235 }
236
237 _ => Err(format!("Type mismatch: cannot unify {} with {}", t1, t2)),
239 }
240}
241
242pub fn unify_stacks(s1: &StackType, s2: &StackType) -> Result<Subst, String> {
244 match (s1, s2) {
245 (StackType::Empty, StackType::Empty) => Ok(Subst::empty()),
247
248 (StackType::RowVar(name), stack) | (stack, StackType::RowVar(name)) => {
250 if matches!(stack, StackType::RowVar(stack_name) if stack_name == name) {
252 return Ok(Subst::empty());
253 }
254
255 if occurs_in_stack(name, stack) {
257 return Err(format!(
258 "Occurs check failed: cannot unify {} with {} (would create infinite stack type)",
259 StackType::RowVar(name.clone()),
260 stack
261 ));
262 }
263
264 let mut subst = Subst::empty();
265 subst.rows.insert(name.clone(), stack.clone());
266 Ok(subst)
267 }
268
269 (
271 StackType::Cons {
272 rest: rest1,
273 top: top1,
274 },
275 StackType::Cons {
276 rest: rest2,
277 top: top2,
278 },
279 ) => {
280 let s_top = unify_types(top1, top2)?;
282
283 let rest1_subst = s_top.apply_stack(rest1);
285 let rest2_subst = s_top.apply_stack(rest2);
286 let s_rest = unify_stacks(&rest1_subst, &rest2_subst)?;
287
288 Ok(s_top.compose(&s_rest))
290 }
291
292 _ => Err(format!(
294 "Stack shape mismatch: cannot unify {} with {}",
295 s1, s2
296 )),
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a type variable with the given name.
    fn tvar(name: &str) -> Type {
        Type::Var(name.to_string())
    }

    /// Builds a row variable with the given name.
    fn row(name: &str) -> StackType {
        StackType::RowVar(name.to_string())
    }

    #[test]
    fn test_unify_concrete_types() {
        for ty in &[Type::Int, Type::Bool, Type::String] {
            assert!(unify_types(ty, ty).is_ok());
        }

        assert!(unify_types(&Type::Int, &Type::Bool).is_err());
    }

    #[test]
    fn test_unify_type_variable() {
        let var_left = unify_types(&tvar("T"), &Type::Int).unwrap();
        assert_eq!(var_left.types.get("T"), Some(&Type::Int));

        let var_right = unify_types(&Type::Bool, &tvar("U")).unwrap();
        assert_eq!(var_right.types.get("U"), Some(&Type::Bool));
    }

    #[test]
    fn test_unify_empty_stacks() {
        assert!(unify_stacks(&StackType::Empty, &StackType::Empty).is_ok());
    }

    #[test]
    fn test_unify_row_variable() {
        let int_stack = StackType::singleton(Type::Int);
        let subst = unify_stacks(&row("a"), &int_stack).unwrap();

        assert_eq!(subst.rows.get("a"), Some(&int_stack));
    }

    #[test]
    fn test_unify_cons_stacks() {
        let lhs = StackType::singleton(Type::Int);
        let rhs = StackType::singleton(Type::Int);

        assert!(unify_stacks(&lhs, &rhs).is_ok());
    }

    #[test]
    fn test_unify_cons_with_type_var() {
        let lhs = StackType::singleton(tvar("T"));
        let rhs = StackType::singleton(Type::Int);

        let subst = unify_stacks(&lhs, &rhs).unwrap();
        assert_eq!(subst.types.get("T"), Some(&Type::Int));
    }

    #[test]
    fn test_unify_row_poly_stack() {
        // ( ..a Int ) against ( Bool Int ): the row variable absorbs the Bool below.
        let declared = row("a").push(Type::Int);
        let actual = StackType::Empty.push(Type::Bool).push(Type::Int);

        let subst = unify_stacks(&declared, &actual).unwrap();

        assert_eq!(subst.rows.get("a"), Some(&StackType::singleton(Type::Bool)));
    }

    #[test]
    fn test_unify_polymorphic_dup() {
        // dup : ( ..a T -- ..a T T ), applied to a stack holding a single Int.
        let actual_in = StackType::singleton(Type::Int);
        let declared_in = row("a").push(tvar("T"));

        let subst = unify_stacks(&declared_in, &actual_in).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&StackType::Empty));
        assert_eq!(subst.types.get("T"), Some(&Type::Int));

        let declared_out = row("a").push(tvar("T")).push(tvar("T"));
        let expected_out = StackType::Empty.push(Type::Int).push(Type::Int);
        assert_eq!(subst.apply_stack(&declared_out), expected_out);
    }

    #[test]
    fn test_subst_compose() {
        let mut first = Subst::empty();
        first.types.insert("T".to_string(), Type::Int);

        let mut second = Subst::empty();
        second.types.insert("U".to_string(), tvar("T"));

        let composed = first.compose(&second);

        assert_eq!(composed.types.get("T"), Some(&Type::Int));
        assert_eq!(composed.types.get("U"), Some(&Type::Int));
    }

    #[test]
    fn test_occurs_check_type_var_with_itself() {
        let outcome = unify_types(&tvar("T"), &tvar("T"));
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().types.is_empty());
    }

    #[test]
    fn test_occurs_check_row_var_with_itself() {
        let outcome = unify_stacks(&row("a"), &row("a"));
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().rows.is_empty());
    }

    #[test]
    fn test_occurs_check_prevents_infinite_stack() {
        // ..a ~ ..a Int would require ..a to contain itself.
        let outcome = unify_stacks(&row("a"), &row("a").push(Type::Int));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err();
        assert!(message.contains("Occurs check failed"));
        assert!(message.contains("infinite"));
    }

    #[test]
    fn test_occurs_check_allows_different_row_vars() {
        let outcome = unify_stacks(&row("a"), &row("b"));
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().rows.get("a"), Some(&row("b")));
    }

    #[test]
    fn test_occurs_check_allows_concrete_stack() {
        let concrete = StackType::Empty.push(Type::Int).push(Type::String);

        let outcome = unify_stacks(&row("a"), &concrete);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().rows.get("a"), Some(&concrete));
    }

    #[test]
    fn test_occurs_in_type() {
        assert!(occurs_in_type("T", &tvar("T")));
        assert!(!occurs_in_type("T", &tvar("U")));

        for concrete in &[Type::Int, Type::String, Type::Bool] {
            assert!(!occurs_in_type("T", concrete));
        }
    }

    #[test]
    fn test_occurs_in_stack() {
        assert!(occurs_in_stack("a", &row("a")));
        assert!(!occurs_in_stack("a", &row("b")));
        assert!(!occurs_in_stack("a", &StackType::Empty));

        assert!(occurs_in_stack("a", &row("a").push(Type::Int)));
        assert!(!occurs_in_stack("a", &row("b").push(Type::Int)));

        let concrete = StackType::Empty.push(Type::Int).push(Type::String);
        assert!(!occurs_in_stack("a", &concrete));
    }

    #[test]
    fn test_quotation_type_unification_stack_neutral() {
        use crate::types::Effect;

        // [ ..a -- ..a ] leaves the stack as it is; [ ..b -- ..b Int ] grows it by one.
        let stack_neutral = Type::Quotation(Box::new(Effect::new(row("a"), row("a"))));
        let pushes_int = Type::Quotation(Box::new(Effect::new(
            row("b"),
            row("b").push(Type::Int),
        )));

        let result = unify_types(&stack_neutral, &pushes_int);
        assert!(
            result.is_err(),
            "Unifying stack-neutral with stack-pushing quotation should fail, got {:?}",
            result
        );
    }
}