1use rlx_ir::op::BinaryOp;
26use rlx_ir::{Graph, NodeId, Op};
27use std::collections::HashMap;
28
/// Minimum number of padding bytes reserved after every `Input`, `Param`, and
/// `Constant` buffer when it is placed in the arena (see `boundary_tail_guard`).
// NOTE(review): presumably protects against reads running past the end of a
// boundary tensor — confirm the intended consumer of this slack.
const BOUNDARY_TAIL_GUARD_BYTES: usize = 128;
33
34fn boundary_tail_guard(op: &rlx_ir::Op, alignment: usize) -> usize {
35 if matches!(
36 op,
37 rlx_ir::Op::Input { .. } | rlx_ir::Op::Param { .. } | rlx_ir::Op::Constant { .. }
38 ) {
39 alignment.max(BOUNDARY_TAIL_GUARD_BYTES)
40 } else {
41 0
42 }
43}
44fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
58 match &node.op {
59 Op::Reshape { .. } => Some((node.inputs[0], 0)),
60 Op::Cast { to } => {
61 let parent = graph.node(node.inputs[0]);
62 if parent.shape.dtype() == *to {
63 Some((node.inputs[0], 0))
64 } else {
65 None
66 }
67 }
68 Op::Narrow {
69 axis,
70 start,
71 len: _,
72 } if *axis == 0 => {
73 let parent = graph.node(node.inputs[0]);
74 let inner_elems: usize = (1..parent.shape.rank())
76 .map(|i| parent.shape.dim(i).unwrap_static())
77 .product();
78 let dt_bytes = parent.shape.dtype().size_bytes();
79 Some((node.inputs[0], start * inner_elems * dt_bytes))
80 }
81 _ => None,
82 }
83}
84
/// Returns `true` when `node` is a zero-copy view of its first input, i.e. when
/// `pure_view_offset` reports a parent and byte offset for it.
pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
    pure_view_offset(graph, node).is_some()
}
91
/// A byte range inside the planned arena assigned to one graph node.
#[derive(Debug, Clone)]
pub struct BufferSlot {
    /// Byte offset of the slot from the start of the arena.
    pub offset: usize,
    /// Size of the slot in bytes (excludes any tail-guard padding the planner reserved).
    pub size: usize,
}
100
/// Result of memory planning: an arena size plus a slot for each planned node.
#[derive(Debug, Clone)]
pub struct MemoryPlan {
    /// Total number of bytes the arena must provide.
    pub arena_size: usize,
    /// Slot assigned to each node; view nodes get a slot inside their root's slot.
    pub assignments: HashMap<NodeId, BufferSlot>,
    /// Node ids collected from the graph's `topo_order()`.
    pub schedule: Vec<NodeId>,
}
111
112impl MemoryPlan {
113 pub fn total_unshared_bytes(&self) -> usize {
117 self.assignments.values().map(|s| s.size).sum()
118 }
119
120 pub fn bytes_saved(&self) -> usize {
123 self.total_unshared_bytes().saturating_sub(self.arena_size)
124 }
125
126 pub fn report(&self) -> String {
134 let mut rows: Vec<(usize, usize, NodeId)> = self
135 .assignments
136 .iter()
137 .map(|(id, slot)| (slot.offset, slot.size, *id))
138 .collect();
139 rows.sort();
140 let mut out = String::new();
141 out.push_str(&format!(
142 "# arena_size={} total_unshared={} saved={}\n",
143 self.arena_size,
144 self.total_unshared_bytes(),
145 self.bytes_saved()
146 ));
147 out.push_str("# offset\tsize\tnode\n");
148 for (off, sz, id) in rows {
149 out.push_str(&format!("{off}\t{sz}\t{id}\n"));
150 }
151 out
152 }
153}
154
155pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
157 let mut out = HashMap::new();
158 for node in graph.nodes() {
159 if pure_view_offset(graph, node).is_some() {
160 let (root, off) = resolve_view_root(graph, node.id);
161 out.insert(node.id, (root, off));
162 }
163 }
164 out
165}
166
167fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
170 let mut total_offset = 0usize;
171 loop {
172 let node = graph.node(id);
173 match pure_view_offset(graph, node) {
174 Some((parent, off)) => {
175 total_offset += off;
176 id = parent;
177 }
178 None => return (id, total_offset),
179 }
180 }
181}
182
183fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
187 let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();
188
189 for (step, node) in graph.nodes().iter().enumerate() {
190 ranges.entry(node.id).or_insert((step, step));
192
193 for &input in &node.inputs {
198 let (root, _off) = resolve_view_root(graph, input);
199 ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
200 if root != input {
204 ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
205 }
206 }
207 }
208
209 let last_step = graph.len();
211 for &out in &graph.outputs {
212 let (root, _off) = resolve_view_root(graph, out);
213 ranges.entry(root).and_modify(|r| r.1 = last_step);
214 if root != out {
215 ranges.entry(out).and_modify(|r| r.1 = last_step);
216 }
217 }
218
219 {
224 let mut stack: Vec<NodeId> = graph.outputs.clone();
225 let mut seen = std::collections::HashSet::new();
226 while let Some(id) = stack.pop() {
227 if !seen.insert(id) {
228 continue;
229 }
230 let (root, _) = resolve_view_root(graph, id);
231 ranges.entry(root).and_modify(|r| r.1 = last_step);
232 if root != id {
233 ranges.entry(id).and_modify(|r| r.1 = last_step);
234 }
235 for &input in &graph.node(id).inputs {
236 stack.push(input);
237 }
238 }
239 }
240
241 for node in graph.nodes() {
248 if matches!(
249 node.op,
250 rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
251 ) {
252 ranges.entry(node.id).and_modify(|r| {
253 r.0 = 0;
254 r.1 = last_step;
255 });
256 }
257 }
258
259 ranges
260}
261
262fn extend_node_chain_liveness_to_end(
267 graph: &Graph,
268 ranges: &mut HashMap<NodeId, (usize, usize)>,
269 start: NodeId,
270 last_step: usize,
271) {
272 let mut stack = vec![start];
273 let mut seen = std::collections::HashSet::new();
274 while let Some(id) = stack.pop() {
275 if !seen.insert(id) {
276 continue;
277 }
278 let (root, _) = resolve_view_root(graph, id);
279 ranges.entry(root).and_modify(|r| r.1 = last_step);
280 if root != id {
281 ranges.entry(id).and_modify(|r| r.1 = last_step);
282 }
283 for &input in &graph.node(id).inputs {
284 stack.push(input);
285 }
286 }
287}
288
289fn extend_custom_op_input_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
292 let last_step = graph.len();
293 for node in graph.nodes() {
294 let Op::Custom {
295 name, num_inputs, ..
296 } = &node.op
297 else {
298 continue;
299 };
300 if !name.starts_with("onnx.") {
301 continue;
302 }
303 let n = (*num_inputs as usize).min(node.inputs.len());
304 for &input in &node.inputs[..n] {
305 extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
306 }
307 }
308}
309
310fn extend_bert_hidden_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
314 let uses_onnx_qmatmul = graph.nodes().iter().any(|node| {
315 matches!(
316 &node.op,
317 Op::Custom { name, .. } if name == "onnx.QMatMul" || name == "onnx.ActCopy"
318 )
319 });
320 if !uses_onnx_qmatmul {
321 return;
322 }
323 let last_step = graph.len();
324 for node in graph.nodes() {
325 match &node.op {
326 Op::LayerNorm { .. } | Op::LayerNorm2d { .. } => {
327 if let Some(&input) = node.inputs.first() {
328 extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
329 }
330 ranges.entry(node.id).and_modify(|r| r.1 = last_step);
331 }
332 Op::Binary(BinaryOp::Add) => {
333 for &input in &node.inputs {
334 extend_node_chain_liveness_to_end(graph, ranges, input, last_step);
335 }
336 ranges.entry(node.id).and_modify(|r| r.1 = last_step);
337 }
338 _ => {}
339 }
340 }
341}
342
343fn extend_packed_qkv_parent_liveness(graph: &Graph, ranges: &mut HashMap<NodeId, (usize, usize)>) {
344 for (step, node) in graph.nodes().iter().enumerate() {
345 let rlx_ir::Op::Attention { .. } = &node.op else {
346 continue;
347 };
348 if node.inputs.len() < 3 {
349 continue;
350 }
351 let Some((parent, _, _)) = rlx_ir::detect_packed_bshd_qkv_attention(
352 graph,
353 node.inputs[0],
354 node.inputs[1],
355 node.inputs[2],
356 ) else {
357 continue;
358 };
359 let (root, _) = resolve_view_root(graph, parent);
360 ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
361 if root != parent {
362 ranges.entry(parent).and_modify(|r| r.1 = r.1.max(step));
363 }
364 }
365}
366
/// Switches that control which boundary tensors the planner places and whether
/// slots may be reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryPlanOptions {
    /// Place `Param` nodes in the arena.
    pub allocate_params: bool,
    /// Place `Input` nodes in the arena.
    pub allocate_inputs: bool,
    /// Place `Constant` nodes in the arena.
    pub allocate_constants: bool,
    /// When `true`, every buffer is appended at the aligned end of the arena and
    /// no slot is ever reused.
    pub arena_no_reuse: bool,
}

impl MemoryPlanOptions {
    /// Reads the `RLX_ARENA_NO_REUSE` environment variable.
    ///
    /// Returns `true` only when it is set to `"1"` or to `"true"` in any ASCII
    /// case; an unset or non-Unicode value counts as `false`.
    fn arena_no_reuse_from_env() -> bool {
        std::env::var("RLX_ARENA_NO_REUSE")
            .ok()
            .is_some_and(|v| v == "1" || v.eq_ignore_ascii_case("true"))
    }

    /// Options that place params, inputs, and constants in the arena.
    ///
    /// `arena_no_reuse` is taken from the `RLX_ARENA_NO_REUSE` environment variable.
    pub fn inference() -> Self {
        Self {
            allocate_params: true,
            allocate_inputs: true,
            allocate_constants: true,
            arena_no_reuse: Self::arena_no_reuse_from_env(),
        }
    }

    /// Options that place inputs and constants but leave params out of placement.
    ///
    /// `arena_no_reuse` is taken from the `RLX_ARENA_NO_REUSE` environment variable.
    pub fn backward_activations_only() -> Self {
        Self {
            allocate_params: false,
            ..Self::inference()
        }
    }
}
411
/// The default options are the same as [`MemoryPlanOptions::inference`].
impl Default for MemoryPlanOptions {
    fn default() -> Self {
        Self::inference()
    }
}
417
/// Offsets of parameter slots taken from one plan so that another plan can reuse
/// the same placement for identically named parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedWeightLayout {
    /// Largest `offset + size` over all `slots` (0 when there are none).
    pub arena_size: usize,
    /// One entry per planned parameter.
    pub slots: Vec<WeightSlot>,
}
424
/// Placement of a single named parameter inside a shared weight layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WeightSlot {
    /// Parameter name, used to match parameters across graphs.
    pub name: String,
    /// Id of the parameter node in the graph the layout was captured from.
    pub forward_id: NodeId,
    /// Byte offset of the parameter's slot.
    pub offset: usize,
    /// Size of the parameter's slot in bytes.
    pub size: usize,
}
433
434impl SharedWeightLayout {
435 pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
437 let mut slots = Vec::new();
438 for node in graph.nodes() {
439 if let rlx_ir::Op::Param { name } = &node.op {
440 if let Some(slot) = plan.assignments.get(&node.id) {
441 slots.push(WeightSlot {
442 name: name.clone(),
443 forward_id: node.id,
444 offset: slot.offset,
445 size: slot.size,
446 });
447 }
448 }
449 }
450 slots.sort_by(|a, b| a.name.cmp(&b.name));
451 let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
452 Self { arena_size, slots }
453 }
454
455 pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
457 let by_name: std::collections::HashMap<&str, &WeightSlot> =
458 self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
459 for node in graph.nodes() {
460 if let rlx_ir::Op::Param { name } = &node.op {
461 let Some(slot) = by_name.get(name.as_str()) else {
462 continue;
463 };
464 plan.assignments.insert(
465 node.id,
466 BufferSlot {
467 offset: slot.offset,
468 size: slot.size,
469 },
470 );
471 }
472 }
473 plan.arena_size = plan.arena_size.max(self.arena_size);
474 }
475}
476
477#[inline]
478fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
479 match op {
480 rlx_ir::Op::Param { .. } => opts.allocate_params,
481 rlx_ir::Op::Input { .. } => opts.allocate_inputs,
482 rlx_ir::Op::Constant { .. } => opts.allocate_constants,
483 _ => true,
484 }
485}
486
/// Plans memory for `graph` with a 64-byte slot alignment and default options.
pub fn plan_memory(graph: &Graph) -> MemoryPlan {
    plan_memory_aligned(graph, 64)
}
491
/// Plans memory for `graph` with the given slot `alignment` (in bytes) and
/// caller-supplied `opts`, sizing each slot from the node's own dtype and without
/// applying a shared weight layout.
pub fn plan_memory_with_options(
    graph: &Graph,
    alignment: usize,
    opts: MemoryPlanOptions,
) -> MemoryPlan {
    plan_memory_aligned_inner(graph, alignment, opts, None, false)
}
500
/// Plans memory for `graph` with the given slot `alignment` (in bytes) and
/// `MemoryPlanOptions::default()`.
pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, false)
}
505
/// Like `plan_memory_aligned`, but sizes every slot as 4 bytes per element
/// regardless of the node's dtype (the `f32_uniform` mode of the inner planner).
pub fn plan_memory_f32_uniform(graph: &Graph, alignment: usize) -> MemoryPlan {
    plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, true)
}
512
/// Plans memory for `graph` using `MemoryPlanOptions::backward_activations_only()`
/// (params are not placed by the planner) and then applies `weights`, so that each
/// `Param` whose name appears in the layout receives the layout's offset and size.
pub fn plan_memory_backward(
    graph: &Graph,
    alignment: usize,
    weights: &SharedWeightLayout,
) -> MemoryPlan {
    plan_memory_aligned_inner(
        graph,
        alignment,
        MemoryPlanOptions::backward_activations_only(),
        Some(weights),
        false,
    )
}
527
528#[inline]
529fn node_slot_bytes(node: &rlx_ir::Node, f32_uniform: bool) -> usize {
530 if f32_uniform {
531 node.shape.num_elements().unwrap_or(0) * 4
532 } else {
533 node.shape.size_bytes().unwrap_or(0)
534 }
535}
536
537fn plan_memory_aligned_inner(
538 graph: &Graph,
539 alignment: usize,
540 opts: MemoryPlanOptions,
541 weights: Option<&SharedWeightLayout>,
542 f32_uniform: bool,
543) -> MemoryPlan {
544 let mut ranges = compute_live_ranges(graph);
545 extend_packed_qkv_parent_liveness(graph, &mut ranges);
546 extend_custom_op_input_liveness(graph, &mut ranges);
547 extend_bert_hidden_liveness(graph, &mut ranges);
548 struct BufInfo {
550 id: NodeId,
551 size: usize,
552 birth: usize,
553 death: usize,
554 }
555
556 let mut buffers: Vec<BufInfo> = Vec::new();
557 for node in graph.nodes() {
558 if pure_view_offset(graph, node).is_some() {
561 continue;
562 }
563 let size = node_slot_bytes(node, f32_uniform);
564 if size > 0
565 && let Some(&(birth, death)) = ranges.get(&node.id)
566 && plans_boundary_buffer(&node.op, opts)
567 {
568 buffers.push(BufInfo {
569 id: node.id,
570 size,
571 birth,
572 death,
573 });
574 }
575 }
576
577 buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
579
580 let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
582 let mut arena_size: usize = 0;
583
584 let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); for buf in &buffers {
588 let align = alignment;
589 let node = graph.node(buf.id);
590 let tail_guard = boundary_tail_guard(&node.op, align);
591 let placement_size = buf.size + tail_guard;
592 let mut best_offset: Option<usize> = None;
593
594 let mut candidates = vec![0usize];
597 for &(p_off, p_size, _, _) in &placed {
598 candidates.push(p_off + p_size);
599 }
600 candidates.sort_unstable();
601 candidates.dedup();
602
603 for &candidate_offset in &candidates {
604 let aligned = (candidate_offset + align - 1) & !(align - 1);
605 let end = aligned + placement_size;
606
607 let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
608 let p_end = p_off + p_size;
609 let mem_overlap = aligned < p_end && end > p_off;
610 let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
611 mem_overlap && time_overlap
612 });
613
614 if !conflict {
615 match best_offset {
616 None => best_offset = Some(aligned),
617 Some(best) if aligned < best => best_offset = Some(aligned),
618 _ => {}
619 }
620 }
621 }
622
623 let aligned = if opts.arena_no_reuse {
624 (arena_size + align - 1) & !(align - 1)
625 } else {
626 best_offset.unwrap_or_else(|| {
627 (arena_size + align - 1) & !(align - 1)
629 })
630 };
631 assignments.insert(
632 buf.id,
633 BufferSlot {
634 offset: aligned,
635 size: buf.size,
636 },
637 );
638 placed.push((aligned, placement_size, buf.birth, buf.death));
639 arena_size = arena_size.max(aligned + placement_size);
640 }
641
642 for node in graph.nodes() {
648 if pure_view_offset(graph, node).is_some() {
649 let (root, off) = resolve_view_root(graph, node.id);
650 if let Some(root_slot) = assignments.get(&root).cloned() {
651 let view_size = node_slot_bytes(node, f32_uniform);
652 assignments.insert(
653 node.id,
654 BufferSlot {
655 offset: root_slot.offset + off,
656 size: view_size,
657 },
658 );
659 }
660 }
661 }
662
663 let schedule = graph.topo_order().collect();
664
665 let mut plan = MemoryPlan {
666 arena_size,
667 assignments,
668 schedule,
669 };
670 if let Some(w) = weights {
671 w.apply_to_plan(graph, &mut plan);
672 }
673 plan
674}
675
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// The arena for a two-matmul chain must not exceed the sum of all slot sizes
    /// plus one tail guard per assignment.
    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f));
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        println!("Arena size: {} bytes", plan.arena_size);
        // Live ranges depend only on the graph, so compute them once for the dump.
        let live = compute_live_ranges(&g);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = live.get(id) {
                println!(
                    "  {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        let total_logical: usize = plan.assignments.values().map(|s| s.size).sum();
        let align_slack = plan.assignments.len() * BOUNDARY_TAIL_GUARD_BYTES;
        assert!(
            plan.arena_size <= total_logical + align_slack,
            "arena {} should be <= logical sum {} + slack {}",
            plan.arena_size,
            total_logical,
            align_slack
        );
    }

    /// The textual report carries the summary header and at least one row, and the
    /// plan has a slot for each matmul.
    #[test]
    fn plan_report_includes_savings() {
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
    }

    /// Reshape, same-dtype cast, and axis-0 narrow must alias the producing
    /// matmul's slot instead of getting storage of their own.
    #[test]
    fn view_ops_alias_parent_slot() {
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f));
        let w = g.param("w", Shape::new(&[4, 4], f));
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f));
        let r = g.reshape_(mm, vec![32]);
        let c = g.cast(r, DType::F32);
        let n = g.narrow_(c, 0, 8, 16);
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        let mm_off = plan.assignments[&mm].offset;
        assert_eq!(
            plan.assignments[&r].offset, mm_off,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&c].offset, mm_off,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&n].offset,
            mm_off + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            plan.assignments[&n].size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    /// A backward plan built against a forward weight layout must give the
    /// same-named parameter the forward plan's offset and size.
    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);
        let fwd_w_off = fwd_plan.assignments[&w].offset;
        let bwd_w_off = bwd_plan.assignments[&w2].offset;
        assert_eq!(bwd_w_off, fwd_w_off, "backward w must share forward offset");
        // The layout copies the forward slot verbatim, so the size must match too.
        assert_eq!(
            bwd_plan.assignments[&w2].size,
            fwd_plan.assignments[&w].size,
            "backward w must share forward slot size"
        );
    }

    /// Two buffers that are live at the same time must not overlap in the arena.
    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }
}