1use rlx_ir::op::*;
114use rlx_ir::shape::Dim;
115use rlx_ir::*;
116use std::collections::HashMap;
117
118pub use crate::prepare_ad::{
119 AutodiffError, PrepareForAutodiff, grad_with_loss_module, jvp_module, prepare_graph_for_ad,
120 prepare_mir_for_ad, prepare_module_for_ad,
121};
122
/// Knobs for `grad_with_loss_opts`.
///
/// The only setting today decides what happens when a requested `wrt` node
/// ends up with no gradient at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GradWithLossOptions {
    /// `true`: a `wrt` node that no gradient reached is given an all-zero
    /// constant as its gradient. `false`: that situation panics.
    pub zero_missing_wrt: bool,
}

impl GradWithLossOptions {
    /// Preset with `zero_missing_wrt` switched off (missing gradients panic).
    pub const STRICT: GradWithLossOptions = GradWithLossOptions {
        zero_missing_wrt: false,
    };

    /// Preset with `zero_missing_wrt` switched on (missing gradients become
    /// zeros).
    pub const TRAINING: GradWithLossOptions = GradWithLossOptions {
        zero_missing_wrt: true,
    };
}
163
164pub fn grad_with_loss(forward: &Graph, wrt: &[NodeId]) -> Graph {
170 grad_with_loss_opts(forward, wrt, GradWithLossOptions::STRICT)
171}
172
173pub fn grad_with_loss_opts(forward: &Graph, wrt: &[NodeId], opts: GradWithLossOptions) -> Graph {
175 assert!(
176 !forward.outputs.is_empty(),
177 "grad_with_loss: forward must have at least one output (the loss)"
178 );
179
180 let forward_owned = crate::prepare_ad::prepare_graph_for_ad(forward.clone());
193 let forward = &forward_owned;
194
195 let mut bwd = Graph::new(format!("{}_grad", forward.name));
196
197 let mut fwd_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
202 for node in forward.nodes() {
203 let inputs: Vec<NodeId> = node.inputs.iter().map(|i| fwd_to_bwd[i]).collect();
204 let new_id = bwd.add_node(node.op.clone(), inputs, node.shape.clone());
205 fwd_to_bwd.insert(node.id, new_id);
206 }
207
208 let loss_fwd = forward.outputs[0];
211 let loss_bwd = fwd_to_bwd[&loss_fwd];
212 let loss_shape = forward.node(loss_fwd).shape.clone();
213 let d_output = bwd.input("d_output", loss_shape);
214
215 let mut grads: HashMap<NodeId, NodeId> = HashMap::new();
216 grads.insert(loss_bwd, d_output);
217
218 for fwd_node in forward.nodes().iter().rev() {
219 let bwd_id = fwd_to_bwd[&fwd_node.id];
220 let upstream = match grads.get(&bwd_id) {
221 Some(g) => *g,
222 None => continue,
223 };
224 let input_grads = vjp(fwd_node, upstream, &fwd_to_bwd, &mut bwd);
225 for (idx, grad_id) in input_grads {
226 let target = fwd_node.inputs[idx];
227 let bwd_target = fwd_to_bwd[&target];
228 let new_grad = if let Some(&prev) = grads.get(&bwd_target) {
230 let shape = bwd.node(prev).shape.clone();
231 bwd.binary(BinaryOp::Add, prev, grad_id, shape)
232 } else {
233 grad_id
234 };
235 grads.insert(bwd_target, new_grad);
236 }
237 }
238
239 let n_aux = forward.outputs.len().saturating_sub(1);
240 let mut outputs = Vec::with_capacity(1 + n_aux + wrt.len());
241 outputs.push(loss_bwd);
242 for &aux in &forward.outputs[1..] {
245 outputs.push(fwd_to_bwd[&aux]);
246 }
247 for &id in wrt {
248 let g = match grads.get(&fwd_to_bwd[&id]).copied() {
249 Some(g) => g,
250 None if opts.zero_missing_wrt => {
251 let shape = forward.node(id).shape.clone();
252 let n = shape.num_elements().unwrap_or(0);
253 let data: Vec<u8> = (0..n).flat_map(|_| 0.0f32.to_le_bytes()).collect();
254 bwd.add_node(Op::Constant { data }, vec![], shape)
255 }
256 None => {
257 panic!(
258 "no gradient flowed to {id} — \
259 either the forward graph doesn't depend on it, or one \
260 of its consumer ops has no VJP rule"
261 )
262 }
263 };
264 outputs.push(g);
265 }
266 bwd.set_outputs(outputs);
267 bwd
268}
269
270pub fn grad(forward: &Graph, wrt: &[NodeId]) -> Graph {
274 let g = grad_with_loss(forward, wrt);
275 let mut g = g;
276 let outs = g.outputs.iter().skip(1).copied().collect();
278 g.set_outputs(outs);
279 g
280}
281
282pub fn quantized_weight_bits(forward: &Graph, node_id: NodeId) -> Option<u8> {
295 match &forward.node(node_id).op {
296 Op::FakeQuantize { bits, .. } => Some(*bits),
297 Op::FakeQuantizeLSQ { bits, .. } => Some(*bits),
298 _ => None,
299 }
300}
301
302fn unbroadcast(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
303 let grad_shape = bwd.node(grad).shape.clone();
304 if grad_shape == *target_shape {
305 return grad;
306 }
307 let g_rank = grad_shape.rank();
308 let t_rank = target_shape.rank();
309 let extra = g_rank.saturating_sub(t_rank);
310
311 let mut axes: Vec<usize> = (0..extra).collect();
313 for i in 0..t_rank {
314 let g_dim = grad_shape.dim(extra + i);
315 let t_dim = target_shape.dim(i);
316 if matches!(t_dim, Dim::Static(1)) && !matches!(g_dim, Dim::Static(1)) {
317 axes.push(extra + i);
318 }
319 }
320
321 let mut current = grad;
322 if !axes.is_empty() {
323 let mut running_dims: Vec<Dim> = (0..g_rank).map(|i| grad_shape.dim(i)).collect();
330 for &ax in &axes {
331 running_dims[ax] = Dim::Static(1);
332 let step_shape = Shape::from_dims(&running_dims, target_shape.dtype());
333 current = bwd.add_node(
334 Op::Reduce {
335 op: ReduceOp::Sum,
336 axes: vec![ax],
337 keep_dim: true,
338 },
339 vec![current],
340 step_shape,
341 );
342 }
343 }
344
345 if bwd.node(current).shape.rank() != t_rank {
347 let new_shape: Vec<i64> = target_shape
348 .dims()
349 .iter()
350 .map(|d| match d {
351 Dim::Static(n) => *n as i64,
352 Dim::Dynamic(_) => -1,
353 })
354 .collect();
355 current = bwd.add_node(
356 Op::Reshape { new_shape },
357 vec![current],
358 target_shape.clone(),
359 );
360 }
361 current
362}
363
364fn reshape_to(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
366 if bwd.node(grad).shape == *target_shape {
367 return grad;
368 }
369 let new_shape: Vec<i64> = target_shape
370 .dims()
371 .iter()
372 .map(|d| match d {
373 Dim::Static(n) => *n as i64,
374 Dim::Dynamic(_) => -1,
375 })
376 .collect();
377 bwd.add_node(Op::Reshape { new_shape }, vec![grad], target_shape.clone())
378}
379
380fn grouped_matmul_vjp(
382 bwd: &mut Graph,
383 upstream: NodeId,
384 x_bwd: NodeId,
385 w_bwd: NodeId,
386 expert_bwd: NodeId,
387 x_shape: &Shape,
388 w_shape: &Shape,
389) -> (NodeId, NodeId) {
390 let dtype = x_shape.dtype();
391 let m = x_shape.dim(0);
392 let k = x_shape.dim(1);
393 let e = w_shape.dim(0);
394 let n_out = w_shape.dim(2);
395 let m_static = match m {
396 Dim::Static(v) => v,
397 _ => panic!("GroupedMatMul VJP: M must be static"),
398 };
399 let k_static = match k {
400 Dim::Static(v) => v,
401 _ => panic!("GroupedMatMul VJP: K must be static"),
402 };
403 let n_static = match n_out {
404 Dim::Static(v) => v,
405 _ => panic!("GroupedMatMul VJP: N must be static"),
406 };
407
408 let w_per = bwd.add_node(
409 Op::Gather { axis: 0 },
410 vec![w_bwd, expert_bwd],
411 Shape::from_dims(&[m, k, n_out], dtype),
412 );
413
414 let up_3d_shape = Shape::from_dims(&[m, Dim::Static(1), n_out], dtype);
415 let up_3d = bwd.reshape(
416 upstream,
417 vec![m_static as i64, 1, n_static as i64],
418 up_3d_shape,
419 );
420 let w_per_t = bwd.add_node(
421 Op::Transpose {
422 perm: vec![0, 2, 1],
423 },
424 vec![w_per],
425 Shape::from_dims(&[m, n_out, k], dtype),
426 );
427 let dx_3d_shape = Shape::from_dims(&[m, Dim::Static(1), k], dtype);
428 let dx_3d = bwd.matmul(up_3d, w_per_t, dx_3d_shape);
429 let dx = bwd.reshape(
430 dx_3d,
431 vec![m_static as i64, k_static as i64],
432 x_shape.clone(),
433 );
434
435 let x_3d = bwd.reshape(
436 x_bwd,
437 vec![m_static as i64, k_static as i64, 1],
438 Shape::from_dims(&[m, k, Dim::Static(1)], dtype),
439 );
440 let up_for_outer = bwd.reshape(
441 upstream,
442 vec![m_static as i64, 1, n_static as i64],
443 Shape::from_dims(&[m, Dim::Static(1), n_out], dtype),
444 );
445 let dw_per = bwd.matmul(x_3d, up_for_outer, Shape::from_dims(&[m, k, n_out], dtype));
446 let dw = bwd.add_node(
447 Op::ScatterAdd,
448 vec![dw_per, expert_bwd],
449 Shape::from_dims(&[e, k, n_out], dtype),
450 );
451 (dx, dw)
452}
453
454fn scalar_const(value: f32, bwd: &mut Graph) -> NodeId {
456 let bytes = value.to_le_bytes().to_vec();
457 let shape = Shape::from_dims(&[Dim::Static(1)], DType::F32);
458 bwd.add_node(Op::Constant { data: bytes }, vec![], shape)
459}
460
461fn vjp(
465 node: &Node,
466 upstream: NodeId,
467 fwd_map: &HashMap<NodeId, NodeId>,
468 bwd: &mut Graph,
469) -> Vec<(usize, NodeId)> {
470 let upstream_shape = bwd.node(upstream).shape.clone();
471 match &node.op {
472 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => vec![],
474
475 Op::Binary(BinaryOp::Add) => {
476 let a_bwd = fwd_map[&node.inputs[0]];
477 let b_bwd = fwd_map[&node.inputs[1]];
478 let a_shape = bwd.node(a_bwd).shape.clone();
479 let b_shape = bwd.node(b_bwd).shape.clone();
480 let g_a = unbroadcast(upstream, &a_shape, bwd);
481 let g_b = unbroadcast(upstream, &b_shape, bwd);
482 vec![(0, g_a), (1, g_b)]
483 }
484
485 Op::Binary(BinaryOp::Sub) => {
486 let a_bwd = fwd_map[&node.inputs[0]];
487 let b_bwd = fwd_map[&node.inputs[1]];
488 let a_shape = bwd.node(a_bwd).shape.clone();
489 let b_shape = bwd.node(b_bwd).shape.clone();
490 let neg = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
491 let g_a = unbroadcast(upstream, &a_shape, bwd);
492 let g_b = unbroadcast(neg, &b_shape, bwd);
493 vec![(0, g_a), (1, g_b)]
494 }
495
496 Op::Binary(BinaryOp::Mul) => {
497 let a_bwd = fwd_map[&node.inputs[0]];
498 let b_bwd = fwd_map[&node.inputs[1]];
499 let a_shape = bwd.node(a_bwd).shape.clone();
500 let b_shape = bwd.node(b_bwd).shape.clone();
501 let is_c64 = upstream_shape.dtype() == DType::C64;
507 let b_for_a = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
508 let a_for_b = if is_c64 { bwd.conjugate(a_bwd) } else { a_bwd };
509 let g_a_full = bwd.binary(BinaryOp::Mul, upstream, b_for_a, upstream_shape.clone());
510 let g_b_full = bwd.binary(BinaryOp::Mul, upstream, a_for_b, upstream_shape);
511 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
512 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
513 vec![(0, g_a), (1, g_b)]
514 }
515
516 Op::Activation(kind) => {
517 let x_bwd = fwd_map[&node.inputs[0]];
518 let dx = match kind {
523 Activation::Relu => bwd.relu_backward(x_bwd, upstream),
524 _ => bwd.activation_backward(*kind, x_bwd, upstream),
525 };
526 vec![(0, dx)]
527 }
528
529 Op::MatMul => {
530 let a_bwd = fwd_map[&node.inputs[0]];
541 let b_bwd = fwd_map[&node.inputs[1]];
542 let a_shape = bwd.node(a_bwd).shape.clone();
543 let b_shape = bwd.node(b_bwd).shape.clone();
544 assert!(
545 a_shape.rank() >= 2 && b_shape.rank() >= 2,
546 "MatMul VJP: rank must be ≥ 2, got {} and {}",
547 a_shape.rank(),
548 b_shape.rank()
549 );
550 let dtype = upstream_shape.dtype();
551
552 let trans_last_two = |bwd: &mut Graph, x: NodeId| -> NodeId {
554 let s = bwd.node(x).shape.clone();
555 let r = s.rank();
556 let mut perm: Vec<usize> = (0..r).collect();
557 perm.swap(r - 2, r - 1);
558 let mut dims: Vec<Dim> = s.dims().to_vec();
559 dims.swap(r - 2, r - 1);
560 let new_shape = Shape::from_dims(&dims, s.dtype());
561 bwd.add_node(Op::Transpose { perm }, vec![x], new_shape)
562 };
563
564 let upstream_dims: Vec<Dim> = upstream_shape.dims().to_vec();
567 let r_up = upstream_dims.len();
568
569 let b_t = trans_last_two(bwd, b_bwd);
571 let mut ga_dims = upstream_dims.clone();
572 ga_dims[r_up - 1] = a_shape.dim(a_shape.rank() - 1); let ga_shape = Shape::from_dims(&ga_dims, dtype);
574 let g_a_full = bwd.matmul(upstream, b_t, ga_shape);
575 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
576
577 let a_t = trans_last_two(bwd, a_bwd);
579 let mut gb_dims = upstream_dims.clone();
580 gb_dims[r_up - 2] = a_shape.dim(a_shape.rank() - 1); let gb_shape = Shape::from_dims(&gb_dims, dtype);
582 let g_b_full = bwd.matmul(a_t, upstream, gb_shape);
583 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
584
585 vec![(0, g_a), (1, g_b)]
586 }
587
588 Op::DenseSolve => {
589 let a_bwd = fwd_map[&node.inputs[0]];
598 let x_bwd = fwd_map[&node.id];
599 let a_shape = bwd.node(a_bwd).shape.clone();
600 let x_shape = bwd.node(x_bwd).shape.clone();
601 assert_eq!(a_shape.rank(), 2, "DenseSolve VJP: A must be 2-D");
602 let n = match a_shape.dim(0) {
603 Dim::Static(n) => n,
604 Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic N not supported"),
605 };
606 let dtype = a_shape.dtype();
607
608 let mut a_t_dims: Vec<Dim> = a_shape.dims().to_vec();
610 a_t_dims.swap(0, 1);
611 let a_t_shape = Shape::from_dims(&a_t_dims, dtype);
612 let a_t = bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![a_bwd], a_t_shape);
613
614 let d_b = bwd.dense_solve(a_t, upstream, x_shape.clone());
616
617 let neg_outer = match x_shape.rank() {
619 1 => {
620 let col_shape = Shape::from_dims(&[Dim::Static(n), Dim::Static(1)], dtype);
622 let row_shape = Shape::from_dims(&[Dim::Static(1), Dim::Static(n)], dtype);
623 let db_col = bwd.add_node(
624 Op::Reshape {
625 new_shape: vec![n as i64, 1],
626 },
627 vec![d_b],
628 col_shape,
629 );
630 let x_row = bwd.add_node(
631 Op::Reshape {
632 new_shape: vec![1, n as i64],
633 },
634 vec![x_bwd],
635 row_shape,
636 );
637 let outer = bwd.matmul(db_col, x_row, a_shape.clone());
638 bwd.activation(Activation::Neg, outer, a_shape)
639 }
640 2 => {
641 let k = match x_shape.dim(1) {
643 Dim::Static(k) => k,
644 Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic K not supported"),
645 };
646 let xt_dims = vec![Dim::Static(k), Dim::Static(n)];
647 let xt_shape = Shape::from_dims(&xt_dims, dtype);
648 let x_t =
649 bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![x_bwd], xt_shape);
650 let outer = bwd.matmul(d_b, x_t, a_shape.clone());
651 bwd.activation(Activation::Neg, outer, a_shape)
652 }
653 r => panic!("DenseSolve VJP: B must be rank 1 or 2, got rank {r}"),
654 };
655
656 vec![(0, neg_outer), (1, d_b)]
657 }
658
659 Op::BatchedDenseSolve => {
660 let a_bwd = fwd_map[&node.inputs[0]];
666 let x_bwd = fwd_map[&node.id];
667 let a_shape = bwd.node(a_bwd).shape.clone();
668 let x_shape = bwd.node(x_bwd).shape.clone();
669 assert_eq!(
670 a_shape.rank(),
671 3,
672 "BatchedDenseSolve VJP: A must be rank-3 [B, N, N]"
673 );
674 let b_dim = match a_shape.dim(0) {
675 Dim::Static(b) => b,
676 Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic B not supported"),
677 };
678 let n = match a_shape.dim(1) {
679 Dim::Static(n) => n,
680 Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic N not supported"),
681 };
682 let dtype = a_shape.dtype();
683
684 let a_t = bwd.add_node(
687 Op::Transpose {
688 perm: vec![0, 2, 1],
689 },
690 vec![a_bwd],
691 a_shape.clone(),
692 );
693
694 let d_b = bwd.batched_dense_solve(a_t, upstream, x_shape.clone());
696
697 let neg_outer = match x_shape.rank() {
699 2 => {
700 let col_shape = Shape::from_dims(
703 &[Dim::Static(b_dim), Dim::Static(n), Dim::Static(1)],
704 dtype,
705 );
706 let row_shape = Shape::from_dims(
707 &[Dim::Static(b_dim), Dim::Static(1), Dim::Static(n)],
708 dtype,
709 );
710 let db_col = bwd.add_node(
711 Op::Reshape {
712 new_shape: vec![b_dim as i64, n as i64, 1],
713 },
714 vec![d_b],
715 col_shape,
716 );
717 let x_row = bwd.add_node(
718 Op::Reshape {
719 new_shape: vec![b_dim as i64, 1, n as i64],
720 },
721 vec![x_bwd],
722 row_shape,
723 );
724 let outer = bwd.matmul(db_col, x_row, a_shape.clone());
725 bwd.activation(Activation::Neg, outer, a_shape)
726 }
727 3 => {
728 let k = match x_shape.dim(2) {
731 Dim::Static(k) => k,
732 Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic K not supported"),
733 };
734 let xt_shape = Shape::from_dims(
735 &[Dim::Static(b_dim), Dim::Static(k), Dim::Static(n)],
736 dtype,
737 );
738 let x_t = bwd.add_node(
739 Op::Transpose {
740 perm: vec![0, 2, 1],
741 },
742 vec![x_bwd],
743 xt_shape,
744 );
745 let outer = bwd.matmul(d_b, x_t, a_shape.clone());
746 bwd.activation(Activation::Neg, outer, a_shape)
747 }
748 r => panic!("BatchedDenseSolve VJP: b must be rank 2 or 3, got rank {r}"),
749 };
750
751 vec![(0, neg_outer), (1, d_b)]
752 }
753
754 Op::Scan {
755 body,
756 length,
757 save_trajectory,
758 num_bcast: _,
759 num_xs,
760 num_checkpoints,
761 } => {
762 let init_bwd = fwd_map[&node.inputs[0]];
770 let traj_bwd = fwd_map[&node.id];
771 let init_shape = bwd.node(init_bwd).shape.clone();
772
773 let mut body_input_ids: Vec<NodeId> = body
775 .nodes()
776 .iter()
777 .filter(|n| matches!(n.op, Op::Input { .. }))
778 .map(|n| n.id)
779 .collect();
780 body_input_ids.sort();
781
782 let body_vjp = grad(body, &body_input_ids);
783
784 let xs_bwd: Vec<NodeId> = (0..*num_xs as usize)
785 .map(|i| fwd_map[&node.inputs[1 + i]])
786 .collect();
787
788 let is_checkpointed = *num_checkpoints != 0 && *num_checkpoints != *length;
794 let forward_body_for_bwd = if is_checkpointed {
795 Some((**body).clone())
796 } else {
797 None
798 };
799
800 let dinit = bwd.scan_backward_with_checkpoints(
801 init_bwd,
802 traj_bwd,
803 upstream,
804 &xs_bwd,
805 body_vjp.clone(),
806 *length,
807 *save_trajectory,
808 *num_checkpoints,
809 forward_body_for_bwd.clone(),
810 init_shape,
811 );
812
813 let mut grads: Vec<(usize, NodeId)> = vec![(0, dinit)];
814 for i in 0..*num_xs as usize {
815 let outer_xs_id = node.inputs[1 + i];
816 let xs_shape = bwd.node(fwd_map[&outer_xs_id]).shape.clone();
817 let dxs_i = bwd.scan_backward_xs_with_checkpoints(
818 init_bwd,
819 traj_bwd,
820 upstream,
821 &xs_bwd,
822 body_vjp.clone(),
823 *length,
824 *save_trajectory,
825 i as u32,
826 *num_checkpoints,
827 forward_body_for_bwd.clone(),
828 xs_shape,
829 );
830 grads.push((1 + i, dxs_i));
831 }
832 grads
833 }
834
835 Op::Conv {
836 kernel_size,
837 stride,
838 padding,
839 dilation,
840 groups,
841 } => {
842 let x_bwd = fwd_map[&node.inputs[0]];
843 let w_bwd = fwd_map[&node.inputs[1]];
844 let x_shape = bwd.node(x_bwd).shape.clone();
845 let w_shape = bwd.node(w_bwd).shape.clone();
846 let dx = bwd.conv2d_backward_input(
847 upstream,
848 w_bwd,
849 x_shape,
850 kernel_size.clone(),
851 stride.clone(),
852 padding.clone(),
853 dilation.clone(),
854 *groups,
855 );
856 let _qat_bits: Option<u8> = None;
866 let dw = bwd.conv2d_backward_weight(
867 x_bwd,
868 upstream,
869 w_shape,
870 kernel_size.clone(),
871 stride.clone(),
872 padding.clone(),
873 dilation.clone(),
874 *groups,
875 );
876 vec![(0, dx), (1, dw)]
877 }
878
879 Op::Pool {
880 kind: ReduceOp::Max,
881 kernel_size,
882 stride,
883 padding,
884 } => {
885 let x_bwd = fwd_map[&node.inputs[0]];
886 let dx = bwd.maxpool2d_backward(
887 x_bwd,
888 upstream,
889 kernel_size.clone(),
890 stride.clone(),
891 padding.clone(),
892 );
893 vec![(0, dx)]
894 }
895
896 Op::SoftmaxCrossEntropyWithLogits => {
897 let logits_bwd = fwd_map[&node.inputs[0]];
898 let labels_bwd = fwd_map[&node.inputs[1]];
899 let dlogits = bwd.softmax_cross_entropy_backward(logits_bwd, labels_bwd, upstream);
900 vec![(0, dlogits)]
902 }
903
904 Op::Reduce {
905 op: ReduceOp::Sum,
906 axes,
907 keep_dim,
908 } => {
909 let x_bwd = fwd_map[&node.inputs[0]];
910 let x_shape = bwd.node(x_bwd).shape.clone();
911 let g = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
912 vec![(0, g)]
913 }
914
915 Op::Reduce {
916 op: ReduceOp::Mean,
917 axes,
918 keep_dim,
919 } => {
920 let x_bwd = fwd_map[&node.inputs[0]];
926 let x_shape = bwd.node(x_bwd).shape.clone();
927 let count: usize = axes
928 .iter()
929 .map(|&a| match x_shape.dim(a) {
930 Dim::Static(n) => n,
931 _ => panic!("Reduce::Mean VJP requires static reduced dims"),
932 })
933 .product();
934 let expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
935 let inv_count = scalar_const(1.0 / count as f32, bwd);
936 let g = bwd.binary(BinaryOp::Mul, expanded, inv_count, x_shape);
937 vec![(0, g)]
938 }
939
940 Op::Reshape { .. } => {
941 let x_bwd = fwd_map[&node.inputs[0]];
942 let x_shape = bwd.node(x_bwd).shape.clone();
943 let dx = reshape_to(upstream, &x_shape, bwd);
944 vec![(0, dx)]
945 }
946
947 Op::ComplexNormSq => {
948 let z_bwd = fwd_map[&node.inputs[0]];
951 let dz = bwd.complex_norm_sq_backward(z_bwd, upstream);
952 vec![(0, dz)]
953 }
954
955 Op::Conjugate => {
956 let dz = bwd.conjugate(upstream);
962 vec![(0, dz)]
963 }
964
965 Op::Cast { .. } => {
966 let x_bwd = fwd_map[&node.inputs[0]];
967 let x_shape = bwd.node(x_bwd).shape.clone();
968 let dx = bwd.add_node(
969 Op::Cast {
970 to: x_shape.dtype(),
971 },
972 vec![upstream],
973 x_shape,
974 );
975 vec![(0, dx)]
976 }
977
978 Op::StopGradient => vec![],
983
984 Op::Quantize { .. } | Op::Dequantize { .. } => {
996 vec![(0, upstream)]
997 }
998
999 Op::FakeQuantizeLSQ { bits, axis } => {
1000 let x_bwd = fwd_map[&node.inputs[0]];
1004 let scale_bwd = fwd_map[&node.inputs[1]];
1005 let x_shape = bwd.node(x_bwd).shape.clone();
1006 let scale_shape = bwd.node(scale_bwd).shape.clone();
1007 let dx = bwd.add_node(
1008 Op::FakeQuantizeLSQBackwardX {
1009 bits: *bits,
1010 axis: *axis,
1011 },
1012 vec![x_bwd, scale_bwd, upstream],
1013 x_shape,
1014 );
1015 let dscale = bwd.add_node(
1016 Op::FakeQuantizeLSQBackwardScale {
1017 bits: *bits,
1018 axis: *axis,
1019 },
1020 vec![x_bwd, scale_bwd, upstream],
1021 scale_shape,
1022 );
1023 vec![(0, dx), (1, dscale)]
1024 }
1025
1026 Op::FakeQuantize {
1031 bits, axis, ste, ..
1032 } => {
1033 use rlx_ir::op::SteKind;
1034 match ste {
1035 SteKind::Identity => vec![(0, upstream)],
1036 _ => {
1037 let x_bwd = fwd_map[&node.inputs[0]];
1038 let x_shape = bwd.node(x_bwd).shape.clone();
1039 let dx = bwd.add_node(
1040 Op::FakeQuantizeBackward {
1041 bits: *bits,
1042 axis: *axis,
1043 ste: *ste,
1044 },
1045 vec![x_bwd, upstream],
1046 x_shape,
1047 );
1048 vec![(0, dx)]
1049 }
1050 }
1051 }
1052
1053 Op::Expand { .. } => {
1054 let x_bwd = fwd_map[&node.inputs[0]];
1055 let x_shape = bwd.node(x_bwd).shape.clone();
1056 let dx = unbroadcast(upstream, &x_shape, bwd);
1057 vec![(0, dx)]
1058 }
1059
1060 Op::BatchNormInference { eps } => {
1061 let x_bwd = fwd_map[&node.inputs[0]];
1062 let gamma_bwd = fwd_map[&node.inputs[1]];
1063 let _beta_bwd = fwd_map[&node.inputs[2]];
1064 let mean_bwd = fwd_map[&node.inputs[3]];
1065 let var_bwd = fwd_map[&node.inputs[4]];
1066 let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1067 let dx = bwd.batch_norm_inference_backward_input(
1068 x_bwd, gamma_bwd, mean_bwd, var_bwd, upstream, *eps,
1069 );
1070 let dgamma = bwd.batch_norm_inference_backward_gamma(
1071 x_bwd,
1072 mean_bwd,
1073 var_bwd,
1074 upstream,
1075 gamma_shape.clone(),
1076 *eps,
1077 );
1078 let dbeta = bwd.batch_norm_inference_backward_beta(upstream, gamma_shape);
1079 vec![(0, dx), (1, dgamma), (2, dbeta)]
1081 }
1082
1083 Op::LayerNorm { axis, eps } => {
1084 let x_bwd = fwd_map[&node.inputs[0]];
1091 let gamma_bwd = fwd_map[&node.inputs[1]];
1092 let _beta_bwd = fwd_map[&node.inputs[2]];
1093 let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1094
1095 let dx = bwd.layer_norm_backward_input(x_bwd, gamma_bwd, upstream, *axis, *eps);
1096 let dgamma =
1097 bwd.layer_norm_backward_gamma(x_bwd, upstream, gamma_shape.clone(), *axis, *eps);
1098 let dbeta = unbroadcast(upstream, &gamma_shape, bwd);
1099 vec![(0, dx), (1, dgamma), (2, dbeta)]
1100 }
1101
1102 Op::Softmax { axis } => {
1103 let y_bwd = fwd_map[&node.id];
1119 let y_shape = bwd.node(y_bwd).shape.clone();
1120 let dtype = y_shape.dtype();
1121 let rank = y_shape.rank();
1122 let axis_pos = if *axis < 0 {
1123 (rank as i32 + *axis) as usize
1124 } else {
1125 *axis as usize
1126 };
1127
1128 let yg = bwd.binary(BinaryOp::Mul, y_bwd, upstream, y_shape.clone());
1129
1130 let mut kept_dims: Vec<Dim> = y_shape.dims().to_vec();
1131 kept_dims[axis_pos] = Dim::Static(1);
1132 let kept_shape = Shape::from_dims(&kept_dims, dtype);
1133 let s = bwd.add_node(
1134 Op::Reduce {
1135 op: ReduceOp::Sum,
1136 axes: vec![axis_pos],
1137 keep_dim: true,
1138 },
1139 vec![yg],
1140 kept_shape,
1141 );
1142
1143 let target_dims: Vec<i64> = y_shape
1144 .dims()
1145 .iter()
1146 .map(|d| match d {
1147 Dim::Static(n) => *n as i64,
1148 Dim::Dynamic(_) => -1,
1149 })
1150 .collect();
1151 let s_expanded = bwd.add_node(
1152 Op::Expand {
1153 target_shape: target_dims,
1154 },
1155 vec![s],
1156 y_shape.clone(),
1157 );
1158
1159 let diff = bwd.binary(BinaryOp::Sub, upstream, s_expanded, y_shape.clone());
1160 let dx = bwd.binary(BinaryOp::Mul, y_bwd, diff, y_shape);
1161 vec![(0, dx)]
1162 }
1163
1164 Op::Transpose { perm } => {
1166 let inv: Vec<usize> = {
1169 let mut v = vec![0usize; perm.len()];
1170 for (i, &p) in perm.iter().enumerate() {
1171 v[p] = i;
1172 }
1173 v
1174 };
1175 let x_bwd = fwd_map[&node.inputs[0]];
1176 let x_shape = bwd.node(x_bwd).shape.clone();
1177 let dx = bwd.add_node(Op::Transpose { perm: inv }, vec![upstream], x_shape);
1178 vec![(0, dx)]
1179 }
1180
1181 Op::Concat { axis } => {
1182 let mut grads = Vec::with_capacity(node.inputs.len());
1185 let mut offset: usize = 0;
1186 for (i, &input_id) in node.inputs.iter().enumerate() {
1187 let x_bwd = fwd_map[&input_id];
1188 let x_shape = bwd.node(x_bwd).shape.clone();
1189 let len = match x_shape.dim(*axis) {
1190 Dim::Static(n) => n,
1191 _ => panic!("Concat VJP: dynamic concat dim"),
1192 };
1193 let dx = bwd.add_node(
1194 Op::Narrow {
1195 axis: *axis,
1196 start: offset,
1197 len,
1198 },
1199 vec![upstream],
1200 x_shape,
1201 );
1202 grads.push((i, dx));
1203 offset += len;
1204 }
1205 grads
1206 }
1207
1208 Op::Narrow { axis, start, len } => {
1209 let x_bwd = fwd_map[&node.inputs[0]];
1213 let x_shape = bwd.node(x_bwd).shape.clone();
1214 let full_n = match x_shape.dim(*axis) {
1215 Dim::Static(n) => n,
1216 _ => panic!("Narrow VJP: dynamic axis"),
1217 };
1218 let pre = *start;
1219 let post = full_n - *start - *len;
1220
1221 let zero_buf = |bwd: &mut Graph, len_axis: usize| -> NodeId {
1222 if len_axis == 0 {
1223 return upstream; }
1225 let dtype = x_shape.dtype();
1226 let mut dims: Vec<Dim> = x_shape.dims().to_vec();
1227 dims[*axis] = Dim::Static(len_axis);
1228 let s = Shape::from_dims(&dims, dtype);
1229 let n_elems = dims.iter().fold(1usize, |a, d| match d {
1230 Dim::Static(k) => a * k,
1231 _ => a,
1232 });
1233 let bytes = vec![0u8; n_elems * dtype.size_bytes()];
1237 bwd.add_node(Op::Constant { data: bytes }, vec![], s)
1238 };
1239
1240 let mut parts: Vec<NodeId> = Vec::new();
1241 if pre > 0 {
1242 parts.push(zero_buf(bwd, pre));
1243 }
1244 parts.push(upstream);
1245 if post > 0 {
1246 parts.push(zero_buf(bwd, post));
1247 }
1248
1249 let dx = if parts.len() == 1 {
1250 parts[0]
1251 } else {
1252 bwd.add_node(Op::Concat { axis: *axis }, parts, x_shape)
1253 };
1254 vec![(0, dx)]
1255 }
1256
1257 Op::Gather { axis } => {
1258 let table_bwd = fwd_map[&node.inputs[0]];
1259 let indices_bwd = fwd_map[&node.inputs[1]];
1260 let table_shape = bwd.node(table_bwd).shape.clone();
1261 if *axis == 0 {
1262 let dtable = bwd.add_node(Op::ScatterAdd, vec![upstream, indices_bwd], table_shape);
1263 vec![(0, dtable)]
1264 } else {
1265 let dtable = bwd.gather_backward(
1266 upstream,
1267 indices_bwd,
1268 table_shape,
1269 (*axis).try_into().unwrap(),
1270 );
1271 vec![(0, dtable)]
1272 }
1273 }
1274
1275 Op::Compare(_) => {
1277 vec![]
1282 }
1283
1284 Op::Where => {
1285 let cond = fwd_map[&node.inputs[0]];
1289 let a_bwd = fwd_map[&node.inputs[1]];
1290 let b_bwd = fwd_map[&node.inputs[2]];
1291 let a_shape = bwd.node(a_bwd).shape.clone();
1292 let b_shape = bwd.node(b_bwd).shape.clone();
1293 let out_shape = upstream_shape.clone();
1294
1295 let zero_a_bytes = vec![0u8; a_shape.num_elements().expect("Where VJP: dynamic a") * 4];
1296 let zero_b_bytes = vec![0u8; b_shape.num_elements().expect("Where VJP: dynamic b") * 4];
1297 let zero_a = bwd.add_node(Op::Constant { data: zero_a_bytes }, vec![], a_shape.clone());
1298 let zero_b = bwd.add_node(Op::Constant { data: zero_b_bytes }, vec![], b_shape.clone());
1299 let zero_a_bcast = unbroadcast_inverse(zero_a, &out_shape, bwd);
1302 let zero_b_bcast = unbroadcast_inverse(zero_b, &out_shape, bwd);
1303 let g_a_full = bwd.add_node(
1304 Op::Where,
1305 vec![cond, upstream, zero_a_bcast],
1306 out_shape.clone(),
1307 );
1308 let g_b_full = bwd.add_node(Op::Where, vec![cond, zero_b_bcast, upstream], out_shape);
1309 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1310 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1311 vec![(1, g_a), (2, g_b)]
1312 }
1313
1314 Op::Binary(BinaryOp::Div) => {
1316 let a_bwd = fwd_map[&node.inputs[0]];
1324 let b_bwd = fwd_map[&node.inputs[1]];
1325 let y_bwd = fwd_map[&node.id];
1326 let a_shape = bwd.node(a_bwd).shape.clone();
1327 let b_shape = bwd.node(b_bwd).shape.clone();
1328 let is_c64 = upstream_shape.dtype() == DType::C64;
1329
1330 let b_term = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
1331 let y_term = if is_c64 { bwd.conjugate(y_bwd) } else { y_bwd };
1332
1333 let g_a_full = bwd.binary(BinaryOp::Div, upstream, b_term, upstream_shape.clone());
1335 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1336
1337 let neg_up = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
1339 let neg_up_y = bwd.binary(BinaryOp::Mul, neg_up, y_term, upstream_shape.clone());
1340 let g_b_full = bwd.binary(BinaryOp::Div, neg_up_y, b_term, upstream_shape);
1341 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1342
1343 vec![(0, g_a), (1, g_b)]
1344 }
1345
1346 Op::Reduce {
1348 op: ReduceOp::Max,
1349 axes,
1350 keep_dim,
1351 }
1352 | Op::Reduce {
1353 op: ReduceOp::Min,
1354 axes,
1355 keep_dim,
1356 } => {
1357 let is_max = matches!(
1361 node.op,
1362 Op::Reduce {
1363 op: ReduceOp::Max,
1364 ..
1365 }
1366 );
1367 let _ = is_max;
1368 let x_bwd = fwd_map[&node.inputs[0]];
1369 let y_bwd = fwd_map[&node.id];
1370 let x_shape = bwd.node(x_bwd).shape.clone();
1371 let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1372 let mask_bool = bwd.add_node(
1373 Op::Compare(CmpOp::Eq),
1374 vec![x_bwd, y_expanded],
1375 Shape::from_dims(x_shape.dims(), DType::F32),
1376 );
1377 let mask_f32 = bwd.add_node(
1381 Op::Cast {
1382 to: x_shape.dtype(),
1383 },
1384 vec![mask_bool],
1385 x_shape.clone(),
1386 );
1387 let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1388 let dx = bwd.binary(BinaryOp::Mul, upstream_expanded, mask_f32, x_shape);
1389 vec![(0, dx)]
1390 }
1391
1392 Op::Rope { head_dim, n_rot } => {
1398 let cos = fwd_map[&node.inputs[1]];
1399 let sin = fwd_map[&node.inputs[2]];
1400 let dx = bwd.rope_backward(upstream, cos, sin, *head_dim, *n_rot);
1401 vec![(0, dx)]
1402 }
1403
1404 Op::RmsNorm { axis, eps } => {
1405 let x = fwd_map[&node.inputs[0]];
1406 let gamma = fwd_map[&node.inputs[1]];
1407 let beta = fwd_map[&node.inputs[2]];
1408 let dx = bwd.rms_norm_backward_input(x, gamma, beta, upstream, *axis, *eps);
1409 let dgamma = bwd.rms_norm_backward_gamma(x, gamma, beta, upstream, *axis, *eps);
1410 let dbeta = bwd.rms_norm_backward_beta(x, gamma, beta, upstream, *axis, *eps);
1411 vec![(0, dx), (1, dgamma), (2, dbeta)]
1412 }
1413
1414 Op::GroupNorm { num_groups, eps } => {
1415 let x = fwd_map[&node.inputs[0]];
1416 let gamma = fwd_map[&node.inputs[1]];
1417 let beta = fwd_map[&node.inputs[2]];
1418 let gamma_shape = bwd.node(gamma).shape.clone();
1419 let beta_shape = bwd.node(beta).shape.clone();
1420 let dx = bwd.group_norm_backward_input(x, gamma, beta, upstream, *num_groups, *eps);
1421 let dgamma = bwd.group_norm_backward_gamma(x, upstream, gamma_shape, *num_groups, *eps);
1422 let dbeta = bwd.group_norm_backward_beta(x, upstream, beta_shape, *num_groups, *eps);
1423 vec![(0, dx), (1, dgamma), (2, dbeta)]
1424 }
1425
1426 Op::Attention {
1428 num_heads,
1429 head_dim,
1430 mask_kind,
1431 score_scale: _,
1432 attn_logit_softcap: _,
1433 } => {
1434 let q = fwd_map[&node.inputs[0]];
1435 let k = fwd_map[&node.inputs[1]];
1436 let v = fwd_map[&node.inputs[2]];
1437 let mask = match mask_kind {
1438 MaskKind::Custom | MaskKind::Bias => Some(fwd_map[&node.inputs[3]]),
1439 _ => None,
1440 };
1441 let (dq, dk, dv) = bwd
1442 .attention_backward_all(q, k, v, upstream, *num_heads, *head_dim, *mask_kind, mask);
1443 vec![(0, dq), (1, dk), (2, dv)]
1444 }
1445
1446 Op::Reduce {
1453 op: ReduceOp::Prod,
1454 axes,
1455 keep_dim,
1456 } => {
1457 let x_bwd = fwd_map[&node.inputs[0]];
1458 let y_bwd = fwd_map[&node.id];
1459 let x_shape = bwd.node(x_bwd).shape.clone();
1460 let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1461 let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1462 let num = bwd.binary(
1464 BinaryOp::Mul,
1465 upstream_expanded,
1466 y_expanded,
1467 x_shape.clone(),
1468 );
1469 let dx = bwd.binary(BinaryOp::Div, num, x_bwd, x_shape);
1470 vec![(0, dx)]
1471 }
1472
1473 Op::Pool {
1485 kind: ReduceOp::Mean,
1486 kernel_size,
1487 stride,
1488 padding,
1489 } => {
1490 assert_eq!(kernel_size.len(), 2, "Pool(Mean) VJP: 2-D pool only");
1491 let x_bwd = fwd_map[&node.inputs[0]];
1492 let x_shape = bwd.node(x_bwd).shape.clone();
1493 let dtype = x_shape.dtype();
1494 let c = match x_shape.dim(1) {
1496 Dim::Static(n) => n,
1497 _ => panic!("Pool(Mean) VJP: dynamic channel dim"),
1498 };
1499 let kh = kernel_size[0];
1500 let kw = kernel_size[1];
1501 let inv_n = 1.0_f32 / (kh as f32 * kw as f32);
1502 let kernel_n = c * kh * kw;
1503 let mut bytes: Vec<u8> = Vec::with_capacity(kernel_n * 4);
1504 for _ in 0..kernel_n {
1505 bytes.extend_from_slice(&inv_n.to_le_bytes());
1506 }
1507 let kernel_shape = Shape::from_dims(
1508 &[
1509 Dim::Static(c),
1510 Dim::Static(1),
1511 Dim::Static(kh),
1512 Dim::Static(kw),
1513 ],
1514 dtype,
1515 );
1516 let kernel = bwd.add_node(Op::Constant { data: bytes }, vec![], kernel_shape);
1517 let dx = bwd.conv2d_backward_input(
1518 upstream,
1519 kernel,
1520 x_shape,
1521 kernel_size.clone(),
1522 stride.clone(),
1523 padding.clone(),
1524 vec![1, 1],
1525 c, );
1527 vec![(0, dx)]
1528 }
1529
1530 Op::Binary(BinaryOp::Min) | Op::Binary(BinaryOp::Max) => {
1537 let a_bwd = fwd_map[&node.inputs[0]];
1538 let b_bwd = fwd_map[&node.inputs[1]];
1539 let y_bwd = fwd_map[&node.id];
1540 let a_shape = bwd.node(a_bwd).shape.clone();
1541 let b_shape = bwd.node(b_bwd).shape.clone();
1542 let dtype = upstream_shape.dtype();
1543
1544 let bool_shape = Shape::from_dims(upstream_shape.dims(), DType::Bool);
1545 let mask_pred = bwd.add_node(Op::Compare(CmpOp::Eq), vec![a_bwd, y_bwd], bool_shape);
1546 let mask_f32 = bwd.add_node(
1547 Op::Cast { to: dtype },
1548 vec![mask_pred],
1549 upstream_shape.clone(),
1550 );
1551 let zero_bytes = vec![
1552 0u8;
1553 upstream_shape
1554 .num_elements()
1555 .expect("Min/Max VJP: dyn shape")
1556 * 4
1557 ];
1558 let zero = bwd.add_node(
1559 Op::Constant { data: zero_bytes },
1560 vec![],
1561 upstream_shape.clone(),
1562 );
1563 let g_a_full = bwd.add_node(
1564 Op::Where,
1565 vec![mask_f32, upstream, zero],
1566 upstream_shape.clone(),
1567 );
1568 let g_b_full = bwd.add_node(Op::Where, vec![mask_f32, zero, upstream], upstream_shape);
1569 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1570 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1571 vec![(0, g_a), (1, g_b)]
1572 }
1573
1574 Op::Binary(BinaryOp::Pow) => {
1583 let a_bwd = fwd_map[&node.inputs[0]];
1584 let b_bwd = fwd_map[&node.inputs[1]];
1585 let y_bwd = fwd_map[&node.id]; let a_shape = bwd.node(a_bwd).shape.clone();
1587 let b_shape = bwd.node(b_bwd).shape.clone();
1588
1589 let yb = bwd.binary(BinaryOp::Mul, y_bwd, b_bwd, upstream_shape.clone());
1592 let yb_over_a = bwd.binary(BinaryOp::Div, yb, a_bwd, upstream_shape.clone());
1593 let g_a_full = bwd.binary(BinaryOp::Mul, upstream, yb_over_a, upstream_shape.clone());
1594 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1595
1596 let ln_a = bwd.activation(Activation::Log, a_bwd, a_shape);
1598 let ln_a_b = unbroadcast_inverse(ln_a, &upstream_shape, bwd);
1599 let yln = bwd.binary(BinaryOp::Mul, y_bwd, ln_a_b, upstream_shape.clone());
1600 let g_b_full = bwd.binary(BinaryOp::Mul, upstream, yln, upstream_shape);
1601 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1602
1603 vec![(0, g_a), (1, g_b)]
1604 }
1605
1606 Op::DequantMatMul { scheme: _ } => {
1625 let x_bwd = fwd_map[&node.inputs[0]];
1626 let w_q_bwd = fwd_map[&node.inputs[1]];
1627 let scale_bwd = fwd_map[&node.inputs[2]];
1628 let zp_bwd = fwd_map[&node.inputs[3]];
1629 let x_shape = bwd.node(x_bwd).shape.clone();
1630 let w_shape = bwd.node(w_q_bwd).shape.clone();
1631 let scale_shape = bwd.node(scale_bwd).shape.clone();
1632 let zp_shape = bwd.node(zp_bwd).shape.clone();
1633
1634 let dtype = x_shape.dtype();
1638 let w_q_f32 = bwd.add_node(
1639 Op::Cast { to: dtype },
1640 vec![w_q_bwd],
1641 Shape::from_dims(w_shape.dims(), dtype),
1642 );
1643 let scale_b =
1645 unbroadcast_inverse(scale_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1646 let zp_b = unbroadcast_inverse(zp_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1647 let w_centered = bwd.binary(
1648 BinaryOp::Sub,
1649 w_q_f32,
1650 zp_b,
1651 Shape::from_dims(w_shape.dims(), dtype),
1652 );
1653 let w_dq = bwd.binary(
1654 BinaryOp::Mul,
1655 w_centered,
1656 scale_b,
1657 Shape::from_dims(w_shape.dims(), dtype),
1658 );
1659
1660 let w_rank = w_shape.rank();
1662 let mut perm: Vec<usize> = (0..w_rank).collect();
1663 perm.swap(w_rank - 2, w_rank - 1);
1664 let mut wdt_dims: Vec<Dim> = w_shape.dims().to_vec();
1665 wdt_dims.swap(w_rank - 2, w_rank - 1);
1666 let w_dq_t_shape = Shape::from_dims(&wdt_dims, dtype);
1667 let w_dq_t = bwd.add_node(Op::Transpose { perm }, vec![w_dq], w_dq_t_shape);
1668 let dx = bwd.matmul(upstream, w_dq_t, x_shape.clone());
1669
1670 let x_rank = x_shape.rank();
1675 let mut x_perm: Vec<usize> = (0..x_rank).collect();
1676 x_perm.swap(x_rank - 2, x_rank - 1);
1677 let mut x_t_dims: Vec<Dim> = x_shape.dims().to_vec();
1678 x_t_dims.swap(x_rank - 2, x_rank - 1);
1679 let x_t = bwd.add_node(
1680 Op::Transpose { perm: x_perm },
1681 vec![x_bwd],
1682 Shape::from_dims(&x_t_dims, dtype),
1683 );
1684 let dw_unscaled = bwd.matmul(x_t, upstream, Shape::from_dims(w_shape.dims(), dtype));
1685 let dw_q_f32 = bwd.binary(
1686 BinaryOp::Mul,
1687 dw_unscaled,
1688 scale_b,
1689 Shape::from_dims(w_shape.dims(), dtype),
1690 );
1691 let dw_q = bwd.add_node(
1693 Op::Cast {
1694 to: w_shape.dtype(),
1695 },
1696 vec![dw_q_f32],
1697 w_shape,
1698 );
1699
1700 let zero_scale_bytes =
1702 vec![0u8; scale_shape.num_elements().expect("DQMM VJP: dyn scale") * 4];
1703 let zero_zp_bytes = vec![0u8; zp_shape.num_elements().expect("DQMM VJP: dyn zp") * 4];
1704 let dscale = bwd.add_node(
1705 Op::Constant {
1706 data: zero_scale_bytes,
1707 },
1708 vec![],
1709 scale_shape,
1710 );
1711 let dzp = bwd.add_node(
1712 Op::Constant {
1713 data: zero_zp_bytes,
1714 },
1715 vec![],
1716 zp_shape,
1717 );
1718
1719 vec![(0, dx), (1, dw_q), (2, dscale), (3, dzp)]
1720 }
1721
1722 Op::ScatterAdd => {
1728 let updates_bwd = fwd_map[&node.inputs[0]];
1729 let indices_bwd = fwd_map[&node.inputs[1]];
1730 let updates_shape = bwd.node(updates_bwd).shape.clone();
1731 let dupdates = bwd.add_node(
1732 Op::Gather { axis: 0 },
1733 vec![upstream, indices_bwd],
1734 updates_shape,
1735 );
1736 vec![(0, dupdates)]
1737 }
1738
1739 Op::Cumsum { axis, exclusive } => {
1742 let x_bwd = fwd_map[&node.inputs[0]];
1743 let x_shape = bwd.node(x_bwd).shape.clone();
1744 let dx = bwd.cumsum_backward(upstream, x_shape, *axis, *exclusive);
1745 vec![(0, dx)]
1746 }
1747
1748 Op::GroupedMatMul => {
1761 let x_bwd = fwd_map[&node.inputs[0]];
1762 let w_bwd = fwd_map[&node.inputs[1]];
1763 let expert_bwd = fwd_map[&node.inputs[2]];
1764 let x_shape = bwd.node(x_bwd).shape.clone();
1765 let w_shape = bwd.node(w_bwd).shape.clone();
1766 let (dx, dw) =
1767 grouped_matmul_vjp(bwd, upstream, x_bwd, w_bwd, expert_bwd, &x_shape, &w_shape);
1768 vec![(0, dx), (1, dw)]
1769 }
1770
1771 Op::DequantGroupedMatMul { scheme } => {
1777 let x_bwd = fwd_map[&node.inputs[0]];
1778 let w_packed = fwd_map[&node.inputs[1]];
1779 let expert_bwd = fwd_map[&node.inputs[2]];
1780 let x_shape = bwd.node(x_bwd).shape.clone();
1781 let w_packed_shape = bwd.node(w_packed).shape.clone();
1782 let dtype = x_shape.dtype();
1783 let k = x_shape.dim(1);
1784 let n_out = node.shape.dim(node.shape.rank() - 1);
1785 let k_static = match k {
1786 Dim::Static(v) => v,
1787 _ => panic!("DequantGroupedMatMul VJP: K must be static"),
1788 };
1789 let n_static = match n_out {
1790 Dim::Static(v) => v,
1791 _ => panic!("DequantGroupedMatMul VJP: N must be static"),
1792 };
1793 let block_elems = scheme.gguf_block_size() as usize;
1794 let block_bytes = scheme.gguf_block_bytes() as usize;
1795 let slab_bytes = (k_static * n_static) / block_elems * block_bytes;
1796 let total_bytes = w_packed_shape
1797 .num_elements()
1798 .expect("DequantGroupedMatMul VJP: dyn packed");
1799 let e_static = total_bytes / slab_bytes.max(1);
1800 let w_shape = Shape::from_dims(
1801 &[
1802 Dim::Static(e_static),
1803 Dim::Static(k_static),
1804 Dim::Static(n_static),
1805 ],
1806 dtype,
1807 );
1808 let w_dq = bwd.add_node(
1809 Op::DequantMoEWeights { scheme: *scheme },
1810 vec![w_packed],
1811 w_shape.clone(),
1812 );
1813 let (dx, _dw) =
1814 grouped_matmul_vjp(bwd, upstream, x_bwd, w_dq, expert_bwd, &x_shape, &w_shape);
1815 vec![(0, dx)]
1816 }
1817
1818 Op::QMatMul {
1831 x_zp,
1832 w_zp,
1833 out_zp: _,
1834 mult,
1835 } => {
1836 let x_bwd = fwd_map[&node.inputs[0]];
1837 let w_bwd = fwd_map[&node.inputs[1]];
1838 let bias_bwd = fwd_map[&node.inputs[2]];
1839 let x_shape = bwd.node(x_bwd).shape.clone();
1840 let w_shape = bwd.node(w_bwd).shape.clone();
1841 let bias_shape = bwd.node(bias_bwd).shape.clone();
1842 let dtype = upstream_shape.dtype();
1843
1844 let x_f32 = bwd.add_node(
1846 Op::Cast { to: dtype },
1847 vec![x_bwd],
1848 Shape::from_dims(x_shape.dims(), dtype),
1849 );
1850 let w_f32 = bwd.add_node(
1851 Op::Cast { to: dtype },
1852 vec![w_bwd],
1853 Shape::from_dims(w_shape.dims(), dtype),
1854 );
1855 let xzp_c = scalar_const(*x_zp as f32, bwd);
1856 let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
1857 let _ = bwd.binary(
1858 BinaryOp::Sub,
1859 x_f32,
1860 xzp_b,
1861 Shape::from_dims(x_shape.dims(), dtype),
1862 );
1863 let wzp_c = scalar_const(*w_zp as f32, bwd);
1864 let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1865 let w_centered = bwd.binary(
1866 BinaryOp::Sub,
1867 w_f32,
1868 wzp_b,
1869 Shape::from_dims(w_shape.dims(), dtype),
1870 );
1871
1872 let mult_c = scalar_const(*mult, bwd);
1874 let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
1875 let upstream_scaled =
1876 bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
1877
1878 let w_rank = w_shape.rank();
1881 let mut perm: Vec<usize> = (0..w_rank).collect();
1882 perm.swap(w_rank - 2, w_rank - 1);
1883 let mut wt_dims: Vec<Dim> = w_shape.dims().to_vec();
1884 wt_dims.swap(w_rank - 2, w_rank - 1);
1885 let w_t = bwd.add_node(
1886 Op::Transpose { perm },
1887 vec![w_centered],
1888 Shape::from_dims(&wt_dims, dtype),
1889 );
1890 let dx_f32 = bwd.matmul(
1891 upstream_scaled,
1892 w_t,
1893 Shape::from_dims(x_shape.dims(), dtype),
1894 );
1895 let dx = bwd.add_node(
1896 Op::Cast {
1897 to: x_shape.dtype(),
1898 },
1899 vec![dx_f32],
1900 x_shape.clone(),
1901 );
1902
1903 let x_rank = x_shape.rank();
1905 let mut x_perm: Vec<usize> = (0..x_rank).collect();
1906 x_perm.swap(x_rank - 2, x_rank - 1);
1907 let mut xt_dims: Vec<Dim> = x_shape.dims().to_vec();
1908 xt_dims.swap(x_rank - 2, x_rank - 1);
1909 let x_f32_2 = bwd.add_node(
1911 Op::Cast { to: dtype },
1912 vec![x_bwd],
1913 Shape::from_dims(x_shape.dims(), dtype),
1914 );
1915 let x_centered = bwd.binary(
1916 BinaryOp::Sub,
1917 x_f32_2,
1918 xzp_b,
1919 Shape::from_dims(x_shape.dims(), dtype),
1920 );
1921 let x_t = bwd.add_node(
1922 Op::Transpose { perm: x_perm },
1923 vec![x_centered],
1924 Shape::from_dims(&xt_dims, dtype),
1925 );
1926 let dw_f32 = bwd.matmul(
1927 x_t,
1928 upstream_scaled,
1929 Shape::from_dims(w_shape.dims(), dtype),
1930 );
1931 let dw = bwd.add_node(
1932 Op::Cast {
1933 to: w_shape.dtype(),
1934 },
1935 vec![dw_f32],
1936 w_shape,
1937 );
1938
1939 let bias_rank = bias_shape.rank();
1942 let reduce_axes: Vec<usize> = (0..upstream_shape.rank())
1943 .filter(|&i| i + bias_rank < upstream_shape.rank() || i == 0)
1944 .collect();
1945 let dbias_f32 = bwd.add_node(
1946 Op::Reduce {
1947 op: ReduceOp::Sum,
1948 axes: reduce_axes,
1949 keep_dim: false,
1950 },
1951 vec![upstream_scaled],
1952 Shape::from_dims(bias_shape.dims(), dtype),
1953 );
1954 let dbias = bwd.add_node(
1955 Op::Cast {
1956 to: bias_shape.dtype(),
1957 },
1958 vec![dbias_f32],
1959 bias_shape,
1960 );
1961
1962 vec![(0, dx), (1, dw), (2, dbias)]
1963 }
1964
1965 Op::QConv2d {
1966 kernel_size,
1967 stride,
1968 padding,
1969 dilation,
1970 groups,
1971 x_zp,
1972 w_zp,
1973 out_zp: _,
1974 mult,
1975 } => {
1976 let x_bwd = fwd_map[&node.inputs[0]];
1980 let w_bwd = fwd_map[&node.inputs[1]];
1981 let bias_bwd = fwd_map[&node.inputs[2]];
1982 let x_shape = bwd.node(x_bwd).shape.clone();
1983 let w_shape = bwd.node(w_bwd).shape.clone();
1984 let bias_shape = bwd.node(bias_bwd).shape.clone();
1985 let dtype = upstream_shape.dtype();
1986
1987 let x_f32 = bwd.add_node(
1989 Op::Cast { to: dtype },
1990 vec![x_bwd],
1991 Shape::from_dims(x_shape.dims(), dtype),
1992 );
1993 let w_f32 = bwd.add_node(
1994 Op::Cast { to: dtype },
1995 vec![w_bwd],
1996 Shape::from_dims(w_shape.dims(), dtype),
1997 );
1998 let xzp_c = scalar_const(*x_zp as f32, bwd);
1999 let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
2000 let x_centered = bwd.binary(
2001 BinaryOp::Sub,
2002 x_f32,
2003 xzp_b,
2004 Shape::from_dims(x_shape.dims(), dtype),
2005 );
2006 let wzp_c = scalar_const(*w_zp as f32, bwd);
2007 let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
2008 let w_centered = bwd.binary(
2009 BinaryOp::Sub,
2010 w_f32,
2011 wzp_b,
2012 Shape::from_dims(w_shape.dims(), dtype),
2013 );
2014
2015 let mult_c = scalar_const(*mult, bwd);
2017 let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
2018 let upstream_scaled =
2019 bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
2020
2021 let dx_f32 = bwd.conv2d_backward_input(
2023 upstream_scaled,
2024 w_centered,
2025 Shape::from_dims(x_shape.dims(), dtype),
2026 kernel_size.clone(),
2027 stride.clone(),
2028 padding.clone(),
2029 dilation.clone(),
2030 *groups,
2031 );
2032 let dx = bwd.add_node(
2033 Op::Cast {
2034 to: x_shape.dtype(),
2035 },
2036 vec![dx_f32],
2037 x_shape,
2038 );
2039 let dw_f32 = bwd.conv2d_backward_weight(
2040 x_centered,
2041 upstream_scaled,
2042 Shape::from_dims(w_shape.dims(), dtype),
2043 kernel_size.clone(),
2044 stride.clone(),
2045 padding.clone(),
2046 dilation.clone(),
2047 *groups,
2048 );
2049 let dw = bwd.add_node(
2050 Op::Cast {
2051 to: w_shape.dtype(),
2052 },
2053 vec![dw_f32],
2054 w_shape,
2055 );
2056
2057 let dbias_f32 = bwd.add_node(
2059 Op::Reduce {
2060 op: ReduceOp::Sum,
2061 axes: vec![0, 2, 3],
2062 keep_dim: false,
2063 },
2064 vec![upstream_scaled],
2065 Shape::from_dims(bias_shape.dims(), dtype),
2066 );
2067 let dbias = bwd.add_node(
2068 Op::Cast {
2069 to: bias_shape.dtype(),
2070 },
2071 vec![dbias_f32],
2072 bias_shape,
2073 );
2074
2075 vec![(0, dx), (1, dw), (2, dbias)]
2076 }
2077
2078 Op::TopK { .. } | Op::Sample { .. } => {
2080 vec![]
2084 }
2085
2086 Op::GaussianSplatRender {
2087 width,
2088 height,
2089 tile_size,
2090 radius_scale,
2091 alpha_cutoff,
2092 max_splat_steps,
2093 transmittance_threshold,
2094 max_list_entries,
2095 ..
2096 } => {
2097 use rlx_ir::ops::splat::{
2098 GaussianSplatBackwardParams, GaussianSplatInputs, GaussianSplatRenderParams,
2099 unpack_gaussian_splat_packed_grads,
2100 };
2101 let render = GaussianSplatRenderParams {
2102 width: *width,
2103 height: *height,
2104 tile_size: *tile_size,
2105 radius_scale: *radius_scale,
2106 alpha_cutoff: *alpha_cutoff,
2107 max_splat_steps: *max_splat_steps,
2108 transmittance_threshold: *transmittance_threshold,
2109 max_list_entries: *max_list_entries,
2110 };
2111 let inputs = GaussianSplatInputs {
2112 positions: fwd_map[&node.inputs[0]],
2113 scales: fwd_map[&node.inputs[1]],
2114 rotations: fwd_map[&node.inputs[2]],
2115 opacities: fwd_map[&node.inputs[3]],
2116 colors: fwd_map[&node.inputs[4]],
2117 sh_coeffs: fwd_map[&node.inputs[5]],
2118 meta: fwd_map[&node.inputs[6]],
2119 };
2120 let count = bwd.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
2121 let sh_len = bwd.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
2122 let meta_shape = bwd.shape(inputs.meta).clone();
2123 let packed = bwd.gaussian_splat_render_backward(
2124 inputs,
2125 upstream,
2126 GaussianSplatBackwardParams {
2127 render,
2128 loss_grad_clip: 1.0,
2129 sh_band: 0,
2130 max_anisotropy: 10.0,
2131 },
2132 );
2133 let sh_coeff_count = if count == 0 {
2134 1
2135 } else {
2136 (sh_len / (count * 3)).max(1)
2137 };
2138 let grads = unpack_gaussian_splat_packed_grads(bwd, packed, count, sh_coeff_count);
2139 let meta_n = meta_shape.num_elements().unwrap_or(0);
2140 let zero_meta = bwd.add_node(
2141 Op::Constant {
2142 data: vec![0u8; meta_n * meta_shape.dtype().size_bytes()],
2143 },
2144 vec![],
2145 meta_shape,
2146 );
2147 vec![
2148 (0, grads.positions),
2149 (1, grads.scales),
2150 (2, grads.rotations),
2151 (3, grads.opacities),
2152 (4, grads.colors),
2153 (5, grads.sh_coeffs),
2154 (6, zero_meta),
2155 ]
2156 }
2157
2158 Op::GaussianSplatRenderBackward { .. } => {
2159 vec![]
2161 }
2162
2163 Op::GaussianSplatPrepare { .. } | Op::GaussianSplatRasterize { .. } => {
2164 panic!(
2165 "autodiff: decomposed splat ops must be fused before AD — \
2166 `prepare_graph_for_ad` rewrites Prepare→Rasterize into \
2167 `GaussianSplatRender`, or use `Op::GaussianSplatRender` directly"
2168 );
2169 }
2170
2171 Op::CustomFn {
2193 vjp_body: Some(vjp_body),
2194 num_inputs,
2195 ..
2196 } => {
2197 let mut sub_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
2199
2200 let mut primal_input_ids: Vec<NodeId> = vjp_body
2204 .nodes()
2205 .iter()
2206 .filter_map(|n| match &n.op {
2207 Op::Input { name } if name != "primal_output" && name != "d_output" => {
2208 Some(n.id)
2209 }
2210 _ => None,
2211 })
2212 .collect();
2213 primal_input_ids.sort();
2214 assert_eq!(primal_input_ids.len(), *num_inputs as usize);
2215
2216 for sub_node in vjp_body.nodes() {
2219 let new_id = match &sub_node.op {
2220 Op::Input { name } if name == "primal_output" => fwd_map[&node.id],
2221 Op::Input { name } if name == "d_output" => upstream,
2222 Op::Input { .. } => {
2223 let idx = primal_input_ids
2225 .iter()
2226 .position(|&id| id == sub_node.id)
2227 .expect(
2228 "custom_fn vjp_body: primal Input \
2229 not found in primal list",
2230 );
2231 fwd_map[&node.inputs[idx]]
2232 }
2233 _ => {
2234 let new_inputs: Vec<NodeId> =
2235 sub_node.inputs.iter().map(|i| sub_to_bwd[i]).collect();
2236 bwd.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
2237 }
2238 };
2239 sub_to_bwd.insert(sub_node.id, new_id);
2240 }
2241
2242 let mut grads: Vec<(usize, NodeId)> = Vec::with_capacity(*num_inputs as usize);
2245 for (i, out_id) in vjp_body.outputs.iter().enumerate() {
2246 grads.push((i, sub_to_bwd[out_id]));
2247 }
2248 grads
2249 }
2250
2251 Op::CustomFn { vjp_body: None, .. } => {
2254 panic!(
2255 "autodiff: Op::CustomFn has no vjp_body and was not inlined. \
2256 This is an internal error in inline_custom_fn_for_autodiff."
2257 )
2258 }
2259
2260 Op::Custom { name, .. } => {
2265 let ext = rlx_ir::lookup_op(name).unwrap_or_else(|| {
2266 panic!(
2267 "autodiff: Op::Custom('{name}') is not registered \
2268 in the op registry — register it via \
2269 rlx_ir::register_op before compiling the graph"
2270 )
2271 });
2272 let mut ctx = rlx_ir::VjpContext {
2273 upstream,
2274 fwd_map,
2275 bwd,
2276 };
2277 ext.vjp(node, &mut ctx)
2278 }
2279
2280 Op::Conv2dBackwardInput {
2281 kernel_size,
2282 stride,
2283 padding,
2284 dilation,
2285 groups,
2286 } => {
2287 let dy_bwd = fwd_map[&node.inputs[0]];
2288 let w_bwd = fwd_map[&node.inputs[1]];
2289 let dy_shape = bwd.node(dy_bwd).shape.clone();
2290 let _x_shape = node.shape.clone();
2291 let d_dy = bwd.add_node(
2292 Op::Conv {
2293 kernel_size: kernel_size.clone(),
2294 stride: stride.clone(),
2295 padding: padding.clone(),
2296 dilation: dilation.clone(),
2297 groups: *groups,
2298 },
2299 vec![upstream, w_bwd],
2300 dy_shape,
2301 );
2302 vec![(0, d_dy)]
2303 }
2304
2305 Op::Conv2dBackwardWeight {
2306 kernel_size,
2307 stride,
2308 padding,
2309 dilation,
2310 groups,
2311 } => {
2312 let x_bwd = fwd_map[&node.inputs[0]];
2313 let dy_bwd = fwd_map[&node.inputs[1]];
2314 let x_shape = bwd.node(x_bwd).shape.clone();
2315 let dy_shape = bwd.node(dy_bwd).shape.clone();
2316 let d_x = bwd.conv2d_backward_input(
2317 dy_bwd,
2318 upstream,
2319 x_shape,
2320 kernel_size.clone(),
2321 stride.clone(),
2322 padding.clone(),
2323 dilation.clone(),
2324 *groups,
2325 );
2326 let d_dy = bwd.add_node(
2327 Op::Conv {
2328 kernel_size: kernel_size.clone(),
2329 stride: stride.clone(),
2330 padding: padding.clone(),
2331 dilation: dilation.clone(),
2332 groups: *groups,
2333 },
2334 vec![x_bwd, upstream],
2335 dy_shape,
2336 );
2337 vec![(0, d_x), (1, d_dy)]
2338 }
2339
2340 Op::Fft { inverse, norm } => {
2349 let n = rlx_ir::fft::fft_meta(bwd.shape(node.inputs[0])).n_complex;
2350 let s = norm.output_scale(n, *inverse) as f32;
2351 let z = if s != 1.0 {
2352 let sc = scalar_const(s, bwd);
2353 bwd.mul(upstream, sc)
2354 } else {
2355 upstream
2356 };
2357 let dx = bwd.fft(z, !*inverse);
2358 vec![(0, dx)]
2359 }
2360
2361 Op::LogMel => {
2362 let spec_bwd = fwd_map[&node.inputs[0]];
2363 let filt_bwd = fwd_map[&node.inputs[1]];
2364 let dx = bwd.log_mel_backward(spec_bwd, filt_bwd, upstream);
2365 vec![(0, dx)]
2366 }
2367
2368 other => panic!(
2372 "autodiff: no VJP rule for {other}. See the matching \
2373 entry in rlx-opt/src/autodiff.rs (catch-all panic) for \
2374 a pointer to what's needed to differentiate this op.",
2375 ),
2376 }
2377}
2378
2379fn materialize_bcasts_for_ad(g: Graph) -> Graph {
2412 use rlx_ir::op::BinaryOp;
2413
2414 let needs = g.nodes().iter().any(|n| {
2415 matches!(
2416 &n.op, Op::Scan { num_bcast, .. } if *num_bcast > 0
2417 )
2418 });
2419 if !needs {
2420 return g;
2421 }
2422
2423 let mut out = Graph::new(g.name.clone());
2424 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2425
2426 for node in g.nodes() {
2427 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2428 match &node.op {
2429 Op::Scan {
2430 body,
2431 length,
2432 save_trajectory,
2433 num_bcast,
2434 num_xs,
2435 num_checkpoints,
2436 } if *num_bcast > 0 => {
2437 let bcast_base = 1;
2442 let xs_base = 1 + *num_bcast as usize;
2443
2444 let mut new_scan_inputs = vec![new_inputs[0]];
2445
2446 let mut materialised_xs: Vec<NodeId> = Vec::new();
2448 for i in 0..*num_bcast as usize {
2449 let b_id = new_inputs[bcast_base + i];
2450 let b_shape = out.node(b_id).shape.clone();
2451 let dtype = b_shape.dtype();
2452
2453 let mut ones_dims: Vec<rlx_ir::Dim> =
2457 vec![rlx_ir::Dim::Static(*length as usize)];
2458 for _ in 0..b_shape.rank() {
2459 ones_dims.push(rlx_ir::Dim::Static(1));
2460 }
2461 let ones_shape = rlx_ir::Shape::from_dims(&ones_dims, dtype);
2462 let n_elems: usize = ones_dims
2463 .iter()
2464 .map(|d| match d {
2465 rlx_ir::Dim::Static(n) => *n,
2466 rlx_ir::Dim::Dynamic(_) => 1,
2467 })
2468 .product();
2469 let elem_size = dtype.size_bytes();
2470 let mut data = Vec::with_capacity(n_elems * elem_size);
2471 match dtype {
2472 rlx_ir::DType::F64 => {
2473 for _ in 0..n_elems {
2474 data.extend_from_slice(&1.0_f64.to_le_bytes());
2475 }
2476 }
2477 rlx_ir::DType::F32 => {
2478 for _ in 0..n_elems {
2479 data.extend_from_slice(&1.0_f32.to_le_bytes());
2480 }
2481 }
2482 other => {
2483 panic!("materialize_bcasts_for_ad: unsupported bcast dtype {other:?}")
2484 }
2485 }
2486 let ones = out.add_node(Op::Constant { data }, vec![], ones_shape);
2487
2488 let mut xs_dims: Vec<rlx_ir::Dim> = vec![rlx_ir::Dim::Static(*length as usize)];
2490 for i in 0..b_shape.rank() {
2491 xs_dims.push(b_shape.dim(i));
2492 }
2493 let xs_shape = rlx_ir::Shape::from_dims(&xs_dims, dtype);
2494 let xs_id = out.add_node(Op::Binary(BinaryOp::Mul), vec![ones, b_id], xs_shape);
2495 materialised_xs.push(xs_id);
2496 }
2497
2498 new_scan_inputs.extend_from_slice(&materialised_xs);
2499 for i in 0..*num_xs as usize {
2500 new_scan_inputs.push(new_inputs[xs_base + i]);
2501 }
2502
2503 let new_id = out.add_node(
2504 Op::Scan {
2505 body: body.clone(),
2506 length: *length,
2507 save_trajectory: *save_trajectory,
2508 num_bcast: 0,
2509 num_xs: *num_bcast + *num_xs,
2510 num_checkpoints: *num_checkpoints,
2511 },
2512 new_scan_inputs,
2513 node.shape.clone(),
2514 );
2515 id_map.insert(node.id, new_id);
2516 }
2517 _ => {
2518 let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2519 id_map.insert(node.id, new_id);
2520 }
2521 }
2522 }
2523
2524 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2525 out.set_outputs(new_outputs);
2526 out
2527}
2528
2529pub fn convert_scans_for_ad(g: Graph) -> Graph {
2530 use rlx_ir::shape::Shape as IrShape;
2531
2532 let g = materialize_bcasts_for_ad(g);
2537
2538 let needs = g.nodes().iter().any(|n| {
2541 matches!(
2542 &n.op,
2543 Op::Scan {
2544 save_trajectory: false,
2545 ..
2546 }
2547 )
2548 });
2549 if !needs {
2550 return g;
2551 }
2552
2553 let mut out = Graph::new(g.name.clone());
2554 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2555
2556 for node in g.nodes() {
2557 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2558 match &node.op {
2559 Op::Scan {
2560 body,
2561 length,
2562 save_trajectory: false,
2563 num_xs,
2564 num_checkpoints,
2565 ..
2566 } => {
2567 let carry_shape = node.shape.clone();
2568 let mut traj_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2589 traj_dims.push(Dim::Static(*length as usize));
2590 for i in 0..carry_shape.rank() {
2591 traj_dims.push(carry_shape.dim(i));
2592 }
2593 let traj_shape = IrShape::from_dims(&traj_dims, carry_shape.dtype());
2594 let traj = out.add_node(
2595 Op::Scan {
2596 body: body.clone(),
2597 length: *length,
2598 save_trajectory: true,
2599 num_bcast: 0,
2600 num_xs: *num_xs,
2601 num_checkpoints: *num_checkpoints,
2602 },
2603 new_inputs,
2604 traj_shape,
2605 );
2606 let mut narrow_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2608 narrow_dims.push(Dim::Static(1));
2609 for i in 0..carry_shape.rank() {
2610 narrow_dims.push(carry_shape.dim(i));
2611 }
2612 let narrow_shape = IrShape::from_dims(&narrow_dims, carry_shape.dtype());
2613 let narrowed = out.add_node(
2614 Op::Narrow {
2615 axis: 0,
2616 start: (*length as usize).saturating_sub(1),
2617 len: 1,
2618 },
2619 vec![traj],
2620 narrow_shape,
2621 );
2622 let new_shape: Vec<i64> = (0..carry_shape.rank())
2624 .map(|i| match carry_shape.dim(i) {
2625 Dim::Static(n) => n as i64,
2626 Dim::Dynamic(_) => -1,
2627 })
2628 .collect();
2629 let final_id = out.add_node(Op::Reshape { new_shape }, vec![narrowed], carry_shape);
2630 id_map.insert(node.id, final_id);
2631 }
2632 _ => {
2633 let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2634 id_map.insert(node.id, new_id);
2635 }
2636 }
2637 }
2638
2639 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2640 out.set_outputs(new_outputs);
2641 out
2642}
2643
2644pub fn inline_custom_fn_for_autodiff(g: Graph) -> Graph {
2649 use rlx_fusion::control_flow::inline_subgraph_into;
2650
2651 let mut out = Graph::new(g.name.clone());
2652 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2653 let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
2654
2655 for node in &nodes {
2656 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2657 let new_id = match &node.op {
2658 Op::CustomFn {
2659 vjp_body: None,
2660 jvp_body: None,
2661 fwd_body,
2662 num_inputs,
2663 ..
2664 } => {
2665 assert_eq!(
2666 new_inputs.len(),
2667 *num_inputs as usize,
2668 "custom_fn: outer input count mismatch"
2669 );
2670 inline_subgraph_into(fwd_body, &new_inputs, &mut out)
2671 }
2672 _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
2673 };
2674 id_map.insert(node.id, new_id);
2675 }
2676
2677 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
2678 out.set_outputs(new_outputs);
2679 out
2680}
2681
2682pub(crate) fn unbroadcast_inverse(x: NodeId, target: &Shape, bwd: &mut Graph) -> NodeId {
2686 let target_dims: Vec<i64> = target
2687 .dims()
2688 .iter()
2689 .map(|d| match d {
2690 Dim::Static(n) => *n as i64,
2691 Dim::Dynamic(_) => -1,
2692 })
2693 .collect();
2694 bwd.add_node(
2695 Op::Expand {
2696 target_shape: target_dims,
2697 },
2698 vec![x],
2699 target.clone(),
2700 )
2701}
2702
2703fn expand_to(
2708 grad: NodeId,
2709 x_shape: &Shape,
2710 axes: &[usize],
2711 keep_dim: bool,
2712 bwd: &mut Graph,
2713) -> NodeId {
2714 let mut current = grad;
2715 if !keep_dim {
2716 let kept_dims: Vec<Dim> = (0..x_shape.rank())
2719 .map(|i| {
2720 if axes.contains(&i) {
2721 Dim::Static(1)
2722 } else {
2723 x_shape.dim(i)
2724 }
2725 })
2726 .collect();
2727 let kept = Shape::from_dims(&kept_dims, x_shape.dtype());
2728 current = reshape_to(current, &kept, bwd);
2729 }
2730 let target_shape: Vec<i64> = x_shape
2731 .dims()
2732 .iter()
2733 .map(|d| match d {
2734 Dim::Static(n) => *n as i64,
2735 Dim::Dynamic(_) => -1,
2736 })
2737 .collect();
2738 bwd.add_node(Op::Expand { target_shape }, vec![current], x_shape.clone())
2739}
2740
2741#[cfg(test)]
2742mod tests {
2743 use super::*;
2744
2745 #[test]
2746 fn grad_of_add_is_identity() {
2747 let mut g = Graph::new("test");
2748 let x = g.input("x", Shape::new(&[4], DType::F32));
2749 let y = g.input("y", Shape::new(&[4], DType::F32));
2750 let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2751 g.set_outputs(vec![z]);
2752
2753 let bwd = grad(&g, &[x, y]);
2754 assert_eq!(bwd.outputs.len(), 2);
2756 }
2757
2758 #[test]
2759 fn grad_of_mul_uses_other_operand() {
2760 let mut g = Graph::new("test");
2761 let x = g.input("x", Shape::new(&[4], DType::F32));
2762 let y = g.input("y", Shape::new(&[4], DType::F32));
2763 let z = g.binary(BinaryOp::Mul, x, y, Shape::new(&[4], DType::F32));
2764 g.set_outputs(vec![z]);
2765
2766 let bwd = grad(&g, &[x, y]);
2767 assert!(
2769 bwd.nodes()
2770 .iter()
2771 .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
2772 .count()
2773 >= 2
2774 );
2775 }
2776
2777 #[test]
2778 fn grad_with_loss_returns_loss_first() {
2779 let mut g = Graph::new("loss");
2780 let x = g.input("x", Shape::new(&[4], DType::F32));
2781 let y = g.input("y", Shape::new(&[4], DType::F32));
2782 let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2783 g.set_outputs(vec![z]);
2784
2785 let bwd = grad_with_loss(&g, &[x, y]);
2786 assert_eq!(bwd.outputs.len(), 3);
2788 }
2789
2790 #[test]
2791 fn grad_of_dense_solve_emits_implicit_function_rule() {
2792 let mut g = Graph::new("solve_test");
2806 let a = g.param("A", Shape::new(&[2, 2], DType::F32));
2807 let b = g.input("b", Shape::new(&[2], DType::F32));
2808 let x = g.dense_solve(a, b, Shape::new(&[2], DType::F32));
2809 let loss = g.reduce(
2810 x,
2811 ReduceOp::Sum,
2812 vec![0],
2813 false,
2814 Shape::new(&[1], DType::F32),
2815 );
2816 g.set_outputs(vec![loss]);
2817
2818 let bwd = grad_with_loss(&g, &[a, b]);
2819 assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
2820
2821 let count =
2822 |pred: fn(&Op) -> bool| -> usize { bwd.nodes().iter().filter(|n| pred(&n.op)).count() };
2823
2824 assert!(
2827 count(|o| matches!(o, Op::DenseSolve)) >= 2,
2828 "expected ≥2 DenseSolve nodes (forward mirror + reverse), got\n{bwd}"
2829 );
2830 assert!(
2831 count(|o| matches!(o, Op::Transpose { .. })) >= 1,
2832 "expected a Transpose for Aᵀ, got\n{bwd}"
2833 );
2834 assert!(
2835 count(|o| matches!(o, Op::MatMul)) >= 1,
2836 "expected a MatMul for the outer product, got\n{bwd}"
2837 );
2838 assert!(
2839 count(|o| matches!(o, Op::Activation(Activation::Neg))) >= 1,
2840 "expected a Neg for −outer, got\n{bwd}"
2841 );
2842 }
2843
2844 #[test]
2845 fn inline_if_replaces_with_where() {
2846 let s = Shape::new(&[4], DType::F32);
2853 let pred_s = Shape::new(&[1], DType::F32);
2854
2855 let mut then_g = Graph::new("then_branch");
2856 let then_in = then_g.input("captured", s.clone());
2857 let then_out = then_g.activation(Activation::Relu, then_in, s.clone());
2858 then_g.set_outputs(vec![then_out]);
2859
2860 let mut else_g = Graph::new("else_branch");
2861 let else_in = else_g.input("captured", s.clone());
2862 let else_out = else_g.activation(Activation::Sigmoid, else_in, s.clone());
2863 else_g.set_outputs(vec![else_out]);
2864
2865 let mut g = Graph::new("parent");
2866 let x = g.input("x", s.clone());
2867 let pred = g.input("pred", pred_s);
2868 let if_out = g.add_node(
2869 Op::If {
2870 then_branch: Box::new(then_g),
2871 else_branch: Box::new(else_g),
2872 },
2873 vec![pred, x],
2874 s,
2875 );
2876 g.set_outputs(vec![if_out]);
2877
2878 let inlined = rlx_fusion::control_flow::inline_if(g);
2879
2880 let has_if = inlined
2884 .nodes()
2885 .iter()
2886 .any(|n| matches!(n.op, Op::If { .. }));
2887 let has_where = inlined.nodes().iter().any(|n| matches!(n.op, Op::Where));
2888 let has_relu = inlined
2889 .nodes()
2890 .iter()
2891 .any(|n| matches!(n.op, Op::Activation(Activation::Relu)));
2892 let has_sigmoid = inlined
2893 .nodes()
2894 .iter()
2895 .any(|n| matches!(n.op, Op::Activation(Activation::Sigmoid)));
2896 assert!(!has_if, "Op::If should be inlined away");
2897 assert!(has_where, "Op::Where should replace the Op::If");
2898 assert!(has_relu, "then_branch's Activation(Relu) should be inlined");
2899 assert!(
2900 has_sigmoid,
2901 "else_branch's Activation(Sigmoid) should be inlined"
2902 );
2903 assert_eq!(inlined.outputs.len(), 1);
2904 }
2905
2906 #[test]
2907 fn grad_through_if_propagates() {
2908 let s = Shape::new(&[4], DType::F32);
2911 let pred_s = Shape::new(&[1], DType::F32);
2912
2913 let mut then_g = Graph::new("th");
2914 let ti = then_g.input("c", s.clone());
2915 let to = then_g.binary(BinaryOp::Mul, ti, ti, s.clone());
2916 then_g.set_outputs(vec![to]);
2917
2918 let mut else_g = Graph::new("el");
2919 let ei = else_g.input("c", s.clone());
2920 let eo = else_g.activation(Activation::Relu, ei, s.clone());
2921 else_g.set_outputs(vec![eo]);
2922
2923 let mut g = Graph::new("parent");
2924 let x = g.input("x", s.clone());
2925 let pred = g.input("pred", pred_s);
2926 let z = g.add_node(
2927 Op::If {
2928 then_branch: Box::new(then_g),
2929 else_branch: Box::new(else_g),
2930 },
2931 vec![pred, x],
2932 s,
2933 );
2934 g.set_outputs(vec![z]);
2935
2936 let bwd = grad_with_loss(&g, &[x]);
2937 assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
2939 }
2940
2941 #[test]
2942 fn unroll_while_replicates_body_n_times() {
2943 let s = Shape::new(&[4], DType::F32);
2949 let bool_s = Shape::new(&[1], DType::F32);
2950
2951 let mut cond_g = Graph::new("cond");
2952 let ci = cond_g.input("c", s.clone());
2953 cond_g.set_outputs(vec![ci]);
2956 let _ = bool_s;
2959
2960 let mut body_g = Graph::new("body");
2961 let bi = body_g.input("c", s.clone());
2962 let bo = body_g.activation(Activation::Relu, bi, s.clone());
2963 body_g.set_outputs(vec![bo]);
2964
2965 let mut g = Graph::new("parent");
2966 let x = g.input("x", s.clone());
2967 let w = g.add_node(
2968 Op::While {
2969 cond: Box::new(cond_g),
2970 body: Box::new(body_g),
2971 max_iterations: Some(3),
2972 },
2973 vec![x],
2974 s,
2975 );
2976 g.set_outputs(vec![w]);
2977
2978 let unrolled = rlx_fusion::control_flow::unroll_while(g);
2979
2980 let has_while = unrolled
2981 .nodes()
2982 .iter()
2983 .any(|n| matches!(n.op, Op::While { .. }));
2984 let relu_count = unrolled
2985 .nodes()
2986 .iter()
2987 .filter(|n| matches!(n.op, Op::Activation(Activation::Relu)))
2988 .count();
2989 assert!(!has_while, "Op::While should be unrolled away");
2990 assert_eq!(
2991 relu_count, 3,
2992 "body's Activation(Relu) should appear once per iteration"
2993 );
2994 assert_eq!(unrolled.outputs.len(), 1);
2995 }
2996
2997 #[test]
2998 fn grad_through_while_propagates() {
2999 let s = Shape::new(&[4], DType::F32);
3003
3004 let mut cond_g = Graph::new("cond");
3005 let ci = cond_g.input("c", s.clone());
3006 cond_g.set_outputs(vec![ci]);
3007
3008 let mut body_g = Graph::new("body");
3009 let bi = body_g.input("c", s.clone());
3010 let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
3011 body_g.set_outputs(vec![bo]);
3012
3013 let mut g = Graph::new("parent");
3014 let x = g.input("x", s.clone());
3015 let w = g.add_node(
3016 Op::While {
3017 cond: Box::new(cond_g),
3018 body: Box::new(body_g),
3019 max_iterations: Some(2),
3020 },
3021 vec![x],
3022 s,
3023 );
3024 g.set_outputs(vec![w]);
3025
3026 let bwd = grad_with_loss(&g, &[x]);
3027 assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
3028 }
3029
3030 fn build_ftl_graph(has_bias: bool) -> (Graph, NodeId, Vec<NodeId>) {
3033 let mut g = Graph::new("ftl_test");
3035 let h_shape = Shape::new(&[1, 2, 4], DType::F32);
3036 let h = g.input("h", h_shape.clone());
3037 let qkv_w = g.param("qkv_w", Shape::new(&[4, 12], DType::F32));
3038 let out_w = g.param("out_w", Shape::new(&[4, 4], DType::F32));
3039 let ln1_g = g.param("ln1_g", Shape::new(&[4], DType::F32));
3040 let fc1_w = g.param("fc1_w", Shape::new(&[4, 8], DType::F32));
3041 let fc2_w = g.param("fc2_w", Shape::new(&[8, 4], DType::F32));
3042 let ln2_g = g.param("ln2_g", Shape::new(&[4], DType::F32));
3043 let mask = g.input("mask", Shape::new(&[1, 2, 2, 2], DType::F32));
3044
3045 let (inputs, params) = if has_bias {
3046 let qkv_b = g.param("qkv_b", Shape::new(&[12], DType::F32));
3047 let out_b = g.param("out_b", Shape::new(&[4], DType::F32));
3048 let ln1_b = g.param("ln1_b", Shape::new(&[4], DType::F32));
3049 let fc1_b = g.param("fc1_b", Shape::new(&[8], DType::F32));
3050 let fc2_b = g.param("fc2_b", Shape::new(&[4], DType::F32));
3051 let ln2_b = g.param("ln2_b", Shape::new(&[4], DType::F32));
3052 (
3053 vec![
3054 h, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3055 ln2_b, mask,
3056 ],
3057 vec![
3058 qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3059 ln2_b,
3060 ],
3061 )
3062 } else {
3063 (
3064 vec![h, qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g, mask],
3065 vec![qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g],
3066 )
3067 };
3068 let y = g.add_node(
3069 Op::FusedTransformerLayer {
3070 num_heads: 2,
3071 head_dim: 2,
3072 intermediate_size: 8,
3073 eps1: 1e-5,
3074 eps2: 1e-5,
3075 activation: rlx_ir::op::Activation::Gelu,
3076 has_bias,
3077 },
3078 inputs,
3079 h_shape,
3080 );
3081 g.set_outputs(vec![y]);
3082 (g, h, params)
3083 }
3084
3085 #[test]
3086 fn unfuse_decomposes_fused_transformer_layer() {
3087 let (g, _h, _params) = build_ftl_graph(true);
3091 let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3092
3093 let has_ftl = unfused
3094 .nodes()
3095 .iter()
3096 .any(|n| matches!(n.op, Op::FusedTransformerLayer { .. }));
3097 assert!(!has_ftl, "Op::FusedTransformerLayer should be unfused");
3098
3099 let count = |pred: fn(&Op) -> bool| -> usize {
3100 unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3101 };
3102 assert!(
3103 count(|o| matches!(o, Op::MatMul)) >= 4,
3104 "expected >=4 MatMul after FTL unfuse"
3105 );
3106 assert_eq!(
3107 count(|o| matches!(o, Op::Attention { .. })),
3108 1,
3109 "expected exactly 1 Attention after FTL unfuse"
3110 );
3111 assert_eq!(
3112 count(|o| matches!(o, Op::LayerNorm { .. })),
3113 2,
3114 "expected exactly 2 LayerNorm after FTL unfuse"
3115 );
3116 assert!(
3117 count(|o| matches!(o, Op::Narrow { .. })) >= 3,
3118 "expected >=3 Narrow (Q/K/V split) after FTL unfuse"
3119 );
3120 assert_eq!(
3121 count(|o| matches!(o, Op::Activation(_))),
3122 1,
3123 "expected exactly 1 Activation (FFN) after FTL unfuse"
3124 );
3125 }
3126
3127 #[test]
3128 fn grad_through_fused_transformer_layer_propagates() {
3129 let (g, _h, params) = build_ftl_graph(true);
3133 let bwd = grad_with_loss(&g, ¶ms);
3134 assert_eq!(
3135 bwd.outputs.len(),
3136 1 + params.len(),
3137 "expected loss + {} param grads",
3138 params.len()
3139 );
3140 }
3141
3142 #[test]
3143 fn grad_through_fused_transformer_layer_no_bias() {
3144 let (g, _h, params) = build_ftl_graph(false);
3147 let bwd = grad_with_loss(&g, ¶ms);
3148 assert_eq!(
3149 bwd.outputs.len(),
3150 1 + params.len(),
3151 "expected loss + {} param grads (no-bias)",
3152 params.len()
3153 );
3154 }
3155
3156 fn build_ssm_graph() -> (Graph, NodeId, Vec<NodeId>) {
3159 let mut g = Graph::new("ssm_test");
3160 let bsh = Shape::new(&[1, 3, 2], DType::F32);
3161 let hn = Shape::new(&[2, 4], DType::F32);
3162 let bsn = Shape::new(&[1, 3, 4], DType::F32);
3163
3164 let x = g.input("x", bsh.clone());
3165 let delta = g.input("delta", bsh.clone());
3166 let a = g.param("a", hn);
3167 let b = g.input("b", bsn.clone());
3168 let c = g.input("c", bsn);
3169 let y = g.selective_scan(x, delta, a, b, c, 4, bsh);
3170 g.set_outputs(vec![y]);
3171 (g, x, vec![a])
3172 }
3173
3174 #[test]
3175 fn unfuse_decomposes_selective_scan() {
3176 let (g, _x, _params) = build_ssm_graph();
3181 let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3182
3183 let has_ssm = unfused
3184 .nodes()
3185 .iter()
3186 .any(|n| matches!(n.op, Op::SelectiveScan { .. }));
3187 assert!(!has_ssm, "Op::SelectiveScan should be unfused");
3188
3189 let count = |pred: fn(&Op) -> bool| -> usize {
3190 unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3191 };
3192 assert_eq!(
3193 count(|o| matches!(o, Op::Concat { .. })),
3194 1,
3195 "expected 1 Concat (over the 3 time steps)"
3196 );
3197 assert_eq!(
3198 count(|o| matches!(
3199 o,
3200 Op::Reduce {
3201 op: ReduceOp::Sum,
3202 ..
3203 }
3204 )),
3205 3,
3206 "expected one Reduce(Sum) per time step (S=3)"
3207 );
3208 assert_eq!(
3209 count(|o| matches!(o, Op::Activation(Activation::Exp))),
3210 3,
3211 "expected one exp(δA) per time step (S=3)"
3212 );
3213 assert!(
3214 count(|o| matches!(o, Op::Narrow { .. })) >= 12,
3215 "expected >=12 Narrows (4 per step × 3 steps)"
3216 );
3217 }
3218
3219 #[test]
3220 fn grad_through_selective_scan_propagates() {
3221 let (g, _x, params) = build_ssm_graph();
3227 let bwd = grad_with_loss(&g, ¶ms);
3228 assert_eq!(
3229 bwd.outputs.len(),
3230 1 + params.len(),
3231 "expected loss + {} param grads",
3232 params.len()
3233 );
3234 }
3235
3236 fn build_gdn_graph() -> (Graph, NodeId, Vec<NodeId>) {
3238 let (b, s, h, n) = (1usize, 3, 2, 4);
3239 let mut g = Graph::new("gdn_test");
3240 let bshn = Shape::new(&[b, s, h, n], DType::F32);
3241 let bsh = Shape::new(&[b, s, h], DType::F32);
3242 let q = g.input("q", bshn.clone());
3243 let k = g.input("k", bshn.clone());
3244 let v = g.input("v", bshn.clone());
3245 let g_in = g.input("g", bsh.clone());
3246 let beta = g.input("beta", bsh);
3247 let y = g.gated_delta_net(q, k, v, g_in, beta, n, bshn);
3248 g.set_outputs(vec![y]);
3249 (g, q, vec![q, k, v, g_in, beta])
3250 }
3251
3252 #[test]
3253 fn unfuse_decomposes_gated_delta_net() {
3254 let (g, _q, _params) = build_gdn_graph();
3255 let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3256
3257 let has_gdn = unfused
3258 .nodes()
3259 .iter()
3260 .any(|n| matches!(n.op, Op::GatedDeltaNet { .. }));
3261 assert!(!has_gdn, "Op::GatedDeltaNet should be unfused");
3262
3263 let count = |pred: fn(&Op) -> bool| -> usize {
3264 unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3265 };
3266 assert_eq!(
3267 count(|o| matches!(o, Op::Concat { .. })),
3268 1,
3269 "expected 1 Concat over S=3 steps"
3270 );
3271 assert!(
3272 count(|o| matches!(o, Op::MatMul)) >= 3,
3273 "expected >=3 MatMul per step (sk + out) × S=3"
3274 );
3275 assert_eq!(
3276 count(|o| matches!(o, Op::Activation(Activation::Exp))),
3277 3,
3278 "expected one exp(g) per time step"
3279 );
3280 }
3281
3282 #[test]
3283 fn grad_through_gated_delta_net_propagates() {
3284 let (g, _q, params) = build_gdn_graph();
3285 let bwd = grad_with_loss(&g, ¶ms);
3286 assert_eq!(
3287 bwd.outputs.len(),
3288 1 + params.len(),
3289 "expected loss + {} input grads",
3290 params.len()
3291 );
3292 }
3293
3294 #[test]
3295 fn custom_fn_vjp_body_is_inlined_into_bwd() {
3296 let n = 4usize;
3304 let shape = Shape::new(&[n], DType::F32);
3305
3306 let mut fwd_body = Graph::new("square_fwd");
3308 let xb = fwd_body.input("x", shape.clone());
3309 let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
3310 fwd_body.set_outputs(vec![yb]);
3311
3312 let mut vjp_body = Graph::new("square_vjp");
3314 let _vx = vjp_body.input("x", shape.clone());
3315 let _vp = vjp_body.input("primal_output", shape.clone());
3316 let vd = vjp_body.input("d_output", shape.clone());
3317 let dx = vjp_body.activation(Activation::Sin, vd, shape.clone());
3318 vjp_body.set_outputs(vec![dx]);
3319
3320 let mut g = Graph::new("custom_fn_test");
3321 let x = g.input("x", shape.clone());
3322 let y = g.custom_fn(vec![x], fwd_body, Some(vjp_body), None);
3323 let loss = g.reduce(
3324 y,
3325 ReduceOp::Sum,
3326 vec![0],
3327 false,
3328 Shape::new(&[1], DType::F32),
3329 );
3330 g.set_outputs(vec![loss]);
3331
3332 let bwd = grad_with_loss(&g, &[x]);
3333 assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
3334 let sin_count = bwd
3335 .nodes()
3336 .iter()
3337 .filter(|n| matches!(n.op, Op::Activation(Activation::Sin)))
3338 .count();
3339 assert!(
3340 sin_count >= 1,
3341 "expected the vjp_body's Sin to be inlined into bwd, got\n{bwd}"
3342 );
3343 }
3344
3345 #[test]
3346 fn custom_fn_without_vjp_inlines_fwd_body_for_autodiff() {
3347 let n = 4usize;
3351 let shape = Shape::new(&[n], DType::F32);
3352
3353 let mut fwd_body = Graph::new("square_fwd");
3354 let xb = fwd_body.input("x", shape.clone());
3355 let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
3356 fwd_body.set_outputs(vec![yb]);
3357
3358 let mut g = Graph::new("custom_fn_no_vjp");
3359 let x = g.input("x", shape.clone());
3360 let y = g.custom_fn(vec![x], fwd_body, None, None);
3361 let loss = g.reduce(
3362 y,
3363 ReduceOp::Sum,
3364 vec![0],
3365 false,
3366 Shape::new(&[1], DType::F32),
3367 );
3368 g.set_outputs(vec![loss]);
3369
3370 let bwd = grad_with_loss(&g, &[x]);
3371 assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
3372 let custom_fn_count = bwd
3373 .nodes()
3374 .iter()
3375 .filter(|n| matches!(n.op, Op::CustomFn { .. }))
3376 .count();
3377 assert_eq!(
3378 custom_fn_count, 0,
3379 "CustomFn should be inlined away before autodiff"
3380 );
3381 let mul_count = bwd
3382 .nodes()
3383 .iter()
3384 .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
3385 .count();
3386 assert!(mul_count >= 2, "expected Mul-based VJP for x², got\n{bwd}");
3387 }
3388
3389 #[test]
3390 fn convert_scans_for_ad_forces_save_trajectory_true() {
3391 let n = 2usize;
3398 let length = 3u32;
3399 let carry = Shape::new(&[n], DType::F32);
3400 let xs_shape = Shape::new(&[length as usize, n], DType::F32);
3401
3402 let mut body = Graph::new("scan_body");
3404 let bc = body.input("carry", carry.clone());
3405 let bx = body.input("x_t", carry.clone());
3406 let by = body.binary(BinaryOp::Add, bc, bx, carry.clone());
3407 body.set_outputs(vec![by]);
3408
3409 let mut g = Graph::new("scan_save_false");
3410 let init = g.input("init", carry.clone());
3411 let xs = g.input("xs", xs_shape);
3412 let scan_out = g.add_node(
3413 Op::Scan {
3414 body: Box::new(body),
3415 length,
3416 save_trajectory: false,
3417 num_bcast: 0,
3418 num_xs: 1,
3419 num_checkpoints: 0,
3420 },
3421 vec![init, xs],
3422 carry.clone(),
3423 );
3424 let loss = g.reduce(
3425 scan_out,
3426 ReduceOp::Sum,
3427 vec![0],
3428 false,
3429 Shape::new(&[1], DType::F32),
3430 );
3431 g.set_outputs(vec![loss]);
3432
3433 let bwd = grad_with_loss(&g, &[init, xs]);
3434 let saved_traj = bwd.nodes().iter().any(|n| {
3435 matches!(
3436 &n.op,
3437 Op::Scan {
3438 save_trajectory: true,
3439 ..
3440 }
3441 )
3442 });
3443 assert!(
3444 saved_traj,
3445 "convert_scans_for_ad should rewrite save_trajectory=false → \
3446 save_trajectory=true in the AD-prepared graph; got\n{bwd}"
3447 );
3448 }
3449}