use crate::compiler::parser::{Stmt, Expr, VarDecl, FuncDecl, Type, BinOp, Mutability, Assignment, IfStatement, ForLoop};
use std::collections::HashMap;
use crate::compiler::parser::WhileLoop;
/// Walks a parsed program and accumulates semantic errors: undeclared names, type
/// mismatches, assignments to `const` variables, `break`/`continue` outside loops and
/// bad or missing returns.
#[derive(Clone, Debug)]
pub struct SemanticAnalyzer {
    /// Lexical scope stack; the last map is the innermost scope.
    scopes: Vec<HashMap<String, SymbolInfo>>,
    /// Signatures of registered functions, keyed by function name.
    functions: HashMap<String, FunctionInfo>,
    /// Name of the function whose body is currently being checked, if any.
    current_function: Option<String>,
    /// Number of loops enclosing the statement being checked; `break`/`continue`
    /// are only valid when this is non-zero.
    loop_depth: usize,
    /// Every error recorded so far, in the order it was found.
    errors: Vec<SemanticError>,
}
/// What the analyzer knows about one declared variable.
#[derive(Clone, Debug)]
struct SymbolInfo {
    /// Declared type of the variable.
    ty: Type,
    /// Declared mutability; assignments to `Mutability::Const` variables are rejected.
    mutability: Mutability,
    /// NOTE(review): the analyzer always stores `true` here and never reads it —
    /// either remove it or start tracking declarations without an initializer.
    initialized: bool,
}
/// Signature of a registered function, used to check calls and `return` statements.
#[derive(Clone, Debug)]
struct FunctionInfo {
    /// `(type, name)` pairs, in declaration order.
    params: Vec<(Type, String)>,
    /// Declared return type; `None` is treated as `Type::Void` for a call's result.
    return_type: Option<Type>,
}
/// Category of a semantic error.
///
/// The enum is fieldless, so it is `Copy` and comparable: callers and tests can match on it
/// or compare it with `==` instead of inspecting message text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SemanticErrorKind {
    /// A variable was read or assigned without a visible declaration.
    UndeclaredVariable,
    /// A called function was never registered.
    UndeclaredFunction,
    /// A variable was declared twice in the same scope; also reported for a duplicate
    /// function declaration.
    RedeclaredVariable,
    /// A value's type does not fit its context (declaration, assignment, condition,
    /// loop bound, vector element, identity operand).
    TypeMismatch,
    /// Plain or compound assignment to a `const` variable.
    ImmutableAssignment,
    /// An operator applied to unsupported operand types, `break`/`continue` outside a
    /// loop, or an empty vector literal.
    InvalidOperation,
    /// A function with a non-void return type has no top-level `return`.
    MissingReturn,
    /// A `return` value does not match the enclosing function's declared return type.
    WrongReturnType,
    /// A call passed a different number of arguments than the function declares.
    ArgumentCountMismatch,
    /// A call argument's type does not match the corresponding parameter type.
    ArgumentTypeMismatch,
}
#[derive(Debug, Clone)]
pub struct SemanticError {
kind: SemanticErrorKind,
message: String,
}
impl std::fmt::Display for SemanticError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Semantic Error: {} ({:?})", self.message, self.kind)
}
}
/// Outcome of analyzing a program.
#[derive(Debug, Clone)]
pub struct AnalysisResult {
    /// Every semantic error that was found, in discovery order.
    pub errors: Vec<SemanticError>,
    /// `true` exactly when `errors` is empty.
    pub success: bool,
}
impl AnalysisResult {
    /// Returns normally when no errors were recorded.
    ///
    /// Otherwise prints every error to stderr, one per line, and terminates the process
    /// with exit status 1.
    pub fn unwrap_or_exit(self) {
        if self.errors.is_empty() {
            return;
        }
        self.errors.iter().for_each(|err| eprintln!("{}", err));
        std::process::exit(1);
    }
}
impl SemanticAnalyzer {
pub fn new() -> Self {
Self {
scopes: vec![HashMap::new()], functions: HashMap::new(),
current_function: None,
loop_depth: 0,
errors: Vec::new(),
}
}
pub fn analyze(&mut self, stmts: &[Stmt]) -> AnalysisResult {
for stmt in stmts {
if let Stmt::FuncDecl(func) = stmt {
self.register_function(func);
}
}
for stmt in stmts {
self.check_stmt(stmt);
}
AnalysisResult {
success: self.errors.is_empty(),
errors: self.errors.clone(),
}
}
fn enter_scope(&mut self) {
self.scopes.push(HashMap::new());
}
fn exit_scope(&mut self) {
self.scopes.pop();
}
fn declare_var(&mut self, name: String, ty: Type, mutability: Mutability) {
if let Some(current_scope) = self.scopes.last_mut() {
if current_scope.contains_key(&name) {
self.error(
SemanticErrorKind::RedeclaredVariable,
format!("Variable '{}' is already declared in this scope", name)
);
} else {
current_scope.insert(name, SymbolInfo { ty, mutability, initialized: true });
}
}
}
fn lookup_var(&self, name: &str) -> Option<&SymbolInfo> {
for scope in self.scopes.iter().rev() {
if let Some(info) = scope.get(name) {
return Some(info);
}
}
None
}
fn register_function(&mut self, func: &FuncDecl) {
if self.functions.contains_key(&func.name) {
self.error(
SemanticErrorKind::RedeclaredVariable,
format!("Function '{}' is already declared", func.name)
);
} else {
self.functions.insert(
func.name.clone(),
FunctionInfo {
params: func.params.clone(),
return_type: func.return_type.clone(),
}
);
}
}
fn check_stmt(&mut self, stmt: &Stmt) {
match stmt {
Stmt::VarDecl(var) => self.check_var_decl(var),
Stmt::FuncDecl(func) => self.check_func_decl(func),
Stmt::Assignment(assign) => self.check_assignment(assign),
Stmt::CompoundAssign { target, op, value } => {
self.check_compound_assign(target, op, value);
}
Stmt::IfStatement(if_stmt) => self.check_if(if_stmt),
Stmt::ForLoop(for_loop) => self.check_for(for_loop),
Stmt::Return(expr) => self.check_return(expr),
Stmt::Print(exprs) => {
for expr in exprs {
self.check_expr(expr);
}
}
Stmt::Break => {
if self.loop_depth == 0 {
self.error(
SemanticErrorKind::InvalidOperation,
"break statement outside of loop".to_string()
);
}
}
Stmt::Continue => {
if self.loop_depth == 0 {
self.error(
SemanticErrorKind::InvalidOperation,
"continue statement outside of loop".to_string()
);
}
}
Stmt::WhileLoop(while_loop) => self.check_while(while_loop)
}
}
fn check_while(&mut self, while_loop: &WhileLoop) {
let cond_type = self.check_expr(&while_loop.cond);
if !matches!(cond_type, Type::Bool) {
self.error(
SemanticErrorKind::TypeMismatch,
format!("While condition must be bool, found {:?}", cond_type)
);
}
self.enter_scope();
self.loop_depth += 1;
for stmt in &while_loop.body {
self.check_stmt(stmt);
}
self.exit_scope();
}
fn check_var_decl(&mut self, var: &VarDecl) {
let expr_type = self.check_expr(&var.value);
if !self.types_compatible(&var.ty, &expr_type) {
self.error(
SemanticErrorKind::TypeMismatch,
format!(
"Type mismatch in variable '{}': expected {:?}, found {:?}",
var.name, var.ty, expr_type
)
);
}
self.declare_var(var.name.clone(), var.ty.clone(), var.mutability);
}
fn check_assignment(&mut self, assign: &Assignment) {
let var_info = match self.lookup_var(&assign.target) {
Some(info) => info.clone(),
None => {
self.error(
SemanticErrorKind::UndeclaredVariable,
format!("Variable '{}' is not declared", assign.target)
);
return;
}
};
if var_info.mutability == Mutability::Const {
self.error(
SemanticErrorKind::ImmutableAssignment,
format!("Cannot assign to const variable '{}'", assign.target)
);
}
let expr_type = self.check_expr(&assign.value);
if !self.types_compatible(&var_info.ty, &expr_type) {
self.error(
SemanticErrorKind::TypeMismatch,
format!(
"Type mismatch in assignment to '{}': expected {:?}, found {:?}",
assign.target, var_info.ty, expr_type
)
);
}
}
fn check_compound_assign(&mut self, target: &str, op: &BinOp, value: &Expr) {
let var_info = match self.lookup_var(target) {
Some(info) => info.clone(),
None => {
self.error(
SemanticErrorKind::UndeclaredVariable,
format!("Variable '{}' is not declared", target)
);
return;
}
};
if var_info.mutability == Mutability::Const {
self.error(
SemanticErrorKind::ImmutableAssignment,
format!("Cannot assign to const variable '{}'", target)
);
}
let expr_type = self.check_expr(value);
if !self.is_valid_binop(&var_info.ty, op, &expr_type) {
self.error(
SemanticErrorKind::InvalidOperation,
format!(
"Invalid operation {:?} between {:?} and {:?}",
op, var_info.ty, expr_type
)
);
}
}
fn check_if(&mut self, if_stmt: &IfStatement) {
let cond_type = self.check_expr(&if_stmt.condition);
if !matches!(cond_type, Type::Bool) {
self.error(
SemanticErrorKind::TypeMismatch,
format!("If condition must be bool, found {:?}", cond_type)
);
}
self.enter_scope();
for stmt in &if_stmt.then_branch {
self.check_stmt(stmt);
}
self.exit_scope();
if let Some(else_branch) = &if_stmt.else_branch {
self.enter_scope();
for stmt in else_branch {
self.check_stmt(stmt);
}
self.exit_scope();
}
}
fn check_for(&mut self, for_loop: &ForLoop) {
let start_type = self.check_expr(&for_loop.start);
let end_type = self.check_expr(&for_loop.end);
if !self.is_integer_type(&start_type) {
self.error(
SemanticErrorKind::TypeMismatch,
format!("For loop start must be integer, found {:?}", start_type)
);
}
if !self.is_integer_type(&end_type) {
self.error(
SemanticErrorKind::TypeMismatch,
format!("For loop end must be integer, found {:?}", end_type)
);
}
self.enter_scope();
self.loop_depth += 1;
self.declare_var(for_loop.var.clone(), Type::Int, Mutability::Const);
for stmt in &for_loop.body {
self.check_stmt(stmt);
}
self.exit_scope();
}
fn check_func_decl(&mut self, func: &FuncDecl) {
self.current_function = Some(func.name.clone());
self.enter_scope();
for (ty, name) in &func.params {
self.declare_var(name.clone(), ty.clone(), Mutability::Mutable);
}
let mut has_return = false;
for stmt in &func.body {
if matches!(stmt, Stmt::Return(_)) {
has_return = true;
}
self.check_stmt(stmt);
}
if func.return_type.is_some() && !matches!(func.return_type, Some(Type::Void)) && !has_return {
self.error(
SemanticErrorKind::MissingReturn,
format!("Function '{}' must return a value", func.name)
);
}
self.exit_scope();
self.current_function = None;
}
fn check_return(&mut self, expr: &Expr) {
let expr_type = self.check_expr(expr);
if let Some(func_name) = &self.current_function {
if let Some(func_info) = self.functions.get(func_name) {
if let Some(expected_type) = &func_info.return_type {
if !self.types_compatible(expected_type, &expr_type) {
self.error(
SemanticErrorKind::WrongReturnType,
format!(
"Function '{}' expects return type {:?}, found {:?}",
func_name, expected_type, expr_type
)
);
}
}
}
}
}
fn check_expr(&mut self, expr: &Expr) -> Type {
match expr {
Expr::UIntLiteral(_) => Type::UInt,
Expr::IntLiteral(_) => Type::Int,
Expr::BoolLiteral(_) => Type::Bool,
Expr::StringLiteral(_) => Type::Str,
Expr::UInt8(_) => Type::U8,
Expr::UInt16(_) => Type::U16,
Expr::UInt32(_) => Type::U32,
Expr::UInt64(_) => Type::U64,
Expr::UInt128(_) => Type::U128,
Expr::Int8(_) => Type::I8,
Expr::Int16(_) => Type::I16,
Expr::Int32(_) => Type::I32,
Expr::Int64(_) => Type::I64,
Expr::Int128(_) => Type::I128,
Expr::Int(_) => Type::Int,
Expr::UInt(_) => Type::UInt,
Expr::Ident(name) => {
match self.lookup_var(name) {
Some(info) => info.ty.clone(),
None => {
self.error(
SemanticErrorKind::UndeclaredVariable,
format!("Variable '{}' is not declared", name)
);
Type::Void }
}
}
Expr::Call { name, args } => {
let func_info = match self.functions.get(name).cloned() {
Some(info) => info,
None => {
self.error(
SemanticErrorKind::UndeclaredFunction,
format!("Function '{}' is not declared", name)
);
return Type::Void;
}
};
if args.len() != func_info.params.len() {
self.error(
SemanticErrorKind::ArgumentCountMismatch,
format!(
"Function '{}' expects {} arguments, found {}",
name, func_info.params.len(), args.len()
)
);
}
for (i, arg) in args.iter().enumerate() {
if let Some((expected_type, _)) = func_info.params.get(i) {
let arg_type = self.check_expr(arg);
if !self.types_compatible(expected_type, &arg_type) {
self.error(
SemanticErrorKind::ArgumentTypeMismatch,
format!(
"Argument {} of function '{}': expected {:?}, found {:?}",
i + 1, name, expected_type, arg_type
)
);
}
}
}
func_info.return_type.unwrap_or(Type::Void)
}
Expr::BinaryOp { left, op, right } => {
let left_type = self.check_expr(left);
let right_type = self.check_expr(right);
if !self.is_valid_binop(&left_type, op, &right_type) {
self.error(
SemanticErrorKind::InvalidOperation,
format!(
"Invalid binary operation {:?} between {:?} and {:?}",
op, left_type, right_type
)
);
}
self.result_type_of_binop(&left_type, op, &right_type)
}
Expr::Identity { expr, negated: _ } => {
let ty = self.check_expr(expr);
if !matches!(ty, Type::Bool) {
self.error(
SemanticErrorKind::TypeMismatch,
format!("Identity operator expects bool, found {:?}", ty)
);
}
Type::Bool
}
Expr::Vec { values, size: _ } => {
if values.is_empty() {
self.error(
SemanticErrorKind::InvalidOperation,
"Vector cannot be empty".to_string()
);
return Type::Void;
}
let first_type = self.check_expr(&values[0]);
for val in values.iter().skip(1) {
let val_type = self.check_expr(val);
if !self.types_compatible(&first_type, &val_type) {
self.error(
SemanticErrorKind::TypeMismatch,
format!("Vector elements must have same type: expected {:?}, found {:?}", first_type, val_type)
);
}
}
Type::Vec {
inner: Box::new(first_type),
size: values.len(),
}
}
Expr::Unknown => Type::Void,
}
}
fn types_compatible(&self, expected: &Type, found: &Type) -> bool {
match (expected, found) {
(Type::Int, Type::I8 | Type::I16 | Type::I32 | Type::I64 | Type::I128 | Type::Int) => true,
(Type::UInt, Type::U8 | Type::U16 | Type::U32 | Type::U64 | Type::U128 | Type::UInt) => true,
_ => expected == found,
}
}
fn is_integer_type(&self, ty: &Type) -> bool {
matches!(
ty,
Type::Int | Type::UInt |
Type::I8 | Type::I16 | Type::I32 | Type::I64 | Type::I128 |
Type::U8 | Type::U16 | Type::U32 | Type::U64 | Type::U128
)
}
fn is_valid_binop(&self, left: &Type, op: &BinOp, right: &Type) -> bool {
match op {
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Percent => {
self.is_integer_type(left) && self.is_integer_type(right)
}
BinOp::DoubleEqual | BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual => {
self.types_compatible(left, right)
}
BinOp::IndentityOp => {
matches!(left, Type::Bool) && matches!(right, Type::Bool)
}
BinOp::CompoundAdd | BinOp::CompoundSub | BinOp::CompoundMul | BinOp::CompoundDiv => {
self.is_integer_type(left) && self.is_integer_type(right)
}
}
}
fn result_type_of_binop(&self, left: &Type, op: &BinOp, _right: &Type) -> Type {
match op {
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Percent => left.clone(),
BinOp::DoubleEqual | BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual | BinOp::IndentityOp => Type::Bool,
BinOp::CompoundAdd | BinOp::CompoundSub | BinOp::CompoundMul | BinOp::CompoundDiv => left.clone(),
}
}
fn error(&mut self, kind: SemanticErrorKind, message: String) {
self.errors.push(SemanticError { kind, message });
}
}