Skip to main content

rlx_compile/
memory.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Memory planning — liveness analysis and buffer assignment.
17//!
18//! This is the XLA feature that no other Rust framework has. It computes
19//! which intermediate tensors have non-overlapping lifetimes and assigns
20//! them to the same memory, minimizing total arena size.
21//!
22//! The output is a [`MemoryPlan`] that tells the runtime exactly how
23//! large the arena should be and where each tensor lives within it.
24
25use rlx_ir::{Graph, NodeId, Op};
26use std::collections::HashMap;
27
28/// Identify ops whose output is a *view* of an existing buffer — no
29/// copy needed, no separate arena slot. Returns the parent input index
30/// and the byte offset of the view within the parent.
31///
32/// Borrowed from MAX's "view-vs-copy" pattern (#46 in PLAN.md).
33/// The hard case (strided narrow on a non-outermost axis — e.g. BERT
34/// QKV split) requires kernels that consume strided inputs and is
35/// deferred. This function only catches the safely-elidable cases:
36///
37///   - **`Reshape`**: pure metadata; data layout is identical.
38///   - **`Cast`** with `src dtype == dst dtype`: pure metadata.
39///   - **`Narrow` on axis 0**: contiguous sub-slice of the parent;
40///     offset = `start * size_of_inner_in_bytes`.
41fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
42    match &node.op {
43        Op::Reshape { .. } => Some((node.inputs[0], 0)),
44        Op::Cast { to } => {
45            let parent = graph.node(node.inputs[0]);
46            if parent.shape.dtype() == *to {
47                Some((node.inputs[0], 0))
48            } else {
49                None
50            }
51        }
52        Op::Narrow {
53            axis,
54            start,
55            len: _,
56        } if *axis == 0 => {
57            let parent = graph.node(node.inputs[0]);
58            // inner = product of dims after axis 0
59            let inner_elems: usize = (1..parent.shape.rank())
60                .map(|i| parent.shape.dim(i).unwrap_static())
61                .product();
62            let dt_bytes = parent.shape.dtype().size_bytes();
63            Some((node.inputs[0], start * inner_elems * dt_bytes))
64        }
65        _ => None,
66    }
67}
68
69/// Public predicate for backends — true iff this op should compile to
70/// a Nop because its output aliases a parent buffer (the memory
71/// planner has already aliased its slot).
72pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
73    pure_view_offset(graph, node).is_some()
74}
75
/// A contiguous byte range `[offset, offset + size)` inside the memory arena.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferSlot {
    /// Offset in bytes from the start of the arena.
    pub offset: usize,
    /// Size in bytes.
    pub size: usize,
}

impl BufferSlot {
    /// One-past-the-end byte offset of this slot (`offset + size`).
    #[must_use]
    pub fn end(&self) -> usize {
        self.offset + self.size
    }
}
84
85/// Complete memory plan for executing a graph.
86#[derive(Debug, Clone)]
87pub struct MemoryPlan {
88    /// Total arena size in bytes.
89    pub arena_size: usize,
90    /// Buffer assignment: NodeId → offset within arena.
91    pub assignments: HashMap<NodeId, BufferSlot>,
92    /// Node execution order (topological).
93    pub schedule: Vec<NodeId>,
94}
95
96impl MemoryPlan {
97    /// Sum of all assigned buffer sizes (i.e. how much memory the
98    /// plan would use if every node had its own slot). Useful for
99    /// reporting how much the liveness-aware sharing saved.
100    pub fn total_unshared_bytes(&self) -> usize {
101        self.assignments.values().map(|s| s.size).sum()
102    }
103
104    /// Bytes saved vs. naive "every node gets its own slot" — how
105    /// much the liveness analysis bought you.
106    pub fn bytes_saved(&self) -> usize {
107        self.total_unshared_bytes().saturating_sub(self.arena_size)
108    }
109
110    /// Render the buffer plan as a one-line-per-node table for
111    /// debugging — sorted by offset so adjacent buffers in memory
112    /// are adjacent in the report (plan #87).
113    ///
114    /// The output is parseable: `<offset>\t<size>\t%<node_id>`. Pipe
115    /// through `column -t` for human display, or grep / awk it for
116    /// scripted analysis.
117    pub fn report(&self) -> String {
118        let mut rows: Vec<(usize, usize, NodeId)> = self
119            .assignments
120            .iter()
121            .map(|(id, slot)| (slot.offset, slot.size, *id))
122            .collect();
123        rows.sort();
124        let mut out = String::new();
125        out.push_str(&format!(
126            "# arena_size={} total_unshared={} saved={}\n",
127            self.arena_size,
128            self.total_unshared_bytes(),
129            self.bytes_saved()
130        ));
131        out.push_str("# offset\tsize\tnode\n");
132        for (off, sz, id) in rows {
133            out.push_str(&format!("{off}\t{sz}\t{id}\n"));
134        }
135        out
136    }
137}
138
139/// Collect view-node aliases for embedding in LIR.
140pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
141    let mut out = HashMap::new();
142    for node in graph.nodes() {
143        if pure_view_offset(graph, node).is_some() {
144            let (root, off) = resolve_view_root(graph, node.id);
145            out.insert(node.id, (root, off));
146        }
147    }
148    out
149}
150
151/// Walk view chains until reaching a non-view ancestor. Returns the
152/// root buffer-owning node and the cumulative byte offset from the root.
153fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
154    let mut total_offset = 0usize;
155    loop {
156        let node = graph.node(id);
157        match pure_view_offset(graph, node) {
158            Some((parent, off)) => {
159                total_offset += off;
160                id = parent;
161            }
162            None => return (id, total_offset),
163        }
164    }
165}
166
167/// Compute the live range [birth, death] for each node's output buffer.
168/// Birth = when the node produces its output.
169/// Death = the last time any consumer reads it.
170fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
171    let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();
172
173    for (step, node) in graph.nodes().iter().enumerate() {
174        // Birth: this node's output is produced at this step
175        ranges.entry(node.id).or_insert((step, step));
176
177        // Extend death of all inputs to at least this step. For view
178        // inputs, attribute the read to the *root* buffer so the
179        // underlying allocation stays alive while any view of it is
180        // still being read (#46 view-aliasing pattern).
181        for &input in &node.inputs {
182            let (root, _off) = resolve_view_root(graph, input);
183            ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
184            // Also track the view itself so we don't leave a dangling
185            // entry; views inherit the root's range later in
186            // plan_memory_aligned.
187            if root != input {
188                ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
189            }
190        }
191    }
192
193    // Extend death of output nodes to the end
194    let last_step = graph.len();
195    for &out in &graph.outputs {
196        let (root, _off) = resolve_view_root(graph, out);
197        ranges.entry(root).and_modify(|r| r.1 = last_step);
198        if root != out {
199            ranges.entry(out).and_modify(|r| r.1 = last_step);
200        }
201    }
202
203    // Params, Inputs, and Constants live for the ENTIRE execution.
204    // Params/Inputs are pre-loaded externally; Constants are pre-loaded
205    // by the runtime's compile step (see backend.rs::compile_inner). In
206    // all three cases the slot must not be overwritten by intermediate
207    // buffer sharing, otherwise iteration 2 of a training/inference
208    // loop would read whatever the previous run scribbled into it.
209    for node in graph.nodes() {
210        if matches!(
211            node.op,
212            rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
213        ) {
214            ranges.entry(node.id).and_modify(|r| {
215                r.0 = 0;
216                r.1 = last_step;
217            });
218        }
219    }
220
221    ranges
222}
223
/// Controls which graph boundaries receive arena slots during planning.
///
/// Only the three boundary kinds — `Op::Param`, `Op::Input` and
/// `Op::Constant` — are gated by these flags; every other node is always
/// eligible for a slot.
///
/// Inference graphs use [`Self::inference`] (all boundaries allocated).
/// Backward graphs in a training pair use [`Self::backward_activations_only`]:
/// parameters borrow offsets from the forward plan via [`SharedWeightLayout`]
/// so weights are not stored twice in the activation arena.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryPlanOptions {
    /// Give `Op::Param` nodes their own arena slot.
    pub allocate_params: bool,
    /// Give `Op::Input` nodes their own arena slot.
    pub allocate_inputs: bool,
    /// Give `Op::Constant` nodes their own arena slot.
    pub allocate_constants: bool,
}

impl MemoryPlanOptions {
    /// Allocate every boundary: params, inputs and constants.
    #[must_use]
    pub const fn inference() -> Self {
        Self {
            allocate_params: true,
            allocate_inputs: true,
            allocate_constants: true,
        }
    }

    /// Activations + inputs/constants only; params bound via [`SharedWeightLayout`].
    #[must_use]
    pub const fn backward_activations_only() -> Self {
        Self {
            allocate_params: false,
            allocate_inputs: true,
            allocate_constants: true,
        }
    }
}

impl Default for MemoryPlanOptions {
    /// Same as [`MemoryPlanOptions::inference`].
    fn default() -> Self {
        Self::inference()
    }
}
266
267/// Persistent parameter slots extracted from a forward [`MemoryPlan`].
268#[derive(Debug, Clone, PartialEq, Eq)]
269pub struct SharedWeightLayout {
270    pub arena_size: usize,
271    pub slots: Vec<WeightSlot>,
272}
273
274/// One named parameter and its byte range in the shared weight region.
275#[derive(Debug, Clone, PartialEq, Eq)]
276pub struct WeightSlot {
277    pub name: String,
278    pub forward_id: NodeId,
279    pub offset: usize,
280    pub size: usize,
281}
282
283impl SharedWeightLayout {
284    /// Collect `Op::Param` slots from a forward memory plan (by param name).
285    pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
286        let mut slots = Vec::new();
287        for node in graph.nodes() {
288            if let rlx_ir::Op::Param { name } = &node.op {
289                if let Some(slot) = plan.assignments.get(&node.id) {
290                    slots.push(WeightSlot {
291                        name: name.clone(),
292                        forward_id: node.id,
293                        offset: slot.offset,
294                        size: slot.size,
295                    });
296                }
297            }
298        }
299        slots.sort_by(|a, b| a.name.cmp(&b.name));
300        let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
301        Self { arena_size, slots }
302    }
303
304    /// Map backward-graph `Op::Param` nodes to the forward weight offsets.
305    pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
306        let by_name: std::collections::HashMap<&str, &WeightSlot> =
307            self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
308        for node in graph.nodes() {
309            if let rlx_ir::Op::Param { name } = &node.op {
310                let Some(slot) = by_name.get(name.as_str()) else {
311                    continue;
312                };
313                plan.assignments.insert(
314                    node.id,
315                    BufferSlot {
316                        offset: slot.offset,
317                        size: slot.size,
318                    },
319                );
320            }
321        }
322        plan.arena_size = plan.arena_size.max(self.arena_size);
323    }
324}
325
326#[inline]
327fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
328    match op {
329        rlx_ir::Op::Param { .. } => opts.allocate_params,
330        rlx_ir::Op::Input { .. } => opts.allocate_inputs,
331        rlx_ir::Op::Constant { .. } => opts.allocate_constants,
332        _ => true,
333    }
334}
335
336/// Plan memory with default 64-byte alignment.
337pub fn plan_memory(graph: &Graph) -> MemoryPlan {
338    plan_memory_aligned(graph, 64)
339}
340
341/// Plan memory with custom alignment and boundary allocation policy.
342pub fn plan_memory_with_options(
343    graph: &Graph,
344    alignment: usize,
345    opts: MemoryPlanOptions,
346) -> MemoryPlan {
347    plan_memory_aligned_inner(graph, alignment, opts, None, false)
348}
349
350/// Plan memory with custom alignment (inference defaults).
351pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
352    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, false)
353}
354
355/// Liveness-aware planning with every slot sized as `num_elements * 4`
356/// bytes (wgpu / uniform-f32 arenas). Reuses dead tensor slots so large
357/// `[n, n]` pairwise graphs stay under WebGPU's 128 MiB binding cap.
358pub fn plan_memory_f32_uniform(graph: &Graph, alignment: usize) -> MemoryPlan {
359    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, true)
360}
361
362/// Plan backward activations, then alias params onto `weights`.
363pub fn plan_memory_backward(
364    graph: &Graph,
365    alignment: usize,
366    weights: &SharedWeightLayout,
367) -> MemoryPlan {
368    plan_memory_aligned_inner(
369        graph,
370        alignment,
371        MemoryPlanOptions::backward_activations_only(),
372        Some(weights),
373        false,
374    )
375}
376
377#[inline]
378fn node_slot_bytes(node: &rlx_ir::Node, f32_uniform: bool) -> usize {
379    if f32_uniform {
380        node.shape.num_elements().unwrap_or(0) * 4
381    } else {
382        node.shape.size_bytes().unwrap_or(0)
383    }
384}
385
386fn plan_memory_aligned_inner(
387    graph: &Graph,
388    alignment: usize,
389    opts: MemoryPlanOptions,
390    weights: Option<&SharedWeightLayout>,
391    f32_uniform: bool,
392) -> MemoryPlan {
393    let ranges = compute_live_ranges(graph);
394
395    // Collect buffers that need allocation (skip inputs/params — external)
396    struct BufInfo {
397        id: NodeId,
398        size: usize,
399        birth: usize,
400        death: usize,
401    }
402
403    let mut buffers: Vec<BufInfo> = Vec::new();
404    for node in graph.nodes() {
405        // Skip view nodes — they alias their parent's buffer (handled
406        // in the post-pass below). Plan #46.
407        if pure_view_offset(graph, node).is_some() {
408            continue;
409        }
410        let size = node_slot_bytes(node, f32_uniform);
411        if size > 0
412            && let Some(&(birth, death)) = ranges.get(&node.id)
413            && plans_boundary_buffer(&node.op, opts)
414        {
415            buffers.push(BufInfo {
416                id: node.id,
417                size,
418                birth,
419                death,
420            });
421        }
422    }
423
424    // Sort by size descending (largest first gets priority placement)
425    buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
426
427    // Greedy first-fit allocation
428    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
429    let mut arena_size: usize = 0;
430
431    // Track allocated regions with their live ranges
432    let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); // (offset, size, birth, death)
433
434    for buf in &buffers {
435        let align = alignment;
436        let mut best_offset: Option<usize> = None;
437
438        // Collect candidate start offsets: 0 plus the end of every placed
439        // buffer that could border a free gap.
440        let mut candidates = vec![0usize];
441        for &(p_off, p_size, _, _) in &placed {
442            candidates.push(p_off + p_size);
443        }
444        candidates.sort_unstable();
445        candidates.dedup();
446
447        for &candidate_offset in &candidates {
448            let aligned = (candidate_offset + align - 1) & !(align - 1);
449            let end = aligned + buf.size;
450
451            let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
452                let p_end = p_off + p_size;
453                let mem_overlap = aligned < p_end && end > p_off;
454                let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
455                mem_overlap && time_overlap
456            });
457
458            if !conflict {
459                match best_offset {
460                    None => best_offset = Some(aligned),
461                    Some(best) if aligned < best => best_offset = Some(aligned),
462                    _ => {}
463                }
464            }
465        }
466
467        let aligned = best_offset.unwrap_or_else(|| {
468            // No gap fit — append at arena tail.
469            (arena_size + align - 1) & !(align - 1)
470        });
471        assignments.insert(
472            buf.id,
473            BufferSlot {
474                offset: aligned,
475                size: buf.size,
476            },
477        );
478        placed.push((aligned, buf.size, buf.birth, buf.death));
479        arena_size = arena_size.max(aligned + buf.size);
480    }
481
482    // ── View aliasing pass (plan #46) ────────────────────────
483    // Every view node points at its root buffer's slot, offset by the
484    // accumulated view offset. The root has its own allocation above;
485    // views just borrow its bytes. This is the post-pass — done after
486    // root allocations are placed so we have offsets to point at.
487    for node in graph.nodes() {
488        if pure_view_offset(graph, node).is_some() {
489            let (root, off) = resolve_view_root(graph, node.id);
490            if let Some(root_slot) = assignments.get(&root).cloned() {
491                let view_size = node_slot_bytes(node, f32_uniform);
492                assignments.insert(
493                    node.id,
494                    BufferSlot {
495                        offset: root_slot.offset + off,
496                        size: view_size,
497                    },
498                );
499            }
500        }
501    }
502
503    let schedule = graph.topo_order().collect();
504
505    let mut plan = MemoryPlan {
506        arena_size,
507        assignments,
508        schedule,
509    };
510    if let Some(w) = weights {
511        w.apply_to_plan(graph, &mut plan);
512    }
513    plan
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f)); // 153.6KB
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        // mm1 is read only by mm2; after mm2 runs, mm1 is dead.
        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f)); // 153.6KB; last read by mm2
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f)); // 153.6KB; output, live to the end
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        println!("Arena size: {} bytes", plan.arena_size);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = compute_live_ranges(&g).get(id) {
                println!(
                    "  {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        // NOTE: live ranges are inclusive, and mm1's last read happens at the
        // very step that produces mm2, so the planner treats mm1 and mm2 as
        // simultaneously live and keeps them in separate bytes. This test
        // therefore only checks the weak invariant arena <= sum of slot
        // sizes (note `<=`, not `<`); it does not prove any sharing happened.
        let total_if_no_sharing: usize = plan.assignments.values().map(|s| s.size).sum();
        assert!(
            plan.arena_size <= total_if_no_sharing,
            "arena {0} should be <= sum {total_if_no_sharing}",
            plan.arena_size
        );
    }

    #[test]
    fn plan_report_includes_savings() {
        // Plan #87: the public report() string surfaces enough info
        // for debug tooling — arena size, unshared total, saved
        // bytes, and a per-buffer table sorted by offset.
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();
        // Header carries the headline numbers.
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));
        // Body rows (offset\tsize\tnode) follow the '#' header lines; the
        // offset ordering itself is not asserted here.
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());
        // Both matmul outputs must own a slot. x and w are also allocated
        // under the inference defaults, but that is not checked here.
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
    }

    #[test]
    fn view_ops_alias_parent_slot() {
        // Reshape, same-dtype Cast, and axis-0 Narrow should NOT get
        // their own arena slot — they alias the parent (#46).
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f)); // 128B
        let w = g.param("w", Shape::new(&[4, 4], f)); // 64B
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f)); // 128B (root)
        let r = g.reshape_(mm, vec![32]); // VIEW (Reshape)
        let c = g.cast(r, DType::F32); // VIEW (same-dtype Cast)
        let n = g.narrow_(c, 0, 8, 16); // VIEW (axis-0 Narrow)
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        // All three view nodes point into mm's slot. The narrow is taken on
        // the rank-1 [32] view, so its offset is start=8 elements * 1 inner
        // element * 4 bytes = 32 bytes.
        let mm_off = plan.assignments[&mm].offset;
        assert_eq!(
            plan.assignments[&r].offset, mm_off,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&c].offset, mm_off,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&n].offset,
            mm_off + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            plan.assignments[&n].size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        // Forward graph: its plan supplies the shared weight layout.
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        // Backward graph: a param with the same name ("w") must be bound to
        // the forward slot rather than getting a fresh one.
        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);
        let fwd_w_off = fwd_plan.assignments[&w].offset;
        let bwd_w_off = bwd_plan.assignments[&w2].offset;
        assert_eq!(bwd_w_off, fwd_w_off, "backward w must share forward offset");
        // Redundant with the assert_eq above (indexing `assignments[&w2]`
        // already panics if w2 has no slot); kept as a belt-and-braces check.
        assert!(
            !bwd_plan.assignments.contains_key(&w2)
                || bwd_plan.assignments[&w2].offset == fwd_w_off
        );
    }

    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        // mm and x are live at the same time (mm reads x; add reads both).
        // x is an Input, so the planner pins it live for the whole run anyway.
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        // mm is last read at the step that produces add; with inclusive live
        // ranges they overlap in time, so they must not overlap in memory.
        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }
}