use super::CodeGenerator;
use crate::ast::{EnumVariant, ExpressionNode, Field, PrimitiveType, TypeKind};
use crate::semantics::{GenericContext, MethodSig, Type};
use inkwell::AddressSpace;
use inkwell::types::BasicType;
use inkwell::values::{BasicMetadataValueEnum, BasicValueEnum, IntValue, PointerValue};
use std::collections::HashMap;
impl<'a> CodeGenerator<'a> {
/// Creates an empty runtime collection and returns its wrapped value.
///
/// Calls the runtime constructor `new_fn` with no arguments (e.g. `mux_new_list`),
/// passes the resulting pointer to `value_fn` (e.g. `mux_list_value`), and returns
/// that second call's result.
///
/// Panics if either runtime call yields no value.
fn create_empty_collection_value(
    &mut self,
    new_fn: &str,
    value_fn: &str,
) -> BasicValueEnum<'a> {
    const NO_VALUE: &str = "should always return a value";
    let collection = self.generate_runtime_call(new_fn, &[]).expect(NO_VALUE);
    self.generate_runtime_call(value_fn, &[collection.into()])
        .expect(NO_VALUE)
}
/// Computes the value that `Class.new` stores into `field`.
///
/// - With an explicit default expression, the expression is generated; for
///   primitive fields the result is boxed with `box_value`.
/// - Without one, a primitive field gets a boxed zero of its LLVM type (a boxed
///   null pointer when that type is neither an int nor a float), and any other
///   field gets a null pointer.
fn compute_class_field_default_value(
    &mut self,
    field: &Field,
) -> Result<BasicValueEnum<'a>, String> {
    let is_primitive = matches!(field.type_.kind, TypeKind::Primitive(_));
    match (&field.default_value, is_primitive) {
        (Some(expr), true) => {
            let value = self.generate_expression(expr)?;
            Ok(self.box_value(value).into())
        }
        (Some(expr), false) => Ok(self.generate_expression(expr)?),
        (None, true) => {
            let llvm_type = self.llvm_type_from_mux_type(&field.type_)?;
            let zero: BasicValueEnum<'a> = if llvm_type.is_int_type() {
                llvm_type.into_int_type().const_zero().into()
            } else if llvm_type.is_float_type() {
                llvm_type.into_float_type().const_zero().into()
            } else {
                self.context
                    .ptr_type(AddressSpace::default())
                    .const_zero()
                    .into()
            };
            Ok(self.box_value(zero).into())
        }
        (None, false) => Ok(self
            .context
            .ptr_type(AddressSpace::default())
            .const_zero()
            .into()),
    }
}
/// Emits one LLVM constructor function per variant of enum `name`.
///
/// Each constructor:
/// - is named `{name}!{Variant}`;
/// - takes the variant's payload fields (if any) as parameters;
/// - returns the enum's struct type from `type_map` by value.
///
/// The body zero-initializes a stack temporary, then writes the variant tag (from
/// `get_variant_index`) as an `i32` into struct field 0. It copies parameter `i`
/// into field `i + 1`, then loads and returns the struct. Each function is
/// registered in `self.constructors` under `{name}.{Variant}`.
///
/// The builder is left positioned at the end of the last constructor emitted.
///
/// # Errors
/// Returns an error in any of these cases:
/// - a payload type cannot be lowered;
/// - the enum type is missing from `type_map`;
/// - the tag lookup fails;
/// - an IR instruction cannot be built.
pub(super) fn generate_enum_constructors(
    &mut self,
    name: &str,
    variants: &[EnumVariant],
) -> Result<(), String> {
    for variant in variants {
        let variant_name = &variant.name;
        let full_name = format!("{}!{}", name, variant_name);

        // Lower each payload field type to its LLVM parameter type.
        let mut param_types = Vec::new();
        if let Some(ref data) = variant.data {
            for (_, field_type) in data {
                let llvm_type = self.llvm_type_from_mux_type(field_type)?;
                param_types.push(llvm_type.into());
            }
        }

        let enum_type = *self.type_map.get(name).ok_or("Enum type not found")?;
        let struct_type = enum_type.into_struct_type();
        let fn_type = enum_type.fn_type(&param_types, false);
        let function = self.module.add_function(&full_name, fn_type, None);
        let entry = self.context.append_basic_block(function, "entry");
        self.builder.position_at_end(entry);

        let tag_index = self.get_variant_index(name, variant_name)?;
        let tag_val = self.context.i32_type().const_int(tag_index as u64, false);
        let temp_ptr = self
            .builder
            .build_alloca(struct_type, "temp_struct")
            .map_err(|e| e.to_string())?;
        // Zero the whole struct so slots this variant does not write hold zero.
        self.builder
            .build_store(temp_ptr, struct_type.const_zero())
            .map_err(|e| e.to_string())?;
        let tag_ptr = self
            .builder
            .build_struct_gep(struct_type, temp_ptr, 0, "tag_ptr")
            .map_err(|e| e.to_string())?;
        self.builder
            .build_store(tag_ptr, tag_val)
            .map_err(|e| e.to_string())?;

        // Payload field `i` lives at struct index `i + 1`; index 0 is the tag.
        for (i, arg) in function.get_param_iter().enumerate() {
            let data_ptr = self
                .builder
                .build_struct_gep(struct_type, temp_ptr, (i + 1) as u32, "data_ptr")
                .map_err(|e| e.to_string())?;
            self.builder
                .build_store(data_ptr, arg)
                .map_err(|e| e.to_string())?;
        }

        let struct_val = self
            .builder
            .build_load(struct_type, temp_ptr, "struct")
            .map_err(|e| e.to_string())?;
        self.builder
            .build_return(Some(&struct_val))
            .map_err(|e| e.to_string())?;
        self.constructors
            .insert(format!("{}.{}", name, variant_name), function);
    }
    Ok(())
}
/// Registers class `name` with the runtime type registry and returns its type id.
///
/// Emits `mux_register_object_type(type_name_global, type_size)`. When both a copy
/// function (`class_copy_fns`) and a destructor (`class_destructor_fns`) are known
/// for `name`, they are registered for the returned id via
/// `mux_register_object_copy` and `mux_register_object_destructor`.
fn register_class_type(
    &mut self,
    name: &str,
    type_name_global: PointerValue<'a>,
    type_size: IntValue<'a>,
) -> Result<IntValue<'a>, String> {
    let register_type_fn = self
        .runtime_function("mux_register_object_type")
        .ok_or("mux_register_object_type not found")?;
    let type_id = self
        .builder
        .build_call(
            register_type_fn,
            &[type_name_global.into(), type_size.into()],
            "type_id",
        )
        .map_err(|e| e.to_string())?
        .try_as_basic_value()
        .left()
        .ok_or("type_id call should return a basic value")?
        .into_int_value();

    // Lifecycle hooks are registered only when both halves exist.
    let hooks = (
        self.class_copy_fns.get(name).copied(),
        self.class_destructor_fns.get(name).copied(),
    );
    let (Some(copy_fn), Some(destructor_fn)) = hooks else {
        return Ok(type_id);
    };

    let register_copy_fn = self
        .runtime_function("mux_register_object_copy")
        .ok_or("mux_register_object_copy not found")?;
    self.builder
        .build_call(
            register_copy_fn,
            &[type_id.into(), copy_fn.into()],
            "register_copy",
        )
        .map_err(|e| e.to_string())?;

    let register_destructor_fn = self
        .runtime_function("mux_register_object_destructor")
        .ok_or("mux_register_object_destructor not found")?;
    self.builder
        .build_call(
            register_destructor_fn,
            &[type_id.into(), destructor_fn.into()],
            "register_destructor",
        )
        .map_err(|e| e.to_string())?;

    Ok(type_id)
}
/// Allocates a runtime object for `type_id_val`.
///
/// Returns `(object, data)`:
/// - `object` is the pointer returned by `mux_alloc_object(type_id)`;
/// - `data` is the pointer returned by `mux_get_object_ptr(object)`, which the
///   constructors in this file use to address the class struct's fields.
fn allocate_class_object(
    &mut self,
    type_id_val: IntValue<'a>,
) -> Result<(PointerValue<'a>, PointerValue<'a>), String> {
    let alloc_fn = self
        .runtime_function("mux_alloc_object")
        .ok_or("mux_alloc_object not found")?;
    let object = self
        .builder
        .build_call(alloc_fn, &[type_id_val.into()], "obj_ptr")
        .map_err(|e| e.to_string())?
        .try_as_basic_value()
        .left()
        .ok_or("alloc_object call should return a pointer value")?
        .into_pointer_value();

    let data_ptr_fn = self
        .runtime_function("mux_get_object_ptr")
        .ok_or("mux_get_object_ptr not found")?;
    let data = self
        .builder
        .build_call(data_ptr_fn, &[object.into()], "data_ptr")
        .map_err(|e| e.to_string())?
        .try_as_basic_value()
        .left()
        .ok_or("mux_get_object_ptr should return a basic value")?
        .into_pointer_value();

    Ok((object, data))
}
/// Emits the zero-argument `{name}.new` constructor for a class and registers it in
/// `self.constructors`.
///
/// The generated function:
/// 1. creates the `type_name_{name}` string global and registers the class with the
///    runtime (`register_class_type`);
/// 2. allocates an object (`allocate_class_object`);
/// 3. stores each field's default (`compute_class_field_default_value`) at the
///    struct index recorded in `field_map`;
/// 4. stores the vtable for every implemented interface into its
///    `vtable_{Interface}` field;
/// 5. returns the object pointer.
///
/// Interfaces are visited in sorted name order so the emitted IR is independent of
/// `HashMap` iteration order.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the class type, a field index, a vtable, or a runtime function is missing;
/// - any IR instruction cannot be built.
pub(super) fn generate_class_constructors(
    &mut self,
    name: &str,
    fields: &[Field],
    interfaces: &HashMap<String, (Vec<Type>, HashMap<String, MethodSig>)>,
) -> Result<(), String> {
    let full_name = format!("{}.new", name);
    let ptr_type = self.context.ptr_type(AddressSpace::default());
    let fn_type = ptr_type.fn_type(&[], false);
    let function = self.module.add_function(&full_name, fn_type, None);
    let entry = self.context.append_basic_block(function, "entry");
    self.builder.position_at_end(entry);

    let type_name = format!("type_name_{}", name);
    let type_name_global = self
        .builder
        .build_global_string_ptr(name, &type_name)
        .map_err(|e| e.to_string())?;
    if let Some(global) = self.module.get_global(&type_name) {
        global.set_linkage(inkwell::module::Linkage::External);
    }

    let class_type = *self.type_map.get(name).ok_or("Class type not found")?;
    let type_size = class_type.size_of().ok_or("Cannot get type size")?;
    let type_id_val =
        self.register_class_type(name, type_name_global.as_pointer_value(), type_size)?;
    let (obj_value_ptr, struct_ptr) = self.allocate_class_object(type_id_val)?;
    let struct_ptr = self
        .builder
        .build_pointer_cast(struct_ptr, ptr_type, "struct_ptr")
        .map_err(|e| e.to_string())?;

    for field in fields {
        let field_index = *self
            .field_map
            .get(name)
            .ok_or_else(|| format!("Field map not found for class {}", name))?
            .get(&field.name)
            .ok_or_else(|| format!("Field {} not found in class {}", field.name, name))?;
        let field_ptr = self
            .builder
            .build_struct_gep(class_type, struct_ptr, field_index as u32, &field.name)
            .map_err(|e| e.to_string())?;
        let default_value = self.compute_class_field_default_value(field)?;
        self.builder
            .build_store(field_ptr, default_value)
            .map_err(|e| e.to_string())?;
    }

    let mut interface_names: Vec<&String> = interfaces.keys().collect();
    interface_names.sort_unstable();
    for interface_name in interface_names {
        let vtable_key = format!("{}_{}", name, interface_name);
        let vtable_ptr = *self
            .vtable_map
            .get(&vtable_key)
            .ok_or_else(|| format!("Vtable not found for {}", vtable_key))?;
        let vtable_field_name = format!("vtable_{}", interface_name);
        let field_index = *self
            .field_map
            .get(name)
            .ok_or_else(|| format!("Field map not found for class {}", name))?
            .get(&vtable_field_name)
            .ok_or_else(|| {
                format!(
                    "Vtable field {} not found in class {}",
                    vtable_field_name, name
                )
            })?;
        let field_ptr = self
            .builder
            .build_struct_gep(
                class_type,
                struct_ptr,
                field_index as u32,
                &vtable_field_name,
            )
            .map_err(|e| e.to_string())?;
        self.builder
            .build_store(field_ptr, vtable_ptr)
            .map_err(|e| e.to_string())?;
    }

    self.builder
        .build_return(Some(&obj_value_ptr))
        .map_err(|e| e.to_string())?;
    self.constructors.insert(full_name, function);
    Ok(())
}
pub(super) fn initialize_field_by_type(
&mut self,
field_ptr: PointerValue<'a>,
field_type: &Type,
is_generic_param: bool,
) -> Result<(), String> {
if is_generic_param {
let null_ptr = self.context.ptr_type(AddressSpace::default()).const_null();
self.builder
.build_store(field_ptr, null_ptr)
.map_err(|e| e.to_string())?;
return Ok(());
}
let resolved_type = self.resolve_type(field_type)?;
match resolved_type {
Type::Primitive(PrimitiveType::Bool) => {
let false_val = self.context.bool_type().const_int(0, false);
self.builder
.build_store(field_ptr, false_val)
.map_err(|e| e.to_string())?;
}
Type::Primitive(PrimitiveType::Int) => {
let zero_val = self.context.i64_type().const_int(0, false);
self.builder
.build_store(field_ptr, zero_val)
.map_err(|e| e.to_string())?;
}
Type::Primitive(PrimitiveType::Float) => {
let zero_val = self.context.f64_type().const_float(0.0);
self.builder
.build_store(field_ptr, zero_val)
.map_err(|e| e.to_string())?;
}
Type::Primitive(PrimitiveType::Str) => {
let null_ptr = self.context.ptr_type(AddressSpace::default()).const_null();
self.builder
.build_store(field_ptr, null_ptr)
.map_err(|e| e.to_string())?;
}
Type::List(_) => {
let val = self.create_empty_collection_value("mux_new_list", "mux_list_value");
self.builder
.build_store(field_ptr, val)
.map_err(|e| e.to_string())?;
}
Type::Map(_, _) => {
let val = self.create_empty_collection_value("mux_new_map", "mux_map_value");
self.builder
.build_store(field_ptr, val)
.map_err(|e| e.to_string())?;
}
Type::Set(_) => {
let val = self.create_empty_collection_value("mux_new_set", "mux_set_value");
self.builder
.build_store(field_ptr, val)
.map_err(|e| e.to_string())?;
}
Type::Optional(_) => {
let optional_ptr = self
.generate_runtime_call("mux_optional_none", &[])
.expect("mux_optional_none should always return a value");
self.builder
.build_store(field_ptr, optional_ptr)
.map_err(|e| e.to_string())?;
}
Type::Result(ok_type, _) => {
let ok_value = self.create_default_value_ptr(&ok_type)?;
let result_ptr = self
.generate_runtime_call("mux_result_ok_value", &[ok_value.into()])
.expect("mux_result_ok_value should always return a value");
self.builder
.build_store(field_ptr, result_ptr)
.map_err(|e| e.to_string())?;
}
Type::Tuple(left_type, right_type) => {
let tuple_value = self.generate_tuple_constructor(&left_type, &right_type)?;
self.builder
.build_store(field_ptr, tuple_value)
.map_err(|e| e.to_string())?;
}
Type::Named(class_name, type_args) => {
if class_name == "string" && type_args.is_empty() {
let null_ptr = self.context.ptr_type(AddressSpace::default()).const_null();
self.builder
.build_store(field_ptr, null_ptr)
.map_err(|e| e.to_string())?;
} else if class_name == "bool" && type_args.is_empty() {
let false_val = self.context.bool_type().const_int(0, false);
self.builder
.build_store(field_ptr, false_val)
.map_err(|e| e.to_string())?;
} else {
let nested_obj =
self.generate_constructor_call_with_types(&class_name, &type_args, &[])?;
self.builder
.build_store(field_ptr, nested_obj)
.map_err(|e| e.to_string())?;
}
}
_ => return Err(format!("Unsupported field type: {:?}", resolved_type)),
}
Ok(())
}
/// Builds a runtime tuple whose elements are default values of `left_type` and
/// `right_type`, as produced by `create_default_value_ptr`.
///
/// Emits `mux_new_tuple(left, right)` and wraps the result with `mux_tuple_value`.
pub(super) fn generate_tuple_constructor(
    &mut self,
    left_type: &Type,
    right_type: &Type,
) -> Result<BasicValueEnum<'a>, String> {
    let (left, right) = (
        self.create_default_value_ptr(left_type)?,
        self.create_default_value_ptr(right_type)?,
    );
    let raw_tuple = self
        .generate_runtime_call("mux_new_tuple", &[left.into(), right.into()])
        .expect("mux_new_tuple should always return a value");
    Ok(self
        .generate_runtime_call("mux_tuple_value", &[raw_tuple.into()])
        .expect("mux_tuple_value should always return a value"))
}
pub(super) fn create_default_value_ptr(
&mut self,
mux_type: &Type,
) -> Result<PointerValue<'a>, String> {
let resolved_type = self.resolve_type(mux_type)?;
match resolved_type {
Type::Primitive(PrimitiveType::Int) => {
let zero = self.context.i64_type().const_zero();
Ok(self.box_value(zero.into()))
}
Type::Primitive(PrimitiveType::Float) => {
let zero = self.context.f64_type().const_zero();
Ok(self.box_value(zero.into()))
}
Type::Primitive(PrimitiveType::Bool) => {
let zero = self.context.bool_type().const_zero();
Ok(self.box_value(zero.into()))
}
Type::Primitive(PrimitiveType::Str) => {
let str_ptr = self
.builder
.build_global_string_ptr("", "empty_str")
.map_err(|e| e.to_string())?;
let value_ptr = self
.generate_runtime_call(
"mux_new_string_from_cstr",
&[str_ptr.as_pointer_value().into()],
)
.expect("mux_new_string_from_cstr should always return a value");
Ok(value_ptr.into_pointer_value())
}
Type::List(_) => {
let val = self.create_empty_collection_value("mux_new_list", "mux_list_value");
Ok(val.into_pointer_value())
}
Type::Map(_, _) => {
let val = self.create_empty_collection_value("mux_new_map", "mux_map_value");
Ok(val.into_pointer_value())
}
Type::Set(_) => {
let val = self.create_empty_collection_value("mux_new_set", "mux_set_value");
Ok(val.into_pointer_value())
}
Type::Tuple(left_type, right_type) => {
let tuple_value = self.generate_tuple_constructor(&left_type, &right_type)?;
Ok(tuple_value.into_pointer_value())
}
Type::Optional(_) => {
let optional_ptr = self
.generate_runtime_call("mux_optional_none", &[])
.expect("mux_optional_none should always return a value");
Ok(optional_ptr.into_pointer_value())
}
Type::Result(ok_type, _) => {
let ok_value = self.create_default_value_ptr(&ok_type)?;
let result_ptr = self
.generate_runtime_call("mux_result_ok_value", &[ok_value.into()])
.expect("mux_result_ok_value should always return a value");
Ok(result_ptr.into_pointer_value())
}
Type::Named(name, type_args) => {
if name == "optional" {
let optional_ptr = self
.generate_runtime_call("mux_optional_none", &[])
.expect("mux_optional_none should always return a value");
return Ok(optional_ptr.into_pointer_value());
}
if name == "result" {
if let Some(ok_type) = type_args.first() {
let ok_value = self.create_default_value_ptr(ok_type)?;
let result_ptr = self
.generate_runtime_call("mux_result_ok_value", &[ok_value.into()])
.expect("mux_result_ok_value should always return a value");
return Ok(result_ptr.into_pointer_value());
}
return Ok(self.context.ptr_type(AddressSpace::default()).const_zero());
}
if self.classes.contains_key(&name) {
let obj_value =
self.generate_constructor_call_with_types(&name, &type_args, &[])?;
return Ok(obj_value.into_pointer_value());
}
Ok(self.context.ptr_type(AddressSpace::default()).const_zero())
}
Type::Instantiated(name, type_args) => {
if self.classes.contains_key(&name) {
let obj_value =
self.generate_constructor_call_with_types(&name, &type_args, &[])?;
return Ok(obj_value.into_pointer_value());
}
Ok(self.context.ptr_type(AddressSpace::default()).const_zero())
}
_ => Ok(self.context.ptr_type(AddressSpace::default()).const_zero()),
}
}
/// Generates a constructor call for `class_name` instantiated with `type_args`.
///
/// `tuple` with exactly two type arguments is special-cased to a runtime tuple of
/// default values. Otherwise the call proceeds in three steps:
/// 1. A `GenericContext` mapping the class's type parameters to `type_args` is
///    pushed onto `context_stack` and made the current `generic_context`.
/// 2. Specialized methods are generated when `type_args` is non-empty.
/// 3. `generate_constructor_call` emits the allocation.
///
/// The pushed context is popped, and `generic_context` is restored to the
/// enclosing one, on every path — including when specialization or construction
/// fails — so an error cannot leave a stale generic context behind.
pub(super) fn generate_constructor_call_with_types(
    &mut self,
    class_name: &str,
    type_args: &[Type],
    args: &[ExpressionNode],
) -> Result<BasicValueEnum<'a>, String> {
    if class_name == "tuple"
        && let [left_type, right_type] = type_args
    {
        return self.generate_tuple_constructor(left_type, right_type);
    }
    let context = GenericContext {
        type_params: self.build_type_param_map(class_name, type_args)?,
    };
    self.context_stack.push(context.clone());
    self.generic_context = Some(context);
    // Run the fallible part separately so the pop below always executes.
    let result = self.generate_constructor_call_in_generic_context(class_name, type_args, args);
    self.context_stack.pop();
    self.generic_context = self.context_stack.last().cloned();
    result
}

/// Work done while the generic context for `class_name` is pushed: specializes
/// methods for non-empty `type_args`, then emits the constructor call.
fn generate_constructor_call_in_generic_context(
    &mut self,
    class_name: &str,
    type_args: &[Type],
    args: &[ExpressionNode],
) -> Result<BasicValueEnum<'a>, String> {
    if !type_args.is_empty() {
        self.generate_specialized_methods(class_name, type_args)?;
    }
    self.generate_constructor_call(class_name, args)
}
/// Produces the name-mangling fragment for `type_` used in specialized symbols.
///
/// Spellings by kind:
/// - primitives map to `int`, `float`, `bool` and `string`;
/// - containers are spelled `list_T`, `map_K_V`, `set_T`, `optional_T`,
///   `result_T_E` and `tuple_L_R`;
/// - named types append their arguments joined with `_`;
/// - instantiated types join them with `$`;
/// - generic and type variables use their own name;
/// - anything else becomes `"unknown"`.
///
/// Tuples get their own spelling so that specializations over different tuple
/// types (e.g. `(int, int)` vs `(string, string)`) do not collide on `"unknown"`.
// `self` is only threaded through the recursion; the function stays a method so it
// is called uniformly with the other `CodeGenerator` naming helpers.
#[allow(clippy::only_used_in_recursion)]
pub(super) fn sanitize_type_name(&self, type_: &Type) -> String {
    match type_ {
        Type::Primitive(PrimitiveType::Int) => "int".to_string(),
        Type::Primitive(PrimitiveType::Float) => "float".to_string(),
        Type::Primitive(PrimitiveType::Bool) => "bool".to_string(),
        Type::Primitive(PrimitiveType::Str) => "string".to_string(),
        Type::Named(name, type_args) => {
            if type_args.is_empty() {
                name.clone()
            } else {
                let args_str = type_args
                    .iter()
                    .map(|arg| self.sanitize_type_name(arg))
                    .collect::<Vec<_>>()
                    .join("_");
                format!("{}_{}", name, args_str)
            }
        }
        Type::Generic(name) | Type::Variable(name) => name.clone(),
        Type::List(inner) => format!("list_{}", self.sanitize_type_name(inner)),
        Type::Map(k, v) => format!(
            "map_{}_{}",
            self.sanitize_type_name(k),
            self.sanitize_type_name(v)
        ),
        Type::Set(inner) => format!("set_{}", self.sanitize_type_name(inner)),
        Type::Optional(inner) => format!("optional_{}", self.sanitize_type_name(inner)),
        Type::Result(ok, err) => format!(
            "result_{}_{}",
            self.sanitize_type_name(ok),
            self.sanitize_type_name(err)
        ),
        Type::Tuple(left, right) => format!(
            "tuple_{}_{}",
            self.sanitize_type_name(left),
            self.sanitize_type_name(right)
        ),
        Type::Instantiated(name, type_args) => {
            let args_str = type_args
                .iter()
                .map(|arg| self.sanitize_type_name(arg))
                .collect::<Vec<_>>()
                .join("$");
            format!("{}${}", name, args_str)
        }
        _ => "unknown".to_string(),
    }
}
/// Builds the symbol name for `method_name` on `class_name` under `type_args`.
///
/// With no type arguments the result is `{class}.{method}`. Otherwise it is
/// `{class}${arg1}${arg2}….{method}`, where each argument is spelled by
/// `sanitize_type_name`.
pub(super) fn create_specialized_method_name(
    &self,
    class_name: &str,
    type_args: &[Type],
    method_name: &str,
) -> String {
    if type_args.is_empty() {
        return format!("{}.{}", class_name, method_name);
    }
    let mangled: Vec<String> = type_args
        .iter()
        .map(|arg| self.sanitize_type_name(arg))
        .collect();
    format!("{}${}.{}", class_name, mangled.join("$"), method_name)
}
/// Emits IR that allocates a new `class_name` instance with default-initialized
/// fields, and returns the runtime object pointer.
///
/// The steps are:
/// 1. Create the `type_name_{class}` string global and register the class type.
/// 2. Allocate the object.
/// 3. For each field listed in `self.classes[class_name]`, store a type-based
///    default via `initialize_field_by_type`.
///
/// `_args` is currently ignored.
///
/// NOTE(review): fields are addressed by their position in `self.classes`, while
/// `generate_class_constructors` uses `field_map` indices. This assumes the two
/// orders agree (e.g. no vtable slots precede data fields) — confirm. Vtable
/// fields are not populated here.
pub(super) fn generate_constructor_call(
    &mut self,
    class_name: &str,
    _args: &[ExpressionNode],
) -> Result<BasicValueEnum<'a>, String> {
    let class_type = self
        .type_map
        .get(class_name)
        .copied()
        .ok_or_else(|| format!("Class '{}' not found in type map", class_name))?;

    let global_name = format!("type_name_{}", class_name);
    let name_global = self
        .builder
        .build_global_string_ptr(class_name, &global_name)
        .map_err(|e| e.to_string())?;
    if let Some(global) = self.module.get_global(&global_name) {
        global.set_linkage(inkwell::module::Linkage::External);
    }

    let size = class_type.size_of().ok_or("Cannot get type size")?;
    let type_id = self.register_class_type(class_name, name_global.as_pointer_value(), size)?;
    let (object, data_ptr) = self.allocate_class_object(type_id)?;

    // Snapshot the field list so `self` can be borrowed mutably while initializing.
    let class_fields: Vec<_> = self
        .classes
        .get(class_name)
        .map(|fields| fields.iter().cloned().collect())
        .unwrap_or_default();

    for (index, field) in class_fields.into_iter().enumerate() {
        let slot = self
            .builder
            .build_struct_gep(
                class_type.into_struct_type(),
                data_ptr,
                index as u32,
                &format!("field_{}", index),
            )
            .map_err(|e| e.to_string())?;
        let resolved_type = self.type_node_to_type(&field.type_);
        self.initialize_field_by_type(slot, &resolved_type, field.is_generic_param)?;
    }

    Ok(object.into())
}
/// Returns `(self_ptr, class_name, type_args)` for the `self` binding in scope.
///
/// `self` is looked up in `variables` first and then `global_variables`.
///
/// # Errors
/// Returns an error if:
/// - no `self` binding exists;
/// - the binding's type is not a `Type::Named`.
fn get_self_class_info(&mut self) -> Result<(PointerValue<'a>, String, Vec<Type>), String> {
    let binding = self
        .variables
        .get("self")
        .or_else(|| self.global_variables.get("self"))
        .ok_or("Self not found in method call")?;
    match binding {
        (self_ptr, _, Type::Named(class_name, type_args)) => {
            Ok((*self_ptr, class_name.clone(), type_args.clone()))
        }
        _ => Err("Self type not found".to_string()),
    }
}
fn resolve_method_function_name(
&mut self,
class_name: &str,
type_args: &[Type],
method_name: &str,
) -> Result<String, String> {
let mut method_func_name =
self.create_specialized_method_name(class_name, type_args, method_name);
if !self.functions.contains_key(&method_func_name) {
if type_args.is_empty()
&& let Some(current_fn) = &self.current_function_name
&& let Some((current_class_part, _)) = current_fn.split_once('.')
&& current_class_part.starts_with(&format!("{}$", class_name))
{
let contextual_name = format!("{}.{}", current_class_part, method_name);
if self.functions.contains_key(&contextual_name) {
method_func_name = contextual_name;
} else {
method_func_name = format!("{}.{}", class_name, method_name);
}
} else {
method_func_name = format!("{}.{}", class_name, method_name);
}
}
Ok(method_func_name)
}
fn check_method_is_static(
&mut self,
class_name: &str,
method_name: &str,
) -> Result<(), String> {
if let Some(class) = self.analyzer.symbol_table().lookup(class_name)
&& let Some(method) = class.methods.get(method_name)
&& method.is_static
{
return Err(format!(
"Cannot call static method {} with self",
method_name
));
}
Ok(())
}
/// Builds the argument list for a method call on `self`.
///
/// The first argument is the object pointer loaded from `self_ptr`. Each element
/// of `args` is then generated in order; when `is_specialized` is true, every
/// generated argument is boxed with `box_value` before being passed.
fn build_call_arguments_for_method(
    &mut self,
    self_ptr: PointerValue<'a>,
    args: &[ExpressionNode],
    is_specialized: bool,
) -> Result<Vec<BasicMetadataValueEnum<'a>>, String> {
    let ptr_type = self.context.ptr_type(AddressSpace::default());
    let receiver = self
        .builder
        .build_load(ptr_type, self_ptr, "load_self_for_method_call")
        .map_err(|e| e.to_string())?;

    let mut call_args: Vec<BasicMetadataValueEnum<'a>> = Vec::with_capacity(args.len() + 1);
    call_args.push(receiver.into());
    for arg in args {
        let generated = self.generate_expression(arg)?;
        let passed: BasicMetadataValueEnum<'a> = if is_specialized {
            self.box_value(generated).into()
        } else {
            generated.into()
        };
        call_args.push(passed);
    }
    Ok(call_args)
}
/// Emits a call to `method_name` on the current `self` object.
///
/// The steps are:
/// 1. Resolve the callee name via `resolve_method_function_name`.
/// 2. Reject static methods.
/// 3. Box the arguments when the resolved name contains `$` (the separator
///    `create_specialized_method_name` uses for type arguments).
///
/// Returns the call's result, or an `i64` zero when the callee produces no basic
/// value.
pub(super) fn generate_method_call_on_self(
    &mut self,
    method_name: &str,
    args: &[ExpressionNode],
) -> Result<BasicValueEnum<'a>, String> {
    let (self_ptr, class_name, type_args) = self.get_self_class_info()?;
    let resolved_name =
        self.resolve_method_function_name(&class_name, &type_args, method_name)?;
    self.check_method_is_static(&class_name, method_name)?;

    let Some(&callee) = self.functions.get(&resolved_name) else {
        return Err(format!("Method {} not found", resolved_name));
    };

    let boxed_args = resolved_name.contains('$');
    let call_args = self.build_call_arguments_for_method(self_ptr, args, boxed_args)?;
    let call_site = self
        .builder
        .build_call(callee, &call_args, &format!("call_{}", method_name))
        .map_err(|e| e.to_string())?;

    Ok(call_site
        .try_as_basic_value()
        .left()
        .unwrap_or_else(|| self.context.i64_type().const_zero().into()))
}
}