Skip to main content

aver/ir/
buffer_build.rs

1//! Buffer-build sink detection.
2//!
3//! Identifies user fns that match the canonical functional list-builder
4//! shape consumed by `String.join`:
5//!
6//! ```aver
7//! fn build(..., acc: List<T>) -> List<T>
8//!     match <cond>
9//!         true  -> List.reverse(acc)
10//!         false -> build(..., List.prepend(<elem>, acc))
11//! ```
12//!
13//! When such a fn is called from `String.join(build(..., []), sep)`, the
14//! whole pipeline is semantically equivalent to a single buffer-write
15//! loop — deforestation (Wadler 1990), here in its shortcut-fusion form. This module is
16//! Phase 1 of the deforestation work for 0.15 "Traversal": it detects
17//! candidate fns. Lowering (rewriting matched fns + their `String.join`
18//! call sites) lives in a separate pass.
19//!
20//! Detection is intentionally local — the analyzer looks only at the fn
21//! body, not its call sites. A matched fn may or may not actually be
22//! consumed by `String.join`; the lowering pass cross-references call
23//! sites separately and only fuses when both ends of the pipeline agree.
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use crate::ast::{Expr, FnBody, FnDef, Literal, MatchArm, Pattern, Spanned, Stmt, TailCallData};
29
/// Where a matched builder performs the single `List.reverse` that puts
/// its result back into forward order.
///
/// Accumulating with `List.prepend(elem, acc)` yields elements in
/// reverse processing order, so exactly one reverse is needed before the
/// list is handed to `String.join`. Aver code commonly puts that reverse
/// in one of two places; each place is one variant below.
///
/// Both variants lower to the same buffered variant: appending in
/// processing order (the forward order of the input) already produces
/// the final string in the right order, so no explicit reverse remains.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BufferBuildKind {
    /// The sink reverses internally: its base case is
    /// `true -> List.reverse(acc)`, and callers write
    /// `String.join(<sink>(args, []), sep)`. This is the classic
    /// "loop with reversed accumulator, reverse on exit" shape.
    InternalReverse,
    /// The sink's base case is `[] -> acc` (no reverse) and it matches on
    /// its input list directly; callers write
    /// `String.join(List.reverse(<sink>(args, [])), sep)`. Common in the
    /// payment_ops / workflow_engine codebases under the `*Into` naming
    /// convention (e.g. `serializeEntriesInto`, `filterSubjectInto`).
    ExternalReverse,
}
56
57/// Information about a fn that matches the buffer-build sink shape.
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct BufferBuildShape {
60    /// 0-based index of the `acc: List<T>` parameter in the fn signature.
61    /// Identifies which arg in tail-call positions threads the
62    /// accumulator and which `Ident` in the base-case arm is returned.
63    pub acc_param_idx: usize,
64    /// The accumulator parameter's binding name (looked up in tail-call
65    /// args and in the base-case return position).
66    pub acc_param_name: String,
67    /// Which of the two reverse-placement idioms this sink follows;
68    /// determines (a) what shape the original base arm has and (b)
69    /// whether the call site is `String.join(<sink>(...), sep)` or
70    /// `String.join(List.reverse(<sink>(...)), sep)`.
71    pub kind: BufferBuildKind,
72}
73
/// The consumer a matched builder feeds into.
///
/// Each consumer compiles to its own buffer type and finalizer. All of
/// them share the same deforestation idea: skip the intermediate List
/// and write elements straight into the consumer's storage.
///
/// Phase 2 only implements `StringJoin`, the canonical case from the
/// fractal demo. Later phases are expected to add `VectorFromList`
/// (already half-fused via `Vector.set` owned-mutate in 0.14.0;
/// deforestation closes the cons-cell side) and `ListFold` for
/// stream-fusion-style consumer rewrites.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConsumerKind {
    /// `String.join(builder(...), sep)`: each element plus separator is
    /// written directly into a `Vec<u8>`-shaped buffer in linear memory.
    StringJoin,
}
90
91/// One detected fusion site: a builder call whose result is consumed
92/// by a known sink (currently just `String.join`). Lowering rewrites
93/// the producer + consumer pair into a single buffer-write loop.
94#[derive(Debug, Clone, PartialEq, Eq)]
95pub struct FusionSite {
96    /// Name of the enclosing user fn that contains the call.
97    pub enclosing_fn: String,
98    /// Line of the consumer call.
99    pub line: usize,
100    /// The matched buffer-build fn being wrapped.
101    pub sink_fn: String,
102    /// What's consuming the builder's result.
103    pub consumer: ConsumerKind,
104}
105
106/// Walk all fns in `fns`, return a map from fn name to detected shape
107/// for fns that match the buffer-build sink pattern. Fns that don't
108/// match are absent from the result.
109pub fn compute_buffer_build_sinks(fns: &[&FnDef]) -> HashMap<String, BufferBuildShape> {
110    let mut out = HashMap::new();
111    for fd in fns {
112        if let Some(shape) = match_buffer_build_shape(fd) {
113            out.insert(fd.name.clone(), shape);
114        }
115    }
116    out
117}
118
119/// Walk every expression in every fn body looking for fusion sites:
120/// `String.join(matched_fn(...), sep)` calls where `matched_fn` is a
121/// key in `sinks`. Returns one `FusionSite` per call. The lowering
122/// pass rewrites each site to call a buffered variant of `matched_fn`
123/// directly into a pre-allocated buffer.
124pub fn find_fusion_sites(
125    fns: &[&FnDef],
126    sinks: &HashMap<String, BufferBuildShape>,
127) -> Vec<FusionSite> {
128    let mut out = Vec::new();
129    for fd in fns {
130        for stmt in fd.body.stmts() {
131            match stmt {
132                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
133                    walk_expr_for_fusion_sites(&expr.node, expr.line, &fd.name, sinks, &mut out);
134                }
135            }
136        }
137    }
138    out
139}
140
141/// Recursively walk an expression tree, recording any fusion site we
142/// find. The fallback `expr_line` is used when a sub-expression has no
143/// own line info.
144fn walk_expr_for_fusion_sites(
145    expr: &Expr,
146    expr_line: usize,
147    enclosing_fn: &str,
148    sinks: &HashMap<String, BufferBuildShape>,
149    out: &mut Vec<FusionSite>,
150) {
151    if let Some(inner_name) = match_string_join_fusion_site(expr, sinks) {
152        out.push(FusionSite {
153            enclosing_fn: enclosing_fn.to_string(),
154            line: expr_line,
155            sink_fn: inner_name,
156            consumer: ConsumerKind::StringJoin,
157        });
158    }
159    // Recurse into all sub-expressions regardless of whether this node
160    // matched (a fusion site can sit inside another fusion site's args
161    // — rare but valid; we'd record both and let the lowering decide).
162    visit_subexprs(expr, expr_line, enclosing_fn, sinks, out);
163}
164
165/// Helper: recurse into the sub-expressions of `expr`. Mirrors the
166/// shape coverage of `expr_allocates` in `alloc_info.rs` so we don't
167/// miss any node kind.
168fn visit_subexprs(
169    expr: &Expr,
170    fallback_line: usize,
171    enclosing_fn: &str,
172    sinks: &HashMap<String, BufferBuildShape>,
173    out: &mut Vec<FusionSite>,
174) {
175    let line_of = |s: &crate::ast::Spanned<Expr>| {
176        if s.line > 0 { s.line } else { fallback_line }
177    };
178    match expr {
179        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
180        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
181            walk_expr_for_fusion_sites(&inner.node, line_of(inner), enclosing_fn, sinks, out);
182        }
183        Expr::FnCall(callee, args) => {
184            walk_expr_for_fusion_sites(&callee.node, line_of(callee), enclosing_fn, sinks, out);
185            for a in args {
186                walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
187            }
188        }
189        Expr::TailCall(data) => {
190            for a in &data.args {
191                walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
192            }
193        }
194        Expr::BinOp(_, l, r) => {
195            walk_expr_for_fusion_sites(&l.node, line_of(l), enclosing_fn, sinks, out);
196            walk_expr_for_fusion_sites(&r.node, line_of(r), enclosing_fn, sinks, out);
197        }
198        Expr::Match { subject, arms } => {
199            walk_expr_for_fusion_sites(&subject.node, line_of(subject), enclosing_fn, sinks, out);
200            for arm in arms {
201                walk_expr_for_fusion_sites(
202                    &arm.body.node,
203                    line_of(&arm.body),
204                    enclosing_fn,
205                    sinks,
206                    out,
207                );
208            }
209        }
210        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
211            for it in items {
212                walk_expr_for_fusion_sites(&it.node, line_of(it), enclosing_fn, sinks, out);
213            }
214        }
215        Expr::MapLiteral(entries) => {
216            for (k, v) in entries {
217                walk_expr_for_fusion_sites(&k.node, line_of(k), enclosing_fn, sinks, out);
218                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
219            }
220        }
221        Expr::RecordCreate { fields, .. } => {
222            for (_, v) in fields {
223                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
224            }
225        }
226        Expr::RecordUpdate { base, updates, .. } => {
227            walk_expr_for_fusion_sites(&base.node, line_of(base), enclosing_fn, sinks, out);
228            for (_, v) in updates {
229                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
230            }
231        }
232        Expr::InterpolatedStr(parts) => {
233            for part in parts {
234                if let crate::ast::StrPart::Parsed(inner) = part {
235                    walk_expr_for_fusion_sites(
236                        &inner.node,
237                        line_of(inner),
238                        enclosing_fn,
239                        sinks,
240                        out,
241                    );
242                }
243            }
244        }
245    }
246}
247
248/// Pattern-match a single fn against the buffer-build shape.
249fn match_buffer_build_shape(fd: &FnDef) -> Option<BufferBuildShape> {
250    // The accumulator must be a parameter of type `List<...>`. The
251    // params vector stores type strings, not parsed `Type` values, so we
252    // match the textual form. Aver's surface syntax accepts both
253    // `List<T>` and (rarely) `[T]`-like sugar; canonical form is
254    // `List<T>`.
255    // Take the *rightmost* List<...> parameter as the accumulator. The
256    // InternalReverse shape (`fn build(n: Int, acc: List<T>)`) typically
257    // has only one list param; the ExternalReverse shape often has two
258    // (`fn build(input: List<T>, acc: List<U>)`) and the accumulator is
259    // by convention the last argument. Picking the first match would
260    // misidentify `input` as the accumulator and the rest of the
261    // detection would silently fail.
262    let (acc_idx, acc_name) = fd
263        .params
264        .iter()
265        .enumerate()
266        .rfind(|(_, (_, ty))| is_list_type_str(ty))
267        .map(|(i, (name, _))| (i, name.clone()))?;
268
269    // Body must be a single expression statement holding the match.
270    let match_expr = single_match_body(&fd.body)?;
271    let (subject_expr, arms) = match match_expr {
272        Expr::Match { subject, arms } => (subject, arms),
273        _ => return None,
274    };
275
276    // Try the InternalReverse shape first: `match <bool> { true -> List.reverse(acc); false -> recurse(... prepend(_, acc)) }`.
277    if let Some((true_body, false_body)) = pair_bool_arms(arms) {
278        let _ = subject_expr;
279        if is_list_reverse_of(true_body, &acc_name)
280            && is_self_tail_with_prepend_acc(false_body, &fd.name, acc_idx, &acc_name)
281        {
282            return Some(BufferBuildShape {
283                acc_param_idx: acc_idx,
284                acc_param_name: acc_name,
285                kind: BufferBuildKind::InternalReverse,
286            });
287        }
288    }
289
290    // Otherwise try the ExternalReverse shape: `match <list> { [] -> acc;
291    // [_, .._] -> recurse(... prepend(_, acc)) }`. The reverse lives at
292    // the caller, e.g. `List.reverse(<this>(args, []))` (see the
293    // `BufferBuildKind` doc above for context).
294    if let Some((nil_body, cons_body)) = pair_nil_cons_arms(arms)
295        && is_ident_named(nil_body, &acc_name)
296        && is_self_tail_with_prepend_acc(cons_body, &fd.name, acc_idx, &acc_name)
297    {
298        return Some(BufferBuildShape {
299            acc_param_idx: acc_idx,
300            acc_param_name: acc_name,
301            kind: BufferBuildKind::ExternalReverse,
302        });
303    }
304
305    None
306}
307
308/// Match a 2-arm match where one arm is `[]` and the other is
309/// `[head, ..tail]`. Returns `(nil_body, cons_body)` (both as `&Expr`,
310/// matching `pair_bool_arms`). `None` if the arms don't exactly cover
311/// those two patterns.
312fn pair_nil_cons_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
313    if arms.len() != 2 {
314        return None;
315    }
316    let mut nil_body: Option<&Expr> = None;
317    let mut cons_body: Option<&Expr> = None;
318    for arm in arms {
319        match &arm.pattern {
320            Pattern::EmptyList => nil_body = Some(&arm.body.node),
321            Pattern::Cons(_, _) => cons_body = Some(&arm.body.node),
322            _ => return None,
323        }
324    }
325    match (nil_body, cons_body) {
326        (Some(n), Some(c)) => Some((n, c)),
327        _ => None,
328    }
329}
330
331/// True if `expr` is just an identifier reference to `name`.
332fn is_ident_named(expr: &Expr, name: &str) -> bool {
333    matches!(expr, Expr::Ident(n) if n == name)
334}
335
336/// Recognise a `String.join` first-arg as either a direct sink call or
337/// a `List.reverse(<sink>(args, []))` wrapper. Returns the matched
338/// sink fn name when the kind matches the shape we expect:
339///   InternalReverse sinks → only the direct call form
340///   ExternalReverse sinks → only the `List.reverse(...)` form
341/// Returning the name only when the kinds line up keeps us from
342/// accidentally fusing a sink against a call site that's missing
343/// its required reverse (or has an extraneous one).
344/// Single source of truth for "is this expression a rewriteable
345/// `String.join(<sink>(args, []), sep)` fusion site?". Returns the
346/// matched sink name when **all** preconditions hold:
347/// 1. The expression is a `String.join(<inner>, _)` call.
348/// 2. `<inner>` is either a direct sink call (for InternalReverse
349///    sinks) or `List.reverse(<sink>(...))` (for ExternalReverse).
350/// 3. The reverse-placement on the call site matches the sink's kind
351///    — mismatch would silently drop or double-reverse.
352/// 4. The acc-position arg in the inner call is a literal empty list.
353///    Anything else means the user is starting the fold with a
354///    non-empty accumulator, and the buffered rewrite would silently
355///    drop those initial elements.
356///
357/// Both `find_fusion_sites` (diagnostics) and `try_rewrite_fusion_site`
358/// (the actual AST rewrite) call this so the two stay in lockstep —
359/// `aver check` can never report a site that the rewrite then refuses
360/// to take.
361fn match_string_join_fusion_site(
362    expr: &Expr,
363    sinks: &HashMap<String, BufferBuildShape>,
364) -> Option<String> {
365    let Expr::FnCall(callee, args) = expr else {
366        return None;
367    };
368    if !is_dotted_ident(&callee.node, "String", "join") || args.len() != 2 {
369        return None;
370    }
371    let consumer_arg = &args[0].node;
372
373    // Peel an optional `List.reverse(...)` wrapper.
374    let (inner_call_expr, saw_external_reverse) = match consumer_arg {
375        Expr::FnCall(rev_callee, rev_args)
376            if is_dotted_ident(&rev_callee.node, "List", "reverse") && rev_args.len() == 1 =>
377        {
378            (&rev_args[0].node, true)
379        }
380        other => (other, false),
381    };
382
383    let Expr::FnCall(inner_callee, inner_args) = inner_call_expr else {
384        return None;
385    };
386    let Expr::Ident(name) = &inner_callee.node else {
387        return None;
388    };
389    let shape = sinks.get(name)?;
390
391    let kinds_align = matches!(
392        (saw_external_reverse, &shape.kind),
393        (false, BufferBuildKind::InternalReverse) | (true, BufferBuildKind::ExternalReverse)
394    );
395    if !kinds_align {
396        return None;
397    }
398
399    let acc_arg = inner_args.get(shape.acc_param_idx)?;
400    if !matches!(&acc_arg.node, Expr::List(items) if items.is_empty()) {
401        return None;
402    }
403
404    Some(name.clone())
405}
406
/// Whether a parameter type string has the form `List<...>`, ignoring
/// surrounding whitespace. Only the canonical `List<` spelling is
/// recognised.
fn is_list_type_str(ty: &str) -> bool {
    ty.trim()
        .strip_prefix("List<")
        .is_some_and(|rest| rest.ends_with('>'))
}
412
413/// Extract the single match expression that forms a fn's entire body.
414/// Returns `None` if the body is empty, has multiple statements, or its
415/// single statement isn't a match expression.
416fn single_match_body(body: &FnBody) -> Option<&Expr> {
417    let stmts = body.stmts();
418    if stmts.len() != 1 {
419        return None;
420    }
421    match &stmts[0] {
422        Stmt::Expr(spanned) => match &spanned.node {
423            Expr::Match { .. } => Some(&spanned.node),
424            _ => None,
425        },
426        Stmt::Binding(_, _, _) => None,
427    }
428}
429
430/// If `arms` is exactly two arms with `Bool(true)` / `Bool(false)`
431/// patterns, return `(true_body, false_body)` references. Order in
432/// source doesn't matter — we sort by pattern.
433fn pair_bool_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
434    if arms.len() != 2 {
435        return None;
436    }
437    let mut t = None;
438    let mut f = None;
439    for arm in arms {
440        match &arm.pattern {
441            Pattern::Literal(Literal::Bool(true)) => {
442                if t.is_some() {
443                    return None;
444                }
445                t = Some(&arm.body.node);
446            }
447            Pattern::Literal(Literal::Bool(false)) => {
448                if f.is_some() {
449                    return None;
450                }
451                f = Some(&arm.body.node);
452            }
453            _ => return None,
454        }
455    }
456    Some((t?, f?))
457}
458
459/// True if `expr` is `List.reverse(<Ident(acc_name)>)`.
460fn is_list_reverse_of(expr: &Expr, acc_name: &str) -> bool {
461    let (callee, args) = match expr {
462        Expr::FnCall(c, a) => (c, a),
463        _ => return false,
464    };
465    if !is_dotted_ident(&callee.node, "List", "reverse") {
466        return false;
467    }
468    if args.len() != 1 {
469        return false;
470    }
471    matches!(&args[0].node, Expr::Ident(name) if name == acc_name)
472}
473
474/// True if `expr` is a tail-call to `self_name` whose argument list
475/// contains `List.prepend(<anything>, <Ident(acc_name)>)` in any
476/// position. The position should match the `acc_param_idx` but the
477/// caller may have other params before it; we only require the
478/// `prepend` to terminate in the expected accumulator binding.
479fn is_self_tail_with_prepend_acc(
480    expr: &Expr,
481    self_name: &str,
482    acc_idx: usize,
483    acc_name: &str,
484) -> bool {
485    let data = match expr {
486        Expr::TailCall(data) => data,
487        _ => return false,
488    };
489    if data.target != self_name {
490        return false;
491    }
492    // The prepend has to land in the *acc* position specifically — the
493    // synthesizer extracts the element expression from `args[acc_idx]`.
494    // A loose `any` here would let through fns where some other arg
495    // happens to be a prepend, the synth would later return None, but
496    // detection had already promised the sink to call-site rewriting:
497    // `String.join(<sink>(...))` would be rewritten to call a
498    // `<sink>__buffered` that never gets generated. Require the exact
499    // shape here so detection and synthesis agree.
500    let acc_arg = match data.args.get(acc_idx) {
501        Some(a) => a,
502        None => return false,
503    };
504    is_list_prepend_to_acc(&acc_arg.node, acc_name)
505}
506
507/// True if `expr` is `List.prepend(<anything>, <Ident(acc_name)>)`.
508fn is_list_prepend_to_acc(expr: &Expr, acc_name: &str) -> bool {
509    let (callee, args) = match expr {
510        Expr::FnCall(c, a) => (c, a),
511        _ => return false,
512    };
513    if !is_dotted_ident(&callee.node, "List", "prepend") {
514        return false;
515    }
516    if args.len() != 2 {
517        return false;
518    }
519    matches!(&args[1].node, Expr::Ident(name) if name == acc_name)
520}
521
522/// True if `expr` is `<Module>.<Member>` access (the un-called callee
523/// shape of `Module.member(...)`).
524fn is_dotted_ident(expr: &Expr, module: &str, member: &str) -> bool {
525    let (base, attr) = match expr {
526        Expr::Attr(b, a) => (b, a),
527        _ => return false,
528    };
529    if attr != member {
530        return false;
531    }
532    matches!(&base.node, Expr::Ident(name) if name == module)
533}
534
535/// Synthesize a `<fn>__buffered` variant for each matched buffer-build
536/// sink. The synthesized FnDef walks the same shape as the original but
537/// threads a runtime `Buffer` through tail-call args instead of building
538/// a `List<T>` of strings:
539///
540/// Original:
541/// ```aver
542/// fn build(.., acc: List<T>) -> List<T>
543///     match <cond>
544///         true  -> List.reverse(acc)
545///         false -> build(.., List.prepend(<elem>, acc))
546/// ```
547///
548/// Synthesized:
549/// ```aver
550/// fn build__buffered(.., __buf: Buffer, __sep: String) -> Buffer
551///     match <cond>
552///         true  -> __buf
553///         false -> build__buffered(..,
554///             __buf_append(
555///                 __buf_append_sep_unless_first(__buf, __sep),
556///                 <elem>
557///             ),
558///             __sep
559///         )
560/// ```
561///
562/// Threading is via expression composition: the inner
563/// `__buf_append_sep_unless_first` returns the (possibly grown) buffer,
564/// the outer `__buf_append` writes the element and again returns
565/// the (possibly grown) buffer, and that final pointer is what the tail
566/// call sees as `__buf`. No `_ =` discards anywhere — the C' review
567/// explicitly required this to avoid use-after-grow corruption.
568///
569/// Returns one `FnDef` per matched fn. Caller appends to the user-fn
570/// list before WASM emission so both original and buffered variants
571/// reach codegen through the same pipeline.
572pub fn synthesize_buffered_variants(
573    fns: &[&FnDef],
574    sinks: &HashMap<String, BufferBuildShape>,
575) -> Vec<FnDef> {
576    let mut out = Vec::new();
577    for fd in fns {
578        if let Some(shape) = sinks.get(&fd.name)
579            && let Some(buffered) = build_buffered_variant(fd, shape)
580        {
581            out.push(buffered);
582        }
583    }
584    out
585}
586
587/// Wrap an `Expr` as `Spanned<Expr>` carrying the same line as the
588/// matched fn (best effort — the synthesized code is internal and
589/// won't be source-located by the user, but having a non-zero line
590/// keeps downstream visitors happy).
591fn sp_at(line: usize, expr: Expr) -> Spanned<Expr> {
592    Spanned::new(expr, line)
593}
594
595fn sp_at_typed(line: usize, expr: Expr, ty: crate::types::Type) -> Spanned<Expr> {
596    let s = Spanned::new(expr, line);
597    s.set_ty(ty);
598    s
599}
600
601/// Build `<intrinsic>(args...)` as a Spanned<Expr>. Intrinsic names
602/// are bare identifiers (no module dot) — `__buf_append`,
603/// `__buf_append_sep_unless_first`. The WASM emitter recognises them
604/// in the builtin dispatch.
605///
606/// This pass runs *after* the type checker, so synthesized nodes never
607/// see `infer_type`. Callers that want the result `Spanned` to carry a
608/// type for downstream readers (Rust codegen reads `expr.ty()` to gate
609/// owned-mutable `Buffer` hoists) should use [`buffer_intrinsic_call`]
610/// or [`finalize_intrinsic_call`].
611fn intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
612    let callee = sp_at(line, Expr::Ident(name.to_string()));
613    sp_at(line, Expr::FnCall(Box::new(callee), args))
614}
615
616/// Same as [`intrinsic_call`] but stamps the result with `Buffer` —
617/// every `__buf_new` / `__buf_append` / `__buf_append_sep_unless_first`
618/// returns the (possibly-grown) owned buffer.
619fn buffer_intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
620    let call = intrinsic_call(line, name, args);
621    call.set_ty(crate::types::Type::Named("Buffer".to_string()));
622    call
623}
624
625/// `__buf_finalize` consumes the buffer and yields the final `String`.
626fn finalize_intrinsic_call(line: usize, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
627    let call = intrinsic_call(line, "__buf_finalize", args);
628    call.set_ty(crate::types::Type::Str);
629    call
630}
631
632/// Run the full buffer-build deforestation pass on a program: detect
633/// sinks, synthesize buffered variants, rewrite fusion sites in place,
634/// and APPEND the synthesized FnDefs to the items list as new
635/// top-level fns. Caller is responsible for invoking this AFTER
636/// `tco::transform_program` (the detector requires `Expr::TailCall`
637/// nodes) and BEFORE `resolver::resolve_program` (the detector +
638/// rewrite both match on `Expr::Ident` shapes that the resolver
639/// rewrites to `Expr::Resolved`).
640///
641/// Returns a [`BufferBuildPassReport`] describing what fired, for
642/// diagnostic / bench reporting and `--explain-passes`.
643pub fn run_buffer_build_pass(items: &mut Vec<crate::ast::TopLevel>) -> BufferBuildPassReport {
644    let fn_refs: Vec<&FnDef> = items
645        .iter()
646        .filter_map(|it| match it {
647            crate::ast::TopLevel::FnDef(fd) => Some(fd),
648            _ => None,
649        })
650        .collect();
651    let all_sinks = compute_buffer_build_sinks(&fn_refs);
652    if all_sinks.is_empty() {
653        return BufferBuildPassReport::default();
654    }
655    let sites = find_fusion_sites(&fn_refs, &all_sinks);
656
657    // Synthesize a buffered variant only for sinks that actually have
658    // at least one rewriteable call site. The earlier shape produced
659    // a `<sink>__buffered` for every detected sink — bloat in the
660    // common case (most detected sinks aren't called via the canonical
661    // String.join shape) and a real risk of name-shadowing a user fn
662    // named `<sink>__buffered`. Restricting to used sinks keeps the
663    // synthetic surface tight.
664    let mut used_sinks: HashMap<String, BufferBuildShape> = HashMap::new();
665    for site in &sites {
666        if let Some(shape) = all_sinks.get(&site.sink_fn) {
667            used_sinks.insert(site.sink_fn.clone(), shape.clone());
668        }
669    }
670    let synthesized = synthesize_buffered_variants(&fn_refs, &used_sinks);
671    let sinks = used_sinks;
672    drop(fn_refs);
673
674    let mut fn_defs_owned: Vec<&mut FnDef> = items
675        .iter_mut()
676        .filter_map(|it| match it {
677            crate::ast::TopLevel::FnDef(fd) => Some(fd),
678            _ => None,
679        })
680        .collect();
681    // rewrite_fusion_sites takes &mut [FnDef], so pull a fresh
682    // mutable view across owned slots. We can't pass &mut [&mut FnDef]
683    // directly — instead, walk and rewrite each fn body individually.
684    for fd in fn_defs_owned.iter_mut() {
685        rewrite_one_fn(fd, &sinks);
686    }
687
688    items.reserve(synthesized.len());
689    for fd in synthesized.iter() {
690        items.push(crate::ast::TopLevel::FnDef(fd.clone()));
691    }
692
693    let mut sink_fns: Vec<String> = sinks.keys().cloned().collect();
694    sink_fns.sort();
695    let synthesized_fns: Vec<String> = synthesized.iter().map(|fd| fd.name.clone()).collect();
696
697    let mut rewrites_by_sink: std::collections::BTreeMap<String, usize> =
698        std::collections::BTreeMap::new();
699    for site in &sites {
700        *rewrites_by_sink.entry(site.sink_fn.clone()).or_default() += 1;
701    }
702
703    BufferBuildPassReport {
704        rewrites: sites.len(),
705        synthesized: synthesized_fns,
706        sink_fns,
707        rewrites_by_sink,
708    }
709}
710
/// Summary of what buffer_build did during one pipeline run.
///
/// Backs `aver compile --explain-passes` and feeds the bench regression
/// checks (e.g. "fail if buffer_build no longer fires on the canonical
/// shape"). `Default` is the "nothing fired" report.
#[derive(Clone, Debug, Default)]
pub struct BufferBuildPassReport {
    /// How many fusion sites were rewritten in place.
    pub rewrites: usize,
    /// Names of the synthesized `<sink>__buffered` variants that were
    /// appended to the items list.
    pub synthesized: Vec<String>,
    /// Sink fns that had at least one rewriteable call site. Sorted.
    pub sink_fns: Vec<String>,
    /// Rewrite count per sink fn, ordered alphabetically by sink name.
    pub rewrites_by_sink: std::collections::BTreeMap<String, usize>,
}
728
729/// Apply fusion-site rewrite to a single fn body. Internal helper
730/// for `run_buffer_build_pass` since `rewrite_fusion_sites` takes a
731/// slice and we have an iterator-of-mut-refs here.
732fn rewrite_one_fn(fd: &mut FnDef, sinks: &HashMap<String, BufferBuildShape>) {
733    let body_arc = std::sync::Arc::make_mut(&mut fd.body);
734    let FnBody::Block(stmts) = body_arc;
735    for stmt in stmts.iter_mut() {
736        match stmt {
737            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
738                rewrite_expr_in_place(expr, sinks);
739            }
740        }
741    }
742}
743
744/// Walk every expression in `fn_defs` and rewrite `String.join`
745/// fusion sites in place: `String.join(matched_fn(args, []), sep)` →
746/// `__buf_finalize(matched_fn__buffered(args_without_acc, __buf_new(8192), sep))`.
747///
748/// Conservative trigger per the C' review: only fires when the
749/// acc-position arg is a literal `Expr::List([])`. A non-empty
750/// initial accumulator would silently lose elements after rewrite,
751/// so we skip in that case.
752///
753/// The rewrite is recursive: nested fusion sites (a fusion site
754/// inside another fusion site's args) all get rewritten in one pass.
755pub fn rewrite_fusion_sites(fn_defs: &mut [FnDef], sinks: &HashMap<String, BufferBuildShape>) {
756    if sinks.is_empty() {
757        return;
758    }
759    for fd in fn_defs.iter_mut() {
760        let body_arc = std::sync::Arc::make_mut(&mut fd.body);
761        let FnBody::Block(stmts) = body_arc;
762        for stmt in stmts.iter_mut() {
763            match stmt {
764                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
765                    rewrite_expr_in_place(expr, sinks);
766                }
767            }
768        }
769    }
770}
771
772/// Recursive expression-tree walker that rewrites fusion sites in
773/// place. Rewrite is "outermost first" — if the whole expression is
774/// a fusion site, transform it before descending into the new shape's
775/// children, so we don't double-rewrite.
776fn rewrite_expr_in_place(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
777    if let Some(replacement) = try_rewrite_fusion_site(expr, sinks) {
778        *expr = replacement;
779        // The replacement contains the original elem expressions
780        // (possibly themselves containing fusion sites in deep
781        // gradient builders). Recurse into the new tree.
782        descend_into_subexprs(expr, sinks);
783        return;
784    }
785    descend_into_subexprs(expr, sinks);
786}
787
788/// Recurse into the children of an Expr, applying `rewrite_expr_in_place`
789/// to each. Mirrors the shape coverage of `walk_expr_for_fusion_sites`
790/// in this module so we don't miss any node kind.
791fn descend_into_subexprs(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
792    match &mut expr.node {
793        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
794        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
795            rewrite_expr_in_place(inner, sinks);
796        }
797        Expr::FnCall(callee, args) => {
798            rewrite_expr_in_place(callee, sinks);
799            for a in args.iter_mut() {
800                rewrite_expr_in_place(a, sinks);
801            }
802        }
803        Expr::TailCall(data) => {
804            for a in data.args.iter_mut() {
805                rewrite_expr_in_place(a, sinks);
806            }
807        }
808        Expr::BinOp(_, l, r) => {
809            rewrite_expr_in_place(l, sinks);
810            rewrite_expr_in_place(r, sinks);
811        }
812        Expr::Match { subject, arms } => {
813            rewrite_expr_in_place(subject, sinks);
814            for arm in arms.iter_mut() {
815                rewrite_expr_in_place(&mut arm.body, sinks);
816            }
817        }
818        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
819            for it in items.iter_mut() {
820                rewrite_expr_in_place(it, sinks);
821            }
822        }
823        Expr::MapLiteral(entries) => {
824            for (k, v) in entries.iter_mut() {
825                rewrite_expr_in_place(k, sinks);
826                rewrite_expr_in_place(v, sinks);
827            }
828        }
829        Expr::RecordCreate { fields, .. } => {
830            for (_, v) in fields.iter_mut() {
831                rewrite_expr_in_place(v, sinks);
832            }
833        }
834        Expr::RecordUpdate { base, updates, .. } => {
835            rewrite_expr_in_place(base, sinks);
836            for (_, v) in updates.iter_mut() {
837                rewrite_expr_in_place(v, sinks);
838            }
839        }
840        Expr::InterpolatedStr(parts) => {
841            for part in parts.iter_mut() {
842                if let crate::ast::StrPart::Parsed(inner) = part {
843                    rewrite_expr_in_place(inner, sinks);
844                }
845            }
846        }
847    }
848}
849
850/// If `expr` is a `String.join(matched_fn(args, []), sep)` with
851/// matched_fn in `sinks` and acc-position arg a literal empty list,
852/// return the rewritten Spanned<Expr>. Else return None.
853fn try_rewrite_fusion_site(
854    expr: &Spanned<Expr>,
855    sinks: &HashMap<String, BufferBuildShape>,
856) -> Option<Spanned<Expr>> {
857    let line = expr.line;
858
859    // Match the same predicate used by `find_fusion_sites` so the
860    // diagnostic count and the rewrite count are guaranteed equal.
861    let sink_name = match_string_join_fusion_site(&expr.node, sinks)?;
862    let shape = sinks.get(&sink_name)?;
863
864    // Re-extract the inner sink call and its args. The match predicate
865    // above already verified the shape; this is just to recover the
866    // pieces we need to assemble the rewrite.
867    let outer_args = match &expr.node {
868        Expr::FnCall(_, a) => a,
869        _ => return None,
870    };
871    let consumer_arg = &outer_args[0].node;
872    let inner_call_expr = if let Expr::FnCall(rev_callee, rev_args) = consumer_arg
873        && is_dotted_ident(&rev_callee.node, "List", "reverse")
874        && rev_args.len() == 1
875    {
876        &rev_args[0].node
877    } else {
878        consumer_arg
879    };
880    let inner_args = match inner_call_expr {
881        Expr::FnCall(_, a) => a,
882        _ => return None,
883    };
884
885    // Build the rewrite:
886    //   __buf_finalize(
887    //     <fn>__buffered(
888    //       <args without acc-pos>,
889    //       __buf_new(8192),
890    //       <sep>
891    //     )
892    //   )
893    let sep_expr = outer_args[1].clone();
894    let buf_new = buffer_intrinsic_call(
895        line,
896        "__buf_new",
897        vec![sp_at_typed(
898            line,
899            Expr::Literal(Literal::Int(8192)),
900            crate::types::Type::Int,
901        )],
902    );
903    let mut buffered_args: Vec<Spanned<Expr>> = inner_args
904        .iter()
905        .enumerate()
906        .filter_map(|(i, a)| (i != shape.acc_param_idx).then_some(a).cloned())
907        .collect();
908    buffered_args.push(buf_new);
909    buffered_args.push(sep_expr);
910    // `<sink>__buffered` returns `String` (the buffered variant signature
911    // in `synthesize_buffered_fn` types it that way).
912    let buffered_call = sp_at_typed(
913        line,
914        Expr::FnCall(
915            Box::new(sp_at(line, Expr::Ident(format!("{}__buffered", sink_name)))),
916            buffered_args,
917        ),
918        crate::types::Type::Str,
919    );
920    Some(finalize_intrinsic_call(line, vec![buffered_call]))
921}
922
923/// Construct the buffered FnDef for a single matched fn. Returns
924/// `None` if the original body shape doesn't match what we expect
925/// (defensive: detection should have caught this, but if the body
926/// changed shape between detection and synthesis, skip).
927fn build_buffered_variant(fd: &FnDef, shape: &BufferBuildShape) -> Option<FnDef> {
928    // Original body: `match <subject> { <terminating-arm>; <recursive-arm> }`.
929    // The two BufferBuildKind variants pair different patterns:
930    //   InternalReverse: `true -> List.reverse(acc)`, `false -> recurse(...)`.
931    //   ExternalReverse: `[] -> acc`, `[head, ..rest] -> recurse(...)`.
932    // We extract the recursive arm in both cases (for the prepend tail
933    // call) and rebuild the match with terminating arm `... -> __buf`.
934    let stmts = fd.body.stmts();
935    if stmts.len() != 1 {
936        return None;
937    }
938    let outer_expr = match &stmts[0] {
939        Stmt::Expr(spanned) => spanned,
940        _ => return None,
941    };
942    let (subject_orig, arms_orig) = match &outer_expr.node {
943        Expr::Match { subject, arms } => (subject, arms),
944        _ => return None,
945    };
946    let recursive_body: &Spanned<Expr> = match shape.kind {
947        BufferBuildKind::InternalReverse => arms_orig
948            .iter()
949            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
950            .map(|a| a.body.as_ref())?,
951        BufferBuildKind::ExternalReverse => arms_orig
952            .iter()
953            .find(|a| matches!(a.pattern, Pattern::Cons(_, _)))
954            .map(|a| a.body.as_ref())?,
955    };
956    let tail_data = match &recursive_body.node {
957        Expr::TailCall(data) => data,
958        _ => return None,
959    };
960
961    // The acc-position arg in the original tail call is
962    // `List.prepend(<elem>, acc)`. Extract the element expression.
963    let acc_arg_orig = tail_data.args.get(shape.acc_param_idx)?;
964    let elem_expr = match &acc_arg_orig.node {
965        Expr::FnCall(callee, args) => {
966            if !is_dotted_ident(&callee.node, "List", "prepend") {
967                return None;
968            }
969            if args.len() != 2 {
970                return None;
971            }
972            // args[0] is elem, args[1] is acc ident — verify acc.
973            match &args[1].node {
974                Expr::Ident(name) if name == &shape.acc_param_name => {}
975                _ => return None,
976            }
977            args[0].clone()
978        }
979        _ => return None,
980    };
981
982    let line = fd.line;
983    let buf_name = "__buf";
984    let sep_name = "__sep";
985    let buffered_target = format!("{}__buffered", fd.name);
986
987    // Synthesized false arm body:
988    //   <self>__buffered(<orig args minus acc>, __buf_append(<sep_unless_first>, <elem>), __sep)
989    //
990    // Build the buffer-threading expression first: the inner intrinsic
991    // appends `__sep` if the buffer is non-empty (otherwise no-op),
992    // returning the possibly-grown buffer. The outer intrinsic appends
993    // the user's element. The result is what gets passed as the
994    // buffered variant's `__buf` arg in the recursive call.
995    let buffer_ty = crate::types::Type::Named("Buffer".to_string());
996    let buf_ident = || sp_at_typed(line, Expr::Ident(buf_name.to_string()), buffer_ty.clone());
997    let sep_ident = || {
998        sp_at_typed(
999            line,
1000            Expr::Ident(sep_name.to_string()),
1001            crate::types::Type::Str,
1002        )
1003    };
1004    let sep_then_buf = buffer_intrinsic_call(
1005        line,
1006        "__buf_append_sep_unless_first",
1007        vec![buf_ident(), sep_ident()],
1008    );
1009    let final_buf = buffer_intrinsic_call(line, "__buf_append", vec![sep_then_buf, elem_expr]);
1010
1011    // Build new tail-call args: original args with acc-pos replaced by
1012    // the threaded buffer expression, then `__sep` appended at end.
1013    let mut new_args: Vec<Spanned<Expr>> = tail_data
1014        .args
1015        .iter()
1016        .enumerate()
1017        .map(|(i, a)| {
1018            if i == shape.acc_param_idx {
1019                final_buf.clone()
1020            } else {
1021                a.clone()
1022            }
1023        })
1024        .collect();
1025    new_args.push(sep_ident());
1026
1027    // The recursive tail-call lands on `<fn>__buffered`, which returns
1028    // Buffer; stamp the type so the typed WASM emitter doesn't have to
1029    // re-derive it from the call target.
1030    let new_recursive_body = sp_at_typed(
1031        line,
1032        Expr::TailCall(Box::new(TailCallData {
1033            target: buffered_target.clone(),
1034            args: new_args,
1035        })),
1036        buffer_ty.clone(),
1037    );
1038
1039    // Terminating arm body: just return `__buf` — the buffer IS the result.
1040    // Pattern depends on which sink shape we matched: `true` for the
1041    // InternalReverse idiom (where the original returned `List.reverse(acc)`),
1042    // `[]` for ExternalReverse (where the original returned `acc`).
1043    let new_arms = match shape.kind {
1044        BufferBuildKind::InternalReverse => vec![
1045            MatchArm {
1046                pattern: Pattern::Literal(Literal::Bool(true)),
1047                body: Box::new(buf_ident()),
1048                binding_slots: std::sync::OnceLock::new(),
1049            },
1050            MatchArm {
1051                pattern: Pattern::Literal(Literal::Bool(false)),
1052                body: Box::new(new_recursive_body),
1053                binding_slots: std::sync::OnceLock::new(),
1054            },
1055        ],
1056        BufferBuildKind::ExternalReverse => {
1057            // Re-use the cons binding names from the original arm so
1058            // any `head` / `rest` references inside the recursive body
1059            // continue to resolve.
1060            let cons_pat = arms_orig
1061                .iter()
1062                .find_map(|a| match &a.pattern {
1063                    Pattern::Cons(h, t) => Some(Pattern::Cons(h.clone(), t.clone())),
1064                    _ => None,
1065                })
1066                .unwrap_or(Pattern::Cons("__head".to_string(), "__tail".to_string()));
1067            vec![
1068                MatchArm {
1069                    pattern: Pattern::EmptyList,
1070                    body: Box::new(buf_ident()),
1071                    binding_slots: std::sync::OnceLock::new(),
1072                },
1073                MatchArm {
1074                    pattern: cons_pat,
1075                    body: Box::new(new_recursive_body),
1076                    binding_slots: std::sync::OnceLock::new(),
1077                },
1078            ]
1079        }
1080    };
1081
1082    // Synthesised Match returns Buffer — both arms produce one
1083    // (terminating arm: bare `__buf` ident; recursive arm: a
1084    // `<fn>__buffered` tail-call returning Buffer). Stamp it so the
1085    // Step 2 type-driven WASM emitter can read the type without a
1086    // fallback.
1087    let new_match = sp_at_typed(
1088        line,
1089        Expr::Match {
1090            subject: subject_orig.clone(),
1091            arms: new_arms,
1092        },
1093        crate::types::Type::Named("Buffer".to_string()),
1094    );
1095
1096    let new_body = FnBody::Block(vec![Stmt::Expr(new_match)]);
1097
1098    // Params: original minus acc + (__buf, "Buffer") + (__sep, "String").
1099    let mut new_params: Vec<(String, String)> = fd
1100        .params
1101        .iter()
1102        .enumerate()
1103        .filter_map(|(i, p)| (i != shape.acc_param_idx).then_some(p).cloned())
1104        .collect();
1105    new_params.push((buf_name.to_string(), "Buffer".to_string()));
1106    new_params.push((sep_name.to_string(), "String".to_string()));
1107
1108    Some(FnDef {
1109        name: buffered_target,
1110        line,
1111        params: new_params,
1112        return_type: "Buffer".to_string(),
1113        // Synthesized variants inherit effects from the original — if
1114        // the matched fn calls effectful helpers (like `renderRow`
1115        // calling `Console.print`), the buffered variant calls them
1116        // too at the same positions. Conservative.
1117        effects: fd.effects.clone(),
1118        desc: Some(format!(
1119            "Synthesized buffered variant of `{}` for deforestation \
1120             lowering. Call sites that match `String.join({}(...), sep)` \
1121             are rewritten to alloc a buffer + call this variant + \
1122             finalize, skipping the intermediate List.",
1123            fd.name, fd.name
1124        )),
1125        body: Arc::new(new_body),
1126        resolution: None,
1127    })
1128}
1129
1130#[cfg(test)]
1131mod tests {
1132    use super::*;
1133    use crate::ast::{BinOp, FnBody, FnDef, Literal, Spanned, TailCallData};
1134    use std::sync::Arc;
1135
1136    fn sp<T>(value: T) -> Spanned<T> {
1137        Spanned::new(value, 1)
1138    }
1139
1140    fn ident(name: &str) -> Spanned<Expr> {
1141        sp(Expr::Ident(name.to_string()))
1142    }
1143
1144    fn dotted(module: &str, member: &str) -> Spanned<Expr> {
1145        sp(Expr::Attr(Box::new(ident(module)), member.to_string()))
1146    }
1147
1148    fn call(callee: Spanned<Expr>, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
1149        sp(Expr::FnCall(Box::new(callee), args))
1150    }
1151
1152    /// Build a canonical buffer-build fn: takes (col: Int, acc: List<Int>),
1153    /// matches col >= 10, true → reverse(acc), false → tail-call self
1154    /// with prepend(col, acc).
1155    fn canonical_builder(name: &str) -> FnDef {
1156        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1157        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1158        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1159            target: name.to_string(),
1160            args: vec![
1161                sp(Expr::BinOp(
1162                    BinOp::Add,
1163                    Box::new(ident("col")),
1164                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
1165                )),
1166                prepend,
1167            ],
1168        })));
1169        let match_expr = sp(Expr::Match {
1170            subject: Box::new(sp(Expr::BinOp(
1171                BinOp::Gte,
1172                Box::new(ident("col")),
1173                Box::new(sp(Expr::Literal(Literal::Int(10)))),
1174            ))),
1175            arms: vec![
1176                MatchArm {
1177                    pattern: Pattern::Literal(Literal::Bool(true)),
1178                    body: Box::new(true_body),
1179                    binding_slots: std::sync::OnceLock::new(),
1180                },
1181                MatchArm {
1182                    pattern: Pattern::Literal(Literal::Bool(false)),
1183                    body: Box::new(false_body),
1184                    binding_slots: std::sync::OnceLock::new(),
1185                },
1186            ],
1187        });
1188        FnDef {
1189            name: name.to_string(),
1190            line: 1,
1191            params: vec![
1192                ("col".to_string(), "Int".to_string()),
1193                ("acc".to_string(), "List<Int>".to_string()),
1194            ],
1195            return_type: "List<Int>".to_string(),
1196            effects: vec![],
1197            desc: None,
1198            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1199            resolution: None,
1200        }
1201    }
1202
1203    #[test]
1204    fn matches_canonical_buffer_build() {
1205        let fd = canonical_builder("build");
1206        let info = compute_buffer_build_sinks(&[&fd]);
1207        let shape = info.get("build").expect("expected match");
1208        assert_eq!(shape.acc_param_idx, 1);
1209        assert_eq!(shape.acc_param_name, "acc");
1210    }
1211
1212    #[test]
1213    fn rejects_fn_without_list_param() {
1214        let mut fd = canonical_builder("build");
1215        // Strip the List<...> param.
1216        fd.params = vec![("col".to_string(), "Int".to_string())];
1217        let info = compute_buffer_build_sinks(&[&fd]);
1218        assert!(info.is_empty(), "fn without List param should not match");
1219    }
1220
1221    #[test]
1222    fn rejects_when_true_arm_isnt_reverse() {
1223        let mut fd = canonical_builder("build");
1224        // Replace true arm body with a different expression.
1225        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body)
1226            && let Stmt::Expr(spanned) = &mut stmts[0]
1227            && let Expr::Match { arms, .. } = &mut spanned.node
1228        {
1229            *arms[0].body = ident("acc");
1230        }
1231        let info = compute_buffer_build_sinks(&[&fd]);
1232        assert!(
1233            info.is_empty(),
1234            "fn returning bare acc instead of reverse should not match"
1235        );
1236    }
1237
1238    #[test]
1239    fn rejects_when_false_arm_uses_append_not_prepend() {
1240        let mut fd = canonical_builder("build");
1241        // Swap List.prepend → List.append in the false arm tail call.
1242        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body)
1243            && let Stmt::Expr(spanned) = &mut stmts[0]
1244            && let Expr::Match { arms, .. } = &mut spanned.node
1245        {
1246            let false_body = arms[1].body.as_mut();
1247            if let Expr::TailCall(data) = &mut false_body.node
1248                && let Expr::FnCall(callee, _) = &mut data.args[1].node
1249                && let Expr::Attr(_, attr) = &mut callee.node
1250            {
1251                *attr = "append".to_string();
1252            }
1253        }
1254        let info = compute_buffer_build_sinks(&[&fd]);
1255        assert!(
1256            info.is_empty(),
1257            "fn using List.append instead of prepend should not match"
1258        );
1259    }
1260
1261    #[test]
1262    fn rejects_tail_call_to_different_fn() {
1263        let mut fd = canonical_builder("build");
1264        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body)
1265            && let Stmt::Expr(spanned) = &mut stmts[0]
1266            && let Expr::Match { arms, .. } = &mut spanned.node
1267        {
1268            let false_body = arms[1].body.as_mut();
1269            if let Expr::TailCall(data) = &mut false_body.node {
1270                data.target = "someone_else".to_string();
1271            }
1272        }
1273        let info = compute_buffer_build_sinks(&[&fd]);
1274        assert!(
1275            info.is_empty(),
1276            "fn whose recursive call targets a different name should not match"
1277        );
1278    }
1279
1280    #[test]
1281    fn rejects_match_with_non_bool_arms() {
1282        let mut fd = canonical_builder("build");
1283        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body)
1284            && let Stmt::Expr(spanned) = &mut stmts[0]
1285            && let Expr::Match { arms, .. } = &mut spanned.node
1286        {
1287            arms[0].pattern = Pattern::Literal(Literal::Int(0));
1288        }
1289        let info = compute_buffer_build_sinks(&[&fd]);
1290        assert!(
1291            info.is_empty(),
1292            "match on non-bool patterns should not be detected as buffer-build"
1293        );
1294    }
1295
1296    /// End-to-end: parse a small Aver source, run TCO, then detect.
1297    /// The TCO transform is what produces `Expr::TailCall` nodes from
1298    /// raw `Expr::FnCall` self-recursion; detection runs on the post-TCO
1299    /// AST.
1300    #[test]
1301    fn detects_via_parser_after_tco() {
1302        let src = r#"
1303fn build(n: Int, acc: List<Int>) -> List<Int>
1304    match n <= 0
1305        true  -> List.reverse(acc)
1306        false -> build(n - 1, List.prepend(n, acc))
1307"#;
1308        let mut lexer = crate::lexer::Lexer::new(src);
1309        let tokens = lexer.tokenize().expect("lex");
1310        let mut parser = crate::parser::Parser::new(tokens);
1311        let mut items = parser.parse().expect("parse");
1312        crate::ir::pipeline::tco(&mut items);
1313        let fns: Vec<&FnDef> = items
1314            .iter()
1315            .filter_map(|it| match it {
1316                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1317                _ => None,
1318            })
1319            .collect();
1320        let info = compute_buffer_build_sinks(&fns);
1321        let shape = info
1322            .get("build")
1323            .expect("expected end-to-end shape match for canonical builder");
1324        assert_eq!(shape.acc_param_idx, 1);
1325        assert_eq!(shape.acc_param_name, "acc");
1326    }
1327
1328    /// End-to-end fusion-site detection: builder + caller `String.join`
1329    /// site recognised, line recorded, sink name attached.
1330    #[test]
1331    fn finds_fusion_site_via_parser() {
1332        let src = r#"
1333fn build(n: Int, acc: List<Int>) -> List<Int>
1334    match n <= 0
1335        true  -> List.reverse(acc)
1336        false -> build(n - 1, List.prepend(n, acc))
1337
1338fn main() -> String
1339    String.join(build(5, []), ",")
1340"#;
1341        let mut lexer = crate::lexer::Lexer::new(src);
1342        let tokens = lexer.tokenize().expect("lex");
1343        let mut parser = crate::parser::Parser::new(tokens);
1344        let mut items = parser.parse().expect("parse");
1345        crate::ir::pipeline::tco(&mut items);
1346        let fns: Vec<&FnDef> = items
1347            .iter()
1348            .filter_map(|it| match it {
1349                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1350                _ => None,
1351            })
1352            .collect();
1353        let sinks = compute_buffer_build_sinks(&fns);
1354        let sites = find_fusion_sites(&fns, &sinks);
1355        assert_eq!(sites.len(), 1, "expected one fusion site, got {sites:?}");
1356        let site = &sites[0];
1357        assert_eq!(site.enclosing_fn, "main");
1358        assert_eq!(site.sink_fn, "build");
1359        assert!(site.line > 0, "expected real line info, got 0");
1360    }
1361
1362    /// Caller passes the matched fn's result to a non-`String.join`
1363    /// destination — should NOT register as a fusion site (no buffer
1364    /// to write into).
1365    #[test]
1366    fn ignores_call_when_not_wrapped_in_string_join() {
1367        let src = r#"
1368fn build(n: Int, acc: List<Int>) -> List<Int>
1369    match n <= 0
1370        true  -> List.reverse(acc)
1371        false -> build(n - 1, List.prepend(n, acc))
1372
1373fn main() -> List<Int>
1374    build(5, [])
1375"#;
1376        let mut lexer = crate::lexer::Lexer::new(src);
1377        let tokens = lexer.tokenize().expect("lex");
1378        let mut parser = crate::parser::Parser::new(tokens);
1379        let mut items = parser.parse().expect("parse");
1380        crate::ir::pipeline::tco(&mut items);
1381        let fns: Vec<&FnDef> = items
1382            .iter()
1383            .filter_map(|it| match it {
1384                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1385                _ => None,
1386            })
1387            .collect();
1388        let sinks = compute_buffer_build_sinks(&fns);
1389        let sites = find_fusion_sites(&fns, &sinks);
1390        assert!(
1391            sites.is_empty(),
1392            "build called outside String.join must not be a fusion site, got {sites:?}"
1393        );
1394    }
1395
1396    /// Counter-test: a recursive fn that returns `acc` directly (no
1397    /// reverse) — semantically valid Aver, but its result order is
1398    /// reversed relative to natural read order, so deforestation can't
1399    /// safely rewrite to a forward-emit buffer loop without explicit
1400    /// authorisation. Detector must reject it.
1401    #[test]
1402    fn rejects_via_parser_when_true_arm_returns_bare_acc() {
1403        let src = r#"
1404fn build(n: Int, acc: List<Int>) -> List<Int>
1405    match n <= 0
1406        true  -> acc
1407        false -> build(n - 1, List.prepend(n, acc))
1408"#;
1409        let mut lexer = crate::lexer::Lexer::new(src);
1410        let tokens = lexer.tokenize().expect("lex");
1411        let mut parser = crate::parser::Parser::new(tokens);
1412        let mut items = parser.parse().expect("parse");
1413        crate::ir::pipeline::tco(&mut items);
1414        let fns: Vec<&FnDef> = items
1415            .iter()
1416            .filter_map(|it| match it {
1417                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1418                _ => None,
1419            })
1420            .collect();
1421        let info = compute_buffer_build_sinks(&fns);
1422        assert!(
1423            info.is_empty(),
1424            "fn returning bare acc must not be detected as a deforestation candidate"
1425        );
1426    }
1427
1428    /// End-to-end synthesis: parse a small builder, run TCO, detect
1429    /// it as a sink, then synthesize the buffered variant. Verify the
1430    /// shape: name suffix, dropped acc param, added __buf/__sep
1431    /// params, true arm returns __buf ident, false arm tail-calls
1432    /// __buffered self with threaded buffer expression.
1433    #[test]
1434    fn synthesizes_buffered_variant_from_real_builder() {
1435        let src = r#"
1436fn build(n: Int, acc: List<Int>) -> List<Int>
1437    match n <= 0
1438        true  -> List.reverse(acc)
1439        false -> build(n - 1, List.prepend(n, acc))
1440"#;
1441        let mut lexer = crate::lexer::Lexer::new(src);
1442        let tokens = lexer.tokenize().expect("lex");
1443        let mut parser = crate::parser::Parser::new(tokens);
1444        let mut items = parser.parse().expect("parse");
1445        crate::ir::pipeline::tco(&mut items);
1446        let fns: Vec<&FnDef> = items
1447            .iter()
1448            .filter_map(|it| match it {
1449                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1450                _ => None,
1451            })
1452            .collect();
1453        let sinks = compute_buffer_build_sinks(&fns);
1454        assert!(sinks.contains_key("build"));
1455        let synthesized = synthesize_buffered_variants(&fns, &sinks);
1456        assert_eq!(
1457            synthesized.len(),
1458            1,
1459            "expected exactly one synthesized variant"
1460        );
1461        let bf = &synthesized[0];
1462
1463        // Name + signature shape.
1464        assert_eq!(bf.name, "build__buffered");
1465        assert_eq!(bf.return_type, "Buffer");
1466        let param_names: Vec<&str> = bf.params.iter().map(|(n, _)| n.as_str()).collect();
1467        let param_types: Vec<&str> = bf.params.iter().map(|(_, t)| t.as_str()).collect();
1468        assert_eq!(param_names, vec!["n", "__buf", "__sep"]);
1469        assert_eq!(param_types, vec!["Int", "Buffer", "String"]);
1470
1471        // Body: single Stmt::Expr holding a 2-arm match.
1472        let stmts = bf.body.stmts();
1473        assert_eq!(stmts.len(), 1);
1474        let match_expr = match &stmts[0] {
1475            Stmt::Expr(s) => match &s.node {
1476                Expr::Match { subject: _, arms } => arms,
1477                _ => panic!("body root must be a match"),
1478            },
1479            _ => panic!("body root must be Stmt::Expr"),
1480        };
1481        assert_eq!(match_expr.len(), 2);
1482
1483        // True arm: body is `__buf` ident.
1484        let true_arm = match_expr
1485            .iter()
1486            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(true))))
1487            .expect("true arm");
1488        match &true_arm.body.node {
1489            Expr::Ident(name) => assert_eq!(name, "__buf"),
1490            other => panic!("true arm should be Ident(__buf), got {other:?}"),
1491        }
1492
1493        // False arm: tail-call to build__buffered with threaded buf.
1494        let false_arm = match_expr
1495            .iter()
1496            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
1497            .expect("false arm");
1498        let tail_data = match &false_arm.body.node {
1499            Expr::TailCall(d) => d,
1500            other => panic!("false arm should be TailCall, got {other:?}"),
1501        };
1502        assert_eq!(tail_data.target, "build__buffered");
1503        // Args: [n - 1, threaded-buffer-expr, __sep_ident]. acc-pos
1504        // (was index 1 in original) is now the threaded buffer; sep
1505        // appended at end.
1506        assert_eq!(tail_data.args.len(), 3);
1507        // Arg 1 is the buffer-threading composition; verify it's
1508        // `__buf_append(__buf_append_sep_unless_first(__buf, __sep), n)`.
1509        let outer = match &tail_data.args[1].node {
1510            Expr::FnCall(callee, args) => {
1511                match &callee.node {
1512                    Expr::Ident(name) => assert_eq!(name, "__buf_append"),
1513                    _ => panic!("expected Ident callee"),
1514                }
1515                args
1516            }
1517            _ => panic!("expected outer __buf_append FnCall"),
1518        };
1519        assert_eq!(outer.len(), 2);
1520        // First arg of outer = inner sep-then-buf.
1521        match &outer[0].node {
1522            Expr::FnCall(callee, _) => match &callee.node {
1523                Expr::Ident(name) => assert_eq!(name, "__buf_append_sep_unless_first"),
1524                _ => panic!("expected Ident callee for inner intrinsic"),
1525            },
1526            _ => panic!("expected inner __buf_append_sep_unless_first FnCall"),
1527        }
1528        // Second arg of outer = original `n` (the prepend's element).
1529        match &outer[1].node {
1530            Expr::Ident(name) => assert_eq!(name, "n"),
1531            _ => panic!("expected `n` ident as elem"),
1532        }
1533        // Last tail-call arg = __sep ident.
1534        match &tail_data.args[2].node {
1535            Expr::Ident(name) => assert_eq!(name, "__sep"),
1536            _ => panic!("expected __sep ident as last arg"),
1537        }
1538    }
1539
1540    #[test]
1541    fn detects_acc_param_at_arbitrary_index() {
1542        // Builder where the List<T> param is first and the tail-call
1543        // body wires the prepend at the same index. Detection has to
1544        // pin the acc position to where the prepend actually lands —
1545        // an earlier loose `any` check would silently pass even on
1546        // mismatched param/arg orderings, then synthesis would fail
1547        // to extract the element expression. Keep the body and the
1548        // params consistent so we exercise the real path.
1549        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1550        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1551        // Tail call: build(prepend(col, acc), col + 1)
1552        // — acc-position arg is at index 0, col+1 at index 1.
1553        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1554            target: "build".to_string(),
1555            args: vec![
1556                prepend,
1557                sp(Expr::BinOp(
1558                    BinOp::Add,
1559                    Box::new(ident("col")),
1560                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
1561                )),
1562            ],
1563        })));
1564        let match_expr = sp(Expr::Match {
1565            subject: Box::new(sp(Expr::BinOp(
1566                BinOp::Gte,
1567                Box::new(ident("col")),
1568                Box::new(sp(Expr::Literal(Literal::Int(10)))),
1569            ))),
1570            arms: vec![
1571                MatchArm {
1572                    pattern: Pattern::Literal(Literal::Bool(true)),
1573                    body: Box::new(true_body),
1574                    binding_slots: std::sync::OnceLock::new(),
1575                },
1576                MatchArm {
1577                    pattern: Pattern::Literal(Literal::Bool(false)),
1578                    body: Box::new(false_body),
1579                    binding_slots: std::sync::OnceLock::new(),
1580                },
1581            ],
1582        });
1583        let fd = FnDef {
1584            name: "build".to_string(),
1585            line: 1,
1586            params: vec![
1587                ("acc".to_string(), "List<Int>".to_string()),
1588                ("col".to_string(), "Int".to_string()),
1589            ],
1590            return_type: "List<Int>".to_string(),
1591            effects: vec![],
1592            desc: None,
1593            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1594            resolution: None,
1595        };
1596        let info = compute_buffer_build_sinks(&[&fd]);
1597        let shape = info.get("build").expect("expected match");
1598        assert_eq!(shape.acc_param_idx, 0);
1599        assert_eq!(shape.acc_param_name, "acc");
1600    }
1601
1602    #[test]
1603    fn rejects_loose_prepend_in_non_acc_position() {
1604        // Earlier the detector accepted a fn whose tail call had a
1605        // prepend in *some* arg, regardless of position. That let
1606        // detection promise a sink the synthesizer couldn't actually
1607        // build. Make sure the tightened predicate refuses this.
1608        let mut fd = canonical_builder("build");
1609        // Reorder tail-call args so prepend ends up at index 0 instead
1610        // of index 1 — but keep params [(col, Int), (acc, List<Int>)],
1611        // so acc-position is index 1, where there's now a `col + 1`
1612        // expression (no prepend). Detection should refuse.
1613        {
1614            let body = std::sync::Arc::make_mut(&mut fd.body);
1615            let FnBody::Block(stmts) = body;
1616            if let Stmt::Expr(spanned) = &mut stmts[0]
1617                && let Expr::Match { arms, .. } = &mut spanned.node
1618            {
1619                for arm in arms.iter_mut() {
1620                    if matches!(arm.pattern, Pattern::Literal(Literal::Bool(false)))
1621                        && let Expr::TailCall(data) = &mut arm.body.node
1622                    {
1623                        data.args.reverse();
1624                    }
1625                }
1626            }
1627        }
1628        let info = compute_buffer_build_sinks(&[&fd]);
1629        assert!(
1630            !info.contains_key("build"),
1631            "loose-prepend (prepend not at acc-position) must not be detected"
1632        );
1633    }
1634
1635    #[test]
1636    fn skips_synth_when_no_rewriteable_call_site() {
1637        // A fn that matches the sink shape but whose only call site
1638        // doesn't fit the canonical fusion pattern (e.g. starts with a
1639        // non-empty initial accumulator, or the wrapper is an unrelated
1640        // function call rather than `String.join`) should NOT get a
1641        // synthesized `__buffered` variant. Generating one is bloat
1642        // and risks shadowing user fns.
1643        let sink = canonical_builder("build");
1644        // Dummy caller that uses `build` but not via `String.join(...)`.
1645        let caller = FnDef {
1646            name: "use_build".to_string(),
1647            line: 2,
1648            params: vec![],
1649            return_type: "List<Int>".to_string(),
1650            effects: vec![],
1651            desc: None,
1652            body: Arc::new(FnBody::Block(vec![Stmt::Expr(call(
1653                ident_expr("build"),
1654                vec![sp(Expr::Literal(Literal::Int(0))), sp(Expr::List(vec![]))],
1655            ))])),
1656            resolution: None,
1657        };
1658        let mut items = vec![
1659            crate::ast::TopLevel::FnDef(sink),
1660            crate::ast::TopLevel::FnDef(caller),
1661        ];
1662        let initial_count = items.len();
1663        let report = run_buffer_build_pass(&mut items);
1664        assert_eq!(report.rewrites, 0, "no fusion sites — no rewriteable call");
1665        assert_eq!(
1666            report.synthesized.len(),
1667            0,
1668            "no synth — nothing to fuse against"
1669        );
1670        assert_eq!(items.len(), initial_count, "no buffered variant appended");
1671    }
1672
1673    #[test]
1674    fn external_reverse_pattern_round_trips() {
1675        // `match list { [] -> acc; [h, ..t] -> recurse(t, prepend(_, acc)) }`
1676        // sink + `String.join(List.reverse(<sink>(args, [])), sep)` call
1677        // site should detect, synth, and rewrite as a single fusion.
1678        let nil_body = ident("acc");
1679        let prepend = call(dotted("List", "prepend"), vec![ident("h"), ident("acc")]);
1680        let cons_body = sp(Expr::TailCall(Box::new(TailCallData {
1681            target: "build".to_string(),
1682            args: vec![ident("t"), prepend],
1683        })));
1684        let match_expr = sp(Expr::Match {
1685            subject: Box::new(ident("xs")),
1686            arms: vec![
1687                MatchArm {
1688                    pattern: Pattern::EmptyList,
1689                    body: Box::new(nil_body),
1690                    binding_slots: std::sync::OnceLock::new(),
1691                },
1692                MatchArm {
1693                    pattern: Pattern::Cons("h".to_string(), "t".to_string()),
1694                    body: Box::new(cons_body),
1695                    binding_slots: std::sync::OnceLock::new(),
1696                },
1697            ],
1698        });
1699        let sink = FnDef {
1700            name: "build".to_string(),
1701            line: 1,
1702            params: vec![
1703                ("xs".to_string(), "List<Int>".to_string()),
1704                ("acc".to_string(), "List<String>".to_string()),
1705            ],
1706            return_type: "List<String>".to_string(),
1707            effects: vec![],
1708            desc: None,
1709            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1710            resolution: None,
1711        };
1712        let info = compute_buffer_build_sinks(&[&sink]);
1713        let shape = info
1714            .get("build")
1715            .expect("external-reverse sink should be detected");
1716        assert_eq!(shape.kind, BufferBuildKind::ExternalReverse);
1717        assert_eq!(shape.acc_param_idx, 1);
1718
1719        // Caller: `String.join(List.reverse(build(xs, [])), "\n")`
1720        let join_call = call(
1721            dotted("String", "join"),
1722            vec![
1723                call(
1724                    dotted("List", "reverse"),
1725                    vec![call(
1726                        ident_expr("build"),
1727                        vec![ident("xs"), sp(Expr::List(vec![]))],
1728                    )],
1729                ),
1730                sp(Expr::Literal(Literal::Str("\n".to_string()))),
1731            ],
1732        );
1733        let caller = FnDef {
1734            name: "render".to_string(),
1735            line: 2,
1736            params: vec![("xs".to_string(), "List<Int>".to_string())],
1737            return_type: "String".to_string(),
1738            effects: vec![],
1739            desc: None,
1740            body: Arc::new(FnBody::Block(vec![Stmt::Expr(join_call)])),
1741            resolution: None,
1742        };
1743
1744        let mut items = vec![
1745            crate::ast::TopLevel::FnDef(sink),
1746            crate::ast::TopLevel::FnDef(caller),
1747        ];
1748        let report = run_buffer_build_pass(&mut items);
1749        assert_eq!(
1750            report.rewrites, 1,
1751            "external-reverse pattern should be one fusion site"
1752        );
1753        assert_eq!(
1754            report.synthesized.len(),
1755            1,
1756            "exactly one buffered variant for the used sink"
1757        );
1758
1759        // The synthesized variant should be appended.
1760        let synth_present = items.iter().any(|it| match it {
1761            crate::ast::TopLevel::FnDef(fd) => fd.name == "build__buffered",
1762            _ => false,
1763        });
1764        assert!(synth_present, "build__buffered must be appended");
1765    }
1766
1767    fn ident_expr(name: &str) -> Spanned<Expr> {
1768        sp(Expr::Ident(name.to_string()))
1769    }
1770}