Skip to main content

rlx_compile/
memory.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Memory planning — liveness analysis and buffer assignment.
17//!
18//! This is the XLA feature that no other Rust framework has. It computes
19//! which intermediate tensors have non-overlapping lifetimes and assigns
20//! them to the same memory, minimizing total arena size.
21//!
22//! The output is a [`MemoryPlan`] that tells the runtime exactly how
23//! large the arena should be and where each tensor lives within it.
24
25use rlx_ir::op::BinaryOp;
26use rlx_ir::{Graph, NodeId, Op};
27use std::collections::HashMap;
28
29/// Extra bytes reserved after Input/Param/Constant slots so a kernel
30/// that writes slightly past its logical tensor size cannot stomp the
31/// next arena slot (e.g. small bias tensor adjacent to input_ids).
32const BOUNDARY_TAIL_GUARD_BYTES: usize = 128;
33
34fn boundary_min_slot_bytes(op: &rlx_ir::Op, alignment: usize) -> usize {
35    if matches!(
36        op,
37        rlx_ir::Op::Input { .. } | rlx_ir::Op::Param { .. } | rlx_ir::Op::Constant { .. }
38    ) {
39        alignment.max(1)
40    } else {
41        0
42    }
43}
44
45fn boundary_tail_guard(op: &rlx_ir::Op, alignment: usize) -> usize {
46    if matches!(
47        op,
48        rlx_ir::Op::Input { .. } | rlx_ir::Op::Param { .. } | rlx_ir::Op::Constant { .. }
49    ) {
50        alignment.max(BOUNDARY_TAIL_GUARD_BYTES)
51    } else {
52        0
53    }
54}
55/// Identify ops whose output is a *view* of an existing buffer — no
56/// copy needed, no separate arena slot. Returns the parent input index
57/// and the byte offset of the view within the parent.
58///
59/// Borrowed from MAX's "view-vs-copy" pattern (#46 in PLAN.md).
60/// The hard case (strided narrow on a non-outermost axis — e.g. BERT
61/// QKV split) requires kernels that consume strided inputs and is
62/// deferred. This function only catches the safely-elidable cases:
63///
64///   - **`Reshape`**: pure metadata; data layout is identical.
65///   - **`Cast`** with `src dtype == dst dtype`: pure metadata.
66///   - **`Narrow` on axis 0**: contiguous sub-slice of the parent;
67///     offset = `start * size_of_inner_in_bytes`.
68fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
69    match &node.op {
70        Op::Reshape { .. } => Some((node.inputs[0], 0)),
71        Op::Cast { to } => {
72            let parent = graph.node(node.inputs[0]);
73            if parent.shape.dtype() == *to {
74                Some((node.inputs[0], 0))
75            } else {
76                None
77            }
78        }
79        Op::Narrow {
80            axis,
81            start,
82            len: _,
83        } if *axis == 0 => {
84            let parent = graph.node(node.inputs[0]);
85            // inner = product of dims after axis 0
86            let inner_elems: usize = (1..parent.shape.rank())
87                .map(|i| parent.shape.dim(i).unwrap_static())
88                .product();
89            let dt_bytes = parent.shape.dtype().size_bytes();
90            Some((node.inputs[0], start * inner_elems * dt_bytes))
91        }
92        _ => None,
93    }
94}
95
96/// Public predicate for backends — true iff this op should compile to
97/// a Nop because its output aliases a parent buffer (the memory
98/// planner has already aliased its slot).
99pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
100    pure_view_offset(graph, node).is_some()
101}
102
/// A buffer slot in the memory arena: a byte range `[offset, offset + size)`.
///
/// `size` is the logical tensor size recorded by the planner; any tail
/// guard the planner reserves after a slot is not included in it.
#[derive(Debug, Clone)]
pub struct BufferSlot {
    /// Offset in bytes from the start of the arena.
    pub offset: usize,
    /// Size in bytes.
    pub size: usize,
}
111
112/// Complete memory plan for executing a graph.
113#[derive(Debug, Clone)]
114pub struct MemoryPlan {
115    /// Total arena size in bytes.
116    pub arena_size: usize,
117    /// Buffer assignment: NodeId → offset within arena.
118    pub assignments: HashMap<NodeId, BufferSlot>,
119    /// Node execution order (topological).
120    pub schedule: Vec<NodeId>,
121}
122
123impl MemoryPlan {
124    /// Sum of all assigned buffer sizes (i.e. how much memory the
125    /// plan would use if every node had its own slot). Useful for
126    /// reporting how much the liveness-aware sharing saved.
127    pub fn total_unshared_bytes(&self) -> usize {
128        self.assignments.values().map(|s| s.size).sum()
129    }
130
131    /// Bytes saved vs. naive "every node gets its own slot" — how
132    /// much the liveness analysis bought you.
133    pub fn bytes_saved(&self) -> usize {
134        self.total_unshared_bytes().saturating_sub(self.arena_size)
135    }
136
137    /// Render the buffer plan as a one-line-per-node table for
138    /// debugging — sorted by offset so adjacent buffers in memory
139    /// are adjacent in the report (plan #87).
140    ///
141    /// The output is parseable: `<offset>\t<size>\t%<node_id>`. Pipe
142    /// through `column -t` for human display, or grep / awk it for
143    /// scripted analysis.
144    pub fn report(&self) -> String {
145        let mut rows: Vec<(usize, usize, NodeId)> = self
146            .assignments
147            .iter()
148            .map(|(id, slot)| (slot.offset, slot.size, *id))
149            .collect();
150        rows.sort();
151        let mut out = String::new();
152        out.push_str(&format!(
153            "# arena_size={} total_unshared={} saved={}\n",
154            self.arena_size,
155            self.total_unshared_bytes(),
156            self.bytes_saved()
157        ));
158        out.push_str("# offset\tsize\tnode\n");
159        for (off, sz, id) in rows {
160            out.push_str(&format!("{off}\t{sz}\t{id}\n"));
161        }
162        out
163    }
164}
165
166/// Collect view-node aliases for embedding in LIR.
167pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
168    let mut out = HashMap::new();
169    for node in graph.nodes() {
170        if pure_view_offset(graph, node).is_some() {
171            let (root, off) = resolve_view_root(graph, node.id);
172            out.insert(node.id, (root, off));
173        }
174    }
175    out
176}
177
178/// Walk view chains until reaching a non-view ancestor. Returns the
179/// root buffer-owning node and the cumulative byte offset from the root.
180fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
181    let mut total_offset = 0usize;
182    loop {
183        let node = graph.node(id);
184        match pure_view_offset(graph, node) {
185            Some((parent, off)) => {
186                total_offset += off;
187                id = parent;
188            }
189            None => return (id, total_offset),
190        }
191    }
192}
193
194/// Compute the live range [birth, death] for each node's output buffer.
195/// Birth = when the node produces its output.
196/// Death = the last time any consumer reads it.
197#[allow(dead_code)]
198fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
199    compute_live_ranges_opts(graph, true)
200}
201
202fn compute_live_ranges_opts(
203    graph: &Graph,
204    pin_output_ancestors: bool,
205) -> HashMap<NodeId, (usize, usize)> {
206    let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();
207
208    for (step, node) in graph.nodes().iter().enumerate() {
209        // Birth: this node's output is produced at this step
210        ranges.entry(node.id).or_insert((step, step));
211
212        // Extend death of all inputs to at least this step. For view
213        // inputs, attribute the read to the *root* buffer so the
214        // underlying allocation stays alive while any view of it is
215        // still being read (#46 view-aliasing pattern).
216        for &input in &node.inputs {
217            let (root, _off) = resolve_view_root(graph, input);
218            ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
219            // Also track the view itself so we don't leave a dangling
220            // entry; views inherit the root's range later in
221            // plan_memory_aligned.
222            if root != input {
223                ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
224            }
225        }
226    }
227
228    // Extend death of output nodes to the end
229    let last_step = graph.len();
230    for &out in &graph.outputs {
231        let (root, _off) = resolve_view_root(graph, out);
232        ranges.entry(root).and_modify(|r| r.1 = last_step);
233        if root != out {
234            ranges.entry(out).and_modify(|r| r.1 = last_step);
235        }
236    }
237
238    // All producers feeding graph outputs must stay live through the final
239    // read-back (e.g. Cast f32→i64 feeding a boundary output). Without
240    // this, a later epilogue tensor can reuse an ancestor slot while thunks
241    // still run out of schedule order on overlapping paths.
242    {
243        let mut stack: Vec<NodeId> = graph.outputs.clone();
244        let mut seen = std::collections::HashSet::new();
245        while let Some(id) = stack.pop() {
246            if !seen.insert(id) {
247                continue;
248            }
249            let (root, _) = resolve_view_root(graph, id);
250            ranges.entry(root).and_modify(|r| r.1 = last_step);
251            if root != id {
252                ranges.entry(id).and_modify(|r| r.1 = last_step);
253            }
254            // Walking the full transitive ancestor DAG pins (almost) every node of
255            // a deep feed-forward graph to the final step, which destroys slot reuse
256            // — the HiFi-GAN decoder ballooned to a 5 GB arena (over wgpu's 4 GB bind
257            // limit) purely from this. `pin_output_ancestors=false` keeps only the
258            // read-back protection on the output nodes (and their view roots), which
259            // is sufficient for in-order executors and drops that arena to ~0.12 GB.
260            if pin_output_ancestors {
261                for &input in &graph.node(id).inputs {
262                    stack.push(input);
263                }
264            }
265        }
266    }
267
268    // Params, Inputs, and Constants live for the ENTIRE execution.
269    // Params/Inputs are pre-loaded externally; Constants are pre-loaded
270    // by the runtime's compile step (see backend.rs::compile_inner). In
271    // all three cases the slot must not be overwritten by intermediate
272    // buffer sharing, otherwise iteration 2 of a training/inference
273    // loop would read whatever the previous run scribbled into it.
274    for node in graph.nodes() {
275        if matches!(
276            node.op,
277            rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
278        ) {
279            ranges.entry(node.id).and_modify(|r| {
280                r.0 = 0;
281                r.1 = last_step;
282            });
283        }
284    }
285
286    ranges
287}
288
289/// Keep packed `[B,S,3,H,D]` QKV parents alive through Attention. Without
290/// this, liveness ends after the Narrow ops and the planner may reuse the
291/// parent slot for the attention output while the CPU fused path (and
292/// wgpu packed stride path) still read Q/K/V from that buffer.
293fn extend_node_chain_liveness_to_end(
294    graph: &Graph,
295    ranges: &mut HashMap<NodeId, (usize, usize)>,
296    start: NodeId,
297    last_step: usize,
298) {
299    let mut stack = vec![start];
300    let mut seen = std::collections::HashSet::new();
301    while let Some(id) = stack.pop() {
302        if !seen.insert(id) {
303            continue;
304        }
305        let (root, _) = resolve_view_root(graph, id);
306        ranges.entry(root).and_modify(|r| r.1 = last_step);
307        if root != id {
308            ranges.entry(id).and_modify(|r| r.1 = last_step);
309        }
310        for &input in &graph.node(id).inputs {
311            stack.push(input);
312        }
313    }
314}
315
316/// Keep primary data inputs alive through graph end for `Op::Custom("onnx.*")`
317/// thunks that read activations after parallel branches would otherwise reuse slots.
318fn extend_custom_op_input_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
319    let last_step = graph.len();
320    for node in graph.nodes() {
321        let Op::Custom {
322            name, num_inputs, ..
323        } = &node.op
324        else {
325            continue;
326        };
327        if !name.starts_with("onnx.") {
328            continue;
329        }
330        let n = (*num_inputs as usize).min(node.inputs.len());
331        for &input in &node.inputs[..n] {
332            extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
333        }
334    }
335    // Op::DequantMatMul / Op::DequantGroupedMatMul on Metal may fall back to a
336    // deferred-host execution path (`RLX_METAL_DEQUANT_GPU_DISABLE=1`, or when
337    // `dequant_scratch_off == 0`, or for schemes the GPU kernel doesn't support).
338    // The deferred path runs at a `flush_deferred_host` sync point INSIDE a
339    // later `e!()` macro invocation — by then the activation buffer may have
340    // been reused by a subsequent GPU op because the planner sees the host op
341    // as a normal-step consumer and considers the input free for reuse after
342    // that step. Without this extension, attention output (last read by the
343    // o_proj DequantMatMul) gets clobbered between attention's GPU dispatch
344    // and the host o_proj flush, producing exact-zero downstream values
345    // (task #50). The fix is conservative — extends only the direct
346    // activation input (operand 0), not the whole ancestor chain — because
347    // weights (operand 1+) are always Params and already pinned.
348    for node in graph.nodes() {
349        match &node.op {
350            Op::DequantMatMul { .. } => {
351                if let Some(&x) = node.inputs.first() {
352                    extend_node_chain_liveness_to_end(graph, ranges, x, last_step);
353                }
354            }
355            Op::DequantGroupedMatMul { .. } => {
356                if let Some(&x) = node.inputs.first() {
357                    extend_node_chain_liveness_to_end(graph, ranges, x, last_step);
358                }
359            }
360            _ => {}
361        }
362    }
363}
364
365/// Albert-style blocks reuse hidden buffers across many sequential Add/LN
366/// stages; keep residual inputs alive through graph end when this graph uses
367/// ONNX `QMatMul` thunks (marker for the bundled ONNX import path).
368fn extend_bert_hidden_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
369    let uses_onnx_qmatmul = graph.nodes().iter().any(|node| {
370        matches!(
371            &node.op,
372            Op::Custom { name, .. } if name == "onnx.QMatMul" || name == "onnx.ActCopy"
373        )
374    });
375    if !uses_onnx_qmatmul {
376        return;
377    }
378    let last_step = graph.len();
379    for node in graph.nodes() {
380        match &node.op {
381            Op::LayerNorm { .. } | Op::LayerNorm2d { .. } => {
382                if let Some(&input) = node.inputs.first() {
383                    extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
384                }
385                ranges.entry(node.id).and_modify(|r| r.1 = last_step);
386            }
387            Op::Binary(BinaryOp::Add) => {
388                for &input in &node.inputs {
389                    extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
390                }
391                ranges.entry(node.id).and_modify(|r| r.1 = last_step);
392            }
393            _ => {}
394        }
395    }
396}
397
398fn extend_onnx_duration_epilogue_liveness(
399    graph: &Graph,
400    ranges: &mut HashMap<NodeId, (usize, usize)>,
401) {
402    // Waveform-only graphs still contain duration-loop nodes in IR, but when
403    // duration is not exported we can use normal slot reuse.
404    if !graph_exports_onnx_duration(graph) {
405        return;
406    }
407    let last_step = graph.len();
408    for &out in &graph.outputs {
409        extend_node_chain_liveness_to_end(graph, ranges, out, last_step);
410    }
411    for node in graph.nodes() {
412        let keep = match &node.op {
413            Op::Custom { name, .. }
414                if name == "onnx.ConcatFromSequence" || name == "onnx.KittenConcatFromSequence" =>
415            {
416                true
417            }
418            Op::Expand { .. } => node.shape.dtype() == rlx_ir::DType::I64,
419            Op::Cast { to, .. } => *to == rlx_ir::DType::I64,
420            Op::Where => node.shape.dtype() == rlx_ir::DType::I64,
421            Op::Binary(_) => node.shape.dtype() == rlx_ir::DType::I64,
422            _ => node.shape.dtype() == rlx_ir::DType::I64 && node.shape.rank() <= 2,
423        };
424        if keep {
425            extend_node_chain_liveness_to_end(graph, ranges, node.id, last_step);
426            ranges.entry(node.id).and_modify(|r| r.1 = last_step);
427        }
428    }
429}
430
431fn graph_exports_onnx_duration(graph: &Graph) -> bool {
432    graph
433        .outputs
434        .iter()
435        .any(|&id| graph.node(id).shape.dtype() == rlx_ir::DType::I64)
436}
437
438#[allow(dead_code)]
439fn graph_uses_onnx_duration_epilogue(graph: &Graph) -> bool {
440    if graph.nodes().iter().any(|node| {
441        matches!(
442            &node.op,
443            Op::Custom { name, .. }
444                if name == "onnx.ConcatFromSequence"
445                    || name == "onnx.KittenConcatFromSequence"
446        )
447    }) {
448        return true;
449    }
450    graph_exports_onnx_duration(graph)
451}
452
453fn extend_packed_qkv_parent_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
454    for (step, node) in graph.nodes().iter().enumerate() {
455        let rlx_ir::Op::Attention { .. } = &node.op else {
456            continue;
457        };
458        if node.inputs.len() < 3 {
459            continue;
460        }
461        let Some((parent, _, _)) = rlx_ir::detect_packed_bshd_qkv_attention(
462            graph,
463            node.inputs[0],
464            node.inputs[1],
465            node.inputs[2],
466        ) else {
467            continue;
468        };
469        let (root, _) = resolve_view_root(graph, parent);
470        ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
471        if root != parent {
472            ranges.entry(parent).and_modify(|r| r.1 = r.1.max(step));
473        }
474    }
475}
476
// NOTE: a doc paragraph describing the buffer allocator ("sorts buffers by
// size, largest first, then places each one in a free gap of the arena
// during its live interval — a simplified version of XLA's
// GlobalDecreasingSizeBestFitHeap") used to be attached to this struct. It
// documents `plan_memory_aligned_inner`, not the options type, so it is kept
// here as a plain comment instead of being rendered as this struct's docs.

/// Controls which graph boundaries receive arena slots during planning.
///
/// Inference graphs use [`Self::inference`] (all boundaries allocated).
/// Backward graphs in a training pair use [`Self::backward_activations_only`]:
/// parameters borrow offsets from the forward plan via [`SharedWeightLayout`]
/// so weights are not stored twice in the activation arena.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryPlanOptions {
    /// Allocate arena slots for `Op::Param` nodes.
    pub allocate_params: bool,
    /// Allocate arena slots for `Op::Input` nodes.
    pub allocate_inputs: bool,
    /// Allocate arena slots for `Op::Constant` nodes.
    pub allocate_constants: bool,
    /// When true (or env `RLX_ARENA_NO_REUSE=1`), every tensor gets a unique arena slot.
    pub arena_no_reuse: bool,
    /// When true (default), pin the *entire* transitive ancestor DAG of the graph
    /// outputs to the final step. That's a conservative guard for out-of-order
    /// execution, but it destroys slot reuse on deep feed-forward graphs (the
    /// HiFi-GAN decoder hit a 5 GB arena). In-order executors (CPU, wgpu) can set
    /// this false: only the output nodes are pinned (read-back protection), which
    /// is sufficient and keeps the arena small.
    pub pin_output_ancestors: bool,
}

impl MemoryPlanOptions {
    /// Read the `RLX_ARENA_NO_REUSE` environment variable. Returns `true`
    /// when it is set to `1` or to `true` (ASCII case-insensitive); unset,
    /// non-Unicode or any other value yields `false`.
    fn arena_no_reuse_from_env() -> bool {
        std::env::var("RLX_ARENA_NO_REUSE")
            .ok()
            .is_some_and(|v| v == "1" || v.eq_ignore_ascii_case("true"))
    }

    /// Options for inference graphs: params, inputs and constants all get
    /// arena slots and output ancestors are pinned. `arena_no_reuse` is
    /// taken from the `RLX_ARENA_NO_REUSE` environment variable.
    pub fn inference() -> Self {
        Self {
            allocate_params: true,
            allocate_inputs: true,
            allocate_constants: true,
            arena_no_reuse: Self::arena_no_reuse_from_env(),
            pin_output_ancestors: true,
        }
    }

    /// Activations + inputs/constants only; params bound via [`SharedWeightLayout`].
    ///
    /// Identical to [`Self::inference`] except that `allocate_params` is false.
    pub fn backward_activations_only() -> Self {
        Self {
            allocate_params: false,
            ..Self::inference()
        }
    }
}

impl Default for MemoryPlanOptions {
    /// Same as [`MemoryPlanOptions::inference`].
    fn default() -> Self {
        Self::inference()
    }
}
536
537/// Persistent parameter slots extracted from a forward [`MemoryPlan`].
538#[derive(Debug, Clone, PartialEq, Eq)]
539pub struct SharedWeightLayout {
540    pub arena_size: usize,
541    pub slots: Vec<WeightSlot>,
542}
543
544/// One named parameter and its byte range in the shared weight region.
545#[derive(Debug, Clone, PartialEq, Eq)]
546pub struct WeightSlot {
547    pub name: String,
548    pub forward_id: NodeId,
549    pub offset: usize,
550    pub size: usize,
551}
552
553impl SharedWeightLayout {
554    /// Collect `Op::Param` slots from a forward memory plan (by param name).
555    pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
556        let mut slots = Vec::new();
557        for node in graph.nodes() {
558            if let rlx_ir::Op::Param { name } = &node.op {
559                if let Some(slot) = plan.assignments.get(&node.id) {
560                    slots.push(WeightSlot {
561                        name: name.clone(),
562                        forward_id: node.id,
563                        offset: slot.offset,
564                        size: slot.size,
565                    });
566                }
567            }
568        }
569        slots.sort_by(|a, b| a.name.cmp(&b.name));
570        let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
571        Self { arena_size, slots }
572    }
573
574    /// Map backward-graph `Op::Param` nodes to the forward weight offsets.
575    pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
576        let by_name: std::collections::HashMap<&str, &WeightSlot> =
577            self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
578        for node in graph.nodes() {
579            if let rlx_ir::Op::Param { name } = &node.op {
580                let Some(slot) = by_name.get(name.as_str()) else {
581                    continue;
582                };
583                plan.assignments.insert(
584                    node.id,
585                    BufferSlot {
586                        offset: slot.offset,
587                        size: slot.size,
588                    },
589                );
590            }
591        }
592        plan.arena_size = plan.arena_size.max(self.arena_size);
593    }
594}
595
596#[inline]
597fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
598    match op {
599        rlx_ir::Op::Param { .. } => opts.allocate_params,
600        rlx_ir::Op::Input { .. } => opts.allocate_inputs,
601        rlx_ir::Op::Constant { .. } => opts.allocate_constants,
602        _ => true,
603    }
604}
605
606/// Plan memory with default 64-byte alignment.
607pub fn plan_memory(graph: &Graph) -> MemoryPlan {
608    plan_memory_aligned(graph, 64)
609}
610
611/// Plan memory with custom alignment and boundary allocation policy.
612pub fn plan_memory_with_options(
613    graph: &Graph,
614    alignment: usize,
615    opts: MemoryPlanOptions,
616) -> MemoryPlan {
617    plan_memory_aligned_inner(graph, alignment, opts, None, false)
618}
619
620/// Plan memory with custom alignment (inference defaults).
621pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
622    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, false)
623}
624
625/// Liveness-aware planning with every slot sized as `num_elements * 4`
626/// bytes (wgpu / uniform-f32 arenas). Reuses dead tensor slots so large
627/// `[n, n]` pairwise graphs stay under WebGPU's 128 MiB binding cap.
628pub fn plan_memory_f32_uniform(graph: &Graph, alignment: usize) -> MemoryPlan {
629    // wgpu executes dispatches in schedule order, so the conservative
630    // output-ancestor liveness pin isn't needed — and dropping it is what keeps
631    // deep decoders (HiFi-GAN) under wgpu's 4 GB storage-buffer binding limit.
632    let opts = MemoryPlanOptions {
633        pin_output_ancestors: false,
634        ..MemoryPlanOptions::default()
635    };
636    plan_memory_aligned_inner(graph, alignment, opts, None, true)
637}
638
639/// Plan backward activations, then alias params onto `weights`.
640pub fn plan_memory_backward(
641    graph: &Graph,
642    alignment: usize,
643    weights: &SharedWeightLayout,
644) -> MemoryPlan {
645    plan_memory_aligned_inner(
646        graph,
647        alignment,
648        MemoryPlanOptions::backward_activations_only(),
649        Some(weights),
650        false,
651    )
652}
653
654#[inline]
655fn node_slot_bytes(node: &rlx_ir::Node, f32_uniform: bool) -> usize {
656    if f32_uniform {
657        node.shape.num_elements().unwrap_or(0) * 4
658    } else {
659        node.shape.size_bytes().unwrap_or(0)
660    }
661}
662
663fn plan_memory_aligned_inner(
664    graph: &Graph,
665    alignment: usize,
666    opts: MemoryPlanOptions,
667    weights: Option<&SharedWeightLayout>,
668    f32_uniform: bool,
669) -> MemoryPlan {
670    let mut ranges = compute_live_ranges_opts(graph, opts.pin_output_ancestors);
671    extend_packed_qkv_parent_liveness(graph, &mut ranges);
672    extend_custom_op_input_liveness(graph, &mut ranges);
673    extend_bert_hidden_liveness(graph, &mut ranges);
674    extend_onnx_duration_epilogue_liveness(graph, &mut ranges);
675    let mut opts = opts;
676    if graph_exports_onnx_duration(graph) {
677        opts.arena_no_reuse = true;
678    }
679    // Collect buffers that need allocation (skip inputs/params — external)
680    struct BufInfo {
681        id: NodeId,
682        size: usize,
683        birth: usize,
684        death: usize,
685    }
686
687    let mut buffers: Vec<BufInfo> = Vec::new();
688    for node in graph.nodes() {
689        // Skip view nodes — they alias their parent's buffer (handled
690        // in the post-pass below). Plan #46.
691        if pure_view_offset(graph, node).is_some() {
692            continue;
693        }
694        let raw_size = node_slot_bytes(node, f32_uniform);
695        let size = if raw_size == 0 {
696            boundary_min_slot_bytes(&node.op, alignment)
697        } else {
698            raw_size
699        };
700        if size > 0
701            && let Some(&(birth, death)) = ranges.get(&node.id)
702            && plans_boundary_buffer(&node.op, opts)
703        {
704            buffers.push(BufInfo {
705                id: node.id,
706                size,
707                birth,
708                death,
709            });
710        }
711    }
712
713    // Sort by size descending (largest first gets priority placement)
714    buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
715
716    // Greedy first-fit allocation
717    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
718    let mut arena_size: usize = 0;
719
720    // Track allocated regions with their live ranges
721    let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); // (offset, size, birth, death)
722
723    for buf in &buffers {
724        let align = alignment;
725        let node = graph.node(buf.id);
726        let tail_guard = boundary_tail_guard(&node.op, align);
727        let placement_size = buf.size + tail_guard;
728        let mut best_offset: Option<usize> = None;
729
730        // Collect candidate start offsets: 0 plus the end of every placed
731        // buffer that could border a free gap.
732        let mut candidates = vec![0usize];
733        for &(p_off, p_size, _, _) in &placed {
734            candidates.push(p_off + p_size);
735        }
736        candidates.sort_unstable();
737        candidates.dedup();
738
739        for &candidate_offset in &candidates {
740            let aligned = (candidate_offset + align - 1) & !(align - 1);
741            let end = aligned + placement_size;
742
743            let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
744                let p_end = p_off + p_size;
745                let mem_overlap = aligned < p_end && end > p_off;
746                let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
747                mem_overlap && time_overlap
748            });
749
750            if !conflict {
751                match best_offset {
752                    None => best_offset = Some(aligned),
753                    Some(best) if aligned < best => best_offset = Some(aligned),
754                    _ => {}
755                }
756            }
757        }
758
759        let aligned = if opts.arena_no_reuse {
760            (arena_size + align - 1) & !(align - 1)
761        } else {
762            best_offset.unwrap_or_else(|| {
763                // No gap fit — append at arena tail.
764                (arena_size + align - 1) & !(align - 1)
765            })
766        };
767        assignments.insert(
768            buf.id,
769            BufferSlot {
770                offset: aligned,
771                size: buf.size,
772            },
773        );
774        placed.push((aligned, placement_size, buf.birth, buf.death));
775        arena_size = arena_size.max(aligned + placement_size);
776    }
777
778    // ── View aliasing pass (plan #46) ────────────────────────
779    // Every view node points at its root buffer's slot, offset by the
780    // accumulated view offset. The root has its own allocation above;
781    // views just borrow its bytes. This is the post-pass — done after
782    // root allocations are placed so we have offsets to point at.
783    for node in graph.nodes() {
784        if pure_view_offset(graph, node).is_some() {
785            let (root, off) = resolve_view_root(graph, node.id);
786            if let Some(root_slot) = assignments.get(&root).cloned() {
787                let view_size = node_slot_bytes(node, f32_uniform);
788                assignments.insert(
789                    node.id,
790                    BufferSlot {
791                        offset: root_slot.offset + off,
792                        size: view_size,
793                    },
794                );
795            }
796        }
797    }
798
799    let schedule = graph.topo_order().collect();
800
801    let mut plan = MemoryPlan {
802        arena_size,
803        assignments,
804        schedule,
805    };
806    if let Some(w) = weights {
807        w.apply_to_plan(graph, &mut plan);
808    }
809    plan
810}
811
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// Plans a two-matmul chain and checks that the arena is no larger than
    /// the sum of the logical slot sizes plus one tail guard per slot.
    ///
    /// Note: this bounds the arena size only; it does not assert that any two
    /// specific buffers were placed at the same offset.
    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f)); // 153.6KB
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        // mm1 is only used by mm2's input; after mm2, mm1 is dead
        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f)); // 153.6KB, live [4, 5]
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f)); // 153.6KB, live [5, ∞]
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);

        // Live ranges depend only on the graph, so compute them once instead
        // of once per assignment inside the loop.
        let live_ranges = compute_live_ranges(&g);

        println!("Arena size: {} bytes", plan.arena_size);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = live_ranges.get(id) {
                println!(
                    "  {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        // Logical slot sizes omit 64-byte alignment gaps and param tail guards
        // (see `boundary_tail_guard`). Arena may be slightly larger than that sum
        // even when temporaries reuse gaps; cap slack at one guard per slot.
        let total_logical: usize = plan.assignments.values().map(|s| s.size).sum();
        let align_slack = plan.assignments.len() * BOUNDARY_TAIL_GUARD_BYTES;
        assert!(
            plan.arena_size <= total_logical + align_slack,
            "arena {} should be <= logical sum {} + slack {}",
            plan.arena_size,
            total_logical,
            align_slack
        );
    }

    /// Checks that `MemoryPlan::report()` carries the headline fields
    /// (`arena_size`, `total_unshared`, `saved`) and a non-empty body, and
    /// that every node of the graph has an entry in the assignments map.
    #[test]
    fn plan_report_includes_savings() {
        // Plan #87: the public report() string surfaces enough info
        // for debug tooling — arena size, unshared total, saved
        // bytes, and a per-buffer table.
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();

        // Header carries the headline numbers.
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));

        // Every line not starting with '#' belongs to the per-buffer table.
        // Only non-emptiness is verified here; row format and ordering are not.
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());

        // assignments map → mm1 + mm2 + x + w should all appear.
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
        assert!(
            plan.assignments.contains_key(&x),
            "input x should have an arena slot"
        );
        assert!(
            plan.assignments.contains_key(&w),
            "param w should have an arena slot"
        );
    }

    /// Reshape, same-dtype Cast, and axis-0 Narrow are views: they must point
    /// into the root buffer's slot (plus the narrow's byte offset) rather
    /// than receive an independent offset (#46).
    #[test]
    fn view_ops_alias_parent_slot() {
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f)); // 128B
        let w = g.param("w", Shape::new(&[4, 4], f)); // 64B
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f)); // 128B (root)
        let r = g.reshape_(mm, vec![32]); // VIEW (Reshape)
        let c = g.cast(r, DType::F32); // VIEW (same-dtype Cast)
        let n = g.narrow_(c, 0, 8, 16); // VIEW (axis-0 Narrow)
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        // Reshape and cast sit exactly on the root; the narrow is shifted by
        // its start index (start=8 → +8*4 = 32 bytes).
        let root_offset = plan.assignments[&mm].offset;
        let reshape_slot = &plan.assignments[&r];
        let cast_slot = &plan.assignments[&c];
        let narrow_slot = &plan.assignments[&n];

        assert_eq!(
            reshape_slot.offset, root_offset,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            cast_slot.offset, root_offset,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            narrow_slot.offset,
            root_offset + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            narrow_slot.size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    /// A backward graph planned against a `SharedWeightLayout` derived from
    /// the forward plan must place the same-named param at the forward offset.
    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);

        // Look the slots up explicitly so a missing entry fails with a
        // descriptive message instead of a bare index panic.
        let fwd_w_off = fwd_plan
            .assignments
            .get(&w)
            .expect("forward plan must assign a slot to param w")
            .offset;
        let bwd_w_off = bwd_plan
            .assignments
            .get(&w2)
            .expect("backward plan must assign a slot to param w")
            .offset;
        assert_eq!(bwd_w_off, fwd_w_off, "backward w must share forward offset");
    }

    /// Two buffers that are live at the same time must occupy disjoint byte
    /// ranges in the arena.
    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        // `add` reads both mm and x, so mm is still live while add is produced.
        // The only graph output is `add`.
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        // mm and add overlap in time, so they must not overlap in memory
        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }

    /// An input with a zero-length dimension must still receive an arena slot,
    /// and that slot must be at least 64 bytes.
    #[test]
    fn zero_length_inputs_get_arena_slots() {
        let mut g = Graph::new("empty_past");
        let f = DType::F32;
        let past = g.input("past_k", Shape::new(&[1, 0, 8], f));
        let x = g.input("x", Shape::new(&[1, 1, 8], f));
        let cat = g.concat(vec![past, x], 1, Shape::new(&[1, 1, 8], f));
        g.set_outputs(vec![cat]);

        let plan = plan_memory(&g);
        assert!(
            plan.assignments.contains_key(&past),
            "zero-length decode past input must have an arena slot"
        );
        // NOTE(review): 64 presumably mirrors plan_memory's default alignment
        // (the boundary minimum slot size) — confirm against plan_memory.
        let past_slot = &plan.assignments[&past];
        assert!(past_slot.size >= 64);
    }

    /// `graph_exports_onnx_duration` must be false for a graph whose only
    /// output is the waveform and true when a "dur" output is also exported.
    ///
    /// Note: only the predicate is checked; the planner's reuse behaviour
    /// is not exercised here.
    #[test]
    fn duration_export_forces_no_reuse_waveform_only_does_not() {
        let f = DType::F32;

        // Single "wave" output: predicate must be false.
        let mut wave_only = Graph::new("wave_only");
        let w = wave_only.input("wave", Shape::new(&[1024], f));
        wave_only.set_outputs(vec![w]);
        assert!(!graph_exports_onnx_duration(&wave_only));

        // "wave" plus an I64 "dur" output: predicate must be true.
        let mut dual = Graph::new("dual");
        let w2 = dual.input("wave", Shape::new(&[1024], f));
        let d = dual.input("dur", Shape::new(&[8], DType::I64));
        dual.set_outputs(vec![w2, d]);
        assert!(graph_exports_onnx_duration(&dual));
    }
}