Skip to main content

rlx_compile/
memory.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Memory planning — liveness analysis and buffer assignment.
//!
//! Modeled on XLA's buffer assignment: it computes which intermediate
//! tensors have non-overlapping lifetimes and lets them share the same
//! memory, greedily reducing the total arena size.
//!
//! The output is a [`MemoryPlan`] that tells the runtime exactly how
//! large the arena should be and where each tensor lives within it.
24
25use rlx_ir::op::BinaryOp;
26use rlx_ir::{Graph, NodeId, Op};
27use std::collections::HashMap;
28
29/// Extra bytes reserved after Input/Param/Constant slots so a kernel
30/// that writes slightly past its logical tensor size cannot stomp the
31/// next arena slot (e.g. small bias tensor adjacent to input_ids).
32const BOUNDARY_TAIL_GUARD_BYTES: usize = 128;
33
34fn boundary_tail_guard(op: &rlx_ir::Op, alignment: usize) -> usize {
35    if matches!(
36        op,
37        rlx_ir::Op::Input { .. } | rlx_ir::Op::Param { .. } | rlx_ir::Op::Constant { .. }
38    ) {
39        alignment.max(BOUNDARY_TAIL_GUARD_BYTES)
40    } else {
41        0
42    }
43}
44/// Identify ops whose output is a *view* of an existing buffer — no
45/// copy needed, no separate arena slot. Returns the parent input index
46/// and the byte offset of the view within the parent.
47///
48/// Borrowed from MAX's "view-vs-copy" pattern (#46 in PLAN.md).
49/// The hard case (strided narrow on a non-outermost axis — e.g. BERT
50/// QKV split) requires kernels that consume strided inputs and is
51/// deferred. This function only catches the safely-elidable cases:
52///
53///   - **`Reshape`**: pure metadata; data layout is identical.
54///   - **`Cast`** with `src dtype == dst dtype`: pure metadata.
55///   - **`Narrow` on axis 0**: contiguous sub-slice of the parent;
56///     offset = `start * size_of_inner_in_bytes`.
57fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
58    match &node.op {
59        Op::Reshape { .. } => Some((node.inputs[0], 0)),
60        Op::Cast { to } => {
61            let parent = graph.node(node.inputs[0]);
62            if parent.shape.dtype() == *to {
63                Some((node.inputs[0], 0))
64            } else {
65                None
66            }
67        }
68        Op::Narrow {
69            axis,
70            start,
71            len: _,
72        } if *axis == 0 => {
73            let parent = graph.node(node.inputs[0]);
74            // inner = product of dims after axis 0
75            let inner_elems: usize = (1..parent.shape.rank())
76                .map(|i| parent.shape.dim(i).unwrap_static())
77                .product();
78            let dt_bytes = parent.shape.dtype().size_bytes();
79            Some((node.inputs[0], start * inner_elems * dt_bytes))
80        }
81        _ => None,
82    }
83}
84
85/// Public predicate for backends — true iff this op should compile to
86/// a Nop because its output aliases a parent buffer (the memory
87/// planner has already aliased its slot).
88pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
89    pure_view_offset(graph, node).is_some()
90}
91
/// One contiguous byte range reserved inside the memory arena.
#[derive(Debug, Clone)]
pub struct BufferSlot {
    /// Start of the range, in bytes from the beginning of the arena.
    pub offset: usize,
    /// Length of the range in bytes. For slots produced by this module's
    /// planner this is the tensor's logical size; any boundary tail guard
    /// reserved after it is not included.
    pub size: usize,
}
100
101/// Complete memory plan for executing a graph.
102#[derive(Debug, Clone)]
103pub struct MemoryPlan {
104    /// Total arena size in bytes.
105    pub arena_size: usize,
106    /// Buffer assignment: NodeId → offset within arena.
107    pub assignments: HashMap<NodeId, BufferSlot>,
108    /// Node execution order (topological).
109    pub schedule: Vec<NodeId>,
110}
111
112impl MemoryPlan {
113    /// Sum of all assigned buffer sizes (i.e. how much memory the
114    /// plan would use if every node had its own slot). Useful for
115    /// reporting how much the liveness-aware sharing saved.
116    pub fn total_unshared_bytes(&self) -> usize {
117        self.assignments.values().map(|s| s.size).sum()
118    }
119
120    /// Bytes saved vs. naive "every node gets its own slot" — how
121    /// much the liveness analysis bought you.
122    pub fn bytes_saved(&self) -> usize {
123        self.total_unshared_bytes().saturating_sub(self.arena_size)
124    }
125
126    /// Render the buffer plan as a one-line-per-node table for
127    /// debugging — sorted by offset so adjacent buffers in memory
128    /// are adjacent in the report (plan #87).
129    ///
130    /// The output is parseable: `<offset>\t<size>\t%<node_id>`. Pipe
131    /// through `column -t` for human display, or grep / awk it for
132    /// scripted analysis.
133    pub fn report(&self) -> String {
134        let mut rows: Vec<(usize, usize, NodeId)> = self
135            .assignments
136            .iter()
137            .map(|(id, slot)| (slot.offset, slot.size, *id))
138            .collect();
139        rows.sort();
140        let mut out = String::new();
141        out.push_str(&format!(
142            "# arena_size={} total_unshared={} saved={}\n",
143            self.arena_size,
144            self.total_unshared_bytes(),
145            self.bytes_saved()
146        ));
147        out.push_str("# offset\tsize\tnode\n");
148        for (off, sz, id) in rows {
149            out.push_str(&format!("{off}\t{sz}\t{id}\n"));
150        }
151        out
152    }
153}
154
155/// Collect view-node aliases for embedding in LIR.
156pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
157    let mut out = HashMap::new();
158    for node in graph.nodes() {
159        if pure_view_offset(graph, node).is_some() {
160            let (root, off) = resolve_view_root(graph, node.id);
161            out.insert(node.id, (root, off));
162        }
163    }
164    out
165}
166
167/// Walk view chains until reaching a non-view ancestor. Returns the
168/// root buffer-owning node and the cumulative byte offset from the root.
169fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
170    let mut total_offset = 0usize;
171    loop {
172        let node = graph.node(id);
173        match pure_view_offset(graph, node) {
174            Some((parent, off)) => {
175                total_offset += off;
176                id = parent;
177            }
178            None => return (id, total_offset),
179        }
180    }
181}
182
/// Compute the live range `[birth, death]` (inclusive step indices into
/// `graph.nodes()`) for each node's output buffer.
///
/// - Birth = the step at which the node produces its output.
/// - Death = the last step at which any consumer reads it.
///
/// Adjustments applied on top of the plain def/use scan:
/// - Reads through pure views are charged to the view's root buffer as well
///   as to the view node itself.
/// - Every graph output, and every node reachable backwards from an output
///   through `inputs`, gets its death set to `graph.len()` (one past the
///   last step).
/// - `Param` / `Input` / `Constant` nodes span `[0, graph.len()]`.
///
/// NOTE(review): assumes `graph.nodes()` lists producers before consumers;
/// an input not yet inserted is skipped by `and_modify` — confirm.
/// NOTE(review): as written, the ancestor walk pins every output-reachable
/// node to the end (and subsumes the separate "outputs" pass), so only
/// nodes that feed no output can have their slots reused.
fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
    let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();

    for (step, node) in graph.nodes().iter().enumerate() {
        // Birth: this node's output is produced at this step
        ranges.entry(node.id).or_insert((step, step));

        // Extend death of all inputs to at least this step. For view
        // inputs, attribute the read to the *root* buffer so the
        // underlying allocation stays alive while any view of it is
        // still being read (#46 view-aliasing pattern).
        for &input in &node.inputs {
            let (root, _off) = resolve_view_root(graph, input);
            ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
            // Also track the view itself so its entry is not left stale;
            // views are never allocated on their own — the planner aliases
            // them into the root's slot.
            if root != input {
                ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
            }
        }
    }

    // Extend death of output nodes to the end (also covered by the ancestor
    // walk below, which starts from the same outputs).
    let last_step = graph.len();
    for &out in &graph.outputs {
        let (root, _off) = resolve_view_root(graph, out);
        ranges.entry(root).and_modify(|r| r.1 = last_step);
        if root != out {
            ranges.entry(out).and_modify(|r| r.1 = last_step);
        }
    }

    // All producers feeding graph outputs must stay live through the final
    // read-back (e.g. Cast f32→i64 feeding a boundary output). Without
    // this, a later epilogue tensor can reuse an ancestor slot while thunks
    // still run out of schedule order on overlapping paths.
    // Depth-first walk over `inputs`; `seen` prevents revisiting shared
    // ancestors.
    {
        let mut stack: Vec<NodeId> = graph.outputs.clone();
        let mut seen = std::collections::HashSet::new();
        while let Some(id) = stack.pop() {
            if !seen.insert(id) {
                continue;
            }
            let (root, _) = resolve_view_root(graph, id);
            ranges.entry(root).and_modify(|r| r.1 = last_step);
            if root != id {
                ranges.entry(id).and_modify(|r| r.1 = last_step);
            }
            for &input in &graph.node(id).inputs {
                stack.push(input);
            }
        }
    }

    // Params, Inputs, and Constants live for the ENTIRE execution.
    // Params/Inputs are pre-loaded externally; Constants are pre-loaded
    // by the runtime's compile step (see backend.rs::compile_inner). In
    // all three cases the slot must not be overwritten by intermediate
    // buffer sharing, otherwise iteration 2 of a training/inference
    // loop would read whatever the previous run scribbled into it.
    for node in graph.nodes() {
        if matches!(
            node.op,
            rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
        ) {
            ranges.entry(node.id).and_modify(|r| {
                r.0 = 0;
                r.1 = last_step;
            });
        }
    }

    ranges
}
261
262/// Keep packed `[B,S,3,H,D]` QKV parents alive through Attention. Without
263/// this, liveness ends after the Narrow ops and the planner may reuse the
264/// parent slot for the attention output while the CPU fused path (and
265/// wgpu packed stride path) still read Q/K/V from that buffer.
266fn extend_node_chain_liveness_to_end(
267    graph: &Graph,
268    ranges: &mut HashMap<NodeId, (usize, usize)>,
269    start: NodeId,
270    last_step: usize,
271) {
272    let mut stack = vec![start];
273    let mut seen = std::collections::HashSet::new();
274    while let Some(id) = stack.pop() {
275        if !seen.insert(id) {
276            continue;
277        }
278        let (root, _) = resolve_view_root(graph, id);
279        ranges.entry(root).and_modify(|r| r.1 = last_step);
280        if root != id {
281            ranges.entry(id).and_modify(|r| r.1 = last_step);
282        }
283        for &input in &graph.node(id).inputs {
284            stack.push(input);
285        }
286    }
287}
288
289/// Keep primary data inputs alive through graph end for `Op::Custom("onnx.*")`
290/// thunks that read activations after parallel branches would otherwise reuse slots.
291fn extend_custom_op_input_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
292    let last_step = graph.len();
293    for node in graph.nodes() {
294        let Op::Custom {
295            name, num_inputs, ..
296        } = &node.op
297        else {
298            continue;
299        };
300        if !name.starts_with("onnx.") {
301            continue;
302        }
303        let n = (*num_inputs as usize).min(node.inputs.len());
304        for &input in &node.inputs[..n] {
305            extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
306        }
307    }
308}
309
310/// Albert-style blocks reuse hidden buffers across many sequential Add/LN
311/// stages; keep residual inputs alive through graph end when this graph uses
312/// ONNX `QMatMul` thunks (marker for the bundled ONNX import path).
313fn extend_bert_hidden_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
314    let uses_onnx_qmatmul = graph.nodes().iter().any(|node| {
315        matches!(
316            &node.op,
317            Op::Custom { name, .. } if name == "onnx.QMatMul" || name == "onnx.ActCopy"
318        )
319    });
320    if !uses_onnx_qmatmul {
321        return;
322    }
323    let last_step = graph.len();
324    for node in graph.nodes() {
325        match &node.op {
326            Op::LayerNorm { .. } | Op::LayerNorm2d { .. } => {
327                if let Some(&input) = node.inputs.first() {
328                    extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
329                }
330                ranges.entry(node.id).and_modify(|r| r.1 = last_step);
331            }
332            Op::Binary(BinaryOp::Add) => {
333                for &input in &node.inputs {
334                    extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
335                }
336                ranges.entry(node.id).and_modify(|r| r.1 = last_step);
337            }
338            _ => {}
339        }
340    }
341}
342
343fn extend_packed_qkv_parent_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
344    for (step, node) in graph.nodes().iter().enumerate() {
345        let rlx_ir::Op::Attention { .. } = &node.op else {
346            continue;
347        };
348        if node.inputs.len() < 3 {
349            continue;
350        }
351        let Some((parent, _, _)) = rlx_ir::detect_packed_bshd_qkv_attention(
352            graph,
353            node.inputs[0],
354            node.inputs[1],
355            node.inputs[2],
356        ) else {
357            continue;
358        };
359        let (root, _) = resolve_view_root(graph, parent);
360        ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
361        if root != parent {
362            ranges.entry(parent).and_modify(|r| r.1 = r.1.max(step));
363        }
364    }
365}
366
/// Controls which graph boundaries receive arena slots during planning.
///
/// Inference graphs use [`Self::inference`] (all boundaries allocated).
/// Backward graphs in a training pair use [`Self::backward_activations_only`]:
/// parameters borrow offsets from the forward plan via [`SharedWeightLayout`]
/// so weights are not stored twice in the activation arena.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryPlanOptions {
    /// Give `Op::Param` nodes their own arena slots.
    pub allocate_params: bool,
    /// Give `Op::Input` nodes their own arena slots.
    pub allocate_inputs: bool,
    /// Give `Op::Constant` nodes their own arena slots.
    pub allocate_constants: bool,
    /// When true, every tensor gets a unique arena slot (no liveness-based
    /// reuse). The constructors initialise it from the `RLX_ARENA_NO_REUSE`
    /// environment variable (`1` or case-insensitive `true`).
    pub arena_no_reuse: bool,
}

impl MemoryPlanOptions {
    /// Environment variable that disables slot reuse when set.
    const NO_REUSE_ENV: &'static str = "RLX_ARENA_NO_REUSE";

    /// Read the no-reuse debug switch: true iff the variable is `1` or
    /// (case-insensitively) `true`; unset or non-UTF-8 values count as false.
    fn arena_no_reuse_from_env() -> bool {
        std::env::var(Self::NO_REUSE_ENV)
            .is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true"))
    }

    /// Allocate every boundary kind (params, inputs, constants).
    pub fn inference() -> Self {
        Self {
            allocate_params: true,
            allocate_inputs: true,
            allocate_constants: true,
            arena_no_reuse: Self::arena_no_reuse_from_env(),
        }
    }

    /// Activations + inputs/constants only; params bound via [`SharedWeightLayout`].
    pub fn backward_activations_only() -> Self {
        Self {
            allocate_params: false,
            ..Self::inference()
        }
    }
}

impl Default for MemoryPlanOptions {
    /// Same as [`MemoryPlanOptions::inference`].
    fn default() -> Self {
        Self::inference()
    }
}
417
418/// Persistent parameter slots extracted from a forward [`MemoryPlan`].
419#[derive(Debug, Clone, PartialEq, Eq)]
420pub struct SharedWeightLayout {
421    pub arena_size: usize,
422    pub slots: Vec<WeightSlot>,
423}
424
425/// One named parameter and its byte range in the shared weight region.
426#[derive(Debug, Clone, PartialEq, Eq)]
427pub struct WeightSlot {
428    pub name: String,
429    pub forward_id: NodeId,
430    pub offset: usize,
431    pub size: usize,
432}
433
434impl SharedWeightLayout {
435    /// Collect `Op::Param` slots from a forward memory plan (by param name).
436    pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
437        let mut slots = Vec::new();
438        for node in graph.nodes() {
439            if let rlx_ir::Op::Param { name } = &node.op {
440                if let Some(slot) = plan.assignments.get(&node.id) {
441                    slots.push(WeightSlot {
442                        name: name.clone(),
443                        forward_id: node.id,
444                        offset: slot.offset,
445                        size: slot.size,
446                    });
447                }
448            }
449        }
450        slots.sort_by(|a, b| a.name.cmp(&b.name));
451        let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
452        Self { arena_size, slots }
453    }
454
455    /// Map backward-graph `Op::Param` nodes to the forward weight offsets.
456    pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
457        let by_name: std::collections::HashMap<&str, &WeightSlot> =
458            self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
459        for node in graph.nodes() {
460            if let rlx_ir::Op::Param { name } = &node.op {
461                let Some(slot) = by_name.get(name.as_str()) else {
462                    continue;
463                };
464                plan.assignments.insert(
465                    node.id,
466                    BufferSlot {
467                        offset: slot.offset,
468                        size: slot.size,
469                    },
470                );
471            }
472        }
473        plan.arena_size = plan.arena_size.max(self.arena_size);
474    }
475}
476
477#[inline]
478fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
479    match op {
480        rlx_ir::Op::Param { .. } => opts.allocate_params,
481        rlx_ir::Op::Input { .. } => opts.allocate_inputs,
482        rlx_ir::Op::Constant { .. } => opts.allocate_constants,
483        _ => true,
484    }
485}
486
487/// Plan memory with default 64-byte alignment.
488pub fn plan_memory(graph: &Graph) -> MemoryPlan {
489    plan_memory_aligned(graph, 64)
490}
491
492/// Plan memory with custom alignment and boundary allocation policy.
493pub fn plan_memory_with_options(
494    graph: &Graph,
495    alignment: usize,
496    opts: MemoryPlanOptions,
497) -> MemoryPlan {
498    plan_memory_aligned_inner(graph, alignment, opts, None, false)
499}
500
501/// Plan memory with custom alignment (inference defaults).
502pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
503    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, false)
504}
505
506/// Liveness-aware planning with every slot sized as `num_elements * 4`
507/// bytes (wgpu / uniform-f32 arenas). Reuses dead tensor slots so large
508/// `[n, n]` pairwise graphs stay under WebGPU's 128 MiB binding cap.
509pub fn plan_memory_f32_uniform(graph: &Graph, alignment: usize) -> MemoryPlan {
510    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, true)
511}
512
513/// Plan backward activations, then alias params onto `weights`.
514pub fn plan_memory_backward(
515    graph: &Graph,
516    alignment: usize,
517    weights: &SharedWeightLayout,
518) -> MemoryPlan {
519    plan_memory_aligned_inner(
520        graph,
521        alignment,
522        MemoryPlanOptions::backward_activations_only(),
523        Some(weights),
524        false,
525    )
526}
527
528#[inline]
529fn node_slot_bytes(node: &rlx_ir::Node, f32_uniform: bool) -> usize {
530    if f32_uniform {
531        node.shape.num_elements().unwrap_or(0) * 4
532    } else {
533        node.shape.size_bytes().unwrap_or(0)
534    }
535}
536
537fn plan_memory_aligned_inner(
538    graph: &Graph,
539    alignment: usize,
540    opts: MemoryPlanOptions,
541    weights: Option<&SharedWeightLayout>,
542    f32_uniform: bool,
543) -> MemoryPlan {
544    let mut ranges = compute_live_ranges(graph);
545    extend_packed_qkv_parent_liveness(graph, &mut ranges);
546    extend_custom_op_input_liveness(graph, &mut ranges);
547    extend_bert_hidden_liveness(graph, &mut ranges);
548    // Collect buffers that need allocation (skip inputs/params — external)
549    struct BufInfo {
550        id: NodeId,
551        size: usize,
552        birth: usize,
553        death: usize,
554    }
555
556    let mut buffers: Vec<BufInfo> = Vec::new();
557    for node in graph.nodes() {
558        // Skip view nodes — they alias their parent's buffer (handled
559        // in the post-pass below). Plan #46.
560        if pure_view_offset(graph, node).is_some() {
561            continue;
562        }
563        let size = node_slot_bytes(node, f32_uniform);
564        if size > 0
565            && let Some(&(birth, death)) = ranges.get(&node.id)
566            && plans_boundary_buffer(&node.op, opts)
567        {
568            buffers.push(BufInfo {
569                id: node.id,
570                size,
571                birth,
572                death,
573            });
574        }
575    }
576
577    // Sort by size descending (largest first gets priority placement)
578    buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
579
580    // Greedy first-fit allocation
581    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
582    let mut arena_size: usize = 0;
583
584    // Track allocated regions with their live ranges
585    let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); // (offset, size, birth, death)
586
587    for buf in &buffers {
588        let align = alignment;
589        let node = graph.node(buf.id);
590        let tail_guard = boundary_tail_guard(&node.op, align);
591        let placement_size = buf.size + tail_guard;
592        let mut best_offset: Option<usize> = None;
593
594        // Collect candidate start offsets: 0 plus the end of every placed
595        // buffer that could border a free gap.
596        let mut candidates = vec![0usize];
597        for &(p_off, p_size, _, _) in &placed {
598            candidates.push(p_off + p_size);
599        }
600        candidates.sort_unstable();
601        candidates.dedup();
602
603        for &candidate_offset in &candidates {
604            let aligned = (candidate_offset + align - 1) & !(align - 1);
605            let end = aligned + placement_size;
606
607            let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
608                let p_end = p_off + p_size;
609                let mem_overlap = aligned < p_end && end > p_off;
610                let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
611                mem_overlap && time_overlap
612            });
613
614            if !conflict {
615                match best_offset {
616                    None => best_offset = Some(aligned),
617                    Some(best) if aligned < best => best_offset = Some(aligned),
618                    _ => {}
619                }
620            }
621        }
622
623        let aligned = if opts.arena_no_reuse {
624            (arena_size + align - 1) & !(align - 1)
625        } else {
626            best_offset.unwrap_or_else(|| {
627                // No gap fit — append at arena tail.
628                (arena_size + align - 1) & !(align - 1)
629            })
630        };
631        assignments.insert(
632            buf.id,
633            BufferSlot {
634                offset: aligned,
635                size: buf.size,
636            },
637        );
638        placed.push((aligned, placement_size, buf.birth, buf.death));
639        arena_size = arena_size.max(aligned + placement_size);
640    }
641
642    // ── View aliasing pass (plan #46) ────────────────────────
643    // Every view node points at its root buffer's slot, offset by the
644    // accumulated view offset. The root has its own allocation above;
645    // views just borrow its bytes. This is the post-pass — done after
646    // root allocations are placed so we have offsets to point at.
647    for node in graph.nodes() {
648        if pure_view_offset(graph, node).is_some() {
649            let (root, off) = resolve_view_root(graph, node.id);
650            if let Some(root_slot) = assignments.get(&root).cloned() {
651                let view_size = node_slot_bytes(node, f32_uniform);
652                assignments.insert(
653                    node.id,
654                    BufferSlot {
655                        offset: root_slot.offset + off,
656                        size: view_size,
657                    },
658                );
659            }
660        }
661    }
662
663    let schedule = graph.topo_order().collect();
664
665    let mut plan = MemoryPlan {
666        arena_size,
667        assignments,
668        schedule,
669    };
670    if let Some(w) = weights {
671        w.apply_to_plan(graph, &mut plan);
672    }
673    plan
674}
675
#[cfg(test)]
mod tests {
    //! Unit tests for liveness analysis, buffer placement, view aliasing and
    //! forward/backward weight sharing.

    use super::*;
    use rlx_ir::*;

    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f)); // 153.6KB
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        // mm1 is only read by mm2, so a plain def/use analysis would let it
        // die at mm2's step. `compute_live_ranges` additionally pins every
        // ancestor of a graph output live to the end, so this graph cannot
        // demonstrate reuse; the assertion below only bounds padding/guard
        // overhead.
        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f)); // 153.6KB
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f)); // 153.6KB
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        // Computed once for the debug dump below.
        let ranges = compute_live_ranges(&g);
        println!("Arena size: {} bytes", plan.arena_size);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = ranges.get(id) {
                println!(
                    "  {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        // Logical slot sizes omit 64-byte alignment gaps and param tail guards
        // (see `boundary_tail_guard`). Arena may be slightly larger than that sum
        // even when temporaries reuse gaps; cap slack at one guard per slot.
        let total_logical: usize = plan.assignments.values().map(|s| s.size).sum();
        let align_slack = plan.assignments.len() * BOUNDARY_TAIL_GUARD_BYTES;
        assert!(
            plan.arena_size <= total_logical + align_slack,
            "arena {} should be <= logical sum {} + slack {}",
            plan.arena_size,
            total_logical,
            align_slack
        );
    }

    #[test]
    fn plan_report_includes_savings() {
        // Plan #87: the public report() string surfaces enough info
        // for debug tooling — arena size, unshared total, saved
        // bytes, and a per-buffer table sorted by offset.
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();
        // Header carries the headline numbers.
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));
        // Body is parseable (offset\tsize\tnode), sorted ascending.
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());
        // assignments map → at least mm1 + mm2 + x + w should appear.
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
    }

    #[test]
    fn view_ops_alias_parent_slot() {
        // Reshape, same-dtype Cast, and axis-0 Narrow should NOT get
        // their own arena slot — they alias the parent (#46).
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f)); // 128B
        let w = g.param("w", Shape::new(&[4, 4], f)); // 64B
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f)); // 128B (root)
        let r = g.reshape_(mm, vec![32]); // VIEW (Reshape)
        let c = g.cast(r, DType::F32); // VIEW (same-dtype Cast)
        let n = g.narrow_(c, 0, 8, 16); // VIEW (axis-0 Narrow)
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        // All three view nodes should share mm's offset (with adjustment
        // for the narrow's start=8 → +8*4 = 32 bytes).
        let mm_off = plan.assignments[&mm].offset;
        assert_eq!(
            plan.assignments[&r].offset, mm_off,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&c].offset, mm_off,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&n].offset,
            mm_off + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            plan.assignments[&n].size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);
        let fwd_w_off = fwd_plan.assignments[&w].offset;
        // Indexing panics if the backward plan has no slot for w2, so this
        // also checks that the param was bound.
        let bwd_w_off = bwd_plan.assignments[&w2].offset;
        assert_eq!(bwd_w_off, fwd_w_off, "backward w must share forward offset");
    }

    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        // `add` reads both mm and x, so mm is still live when add is produced.
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        // mm and add overlap in time, so they must not overlap in memory
        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }
}