use std::collections::HashMap;
use crate::error::{CompileError, Result};
use crate::parse::ast::{
BinaryOp, BlockItem, Declaration, Expr, ForInit, FunctionDecl, Program, Statement,
TopLevelDecl, Type, UnaryOp, has_fam, struct_member_offset_ex,
};
/// What a name in the type checker's symbol table refers to.
#[derive(Debug, Clone)]
enum SymbolType {
/// An object (variable or parameter) together with its declared type.
Variable(Type),
/// A function, described by its signature.
Function {
/// Type produced by calling the function.
return_type: Type,
/// Parameter types in declaration order; `None` when the function was
/// declared without a prototype.
param_types: Option<Vec<Type>>,
/// Copied from the declaration's `is_variadic` flag.
is_variadic: bool,
},
}
/// Type-checks every top-level declaration of `program` in source order.
///
/// Functions and file-scope variables are checked against (and recorded in)
/// a single symbol table that lives for the duration of the call. Typedefs
/// are skipped. Expressions inside the program may be rewritten in place
/// (for example, implicit conversions become explicit casts).
///
/// # Errors
///
/// Returns the first error reported by any declaration; later declarations
/// are not visited.
pub fn typecheck(program: &mut Program) -> Result<()> {
    let mut globals = HashMap::<String, SymbolType>::new();
    program
        .declarations
        .iter_mut()
        .try_for_each(|top_level| match top_level {
            TopLevelDecl::Function(function) => typecheck_function_decl(function, &mut globals),
            TopLevelDecl::Variable(variable) => typecheck_file_scope_var(variable, &mut globals),
            // Typedefs carry nothing for this pass to verify.
            TopLevelDecl::Typedef { .. } => Ok(()),
        })
}
/// Rewrites `char`/`unsigned char` array declarations that are initialized
/// with a string literal into an explicit per-element compound initializer.
///
/// Declarations with any other initializer or type are left untouched.
/// When the declared element count is `0` the array is sized to the string
/// length plus one (for the terminating zero); otherwise the declared count
/// is kept and the unused tail is zero-filled.
///
/// # Errors
///
/// Fails when the array has an explicit size and the string has more bytes
/// than the array has elements.
fn convert_string_to_char_array(decl: &mut Declaration) -> Result<()> {
    let Some(Expr::StringLiteral(text)) = &decl.init else {
        return Ok(());
    };
    let (element, declared_len) = match &decl.var_type {
        Type::Array(elem, len) if matches!(**elem, Type::Char | Type::UChar) => {
            (elem.clone(), *len)
        }
        _ => return Ok(()),
    };

    let raw: Vec<u8> = text.bytes().collect();
    if declared_len > 0 && raw.len() > declared_len {
        return Err(CompileError::TypeError(format!(
            "initializer-string for '{}' is too long ({} chars for array of {})",
            decl.name,
            raw.len(),
            declared_len
        )));
    }

    // An unsized array gets room for the string plus its terminator.
    let total = if declared_len == 0 {
        raw.len() + 1
    } else {
        declared_len
    };

    // String bytes first, then zero padding, cut off at exactly `total` elements.
    let elements: Vec<Expr> = raw
        .iter()
        .copied()
        .chain(std::iter::repeat(0u8))
        .take(total)
        .map(|byte| Expr::Constant(byte as i64))
        .collect();

    decl.var_type = Type::Array(element, total);
    decl.init = Some(Expr::CompoundInit(elements));
    Ok(())
}
/// Checks a file-scope variable declaration and records it in `symbols`.
///
/// Unnamed declarations are ignored. String-literal initializers of
/// character arrays are expanded first, so the recorded type carries the
/// final array length. The initializer itself is not type-checked here;
/// only a top-level integer constant is widened when it is out of range.
///
/// # Errors
///
/// Fails when the declared type is not an object type, or when a string
/// initializer does not fit its array.
fn typecheck_file_scope_var(
    decl: &mut Declaration,
    symbols: &mut HashMap<String, SymbolType>,
) -> Result<()> {
    if decl.name.is_empty() {
        return Ok(());
    }
    if !decl.var_type.is_object_type() {
        return Err(CompileError::TypeError(format!(
            "variable '{}' has non-object type '{:?}'",
            decl.name, decl.var_type
        )));
    }

    convert_string_to_char_array(decl)?;

    let recorded = SymbolType::Variable(decl.var_type.clone());
    symbols.insert(decl.name.clone(), recorded);

    if let Some(initializer) = decl.init.as_mut() {
        resolve_constant(initializer);
    }
    Ok(())
}
/// Checks a function declaration or definition and records its signature.
///
/// The declaration is compared with any earlier declaration of the same
/// function: return type and variadic-ness must agree, and parameter lists
/// must agree when both declarations carry a prototype. An existing entry is
/// only replaced when it lacked a prototype and the new declaration has one;
/// a name that is not yet a function in `symbols` is always (re)recorded.
///
/// When the function has a body, it is checked against a private copy of the
/// symbol table extended with the parameters, so locals never leak into
/// `symbols`.
///
/// # Errors
///
/// Fails with "conflicting types" on an incompatible redeclaration, or with
/// the first error found inside the body.
fn typecheck_function_decl(
    func: &mut FunctionDecl,
    symbols: &mut HashMap<String, SymbolType>,
) -> Result<()> {
    // A declaration without a prototype says nothing about its parameters.
    let param_types: Option<Vec<Type>> = func
        .has_prototype
        .then(|| func.params.iter().map(|(ty, _)| ty.clone()).collect());

    let should_update = match symbols.get(&func.name) {
        Some(SymbolType::Function {
            return_type,
            param_types: existing_params,
            is_variadic,
        }) => {
            // Parameter lists can only conflict when both sides have a prototype.
            let params_conflict = match (existing_params, &param_types) {
                (Some(previous), Some(current)) => previous != current,
                _ => false,
            };
            if *return_type != func.return_type
                || *is_variadic != func.is_variadic
                || params_conflict
            {
                return Err(CompileError::TypeError(format!(
                    "conflicting types for function '{}'",
                    func.name
                )));
            }
            // Only upgrade an unprototyped entry to a prototyped one.
            existing_params.is_none() && param_types.is_some()
        }
        _ => true,
    };

    if should_update {
        symbols.insert(
            func.name.clone(),
            SymbolType::Function {
                return_type: func.return_type.clone(),
                param_types,
                is_variadic: func.is_variadic,
            },
        );
    }

    if let Some(body) = &mut func.body {
        // The function itself is already recorded, so recursive calls resolve.
        let mut local_symbols = symbols.clone();
        for (param_type, param_name) in &func.params {
            local_symbols.insert(param_name.clone(), SymbolType::Variable(param_type.clone()));
        }
        for item in body.iter_mut() {
            typecheck_block_item(item, &mut local_symbols, &func.return_type)?;
        }
    }
    Ok(())
}
/// Dispatches one block item to the matching checker.
///
/// Statements are checked against `return_type` (the enclosing function's
/// return type), declarations are checked and added to `symbols`, and
/// block-scope typedefs are accepted as-is.
fn typecheck_block_item(
    item: &mut BlockItem,
    symbols: &mut HashMap<String, SymbolType>,
    return_type: &Type,
) -> Result<()> {
    match item {
        BlockItem::Typedef { .. } => Ok(()),
        BlockItem::Declaration(local) => typecheck_local_declaration(local, symbols),
        BlockItem::Statement(inner) => typecheck_statement(inner, symbols, return_type),
    }
}
fn typecheck_local_declaration(
decl: &mut Declaration,
symbols: &mut HashMap<String, SymbolType>,
) -> Result<()> {
if decl.name.is_empty() {
return Ok(());
}
if !decl.var_type.is_object_type() {
return Err(CompileError::TypeError(format!(
"variable '{}' has non-object type '{:?}'",
decl.name, decl.var_type
)));
}
convert_string_to_char_array(decl)?;
symbols.insert(
decl.name.clone(),
SymbolType::Variable(decl.var_type.clone()),
);
if let Some(init) = &mut decl.init {
if let Expr::CompoundInit(inits) = init {
if let Type::Struct { ref members, .. } = decl.var_type {
let expected = if has_fam(members) {
members.len() - 1
} else {
members.len()
};
if inits.len() != expected {
return Err(CompileError::TypeError(format!(
"wrong number of initializers for struct (expected {}, got {})",
expected,
inits.len()
)));
}
let init_members = if has_fam(members) {
&members[..members.len() - 1]
} else {
members
};
for (init_expr, member) in inits.iter_mut().zip(init_members.iter()) {
let init_type = typecheck_expr(init_expr, symbols)?;
if init_type != member.member_type {
let old = std::mem::replace(init_expr, Expr::Constant(0));
*init_expr = Expr::Cast {
target_type: member.member_type.clone(),
source_type: init_type,
expr: Box::new(old),
};
}
}
return Ok(());
} else if let Type::Array(_, count) = &decl.var_type {
let count = *count;
if count == 0
&& !inits.is_empty()
&& let Type::Array(ref elem_type, _) = decl.var_type
{
let et = (**elem_type).clone();
decl.var_type = Type::Array(Box::new(et), inits.len());
}
let (elem_type, count) = match &decl.var_type {
Type::Array(e, c) => (e.clone(), *c),
_ => unreachable!(),
};
if inits.len() > count {
return Err(CompileError::TypeError(format!(
"too many initializers for array (expected at most {}, got {})",
count,
inits.len()
)));
}
for init_expr in inits.iter_mut() {
if let Expr::CompoundInit(sub_inits) = init_expr
&& let Type::Struct { ref members, .. } = *elem_type
{
for (sub_init, member) in sub_inits.iter_mut().zip(members.iter()) {
let sub_type = typecheck_expr(sub_init, symbols)?;
if sub_type != member.member_type {
let old = std::mem::replace(sub_init, Expr::Constant(0));
*sub_init = Expr::Cast {
target_type: member.member_type.clone(),
source_type: sub_type,
expr: Box::new(old),
};
}
}
continue;
}
let init_type = typecheck_expr(init_expr, symbols)?;
if init_type != *elem_type {
let old = std::mem::replace(init_expr, Expr::Constant(0));
*init_expr = Expr::Cast {
target_type: (*elem_type).clone(),
source_type: init_type,
expr: Box::new(old),
};
}
}
return Ok(());
} else {
return Err(CompileError::TypeError(
"compound initializer used with non-struct/non-array type".to_string(),
));
}
}
let init_type = typecheck_expr(init, symbols)?;
if decl.var_type.is_struct() && init_type.is_struct() {
if decl.var_type != init_type {
return Err(CompileError::TypeError(
"incompatible struct types in initialization".to_string(),
));
}
return Ok(());
}
if init_type != decl.var_type {
if decl.var_type.is_struct() || init_type.is_struct() {
return Err(CompileError::TypeError(
"incompatible types in initialization".to_string(),
));
}
if decl.var_type.is_pointer()
&& !init_type.is_pointer()
&& !is_null_pointer_constant(init)
{
return Err(CompileError::TypeError(
"cannot initialize pointer with non-pointer non-null value".to_string(),
));
}
if decl.var_type.is_pointer()
&& init_type.is_pointer()
&& decl.var_type != init_type
&& !is_void_pointer(&decl.var_type)
&& !is_void_pointer(&init_type)
{
return Err(CompileError::TypeError(
"incompatible pointer types in initialization".to_string(),
));
}
let old_init = std::mem::replace(init, Expr::Constant(0)); *init = Expr::Cast {
target_type: decl.var_type.clone(),
source_type: init_type,
expr: Box::new(old_init),
};
}
}
Ok(())
}
fn typecheck_statement(
stmt: &mut Statement,
symbols: &mut HashMap<String, SymbolType>,
return_type: &Type,
) -> Result<()> {
match stmt {
Statement::Return(opt_expr) => {
match opt_expr {
None => {
if !return_type.is_void() {
return Err(CompileError::TypeError(
"return with no value in non-void function".to_string(),
));
}
Ok(())
}
Some(expr) => {
if return_type.is_void() {
return Err(CompileError::TypeError(
"return with a value in void function".to_string(),
));
}
let expr_type = typecheck_expr(expr, symbols)?;
if expr_type != *return_type {
if return_type.is_pointer()
&& expr_type.is_pointer()
&& *return_type != expr_type
&& !is_void_pointer(return_type)
&& !is_void_pointer(&expr_type)
{
return Err(CompileError::TypeError(
"incompatible pointer types in return".to_string(),
));
}
let old_expr = std::mem::replace(expr, Expr::Constant(0));
*expr = Expr::Cast {
target_type: return_type.clone(),
source_type: expr_type,
expr: Box::new(old_expr),
};
}
Ok(())
}
}
}
Statement::Expression(expr) => {
typecheck_expr(expr, symbols)?;
Ok(())
}
Statement::Null => Ok(()),
Statement::If {
condition,
then_branch,
else_branch,
} => {
let cond_type = typecheck_expr(condition, symbols)?;
if cond_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as condition".to_string(),
));
}
typecheck_statement(then_branch, symbols, return_type)?;
if let Some(else_stmt) = else_branch {
typecheck_statement(else_stmt, symbols, return_type)?;
}
Ok(())
}
Statement::Compound(items) => {
let mut inner_symbols = symbols.clone();
for item in items.iter_mut() {
typecheck_block_item(item, &mut inner_symbols, return_type)?;
}
Ok(())
}
Statement::While { condition, body } => {
let cond_type = typecheck_expr(condition, symbols)?;
if cond_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as condition".to_string(),
));
}
typecheck_statement(body, symbols, return_type)
}
Statement::DoWhile { body, condition } => {
typecheck_statement(body, symbols, return_type)?;
let cond_type = typecheck_expr(condition, symbols)?;
if cond_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as condition".to_string(),
));
}
Ok(())
}
Statement::For {
init,
condition,
post,
body,
} => {
let mut inner_symbols = symbols.clone();
match init.as_mut() {
ForInit::Declaration(decls) => {
for decl in decls.iter_mut() {
typecheck_local_declaration(decl, &mut inner_symbols)?;
}
}
ForInit::Expression(Some(expr)) => {
typecheck_expr(expr, &inner_symbols)?;
}
ForInit::Expression(None) => {}
}
if let Some(cond) = condition {
let cond_type = typecheck_expr(cond, &inner_symbols)?;
if cond_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as condition".to_string(),
));
}
}
if let Some(post_expr) = post {
typecheck_expr(post_expr, &inner_symbols)?;
}
typecheck_statement(body, &mut inner_symbols, return_type)
}
Statement::Break | Statement::Continue => Ok(()),
Statement::Switch { expr, body } => {
let expr_type = typecheck_expr(expr, symbols)?;
if expr_type.is_void()
|| expr_type.is_struct()
|| expr_type.is_floating()
|| expr_type.is_pointer()
{
return Err(CompileError::TypeError(
"switch expression must have integer type".to_string(),
));
}
typecheck_statement(body, symbols, return_type)
}
Statement::Case { body, .. } => typecheck_statement(body, symbols, return_type),
Statement::Default(body) => typecheck_statement(body, symbols, return_type),
Statement::Goto(_) => Ok(()),
Statement::Label { body, .. } => typecheck_statement(body, symbols, return_type),
}
}
/// Computes the type of `expr`, validating it against `symbols`.
///
/// The expression tree is rewritten in place as a side effect: out-of-range
/// integer constants are replaced by their wider variants, `sizeof` forms are
/// replaced by `ConstantULong`, and implicit conversions are made explicit by
/// wrapping operands in `Expr::Cast` (directly or through `convert_operand`).
///
/// Returns the type of the expression after those rewrites, or a
/// `CompileError::TypeError` describing the first violation found.
fn typecheck_expr(expr: &mut Expr, symbols: &HashMap<String, SymbolType>) -> Result<Type> {
match expr {
// Plain integer constant: values outside the i32 range become ConstantLong.
Expr::Constant(v) => {
if *v > i32::MAX as i64 || *v < i32::MIN as i64 {
let val = *v;
*expr = Expr::ConstantLong(val);
Ok(Type::Long)
} else {
Ok(Type::Int)
}
}
Expr::ConstantLong(_) => Ok(Type::Long),
// Unsigned constant: values above u32::MAX become ConstantULong.
Expr::ConstantUInt(v) => {
if *v > u32::MAX as u64 {
let val = *v;
*expr = Expr::ConstantULong(val);
Ok(Type::ULong)
} else {
Ok(Type::UInt)
}
}
Expr::ConstantULong(_) => Ok(Type::ULong),
Expr::ConstantDouble(_) => Ok(Type::Double),
Expr::ConstantFloat(_) => Ok(Type::Float),
// Explicit cast: the recorded source type is overwritten with the type
// actually computed for the operand.
Expr::Cast {
target_type,
source_type,
expr: inner,
} => {
let actual_source = typecheck_expr(inner, symbols)?;
*source_type = actual_source.clone();
// Casting to void is accepted for any operand, including void.
if target_type.is_void() {
return Ok(Type::Void);
}
if actual_source.is_void() {
return Err(CompileError::TypeError(
"cannot cast void expression to non-void type".to_string(),
));
}
if target_type.is_pointer() && actual_source.is_floating() {
return Err(CompileError::TypeError(
"cannot cast floating-point to pointer type".to_string(),
));
}
if target_type.is_floating() && actual_source.is_pointer() {
return Err(CompileError::TypeError(
"cannot cast pointer type to floating-point".to_string(),
));
}
Ok(target_type.clone())
}
// Identifier used as a value.
Expr::Var(name) => {
match symbols.get(name) {
Some(SymbolType::Variable(t)) => {
// The declared type is passed through `array_decay`
// (presumably array -> pointer to element; defined elsewhere).
Ok(array_decay(t.clone()))
}
Some(SymbolType::Function {
return_type,
param_types,
is_variadic,
}) => {
// A function name is typed as a pointer to that function.
Ok(Type::Pointer(Box::new(Type::Function {
return_type: Box::new(return_type.clone()),
param_types: param_types.clone(),
is_variadic: *is_variadic,
})))
}
// Function-name identifiers not present in the table are typed `char *`.
None if name == "__func__" || name == "__FUNCTION__" => {
Ok(Type::Pointer(Box::new(Type::Char)))
}
None => Err(CompileError::TypeError(format!(
"undeclared variable '{}'",
name
))),
}
}
// Simple assignment: the result has the type of the left-hand side and a
// right-hand side of a different type is wrapped in a cast to it.
Expr::Assign(lhs, rhs) => {
let lhs_type = typecheck_expr(lhs, symbols)?;
if !is_lvalue(lhs) {
return Err(CompileError::TypeError(
"left side of assignment must be an lvalue".to_string(),
));
}
if is_array_var(lhs, symbols) {
return Err(CompileError::TypeError(
"cannot assign to array variable".to_string(),
));
}
let rhs_type = typecheck_expr(rhs, symbols)?;
if rhs_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as right-hand side of assignment".to_string(),
));
}
// Struct assignment requires identical types and inserts no cast.
if lhs_type.is_struct() || rhs_type.is_struct() {
if lhs_type != rhs_type {
return Err(CompileError::TypeError(
"incompatible struct types in assignment".to_string(),
));
}
return Ok(lhs_type);
}
if rhs_type != lhs_type {
// Distinct pointer types are only accepted when one side is a void pointer.
if lhs_type.is_pointer()
&& rhs_type.is_pointer()
&& lhs_type != rhs_type
&& !is_void_pointer(&lhs_type)
&& !is_void_pointer(&rhs_type)
{
return Err(CompileError::TypeError(
"incompatible pointer types in assignment".to_string(),
));
}
let old_rhs = std::mem::replace(rhs.as_mut(), Expr::Constant(0));
**rhs = Expr::Cast {
target_type: lhs_type.clone(),
source_type: rhs_type,
expr: Box::new(old_rhs),
};
}
Ok(lhs_type)
}
// Compound assignment (`op=`): the result has the type of the left-hand side.
Expr::CompoundAssign(op, lhs, rhs) => {
let lhs_type = typecheck_expr(lhs, symbols)?;
if !is_lvalue(lhs) {
return Err(CompileError::TypeError(
"left side of compound assignment must be an lvalue".to_string(),
));
}
if is_array_var(lhs, symbols) {
return Err(CompileError::TypeError(
"cannot assign to array variable".to_string(),
));
}
let rhs_type = typecheck_expr(rhs, symbols)?;
if lhs_type.is_pointer()
&& is_void_pointer(&lhs_type)
&& matches!(op, BinaryOp::Add | BinaryOp::Subtract)
{
return Err(CompileError::TypeError(
"pointer arithmetic on void pointer".to_string(),
));
}
// Pointer += / -= : the right-hand side is cast to Long when it is not already.
if lhs_type.is_pointer() && matches!(op, BinaryOp::Add | BinaryOp::Subtract) {
if rhs_type != Type::Long {
let old_rhs = std::mem::replace(rhs.as_mut(), Expr::Constant(0));
**rhs = Expr::Cast {
target_type: Type::Long,
source_type: rhs_type,
expr: Box::new(old_rhs),
};
}
Ok(lhs_type)
} else if lhs_type.is_pointer() {
Err(CompileError::TypeError(format!(
"compound assignment '{:?}' cannot be applied to pointer types",
op
)))
} else {
// NOTE(review): the common type is computed but unused, and no cast
// is inserted on this path — confirm later passes handle the conversion.
let _common = common_type(&lhs_type, &rhs_type);
Ok(lhs_type)
}
}
// Postfix ++ / --: the operand must be an lvalue and not an array variable.
Expr::PostfixIncrement(inner) | Expr::PostfixDecrement(inner) => {
let inner_type = typecheck_expr(inner, symbols)?;
if !is_lvalue(inner) {
return Err(CompileError::TypeError(
"operand of postfix increment/decrement must be an lvalue".to_string(),
));
}
if is_array_var(inner, symbols) {
return Err(CompileError::TypeError(
"cannot increment/decrement array variable".to_string(),
));
}
Ok(inner_type)
}
Expr::Unary(op, inner) => {
let inner_type = typecheck_expr(inner, symbols)?;
match op {
// Logical not always yields Int.
UnaryOp::Not => {
if inner_type.is_void() {
return Err(CompileError::TypeError(
"logical not '!' cannot be applied to void expression".to_string(),
));
}
Ok(Type::Int)
}
UnaryOp::Complement => {
if inner_type.is_void() {
return Err(CompileError::TypeError(
"bitwise complement '~' cannot be applied to void expression"
.to_string(),
));
}
if inner_type.is_floating() {
return Err(CompileError::TypeError(
"bitwise complement '~' cannot be applied to floating-point"
.to_string(),
));
}
if inner_type.is_pointer() {
return Err(CompileError::TypeError(
"bitwise complement '~' cannot be applied to pointer type".to_string(),
));
}
// Character operands are cast to Int before the operation.
if inner_type.is_character() {
let old = std::mem::replace(inner.as_mut(), Expr::Constant(0));
**inner = Expr::Cast {
target_type: Type::Int,
source_type: inner_type,
expr: Box::new(old),
};
Ok(Type::Int)
} else {
Ok(inner_type)
}
}
UnaryOp::Negate => {
if inner_type.is_void() {
return Err(CompileError::TypeError(
"unary negation '-' cannot be applied to void expression".to_string(),
));
}
if inner_type.is_pointer() {
return Err(CompileError::TypeError(
"unary negation '-' cannot be applied to pointer type".to_string(),
));
}
if inner_type.is_struct() {
return Err(CompileError::TypeError(
"unary negation '-' cannot be applied to struct type".to_string(),
));
}
// Character operands are cast to Int before the operation.
if inner_type.is_character() {
let old = std::mem::replace(inner.as_mut(), Expr::Constant(0));
**inner = Expr::Cast {
target_type: Type::Int,
source_type: inner_type,
expr: Box::new(old),
};
Ok(Type::Int)
} else {
Ok(inner_type)
}
}
UnaryOp::PreIncrement | UnaryOp::PreDecrement => {
if !is_lvalue(inner) {
return Err(CompileError::TypeError(
"operand of prefix increment/decrement must be an lvalue".to_string(),
));
}
if is_array_var(inner, symbols) {
return Err(CompileError::TypeError(
"cannot increment/decrement array variable".to_string(),
));
}
Ok(inner_type)
}
}
}
// Binary operators: both operands are checked first, left then right.
Expr::Binary(op, left, right) => {
let left_type = typecheck_expr(left, symbols)?;
let right_type = typecheck_expr(right, symbols)?;
match op {
// Logical operators yield Int; operands are not converted.
BinaryOp::LogicalAnd | BinaryOp::LogicalOr => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of logical operator".to_string(),
));
}
Ok(Type::Int)
}
// Comparisons always yield Int.
BinaryOp::LessThan
| BinaryOp::LessEqual
| BinaryOp::GreaterThan
| BinaryOp::GreaterEqual
| BinaryOp::Equal
| BinaryOp::NotEqual => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of comparison".to_string(),
));
}
if left_type.is_pointer() || right_type.is_pointer() {
if left_type.is_pointer() && right_type.is_pointer() {
if left_type != right_type {
// Differing pointer types: only (in)equality against a void
// pointer is accepted; the non-void side is converted.
if is_void_pointer(&left_type) || is_void_pointer(&right_type) {
if !matches!(op, BinaryOp::Equal | BinaryOp::NotEqual) {
return Err(CompileError::TypeError(
"ordered comparison of void pointer with other pointer type".to_string()
));
}
if !is_void_pointer(&left_type) {
convert_operand(left, &left_type, &right_type);
} else if !is_void_pointer(&right_type) {
convert_operand(right, &right_type, &left_type);
}
} else {
return Err(CompileError::TypeError(
"comparison between incompatible pointer types".to_string(),
));
}
}
} else if left_type.is_pointer() && !right_type.is_pointer() {
// Pointer vs non-pointer: the other side must be a null pointer constant.
if !is_null_pointer_constant(right) {
return Err(CompileError::TypeError(
"comparison between pointer and non-zero integer".to_string(),
));
}
convert_operand(right, &right_type, &left_type);
} else {
if !is_null_pointer_constant(left) {
return Err(CompileError::TypeError(
"comparison between pointer and non-zero integer".to_string(),
));
}
convert_operand(left, &left_type, &right_type);
}
Ok(Type::Int)
} else {
// Non-pointer operands are both converted to their common type.
let common = common_type(&left_type, &right_type);
convert_operand(left, &left_type, &common);
convert_operand(right, &right_type, &common);
Ok(Type::Int)
}
}
BinaryOp::Add => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of addition".to_string(),
));
}
if left_type.is_pointer() && right_type.is_pointer() {
return Err(CompileError::TypeError(
"cannot add two pointers".to_string(),
));
}
if left_type.is_pointer() && !right_type.is_pointer() {
// pointer + integer: the integer side is converted to Long.
if is_void_pointer(&left_type) {
return Err(CompileError::TypeError(
"pointer arithmetic on void pointer".to_string(),
));
}
if right_type != Type::Long {
convert_operand(right, &right_type, &Type::Long);
}
Ok(left_type)
} else if !left_type.is_pointer() && right_type.is_pointer() {
// integer + pointer: operands are swapped so the pointer ends up
// on the left; after the swap `right` holds the integer operand.
if is_void_pointer(&right_type) {
return Err(CompileError::TypeError(
"pointer arithmetic on void pointer".to_string(),
));
}
std::mem::swap(left, right);
let int_type = left_type.clone(); if int_type != Type::Long {
convert_operand(right, &int_type, &Type::Long);
}
Ok(right_type) } else {
let common = common_type(&left_type, &right_type);
convert_operand(left, &left_type, &common);
convert_operand(right, &right_type, &common);
Ok(common)
}
}
BinaryOp::Subtract => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of subtraction".to_string(),
));
}
if left_type.is_pointer() && right_type.is_pointer() {
// pointer - pointer: identical non-void pointer types only; yields Long.
if is_void_pointer(&left_type) || is_void_pointer(&right_type) {
return Err(CompileError::TypeError(
"pointer arithmetic on void pointer".to_string(),
));
}
if left_type != right_type {
return Err(CompileError::TypeError(
"subtraction between incompatible pointer types".to_string(),
));
}
Ok(Type::Long) } else if left_type.is_pointer() && !right_type.is_pointer() {
// pointer - integer: the integer side is converted to Long.
if is_void_pointer(&left_type) {
return Err(CompileError::TypeError(
"pointer arithmetic on void pointer".to_string(),
));
}
if right_type != Type::Long {
convert_operand(right, &right_type, &Type::Long);
}
Ok(left_type)
} else if !left_type.is_pointer() && right_type.is_pointer() {
Err(CompileError::TypeError(
"cannot subtract pointer from integer".to_string(),
))
} else {
let common = common_type(&left_type, &right_type);
convert_operand(left, &left_type, &common);
convert_operand(right, &right_type, &common);
Ok(common)
}
}
// Multiplicative operators: both operands go to the common type;
// remainder additionally rejects a floating common type.
BinaryOp::Multiply | BinaryOp::Divide | BinaryOp::Remainder => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of arithmetic".to_string(),
));
}
if left_type.is_pointer() || right_type.is_pointer() {
return Err(CompileError::TypeError(format!(
"arithmetic operator '{:?}' cannot be applied to pointer types",
op
)));
}
let common = common_type(&left_type, &right_type);
if matches!(op, BinaryOp::Remainder) && common.is_floating() {
return Err(CompileError::TypeError(
"remainder '%' cannot be applied to floating-point".to_string(),
));
}
convert_operand(left, &left_type, &common);
convert_operand(right, &right_type, &common);
Ok(common)
}
// Bitwise operators: void, pointer and floating operands are rejected.
BinaryOp::BitwiseAnd | BinaryOp::BitwiseOr | BinaryOp::BitwiseXor => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of bitwise operation".to_string(),
));
}
if left_type.is_pointer() || right_type.is_pointer() {
return Err(CompileError::TypeError(format!(
"bitwise operator '{:?}' cannot be applied to pointer types",
op
)));
}
if left_type.is_floating() || right_type.is_floating() {
return Err(CompileError::TypeError(
"bitwise operator cannot be applied to floating-point".to_string(),
));
}
let common = common_type(&left_type, &right_type);
convert_operand(left, &left_type, &common);
convert_operand(right, &right_type, &common);
Ok(common)
}
// Shifts: each operand is promoted on its own (no common type) and the
// result has the promoted type of the left operand.
BinaryOp::ShiftLeft | BinaryOp::ShiftRight => {
if left_type.is_void() || right_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as operand of shift operation".to_string(),
));
}
if left_type.is_pointer() || right_type.is_pointer() {
return Err(CompileError::TypeError(
"shift operator cannot be applied to pointer types".to_string(),
));
}
if left_type.is_floating() || right_type.is_floating() {
return Err(CompileError::TypeError(
"shift operator cannot be applied to floating-point".to_string(),
));
}
let promoted_left = integer_promote(&left_type);
let promoted_right = integer_promote(&right_type);
convert_operand(left, &left_type, &promoted_left);
convert_operand(right, &right_type, &promoted_right);
Ok(promoted_left)
}
// Comma: the result has the type of the right operand.
BinaryOp::Comma => Ok(right_type),
}
}
// Ternary conditional.
Expr::Conditional {
condition,
then_expr,
else_expr,
} => {
let cond_type = typecheck_expr(condition, symbols)?;
if cond_type.is_void() {
return Err(CompileError::TypeError(
"void expression used as condition".to_string(),
));
}
let then_type = typecheck_expr(then_expr, symbols)?;
let else_type = typecheck_expr(else_expr, symbols)?;
// Both branches void is fine; exactly one void is an error.
if then_type.is_void() && else_type.is_void() {
return Ok(Type::Void);
}
if then_type.is_void() || else_type.is_void() {
return Err(CompileError::TypeError(
"incompatible types in conditional expression (void and non-void)".to_string(),
));
}
if then_type.is_pointer() || else_type.is_pointer() {
// Pointer branches: identical types, a null pointer constant on the
// other side, or a void pointer on exactly one side are accepted.
if then_type == else_type {
Ok(then_type)
} else if then_type.is_pointer() && is_null_pointer_constant(else_expr) {
convert_operand(else_expr, &else_type, &then_type);
Ok(then_type)
} else if else_type.is_pointer() && is_null_pointer_constant(then_expr) {
convert_operand(then_expr, &then_type, &else_type);
Ok(else_type)
} else if then_type.is_pointer() && else_type.is_pointer() {
if is_void_pointer(&then_type) && !is_void_pointer(&else_type) {
convert_operand(else_expr, &else_type, &then_type);
Ok(then_type)
} else if is_void_pointer(&else_type) && !is_void_pointer(&then_type) {
convert_operand(then_expr, &then_type, &else_type);
Ok(else_type)
} else {
Err(CompileError::TypeError(
"incompatible types in conditional expression".to_string(),
))
}
} else {
Err(CompileError::TypeError(
"incompatible types in conditional expression".to_string(),
))
}
} else {
let common = common_type(&then_type, &else_type);
convert_operand(then_expr, &then_type, &common);
convert_operand(else_expr, &else_type, &common);
Ok(common)
}
}
// Call through a name.
Expr::FunctionCall(name, args) => {
// `__builtin_expect` takes exactly two arguments and has the type of the first.
if name == "__builtin_expect" {
if args.len() != 2 {
return Err(CompileError::TypeError(
"__builtin_expect requires 2 arguments".to_string(),
));
}
let first_type = typecheck_expr(&mut args[0], symbols)?;
typecheck_expr(&mut args[1], symbols)?;
return Ok(first_type);
}
// Other recognised builtins: every argument is checked, and the call is
// given the type of its first argument (Int when there are no arguments).
// NOTE(review): the first argument is type-checked a second time to
// obtain that type.
if matches!(
name.as_str(),
"__builtin_bswap16"
| "__builtin_bswap32"
| "__builtin_bswap64"
| "__builtin_object_size"
| "__builtin_ctz"
| "__builtin_ctzl"
| "__builtin_clz"
| "__builtin_clzl"
| "__builtin_popcount"
| "__builtin_popcountl"
| "__builtin_abs"
| "__builtin_labs"
) {
for arg in args.iter_mut() {
typecheck_expr(arg, symbols)?;
}
if !args.is_empty() {
return typecheck_expr(&mut args[0], symbols);
}
return Ok(Type::Int);
}
// Resolve the callee's signature from a function symbol or from a
// variable whose type is pointer-to-function.
let call_sig = match symbols.get(name) {
Some(SymbolType::Function {
return_type,
param_types,
is_variadic,
}) => CallSig {
return_type: return_type.clone(),
param_types: param_types.clone(),
is_variadic: *is_variadic,
},
Some(SymbolType::Variable(Type::Pointer(inner))) => match inner.as_ref() {
Type::Function {
return_type,
param_types,
is_variadic,
} => CallSig {
return_type: *return_type.clone(),
param_types: param_types.clone(),
is_variadic: *is_variadic,
},
// Call through a void-pointer variable: arguments are checked
// but not matched against any signature; the call is typed Int.
Type::Void => {
for arg in args.iter_mut() {
typecheck_expr(arg, symbols)?;
}
return Ok(Type::Int);
}
_ => {
return Err(CompileError::TypeError(format!(
"called object '{}' is not a function or function pointer",
name
)));
}
},
_ => {
return Err(CompileError::TypeError(format!(
"undeclared function '{}'",
name
)));
}
};
typecheck_call_args(name, args, &call_sig, symbols)?;
Ok(call_sig.return_type)
}
// Call through an arbitrary callee expression.
Expr::CallExpr(callee, args) => {
let callee_type = typecheck_expr(callee, symbols)?;
let call_sig = match &callee_type {
Type::Pointer(inner) => match inner.as_ref() {
Type::Function {
return_type,
param_types,
is_variadic,
} => CallSig {
return_type: *return_type.clone(),
param_types: param_types.clone(),
is_variadic: *is_variadic,
},
// Same void-pointer handling as for named calls.
Type::Void => {
for arg in args.iter_mut() {
typecheck_expr(arg, symbols)?;
}
return Ok(Type::Int);
}
_ => {
return Err(CompileError::TypeError(
"called expression is not a function pointer".to_string(),
));
}
},
_ => {
return Err(CompileError::TypeError(
"called expression is not a function pointer".to_string(),
));
}
};
typecheck_call_args("<expr>", args, &call_sig, symbols)?;
Ok(call_sig.return_type)
}
// Pointer dereference: the pointee type is passed through `array_decay`.
Expr::Dereref(inner) => {
let inner_type = typecheck_expr(inner, symbols)?;
match inner_type {
Type::Pointer(ref target) if target.is_void() => Err(CompileError::TypeError(
"dereference of void pointer".to_string(),
)),
Type::Pointer(target) => {
Ok(array_decay(*target))
}
_ => Err(CompileError::TypeError(
"dereference of non-pointer type".to_string(),
)),
}
}
Expr::AddrOf(inner) => {
// Address of an array variable is a pointer to the whole array; the
// operand is deliberately not type-checked here, which would go through
// `array_decay`.
if let Expr::Var(name) = inner.as_ref()
&& let Some(SymbolType::Variable(Type::Array(elem, size))) = symbols.get(name)
{
return Ok(Type::Pointer(Box::new(Type::Array(elem.clone(), *size))));
}
let inner_type = typecheck_expr(inner, symbols)?;
// An operand already typed as pointer-to-function keeps that type.
if matches!(inner_type, Type::Pointer(ref t) if matches!(t.as_ref(), Type::Function { .. }))
{
return Ok(inner_type);
}
if !is_lvalue(inner) {
return Err(CompileError::TypeError(
"cannot take address of non-lvalue expression".to_string(),
));
}
Ok(Type::Pointer(Box::new(inner_type)))
}
// String literals in expression context are typed as pointer to Char.
Expr::StringLiteral(_) => Ok(Type::Pointer(Box::new(Type::Char))),
// sizeof(type): replaced in place by a ConstantULong holding the size.
Expr::SizeOfType(ty) => {
if ty.is_incomplete() {
return Err(CompileError::TypeError(
"sizeof applied to incomplete type".to_string(),
));
}
let size = ty.size() as u64;
*expr = Expr::ConstantULong(size);
Ok(Type::ULong)
}
// sizeof expr: replaced in place by a ConstantULong holding the size.
Expr::SizeOfExpr(inner) => {
// For a plain variable the declared type from the symbol table is used
// directly (not passed through `array_decay`).
let actual_type = if let Expr::Var(name) = inner.as_ref() {
match symbols.get(name) {
Some(SymbolType::Variable(t)) => t.clone(),
_ => typecheck_expr(inner, symbols)?,
}
} else if matches!(inner.as_ref(), Expr::Dot(..)) {
// Member access: the member's declared type is used directly, again
// without `array_decay`.
let Expr::Dot(struct_expr, member_name) = inner.as_mut() else {
unreachable!()
};
let inner_type = typecheck_expr(struct_expr, symbols)?;
match &inner_type {
Type::Struct {
members, is_union, ..
} => match struct_member_offset_ex(members, member_name, *is_union) {
Some((_, member_type)) => member_type,
None => {
return Err(CompileError::TypeError(format!(
"member access on non-existent member '{}'",
member_name
)));
}
},
_ => {
return Err(CompileError::TypeError(
"member access on non-struct type".to_string(),
));
}
}
} else {
typecheck_expr(inner, symbols)?
};
if actual_type.is_incomplete() {
return Err(CompileError::TypeError(
"sizeof applied to expression of incomplete type".to_string(),
));
}
let size = actual_type.size() as u64;
*expr = Expr::ConstantULong(size);
Ok(Type::ULong)
}
// Member access: the member's type is passed through `array_decay`.
Expr::Dot(inner, member_name) => {
let inner_type = typecheck_expr(inner, symbols)?;
match &inner_type {
Type::Struct {
members,
tag,
is_union,
} => match struct_member_offset_ex(members, member_name, *is_union) {
Some((_, member_type)) => Ok(array_decay(member_type)),
None => Err(CompileError::TypeError(format!(
"struct '{}' has no member '{}'",
tag, member_name
))),
},
_ => Err(CompileError::TypeError(
"member access on non-struct type".to_string(),
)),
}
}
// A bare brace initializer is only valid inside a declaration or compound literal.
Expr::CompoundInit(_) => Err(CompileError::TypeError(
"compound initializer not allowed in expression context".to_string(),
)),
// Compound literal: each initializer is checked and, when its type differs
// from the slot it initializes, wrapped in a cast.
Expr::CompoundLiteral { target_type, init } => {
if let Expr::CompoundInit(inits) = init.as_mut() {
if let Type::Struct { members, .. } = target_type {
// Initializers beyond the last member are skipped without being checked.
for (i, init_expr) in inits.iter_mut().enumerate() {
if i < members.len() {
let member_type = &members[i].member_type;
let init_type = typecheck_expr(init_expr, symbols)?;
if init_type != *member_type {
let old = std::mem::replace(init_expr, Expr::Constant(0));
*init_expr = Expr::Cast {
target_type: member_type.clone(),
source_type: init_type,
expr: Box::new(old),
};
}
}
}
} else if let Type::Array(elem_type, _) = target_type {
for init_expr in inits.iter_mut() {
let init_type = typecheck_expr(init_expr, symbols)?;
if init_type != **elem_type {
let old = std::mem::replace(init_expr, Expr::Constant(0));
*init_expr = Expr::Cast {
target_type: *elem_type.clone(),
source_type: init_type,
expr: Box::new(old),
};
}
}
} else {
// Scalar target: only the first initializer is considered.
if let Some(init_expr) = inits.first_mut() {
let init_type = typecheck_expr(init_expr, symbols)?;
if init_type != *target_type {
let old = std::mem::replace(init_expr, Expr::Constant(0));
*init_expr = Expr::Cast {
target_type: target_type.clone(),
source_type: init_type,
expr: Box::new(old),
};
}
}
}
}
// An array compound literal is typed as a pointer to its element type.
if let Type::Array(elem, _) = target_type {
Ok(Type::Pointer(elem.clone()))
} else {
Ok(target_type.clone())
}
}
// Variadic-argument forms: va_start / va_arg / va_end require a VaList operand.
Expr::VaStart(ap) => {
let ap_type = typecheck_expr(ap, symbols)?;
if ap_type != Type::VaList {
return Err(CompileError::TypeError(
"va_start requires va_list argument".into(),
));
}
Ok(Type::Void)
}
Expr::VaArg { ap, arg_type } => {
let ap_type = typecheck_expr(ap, symbols)?;
if ap_type != Type::VaList {
return Err(CompileError::TypeError(
"va_arg requires va_list argument".into(),
));
}
Ok(arg_type.clone())
}
Expr::VaEnd(ap) => {
let ap_type = typecheck_expr(ap, symbols)?;
if ap_type != Type::VaList {
return Err(CompileError::TypeError(
"va_end requires va_list argument".into(),
));
}
Ok(Type::Void)
}
// NOTE(review): unlike the other va_* forms, the operand types are not
// compared against VaList here.
Expr::VaCopy(dst, src) => {
typecheck_expr(dst, symbols)?;
typecheck_expr(src, symbols)?;
Ok(Type::Void)
}
}
}
/// Widens an integer literal in place when its value does not fit the
/// 32-bit variant it was parsed into.
///
/// * `Expr::Constant` outside the `i32` range becomes `Expr::ConstantLong`.
/// * `Expr::ConstantUInt` above `u32::MAX` becomes `Expr::ConstantULong`.
///
/// Every other expression (including floating-point constants) is left
/// untouched.
fn resolve_constant(expr: &mut Expr) {
    // Compute the replacement first so the shared borrow of `expr` has ended
    // by the time we overwrite it.
    let widened = match &*expr {
        Expr::Constant(value)
            if *value < i64::from(i32::MIN) || *value > i64::from(i32::MAX) =>
        {
            Some(Expr::ConstantLong(*value))
        }
        Expr::ConstantUInt(value) if *value > u64::from(u32::MAX) => {
            Some(Expr::ConstantULong(*value))
        }
        _ => None,
    };
    if let Some(wide) = widened {
        *expr = wide;
    }
}
/// Wraps `arg` in a cast implementing the default argument promotions.
///
/// * character and short types are promoted to `Type::Int`;
/// * float is promoted to `Type::Double`;
/// * any other `arg_type` leaves `arg` unchanged.
///
/// `arg_type` must be the already-computed type of `arg`; it is recorded as
/// the cast's `source_type`.
fn apply_default_argument_promotion(arg: &mut Expr, arg_type: &Type) {
    let target_type = if arg_type.is_character() || arg_type.is_short() {
        Type::Int
    } else if arg_type.is_float() {
        Type::Double
    } else {
        // Nothing to promote.
        return;
    };
    // Temporarily park a placeholder in `arg` so the original expression can
    // be moved into the cast node.
    let original = std::mem::replace(arg, Expr::Constant(0));
    *arg = Expr::Cast {
        target_type,
        source_type: arg_type.clone(),
        expr: Box::new(original),
    };
}
/// Signature information about a callee, as consumed by `typecheck_call_args`.
struct CallSig {
/// Declared return type of the callee (not consulted by `typecheck_call_args`).
return_type: Type,
/// Declared parameter types. `None` means no parameter list is known, in
/// which case arguments are only type-checked and default-promoted.
param_types: Option<Vec<Type>>,
/// Whether extra arguments beyond `param_types` are accepted.
is_variadic: bool,
}
/// Type-checks the arguments of a call to `name` against `sig`, rewriting
/// each argument in place with whatever conversion cast it needs.
///
/// * With no known parameter list (`sig.param_types` is `None`) every
///   argument is type-checked and given the default argument promotions.
/// * Otherwise the argument count is validated first (exact match, or "at
///   least" for variadic callees), each argument that has a matching
///   parameter is cast to the parameter type when the types differ, and any
///   extra variadic arguments receive the default argument promotions.
///
/// # Errors
///
/// Returns `CompileError::TypeError` on an argument-count mismatch or when
/// two distinct pointer types meet and neither is `void *` nor are both
/// function pointers. Errors from `typecheck_expr` are propagated.
fn typecheck_call_args(
    name: &str,
    args: &mut [Expr],
    sig: &CallSig,
    symbols: &HashMap<String, SymbolType>,
) -> Result<()> {
    /// True for a pointer whose pointee is a function type.
    fn is_fn_ptr(t: &Type) -> bool {
        match t {
            Type::Pointer(pointee) => matches!(pointee.as_ref(), Type::Function { .. }),
            _ => false,
        }
    }

    let Some(params) = sig.param_types.as_ref() else {
        // No parameter list to check against: only promote.
        for arg in args.iter_mut() {
            let arg_type = typecheck_expr(arg, symbols)?;
            apply_default_argument_promotion(arg, &arg_type);
        }
        return Ok(());
    };

    let given = args.len();
    let declared = params.len();
    if sig.is_variadic && given < declared {
        return Err(CompileError::TypeError(format!(
            "function '{}' requires at least {} arguments, got {}",
            name, declared, given
        )));
    }
    if !sig.is_variadic && given != declared {
        return Err(CompileError::TypeError(format!(
            "function '{}' expects {} arguments, got {}",
            name, declared, given
        )));
    }

    for (index, arg) in args.iter_mut().enumerate() {
        let arg_type = typecheck_expr(arg, symbols)?;

        let Some(expected_type) = params.get(index) else {
            // Past the declared parameters: only reachable for variadic
            // callees because of the count check above.
            apply_default_argument_promotion(arg, &arg_type);
            continue;
        };
        if arg_type == *expected_type {
            continue;
        }

        let both_pointers = expected_type.is_pointer() && arg_type.is_pointer();
        let involves_void_ptr = is_void_pointer(expected_type) || is_void_pointer(&arg_type);
        let both_fn_ptrs = is_fn_ptr(expected_type) && is_fn_ptr(&arg_type);
        if both_pointers && !involves_void_ptr && !both_fn_ptrs {
            return Err(CompileError::TypeError(format!(
                "incompatible pointer types in argument to function '{}'",
                name
            )));
        }

        // Insert the implicit conversion to the parameter type.
        let original = std::mem::replace(arg, Expr::Constant(0));
        *arg = Expr::Cast {
            target_type: expected_type.clone(),
            source_type: arg_type,
            expr: Box::new(original),
        };
    }
    Ok(())
}
/// Converts an array type to a pointer to its element type and a function
/// type to a pointer to that function type; any other type is returned
/// unchanged.
fn array_decay(ty: Type) -> Type {
    match ty {
        Type::Array(element, _) => Type::Pointer(element),
        function @ Type::Function { .. } => Type::Pointer(Box::new(function)),
        unchanged => unchanged,
    }
}
/// Returns `true` when `expr` is a plain variable reference whose entry in
/// `symbols` is a variable of array type. Unknown names, function symbols
/// and non-variable expressions all yield `false`.
fn is_array_var(expr: &Expr, symbols: &HashMap<String, SymbolType>) -> bool {
    match expr {
        Expr::Var(name) => matches!(
            symbols.get(name),
            Some(SymbolType::Variable(var_type)) if var_type.is_array()
        ),
        _ => false,
    }
}
/// Returns `true` when `expr` is treated as an lvalue: a variable, a
/// dereference, or a member access (`.`) whose base is itself an lvalue.
fn is_lvalue(expr: &Expr) -> bool {
    // Strip any chain of member accesses down to the underlying base.
    let mut base = expr;
    while let Expr::Dot(inner, _) = base {
        base = &**inner;
    }
    // NOTE(review): the variant is spelled `Dereref` here — confirm this
    // matches the name declared in `ast::Expr`.
    matches!(base, Expr::Var(_) | Expr::Dereref(_))
}
/// Returns `true` when `ty` is a pointer whose pointee type is void.
fn is_void_pointer(ty: &Type) -> bool {
    match ty {
        Type::Pointer(pointee) => pointee.is_void(),
        _ => false,
    }
}
/// Returns `true` when `expr` is an integer constant node (of any of the
/// four integer constant variants) whose value is zero.
fn is_null_pointer_constant(expr: &Expr) -> bool {
    match expr {
        Expr::Constant(value) | Expr::ConstantLong(value) => *value == 0,
        Expr::ConstantUInt(value) | Expr::ConstantULong(value) => *value == 0,
        _ => false,
    }
}
/// Applies the integer promotions to `t`.
///
/// Character and short types are promoted to `Type::Int`; every other type
/// is returned unchanged (as a clone).
///
/// Short types are included so that this helper agrees with `common_type`
/// and `apply_default_argument_promotion`, which both treat
/// `is_character() || is_short()` as promotable to `Int`.
fn integer_promote(t: &Type) -> Type {
    if t.is_character() || t.is_short() {
        Type::Int
    } else {
        t.clone()
    }
}
/// Computes the type two arithmetic operands are converted to before a
/// binary operation.
///
/// Character and short operands are first treated as `Type::Int`. If the
/// two resulting types are equal, that type is the result. Otherwise the
/// result is the first of `Double`, `Float`, `ULong`, `Long`, `UInt` that
/// either operand has, falling back to `Int` when none match.
fn common_type(a: &Type, b: &Type) -> Type {
    let promote = |t: &Type| {
        if t.is_character() || t.is_short() {
            Type::Int
        } else {
            t.clone()
        }
    };
    let lhs = promote(a);
    let rhs = promote(b);
    if lhs == rhs {
        return lhs;
    }

    // Candidate result types, highest priority first.
    [
        Type::Double,
        Type::Float,
        Type::ULong,
        Type::Long,
        Type::UInt,
    ]
    .iter()
    .find(|candidate| lhs == **candidate || rhs == **candidate)
    .cloned()
    .unwrap_or(Type::Int)
}
/// Wraps the boxed operand in a cast from `from` to `to`, in place.
/// Does nothing when the two types are already equal.
fn convert_operand(expr: &mut Box<Expr>, from: &Type, to: &Type) {
    if from == to {
        return;
    }
    // Swap a placeholder into the box so the operand can be moved into the
    // new cast node, then store the cast back in the same box.
    let operand = std::mem::replace(&mut **expr, Expr::Constant(0));
    **expr = Expr::Cast {
        target_type: to.clone(),
        source_type: from.clone(),
        expr: Box::new(operand),
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::ast::*;

    /// Builds a program containing a single prototyped, non-variadic,
    /// parameterless `main` with the given return type and body.
    fn program_with_main(return_type: Type, body: Vec<BlockItem>) -> Program {
        Program {
            declarations: vec![TopLevelDecl::Function(FunctionDecl {
                name: "main".to_string(),
                return_type,
                params: vec![],
                body: Some(body),
                storage_class: None,
                is_variadic: false,
                has_prototype: true,
            })],
        }
    }

    /// `return <value>;` as a block item.
    fn return_stmt(value: Expr) -> BlockItem {
        BlockItem::Statement(Statement::Return(Some(value)))
    }

    /// A local declaration with an initializer and no storage class.
    fn local_var(name: &str, var_type: Type, init: Expr) -> BlockItem {
        BlockItem::Declaration(Declaration {
            name: name.to_string(),
            var_type,
            init: Some(init),
            storage_class: None,
        })
    }

    /// Body of the first top-level declaration, which must be a function
    /// definition.
    fn main_body(program: &Program) -> &[BlockItem] {
        match &program.declarations[0] {
            TopLevelDecl::Function(f) => f.body.as_deref().unwrap(),
            _ => panic!(),
        }
    }

    #[test]
    fn typecheck_constant_int() {
        let mut program = program_with_main(Type::Int, vec![return_stmt(Expr::Constant(42))]);
        typecheck(&mut program).unwrap();

        match &main_body(&program)[0] {
            BlockItem::Statement(Statement::Return(Some(expr))) => {
                assert!(matches!(expr, Expr::Constant(42)));
            }
            _ => panic!("expected return"),
        }
    }

    #[test]
    fn typecheck_constant_too_large_for_int() {
        let mut program = program_with_main(
            Type::Long,
            vec![return_stmt(Expr::Constant(8589934592))],
        );
        typecheck(&mut program).unwrap();

        match &main_body(&program)[0] {
            BlockItem::Statement(Statement::Return(Some(expr))) => {
                assert!(matches!(expr, Expr::ConstantLong(8589934592)));
            }
            _ => panic!("expected return"),
        }
    }

    #[test]
    fn typecheck_cast_on_return_type_mismatch() {
        let mut program =
            program_with_main(Type::Int, vec![return_stmt(Expr::ConstantLong(42))]);
        typecheck(&mut program).unwrap();

        match &main_body(&program)[0] {
            BlockItem::Statement(Statement::Return(Some(expr))) => {
                assert!(matches!(
                    expr,
                    Expr::Cast {
                        target_type: Type::Int,
                        ..
                    }
                ));
            }
            _ => panic!("expected return with cast"),
        }
    }

    #[test]
    fn typecheck_binary_promotion() {
        let sum = Expr::Binary(
            BinaryOp::Add,
            Box::new(Expr::Var("a".to_string())),
            Box::new(Expr::Var("b".to_string())),
        );
        let mut program = program_with_main(
            Type::Long,
            vec![
                local_var("a", Type::Int, Expr::Constant(1)),
                local_var("b", Type::Long, Expr::ConstantLong(2)),
                return_stmt(sum),
            ],
        );
        typecheck(&mut program).unwrap();

        match &main_body(&program)[2] {
            BlockItem::Statement(Statement::Return(Some(Expr::Binary(_, left, _)))) => {
                // The `int` operand must have been converted to `long`.
                assert!(matches!(
                    left.as_ref(),
                    Expr::Cast {
                        target_type: Type::Long,
                        ..
                    }
                ));
            }
            _ => panic!("expected return with binary"),
        }
    }
}