use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::targets::{CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine};
use inkwell::types::BasicMetadataTypeEnum;
use inkwell::values::{BasicMetadataValueEnum, BasicValueEnum, FunctionValue, IntValue, PointerValue};
use inkwell::IntPredicate;
use inkwell::{AddressSpace, OptimizationLevel};
use std::collections::HashMap;

use crate::ast::*;
/// LLVM IR generator for a [`Program`].
///
/// Owns the inkwell module and builder being filled in, plus the per-function
/// state used while lowering statements and expressions.
pub struct Codegen<'ctx> {
    /// LLVM context that owns every type and value created by this generator.
    pub context: &'ctx Context,
    /// Module that receives the declared and defined functions.
    pub module: Module<'ctx>,
    /// Instruction builder; its insertion point is the block currently being emitted.
    pub builder: Builder<'ctx>,
    /// Stack slot (alloca) for each parameter and local variable of the
    /// function currently being defined; cleared at the start of every function.
    variables: HashMap<String, PointerValue<'ctx>>,
    /// Function whose body is currently being emitted, if any.
    current_function: Option<FunctionValue<'ctx>>,
}
impl<'ctx> Codegen<'ctx> {
pub fn new(context: &'ctx Context, module_name: &str) -> Self {
let module = context.create_module(module_name);
let builder = context.create_builder();
Self {
context,
module,
builder,
variables: HashMap::new(),
current_function: None,
}
}
pub fn compile_program(&mut self, program: &Program) -> Result<(), String> {
for item in &program.items {
if let TopLevel::Func(func) = item {
self.declare_function(func)?;
}
}
for item in &program.items {
if let TopLevel::Func(func) = item {
self.define_function(func)?;
}
}
Ok(())
}
fn declare_function(&mut self, func: &FuncDecl) -> Result<(), String> {
let i64_type = self.context.i64_type();
let param_types: Vec<BasicMetadataTypeEnum> = func
.params
.iter()
.map(|_| i64_type.into())
.collect();
let fn_type = i64_type.fn_type(¶m_types, false);
self.module.add_function(&func.name, fn_type, None);
Ok(())
}
fn define_function(&mut self, func: &FuncDecl) -> Result<(), String> {
self.variables.clear();
let function = self.module.get_function(&func.name)
.ok_or_else(|| format!("Function {} not declared", func.name))?;
self.current_function = Some(function);
let entry = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(entry);
let i64_type = self.context.i64_type();
for (i, param) in func.params.iter().enumerate() {
let alloca = self.builder.build_alloca(i64_type, ¶m.name)
.map_err(|e| format!("Failed to build alloca: {}", e))?;
let param_value = function.get_nth_param(i as u32)
.ok_or_else(|| format!("Failed to get parameter {}", i))?;
self.builder.build_store(alloca, param_value)
.map_err(|e| format!("Failed to store param: {}", e))?;
self.variables.insert(param.name.clone(), alloca);
}
self.compile_block(&func.body)?;
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
let zero = i64_type.const_int(0, false);
self.builder.build_return(Some(&zero))
.map_err(|e| format!("Failed to build return: {}", e))?;
}
Ok(())
}
fn compile_block(&mut self, block: &Block) -> Result<(), String> {
for stmt in &block.statements {
self.compile_stmt(stmt)?;
}
Ok(())
}
fn compile_stmt(&mut self, stmt: &Stmt) -> Result<(), String> {
match stmt {
Stmt::Return(expr) => {
if let Some(e) = expr {
let val = self.compile_expr(e)?;
self.builder.build_return(Some(&val))
.map_err(|e| format!("Failed to build return: {}", e))?;
} else {
let zero = self.context.i64_type().const_int(0, false);
self.builder.build_return(Some(&zero))
.map_err(|e| format!("Failed to build return: {}", e))?;
}
}
Stmt::VarDecl(decl) => {
let i64_type = self.context.i64_type();
let init_val = self.compile_expr(&decl.init)?;
if let Some(&ptr) = self.variables.get(&decl.name) {
self.builder.build_store(ptr, init_val)
.map_err(|e| format!("Failed to store: {}", e))?;
} else {
let alloca = self.builder.build_alloca(i64_type, &decl.name)
.map_err(|e| format!("Failed to build alloca: {}", e))?;
self.builder.build_store(alloca, init_val)
.map_err(|e| format!("Failed to store: {}", e))?;
self.variables.insert(decl.name.clone(), alloca);
}
}
Stmt::Expr(e) => {
self.compile_expr(e)?;
}
Stmt::If(if_stmt) => {
self.compile_if(if_stmt)?;
}
Stmt::While(while_stmt) => {
self.compile_while(while_stmt)?;
}
Stmt::For(for_stmt) => {
self.compile_for(for_stmt)?;
}
_ => {
return Err("Unsupported statement type".to_string());
}
}
Ok(())
}
fn compile_if(&mut self, if_stmt: &IfStmt) -> Result<(), String> {
let function = self.current_function
.ok_or("No current function")?;
let cond_val = self.compile_expr(&if_stmt.condition)?;
let cond_int = cond_val.into_int_value();
let cond_bool = if cond_int.get_type().get_bit_width() == 1 {
cond_int
} else {
let zero = self.context.i64_type().const_int(0, false);
self.builder.build_int_compare(
IntPredicate::NE,
cond_int,
zero,
"ifcond"
).map_err(|e| format!("Failed to compare: {}", e))?
};
let then_bb = self.context.append_basic_block(function, "then");
let else_bb = self.context.append_basic_block(function, "else");
let merge_bb = self.context.append_basic_block(function, "ifcont");
self.builder.build_conditional_branch(cond_bool, then_bb, else_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
self.builder.position_at_end(then_bb);
self.compile_block(&if_stmt.then_branch)?;
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
self.builder.build_unconditional_branch(merge_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
}
self.builder.position_at_end(else_bb);
if let Some(else_block) = &if_stmt.else_branch {
self.compile_block(else_block)?;
}
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
self.builder.build_unconditional_branch(merge_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
}
self.builder.position_at_end(merge_bb);
Ok(())
}
fn compile_while(&mut self, while_stmt: &WhileStmt) -> Result<(), String> {
let function = self.current_function
.ok_or("No current function")?;
let cond_bb = self.context.append_basic_block(function, "whilecond");
let body_bb = self.context.append_basic_block(function, "whilebody");
let after_bb = self.context.append_basic_block(function, "afterwhile");
self.builder.build_unconditional_branch(cond_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
self.builder.position_at_end(cond_bb);
let cond_val = self.compile_expr(&while_stmt.condition)?;
let cond_int = cond_val.into_int_value();
let cond_bool = if cond_int.get_type().get_bit_width() == 1 {
cond_int
} else {
let zero = self.context.i64_type().const_int(0, false);
self.builder.build_int_compare(
IntPredicate::NE,
cond_int,
zero,
"whilecond"
).map_err(|e| format!("Failed to compare: {}", e))?
};
self.builder.build_conditional_branch(cond_bool, body_bb, after_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
self.builder.position_at_end(body_bb);
self.compile_block(&while_stmt.body)?;
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
self.builder.build_unconditional_branch(cond_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
}
self.builder.position_at_end(after_bb);
Ok(())
}
fn compile_for(&mut self, for_stmt: &ForStmt) -> Result<(), String> {
let function = self.current_function.ok_or("No current function")?;
let cond_bb = self.context.append_basic_block(function, "forcond");
let body_bb = self.context.append_basic_block(function, "forbody");
let after_bb = self.context.append_basic_block(function, "afterfor");
if let Some(init_stmt) = &for_stmt.init {
self.compile_stmt(init_stmt)?;
}
self.builder.build_unconditional_branch(cond_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
self.builder.position_at_end(cond_bb);
let cond_val = self.compile_expr(&for_stmt.condition)?;
let cond_int = cond_val.into_int_value();
let cond_bool = if cond_int.get_type().get_bit_width() == 1 {
cond_int
} else {
let zero = self.context.i64_type().const_int(0, false);
self.builder.build_int_compare(IntPredicate::NE, cond_int, zero, "forcond")
.map_err(|e| format!("Failed to compare: {}", e))?
};
self.builder.build_conditional_branch(cond_bool, body_bb, after_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
self.builder.position_at_end(body_bb);
self.compile_block(&for_stmt.body)?;
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
if let Some(update_stmt) = &for_stmt.update {
self.compile_stmt(update_stmt)?;
}
self.builder.build_unconditional_branch(cond_bb)
.map_err(|e| format!("Failed to build branch: {}", e))?;
}
self.builder.position_at_end(after_bb);
Ok(())
}
fn compile_expr(&mut self, expr: &Expr) -> Result<BasicValueEnum<'ctx>, String> {
match expr {
Expr::Literal(lit) => {
match lit {
Literal::Int(n) => {
let val = self.context.i64_type().const_int(*n as u64, false);
Ok(val.into())
}
Literal::Bool(b) => {
let i1 = self.context.bool_type().const_int(if *b { 1 } else { 0 }, false);
Ok(i1.into())
}
_ => Err("Unsupported literal type".to_string()),
}
}
Expr::Ident(name) => {
let ptr = self.variables.get(name)
.ok_or_else(|| format!("Undefined variable: {}", name))?;
let val = self.builder.build_load(self.context.i64_type(), *ptr, name)
.map_err(|e| format!("Failed to load: {}", e))?;
Ok(val)
}
Expr::Binary(lhs, op, rhs) => {
let left = self.compile_expr(lhs)?;
let right = self.compile_expr(rhs)?;
let left_int = left.into_int_value();
let right_int = right.into_int_value();
let result = match op {
BinOp::Add => self.builder.build_int_add(left_int, right_int, "addtmp")
.map_err(|e| format!("Failed to add: {}", e))?,
BinOp::Sub => self.builder.build_int_sub(left_int, right_int, "subtmp")
.map_err(|e| format!("Failed to sub: {}", e))?,
BinOp::Mul => self.builder.build_int_mul(left_int, right_int, "multmp")
.map_err(|e| format!("Failed to mul: {}", e))?,
BinOp::Div => self.builder.build_int_signed_div(left_int, right_int, "divtmp")
.map_err(|e| format!("Failed to div: {}", e))?,
BinOp::Mod => self.builder.build_int_signed_rem(left_int, right_int, "modtmp")
.map_err(|e| format!("Failed to mod: {}", e))?,
BinOp::Eq => self.builder.build_int_compare(IntPredicate::EQ, left_int, right_int, "eqtmp")
.map_err(|e| format!("Failed to compare: {}", e))?,
BinOp::NotEq => self.builder.build_int_compare(IntPredicate::NE, left_int, right_int, "netmp")
.map_err(|e| format!("Failed to compare: {}", e))?,
BinOp::Lt => self.builder.build_int_compare(IntPredicate::SLT, left_int, right_int, "lttmp")
.map_err(|e| format!("Failed to compare: {}", e))?,
BinOp::Gt => self.builder.build_int_compare(IntPredicate::SGT, left_int, right_int, "gttmp")
.map_err(|e| format!("Failed to compare: {}", e))?,
BinOp::LtEq => self.builder.build_int_compare(IntPredicate::SLE, left_int, right_int, "letmp")
.map_err(|e| format!("Failed to compare: {}", e))?,
BinOp::GtEq => self.builder.build_int_compare(IntPredicate::SGE, left_int, right_int, "getmp")
.map_err(|e| format!("Failed to compare: {}", e))?,
_ => return Err(format!("Unsupported binary operator: {:?}", op)),
};
Ok(result.into())
}
Expr::Call(func_expr, args) => {
if let Expr::Ident(func_name) = func_expr.as_ref() {
if func_name == "print" || func_name == "println" {
let (fmt_str, mut call_args): (String, Vec<inkwell::values::BasicMetadataValueEnum>) = {
match args.get(0).ok_or("print/println expect exactly one argument")? {
Expr::Literal(Literal::String(s)) => {
let unquoted = s.trim_matches('"');
let gsp = self.builder.build_global_string_ptr(unquoted, "strlit")
.map_err(|e| format!("Failed to build string: {}", e))?;
let fmt = if func_name == "println" { "%s\n".to_string() } else { "%s".to_string() };
let mut v = Vec::new();
v.push(gsp.as_pointer_value().into());
(fmt, v)
}
_ => {
let val = self.compile_expr(&args[0])?;
let is_bool = val.into_int_value().get_type().get_bit_width() == 1;
if is_bool {
let true_ptr = self.builder.build_global_string_ptr("true", "true")
.map_err(|e| format!("Failed to build string: {}", e))?;
let false_ptr = self.builder.build_global_string_ptr("false", "false")
.map_err(|e| format!("Failed to build string: {}", e))?;
let cond = val.into_int_value();
let selected = self.builder.build_select(cond, true_ptr.as_pointer_value(), false_ptr.as_pointer_value(), "boolstr")
.map_err(|e| format!("Failed to select: {}", e))?;
let fmt = if func_name == "println" { "%s\n".to_string() } else { "%s".to_string() };
let mut v = Vec::new();
v.push(selected.into());
(fmt, v)
} else {
let fmt = if func_name == "println" { "%ld\n".to_string() } else { "%ld".to_string() };
let mut v = Vec::new();
v.push(val.into());
(fmt, v)
}
}
}
};
let fmt_ptr = self.builder.build_global_string_ptr(&fmt_str, "fmt")
.map_err(|e| format!("Failed to build string: {}", e))?;
let printf = match self.module.get_function("printf") {
Some(f) => f,
None => {
let i32_type = self.context.i32_type();
let ptr_type = self.context.ptr_type(AddressSpace::default());
let fn_type = i32_type.fn_type(&[ptr_type.into()], true);
self.module.add_function("printf", fn_type, None)
}
};
call_args.insert(0, fmt_ptr.as_pointer_value().into());
let _ = self.builder.build_call(printf, &call_args, "printcall")
.map_err(|e| format!("Failed to call printf: {}", e))?;
let zero = self.context.i64_type().const_int(0, false);
return Ok(zero.into());
}
let function = self.module.get_function(func_name)
.ok_or_else(|| format!("Undefined function: {}", func_name))?;
let mut arg_values = Vec::new();
for arg in args {
let val = self.compile_expr(arg)?;
arg_values.push(val.into());
}
let call_result = self.builder.build_call(function, &arg_values, "calltmp")
.map_err(|e| format!("Failed to call: {}", e))?;
Ok(call_result.try_as_basic_value().left()
.ok_or("Function call returned void")?)
} else {
Err("Only direct function calls supported".to_string())
}
}
Expr::Assign(lhs, rhs) => {
if let Expr::Ident(var_name) = lhs.as_ref() {
let val = self.compile_expr(rhs)?;
let ptr = *self.variables.get(var_name)
.ok_or_else(|| format!("Undefined variable: {}", var_name))?;
self.builder.build_store(ptr, val)
.map_err(|e| format!("Failed to store: {}", e))?;
Ok(val)
} else {
Err("Only simple variable assignment supported".to_string())
}
}
Expr::CompoundAssign(lhs, op, rhs) => {
if let Expr::Ident(var_name) = lhs.as_ref() {
let ptr = *self.variables.get(var_name)
.ok_or_else(|| format!("Undefined variable: {}", var_name))?;
let current_val = self.builder.build_load(self.context.i64_type(), ptr, var_name)
.map_err(|e| format!("Failed to load: {}", e))?;
let rhs_val = self.compile_expr(rhs)?;
let current_int = current_val.into_int_value();
let rhs_int = rhs_val.into_int_value();
let result = match op {
BinOp::Add => self.builder.build_int_add(current_int, rhs_int, "addtmp")
.map_err(|e| format!("Failed to add: {}", e))?,
BinOp::Sub => self.builder.build_int_sub(current_int, rhs_int, "subtmp")
.map_err(|e| format!("Failed to sub: {}", e))?,
BinOp::Mul => self.builder.build_int_mul(current_int, rhs_int, "multmp")
.map_err(|e| format!("Failed to mul: {}", e))?,
BinOp::Div => self.builder.build_int_signed_div(current_int, rhs_int, "divtmp")
.map_err(|e| format!("Failed to div: {}", e))?,
_ => return Err(format!("Unsupported compound assignment operator: {:?}", op)),
};
self.builder.build_store(ptr, result)
.map_err(|e| format!("Failed to store: {}", e))?;
Ok(result.into())
} else {
Err("Only simple variable compound assignment supported".to_string())
}
}
Expr::MethodCall(obj, _method, args) => {
self.compile_expr(obj)?;
for arg in args {
self.compile_expr(arg)?;
}
Ok(self.context.i64_type().const_int(0, false).into())
}
Expr::StructInit(_struct_name, fields) => {
for (_, value_expr) in fields {
self.compile_expr(value_expr)?;
}
Ok(self.context.i64_type().const_int(0, false).into())
}
_ => Err(format!("Unsupported expression type: {:?}", expr)),
}
}
pub fn print_ir(&self) {
self.module.print_to_stderr();
}
pub fn write_ir(&self, path: &str) -> Result<(), String> {
self.module.print_to_file(path).map_err(|e| e.to_string())
}
pub fn write_object(&self, path: &str) -> Result<(), String> {
Target::initialize_all(&InitializationConfig::default());
let triple = TargetMachine::get_default_triple();
let target = Target::from_triple(&triple).map_err(|e| e.to_string())?;
let cpu = "generic";
let features = "";
let tm = target.create_target_machine(
&triple,
cpu,
features,
OptimizationLevel::Default,
RelocMode::Default,
CodeModel::Default,
).ok_or_else(|| "Failed to create target machine".to_string())?;
tm.write_to_file(&self.module, FileType::Object, std::path::Path::new(path))
.map_err(|e| e.to_string())
}
}