// mux-lang 0.4.1
//
// The Mux Programming Language Compiler
use super::{SemanticAnalyzer, SemanticError, Symbol, SymbolKind, Type, format_type};
use crate::ast::{LiteralNode, PatternNode, PrimitiveType};
use crate::lexer::Span;
use std::collections::HashMap;

impl SemanticAnalyzer {
    /// Type-checks `pattern` against `expected_type`, dispatching on the pattern's shape.
    ///
    /// Identifier, enum-variant, literal and list patterns are forwarded to their
    /// dedicated handlers; a wildcard accepts any type without further work.
    ///
    /// # Errors
    /// Propagates whatever `SemanticError` the selected handler reports.
    pub(super) fn set_pattern_types(
        &mut self,
        pattern: &PatternNode,
        expected_type: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        match pattern {
            // A wildcard matches anything and binds nothing.
            PatternNode::Wildcard => Ok(()),
            PatternNode::Literal(lit) => self.handle_literal_pattern(lit, expected_type, span),
            PatternNode::Identifier(name) => {
                self.handle_identifier_pattern(name, expected_type, span)
            }
            PatternNode::List { elements, rest } => {
                self.handle_list_pattern(elements, rest, expected_type, span)
            }
            PatternNode::EnumVariant { name, args } => {
                self.handle_enum_variant_pattern(name, args, expected_type, span)
            }
        }
    }

    /// Resolves an identifier pattern against `expected_type`.
    ///
    /// If `name` already resolves to a symbol whose kind is
    /// `SymbolKind::Constant`, the constant's recorded type is checked for
    /// compatibility with `expected_type`. Otherwise `name` is added to the
    /// symbol table as a new `SymbolKind::Variable` of type `expected_type`.
    ///
    /// # Errors
    /// - the matched constant has no recorded type;
    /// - the constant's type is not compatible with `expected_type`;
    /// - the symbol table rejects the new variable.
    fn handle_identifier_pattern(
        &mut self,
        name: &str,
        expected_type: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        // A single lookup answers both "is this a constant?" and "what is its type?".
        let constant = self
            .symbol_table
            .lookup(name)
            .filter(|symbol| symbol.kind == SymbolKind::Constant);

        if let Some(constant) = constant {
            let const_type = constant.type_.clone().ok_or_else(|| {
                SemanticError::new(format!("Constant '{}' has no type information", name), span)
            })?;
            return self.check_type_compatibility(&const_type, expected_type, span);
        }

        // Not a known constant: the identifier introduces a fresh binding.
        self.symbol_table.add_symbol(
            name,
            Symbol {
                kind: SymbolKind::Variable,
                span,
                type_: Some(expected_type.clone()),
                interfaces: HashMap::new(),
                methods: HashMap::new(),
                fields: HashMap::new(),
                type_params: Vec::new(),
                original_name: None,
                llvm_name: None,
                default_param_count: 0,
                variants: None,
            },
        )?;
        Ok(())
    }

    /// Checks a variant pattern `name(args...)` against `expected_type`.
    ///
    /// `Type::Optional` and `Type::Result` are handled by their own matchers.
    /// For `Type::Named` the type name is resolved in the symbol table and the
    /// variant is checked against that symbol. Every other type is rejected.
    ///
    /// # Errors
    /// Returns a `SemanticError` when the expected type cannot carry variant
    /// patterns, when the named type is not in the symbol table, or when the
    /// delegated matcher rejects the pattern.
    fn handle_enum_variant_pattern(
        &mut self,
        name: &str,
        args: &[PatternNode],
        expected_type: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        // Built-in wrappers return early; only a named type falls through.
        let enum_name = match expected_type {
            Type::Optional(inner) => {
                return self.match_optional_variant(name, args, inner, span);
            }
            Type::Result(ok_type, err_type) => {
                return self.match_result_variant(name, args, ok_type, err_type, span);
            }
            Type::Named(enum_name, _) => enum_name,
            _ => {
                let message = format!(
                    "Enum variant patterns are not supported for type {}",
                    format_type(expected_type)
                );
                return Err(SemanticError::with_help(
                    message,
                    span,
                    "Variant patterns can only be used with enum, Optional, or Result types",
                ));
            }
        };

        let symbol = self
            .symbol_table
            .lookup(enum_name)
            .ok_or_else(|| self.undefined_symbol_error("type", enum_name, span))?;
        self.match_enum_variant(name, args, enum_name, &symbol, span)
    }

    /// Checks a variant pattern against an optional whose payload type is `inner`.
    ///
    /// Accepted shapes are `some` with exactly one sub-pattern (which is then
    /// checked against `inner`) and `none` with no sub-patterns.
    ///
    /// # Errors
    /// Any other name or argument count yields a `SemanticError`. Errors from
    /// checking the `some` payload are propagated.
    fn match_optional_variant(
        &mut self,
        name: &str,
        args: &[PatternNode],
        inner: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        match (name, args) {
            ("some", [payload]) => self.set_pattern_types(payload, inner, span),
            ("none", []) => Ok(()),
            _ => Err(SemanticError::with_help(
                format!(
                    "Pattern '{}' does not match type {}",
                    name,
                    format_type(&Type::Optional(Box::new(inner.clone())))
                ),
                span,
                // The hint spells the variants exactly as they are compared above,
                // so following it produces a pattern this function accepts.
                "Optional values can only be matched with some(value) or none",
            )),
        }
    }

    /// Checks a variant pattern against a result with the given `ok_type` and `err_type`.
    ///
    /// Accepted shapes are `ok` and `err`, each with exactly one sub-pattern;
    /// the sub-pattern is checked against `ok_type` or `err_type` respectively.
    ///
    /// # Errors
    /// Any other name or argument count yields a `SemanticError`. Errors from
    /// checking the payload sub-pattern are propagated.
    fn match_result_variant(
        &mut self,
        name: &str,
        args: &[PatternNode],
        ok_type: &Type,
        err_type: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        match (name, args) {
            ("ok", [payload]) => self.set_pattern_types(payload, ok_type, span),
            ("err", [payload]) => self.set_pattern_types(payload, err_type, span),
            _ => Err(SemanticError::with_help(
                format!(
                    "Pattern '{}' does not match type {}",
                    name,
                    format_type(&Type::Result(
                        Box::new(ok_type.clone()),
                        Box::new(err_type.clone())
                    ))
                ),
                span,
                // The hint spells the variants exactly as they are compared above,
                // so following it produces a pattern this function accepts.
                "Result values can only be matched with ok(value) or err(value)",
            )),
        }
    }

    /// Checks that the type of literal `lit` is compatible with `expected_type`.
    ///
    /// Each literal kind maps to one primitive type: integer → `Int`,
    /// float → `Float`, string → `Str`, boolean → `Bool`, char → `Char`.
    ///
    /// # Errors
    /// Returns the error produced by `check_type_compatibility` when the
    /// literal's type and `expected_type` are not compatible.
    fn handle_literal_pattern(
        &mut self,
        lit: &LiteralNode,
        expected_type: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        let primitive = match lit {
            LiteralNode::Boolean(_) => PrimitiveType::Bool,
            LiteralNode::Char(_) => PrimitiveType::Char,
            LiteralNode::Float(_) => PrimitiveType::Float,
            LiteralNode::Integer(_) => PrimitiveType::Int,
            LiteralNode::String(_) => PrimitiveType::Str,
        };
        let literal_type = Type::Primitive(primitive);
        self.check_type_compatibility(&literal_type, expected_type, span)
    }

    /// Checks a list pattern (`elements` plus an optional `rest` sub-pattern)
    /// against `expected_type`.
    ///
    /// For `Type::List` each element pattern is checked against the list's
    /// element type; for `Type::EmptyList` the element type used is
    /// `Type::Void`. When present, the `rest` pattern is checked against a
    /// `Type::List` of that same element type.
    ///
    /// # Errors
    /// Returns a `SemanticError` if `expected_type` is neither `Type::List`
    /// nor `Type::EmptyList`, or if any sub-pattern fails its own check.
    fn handle_list_pattern(
        &mut self,
        elements: &[PatternNode],
        rest: &Option<Box<PatternNode>>,
        expected_type: &Type,
        span: Span,
    ) -> Result<(), SemanticError> {
        let element_type = match expected_type {
            Type::EmptyList => Type::Void,
            Type::List(inner) => inner.as_ref().clone(),
            other => {
                let message = format!("List pattern cannot match type {}", format_type(other));
                return Err(SemanticError::with_help(
                    message,
                    span,
                    "List patterns (e.g. [head, ...rest]) can only match list types",
                ));
            }
        };

        // Positional elements are checked first, in order, stopping at the first error.
        elements
            .iter()
            .try_for_each(|element| self.set_pattern_types(element, &element_type, span))?;

        // The rest pattern, if any, stands for the remaining tail of the list.
        match rest.as_deref() {
            Some(tail) => {
                let tail_type = Type::List(Box::new(element_type));
                self.set_pattern_types(tail, &tail_type, span)
            }
            None => Ok(()),
        }
    }

    /// Checks a variant pattern `name(args...)` against the enum described by `symbol`.
    ///
    /// The variant is looked up by `name` in `symbol.methods`. The pattern must
    /// supply exactly as many sub-patterns as the found signature has
    /// parameters; each sub-pattern is then checked against the parameter type
    /// at the same position.
    ///
    /// # Errors
    /// - `name` is not a key of `symbol.methods`. The help text lists the
    ///   available names in sorted order, so the diagnostic is the same on
    ///   every run instead of following hash-map iteration order.
    /// - The number of sub-patterns differs from the signature's parameter count.
    /// - A sub-pattern fails its own type check.
    fn match_enum_variant(
        &mut self,
        name: &str,
        args: &[PatternNode],
        enum_name: &str,
        symbol: &Symbol,
        span: Span,
    ) -> Result<(), SemanticError> {
        let sig = symbol.methods.get(name).ok_or_else(|| {
            let message = format!("Unknown variant '{}' for enum '{}'", name, enum_name);
            let mut available: Vec<&str> = symbol.methods.keys().map(String::as_str).collect();
            if available.is_empty() {
                SemanticError::new(message, span)
            } else {
                // Hash-map key order is arbitrary; sort for a stable message.
                available.sort_unstable();
                SemanticError::with_help(
                    message,
                    span,
                    format!("Available variants: {}", available.join(", ")),
                )
            }
        })?;

        let expected = sig.params.len();
        if args.len() != expected {
            let plural = if expected == 1 { "" } else { "s" };
            return Err(SemanticError::with_help(
                format!(
                    "Variant '{}' expects {} argument{}, but pattern provides {}",
                    name,
                    expected,
                    plural,
                    args.len()
                ),
                span,
                format!(
                    "Match the variant with exactly {} argument{}",
                    expected, plural
                ),
            ));
        }

        for (arg, param_type) in args.iter().zip(&sig.params) {
            self.set_pattern_types(arg, param_type, span)?;
        }
        Ok(())
    }

    /// Recursively walks `pattern` and all of its nested sub-patterns.
    ///
    /// Identifier, literal and wildcard patterns are accepted as-is. Variant
    /// patterns recurse into their arguments, and list patterns recurse into
    /// their elements followed by the optional rest pattern. No node is
    /// currently rejected, so the walk always ends in `Ok(())`.
    // `self` is only passed along to the recursive calls, which is exactly what
    // the lint below reports.
    #[allow(clippy::only_used_in_recursion)]
    pub(super) fn analyze_pattern(&mut self, pattern: &PatternNode) -> Result<(), SemanticError> {
        match pattern {
            PatternNode::Identifier(_) | PatternNode::Literal(_) | PatternNode::Wildcard => Ok(()),
            PatternNode::EnumVariant { args, .. } => {
                args.iter().try_for_each(|arg| self.analyze_pattern(arg))
            }
            PatternNode::List { elements, rest } => elements
                .iter()
                .chain(rest.as_deref())
                .try_for_each(|sub_pattern| self.analyze_pattern(sub_pattern)),
        }
    }
}