t-ree 0.1.0

AST definitions for the T programming language
Documentation
use std::collections::{HashMap, HashSet};

use crate::declaration::{Declaration, Module};
use crate::expression::{Block, Expression, ExpressionKind, Statement};
use crate::operator::BinaryOperator;
use crate::types::Type;

/// Runs the type-validation pass over every declaration in `module`.
///
/// The pass has two phases. First, the name and inner type of every
/// `Declaration::Type` are gathered into a lookup table (a later declaration
/// with the same name replaces an earlier one). Second, each function body and
/// each constant initializer is walked with that table available, checking
/// binary operand compatibility, assignment targets, and type-construction
/// fields. All other declaration kinds are skipped.
///
/// # Errors
///
/// Returns the message of the first check that fails; validation stops there.
pub fn validate_module(module: &Module) -> Result<(), String> {
    // Phase 1: collect the declared types so later checks can look them up.
    let mut newtypes: HashMap<String, Type> = HashMap::new();
    for declaration in module {
        let Declaration::Type(newtype) = declaration else {
            continue;
        };
        newtypes.insert(newtype.name.clone(), newtype.inner_type.clone());
    }

    // Phase 2: validate everything that contains expressions.
    let context = ValidationContext { newtypes };
    for declaration in module {
        let outcome = match declaration {
            Declaration::Function(function) => context.validate_block(&function.body),
            Declaration::Constant(constant) => context.validate_expression(&constant.value),
            _ => Ok(()),
        };
        outcome?;
    }
    Ok(())
}

/// State shared by all checks of a single validation pass.
struct ValidationContext {
    /// Maps the name of each declared type to its inner type, as collected
    /// from the module's `Declaration::Type` entries.
    newtypes: HashMap<String, Type>,
}

impl ValidationContext {
    fn validate_block(&self, block: &Block) -> Result<(), String> {
        for statement in &block.statements {
            self.validate_statement(statement)?;
        }
        if let Some(result) = &block.result {
            self.validate_expression(result)?;
        }
        Ok(())
    }

    fn validate_statement(&self, statement: &Statement) -> Result<(), String> {
        match statement {
            Statement::Expression(expression) | Statement::Return(Some(expression)) => {
                self.validate_expression(expression)?;
            }
            Statement::Let { value, .. } => {
                self.validate_expression(value)?;
            }
            Statement::Assign(target, value) => {
                self.validate_expression(target)?;
                self.validate_expression(value)?;
                Self::check_replace_types(target, value)?;
            }
            Statement::Label {
                initial_arguments, ..
            } => {
                for argument in initial_arguments {
                    self.validate_expression(argument)?;
                }
            }
            Statement::Jump { arguments, .. } => {
                for argument in arguments {
                    self.validate_expression(argument)?;
                }
            }
            Statement::MultiReplace {
                targets, values, ..
            } => {
                for target in targets {
                    self.validate_expression(target)?;
                }
                for value in values {
                    self.validate_expression(value)?;
                }
            }
            Statement::Defer(inner) => {
                self.validate_statement(inner)?;
            }
            Statement::Return(None) => {}
        }
        Ok(())
    }

    fn validate_expression(&self, expression: &Expression) -> Result<(), String> {
        match &expression.kind {
            ExpressionKind::BinaryOperation(operator, left, right) => {
                self.validate_expression(left)?;
                self.validate_expression(right)?;
                self.check_binary_operands(operator, left, right)?;
            }
            ExpressionKind::TypeConstruction(name, fields) => {
                for (_, value) in fields {
                    self.validate_expression(value)?;
                }
                self.check_construction_fields(name, fields)?;
            }
            ExpressionKind::Replace(target, value) | ExpressionKind::OpAssign(_, target, value) => {
                self.validate_expression(target)?;
                self.validate_expression(value)?;
                Self::check_replace_types(target, value)?;
            }
            ExpressionKind::Call(callee, arguments) => {
                self.validate_expression(callee)?;
                for argument in arguments {
                    self.validate_expression(argument)?;
                }
            }
            ExpressionKind::UnaryOperation(_, operand)
            | ExpressionKind::Dereference(operand)
            | ExpressionKind::Convert(operand, _)
            | ExpressionKind::Transmute(operand, _) => {
                self.validate_expression(operand)?;
            }
            ExpressionKind::Field(object, _) => {
                self.validate_expression(object)?;
            }
            ExpressionKind::Index(object, index) => {
                self.validate_expression(object)?;
                self.validate_expression(index)?;
            }
            ExpressionKind::ArrayLiteral(elements)
            | ExpressionKind::TupleLiteral(elements)
            | ExpressionKind::Print(elements) => {
                for element in elements {
                    self.validate_expression(element)?;
                }
            }
            ExpressionKind::Block(block) => {
                self.validate_block(block)?;
            }
            ExpressionKind::If {
                condition,
                then_branch,
                else_branch,
            } => {
                self.validate_expression(condition)?;
                self.validate_block(then_branch)?;
                if let Some(else_branch) = else_branch {
                    self.validate_block(else_branch)?;
                }
            }
            ExpressionKind::Match { value, arms } => {
                self.validate_expression(value)?;
                for arm in arms {
                    self.validate_block(&arm.body)?;
                }
            }
            ExpressionKind::Slice(array, start, end) => {
                self.validate_expression(array)?;
                if let Some(start) = start {
                    self.validate_expression(start)?;
                }
                if let Some(end) = end {
                    self.validate_expression(end)?;
                }
            }
            ExpressionKind::Literal(_)
            | ExpressionKind::Variable(_)
            | ExpressionKind::SizeOf(_) => {}
        }
        Ok(())
    }

    fn resolve_underlying(&self, resolved_type: &Type) -> Type {
        match resolved_type {
            Type::Named(name) => self.newtypes.get(name).map_or_else(
                || resolved_type.clone(),
                |inner| self.resolve_underlying(inner),
            ),
            Type::Pointer(mutability, inner) => {
                Type::Pointer(*mutability, Box::new(self.resolve_underlying(inner)))
            }
            other => other.clone(),
        }
    }

    fn check_binary_operands(
        &self,
        operator: &BinaryOperator,
        left: &Expression,
        right: &Expression,
    ) -> Result<(), String> {
        if matches!(operator, BinaryOperator::Logical(_)) {
            return Ok(());
        }
        let (Some(left_type), Some(right_type)) = (&left.resolved_type, &right.resolved_type)
        else {
            return Ok(());
        };
        if left_type == right_type {
            return Ok(());
        }
        let left_resolved = self.resolve_underlying(left_type);
        let right_resolved = self.resolve_underlying(right_type);
        if left_resolved != right_resolved {
            return Err(format!(
                "type mismatch in '{operator}': left is {left_type}, right is {right_type}",
            ));
        }
        if matches!(left_type, Type::Named(_)) && matches!(right_type, Type::Named(_)) {
            return Err(format!(
                "cannot mix distinct types in '{operator}': left is {left_type}, right is {right_type}",
            ));
        }
        Ok(())
    }

    fn check_replace_types(target: &Expression, value: &Expression) -> Result<(), String> {
        let Some(target_resolved) = &target.resolved_type else {
            return Ok(());
        };
        let Some(value_type) = &value.resolved_type else {
            return Ok(());
        };
        let target_type = match target_resolved {
            Type::Pointer(_, inner) => inner.as_ref(),
            other => other,
        };
        if target_type == value_type {
            return Ok(());
        }
        if matches!(target_type, Type::Named(_)) && matches!(value_type, Type::Named(_)) {
            return Err(format!(
                "type mismatch in assignment: target is {target_type}, value is {value_type}",
            ));
        }
        Ok(())
    }

    fn check_construction_fields(
        &self,
        type_name: &str,
        fields: &[(String, Expression)],
    ) -> Result<(), String> {
        let Some(inner) = self.newtypes.get(type_name) else {
            return Ok(());
        };
        let expected_fields: Vec<&str> = match inner {
            Type::Tuple(field_types) => field_types
                .iter()
                .filter_map(|field_type| match field_type {
                    Type::Named(name) => Some(name.as_str()),
                    _ => None,
                })
                .collect(),
            Type::Named(name) => vec![name.as_str()],
            _ => return Ok(()),
        };

        let mut seen = HashSet::new();
        for (field_name, _) in fields {
            if !expected_fields.contains(&field_name.as_str()) {
                return Err(format!("'{type_name}' has no field '{field_name}'"));
            }
            if !seen.insert(field_name.as_str()) {
                return Err(format!(
                    "duplicate field '{field_name}' in '{type_name}' construction"
                ));
            }
        }
        for expected in &expected_fields {
            if !seen.contains(expected) {
                return Err(format!(
                    "missing field '{expected}' in '{type_name}' construction"
                ));
            }
        }
        Ok(())
    }
}