Skip to main content

aver/ir/
buffer_build.rs

1//! Buffer-build sink detection.
2//!
3//! Identifies user fns that match the canonical functional list-builder
4//! shape consumed by `String.join`:
5//!
6//! ```aver
7//! fn build(..., acc: List<T>) -> List<T>
8//!     match <cond>
9//!         true  -> List.reverse(acc)
10//!         false -> build(..., List.prepend(<elem>, acc))
11//! ```
12//!
13//! When such a fn is called from `String.join(build(..., []), sep)`, the
14//! whole pipeline is semantically equivalent to a single buffer-write
//! loop — Wadler 1990 shortcut fusion / deforestation. This module is
//! part of the deforestation work for 0.15 "Traversal": it detects
//! candidate fns and also hosts their lowering — synthesizing
//! `<fn>__buffered` variants and rewriting the matching `String.join`
//! call sites (entry point: `run_buffer_build_pass`).
19//!
20//! Detection is intentionally local — the analyzer looks only at the fn
21//! body, not its call sites. A matched fn may or may not actually be
22//! consumed by `String.join`; the lowering pass cross-references call
23//! sites separately and only fuses when both ends of the pipeline agree.
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use crate::ast::{Expr, FnBody, FnDef, Literal, MatchArm, Pattern, Spanned, Stmt, TailCallData};
29
/// Where the matched builder puts the `List.reverse` step that gives
/// the result its forward order.
///
/// `prepend(elem, acc)` builds the accumulator in reverse-of-input
/// order. To get a forward list (the order we'd hand to `String.join`)
/// it has to be reversed exactly once. Two equally common Aver idioms
/// place that reverse in different spots:
///
/// - `InternalReverse`: the sink itself has `true -> List.reverse(acc)`
///   in its base case. Caller writes `String.join(<sink>(args, []), sep)`.
///   This is the classic "loop with reversed accumulator + reverse on
///   exit" shape.
/// - `ExternalReverse`: the sink has `[] -> acc` (no reverse) and
///   matches on the input list directly. Caller writes
///   `String.join(List.reverse(<sink>(args, [])), sep)`. Common in the
///   payment_ops / workflow_engine codebases under the `*Into` naming
///   convention (e.g. `serializeEntriesInto`, `filterSubjectInto`).
///
/// Both shapes lower to the same buffered variant — appending in
/// processing order (which is forward order of the input) yields the
/// final string in the right order without any explicit reverse.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BufferBuildKind {
    /// The sink's base arm is `true -> List.reverse(acc)`; a fusion site
    /// must consume the sink call directly, with no `List.reverse` wrapper.
    InternalReverse,
    /// The sink's base arm is `[] -> acc`; a fusion site must wrap the
    /// sink call in `List.reverse(...)`.
    ExternalReverse,
}
56
/// Information about a fn that matches the buffer-build sink shape.
///
/// Produced by the shape detector for each matching fn and stored in
/// the sinks map under the fn's name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferBuildShape {
    /// 0-based index of the `acc: List<T>` parameter in the fn signature.
    /// Identifies which arg in tail-call positions threads the
    /// accumulator and which `Ident` in the base-case arm is returned.
    /// Call-site matching also uses it to locate the argument that must
    /// be a literal empty list.
    pub acc_param_idx: usize,
    /// The accumulator parameter's binding name (looked up in tail-call
    /// args and in the base-case return position).
    pub acc_param_name: String,
    /// Which of the two reverse-placement idioms this sink follows;
    /// determines (a) what shape the original base arm has and (b)
    /// whether the call site is `String.join(<sink>(...), sep)` or
    /// `String.join(List.reverse(<sink>(...)), sep)`.
    pub kind: BufferBuildKind,
}
73
/// What the matched builder feeds into. Different consumers compile
/// to different buffer types and finalizers, but all share the same
/// underlying deforestation: skip the intermediate List, write
/// elements straight to the consumer's storage.
///
/// Phase 2 implements `StringJoin` only — the canonical case from the
/// fractal demo. Future variants land as separate phases:
/// `VectorFromList` (already half-fused via `Vector.set` owned-mutate
/// in 0.14.0; deforestation closes the cons-cell side), and `ListFold`
/// for stream-fusion-style consumer rewrites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsumerKind {
    /// `String.join(builder(...), sep)` — write each element + sep
    /// directly into a `Vec<u8>`-shaped buffer in linear memory.
    /// Currently the only variant, so every recorded `FusionSite`
    /// carries it.
    StringJoin,
}
90
/// One detected fusion site: a builder call whose result is consumed
/// by a known sink (currently just `String.join`). Lowering rewrites
/// the producer + consumer pair into a single buffer-write loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusionSite {
    /// Name of the enclosing user fn that contains the call.
    pub enclosing_fn: String,
    /// Line of the consumer call. When the consumer expression is a
    /// sub-expression whose own recorded line is 0, the line of the
    /// nearest enclosing expression is used instead.
    pub line: usize,
    /// The matched buffer-build fn being wrapped (a key of the sinks
    /// map the site was matched against).
    pub sink_fn: String,
    /// What's consuming the builder's result.
    pub consumer: ConsumerKind,
}
105
106/// Walk all fns in `fns`, return a map from fn name to detected shape
107/// for fns that match the buffer-build sink pattern. Fns that don't
108/// match are absent from the result.
109pub fn compute_buffer_build_sinks(fns: &[&FnDef]) -> HashMap<String, BufferBuildShape> {
110    let mut out = HashMap::new();
111    for fd in fns {
112        if let Some(shape) = match_buffer_build_shape(fd) {
113            out.insert(fd.name.clone(), shape);
114        }
115    }
116    out
117}
118
119/// Walk every expression in every fn body looking for fusion sites:
120/// `String.join(matched_fn(...), sep)` calls where `matched_fn` is a
121/// key in `sinks`. Returns one `FusionSite` per call. The lowering
122/// pass rewrites each site to call a buffered variant of `matched_fn`
123/// directly into a pre-allocated buffer.
124pub fn find_fusion_sites(
125    fns: &[&FnDef],
126    sinks: &HashMap<String, BufferBuildShape>,
127) -> Vec<FusionSite> {
128    let mut out = Vec::new();
129    for fd in fns {
130        for stmt in fd.body.stmts() {
131            match stmt {
132                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
133                    walk_expr_for_fusion_sites(&expr.node, expr.line, &fd.name, sinks, &mut out);
134                }
135            }
136        }
137    }
138    out
139}
140
141/// Recursively walk an expression tree, recording any fusion site we
142/// find. The fallback `expr_line` is used when a sub-expression has no
143/// own line info.
144fn walk_expr_for_fusion_sites(
145    expr: &Expr,
146    expr_line: usize,
147    enclosing_fn: &str,
148    sinks: &HashMap<String, BufferBuildShape>,
149    out: &mut Vec<FusionSite>,
150) {
151    if let Some(inner_name) = match_string_join_fusion_site(expr, sinks) {
152        out.push(FusionSite {
153            enclosing_fn: enclosing_fn.to_string(),
154            line: expr_line,
155            sink_fn: inner_name,
156            consumer: ConsumerKind::StringJoin,
157        });
158    }
159    // Recurse into all sub-expressions regardless of whether this node
160    // matched (a fusion site can sit inside another fusion site's args
161    // — rare but valid; we'd record both and let the lowering decide).
162    visit_subexprs(expr, expr_line, enclosing_fn, sinks, out);
163}
164
165/// Helper: recurse into the sub-expressions of `expr`. Mirrors the
166/// shape coverage of `expr_allocates` in `alloc_info.rs` so we don't
167/// miss any node kind.
168fn visit_subexprs(
169    expr: &Expr,
170    fallback_line: usize,
171    enclosing_fn: &str,
172    sinks: &HashMap<String, BufferBuildShape>,
173    out: &mut Vec<FusionSite>,
174) {
175    let line_of = |s: &crate::ast::Spanned<Expr>| {
176        if s.line > 0 { s.line } else { fallback_line }
177    };
178    match expr {
179        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
180        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
181            walk_expr_for_fusion_sites(&inner.node, line_of(inner), enclosing_fn, sinks, out);
182        }
183        Expr::FnCall(callee, args) => {
184            walk_expr_for_fusion_sites(&callee.node, line_of(callee), enclosing_fn, sinks, out);
185            for a in args {
186                walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
187            }
188        }
189        Expr::TailCall(data) => {
190            for a in &data.args {
191                walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
192            }
193        }
194        Expr::BinOp(_, l, r) => {
195            walk_expr_for_fusion_sites(&l.node, line_of(l), enclosing_fn, sinks, out);
196            walk_expr_for_fusion_sites(&r.node, line_of(r), enclosing_fn, sinks, out);
197        }
198        Expr::Match { subject, arms } => {
199            walk_expr_for_fusion_sites(&subject.node, line_of(subject), enclosing_fn, sinks, out);
200            for arm in arms {
201                walk_expr_for_fusion_sites(
202                    &arm.body.node,
203                    line_of(&arm.body),
204                    enclosing_fn,
205                    sinks,
206                    out,
207                );
208            }
209        }
210        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
211            for it in items {
212                walk_expr_for_fusion_sites(&it.node, line_of(it), enclosing_fn, sinks, out);
213            }
214        }
215        Expr::MapLiteral(entries) => {
216            for (k, v) in entries {
217                walk_expr_for_fusion_sites(&k.node, line_of(k), enclosing_fn, sinks, out);
218                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
219            }
220        }
221        Expr::RecordCreate { fields, .. } => {
222            for (_, v) in fields {
223                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
224            }
225        }
226        Expr::RecordUpdate { base, updates, .. } => {
227            walk_expr_for_fusion_sites(&base.node, line_of(base), enclosing_fn, sinks, out);
228            for (_, v) in updates {
229                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
230            }
231        }
232        Expr::InterpolatedStr(parts) => {
233            for part in parts {
234                if let crate::ast::StrPart::Parsed(inner) = part {
235                    walk_expr_for_fusion_sites(
236                        &inner.node,
237                        line_of(inner),
238                        enclosing_fn,
239                        sinks,
240                        out,
241                    );
242                }
243            }
244        }
245    }
246}
247
248/// Pattern-match a single fn against the buffer-build shape.
249fn match_buffer_build_shape(fd: &FnDef) -> Option<BufferBuildShape> {
250    // The accumulator must be a parameter of type `List<...>`. The
251    // params vector stores type strings, not parsed `Type` values, so we
252    // match the textual form. Aver's surface syntax accepts both
253    // `List<T>` and (rarely) `[T]`-like sugar; canonical form is
254    // `List<T>`.
255    // Take the *rightmost* List<...> parameter as the accumulator. The
256    // InternalReverse shape (`fn build(n: Int, acc: List<T>)`) typically
257    // has only one list param; the ExternalReverse shape often has two
258    // (`fn build(input: List<T>, acc: List<U>)`) and the accumulator is
259    // by convention the last argument. Picking the first match would
260    // misidentify `input` as the accumulator and the rest of the
261    // detection would silently fail.
262    let (acc_idx, acc_name) = fd
263        .params
264        .iter()
265        .enumerate()
266        .rfind(|(_, (_, ty))| is_list_type_str(ty))
267        .map(|(i, (name, _))| (i, name.clone()))?;
268
269    // Body must be a single expression statement holding the match.
270    let match_expr = single_match_body(&fd.body)?;
271    let (subject_expr, arms) = match match_expr {
272        Expr::Match { subject, arms } => (subject, arms),
273        _ => return None,
274    };
275
276    // Try the InternalReverse shape first: `match <bool> { true -> List.reverse(acc); false -> recurse(... prepend(_, acc)) }`.
277    if let Some((true_body, false_body)) = pair_bool_arms(arms) {
278        let _ = subject_expr;
279        if is_list_reverse_of(true_body, &acc_name)
280            && is_self_tail_with_prepend_acc(false_body, &fd.name, acc_idx, &acc_name)
281        {
282            return Some(BufferBuildShape {
283                acc_param_idx: acc_idx,
284                acc_param_name: acc_name,
285                kind: BufferBuildKind::InternalReverse,
286            });
287        }
288    }
289
290    // Otherwise try the ExternalReverse shape: `match <list> { [] -> acc;
291    // [_, .._] -> recurse(... prepend(_, acc)) }`. The reverse lives at
292    // the caller, e.g. `List.reverse(<this>(args, []))` (see the
293    // `BufferBuildKind` doc above for context).
294    if let Some((nil_body, cons_body)) = pair_nil_cons_arms(arms)
295        && is_ident_named(nil_body, &acc_name)
296        && is_self_tail_with_prepend_acc(cons_body, &fd.name, acc_idx, &acc_name)
297    {
298        return Some(BufferBuildShape {
299            acc_param_idx: acc_idx,
300            acc_param_name: acc_name,
301            kind: BufferBuildKind::ExternalReverse,
302        });
303    }
304
305    None
306}
307
308/// Match a 2-arm match where one arm is `[]` and the other is
309/// `[head, ..tail]`. Returns `(nil_body, cons_body)` (both as `&Expr`,
310/// matching `pair_bool_arms`). `None` if the arms don't exactly cover
311/// those two patterns.
312fn pair_nil_cons_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
313    if arms.len() != 2 {
314        return None;
315    }
316    let mut nil_body: Option<&Expr> = None;
317    let mut cons_body: Option<&Expr> = None;
318    for arm in arms {
319        match &arm.pattern {
320            Pattern::EmptyList => nil_body = Some(&arm.body.node),
321            Pattern::Cons(_, _) => cons_body = Some(&arm.body.node),
322            _ => return None,
323        }
324    }
325    match (nil_body, cons_body) {
326        (Some(n), Some(c)) => Some((n, c)),
327        _ => None,
328    }
329}
330
331/// True if `expr` is just an identifier reference to `name`.
332fn is_ident_named(expr: &Expr, name: &str) -> bool {
333    matches!(expr, Expr::Ident(n) if n == name)
334}
335
336/// Recognise a `String.join` first-arg as either a direct sink call or
337/// a `List.reverse(<sink>(args, []))` wrapper. Returns the matched
338/// sink fn name when the kind matches the shape we expect:
339///   InternalReverse sinks → only the direct call form
340///   ExternalReverse sinks → only the `List.reverse(...)` form
341/// Returning the name only when the kinds line up keeps us from
342/// accidentally fusing a sink against a call site that's missing
343/// its required reverse (or has an extraneous one).
344/// Single source of truth for "is this expression a rewriteable
345/// `String.join(<sink>(args, []), sep)` fusion site?". Returns the
346/// matched sink name when **all** preconditions hold:
347/// 1. The expression is a `String.join(<inner>, _)` call.
348/// 2. `<inner>` is either a direct sink call (for InternalReverse
349///    sinks) or `List.reverse(<sink>(...))` (for ExternalReverse).
350/// 3. The reverse-placement on the call site matches the sink's kind
351///    — mismatch would silently drop or double-reverse.
352/// 4. The acc-position arg in the inner call is a literal empty list.
353///    Anything else means the user is starting the fold with a
354///    non-empty accumulator, and the buffered rewrite would silently
355///    drop those initial elements.
356///
357/// Both `find_fusion_sites` (diagnostics) and `try_rewrite_fusion_site`
358/// (the actual AST rewrite) call this so the two stay in lockstep —
359/// `aver check` can never report a site that the rewrite then refuses
360/// to take.
361fn match_string_join_fusion_site(
362    expr: &Expr,
363    sinks: &HashMap<String, BufferBuildShape>,
364) -> Option<String> {
365    let Expr::FnCall(callee, args) = expr else {
366        return None;
367    };
368    if !is_dotted_ident(&callee.node, "String", "join") || args.len() != 2 {
369        return None;
370    }
371    let consumer_arg = &args[0].node;
372
373    // Peel an optional `List.reverse(...)` wrapper.
374    let (inner_call_expr, saw_external_reverse) = match consumer_arg {
375        Expr::FnCall(rev_callee, rev_args)
376            if is_dotted_ident(&rev_callee.node, "List", "reverse") && rev_args.len() == 1 =>
377        {
378            (&rev_args[0].node, true)
379        }
380        other => (other, false),
381    };
382
383    let Expr::FnCall(inner_callee, inner_args) = inner_call_expr else {
384        return None;
385    };
386    let Expr::Ident(name) = &inner_callee.node else {
387        return None;
388    };
389    let shape = sinks.get(name)?;
390
391    let kinds_align = matches!(
392        (saw_external_reverse, &shape.kind),
393        (false, BufferBuildKind::InternalReverse) | (true, BufferBuildKind::ExternalReverse)
394    );
395    if !kinds_align {
396        return None;
397    }
398
399    let acc_arg = inner_args.get(shape.acc_param_idx)?;
400    if !matches!(&acc_arg.node, Expr::List(items) if items.is_empty()) {
401        return None;
402    }
403
404    Some(name.clone())
405}
406
/// True if a parameter type-string has the textual form `List<...>`.
///
/// Surrounding whitespace is ignored. The check is purely textual: the
/// trimmed string must begin with `List<` and end with `>`; what sits
/// between them is not validated.
fn is_list_type_str(ty: &str) -> bool {
    ty.trim()
        .strip_prefix("List<")
        .is_some_and(|rest| rest.ends_with('>'))
}
412
413/// Extract the single match expression that forms a fn's entire body.
414/// Returns `None` if the body is empty, has multiple statements, or its
415/// single statement isn't a match expression.
416fn single_match_body(body: &FnBody) -> Option<&Expr> {
417    let stmts = body.stmts();
418    if stmts.len() != 1 {
419        return None;
420    }
421    match &stmts[0] {
422        Stmt::Expr(spanned) => match &spanned.node {
423            Expr::Match { .. } => Some(&spanned.node),
424            _ => None,
425        },
426        Stmt::Binding(_, _, _) => None,
427    }
428}
429
430/// If `arms` is exactly two arms with `Bool(true)` / `Bool(false)`
431/// patterns, return `(true_body, false_body)` references. Order in
432/// source doesn't matter — we sort by pattern.
433fn pair_bool_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
434    if arms.len() != 2 {
435        return None;
436    }
437    let mut t = None;
438    let mut f = None;
439    for arm in arms {
440        match &arm.pattern {
441            Pattern::Literal(Literal::Bool(true)) => {
442                if t.is_some() {
443                    return None;
444                }
445                t = Some(&arm.body.node);
446            }
447            Pattern::Literal(Literal::Bool(false)) => {
448                if f.is_some() {
449                    return None;
450                }
451                f = Some(&arm.body.node);
452            }
453            _ => return None,
454        }
455    }
456    Some((t?, f?))
457}
458
459/// True if `expr` is `List.reverse(<Ident(acc_name)>)`.
460fn is_list_reverse_of(expr: &Expr, acc_name: &str) -> bool {
461    let (callee, args) = match expr {
462        Expr::FnCall(c, a) => (c, a),
463        _ => return false,
464    };
465    if !is_dotted_ident(&callee.node, "List", "reverse") {
466        return false;
467    }
468    if args.len() != 1 {
469        return false;
470    }
471    matches!(&args[0].node, Expr::Ident(name) if name == acc_name)
472}
473
474/// True if `expr` is a tail-call to `self_name` whose argument list
475/// contains `List.prepend(<anything>, <Ident(acc_name)>)` in any
476/// position. The position should match the `acc_param_idx` but the
477/// caller may have other params before it; we only require the
478/// `prepend` to terminate in the expected accumulator binding.
479fn is_self_tail_with_prepend_acc(
480    expr: &Expr,
481    self_name: &str,
482    acc_idx: usize,
483    acc_name: &str,
484) -> bool {
485    let data = match expr {
486        Expr::TailCall(data) => data,
487        _ => return false,
488    };
489    if data.target != self_name {
490        return false;
491    }
492    // The prepend has to land in the *acc* position specifically — the
493    // synthesizer extracts the element expression from `args[acc_idx]`.
494    // A loose `any` here would let through fns where some other arg
495    // happens to be a prepend, the synth would later return None, but
496    // detection had already promised the sink to call-site rewriting:
497    // `String.join(<sink>(...))` would be rewritten to call a
498    // `<sink>__buffered` that never gets generated. Require the exact
499    // shape here so detection and synthesis agree.
500    let acc_arg = match data.args.get(acc_idx) {
501        Some(a) => a,
502        None => return false,
503    };
504    is_list_prepend_to_acc(&acc_arg.node, acc_name)
505}
506
507/// True if `expr` is `List.prepend(<anything>, <Ident(acc_name)>)`.
508fn is_list_prepend_to_acc(expr: &Expr, acc_name: &str) -> bool {
509    let (callee, args) = match expr {
510        Expr::FnCall(c, a) => (c, a),
511        _ => return false,
512    };
513    if !is_dotted_ident(&callee.node, "List", "prepend") {
514        return false;
515    }
516    if args.len() != 2 {
517        return false;
518    }
519    matches!(&args[1].node, Expr::Ident(name) if name == acc_name)
520}
521
522/// True if `expr` is `<Module>.<Member>` access (the un-called callee
523/// shape of `Module.member(...)`).
524fn is_dotted_ident(expr: &Expr, module: &str, member: &str) -> bool {
525    let (base, attr) = match expr {
526        Expr::Attr(b, a) => (b, a),
527        _ => return false,
528    };
529    if attr != member {
530        return false;
531    }
532    matches!(&base.node, Expr::Ident(name) if name == module)
533}
534
535/// Synthesize a `<fn>__buffered` variant for each matched buffer-build
536/// sink. The synthesized FnDef walks the same shape as the original but
537/// threads a runtime `Buffer` through tail-call args instead of building
538/// a `List<T>` of strings:
539///
540/// Original:
541/// ```aver
542/// fn build(.., acc: List<T>) -> List<T>
543///     match <cond>
544///         true  -> List.reverse(acc)
545///         false -> build(.., List.prepend(<elem>, acc))
546/// ```
547///
548/// Synthesized:
549/// ```aver
550/// fn build__buffered(.., __buf: Buffer, __sep: String) -> Buffer
551///     match <cond>
552///         true  -> __buf
553///         false -> build__buffered(..,
554///             __buf_append(
555///                 __buf_append_sep_unless_first(__buf, __sep),
556///                 <elem>
557///             ),
558///             __sep
559///         )
560/// ```
561///
562/// Threading is via expression composition: the inner
563/// `__buf_append_sep_unless_first` returns the (possibly grown) buffer,
564/// the outer `__buf_append` writes the element and again returns
565/// the (possibly grown) buffer, and that final pointer is what the tail
566/// call sees as `__buf`. No `_ =` discards anywhere — the C' review
567/// explicitly required this to avoid use-after-grow corruption.
568///
569/// Returns one `FnDef` per matched fn. Caller appends to the user-fn
570/// list before WASM emission so both original and buffered variants
571/// reach codegen through the same pipeline.
572pub fn synthesize_buffered_variants(
573    fns: &[&FnDef],
574    sinks: &HashMap<String, BufferBuildShape>,
575) -> Vec<FnDef> {
576    let mut out = Vec::new();
577    for fd in fns {
578        if let Some(shape) = sinks.get(&fd.name)
579            && let Some(buffered) = build_buffered_variant(fd, shape)
580        {
581            out.push(buffered);
582        }
583    }
584    out
585}
586
587/// Wrap an `Expr` as `Spanned<Expr>` carrying the same line as the
588/// matched fn (best effort — the synthesized code is internal and
589/// won't be source-located by the user, but having a non-zero line
590/// keeps downstream visitors happy).
591fn sp_at(line: usize, expr: Expr) -> Spanned<Expr> {
592    Spanned { node: expr, line }
593}
594
595/// Build `<intrinsic>(args...)` as a Spanned<Expr>. Intrinsic names
596/// are bare identifiers (no module dot) — `__buf_append`,
597/// `__buf_append_sep_unless_first`. The WASM emitter recognises them
598/// in the builtin dispatch.
599fn intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
600    let callee = sp_at(line, Expr::Ident(name.to_string()));
601    sp_at(line, Expr::FnCall(Box::new(callee), args))
602}
603
604/// Run the full buffer-build deforestation pass on a program: detect
605/// sinks, synthesize buffered variants, rewrite fusion sites in place,
606/// and APPEND the synthesized FnDefs to the items list as new
607/// top-level fns. Caller is responsible for invoking this AFTER
608/// `tco::transform_program` (the detector requires `Expr::TailCall`
609/// nodes) and BEFORE `resolver::resolve_program` (the detector +
610/// rewrite both match on `Expr::Ident` shapes that the resolver
611/// rewrites to `Expr::Resolved`).
612///
613/// Returns a [`BufferBuildPassReport`] describing what fired, for
614/// diagnostic / bench reporting and `--explain-passes`.
615pub fn run_buffer_build_pass(items: &mut Vec<crate::ast::TopLevel>) -> BufferBuildPassReport {
616    let fn_refs: Vec<&FnDef> = items
617        .iter()
618        .filter_map(|it| match it {
619            crate::ast::TopLevel::FnDef(fd) => Some(fd),
620            _ => None,
621        })
622        .collect();
623    let all_sinks = compute_buffer_build_sinks(&fn_refs);
624    if all_sinks.is_empty() {
625        return BufferBuildPassReport::default();
626    }
627    let sites = find_fusion_sites(&fn_refs, &all_sinks);
628
629    // Synthesize a buffered variant only for sinks that actually have
630    // at least one rewriteable call site. The earlier shape produced
631    // a `<sink>__buffered` for every detected sink — bloat in the
632    // common case (most detected sinks aren't called via the canonical
633    // String.join shape) and a real risk of name-shadowing a user fn
634    // named `<sink>__buffered`. Restricting to used sinks keeps the
635    // synthetic surface tight.
636    let mut used_sinks: HashMap<String, BufferBuildShape> = HashMap::new();
637    for site in &sites {
638        if let Some(shape) = all_sinks.get(&site.sink_fn) {
639            used_sinks.insert(site.sink_fn.clone(), shape.clone());
640        }
641    }
642    let synthesized = synthesize_buffered_variants(&fn_refs, &used_sinks);
643    let sinks = used_sinks;
644    drop(fn_refs);
645
646    let mut fn_defs_owned: Vec<&mut FnDef> = items
647        .iter_mut()
648        .filter_map(|it| match it {
649            crate::ast::TopLevel::FnDef(fd) => Some(fd),
650            _ => None,
651        })
652        .collect();
653    // rewrite_fusion_sites takes &mut [FnDef], so pull a fresh
654    // mutable view across owned slots. We can't pass &mut [&mut FnDef]
655    // directly — instead, walk and rewrite each fn body individually.
656    for fd in fn_defs_owned.iter_mut() {
657        rewrite_one_fn(fd, &sinks);
658    }
659
660    items.reserve(synthesized.len());
661    for fd in synthesized.iter() {
662        items.push(crate::ast::TopLevel::FnDef(fd.clone()));
663    }
664
665    let mut sink_fns: Vec<String> = sinks.keys().cloned().collect();
666    sink_fns.sort();
667    let synthesized_fns: Vec<String> = synthesized.iter().map(|fd| fd.name.clone()).collect();
668
669    let mut rewrites_by_sink: std::collections::BTreeMap<String, usize> =
670        std::collections::BTreeMap::new();
671    for site in &sites {
672        *rewrites_by_sink.entry(site.sink_fn.clone()).or_default() += 1;
673    }
674
675    BufferBuildPassReport {
676        rewrites: sites.len(),
677        synthesized: synthesized_fns,
678        sink_fns,
679        rewrites_by_sink,
680    }
681}
682
/// Per-pass report — what buffer_build did during a single pipeline run.
/// Drives `aver compile --explain-passes`; consumed by the bench
/// regression checks (e.g. "fail if buffer_build no longer fires on the
/// canonical shape").
///
/// The `Default` value (all fields empty / zero) is what the pass
/// returns when no fn matches the sink shape.
#[derive(Debug, Clone, Default)]
pub struct BufferBuildPassReport {
    /// Number of fusion sites rewritten in place — filled from the
    /// number of sites `find_fusion_sites` reported for this run.
    pub rewrites: usize,
    /// Names of synthesized `<sink>__buffered` variants appended to
    /// the items list, in the order they were appended.
    pub synthesized: Vec<String>,
    /// Sink fns whose buffered variant actually fired (matches one of
    /// `synthesized` minus the `__buffered` suffix), i.e. sinks with at
    /// least one detected fusion site. Sorted.
    pub sink_fns: Vec<String>,
    /// Per-sink rewrite counts. Sorted alphabetically by sink fn.
    pub rewrites_by_sink: std::collections::BTreeMap<String, usize>,
}
700
701/// Apply fusion-site rewrite to a single fn body. Internal helper
702/// for `run_buffer_build_pass` since `rewrite_fusion_sites` takes a
703/// slice and we have an iterator-of-mut-refs here.
704fn rewrite_one_fn(fd: &mut FnDef, sinks: &HashMap<String, BufferBuildShape>) {
705    let body_arc = std::sync::Arc::make_mut(&mut fd.body);
706    let FnBody::Block(stmts) = body_arc;
707    for stmt in stmts.iter_mut() {
708        match stmt {
709            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
710                rewrite_expr_in_place(expr, sinks);
711            }
712        }
713    }
714}
715
716/// Walk every expression in `fn_defs` and rewrite `String.join`
717/// fusion sites in place: `String.join(matched_fn(args, []), sep)` →
718/// `__buf_finalize(matched_fn__buffered(args_without_acc, __buf_new(8192), sep))`.
719///
720/// Conservative trigger per the C' review: only fires when the
721/// acc-position arg is a literal `Expr::List([])`. A non-empty
722/// initial accumulator would silently lose elements after rewrite,
723/// so we skip in that case.
724///
725/// The rewrite is recursive: nested fusion sites (a fusion site
726/// inside another fusion site's args) all get rewritten in one pass.
727pub fn rewrite_fusion_sites(fn_defs: &mut [FnDef], sinks: &HashMap<String, BufferBuildShape>) {
728    if sinks.is_empty() {
729        return;
730    }
731    for fd in fn_defs.iter_mut() {
732        let body_arc = std::sync::Arc::make_mut(&mut fd.body);
733        let FnBody::Block(stmts) = body_arc;
734        for stmt in stmts.iter_mut() {
735            match stmt {
736                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
737                    rewrite_expr_in_place(expr, sinks);
738                }
739            }
740        }
741    }
742}
743
744/// Recursive expression-tree walker that rewrites fusion sites in
745/// place. Rewrite is "outermost first" — if the whole expression is
746/// a fusion site, transform it before descending into the new shape's
747/// children, so we don't double-rewrite.
748fn rewrite_expr_in_place(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
749    if let Some(replacement) = try_rewrite_fusion_site(expr, sinks) {
750        *expr = replacement;
751        // The replacement contains the original elem expressions
752        // (possibly themselves containing fusion sites in deep
753        // gradient builders). Recurse into the new tree.
754        descend_into_subexprs(expr, sinks);
755        return;
756    }
757    descend_into_subexprs(expr, sinks);
758}
759
760/// Recurse into the children of an Expr, applying `rewrite_expr_in_place`
761/// to each. Mirrors the shape coverage of `walk_expr_for_fusion_sites`
762/// in this module so we don't miss any node kind.
763fn descend_into_subexprs(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
764    match &mut expr.node {
765        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
766        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
767            rewrite_expr_in_place(inner, sinks);
768        }
769        Expr::FnCall(callee, args) => {
770            rewrite_expr_in_place(callee, sinks);
771            for a in args.iter_mut() {
772                rewrite_expr_in_place(a, sinks);
773            }
774        }
775        Expr::TailCall(data) => {
776            for a in data.args.iter_mut() {
777                rewrite_expr_in_place(a, sinks);
778            }
779        }
780        Expr::BinOp(_, l, r) => {
781            rewrite_expr_in_place(l, sinks);
782            rewrite_expr_in_place(r, sinks);
783        }
784        Expr::Match { subject, arms } => {
785            rewrite_expr_in_place(subject, sinks);
786            for arm in arms.iter_mut() {
787                rewrite_expr_in_place(&mut arm.body, sinks);
788            }
789        }
790        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
791            for it in items.iter_mut() {
792                rewrite_expr_in_place(it, sinks);
793            }
794        }
795        Expr::MapLiteral(entries) => {
796            for (k, v) in entries.iter_mut() {
797                rewrite_expr_in_place(k, sinks);
798                rewrite_expr_in_place(v, sinks);
799            }
800        }
801        Expr::RecordCreate { fields, .. } => {
802            for (_, v) in fields.iter_mut() {
803                rewrite_expr_in_place(v, sinks);
804            }
805        }
806        Expr::RecordUpdate { base, updates, .. } => {
807            rewrite_expr_in_place(base, sinks);
808            for (_, v) in updates.iter_mut() {
809                rewrite_expr_in_place(v, sinks);
810            }
811        }
812        Expr::InterpolatedStr(parts) => {
813            for part in parts.iter_mut() {
814                if let crate::ast::StrPart::Parsed(inner) = part {
815                    rewrite_expr_in_place(inner, sinks);
816                }
817            }
818        }
819    }
820}
821
822/// If `expr` is a `String.join(matched_fn(args, []), sep)` with
823/// matched_fn in `sinks` and acc-position arg a literal empty list,
824/// return the rewritten Spanned<Expr>. Else return None.
825fn try_rewrite_fusion_site(
826    expr: &Spanned<Expr>,
827    sinks: &HashMap<String, BufferBuildShape>,
828) -> Option<Spanned<Expr>> {
829    let line = expr.line;
830
831    // Match the same predicate used by `find_fusion_sites` so the
832    // diagnostic count and the rewrite count are guaranteed equal.
833    let sink_name = match_string_join_fusion_site(&expr.node, sinks)?;
834    let shape = sinks.get(&sink_name)?;
835
836    // Re-extract the inner sink call and its args. The match predicate
837    // above already verified the shape; this is just to recover the
838    // pieces we need to assemble the rewrite.
839    let outer_args = match &expr.node {
840        Expr::FnCall(_, a) => a,
841        _ => return None,
842    };
843    let consumer_arg = &outer_args[0].node;
844    let inner_call_expr = if let Expr::FnCall(rev_callee, rev_args) = consumer_arg
845        && is_dotted_ident(&rev_callee.node, "List", "reverse")
846        && rev_args.len() == 1
847    {
848        &rev_args[0].node
849    } else {
850        consumer_arg
851    };
852    let inner_args = match inner_call_expr {
853        Expr::FnCall(_, a) => a,
854        _ => return None,
855    };
856
857    // Build the rewrite:
858    //   __buf_finalize(
859    //     <fn>__buffered(
860    //       <args without acc-pos>,
861    //       __buf_new(8192),
862    //       <sep>
863    //     )
864    //   )
865    let sep_expr = outer_args[1].clone();
866    let buf_new = intrinsic_call(
867        line,
868        "__buf_new",
869        vec![sp_at(line, Expr::Literal(Literal::Int(8192)))],
870    );
871    let mut buffered_args: Vec<Spanned<Expr>> = inner_args
872        .iter()
873        .enumerate()
874        .filter_map(|(i, a)| (i != shape.acc_param_idx).then_some(a).cloned())
875        .collect();
876    buffered_args.push(buf_new);
877    buffered_args.push(sep_expr);
878    let buffered_call = sp_at(
879        line,
880        Expr::FnCall(
881            Box::new(sp_at(line, Expr::Ident(format!("{}__buffered", sink_name)))),
882            buffered_args,
883        ),
884    );
885    Some(intrinsic_call(line, "__buf_finalize", vec![buffered_call]))
886}
887
888/// Construct the buffered FnDef for a single matched fn. Returns
889/// `None` if the original body shape doesn't match what we expect
890/// (defensive: detection should have caught this, but if the body
891/// changed shape between detection and synthesis, skip).
892fn build_buffered_variant(fd: &FnDef, shape: &BufferBuildShape) -> Option<FnDef> {
893    // Original body: `match <subject> { <terminating-arm>; <recursive-arm> }`.
894    // The two BufferBuildKind variants pair different patterns:
895    //   InternalReverse: `true -> List.reverse(acc)`, `false -> recurse(...)`.
896    //   ExternalReverse: `[] -> acc`, `[head, ..rest] -> recurse(...)`.
897    // We extract the recursive arm in both cases (for the prepend tail
898    // call) and rebuild the match with terminating arm `... -> __buf`.
899    let stmts = fd.body.stmts();
900    if stmts.len() != 1 {
901        return None;
902    }
903    let outer_expr = match &stmts[0] {
904        Stmt::Expr(spanned) => spanned,
905        _ => return None,
906    };
907    let (subject_orig, arms_orig) = match &outer_expr.node {
908        Expr::Match { subject, arms } => (subject, arms),
909        _ => return None,
910    };
911    let recursive_body: &Spanned<Expr> = match shape.kind {
912        BufferBuildKind::InternalReverse => arms_orig
913            .iter()
914            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
915            .map(|a| a.body.as_ref())?,
916        BufferBuildKind::ExternalReverse => arms_orig
917            .iter()
918            .find(|a| matches!(a.pattern, Pattern::Cons(_, _)))
919            .map(|a| a.body.as_ref())?,
920    };
921    let tail_data = match &recursive_body.node {
922        Expr::TailCall(data) => data,
923        _ => return None,
924    };
925
926    // The acc-position arg in the original tail call is
927    // `List.prepend(<elem>, acc)`. Extract the element expression.
928    let acc_arg_orig = tail_data.args.get(shape.acc_param_idx)?;
929    let elem_expr = match &acc_arg_orig.node {
930        Expr::FnCall(callee, args) => {
931            if !is_dotted_ident(&callee.node, "List", "prepend") {
932                return None;
933            }
934            if args.len() != 2 {
935                return None;
936            }
937            // args[0] is elem, args[1] is acc ident — verify acc.
938            match &args[1].node {
939                Expr::Ident(name) if name == &shape.acc_param_name => {}
940                _ => return None,
941            }
942            args[0].clone()
943        }
944        _ => return None,
945    };
946
947    let line = fd.line;
948    let buf_name = "__buf";
949    let sep_name = "__sep";
950    let buffered_target = format!("{}__buffered", fd.name);
951
952    // Synthesized false arm body:
953    //   <self>__buffered(<orig args minus acc>, __buf_append(<sep_unless_first>, <elem>), __sep)
954    //
955    // Build the buffer-threading expression first: the inner intrinsic
956    // appends `__sep` if the buffer is non-empty (otherwise no-op),
957    // returning the possibly-grown buffer. The outer intrinsic appends
958    // the user's element. The result is what gets passed as the
959    // buffered variant's `__buf` arg in the recursive call.
960    let buf_ident = || sp_at(line, Expr::Ident(buf_name.to_string()));
961    let sep_ident = || sp_at(line, Expr::Ident(sep_name.to_string()));
962    let sep_then_buf = intrinsic_call(
963        line,
964        "__buf_append_sep_unless_first",
965        vec![buf_ident(), sep_ident()],
966    );
967    let final_buf = intrinsic_call(line, "__buf_append", vec![sep_then_buf, elem_expr]);
968
969    // Build new tail-call args: original args with acc-pos replaced by
970    // the threaded buffer expression, then `__sep` appended at end.
971    let mut new_args: Vec<Spanned<Expr>> = tail_data
972        .args
973        .iter()
974        .enumerate()
975        .map(|(i, a)| {
976            if i == shape.acc_param_idx {
977                final_buf.clone()
978            } else {
979                a.clone()
980            }
981        })
982        .collect();
983    new_args.push(sep_ident());
984
985    let new_recursive_body = sp_at(
986        line,
987        Expr::TailCall(Box::new(TailCallData {
988            target: buffered_target.clone(),
989            args: new_args,
990        })),
991    );
992
993    // Terminating arm body: just return `__buf` — the buffer IS the result.
994    // Pattern depends on which sink shape we matched: `true` for the
995    // InternalReverse idiom (where the original returned `List.reverse(acc)`),
996    // `[]` for ExternalReverse (where the original returned `acc`).
997    let new_arms = match shape.kind {
998        BufferBuildKind::InternalReverse => vec![
999            MatchArm {
1000                pattern: Pattern::Literal(Literal::Bool(true)),
1001                body: Box::new(buf_ident()),
1002            },
1003            MatchArm {
1004                pattern: Pattern::Literal(Literal::Bool(false)),
1005                body: Box::new(new_recursive_body),
1006            },
1007        ],
1008        BufferBuildKind::ExternalReverse => {
1009            // Re-use the cons binding names from the original arm so
1010            // any `head` / `rest` references inside the recursive body
1011            // continue to resolve.
1012            let cons_pat = arms_orig
1013                .iter()
1014                .find_map(|a| match &a.pattern {
1015                    Pattern::Cons(h, t) => Some(Pattern::Cons(h.clone(), t.clone())),
1016                    _ => None,
1017                })
1018                .unwrap_or(Pattern::Cons("__head".to_string(), "__tail".to_string()));
1019            vec![
1020                MatchArm {
1021                    pattern: Pattern::EmptyList,
1022                    body: Box::new(buf_ident()),
1023                },
1024                MatchArm {
1025                    pattern: cons_pat,
1026                    body: Box::new(new_recursive_body),
1027                },
1028            ]
1029        }
1030    };
1031
1032    let new_match = sp_at(
1033        line,
1034        Expr::Match {
1035            subject: subject_orig.clone(),
1036            arms: new_arms,
1037        },
1038    );
1039
1040    let new_body = FnBody::Block(vec![Stmt::Expr(new_match)]);
1041
1042    // Params: original minus acc + (__buf, "Buffer") + (__sep, "String").
1043    let mut new_params: Vec<(String, String)> = fd
1044        .params
1045        .iter()
1046        .enumerate()
1047        .filter_map(|(i, p)| (i != shape.acc_param_idx).then_some(p).cloned())
1048        .collect();
1049    new_params.push((buf_name.to_string(), "Buffer".to_string()));
1050    new_params.push((sep_name.to_string(), "String".to_string()));
1051
1052    Some(FnDef {
1053        name: buffered_target,
1054        line,
1055        params: new_params,
1056        return_type: "Buffer".to_string(),
1057        // Synthesized variants inherit effects from the original — if
1058        // the matched fn calls effectful helpers (like `renderRow`
1059        // calling `Console.print`), the buffered variant calls them
1060        // too at the same positions. Conservative.
1061        effects: fd.effects.clone(),
1062        desc: Some(format!(
1063            "Synthesized buffered variant of `{}` for deforestation \
1064             lowering. Call sites that match `String.join({}(...), sep)` \
1065             are rewritten to alloc a buffer + call this variant + \
1066             finalize, skipping the intermediate List.",
1067            fd.name, fd.name
1068        )),
1069        body: Arc::new(new_body),
1070        resolution: None,
1071    })
1072}
1073
1074#[cfg(test)]
1075mod tests {
1076    use super::*;
1077    use crate::ast::{BinOp, FnBody, FnDef, Literal, Spanned, TailCallData};
1078    use std::sync::Arc;
1079
1080    fn sp<T>(value: T) -> Spanned<T> {
1081        Spanned {
1082            node: value,
1083            line: 1,
1084        }
1085    }
1086
1087    fn ident(name: &str) -> Spanned<Expr> {
1088        sp(Expr::Ident(name.to_string()))
1089    }
1090
1091    fn dotted(module: &str, member: &str) -> Spanned<Expr> {
1092        sp(Expr::Attr(Box::new(ident(module)), member.to_string()))
1093    }
1094
1095    fn call(callee: Spanned<Expr>, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
1096        sp(Expr::FnCall(Box::new(callee), args))
1097    }
1098
1099    /// Build a canonical buffer-build fn: takes (col: Int, acc: List<Int>),
1100    /// matches col >= 10, true → reverse(acc), false → tail-call self
1101    /// with prepend(col, acc).
1102    fn canonical_builder(name: &str) -> FnDef {
1103        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1104        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1105        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1106            target: name.to_string(),
1107            args: vec![
1108                sp(Expr::BinOp(
1109                    BinOp::Add,
1110                    Box::new(ident("col")),
1111                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
1112                )),
1113                prepend,
1114            ],
1115        })));
1116        let match_expr = sp(Expr::Match {
1117            subject: Box::new(sp(Expr::BinOp(
1118                BinOp::Gte,
1119                Box::new(ident("col")),
1120                Box::new(sp(Expr::Literal(Literal::Int(10)))),
1121            ))),
1122            arms: vec![
1123                MatchArm {
1124                    pattern: Pattern::Literal(Literal::Bool(true)),
1125                    body: Box::new(true_body),
1126                },
1127                MatchArm {
1128                    pattern: Pattern::Literal(Literal::Bool(false)),
1129                    body: Box::new(false_body),
1130                },
1131            ],
1132        });
1133        FnDef {
1134            name: name.to_string(),
1135            line: 1,
1136            params: vec![
1137                ("col".to_string(), "Int".to_string()),
1138                ("acc".to_string(), "List<Int>".to_string()),
1139            ],
1140            return_type: "List<Int>".to_string(),
1141            effects: vec![],
1142            desc: None,
1143            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1144            resolution: None,
1145        }
1146    }
1147
1148    #[test]
1149    fn matches_canonical_buffer_build() {
1150        let fd = canonical_builder("build");
1151        let info = compute_buffer_build_sinks(&[&fd]);
1152        let shape = info.get("build").expect("expected match");
1153        assert_eq!(shape.acc_param_idx, 1);
1154        assert_eq!(shape.acc_param_name, "acc");
1155    }
1156
1157    #[test]
1158    fn rejects_fn_without_list_param() {
1159        let mut fd = canonical_builder("build");
1160        // Strip the List<...> param.
1161        fd.params = vec![("col".to_string(), "Int".to_string())];
1162        let info = compute_buffer_build_sinks(&[&fd]);
1163        assert!(info.is_empty(), "fn without List param should not match");
1164    }
1165
1166    #[test]
1167    fn rejects_when_true_arm_isnt_reverse() {
1168        let mut fd = canonical_builder("build");
1169        // Replace true arm body with a different expression.
1170        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1171            if let Stmt::Expr(spanned) = &mut stmts[0] {
1172                if let Expr::Match { arms, .. } = &mut spanned.node {
1173                    arms[0].body = Box::new(ident("acc"));
1174                }
1175            }
1176        }
1177        let info = compute_buffer_build_sinks(&[&fd]);
1178        assert!(
1179            info.is_empty(),
1180            "fn returning bare acc instead of reverse should not match"
1181        );
1182    }
1183
1184    #[test]
1185    fn rejects_when_false_arm_uses_append_not_prepend() {
1186        let mut fd = canonical_builder("build");
1187        // Swap List.prepend → List.append in the false arm tail call.
1188        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1189            if let Stmt::Expr(spanned) = &mut stmts[0] {
1190                if let Expr::Match { arms, .. } = &mut spanned.node {
1191                    let false_body = arms[1].body.as_mut();
1192                    if let Expr::TailCall(data) = &mut false_body.node {
1193                        if let Expr::FnCall(callee, _) = &mut data.args[1].node {
1194                            if let Expr::Attr(_, attr) = &mut callee.node {
1195                                *attr = "append".to_string();
1196                            }
1197                        }
1198                    }
1199                }
1200            }
1201        }
1202        let info = compute_buffer_build_sinks(&[&fd]);
1203        assert!(
1204            info.is_empty(),
1205            "fn using List.append instead of prepend should not match"
1206        );
1207    }
1208
1209    #[test]
1210    fn rejects_tail_call_to_different_fn() {
1211        let mut fd = canonical_builder("build");
1212        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1213            if let Stmt::Expr(spanned) = &mut stmts[0] {
1214                if let Expr::Match { arms, .. } = &mut spanned.node {
1215                    let false_body = arms[1].body.as_mut();
1216                    if let Expr::TailCall(data) = &mut false_body.node {
1217                        data.target = "someone_else".to_string();
1218                    }
1219                }
1220            }
1221        }
1222        let info = compute_buffer_build_sinks(&[&fd]);
1223        assert!(
1224            info.is_empty(),
1225            "fn whose recursive call targets a different name should not match"
1226        );
1227    }
1228
1229    #[test]
1230    fn rejects_match_with_non_bool_arms() {
1231        let mut fd = canonical_builder("build");
1232        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1233            if let Stmt::Expr(spanned) = &mut stmts[0] {
1234                if let Expr::Match { arms, .. } = &mut spanned.node {
1235                    arms[0].pattern = Pattern::Literal(Literal::Int(0));
1236                }
1237            }
1238        }
1239        let info = compute_buffer_build_sinks(&[&fd]);
1240        assert!(
1241            info.is_empty(),
1242            "match on non-bool patterns should not be detected as buffer-build"
1243        );
1244    }
1245
1246    /// End-to-end: parse a small Aver source, run TCO, then detect.
1247    /// The TCO transform is what produces `Expr::TailCall` nodes from
1248    /// raw `Expr::FnCall` self-recursion; detection runs on the post-TCO
1249    /// AST.
1250    #[test]
1251    fn detects_via_parser_after_tco() {
1252        let src = r#"
1253fn build(n: Int, acc: List<Int>) -> List<Int>
1254    match n <= 0
1255        true  -> List.reverse(acc)
1256        false -> build(n - 1, List.prepend(n, acc))
1257"#;
1258        let mut lexer = crate::lexer::Lexer::new(src);
1259        let tokens = lexer.tokenize().expect("lex");
1260        let mut parser = crate::parser::Parser::new(tokens);
1261        let mut items = parser.parse().expect("parse");
1262        crate::ir::pipeline::tco(&mut items);
1263        let fns: Vec<&FnDef> = items
1264            .iter()
1265            .filter_map(|it| match it {
1266                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1267                _ => None,
1268            })
1269            .collect();
1270        let info = compute_buffer_build_sinks(&fns);
1271        let shape = info
1272            .get("build")
1273            .expect("expected end-to-end shape match for canonical builder");
1274        assert_eq!(shape.acc_param_idx, 1);
1275        assert_eq!(shape.acc_param_name, "acc");
1276    }
1277
1278    /// End-to-end fusion-site detection: builder + caller `String.join`
1279    /// site recognised, line recorded, sink name attached.
1280    #[test]
1281    fn finds_fusion_site_via_parser() {
1282        let src = r#"
1283fn build(n: Int, acc: List<Int>) -> List<Int>
1284    match n <= 0
1285        true  -> List.reverse(acc)
1286        false -> build(n - 1, List.prepend(n, acc))
1287
1288fn main() -> String
1289    String.join(build(5, []), ",")
1290"#;
1291        let mut lexer = crate::lexer::Lexer::new(src);
1292        let tokens = lexer.tokenize().expect("lex");
1293        let mut parser = crate::parser::Parser::new(tokens);
1294        let mut items = parser.parse().expect("parse");
1295        crate::ir::pipeline::tco(&mut items);
1296        let fns: Vec<&FnDef> = items
1297            .iter()
1298            .filter_map(|it| match it {
1299                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1300                _ => None,
1301            })
1302            .collect();
1303        let sinks = compute_buffer_build_sinks(&fns);
1304        let sites = find_fusion_sites(&fns, &sinks);
1305        assert_eq!(sites.len(), 1, "expected one fusion site, got {sites:?}");
1306        let site = &sites[0];
1307        assert_eq!(site.enclosing_fn, "main");
1308        assert_eq!(site.sink_fn, "build");
1309        assert!(site.line > 0, "expected real line info, got 0");
1310    }
1311
1312    /// Caller passes the matched fn's result to a non-`String.join`
1313    /// destination — should NOT register as a fusion site (no buffer
1314    /// to write into).
1315    #[test]
1316    fn ignores_call_when_not_wrapped_in_string_join() {
1317        let src = r#"
1318fn build(n: Int, acc: List<Int>) -> List<Int>
1319    match n <= 0
1320        true  -> List.reverse(acc)
1321        false -> build(n - 1, List.prepend(n, acc))
1322
1323fn main() -> List<Int>
1324    build(5, [])
1325"#;
1326        let mut lexer = crate::lexer::Lexer::new(src);
1327        let tokens = lexer.tokenize().expect("lex");
1328        let mut parser = crate::parser::Parser::new(tokens);
1329        let mut items = parser.parse().expect("parse");
1330        crate::ir::pipeline::tco(&mut items);
1331        let fns: Vec<&FnDef> = items
1332            .iter()
1333            .filter_map(|it| match it {
1334                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1335                _ => None,
1336            })
1337            .collect();
1338        let sinks = compute_buffer_build_sinks(&fns);
1339        let sites = find_fusion_sites(&fns, &sinks);
1340        assert!(
1341            sites.is_empty(),
1342            "build called outside String.join must not be a fusion site, got {sites:?}"
1343        );
1344    }
1345
1346    /// Counter-test: a recursive fn that returns `acc` directly (no
1347    /// reverse) — semantically valid Aver, but its result order is
1348    /// reversed relative to natural read order, so deforestation can't
1349    /// safely rewrite to a forward-emit buffer loop without explicit
1350    /// authorisation. Detector must reject it.
1351    #[test]
1352    fn rejects_via_parser_when_true_arm_returns_bare_acc() {
1353        let src = r#"
1354fn build(n: Int, acc: List<Int>) -> List<Int>
1355    match n <= 0
1356        true  -> acc
1357        false -> build(n - 1, List.prepend(n, acc))
1358"#;
1359        let mut lexer = crate::lexer::Lexer::new(src);
1360        let tokens = lexer.tokenize().expect("lex");
1361        let mut parser = crate::parser::Parser::new(tokens);
1362        let mut items = parser.parse().expect("parse");
1363        crate::ir::pipeline::tco(&mut items);
1364        let fns: Vec<&FnDef> = items
1365            .iter()
1366            .filter_map(|it| match it {
1367                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1368                _ => None,
1369            })
1370            .collect();
1371        let info = compute_buffer_build_sinks(&fns);
1372        assert!(
1373            info.is_empty(),
1374            "fn returning bare acc must not be detected as a deforestation candidate"
1375        );
1376    }
1377
1378    /// End-to-end synthesis: parse a small builder, run TCO, detect
1379    /// it as a sink, then synthesize the buffered variant. Verify the
1380    /// shape: name suffix, dropped acc param, added __buf/__sep
1381    /// params, true arm returns __buf ident, false arm tail-calls
1382    /// __buffered self with threaded buffer expression.
1383    #[test]
1384    fn synthesizes_buffered_variant_from_real_builder() {
1385        let src = r#"
1386fn build(n: Int, acc: List<Int>) -> List<Int>
1387    match n <= 0
1388        true  -> List.reverse(acc)
1389        false -> build(n - 1, List.prepend(n, acc))
1390"#;
1391        let mut lexer = crate::lexer::Lexer::new(src);
1392        let tokens = lexer.tokenize().expect("lex");
1393        let mut parser = crate::parser::Parser::new(tokens);
1394        let mut items = parser.parse().expect("parse");
1395        crate::ir::pipeline::tco(&mut items);
1396        let fns: Vec<&FnDef> = items
1397            .iter()
1398            .filter_map(|it| match it {
1399                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1400                _ => None,
1401            })
1402            .collect();
1403        let sinks = compute_buffer_build_sinks(&fns);
1404        assert!(sinks.contains_key("build"));
1405        let synthesized = synthesize_buffered_variants(&fns, &sinks);
1406        assert_eq!(
1407            synthesized.len(),
1408            1,
1409            "expected exactly one synthesized variant"
1410        );
1411        let bf = &synthesized[0];
1412
1413        // Name + signature shape.
1414        assert_eq!(bf.name, "build__buffered");
1415        assert_eq!(bf.return_type, "Buffer");
1416        let param_names: Vec<&str> = bf.params.iter().map(|(n, _)| n.as_str()).collect();
1417        let param_types: Vec<&str> = bf.params.iter().map(|(_, t)| t.as_str()).collect();
1418        assert_eq!(param_names, vec!["n", "__buf", "__sep"]);
1419        assert_eq!(param_types, vec!["Int", "Buffer", "String"]);
1420
1421        // Body: single Stmt::Expr holding a 2-arm match.
1422        let stmts = bf.body.stmts();
1423        assert_eq!(stmts.len(), 1);
1424        let match_expr = match &stmts[0] {
1425            Stmt::Expr(s) => match &s.node {
1426                Expr::Match { subject: _, arms } => arms,
1427                _ => panic!("body root must be a match"),
1428            },
1429            _ => panic!("body root must be Stmt::Expr"),
1430        };
1431        assert_eq!(match_expr.len(), 2);
1432
1433        // True arm: body is `__buf` ident.
1434        let true_arm = match_expr
1435            .iter()
1436            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(true))))
1437            .expect("true arm");
1438        match &true_arm.body.node {
1439            Expr::Ident(name) => assert_eq!(name, "__buf"),
1440            other => panic!("true arm should be Ident(__buf), got {other:?}"),
1441        }
1442
1443        // False arm: tail-call to build__buffered with threaded buf.
1444        let false_arm = match_expr
1445            .iter()
1446            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
1447            .expect("false arm");
1448        let tail_data = match &false_arm.body.node {
1449            Expr::TailCall(d) => d,
1450            other => panic!("false arm should be TailCall, got {other:?}"),
1451        };
1452        assert_eq!(tail_data.target, "build__buffered");
1453        // Args: [n - 1, threaded-buffer-expr, __sep_ident]. acc-pos
1454        // (was index 1 in original) is now the threaded buffer; sep
1455        // appended at end.
1456        assert_eq!(tail_data.args.len(), 3);
1457        // Arg 1 is the buffer-threading composition; verify it's
1458        // `__buf_append(__buf_append_sep_unless_first(__buf, __sep), n)`.
1459        let outer = match &tail_data.args[1].node {
1460            Expr::FnCall(callee, args) => {
1461                match &callee.node {
1462                    Expr::Ident(name) => assert_eq!(name, "__buf_append"),
1463                    _ => panic!("expected Ident callee"),
1464                }
1465                args
1466            }
1467            _ => panic!("expected outer __buf_append FnCall"),
1468        };
1469        assert_eq!(outer.len(), 2);
1470        // First arg of outer = inner sep-then-buf.
1471        match &outer[0].node {
1472            Expr::FnCall(callee, _) => match &callee.node {
1473                Expr::Ident(name) => assert_eq!(name, "__buf_append_sep_unless_first"),
1474                _ => panic!("expected Ident callee for inner intrinsic"),
1475            },
1476            _ => panic!("expected inner __buf_append_sep_unless_first FnCall"),
1477        }
1478        // Second arg of outer = original `n` (the prepend's element).
1479        match &outer[1].node {
1480            Expr::Ident(name) => assert_eq!(name, "n"),
1481            _ => panic!("expected `n` ident as elem"),
1482        }
1483        // Last tail-call arg = __sep ident.
1484        match &tail_data.args[2].node {
1485            Expr::Ident(name) => assert_eq!(name, "__sep"),
1486            _ => panic!("expected __sep ident as last arg"),
1487        }
1488    }
1489
1490    #[test]
1491    fn detects_acc_param_at_arbitrary_index() {
1492        // Builder where the List<T> param is first and the tail-call
1493        // body wires the prepend at the same index. Detection has to
1494        // pin the acc position to where the prepend actually lands —
1495        // an earlier loose `any` check would silently pass even on
1496        // mismatched param/arg orderings, then synthesis would fail
1497        // to extract the element expression. Keep the body and the
1498        // params consistent so we exercise the real path.
1499        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1500        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1501        // Tail call: build(prepend(col, acc), col + 1)
1502        // — acc-position arg is at index 0, col+1 at index 1.
1503        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1504            target: "build".to_string(),
1505            args: vec![
1506                prepend,
1507                sp(Expr::BinOp(
1508                    BinOp::Add,
1509                    Box::new(ident("col")),
1510                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
1511                )),
1512            ],
1513        })));
1514        let match_expr = sp(Expr::Match {
1515            subject: Box::new(sp(Expr::BinOp(
1516                BinOp::Gte,
1517                Box::new(ident("col")),
1518                Box::new(sp(Expr::Literal(Literal::Int(10)))),
1519            ))),
1520            arms: vec![
1521                MatchArm {
1522                    pattern: Pattern::Literal(Literal::Bool(true)),
1523                    body: Box::new(true_body),
1524                },
1525                MatchArm {
1526                    pattern: Pattern::Literal(Literal::Bool(false)),
1527                    body: Box::new(false_body),
1528                },
1529            ],
1530        });
1531        let fd = FnDef {
1532            name: "build".to_string(),
1533            line: 1,
1534            params: vec![
1535                ("acc".to_string(), "List<Int>".to_string()),
1536                ("col".to_string(), "Int".to_string()),
1537            ],
1538            return_type: "List<Int>".to_string(),
1539            effects: vec![],
1540            desc: None,
1541            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1542            resolution: None,
1543        };
1544        let info = compute_buffer_build_sinks(&[&fd]);
1545        let shape = info.get("build").expect("expected match");
1546        assert_eq!(shape.acc_param_idx, 0);
1547        assert_eq!(shape.acc_param_name, "acc");
1548    }
1549
1550    #[test]
1551    fn rejects_loose_prepend_in_non_acc_position() {
1552        // Earlier the detector accepted a fn whose tail call had a
1553        // prepend in *some* arg, regardless of position. That let
1554        // detection promise a sink the synthesizer couldn't actually
1555        // build. Make sure the tightened predicate refuses this.
1556        let mut fd = canonical_builder("build");
1557        // Reorder tail-call args so prepend ends up at index 0 instead
1558        // of index 1 — but keep params [(col, Int), (acc, List<Int>)],
1559        // so acc-position is index 1, where there's now a `col + 1`
1560        // expression (no prepend). Detection should refuse.
1561        {
1562            let body = std::sync::Arc::make_mut(&mut fd.body);
1563            let FnBody::Block(stmts) = body;
1564            if let Stmt::Expr(spanned) = &mut stmts[0]
1565                && let Expr::Match { arms, .. } = &mut spanned.node
1566            {
1567                for arm in arms.iter_mut() {
1568                    if matches!(arm.pattern, Pattern::Literal(Literal::Bool(false)))
1569                        && let Expr::TailCall(data) = &mut arm.body.node
1570                    {
1571                        data.args.reverse();
1572                    }
1573                }
1574            }
1575        }
1576        let info = compute_buffer_build_sinks(&[&fd]);
1577        assert!(
1578            info.get("build").is_none(),
1579            "loose-prepend (prepend not at acc-position) must not be detected"
1580        );
1581    }
1582
1583    #[test]
1584    fn skips_synth_when_no_rewriteable_call_site() {
1585        // A fn that matches the sink shape but whose only call site
1586        // doesn't fit the canonical fusion pattern (e.g. starts with a
1587        // non-empty initial accumulator, or the wrapper is an unrelated
1588        // function call rather than `String.join`) should NOT get a
1589        // synthesized `__buffered` variant. Generating one is bloat
1590        // and risks shadowing user fns.
1591        let sink = canonical_builder("build");
1592        // Dummy caller that uses `build` but not via `String.join(...)`.
1593        let caller = FnDef {
1594            name: "use_build".to_string(),
1595            line: 2,
1596            params: vec![],
1597            return_type: "List<Int>".to_string(),
1598            effects: vec![],
1599            desc: None,
1600            body: Arc::new(FnBody::Block(vec![Stmt::Expr(call(
1601                ident_expr("build"),
1602                vec![sp(Expr::Literal(Literal::Int(0))), sp(Expr::List(vec![]))],
1603            ))])),
1604            resolution: None,
1605        };
1606        let mut items = vec![
1607            crate::ast::TopLevel::FnDef(sink),
1608            crate::ast::TopLevel::FnDef(caller),
1609        ];
1610        let initial_count = items.len();
1611        let report = run_buffer_build_pass(&mut items);
1612        assert_eq!(report.rewrites, 0, "no fusion sites — no rewriteable call");
1613        assert_eq!(
1614            report.synthesized.len(),
1615            0,
1616            "no synth — nothing to fuse against"
1617        );
1618        assert_eq!(items.len(), initial_count, "no buffered variant appended");
1619    }
1620
1621    #[test]
1622    fn external_reverse_pattern_round_trips() {
1623        // `match list { [] -> acc; [h, ..t] -> recurse(t, prepend(_, acc)) }`
1624        // sink + `String.join(List.reverse(<sink>(args, [])), sep)` call
1625        // site should detect, synth, and rewrite as a single fusion.
1626        let nil_body = ident("acc");
1627        let prepend = call(dotted("List", "prepend"), vec![ident("h"), ident("acc")]);
1628        let cons_body = sp(Expr::TailCall(Box::new(TailCallData {
1629            target: "build".to_string(),
1630            args: vec![ident("t"), prepend],
1631        })));
1632        let match_expr = sp(Expr::Match {
1633            subject: Box::new(ident("xs")),
1634            arms: vec![
1635                MatchArm {
1636                    pattern: Pattern::EmptyList,
1637                    body: Box::new(nil_body),
1638                },
1639                MatchArm {
1640                    pattern: Pattern::Cons("h".to_string(), "t".to_string()),
1641                    body: Box::new(cons_body),
1642                },
1643            ],
1644        });
1645        let sink = FnDef {
1646            name: "build".to_string(),
1647            line: 1,
1648            params: vec![
1649                ("xs".to_string(), "List<Int>".to_string()),
1650                ("acc".to_string(), "List<String>".to_string()),
1651            ],
1652            return_type: "List<String>".to_string(),
1653            effects: vec![],
1654            desc: None,
1655            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1656            resolution: None,
1657        };
1658        let info = compute_buffer_build_sinks(&[&sink]);
1659        let shape = info
1660            .get("build")
1661            .expect("external-reverse sink should be detected");
1662        assert_eq!(shape.kind, BufferBuildKind::ExternalReverse);
1663        assert_eq!(shape.acc_param_idx, 1);
1664
1665        // Caller: `String.join(List.reverse(build(xs, [])), "\n")`
1666        let join_call = call(
1667            dotted("String", "join"),
1668            vec![
1669                call(
1670                    dotted("List", "reverse"),
1671                    vec![call(
1672                        ident_expr("build"),
1673                        vec![ident("xs"), sp(Expr::List(vec![]))],
1674                    )],
1675                ),
1676                sp(Expr::Literal(Literal::Str("\n".to_string()))),
1677            ],
1678        );
1679        let caller = FnDef {
1680            name: "render".to_string(),
1681            line: 2,
1682            params: vec![("xs".to_string(), "List<Int>".to_string())],
1683            return_type: "String".to_string(),
1684            effects: vec![],
1685            desc: None,
1686            body: Arc::new(FnBody::Block(vec![Stmt::Expr(join_call)])),
1687            resolution: None,
1688        };
1689
1690        let mut items = vec![
1691            crate::ast::TopLevel::FnDef(sink),
1692            crate::ast::TopLevel::FnDef(caller),
1693        ];
1694        let report = run_buffer_build_pass(&mut items);
1695        assert_eq!(
1696            report.rewrites, 1,
1697            "external-reverse pattern should be one fusion site"
1698        );
1699        assert_eq!(
1700            report.synthesized.len(),
1701            1,
1702            "exactly one buffered variant for the used sink"
1703        );
1704
1705        // The synthesized variant should be appended.
1706        let synth_present = items.iter().any(|it| match it {
1707            crate::ast::TopLevel::FnDef(fd) => fd.name == "build__buffered",
1708            _ => false,
1709        });
1710        assert!(synth_present, "build__buffered must be appended");
1711    }
1712
1713    fn ident_expr(name: &str) -> Spanned<Expr> {
1714        sp(Expr::Ident(name.to_string()))
1715    }
1716}