use crate::types::{StackType, Type};
use std::collections::HashMap;
/// Bindings for type variables: variable name → the `Type` it stands for.
pub type TypeSubst = std::collections::HashMap<String, Type>;
/// Bindings for row (stack) variables: variable name → the `StackType` it stands for.
pub type RowSubst = std::collections::HashMap<String, StackType>;
/// A substitution over both kinds of variables in the type system.
///
/// Type variables and row variables are kept in separate maps, so a type
/// variable and a row variable that happen to share a name are distinct.
/// `Default` yields the empty substitution (same as `Subst::empty()`).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Subst {
    /// Type-variable bindings (`Type::Var` name → bound type).
    pub types: TypeSubst,
    /// Row-variable bindings (`StackType::RowVar` name → bound stack).
    pub rows: RowSubst,
}
impl Subst {
pub fn empty() -> Self {
Subst {
types: HashMap::new(),
rows: HashMap::new(),
}
}
pub fn apply_type(&self, ty: &Type) -> Type {
match ty {
Type::Var(name) => self.types.get(name).cloned().unwrap_or(ty.clone()),
_ => ty.clone(),
}
}
pub fn apply_stack(&self, stack: &StackType) -> StackType {
match stack {
StackType::Empty => StackType::Empty,
StackType::Cons { rest, top } => {
let new_rest = self.apply_stack(rest);
let new_top = self.apply_type(top);
StackType::Cons {
rest: Box::new(new_rest),
top: new_top,
}
}
StackType::RowVar(name) => self.rows.get(name).cloned().unwrap_or(stack.clone()),
}
}
pub fn compose(&self, other: &Subst) -> Subst {
let mut types = HashMap::new();
let mut rows = HashMap::new();
for (k, v) in &self.types {
types.insert(k.clone(), other.apply_type(v));
}
for (k, v) in &other.types {
let v_subst = self.apply_type(v);
types.insert(k.clone(), v_subst);
}
for (k, v) in &self.rows {
rows.insert(k.clone(), other.apply_stack(v));
}
for (k, v) in &other.rows {
let v_subst = self.apply_stack(v);
rows.insert(k.clone(), v_subst);
}
Subst { types, rows }
}
}
fn occurs_in_type(var: &str, ty: &Type) -> bool {
match ty {
Type::Var(name) => name == var,
Type::Int
| Type::Float
| Type::Bool
| Type::String
| Type::Symbol
| Type::Channel
| Type::Union(_) => false,
Type::Quotation(effect) => {
occurs_in_stack(var, &effect.inputs) || occurs_in_stack(var, &effect.outputs)
}
Type::Closure { effect, captures } => {
occurs_in_stack(var, &effect.inputs)
|| occurs_in_stack(var, &effect.outputs)
|| captures.iter().any(|t| occurs_in_type(var, t))
}
}
}
fn occurs_in_stack(var: &str, stack: &StackType) -> bool {
match stack {
StackType::Empty => false,
StackType::RowVar(name) => name == var,
StackType::Cons { rest, top: _ } => {
occurs_in_stack(var, rest)
}
}
}
pub fn unify_types(t1: &Type, t2: &Type) -> Result<Subst, String> {
match (t1, t2) {
(Type::Int, Type::Int)
| (Type::Float, Type::Float)
| (Type::Bool, Type::Bool)
| (Type::String, Type::String)
| (Type::Symbol, Type::Symbol)
| (Type::Channel, Type::Channel) => Ok(Subst::empty()),
(Type::Union(name1), Type::Union(name2)) => {
if name1 == name2 {
Ok(Subst::empty())
} else {
Err(format!(
"Type mismatch: cannot unify Union({}) with Union({})",
name1, name2
))
}
}
(Type::Var(name), ty) | (ty, Type::Var(name)) => {
if matches!(ty, Type::Var(ty_name) if ty_name == name) {
return Ok(Subst::empty());
}
if occurs_in_type(name, ty) {
return Err(format!(
"Occurs check failed: cannot unify {:?} with {:?} (would create infinite type)",
Type::Var(name.clone()),
ty
));
}
let mut subst = Subst::empty();
subst.types.insert(name.clone(), ty.clone());
Ok(subst)
}
(Type::Quotation(effect1), Type::Quotation(effect2)) => {
let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
let out1 = s_in.apply_stack(&effect1.outputs);
let out2 = s_in.apply_stack(&effect2.outputs);
let s_out = unify_stacks(&out1, &out2)?;
Ok(s_in.compose(&s_out))
}
(
Type::Closure {
effect: effect1, ..
},
Type::Closure {
effect: effect2, ..
},
) => {
let s_in = unify_stacks(&effect1.inputs, &effect2.inputs)?;
let out1 = s_in.apply_stack(&effect1.outputs);
let out2 = s_in.apply_stack(&effect2.outputs);
let s_out = unify_stacks(&out1, &out2)?;
Ok(s_in.compose(&s_out))
}
(Type::Quotation(quot_effect), Type::Closure { effect, .. })
| (Type::Closure { effect, .. }, Type::Quotation(quot_effect)) => {
let s_in = unify_stacks("_effect.inputs, &effect.inputs)?;
let out1 = s_in.apply_stack("_effect.outputs);
let out2 = s_in.apply_stack(&effect.outputs);
let s_out = unify_stacks(&out1, &out2)?;
Ok(s_in.compose(&s_out))
}
_ => Err(format!("Type mismatch: cannot unify {} with {}", t1, t2)),
}
}
/// Unifies two stack types, returning a substitution that makes them equal.
///
/// - Two empty stacks unify trivially.
/// - A row variable unifies with itself trivially. Otherwise it is bound to
///   the other stack, unless it occurs inside that stack.
/// - Two non-empty stacks unify top element first. That substitution is
///   applied to both remainders before they are unified, and the two steps
///   are composed.
///
/// # Errors
///
/// Returns a message in any of these cases:
/// - the stack shapes differ;
/// - the top elements fail to unify;
/// - the occurs check rejects a row binding.
pub fn unify_stacks(s1: &StackType, s2: &StackType) -> Result<Subst, String> {
    match (s1, s2) {
        (StackType::Empty, StackType::Empty) => Ok(Subst::empty()),
        (StackType::RowVar(row), other) | (other, StackType::RowVar(row)) => {
            // `..a ~ ..a` needs no binding at all.
            if let StackType::RowVar(other_row) = other {
                if other_row == row {
                    return Ok(Subst::empty());
                }
            }
            if occurs_in_stack(row, other) {
                return Err(format!(
                    "Occurs check failed: cannot unify {} with {} (would create infinite stack type)",
                    StackType::RowVar(row.clone()),
                    other
                ));
            }
            let mut binding = Subst::empty();
            binding.rows.insert(row.clone(), other.clone());
            Ok(binding)
        }
        (
            StackType::Cons {
                rest: below1,
                top: head1,
            },
            StackType::Cons {
                rest: below2,
                top: head2,
            },
        ) => {
            let head_subst = unify_types(head1, head2)?;
            let below_subst = unify_stacks(
                &head_subst.apply_stack(below1),
                &head_subst.apply_stack(below2),
            )?;
            Ok(head_subst.compose(&below_subst))
        }
        _ => Err(format!(
            "Stack shape mismatch: cannot unify {} with {}",
            s1, s2
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Small constructors that keep the fixtures below readable.
    fn tvar(name: &str) -> Type {
        Type::Var(name.to_string())
    }

    fn rvar(name: &str) -> StackType {
        StackType::RowVar(name.to_string())
    }

    #[test]
    fn test_unify_concrete_types() {
        for ty in &[Type::Int, Type::Bool, Type::String] {
            assert!(unify_types(ty, ty).is_ok());
        }
        assert!(unify_types(&Type::Int, &Type::Bool).is_err());
    }

    #[test]
    fn test_unify_type_variable() {
        let left_var = unify_types(&tvar("T"), &Type::Int).unwrap();
        assert_eq!(left_var.types.get("T"), Some(&Type::Int));

        let right_var = unify_types(&Type::Bool, &tvar("U")).unwrap();
        assert_eq!(right_var.types.get("U"), Some(&Type::Bool));
    }

    #[test]
    fn test_unify_empty_stacks() {
        let empty = StackType::Empty;
        assert!(unify_stacks(&empty, &empty).is_ok());
    }

    #[test]
    fn test_unify_row_variable() {
        let ints = StackType::singleton(Type::Int);
        let subst = unify_stacks(&rvar("a"), &ints).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&ints));
    }

    #[test]
    fn test_unify_cons_stacks() {
        let ints = StackType::singleton(Type::Int);
        assert!(unify_stacks(&ints, &StackType::singleton(Type::Int)).is_ok());
    }

    #[test]
    fn test_unify_cons_with_type_var() {
        let declared = StackType::singleton(tvar("T"));
        let actual = StackType::singleton(Type::Int);
        let subst = unify_stacks(&declared, &actual).unwrap();
        assert_eq!(subst.types.get("T"), Some(&Type::Int));
    }

    #[test]
    fn test_unify_row_poly_stack() {
        let declared = rvar("a").push(Type::Int);
        let actual = StackType::Empty.push(Type::Bool).push(Type::Int);
        let subst = unify_stacks(&declared, &actual).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&StackType::singleton(Type::Bool)));
    }

    #[test]
    fn test_unify_polymorphic_dup() {
        // ( ..a T -- ..a T T ) applied to a stack holding a single Int.
        let declared_in = rvar("a").push(tvar("T"));
        let actual_in = StackType::singleton(Type::Int);
        let subst = unify_stacks(&declared_in, &actual_in).unwrap();
        assert_eq!(subst.rows.get("a"), Some(&StackType::Empty));
        assert_eq!(subst.types.get("T"), Some(&Type::Int));

        let declared_out = rvar("a").push(tvar("T")).push(tvar("T"));
        assert_eq!(
            subst.apply_stack(&declared_out),
            StackType::Empty.push(Type::Int).push(Type::Int)
        );
    }

    #[test]
    fn test_subst_compose() {
        let mut first = Subst::empty();
        first.types.insert("T".to_string(), Type::Int);
        let mut second = Subst::empty();
        second.types.insert("U".to_string(), tvar("T"));

        let composed = first.compose(&second);
        assert_eq!(composed.types.get("T"), Some(&Type::Int));
        assert_eq!(composed.types.get("U"), Some(&Type::Int));
    }

    #[test]
    fn test_occurs_check_type_var_with_itself() {
        let result = unify_types(&tvar("T"), &tvar("T"));
        assert!(result.is_ok());
        assert!(result.unwrap().types.is_empty());
    }

    #[test]
    fn test_occurs_check_row_var_with_itself() {
        let result = unify_stacks(&rvar("a"), &rvar("a"));
        assert!(result.is_ok());
        assert!(result.unwrap().rows.is_empty());
    }

    #[test]
    fn test_occurs_check_prevents_infinite_stack() {
        let result = unify_stacks(&rvar("a"), &rvar("a").push(Type::Int));
        assert!(result.is_err());
        let message = result.unwrap_err();
        for fragment in &["Occurs check failed", "infinite"] {
            assert!(message.contains(*fragment));
        }
    }

    #[test]
    fn test_occurs_check_allows_different_row_vars() {
        let result = unify_stacks(&rvar("a"), &rvar("b"));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().rows.get("a"), Some(&rvar("b")));
    }

    #[test]
    fn test_occurs_check_allows_concrete_stack() {
        let concrete = StackType::Empty.push(Type::Int).push(Type::String);
        let result = unify_stacks(&rvar("a"), &concrete);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().rows.get("a"), Some(&concrete));
    }

    #[test]
    fn test_occurs_in_type() {
        assert!(occurs_in_type("T", &tvar("T")));
        for ty in &[tvar("U"), Type::Int, Type::String, Type::Bool] {
            assert!(!occurs_in_type("T", ty));
        }
    }

    #[test]
    fn test_occurs_in_stack() {
        assert!(occurs_in_stack("a", &rvar("a")));
        assert!(occurs_in_stack("a", &rvar("a").push(Type::Int)));

        let misses = [
            rvar("b"),
            StackType::Empty,
            rvar("b").push(Type::Int),
            StackType::Empty.push(Type::Int).push(Type::String),
        ];
        for stack in &misses {
            assert!(!occurs_in_stack("a", stack));
        }
    }

    #[test]
    fn test_quotation_type_unification_stack_neutral() {
        use crate::types::Effect;
        // [ ..a -- ..a ] must not unify with [ ..b -- ..b Int ].
        let stack_neutral = Type::Quotation(Box::new(Effect::new(rvar("a"), rvar("a"))));
        let pushes_int = Type::Quotation(Box::new(Effect::new(
            rvar("b"),
            rvar("b").push(Type::Int),
        )));
        let result = unify_types(&stack_neutral, &pushes_int);
        assert!(
            result.is_err(),
            "Unifying stack-neutral with stack-pushing quotation should fail, got {:?}",
            result
        );
    }
}