1use rlx_ir::shape::Dim;
38use rlx_ir::*;
39use std::collections::{HashMap, HashSet};
40
41pub fn vmap(forward: &Graph, batched_input_names: &[&str], batch_size: usize) -> Graph {
58 let batched_set: HashSet<&str> = batched_input_names.iter().copied().collect();
59 let mut out = Graph::new(format!("{}_vmap", forward.name));
60 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
61 let mut batched: HashSet<NodeId> = HashSet::new();
65
66 for node in forward.nodes() {
67 let new_id = match &node.op {
68 Op::Input { name } => {
69 if batched_set.contains(name.as_str()) {
70 let mut dims: Vec<Dim> = vec![Dim::Static(batch_size)];
71 dims.extend(node.shape.dims().iter().copied());
72 let s = Shape::from_dims(&dims, node.shape.dtype());
73 let id = out.input(name.clone(), s);
74 batched.insert(id);
75 id
76 } else {
77 out.input(name.clone(), node.shape.clone())
78 }
79 }
80 Op::Param { name } => {
81 out.param(name.clone(), node.shape.clone())
84 }
85 Op::Constant { data } => out.add_node(
86 Op::Constant { data: data.clone() },
87 vec![],
88 node.shape.clone(),
89 ),
90 _ => {
91 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
92 let any_batched = new_inputs.iter().any(|i| batched.contains(i));
93 if !any_batched {
94 out.add_node(node.op.clone(), new_inputs, node.shape.clone())
97 } else {
98 let id = vmap_op(node, &new_inputs, &mut out, &mut batched, batch_size);
99 batched.insert(id);
100 id
101 }
102 }
103 };
104 id_map.insert(node.id, new_id);
105 }
106
107 let new_outputs: Vec<NodeId> = forward.outputs.iter().map(|o| id_map[o]).collect();
108 out.set_outputs(new_outputs);
109 out
110}
111
112fn vmap_op(
114 node: &Node,
115 new_inputs: &[NodeId],
116 out: &mut Graph,
117 batched: &mut HashSet<NodeId>,
118 batch_size: usize,
119) -> NodeId {
120 let orig_shape = &node.shape;
121 let dtype = orig_shape.dtype();
122
123 let batched_shape = || -> Shape {
125 let mut dims: Vec<Dim> = vec![Dim::Static(batch_size)];
126 dims.extend(orig_shape.dims().iter().copied());
127 Shape::from_dims(&dims, dtype)
128 };
129
130 match &node.op {
131 Op::Binary(_) | Op::Activation(_) | Op::Where | Op::Compare(_) | Op::Cast { .. } => {
133 let lifted: Vec<NodeId> = new_inputs
134 .iter()
135 .map(|&id| lift_to_batched(out, id, batched, batch_size))
136 .collect();
137 for &id in &lifted {
138 batched.insert(id);
139 }
140 out.add_node(node.op.clone(), lifted, batched_shape())
141 }
142
143 Op::Reshape { new_shape } => {
145 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
146 batched.insert(lifted);
147 let mut bsh: Vec<i64> = vec![batch_size as i64];
148 bsh.extend(new_shape.iter().copied());
149 out.add_node(
150 Op::Reshape { new_shape: bsh },
151 vec![lifted],
152 batched_shape(),
153 )
154 }
155
156 Op::Transpose { perm } => {
158 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
159 batched.insert(lifted);
160 let mut new_perm: Vec<usize> = vec![0];
161 new_perm.extend(perm.iter().map(|p| p + 1));
162 out.add_node(
163 Op::Transpose { perm: new_perm },
164 vec![lifted],
165 batched_shape(),
166 )
167 }
168
169 Op::Expand { target_shape } => {
171 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
172 batched.insert(lifted);
173 let mut bsh: Vec<i64> = vec![batch_size as i64];
174 bsh.extend(target_shape.iter().copied());
175 out.add_node(
176 Op::Expand { target_shape: bsh },
177 vec![lifted],
178 batched_shape(),
179 )
180 }
181
182 Op::Reduce { op, axes, keep_dim } => {
184 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
185 batched.insert(lifted);
186 let new_axes: Vec<usize> = axes.iter().map(|a| a + 1).collect();
187 out.add_node(
188 Op::Reduce {
189 op: *op,
190 axes: new_axes,
191 keep_dim: *keep_dim,
192 },
193 vec![lifted],
194 batched_shape(),
195 )
196 }
197
198 Op::MatMul => {
203 let a = lift_to_batched(out, new_inputs[0], batched, batch_size);
204 let b = lift_to_batched(out, new_inputs[1], batched, batch_size);
205 batched.insert(a);
206 batched.insert(b);
207 out.matmul(a, b, batched_shape())
208 }
209
210 Op::DenseSolve => {
213 let a = lift_to_batched(out, new_inputs[0], batched, batch_size);
214 let b = lift_to_batched(out, new_inputs[1], batched, batch_size);
215 batched.insert(a);
216 batched.insert(b);
217 out.batched_dense_solve(a, b, batched_shape())
218 }
219
220 Op::Scan {
235 body,
236 length,
237 save_trajectory,
238 num_xs,
239 num_checkpoints: _,
240 num_bcast,
241 } => {
242 let init_b = lift_to_batched(out, new_inputs[0], batched, batch_size);
244 batched.insert(init_b);
245
246 let mut bcasts_b: Vec<NodeId> = Vec::with_capacity(*num_bcast as usize);
250 for i in 0..*num_bcast as usize {
251 let bcast_in = new_inputs[1 + i];
252 let lifted = lift_to_batched(out, bcast_in, batched, batch_size);
253 batched.insert(lifted);
254 bcasts_b.push(lifted);
255 }
256
257 let xs_base = 1 + *num_bcast as usize;
260 let mut xs_t: Vec<NodeId> = Vec::with_capacity(*num_xs as usize);
261 for i in 0..*num_xs as usize {
262 let xs_in = new_inputs[xs_base + i];
263 let lifted = lift_to_batched(out, xs_in, batched, batch_size);
264 batched.insert(lifted);
265 let xs_shape = out.node(lifted).shape.clone();
266 let r = xs_shape.rank();
267 let mut perm: Vec<usize> = vec![1, 0];
268 for k in 2..r {
269 perm.push(k);
270 }
271 let mut new_dims: Vec<Dim> = xs_shape.dims().to_vec();
272 new_dims.swap(0, 1);
273 let new_shape = Shape::from_dims(&new_dims, xs_shape.dtype());
274 let transposed = out.add_node(Op::Transpose { perm }, vec![lifted], new_shape);
275 batched.insert(transposed);
276 xs_t.push(transposed);
277 }
278
279 let body_input_names_owned: Vec<String> = body
283 .nodes()
284 .iter()
285 .filter_map(|n| match &n.op {
286 Op::Input { name } => Some(name.clone()),
287 _ => None,
288 })
289 .collect();
290 let body_input_names: Vec<&str> =
291 body_input_names_owned.iter().map(|s| s.as_str()).collect();
292 let body_b = vmap(body, &body_input_names, batch_size);
293
294 let dtype = orig_shape.dtype();
298 let inner_out_shape: Shape = if *save_trajectory {
301 let mut dims: Vec<Dim> = vec![orig_shape.dim(0)];
303 dims.push(Dim::Static(batch_size));
304 for i in 1..orig_shape.rank() {
305 dims.push(orig_shape.dim(i));
306 }
307 Shape::from_dims(&dims, dtype)
308 } else {
309 let mut dims: Vec<Dim> = vec![Dim::Static(batch_size)];
311 for i in 0..orig_shape.rank() {
312 dims.push(orig_shape.dim(i));
313 }
314 Shape::from_dims(&dims, dtype)
315 };
316
317 let mut inner_inputs = vec![init_b];
319 inner_inputs.extend_from_slice(&bcasts_b);
320 inner_inputs.extend_from_slice(&xs_t);
321
322 let inner_id = out.add_node(
323 Op::Scan {
324 body: Box::new(body_b),
325 length: *length,
326 save_trajectory: *save_trajectory,
327 num_xs: *num_xs,
328 num_checkpoints: 0,
329 num_bcast: *num_bcast,
330 },
331 inner_inputs,
332 inner_out_shape,
333 );
334
335 if *save_trajectory {
336 let r = orig_shape.rank() + 1; let mut perm: Vec<usize> = vec![1, 0];
339 for k in 2..r {
340 perm.push(k);
341 }
342 out.add_node(Op::Transpose { perm }, vec![inner_id], batched_shape())
343 } else {
344 inner_id
345 }
346 }
347
348 Op::Narrow { axis, start, len } => {
350 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
351 batched.insert(lifted);
352 out.add_node(
353 Op::Narrow {
354 axis: axis + 1,
355 start: *start,
356 len: *len,
357 },
358 vec![lifted],
359 batched_shape(),
360 )
361 }
362
363 Op::Concat { axis } => {
365 let lifted: Vec<NodeId> = new_inputs
366 .iter()
367 .map(|&id| lift_to_batched(out, id, batched, batch_size))
368 .collect();
369 for &id in &lifted {
370 batched.insert(id);
371 }
372 out.add_node(Op::Concat { axis: axis + 1 }, lifted, batched_shape())
373 }
374
375 Op::Softmax { axis } => {
377 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
378 batched.insert(lifted);
379 out.add_node(
380 Op::Softmax { axis: *axis + 1 },
381 vec![lifted],
382 batched_shape(),
383 )
384 }
385
386 Op::Cumsum { axis, exclusive } => {
388 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
389 batched.insert(lifted);
390 out.add_node(
391 Op::Cumsum {
392 axis: *axis + 1,
393 exclusive: *exclusive,
394 },
395 vec![lifted],
396 batched_shape(),
397 )
398 }
399
400 Op::LayerNorm { axis, eps } => {
406 let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
407 batched.insert(x);
408 out.add_node(
409 Op::LayerNorm {
410 axis: *axis + 1,
411 eps: *eps,
412 },
413 vec![x, new_inputs[1], new_inputs[2]],
414 batched_shape(),
415 )
416 }
417
418 Op::RmsNorm { axis, eps } => {
420 let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
421 batched.insert(x);
422 out.add_node(
423 Op::RmsNorm {
424 axis: *axis + 1,
425 eps: *eps,
426 },
427 vec![x, new_inputs[1], new_inputs[2]],
428 batched_shape(),
429 )
430 }
431
432 Op::Gather { axis } => {
436 let table = lift_to_batched(out, new_inputs[0], batched, batch_size);
437 let indices = lift_to_batched(out, new_inputs[1], batched, batch_size);
438 batched.insert(table);
439 batched.insert(indices);
440 out.add_node(
441 Op::Gather { axis: axis + 1 },
442 vec![table, indices],
443 batched_shape(),
444 )
445 }
446
447 Op::ScatterAdd => {
454 let updates = lift_to_batched(out, new_inputs[0], batched, batch_size);
455 let indices = lift_to_batched(out, new_inputs[1], batched, batch_size);
456 batched.insert(updates);
457 batched.insert(indices);
458 out.add_node(Op::ScatterAdd, vec![updates, indices], batched_shape())
459 }
460
461 Op::ElementwiseRegion { .. } => {
470 let lifted: Vec<NodeId> = new_inputs
471 .iter()
472 .map(|&id| lift_to_batched(out, id, batched, batch_size))
473 .collect();
474 for &id in &lifted {
475 batched.insert(id);
476 }
477 out.add_node(node.op.clone(), lifted, batched_shape())
478 }
479
480 Op::DotGeneral {
482 lhs_contracting,
483 rhs_contracting,
484 lhs_batch,
485 rhs_batch,
486 } => {
487 let lhs = lift_to_batched(out, new_inputs[0], batched, batch_size);
488 let rhs = lift_to_batched(out, new_inputs[1], batched, batch_size);
489 batched.insert(lhs);
490 batched.insert(rhs);
491 let mut new_lhs_b: Vec<usize> = vec![0];
495 new_lhs_b.extend(lhs_batch.iter().map(|i| i + 1));
496 let mut new_rhs_b: Vec<usize> = vec![0];
497 new_rhs_b.extend(rhs_batch.iter().map(|i| i + 1));
498 out.add_node(
499 Op::DotGeneral {
500 lhs_contracting: lhs_contracting.iter().map(|i| i + 1).collect(),
501 rhs_contracting: rhs_contracting.iter().map(|i| i + 1).collect(),
502 lhs_batch: new_lhs_b,
503 rhs_batch: new_rhs_b,
504 },
505 vec![lhs, rhs],
506 batched_shape(),
507 )
508 }
509
510 Op::ReluBackward | Op::ActivationBackward { .. } => {
515 let lifted: Vec<NodeId> = new_inputs
516 .iter()
517 .map(|&id| lift_to_batched(out, id, batched, batch_size))
518 .collect();
519 for &id in &lifted {
520 batched.insert(id);
521 }
522 out.add_node(node.op.clone(), lifted, batched_shape())
523 }
524
525 Op::ScanBackward {
529 body_vjp,
530 length,
531 save_trajectory,
532 num_xs,
533 num_checkpoints: _,
534 forward_body: _,
535 } => {
536 let init_b = lift_to_batched(out, new_inputs[0], batched, batch_size);
538 batched.insert(init_b);
539
540 let traj_lifted = lift_to_batched(out, new_inputs[1], batched, batch_size);
544 batched.insert(traj_lifted);
545 let traj_t = transpose_swap_01(out, traj_lifted);
546 batched.insert(traj_t);
547
548 let up_lifted = lift_to_batched(out, new_inputs[2], batched, batch_size);
554 batched.insert(up_lifted);
555 let up_t = if *save_trajectory {
556 let id = transpose_swap_01(out, up_lifted);
557 batched.insert(id);
558 id
559 } else {
560 up_lifted
561 };
562
563 let mut xs_t: Vec<NodeId> = Vec::with_capacity(*num_xs as usize);
565 for i in 0..*num_xs as usize {
566 let xs_in = new_inputs[3 + i];
567 let lifted = lift_to_batched(out, xs_in, batched, batch_size);
568 batched.insert(lifted);
569 let t = transpose_swap_01(out, lifted);
570 batched.insert(t);
571 xs_t.push(t);
572 }
573
574 let body_input_names_owned: Vec<String> = body_vjp
577 .nodes()
578 .iter()
579 .filter_map(|n| match &n.op {
580 Op::Input { name } => Some(name.clone()),
581 _ => None,
582 })
583 .collect();
584 let body_input_names: Vec<&str> =
585 body_input_names_owned.iter().map(|s| s.as_str()).collect();
586 let body_vjp_b = vmap(body_vjp, &body_input_names, batch_size);
587
588 let mut dinit_dims: Vec<Dim> = vec![Dim::Static(batch_size)];
590 for i in 0..orig_shape.rank() {
591 dinit_dims.push(orig_shape.dim(i));
592 }
593 let dinit_shape = Shape::from_dims(&dinit_dims, dtype);
594
595 let mut inner_inputs = vec![init_b, traj_t, up_t];
596 inner_inputs.extend_from_slice(&xs_t);
597
598 out.scan_backward(
599 init_b,
600 traj_t,
601 up_t,
602 &xs_t,
603 body_vjp_b,
604 *length,
605 *save_trajectory,
606 dinit_shape,
607 )
608 }
609
610 Op::ScanBackwardXs {
614 body_vjp,
615 length,
616 save_trajectory,
617 num_xs,
618 xs_idx,
619 num_checkpoints: _,
620 forward_body: _,
621 } => {
622 let init_b = lift_to_batched(out, new_inputs[0], batched, batch_size);
623 batched.insert(init_b);
624 let traj_lifted = lift_to_batched(out, new_inputs[1], batched, batch_size);
625 batched.insert(traj_lifted);
626 let traj_t = transpose_swap_01(out, traj_lifted);
627 batched.insert(traj_t);
628 let up_lifted = lift_to_batched(out, new_inputs[2], batched, batch_size);
629 batched.insert(up_lifted);
630 let up_t = if *save_trajectory {
631 let id = transpose_swap_01(out, up_lifted);
632 batched.insert(id);
633 id
634 } else {
635 up_lifted
636 };
637
638 let mut xs_t: Vec<NodeId> = Vec::with_capacity(*num_xs as usize);
639 for i in 0..*num_xs as usize {
640 let xs_in = new_inputs[3 + i];
641 let lifted = lift_to_batched(out, xs_in, batched, batch_size);
642 batched.insert(lifted);
643 let t = transpose_swap_01(out, lifted);
644 batched.insert(t);
645 xs_t.push(t);
646 }
647
648 let body_input_names_owned: Vec<String> = body_vjp
649 .nodes()
650 .iter()
651 .filter_map(|n| match &n.op {
652 Op::Input { name } => Some(name.clone()),
653 _ => None,
654 })
655 .collect();
656 let body_input_names: Vec<&str> =
657 body_input_names_owned.iter().map(|s| s.as_str()).collect();
658 let body_vjp_b = vmap(body_vjp, &body_input_names, batch_size);
659
660 let mut inner_dims: Vec<Dim> = vec![orig_shape.dim(0)];
663 inner_dims.push(Dim::Static(batch_size));
664 for i in 1..orig_shape.rank() {
665 inner_dims.push(orig_shape.dim(i));
666 }
667 let inner_shape = Shape::from_dims(&inner_dims, dtype);
668
669 let inner_id = out.scan_backward_xs(
670 init_b,
671 traj_t,
672 up_t,
673 &xs_t,
674 body_vjp_b,
675 *length,
676 *save_trajectory,
677 *xs_idx,
678 inner_shape,
679 );
680
681 transpose_swap_01(out, inner_id)
683 }
684
685 Op::Quantize {
687 axis,
688 scales,
689 zero_points,
690 } => {
691 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
692 batched.insert(lifted);
693 let new_axis = axis.map(|a| a + 1);
694 out.add_node(
695 Op::Quantize {
696 axis: new_axis,
697 scales: scales.clone(),
698 zero_points: zero_points.clone(),
699 },
700 vec![lifted],
701 batched_shape(),
702 )
703 }
704 Op::Dequantize {
705 axis,
706 scales,
707 zero_points,
708 } => {
709 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
710 batched.insert(lifted);
711 let new_axis = axis.map(|a| a + 1);
712 out.add_node(
713 Op::Dequantize {
714 axis: new_axis,
715 scales: scales.clone(),
716 zero_points: zero_points.clone(),
717 },
718 vec![lifted],
719 batched_shape(),
720 )
721 }
722 Op::FakeQuantize {
723 bits,
724 axis,
725 ste,
726 scale_mode,
727 } => {
728 let lifted: Vec<NodeId> = new_inputs
729 .iter()
730 .map(|&id| lift_to_batched(out, id, batched, batch_size))
731 .collect();
732 for &id in &lifted {
733 batched.insert(id);
734 }
735 let new_axis = axis.map(|a| a + 1);
736 out.add_node(
737 Op::FakeQuantize {
738 bits: *bits,
739 axis: new_axis,
740 ste: *ste,
741 scale_mode: *scale_mode,
742 },
743 lifted,
744 batched_shape(),
745 )
746 }
747 Op::FakeQuantizeBackward { bits, axis, ste } => {
748 let lifted: Vec<NodeId> = new_inputs
749 .iter()
750 .map(|&id| lift_to_batched(out, id, batched, batch_size))
751 .collect();
752 for &id in &lifted {
753 batched.insert(id);
754 }
755 let new_axis = axis.map(|a| a + 1);
756 out.add_node(
757 Op::FakeQuantizeBackward {
758 bits: *bits,
759 axis: new_axis,
760 ste: *ste,
761 },
762 lifted,
763 batched_shape(),
764 )
765 }
766 Op::FakeQuantizeLSQ { bits, axis } => {
767 let lifted: Vec<NodeId> = new_inputs
768 .iter()
769 .map(|&id| lift_to_batched(out, id, batched, batch_size))
770 .collect();
771 for &id in &lifted {
772 batched.insert(id);
773 }
774 out.add_node(
775 Op::FakeQuantizeLSQ {
776 bits: *bits,
777 axis: axis.map(|a| a + 1),
778 },
779 lifted,
780 batched_shape(),
781 )
782 }
783 Op::FakeQuantizeLSQBackwardX { bits, axis } => {
784 let lifted: Vec<NodeId> = new_inputs
785 .iter()
786 .map(|&id| lift_to_batched(out, id, batched, batch_size))
787 .collect();
788 for &id in &lifted {
789 batched.insert(id);
790 }
791 out.add_node(
792 Op::FakeQuantizeLSQBackwardX {
793 bits: *bits,
794 axis: axis.map(|a| a + 1),
795 },
796 lifted,
797 batched_shape(),
798 )
799 }
800 Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
801 let lifted: Vec<NodeId> = new_inputs
802 .iter()
803 .map(|&id| lift_to_batched(out, id, batched, batch_size))
804 .collect();
805 for &id in &lifted {
806 batched.insert(id);
807 }
808 out.add_node(
809 Op::FakeQuantizeLSQBackwardScale {
810 bits: *bits,
811 axis: axis.map(|a| a + 1),
812 },
813 lifted,
814 batched_shape(),
815 )
816 }
817
818 Op::LayerNormBackwardInput { axis, eps } => {
820 let lifted: Vec<NodeId> = new_inputs
821 .iter()
822 .map(|&id| lift_to_batched(out, id, batched, batch_size))
823 .collect();
824 for &id in &lifted {
825 batched.insert(id);
826 }
827 out.add_node(
828 Op::LayerNormBackwardInput {
829 axis: axis + 1,
830 eps: *eps,
831 },
832 lifted,
833 batched_shape(),
834 )
835 }
836 Op::LayerNormBackwardGamma { axis, eps } => {
837 let lifted: Vec<NodeId> = new_inputs
838 .iter()
839 .map(|&id| lift_to_batched(out, id, batched, batch_size))
840 .collect();
841 for &id in &lifted {
842 batched.insert(id);
843 }
844 out.add_node(
845 Op::LayerNormBackwardGamma {
846 axis: axis + 1,
847 eps: *eps,
848 },
849 lifted,
850 batched_shape(),
851 )
852 }
853
854 Op::TopK { k } => {
857 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
858 batched.insert(lifted);
859 out.add_node(Op::TopK { k: *k }, vec![lifted], batched_shape())
860 }
861 Op::Sample {
862 top_k,
863 top_p,
864 temperature,
865 seed,
866 } => {
867 let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
868 batched.insert(lifted);
869 out.add_node(
870 Op::Sample {
871 top_k: *top_k,
872 top_p: *top_p,
873 temperature: *temperature,
874 seed: *seed,
875 },
876 vec![lifted],
877 batched_shape(),
878 )
879 }
880
881 Op::LoraMatMul { scale } => {
883 let lifted: Vec<NodeId> = new_inputs
884 .iter()
885 .map(|&id| lift_to_batched(out, id, batched, batch_size))
886 .collect();
887 for &id in &lifted {
888 batched.insert(id);
889 }
890 out.add_node(Op::LoraMatMul { scale: *scale }, lifted, batched_shape())
891 }
892
893 Op::Conv {
899 kernel_size,
900 stride,
901 padding,
902 dilation,
903 groups,
904 } => {
905 let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
906 let w = new_inputs[1]; batched.insert(x);
908 let x_shape = out.node(x).shape.clone();
911 let r = x_shape.rank();
912 assert!(r == 5, "vmap Conv: expected 5-D after lift, got {r}");
913 let n_orig = match x_shape.dim(1) {
914 Dim::Static(n) => n,
915 _ => panic!("dynamic N"),
916 };
917 let bn = batch_size * n_orig;
918 let inner_dims_static: Vec<i64> = (2..r)
919 .map(|i| match x_shape.dim(i) {
920 Dim::Static(d) => d as i64,
921 _ => -1,
922 })
923 .collect();
924 let mut flat_dims = vec![bn as i64];
925 flat_dims.extend(inner_dims_static.iter().copied());
926 let mut flat_dim_objs = vec![Dim::Static(bn)];
927 for i in 2..r {
928 flat_dim_objs.push(x_shape.dim(i));
929 }
930 let flat_shape = Shape::from_dims(&flat_dim_objs, x_shape.dtype());
931 let x_flat = out.add_node(
932 Op::Reshape {
933 new_shape: flat_dims,
934 },
935 vec![x],
936 flat_shape,
937 );
938 let mut conv_out_dims = vec![Dim::Static(bn)];
940 for i in 1..orig_shape.rank() {
941 conv_out_dims.push(orig_shape.dim(i));
942 }
943 let conv_out_shape = Shape::from_dims(&conv_out_dims, dtype);
944 let conv_out = out.add_node(
945 Op::Conv {
946 kernel_size: kernel_size.clone(),
947 stride: stride.clone(),
948 padding: padding.clone(),
949 dilation: dilation.clone(),
950 groups: *groups,
951 },
952 vec![x_flat, w],
953 conv_out_shape,
954 );
955 let mut final_dims_static: Vec<i64> = vec![batch_size as i64];
957 for i in 0..orig_shape.rank() {
958 final_dims_static.push(match orig_shape.dim(i) {
959 Dim::Static(d) => d as i64,
960 _ => -1,
961 });
962 }
963 out.add_node(
964 Op::Reshape {
965 new_shape: final_dims_static,
966 },
967 vec![conv_out],
968 batched_shape(),
969 )
970 }
971 Op::Pool {
972 kind,
973 kernel_size,
974 stride,
975 padding,
976 } => {
977 let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
978 batched.insert(x);
979 let x_shape = out.node(x).shape.clone();
980 let r = x_shape.rank();
981 assert!(r == 5, "vmap Pool: expected 5-D after lift, got {r}");
982 let n_orig = match x_shape.dim(1) {
983 Dim::Static(n) => n,
984 _ => panic!("dynamic N"),
985 };
986 let bn = batch_size * n_orig;
987 let mut flat_dims = vec![bn as i64];
988 for i in 2..r {
989 flat_dims.push(match x_shape.dim(i) {
990 Dim::Static(d) => d as i64,
991 _ => -1,
992 });
993 }
994 let mut flat_dim_objs = vec![Dim::Static(bn)];
995 for i in 2..r {
996 flat_dim_objs.push(x_shape.dim(i));
997 }
998 let flat_shape = Shape::from_dims(&flat_dim_objs, x_shape.dtype());
999 let x_flat = out.add_node(
1000 Op::Reshape {
1001 new_shape: flat_dims,
1002 },
1003 vec![x],
1004 flat_shape,
1005 );
1006 let mut pool_dims = vec![Dim::Static(bn)];
1007 for i in 1..orig_shape.rank() {
1008 pool_dims.push(orig_shape.dim(i));
1009 }
1010 let pool_out_shape = Shape::from_dims(&pool_dims, dtype);
1011 let pool_out = out.add_node(
1012 Op::Pool {
1013 kind: *kind,
1014 kernel_size: kernel_size.clone(),
1015 stride: stride.clone(),
1016 padding: padding.clone(),
1017 },
1018 vec![x_flat],
1019 pool_out_shape,
1020 );
1021 let mut final_dims_static: Vec<i64> = vec![batch_size as i64];
1022 for i in 0..orig_shape.rank() {
1023 final_dims_static.push(match orig_shape.dim(i) {
1024 Dim::Static(d) => d as i64,
1025 _ => -1,
1026 });
1027 }
1028 out.add_node(
1029 Op::Reshape {
1030 new_shape: final_dims_static,
1031 },
1032 vec![pool_out],
1033 batched_shape(),
1034 )
1035 }
1036
1037 Op::Attention { .. }
1042 | Op::FusedAttentionBlock { .. }
1043 | Op::FusedTransformerLayer { .. }
1044 | Op::Rope { .. } => panic!(
1045 "vmap: {:?} kernels expect a fixed input rank — extra batch \
1046 axis would need either decomposition (use rlx-opt unfuse \
1047 passes first) or a kernel rewrite. Skipped in MVP.",
1048 node.op,
1049 ),
1050
1051 Op::Conv2dBackwardInput {
1058 kernel_size,
1059 stride,
1060 padding,
1061 dilation,
1062 groups,
1063 } => {
1064 let dy = lift_to_batched(out, new_inputs[0], batched, batch_size);
1065 let w = new_inputs[1];
1066 batched.insert(dy);
1067 let dy_shape = out.node(dy).shape.clone();
1068 assert_eq!(
1069 dy_shape.rank(),
1070 5,
1071 "vmap Conv2dBackwardInput: expected 5-D dy"
1072 );
1073 let n_orig = match dy_shape.dim(1) {
1074 Dim::Static(n) => n,
1075 _ => panic!("dynamic N"),
1076 };
1077 let bn = batch_size * n_orig;
1078 let mut flat_dims_static: Vec<i64> = vec![bn as i64];
1079 for i in 2..dy_shape.rank() {
1080 flat_dims_static.push(match dy_shape.dim(i) {
1081 Dim::Static(d) => d as i64,
1082 _ => -1,
1083 });
1084 }
1085 let mut flat_dim_objs = vec![Dim::Static(bn)];
1086 for i in 2..dy_shape.rank() {
1087 flat_dim_objs.push(dy_shape.dim(i));
1088 }
1089 let dy_flat = out.add_node(
1090 Op::Reshape {
1091 new_shape: flat_dims_static,
1092 },
1093 vec![dy],
1094 Shape::from_dims(&flat_dim_objs, dy_shape.dtype()),
1095 );
1096 let mut out_flat_dim_objs = vec![Dim::Static(bn)];
1098 for i in 1..orig_shape.rank() {
1099 out_flat_dim_objs.push(orig_shape.dim(i));
1100 }
1101 let out_flat_shape = Shape::from_dims(&out_flat_dim_objs, dtype);
1102 let out_flat = out.add_node(
1103 Op::Conv2dBackwardInput {
1104 kernel_size: kernel_size.clone(),
1105 stride: stride.clone(),
1106 padding: padding.clone(),
1107 dilation: dilation.clone(),
1108 groups: *groups,
1109 },
1110 vec![dy_flat, w],
1111 out_flat_shape,
1112 );
1113 let mut final_dims: Vec<i64> = vec![batch_size as i64];
1115 for i in 0..orig_shape.rank() {
1116 final_dims.push(match orig_shape.dim(i) {
1117 Dim::Static(d) => d as i64,
1118 _ => -1,
1119 });
1120 }
1121 out.add_node(
1122 Op::Reshape {
1123 new_shape: final_dims,
1124 },
1125 vec![out_flat],
1126 batched_shape(),
1127 )
1128 }
1129
1130 Op::MaxPool2dBackward {
1134 kernel_size,
1135 stride,
1136 padding,
1137 } => {
1138 let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
1139 let dy = lift_to_batched(out, new_inputs[1], batched, batch_size);
1140 batched.insert(x);
1141 batched.insert(dy);
1142 let x_shape = out.node(x).shape.clone();
1143 assert_eq!(x_shape.rank(), 5, "vmap MaxPool2dBackward: expected 5-D x");
1144 let n_orig = match x_shape.dim(1) {
1145 Dim::Static(n) => n,
1146 _ => panic!("dynamic N"),
1147 };
1148 let bn = batch_size * n_orig;
1149 let flatten = |out: &mut Graph, id: NodeId| -> NodeId {
1150 let s = out.node(id).shape.clone();
1151 let mut flat_objs = vec![Dim::Static(bn)];
1152 for i in 2..s.rank() {
1153 flat_objs.push(s.dim(i));
1154 }
1155 let flat_shape = Shape::from_dims(&flat_objs, s.dtype());
1156 let mut flat_static: Vec<i64> = vec![bn as i64];
1157 for i in 2..s.rank() {
1158 flat_static.push(match s.dim(i) {
1159 Dim::Static(d) => d as i64,
1160 _ => -1,
1161 });
1162 }
1163 out.add_node(
1164 Op::Reshape {
1165 new_shape: flat_static,
1166 },
1167 vec![id],
1168 flat_shape,
1169 )
1170 };
1171 let x_flat = flatten(out, x);
1172 let dy_flat = flatten(out, dy);
1173 let mut out_flat_objs = vec![Dim::Static(bn)];
1174 for i in 1..orig_shape.rank() {
1175 out_flat_objs.push(orig_shape.dim(i));
1176 }
1177 let out_flat_shape = Shape::from_dims(&out_flat_objs, dtype);
1178 let pool_out = out.add_node(
1179 Op::MaxPool2dBackward {
1180 kernel_size: kernel_size.clone(),
1181 stride: stride.clone(),
1182 padding: padding.clone(),
1183 },
1184 vec![x_flat, dy_flat],
1185 out_flat_shape,
1186 );
1187 let mut final_dims: Vec<i64> = vec![batch_size as i64];
1188 for i in 0..orig_shape.rank() {
1189 final_dims.push(match orig_shape.dim(i) {
1190 Dim::Static(d) => d as i64,
1191 _ => -1,
1192 });
1193 }
1194 out.add_node(
1195 Op::Reshape {
1196 new_shape: final_dims,
1197 },
1198 vec![pool_out],
1199 batched_shape(),
1200 )
1201 }
1202
1203 Op::Conv2dBackwardWeight { .. } => panic!(
1204 "vmap: Conv2dBackwardWeight: weight gradient is summed across \
1205 samples — vmap-batching gives a B-stack of independent dWs. \
1206 Reshape-trick doesn't apply since the output isn't naturally \
1207 N-leading. Add a per-batch dW pattern when needed.",
1208 ),
1209
1210 Op::SelectiveScan { .. }
1211 | Op::GroupedMatMul
1212 | Op::QMatMul { .. }
1213 | Op::QConv2d { .. }
1214 | Op::DequantMatMul { .. } => panic!(
1215 "vmap: {:?} has its own internal batch handling; \
1216 the right rule depends on whether the user wants \
1217 nested batching or to fold into the existing batch \
1218 dim. Add a rule when a real workload demands it.",
1219 node.op,
1220 ),
1221
1222 Op::DequantGroupedMatMul { scheme } => {
1224 let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
1225 let idx = lift_to_batched(out, new_inputs[2], batched, batch_size);
1226 let w = new_inputs[1];
1227 batched.insert(x);
1228 batched.insert(idx);
1229 let x_shape = out.node(x).shape.clone();
1230 assert_eq!(
1231 x_shape.rank(),
1232 3,
1233 "vmap DequantGroupedMatMul: expected 3-D x"
1234 );
1235 let m_orig = match x_shape.dim(1) {
1236 Dim::Static(v) => v,
1237 _ => panic!("dynamic M"),
1238 };
1239 let k = match x_shape.dim(2) {
1240 Dim::Static(v) => v as i64,
1241 _ => -1,
1242 };
1243 let bm = batch_size * m_orig;
1244 let n = match orig_shape.dim(orig_shape.rank() - 1) {
1245 Dim::Static(v) => v as i64,
1246 _ => -1,
1247 };
1248 let x_flat = out.add_node(
1249 Op::Reshape {
1250 new_shape: vec![bm as i64, k],
1251 },
1252 vec![x],
1253 Shape::from_dims(&[Dim::Static(bm), x_shape.dim(2)], orig_shape.dtype()),
1254 );
1255 let idx_flat = out.add_node(
1256 Op::Reshape {
1257 new_shape: vec![bm as i64],
1258 },
1259 vec![idx],
1260 Shape::from_dims(&[Dim::Static(bm)], orig_shape.dtype()),
1261 );
1262 let y_flat = out.add_node(
1263 Op::DequantGroupedMatMul { scheme: *scheme },
1264 vec![x_flat, w, idx_flat],
1265 Shape::from_dims(
1266 &[Dim::Static(bm), orig_shape.dim(orig_shape.rank() - 1)],
1267 orig_shape.dtype(),
1268 ),
1269 );
1270 let mut final_dims: Vec<i64> = vec![batch_size as i64, m_orig as i64];
1271 final_dims.push(n);
1272 out.add_node(
1273 Op::Reshape {
1274 new_shape: final_dims,
1275 },
1276 vec![y_flat],
1277 batched_shape(),
1278 )
1279 }
1280
1281 Op::DequantMoEWeights { .. } => panic!(
1282 "vmap: DequantMoEWeights is a weight materialization helper; \
1283 vmap the downstream GroupedMatMul / DequantGroupedMatMul instead.",
1284 ),
1285
1286 Op::FusedSwiGLU { .. }
1287 | Op::FusedMatMulBiasAct { .. }
1288 | Op::FusedResidualLN { .. }
1289 | Op::FusedResidualRmsNorm { .. } => {
1290 panic!(
1291 "vmap: {:?} is fused — decompose first via \
1292 `rlx_fusion::UnfuseElementwiseRegions` (or \
1293 `rlx_fusion::unfuse_fused_for_autodiff`) so the simpler \
1294 ops get vmap'd individually.",
1295 node.op,
1296 )
1297 }
1298
1299 Op::SoftmaxCrossEntropyWithLogits | Op::SoftmaxCrossEntropyBackward => panic!(
1300 "vmap: SoftmaxCrossEntropy* expect 2-D logits; lifting to \
1301 3-D would need a kernel change. Workaround: reshape \
1302 logits to 2-D before the op and back after.",
1303 ),
1304
1305 Op::Custom { name, .. } => {
1306 let ext = rlx_ir::lookup_op(name)
1311 .unwrap_or_else(|| panic!("vmap: Op::Custom('{name}') not registered"));
1312 let is_batched: Vec<bool> = new_inputs.iter().map(|i| batched.contains(i)).collect();
1313 let mut ctx = rlx_ir::VmapContext {
1314 lifted_inputs: new_inputs,
1315 is_batched: &is_batched,
1316 batch_size,
1317 out,
1318 };
1319 match ext.vmap(node, &mut ctx) {
1320 Some(id) => id,
1321 None => panic!(
1322 "vmap: Op::Custom('{name}') has no vmap rule registered. \
1323 Override `OpExtension::vmap` on the impl to add one."
1324 ),
1325 }
1326 }
1327
1328 Op::CustomFn {
1335 fwd_body,
1336 vjp_body,
1337 jvp_body,
1338 num_inputs,
1339 } => {
1340 let mut lifted_inputs: Vec<NodeId> = Vec::with_capacity(*num_inputs as usize);
1342 for &raw in new_inputs.iter() {
1343 let lifted = lift_to_batched(out, raw, batched, batch_size);
1344 batched.insert(lifted);
1345 lifted_inputs.push(lifted);
1346 }
1347
1348 let vmap_body = |body: &Graph| -> Graph {
1349 let names_owned: Vec<String> = body
1350 .nodes()
1351 .iter()
1352 .filter_map(|n| match &n.op {
1353 Op::Input { name } => Some(name.clone()),
1354 _ => None,
1355 })
1356 .collect();
1357 let names: Vec<&str> = names_owned.iter().map(|s| s.as_str()).collect();
1358 vmap(body, &names, batch_size)
1359 };
1360
1361 let fwd_b = vmap_body(fwd_body);
1362 let vjp_b = vjp_body.as_ref().map(|g| vmap_body(g));
1363 let jvp_b = jvp_body.as_ref().map(|g| vmap_body(g));
1364
1365 let mut out_dims: Vec<Dim> = vec![Dim::Static(batch_size)];
1367 for i in 0..orig_shape.rank() {
1368 out_dims.push(orig_shape.dim(i));
1369 }
1370 let out_shape = Shape::from_dims(&out_dims, orig_shape.dtype());
1371
1372 let id = out.add_node(
1373 Op::CustomFn {
1374 fwd_body: Box::new(fwd_b),
1375 vjp_body: vjp_b.map(Box::new),
1376 jvp_body: jvp_b.map(Box::new),
1377 num_inputs: *num_inputs,
1378 },
1379 lifted_inputs,
1380 out_shape,
1381 );
1382 batched.insert(id);
1383 id
1384 }
1385
1386 other => panic!(
1387 "vmap: no rule for op {:?}. Add a per-op rule in vmap.rs.",
1388 other,
1389 ),
1390 }
1391}
1392
1393fn transpose_swap_01(out: &mut Graph, id: NodeId) -> NodeId {
1399 let s = out.node(id).shape.clone();
1400 let r = s.rank();
1401 debug_assert!(r >= 2, "transpose_swap_01 needs rank ≥ 2");
1402 let mut perm: Vec<usize> = vec![1, 0];
1403 for i in 2..r {
1404 perm.push(i);
1405 }
1406 let mut new_dims: Vec<Dim> = s.dims().to_vec();
1407 new_dims.swap(0, 1);
1408 let new_shape = Shape::from_dims(&new_dims, s.dtype());
1409 out.add_node(Op::Transpose { perm }, vec![id], new_shape)
1410}
1411
1412fn lift_to_batched(
1416 out: &mut Graph,
1417 id: NodeId,
1418 batched: &HashSet<NodeId>,
1419 batch_size: usize,
1420) -> NodeId {
1421 if batched.contains(&id) {
1422 return id;
1423 }
1424 let orig_shape = out.node(id).shape.clone();
1425 let dtype = orig_shape.dtype();
1426
1427 let mut dims_with_1: Vec<Dim> = vec![Dim::Static(1)];
1429 dims_with_1.extend(orig_shape.dims().iter().copied());
1430 let with1_shape = Shape::from_dims(&dims_with_1, dtype);
1431 let reshape_dims: Vec<i64> = dims_with_1
1432 .iter()
1433 .map(|d| match d {
1434 Dim::Static(n) => *n as i64,
1435 Dim::Dynamic(_) => -1,
1436 })
1437 .collect();
1438 let with1 = out.add_node(
1439 Op::Reshape {
1440 new_shape: reshape_dims,
1441 },
1442 vec![id],
1443 with1_shape,
1444 );
1445
1446 let mut target_dims: Vec<i64> = vec![batch_size as i64];
1448 for d in orig_shape.dims().iter() {
1449 target_dims.push(match d {
1450 Dim::Static(n) => *n as i64,
1451 Dim::Dynamic(_) => -1,
1452 });
1453 }
1454 let mut target_shape_dims: Vec<Dim> = vec![Dim::Static(batch_size)];
1455 target_shape_dims.extend(orig_shape.dims().iter().copied());
1456 let target_shape = Shape::from_dims(&target_shape_dims, dtype);
1457 out.add_node(
1458 Op::Expand {
1459 target_shape: target_dims,
1460 },
1461 vec![with1],
1462 target_shape,
1463 )
1464}
1465
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{BinaryOp, ReduceOp};

    /// Asserts that `shape` has exactly `expected.len()` axes and that each
    /// axis is `Dim::Static` with the corresponding size.
    fn assert_static_dims(shape: &Shape, expected: &[usize]) {
        assert_eq!(shape.dims().len(), expected.len());
        for (axis, &size) in expected.iter().enumerate() {
            assert_eq!(shape.dim(axis), Dim::Static(size));
        }
    }

    /// A batched input multiplied by an unbatched constant vector should come
    /// out as `[batch, n]`.
    #[test]
    fn vmap_elementwise_scaling_lifts_to_batched_shape() {
        let (n, batch) = (3usize, 4usize);

        let mut g = Graph::new("scale");
        let x = g.input("x", Shape::new(&[n], DType::F64));
        // n little-endian copies of 2.0_f64.
        let twos = g.add_node(
            Op::Constant {
                data: 2.0_f64.to_le_bytes().repeat(n),
            },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let scaled = g.binary(BinaryOp::Mul, x, twos, Shape::new(&[n], DType::F64));
        g.set_outputs(vec![scaled]);

        let lifted = vmap(&g, &["x"], batch);
        assert_static_dims(&lifted.node(lifted.outputs[0]).shape, &[batch, n]);
    }

    /// Batching only the left matmul operand (`w` stays unbatched) should
    /// yield `[batch, m, n]`.
    #[test]
    fn vmap_matmul_with_shared_weight() {
        let (m, k, n, batch) = (2usize, 3usize, 4usize, 5usize);

        let mut g = Graph::new("mm");
        let lhs = g.input("x", Shape::new(&[m, k], DType::F64));
        let weight = g.input("w", Shape::new(&[k, n], DType::F64));
        let product = g.matmul(lhs, weight, Shape::new(&[m, n], DType::F64));
        g.set_outputs(vec![product]);

        let lifted = vmap(&g, &["x"], batch);
        assert_static_dims(&lifted.node(lifted.outputs[0]).shape, &[batch, m, n]);
    }

    /// Smoke test: `Gather` and `ReluBackward` both have vmap rules and gain a
    /// leading batch axis.
    #[test]
    fn vmap_extended_op_set_lifts_without_panic() {
        // Gather over a batched table with an unbatched index vector.
        let mut gather_g = Graph::new("gather_check");
        let table = gather_g.input("table", Shape::new(&[5, 4], DType::F32));
        let indices = gather_g.input("idx", Shape::new(&[3], DType::F32));
        let gathered = gather_g.add_node(
            Op::Gather { axis: 0 },
            vec![table, indices],
            Shape::new(&[3, 4], DType::F32),
        );
        gather_g.set_outputs(vec![gathered]);
        let lifted_gather = vmap(&gather_g, &["table"], 2);
        let gather_shape = &lifted_gather.node(lifted_gather.outputs[0]).shape;
        assert_eq!(gather_shape.rank(), 3);
        assert_eq!(gather_shape.dim(0), Dim::Static(2));

        // ReluBackward with only the primal input batched.
        let mut relu_g = Graph::new("relu_bwd_check");
        let primal = relu_g.input("x", Shape::new(&[4], DType::F32));
        let upstream = relu_g.input("dy", Shape::new(&[4], DType::F32));
        let grad = relu_g.add_node(
            Op::ReluBackward,
            vec![primal, upstream],
            Shape::new(&[4], DType::F32),
        );
        relu_g.set_outputs(vec![grad]);
        let lifted_relu = vmap(&relu_g, &["x"], 3);
        let relu_shape = &lifted_relu.node(lifted_relu.outputs[0]).shape;
        assert_eq!(relu_shape.rank(), 2);
        assert_eq!(relu_shape.dim(0), Dim::Static(3));
    }

    /// Reshape → matmul → reshape → add → reduce chain, with only `x` batched.
    /// The reduced output should be rank 2 with the batch axis first.
    #[test]
    fn vmap_combined_matmul_add_reduce() {
        let (n, batch) = (3usize, 4usize);

        let mut g = Graph::new("combined");
        let x = g.input("x", Shape::new(&[n], DType::F64));
        let w = g.input("w", Shape::new(&[n, n], DType::F64));
        let bias = g.input("b", Shape::new(&[n], DType::F64));

        // Treat x as a 1×n row so it can feed the matmul.
        let row = g.add_node(
            Op::Reshape {
                new_shape: vec![1, n as i64],
            },
            vec![x],
            Shape::new(&[1, n], DType::F64),
        );
        let row_w = g.matmul(row, w, Shape::new(&[1, n], DType::F64));
        let flat = g.add_node(
            Op::Reshape {
                new_shape: vec![n as i64],
            },
            vec![row_w],
            Shape::new(&[n], DType::F64),
        );
        let affine = g.binary(BinaryOp::Add, flat, bias, Shape::new(&[n], DType::F64));
        let total = g.reduce(
            affine,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F64),
        );
        g.set_outputs(vec![total]);

        let lifted = vmap(&g, &["x"], batch);
        let result = lifted.node(lifted.outputs[0]);
        assert_eq!(result.shape.dim(0), Dim::Static(batch));
        assert_eq!(result.shape.rank(), 2);
    }
}