use std::collections::HashMap;
use std::sync::Arc;
use crate::ast::{Expr, FnBody, FnDef, Literal, MatchArm, Pattern, Spanned, Stmt, TailCallData};
/// The two sink shapes this pass recognises, distinguished by what the base
/// case of the recursive builder returns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BufferBuildKind {
/// Boolean match whose `true` arm returns `List.reverse(acc)`; fusion sites
/// call the sink directly inside `String.join`.
InternalReverse,
/// Empty-list / cons match whose empty-list arm returns the bare accumulator;
/// fusion sites wrap the sink call in `List.reverse(...)` themselves.
ExternalReverse,
}
/// Description of a function that was recognised as a buffer-build sink.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferBuildShape {
/// Position of the accumulator parameter in the function's parameter list.
pub acc_param_idx: usize,
/// Name of the accumulator parameter.
pub acc_param_name: String,
/// Which of the recognised base-case shapes the function uses.
pub kind: BufferBuildKind,
}
/// The consuming call that wraps a sink call at a fusion site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsumerKind {
/// The sink's result is passed as the first argument of `String.join`.
StringJoin,
}
/// One expression that matches a consumer-over-sink pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusionSite {
/// Name of the function whose body contains the matching expression.
pub enclosing_fn: String,
/// Line recorded for the matching expression (an enclosing expression's line
/// is used when the expression itself carries line 0).
pub line: usize,
/// Name of the sink function being called at the site.
pub sink_fn: String,
/// The consumer that wraps the sink call.
pub consumer: ConsumerKind,
}
/// Returns, keyed by function name, the shape of every function in `fns`
/// that `match_buffer_build_shape` recognises as a buffer-build sink.
///
/// If two functions share a name, the later one in `fns` wins.
pub fn compute_buffer_build_sinks(fns: &[&FnDef]) -> HashMap<String, BufferBuildShape> {
    fns.iter()
        .filter_map(|fd| match_buffer_build_shape(fd).map(|shape| (fd.name.clone(), shape)))
        .collect()
}
/// Walks every statement of every function in `fns` and collects the
/// expressions that consume a call to one of `sinks` in a fusable way.
///
/// Sites are returned in traversal order: function by function, statement by
/// statement, outer expressions before the expressions nested inside them.
pub fn find_fusion_sites(
    fns: &[&FnDef],
    sinks: &HashMap<String, BufferBuildShape>,
) -> Vec<FusionSite> {
    let mut sites = Vec::new();
    for fd in fns {
        for stmt in fd.body.stmts() {
            // Both statement kinds carry exactly one root expression.
            let (Stmt::Binding(_, _, root) | Stmt::Expr(root)) = stmt;
            walk_expr_for_fusion_sites(&root.node, root.line, &fd.name, sinks, &mut sites);
        }
    }
    sites
}
/// Records `expr` as a fusion site if it matches the `String.join` pattern,
/// then keeps searching inside its sub-expressions.
///
/// `expr_line` is the line reported for a match at this node; it is also
/// handed down as the fallback line for children.
fn walk_expr_for_fusion_sites(
    expr: &Expr,
    expr_line: usize,
    enclosing_fn: &str,
    sinks: &HashMap<String, BufferBuildShape>,
    out: &mut Vec<FusionSite>,
) {
    let site = match_string_join_fusion_site(expr, sinks).map(|sink_fn| FusionSite {
        enclosing_fn: enclosing_fn.to_string(),
        line: expr_line,
        sink_fn,
        consumer: ConsumerKind::StringJoin,
    });
    // `extend` with an `Option` pushes zero or one element.
    out.extend(site);
    visit_subexprs(expr, expr_line, enclosing_fn, sinks, out);
}
fn visit_subexprs(
expr: &Expr,
fallback_line: usize,
enclosing_fn: &str,
sinks: &HashMap<String, BufferBuildShape>,
out: &mut Vec<FusionSite>,
) {
let line_of = |s: &crate::ast::Spanned<Expr>| {
if s.line > 0 { s.line } else { fallback_line }
};
match expr {
Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
walk_expr_for_fusion_sites(&inner.node, line_of(inner), enclosing_fn, sinks, out);
}
Expr::FnCall(callee, args) => {
walk_expr_for_fusion_sites(&callee.node, line_of(callee), enclosing_fn, sinks, out);
for a in args {
walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
}
}
Expr::TailCall(data) => {
for a in &data.args {
walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
}
}
Expr::BinOp(_, l, r) => {
walk_expr_for_fusion_sites(&l.node, line_of(l), enclosing_fn, sinks, out);
walk_expr_for_fusion_sites(&r.node, line_of(r), enclosing_fn, sinks, out);
}
Expr::Match { subject, arms } => {
walk_expr_for_fusion_sites(&subject.node, line_of(subject), enclosing_fn, sinks, out);
for arm in arms {
walk_expr_for_fusion_sites(
&arm.body.node,
line_of(&arm.body),
enclosing_fn,
sinks,
out,
);
}
}
Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
for it in items {
walk_expr_for_fusion_sites(&it.node, line_of(it), enclosing_fn, sinks, out);
}
}
Expr::MapLiteral(entries) => {
for (k, v) in entries {
walk_expr_for_fusion_sites(&k.node, line_of(k), enclosing_fn, sinks, out);
walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
}
}
Expr::RecordCreate { fields, .. } => {
for (_, v) in fields {
walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
}
}
Expr::RecordUpdate { base, updates, .. } => {
walk_expr_for_fusion_sites(&base.node, line_of(base), enclosing_fn, sinks, out);
for (_, v) in updates {
walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
}
}
Expr::InterpolatedStr(parts) => {
for part in parts {
if let crate::ast::StrPart::Parsed(inner) = part {
walk_expr_for_fusion_sites(
&inner.node,
line_of(inner),
enclosing_fn,
sinks,
out,
);
}
}
}
}
}
fn match_buffer_build_shape(fd: &FnDef) -> Option<BufferBuildShape> {
let (acc_idx, acc_name) = fd
.params
.iter()
.enumerate()
.rfind(|(_, (_, ty))| is_list_type_str(ty))
.map(|(i, (name, _))| (i, name.clone()))?;
let match_expr = single_match_body(&fd.body)?;
let (subject_expr, arms) = match match_expr {
Expr::Match { subject, arms } => (subject, arms),
_ => return None,
};
if let Some((true_body, false_body)) = pair_bool_arms(arms) {
let _ = subject_expr;
if is_list_reverse_of(true_body, &acc_name)
&& is_self_tail_with_prepend_acc(false_body, &fd.name, acc_idx, &acc_name)
{
return Some(BufferBuildShape {
acc_param_idx: acc_idx,
acc_param_name: acc_name,
kind: BufferBuildKind::InternalReverse,
});
}
}
if let Some((nil_body, cons_body)) = pair_nil_cons_arms(arms)
&& is_ident_named(nil_body, &acc_name)
&& is_self_tail_with_prepend_acc(cons_body, &fd.name, acc_idx, &acc_name)
{
return Some(BufferBuildShape {
acc_param_idx: acc_idx,
acc_param_name: acc_name,
kind: BufferBuildKind::ExternalReverse,
});
}
None
}
/// If `arms` is exactly one empty-list arm and one cons arm (in either
/// order), returns `(empty_list_body, cons_body)`; otherwise `None`.
fn pair_nil_cons_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
    let [first, second] = arms else {
        return None;
    };
    match (&first.pattern, &second.pattern) {
        (Pattern::EmptyList, Pattern::Cons(_, _)) => Some((&first.body.node, &second.body.node)),
        (Pattern::Cons(_, _), Pattern::EmptyList) => Some((&second.body.node, &first.body.node)),
        // Any other pattern, or the same pattern twice, is not this shape.
        _ => None,
    }
}
/// True when `expr` is a bare identifier spelled exactly `name`.
fn is_ident_named(expr: &Expr, name: &str) -> bool {
    if let Expr::Ident(found) = expr {
        found == name
    } else {
        false
    }
}
/// Recognises `String.join(<sink>(..., [], ...), sep)` — optionally with the
/// sink call wrapped in `List.reverse(...)` — and returns the sink's name.
///
/// A match requires all of:
/// * `expr` is a two-argument call to `String.join`;
/// * its first argument, after peeling an optional one-argument
///   `List.reverse(...)`, is a call whose callee is a plain identifier that is
///   a key of `sinks`;
/// * the presence of the `List.reverse` wrapper agrees with the sink's kind
///   (wrapper present for `ExternalReverse`, absent for `InternalReverse`);
/// * the argument in the sink's accumulator position is an empty list literal.
fn match_string_join_fusion_site(
    expr: &Expr,
    sinks: &HashMap<String, BufferBuildShape>,
) -> Option<String> {
    let (join_callee, join_args) = match expr {
        Expr::FnCall(callee, args) if args.len() == 2 => (callee, args),
        _ => return None,
    };
    if !is_dotted_ident(&join_callee.node, "String", "join") {
        return None;
    }

    // Look through `List.reverse(<call>)` if it is there.
    let list_arg = &join_args[0].node;
    let unwrapped = match list_arg {
        Expr::FnCall(rev_callee, rev_args)
            if rev_args.len() == 1 && is_dotted_ident(&rev_callee.node, "List", "reverse") =>
        {
            Some(&rev_args[0].node)
        }
        _ => None,
    };
    let expected_kind = if unwrapped.is_some() {
        BufferBuildKind::ExternalReverse
    } else {
        BufferBuildKind::InternalReverse
    };
    let sink_call = unwrapped.unwrap_or(list_arg);

    let Expr::FnCall(sink_callee, sink_args) = sink_call else {
        return None;
    };
    let Expr::Ident(sink_name) = &sink_callee.node else {
        return None;
    };
    let shape = sinks.get(sink_name)?;
    if shape.kind != expected_kind {
        return None;
    }

    // Only calls that seed the accumulator with `[]` are fusable.
    let seed = sink_args.get(shape.acc_param_idx)?;
    match &seed.node {
        Expr::List(items) if items.is_empty() => Some(sink_name.clone()),
        _ => None,
    }
}
/// True when the type annotation `ty`, ignoring surrounding whitespace,
/// starts with `List<` and ends with `>`.
///
/// This is a purely textual check; the text between the brackets is not
/// inspected.
fn is_list_type_str(ty: &str) -> bool {
    matches!(ty.trim().strip_prefix("List<"), Some(rest) if rest.ends_with('>'))
}
/// Returns the body's only statement when that statement is a bare `match`
/// expression; `None` for any other body (several statements, a binding, or a
/// non-match expression).
fn single_match_body(body: &FnBody) -> Option<&Expr> {
    let stmts = body.stmts();
    let sole = match (stmts.len(), stmts.first()) {
        (1, Some(Stmt::Expr(spanned))) => spanned,
        _ => return None,
    };
    matches!(sole.node, Expr::Match { .. }).then_some(&sole.node)
}
/// If `arms` is exactly one `true` arm and one `false` arm (in either order),
/// returns `(true_body, false_body)`; otherwise `None`.
fn pair_bool_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
    let [first, second] = arms else {
        return None;
    };
    match (&first.pattern, &second.pattern) {
        (Pattern::Literal(Literal::Bool(true)), Pattern::Literal(Literal::Bool(false))) => {
            Some((&first.body.node, &second.body.node))
        }
        (Pattern::Literal(Literal::Bool(false)), Pattern::Literal(Literal::Bool(true))) => {
            Some((&second.body.node, &first.body.node))
        }
        // Non-boolean patterns and duplicated literals are both rejected.
        _ => None,
    }
}
/// True when `expr` is exactly `List.reverse(<acc_name>)`.
fn is_list_reverse_of(expr: &Expr, acc_name: &str) -> bool {
    let Expr::FnCall(callee, args) = expr else {
        return false;
    };
    // The length check short-circuits before the index below.
    is_dotted_ident(&callee.node, "List", "reverse")
        && args.len() == 1
        && matches!(&args[0].node, Expr::Ident(name) if name == acc_name)
}
/// True when `expr` is a tail call to `self_name` whose argument at position
/// `acc_idx` is `List.prepend(_, <acc_name>)`.
///
/// The other arguments of the tail call are not inspected.
fn is_self_tail_with_prepend_acc(
    expr: &Expr,
    self_name: &str,
    acc_idx: usize,
    acc_name: &str,
) -> bool {
    let Expr::TailCall(data) = expr else {
        return false;
    };
    data.target == self_name
        && data
            .args
            .get(acc_idx)
            .is_some_and(|arg| is_list_prepend_to_acc(&arg.node, acc_name))
}
/// True when `expr` is `List.prepend(<anything>, <acc_name>)`, i.e. a
/// two-argument `List.prepend` call whose second argument is the accumulator
/// identifier.
fn is_list_prepend_to_acc(expr: &Expr, acc_name: &str) -> bool {
    let Expr::FnCall(callee, args) = expr else {
        return false;
    };
    // The length check short-circuits before the index below.
    is_dotted_ident(&callee.node, "List", "prepend")
        && args.len() == 2
        && matches!(&args[1].node, Expr::Ident(name) if name == acc_name)
}
/// True when `expr` is the attribute access `<module>.<member>` where the
/// base is a plain identifier (e.g. `List.reverse`).
fn is_dotted_ident(expr: &Expr, module: &str, member: &str) -> bool {
    match expr {
        Expr::Attr(base, attr) if attr == member => {
            matches!(&base.node, Expr::Ident(name) if name == module)
        }
        _ => false,
    }
}
/// Builds the `__buffered` variant of every function in `fns` that is listed
/// in `sinks`, in the order the functions appear in `fns`.
///
/// Functions that are not in `sinks`, or for which `build_buffered_variant`
/// returns `None`, contribute nothing.
pub fn synthesize_buffered_variants(
    fns: &[&FnDef],
    sinks: &HashMap<String, BufferBuildShape>,
) -> Vec<FnDef> {
    fns.iter()
        .filter_map(|fd| {
            sinks
                .get(&fd.name)
                .and_then(|shape| build_buffered_variant(fd, shape))
        })
        .collect()
}
/// Wraps `expr` in a `Spanned` node positioned at `line`.
fn sp_at(line: usize, expr: Expr) -> Spanned<Expr> {
Spanned::new(expr, line)
}
/// Wraps `expr` in a `Spanned` node positioned at `line` and records `ty` on
/// it through `set_ty`.
fn sp_at_typed(line: usize, expr: Expr, ty: crate::types::Type) -> Spanned<Expr> {
    let typed_node = Spanned::new(expr, line);
    typed_node.set_ty(ty);
    typed_node
}
/// Builds the call expression `name(args...)` at `line`, with the callee
/// represented as a plain identifier. No type is recorded on the result.
fn intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
    let callee = Box::new(sp_at(line, Expr::Ident(name.to_string())));
    sp_at(line, Expr::FnCall(callee, args))
}
/// Builds the call `name(args...)` at `line` and records its type as the
/// named type `Buffer`.
fn buffer_intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
    let buffer_call = intrinsic_call(line, name, args);
    buffer_call.set_ty(crate::types::Type::Named("Buffer".to_string()));
    buffer_call
}
/// Builds the call `__buf_finalize(args...)` at `line` and records its type
/// as `Str`.
fn finalize_intrinsic_call(line: usize, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
    let finalize = intrinsic_call(line, "__buf_finalize", args);
    finalize.set_ty(crate::types::Type::Str);
    finalize
}
/// Runs the whole buffer-build fusion pass over `items`.
///
/// Steps:
/// 1. detect sink functions among the top-level function definitions;
/// 2. find the fusion sites that call them;
/// 3. synthesize a `__buffered` variant for each sink that has at least one
///    fusion site;
/// 4. rewrite the fusion sites in the original functions;
/// 5. append the synthesized variants to `items` (they are not themselves
///    rewritten).
///
/// Returns a report of what was found. When there are no sinks or no fusion
/// sites, `items` is left completely untouched and the default (empty) report
/// is returned.
pub fn run_buffer_build_pass(items: &mut Vec<crate::ast::TopLevel>) -> BufferBuildPassReport {
    // Analysis phase: everything that only needs shared access to `items`.
    let (sites, sinks, synthesized) = {
        let fn_refs: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        let all_sinks = compute_buffer_build_sinks(&fn_refs);
        if all_sinks.is_empty() {
            return BufferBuildPassReport::default();
        }
        let sites = find_fusion_sites(&fn_refs, &all_sinks);
        if sites.is_empty() {
            // Nothing to fuse: skip the mutable traversal entirely so that
            // `Arc::make_mut` never has to un-share (copy) a function body.
            return BufferBuildPassReport::default();
        }
        // Only sinks that are actually consumed somewhere get a variant.
        let used_sinks: HashMap<String, BufferBuildShape> = sites
            .iter()
            .filter_map(|site| {
                all_sinks
                    .get(&site.sink_fn)
                    .map(|shape| (site.sink_fn.clone(), shape.clone()))
            })
            .collect();
        let synthesized = synthesize_buffered_variants(&fn_refs, &used_sinks);
        (sites, used_sinks, synthesized)
    };

    // Mutation phase: rewrite the call sites in the existing functions.
    for item in items.iter_mut() {
        if let crate::ast::TopLevel::FnDef(fd) = item {
            rewrite_one_fn(fd, &sinks);
        }
    }

    // Record the names first, then move the variants into `items` without
    // cloning them.
    let synthesized_fns: Vec<String> = synthesized.iter().map(|fd| fd.name.clone()).collect();
    items.extend(synthesized.into_iter().map(crate::ast::TopLevel::FnDef));

    let mut sink_fns: Vec<String> = sinks.keys().cloned().collect();
    sink_fns.sort();

    let mut rewrites_by_sink: std::collections::BTreeMap<String, usize> =
        std::collections::BTreeMap::new();
    for site in &sites {
        *rewrites_by_sink.entry(site.sink_fn.clone()).or_default() += 1;
    }

    BufferBuildPassReport {
        rewrites: sites.len(),
        synthesized: synthesized_fns,
        sink_fns,
        rewrites_by_sink,
    }
}
/// Summary returned by `run_buffer_build_pass`.
#[derive(Debug, Clone, Default)]
pub struct BufferBuildPassReport {
/// Number of fusion sites found by the pass.
pub rewrites: usize,
/// Names of the synthesized `__buffered` functions appended to the program.
pub synthesized: Vec<String>,
/// Sorted names of the sinks that have at least one fusion site.
pub sink_fns: Vec<String>,
/// Number of fusion sites per sink name.
pub rewrites_by_sink: std::collections::BTreeMap<String, usize>,
}
/// Rewrites every fusion site inside the body of `fd` in place.
///
/// The body is obtained through `Arc::make_mut`, so a body that is shared
/// with another `Arc` handle is copied before being modified.
fn rewrite_one_fn(fd: &mut FnDef, sinks: &HashMap<String, BufferBuildShape>) {
    let FnBody::Block(stmts) = std::sync::Arc::make_mut(&mut fd.body);
    for stmt in stmts {
        // Both statement kinds carry exactly one root expression.
        let (Stmt::Binding(_, _, root) | Stmt::Expr(root)) = stmt;
        rewrite_expr_in_place(root, sinks);
    }
}
/// Rewrites every fusion site in every function of `fn_defs` in place.
///
/// Does nothing at all (no body is touched) when `sinks` is empty. Otherwise
/// each body is obtained through `Arc::make_mut`, so shared bodies are copied
/// before being modified.
pub fn rewrite_fusion_sites(fn_defs: &mut [FnDef], sinks: &HashMap<String, BufferBuildShape>) {
    if sinks.is_empty() {
        return;
    }
    fn_defs.iter_mut().for_each(|fd| {
        let FnBody::Block(stmts) = std::sync::Arc::make_mut(&mut fd.body);
        stmts.iter_mut().for_each(|stmt| {
            let (Stmt::Binding(_, _, root) | Stmt::Expr(root)) = stmt;
            rewrite_expr_in_place(root, sinks);
        });
    });
}
/// Replaces `expr` with its fused form when it is a fusion site, then
/// continues rewriting inside whatever expression is now in place.
///
/// Descending after a replacement matters: the replacement carries copies of
/// the original call's arguments, which may contain further fusion sites.
fn rewrite_expr_in_place(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
    if let Some(fused) = try_rewrite_fusion_site(expr, sinks) {
        *expr = fused;
    }
    descend_into_subexprs(expr, sinks);
}
/// Applies `rewrite_expr_in_place` to every direct child expression of
/// `expr`.
///
/// Mirrors the read-only traversal used to find fusion sites: tail calls
/// contribute only their arguments, match expressions their subject and arm
/// bodies, interpolated strings only their parsed parts.
fn descend_into_subexprs(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
    match &mut expr.node {
        // Leaves: nothing nested to rewrite.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}

        // Single-child nodes.
        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
            rewrite_expr_in_place(inner, sinks);
        }

        // Two fixed children.
        Expr::BinOp(_, lhs, rhs) => {
            rewrite_expr_in_place(lhs, sinks);
            rewrite_expr_in_place(rhs, sinks);
        }

        // Calls.
        Expr::FnCall(callee, args) => {
            rewrite_expr_in_place(callee, sinks);
            args.iter_mut()
                .for_each(|arg| rewrite_expr_in_place(arg, sinks));
        }
        Expr::TailCall(data) => {
            data.args
                .iter_mut()
                .for_each(|arg| rewrite_expr_in_place(arg, sinks));
        }

        // Sequences of expressions.
        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
            items
                .iter_mut()
                .for_each(|item| rewrite_expr_in_place(item, sinks));
        }
        Expr::MapLiteral(entries) => {
            for (key, value) in entries.iter_mut() {
                rewrite_expr_in_place(key, sinks);
                rewrite_expr_in_place(value, sinks);
            }
        }

        // Records.
        Expr::RecordCreate { fields, .. } => {
            for (_, value) in fields.iter_mut() {
                rewrite_expr_in_place(value, sinks);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            rewrite_expr_in_place(base, sinks);
            for (_, value) in updates.iter_mut() {
                rewrite_expr_in_place(value, sinks);
            }
        }

        // Control flow.
        Expr::Match { subject, arms } => {
            rewrite_expr_in_place(subject, sinks);
            for arm in arms.iter_mut() {
                rewrite_expr_in_place(&mut arm.body, sinks);
            }
        }

        // Strings with embedded expressions.
        Expr::InterpolatedStr(parts) => {
            for part in parts.iter_mut() {
                if let crate::ast::StrPart::Parsed(inner) = part {
                    rewrite_expr_in_place(inner, sinks);
                }
            }
        }
    }
}
fn try_rewrite_fusion_site(
expr: &Spanned<Expr>,
sinks: &HashMap<String, BufferBuildShape>,
) -> Option<Spanned<Expr>> {
let line = expr.line;
let sink_name = match_string_join_fusion_site(&expr.node, sinks)?;
let shape = sinks.get(&sink_name)?;
let outer_args = match &expr.node {
Expr::FnCall(_, a) => a,
_ => return None,
};
let consumer_arg = &outer_args[0].node;
let inner_call_expr = if let Expr::FnCall(rev_callee, rev_args) = consumer_arg
&& is_dotted_ident(&rev_callee.node, "List", "reverse")
&& rev_args.len() == 1
{
&rev_args[0].node
} else {
consumer_arg
};
let inner_args = match inner_call_expr {
Expr::FnCall(_, a) => a,
_ => return None,
};
let sep_expr = outer_args[1].clone();
let buf_new = buffer_intrinsic_call(
line,
"__buf_new",
vec![sp_at_typed(
line,
Expr::Literal(Literal::Int(8192)),
crate::types::Type::Int,
)],
);
let mut buffered_args: Vec<Spanned<Expr>> = inner_args
.iter()
.enumerate()
.filter_map(|(i, a)| (i != shape.acc_param_idx).then_some(a).cloned())
.collect();
buffered_args.push(buf_new);
buffered_args.push(sep_expr);
let buffered_call = sp_at_typed(
line,
Expr::FnCall(
Box::new(sp_at(line, Expr::Ident(format!("{}__buffered", sink_name)))),
buffered_args,
),
crate::types::Type::Str,
);
Some(finalize_intrinsic_call(line, vec![buffered_call]))
}
fn build_buffered_variant(fd: &FnDef, shape: &BufferBuildShape) -> Option<FnDef> {
let stmts = fd.body.stmts();
if stmts.len() != 1 {
return None;
}
let outer_expr = match &stmts[0] {
Stmt::Expr(spanned) => spanned,
_ => return None,
};
let (subject_orig, arms_orig) = match &outer_expr.node {
Expr::Match { subject, arms } => (subject, arms),
_ => return None,
};
let recursive_body: &Spanned<Expr> = match shape.kind {
BufferBuildKind::InternalReverse => arms_orig
.iter()
.find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
.map(|a| a.body.as_ref())?,
BufferBuildKind::ExternalReverse => arms_orig
.iter()
.find(|a| matches!(a.pattern, Pattern::Cons(_, _)))
.map(|a| a.body.as_ref())?,
};
let tail_data = match &recursive_body.node {
Expr::TailCall(data) => data,
_ => return None,
};
let acc_arg_orig = tail_data.args.get(shape.acc_param_idx)?;
let elem_expr = match &acc_arg_orig.node {
Expr::FnCall(callee, args) => {
if !is_dotted_ident(&callee.node, "List", "prepend") {
return None;
}
if args.len() != 2 {
return None;
}
match &args[1].node {
Expr::Ident(name) if name == &shape.acc_param_name => {}
_ => return None,
}
args[0].clone()
}
_ => return None,
};
let line = fd.line;
let buf_name = "__buf";
let sep_name = "__sep";
let buffered_target = format!("{}__buffered", fd.name);
let buffer_ty = crate::types::Type::Named("Buffer".to_string());
let buf_ident = || sp_at_typed(line, Expr::Ident(buf_name.to_string()), buffer_ty.clone());
let sep_ident = || {
sp_at_typed(
line,
Expr::Ident(sep_name.to_string()),
crate::types::Type::Str,
)
};
let sep_then_buf = buffer_intrinsic_call(
line,
"__buf_append_sep_unless_first",
vec![buf_ident(), sep_ident()],
);
let final_buf = buffer_intrinsic_call(line, "__buf_append", vec![sep_then_buf, elem_expr]);
let mut new_args: Vec<Spanned<Expr>> = tail_data
.args
.iter()
.enumerate()
.map(|(i, a)| {
if i == shape.acc_param_idx {
final_buf.clone()
} else {
a.clone()
}
})
.collect();
new_args.push(sep_ident());
let new_recursive_body = sp_at_typed(
line,
Expr::TailCall(Box::new(TailCallData {
target: buffered_target.clone(),
args: new_args,
})),
buffer_ty.clone(),
);
let new_arms = match shape.kind {
BufferBuildKind::InternalReverse => vec![
MatchArm {
pattern: Pattern::Literal(Literal::Bool(true)),
body: Box::new(buf_ident()),
binding_slots: std::sync::OnceLock::new(),
},
MatchArm {
pattern: Pattern::Literal(Literal::Bool(false)),
body: Box::new(new_recursive_body),
binding_slots: std::sync::OnceLock::new(),
},
],
BufferBuildKind::ExternalReverse => {
let cons_pat = arms_orig
.iter()
.find_map(|a| match &a.pattern {
Pattern::Cons(h, t) => Some(Pattern::Cons(h.clone(), t.clone())),
_ => None,
})
.unwrap_or(Pattern::Cons("__head".to_string(), "__tail".to_string()));
vec![
MatchArm {
pattern: Pattern::EmptyList,
body: Box::new(buf_ident()),
binding_slots: std::sync::OnceLock::new(),
},
MatchArm {
pattern: cons_pat,
body: Box::new(new_recursive_body),
binding_slots: std::sync::OnceLock::new(),
},
]
}
};
let new_match = sp_at_typed(
line,
Expr::Match {
subject: subject_orig.clone(),
arms: new_arms,
},
crate::types::Type::Named("Buffer".to_string()),
);
let new_body = FnBody::Block(vec![Stmt::Expr(new_match)]);
let mut new_params: Vec<(String, String)> = fd
.params
.iter()
.enumerate()
.filter_map(|(i, p)| (i != shape.acc_param_idx).then_some(p).cloned())
.collect();
new_params.push((buf_name.to_string(), "Buffer".to_string()));
new_params.push((sep_name.to_string(), "String".to_string()));
Some(FnDef {
name: buffered_target,
line,
params: new_params,
return_type: "Buffer".to_string(),
effects: fd.effects.clone(),
desc: Some(format!(
"Synthesized buffered variant of `{}` for deforestation \
lowering. Call sites that match `String.join({}(...), sep)` \
are rewritten to alloc a buffer + call this variant + \
finalize, skipping the intermediate List.",
fd.name, fd.name
)),
body: Arc::new(new_body),
resolution: None,
})
}
#[cfg(test)]
mod tests {
    // Tests for sink detection, fusion-site discovery, variant synthesis and
    // the end-to-end pass. Hand-built ASTs use the small constructors below;
    // parser-driven tests go through `with_parsed_fns`.
    use super::*;
    use crate::ast::{BinOp, FnBody, FnDef, Literal, Spanned, TailCallData};
    use std::sync::Arc;

    // ---- AST construction helpers -------------------------------------

    /// Wraps `value` in a `Spanned` at line 1.
    fn sp<T>(value: T) -> Spanned<T> {
        Spanned::new(value, 1)
    }

    /// A bare identifier expression.
    fn ident(name: &str) -> Spanned<Expr> {
        sp(Expr::Ident(name.to_string()))
    }

    /// The attribute access `module.member`.
    fn dotted(module: &str, member: &str) -> Spanned<Expr> {
        sp(Expr::Attr(Box::new(ident(module)), member.to_string()))
    }

    /// The call `callee(args...)`.
    fn call(callee: Spanned<Expr>, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
        sp(Expr::FnCall(Box::new(callee), args))
    }

    /// A tail call to `target` with `args`.
    fn tail_call(target: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
        sp(Expr::TailCall(Box::new(TailCallData {
            target: target.to_string(),
            args,
        })))
    }

    /// A match arm with a fresh, unset `binding_slots`.
    fn arm(pattern: Pattern, body: Spanned<Expr>) -> MatchArm {
        MatchArm {
            pattern,
            body: Box::new(body),
            binding_slots: std::sync::OnceLock::new(),
        }
    }

    /// `match subject { true -> on_true, false -> on_false }`.
    fn bool_match(
        subject: Spanned<Expr>,
        on_true: Spanned<Expr>,
        on_false: Spanned<Expr>,
    ) -> Spanned<Expr> {
        sp(Expr::Match {
            subject: Box::new(subject),
            arms: vec![
                arm(Pattern::Literal(Literal::Bool(true)), on_true),
                arm(Pattern::Literal(Literal::Bool(false)), on_false),
            ],
        })
    }

    /// The expression `col + 1`.
    fn col_plus_one() -> Spanned<Expr> {
        sp(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("col")),
            Box::new(sp(Expr::Literal(Literal::Int(1)))),
        ))
    }

    /// The expression `col >= 10`.
    fn col_gte_ten() -> Spanned<Expr> {
        sp(Expr::BinOp(
            BinOp::Gte,
            Box::new(ident("col")),
            Box::new(sp(Expr::Literal(Literal::Int(10)))),
        ))
    }

    /// A function definition whose body is the single expression `body`.
    fn fn_def(
        name: &str,
        line: usize,
        params: &[(&str, &str)],
        return_type: &str,
        body: Spanned<Expr>,
    ) -> FnDef {
        FnDef {
            name: name.to_string(),
            line,
            params: params
                .iter()
                .map(|(n, t)| (n.to_string(), t.to_string()))
                .collect(),
            return_type: return_type.to_string(),
            effects: vec![],
            desc: None,
            body: Arc::new(FnBody::Block(vec![Stmt::Expr(body)])),
            resolution: None,
        }
    }

    /// The canonical internal-reverse builder:
    /// `fn name(col: Int, acc: List<Int>)` with
    /// `true -> List.reverse(acc)` and
    /// `false -> name(col + 1, List.prepend(col, acc))`.
    fn canonical_builder(name: &str) -> FnDef {
        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
        let false_body = tail_call(name, vec![col_plus_one(), prepend]);
        fn_def(
            name,
            1,
            &[("col", "Int"), ("acc", "List<Int>")],
            "List<Int>",
            bool_match(col_gte_ten(), true_body, false_body),
        )
    }

    /// Mutable access to the arms of a function whose body is a single
    /// `match`. Panics on any other body so a test can never silently skip
    /// the mutation it is about.
    fn match_arms_mut(fd: &mut FnDef) -> &mut Vec<MatchArm> {
        let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body);
        let Stmt::Expr(root) = &mut stmts[0] else {
            panic!("test fixture body must start with an expression statement");
        };
        let Expr::Match { arms, .. } = &mut root.node else {
            panic!("test fixture body must be a match expression");
        };
        arms
    }

    /// Lexes, parses and TCO-transforms `src`, then hands the resulting
    /// function definitions to `check`.
    fn with_parsed_fns<R>(src: &str, check: impl FnOnce(&[&FnDef]) -> R) -> R {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex");
        let mut parser = crate::parser::Parser::new(tokens);
        let mut items = parser.parse().expect("parse");
        crate::ir::pipeline::tco(&mut items);
        let fns: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        check(&fns)
    }

    // ---- Shape detection on hand-built ASTs ---------------------------

    #[test]
    fn matches_canonical_buffer_build() {
        let fd = canonical_builder("build");
        let info = compute_buffer_build_sinks(&[&fd]);
        let shape = info.get("build").expect("expected match");
        assert_eq!(shape.acc_param_idx, 1);
        assert_eq!(shape.acc_param_name, "acc");
    }

    #[test]
    fn rejects_fn_without_list_param() {
        let mut fd = canonical_builder("build");
        fd.params = vec![("col".to_string(), "Int".to_string())];
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(info.is_empty(), "fn without List param should not match");
    }

    #[test]
    fn rejects_when_true_arm_isnt_reverse() {
        let mut fd = canonical_builder("build");
        let arms = match_arms_mut(&mut fd);
        *arms[0].body = ident("acc");
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "fn returning bare acc instead of reverse should not match"
        );
    }

    #[test]
    fn rejects_when_false_arm_uses_append_not_prepend() {
        let mut fd = canonical_builder("build");
        {
            let arms = match_arms_mut(&mut fd);
            let Expr::TailCall(data) = &mut arms[1].body.node else {
                panic!("false arm of the canonical builder must be a tail call");
            };
            let Expr::FnCall(callee, _) = &mut data.args[1].node else {
                panic!("accumulator argument must be a call");
            };
            let Expr::Attr(_, attr) = &mut callee.node else {
                panic!("accumulator call must use a dotted callee");
            };
            *attr = "append".to_string();
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "fn using List.append instead of prepend should not match"
        );
    }

    #[test]
    fn rejects_tail_call_to_different_fn() {
        let mut fd = canonical_builder("build");
        {
            let arms = match_arms_mut(&mut fd);
            let Expr::TailCall(data) = &mut arms[1].body.node else {
                panic!("false arm of the canonical builder must be a tail call");
            };
            data.target = "someone_else".to_string();
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "fn whose recursive call targets a different name should not match"
        );
    }

    #[test]
    fn rejects_match_with_non_bool_arms() {
        let mut fd = canonical_builder("build");
        match_arms_mut(&mut fd)[0].pattern = Pattern::Literal(Literal::Int(0));
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "match on non-bool patterns should not be detected as buffer-build"
        );
    }

    // ---- Parser-driven tests ------------------------------------------

    #[test]
    fn detects_via_parser_after_tco() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
match n <= 0
true -> List.reverse(acc)
false -> build(n - 1, List.prepend(n, acc))
"#;
        with_parsed_fns(src, |fns| {
            let info = compute_buffer_build_sinks(fns);
            let shape = info
                .get("build")
                .expect("expected end-to-end shape match for canonical builder");
            assert_eq!(shape.acc_param_idx, 1);
            assert_eq!(shape.acc_param_name, "acc");
        });
    }

    #[test]
    fn finds_fusion_site_via_parser() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
match n <= 0
true -> List.reverse(acc)
false -> build(n - 1, List.prepend(n, acc))
fn main() -> String
String.join(build(5, []), ",")
"#;
        with_parsed_fns(src, |fns| {
            let sinks = compute_buffer_build_sinks(fns);
            let sites = find_fusion_sites(fns, &sinks);
            assert_eq!(sites.len(), 1, "expected one fusion site, got {sites:?}");
            let site = &sites[0];
            assert_eq!(site.enclosing_fn, "main");
            assert_eq!(site.sink_fn, "build");
            assert!(site.line > 0, "expected real line info, got 0");
        });
    }

    #[test]
    fn ignores_call_when_not_wrapped_in_string_join() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
match n <= 0
true -> List.reverse(acc)
false -> build(n - 1, List.prepend(n, acc))
fn main() -> List<Int>
build(5, [])
"#;
        with_parsed_fns(src, |fns| {
            let sinks = compute_buffer_build_sinks(fns);
            let sites = find_fusion_sites(fns, &sinks);
            assert!(
                sites.is_empty(),
                "build called outside String.join must not be a fusion site, got {sites:?}"
            );
        });
    }

    #[test]
    fn rejects_via_parser_when_true_arm_returns_bare_acc() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
match n <= 0
true -> acc
false -> build(n - 1, List.prepend(n, acc))
"#;
        with_parsed_fns(src, |fns| {
            let info = compute_buffer_build_sinks(fns);
            assert!(
                info.is_empty(),
                "fn returning bare acc must not be detected as a deforestation candidate"
            );
        });
    }

    #[test]
    fn synthesizes_buffered_variant_from_real_builder() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
match n <= 0
true -> List.reverse(acc)
false -> build(n - 1, List.prepend(n, acc))
"#;
        with_parsed_fns(src, |fns| {
            let sinks = compute_buffer_build_sinks(fns);
            assert!(sinks.contains_key("build"));
            let synthesized = synthesize_buffered_variants(fns, &sinks);
            assert_eq!(
                synthesized.len(),
                1,
                "expected exactly one synthesized variant"
            );
            let bf = &synthesized[0];
            assert_eq!(bf.name, "build__buffered");
            assert_eq!(bf.return_type, "Buffer");
            let param_names: Vec<&str> = bf.params.iter().map(|(n, _)| n.as_str()).collect();
            let param_types: Vec<&str> = bf.params.iter().map(|(_, t)| t.as_str()).collect();
            assert_eq!(param_names, vec!["n", "__buf", "__sep"]);
            assert_eq!(param_types, vec!["Int", "Buffer", "String"]);

            let stmts = bf.body.stmts();
            assert_eq!(stmts.len(), 1);
            let match_expr = match &stmts[0] {
                Stmt::Expr(s) => match &s.node {
                    Expr::Match { subject: _, arms } => arms,
                    _ => panic!("body root must be a match"),
                },
                _ => panic!("body root must be Stmt::Expr"),
            };
            assert_eq!(match_expr.len(), 2);

            let true_arm = match_expr
                .iter()
                .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(true))))
                .expect("true arm");
            match &true_arm.body.node {
                Expr::Ident(name) => assert_eq!(name, "__buf"),
                other => panic!("true arm should be Ident(__buf), got {other:?}"),
            }

            let false_arm = match_expr
                .iter()
                .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
                .expect("false arm");
            let tail_data = match &false_arm.body.node {
                Expr::TailCall(d) => d,
                other => panic!("false arm should be TailCall, got {other:?}"),
            };
            assert_eq!(tail_data.target, "build__buffered");
            assert_eq!(tail_data.args.len(), 3);

            let outer = match &tail_data.args[1].node {
                Expr::FnCall(callee, args) => {
                    match &callee.node {
                        Expr::Ident(name) => assert_eq!(name, "__buf_append"),
                        _ => panic!("expected Ident callee"),
                    }
                    args
                }
                _ => panic!("expected outer __buf_append FnCall"),
            };
            assert_eq!(outer.len(), 2);
            match &outer[0].node {
                Expr::FnCall(callee, _) => match &callee.node {
                    Expr::Ident(name) => assert_eq!(name, "__buf_append_sep_unless_first"),
                    _ => panic!("expected Ident callee for inner intrinsic"),
                },
                _ => panic!("expected inner __buf_append_sep_unless_first FnCall"),
            }
            match &outer[1].node {
                Expr::Ident(name) => assert_eq!(name, "n"),
                _ => panic!("expected `n` ident as elem"),
            }
            match &tail_data.args[2].node {
                Expr::Ident(name) => assert_eq!(name, "__sep"),
                _ => panic!("expected __sep ident as last arg"),
            }
        });
    }

    // ---- Accumulator position -----------------------------------------

    #[test]
    fn detects_acc_param_at_arbitrary_index() {
        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
        // Accumulator comes first both in the parameters and in the tail call.
        let false_body = tail_call("build", vec![prepend, col_plus_one()]);
        let fd = fn_def(
            "build",
            1,
            &[("acc", "List<Int>"), ("col", "Int")],
            "List<Int>",
            bool_match(col_gte_ten(), true_body, false_body),
        );
        let info = compute_buffer_build_sinks(&[&fd]);
        let shape = info.get("build").expect("expected match");
        assert_eq!(shape.acc_param_idx, 0);
        assert_eq!(shape.acc_param_name, "acc");
    }

    #[test]
    fn rejects_loose_prepend_in_non_acc_position() {
        let mut fd = canonical_builder("build");
        for arm in match_arms_mut(&mut fd).iter_mut() {
            if !matches!(arm.pattern, Pattern::Literal(Literal::Bool(false))) {
                continue;
            }
            if let Expr::TailCall(data) = &mut arm.body.node {
                // Moves the prepend out of the accumulator slot.
                data.args.reverse();
            }
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            !info.contains_key("build"),
            "loose-prepend (prepend not at acc-position) must not be detected"
        );
    }

    // ---- End-to-end pass ----------------------------------------------

    #[test]
    fn skips_synth_when_no_rewriteable_call_site() {
        let sink = canonical_builder("build");
        let caller = fn_def(
            "use_build",
            2,
            &[],
            "List<Int>",
            call(
                ident("build"),
                vec![sp(Expr::Literal(Literal::Int(0))), sp(Expr::List(vec![]))],
            ),
        );
        let mut items = vec![
            crate::ast::TopLevel::FnDef(sink),
            crate::ast::TopLevel::FnDef(caller),
        ];
        let initial_count = items.len();
        let report = run_buffer_build_pass(&mut items);
        assert_eq!(report.rewrites, 0, "no fusion sites — no rewriteable call");
        assert_eq!(
            report.synthesized.len(),
            0,
            "no synth — nothing to fuse against"
        );
        assert_eq!(items.len(), initial_count, "no buffered variant appended");
    }

    #[test]
    fn external_reverse_pattern_round_trips() {
        let nil_body = ident("acc");
        let prepend = call(dotted("List", "prepend"), vec![ident("h"), ident("acc")]);
        let cons_body = tail_call("build", vec![ident("t"), prepend]);
        let match_expr = sp(Expr::Match {
            subject: Box::new(ident("xs")),
            arms: vec![
                arm(Pattern::EmptyList, nil_body),
                arm(Pattern::Cons("h".to_string(), "t".to_string()), cons_body),
            ],
        });
        let sink = fn_def(
            "build",
            1,
            &[("xs", "List<Int>"), ("acc", "List<String>")],
            "List<String>",
            match_expr,
        );
        let info = compute_buffer_build_sinks(&[&sink]);
        let shape = info
            .get("build")
            .expect("external-reverse sink should be detected");
        assert_eq!(shape.kind, BufferBuildKind::ExternalReverse);
        assert_eq!(shape.acc_param_idx, 1);

        let join_call = call(
            dotted("String", "join"),
            vec![
                call(
                    dotted("List", "reverse"),
                    vec![call(
                        ident("build"),
                        vec![ident("xs"), sp(Expr::List(vec![]))],
                    )],
                ),
                sp(Expr::Literal(Literal::Str("\n".to_string()))),
            ],
        );
        let caller = fn_def("render", 2, &[("xs", "List<Int>")], "String", join_call);
        let mut items = vec![
            crate::ast::TopLevel::FnDef(sink),
            crate::ast::TopLevel::FnDef(caller),
        ];
        let report = run_buffer_build_pass(&mut items);
        assert_eq!(
            report.rewrites, 1,
            "external-reverse pattern should be one fusion site"
        );
        assert_eq!(
            report.synthesized.len(),
            1,
            "exactly one buffered variant for the used sink"
        );
        let synth_present = items.iter().any(|it| match it {
            crate::ast::TopLevel::FnDef(fd) => fd.name == "build__buffered",
            _ => false,
        });
        assert!(synth_present, "build__buffered must be appended");
    }
}