kodept-inference 0.2.3

Simple compiler with dependent types support in mind
Documentation
use std::fmt::{Display, Formatter};

use itertools::Itertools;
use thiserror::Error;

use MonomorphicType::*;

use crate::algorithm_u::AlgorithmUError::{InfiniteType, UnificationFail};
use crate::r#type::{MonomorphicType, TVar};
use crate::substitution::Substitutions;
use crate::traits::{FreeTypeVars, Substitutable};

/// Error raised when two sequences of types cannot be unified because one of
/// them runs out of elements before the other.
///
/// Field `0` is the left-hand sequence and field `1` the right-hand sequence,
/// in the order they are printed by the `Display` implementation.
#[derive(Debug, Error)]
pub struct UnificationMismatch(pub Vec<MonomorphicType>, pub Vec<MonomorphicType>);

/// Errors that unifying two [`MonomorphicType`]s can produce.
#[derive(Debug, Error)]
pub enum AlgorithmUError {
    /// The two types are not equal, neither is a type variable, and they are
    /// not a pair of constructors that can be unified component-wise.
    #[error("Cannot unify types: {0} with {1}")]
    UnificationFail(MonomorphicType, MonomorphicType),
    /// The variable occurs free in the type it would have to be bound to, so
    /// the binding would be self-referential.
    #[error("Cannot construct an infinite type: {0} ~ {1}")]
    InfiniteType(TVar, MonomorphicType),
    /// Two sequences of types with different lengths were unified; the message
    /// is forwarded from the wrapped [`UnificationMismatch`].
    #[error(transparent)]
    UnificationMismatch(#[from] UnificationMismatch),
}

/// Field-less marker type that only serves as a namespace for the unification
/// routines; all of them are associated functions that take no `self`.
struct AlgorithmU;

impl AlgorithmU {
    fn occurs_check(var: &TVar, with: impl FreeTypeVars) -> bool {
        with.free_types().contains(var)
    }

    fn unify_vec(
        vec1: &[MonomorphicType],
        vec2: &[MonomorphicType],
    ) -> Result<Substitutions, AlgorithmUError> {
        match (vec1, vec2) {
            ([], []) => Ok(Substitutions::empty()),
            ([t1, ts1 @ ..], [t2, ts2 @ ..]) => {
                let s1 = t1.unify(t2)?;
                let s2 = Self::unify_vec(&ts1.substitute(&s1), &ts2.substitute(&s1))?;
                Ok(s1 + s2)
            }
            (t1, t2) => Err(UnificationMismatch(t1.to_vec(), t2.to_vec()).into()),
        }
    }

    fn bind(var: &TVar, ty: &MonomorphicType) -> Result<Substitutions, AlgorithmUError> {
        match ty {
            Var(v) if var == v => Ok(Substitutions::empty()),
            _ if Self::occurs_check(var, ty) => Err(InfiniteType(*var, ty.clone())),
            _ => Ok(Substitutions::single(*var, ty.clone())),
        }
    }

    fn apply(
        lhs: &MonomorphicType,
        rhs: &MonomorphicType,
    ) -> Result<Substitutions, AlgorithmUError> {
        match (lhs, rhs) {
            (a, b) if a == b => Ok(Substitutions::empty()),
            (Var(var), b) => Self::bind(var, b),
            (a, Var(var)) => Self::bind(var, a),
            (Fn(i1, o1), Fn(i2, o2)) => Self::unify_vec(
                &[i1.as_ref().clone(), o1.as_ref().clone()],
                &[i2.as_ref().clone(), o2.as_ref().clone()],
            ),
            (Tuple(t1), Tuple(t2)) => Self::unify_vec(&t1.0, &t2.0),
            (Pointer(t1), Pointer(t2)) => t1.unify(t2),
            _ => Err(UnificationFail(lhs.clone(), rhs.clone())),
        }
    }
}

impl MonomorphicType {
    /// Attempts to unify `self` with `other`, returning the substitution that
    /// was found.
    ///
    /// This is the public entry point; it delegates to `AlgorithmU::apply`.
    ///
    /// # Errors
    ///
    /// Returns an [`AlgorithmUError`] when the two types cannot be unified:
    /// `UnificationFail` for incompatible types, `InfiniteType` when a
    /// variable would have to be bound to a type that contains it, and
    /// `UnificationMismatch` when sequences of component types differ in
    /// length.
    pub fn unify(&self, other: &MonomorphicType) -> Result<Substitutions, AlgorithmUError> {
        AlgorithmU::apply(self, other)
    }
}

impl Display for UnificationMismatch {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Cannot unify types: [{}] with [{}]; different structure",
            self.0.iter().join(", "),
            self.1.iter().join(", ")
        )
    }
}

// Unit tests for `MonomorphicType::unify`.
// The tests call `unwrap`/`unwrap_err` directly, hence the clippy allowance below.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;

    use nonempty_collections::nev;

    use crate::algorithm_u::AlgorithmUError;
    use crate::r#type::{fun, fun1, MonomorphicType, PrimitiveType, Tuple, TVar, var};
    use crate::r#type::MonomorphicType::Constant;
    use crate::substitution::Substitutions;
    use crate::traits::Substitutable;

    // Two equal constants unify without producing any binding.
    #[test]
    fn test_tautology_example_on_constants() {
        let a = Constant("A".to_string());
        let b = Constant("A".to_string());

        let s = a.unify(&b).unwrap();
        assert_eq!(s.into_inner(), HashMap::new());
    }

    // Constants with different names are rejected with `UnificationFail`.
    #[test]
    fn test_different_constants_should_not_unify() {
        let a = Constant("A".to_string());
        let b = Constant("B".to_string());

        let e = a.unify(&b).unwrap_err();
        assert!(matches!(e, AlgorithmUError::UnificationFail(..)))
    }

    // A variable unified with itself produces no binding.
    #[test]
    fn test_tautology_example_on_vars() {
        let a = var(0);
        let b = var(0);

        let s = a.unify(&b).unwrap();
        assert_eq!(s.into_inner(), HashMap::new());
    }

    // A variable and a constant unify in either argument order, and both
    // orders bind the variable to the constant.
    #[test]
    fn test_variables_should_be_always_unified() {
        let a = TVar(1);
        let b = Constant("A".to_string());

        let s1 = MonomorphicType::Var(a).unify(&b).unwrap();
        let s2 = b.unify(&MonomorphicType::Var(a)).unwrap();

        assert_eq!(s1, s2);
        assert_eq!(s1, Substitutions::single(a, b))
    }

    // Unifying two distinct variables binds the left-hand variable to the
    // right-hand one, so swapping the arguments swaps the binding.
    #[test]
    fn test_aliasing() {
        let a = TVar(1);
        let b = TVar(2);

        let a_ = MonomorphicType::Var(a);
        let b_ = MonomorphicType::Var(b);

        let s1 = a_.unify(&b_).unwrap();
        let s2 = b_.unify(&a_).unwrap();

        assert_eq!(s1, Substitutions::single(a, b_.clone()));
        assert_eq!(s2, Substitutions::single(b, a_))
    }

    // Functions that differ only in one argument (constant vs. variable)
    // unify by binding that variable; the shared variable adds no binding.
    #[test]
    fn test_simple_function_unifying() {
        let a = fun(nev![var(1), Constant("A".to_string())], Tuple::unit());
        let b = fun(nev![var(1), var(2)], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(s, Substitutions::single(TVar(2), Constant("A".to_string())));
    }

    // Variable-to-variable binding also works in function argument position.
    #[test]
    fn test_aliasing_in_functions() {
        let a = fun(nev![var(1)], Tuple::unit());
        let b = fun(nev![var(2)], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(s, Substitutions::single(TVar(1), var(2)));
    }

    // Functions built from one and from two arguments do not unify; the
    // failure is reported as `UnificationFail`.
    #[test]
    fn test_functions_with_different_arity_should_not_unify() {
        let a = fun(nev![Constant("A".to_string())], Tuple::unit());
        let b = fun(
            nev![Constant("A".to_string()), Constant("B".to_string())],
            Tuple::unit(),
        );

        let s = a.unify(&b).unwrap_err();
        assert!(matches!(s, AlgorithmUError::UnificationFail(..)))
    }

    // One unification may bind several variables; the expected binding for
    // variable 2 already has variable 1 replaced by the constant `A`.
    #[test]
    fn test_multiple_substitutions() {
        let a = fun(
            nev![fun1(var(1), PrimitiveType::Integral), var(1)],
            Tuple::unit(),
        );
        let b = fun(nev![var(2), Constant("A".to_string())], Tuple::unit());

        let s = a.unify(&b).unwrap();
        assert_eq!(
            s.into_inner(),
            HashMap::from([
                (TVar(1), Constant("A".to_string())),
                (
                    TVar(2),
                    fun1(Constant("A".to_string()), PrimitiveType::Integral)
                )
            ])
        )
    }

    // A variable cannot be unified with a function type that mentions the
    // same variable: the occurs check reports `InfiniteType`.
    #[test]
    fn test_infinite_substitution() {
        let a = var(1);
        let b = fun1(var(1), Tuple::unit());

        let e = a.unify(&b).unwrap_err();
        assert!(matches!(e, AlgorithmUError::InfiniteType { .. }))
    }

    // After aliasing two variables, applying that substitution and unifying
    // with a constant binds the variable the alias points to.
    #[test]
    fn test_transitive_substitutions() {
        let a = var(1);
        let b = var(2);
        let c = Constant("A".to_string());

        let s1 = a.unify(&b).unwrap();
        let s2 = b.unify(&a).unwrap();
        let s3 = c.unify(&b.substitute(&s2)).unwrap();
        let s4 = a.substitute(&s1).unify(&c).unwrap();

        assert_eq!(s1, Substitutions::single(TVar(1), b.clone()));
        assert_eq!(s2, Substitutions::single(TVar(2), a.clone()));
        assert_eq!(s3, Substitutions::single(TVar(1), c.clone()));
        assert_eq!(s4, Substitutions::single(TVar(2), c));
    }

    // Once a variable has been replaced by constant `A`, unifying the result
    // with a different constant fails.
    #[test]
    fn test_different_substitutions_of_same_variable() {
        let a = var(1);
        let b = Constant("A".to_string());
        let c = Constant("B".to_string());

        let s = a.unify(&b).unwrap();
        let e = a.substitute(&s).unify(&c).unwrap_err();

        assert_eq!(s, Substitutions::single(TVar(1), b));
        assert!(matches!(e, AlgorithmUError::UnificationFail(..)))
    }

    // Nested function types: unification gives the same substitution in both
    // argument orders, and applying it makes both sides equal to the expected
    // fully resolved type.
    #[test]
    fn test_complex_unification() {
        let a = fun1(
            fun1(fun1(Constant("A".to_string()), var(1)), var(2)),
            var(3),
        );
        let b = fun(nev![var(3), var(2), var(1)], Constant("A".to_string()));

        let s1 = a.unify(&b).unwrap();
        let s2 = b.unify(&a).unwrap();

        assert_eq!(s1, s2);
        assert_eq!(a.substitute(&s1), b.substitute(&s1));
        let h = fun1(Constant("A".to_string()), Constant("A".to_string()));
        assert_eq!(
            a.substitute(&s1),
            fun1(fun1(h.clone(), h.clone()), fun1(h.clone(), h))
        )
    }
}