use inkwell::AddressSpace;
use inkwell::types::BasicMetadataTypeEnum;
use inkwell::types::{BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use crate::ast::{AstNode, FunctionNode, PrimitiveType, StatementNode, TypeKind};
use crate::semantics::Type;
use super::CodeGenerator;
impl<'a> CodeGenerator<'a> {
fn resolve_base_class_name_for_method(method_name: &str) -> &str {
let class_name = method_name
.split('.')
.next()
.or_else(|| method_name.split('$').next())
.expect("class method name should contain '.' or '$'");
class_name.split('$').next().unwrap_or(class_name)
}
fn resolve_self_type_for_method(&self, base_class_name: &str) -> Type {
if let Some(ref context) = self.generic_context
&& let Some(class_symbol) = self.analyzer.symbol_table().lookup(base_class_name)
{
let type_args: Vec<Type> = class_symbol
.type_params
.iter()
.filter_map(|(param_name, _bounds)| context.type_params.get(param_name).cloned())
.collect();
return Type::Named(base_class_name.to_string(), type_args);
}
Type::Named(base_class_name.to_string(), vec![])
}
fn setup_method_self_parameter(
&mut self,
func: &FunctionNode,
function: FunctionValue<'a>,
param_index: &mut u32,
) -> Result<(), String> {
let base_class_name = Self::resolve_base_class_name_for_method(&func.name);
let class_type = self
.type_map
.get(base_class_name)
.expect("class type should be in type_map after type generation");
let arg = function
.get_nth_param(*param_index)
.expect("self parameter should exist for class methods");
*param_index += 1;
let ptr_type = self.context.ptr_type(AddressSpace::default());
let alloca = self
.builder
.build_alloca(ptr_type, "self")
.map_err(|e| e.to_string())?;
self.builder
.build_store(alloca, arg)
.map_err(|e| e.to_string())?;
let self_type = self.resolve_self_type_for_method(base_class_name);
self.variables
.insert("self".to_string(), (alloca, *class_type, self_type.clone()));
self.analyzer.current_self_type = Some(self_type);
Ok(())
}
fn is_enum_type(&self, resolved_type: &Type) -> bool {
matches!(resolved_type, Type::Named(type_name, _) if self
.analyzer
.symbol_table()
.lookup(type_name)
.map(|s| s.kind == crate::semantics::SymbolKind::Enum)
.unwrap_or(false))
}
fn store_enum_parameter(
&mut self,
param_name: &str,
arg: inkwell::values::BasicValueEnum<'a>,
resolved_type: Type,
) -> Result<(), String> {
let struct_type = arg.get_type();
let alloca = self
.builder
.build_alloca(struct_type, param_name)
.map_err(|e| e.to_string())?;
self.builder
.build_store(alloca, arg)
.map_err(|e| e.to_string())?;
self.variables
.insert(param_name.to_string(), (alloca, struct_type, resolved_type));
Ok(())
}
fn store_function_parameter(
&mut self,
param_name: &str,
arg: inkwell::values::BasicValueEnum<'a>,
resolved_type: Type,
) -> Result<(), String> {
let func_ptr_type = self.context.ptr_type(AddressSpace::default());
let alloca = self
.builder
.build_alloca(func_ptr_type, param_name)
.map_err(|e| e.to_string())?;
self.builder
.build_store(alloca, arg)
.map_err(|e| e.to_string())?;
self.variables.insert(
param_name.to_string(),
(alloca, func_ptr_type.into(), resolved_type),
);
Ok(())
}
fn store_boxed_parameter(
&mut self,
param_name: &str,
value_to_store: inkwell::values::PointerValue<'a>,
resolved_type: Type,
) -> Result<(), String> {
let ptr_type = self.context.ptr_type(AddressSpace::default());
let alloca = self
.builder
.build_alloca(ptr_type, param_name)
.map_err(|e| e.to_string())?;
self.builder
.build_store(alloca, value_to_store)
.map_err(|e| e.to_string())?;
self.variables.insert(
param_name.to_string(),
(alloca, BasicTypeEnum::PointerType(ptr_type), resolved_type),
);
Ok(())
}
fn store_function_parameter_value(
&mut self,
param_name: &str,
arg: inkwell::values::BasicValueEnum<'a>,
resolved_type: Type,
) -> Result<(), String> {
if matches!(resolved_type, Type::Reference(_)) {
return self.store_boxed_parameter(param_name, arg.into_pointer_value(), resolved_type);
}
if self.is_enum_type(&resolved_type) {
return self.store_enum_parameter(param_name, arg, resolved_type);
}
if matches!(resolved_type, Type::Function { .. }) {
return self.store_function_parameter(param_name, arg, resolved_type);
}
let boxed_value = self.box_value(arg);
self.store_boxed_parameter(param_name, boxed_value, resolved_type)
}
fn setup_function_parameters(
&mut self,
func: &FunctionNode,
function: FunctionValue<'a>,
start_param_index: u32,
) -> Result<(), String> {
for (i, param) in func.params.iter().enumerate() {
let arg = function
.get_nth_param((i as u32) + start_param_index)
.expect("function parameter should exist at expected index");
let resolved_type = self
.analyzer
.resolve_type(¶m.type_)
.map_err(|e| e.to_string())?;
self.store_function_parameter_value(¶m.name, arg, resolved_type)?;
}
Ok(())
}
fn resolve_decl_llvm_name(&self, func: &FunctionNode) -> String {
if func.name.contains('$') {
return func.name.clone();
}
if let Some(symbol) = self.analyzer.symbol_table().lookup(&func.name)
&& let Some(mangled_name) = &symbol.llvm_name
{
return mangled_name.clone();
}
for module_syms in self.analyzer.imported_symbols().values() {
if let Some(func_symbol) = module_syms.get(&func.name)
&& let Some(mangled) = &func_symbol.llvm_name
{
return mangled.clone();
}
}
func.name.clone()
}
pub(super) fn declare_function(&mut self, func: &FunctionNode) -> Result<(), String> {
let mut param_types: Vec<BasicMetadataTypeEnum> = func
.params
.iter()
.map(|p| self.llvm_type_from_mux_type(&p.type_).map(|t| t.into()))
.collect::<Result<_, _>>()?;
let is_class_method = func.name.contains('.');
if is_class_method && !func.is_common {
param_types.insert(0, self.context.ptr_type(AddressSpace::default()).into());
}
let is_specialized = func.name.contains('$');
let is_static = func.is_common;
if is_specialized && !is_static {
let ptr_type = self.context.ptr_type(AddressSpace::default());
param_types = param_types
.into_iter()
.enumerate()
.map(|(i, param_type)| {
if i == 0 && is_class_method && !func.is_common {
param_type
} else {
ptr_type.into()
}
})
.collect();
}
let fn_type = if matches!(
func.return_type.kind,
TypeKind::Primitive(PrimitiveType::Void)
) {
self.context.void_type().fn_type(¶m_types, false)
} else {
let return_type = self.llvm_type_from_mux_type(&func.return_type)?;
return_type.fn_type(¶m_types, false)
};
let llvm_name = self.resolve_decl_llvm_name(func);
let function = self.module.add_function(&llvm_name, fn_type, None);
self.functions.insert(func.name.clone(), function);
self.function_nodes.insert(func.name.clone(), func.clone());
Ok(())
}
pub(super) fn declare_function_with_name(
&mut self,
func: &FunctionNode,
llvm_name: &str,
) -> Result<(), String> {
let mut param_types: Vec<BasicMetadataTypeEnum> = func
.params
.iter()
.map(|p| self.llvm_type_from_mux_type(&p.type_).map(|t| t.into()))
.collect::<Result<_, _>>()?;
let is_class_method = func.name.contains('.');
if is_class_method && !func.is_common {
param_types.insert(0, self.context.ptr_type(AddressSpace::default()).into());
}
let is_specialized = func.name.contains('$');
let is_static = func.is_common;
if is_specialized && !is_static {
let ptr_type = self.context.ptr_type(AddressSpace::default());
param_types = param_types
.into_iter()
.enumerate()
.map(|(i, param_type)| {
if i == 0 && is_class_method && !func.is_common {
param_type
} else {
ptr_type.into()
}
})
.collect();
}
let fn_type = if matches!(
func.return_type.kind,
TypeKind::Primitive(PrimitiveType::Void)
) {
self.context.void_type().fn_type(¶m_types, false)
} else {
let return_type = self.llvm_type_from_mux_type(&func.return_type)?;
return_type.fn_type(¶m_types, false)
};
let function = self.module.add_function(llvm_name, fn_type, None);
self.functions.insert(llvm_name.to_string(), function);
Ok(())
}
pub(super) fn generate_module_init(
&mut self,
top_level_statements: &[StatementNode],
module_name: &str,
) -> Result<(), String> {
let init_name = format!("!{}!init", module_name.replace(['.', '/'], "_"));
let init_type = self.context.void_type().fn_type(&[], false);
let init_func = self.module.add_function(&init_name, init_type, None);
let entry = self.context.append_basic_block(init_func, "entry");
self.builder.position_at_end(entry);
self.variables = self.global_variables.clone();
for stmt in top_level_statements {
self.generate_statement(stmt, Some(&init_func))?;
}
self.builder.build_return(None).map_err(|e| e.to_string())?;
Ok(())
}
pub(super) fn generate_main_function(&mut self, module_name: &str) -> Result<(), String> {
let main_type = self.context.i32_type().fn_type(&[], false);
let main_func = self.module.add_function("main", main_type, None);
let entry = self.context.append_basic_block(main_func, "entry");
self.builder.position_at_end(entry);
for module_path in &self.analyzer.module_dependencies {
let init_name = format!("!{}!init", Self::sanitize_module_path(module_path));
if let Some(init_func) = self.module.get_function(&init_name) {
self.builder
.build_call(
init_func,
&[],
&format!("{}_init_call", module_path.replace('.', "_")),
)
.map_err(|e| e.to_string())?;
}
}
let init_name = format!("!{}!init", Self::sanitize_module_path(module_name));
if let Some(init_func) = self.module.get_function(&init_name) {
self.builder
.build_call(init_func, &[], "init_call")
.map_err(|e| e.to_string())?;
}
if let Some(user_main) = self.module.get_function("!user!main") {
self.builder
.build_call(user_main, &[], "user_main_call")
.map_err(|e| e.to_string())?;
}
self.builder
.build_return(Some(&self.context.i32_type().const_int(0, false)))
.map_err(|e| e.to_string())?;
Ok(())
}
pub(super) fn get_module_name(&self, nodes: &[AstNode]) -> String {
for node in nodes {
match node {
AstNode::Class { name, .. } => {
return name.split('.').next().unwrap_or("main").to_string();
}
AstNode::Function(func) => {
return func.name.split('.').next().unwrap_or("main").to_string();
}
_ => {}
}
}
"main".to_string()
}
pub(super) fn generate_function(&mut self, func: &FunctionNode) -> Result<(), String> {
let saved_function_name = self.current_function_name.take();
let saved_return_type = self.current_function_return_type.take();
let saved_self_type = self.analyzer.current_self_type.take();
let saved_rc_scope_stack = std::mem::take(&mut self.rc_scope_stack);
self.current_function_name = Some(func.name.clone());
self.current_function_return_type = Some(
self.analyzer
.resolve_type(&func.return_type)
.map_err(|e| e.to_string())?,
);
let function = *self
.functions
.get(&func.name)
.ok_or("Function not declared")?;
let entry = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(entry);
self.variables.clear();
self.push_rc_scope();
let is_class_method = func.name.contains('.');
let mut param_index = 0u32;
if is_class_method && !func.is_common {
self.setup_method_self_parameter(func, function, &mut param_index)?;
}
self.setup_function_parameters(func, function, param_index)?;
for stmt in &func.body {
self.generate_statement(stmt, Some(&function))?;
}
if matches!(
func.return_type.kind,
TypeKind::Primitive(PrimitiveType::Void)
) && let Some(block) = self.builder.get_insert_block()
&& block.get_terminator().is_none()
{
self.generate_all_scopes_cleanup()?;
self.builder.build_return(None).map_err(|e| e.to_string())?;
}
self.rc_scope_stack.pop();
self.current_function_name = saved_function_name;
self.current_function_return_type = saved_return_type;
self.analyzer.current_self_type = saved_self_type;
self.rc_scope_stack = saved_rc_scope_stack;
Ok(())
}
pub(super) fn generate_function_with_llvm_name(
&mut self,
func: &FunctionNode,
llvm_name: &str,
) -> Result<(), String> {
let function = *self.functions.get(llvm_name).ok_or_else(|| {
format!(
"Function {} not declared (LLVM name: {})",
func.name, llvm_name
)
})?;
self.functions.insert(func.name.clone(), function);
let result = self.generate_function(func);
self.functions.remove(&func.name);
result
}
}