1use rlx_ir::op::BinaryOp;
26use rlx_ir::{Graph, NodeId, Op};
27use std::collections::HashMap;
28
/// Minimum padding, in bytes, reserved after every boundary buffer
/// (`Input`, `Param`, `Constant`) when it is placed in the arena.
///
/// The planner reserves `alignment.max(BOUNDARY_TAIL_GUARD_BYTES)` extra bytes
/// behind such slots, so no other buffer starts immediately after them.
/// NOTE(review): presumably this guards against kernels touching a few bytes
/// past the logical end of externally provided tensors — confirm the motivation.
const BOUNDARY_TAIL_GUARD_BYTES: usize = 128;
33
34fn boundary_min_slot_bytes(op: &rlx_ir::Op, alignment: usize) -> usize {
35 if matches!(
36 op,
37 rlx_ir::Op::Input { .. } | rlx_ir::Op::Param { .. } | rlx_ir::Op::Constant { .. }
38 ) {
39 alignment.max(1)
40 } else {
41 0
42 }
43}
44
45fn boundary_tail_guard(op: &rlx_ir::Op, alignment: usize) -> usize {
46 if matches!(
47 op,
48 rlx_ir::Op::Input { .. } | rlx_ir::Op::Param { .. } | rlx_ir::Op::Constant { .. }
49 ) {
50 alignment.max(BOUNDARY_TAIL_GUARD_BYTES)
51 } else {
52 0
53 }
54}
55fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
69 match &node.op {
70 Op::Reshape { .. } => Some((node.inputs[0], 0)),
71 Op::Cast { to } => {
72 let parent = graph.node(node.inputs[0]);
73 if parent.shape.dtype() == *to {
74 Some((node.inputs[0], 0))
75 } else {
76 None
77 }
78 }
79 Op::Narrow {
80 axis,
81 start,
82 len: _,
83 } if *axis == 0 => {
84 let parent = graph.node(node.inputs[0]);
85 let inner_elems: usize = (1..parent.shape.rank())
87 .map(|i| parent.shape.dim(i).unwrap_static())
88 .product();
89 let dt_bytes = parent.shape.dtype().size_bytes();
90 Some((node.inputs[0], start * inner_elems * dt_bytes))
91 }
92 _ => None,
93 }
94}
95
96pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
100 pure_view_offset(graph, node).is_some()
101}
102
/// A region of the planner's arena assigned to one node.
#[derive(Debug, Clone)]
pub struct BufferSlot {
    /// Byte offset of the region from the start of the arena.
    pub offset: usize,
    /// Logical size in bytes of the node's data. This excludes any tail guard
    /// the planner reserves after boundary buffers.
    pub size: usize,
}
111
/// Result of memory planning: one arena plus a slot for each planned node.
#[derive(Debug, Clone)]
pub struct MemoryPlan {
    /// Bytes the arena must provide to hold every placed slot, tail guards
    /// included.
    pub arena_size: usize,
    /// Slot assigned to each planned node. Pure-view nodes reuse (part of)
    /// their root's region rather than owning a separate one.
    pub assignments: HashMap<NodeId, BufferSlot>,
    /// Node execution order, as produced by `Graph::topo_order`.
    pub schedule: Vec<NodeId>,
}
122
123impl MemoryPlan {
124 pub fn total_unshared_bytes(&self) -> usize {
128 self.assignments.values().map(|s| s.size).sum()
129 }
130
131 pub fn bytes_saved(&self) -> usize {
134 self.total_unshared_bytes().saturating_sub(self.arena_size)
135 }
136
137 pub fn report(&self) -> String {
145 let mut rows: Vec<(usize, usize, NodeId)> = self
146 .assignments
147 .iter()
148 .map(|(id, slot)| (slot.offset, slot.size, *id))
149 .collect();
150 rows.sort();
151 let mut out = String::new();
152 out.push_str(&format!(
153 "# arena_size={} total_unshared={} saved={}\n",
154 self.arena_size,
155 self.total_unshared_bytes(),
156 self.bytes_saved()
157 ));
158 out.push_str("# offset\tsize\tnode\n");
159 for (off, sz, id) in rows {
160 out.push_str(&format!("{off}\t{sz}\t{id}\n"));
161 }
162 out
163 }
164}
165
166pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
168 let mut out = HashMap::new();
169 for node in graph.nodes() {
170 if pure_view_offset(graph, node).is_some() {
171 let (root, off) = resolve_view_root(graph, node.id);
172 out.insert(node.id, (root, off));
173 }
174 }
175 out
176}
177
178fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
181 let mut total_offset = 0usize;
182 loop {
183 let node = graph.node(id);
184 match pure_view_offset(graph, node) {
185 Some((parent, off)) => {
186 total_offset += off;
187 id = parent;
188 }
189 None => return (id, total_offset),
190 }
191 }
192}
193
194#[allow(dead_code)]
198fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
199 compute_live_ranges_opts(graph, true)
200}
201
202fn compute_live_ranges_opts(
203 graph: &Graph,
204 pin_output_ancestors: bool,
205) -> HashMap<NodeId, (usize, usize)> {
206 let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();
207
208 for (step, node) in graph.nodes().iter().enumerate() {
209 ranges.entry(node.id).or_insert((step, step));
211
212 for &input in &node.inputs {
217 let (root, _off) = resolve_view_root(graph, input);
218 ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
219 if root != input {
223 ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
224 }
225 }
226 }
227
228 let last_step = graph.len();
230 for &out in &graph.outputs {
231 let (root, _off) = resolve_view_root(graph, out);
232 ranges.entry(root).and_modify(|r| r.1 = last_step);
233 if root != out {
234 ranges.entry(out).and_modify(|r| r.1 = last_step);
235 }
236 }
237
238 {
243 let mut stack: Vec<NodeId> = graph.outputs.clone();
244 let mut seen = std::collections::HashSet::new();
245 while let Some(id) = stack.pop() {
246 if !seen.insert(id) {
247 continue;
248 }
249 let (root, _) = resolve_view_root(graph, id);
250 ranges.entry(root).and_modify(|r| r.1 = last_step);
251 if root != id {
252 ranges.entry(id).and_modify(|r| r.1 = last_step);
253 }
254 if pin_output_ancestors {
261 for &input in &graph.node(id).inputs {
262 stack.push(input);
263 }
264 }
265 }
266 }
267
268 for node in graph.nodes() {
275 if matches!(
276 node.op,
277 rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
278 ) {
279 ranges.entry(node.id).and_modify(|r| {
280 r.0 = 0;
281 r.1 = last_step;
282 });
283 }
284 }
285
286 ranges
287}
288
289fn extend_node_chain_liveness_to_end(
294 graph: &Graph,
295 ranges: &mut HashMap<NodeId, (usize, usize)>,
296 start: NodeId,
297 last_step: usize,
298) {
299 let mut stack = vec![start];
300 let mut seen = std::collections::HashSet::new();
301 while let Some(id) = stack.pop() {
302 if !seen.insert(id) {
303 continue;
304 }
305 let (root, _) = resolve_view_root(graph, id);
306 ranges.entry(root).and_modify(|r| r.1 = last_step);
307 if root != id {
308 ranges.entry(id).and_modify(|r| r.1 = last_step);
309 }
310 for &input in &graph.node(id).inputs {
311 stack.push(input);
312 }
313 }
314}
315
316fn extend_custom_op_input_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
319 let last_step = graph.len();
320 for node in graph.nodes() {
321 let Op::Custom {
322 name, num_inputs, ..
323 } = &node.op
324 else {
325 continue;
326 };
327 if !name.starts_with("onnx.") {
328 continue;
329 }
330 let n = (*num_inputs as usize).min(node.inputs.len());
331 for &input in &node.inputs[..n] {
332 extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
333 }
334 }
335 for node in graph.nodes() {
349 match &node.op {
350 Op::DequantMatMul { .. } => {
351 if let Some(&x) = node.inputs.first() {
352 extend_node_chain_liveness_to_end(graph, ranges, x, last_step);
353 }
354 }
355 Op::DequantGroupedMatMul { .. } => {
356 if let Some(&x) = node.inputs.first() {
357 extend_node_chain_liveness_to_end(graph, ranges, x, last_step);
358 }
359 }
360 _ => {}
361 }
362 }
363}
364
365fn extend_bert_hidden_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
369 let uses_onnx_qmatmul = graph.nodes().iter().any(|node| {
370 matches!(
371 &node.op,
372 Op::Custom { name, .. } if name == "onnx.QMatMul" || name == "onnx.ActCopy"
373 )
374 });
375 if !uses_onnx_qmatmul {
376 return;
377 }
378 let last_step = graph.len();
379 for node in graph.nodes() {
380 match &node.op {
381 Op::LayerNorm { .. } | Op::LayerNorm2d { .. } => {
382 if let Some(&input) = node.inputs.first() {
383 extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
384 }
385 ranges.entry(node.id).and_modify(|r| r.1 = last_step);
386 }
387 Op::Binary(BinaryOp::Add) => {
388 for &input in &node.inputs {
389 extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
390 }
391 ranges.entry(node.id).and_modify(|r| r.1 = last_step);
392 }
393 _ => {}
394 }
395 }
396}
397
398fn extend_onnx_duration_epilogue_liveness(
399 graph: &Graph,
400 ranges: &mut HashMap<NodeId, (usize, usize)>,
401) {
402 if !graph_exports_onnx_duration(graph) {
405 return;
406 }
407 let last_step = graph.len();
408 for &out in &graph.outputs {
409 extend_node_chain_liveness_to_end(graph, ranges, out, last_step);
410 }
411 for node in graph.nodes() {
412 let keep = match &node.op {
413 Op::Custom { name, .. }
414 if name == "onnx.ConcatFromSequence" || name == "onnx.KittenConcatFromSequence" =>
415 {
416 true
417 }
418 Op::Expand { .. } => node.shape.dtype() == rlx_ir::DType::I64,
419 Op::Cast { to, .. } => *to == rlx_ir::DType::I64,
420 Op::Where => node.shape.dtype() == rlx_ir::DType::I64,
421 Op::Binary(_) => node.shape.dtype() == rlx_ir::DType::I64,
422 _ => node.shape.dtype() == rlx_ir::DType::I64 && node.shape.rank() <= 2,
423 };
424 if keep {
425 extend_node_chain_liveness_to_end(graph, ranges, node.id, last_step);
426 ranges.entry(node.id).and_modify(|r| r.1 = last_step);
427 }
428 }
429}
430
431fn graph_exports_onnx_duration(graph: &Graph) -> bool {
432 graph
433 .outputs
434 .iter()
435 .any(|&id| graph.node(id).shape.dtype() == rlx_ir::DType::I64)
436}
437
438#[allow(dead_code)]
439fn graph_uses_onnx_duration_epilogue(graph: &Graph) -> bool {
440 if graph.nodes().iter().any(|node| {
441 matches!(
442 &node.op,
443 Op::Custom { name, .. }
444 if name == "onnx.ConcatFromSequence"
445 || name == "onnx.KittenConcatFromSequence"
446 )
447 }) {
448 return true;
449 }
450 graph_exports_onnx_duration(graph)
451}
452
453fn extend_packed_qkv_parent_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
454 for (step, node) in graph.nodes().iter().enumerate() {
455 let rlx_ir::Op::Attention { .. } = &node.op else {
456 continue;
457 };
458 if node.inputs.len() < 3 {
459 continue;
460 }
461 let Some((parent, _, _)) = rlx_ir::detect_packed_bshd_qkv_attention(
462 graph,
463 node.inputs[0],
464 node.inputs[1],
465 node.inputs[2],
466 ) else {
467 continue;
468 };
469 let (root, _) = resolve_view_root(graph, parent);
470 ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
471 if root != parent {
472 ranges.entry(parent).and_modify(|r| r.1 = r.1.max(step));
473 }
474 }
475}
476
/// Knobs controlling which nodes the memory planner allocates and whether
/// arena memory may be reused between buffers with disjoint lifetimes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryPlanOptions {
    /// When `false`, `Param` nodes receive no slot from the planner.
    pub allocate_params: bool,
    /// When `false`, `Input` nodes receive no slot from the planner.
    pub allocate_inputs: bool,
    /// When `false`, `Constant` nodes receive no slot from the planner.
    pub allocate_constants: bool,
    /// Append every buffer at the current end of the arena instead of reusing
    /// memory of buffers whose live ranges have ended.
    pub arena_no_reuse: bool,
    /// Keep every transitive input of a graph output live until the end of
    /// the program (not only the outputs themselves).
    pub pin_output_ancestors: bool,
}

impl MemoryPlanOptions {
    /// Environment variable that disables slot reuse when set to `1` or
    /// `true` (case-insensitive).
    const NO_REUSE_ENV: &str = "RLX_ARENA_NO_REUSE";

    /// Options for inference graphs.
    ///
    /// Allocates params, inputs and constants, and pins output ancestors.
    /// `arena_no_reuse` is taken from `RLX_ARENA_NO_REUSE`.
    pub fn inference() -> Self {
        Self {
            allocate_params: true,
            allocate_inputs: true,
            allocate_constants: true,
            arena_no_reuse: Self::arena_no_reuse_from_env(),
            pin_output_ancestors: true,
        }
    }

    /// Like [`inference`](Self::inference), but parameters are not allocated.
    /// They are expected to come from a shared weight layout.
    pub fn backward_activations_only() -> Self {
        Self {
            allocate_params: false,
            ..Self::inference()
        }
    }

    /// Reads the `RLX_ARENA_NO_REUSE` override. Unset or non-UTF-8 values
    /// count as `false`.
    fn arena_no_reuse_from_env() -> bool {
        std::env::var(Self::NO_REUSE_ENV)
            .ok()
            .is_some_and(|v| Self::is_truthy_flag(&v))
    }

    /// `true` for `"1"` or any casing of `"true"`.
    fn is_truthy_flag(value: &str) -> bool {
        value == "1" || value.eq_ignore_ascii_case("true")
    }
}

impl Default for MemoryPlanOptions {
    /// Same as [`MemoryPlanOptions::inference`].
    fn default() -> Self {
        Self::inference()
    }
}
536
/// Arena placement of a forward graph's `Param` nodes, recorded by parameter
/// name so another graph (e.g. a backward graph) can reuse the same offsets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedWeightLayout {
    /// Highest `offset + size` over all `slots` (0 when empty).
    pub arena_size: usize,
    /// Parameter slots; `from_forward` produces them sorted by name.
    pub slots: Vec<WeightSlot>,
}
543
/// Placement of one named parameter within a `SharedWeightLayout`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WeightSlot {
    /// The `Param` node's name; parameters are matched across graphs by name.
    pub name: String,
    /// Id of the `Param` node in the forward graph the slot was taken from.
    pub forward_id: NodeId,
    /// Byte offset of the slot in the arena.
    pub offset: usize,
    /// Size of the slot in bytes.
    pub size: usize,
}
552
553impl SharedWeightLayout {
554 pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
556 let mut slots = Vec::new();
557 for node in graph.nodes() {
558 if let rlx_ir::Op::Param { name } = &node.op {
559 if let Some(slot) = plan.assignments.get(&node.id) {
560 slots.push(WeightSlot {
561 name: name.clone(),
562 forward_id: node.id,
563 offset: slot.offset,
564 size: slot.size,
565 });
566 }
567 }
568 }
569 slots.sort_by(|a, b| a.name.cmp(&b.name));
570 let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
571 Self { arena_size, slots }
572 }
573
574 pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
576 let by_name: std::collections::HashMap<&str, &WeightSlot> =
577 self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
578 for node in graph.nodes() {
579 if let rlx_ir::Op::Param { name } = &node.op {
580 let Some(slot) = by_name.get(name.as_str()) else {
581 continue;
582 };
583 plan.assignments.insert(
584 node.id,
585 BufferSlot {
586 offset: slot.offset,
587 size: slot.size,
588 },
589 );
590 }
591 }
592 plan.arena_size = plan.arena_size.max(self.arena_size);
593 }
594}
595
596#[inline]
597fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
598 match op {
599 rlx_ir::Op::Param { .. } => opts.allocate_params,
600 rlx_ir::Op::Input { .. } => opts.allocate_inputs,
601 rlx_ir::Op::Constant { .. } => opts.allocate_constants,
602 _ => true,
603 }
604}
605
606pub fn plan_memory(graph: &Graph) -> MemoryPlan {
608 plan_memory_aligned(graph, 64)
609}
610
611pub fn plan_memory_with_options(
613 graph: &Graph,
614 alignment: usize,
615 opts: MemoryPlanOptions,
616) -> MemoryPlan {
617 plan_memory_aligned_inner(graph, alignment, opts, None, false)
618}
619
620pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
622 plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, false)
623}
624
625pub fn plan_memory_f32_uniform(graph: &Graph, alignment: usize) -> MemoryPlan {
629 let opts = MemoryPlanOptions {
633 pin_output_ancestors: false,
634 ..MemoryPlanOptions::default()
635 };
636 plan_memory_aligned_inner(graph, alignment, opts, None, true)
637}
638
639pub fn plan_memory_backward(
641 graph: &Graph,
642 alignment: usize,
643 weights: &SharedWeightLayout,
644) -> MemoryPlan {
645 plan_memory_aligned_inner(
646 graph,
647 alignment,
648 MemoryPlanOptions::backward_activations_only(),
649 Some(weights),
650 false,
651 )
652}
653
654#[inline]
655fn node_slot_bytes(node: &rlx_ir::Node, f32_uniform: bool) -> usize {
656 if f32_uniform {
657 node.shape.num_elements().unwrap_or(0) * 4
658 } else {
659 node.shape.size_bytes().unwrap_or(0)
660 }
661}
662
663fn plan_memory_aligned_inner(
664 graph: &Graph,
665 alignment: usize,
666 opts: MemoryPlanOptions,
667 weights: Option<&SharedWeightLayout>,
668 f32_uniform: bool,
669) -> MemoryPlan {
670 let mut ranges = compute_live_ranges_opts(graph, opts.pin_output_ancestors);
671 extend_packed_qkv_parent_liveness(graph, &mut ranges);
672 extend_custom_op_input_liveness(graph, &mut ranges);
673 extend_bert_hidden_liveness(graph, &mut ranges);
674 extend_onnx_duration_epilogue_liveness(graph, &mut ranges);
675 let mut opts = opts;
676 if graph_exports_onnx_duration(graph) {
677 opts.arena_no_reuse = true;
678 }
679 struct BufInfo {
681 id: NodeId,
682 size: usize,
683 birth: usize,
684 death: usize,
685 }
686
687 let mut buffers: Vec<BufInfo> = Vec::new();
688 for node in graph.nodes() {
689 if pure_view_offset(graph, node).is_some() {
692 continue;
693 }
694 let raw_size = node_slot_bytes(node, f32_uniform);
695 let size = if raw_size == 0 {
696 boundary_min_slot_bytes(&node.op, alignment)
697 } else {
698 raw_size
699 };
700 if size > 0
701 && let Some(&(birth, death)) = ranges.get(&node.id)
702 && plans_boundary_buffer(&node.op, opts)
703 {
704 buffers.push(BufInfo {
705 id: node.id,
706 size,
707 birth,
708 death,
709 });
710 }
711 }
712
713 buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
715
716 let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
718 let mut arena_size: usize = 0;
719
720 let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); for buf in &buffers {
724 let align = alignment;
725 let node = graph.node(buf.id);
726 let tail_guard = boundary_tail_guard(&node.op, align);
727 let placement_size = buf.size + tail_guard;
728 let mut best_offset: Option<usize> = None;
729
730 let mut candidates = vec![0usize];
733 for &(p_off, p_size, _, _) in &placed {
734 candidates.push(p_off + p_size);
735 }
736 candidates.sort_unstable();
737 candidates.dedup();
738
739 for &candidate_offset in &candidates {
740 let aligned = (candidate_offset + align - 1) & !(align - 1);
741 let end = aligned + placement_size;
742
743 let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
744 let p_end = p_off + p_size;
745 let mem_overlap = aligned < p_end && end > p_off;
746 let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
747 mem_overlap && time_overlap
748 });
749
750 if !conflict {
751 match best_offset {
752 None => best_offset = Some(aligned),
753 Some(best) if aligned < best => best_offset = Some(aligned),
754 _ => {}
755 }
756 }
757 }
758
759 let aligned = if opts.arena_no_reuse {
760 (arena_size + align - 1) & !(align - 1)
761 } else {
762 best_offset.unwrap_or_else(|| {
763 (arena_size + align - 1) & !(align - 1)
765 })
766 };
767 assignments.insert(
768 buf.id,
769 BufferSlot {
770 offset: aligned,
771 size: buf.size,
772 },
773 );
774 placed.push((aligned, placement_size, buf.birth, buf.death));
775 arena_size = arena_size.max(aligned + placement_size);
776 }
777
778 for node in graph.nodes() {
784 if pure_view_offset(graph, node).is_some() {
785 let (root, off) = resolve_view_root(graph, node.id);
786 if let Some(root_slot) = assignments.get(&root).cloned() {
787 let view_size = node_slot_bytes(node, f32_uniform);
788 assignments.insert(
789 node.id,
790 BufferSlot {
791 offset: root_slot.offset + off,
792 size: view_size,
793 },
794 );
795 }
796 }
797 }
798
799 let schedule = graph.topo_order().collect();
800
801 let mut plan = MemoryPlan {
802 arena_size,
803 assignments,
804 schedule,
805 };
806 if let Some(w) = weights {
807 w.apply_to_plan(graph, &mut plan);
808 }
809 plan
810}
811
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f));
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        // Diagnostics only: compute the live ranges once, not once per slot.
        let ranges = compute_live_ranges(&g);
        println!("Arena size: {} bytes", plan.arena_size);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = ranges.get(id) {
                println!(
                    "  {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        let total_logical: usize = plan.assignments.values().map(|s| s.size).sum();
        // Boundary buffers reserve a tail guard, so allow that much slack per slot.
        let align_slack = plan.assignments.len() * BOUNDARY_TAIL_GUARD_BYTES;
        assert!(
            plan.arena_size <= total_logical + align_slack,
            "arena {} should be <= logical sum {} + slack {}",
            plan.arena_size,
            total_logical,
            align_slack
        );
    }

    #[test]
    fn plan_report_includes_savings() {
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
    }

    #[test]
    fn view_ops_alias_parent_slot() {
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f));
        let w = g.param("w", Shape::new(&[4, 4], f));
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f));
        let r = g.reshape_(mm, vec![32]);
        let c = g.cast(r, DType::F32);
        let n = g.narrow_(c, 0, 8, 16);
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        let mm_off = plan.assignments[&mm].offset;
        assert_eq!(
            plan.assignments[&r].offset, mm_off,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&c].offset, mm_off,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&n].offset,
            mm_off + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            plan.assignments[&n].size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);
        let fwd_w = &fwd_plan.assignments[&w];
        let bwd_w = &bwd_plan.assignments[&w2];
        assert_eq!(bwd_w.offset, fwd_w.offset, "backward w must share forward offset");
        assert_eq!(bwd_w.size, fwd_w.size, "backward w must share forward slot size");
    }

    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }

    #[test]
    fn zero_length_inputs_get_arena_slots() {
        let mut g = Graph::new("empty_past");
        let f = DType::F32;
        let past = g.input("past_k", Shape::new(&[1, 0, 8], f));
        let x = g.input("x", Shape::new(&[1, 1, 8], f));
        let cat = g.concat(vec![past, x], 1, Shape::new(&[1, 1, 8], f));
        g.set_outputs(vec![cat]);

        let plan = plan_memory(&g);
        assert!(
            plan.assignments.contains_key(&past),
            "zero-length decode past input must have an arena slot"
        );
        assert!(plan.assignments[&past].size >= 64);
    }

    #[test]
    fn duration_export_forces_no_reuse_waveform_only_does_not() {
        let f = DType::F32;
        let mut wave_only = Graph::new("wave_only");
        let w = wave_only.input("wave", Shape::new(&[1024], f));
        wave_only.set_outputs(vec![w]);
        assert!(!graph_exports_onnx_duration(&wave_only));

        let mut dual = Graph::new("dual");
        let w2 = dual.input("wave", Shape::new(&[1024], f));
        let d = dual.input("dur", Shape::new(&[8], DType::I64));
        dual.set_outputs(vec![w2, d]);
        assert!(graph_exports_onnx_duration(&dual));
    }
}