extern crate alloc;
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use crate::ast::*;
pub fn monomorphize(program: Program) -> Program {
let mut program = program;
let generics: BTreeMap<String, FunctionDef> = program
.functions
.iter()
.filter(|f| !f.type_params.is_empty())
.map(|f| (f.name.clone(), f.clone()))
.collect();
let mut fn_returns: BTreeMap<String, TypeExpr> = BTreeMap::new();
for f in &program.functions {
fn_returns.insert(f.name.clone(), f.return_type.clone());
}
for impl_block in &program.impls {
let head = type_head_for_impl(&impl_block.for_type);
for method in &impl_block.methods {
fn_returns.insert(
alloc::format!("{}::{}", head, method.name),
method.return_type.clone(),
);
}
}
let struct_table: BTreeMap<String, StructDef> = program
.types
.iter()
.filter_map(|td| match td {
TypeDef::Struct(s) => Some((s.name.clone(), s.clone())),
_ => None,
})
.collect();
let mut local_types: BTreeMap<String, TypeExpr> = BTreeMap::new();
let mut specs: BTreeMap<(String, String), String> = BTreeMap::new();
let mut new_functions: Vec<FunctionDef> = Vec::new();
{
use crate::visitor::MutVisitor;
for func in &mut program.functions {
if func.type_params.is_empty() {
local_types.clear();
for param in &func.params {
if let Some(t) = ¶m.type_expr
&& let Pattern::Variable(name, _) = ¶m.pattern
{
local_types.insert(name.clone(), t.clone());
}
}
let mut visitor = CallSpecializer {
generics: &generics,
locals: &mut local_types,
specs: &mut specs,
new_functions: &mut new_functions,
fn_returns: &fn_returns,
struct_table: &struct_table,
};
visitor.visit_block(&mut func.body);
}
}
}
const SPECIALIZATION_LIMIT: usize = 1024;
const PER_FUNCTION_LIMIT: usize = 64;
let mut per_fn_counts: BTreeMap<String, usize> = BTreeMap::new();
for (origin, _) in specs.keys() {
*per_fn_counts.entry(origin.clone()).or_insert(0) += 1;
}
let mut idx = 0;
while idx < new_functions.len() {
if new_functions.len() > SPECIALIZATION_LIMIT {
break;
}
let mut max_count = 0;
for &count in per_fn_counts.values() {
if count > max_count {
max_count = count;
}
}
if max_count > PER_FUNCTION_LIMIT {
break;
}
local_types.clear();
for param in &new_functions[idx].params {
if let Some(t) = ¶m.type_expr
&& let Pattern::Variable(name, _) = ¶m.pattern
{
local_types.insert(name.clone(), t.clone());
}
}
let len_before = new_functions.len();
let mut body_clone = new_functions[idx].body.clone();
{
use crate::visitor::MutVisitor;
let mut visitor = CallSpecializer {
generics: &generics,
locals: &mut local_types,
specs: &mut specs,
new_functions: &mut new_functions,
fn_returns: &fn_returns,
struct_table: &struct_table,
};
visitor.visit_block(&mut body_clone);
}
new_functions[idx].body = body_clone;
if new_functions.len() > len_before {
for new_fn in &new_functions[len_before..] {
let origin = new_fn
.name
.split("__")
.next()
.unwrap_or(&new_fn.name)
.to_string();
*per_fn_counts.entry(origin).or_insert(0) += 1;
}
}
idx += 1;
}
program.functions.extend(new_functions);
let specialized_origins: alloc::collections::BTreeSet<String> =
specs.keys().map(|(name, _)| name.clone()).collect();
program
.functions
.retain(|f| !specialized_origins.contains(&f.name));
program = specialize_structs(program, &fn_returns);
program = specialize_enums(program, &fn_returns);
program
}
/// Specializes constructions of generic enums inside every function body.
///
/// Each distinct instantiation discovered by `EnumSpecializer` becomes a new
/// non-generic `EnumDef` appended to `program.types`. The generic
/// definitions are left in place. Returns `program` unchanged when it
/// declares no generic enum.
fn specialize_enums(mut program: Program, fn_returns: &BTreeMap<String, TypeExpr>) -> Program {
    use crate::visitor::MutVisitor;

    let mut generic_enums: BTreeMap<String, EnumDef> = BTreeMap::new();
    for td in &program.types {
        match td {
            TypeDef::Enum(e) if !e.type_params.is_empty() => {
                generic_enums.insert(e.name.clone(), e.clone());
            }
            _ => {}
        }
    }
    if generic_enums.is_empty() {
        return program;
    }

    let mut enum_specs: BTreeMap<(String, String), String> = BTreeMap::new();
    let mut new_enums: Vec<EnumDef> = Vec::new();
    let mut locals: BTreeMap<String, TypeExpr> = BTreeMap::new();
    for func in program.functions.iter_mut() {
        // Seed the scope with annotated parameters bound to plain variables.
        locals.clear();
        locals.extend(func.params.iter().filter_map(|p| {
            match (&p.type_expr, &p.pattern) {
                (Some(ty), Pattern::Variable(var, _)) => Some((var.clone(), ty.clone())),
                _ => None,
            }
        }));
        let mut specializer = EnumSpecializer {
            generic_enums: &generic_enums,
            locals: &mut locals,
            specs: &mut enum_specs,
            new_enums: &mut new_enums,
            fn_returns,
        };
        specializer.visit_block(&mut func.body);
    }

    program
        .types
        .extend(new_enums.into_iter().map(TypeDef::Enum));
    program
}
/// Visitor that retargets constructions of generic enums to specialized copies.
///
/// The walk is bottom-up: children are rewritten first. Then each
/// `Expr::EnumVariant` naming a generic enum has its type arguments inferred
/// from the variant's payload. For every type parameter, the type comes from
/// the first payload position that is declared as exactly that parameter
/// and whose argument type can be inferred. If any parameter stays
/// unresolved, the expression is left unchanged. Each distinct instantiation
/// is built once and cached in `specs`.
struct EnumSpecializer<'a> {
    /// Generic enum definitions keyed by name.
    generic_enums: &'a BTreeMap<String, EnumDef>,
    /// Known types of local variables in the function being visited.
    locals: &'a mut BTreeMap<String, TypeExpr>,
    /// `(enum name, canonical type arguments)` -> specialized enum name.
    specs: &'a mut BTreeMap<(String, String), String>,
    /// Specialized enum definitions created by this pass.
    new_enums: &'a mut Vec<EnumDef>,
    /// Declared return types of functions and `Type::method` impl methods.
    fn_returns: &'a BTreeMap<String, TypeExpr>,
}
impl crate::visitor::MutVisitor for EnumSpecializer<'_> {
    /// Records the type of `let` bindings so later uses of the variable can
    /// be typed. The annotation wins; otherwise the type is inferred from
    /// the (already rewritten) initializer.
    fn visit_stmt(&mut self, stmt: &mut Stmt) {
        match stmt {
            Stmt::Let(binding) => {
                self.visit_expr(&mut binding.value);
                let Pattern::Variable(var, _) = &binding.pattern else {
                    return;
                };
                let ty = match &binding.type_expr {
                    Some(annotated) => Some(annotated.clone()),
                    None => infer_arg_type(&binding.value, self.locals, self.fn_returns, None),
                };
                if let Some(ty) = ty {
                    self.locals.insert(var.clone(), ty);
                }
            }
            other => self.walk_stmt(other),
        }
    }

    fn visit_expr(&mut self, expr: &mut Expr) {
        self.walk_expr(expr);
        let Expr::EnumVariant {
            enum_name,
            variant,
            args,
            ..
        } = expr
        else {
            return;
        };
        let Some(enum_def) = self.generic_enums.get(enum_name) else {
            return;
        };
        let Some(decl) = enum_def.variants.iter().find(|v| v.name == *variant) else {
            return;
        };

        let mut type_args: Vec<TypeExpr> = Vec::with_capacity(enum_def.type_params.len());
        for tp in &enum_def.type_params {
            let found = decl
                .fields
                .iter()
                .zip(args.iter())
                .find_map(|(decl_ty, arg)| match decl_ty {
                    TypeExpr::Named(n, _, _) if *n == tp.name => {
                        infer_arg_type(arg, self.locals, self.fn_returns, None)
                    }
                    _ => None,
                });
            match found {
                Some(ty) => type_args.push(ty),
                None => return,
            }
        }

        let canonical = type_args
            .iter()
            .map(type_arg_canonical)
            .collect::<Vec<String>>()
            .join(",");
        let cache_key = (enum_name.clone(), canonical);
        let spec_name = match self.specs.get(&cache_key) {
            Some(known) => known.clone(),
            None => {
                let fresh = mangle_struct(enum_name, &type_args);
                self.new_enums
                    .push(specialize_enum(enum_def, &type_args, fresh.clone()));
                self.specs.insert(cache_key, fresh.clone());
                fresh
            }
        };
        *enum_name = spec_name;
    }
}
/// Builds the non-generic copy of `enum_def` named `spec_name`.
///
/// Type parameters are paired positionally with `type_args`, and every
/// variant payload type is rewritten with that substitution. The result has
/// no type parameters and keeps the original spans.
fn specialize_enum(enum_def: &EnumDef, type_args: &[TypeExpr], spec_name: String) -> EnumDef {
    let subst: BTreeMap<String, TypeExpr> = enum_def
        .type_params
        .iter()
        .map(|tp| tp.name.clone())
        .zip(type_args.iter().cloned())
        .collect();

    let mut variants: Vec<VariantDecl> = Vec::with_capacity(enum_def.variants.len());
    for v in &enum_def.variants {
        let fields = v.fields.iter().map(|t| subst_type_expr(t, &subst)).collect();
        variants.push(VariantDecl {
            name: v.name.clone(),
            fields,
            span: v.span,
        });
    }

    EnumDef {
        name: spec_name,
        type_params: Vec::new(),
        variants,
        span: enum_def.span,
    }
}
/// Specializes literals of generic structs inside every function body.
///
/// Each distinct instantiation discovered by `StructSpecializer` becomes a
/// new non-generic `StructDef` appended to `program.types`. The generic
/// definitions are left in place. Returns `program` unchanged when it
/// declares no generic struct.
fn specialize_structs(mut program: Program, fn_returns: &BTreeMap<String, TypeExpr>) -> Program {
    use crate::visitor::MutVisitor;

    let mut generic_structs: BTreeMap<String, StructDef> = BTreeMap::new();
    for td in &program.types {
        match td {
            TypeDef::Struct(s) if !s.type_params.is_empty() => {
                generic_structs.insert(s.name.clone(), s.clone());
            }
            _ => {}
        }
    }
    if generic_structs.is_empty() {
        return program;
    }

    let mut struct_specs: BTreeMap<(String, String), String> = BTreeMap::new();
    let mut new_structs: Vec<StructDef> = Vec::new();
    let mut locals: BTreeMap<String, TypeExpr> = BTreeMap::new();
    for func in program.functions.iter_mut() {
        // Seed the scope with annotated parameters bound to plain variables.
        locals.clear();
        locals.extend(func.params.iter().filter_map(|p| {
            match (&p.type_expr, &p.pattern) {
                (Some(ty), Pattern::Variable(var, _)) => Some((var.clone(), ty.clone())),
                _ => None,
            }
        }));
        let mut specializer = StructSpecializer {
            generic_structs: &generic_structs,
            locals: &mut locals,
            specs: &mut struct_specs,
            new_structs: &mut new_structs,
            fn_returns,
        };
        specializer.visit_block(&mut func.body);
    }

    program
        .types
        .extend(new_structs.into_iter().map(TypeDef::Struct));
    program
}
/// Visitor that retargets literals of generic structs to specialized copies.
///
/// The walk is bottom-up: children are rewritten first. Then each
/// `Expr::StructInit` naming a generic struct has its type arguments
/// inferred from its field initializers. For every type parameter, the type
/// comes from the first declared field of exactly that type whose
/// initializer type can be inferred. If any parameter stays unresolved, the
/// literal is left unchanged. Each distinct instantiation is built once and
/// cached in `specs`.
struct StructSpecializer<'a> {
    /// Generic struct definitions keyed by name.
    generic_structs: &'a BTreeMap<String, StructDef>,
    /// Known types of local variables in the function being visited.
    locals: &'a mut BTreeMap<String, TypeExpr>,
    /// `(struct name, canonical type arguments)` -> specialized struct name.
    specs: &'a mut BTreeMap<(String, String), String>,
    /// Specialized struct definitions created by this pass.
    new_structs: &'a mut Vec<StructDef>,
    /// Declared return types of functions and `Type::method` impl methods.
    fn_returns: &'a BTreeMap<String, TypeExpr>,
}
impl crate::visitor::MutVisitor for StructSpecializer<'_> {
    /// Records the type of `let` bindings so later uses of the variable can
    /// be typed. The annotation wins; otherwise the type is inferred from
    /// the (already rewritten) initializer.
    fn visit_stmt(&mut self, stmt: &mut Stmt) {
        match stmt {
            Stmt::Let(binding) => {
                self.visit_expr(&mut binding.value);
                let Pattern::Variable(var, _) = &binding.pattern else {
                    return;
                };
                let ty = match &binding.type_expr {
                    Some(annotated) => Some(annotated.clone()),
                    None => infer_arg_type(&binding.value, self.locals, self.fn_returns, None),
                };
                if let Some(ty) = ty {
                    self.locals.insert(var.clone(), ty);
                }
            }
            other => self.walk_stmt(other),
        }
    }

    fn visit_expr(&mut self, expr: &mut Expr) {
        self.walk_expr(expr);
        let Expr::StructInit { name, fields, .. } = expr else {
            return;
        };
        let Some(struct_def) = self.generic_structs.get(name) else {
            return;
        };

        let mut type_args: Vec<TypeExpr> = Vec::with_capacity(struct_def.type_params.len());
        for tp in &struct_def.type_params {
            let found = struct_def.fields.iter().find_map(|decl_field| {
                let TypeExpr::Named(n, _, _) = &decl_field.type_expr else {
                    return None;
                };
                if *n != tp.name {
                    return None;
                }
                let init = fields.iter().find(|f| f.name == decl_field.name)?;
                infer_arg_type(&init.value, self.locals, self.fn_returns, None)
            });
            match found {
                Some(ty) => type_args.push(ty),
                None => return,
            }
        }

        let canonical = type_args
            .iter()
            .map(type_arg_canonical)
            .collect::<Vec<String>>()
            .join(",");
        let cache_key = (name.clone(), canonical);
        let spec_name = match self.specs.get(&cache_key) {
            Some(known) => known.clone(),
            None => {
                let fresh = mangle_struct(name, &type_args);
                self.new_structs
                    .push(specialize_struct(struct_def, &type_args, fresh.clone()));
                self.specs.insert(cache_key, fresh.clone());
                fresh
            }
        };
        *name = spec_name;
    }
}
fn mangle_struct(name: &str, type_args: &[TypeExpr]) -> String {
let mut s = name.to_string();
for arg in type_args {
s.push_str("__");
s.push_str(&type_arg_canonical(arg));
}
s
}
/// Builds the non-generic copy of `struct_def` named `spec_name`.
///
/// Type parameters are paired positionally with `type_args`, and every field
/// type is rewritten with that substitution. The result has no type
/// parameters and keeps the original spans.
fn specialize_struct(
    struct_def: &StructDef,
    type_args: &[TypeExpr],
    spec_name: String,
) -> StructDef {
    let subst: BTreeMap<String, TypeExpr> = struct_def
        .type_params
        .iter()
        .map(|tp| tp.name.clone())
        .zip(type_args.iter().cloned())
        .collect();

    let mut fields: Vec<FieldDecl> = Vec::with_capacity(struct_def.fields.len());
    for field in &struct_def.fields {
        fields.push(FieldDecl {
            name: field.name.clone(),
            type_expr: subst_type_expr(&field.type_expr, &subst),
            span: field.span,
        });
    }

    StructDef {
        name: spec_name,
        type_params: Vec::new(),
        fields,
        span: struct_def.span,
    }
}
/// Mangled name of a function specialization: `name` followed by
/// `__<canonical type>` for each type argument (e.g. `id` with `[i64]`
/// becomes `id__i64`).
fn mangle(name: &str, type_args: &[TypeExpr]) -> String {
    type_args.iter().fold(String::from(name), |mut acc, arg| {
        acc.push_str("__");
        acc.push_str(&type_arg_canonical(arg));
        acc
    })
}
/// Canonical textual form of a type, used both in specialization cache keys
/// and in mangled names.
///
/// Primitives map to their keyword (`i64`, `f64`, `bool`, `String`), unit
/// maps to `unit`, and composite types join their parts with `_`
/// (`Name_args`, `tuple_items`, `arr_elem_len`, `opt_inner`).
///
/// NOTE(review): the encoding is not injective. For example, a 1-tuple
/// holding a 2-tuple `((a, b),)` and the 2-tuple `((a,), b)` both yield
/// `tuple_tuple_a_b`. Likewise, `Pair<i64>` collides with a type literally
/// named `Pair_i64`.
fn type_arg_canonical(t: &TypeExpr) -> String {
    // Canonical forms of `items`, joined with `_`.
    fn joined(items: &[TypeExpr]) -> String {
        items
            .iter()
            .map(type_arg_canonical)
            .collect::<Vec<String>>()
            .join("_")
    }

    match t {
        TypeExpr::Prim(PrimType::I64, _) => "i64".to_string(),
        TypeExpr::Prim(PrimType::F64, _) => "f64".to_string(),
        TypeExpr::Prim(PrimType::Bool, _) => "bool".to_string(),
        TypeExpr::Prim(PrimType::KString, _) => "String".to_string(),
        TypeExpr::Unit(_) => "unit".to_string(),
        TypeExpr::Named(n, args, _) if args.is_empty() => n.clone(),
        TypeExpr::Named(n, args, _) => format!("{}_{}", n, joined(args)),
        TypeExpr::Tuple(items, _) => format!("tuple_{}", joined(items)),
        TypeExpr::Array(elem, n, _) => format!("arr_{}_{}", type_arg_canonical(elem), n),
        TypeExpr::Option(inner, _) => format!("opt_{}", type_arg_canonical(inner)),
    }
}
/// Best-effort, purely syntactic type inference for `expr`.
///
/// * `locals` maps variable names to their known types.
/// * `fn_returns` maps function names, and impl methods keyed as
///   `Type::method`, to their declared return types.
/// * `structs`, when provided, enables typing of field accesses. With
///   `None`, every `FieldAccess` yields `None`.
///
/// Returns `None` whenever the type cannot be determined. Struct literals and
/// enum variants are typed as `Named(name)` without type arguments. For
/// `if`/`match` and arithmetic, the first branch/operand whose type can be
/// inferred decides the result.
fn infer_arg_type(
    expr: &Expr,
    locals: &BTreeMap<String, TypeExpr>,
    fn_returns: &BTreeMap<String, TypeExpr>,
    structs: Option<&BTreeMap<String, StructDef>>,
) -> Option<TypeExpr> {
    match expr {
        Expr::Literal { value, span } => Some(match value {
            Literal::Int(_) => TypeExpr::Prim(PrimType::I64, *span),
            Literal::Float(_) => TypeExpr::Prim(PrimType::F64, *span),
            Literal::Bool(_) => TypeExpr::Prim(PrimType::Bool, *span),
            Literal::String(_) => TypeExpr::Prim(PrimType::KString, *span),
            Literal::Unit => TypeExpr::Unit(*span),
        }),
        Expr::Ident { name, .. } => locals.get(name).cloned(),
        Expr::StructInit { name, span, .. } => {
            Some(TypeExpr::Named(name.clone(), Vec::new(), *span))
        }
        Expr::EnumVariant {
            enum_name, span, ..
        } => Some(TypeExpr::Named(enum_name.clone(), Vec::new(), *span)),
        Expr::Call { name, .. } => fn_returns.get(name).cloned(),
        Expr::Cast { target, .. } => Some(target.clone()),
        Expr::TupleLiteral { elements, span } => {
            let parts: Option<Vec<TypeExpr>> = elements
                .iter()
                .map(|e| infer_arg_type(e, locals, fn_returns, structs))
                .collect();
            parts.map(|p| TypeExpr::Tuple(p, *span))
        }
        Expr::ArrayLiteral { elements, span } => {
            // The first element decides the element type; empty literals are untyped.
            let elem = elements.first()?;
            let elem_ty = infer_arg_type(elem, locals, fn_returns, structs)?;
            Some(TypeExpr::Array(
                Box::new(elem_ty),
                elements.len() as i64,
                *span,
            ))
        }
        Expr::If {
            then_block,
            else_block,
            ..
        } => {
            let tail = then_block.tail_expr.as_ref()?;
            infer_arg_type(tail, locals, fn_returns, structs).or_else(|| {
                else_block
                    .as_ref()
                    .and_then(|b| b.tail_expr.as_ref())
                    .and_then(|e| infer_arg_type(e, locals, fn_returns, structs))
            })
        }
        // All arms share one type, so the first arm that can be typed
        // decides it (previously an untypable first arm gave up).
        Expr::Match { arms, .. } => arms
            .iter()
            .find_map(|arm| infer_arg_type(&arm.expr, locals, fn_returns, structs)),
        Expr::TupleIndex { object, index, .. } => {
            let obj_ty = infer_arg_type(object, locals, fn_returns, structs)?;
            if let TypeExpr::Tuple(elements, _) = obj_ty {
                let idx = *index as usize;
                elements.get(idx).cloned()
            } else {
                None
            }
        }
        Expr::ArrayIndex { object, .. } => {
            let obj_ty = infer_arg_type(object, locals, fn_returns, structs)?;
            if let TypeExpr::Array(elem, _, _) = obj_ty {
                Some(*elem)
            } else {
                None
            }
        }
        Expr::FieldAccess { object, field, .. } => {
            let structs = structs?;
            let obj_ty = infer_arg_type(object, locals, fn_returns, Some(structs))?;
            let (struct_name, type_args) = match obj_ty {
                TypeExpr::Named(name, args, _) => (name, args),
                _ => return None,
            };
            let struct_def = structs.get(&struct_name)?;
            let field_decl = struct_def.fields.iter().find(|f| f.name == *field)?;
            if struct_def.type_params.len() == type_args.len() && !type_args.is_empty() {
                // Instantiated generic struct: substitute its type arguments.
                let subst: BTreeMap<String, TypeExpr> = struct_def
                    .type_params
                    .iter()
                    .zip(type_args.iter())
                    .map(|(tp, arg)| (tp.name.clone(), arg.clone()))
                    .collect();
                Some(subst_type_expr(&field_decl.type_expr, &subst))
            } else {
                // A bare type parameter cannot be resolved without type arguments.
                if let TypeExpr::Named(field_name, field_args, _) = &field_decl.type_expr
                    && field_args.is_empty()
                    && struct_def
                        .type_params
                        .iter()
                        .any(|tp| tp.name == *field_name)
                {
                    return None;
                }
                Some(field_decl.type_expr.clone())
            }
        }
        Expr::UnaryOp { op, operand, .. } => match op {
            UnaryOp::Neg => infer_arg_type(operand, locals, fn_returns, structs),
            UnaryOp::Not => Some(TypeExpr::Prim(PrimType::Bool, operand.span())),
        },
        Expr::BinOp {
            op,
            left,
            right,
            span,
        } => match op {
            // Both operands share the result type; fall back to the right
            // operand when the left one is untypable (e.g. `x + 1`).
            BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => {
                infer_arg_type(left, locals, fn_returns, structs)
                    .or_else(|| infer_arg_type(right, locals, fn_returns, structs))
            }
            BinOp::Eq
            | BinOp::NotEq
            | BinOp::Lt
            | BinOp::Gt
            | BinOp::LtEq
            | BinOp::GtEq
            | BinOp::And
            | BinOp::Or => Some(TypeExpr::Prim(PrimType::Bool, *span)),
        },
        Expr::MethodCall {
            receiver, method, ..
        } => {
            let recv_ty = infer_arg_type(receiver, locals, fn_returns, structs)?;
            let head = type_head_for_impl(&recv_ty);
            let key = alloc::format!("{}::{}", head, method);
            fn_returns.get(&key).cloned()
        }
        _ => None,
    }
}
fn type_head_for_impl(ty: &TypeExpr) -> String {
use alloc::string::ToString;
match ty {
TypeExpr::Prim(p, _) => match p {
PrimType::I64 => "i64".to_string(),
PrimType::F64 => "f64".to_string(),
PrimType::Bool => "bool".to_string(),
PrimType::KString => "String".to_string(),
},
TypeExpr::Unit(_) => "()".to_string(),
TypeExpr::Named(name, _, _) => name.clone(),
TypeExpr::Tuple(_, _) => "tuple".to_string(),
TypeExpr::Array(_, _, _) => "array".to_string(),
TypeExpr::Option(_, _) => "Option".to_string(),
}
}
/// Replaces type parameters in `t` according to `subst`, recursively.
///
/// Only a bare `Named(name)` with no type arguments is treated as a
/// parameter reference. A named type with arguments keeps its name and has
/// its arguments substituted. Primitives and unit are returned unchanged.
fn subst_type_expr(t: &TypeExpr, subst: &BTreeMap<String, TypeExpr>) -> TypeExpr {
    match t {
        TypeExpr::Named(name, args, span) => match subst.get(name) {
            Some(replacement) if args.is_empty() => replacement.clone(),
            _ => {
                let new_args = args.iter().map(|a| subst_type_expr(a, subst)).collect();
                TypeExpr::Named(name.clone(), new_args, *span)
            }
        },
        TypeExpr::Tuple(items, span) => {
            let new_items = items.iter().map(|item| subst_type_expr(item, subst)).collect();
            TypeExpr::Tuple(new_items, *span)
        }
        TypeExpr::Array(elem, n, span) => {
            let new_elem = subst_type_expr(elem, subst);
            TypeExpr::Array(Box::new(new_elem), *n, *span)
        }
        TypeExpr::Option(inner, span) => {
            let new_inner = subst_type_expr(inner, subst);
            TypeExpr::Option(Box::new(new_inner), *span)
        }
        TypeExpr::Prim(..) | TypeExpr::Unit(..) => t.clone(),
    }
}
/// Builds the non-generic copy of `func` named `spec_name`.
///
/// Type parameters are paired positionally with `type_args`. The
/// substitution is applied to parameter annotations, the return type, and
/// every type position inside the body (via `subst_in_block`). Category and
/// spans are kept.
///
/// NOTE(review): `guard` is cloned without substitution. Confirm guards
/// cannot mention type parameters.
fn specialize_function(
    func: &FunctionDef,
    type_args: &[TypeExpr],
    spec_name: String,
) -> FunctionDef {
    let subst: BTreeMap<String, TypeExpr> = func
        .type_params
        .iter()
        .zip(type_args)
        .map(|(tp, arg)| (tp.name.clone(), arg.clone()))
        .collect();

    let params = func
        .params
        .iter()
        .map(|p| Param {
            pattern: p.pattern.clone(),
            type_expr: p.type_expr.as_ref().map(|t| subst_type_expr(t, &subst)),
            span: p.span,
        })
        .collect();

    FunctionDef {
        category: func.category,
        name: spec_name,
        type_params: Vec::new(),
        params,
        return_type: subst_type_expr(&func.return_type, &subst),
        guard: func.guard.clone(),
        body: subst_in_block(&func.body, &subst),
        span: func.span,
    }
}
fn subst_in_block(block: &Block, subst: &BTreeMap<String, TypeExpr>) -> Block {
Block {
stmts: block
.stmts
.iter()
.map(|s| subst_in_stmt(s, subst))
.collect(),
tail_expr: block
.tail_expr
.as_ref()
.map(|e| Box::new(subst_in_expr(e, subst))),
span: block.span,
}
}
/// Returns a copy of `stmt` with type parameters substituted in `let`
/// annotations and in all nested expressions and blocks.
fn subst_in_stmt(stmt: &Stmt, subst: &BTreeMap<String, TypeExpr>) -> Stmt {
    match stmt {
        Stmt::Expr(e) => Stmt::Expr(subst_in_expr(e, subst)),
        Stmt::Break(span) => Stmt::Break(*span),
        Stmt::Let(l) => {
            let type_expr = l.type_expr.as_ref().map(|t| subst_type_expr(t, subst));
            let value = subst_in_expr(&l.value, subst);
            Stmt::Let(LetStmt {
                pattern: l.pattern.clone(),
                type_expr,
                value,
                span: l.span,
            })
        }
        Stmt::For(f) => {
            let iterable = subst_in_iterable(&f.iterable, subst);
            let body = subst_in_block(&f.body, subst);
            Stmt::For(ForStmt {
                var: f.var.clone(),
                iterable,
                body,
                span: f.span,
            })
        }
        Stmt::DataFieldAssign {
            data_name,
            field,
            value,
            span,
        } => {
            let value = subst_in_expr(value, subst);
            Stmt::DataFieldAssign {
                data_name: data_name.clone(),
                field: field.clone(),
                value,
                span: *span,
            }
        }
    }
}
/// Returns a copy of a `for` loop's iterable with type parameters
/// substituted in its expression(s).
fn subst_in_iterable(it: &Iterable, subst: &BTreeMap<String, TypeExpr>) -> Iterable {
    match it {
        Iterable::Expr(e) => Iterable::Expr(subst_in_expr(e, subst)),
        Iterable::Range(start, end) => {
            let lo = Box::new(subst_in_expr(start, subst));
            let hi = Box::new(subst_in_expr(end, subst));
            Iterable::Range(lo, hi)
        }
    }
}
/// Returns a deep copy of `expr` with the type parameters in `subst` replaced.
///
/// Types are rewritten (via `subst_type_expr`) only in `Cast` targets and in
/// closure parameter / return annotations. Every other variant is rebuilt
/// field by field, recursing into sub-expressions and blocks. Match-arm
/// patterns, names, and `ClosureRef` captures are cloned as-is.
fn subst_in_expr(expr: &Expr, subst: &BTreeMap<String, TypeExpr>) -> Expr {
match expr {
// Leaves: no sub-expressions and no type positions.
Expr::Literal { value, span } => Expr::Literal {
value: value.clone(),
span: *span,
},
Expr::Ident { name, span } => Expr::Ident {
name: name.clone(),
span: *span,
},
Expr::BinOp {
op,
left,
right,
span,
} => Expr::BinOp {
op: *op,
left: Box::new(subst_in_expr(left, subst)),
right: Box::new(subst_in_expr(right, subst)),
span: *span,
},
Expr::UnaryOp { op, operand, span } => Expr::UnaryOp {
op: *op,
operand: Box::new(subst_in_expr(operand, subst)),
span: *span,
},
// The callee name is kept; call renaming is done separately by CallSpecializer.
Expr::Call { name, args, span } => Expr::Call {
name: name.clone(),
args: args.iter().map(|a| subst_in_expr(a, subst)).collect(),
span: *span,
},
Expr::Pipeline {
left,
func,
args,
span,
} => Expr::Pipeline {
left: Box::new(subst_in_expr(left, subst)),
func: func.clone(),
args: args.iter().map(|a| subst_in_expr(a, subst)).collect(),
span: *span,
},
Expr::Yield { value, span } => Expr::Yield {
value: Box::new(subst_in_expr(value, subst)),
span: *span,
},
Expr::If {
condition,
then_block,
else_block,
span,
} => Expr::If {
condition: Box::new(subst_in_expr(condition, subst)),
then_block: subst_in_block(then_block, subst),
else_block: else_block.as_ref().map(|b| subst_in_block(b, subst)),
span: *span,
},
// Arm patterns are cloned unchanged; only arm bodies are rewritten.
Expr::Match {
scrutinee,
arms,
span,
} => Expr::Match {
scrutinee: Box::new(subst_in_expr(scrutinee, subst)),
arms: arms
.iter()
.map(|a| MatchArm {
pattern: a.pattern.clone(),
expr: subst_in_expr(&a.expr, subst),
span: a.span,
})
.collect(),
span: *span,
},
Expr::Loop { body, span } => Expr::Loop {
body: subst_in_block(body, subst),
span: *span,
},
Expr::FieldAccess {
object,
field,
span,
} => Expr::FieldAccess {
object: Box::new(subst_in_expr(object, subst)),
field: field.clone(),
span: *span,
},
Expr::MethodCall {
receiver,
method,
args,
span,
} => Expr::MethodCall {
receiver: Box::new(subst_in_expr(receiver, subst)),
method: method.clone(),
args: args.iter().map(|a| subst_in_expr(a, subst)).collect(),
span: *span,
},
Expr::TupleIndex {
object,
index,
span,
} => Expr::TupleIndex {
object: Box::new(subst_in_expr(object, subst)),
index: *index,
span: *span,
},
Expr::ArrayIndex {
object,
index,
span,
} => Expr::ArrayIndex {
object: Box::new(subst_in_expr(object, subst)),
index: Box::new(subst_in_expr(index, subst)),
span: *span,
},
// The struct name is kept; struct specialization happens in a later pass.
Expr::StructInit { name, fields, span } => Expr::StructInit {
name: name.clone(),
fields: fields
.iter()
.map(|f| FieldInit {
name: f.name.clone(),
value: subst_in_expr(&f.value, subst),
span: f.span,
})
.collect(),
span: *span,
},
Expr::EnumVariant {
enum_name,
variant,
args,
span,
} => Expr::EnumVariant {
enum_name: enum_name.clone(),
variant: variant.clone(),
args: args.iter().map(|a| subst_in_expr(a, subst)).collect(),
span: *span,
},
Expr::ArrayLiteral { elements, span } => Expr::ArrayLiteral {
elements: elements.iter().map(|e| subst_in_expr(e, subst)).collect(),
span: *span,
},
Expr::TupleLiteral { elements, span } => Expr::TupleLiteral {
elements: elements.iter().map(|e| subst_in_expr(e, subst)).collect(),
span: *span,
},
// Type position: the cast target is substituted.
Expr::Cast { expr, target, span } => Expr::Cast {
expr: Box::new(subst_in_expr(expr, subst)),
target: subst_type_expr(target, subst),
span: *span,
},
Expr::Placeholder { span } => Expr::Placeholder { span: *span },
// Type positions: parameter annotations and the optional return type.
Expr::Closure {
params,
return_type,
body,
span,
} => Expr::Closure {
params: params
.iter()
.map(|p| Param {
pattern: p.pattern.clone(),
type_expr: p.type_expr.as_ref().map(|t| subst_type_expr(t, subst)),
span: p.span,
})
.collect(),
return_type: return_type.as_ref().map(|t| subst_type_expr(t, subst)),
body: subst_in_block(body, subst),
span: *span,
},
Expr::ClosureRef {
name,
captures,
recursive,
span,
} => Expr::ClosureRef {
name: name.clone(),
captures: captures.clone(),
recursive: *recursive,
span: *span,
},
}
}
/// Visitor that rewrites calls to generic functions into calls to
/// specialized, non-generic copies.
///
/// The walk is bottom-up: arguments are rewritten before the call that
/// contains them. For each `Expr::Call` naming a generic function, every
/// type parameter is inferred from the first parameter declared as exactly
/// that type parameter whose argument type can be inferred. If any type
/// parameter stays unresolved, the call is left untouched. Each distinct
/// instantiation is created once (via `specialize_function`), cached in
/// `specs`, and appended to `new_functions`.
struct CallSpecializer<'a> {
    /// Generic function templates keyed by name.
    generics: &'a BTreeMap<String, FunctionDef>,
    /// Known types of local variables in the function being visited.
    locals: &'a mut BTreeMap<String, TypeExpr>,
    /// `(generic name, canonical type arguments)` -> specialized name.
    specs: &'a mut BTreeMap<(String, String), String>,
    /// Specializations created so far; new ones are appended here.
    new_functions: &'a mut Vec<FunctionDef>,
    /// Declared return types of functions and `Type::method` impl methods.
    fn_returns: &'a BTreeMap<String, TypeExpr>,
    /// Struct definitions, used to type field accesses.
    struct_table: &'a BTreeMap<String, StructDef>,
}

impl CallSpecializer<'_> {
    /// Infers the type of `expr`.
    ///
    /// A call that this pass has already renamed to a specialization is
    /// usually absent from `fn_returns`, so its type is taken from the
    /// specialization's (substituted) return type. Everything else is
    /// delegated to `infer_arg_type`. Specialized calls nested deeper inside
    /// other expressions still rely on `fn_returns`.
    fn infer(&self, expr: &Expr) -> Option<TypeExpr> {
        if let Expr::Call { name, .. } = expr
            && let Some(spec) = self.new_functions.iter().find(|f| f.name == *name)
        {
            return Some(spec.return_type.clone());
        }
        infer_arg_type(expr, self.locals, self.fn_returns, Some(self.struct_table))
    }
}

impl crate::visitor::MutVisitor for CallSpecializer<'_> {
    /// Records the type of `let` bindings so that later calls can infer type
    /// arguments from the variable. The annotation wins; otherwise the type
    /// is inferred from the (already rewritten) initializer.
    fn visit_stmt(&mut self, stmt: &mut Stmt) {
        if let Stmt::Let(l) = stmt {
            self.visit_expr(&mut l.value);
            if let Pattern::Variable(name, _) = &l.pattern
                && let Some(t) = l.type_expr.clone().or_else(|| self.infer(&l.value))
            {
                self.locals.insert(name.clone(), t);
            }
            return;
        }
        self.walk_stmt(stmt);
    }

    fn visit_expr(&mut self, expr: &mut Expr) {
        self.walk_expr(expr);
        let Expr::Call { name, args, .. } = expr else {
            return;
        };
        let Some(generic_func) = self.generics.get(name) else {
            return;
        };
        let mut type_args: Vec<TypeExpr> = Vec::with_capacity(generic_func.type_params.len());
        for tp in &generic_func.type_params {
            let mut inferred: Option<TypeExpr> = None;
            for (param_idx, param) in generic_func.params.iter().enumerate() {
                if let Some(TypeExpr::Named(n, _, _)) = &param.type_expr
                    && *n == tp.name
                    && let Some(arg) = args.get(param_idx)
                    && let Some(t) = self.infer(arg)
                {
                    inferred = Some(t);
                    break;
                }
            }
            match inferred {
                Some(t) => type_args.push(t),
                None => return,
            }
        }
        let key_args: Vec<String> = type_args.iter().map(type_arg_canonical).collect();
        let canonical = key_args.join(",");
        let cache_key = (name.clone(), canonical);
        let spec_name = if let Some(existing) = self.specs.get(&cache_key) {
            existing.clone()
        } else {
            let spec_name = mangle(name, &type_args);
            let specialized = specialize_function(generic_func, &type_args, spec_name.clone());
            self.specs.insert(cache_key, spec_name.clone());
            self.new_functions.push(specialized);
            spec_name
        };
        *name = spec_name;
    }
}