use std::collections::{HashMap, HashSet};
use crate::error::CompilerError;
use crate::ir::{IrExpr, IrFunction, IrModule, ResolvedType};
use crate::location::Span;
use super::expr_walk::{iter_expr_children_mut, walk_expr};
use super::specialise::{substitute_expr_types, substitute_type, type_suffix};
/// A requested specialisation of a generic free function: the function's name
/// and the concrete type arguments, listed in the order of its generic parameters.
type FunctionSpec = (String, Vec<ResolvedType>);
/// Monomorphises every generic free function in `module` that is called with
/// inferable concrete type arguments.
///
/// Initial call sites are gathered by [`collect_generic_fn_call_specs`]. Each
/// distinct `(name, type args)` spec is specialised once via
/// [`specialise_function`], and any generic calls found inside the new copy
/// are queued in turn. Afterwards, call paths across the module are retargeted
/// at the specialised copies.
///
/// # Errors
///
/// Returns every error produced while specialising, at most one per distinct
/// spec. Call-path rewriting still runs for the specialisations that succeeded.
pub(super) fn specialise_generic_functions(
    module: &mut IrModule,
) -> Result<(), Vec<CompilerError>> {
    let generic_fn_names: HashSet<String> = module
        .functions
        .iter()
        .filter(|f| !f.generic_params.is_empty())
        .map(|f| f.name.clone())
        .collect();
    if generic_fn_names.is_empty() {
        return Ok(());
    }

    let mut fn_mapping: HashMap<FunctionSpec, String> = HashMap::new();
    // Every spec that has been tried, whether it succeeded or not. Tracking
    // failures as well stops a spec requested from several call sites from
    // being retried and reporting the same error once per call site.
    let mut attempted: HashSet<FunctionSpec> = HashSet::new();
    let mut errors: Vec<CompilerError> = Vec::new();
    let mut worklist: Vec<FunctionSpec> = Vec::new();
    collect_generic_fn_call_specs(module, &generic_fn_names, &mut worklist);

    while let Some(spec) = worklist.pop() {
        // `insert` returns false when this spec was already handled.
        if !attempted.insert(spec.clone()) {
            continue;
        }
        match specialise_function(module, &spec.0, &spec.1) {
            Ok((mangled_name, discovered)) => {
                worklist.extend(discovered.into_iter().filter(|d| !attempted.contains(d)));
                fn_mapping.insert(spec, mangled_name);
            }
            Err(e) => errors.push(e),
        }
    }

    rewrite_function_call_paths(module, &fn_mapping, &generic_fn_names);
    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors)
    }
}
fn collect_generic_fn_call_specs(
module: &IrModule,
generic_fn_names: &HashSet<String>,
out: &mut Vec<FunctionSpec>,
) {
let mut visit = |expr: &IrExpr| {
if let IrExpr::FunctionCall { path, args, .. } = expr {
if let Some(name) = path.last() {
if generic_fn_names.contains(name) {
if let Some(func) = module.functions.iter().find(|f| f.name == *name) {
if let Some(type_args) = infer_call_type_args(func, args) {
out.push((name.clone(), type_args));
}
}
}
}
}
};
for f in &module.functions {
if let Some(body) = &f.body {
walk_expr(body, &mut visit);
}
}
for imp in &module.impls {
for f in &imp.functions {
if let Some(body) = &f.body {
walk_expr(body, &mut visit);
}
}
}
for l in &module.lets {
walk_expr(&l.value, &mut visit);
}
}
/// Infers the concrete type arguments of a call to the generic function `func`.
///
/// Each declared parameter type is unified with the type of its matching call
/// argument. The matching argument is the one labelled with the parameter's
/// name, or else the argument in the same position. A parameter with no
/// matching argument (for example an omitted argument that relies on a
/// default) contributes no bindings. It does not abort inference.
///
/// Returns the bindings in `func.generic_params` order. Returns `None` if any
/// generic parameter is left unbound, or is bound to a type that still
/// mentions a type parameter (the call site is itself inside generic code).
fn infer_call_type_args(
    func: &IrFunction,
    call_args: &[(Option<String>, IrExpr)],
) -> Option<Vec<ResolvedType>> {
    let mut subs: HashMap<String, ResolvedType> = HashMap::new();
    for (i, param) in func.params.iter().enumerate() {
        let Some(declared) = &param.ty else { continue };
        let arg_expr = call_args
            .iter()
            .find_map(|(n, e)| n.as_ref().filter(|name| **name == param.name).map(|_| e))
            .or_else(|| call_args.get(i).map(|(_, e)| e));
        // Missing argument: whether inference still succeeds is decided below,
        // once we know which generic parameters the other arguments bound.
        let Some(arg_expr) = arg_expr else { continue };
        unify_types(declared, arg_expr.ty(), &mut subs);
    }
    func.generic_params
        .iter()
        .map(|gp| subs.get(&gp.name).filter(|t| !contains_type_param(t)).cloned())
        .collect()
}
/// Structurally matches a declared parameter type `param` against the type of
/// a call argument `arg`. For each type parameter reached in `param`, it
/// records a binding in `subs`.
///
/// The first binding for a given name wins, and later (possibly conflicting)
/// bindings are ignored. Pairs whose shapes do not match contribute nothing.
fn unify_types(param: &ResolvedType, arg: &ResolvedType, subs: &mut HashMap<String, ResolvedType>) {
    match (param, arg) {
        (ResolvedType::TypeParam(name), concrete) => {
            subs.entry(name.clone()).or_insert_with(|| concrete.clone());
        }
        (ResolvedType::Array(p), ResolvedType::Array(a))
        | (ResolvedType::Range(p), ResolvedType::Range(a))
        | (ResolvedType::Optional(p), ResolvedType::Optional(a)) => {
            unify_types(p, a, subs);
        }
        (ResolvedType::Tuple(ps), ResolvedType::Tuple(as_)) => {
            for ((_, p), (_, a)) in ps.iter().zip(as_.iter()) {
                unify_types(p, a, subs);
            }
        }
        (
            ResolvedType::Dictionary {
                key_ty: pk,
                value_ty: pv,
            },
            ResolvedType::Dictionary {
                key_ty: ak,
                value_ty: av,
            },
        ) => {
            unify_types(pk, ak, subs);
            unify_types(pv, av, subs);
        }
        (
            ResolvedType::Closure {
                param_tys: pp,
                return_ty: pr,
            },
            ResolvedType::Closure {
                param_tys: ap,
                return_ty: ar,
            },
        ) => {
            for ((_, p), (_, a)) in pp.iter().zip(ap.iter()) {
                unify_types(p, a, subs);
            }
            unify_types(pr, ar, subs);
        }
        (
            ResolvedType::Generic { base: pb, args: pa },
            ResolvedType::Generic { base: ab, args: aa },
        ) if pb == ab => {
            for (p, a) in pa.iter().zip(aa.iter()) {
                unify_types(p, a, subs);
            }
        }
        (
            ResolvedType::External { type_args: pa, .. },
            ResolvedType::External { type_args: aa, .. },
        ) if pa.len() == aa.len() => {
            // `contains_type_param` looks inside external type arguments, so a
            // generic parameter that appears only there must be bindable here;
            // otherwise such calls could never be specialised.
            // NOTE(review): unlike the `Generic` arm, this does not compare the
            // external type's identity (its other fields are not visible here).
            // It relies on the type checker having matched the outer types;
            // confirm.
            for (p, a) in pa.iter().zip(aa.iter()) {
                unify_types(p, a, subs);
            }
        }
        _ => {}
    }
}
/// Reports whether `ty` still mentions a generic type parameter anywhere in
/// its structure, i.e. whether it is not yet fully concrete.
fn contains_type_param(ty: &ResolvedType) -> bool {
    match ty {
        // Leaves: only a bare type parameter counts.
        ResolvedType::TypeParam(_) => true,
        ResolvedType::Primitive(_)
        | ResolvedType::Struct(_)
        | ResolvedType::Trait(_)
        | ResolvedType::Enum(_)
        | ResolvedType::Error => false,
        // Single-element wrappers: look at the element type.
        ResolvedType::Optional(elem) | ResolvedType::Array(elem) | ResolvedType::Range(elem) => {
            contains_type_param(elem)
        }
        ResolvedType::Dictionary { key_ty, value_ty } => {
            contains_type_param(value_ty) || contains_type_param(key_ty)
        }
        ResolvedType::Tuple(members) => members
            .iter()
            .map(|(_, member)| member)
            .any(|member| contains_type_param(member)),
        ResolvedType::Closure {
            param_tys,
            return_ty,
        } => {
            contains_type_param(return_ty)
                || param_tys
                    .iter()
                    .any(|(_, param_ty)| contains_type_param(param_ty))
        }
        ResolvedType::Generic { args, .. } => args.iter().any(|arg| contains_type_param(arg)),
        ResolvedType::External { type_args, .. } => {
            type_args.iter().any(|arg| contains_type_param(arg))
        }
    }
}
/// Creates a monomorphic copy of the generic free function `name`, with its
/// generic parameters replaced positionally by `args`, and adds it to `module`
/// under a freshly mangled name.
///
/// Returns the mangled name together with the specs of generic calls found in
/// the specialised copy (its body and its parameter default expressions),
/// which the caller must specialise in turn.
///
/// # Errors
///
/// * `InternalError` if `module` has no function called `name`.
/// * `GenericArityMismatch` if `args.len()` differs from the function's
///   generic parameter count.
/// * Any error returned by `IrModule::add_function`.
#[expect(
    clippy::result_large_err,
    reason = "CompilerError is large by design; errors are bounded to a Vec<CompilerError> at the pass boundary"
)]
fn specialise_function(
    module: &mut IrModule,
    name: &str,
    args: &[ResolvedType],
) -> Result<(String, Vec<FunctionSpec>), CompilerError> {
    let Some(source) = module.functions.iter().find(|f| f.name == name).cloned() else {
        return Err(CompilerError::InternalError {
            detail: format!("monomorphise: missing generic function `{name}`"),
            span: Span::default(),
        });
    };
    if source.generic_params.len() != args.len() {
        return Err(CompilerError::GenericArityMismatch {
            name: source.name.clone(),
            expected: source.generic_params.len(),
            actual: args.len(),
            span: Span::default(),
        });
    }
    let subs: HashMap<String, ResolvedType> = source
        .generic_params
        .iter()
        .zip(args.iter())
        .map(|(p, a)| (p.name.clone(), a.clone()))
        .collect();
    let mangled = mangle_function_name(&source.name, args, module);

    // Build the specialised copy: new name, no generic params, and every type
    // in the signature, defaults and body substituted.
    let mut spec = source;
    spec.name.clone_from(&mangled);
    spec.generic_params.clear();
    for param in &mut spec.params {
        if let Some(t) = &mut param.ty {
            substitute_type(t, &subs);
        }
        if let Some(default) = &mut param.default {
            substitute_expr_types(default, &subs);
        }
    }
    if let Some(rt) = &mut spec.return_type {
        substitute_type(rt, &subs);
    }
    if let Some(body) = &mut spec.body {
        substitute_expr_types(body, &subs);
    }

    let generic_fn_names: HashSet<String> = module
        .functions
        .iter()
        .filter(|f| !f.generic_params.is_empty())
        .map(|f| f.name.clone())
        .collect();
    let mut discovered: Vec<FunctionSpec> = Vec::new();
    {
        let mut visit = |expr: &IrExpr| {
            if let IrExpr::FunctionCall { path, args: a, .. } = expr {
                if let Some(callee) = path.last() {
                    if generic_fn_names.contains(callee) {
                        if let Some(callee_fn) =
                            module.functions.iter().find(|f| f.name == *callee)
                        {
                            if let Some(type_args) = infer_call_type_args(callee_fn, a) {
                                discovered.push((callee.clone(), type_args));
                            }
                        }
                    }
                }
            }
        };
        if let Some(body) = &spec.body {
            walk_expr(body, &mut visit);
        }
        // Defaults were type-substituted above and are later retargeted by
        // `rewrite_function_call_paths`, so generic calls inside them need
        // specialisations just like calls in the body.
        for param in &spec.params {
            if let Some(default) = &param.default {
                walk_expr(default, &mut visit);
            }
        }
    }
    module.add_function(mangled.clone(), spec)?;
    Ok((mangled, discovered))
}
/// Builds a module-unique name for a specialisation of `base` with the type
/// arguments `args`, of the form `base__<suffix(arg0)>__<suffix(arg1)>...`.
///
/// If that name is already taken in `module`, the first free `#2`, `#3`, ...
/// variant is returned instead. The probe stops at `u32::MAX - 1` and returns
/// that candidate even if it is taken.
fn mangle_function_name(base: &str, args: &[ResolvedType], module: &IrModule) -> String {
    let mut stem = String::from(base);
    args.iter().for_each(|ty| {
        stem.push_str("__");
        type_suffix(ty, &mut stem);
    });
    if module.function_id(&stem).is_none() {
        return stem;
    }

    let mut suffix: u32 = 2;
    loop {
        let candidate = format!("{stem}#{suffix}");
        if module.function_id(&candidate).is_none() || suffix == u32::MAX - 1 {
            return candidate;
        }
        suffix += 1;
    }
}
fn rewrite_function_call_paths(
module: &mut IrModule,
fn_mapping: &HashMap<FunctionSpec, String>,
generic_fn_names: &HashSet<String>,
) {
let snapshot: Vec<IrFunction> = module.functions.clone();
for f in &mut module.functions {
if let Some(body) = &mut f.body {
rewrite_call_paths_expr(body, fn_mapping, generic_fn_names, &snapshot);
}
for param in &mut f.params {
if let Some(default) = &mut param.default {
rewrite_call_paths_expr(default, fn_mapping, generic_fn_names, &snapshot);
}
}
}
for imp in &mut module.impls {
for f in &mut imp.functions {
if let Some(body) = &mut f.body {
rewrite_call_paths_expr(body, fn_mapping, generic_fn_names, &snapshot);
}
for param in &mut f.params {
if let Some(default) = &mut param.default {
rewrite_call_paths_expr(default, fn_mapping, generic_fn_names, &snapshot);
}
}
}
}
for s in &mut module.structs {
for field in &mut s.fields {
if let Some(default) = &mut field.default {
rewrite_call_paths_expr(default, fn_mapping, generic_fn_names, &snapshot);
}
}
}
for e in &mut module.enums {
for variant in &mut e.variants {
for field in &mut variant.fields {
if let Some(default) = &mut field.default {
rewrite_call_paths_expr(default, fn_mapping, generic_fn_names, &snapshot);
}
}
}
}
for l in &mut module.lets {
rewrite_call_paths_expr(&mut l.value, fn_mapping, generic_fn_names, &snapshot);
}
}
fn rewrite_call_paths_expr(
expr: &mut IrExpr,
fn_mapping: &HashMap<FunctionSpec, String>,
generic_fn_names: &HashSet<String>,
snapshot: &[IrFunction],
) {
for child in iter_expr_children_mut(expr) {
rewrite_call_paths_expr(child, fn_mapping, generic_fn_names, snapshot);
}
if let IrExpr::FunctionCall {
path,
function_id: _,
args,
ty,
..
} = expr
{
let Some(last) = path.last() else { return };
if !generic_fn_names.contains(last) {
return;
}
let Some(callee) = snapshot.iter().find(|f| f.name == *last) else {
return;
};
let Some(type_args) = infer_call_type_args(callee, args) else {
return;
};
if let Some(specialised) = fn_mapping.get(&(last.clone(), type_args.clone())) {
if let Some(seg) = path.last_mut() {
seg.clone_from(specialised);
}
let subs: HashMap<String, ResolvedType> = callee
.generic_params
.iter()
.zip(type_args.iter())
.map(|(p, a)| (p.name.clone(), a.clone()))
.collect();
substitute_type(ty, &subs);
}
}
}