// delo 0.1.0
//
// A statically typed, compiled-to-C programming language that can time-travel.
use std::collections::HashMap;

use crate::ast::token::Token;
use crate::error::{TypeError, UnifyError};
use crate::type_system::types::TypeExpr;

/// Solver for equations between [`TypeExpr`]s that records type-variable bindings.
///
/// Type variables are allocated by `Unifier::new_type_var` and bound as a side
/// effect of unification. `Default` produces the same empty state as
/// `Unifier::new` (no variables allocated, no bindings).
#[derive(Default)]
pub struct Unifier {
    /// Id that the next call to `new_type_var` will hand out.
    next_type_var_id: usize,
    /// Bindings from a type-variable id to the type it was unified with; the bound
    /// type may itself be another (possibly bound) type variable.
    substitutions: HashMap<usize, TypeExpr>,
}

impl Unifier {
    /// Creates a unifier with no type variables allocated and no bindings.
    pub fn new() -> Self {
        Self {
            next_type_var_id: 0,
            substitutions: HashMap::new(),
        }
    }

    /// Allocates a fresh, unbound type variable.
    ///
    /// Ids are handed out sequentially starting at 0, so every call returns a
    /// variable distinct from all previously allocated ones.
    pub fn new_type_var(&mut self) -> TypeExpr {
        let id = self.next_type_var_id;
        self.next_type_var_id += 1;
        TypeExpr::TypeVar { id }
    }

    /// Returns a copy of `type_expr` in which each type variable is replaced by a
    /// fresh one.
    ///
    /// All occurrences of the same variable map to the same fresh variable, so
    /// sharing is preserved: `fn(a) -> a` becomes `fn(b) -> b`, never `fn(b) -> c`.
    pub fn instantiate_type(&mut self, type_expr: &TypeExpr) -> TypeExpr {
        let mut mapping = HashMap::new();
        self.instantiate_type_with_mapping(type_expr, &mut mapping)
    }

    /// Returns `type_expr` with every bound type variable replaced, recursively, by
    /// the type it is bound to.
    ///
    /// Substitution reaches type arguments, struct field types, function parameter
    /// and return types, and tuple elements; unbound variables are left as they
    /// are. `identifier`, `type_parameters` and `enum_variants` are copied unchanged.
    pub fn apply_substitutions(&self, type_expr: &TypeExpr) -> TypeExpr {
        match type_expr {
            TypeExpr::Named {
                identifier,
                type_parameters,
                type_arguments,
                enum_variants,
                struct_fields,
            } => TypeExpr::Named {
                identifier: identifier.clone(),
                type_parameters: type_parameters.clone(),
                type_arguments: type_arguments
                    .iter()
                    .map(|t| self.apply_substitutions(t))
                    .collect(),
                // NOTE(review): enum variant payloads (if they hold types) are not
                // substituted here; confirm they cannot contain type variables.
                enum_variants: enum_variants.clone(),
                // `instantiate_type` uses one variable mapping for both
                // `type_arguments` and field types, so fields must be resolved too;
                // otherwise they keep pointing at variables unification has bound.
                struct_fields: struct_fields.as_ref().map(|fields| {
                    fields
                        .iter()
                        .map(|(field_name, field_type)| {
                            (field_name.clone(), self.apply_substitutions(field_type))
                        })
                        .collect()
                }),
            },
            TypeExpr::Function {
                parameter_types,
                return_type,
            } => TypeExpr::Function {
                parameter_types: parameter_types
                    .iter()
                    .map(|t| self.apply_substitutions(t))
                    .collect(),
                return_type: Box::new(self.apply_substitutions(return_type)),
            },
            TypeExpr::Tuple { element_types } => TypeExpr::Tuple {
                element_types: element_types
                    .iter()
                    .map(|t| self.apply_substitutions(t))
                    .collect(),
            },
            TypeExpr::TypeVar { id } => match self.substitutions.get(id) {
                // The bound type may itself contain bound variables; keep resolving.
                Some(bound) => self.apply_substitutions(bound),
                None => type_expr.clone(),
            },
        }
    }

    /// Unifies `type_a` with `type_b`, reporting any failure at `token`'s position.
    ///
    /// On success returns the unified type (bindings are recorded in `self`).
    ///
    /// # Errors
    ///
    /// `TypeError::MismatchedTypes` when the types cannot be made equal, and
    /// `TypeError::InfiniteType` when unification would bind a variable to a type
    /// containing itself. Both carry `token.line` and `token.column`.
    pub fn unify_at(&mut self, type_a: &TypeExpr, type_b: &TypeExpr, token: &Token) -> Result<TypeExpr, TypeError> {
        self.unify(type_a, type_b).map_err(|e| match e {
            UnifyError::MismatchedTypes { expected, found } => TypeError::MismatchedTypes {
                line: token.line,
                column: token.column,
                expected,
                found,
            },
            UnifyError::InfiniteType { type_var_id, type_expr } => TypeError::InfiniteType {
                line: token.line,
                column: token.column,
                type_var_id,
                type_expr,
            },
        })
    }

    /// Makes `type_a` and `type_b` equal by binding type variables, returning the
    /// unified type.
    ///
    /// In a mismatch error `type_a` is reported as `expected` and `type_b` as
    /// `found`; for nested failures these are the innermost pair that disagreed.
    fn unify(&mut self, type_a: &TypeExpr, type_b: &TypeExpr) -> Result<TypeExpr, UnifyError> {
        // Look through already-bound variables so we compare actual structure.
        let type_a = self.resolve_type_var_chains(type_a);
        let type_b = self.resolve_type_var_chains(type_b);

        match (&type_a, &type_b) {
            // When both sides are variables, the left-hand one is bound.
            (TypeExpr::TypeVar { id }, other) | (other, TypeExpr::TypeVar { id }) => {
                self.bind_type_var_to_type(*id, other)
            }
            (
                TypeExpr::Named {
                    identifier: identifier_a,
                    type_parameters,
                    type_arguments: arguments_a,
                    enum_variants,
                    struct_fields,
                },
                TypeExpr::Named {
                    identifier: identifier_b,
                    type_arguments: arguments_b,
                    ..
                },
            ) => {
                // Named types match by name and type arguments only; the remaining
                // metadata of the result is taken from `type_a`.
                if identifier_a.lexeme != identifier_b.lexeme || arguments_a.len() != arguments_b.len() {
                    return Err(UnifyError::MismatchedTypes { expected: type_a, found: type_b });
                }

                let type_arguments = self.unify_all(arguments_a, arguments_b)?;

                Ok(TypeExpr::Named {
                    identifier: identifier_a.clone(),
                    type_parameters: type_parameters.clone(),
                    type_arguments,
                    enum_variants: enum_variants.clone(),
                    struct_fields: struct_fields.clone(),
                })
            }
            (
                TypeExpr::Function {
                    parameter_types: parameters_a,
                    return_type: return_a,
                },
                TypeExpr::Function {
                    parameter_types: parameters_b,
                    return_type: return_b,
                },
            ) => {
                if parameters_a.len() != parameters_b.len() {
                    return Err(UnifyError::MismatchedTypes { expected: type_a, found: type_b });
                }

                // Parameters are unified before the return type.
                let parameter_types = self.unify_all(parameters_a, parameters_b)?;
                let return_type = self.unify(return_a, return_b)?;

                Ok(TypeExpr::Function {
                    parameter_types,
                    return_type: Box::new(return_type),
                })
            }
            (
                TypeExpr::Tuple { element_types: elements_a },
                TypeExpr::Tuple { element_types: elements_b },
            ) => {
                if elements_a.len() != elements_b.len() {
                    return Err(UnifyError::MismatchedTypes { expected: type_a, found: type_b });
                }

                let element_types = self.unify_all(elements_a, elements_b)?;

                Ok(TypeExpr::Tuple { element_types })
            }
            _ => Err(UnifyError::MismatchedTypes { expected: type_a, found: type_b }),
        }
    }

    /// Unifies two equally long lists of types pairwise, left to right, stopping at
    /// the first failure. Callers check the lengths beforehand.
    fn unify_all(&mut self, types_a: &[TypeExpr], types_b: &[TypeExpr]) -> Result<Vec<TypeExpr>, UnifyError> {
        types_a
            .iter()
            .zip(types_b)
            .map(|(a, b)| self.unify(a, b))
            .collect()
    }

    /// Records the binding `type_var_id := type_expr` and returns the bound type.
    ///
    /// `unify` only calls this with a variable it has already resolved, i.e. one
    /// that has no binding yet.
    ///
    /// # Errors
    ///
    /// `UnifyError::InfiniteType` if the variable occurs inside `type_expr`.
    fn bind_type_var_to_type(&mut self, type_var_id: usize, type_expr: &TypeExpr) -> Result<TypeExpr, UnifyError> {
        let resolved_type = self.resolve_type_var_chains(type_expr);

        // `a ~ a` is trivially satisfied; recording it would create a self-loop.
        if matches!(resolved_type, TypeExpr::TypeVar { id } if id == type_var_id) {
            return Ok(resolved_type);
        }

        // Reject `a ~ T` when `a` appears inside `T`: the binding would describe an
        // infinite type and make `apply_substitutions` recurse forever.
        if self.type_var_occurs_in_type(type_var_id, &resolved_type) {
            return Err(UnifyError::InfiniteType { type_var_id, type_expr: resolved_type });
        }

        self.substitutions.insert(type_var_id, resolved_type.clone());
        Ok(resolved_type)
    }

    /// Occurs check: does variable `type_var_id` appear anywhere in `type_expr`
    /// (after following existing bindings)?
    ///
    /// Struct field types are searched as well, because `apply_substitutions`
    /// descends into them and a cycle through a field would never terminate.
    fn type_var_occurs_in_type(&self, type_var_id: usize, type_expr: &TypeExpr) -> bool {
        match &self.resolve_type_var_chains(type_expr) {
            TypeExpr::Named {
                type_arguments,
                struct_fields,
                ..
            } => {
                type_arguments
                    .iter()
                    .any(|t| self.type_var_occurs_in_type(type_var_id, t))
                    || struct_fields
                        .as_ref()
                        .map(|fields| {
                            fields
                                .iter()
                                .any(|(_, field_type)| self.type_var_occurs_in_type(type_var_id, field_type))
                        })
                        .unwrap_or(false)
            }
            TypeExpr::Function {
                parameter_types,
                return_type,
            } => {
                parameter_types
                    .iter()
                    .any(|t| self.type_var_occurs_in_type(type_var_id, t))
                    || self.type_var_occurs_in_type(type_var_id, return_type)
            }
            TypeExpr::Tuple { element_types } => element_types
                .iter()
                .any(|t| self.type_var_occurs_in_type(type_var_id, t)),
            TypeExpr::TypeVar { id } => *id == type_var_id,
        }
    }

    /// Follows a chain of variable-to-variable bindings at the top level only.
    ///
    /// Returns a clone of the first type in the chain that is either not a type
    /// variable or an unbound one; nested types are not resolved.
    fn resolve_type_var_chains(&self, type_expr: &TypeExpr) -> TypeExpr {
        let mut current = type_expr;
        while let TypeExpr::TypeVar { id } = current {
            match self.substitutions.get(id) {
                Some(bound) => current = bound,
                None => break,
            }
        }
        current.clone()
    }

    /// Recursive worker for `instantiate_type`.
    ///
    /// `mapping` maps original variable ids to the fresh variables chosen for
    /// them, so repeated occurrences share one fresh variable. Type arguments are
    /// visited before struct field types, which fixes the order in which fresh ids
    /// are allocated.
    fn instantiate_type_with_mapping(&mut self, type_expr: &TypeExpr, mapping: &mut HashMap<usize, TypeExpr>) -> TypeExpr {
        match type_expr {
            TypeExpr::Named {
                identifier,
                type_parameters,
                type_arguments,
                enum_variants,
                struct_fields,
            } => TypeExpr::Named {
                identifier: identifier.clone(),
                type_parameters: type_parameters.clone(),
                type_arguments: type_arguments
                    .iter()
                    .map(|t| self.instantiate_type_with_mapping(t, mapping))
                    .collect(),
                enum_variants: enum_variants.clone(),
                struct_fields: struct_fields.as_ref().map(|fields| {
                    fields
                        .iter()
                        .map(|(field_name, field_type)| {
                            (field_name.clone(), self.instantiate_type_with_mapping(field_type, mapping))
                        })
                        .collect()
                }),
            },
            TypeExpr::Function {
                parameter_types,
                return_type,
            } => TypeExpr::Function {
                parameter_types: parameter_types
                    .iter()
                    .map(|t| self.instantiate_type_with_mapping(t, mapping))
                    .collect(),
                return_type: Box::new(self.instantiate_type_with_mapping(return_type, mapping)),
            },
            TypeExpr::Tuple { element_types } => TypeExpr::Tuple {
                element_types: element_types
                    .iter()
                    .map(|t| self.instantiate_type_with_mapping(t, mapping))
                    .collect(),
            },
            // NOTE(review): this never consults `self.substitutions`, so a variable
            // that is already bound is still replaced by a fresh one. Confirm callers
            // only instantiate types whose variables are generic parameters.
            TypeExpr::TypeVar { id } => mapping
                .entry(*id)
                .or_insert_with(|| self.new_type_var())
                .clone(),
        }
    }
}