use std::collections::HashMap;
use crate::ast::*;
use thiserror::Error;
/// Errors reported by [`Analyzer`] during semantic analysis.
///
/// The analyzer propagates every error with `?`, so analysis stops at the first one.
#[derive(Debug, Error)]
pub enum SemanticError {
/// An identifier was used but no enclosing scope declares it.
#[error("Undefined identifier: {0}")]
UndefinedIdentifier(String),
/// An expression's type differs from the type its context requires.
///
/// NOTE(review): nothing in this file constructs this variant yet; presumably it is
/// reserved for future type checking — confirm before depending on it.
#[error("Type mismatch: expected {expected}, found {found}")]
TypeMismatch { expected: String, found: String },
/// A name was declared twice in the same scope (shadowing a name from an outer scope is
/// allowed).
#[error("Duplicate definition: {0}")]
DuplicateDefinition(String),
}
/// A declared name and the type recorded for it in one scope of an [`Analyzer`].
#[derive(Debug, Clone)]
pub struct Symbol {
/// The identifier as declared; the analyzer also uses it as the scope-table key.
pub name: String,
/// The recorded type: a function's declared return type (or `Type::Void` when absent, and
/// for the built-ins `print`/`println`), a parameter's annotation (or `Type::Auto`), or the
/// type inferred from a variable's initializer.
pub ty: Type,
}
/// Semantic-analysis pass that resolves identifiers against a stack of lexical scopes and
/// rejects duplicate definitions within a single scope.
pub struct Analyzer {
/// Scope stack, innermost last. Element 0 is the global scope built by `Analyzer::new`
/// (built-ins, plus top-level functions once `analyze` runs).
scopes: Vec<HashMap<String, Symbol>>,
}
impl Analyzer {
pub fn new() -> Self {
let mut global = HashMap::new();
global.insert("print".to_string(), Symbol { name: "print".to_string(), ty: Type::Void });
global.insert("println".to_string(), Symbol { name: "println".to_string(), ty: Type::Void });
Analyzer { scopes: vec![global] }
}
fn push_scope(&mut self) {
self.scopes.push(HashMap::new());
}
fn pop_scope(&mut self) {
self.scopes.pop();
}
fn insert(&mut self, name: String, ty: Type) -> Result<(), SemanticError> {
let current = self.scopes.last_mut().unwrap();
if current.contains_key(&name) {
return Err(SemanticError::DuplicateDefinition(name));
}
current.insert(name.clone(), Symbol { name, ty });
Ok(())
}
fn resolve(&self, name: &str) -> Option<Symbol> {
for scope in self.scopes.iter().rev() {
if let Some(sym) = scope.get(name) {
return Some(sym.clone());
}
}
None
}
pub fn analyze(&mut self, program: &Program) -> Result<(), SemanticError> {
for item in &program.items {
match item {
TopLevel::Func(func) => {
let ret_type = func.ret_type.clone().unwrap_or(Type::Void);
self.insert(func.name.clone(), ret_type)?;
}
_ => {}
}
}
for item in &program.items {
match item {
TopLevel::Func(func) => self.analyze_func(func)?,
_ => {}
}
}
Ok(())
}
fn analyze_func(&mut self, func: &FuncDecl) -> Result<(), SemanticError> {
self.push_scope();
for param in &func.params {
let ty = param.ty.clone().unwrap_or(Type::Auto);
self.insert(param.name.clone(), ty)?;
}
self.analyze_block(&func.body)?;
self.pop_scope();
Ok(())
}
fn analyze_block(&mut self, block: &Block) -> Result<(), SemanticError> {
self.push_scope();
for stmt in &block.statements {
self.analyze_stmt(stmt)?;
}
self.pop_scope();
Ok(())
}
fn analyze_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
match stmt {
Stmt::VarDecl(v) => {
let ty = self.infer_expr(&v.init)?;
self.insert(v.name.clone(), ty)?;
}
Stmt::Expr(e) => {
self.infer_expr(e)?;
}
Stmt::If(if_stmt) => {
self.infer_expr(&if_stmt.condition)?;
self.analyze_block(&if_stmt.then_branch)?;
for (cond, block) in &if_stmt.elif_branches {
self.infer_expr(cond)?;
self.analyze_block(block)?;
}
if let Some(else_block) = &if_stmt.else_branch {
self.analyze_block(else_block)?;
}
}
Stmt::While(w) => {
self.infer_expr(&w.condition)?;
self.analyze_block(&w.body)?;
}
Stmt::For(f) => {
if let Some(init) = &f.init {
self.analyze_stmt(init)?;
}
self.infer_expr(&f.condition)?;
if let Some(update) = &f.update {
self.analyze_stmt(update)?;
}
self.analyze_block(&f.body)?;
}
Stmt::Return(expr) => {
if let Some(e) = expr {
self.infer_expr(e)?;
}
}
Stmt::Match(_) => {
}
}
Ok(())
}
fn infer_expr(&mut self, expr: &Expr) -> Result<Type, SemanticError> {
match expr {
Expr::Literal(lit) => Ok(match lit {
Literal::Int(_) => Type::Int,
Literal::Float(_) => Type::Float,
Literal::Bool(_) => Type::Bool,
Literal::String(_) => Type::String,
Literal::Char(_) => Type::Int,
}),
Expr::Ident(name) => {
self.resolve(name)
.map(|s| s.ty)
.ok_or_else(|| SemanticError::UndefinedIdentifier(name.clone()))
}
Expr::Binary(_, _, _) => Ok(Type::Int),
Expr::Unary(_, _) => Ok(Type::Int),
Expr::Call(func, args) => {
self.infer_expr(func)?;
for arg in args {
self.infer_expr(arg)?;
}
Ok(Type::Auto)
}
Expr::MethodCall(obj, _method, args) => {
self.infer_expr(obj)?;
for arg in args {
self.infer_expr(arg)?;
}
Ok(Type::Auto)
}
Expr::Lambda(_, _) => Ok(Type::Auto),
Expr::List(_) => Ok(Type::Auto),
Expr::Dict(_) => Ok(Type::Auto),
Expr::Index(_, _) => Ok(Type::Auto),
Expr::FieldAccess(_, _) => Ok(Type::Auto),
Expr::StructInit(struct_name, fields) => {
for (_, value) in fields {
self.infer_expr(value)?;
}
Ok(Type::Named(struct_name.clone(), vec![]))
}
Expr::Assign(_, _) => Ok(Type::Void),
Expr::CompoundAssign(_, _, _) => Ok(Type::Void),
Expr::Block(b) => {
self.analyze_block(b)?;
Ok(Type::Void)
}
}
}
}