use crate::lexer::Span;
use crate::semantics::error::SemanticError;
use crate::semantics::format::format_type;
use crate::semantics::types::Type;
use std::collections::HashMap;
/// Accumulates the type-variable bindings discovered while unifying types.
///
/// A default-constructed value holds no bindings.
#[derive(Debug, Clone, Default)]
pub struct Unifier {
/// Bindings recorded so far, keyed by type-variable name.
pub substitutions: HashMap<String, Type>,
}
impl Unifier {
    /// Creates a unifier with no recorded substitutions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Unifies `a` with `b`, recording whatever type-variable bindings are needed.
    ///
    /// Both operands are first resolved through the bindings recorded so far, so a variable is
    /// only ever bound while it is still unbound. This keeps the substitution map acyclic (for
    /// example `unify(a, b)` followed by `unify(b, a)` is a no-op the second time) and keeps the
    /// operands in their original order for the mismatch message.
    ///
    /// # Errors
    ///
    /// Returns a [`SemanticError`] carrying `span` when the types have incompatible shapes, or
    /// when binding a variable would make it occur inside its own definition.
    pub fn unify(&mut self, a: &Type, b: &Type, span: Span) -> Result<(), SemanticError> {
        // Replace bound variables by what they currently stand for. The resolved values are
        // owned so that `self` is free to be mutated below.
        let bound_a = self.resolve_binding(a);
        let bound_b = self.resolve_binding(b);
        let a = bound_a.as_ref().unwrap_or(a);
        let b = bound_b.as_ref().unwrap_or(b);

        match (a, b) {
            // The same unbound variable on both sides: nothing to record.
            (Type::Variable(v1), Type::Variable(v2)) if v1 == v2 => {}
            // After resolution any variable seen here is unbound, so it can be bound directly
            // once the occurs check has ruled out an infinite type.
            (Type::Variable(var), t) | (t, Type::Variable(var)) => {
                if self.occurs(var, t) {
                    return Err(SemanticError {
                        message: format!(
                            "Recursive type: {} occurs in {}",
                            var,
                            format_type(t)
                        ),
                        span,
                    });
                }
                self.substitutions.insert(var.clone(), t.clone());
            }
            (Type::Primitive(p1), Type::Primitive(p2)) if p1 == p2 => {}
            (Type::Named(n1, args1), Type::Named(n2, args2))
                if n1 == n2 && args1.len() == args2.len() =>
            {
                for (a1, a2) in args1.iter().zip(args2) {
                    self.unify(a1, a2, span)?;
                }
            }
            // Only parameters and return type take part in unification; the remaining fields
            // of `Type::Function` (such as `default_count`) are ignored.
            (
                Type::Function {
                    params: p1,
                    returns: r1,
                    ..
                },
                Type::Function {
                    params: p2,
                    returns: r2,
                    ..
                },
            ) if p1.len() == p2.len() => {
                for (a1, a2) in p1.iter().zip(p2) {
                    self.unify(a1, a2, span)?;
                }
                self.unify(r1, r2, span)?;
            }
            (Type::Reference(t1), Type::Reference(t2)) => self.unify(t1, t2, span)?,
            (Type::List(t1), Type::List(t2)) => self.unify(t1, t2, span)?,
            (Type::Map(k1, v1), Type::Map(k2, v2)) => {
                self.unify(k1, k2, span)?;
                self.unify(v1, v2, span)?;
            }
            (Type::Set(t1), Type::Set(t2)) => self.unify(t1, t2, span)?,
            (Type::Tuple(l1, r1), Type::Tuple(l2, r2)) => {
                self.unify(l1, l2, span)?;
                self.unify(r1, r2, span)?;
            }
            (Type::Optional(t1), Type::Optional(t2)) => self.unify(t1, t2, span)?,
            (Type::Void, Type::Void) => {}
            (Type::EmptyList, Type::EmptyList) => {}
            (Type::EmptyMap, Type::EmptyMap) => {}
            (Type::EmptySet, Type::EmptySet) => {}
            // An empty-collection type is accepted against any collection of the matching kind;
            // no bindings are recorded for the element types.
            (Type::List(_), Type::EmptyList) | (Type::EmptyList, Type::List(_)) => {}
            (Type::Map(_, _), Type::EmptyMap) | (Type::EmptyMap, Type::Map(_, _)) => {}
            (Type::Set(_), Type::EmptySet) | (Type::EmptySet, Type::Set(_)) => {}
            // NOTE(review): a map is also accepted against `EmptySet`, while a set is not
            // accepted against `EmptyMap`. Presumably an empty brace literal is ambiguous
            // between the two; confirm that the asymmetry is intended.
            (Type::Map(_, _), Type::EmptySet) | (Type::EmptySet, Type::Map(_, _)) => {}
            // `Never` is accepted against every type on either side.
            (Type::Never, _) => {}
            (_, Type::Never) => {}
            _ => {
                return Err(SemanticError::new(
                    format!(
                        "Type mismatch: expected {}, got {}",
                        format_type(a),
                        format_type(b)
                    ),
                    span,
                ));
            }
        }
        Ok(())
    }

    /// Follows a chain of variable bindings starting at `t`.
    ///
    /// Returns `Some(resolved)` when `t` is a bound variable, where `resolved` is the first type
    /// in the chain that is not itself a bound variable. Returns `None` when `t` is not a
    /// variable or is an unbound one, so callers can keep using `t` without cloning it.
    ///
    /// The substitution map must be acyclic, which `unify` maintains for the bindings it adds.
    fn resolve_binding(&self, t: &Type) -> Option<Type> {
        let mut current = t;
        let mut followed = false;
        while let Type::Variable(name) = current {
            match self.substitutions.get(name) {
                Some(next) => {
                    current = next;
                    followed = true;
                }
                None => break,
            }
        }
        if followed {
            Some(current.clone())
        } else {
            None
        }
    }

    /// Reports whether the variable named `var` appears anywhere inside `t`.
    ///
    /// Bound variables encountered in `t` are looked through, so an occurrence hidden behind an
    /// existing binding is found as well.
    fn occurs(&self, var: &str, t: &Type) -> bool {
        match t {
            Type::Variable(v) => {
                v == var
                    || match self.substitutions.get(v) {
                        Some(bound) => self.occurs(var, bound),
                        None => false,
                    }
            }
            Type::Named(_, args) => args.iter().any(|arg| self.occurs(var, arg)),
            Type::Function {
                params, returns, ..
            } => params.iter().any(|p| self.occurs(var, p)) || self.occurs(var, returns),
            Type::Reference(inner)
            | Type::List(inner)
            | Type::Set(inner)
            | Type::Optional(inner) => self.occurs(var, inner),
            Type::Map(k, v) => self.occurs(var, k) || self.occurs(var, v),
            Type::Tuple(l, r) => self.occurs(var, l) || self.occurs(var, r),
            _ => false,
        }
    }

    /// Returns a copy of `t` with every bound type variable replaced by its binding.
    ///
    /// Substitution is applied recursively, including inside the bindings themselves, so the
    /// result contains only variables that have no binding. Types with no nested components
    /// are returned as plain clones.
    ///
    /// The substitution map must be acyclic, which `unify` maintains for the bindings it adds.
    pub fn apply(&self, t: &Type) -> Type {
        match t {
            Type::Variable(var) => match self.substitutions.get(var) {
                // A binding may itself mention bound variables, so keep substituting.
                Some(bound) => self.apply(bound),
                None => t.clone(),
            },
            Type::Named(name, args) => {
                let applied_args: Vec<Type> = args.iter().map(|arg| self.apply(arg)).collect();
                Type::Named(name.clone(), applied_args)
            }
            Type::Function {
                params,
                returns,
                default_count,
            } => Type::Function {
                params: params.iter().map(|p| self.apply(p)).collect(),
                returns: Box::new(self.apply(returns)),
                default_count: *default_count,
            },
            Type::Reference(inner) => Type::Reference(Box::new(self.apply(inner))),
            Type::List(inner) => Type::List(Box::new(self.apply(inner))),
            Type::Map(k, v) => Type::Map(Box::new(self.apply(k)), Box::new(self.apply(v))),
            Type::Set(inner) => Type::Set(Box::new(self.apply(inner))),
            Type::Tuple(l, r) => Type::Tuple(Box::new(self.apply(l)), Box::new(self.apply(r))),
            Type::Optional(inner) => Type::Optional(Box::new(self.apply(inner))),
            _ => t.clone(),
        }
    }
}