1use rlx_ir::{Graph, NodeId, Op};
26use std::collections::HashMap;
27
28fn pure_view_offset(graph: &Graph, node: &rlx_ir::Node) -> Option<(NodeId, usize)> {
42 match &node.op {
43 Op::Reshape { .. } => Some((node.inputs[0], 0)),
44 Op::Cast { to } => {
45 let parent = graph.node(node.inputs[0]);
46 if parent.shape.dtype() == *to {
47 Some((node.inputs[0], 0))
48 } else {
49 None
50 }
51 }
52 Op::Narrow {
53 axis,
54 start,
55 len: _,
56 } if *axis == 0 => {
57 let parent = graph.node(node.inputs[0]);
58 let inner_elems: usize = (1..parent.shape.rank())
60 .map(|i| parent.shape.dim(i).unwrap_static())
61 .product();
62 let dt_bytes = parent.shape.dtype().size_bytes();
63 Some((node.inputs[0], start * inner_elems * dt_bytes))
64 }
65 _ => None,
66 }
67}
68
69pub fn is_pure_view(graph: &Graph, node: &rlx_ir::Node) -> bool {
73 pure_view_offset(graph, node).is_some()
74}
75
76#[derive(Debug, Clone)]
78pub struct BufferSlot {
79 pub offset: usize,
81 pub size: usize,
83}
84
85#[derive(Debug, Clone)]
87pub struct MemoryPlan {
88 pub arena_size: usize,
90 pub assignments: HashMap<NodeId, BufferSlot>,
92 pub schedule: Vec<NodeId>,
94}
95
96impl MemoryPlan {
97 pub fn total_unshared_bytes(&self) -> usize {
101 self.assignments.values().map(|s| s.size).sum()
102 }
103
104 pub fn bytes_saved(&self) -> usize {
107 self.total_unshared_bytes().saturating_sub(self.arena_size)
108 }
109
110 pub fn report(&self) -> String {
118 let mut rows: Vec<(usize, usize, NodeId)> = self
119 .assignments
120 .iter()
121 .map(|(id, slot)| (slot.offset, slot.size, *id))
122 .collect();
123 rows.sort();
124 let mut out = String::new();
125 out.push_str(&format!(
126 "# arena_size={} total_unshared={} saved={}\n",
127 self.arena_size,
128 self.total_unshared_bytes(),
129 self.bytes_saved()
130 ));
131 out.push_str("# offset\tsize\tnode\n");
132 for (off, sz, id) in rows {
133 out.push_str(&format!("{off}\t{sz}\t{id}\n"));
134 }
135 out
136 }
137}
138
139pub fn collect_view_aliases(graph: &Graph) -> HashMap<NodeId, (NodeId, usize)> {
141 let mut out = HashMap::new();
142 for node in graph.nodes() {
143 if pure_view_offset(graph, node).is_some() {
144 let (root, off) = resolve_view_root(graph, node.id);
145 out.insert(node.id, (root, off));
146 }
147 }
148 out
149}
150
151fn resolve_view_root(graph: &Graph, mut id: NodeId) -> (NodeId, usize) {
154 let mut total_offset = 0usize;
155 loop {
156 let node = graph.node(id);
157 match pure_view_offset(graph, node) {
158 Some((parent, off)) => {
159 total_offset += off;
160 id = parent;
161 }
162 None => return (id, total_offset),
163 }
164 }
165}
166
167fn compute_live_ranges(graph: &Graph) -> HashMap<NodeId, (usize, usize)> {
171 let mut ranges: HashMap<NodeId, (usize, usize)> = HashMap::new();
172
173 for (step, node) in graph.nodes().iter().enumerate() {
174 ranges.entry(node.id).or_insert((step, step));
176
177 for &input in &node.inputs {
182 let (root, _off) = resolve_view_root(graph, input);
183 ranges.entry(root).and_modify(|r| r.1 = r.1.max(step));
184 if root != input {
188 ranges.entry(input).and_modify(|r| r.1 = r.1.max(step));
189 }
190 }
191 }
192
193 let last_step = graph.len();
195 for &out in &graph.outputs {
196 let (root, _off) = resolve_view_root(graph, out);
197 ranges.entry(root).and_modify(|r| r.1 = last_step);
198 if root != out {
199 ranges.entry(out).and_modify(|r| r.1 = last_step);
200 }
201 }
202
203 for node in graph.nodes() {
210 if matches!(
211 node.op,
212 rlx_ir::Op::Param { .. } | rlx_ir::Op::Input { .. } | rlx_ir::Op::Constant { .. }
213 ) {
214 ranges.entry(node.id).and_modify(|r| {
215 r.0 = 0;
216 r.1 = last_step;
217 });
218 }
219 }
220
221 ranges
222}
223
/// Selects which graph-boundary nodes receive their own arena slot. All other
/// (non-boundary) nodes are always planned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryPlanOptions {
    /// Plan a slot for each `Param` node.
    pub allocate_params: bool,
    /// Plan a slot for each `Input` node.
    pub allocate_inputs: bool,
    /// Plan a slot for each `Constant` node.
    pub allocate_constants: bool,
}

impl MemoryPlanOptions {
    /// Every boundary kind gets a slot.
    pub fn inference() -> Self {
        MemoryPlanOptions {
            allocate_params: true,
            allocate_inputs: true,
            allocate_constants: true,
        }
    }

    /// Same as [`MemoryPlanOptions::inference`] except that `Param` nodes get
    /// no slot of their own.
    pub fn backward_activations_only() -> Self {
        MemoryPlanOptions {
            allocate_params: false,
            ..MemoryPlanOptions::inference()
        }
    }
}

impl Default for MemoryPlanOptions {
    /// Defaults to [`MemoryPlanOptions::inference`].
    fn default() -> Self {
        MemoryPlanOptions::inference()
    }
}
266
267#[derive(Debug, Clone, PartialEq, Eq)]
269pub struct SharedWeightLayout {
270 pub arena_size: usize,
271 pub slots: Vec<WeightSlot>,
272}
273
274#[derive(Debug, Clone, PartialEq, Eq)]
276pub struct WeightSlot {
277 pub name: String,
278 pub forward_id: NodeId,
279 pub offset: usize,
280 pub size: usize,
281}
282
283impl SharedWeightLayout {
284 pub fn from_forward(graph: &Graph, plan: &MemoryPlan) -> Self {
286 let mut slots = Vec::new();
287 for node in graph.nodes() {
288 if let rlx_ir::Op::Param { name } = &node.op {
289 if let Some(slot) = plan.assignments.get(&node.id) {
290 slots.push(WeightSlot {
291 name: name.clone(),
292 forward_id: node.id,
293 offset: slot.offset,
294 size: slot.size,
295 });
296 }
297 }
298 }
299 slots.sort_by(|a, b| a.name.cmp(&b.name));
300 let arena_size = slots.iter().map(|s| s.offset + s.size).max().unwrap_or(0);
301 Self { arena_size, slots }
302 }
303
304 pub fn apply_to_plan(&self, graph: &Graph, plan: &mut MemoryPlan) {
306 let by_name: std::collections::HashMap<&str, &WeightSlot> =
307 self.slots.iter().map(|s| (s.name.as_str(), s)).collect();
308 for node in graph.nodes() {
309 if let rlx_ir::Op::Param { name } = &node.op {
310 let Some(slot) = by_name.get(name.as_str()) else {
311 continue;
312 };
313 plan.assignments.insert(
314 node.id,
315 BufferSlot {
316 offset: slot.offset,
317 size: slot.size,
318 },
319 );
320 }
321 }
322 plan.arena_size = plan.arena_size.max(self.arena_size);
323 }
324}
325
326#[inline]
327fn plans_boundary_buffer(op: &rlx_ir::Op, opts: MemoryPlanOptions) -> bool {
328 match op {
329 rlx_ir::Op::Param { .. } => opts.allocate_params,
330 rlx_ir::Op::Input { .. } => opts.allocate_inputs,
331 rlx_ir::Op::Constant { .. } => opts.allocate_constants,
332 _ => true,
333 }
334}
335
336pub fn plan_memory(graph: &Graph) -> MemoryPlan {
338 plan_memory_aligned(graph, 64)
339}
340
341pub fn plan_memory_with_options(
343 graph: &Graph,
344 alignment: usize,
345 opts: MemoryPlanOptions,
346) -> MemoryPlan {
347 plan_memory_aligned_inner(graph, alignment, opts, None)
348}
349
350pub fn plan_memory_aligned(graph: &Graph, alignment: usize) -> MemoryPlan {
352 plan_memory_aligned_inner(graph, alignment, MemoryPlanOptions::default(), None)
353}
354
355pub fn plan_memory_backward(
357 graph: &Graph,
358 alignment: usize,
359 weights: &SharedWeightLayout,
360) -> MemoryPlan {
361 plan_memory_aligned_inner(
362 graph,
363 alignment,
364 MemoryPlanOptions::backward_activations_only(),
365 Some(weights),
366 )
367}
368
369fn plan_memory_aligned_inner(
370 graph: &Graph,
371 alignment: usize,
372 opts: MemoryPlanOptions,
373 weights: Option<&SharedWeightLayout>,
374) -> MemoryPlan {
375 let ranges = compute_live_ranges(graph);
376
377 struct BufInfo {
379 id: NodeId,
380 size: usize,
381 birth: usize,
382 death: usize,
383 }
384
385 let mut buffers: Vec<BufInfo> = Vec::new();
386 for node in graph.nodes() {
387 if pure_view_offset(graph, node).is_some() {
390 continue;
391 }
392 if let Some(size) = node.shape.size_bytes()
393 && size > 0
394 && let Some(&(birth, death)) = ranges.get(&node.id)
395 && plans_boundary_buffer(&node.op, opts)
396 {
397 buffers.push(BufInfo {
398 id: node.id,
399 size,
400 birth,
401 death,
402 });
403 }
404 }
405
406 buffers.sort_by_key(|b| std::cmp::Reverse(b.size));
408
409 let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
411 let mut arena_size: usize = 0;
412
413 let mut placed: Vec<(usize, usize, usize, usize)> = Vec::new(); for buf in &buffers {
417 let align = alignment;
418 let mut best_offset: Option<usize> = None;
419
420 let mut candidates = vec![0usize];
423 for &(p_off, p_size, _, _) in &placed {
424 candidates.push(p_off + p_size);
425 }
426 candidates.sort_unstable();
427 candidates.dedup();
428
429 for &candidate_offset in &candidates {
430 let aligned = (candidate_offset + align - 1) & !(align - 1);
431 let end = aligned + buf.size;
432
433 let conflict = placed.iter().any(|&(p_off, p_size, p_birth, p_death)| {
434 let p_end = p_off + p_size;
435 let mem_overlap = aligned < p_end && end > p_off;
436 let time_overlap = buf.birth <= p_death && buf.death >= p_birth;
437 mem_overlap && time_overlap
438 });
439
440 if !conflict {
441 match best_offset {
442 None => best_offset = Some(aligned),
443 Some(best) if aligned < best => best_offset = Some(aligned),
444 _ => {}
445 }
446 }
447 }
448
449 let aligned = best_offset.unwrap_or_else(|| {
450 (arena_size + align - 1) & !(align - 1)
452 });
453 assignments.insert(
454 buf.id,
455 BufferSlot {
456 offset: aligned,
457 size: buf.size,
458 },
459 );
460 placed.push((aligned, buf.size, buf.birth, buf.death));
461 arena_size = arena_size.max(aligned + buf.size);
462 }
463
464 for node in graph.nodes() {
470 if pure_view_offset(graph, node).is_some() {
471 let (root, off) = resolve_view_root(graph, node.id);
472 if let Some(root_slot) = assignments.get(&root).cloned() {
473 let view_size = node.shape.size_bytes().unwrap_or(0);
474 assignments.insert(
475 node.id,
476 BufferSlot {
477 offset: root_slot.offset + off,
478 size: view_size,
479 },
480 );
481 }
482 }
483 }
484
485 let schedule = graph.topo_order().collect();
486
487 let mut plan = MemoryPlan {
488 arena_size,
489 assignments,
490 schedule,
491 };
492 if let Some(w) = weights {
493 w.apply_to_plan(graph, &mut plan);
494 }
495 plan
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    #[test]
    fn non_overlapping_buffers_share_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w1 = g.param("w1", Shape::new(&[384, 384], f));
        let w2 = g.param("w2", Shape::new(&[384, 384], f));

        let mm1 = g.matmul(x, w1, Shape::new(&[100, 384], f));
        let mm2 = g.matmul(mm1, w2, Shape::new(&[100, 384], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        // Live ranges depend only on the graph; compute them once for the dump.
        let ranges = compute_live_ranges(&g);
        println!("Arena size: {} bytes", plan.arena_size);
        for (id, slot) in &plan.assignments {
            if let Some((b, d)) = ranges.get(id) {
                println!(
                    " {id}: offset={}, size={}, live=[{b}, {d}]",
                    slot.offset, slot.size
                );
            }
        }

        let total_if_no_sharing: usize = plan.assignments.values().map(|s| s.size).sum();
        assert!(
            plan.arena_size <= total_if_no_sharing,
            "arena {0} should be <= sum {total_if_no_sharing}",
            plan.arena_size
        );
    }

    #[test]
    fn plan_report_includes_savings() {
        let mut g = Graph::new("rep");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[16], f));
        let w = g.param("w", Shape::new(&[16, 16], f));
        let mm1 = g.matmul(x, w, Shape::new(&[1, 16], f));
        let mm2 = g.matmul(mm1, w, Shape::new(&[1, 16], f));
        g.set_outputs(vec![mm2]);

        let plan = plan_memory(&g);
        let r = plan.report();
        assert!(r.starts_with("# arena_size="));
        assert!(r.contains("total_unshared="));
        assert!(r.contains("saved="));
        let body: Vec<&str> = r.lines().filter(|l| !l.starts_with('#')).collect();
        assert!(!body.is_empty());
        assert!(plan.assignments.contains_key(&mm1));
        assert!(plan.assignments.contains_key(&mm2));
    }

    #[test]
    fn view_ops_alias_parent_slot() {
        use rlx_ir::GraphExt;
        let mut g = Graph::new("views");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f));
        let w = g.param("w", Shape::new(&[4, 4], f));
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f));
        let r = g.reshape_(mm, vec![32]);
        let c = g.cast(r, DType::F32);
        let n = g.narrow_(c, 0, 8, 16);
        g.set_outputs(vec![n]);

        let plan = plan_memory(&g);

        let mm_off = plan.assignments[&mm].offset;
        assert_eq!(
            plan.assignments[&r].offset, mm_off,
            "reshape view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&c].offset, mm_off,
            "same-dtype cast view should alias mm slot exactly"
        );
        assert_eq!(
            plan.assignments[&n].offset,
            mm_off + 32,
            "axis-0 narrow start=8 should alias mm slot + 8*4 bytes"
        );
        assert_eq!(
            plan.assignments[&n].size, 64,
            "narrow view's size is its own (16 f32 = 64B), not parent's"
        );
    }

    #[test]
    fn collect_view_aliases_resolves_chains_to_root() {
        use rlx_ir::GraphExt;
        let mut g = Graph::new("aliases");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 4], f));
        let w = g.param("w", Shape::new(&[4, 4], f));
        let mm = g.matmul(x, w, Shape::new(&[8, 4], f));
        let r = g.reshape_(mm, vec![32]);
        let n = g.narrow_(r, 0, 8, 16);
        g.set_outputs(vec![n]);

        let aliases = collect_view_aliases(&g);
        assert_eq!(aliases.get(&r), Some(&(mm, 0)));
        assert_eq!(
            aliases.get(&n),
            Some(&(mm, 32)),
            "narrow start=8 on a [32] f32 view is 32 bytes into mm"
        );
        assert!(!aliases.contains_key(&mm), "compute nodes are not aliases");
        assert!(is_pure_view(&g, g.node(n)));
        assert!(!is_pure_view(&g, g.node(mm)));
    }

    #[test]
    fn backward_plan_aliases_forward_param_slots() {
        let f = DType::F32;
        let mut fwd = Graph::new("fwd");
        let x = fwd.input("x", Shape::new(&[2, 4], f));
        let w = fwd.param("w", Shape::new(&[4, 4], f));
        let mm = fwd.matmul(x, w, Shape::new(&[2, 4], f));
        fwd.set_outputs(vec![mm]);
        let fwd_plan = plan_memory_aligned(&fwd, 64);
        let layout = SharedWeightLayout::from_forward(&fwd, &fwd_plan);

        let mut bwd = Graph::new("bwd_grad");
        let x2 = bwd.input("x", Shape::new(&[2, 4], f));
        let w2 = bwd.param("w", Shape::new(&[4, 4], f));
        let mm2 = bwd.matmul(x2, w2, Shape::new(&[2, 4], f));
        bwd.set_outputs(vec![mm2]);

        let bwd_plan = plan_memory_backward(&bwd, 64, &layout);
        let fwd_w_off = fwd_plan.assignments[&w].offset;
        let bwd_w_off = bwd_plan.assignments[&w2].offset;
        assert_eq!(bwd_w_off, fwd_w_off, "backward w must share forward offset");
        assert_eq!(
            bwd_plan.assignments[&w2].size,
            fwd_plan.assignments[&w].size,
            "backward w must keep the forward slot size"
        );
    }

    #[test]
    fn overlapping_buffers_get_separate_memory() {
        let mut g = Graph::new("test");
        let f = DType::F32;

        let x = g.input("x", Shape::new(&[100, 384], f));
        let w = g.param("w", Shape::new(&[384, 384], f));

        let mm = g.matmul(x, w, Shape::new(&[100, 384], f));
        let add = g.binary(BinaryOp::Add, mm, x, Shape::new(&[100, 384], f));
        g.set_outputs(vec![add]);

        let plan = plan_memory(&g);
        let mm_slot = &plan.assignments[&mm];
        let add_slot = &plan.assignments[&add];

        let mm_end = mm_slot.offset + mm_slot.size;
        let add_end = add_slot.offset + add_slot.size;
        let no_overlap = mm_end <= add_slot.offset || add_end <= mm_slot.offset;
        assert!(no_overlap, "overlapping buffers must have separate memory");
    }
}