Skip to main content

aver/ir/
buffer_build.rs

1//! Buffer-build sink detection.
2//!
3//! Identifies user fns that match the canonical functional list-builder
4//! shape consumed by `String.join`:
5//!
6//! ```aver
7//! fn build(..., acc: List<T>) -> List<T>
8//!     match <cond>
9//!         true  -> List.reverse(acc)
10//!         false -> build(..., List.prepend(<elem>, acc))
11//! ```
12//!
13//! When such a fn is called from `String.join(build(..., []), sep)`, the
14//! whole pipeline is semantically equivalent to a single buffer-write
15//! loop — Wadler 1990 shortcut fusion / deforestation. This module is
16//! Phase 1 of the deforestation work for 0.15 "Traversal": it detects
17//! candidate fns. Lowering (rewriting matched fns + their `String.join`
18//! call sites) lives in a separate pass.
19//!
20//! Detection is intentionally local — the analyzer looks only at the fn
21//! body, not its call sites. A matched fn may or may not actually be
22//! consumed by `String.join`; the lowering pass cross-references call
23//! sites separately and only fuses when both ends of the pipeline agree.
24
25use std::collections::HashMap;
26use std::sync::Arc;
27
28use crate::ast::{Expr, FnBody, FnDef, Literal, MatchArm, Pattern, Spanned, Stmt, TailCallData};
29
/// Placement of the single `List.reverse` that restores forward order
/// for a prepend-built accumulator.
///
/// A builder that recurses with `prepend(elem, acc)` collects its
/// elements in reverse-of-input order, so the list has to be reversed
/// exactly once before it reaches `String.join`. Aver code places that
/// reverse in one of two spots, and the variant records which one a
/// given sink uses.
///
/// Either way the buffered lowering is the same: appending elements in
/// processing order (the forward order of the input) already produces
/// the final string in the right order, so no reverse survives.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BufferBuildKind {
    /// The sink reverses on exit: its base case is
    /// `true -> List.reverse(acc)`, and the caller writes
    /// `String.join(<sink>(args, []), sep)`. The classic "loop with a
    /// reversed accumulator, reverse on exit" shape.
    InternalReverse,
    /// The sink matches on the input list directly and returns the
    /// accumulator untouched (`[] -> acc`); the caller supplies the
    /// reverse: `String.join(List.reverse(<sink>(args, [])), sep)`.
    /// Seen under the `*Into` naming convention in the payment_ops /
    /// workflow_engine codebases (e.g. `serializeEntriesInto`,
    /// `filterSubjectInto`).
    ExternalReverse,
}
56
57/// Information about a fn that matches the buffer-build sink shape.
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct BufferBuildShape {
60    /// 0-based index of the `acc: List<T>` parameter in the fn signature.
61    /// Identifies which arg in tail-call positions threads the
62    /// accumulator and which `Ident` in the base-case arm is returned.
63    pub acc_param_idx: usize,
64    /// The accumulator parameter's binding name (looked up in tail-call
65    /// args and in the base-case return position).
66    pub acc_param_name: String,
67    /// Which of the two reverse-placement idioms this sink follows;
68    /// determines (a) what shape the original base arm has and (b)
69    /// whether the call site is `String.join(<sink>(...), sep)` or
70    /// `String.join(List.reverse(<sink>(...)), sep)`.
71    pub kind: BufferBuildKind,
72}
73
/// The consumer a matched builder's result flows into.
///
/// Each consumer implies its own buffer type and finalizer, but the
/// deforestation underneath is shared: the intermediate `List` is
/// skipped and elements are written straight into the consumer's
/// storage.
///
/// Only `StringJoin` exists so far (Phase 2, the canonical case from
/// the fractal demo). Planned follow-ups, each as its own phase:
/// `VectorFromList` (already half-fused through the `Vector.set`
/// owned-mutate work in 0.14.0; deforestation closes the cons-cell
/// side) and `ListFold` for stream-fusion-style consumer rewrites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsumerKind {
    /// `String.join(builder(...), sep)`: each element and separator is
    /// written directly into a `Vec<u8>`-shaped buffer in linear
    /// memory.
    StringJoin,
}
90
91/// One detected fusion site: a builder call whose result is consumed
92/// by a known sink (currently just `String.join`). Lowering rewrites
93/// the producer + consumer pair into a single buffer-write loop.
94#[derive(Debug, Clone, PartialEq, Eq)]
95pub struct FusionSite {
96    /// Name of the enclosing user fn that contains the call.
97    pub enclosing_fn: String,
98    /// Line of the consumer call.
99    pub line: usize,
100    /// The matched buffer-build fn being wrapped.
101    pub sink_fn: String,
102    /// What's consuming the builder's result.
103    pub consumer: ConsumerKind,
104}
105
106/// Walk all fns in `fns`, return a map from fn name to detected shape
107/// for fns that match the buffer-build sink pattern. Fns that don't
108/// match are absent from the result.
109pub fn compute_buffer_build_sinks(fns: &[&FnDef]) -> HashMap<String, BufferBuildShape> {
110    let mut out = HashMap::new();
111    for fd in fns {
112        if let Some(shape) = match_buffer_build_shape(fd) {
113            out.insert(fd.name.clone(), shape);
114        }
115    }
116    out
117}
118
119/// Walk every expression in every fn body looking for fusion sites:
120/// `String.join(matched_fn(...), sep)` calls where `matched_fn` is a
121/// key in `sinks`. Returns one `FusionSite` per call. The lowering
122/// pass rewrites each site to call a buffered variant of `matched_fn`
123/// directly into a pre-allocated buffer.
124pub fn find_fusion_sites(
125    fns: &[&FnDef],
126    sinks: &HashMap<String, BufferBuildShape>,
127) -> Vec<FusionSite> {
128    let mut out = Vec::new();
129    for fd in fns {
130        for stmt in fd.body.stmts() {
131            match stmt {
132                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
133                    walk_expr_for_fusion_sites(&expr.node, expr.line, &fd.name, sinks, &mut out);
134                }
135            }
136        }
137    }
138    out
139}
140
141/// Recursively walk an expression tree, recording any fusion site we
142/// find. The fallback `expr_line` is used when a sub-expression has no
143/// own line info.
144fn walk_expr_for_fusion_sites(
145    expr: &Expr,
146    expr_line: usize,
147    enclosing_fn: &str,
148    sinks: &HashMap<String, BufferBuildShape>,
149    out: &mut Vec<FusionSite>,
150) {
151    if let Some(inner_name) = match_string_join_fusion_site(expr, sinks) {
152        out.push(FusionSite {
153            enclosing_fn: enclosing_fn.to_string(),
154            line: expr_line,
155            sink_fn: inner_name,
156            consumer: ConsumerKind::StringJoin,
157        });
158    }
159    // Recurse into all sub-expressions regardless of whether this node
160    // matched (a fusion site can sit inside another fusion site's args
161    // — rare but valid; we'd record both and let the lowering decide).
162    visit_subexprs(expr, expr_line, enclosing_fn, sinks, out);
163}
164
165/// Helper: recurse into the sub-expressions of `expr`. Mirrors the
166/// shape coverage of `expr_allocates` in `alloc_info.rs` so we don't
167/// miss any node kind.
168fn visit_subexprs(
169    expr: &Expr,
170    fallback_line: usize,
171    enclosing_fn: &str,
172    sinks: &HashMap<String, BufferBuildShape>,
173    out: &mut Vec<FusionSite>,
174) {
175    let line_of = |s: &crate::ast::Spanned<Expr>| {
176        if s.line > 0 { s.line } else { fallback_line }
177    };
178    match expr {
179        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
180        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
181            walk_expr_for_fusion_sites(&inner.node, line_of(inner), enclosing_fn, sinks, out);
182        }
183        Expr::FnCall(callee, args) => {
184            walk_expr_for_fusion_sites(&callee.node, line_of(callee), enclosing_fn, sinks, out);
185            for a in args {
186                walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
187            }
188        }
189        Expr::TailCall(data) => {
190            for a in &data.args {
191                walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
192            }
193        }
194        Expr::BinOp(_, l, r) => {
195            walk_expr_for_fusion_sites(&l.node, line_of(l), enclosing_fn, sinks, out);
196            walk_expr_for_fusion_sites(&r.node, line_of(r), enclosing_fn, sinks, out);
197        }
198        Expr::Match { subject, arms } => {
199            walk_expr_for_fusion_sites(&subject.node, line_of(subject), enclosing_fn, sinks, out);
200            for arm in arms {
201                walk_expr_for_fusion_sites(
202                    &arm.body.node,
203                    line_of(&arm.body),
204                    enclosing_fn,
205                    sinks,
206                    out,
207                );
208            }
209        }
210        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
211            for it in items {
212                walk_expr_for_fusion_sites(&it.node, line_of(it), enclosing_fn, sinks, out);
213            }
214        }
215        Expr::MapLiteral(entries) => {
216            for (k, v) in entries {
217                walk_expr_for_fusion_sites(&k.node, line_of(k), enclosing_fn, sinks, out);
218                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
219            }
220        }
221        Expr::RecordCreate { fields, .. } => {
222            for (_, v) in fields {
223                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
224            }
225        }
226        Expr::RecordUpdate { base, updates, .. } => {
227            walk_expr_for_fusion_sites(&base.node, line_of(base), enclosing_fn, sinks, out);
228            for (_, v) in updates {
229                walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
230            }
231        }
232        Expr::InterpolatedStr(parts) => {
233            for part in parts {
234                if let crate::ast::StrPart::Parsed(inner) = part {
235                    walk_expr_for_fusion_sites(
236                        &inner.node,
237                        line_of(inner),
238                        enclosing_fn,
239                        sinks,
240                        out,
241                    );
242                }
243            }
244        }
245    }
246}
247
248/// Pattern-match a single fn against the buffer-build shape.
249fn match_buffer_build_shape(fd: &FnDef) -> Option<BufferBuildShape> {
250    // The accumulator must be a parameter of type `List<...>`. The
251    // params vector stores type strings, not parsed `Type` values, so we
252    // match the textual form. Aver's surface syntax accepts both
253    // `List<T>` and (rarely) `[T]`-like sugar; canonical form is
254    // `List<T>`.
255    // Take the *rightmost* List<...> parameter as the accumulator. The
256    // InternalReverse shape (`fn build(n: Int, acc: List<T>)`) typically
257    // has only one list param; the ExternalReverse shape often has two
258    // (`fn build(input: List<T>, acc: List<U>)`) and the accumulator is
259    // by convention the last argument. Picking the first match would
260    // misidentify `input` as the accumulator and the rest of the
261    // detection would silently fail.
262    let (acc_idx, acc_name) = fd
263        .params
264        .iter()
265        .enumerate()
266        .rfind(|(_, (_, ty))| is_list_type_str(ty))
267        .map(|(i, (name, _))| (i, name.clone()))?;
268
269    // Body must be a single expression statement holding the match.
270    let match_expr = single_match_body(&fd.body)?;
271    let (subject_expr, arms) = match match_expr {
272        Expr::Match { subject, arms } => (subject, arms),
273        _ => return None,
274    };
275
276    // Try the InternalReverse shape first: `match <bool> { true -> List.reverse(acc); false -> recurse(... prepend(_, acc)) }`.
277    if let Some((true_body, false_body)) = pair_bool_arms(arms) {
278        let _ = subject_expr;
279        if is_list_reverse_of(true_body, &acc_name)
280            && is_self_tail_with_prepend_acc(false_body, &fd.name, acc_idx, &acc_name)
281        {
282            return Some(BufferBuildShape {
283                acc_param_idx: acc_idx,
284                acc_param_name: acc_name,
285                kind: BufferBuildKind::InternalReverse,
286            });
287        }
288    }
289
290    // Otherwise try the ExternalReverse shape: `match <list> { [] -> acc;
291    // [_, .._] -> recurse(... prepend(_, acc)) }`. The reverse lives at
292    // the caller, e.g. `List.reverse(<this>(args, []))` (see the
293    // `BufferBuildKind` doc above for context).
294    if let Some((nil_body, cons_body)) = pair_nil_cons_arms(arms)
295        && is_ident_named(nil_body, &acc_name)
296        && is_self_tail_with_prepend_acc(cons_body, &fd.name, acc_idx, &acc_name)
297    {
298        return Some(BufferBuildShape {
299            acc_param_idx: acc_idx,
300            acc_param_name: acc_name,
301            kind: BufferBuildKind::ExternalReverse,
302        });
303    }
304
305    None
306}
307
308/// Match a 2-arm match where one arm is `[]` and the other is
309/// `[head, ..tail]`. Returns `(nil_body, cons_body)` (both as `&Expr`,
310/// matching `pair_bool_arms`). `None` if the arms don't exactly cover
311/// those two patterns.
312fn pair_nil_cons_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
313    if arms.len() != 2 {
314        return None;
315    }
316    let mut nil_body: Option<&Expr> = None;
317    let mut cons_body: Option<&Expr> = None;
318    for arm in arms {
319        match &arm.pattern {
320            Pattern::EmptyList => nil_body = Some(&arm.body.node),
321            Pattern::Cons(_, _) => cons_body = Some(&arm.body.node),
322            _ => return None,
323        }
324    }
325    match (nil_body, cons_body) {
326        (Some(n), Some(c)) => Some((n, c)),
327        _ => None,
328    }
329}
330
331/// True if `expr` is just an identifier reference to `name`.
332fn is_ident_named(expr: &Expr, name: &str) -> bool {
333    matches!(expr, Expr::Ident(n) if n == name)
334}
335
336/// Recognise a `String.join` first-arg as either a direct sink call or
337/// a `List.reverse(<sink>(args, []))` wrapper. Returns the matched
338/// sink fn name when the kind matches the shape we expect:
339///   InternalReverse sinks → only the direct call form
340///   ExternalReverse sinks → only the `List.reverse(...)` form
341/// Returning the name only when the kinds line up keeps us from
342/// accidentally fusing a sink against a call site that's missing
343/// its required reverse (or has an extraneous one).
344/// Single source of truth for "is this expression a rewriteable
345/// `String.join(<sink>(args, []), sep)` fusion site?". Returns the
346/// matched sink name when **all** preconditions hold:
347/// 1. The expression is a `String.join(<inner>, _)` call.
348/// 2. `<inner>` is either a direct sink call (for InternalReverse
349///    sinks) or `List.reverse(<sink>(...))` (for ExternalReverse).
350/// 3. The reverse-placement on the call site matches the sink's kind
351///    — mismatch would silently drop or double-reverse.
352/// 4. The acc-position arg in the inner call is a literal empty list.
353///    Anything else means the user is starting the fold with a
354///    non-empty accumulator, and the buffered rewrite would silently
355///    drop those initial elements.
356///
357/// Both `find_fusion_sites` (diagnostics) and `try_rewrite_fusion_site`
358/// (the actual AST rewrite) call this so the two stay in lockstep —
359/// `aver check` can never report a site that the rewrite then refuses
360/// to take.
361fn match_string_join_fusion_site(
362    expr: &Expr,
363    sinks: &HashMap<String, BufferBuildShape>,
364) -> Option<String> {
365    let Expr::FnCall(callee, args) = expr else {
366        return None;
367    };
368    if !is_dotted_ident(&callee.node, "String", "join") || args.len() != 2 {
369        return None;
370    }
371    let consumer_arg = &args[0].node;
372
373    // Peel an optional `List.reverse(...)` wrapper.
374    let (inner_call_expr, saw_external_reverse) = match consumer_arg {
375        Expr::FnCall(rev_callee, rev_args)
376            if is_dotted_ident(&rev_callee.node, "List", "reverse") && rev_args.len() == 1 =>
377        {
378            (&rev_args[0].node, true)
379        }
380        other => (other, false),
381    };
382
383    let Expr::FnCall(inner_callee, inner_args) = inner_call_expr else {
384        return None;
385    };
386    let Expr::Ident(name) = &inner_callee.node else {
387        return None;
388    };
389    let shape = sinks.get(name)?;
390
391    let kinds_align = matches!(
392        (saw_external_reverse, &shape.kind),
393        (false, BufferBuildKind::InternalReverse) | (true, BufferBuildKind::ExternalReverse)
394    );
395    if !kinds_align {
396        return None;
397    }
398
399    let acc_arg = inner_args.get(shape.acc_param_idx)?;
400    if !matches!(&acc_arg.node, Expr::List(items) if items.is_empty()) {
401        return None;
402    }
403
404    Some(name.clone())
405}
406
/// Whether a parameter type-string has the textual form `List<...>`:
/// after trimming surrounding whitespace it starts with `List<` and
/// ends with `>`. This is a purely textual check; the type is not
/// parsed.
fn is_list_type_str(ty: &str) -> bool {
    matches!(ty.trim().strip_prefix("List<"), Some(rest) if rest.ends_with('>'))
}
412
413/// Extract the single match expression that forms a fn's entire body.
414/// Returns `None` if the body is empty, has multiple statements, or its
415/// single statement isn't a match expression.
416fn single_match_body(body: &FnBody) -> Option<&Expr> {
417    let stmts = body.stmts();
418    if stmts.len() != 1 {
419        return None;
420    }
421    match &stmts[0] {
422        Stmt::Expr(spanned) => match &spanned.node {
423            Expr::Match { .. } => Some(&spanned.node),
424            _ => None,
425        },
426        Stmt::Binding(_, _, _) => None,
427    }
428}
429
430/// If `arms` is exactly two arms with `Bool(true)` / `Bool(false)`
431/// patterns, return `(true_body, false_body)` references. Order in
432/// source doesn't matter — we sort by pattern.
433fn pair_bool_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
434    if arms.len() != 2 {
435        return None;
436    }
437    let mut t = None;
438    let mut f = None;
439    for arm in arms {
440        match &arm.pattern {
441            Pattern::Literal(Literal::Bool(true)) => {
442                if t.is_some() {
443                    return None;
444                }
445                t = Some(&arm.body.node);
446            }
447            Pattern::Literal(Literal::Bool(false)) => {
448                if f.is_some() {
449                    return None;
450                }
451                f = Some(&arm.body.node);
452            }
453            _ => return None,
454        }
455    }
456    Some((t?, f?))
457}
458
459/// True if `expr` is `List.reverse(<Ident(acc_name)>)`.
460fn is_list_reverse_of(expr: &Expr, acc_name: &str) -> bool {
461    let (callee, args) = match expr {
462        Expr::FnCall(c, a) => (c, a),
463        _ => return false,
464    };
465    if !is_dotted_ident(&callee.node, "List", "reverse") {
466        return false;
467    }
468    if args.len() != 1 {
469        return false;
470    }
471    matches!(&args[0].node, Expr::Ident(name) if name == acc_name)
472}
473
474/// True if `expr` is a tail-call to `self_name` whose argument list
475/// contains `List.prepend(<anything>, <Ident(acc_name)>)` in any
476/// position. The position should match the `acc_param_idx` but the
477/// caller may have other params before it; we only require the
478/// `prepend` to terminate in the expected accumulator binding.
479fn is_self_tail_with_prepend_acc(
480    expr: &Expr,
481    self_name: &str,
482    acc_idx: usize,
483    acc_name: &str,
484) -> bool {
485    let data = match expr {
486        Expr::TailCall(data) => data,
487        _ => return false,
488    };
489    if data.target != self_name {
490        return false;
491    }
492    // The prepend has to land in the *acc* position specifically — the
493    // synthesizer extracts the element expression from `args[acc_idx]`.
494    // A loose `any` here would let through fns where some other arg
495    // happens to be a prepend, the synth would later return None, but
496    // detection had already promised the sink to call-site rewriting:
497    // `String.join(<sink>(...))` would be rewritten to call a
498    // `<sink>__buffered` that never gets generated. Require the exact
499    // shape here so detection and synthesis agree.
500    let acc_arg = match data.args.get(acc_idx) {
501        Some(a) => a,
502        None => return false,
503    };
504    is_list_prepend_to_acc(&acc_arg.node, acc_name)
505}
506
507/// True if `expr` is `List.prepend(<anything>, <Ident(acc_name)>)`.
508fn is_list_prepend_to_acc(expr: &Expr, acc_name: &str) -> bool {
509    let (callee, args) = match expr {
510        Expr::FnCall(c, a) => (c, a),
511        _ => return false,
512    };
513    if !is_dotted_ident(&callee.node, "List", "prepend") {
514        return false;
515    }
516    if args.len() != 2 {
517        return false;
518    }
519    matches!(&args[1].node, Expr::Ident(name) if name == acc_name)
520}
521
522/// True if `expr` is `<Module>.<Member>` access (the un-called callee
523/// shape of `Module.member(...)`).
524fn is_dotted_ident(expr: &Expr, module: &str, member: &str) -> bool {
525    let (base, attr) = match expr {
526        Expr::Attr(b, a) => (b, a),
527        _ => return false,
528    };
529    if attr != member {
530        return false;
531    }
532    matches!(&base.node, Expr::Ident(name) if name == module)
533}
534
535/// Synthesize a `<fn>__buffered` variant for each matched buffer-build
536/// sink. The synthesized FnDef walks the same shape as the original but
537/// threads a runtime `Buffer` through tail-call args instead of building
538/// a `List<T>` of strings:
539///
540/// Original:
541/// ```aver
542/// fn build(.., acc: List<T>) -> List<T>
543///     match <cond>
544///         true  -> List.reverse(acc)
545///         false -> build(.., List.prepend(<elem>, acc))
546/// ```
547///
548/// Synthesized:
549/// ```aver
550/// fn build__buffered(.., __buf: Buffer, __sep: String) -> Buffer
551///     match <cond>
552///         true  -> __buf
553///         false -> build__buffered(..,
554///             __buf_append(
555///                 __buf_append_sep_unless_first(__buf, __sep),
556///                 <elem>
557///             ),
558///             __sep
559///         )
560/// ```
561///
562/// Threading is via expression composition: the inner
563/// `__buf_append_sep_unless_first` returns the (possibly grown) buffer,
564/// the outer `__buf_append` writes the element and again returns
565/// the (possibly grown) buffer, and that final pointer is what the tail
566/// call sees as `__buf`. No `_ =` discards anywhere — the C' review
567/// explicitly required this to avoid use-after-grow corruption.
568///
569/// Returns one `FnDef` per matched fn. Caller appends to the user-fn
570/// list before WASM emission so both original and buffered variants
571/// reach codegen through the same pipeline.
572pub fn synthesize_buffered_variants(
573    fns: &[&FnDef],
574    sinks: &HashMap<String, BufferBuildShape>,
575) -> Vec<FnDef> {
576    let mut out = Vec::new();
577    for fd in fns {
578        if let Some(shape) = sinks.get(&fd.name)
579            && let Some(buffered) = build_buffered_variant(fd, shape)
580        {
581            out.push(buffered);
582        }
583    }
584    out
585}
586
587/// Wrap an `Expr` as `Spanned<Expr>` carrying the same line as the
588/// matched fn (best effort — the synthesized code is internal and
589/// won't be source-located by the user, but having a non-zero line
590/// keeps downstream visitors happy).
591fn sp_at(line: usize, expr: Expr) -> Spanned<Expr> {
592    Spanned { node: expr, line }
593}
594
595/// Build `<intrinsic>(args...)` as a Spanned<Expr>. Intrinsic names
596/// are bare identifiers (no module dot) — `__buf_append`,
597/// `__buf_append_sep_unless_first`. The WASM emitter recognises them
598/// in the builtin dispatch.
599fn intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
600    let callee = sp_at(line, Expr::Ident(name.to_string()));
601    sp_at(line, Expr::FnCall(Box::new(callee), args))
602}
603
604/// Run the full buffer-build deforestation pass on a program: detect
605/// sinks, synthesize buffered variants, rewrite fusion sites in place,
606/// and APPEND the synthesized FnDefs to the items list as new
607/// top-level fns. Caller is responsible for invoking this AFTER
608/// `tco::transform_program` (the detector requires `Expr::TailCall`
609/// nodes) and BEFORE `resolver::resolve_program` (the detector +
610/// rewrite both match on `Expr::Ident` shapes that the resolver
611/// rewrites to `Expr::Resolved`).
612///
613/// Returns the count of fusion sites rewritten + buffered variants
614/// synthesized for diagnostic / bench reporting.
615pub fn run_buffer_build_pass(items: &mut Vec<crate::ast::TopLevel>) -> (usize, usize) {
616    let fn_refs: Vec<&FnDef> = items
617        .iter()
618        .filter_map(|it| match it {
619            crate::ast::TopLevel::FnDef(fd) => Some(fd),
620            _ => None,
621        })
622        .collect();
623    let all_sinks = compute_buffer_build_sinks(&fn_refs);
624    if all_sinks.is_empty() {
625        return (0, 0);
626    }
627    let sites = find_fusion_sites(&fn_refs, &all_sinks);
628
629    // Synthesize a buffered variant only for sinks that actually have
630    // at least one rewriteable call site. The earlier shape produced
631    // a `<sink>__buffered` for every detected sink — bloat in the
632    // common case (most detected sinks aren't called via the canonical
633    // String.join shape) and a real risk of name-shadowing a user fn
634    // named `<sink>__buffered`. Restricting to used sinks keeps the
635    // synthetic surface tight.
636    let mut used_sinks: HashMap<String, BufferBuildShape> = HashMap::new();
637    for site in &sites {
638        if let Some(shape) = all_sinks.get(&site.sink_fn) {
639            used_sinks.insert(site.sink_fn.clone(), shape.clone());
640        }
641    }
642    let synthesized = synthesize_buffered_variants(&fn_refs, &used_sinks);
643    let sinks = used_sinks;
644    drop(fn_refs);
645
646    let mut fn_defs_owned: Vec<&mut FnDef> = items
647        .iter_mut()
648        .filter_map(|it| match it {
649            crate::ast::TopLevel::FnDef(fd) => Some(fd),
650            _ => None,
651        })
652        .collect();
653    // rewrite_fusion_sites takes &mut [FnDef], so pull a fresh
654    // mutable view across owned slots. We can't pass &mut [&mut FnDef]
655    // directly — instead, walk and rewrite each fn body individually.
656    for fd in fn_defs_owned.iter_mut() {
657        rewrite_one_fn(fd, &sinks);
658    }
659
660    items.reserve(synthesized.len());
661    for fd in synthesized.iter() {
662        items.push(crate::ast::TopLevel::FnDef(fd.clone()));
663    }
664
665    (sites.len(), synthesized.len())
666}
667
668/// Apply fusion-site rewrite to a single fn body. Internal helper
669/// for `run_buffer_build_pass` since `rewrite_fusion_sites` takes a
670/// slice and we have an iterator-of-mut-refs here.
671fn rewrite_one_fn(fd: &mut FnDef, sinks: &HashMap<String, BufferBuildShape>) {
672    let body_arc = std::sync::Arc::make_mut(&mut fd.body);
673    let FnBody::Block(stmts) = body_arc;
674    for stmt in stmts.iter_mut() {
675        match stmt {
676            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
677                rewrite_expr_in_place(expr, sinks);
678            }
679        }
680    }
681}
682
683/// Walk every expression in `fn_defs` and rewrite `String.join`
684/// fusion sites in place: `String.join(matched_fn(args, []), sep)` →
685/// `__buf_finalize(matched_fn__buffered(args_without_acc, __buf_new(8192), sep))`.
686///
687/// Conservative trigger per the C' review: only fires when the
688/// acc-position arg is a literal `Expr::List([])`. A non-empty
689/// initial accumulator would silently lose elements after rewrite,
690/// so we skip in that case.
691///
692/// The rewrite is recursive: nested fusion sites (a fusion site
693/// inside another fusion site's args) all get rewritten in one pass.
694pub fn rewrite_fusion_sites(fn_defs: &mut [FnDef], sinks: &HashMap<String, BufferBuildShape>) {
695    if sinks.is_empty() {
696        return;
697    }
698    for fd in fn_defs.iter_mut() {
699        let body_arc = std::sync::Arc::make_mut(&mut fd.body);
700        let FnBody::Block(stmts) = body_arc;
701        for stmt in stmts.iter_mut() {
702            match stmt {
703                Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
704                    rewrite_expr_in_place(expr, sinks);
705                }
706            }
707        }
708    }
709}
710
711/// Recursive expression-tree walker that rewrites fusion sites in
712/// place. Rewrite is "outermost first" — if the whole expression is
713/// a fusion site, transform it before descending into the new shape's
714/// children, so we don't double-rewrite.
715fn rewrite_expr_in_place(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
716    if let Some(replacement) = try_rewrite_fusion_site(expr, sinks) {
717        *expr = replacement;
718        // The replacement contains the original elem expressions
719        // (possibly themselves containing fusion sites in deep
720        // gradient builders). Recurse into the new tree.
721        descend_into_subexprs(expr, sinks);
722        return;
723    }
724    descend_into_subexprs(expr, sinks);
725}
726
727/// Recurse into the children of an Expr, applying `rewrite_expr_in_place`
728/// to each. Mirrors the shape coverage of `walk_expr_for_fusion_sites`
729/// in this module so we don't miss any node kind.
730fn descend_into_subexprs(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
731    match &mut expr.node {
732        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
733        Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
734            rewrite_expr_in_place(inner, sinks);
735        }
736        Expr::FnCall(callee, args) => {
737            rewrite_expr_in_place(callee, sinks);
738            for a in args.iter_mut() {
739                rewrite_expr_in_place(a, sinks);
740            }
741        }
742        Expr::TailCall(data) => {
743            for a in data.args.iter_mut() {
744                rewrite_expr_in_place(a, sinks);
745            }
746        }
747        Expr::BinOp(_, l, r) => {
748            rewrite_expr_in_place(l, sinks);
749            rewrite_expr_in_place(r, sinks);
750        }
751        Expr::Match { subject, arms } => {
752            rewrite_expr_in_place(subject, sinks);
753            for arm in arms.iter_mut() {
754                rewrite_expr_in_place(&mut arm.body, sinks);
755            }
756        }
757        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
758            for it in items.iter_mut() {
759                rewrite_expr_in_place(it, sinks);
760            }
761        }
762        Expr::MapLiteral(entries) => {
763            for (k, v) in entries.iter_mut() {
764                rewrite_expr_in_place(k, sinks);
765                rewrite_expr_in_place(v, sinks);
766            }
767        }
768        Expr::RecordCreate { fields, .. } => {
769            for (_, v) in fields.iter_mut() {
770                rewrite_expr_in_place(v, sinks);
771            }
772        }
773        Expr::RecordUpdate { base, updates, .. } => {
774            rewrite_expr_in_place(base, sinks);
775            for (_, v) in updates.iter_mut() {
776                rewrite_expr_in_place(v, sinks);
777            }
778        }
779        Expr::InterpolatedStr(parts) => {
780            for part in parts.iter_mut() {
781                if let crate::ast::StrPart::Parsed(inner) = part {
782                    rewrite_expr_in_place(inner, sinks);
783                }
784            }
785        }
786    }
787}
788
789/// If `expr` is a `String.join(matched_fn(args, []), sep)` with
790/// matched_fn in `sinks` and acc-position arg a literal empty list,
791/// return the rewritten Spanned<Expr>. Else return None.
792fn try_rewrite_fusion_site(
793    expr: &Spanned<Expr>,
794    sinks: &HashMap<String, BufferBuildShape>,
795) -> Option<Spanned<Expr>> {
796    let line = expr.line;
797
798    // Match the same predicate used by `find_fusion_sites` so the
799    // diagnostic count and the rewrite count are guaranteed equal.
800    let sink_name = match_string_join_fusion_site(&expr.node, sinks)?;
801    let shape = sinks.get(&sink_name)?;
802
803    // Re-extract the inner sink call and its args. The match predicate
804    // above already verified the shape; this is just to recover the
805    // pieces we need to assemble the rewrite.
806    let outer_args = match &expr.node {
807        Expr::FnCall(_, a) => a,
808        _ => return None,
809    };
810    let consumer_arg = &outer_args[0].node;
811    let inner_call_expr = if let Expr::FnCall(rev_callee, rev_args) = consumer_arg
812        && is_dotted_ident(&rev_callee.node, "List", "reverse")
813        && rev_args.len() == 1
814    {
815        &rev_args[0].node
816    } else {
817        consumer_arg
818    };
819    let inner_args = match inner_call_expr {
820        Expr::FnCall(_, a) => a,
821        _ => return None,
822    };
823
824    // Build the rewrite:
825    //   __buf_finalize(
826    //     <fn>__buffered(
827    //       <args without acc-pos>,
828    //       __buf_new(8192),
829    //       <sep>
830    //     )
831    //   )
832    let sep_expr = outer_args[1].clone();
833    let buf_new = intrinsic_call(
834        line,
835        "__buf_new",
836        vec![sp_at(line, Expr::Literal(Literal::Int(8192)))],
837    );
838    let mut buffered_args: Vec<Spanned<Expr>> = inner_args
839        .iter()
840        .enumerate()
841        .filter_map(|(i, a)| (i != shape.acc_param_idx).then_some(a).cloned())
842        .collect();
843    buffered_args.push(buf_new);
844    buffered_args.push(sep_expr);
845    let buffered_call = sp_at(
846        line,
847        Expr::FnCall(
848            Box::new(sp_at(line, Expr::Ident(format!("{}__buffered", sink_name)))),
849            buffered_args,
850        ),
851    );
852    Some(intrinsic_call(line, "__buf_finalize", vec![buffered_call]))
853}
854
855/// Construct the buffered FnDef for a single matched fn. Returns
856/// `None` if the original body shape doesn't match what we expect
857/// (defensive: detection should have caught this, but if the body
858/// changed shape between detection and synthesis, skip).
859fn build_buffered_variant(fd: &FnDef, shape: &BufferBuildShape) -> Option<FnDef> {
860    // Original body: `match <subject> { <terminating-arm>; <recursive-arm> }`.
861    // The two BufferBuildKind variants pair different patterns:
862    //   InternalReverse: `true -> List.reverse(acc)`, `false -> recurse(...)`.
863    //   ExternalReverse: `[] -> acc`, `[head, ..rest] -> recurse(...)`.
864    // We extract the recursive arm in both cases (for the prepend tail
865    // call) and rebuild the match with terminating arm `... -> __buf`.
866    let stmts = fd.body.stmts();
867    if stmts.len() != 1 {
868        return None;
869    }
870    let outer_expr = match &stmts[0] {
871        Stmt::Expr(spanned) => spanned,
872        _ => return None,
873    };
874    let (subject_orig, arms_orig) = match &outer_expr.node {
875        Expr::Match { subject, arms } => (subject, arms),
876        _ => return None,
877    };
878    let recursive_body: &Spanned<Expr> = match shape.kind {
879        BufferBuildKind::InternalReverse => arms_orig
880            .iter()
881            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
882            .map(|a| a.body.as_ref())?,
883        BufferBuildKind::ExternalReverse => arms_orig
884            .iter()
885            .find(|a| matches!(a.pattern, Pattern::Cons(_, _)))
886            .map(|a| a.body.as_ref())?,
887    };
888    let tail_data = match &recursive_body.node {
889        Expr::TailCall(data) => data,
890        _ => return None,
891    };
892
893    // The acc-position arg in the original tail call is
894    // `List.prepend(<elem>, acc)`. Extract the element expression.
895    let acc_arg_orig = tail_data.args.get(shape.acc_param_idx)?;
896    let elem_expr = match &acc_arg_orig.node {
897        Expr::FnCall(callee, args) => {
898            if !is_dotted_ident(&callee.node, "List", "prepend") {
899                return None;
900            }
901            if args.len() != 2 {
902                return None;
903            }
904            // args[0] is elem, args[1] is acc ident — verify acc.
905            match &args[1].node {
906                Expr::Ident(name) if name == &shape.acc_param_name => {}
907                _ => return None,
908            }
909            args[0].clone()
910        }
911        _ => return None,
912    };
913
914    let line = fd.line;
915    let buf_name = "__buf";
916    let sep_name = "__sep";
917    let buffered_target = format!("{}__buffered", fd.name);
918
919    // Synthesized false arm body:
920    //   <self>__buffered(<orig args minus acc>, __buf_append(<sep_unless_first>, <elem>), __sep)
921    //
922    // Build the buffer-threading expression first: the inner intrinsic
923    // appends `__sep` if the buffer is non-empty (otherwise no-op),
924    // returning the possibly-grown buffer. The outer intrinsic appends
925    // the user's element. The result is what gets passed as the
926    // buffered variant's `__buf` arg in the recursive call.
927    let buf_ident = || sp_at(line, Expr::Ident(buf_name.to_string()));
928    let sep_ident = || sp_at(line, Expr::Ident(sep_name.to_string()));
929    let sep_then_buf = intrinsic_call(
930        line,
931        "__buf_append_sep_unless_first",
932        vec![buf_ident(), sep_ident()],
933    );
934    let final_buf = intrinsic_call(line, "__buf_append", vec![sep_then_buf, elem_expr]);
935
936    // Build new tail-call args: original args with acc-pos replaced by
937    // the threaded buffer expression, then `__sep` appended at end.
938    let mut new_args: Vec<Spanned<Expr>> = tail_data
939        .args
940        .iter()
941        .enumerate()
942        .map(|(i, a)| {
943            if i == shape.acc_param_idx {
944                final_buf.clone()
945            } else {
946                a.clone()
947            }
948        })
949        .collect();
950    new_args.push(sep_ident());
951
952    let new_recursive_body = sp_at(
953        line,
954        Expr::TailCall(Box::new(TailCallData {
955            target: buffered_target.clone(),
956            args: new_args,
957        })),
958    );
959
960    // Terminating arm body: just return `__buf` — the buffer IS the result.
961    // Pattern depends on which sink shape we matched: `true` for the
962    // InternalReverse idiom (where the original returned `List.reverse(acc)`),
963    // `[]` for ExternalReverse (where the original returned `acc`).
964    let new_arms = match shape.kind {
965        BufferBuildKind::InternalReverse => vec![
966            MatchArm {
967                pattern: Pattern::Literal(Literal::Bool(true)),
968                body: Box::new(buf_ident()),
969            },
970            MatchArm {
971                pattern: Pattern::Literal(Literal::Bool(false)),
972                body: Box::new(new_recursive_body),
973            },
974        ],
975        BufferBuildKind::ExternalReverse => {
976            // Re-use the cons binding names from the original arm so
977            // any `head` / `rest` references inside the recursive body
978            // continue to resolve.
979            let cons_pat = arms_orig
980                .iter()
981                .find_map(|a| match &a.pattern {
982                    Pattern::Cons(h, t) => Some(Pattern::Cons(h.clone(), t.clone())),
983                    _ => None,
984                })
985                .unwrap_or(Pattern::Cons("__head".to_string(), "__tail".to_string()));
986            vec![
987                MatchArm {
988                    pattern: Pattern::EmptyList,
989                    body: Box::new(buf_ident()),
990                },
991                MatchArm {
992                    pattern: cons_pat,
993                    body: Box::new(new_recursive_body),
994                },
995            ]
996        }
997    };
998
999    let new_match = sp_at(
1000        line,
1001        Expr::Match {
1002            subject: subject_orig.clone(),
1003            arms: new_arms,
1004        },
1005    );
1006
1007    let new_body = FnBody::Block(vec![Stmt::Expr(new_match)]);
1008
1009    // Params: original minus acc + (__buf, "Buffer") + (__sep, "String").
1010    let mut new_params: Vec<(String, String)> = fd
1011        .params
1012        .iter()
1013        .enumerate()
1014        .filter_map(|(i, p)| (i != shape.acc_param_idx).then_some(p).cloned())
1015        .collect();
1016    new_params.push((buf_name.to_string(), "Buffer".to_string()));
1017    new_params.push((sep_name.to_string(), "String".to_string()));
1018
1019    Some(FnDef {
1020        name: buffered_target,
1021        line,
1022        params: new_params,
1023        return_type: "Buffer".to_string(),
1024        // Synthesized variants inherit effects from the original — if
1025        // the matched fn calls effectful helpers (like `renderRow`
1026        // calling `Console.print`), the buffered variant calls them
1027        // too at the same positions. Conservative.
1028        effects: fd.effects.clone(),
1029        desc: Some(format!(
1030            "Synthesized buffered variant of `{}` for deforestation \
1031             lowering. Call sites that match `String.join({}(...), sep)` \
1032             are rewritten to alloc a buffer + call this variant + \
1033             finalize, skipping the intermediate List.",
1034            fd.name, fd.name
1035        )),
1036        body: Arc::new(new_body),
1037        resolution: None,
1038    })
1039}
1040
1041#[cfg(test)]
1042mod tests {
1043    use super::*;
1044    use crate::ast::{BinOp, FnBody, FnDef, Literal, Spanned, TailCallData};
1045    use std::sync::Arc;
1046
1047    fn sp<T>(value: T) -> Spanned<T> {
1048        Spanned {
1049            node: value,
1050            line: 1,
1051        }
1052    }
1053
1054    fn ident(name: &str) -> Spanned<Expr> {
1055        sp(Expr::Ident(name.to_string()))
1056    }
1057
1058    fn dotted(module: &str, member: &str) -> Spanned<Expr> {
1059        sp(Expr::Attr(Box::new(ident(module)), member.to_string()))
1060    }
1061
1062    fn call(callee: Spanned<Expr>, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
1063        sp(Expr::FnCall(Box::new(callee), args))
1064    }
1065
1066    /// Build a canonical buffer-build fn: takes (col: Int, acc: List<Int>),
1067    /// matches col >= 10, true → reverse(acc), false → tail-call self
1068    /// with prepend(col, acc).
1069    fn canonical_builder(name: &str) -> FnDef {
1070        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1071        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1072        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1073            target: name.to_string(),
1074            args: vec![
1075                sp(Expr::BinOp(
1076                    BinOp::Add,
1077                    Box::new(ident("col")),
1078                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
1079                )),
1080                prepend,
1081            ],
1082        })));
1083        let match_expr = sp(Expr::Match {
1084            subject: Box::new(sp(Expr::BinOp(
1085                BinOp::Gte,
1086                Box::new(ident("col")),
1087                Box::new(sp(Expr::Literal(Literal::Int(10)))),
1088            ))),
1089            arms: vec![
1090                MatchArm {
1091                    pattern: Pattern::Literal(Literal::Bool(true)),
1092                    body: Box::new(true_body),
1093                },
1094                MatchArm {
1095                    pattern: Pattern::Literal(Literal::Bool(false)),
1096                    body: Box::new(false_body),
1097                },
1098            ],
1099        });
1100        FnDef {
1101            name: name.to_string(),
1102            line: 1,
1103            params: vec![
1104                ("col".to_string(), "Int".to_string()),
1105                ("acc".to_string(), "List<Int>".to_string()),
1106            ],
1107            return_type: "List<Int>".to_string(),
1108            effects: vec![],
1109            desc: None,
1110            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1111            resolution: None,
1112        }
1113    }
1114
1115    #[test]
1116    fn matches_canonical_buffer_build() {
1117        let fd = canonical_builder("build");
1118        let info = compute_buffer_build_sinks(&[&fd]);
1119        let shape = info.get("build").expect("expected match");
1120        assert_eq!(shape.acc_param_idx, 1);
1121        assert_eq!(shape.acc_param_name, "acc");
1122    }
1123
1124    #[test]
1125    fn rejects_fn_without_list_param() {
1126        let mut fd = canonical_builder("build");
1127        // Strip the List<...> param.
1128        fd.params = vec![("col".to_string(), "Int".to_string())];
1129        let info = compute_buffer_build_sinks(&[&fd]);
1130        assert!(info.is_empty(), "fn without List param should not match");
1131    }
1132
1133    #[test]
1134    fn rejects_when_true_arm_isnt_reverse() {
1135        let mut fd = canonical_builder("build");
1136        // Replace true arm body with a different expression.
1137        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1138            if let Stmt::Expr(spanned) = &mut stmts[0] {
1139                if let Expr::Match { arms, .. } = &mut spanned.node {
1140                    arms[0].body = Box::new(ident("acc"));
1141                }
1142            }
1143        }
1144        let info = compute_buffer_build_sinks(&[&fd]);
1145        assert!(
1146            info.is_empty(),
1147            "fn returning bare acc instead of reverse should not match"
1148        );
1149    }
1150
1151    #[test]
1152    fn rejects_when_false_arm_uses_append_not_prepend() {
1153        let mut fd = canonical_builder("build");
1154        // Swap List.prepend → List.append in the false arm tail call.
1155        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1156            if let Stmt::Expr(spanned) = &mut stmts[0] {
1157                if let Expr::Match { arms, .. } = &mut spanned.node {
1158                    let false_body = arms[1].body.as_mut();
1159                    if let Expr::TailCall(data) = &mut false_body.node {
1160                        if let Expr::FnCall(callee, _) = &mut data.args[1].node {
1161                            if let Expr::Attr(_, attr) = &mut callee.node {
1162                                *attr = "append".to_string();
1163                            }
1164                        }
1165                    }
1166                }
1167            }
1168        }
1169        let info = compute_buffer_build_sinks(&[&fd]);
1170        assert!(
1171            info.is_empty(),
1172            "fn using List.append instead of prepend should not match"
1173        );
1174    }
1175
1176    #[test]
1177    fn rejects_tail_call_to_different_fn() {
1178        let mut fd = canonical_builder("build");
1179        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1180            if let Stmt::Expr(spanned) = &mut stmts[0] {
1181                if let Expr::Match { arms, .. } = &mut spanned.node {
1182                    let false_body = arms[1].body.as_mut();
1183                    if let Expr::TailCall(data) = &mut false_body.node {
1184                        data.target = "someone_else".to_string();
1185                    }
1186                }
1187            }
1188        }
1189        let info = compute_buffer_build_sinks(&[&fd]);
1190        assert!(
1191            info.is_empty(),
1192            "fn whose recursive call targets a different name should not match"
1193        );
1194    }
1195
1196    #[test]
1197    fn rejects_match_with_non_bool_arms() {
1198        let mut fd = canonical_builder("build");
1199        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1200            if let Stmt::Expr(spanned) = &mut stmts[0] {
1201                if let Expr::Match { arms, .. } = &mut spanned.node {
1202                    arms[0].pattern = Pattern::Literal(Literal::Int(0));
1203                }
1204            }
1205        }
1206        let info = compute_buffer_build_sinks(&[&fd]);
1207        assert!(
1208            info.is_empty(),
1209            "match on non-bool patterns should not be detected as buffer-build"
1210        );
1211    }
1212
1213    /// End-to-end: parse a small Aver source, run TCO, then detect.
1214    /// The TCO transform is what produces `Expr::TailCall` nodes from
1215    /// raw `Expr::FnCall` self-recursion; detection runs on the post-TCO
1216    /// AST.
1217    #[test]
1218    fn detects_via_parser_after_tco() {
1219        let src = r#"
1220fn build(n: Int, acc: List<Int>) -> List<Int>
1221    match n <= 0
1222        true  -> List.reverse(acc)
1223        false -> build(n - 1, List.prepend(n, acc))
1224"#;
1225        let mut lexer = crate::lexer::Lexer::new(src);
1226        let tokens = lexer.tokenize().expect("lex");
1227        let mut parser = crate::parser::Parser::new(tokens);
1228        let mut items = parser.parse().expect("parse");
1229        crate::tco::transform_program(&mut items);
1230        let fns: Vec<&FnDef> = items
1231            .iter()
1232            .filter_map(|it| match it {
1233                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1234                _ => None,
1235            })
1236            .collect();
1237        let info = compute_buffer_build_sinks(&fns);
1238        let shape = info
1239            .get("build")
1240            .expect("expected end-to-end shape match for canonical builder");
1241        assert_eq!(shape.acc_param_idx, 1);
1242        assert_eq!(shape.acc_param_name, "acc");
1243    }
1244
1245    /// End-to-end fusion-site detection: builder + caller `String.join`
1246    /// site recognised, line recorded, sink name attached.
1247    #[test]
1248    fn finds_fusion_site_via_parser() {
1249        let src = r#"
1250fn build(n: Int, acc: List<Int>) -> List<Int>
1251    match n <= 0
1252        true  -> List.reverse(acc)
1253        false -> build(n - 1, List.prepend(n, acc))
1254
1255fn main() -> String
1256    String.join(build(5, []), ",")
1257"#;
1258        let mut lexer = crate::lexer::Lexer::new(src);
1259        let tokens = lexer.tokenize().expect("lex");
1260        let mut parser = crate::parser::Parser::new(tokens);
1261        let mut items = parser.parse().expect("parse");
1262        crate::tco::transform_program(&mut items);
1263        let fns: Vec<&FnDef> = items
1264            .iter()
1265            .filter_map(|it| match it {
1266                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1267                _ => None,
1268            })
1269            .collect();
1270        let sinks = compute_buffer_build_sinks(&fns);
1271        let sites = find_fusion_sites(&fns, &sinks);
1272        assert_eq!(sites.len(), 1, "expected one fusion site, got {sites:?}");
1273        let site = &sites[0];
1274        assert_eq!(site.enclosing_fn, "main");
1275        assert_eq!(site.sink_fn, "build");
1276        assert!(site.line > 0, "expected real line info, got 0");
1277    }
1278
1279    /// Caller passes the matched fn's result to a non-`String.join`
1280    /// destination — should NOT register as a fusion site (no buffer
1281    /// to write into).
1282    #[test]
1283    fn ignores_call_when_not_wrapped_in_string_join() {
1284        let src = r#"
1285fn build(n: Int, acc: List<Int>) -> List<Int>
1286    match n <= 0
1287        true  -> List.reverse(acc)
1288        false -> build(n - 1, List.prepend(n, acc))
1289
1290fn main() -> List<Int>
1291    build(5, [])
1292"#;
1293        let mut lexer = crate::lexer::Lexer::new(src);
1294        let tokens = lexer.tokenize().expect("lex");
1295        let mut parser = crate::parser::Parser::new(tokens);
1296        let mut items = parser.parse().expect("parse");
1297        crate::tco::transform_program(&mut items);
1298        let fns: Vec<&FnDef> = items
1299            .iter()
1300            .filter_map(|it| match it {
1301                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1302                _ => None,
1303            })
1304            .collect();
1305        let sinks = compute_buffer_build_sinks(&fns);
1306        let sites = find_fusion_sites(&fns, &sinks);
1307        assert!(
1308            sites.is_empty(),
1309            "build called outside String.join must not be a fusion site, got {sites:?}"
1310        );
1311    }
1312
1313    /// Counter-test: a recursive fn that returns `acc` directly (no
1314    /// reverse) — semantically valid Aver, but its result order is
1315    /// reversed relative to natural read order, so deforestation can't
1316    /// safely rewrite to a forward-emit buffer loop without explicit
1317    /// authorisation. Detector must reject it.
1318    #[test]
1319    fn rejects_via_parser_when_true_arm_returns_bare_acc() {
1320        let src = r#"
1321fn build(n: Int, acc: List<Int>) -> List<Int>
1322    match n <= 0
1323        true  -> acc
1324        false -> build(n - 1, List.prepend(n, acc))
1325"#;
1326        let mut lexer = crate::lexer::Lexer::new(src);
1327        let tokens = lexer.tokenize().expect("lex");
1328        let mut parser = crate::parser::Parser::new(tokens);
1329        let mut items = parser.parse().expect("parse");
1330        crate::tco::transform_program(&mut items);
1331        let fns: Vec<&FnDef> = items
1332            .iter()
1333            .filter_map(|it| match it {
1334                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1335                _ => None,
1336            })
1337            .collect();
1338        let info = compute_buffer_build_sinks(&fns);
1339        assert!(
1340            info.is_empty(),
1341            "fn returning bare acc must not be detected as a deforestation candidate"
1342        );
1343    }
1344
1345    /// End-to-end synthesis: parse a small builder, run TCO, detect
1346    /// it as a sink, then synthesize the buffered variant. Verify the
1347    /// shape: name suffix, dropped acc param, added __buf/__sep
1348    /// params, true arm returns __buf ident, false arm tail-calls
1349    /// __buffered self with threaded buffer expression.
1350    #[test]
1351    fn synthesizes_buffered_variant_from_real_builder() {
1352        let src = r#"
1353fn build(n: Int, acc: List<Int>) -> List<Int>
1354    match n <= 0
1355        true  -> List.reverse(acc)
1356        false -> build(n - 1, List.prepend(n, acc))
1357"#;
1358        let mut lexer = crate::lexer::Lexer::new(src);
1359        let tokens = lexer.tokenize().expect("lex");
1360        let mut parser = crate::parser::Parser::new(tokens);
1361        let mut items = parser.parse().expect("parse");
1362        crate::tco::transform_program(&mut items);
1363        let fns: Vec<&FnDef> = items
1364            .iter()
1365            .filter_map(|it| match it {
1366                crate::ast::TopLevel::FnDef(fd) => Some(fd),
1367                _ => None,
1368            })
1369            .collect();
1370        let sinks = compute_buffer_build_sinks(&fns);
1371        assert!(sinks.contains_key("build"));
1372        let synthesized = synthesize_buffered_variants(&fns, &sinks);
1373        assert_eq!(
1374            synthesized.len(),
1375            1,
1376            "expected exactly one synthesized variant"
1377        );
1378        let bf = &synthesized[0];
1379
1380        // Name + signature shape.
1381        assert_eq!(bf.name, "build__buffered");
1382        assert_eq!(bf.return_type, "Buffer");
1383        let param_names: Vec<&str> = bf.params.iter().map(|(n, _)| n.as_str()).collect();
1384        let param_types: Vec<&str> = bf.params.iter().map(|(_, t)| t.as_str()).collect();
1385        assert_eq!(param_names, vec!["n", "__buf", "__sep"]);
1386        assert_eq!(param_types, vec!["Int", "Buffer", "String"]);
1387
1388        // Body: single Stmt::Expr holding a 2-arm match.
1389        let stmts = bf.body.stmts();
1390        assert_eq!(stmts.len(), 1);
1391        let match_expr = match &stmts[0] {
1392            Stmt::Expr(s) => match &s.node {
1393                Expr::Match { subject: _, arms } => arms,
1394                _ => panic!("body root must be a match"),
1395            },
1396            _ => panic!("body root must be Stmt::Expr"),
1397        };
1398        assert_eq!(match_expr.len(), 2);
1399
1400        // True arm: body is `__buf` ident.
1401        let true_arm = match_expr
1402            .iter()
1403            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(true))))
1404            .expect("true arm");
1405        match &true_arm.body.node {
1406            Expr::Ident(name) => assert_eq!(name, "__buf"),
1407            other => panic!("true arm should be Ident(__buf), got {other:?}"),
1408        }
1409
1410        // False arm: tail-call to build__buffered with threaded buf.
1411        let false_arm = match_expr
1412            .iter()
1413            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
1414            .expect("false arm");
1415        let tail_data = match &false_arm.body.node {
1416            Expr::TailCall(d) => d,
1417            other => panic!("false arm should be TailCall, got {other:?}"),
1418        };
1419        assert_eq!(tail_data.target, "build__buffered");
1420        // Args: [n - 1, threaded-buffer-expr, __sep_ident]. acc-pos
1421        // (was index 1 in original) is now the threaded buffer; sep
1422        // appended at end.
1423        assert_eq!(tail_data.args.len(), 3);
1424        // Arg 1 is the buffer-threading composition; verify it's
1425        // `__buf_append(__buf_append_sep_unless_first(__buf, __sep), n)`.
1426        let outer = match &tail_data.args[1].node {
1427            Expr::FnCall(callee, args) => {
1428                match &callee.node {
1429                    Expr::Ident(name) => assert_eq!(name, "__buf_append"),
1430                    _ => panic!("expected Ident callee"),
1431                }
1432                args
1433            }
1434            _ => panic!("expected outer __buf_append FnCall"),
1435        };
1436        assert_eq!(outer.len(), 2);
1437        // First arg of outer = inner sep-then-buf.
1438        match &outer[0].node {
1439            Expr::FnCall(callee, _) => match &callee.node {
1440                Expr::Ident(name) => assert_eq!(name, "__buf_append_sep_unless_first"),
1441                _ => panic!("expected Ident callee for inner intrinsic"),
1442            },
1443            _ => panic!("expected inner __buf_append_sep_unless_first FnCall"),
1444        }
1445        // Second arg of outer = original `n` (the prepend's element).
1446        match &outer[1].node {
1447            Expr::Ident(name) => assert_eq!(name, "n"),
1448            _ => panic!("expected `n` ident as elem"),
1449        }
1450        // Last tail-call arg = __sep ident.
1451        match &tail_data.args[2].node {
1452            Expr::Ident(name) => assert_eq!(name, "__sep"),
1453            _ => panic!("expected __sep ident as last arg"),
1454        }
1455    }
1456
1457    #[test]
1458    fn detects_acc_param_at_arbitrary_index() {
1459        // Builder where the List<T> param is first and the tail-call
1460        // body wires the prepend at the same index. Detection has to
1461        // pin the acc position to where the prepend actually lands —
1462        // an earlier loose `any` check would silently pass even on
1463        // mismatched param/arg orderings, then synthesis would fail
1464        // to extract the element expression. Keep the body and the
1465        // params consistent so we exercise the real path.
1466        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1467        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1468        // Tail call: build(prepend(col, acc), col + 1)
1469        // — acc-position arg is at index 0, col+1 at index 1.
1470        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1471            target: "build".to_string(),
1472            args: vec![
1473                prepend,
1474                sp(Expr::BinOp(
1475                    BinOp::Add,
1476                    Box::new(ident("col")),
1477                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
1478                )),
1479            ],
1480        })));
1481        let match_expr = sp(Expr::Match {
1482            subject: Box::new(sp(Expr::BinOp(
1483                BinOp::Gte,
1484                Box::new(ident("col")),
1485                Box::new(sp(Expr::Literal(Literal::Int(10)))),
1486            ))),
1487            arms: vec![
1488                MatchArm {
1489                    pattern: Pattern::Literal(Literal::Bool(true)),
1490                    body: Box::new(true_body),
1491                },
1492                MatchArm {
1493                    pattern: Pattern::Literal(Literal::Bool(false)),
1494                    body: Box::new(false_body),
1495                },
1496            ],
1497        });
1498        let fd = FnDef {
1499            name: "build".to_string(),
1500            line: 1,
1501            params: vec![
1502                ("acc".to_string(), "List<Int>".to_string()),
1503                ("col".to_string(), "Int".to_string()),
1504            ],
1505            return_type: "List<Int>".to_string(),
1506            effects: vec![],
1507            desc: None,
1508            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1509            resolution: None,
1510        };
1511        let info = compute_buffer_build_sinks(&[&fd]);
1512        let shape = info.get("build").expect("expected match");
1513        assert_eq!(shape.acc_param_idx, 0);
1514        assert_eq!(shape.acc_param_name, "acc");
1515    }
1516
1517    #[test]
1518    fn rejects_loose_prepend_in_non_acc_position() {
1519        // Earlier the detector accepted a fn whose tail call had a
1520        // prepend in *some* arg, regardless of position. That let
1521        // detection promise a sink the synthesizer couldn't actually
1522        // build. Make sure the tightened predicate refuses this.
1523        let mut fd = canonical_builder("build");
1524        // Reorder tail-call args so prepend ends up at index 0 instead
1525        // of index 1 — but keep params [(col, Int), (acc, List<Int>)],
1526        // so acc-position is index 1, where there's now a `col + 1`
1527        // expression (no prepend). Detection should refuse.
1528        {
1529            let body = std::sync::Arc::make_mut(&mut fd.body);
1530            let FnBody::Block(stmts) = body;
1531            if let Stmt::Expr(spanned) = &mut stmts[0]
1532                && let Expr::Match { arms, .. } = &mut spanned.node
1533            {
1534                for arm in arms.iter_mut() {
1535                    if matches!(arm.pattern, Pattern::Literal(Literal::Bool(false)))
1536                        && let Expr::TailCall(data) = &mut arm.body.node
1537                    {
1538                        data.args.reverse();
1539                    }
1540                }
1541            }
1542        }
1543        let info = compute_buffer_build_sinks(&[&fd]);
1544        assert!(
1545            info.get("build").is_none(),
1546            "loose-prepend (prepend not at acc-position) must not be detected"
1547        );
1548    }
1549
1550    #[test]
1551    fn skips_synth_when_no_rewriteable_call_site() {
1552        // A fn that matches the sink shape but whose only call site
1553        // doesn't fit the canonical fusion pattern (e.g. starts with a
1554        // non-empty initial accumulator, or the wrapper is an unrelated
1555        // function call rather than `String.join`) should NOT get a
1556        // synthesized `__buffered` variant. Generating one is bloat
1557        // and risks shadowing user fns.
1558        let sink = canonical_builder("build");
1559        // Dummy caller that uses `build` but not via `String.join(...)`.
1560        let caller = FnDef {
1561            name: "use_build".to_string(),
1562            line: 2,
1563            params: vec![],
1564            return_type: "List<Int>".to_string(),
1565            effects: vec![],
1566            desc: None,
1567            body: Arc::new(FnBody::Block(vec![Stmt::Expr(call(
1568                ident_expr("build"),
1569                vec![sp(Expr::Literal(Literal::Int(0))), sp(Expr::List(vec![]))],
1570            ))])),
1571            resolution: None,
1572        };
1573        let mut items = vec![
1574            crate::ast::TopLevel::FnDef(sink),
1575            crate::ast::TopLevel::FnDef(caller),
1576        ];
1577        let initial_count = items.len();
1578        let (sites, synth) = run_buffer_build_pass(&mut items);
1579        assert_eq!(sites, 0, "no fusion sites — no rewriteable call");
1580        assert_eq!(synth, 0, "no synth — nothing to fuse against");
1581        assert_eq!(items.len(), initial_count, "no buffered variant appended");
1582    }
1583
1584    #[test]
1585    fn external_reverse_pattern_round_trips() {
1586        // `match list { [] -> acc; [h, ..t] -> recurse(t, prepend(_, acc)) }`
1587        // sink + `String.join(List.reverse(<sink>(args, [])), sep)` call
1588        // site should detect, synth, and rewrite as a single fusion.
1589        let nil_body = ident("acc");
1590        let prepend = call(dotted("List", "prepend"), vec![ident("h"), ident("acc")]);
1591        let cons_body = sp(Expr::TailCall(Box::new(TailCallData {
1592            target: "build".to_string(),
1593            args: vec![ident("t"), prepend],
1594        })));
1595        let match_expr = sp(Expr::Match {
1596            subject: Box::new(ident("xs")),
1597            arms: vec![
1598                MatchArm {
1599                    pattern: Pattern::EmptyList,
1600                    body: Box::new(nil_body),
1601                },
1602                MatchArm {
1603                    pattern: Pattern::Cons("h".to_string(), "t".to_string()),
1604                    body: Box::new(cons_body),
1605                },
1606            ],
1607        });
1608        let sink = FnDef {
1609            name: "build".to_string(),
1610            line: 1,
1611            params: vec![
1612                ("xs".to_string(), "List<Int>".to_string()),
1613                ("acc".to_string(), "List<String>".to_string()),
1614            ],
1615            return_type: "List<String>".to_string(),
1616            effects: vec![],
1617            desc: None,
1618            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1619            resolution: None,
1620        };
1621        let info = compute_buffer_build_sinks(&[&sink]);
1622        let shape = info
1623            .get("build")
1624            .expect("external-reverse sink should be detected");
1625        assert_eq!(shape.kind, BufferBuildKind::ExternalReverse);
1626        assert_eq!(shape.acc_param_idx, 1);
1627
1628        // Caller: `String.join(List.reverse(build(xs, [])), "\n")`
1629        let join_call = call(
1630            dotted("String", "join"),
1631            vec![
1632                call(
1633                    dotted("List", "reverse"),
1634                    vec![call(
1635                        ident_expr("build"),
1636                        vec![ident("xs"), sp(Expr::List(vec![]))],
1637                    )],
1638                ),
1639                sp(Expr::Literal(Literal::Str("\n".to_string()))),
1640            ],
1641        );
1642        let caller = FnDef {
1643            name: "render".to_string(),
1644            line: 2,
1645            params: vec![("xs".to_string(), "List<Int>".to_string())],
1646            return_type: "String".to_string(),
1647            effects: vec![],
1648            desc: None,
1649            body: Arc::new(FnBody::Block(vec![Stmt::Expr(join_call)])),
1650            resolution: None,
1651        };
1652
1653        let mut items = vec![
1654            crate::ast::TopLevel::FnDef(sink),
1655            crate::ast::TopLevel::FnDef(caller),
1656        ];
1657        let (sites, synth) = run_buffer_build_pass(&mut items);
1658        assert_eq!(
1659            sites, 1,
1660            "external-reverse pattern should be one fusion site"
1661        );
1662        assert_eq!(synth, 1, "exactly one buffered variant for the used sink");
1663
1664        // The synthesized variant should be appended.
1665        let synth_present = items.iter().any(|it| match it {
1666            crate::ast::TopLevel::FnDef(fd) => fd.name == "build__buffered",
1667            _ => false,
1668        });
1669        assert!(synth_present, "build__buffered must be appended");
1670    }
1671
1672    fn ident_expr(name: &str) -> Spanned<Expr> {
1673        sp(Expr::Ident(name.to_string()))
1674    }
1675}