1use rlx_ir::{Graph, NodeId, Op};
26use std::collections::HashMap;
27
28fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
42 match &node.op {
43 Op::Reshape { .. } => Some((node.inputs[0], 0)),
44 Op::Cast { to } => {
45 let parent = graph.node(node.inputs[0]);
46 if parent.shape.dtype() == *to {
47 Some((node.inputs[0], 0))
48 } else {
49 None
50 }
51 }
52 Op::Narrow {
53 axis,
54 start,
55 len: _,
56 } if *axis == 0 => {
57 let parent = graph.node(node.inputs[0]);
58 let inner_elems: usize = (1..parent.shape.rank())
60 .map(|i| parent.shape.dim(i).unwrap_static())
61 .product();
62 let dt_bytes = parent.shape.dtype().size_bytes();
63 Some((node.inputs[0], start * inner_elems * dt_bytes))
64 }
65 _ => None,
66 }
67}
68
69pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
73 pure_view_offset(graph, node).is_some()
74}
75
#[derive(Debug, Clone)]
/// A region of the planned arena assigned to one node.
pub struct BufferSlot {
    /// Byte offset of the region from the start of the arena.
    pub offset: usize,
    /// Length of the region in bytes.
    pub size: usize,
}
84
#[derive(Debug, Clone)]
/// Result of memory planning for one graph.
pub struct MemoryPlan {
    /// Bytes the arena must provide: the highest `offset + size` of any planned
    /// slot (and, for backward plans, at least the shared weight layout's size).
    pub arena_size: usize,
    /// Slot assigned to each node that received storage. View nodes point into
    /// their root's slot; nodes the options exclude have no entry.
    pub assignments: HashMap<NodeId, BufferSlot>,
    /// Node execution order, collected from `Graph::topo_order`.
    pub schedule: Vec<NodeId>,
}
95
96impl MemoryPlan {
97 pub fn total_unshared_bytes(&self) -> usize {
101 self.assignments.values().map(|s| s.size).sum()
102 }
103
104 pub fn bytes_saved(&self) -> usize {
107 self.total_unshared_bytes().saturating_sub(self.arena_size)
108 }
109
110 pub fn report(&self) -> String {
118 let mut rows: Vec<(usize, usize, NodeId)> = self
119 .assignments
120 .iter()
121 .map(|(id, slot)| (slot.offset, slot.size, *id))
122 .collect();
123 rows.sort();
124 let mut out = String::new();
125 out.push_str(&format!(
126 "# arena_size={} total_unshared={} saved={}\n",
127 self.arena_size,
128 self.total_unshared_bytes(),
129 self.bytes_saved()
130 ));
131 out.push_str("# offset\tsize\tnode\n");
132 for (off, sz, id) in rows {
133 out.push_str(&format!("{off}\t{sz}\t{id}\n"));
134 }
135 out
136 }
137}
138
139pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
141 let mut out = HashMap::new();
142 for node in graph.nodes() {
143 if pure_view_offset(graph, node).is_some() {
144 let (root, off) = resolve_view_root(graph, node.id);
145 out.insert(node.id, (root, off));
146 }
147 }
148 out
149}
150
151fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
154 let mut total_offset = 0usize;
155 loop {
156 let node = graph.node(id);
157 match pure_view_offset(graph, node) {
158 Some((parent, off)) => {
159 total_offset += off;
160 id = parent;
161 }
162 None => return (id, total_offset),
163 }
164 }
165}
166
167fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
171 let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();
172
173 for (step, node) in graph.nodes().iter().enumerate() {
174 ranges.entry(node.id).or_insert((step, step));
176
177 for &input in &node.inputs {
182 let (root, _off) = resolve_view_root(graph, input);
183 ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
184 if root != input {
188 ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
189 }
190 }
191 }
192
193 let last_step = graph.len();
195 for &out in &graph.outputs {
196 let (root, _off) = resolve_view_root(graph, out);
197 ranges.entry(root).and_modify(|r| r.1 = last_step);
198 if root != out {
199 ranges.entry(out).and_modify(|r| r.1 = last_step);
200 }
201 }
202
203 for node in graph.nodes() {
210 if matches!(
211 node.op,
212 rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
213 ) {
214 ranges.entry(node.id).and_modify(|r| {
215 r.0 = 0;
216 r.1 = last_step;
217 });
218 }
219 }
220
221 ranges
222}
223
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Which graph-boundary nodes the planner reserves arena space for.
/// Non-boundary (intermediate) nodes are always planned.
pub struct MemoryPlanOptions {
    /// Reserve slots for `Param` nodes.
    pub allocate_params: bool,
    /// Reserve slots for `Input` nodes.
    pub allocate_inputs: bool,
    /// Reserve slots for `Constant` nodes.
    pub allocate_constants: bool,
}

impl MemoryPlanOptions {
    /// Plan every boundary node: params, inputs and constants.
    pub fn inference() -> Self {
        let everything = true;
        Self {
            allocate_params: everything,
            allocate_inputs: everything,
            allocate_constants: everything,
        }
    }

    /// Like [`MemoryPlanOptions::inference`] but without params, whose slots
    /// are expected to come from elsewhere (e.g. a shared forward layout).
    pub fn backward_activations_only() -> Self {
        Self {
            allocate_params: false,
            ..Self::inference()
        }
    }
}

impl Default for MemoryPlanOptions {
    /// Defaults to [`MemoryPlanOptions::inference`].
    fn default() -> Self {
        Self::inference()
    }
}
266
#[derive(Debug, Clone, PartialEq, Eq)]
/// Arena placement of the `Param` nodes of a forward plan, keyed by parameter
/// name so that another graph's params with the same names can reuse the slots.
pub struct SharedWeightLayout {
    /// Highest `offset + size` over all slots (0 when there are none).
    pub arena_size: usize,
    /// Weight slots; `from_forward` produces them sorted by parameter name.
    pub slots: Vec<WeightSlot>,
}
273
#[derive(Debug, Clone, PartialEq, Eq)]
/// Placement of one named parameter inside a `SharedWeightLayout`.
pub struct WeightSlot {
    /// Parameter name from `Op::Param`; used to match params across graphs.
    pub name: String,
    /// Id of the `Param` node in the forward graph this slot was taken from.
    pub forward_id: NodeId,
    /// Byte offset of the slot in the arena.
    pub offset: usize,
    /// Slot length in bytes.
    pub size: usize,
}
282
283impl SharedWeightLayout {
284 pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
286 let mut slots = Vec::new();
287 for node in graph.nodes() {
288 if let rlx_ir::Op::Param { name } = &node.op {
289 if let Some(slot) = plan.assignments.get(&node.id) {
290 slots.push(WeightSlot {
291 name: name.clone(),
292 forward_id: node.id,
293 offset: slot.offset,
294 size: slot.size,
295 });
296 }
297 }
298 }
299 slots.sort_by(|a, b| a.name.cmp(&b.name));
300 let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
301 Self { arena_size, slots }
302 }
303
304 pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
306 let by_name: std::collections::HashMap<&str, &WeightSlot> =
307 self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
308 for node in graph.nodes() {
309 if let rlx_ir::Op::Param { name } = &node.op {
310 let Some(slot) = by_name.get(name.as_str()) else {
311 continue;
312 };
313 plan.assignments.insert(
314 node.id,
315 BufferSlot {
316 offset: slot.offset,
317 size: slot.size,
318 },
319 );
320 }
321 }
322 plan.arena_size = plan.arena_size.max(self.arena_size);
323 }
324}
325
326#[inline]
327fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
328 match op {
329 rlx_ir::Op::Param { .. } => opts.allocate_params,
330 rlx_ir::Op::Input { .. } => opts.allocate_inputs,
331 rlx_ir::Op::Constant { .. } => opts.allocate_constants,
332 _ => true,
333 }
334}
335
336pub fn plan_memory(graph: &Graph) -> MemoryPlan {
338 plan_memory_aligned(graph, 64)
339}
340
341pub fn plan_memory_with_options(
343 graph: &Graph,
344 alignment: usize,
345 opts: MemoryPlanOptions,
346) -> MemoryPlan {
347 plan_memory_aligned_inner(graph, alignment, opts, None, false)
348}
349
350pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
352 plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, false)
353}
354
355pub fn plan_memory_f32_uniform(graph: &Graph, alignment: usize) -> MemoryPlan {
359 plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None, true)
360}
361
362pub fn plan_memory_backward(
364 graph: &Graph,
365 alignment: usize,
366 weights: &SharedWeightLayout,
367) -> MemoryPlan {
368 plan_memory_aligned_inner(
369 graph,
370 alignment,
371 MemoryPlanOptions::backward_activations_only(),
372 Some(weights),
373 false,
374 )
375}
376
377#[inline]
378fn node_slot_bytes(node: &rlx_ir::Node, f32_uniform: bool) -> usize {
379 if f32_uniform {
380 node.shape.num_elements().unwrap_or(0) * 4
381 } else {
382 node.shape.size_bytes().unwrap_or(0)
383 }
384}
385
386fn plan_memory_aligned_inner(
387 graph: &Graph,
388 alignment: usize,
389 opts: MemoryPlanOptions,
390 weights: Option<&SharedWeightLayout>,
391 f32_uniform: bool,
392) -> MemoryPlan {
393 let ranges = compute_live_ranges(graph);
394
395 struct BufInfo {
397 id: NodeId,
398 size: usize,
399 birth: usize,
400 death: usize,
401 }
402
403 let mut buffers: Vec<BufInfo> = Vec::new();
404 for node in graph.nodes() {
405 if pure_view_offset(graph, node).is_some() {
408 continue;
409 }
410 let size = node_slot_bytes(node, f32_uniform);
411 if size > 0
412 && let Some(&(birth, death)) = ranges.get(&node.id)
413 && plans_boundary_buffer(&node.op, opts)
414 {
415 buffers.push(BufInfo {
416 id: node.id,
417 size,
418 birth,
419 death,
420 });
421 }
422 }
423
424 buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
426
427 let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
429 let mut arena_size: usize = 0;
430
431 let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); for buf in &buffers {
435 let align = alignment;
436 let mut best_offset: Option<usize> = None;
437
438 let mut candidates = vec![0usize];
441 for &(p_off, p_size, _, _) in &placed {
442 candidates.push(p_off + p_size);
443 }
444 candidates.sort_unstable();
445 candidates.dedup();
446
447 for &candidate_offset in &candidates {
448 let aligned = (candidate_offset + align - 1) & !(align - 1);
449 let end = aligned + buf.size;
450
451 let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
452 let p_end = p_off + p_size;
453 let mem_overlap = aligned < p_end && end > p_off;
454 let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
455 mem_overlap && time_overlap
456 });
457
458 if !conflict {
459 match best_offset {
460 None => best_offset = Some(aligned),
461 Some(best) if aligned < best => best_offset = Some(aligned),
462 _ => {}
463 }
464 }
465 }
466
467 let aligned = best_offset.unwrap_or_else(|| {
468 (arena_size + align - 1) & !(align - 1)
470 });
471 assignments.insert(
472 buf.id,
473 BufferSlot {
474 offset: aligned,
475 size: buf.size,
476 },
477 );
478 placed.push((aligned, buf.size, buf.birth, buf.death));
479 arena_size = arena_size.max(aligned + buf.size);
480 }
481
482 for node in graph.nodes() {
488 if pure_view_offset(graph, node).is_some() {
489 let (root, off) = resolve_view_root(graph, node.id);
490 if let Some(root_slot) = assignments.get(&root).cloned() {
491 let view_size = node_slot_bytes(node, f32_uniform);
492 assignments.insert(
493 node.id,
494 BufferSlot {
495 offset: root_slot.offset + off,
496 size: view_size,
497 },
498 );
499 }
500 }
501 }
502
503 let schedule = graph.topo_order().collect();
504
505 let mut plan = MemoryPlan {
506 arena_size,
507 assignments,
508 schedule,
509 };
510 if let Some(w) = weights {
511 w.apply_to_plan(graph, &mut plan);
512 }
513 plan
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f));
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        // Live ranges depend only on the graph; compute them once for the dump.
        let ranges = compute_live_ranges(&g);
        println!("Arena size: {} bytes", plan.arena_size);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = ranges.get(id) {
                println!(
                    " {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        let total_if_no_sharing: usize = plan.assignments.values().map(|s| s.size).sum();
        assert!(
            plan.arena_size <= total_if_no_sharing,
            "arena {0} should be <= sum {total_if_no_sharing}",
            plan.arena_size
        );
    }

    #[test]
    fn plan_report_includes_savings() {
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));
        // Every non-comment line is one slot row.
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
    }

    #[test]
    fn view_ops_alias_parent_slot() {
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f));
        let w = g.param("w", Shape::new(&[4, 4], f));
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f));
        let r = g.reshape_(mm, vec![32]);
        let c = g.cast(r, DType::F32);
        let n = g.narrow_(c, 0, 8, 16);
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        let mm_off = plan.assignments[&mm].offset;
        assert_eq!(
            plan.assignments[&r].offset, mm_off,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&c].offset, mm_off,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&n].offset,
            mm_off + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            plan.assignments[&n].size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);
        let fwd_w_off = fwd_plan.assignments[&w].offset;
        let bwd_w_off = bwd_plan.assignments[&w2].offset;
        assert_eq!(bwd_w_off, fwd_w_off, "backward w must share forward offset");
        assert_eq!(
            bwd_plan.assignments[&w2].size,
            fwd_plan.assignments[&w].size,
            "backward w must reuse the forward slot size"
        );
        assert!(
            bwd_plan.arena_size >= layout.arena_size,
            "backward arena must cover the shared weight layout"
        );
    }

    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }
}