use super::CodeGenerator;
use crate::compiler::ast::{Type, Expr, mangle_wrap_args, mangle_base_name, mangle_has_args};
use crate::compiler::mangler::MANGLER;
use crate::compiler::ast_subst::TypeSubstitutor;
use inkwell::types::{BasicTypeEnum, StructType};
use inkwell::AddressSpace;
use std::collections::HashMap;
impl<'ctx> CodeGenerator<'ctx> {
/// Instantiates the generic method `struct_name::method_name` for the concrete
/// `generic_args` and returns the mangled name of the resulting function.
///
/// The method is looked up in `generic_impls`; the first impl block that
/// declares a method with this name is used. The impl's generic parameters and
/// `Self` are substituted throughout the signature and body, the prototype is
/// declared via `compile_fn_proto`, and (unless the method is `extern`) the
/// body is queued in `pending_functions` for later compilation.
///
/// If the module already contains the mangled function, its name is returned
/// without doing any work.
///
/// # Errors
/// Returns `Err` when no generic impls are registered for `struct_name`, when
/// the method is not found, when the method's trait bounds are not satisfied
/// for `generic_args`, or when declaring the prototype fails.
pub fn monomorphize_method(
    &mut self,
    struct_name: &str,
    method_name: &str,
    generic_args: &[Type],
) -> Result<String, String> {
    // Fast path for non-generic calls: reuse an existing `tl_<struct>_<method>`.
    if generic_args.is_empty() {
        let mangled = format!("tl_{}_{}", struct_name, method_name);
        if self.module.get_function(&mangled).is_some() {
            return Ok(mangled);
        }
    }

    let impls = self.generic_impls.get(struct_name)
        .ok_or_else(|| format!("No generic impls found for struct {}", struct_name))?;

    // The first impl block declaring a method with this name wins.
    let (imp, method) = impls
        .iter()
        .find_map(|imp| {
            imp.methods
                .iter()
                .find(|m| m.name == method_name)
                .map(|m| (imp, m))
        })
        .ok_or_else(|| format!("Method {} not found in generic impls of {}", method_name, struct_name))?;

    // Bounds are checked against the arguments exactly as supplied by the caller.
    if !self.check_method_trait_bounds(method, generic_args, &imp.generics) {
        return Err(format!("Method {}.{} skipped: trait bounds not satisfied for {:?}", struct_name, method_name, generic_args));
    }

    // Make the argument list match the impl's arity: missing arguments default
    // to `i64`, surplus arguments are dropped.
    let mut final_generic_args = generic_args.to_vec();
    final_generic_args.resize(imp.generics.len(), Type::I64);
    let generic_args = &final_generic_args;

    let mut subst_map = HashMap::new();
    for (param, arg) in imp.generics.iter().zip(generic_args) {
        subst_map.insert(param.clone(), arg.clone());
    }

    let mangled_name = crate::compiler::codegen::builtin_types::resolver::resolve_static_method_name(struct_name, method_name, generic_args);
    if self.module.get_function(&mangled_name).is_some() {
        return Ok(mangled_name);
    }

    let substitutor = TypeSubstitutor::new(subst_map.clone());
    let mut new_method = method.clone();
    new_method.name = mangled_name.clone();
    new_method.generics = vec![];

    // `Self` resolves to the concrete struct/enum being specialized.
    let concrete_self = if self.enum_defs.contains_key(struct_name) {
        Type::Enum(struct_name.to_string(), generic_args.to_vec())
    } else {
        Type::Struct(struct_name.to_string(), generic_args.to_vec())
    };
    let mut full_map = substitutor.subst.clone();
    full_map.insert("Self".to_string(), concrete_self);
    let full_substitutor = TypeSubstitutor::new(full_map);

    for (_, ty) in &mut new_method.args {
        *ty = full_substitutor.substitute_type(ty);
        *ty = self.normalize_type(ty);
    }

    // NOTE(review): the return type is substituted twice while argument types
    // are substituted once. Kept as-is because it is unclear whether the second
    // pass is relied upon; confirm and drop it if it is redundant.
    new_method.return_type = full_substitutor.substitute_type(&new_method.return_type);
    new_method.return_type = full_substitutor.substitute_type(&new_method.return_type);
    new_method.return_type = self.normalize_type(&new_method.return_type);

    new_method.body = new_method.body
        .iter()
        .map(|stmt| full_substitutor.substitute_stmt(stmt))
        .collect();
    self.transform_method_body_enum_inits(&mut new_method.body);
    self.transform_method_body_struct_inits(&mut new_method.body, &full_substitutor);

    // Remember where the builder was so the caller's insertion point can be
    // restored after the prototype has been declared.
    let previous_block = self.builder.get_insert_block();
    self.compile_fn_proto(&new_method)?;
    if !new_method.is_extern {
        self.pending_functions.push((new_method, Some(subst_map)));
    }
    if let Some(block) = previous_block {
        self.builder.position_at_end(block);
    }
    Ok(mangled_name)
}
/// Instantiates the generic free function `func_name` for a call with the
/// given `arg_types` and returns the mangled name of the specialization.
///
/// Names that are not registered in `generic_fn_defs` are returned unchanged.
/// Type parameters are inferred by unifying each declared argument type with
/// the corresponding actual type; every declared parameter must be inferred.
/// The specialized prototype is declared immediately and the body is queued in
/// `pending_functions`.
///
/// # Errors
/// Returns `Err` on an argument-count mismatch, on a unification failure, when
/// a type parameter cannot be inferred, or when declaring the prototype fails.
pub fn monomorphize_generic_function(
    &mut self,
    func_name: &str,
    arg_types: &[Type],
) -> Result<String, String> {
    let template = match self.generic_fn_defs.get(func_name) {
        Some(def) => def.clone(),
        None => return Ok(func_name.to_string()),
    };

    let declared_count = template.args.len();
    if declared_count != arg_types.len() {
        return Err(format!("Argument count mismatch for generic function {}: expected {}, got {}",
            func_name, declared_count, arg_types.len()));
    }

    // Infer bindings for the type parameters from the call's argument types.
    let mut bindings: HashMap<String, Type> = HashMap::new();
    for ((_, declared), supplied) in template.args.iter().zip(arg_types) {
        self.unify_types(declared, supplied, &mut bindings)?;
    }
    if let Some(missing) = template.generics.iter().find(|p| !bindings.contains_key(*p)) {
        return Err(format!("Could not infer type parameter {} for function {}", missing, func_name));
    }

    // The mangled name is built from the inferred arguments in declaration order.
    let type_args: Vec<Type> = template.generics.iter().map(|g| bindings[g].clone()).collect();
    let suffixes: Vec<String> = type_args.iter().map(|t| self.type_to_suffix(t)).collect();
    let mangled_name = mangle_wrap_args(func_name, &suffixes);
    if self.module.get_function(&mangled_name).is_some() {
        return Ok(mangled_name);
    }

    let substitutor = TypeSubstitutor::new(bindings.clone());
    let mut specialized = template;
    specialized.name = mangled_name.clone();
    specialized.generics = vec![];
    for (_, arg_ty) in &mut specialized.args {
        *arg_ty = substitutor.substitute_type(arg_ty);
    }
    specialized.return_type = substitutor.substitute_type(&specialized.return_type);
    specialized.body = specialized.body
        .iter()
        .map(|stmt| substitutor.substitute_stmt(stmt))
        .collect();
    self.transform_method_body_struct_inits(&mut specialized.body, &substitutor);

    // NOTE(review): unlike `monomorphize_method`, the builder's insertion block
    // is not saved/restored around `compile_fn_proto` here — confirm intended.
    self.compile_fn_proto(&specialized)?;
    self.pending_functions.push((specialized, Some(bindings)));
    Ok(mangled_name)
}
/// Instantiates the generic enum `enum_name` for `generic_args` and returns
/// the mangled name of the specialized enum.
///
/// Already-compiled enums (by the given name or by the mangled candidate) are
/// returned directly. Otherwise the definition is looked up in `enum_defs`
/// (falling back to the base name for names containing `[` but not `<`), the
/// generic parameters are substituted in every variant payload, the payload
/// types are themselves monomorphized, and the new definition is compiled.
///
/// # Errors
/// Returns `Err` when the enum definition cannot be found, when the number of
/// generic arguments does not match the definition, when a payload type fails
/// to monomorphize, or when compiling the specialized definition fails.
pub fn monomorphize_enum(
    &mut self,
    enum_name: &str,
    generic_args: &[Type],
) -> Result<String, String> {
    use crate::compiler::ast::VariantKind;

    if self.enum_types.contains_key(enum_name) {
        return Ok(enum_name.to_string());
    }
    let base = mangle_base_name(enum_name);
    if !generic_args.is_empty() {
        let candidate = self.mangle_type_name(base, generic_args);
        if self.enum_types.contains_key(&candidate) {
            return Ok(candidate);
        }
    }

    let enum_def = match self.enum_defs.get(enum_name) {
        Some(def) => def.clone(),
        // Bracketed (already mangled) names fall back to the base definition.
        None if enum_name.contains('[') && !enum_name.contains('<') => {
            match self.enum_defs.get(base) {
                Some(def) => def.clone(),
                None => return Err(format!("Enum {} not found (tried base {})", enum_name, base)),
            }
        }
        None => return Err(format!("Enum {} not found", enum_name)),
    };

    if enum_def.generics.len() != generic_args.len() {
        return Err(format!("Generic count mismatch for enum {}: expected {}, got {}",
            enum_name, enum_def.generics.len(), generic_args.len()));
    }

    let mangled_name = self.mangle_type_name(base, generic_args);
    if self.enum_types.contains_key(&mangled_name) {
        return Ok(mangled_name);
    }
    self.specialization_registry.register(enum_name, generic_args);

    let bindings: HashMap<String, Type> = enum_def.generics
        .iter()
        .cloned()
        .zip(generic_args.iter().cloned())
        .collect();
    let substitutor = TypeSubstitutor::new(bindings);

    let mut new_def = enum_def.clone();
    new_def.name = mangled_name.clone();
    new_def.generics = vec![];

    // Replace generic parameters in every payload type, then fold nested
    // generic types into their unified (specialized) representation.
    for variant in &mut new_def.variants {
        match &mut variant.kind {
            VariantKind::Unit => {}
            VariantKind::Tuple(types) => {
                for t in types.iter_mut() {
                    *t = self.to_unified_type_if_generic(substitutor.substitute_type(t));
                }
            }
            VariantKind::Struct(fields) => {
                for (_, t) in fields.iter_mut() {
                    *t = self.to_unified_type_if_generic(substitutor.substitute_type(t));
                }
            }
            VariantKind::Array(elem, _size) => {
                *elem = self.to_unified_type_if_generic(substitutor.substitute_type(elem));
            }
        }
    }

    // Make sure every payload type is available before compiling the enum.
    for variant in &new_def.variants {
        let payload_types: Vec<&Type> = match &variant.kind {
            VariantKind::Unit => Vec::new(),
            VariantKind::Tuple(types) => types.iter().collect(),
            VariantKind::Struct(fields) => fields.iter().map(|(_, t)| t).collect(),
            // The element type is listed once per array slot.
            VariantKind::Array(elem, size) => vec![elem; *size],
        };
        for payload_ty in payload_types {
            let _ = self.get_or_monomorphize_type(payload_ty)?;
        }
    }

    self.compile_enum_defs(&[new_def])
        .map_err(|e| e.to_string())?;
    Ok(mangled_name)
}
/// Unifies the declared type `expected` with the concrete type `actual`,
/// recording inferred type-parameter bindings in `map`.
///
/// A `Type::Struct` with no arguments on the expected side is treated as a
/// type variable: it is bound to `actual` on first sight and must equal the
/// existing binding afterwards. Structs with arguments, tensors and arrays are
/// unified structurally; every other pair must be equal.
///
/// # Errors
/// Returns `Err` describing the first mismatch encountered.
fn unify_types(
    &self,
    expected: &Type,
    actual: &Type,
    map: &mut HashMap<String, Type>,
) -> Result<(), String> {
    match (expected, actual) {
        // Argument-less struct name: behaves as a type variable.
        (Type::Struct(name, args), _) if args.is_empty() => {
            if let Some(bound) = map.get(name) {
                if bound != actual {
                    return Err(format!("Type mismatch for generic {}: expected {:?}, got {:?}", name, bound, actual));
                }
            } else {
                map.insert(name.clone(), actual.clone());
            }
            Ok(())
        }
        // Generic struct on both sides: same name and arity, then unify pairwise.
        (Type::Struct(name, args), Type::Struct(act_name, act_args)) => {
            if name != act_name || args.len() != act_args.len() {
                return Err("Type mismatch or arity mismatch".into());
            }
            for (exp_arg, act_arg) in args.iter().zip(act_args) {
                self.unify_types(exp_arg, act_arg, map)?;
            }
            Ok(())
        }
        (Type::Struct(name, _), _) => {
            Err(format!("Type mismatch: Expected Struct {}, found {:?}", name, actual))
        }
        (Type::Tensor(exp_elem, exp_rank), Type::Tensor(act_elem, act_rank)) => {
            if exp_rank != act_rank {
                return Err("Rank mismatch".into());
            }
            self.unify_types(exp_elem, act_elem, map)
        }
        (Type::Array(exp_elem, exp_size), Type::Array(act_elem, act_size)) => {
            if exp_size != act_size {
                return Err(format!("Array size mismatch: expected {}, got {}", exp_size, act_size));
            }
            self.unify_types(exp_elem, act_elem, map)
        }
        _ if expected != actual => {
            Err(format!("Type mismatch: expected {:?}, got {:?}", expected, actual))
        }
        _ => Ok(()),
    }
}
/// Builds the mangled name for `base_name` applied to `type_args`.
///
/// The base name is returned unchanged when there are no type arguments or
/// when it already contains `[`; otherwise each argument is converted with
/// `type_to_suffix` and combined via `mangle_wrap_args`.
pub fn mangle_type_name(&self, base_name: &str, type_args: &[Type]) -> String {
    if type_args.is_empty() || base_name.contains('[') {
        return base_name.to_string();
    }
    let suffixes: Vec<String> = type_args
        .iter()
        .map(|arg| self.type_to_suffix(arg))
        .collect();
    mangle_wrap_args(base_name, &suffixes)
}
/// Returns the string used for `ty` inside mangled names.
///
/// Primitive types map to fixed names, nominal types to their (possibly
/// mangled) name, and composite types are wrapped with `mangle_wrap_args` or
/// `MANGLER.wrap_single`. Types without a dedicated case yield `"unknown"`.
pub fn type_to_suffix(&self, ty: &Type) -> String {
    match ty {
        Type::I64 => String::from("i64"),
        Type::I32 => String::from("i32"),
        Type::U8 => String::from("u8"),
        Type::F32 => String::from("f32"),
        Type::F64 => String::from("f64"),
        Type::Bool => String::from("bool"),
        Type::Usize => String::from("usize"),
        Type::Void => String::from("void"),
        Type::String(_) => String::from("String"),
        Type::Char(_) => String::from("Char"),
        Type::Entity => String::from("entity"),
        // Names without arguments, or names that already carry mangled
        // arguments, are used verbatim.
        Type::Struct(name, args) | Type::Enum(name, args) => {
            if args.is_empty() || mangle_has_args(name) {
                name.clone()
            } else {
                self.mangle_type_name(name, args)
            }
        }
        Type::SpecializedType { gen_type, .. } => {
            gen_type.mangled_name_or_name().unwrap_or("specialized").to_string()
        }
        Type::Tensor(inner, rank) => {
            let parts = vec![self.type_to_suffix(inner), rank.to_string()];
            mangle_wrap_args("Tensor", &parts)
        }
        Type::Array(inner, size) => {
            let parts = vec![self.type_to_suffix(inner), size.to_string()];
            mangle_wrap_args("Array", &parts)
        }
        Type::Tuple(types) => {
            let parts: Vec<String> = types.iter().map(|t| self.type_to_suffix(t)).collect();
            mangle_wrap_args("Tuple", &parts)
        }
        // Only the last path segment contributes to the mangled name.
        Type::Path(path, args) => match path.last() {
            Some(name) if args.is_empty() => name.clone(),
            Some(name) => self.mangle_type_name(name, args),
            None => String::from("unknown_path"),
        },
        Type::Undefined(id) => format!("undefined{}", MANGLER.wrap_single(&id.to_string())),
        Type::Ptr(inner) => format!("ptr{}", MANGLER.wrap_single(&self.type_to_suffix(inner))),
        _ => String::from("unknown"),
    }
}
/// Builds the runtime symbol name `tl_<base>{<args>}_<method>` for a method of
/// a generic type.
///
/// The base type name and each argument suffix are lower-cased; every argument
/// suffix is wrapped with `MANGLER.wrap_single`. With no type arguments the
/// argument part is empty.
pub fn mangle_generic_method(
    &self,
    base_type: &str,
    type_args: &[Type],
    method: &str,
) -> String {
    // Collecting an empty iterator yields an empty string, which covers the
    // "no type arguments" case.
    let suffix: String = type_args
        .iter()
        .map(|arg| {
            let lowered = self.type_to_suffix(arg).to_lowercase();
            MANGLER.wrap_single(&lowered)
        })
        .collect();
    let base = base_type.to_lowercase();
    format!("tl_{}{}_{}", base, suffix, method)
}
pub fn get_llvm_type(&self, ty: &Type) -> Result<BasicTypeEnum<'ctx>, String> {
match ty {
Type::Ptr(_) => Ok(self.context.ptr_type(AddressSpace::default()).into()),
Type::I64 | Type::Entity => Ok(self.context.i64_type().into()),
Type::I32 => Ok(self.context.i32_type().into()),
Type::F32 => Ok(self.context.f32_type().into()),
Type::F64 => Ok(self.context.f64_type().into()),
Type::Bool => Ok(self.context.bool_type().into()),
Type::U8 => Ok(self.context.i8_type().into()), Type::Usize => Ok(self.context.i64_type().into()), Type::Void => Ok(self.context.i8_type().into()),
Type::Tensor(_, _) | Type::TensorShaped(_, _) | Type::GradTensor(_, _) => {
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
Type::String(_) => {
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
Type::Char(_) => {
Ok(self.context.i32_type().into())
}
Type::Struct(name, _args) => {
match name.as_str() {
"bool" => return Ok(self.context.bool_type().into()),
"i64" => return Ok(self.context.i64_type().into()),
"i32" => return Ok(self.context.i32_type().into()),
"f32" => return Ok(self.context.f32_type().into()),
"f64" => return Ok(self.context.f64_type().into()),
"usize" => return Ok(self.context.i64_type().into()),
"u8" => return Ok(self.context.i8_type().into()), "String" => return Ok(self.context.ptr_type(inkwell::AddressSpace::default()).into()),
_ => {}
}
if name == "File" || name == "Path" || name == "Env" || name == "Http" {
return Ok(self.context.ptr_type(AddressSpace::default()).into());
}
let simple_name = name.as_str();
if let Some(def) = self.struct_defs.get(simple_name) {
if def.fields.is_empty() {
return Ok(self.context.ptr_type(AddressSpace::default()).into());
}
}
if name.contains("PhantomData") {
}
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
Type::Enum(_name, _args) => {
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
Type::Array(inner, size) => {
let elem_ty = self.get_llvm_type(inner)?;
match elem_ty {
BasicTypeEnum::IntType(t) => Ok(t.array_type(*size as u32).into()),
BasicTypeEnum::FloatType(t) => Ok(t.array_type(*size as u32).into()),
BasicTypeEnum::PointerType(t) => Ok(t.array_type(*size as u32).into()),
BasicTypeEnum::StructType(t) => Ok(t.array_type(*size as u32).into()),
BasicTypeEnum::ArrayType(t) => Ok(t.array_type(*size as u32).into()),
BasicTypeEnum::VectorType(t) => Ok(t.array_type(*size as u32).into()),
_ => Err(format!("Unsupported array element type: {:?}", elem_ty)),
}
}
Type::Tuple(_) => {
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
Type::Path(_segments, _) => {
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
Type::Fn(_, _) => {
let ptr_ty = self.context.ptr_type(AddressSpace::default());
Ok(self.context.struct_type(&[ptr_ty.into(), ptr_ty.into()], false).into())
}
Type::TraitObject(_) => {
let ptr_ty = self.context.ptr_type(AddressSpace::default());
Ok(self.context.struct_type(&[ptr_ty.into(), ptr_ty.into()], false).into())
}
Type::SpecializedType { gen_type, .. } => {
if gen_type.is_enum_type() || gen_type.is_struct_type() {
Ok(self.context.ptr_type(AddressSpace::default()).into())
} else {
Ok(self.context.ptr_type(AddressSpace::default()).into())
}
}
_ => {
Err(format!("get_llvm_type: compilation error, unhandled or unresolved type {:?}", ty))
}
}
}
/// Returns the LLVM type for `ty`, first instantiating it when it is a struct
/// or enum with type arguments.
///
/// Generic structs and enums are monomorphized and represented as a pointer;
/// every other type is delegated to `get_llvm_type`.
///
/// # Errors
/// Propagates failures from monomorphization or from `get_llvm_type`.
pub fn get_or_monomorphize_type(&mut self, ty: &Type) -> Result<BasicTypeEnum<'ctx>, String> {
    match ty {
        Type::Struct(name, args) if !args.is_empty() => {
            self.monomorphize_struct(name, args)?;
        }
        Type::Enum(name, args) if !args.is_empty() => {
            self.monomorphize_enum(name, args)?;
        }
        _ => return self.get_llvm_type(ty),
    }
    // Instantiated generic types are always handled through a pointer.
    Ok(self.context.ptr_type(AddressSpace::default()).into())
}
/// Instantiates the generic struct `base_name` for `type_args` and returns its
/// LLVM struct type.
///
/// A previously created specialization is returned from `struct_types`.
/// Otherwise an opaque LLVM struct is registered under the mangled name
/// before its fields are resolved, its body is set from the substituted field
/// types, and a specialized definition is stored in `struct_defs`.
///
/// # Errors
/// Returns `Err` when no definition exists for `base_name` or when a field
/// type cannot be lowered to LLVM.
pub fn monomorphize_struct(
    &mut self,
    base_name: &str,
    type_args: &[Type],
) -> Result<StructType<'ctx>, String> {
    let mangled_name = self.mangle_type_name(base_name, type_args);
    if let Some(&cached) = self.struct_types.get(&mangled_name) {
        return Ok(cached);
    }

    let template = match self.struct_defs.get(base_name) {
        Some(def) => def.clone(),
        None => return Err(format!("Generic struct definition not found: {}", base_name)),
    };
    self.specialization_registry.register(base_name, type_args);

    // Pair generic parameters with the supplied arguments; parameters without
    // a matching argument stay unbound.
    let subst: HashMap<String, Type> = template.generics
        .iter()
        .cloned()
        .zip(type_args.iter().cloned())
        .collect();

    // Register the (still opaque) struct before lowering its fields.
    let llvm_struct = self.context.opaque_struct_type(&mangled_name);
    self.struct_types.insert(mangled_name.clone(), llvm_struct);

    let field_llvm_types = template.fields
        .iter()
        .map(|(field_name, field_ty)| {
            let concrete = self.substitute_type(field_ty, &subst);
            self.get_llvm_type(&concrete)
                .map_err(|e| format!("Error compiling field {} of {}: {}", field_name, mangled_name, e))
        })
        .collect::<Result<Vec<_>, String>>()?;
    llvm_struct.set_body(&field_llvm_types, false);

    // Record the specialized, non-generic definition under the mangled name.
    let mut specialized_def = template.clone();
    specialized_def.name = mangled_name.clone();
    specialized_def.generics = vec![];
    specialized_def.fields = template.fields
        .iter()
        .map(|(name, ty)| {
            let concrete = self.substitute_type(ty, &subst);
            (name.clone(), self.to_unified_type_if_generic(concrete))
        })
        .collect();
    self.struct_defs.insert(mangled_name.clone(), specialized_def);

    Ok(llvm_struct)
}
#[allow(dead_code)]
fn contains_unresolved_generics(args: &[Type]) -> bool {
args.iter().any(|arg| Self::is_unresolved_generic(arg))
}
#[allow(dead_code)]
fn is_unresolved_generic(ty: &Type) -> bool {
match ty {
Type::Struct(name, inner_args) => {
if inner_args.is_empty() && name.len() <= 2 && name.chars().all(|c| c.is_ascii_uppercase()) {
return true;
}
inner_args.iter().any(|a| Self::is_unresolved_generic(a))
}
Type::Enum(_, inner_args) => inner_args.iter().any(|a| Self::is_unresolved_generic(a)),
Type::SpecializedType { type_args, .. } => type_args.iter().any(|a| Self::is_unresolved_generic(a)),
Type::Tuple(types) => types.iter().any(|a| Self::is_unresolved_generic(a)),
Type::Undefined(_) => true,
_ => false,
}
}
/// Converts generic struct/enum types into the unified
/// `Type::SpecializedType` representation.
///
/// * A struct or enum with type arguments becomes a `SpecializedType` whose
///   `gen_type` is the argument-less base type and whose `mangled_name` is the
///   mangled specialization name; arguments are converted recursively. A
///   struct whose name (or base name) is registered in `enum_defs` is treated
///   as an enum.
/// * An argument-less struct whose name is a known enum becomes `Type::Enum`.
/// * Tuples are converted element-wise.
/// * Everything else is returned unchanged.
pub fn to_unified_type_if_generic(&self, ty: Type) -> Type {
    match &ty {
        Type::Struct(name, args) | Type::Enum(name, args) if !args.is_empty() => {
            let is_enum = matches!(&ty, Type::Enum(..))
                || self.enum_defs.contains_key(name)
                || self.enum_defs.contains_key(mangle_base_name(name));
            // A name that already carries mangled arguments is itself the
            // mangled name; otherwise mangle it from the arguments.
            let (base, mangled_name) = if mangle_has_args(name) {
                (mangle_base_name(name).to_string(), name.clone())
            } else {
                (name.clone(), self.mangle_type_name(name, args))
            };
            let type_args: Vec<Type> = args
                .iter()
                .cloned()
                .map(|arg| self.to_unified_type_if_generic(arg))
                .collect();
            let gen_type = if is_enum {
                Type::Enum(base, vec![])
            } else {
                Type::Struct(base, vec![])
            };
            Type::SpecializedType {
                gen_type: Box::new(gen_type),
                type_args,
                type_map: vec![],
                mangled_name,
            }
        }
        // Only argument-less structs reach this arm.
        Type::Struct(name, _) if self.enum_defs.contains_key(name) => {
            Type::Enum(name.clone(), vec![])
        }
        Type::Tuple(elems) => Type::Tuple(
            elems
                .iter()
                .cloned()
                .map(|elem| self.to_unified_type_if_generic(elem))
                .collect(),
        ),
        _ => ty,
    }
}
/// Replaces generic parameters in `ty` according to `subst`.
///
/// An argument-less struct/enum name, or a single-segment path, that is a key
/// in `subst` is replaced by the mapped type. Otherwise the type is rebuilt
/// with its arguments substituted recursively; struct names and
/// single-segment paths that are registered in `enum_defs` are rebuilt as
/// `Type::Enum`. Types without nested types are cloned unchanged.
pub fn substitute_type(&self, ty: &Type, subst: &HashMap<String, Type>) -> Type {
    match ty {
        Type::Struct(name, args) => match subst.get(name) {
            Some(replacement) if args.is_empty() => replacement.clone(),
            _ => {
                let new_args: Vec<Type> = args.iter().map(|a| self.substitute_type(a, subst)).collect();
                if self.enum_defs.contains_key(name) {
                    Type::Enum(name.clone(), new_args)
                } else {
                    Type::Struct(name.clone(), new_args)
                }
            }
        },
        Type::Enum(name, args) => match subst.get(name) {
            Some(replacement) if args.is_empty() => replacement.clone(),
            _ => {
                let new_args: Vec<Type> = args.iter().map(|a| self.substitute_type(a, subst)).collect();
                Type::Enum(name.clone(), new_args)
            }
        },
        // A single-segment path is either a type parameter or a nominal type.
        Type::Path(segments, args) if segments.len() == 1 => {
            let name = &segments[0];
            if let Some(replacement) = subst.get(name) {
                return replacement.clone();
            }
            let new_args: Vec<Type> = args.iter().map(|a| self.substitute_type(a, subst)).collect();
            if self.enum_defs.contains_key(name) {
                Type::Enum(name.clone(), new_args)
            } else {
                Type::Struct(name.clone(), new_args)
            }
        }
        Type::Path(segments, args) => {
            let new_args: Vec<Type> = args.iter().map(|a| self.substitute_type(a, subst)).collect();
            Type::Path(segments.clone(), new_args)
        }
        Type::Tensor(inner, rank) => {
            Type::Tensor(Box::new(self.substitute_type(inner, subst)), *rank)
        }
        Type::Array(inner, size) => {
            Type::Array(Box::new(self.substitute_type(inner, subst)), *size)
        }
        Type::Ptr(inner) => Type::Ptr(Box::new(self.substitute_type(inner, subst))),
        Type::Tuple(elems) => {
            Type::Tuple(elems.iter().map(|t| self.substitute_type(t, subst)).collect())
        }
        Type::SpecializedType { gen_type, type_args, type_map, mangled_name } => {
            // Only the type arguments are substituted; the remaining fields
            // are carried over as-is.
            let new_args: Vec<Type> = type_args.iter().map(|a| self.substitute_type(a, subst)).collect();
            Type::SpecializedType {
                gen_type: gen_type.clone(),
                type_args: new_args,
                type_map: type_map.clone(),
                mangled_name: mangled_name.clone(),
            }
        }
        _ => ty.clone(),
    }
}
/// Rewrites `ty` into its canonical struct/enum form, recursively.
///
/// * Single-segment paths become `Type::Enum` when the name is a known enum
///   whose generic count matches the argument count (or which has no
///   generics), `Type::Struct` when the name is not a known enum, and are
///   returned unchanged when the enum's generic count does not match.
/// * Structs naming a known enum become `Type::Enum` under the same
///   generic-count condition.
/// * Specialized types collapse to a struct/enum named by their mangled name.
/// * Tensors, tuples, pointers and arrays are normalized element-wise.
pub fn normalize_type(&self, ty: &Type) -> Type {
    match ty {
        Type::Path(segments, args) if segments.len() == 1 => {
            let name = &segments[0];
            let normalized_args: Vec<Type> = args.iter().map(|a| self.normalize_type(a)).collect();
            match self.enum_defs.get(name) {
                Some(enum_def) => {
                    let arity_ok = enum_def.generics.is_empty()
                        || enum_def.generics.len() == normalized_args.len();
                    if arity_ok {
                        Type::Enum(name.clone(), normalized_args)
                    } else {
                        // Arity mismatch: leave the path untouched.
                        ty.clone()
                    }
                }
                None => Type::Struct(name.clone(), normalized_args),
            }
        }
        Type::Path(segments, args) => {
            let normalized_args: Vec<Type> = args.iter().map(|a| self.normalize_type(a)).collect();
            Type::Path(segments.clone(), normalized_args)
        }
        Type::Struct(name, args) => {
            let normalized_args: Vec<Type> = args.iter().map(|a| self.normalize_type(a)).collect();
            let is_matching_enum = self.enum_defs.get(name).map_or(false, |enum_def| {
                enum_def.generics.is_empty() || enum_def.generics.len() == normalized_args.len()
            });
            if is_matching_enum {
                Type::Enum(name.clone(), normalized_args)
            } else {
                Type::Struct(name.clone(), normalized_args)
            }
        }
        Type::Enum(name, args) => {
            Type::Enum(name.clone(), args.iter().map(|a| self.normalize_type(a)).collect())
        }
        Type::SpecializedType { gen_type, type_args, type_map: _, mangled_name } => {
            let normalized_args: Vec<Type> = type_args.iter().map(|a| self.normalize_type(a)).collect();
            if gen_type.is_enum_type() {
                Type::Enum(mangled_name.clone(), normalized_args)
            } else {
                Type::Struct(mangled_name.clone(), normalized_args)
            }
        }
        Type::Tensor(inner, rank) => Type::Tensor(Box::new(self.normalize_type(inner)), *rank),
        Type::Array(inner, size) => Type::Array(Box::new(self.normalize_type(inner)), *size),
        Type::Ptr(inner) => Type::Ptr(Box::new(self.normalize_type(inner))),
        Type::Tuple(elems) => Type::Tuple(elems.iter().map(|t| self.normalize_type(t)).collect()),
        _ => ty.clone(),
    }
}
/// Applies `transform_stmt_enum_inits` to every statement in `stmts`, in order.
fn transform_method_body_enum_inits(&self, stmts: &mut Vec<crate::compiler::ast::Stmt>) {
    stmts
        .iter_mut()
        .for_each(|stmt| self.transform_stmt_enum_inits(stmt));
}
/// Rewrites enum-variant constructor calls inside a single statement.
///
/// Visits the expressions and nested bodies of `let`, expression, `return`
/// (with a value), assignment, `while`, `for` and `loop` statements; all other
/// statement kinds are left untouched.
fn transform_stmt_enum_inits(&self, stmt: &mut crate::compiler::ast::Stmt) {
    use crate::compiler::ast::StmtKind;
    match &mut stmt.inner {
        // Statements carrying a single expression.
        StmtKind::Expr(expression) => self.transform_expr_enum_inits(expression),
        StmtKind::Let { value, .. } => self.transform_expr_enum_inits(value),
        StmtKind::Assign { value, .. } => self.transform_expr_enum_inits(value),
        StmtKind::Return(Some(returned)) => self.transform_expr_enum_inits(returned),
        // Loops: header expression (if any) first, then the body.
        StmtKind::Loop { body } => self.transform_method_body_enum_inits(body),
        StmtKind::While { cond, body } => {
            self.transform_expr_enum_inits(cond);
            self.transform_method_body_enum_inits(body);
        }
        StmtKind::For { iterator, body, .. } => {
            self.transform_expr_enum_inits(iterator);
            self.transform_method_body_enum_inits(body);
        }
        _ => {}
    }
}
/// Recursively rewrites static-method-call expressions that actually name an
/// enum variant (e.g. `Enum::Variant(args)`) into `ExprKind::EnumInit` nodes.
///
/// Sub-expressions are visited first. A `StaticMethodCall` is rewritten only
/// when its type names an enum registered in `enum_defs` and the method name
/// equals one of that enum's variant names; everything else is left as-is.
///
/// NOTE(review): unlike `transform_expr_struct_inits`, this walker does not
/// descend into `Try`, `TupleAccess` or `Closure` expressions — confirm
/// whether that is intentional.
fn transform_expr_enum_inits(&self, expr: &mut crate::compiler::ast::Expr) {
use crate::compiler::ast::{ExprKind, EnumVariantInit};
match &mut expr.inner {
ExprKind::BinOp(l, _, r) => {
self.transform_expr_enum_inits(l);
self.transform_expr_enum_inits(r);
}
ExprKind::UnOp(_, e) => {
self.transform_expr_enum_inits(e);
}
ExprKind::MethodCall(obj, _, args) => {
self.transform_expr_enum_inits(obj);
for arg in args.iter_mut() {
self.transform_expr_enum_inits(arg);
}
}
ExprKind::FnCall(_, args) => {
for arg in args.iter_mut() {
self.transform_expr_enum_inits(arg);
}
}
ExprKind::IndexAccess(e, indices) => {
self.transform_expr_enum_inits(e);
for idx in indices.iter_mut() {
self.transform_expr_enum_inits(idx);
}
}
ExprKind::FieldAccess(obj, _) => {
self.transform_expr_enum_inits(obj);
}
ExprKind::Match { expr: subject, arms } => {
self.transform_expr_enum_inits(subject);
for (_, arm_expr) in arms.iter_mut() {
self.transform_expr_enum_inits(arm_expr);
}
}
ExprKind::Block(stmts) => {
self.transform_method_body_enum_inits(stmts);
}
ExprKind::IfExpr(cond, then_block, else_block) => {
self.transform_expr_enum_inits(cond);
self.transform_method_body_enum_inits(then_block);
if let Some(else_stmts) = else_block {
self.transform_method_body_enum_inits(else_stmts);
}
}
ExprKind::Tuple(exprs) => {
for e in exprs.iter_mut() {
self.transform_expr_enum_inits(e);
}
}
ExprKind::StructInit(_, fields) => {
for (_, e) in fields.iter_mut() {
self.transform_expr_enum_inits(e);
}
}
// Already an enum initializer: only its payload expressions need visiting.
ExprKind::EnumInit { payload, .. } => {
match payload {
EnumVariantInit::Tuple(exprs) => {
for e in exprs.iter_mut() {
self.transform_expr_enum_inits(e);
}
}
EnumVariantInit::Struct(fields) => {
for (_, e) in fields.iter_mut() {
self.transform_expr_enum_inits(e);
}
}
EnumVariantInit::Unit => {}
}
}
ExprKind::StaticMethodCall(ty, method, args) => {
// Arguments are transformed before the call itself is (possibly) rewritten.
for arg in args.iter_mut() {
self.transform_expr_enum_inits(arg);
}
// Candidate enum name: the type's own name, or the last path segment.
let enum_name = match ty {
Type::Struct(name, _) | Type::Enum(name, _) => name.clone(),
Type::Path(segments, _) => segments.last().cloned().unwrap_or_default(),
_ => String::new(),
};
if let Some(enum_def) = self.enum_defs.get(&enum_name) {
if let Some(variant) = enum_def.variants.iter().find(|v| &v.name == method) {
use crate::compiler::ast::VariantKind;
// `std::mem::take` moves the argument expressions out of the call,
// leaving an empty list behind in the node that is about to be replaced.
let payload = match &variant.kind {
VariantKind::Unit => EnumVariantInit::Unit,
VariantKind::Tuple(_) => EnumVariantInit::Tuple(std::mem::take(args)),
VariantKind::Struct(fields) => {
// Positional arguments are paired with the declared field names;
// `zip` stops at the shorter of the two lists.
let field_pairs: Vec<(String, Expr)> = fields.iter()
.zip(std::mem::take(args).into_iter())
.map(|((name, _), expr)| (name.clone(), expr))
.collect();
EnumVariantInit::Struct(field_pairs)
}
VariantKind::Array(_, _) => EnumVariantInit::Tuple(std::mem::take(args)),
};
// Generic arguments written on the type are carried over to the initializer.
let generics = match ty {
Type::Struct(_, g) | Type::Enum(_, g) | Type::Path(_, g) => g.clone(),
_ => vec![],
};
expr.inner = ExprKind::EnumInit {
enum_name,
variant_name: method.clone(),
generics,
payload,
};
return;
}
}
}
_ => {}
}
}
/// Applies `transform_stmt_struct_inits` to every statement in `stmts`, in
/// order, using the same `substitutor` throughout.
fn transform_method_body_struct_inits(&self, stmts: &mut Vec<crate::compiler::ast::Stmt>, substitutor: &TypeSubstitutor) {
    stmts
        .iter_mut()
        .for_each(|stmt| self.transform_stmt_struct_inits(stmt, substitutor));
}
/// Rewrites generic type references inside a single statement.
///
/// Visits the expressions and nested bodies of `let`, expression, `return`
/// (with a value), assignment, `while`, `for` and `loop` statements; all other
/// statement kinds are left untouched.
fn transform_stmt_struct_inits(&self, stmt: &mut crate::compiler::ast::Stmt, substitutor: &TypeSubstitutor) {
    use crate::compiler::ast::StmtKind;
    match &mut stmt.inner {
        // Statements carrying a single expression.
        StmtKind::Expr(expression) => self.transform_expr_struct_inits(expression, substitutor),
        StmtKind::Let { value, .. } => self.transform_expr_struct_inits(value, substitutor),
        StmtKind::Assign { value, .. } => self.transform_expr_struct_inits(value, substitutor),
        StmtKind::Return(Some(returned)) => self.transform_expr_struct_inits(returned, substitutor),
        // Loops: header expression (if any) first, then the body.
        StmtKind::Loop { body } => self.transform_method_body_struct_inits(body, substitutor),
        StmtKind::While { cond, body } => {
            self.transform_expr_struct_inits(cond, substitutor);
            self.transform_method_body_struct_inits(body, substitutor);
        }
        StmtKind::For { iterator, body, .. } => {
            self.transform_expr_struct_inits(iterator, substitutor);
            self.transform_method_body_struct_inits(body, substitutor);
        }
        _ => {}
    }
}
/// Recursively walks `expr` and rewrites the type carried by struct
/// initializers and static method calls via `transform_generic_type`.
///
/// Child expressions are always visited before the node's own type is
/// rewritten. Expression kinds without a dedicated arm are left untouched.
fn transform_expr_struct_inits(&self, expr: &mut crate::compiler::ast::Expr, substitutor: &TypeSubstitutor) {
    use crate::compiler::ast::{EnumVariantInit, ExprKind};
    match &mut expr.inner {
        // --- Nodes that carry a type to rewrite ---
        ExprKind::StructInit(ty, fields) => {
            fields
                .iter_mut()
                .for_each(|(_, value)| self.transform_expr_struct_inits(value, substitutor));
            self.transform_generic_type(ty, substitutor);
        }
        ExprKind::StaticMethodCall(ty, _, args) => {
            args.iter_mut()
                .for_each(|arg| self.transform_expr_struct_inits(arg, substitutor));
            self.transform_generic_type(ty, substitutor);
        }

        // --- Single-child wrappers ---
        ExprKind::UnOp(_, operand) => self.transform_expr_struct_inits(operand, substitutor),
        ExprKind::Try(inner) => self.transform_expr_struct_inits(inner, substitutor),
        ExprKind::FieldAccess(object, _) => self.transform_expr_struct_inits(object, substitutor),
        ExprKind::TupleAccess(object, _) => self.transform_expr_struct_inits(object, substitutor),

        // --- Nodes with several child expressions ---
        ExprKind::BinOp(lhs, _, rhs) => {
            self.transform_expr_struct_inits(lhs, substitutor);
            self.transform_expr_struct_inits(rhs, substitutor);
        }
        ExprKind::MethodCall(receiver, _, args) => {
            self.transform_expr_struct_inits(receiver, substitutor);
            args.iter_mut()
                .for_each(|arg| self.transform_expr_struct_inits(arg, substitutor));
        }
        ExprKind::FnCall(_, args) => {
            args.iter_mut()
                .for_each(|arg| self.transform_expr_struct_inits(arg, substitutor));
        }
        ExprKind::IndexAccess(object, indices) => {
            self.transform_expr_struct_inits(object, substitutor);
            indices
                .iter_mut()
                .for_each(|index| self.transform_expr_struct_inits(index, substitutor));
        }
        ExprKind::Tuple(elements) => {
            elements
                .iter_mut()
                .for_each(|element| self.transform_expr_struct_inits(element, substitutor));
        }
        ExprKind::EnumInit { payload, .. } => match payload {
            EnumVariantInit::Unit => {}
            EnumVariantInit::Tuple(args) => {
                args.iter_mut()
                    .for_each(|arg| self.transform_expr_struct_inits(arg, substitutor));
            }
            EnumVariantInit::Struct(fields) => {
                fields
                    .iter_mut()
                    .for_each(|(_, value)| self.transform_expr_struct_inits(value, substitutor));
            }
        },
        ExprKind::Match { expr: scrutinee, arms } => {
            self.transform_expr_struct_inits(scrutinee, substitutor);
            arms.iter_mut()
                .for_each(|(_, arm_body)| self.transform_expr_struct_inits(arm_body, substitutor));
        }

        // --- Nodes containing statement lists ---
        ExprKind::Block(stmts) => self.transform_method_body_struct_inits(stmts, substitutor),
        ExprKind::Closure { body, .. } => self.transform_method_body_struct_inits(body, substitutor),
        ExprKind::IfExpr(cond, then_block, else_block) => {
            self.transform_expr_struct_inits(cond, substitutor);
            self.transform_method_body_struct_inits(then_block, substitutor);
            if let Some(else_stmts) = else_block {
                self.transform_method_body_struct_inits(else_stmts, substitutor);
            }
        }
        _ => {}
    }
}
/// Rewrites a reference to a generic type into its concrete, mangled form.
///
/// A single-segment `Type::Path` naming a known struct (or a type with generic
/// impls) or a known enum is first converted to `Type::Struct` / `Type::Enum`.
/// If the resulting struct/enum has generic parameters (see
/// `get_type_generic_params`), each parameter is resolved through
/// `substitutor`, and the type's name is replaced by the mangled
/// specialization name with its argument list cleared. Types that are not
/// structs/enums, or that have no generic parameters, are left unchanged.
fn transform_generic_type(&self, ty: &mut Type, substitutor: &TypeSubstitutor) {
    // Step 1: turn a bare single-segment path into a nominal type when the
    // name is known.
    if let Type::Path(segments, generics) = ty {
        if segments.len() == 1 {
            let name = &segments[0];
            if self.struct_defs.contains_key(name) || self.generic_impls.contains_key(name) {
                *ty = Type::Struct(name.clone(), generics.clone());
            } else if self.enum_defs.contains_key(name) {
                *ty = Type::Enum(name.clone(), generics.clone());
            }
        }
    }

    let name = match ty {
        Type::Struct(n, _) | Type::Enum(n, _) => n.clone(),
        _ => return,
    };
    let generic_params = self.get_type_generic_params(&name);
    if generic_params.is_empty() {
        return;
    }

    // Step 2: resolve every generic parameter of the type via the substitutor.
    let concrete_args: Vec<Type> = generic_params.iter()
        .map(|g| {
            let param_ty = Type::Path(vec![g.clone()], vec![]);
            let substituted = substitutor.substitute_type(&param_ty);
            self.normalize_type(&substituted)
        })
        .collect();

    // Bail out if any parameter came back as the very same single-segment path.
    // NOTE(review): `normalize_type` turns most single-segment paths into
    // `Type::Struct`/`Type::Enum`, so an unresolved parameter may no longer be
    // a `Type::Path` at this point — confirm this check still detects it.
    let all_resolved = concrete_args.iter().zip(generic_params.iter()).all(|(resolved, param)| {
        !matches!(resolved, Type::Path(segs, _) if segs.len() == 1 && segs[0] == *param)
    });
    if !all_resolved {
        return;
    }

    // Step 3: replace the name with the mangled specialization and drop the
    // now-redundant argument list.
    let mangled = self.mangle_type_name(&name, &concrete_args);
    match ty {
        Type::Struct(n, g) => { *n = mangled; g.clear(); }
        Type::Enum(n, g) => { *n = mangled; g.clear(); }
        _ => {}
    }
}
/// Returns the generic parameter names declared for the type `name`.
///
/// Sources are consulted in order and the first non-empty list wins: the
/// struct definition, the enum definition, then the first generic impl block
/// with a non-empty parameter list. Returns an empty vector if none applies.
fn get_type_generic_params(&self, name: &str) -> Vec<String> {
    self.struct_defs
        .get(name)
        .map(|def| &def.generics)
        .filter(|generics| !generics.is_empty())
        .or_else(|| {
            self.enum_defs
                .get(name)
                .map(|def| &def.generics)
                .filter(|generics| !generics.is_empty())
        })
        .or_else(|| {
            self.generic_impls.get(name).and_then(|impls| {
                impls
                    .iter()
                    .map(|imp| &imp.generics)
                    .find(|generics| !generics.is_empty())
            })
        })
        .cloned()
        .unwrap_or_default()
}
/// Eagerly monomorphizes every method of `base_name`'s generic impls for the
/// given `type_args`.
///
/// Methods whose trait bounds are not satisfied, or whose mangled function
/// already exists in the module, are skipped. Failures to monomorphize an
/// individual method are logged at debug level and do not abort the loop, so
/// this function currently always returns `Ok(())`. Types without generic
/// impls are a no-op.
pub fn generate_methods_for_specialized_type(
    &mut self,
    base_name: &str,
    type_args: &[Type],
) -> Result<(), String> {
    // Clone the impl list so `self` can be borrowed mutably while iterating.
    let impls = match self.generic_impls.get(base_name) {
        Some(impls) => impls.clone(),
        None => return Ok(()),
    };
    for imp in &impls {
        for method in &imp.methods {
            if !self.check_method_trait_bounds(method, type_args, &imp.generics) {
                log::debug!("Skipping method {}.{}: trait bounds not satisfied for {:?}",
                    base_name, method.name, type_args);
                continue;
            }
            let mangled_name = crate::compiler::codegen::builtin_types::resolver::resolve_static_method_name(
                base_name, &method.name, type_args
            );
            if self.module.get_function(&mangled_name).is_some() {
                continue;
            }
            // Best effort: a method that cannot be generated is reported but
            // does not prevent the remaining methods from being generated.
            if let Err(e) = self.monomorphize_method(base_name, &method.name, type_args) {
                log::debug!("Could not generate method {}.{}: {}",
                    base_name, method.name, e);
            }
        }
    }
    Ok(())
}
/// Checks whether `type_args` satisfy the trait bounds declared on `method`.
///
/// If the method has inline generic bounds, only those are checked; otherwise
/// the predicates of its `where` clause (if any) are checked. A bound is
/// satisfied when the trait is registered for the type in `trait_registry`,
/// or when the type is one of the built-in names listed below and the trait
/// is `PartialEq`, `Clone`, or (except for `String` and `bool`) `PartialOrd`.
///
/// NOTE(review): when inline bounds are present, the `where` clause is not
/// consulted at all — confirm that the two never need to be combined.
fn check_method_trait_bounds(
    &self,
    method: &crate::compiler::ast::FunctionDef,
    type_args: &[Type],
    impl_generics: &[String],
) -> bool {
    // Single source of truth for "does `type_name` implement `trait_name`?".
    let bound_holds = |type_name: &String, trait_name: &String| -> bool {
        let registered = self
            .trait_registry
            .get(type_name)
            .map_or(false, |traits| traits.contains(trait_name));
        if registered {
            return true;
        }
        // Fallback: traits assumed for built-in types.
        let is_builtin = matches!(
            type_name.as_str(),
            "i64" | "f64" | "i32" | "f32" | "u8" | "char" | "bool" | "String"
        );
        is_builtin
            && match trait_name.as_str() {
                "PartialEq" | "Clone" => true,
                "PartialOrd" => !matches!(type_name.as_str(), "String" | "bool"),
                _ => false,
            }
    };

    if method.generic_bounds.is_empty() {
        // No inline bounds: fall back to the `where` clause, if present.
        return method
            .where_clause
            .iter()
            .flat_map(|wc| wc.predicates.iter())
            .all(|pred| {
                let concrete = self.resolve_generic_param(&pred.type_param, type_args, impl_generics);
                let type_key = self.concrete_type_to_trait_key(&concrete);
                pred.bounds
                    .iter()
                    .all(|bound| bound_holds(&type_key, &bound.trait_name))
            });
    }

    method.generic_bounds.iter().all(|(param_name, bounds)| {
        let concrete = self.resolve_generic_param(param_name, type_args, impl_generics);
        let type_key = self.concrete_type_to_trait_key(&concrete);
        bounds
            .iter()
            .all(|bound| bound_holds(&type_key, &bound.trait_name))
    })
}
/// Resolves the generic parameter `param` to its concrete type.
///
/// `impl_generics` and `type_args` are matched by position; the first
/// parameter named `param` that has a corresponding argument determines the
/// result. If there is none, `param` is returned as an argument-less
/// `Type::Struct` carrying the parameter's name.
pub(crate) fn resolve_generic_param(&self, param: &str, type_args: &[Type], impl_generics: &[String]) -> Type {
    impl_generics
        .iter()
        .zip(type_args)
        .find(|(generic_name, _)| generic_name.as_str() == param)
        .map(|(_, concrete)| concrete.clone())
        .unwrap_or_else(|| Type::Struct(param.to_string(), vec![]))
}
/// Returns the key under which `ty` is looked up in the trait registry.
///
/// Primitive types use their lowercase name (`String` stays capitalised),
/// structs and enums use their own name, specialized types use the base name
/// of their generic type, and every other type falls back to its `Debug`
/// rendering.
pub(crate) fn concrete_type_to_trait_key(&self, ty: &Type) -> String {
    let primitive_name = match ty {
        Type::I64 => "i64",
        Type::I32 => "i32",
        Type::F32 => "f32",
        Type::F64 => "f64",
        Type::Bool => "bool",
        Type::U8 => "u8",
        Type::Usize => "usize",
        Type::String(_) => "String",
        Type::Struct(name, _) | Type::Enum(name, _) => return name.clone(),
        Type::SpecializedType { gen_type, .. } => return gen_type.get_base_name(),
        _ => return format!("{:?}", ty),
    };
    primitive_name.to_string()
}
}