use crate::logic::typing::Context;
use crate::logic::typing::Type;
use std::collections::HashMap;
/// Returns `true` for the type forms `Type::Path`, `Type::PathOf` and
/// `Type::ContextCall`, which `Unifier::unify` treats as not yet decidable
/// (it answers `UnifyResult::Indeterminate` for them).
pub fn is_unresolved(ty: &Type) -> bool {
    matches!(
        *ty,
        Type::Path(..) | Type::PathOf(..) | Type::ContextCall(..)
    )
}
/// Outcome of a unification attempt.
#[derive(Debug, Clone, PartialEq)]
pub enum UnifyResult {
    /// The two types were unified.
    Ok,
    /// No failure was found, but the outcome could not be decided yet.
    Indeterminate,
    /// Unification failed; the payload is a human-readable reason.
    Fail(String),
}

impl UnifyResult {
    /// `true` only for [`UnifyResult::Ok`].
    pub fn is_ok(&self) -> bool {
        *self == Self::Ok
    }

    /// `true` only for [`UnifyResult::Fail`], whatever its message.
    pub fn is_fail(&self) -> bool {
        matches!(self, Self::Fail(..))
    }

    /// `true` only for [`UnifyResult::Indeterminate`].
    pub fn is_indeterminate(&self) -> bool {
        *self == Self::Indeterminate
    }
}
/// Substitution-based unifier over [`Type`] terms.
///
/// Meta variables (`Type::Meta`) are recorded in `substitution`. A binding may
/// itself mention other metas; such chains are followed on demand (`walk`,
/// `apply`). Context calls (`Type::ContextCall`) are resolved through
/// `binding_values` and, when set, `context` before types are compared.
#[derive(Debug, Clone, Default)]
pub struct Unifier {
    /// Meta-variable name (without the leading `?`) → bound type.
    pub substitution: HashMap<String, Type>,
    /// Context consulted when resolving `Type::ContextCall`; `None` disables lookups.
    pub context: Option<Context>,
    /// Maps a context call's variable to the name that is actually looked up.
    pub binding_values: HashMap<String, String>,
}

impl Unifier {
    /// Creates an empty unifier: no bindings, no context, no renamings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a unifier whose substitution starts as `map`.
    ///
    /// `map` is taken as-is. It should be acyclic, because `walk`, `apply`
    /// and the occurs check all follow bindings without cycle detection.
    pub fn from_map(map: HashMap<String, Type>) -> Self {
        Self {
            substitution: map,
            ..Self::default()
        }
    }

    /// Stores a copy of `ctx` for resolving context calls.
    pub fn set_context(&mut self, ctx: &Context) {
        self.context = Some(ctx.clone());
    }

    /// Forgets the stored context; context calls are then only renamed, never looked up.
    pub fn clear_context(&mut self) {
        self.context = None;
    }

    /// Replaces the variable-renaming table used when resolving context calls.
    pub fn set_binding_values(&mut self, values: HashMap<String, String>) {
        self.binding_values = values;
    }

    /// Read-only view of the current substitution.
    pub fn as_map(&self) -> &HashMap<String, Type> {
        &self.substitution
    }

    /// Mutable view of the current substitution.
    ///
    /// Callers must not introduce cyclic bindings through this handle (see `from_map`).
    pub fn as_map_mut(&mut self) -> &mut HashMap<String, Type> {
        &mut self.substitution
    }

    /// Returns the direct binding of meta `name`, without following chains.
    pub fn resolve_meta(&self, name: &str) -> Option<&Type> {
        self.substitution.get(name)
    }

    /// Binds each name in `names` that is still unbound to whatever `resolve` returns.
    ///
    /// Already-bound names are skipped without calling `resolve`. That includes
    /// names bound earlier in this same call. A `None` from `resolve` leaves the
    /// name unbound.
    pub fn seed<I, F>(&mut self, names: I, mut resolve: F)
    where
        I: IntoIterator<Item = String>,
        F: FnMut(&str) -> Option<Type>,
    {
        for name in names {
            if self.resolve_meta(&name).is_none()
                && let Some(resolved) = resolve(&name)
            {
                // Best-effort seeding: a rejected binding (occurs check) is ignored.
                let _ = self.bind(&name, &resolved);
            }
        }
    }

    /// Binds meta `name` to `ty`.
    ///
    /// If `name` is already bound, the existing binding is unified with `ty`
    /// instead of being overwritten. Binding a meta to itself, directly or
    /// through a chain of bindings, is a no-op. Otherwise the binding fails if
    /// `name` occurs inside `ty`, which would make the substitution cyclic.
    pub fn bind(&mut self, name: &str, ty: &Type) -> UnifyResult {
        if let Some(existing) = self.substitution.get(name).cloned() {
            return self.unify(&existing, ty);
        }
        if matches!(self.walk(ty), Type::Meta(ref other) if other == name) {
            return UnifyResult::Ok;
        }
        if self.occurs(name, ty) {
            return UnifyResult::Fail(format!("Occurs check failed: ?{} occurs in {}", name, ty));
        }
        self.substitution.insert(name.to_string(), ty.clone());
        UnifyResult::Ok
    }

    /// Returns `true` if meta `name` occurs anywhere inside `ty`.
    ///
    /// This follows the bindings of any other metas it meets. It descends into
    /// every constructor that carries sub-types, since a cycle through any of
    /// them would make `apply`/`unify` recurse forever.
    fn occurs(&self, name: &str, ty: &Type) -> bool {
        match ty {
            Type::Meta(other) if other == name => true,
            Type::Meta(other) => self
                .substitution
                .get(other)
                .is_some_and(|bound| self.occurs(name, bound)),
            Type::Arrow(l, r) => self.occurs(name, l) || self.occurs(name, r),
            Type::Array(inner) => self.occurs(name, inner),
            Type::Not(inner) => self.occurs(name, inner),
            Type::Union(parts) => parts.iter().any(|p| self.occurs(name, p)),
            Type::Partial(inner, _) => self.occurs(name, inner),
            Type::PathOf(inner, _) => self.occurs(name, inner),
            _ => false,
        }
    }

    /// Fully substitutes `ty`: resolves context calls (consulting the context)
    /// and replaces every bound meta inside arrows, arrays, unions and negations.
    ///
    /// # Errors
    /// Returns `Err` naming the first meta found that has no binding.
    pub fn apply(&self, ty: &Type) -> Result<Type, String> {
        match self.resolve_ctx_call(ty, true) {
            Type::Meta(name) => match self.substitution.get(&name) {
                Some(bound) => self.apply(bound),
                None => Err(format!("Unbound meta variable: ?{}", name)),
            },
            Type::Arrow(a, b) => Ok(Type::Arrow(
                Box::new(self.apply(&a)?),
                Box::new(self.apply(&b)?),
            )),
            Type::Array(inner) => Ok(Type::Array(Box::new(self.apply(&inner)?))),
            Type::Union(parts) => parts
                .iter()
                .map(|p| self.apply(p))
                .collect::<Result<Vec<_>, _>>()
                .map(Type::Union),
            Type::Not(a) => Ok(Type::Not(Box::new(self.apply(&a)?))),
            // `resolve_ctx_call` has already resolved `Partial`/`PathOf` payloads
            // (without context). Resolving them again would apply `binding_values`
            // a second time.
            other => Ok(other),
        }
    }

    /// Returns `true` if `apply` would hit an unbound meta in `ty`.
    ///
    /// A bound meta counts as unresolved when its binding still contains an
    /// unbound meta. Like `apply`, it looks inside arrows, arrays, unions and
    /// negations only.
    pub fn has_unresolved_meta(&self, ty: &Type) -> bool {
        match ty {
            Type::Meta(name) => match self.substitution.get(name) {
                Some(bound) => self.has_unresolved_meta(bound),
                None => true,
            },
            Type::Arrow(a, b) => self.has_unresolved_meta(a) || self.has_unresolved_meta(b),
            Type::Array(inner) => self.has_unresolved_meta(inner),
            Type::Union(parts) => parts.iter().any(|p| self.has_unresolved_meta(p)),
            Type::Not(a) => self.has_unresolved_meta(a),
            _ => false,
        }
    }

    /// Resolves context calls in `ty` (consulting the context) without touching meta bindings.
    pub fn resolve_for_subtyping(&self, ty: &Type) -> Type {
        self.resolve_ctx_call(ty, true)
    }

    /// Unifies `t1` with `t2`, extending the substitution as needed.
    ///
    /// Both sides are first resolved through context calls, then walked
    /// through existing meta bindings. `Any` unifies with everything. Unresolved
    /// `Path`, `PathOf` and `ContextCall` terms give
    /// [`UnifyResult::Indeterminate`]. Bindings made before a failure are not
    /// rolled back.
    pub fn unify(&mut self, t1: &Type, t2: &Type) -> UnifyResult {
        let t1 = self.walk(&self.resolve_ctx_call(t1, true));
        let t2 = self.walk(&self.resolve_ctx_call(t2, true));
        // `t1`/`t2` are owned locals, so borrowing into them does not conflict
        // with the `&mut self` recursive calls below.
        match (&t1, &t2) {
            (Type::Raw(a), Type::Raw(b)) => {
                if a == b {
                    UnifyResult::Ok
                } else {
                    UnifyResult::Fail(format!("{} ≠ {}", a, b))
                }
            }
            // The same unbound meta on both sides is already unified; binding it
            // to itself would otherwise trip the occurs check.
            (Type::Meta(a), Type::Meta(b)) if a == b => UnifyResult::Ok,
            (Type::Meta(name), _) => self.bind(name, &t2),
            (_, Type::Meta(name)) => self.bind(name, &t1),
            (Type::Arrow(l1, r1), Type::Arrow(l2, r2)) => match self.unify(l1, l2) {
                UnifyResult::Ok => self.unify(r1, r2),
                // Domain undecided: a codomain failure still wins, otherwise stay undecided.
                UnifyResult::Indeterminate => match self.unify(r1, r2) {
                    fail @ UnifyResult::Fail(_) => fail,
                    _ => UnifyResult::Indeterminate,
                },
                fail => fail,
            },
            (Type::Array(a), Type::Array(b)) | (Type::Not(a), Type::Not(b)) => self.unify(a, b),
            (Type::Any, _) | (_, Type::Any) => UnifyResult::Ok,
            (Type::Union(a), Type::Union(b)) => {
                if a.len() != b.len() {
                    return UnifyResult::Fail(format!(
                        "Union arity mismatch: {} vs {}",
                        a.len(),
                        b.len()
                    ));
                }
                // Members are unified positionally; the first failure aborts.
                let mut saw_indeterminate = false;
                for (l, r) in a.iter().zip(b.iter()) {
                    match self.unify(l, r) {
                        UnifyResult::Ok => {}
                        UnifyResult::Indeterminate => saw_indeterminate = true,
                        fail => return fail,
                    }
                }
                if saw_indeterminate {
                    UnifyResult::Indeterminate
                } else {
                    UnifyResult::Ok
                }
            }
            (Type::Union(_), _) | (_, Type::Union(_)) => {
                UnifyResult::Fail(format!("Cannot unify {} with {}", t1, t2))
            }
            (Type::None, Type::None) => UnifyResult::Ok,
            (Type::None, _) | (_, Type::None) => {
                UnifyResult::Fail("None is not unifiable with non-None".to_string())
            }
            (Type::Path(_), _) | (_, Type::Path(_)) => UnifyResult::Indeterminate,
            (Type::PathOf(_, _), _) | (_, Type::PathOf(_, _)) => UnifyResult::Indeterminate,
            (Type::ContextCall(_, _), _) | (_, Type::ContextCall(_, _)) => {
                UnifyResult::Indeterminate
            }
            // A partial type unifies through its underlying type.
            (Type::Partial(t, _), other) | (other, Type::Partial(t, _)) => self.unify(t, other),
            _ => UnifyResult::Fail(format!("Cannot unify {} with {}", t1, t2)),
        }
    }

    /// Follows the chain of bindings while `ty` is a bound meta and returns the
    /// first non-meta type or unbound meta reached. Does not look inside
    /// compound types.
    fn walk(&self, ty: &Type) -> Type {
        let mut current = ty;
        while let Type::Meta(name) = current {
            match self.substitution.get(name) {
                Some(bound) => current = bound,
                None => break,
            }
        }
        current.clone()
    }

    /// Resolves every context call in `ty`.
    ///
    /// A call's variable is first mapped through `binding_values`. When
    /// `allow_context` is set and a context is stored, it is then looked up
    /// there. Payloads of `Partial`/`PathOf` are always resolved without the context.
    fn resolve_ctx_call(&self, ty: &Type, allow_context: bool) -> Type {
        match ty {
            Type::ContextCall(ctx_name, var) => {
                let resolved_var = self
                    .binding_values
                    .get(var)
                    .map(|v| v.as_str())
                    .unwrap_or(var.as_str());
                if allow_context && let Some(ctx) = self.context.as_ref() {
                    if let Some(found) = ctx.lookup(resolved_var) {
                        return found.clone();
                    }
                    // NOTE(review): on a prefix-only match the call is returned with
                    // its original, un-renamed variable — confirm this is intended.
                    if ctx.lookup_starts_with(resolved_var).is_some() {
                        return ty.clone();
                    }
                }
                if resolved_var != var.as_str() {
                    return Type::ContextCall(ctx_name.clone(), resolved_var.to_string());
                }
                ty.clone()
            }
            Type::Arrow(a, b) => Type::Arrow(
                Box::new(self.resolve_ctx_call(a, allow_context)),
                Box::new(self.resolve_ctx_call(b, allow_context)),
            ),
            Type::Array(inner) => {
                Type::Array(Box::new(self.resolve_ctx_call(inner, allow_context)))
            }
            Type::Union(parts) => Type::Union(
                parts
                    .iter()
                    .map(|p| self.resolve_ctx_call(p, allow_context))
                    .collect(),
            ),
            Type::Not(a) => Type::Not(Box::new(self.resolve_ctx_call(a, allow_context))),
            Type::Partial(t, s) => {
                Type::Partial(Box::new(self.resolve_ctx_call(t, false)), s.clone())
            }
            Type::PathOf(t, p) => {
                Type::PathOf(Box::new(self.resolve_ctx_call(t, false)), p.clone())
            }
            _ => ty.clone(),
        }
    }
}
/// Structural equality of two types, as a three-valued answer.
///
/// Returns `Some(true)`/`Some(false)` when the terms alone decide equality,
/// and `None` when the answer depends on something unresolved: paths, context
/// calls, or `Any` compared with a non-`Any` type. Identical metas are equal.
///
/// For arrows and unions (compared positionally), one component that is
/// definitely different makes the result `Some(false)`, whatever the order of
/// components. Otherwise any undecided component makes it `None`.
pub fn equal(t1: &Type, t2: &Type) -> Option<bool> {
    match (t1, t2) {
        (Type::Raw(a), Type::Raw(b)) => Some(a == b),
        (Type::Meta(a), Type::Meta(b)) if a == b => Some(true),
        (Type::Arrow(l1, r1), Type::Arrow(l2, r2)) => {
            equal_pairwise([(&**l1, &**l2), (&**r1, &**r2)])
        }
        (Type::Array(a), Type::Array(b)) => equal(a, b),
        (Type::Union(a), Type::Union(b)) => {
            if a.len() != b.len() {
                Some(false)
            } else {
                equal_pairwise(a.iter().zip(b.iter()))
            }
        }
        (Type::Not(a), Type::Not(b)) => equal(a, b),
        (Type::ContextCall(_, _), _) | (_, Type::ContextCall(_, _)) => None,
        (Type::Path(_), _) | (_, Type::Path(_)) => None,
        (Type::PathOf(_, _), _) | (_, Type::PathOf(_, _)) => None,
        (Type::Any, Type::Any) => Some(true),
        (Type::Any, _) | (_, Type::Any) => None,
        (Type::None, Type::None) => Some(true),
        (Type::None, _) | (_, Type::None) => Some(false),
        _ => Some(false),
    }
}

/// Combines component-wise [`equal`] results in an order-independent way.
/// Any `Some(false)` wins; otherwise any `None` gives `None`; otherwise the
/// result is `Some(true)`.
fn equal_pairwise<'a>(pairs: impl IntoIterator<Item = (&'a Type, &'a Type)>) -> Option<bool> {
    let mut undecided = false;
    for (x, y) in pairs {
        match equal(x, y) {
            Some(false) => return Some(false),
            Some(true) => {}
            None => undecided = true,
        }
    }
    if undecided { None } else { Some(true) }
}
/// Decides whether `t1` is a subtype of `t2`.
///
/// `None` is a subtype of everything and everything is a subtype of `Any`.
/// Types that [`equal`] proves identical are subtypes of each other. Beyond that:
/// - arrows are contravariant in the domain and covariant in the codomain;
/// - arrays are covariant;
/// - a union on the left must fit entirely into the right-hand side;
/// - a non-union on the left must fit into at least one member of a right-hand union.
///
/// Anything else is not a subtype.
pub fn subtype(t1: &Type, t2: &Type) -> bool {
    let trivially_holds =
        matches!(t1, Type::None) || matches!(t2, Type::Any) || equal(t1, t2) == Some(true);
    if trivially_holds {
        return true;
    }
    match (t1, t2) {
        (Type::Arrow(dom_sub, cod_sub), Type::Arrow(dom_sup, cod_sup)) => {
            subtype(dom_sup, dom_sub) && subtype(cod_sub, cod_sup)
        }
        (Type::Array(elem_sub), Type::Array(elem_sup)) => subtype(elem_sub, elem_sup),
        (Type::Union(members), sup) => members.iter().all(|m| subtype(m, sup)),
        (sub, Type::Union(members)) => members.iter().any(|m| subtype(sub, m)),
        _ => false,
    }
}
/// Returns `true` if meta `name` appears anywhere inside `ty`.
///
/// Purely structural: metas inside `ty` are not looked up in any substitution.
/// It descends into every constructor that carries sub-types (arrows, arrays,
/// negations, unions, partials, path-ofs). A binding `?name := …?name…`
/// through any of them is cyclic.
// `allow(dead_code)`: keeps this helper lint-clean even when nothing in the module calls it.
#[allow(dead_code)]
fn occurs_meta(name: &str, ty: &Type) -> bool {
    match ty {
        Type::Meta(n) => n == name,
        Type::Arrow(l, r) => occurs_meta(name, l) || occurs_meta(name, r),
        Type::Array(inner) => occurs_meta(name, inner),
        Type::Not(t) => occurs_meta(name, t),
        Type::Union(parts) => parts.iter().any(|p| occurs_meta(name, p)),
        Type::Partial(t, _) => occurs_meta(name, t),
        Type::PathOf(t, _) => occurs_meta(name, t),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logic::typing::Type;
    use proptest::prelude::*;

    /// Parses a type expression, panicking on malformed test input.
    fn parse(src: &str) -> Type {
        Type::parse(src).expect("type should parse")
    }

    #[test]
    fn subtype_member_into_union() {
        assert!(subtype(&parse("'Int'"), &parse("'Int' | 'Bool'")));
    }

    #[test]
    fn subtype_union_not_into_single_member() {
        assert!(!subtype(&parse("'Int' | 'Bool'"), &parse("'Int'")));
    }

    #[test]
    fn unify_union_with_meta_member() {
        let mut u = Unifier::new();
        let outcome = u.unify(&parse("?A | 'Bool'"), &parse("'Int' | 'Bool'"));
        assert!(outcome.is_ok());
        let bound = u.resolve_meta("A");
        assert!(matches!(bound, Some(Type::Raw(name)) if name == "Int"));
    }

    #[test]
    fn seed_preserves_existing_binding() {
        let mut u = Unifier::new();
        assert!(u.bind("A", &parse("'Int' ")).is_ok());
        u.seed(["A".to_string()], |_| Some(parse("'Bool'")));
        assert_eq!(u.resolve_meta("A"), Some(&parse("'Int'")));
    }

    #[test]
    fn seed_only_binds_requested_names() {
        let mut u = Unifier::new();
        u.seed(["A".to_string()], |requested| {
            let src = if requested == "A" { "'Int'" } else { "'Bool'" };
            Some(parse(src))
        });
        assert_eq!(u.resolve_meta("A"), Some(&parse("'Int'")));
        assert_eq!(u.resolve_meta("B"), None);
    }

    proptest! {
        #[test]
        fn prop_seed_binds_each_name_at_most_once(name in "[A-Z]{1,3}") {
            let mut u = Unifier::new();
            let mut resolver_calls = 0usize;
            u.seed(vec![name.clone(); 3], |_| {
                resolver_calls += 1;
                Some(parse("'Int'"))
            });
            prop_assert_eq!(resolver_calls, 1);
            prop_assert_eq!(u.resolve_meta(&name), Some(&parse("'Int'")));
        }
    }
}