use crate::compiler::ast::{Type, mangle_wrap_args};
use std::collections::HashSet;
/// Records which generic specializations (a base name plus concrete type
/// arguments) have been requested, deduplicating them by mangled name.
///
/// Newly registered specializations are queued until
/// [`drain_pending`](SpecializationRegistry::drain_pending) is called. The set
/// of mangled names is *not* cleared by draining, so a specialization is queued
/// at most once over the registry's lifetime.
#[derive(Default)]
pub struct SpecializationRegistry {
    /// Specializations registered since the last `drain_pending`, in
    /// registration order.
    specializations: Vec<(String, Vec<Type>)>,
    /// Mangled names of every specialization ever registered; never cleared.
    mangled_names: HashSet<String>,
}

impl SpecializationRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `base_name` specialized with `type_args`.
    ///
    /// Returns `true` if no specialization with the same mangled name (see
    /// [`mangle_type_name`]) was registered before. In that case it is queued
    /// for `drain_pending`. Returns `false` otherwise, including when the earlier
    /// registration has already been drained.
    pub fn register(&mut self, base_name: &str, type_args: &[Type]) -> bool {
        // `HashSet::insert` returns false for an existing key, so the
        // membership test and the insertion share a single hash lookup.
        if !self
            .mangled_names
            .insert(mangle_type_name(base_name, type_args))
        {
            return false;
        }
        self.specializations
            .push((base_name.to_string(), type_args.to_vec()));
        true
    }

    /// Returns `true` if this specialization has ever been registered,
    /// whether or not it has since been drained.
    pub fn contains(&self, base_name: &str, type_args: &[Type]) -> bool {
        self.mangled_names
            .contains(&mangle_type_name(base_name, type_args))
    }

    /// Removes and returns every queued specialization in registration order.
    ///
    /// The deduplication set is kept, so draining does not allow the same
    /// specialization to be registered again.
    pub fn drain_pending(&mut self) -> Vec<(String, Vec<Type>)> {
        std::mem::take(&mut self.specializations)
    }

    /// Number of distinct specializations ever registered. Draining does not
    /// reduce this count.
    pub fn len(&self) -> usize {
        self.mangled_names.len()
    }

    /// Returns `true` if no specialization has ever been registered.
    ///
    /// Always equal to `self.len() == 0`. Previously this checked only the
    /// pending queue, which disagreed with `len` after a drain. Use
    /// [`has_pending`](Self::has_pending) to ask whether anything is queued.
    pub fn is_empty(&self) -> bool {
        self.mangled_names.is_empty()
    }

    /// Returns `true` if specializations are queued that `drain_pending` has
    /// not yet returned.
    pub fn has_pending(&self) -> bool {
        !self.specializations.is_empty()
    }
}
/// Returns the spelling of `ty` used as an argument inside a mangled name.
///
/// - Primitive and leaf types have fixed spellings (`i64`, `bool`, `String`, ...).
/// - Nominal types (`Struct`, `Enum`, `Path`) go through [`mangle_type_name`].
///   `Path` segments are first joined with `_`.
/// - `Tensor`, `Tuple` and `Array` are wrapped with `mangle_wrap_args`. The
///   element suffix is followed by the rank or size where applicable.
///
/// Any variant not handled here yields `"unknown"`.
pub fn type_to_suffix(ty: &Type) -> String {
    // Stage 1: leaf types with a fixed spelling.
    let leaf_spelling = match ty {
        Type::I64 => Some("i64"),
        Type::I32 => Some("i32"),
        Type::U8 => Some("u8"),
        Type::F32 => Some("f32"),
        Type::F64 => Some("f64"),
        Type::Bool => Some("bool"),
        Type::Usize => Some("usize"),
        Type::Void => Some("void"),
        Type::String(_) => Some("String"),
        Type::Char(_) => Some("Char"),
        _ => None,
    };
    if let Some(spelling) = leaf_spelling {
        return spelling.to_owned();
    }

    // Stage 2: composite types, built recursively from their components.
    // `mangle_type_name` already returns the bare name when there are no
    // type arguments, so nominal types need no special empty-args case.
    match ty {
        Type::Struct(name, args) => mangle_type_name(name, args),
        Type::Enum(name, args) => mangle_type_name(name, args),
        Type::Path(segments, args) => mangle_type_name(&segments.join("_"), args),
        Type::Tensor(elem, rank) => {
            let parts = vec![type_to_suffix(elem), rank.to_string()];
            mangle_wrap_args("Tensor", &parts)
        }
        Type::Tuple(members) => {
            let parts: Vec<String> = members.iter().map(type_to_suffix).collect();
            mangle_wrap_args("Tuple", &parts)
        }
        Type::Array(elem, len) => {
            let parts = vec![type_to_suffix(elem), len.to_string()];
            mangle_wrap_args("Array", &parts)
        }
        _ => "unknown".to_string(),
    }
}
/// Builds the mangled name for `base_name` specialized with `type_args`.
///
/// With no type arguments the base name is returned unchanged. Otherwise each
/// argument is converted with [`type_to_suffix`] and the suffixes are attached
/// via `mangle_wrap_args`. For example, `mangle_type_name("Vec", &[Type::I64])`
/// is `"Vec[i64]"`.
pub fn mangle_type_name(base_name: &str, type_args: &[Type]) -> String {
    match type_args {
        [] => base_name.to_owned(),
        args => {
            let suffixes = args.iter().map(type_to_suffix).collect::<Vec<String>>();
            mangle_wrap_args(base_name, &suffixes)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_register_new() {
        let mut registry = SpecializationRegistry::new();
        assert!(registry.register("Vec", &[Type::I64]));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_duplicate() {
        let mut registry = SpecializationRegistry::new();
        assert!(registry.register("Vec", &[Type::I64]));
        assert!(!registry.register("Vec", &[Type::I64]));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_contains() {
        let mut registry = SpecializationRegistry::default();
        assert!(!registry.contains("Vec", &[Type::I64]));
        assert!(registry.register("Vec", &[Type::I64]));
        assert!(registry.contains("Vec", &[Type::I64]));
        assert!(!registry.contains("Vec", &[Type::I32]));
    }

    #[test]
    fn test_drain_pending_preserves_order_and_dedup() {
        let mut registry = SpecializationRegistry::new();
        assert!(registry.register("Vec", &[Type::I64]));
        assert!(registry.register("Option", &[Type::Bool]));

        let drained = registry.drain_pending();
        assert_eq!(drained.len(), 2);
        assert_eq!(drained[0].0, "Vec");
        assert_eq!(drained[1].0, "Option");
        assert!(registry.drain_pending().is_empty());

        // Draining must not forget what was already registered.
        assert!(!registry.register("Vec", &[Type::I64]));
        assert!(registry.drain_pending().is_empty());
        assert!(registry.contains("Vec", &[Type::I64]));
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_mangle_type_name() {
        assert_eq!(mangle_type_name("Vec", &[Type::I64]), "Vec[i64]");
        assert_eq!(
            mangle_type_name("HashMap", &[Type::I64, Type::I64]),
            "HashMap[i64][i64]"
        );
        assert_eq!(mangle_type_name("Option", &[]), "Option");
    }

    #[test]
    fn test_primitive_suffixes() {
        assert_eq!(type_to_suffix(&Type::Bool), "bool");
        assert_eq!(type_to_suffix(&Type::Usize), "usize");
        assert_eq!(type_to_suffix(&Type::Void), "void");
    }
}