1use rlx_ir::op::*;
114use rlx_ir::shape::Dim;
115use rlx_ir::*;
116use std::collections::HashMap;
117
118pub use crate::prepare_ad::{
119 AutodiffError, PrepareForAutodiff, grad_with_loss_module, jvp_module, prepare_graph_for_ad,
120 prepare_mir_for_ad, prepare_module_for_ad,
121};
122
/// Knobs for [`grad_with_loss_opts`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GradWithLossOptions {
    /// When `true`, a `wrt` node that no gradient reaches is given an all-zero constant
    /// gradient; when `false`, such a node makes gradient construction panic.
    pub zero_missing_wrt: bool,
}

impl GradWithLossOptions {
    /// Panic on any `wrt` node without a gradient (the mode [`grad_with_loss`] uses).
    pub const STRICT: Self = Self { zero_missing_wrt: false };

    /// Zero-fill gradients for `wrt` nodes the loss does not depend on.
    pub const TRAINING: Self = Self { zero_missing_wrt: true };
}
163
164pub fn grad_with_loss(forward: &Graph, wrt: &[NodeId]) -> Graph {
170 grad_with_loss_opts(forward, wrt, GradWithLossOptions::STRICT)
171}
172
173pub fn grad_with_loss_opts(forward: &Graph, wrt: &[NodeId], opts: GradWithLossOptions) -> Graph {
175 assert!(
176 !forward.outputs.is_empty(),
177 "grad_with_loss: forward must have at least one output (the loss)"
178 );
179
180 let forward_owned = crate::prepare_ad::prepare_graph_for_ad(forward.clone());
193 let forward = &forward_owned;
194
195 let mut bwd = Graph::new(format!("{}_grad", forward.name));
196
197 let mut fwd_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
202 for node in forward.nodes() {
203 let inputs: Vec<NodeId> = node.inputs.iter().map(|i| fwd_to_bwd[i]).collect();
204 let new_id = bwd.add_node(node.op.clone(), inputs, node.shape.clone());
205 fwd_to_bwd.insert(node.id, new_id);
206 }
207
208 let loss_fwd = forward.outputs[0];
211 let loss_bwd = fwd_to_bwd[&loss_fwd];
212 let loss_shape = forward.node(loss_fwd).shape.clone();
213 let d_output = bwd.input("d_output", loss_shape);
214
215 let mut grads: HashMap<NodeId, NodeId> = HashMap::new();
216 grads.insert(loss_bwd, d_output);
217
218 for fwd_node in forward.nodes().iter().rev() {
219 let bwd_id = fwd_to_bwd[&fwd_node.id];
220 let upstream = match grads.get(&bwd_id) {
221 Some(g) => *g,
222 None => continue,
223 };
224 let input_grads = vjp(fwd_node, upstream, &fwd_to_bwd, &mut bwd);
225 for (idx, grad_id) in input_grads {
226 let target = fwd_node.inputs[idx];
227 let bwd_target = fwd_to_bwd[&target];
228 let new_grad = if let Some(&prev) = grads.get(&bwd_target) {
230 let shape = bwd.node(prev).shape.clone();
231 bwd.binary(BinaryOp::Add, prev, grad_id, shape)
232 } else {
233 grad_id
234 };
235 grads.insert(bwd_target, new_grad);
236 }
237 }
238
239 let n_aux = forward.outputs.len().saturating_sub(1);
240 let mut outputs = Vec::with_capacity(1 + n_aux + wrt.len());
241 outputs.push(loss_bwd);
242 for &aux in &forward.outputs[1..] {
245 outputs.push(fwd_to_bwd[&aux]);
246 }
247 for &id in wrt {
248 let g = match grads.get(&fwd_to_bwd[&id]).copied() {
249 Some(g) => g,
250 None if opts.zero_missing_wrt => {
251 let shape = forward.node(id).shape.clone();
252 let n = shape.num_elements().unwrap_or(0);
253 let data: Vec<u8> = (0..n).flat_map(|_| 0.0f32.to_le_bytes()).collect();
254 bwd.add_node(Op::Constant { data }, vec![], shape)
255 }
256 None => {
257 panic!(
258 "no gradient flowed to {id} — \
259 either the forward graph doesn't depend on it, or one \
260 of its consumer ops has no VJP rule"
261 )
262 }
263 };
264 outputs.push(g);
265 }
266 bwd.set_outputs(outputs);
267 bwd
268}
269
270pub fn grad(forward: &Graph, wrt: &[NodeId]) -> Graph {
274 let g = grad_with_loss(forward, wrt);
275 let mut g = g;
276 let outs = g.outputs.iter().skip(1).copied().collect();
278 g.set_outputs(outs);
279 g
280}
281
282pub fn quantized_weight_bits(forward: &Graph, node_id: NodeId) -> Option<u8> {
295 match &forward.node(node_id).op {
296 Op::FakeQuantize { bits, .. } => Some(*bits),
297 Op::FakeQuantizeLSQ { bits, .. } => Some(*bits),
298 _ => None,
299 }
300}
301
302fn unbroadcast(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
303 let grad_shape = bwd.node(grad).shape.clone();
304 if grad_shape == *target_shape {
305 return grad;
306 }
307 let g_rank = grad_shape.rank();
308 let t_rank = target_shape.rank();
309 let extra = g_rank.saturating_sub(t_rank);
310
311 let mut axes: Vec<usize> = (0..extra).collect();
313 for i in 0..t_rank {
314 let g_dim = grad_shape.dim(extra + i);
315 let t_dim = target_shape.dim(i);
316 if matches!(t_dim, Dim::Static(1)) && !matches!(g_dim, Dim::Static(1)) {
317 axes.push(extra + i);
318 }
319 }
320
321 let mut current = grad;
322 if !axes.is_empty() {
323 let mut running_dims: Vec<Dim> = (0..g_rank).map(|i| grad_shape.dim(i)).collect();
330 for &ax in &axes {
331 running_dims[ax] = Dim::Static(1);
332 let step_shape = Shape::from_dims(&running_dims, target_shape.dtype());
333 current = bwd.add_node(
334 Op::Reduce {
335 op: ReduceOp::Sum,
336 axes: vec![ax],
337 keep_dim: true,
338 },
339 vec![current],
340 step_shape,
341 );
342 }
343 }
344
345 if bwd.node(current).shape.rank() != t_rank {
347 let new_shape: Vec<i64> = target_shape
348 .dims()
349 .iter()
350 .map(|d| match d {
351 Dim::Static(n) => *n as i64,
352 Dim::Dynamic(_) => -1,
353 })
354 .collect();
355 current = bwd.add_node(
356 Op::Reshape { new_shape },
357 vec![current],
358 target_shape.clone(),
359 );
360 }
361 current
362}
363
364fn reshape_to(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
366 if bwd.node(grad).shape == *target_shape {
367 return grad;
368 }
369 let new_shape: Vec<i64> = target_shape
370 .dims()
371 .iter()
372 .map(|d| match d {
373 Dim::Static(n) => *n as i64,
374 Dim::Dynamic(_) => -1,
375 })
376 .collect();
377 bwd.add_node(Op::Reshape { new_shape }, vec![grad], target_shape.clone())
378}
379
380fn grouped_matmul_vjp(
382 bwd: &mut Graph,
383 upstream: NodeId,
384 x_bwd: NodeId,
385 w_bwd: NodeId,
386 expert_bwd: NodeId,
387 x_shape: &Shape,
388 w_shape: &Shape,
389) -> (NodeId, NodeId) {
390 let dtype = x_shape.dtype();
391 let m = x_shape.dim(0);
392 let k = x_shape.dim(1);
393 let e = w_shape.dim(0);
394 let n_out = w_shape.dim(2);
395 let m_static = match m {
396 Dim::Static(v) => v,
397 _ => panic!("GroupedMatMul VJP: M must be static"),
398 };
399 let k_static = match k {
400 Dim::Static(v) => v,
401 _ => panic!("GroupedMatMul VJP: K must be static"),
402 };
403 let n_static = match n_out {
404 Dim::Static(v) => v,
405 _ => panic!("GroupedMatMul VJP: N must be static"),
406 };
407
408 let w_per = bwd.add_node(
409 Op::Gather { axis: 0 },
410 vec![w_bwd, expert_bwd],
411 Shape::from_dims(&[m, k, n_out], dtype),
412 );
413
414 let up_3d_shape = Shape::from_dims(&[m, Dim::Static(1), n_out], dtype);
415 let up_3d = bwd.reshape(
416 upstream,
417 vec![m_static as i64, 1, n_static as i64],
418 up_3d_shape,
419 );
420 let w_per_t = bwd.add_node(
421 Op::Transpose {
422 perm: vec![0, 2, 1],
423 },
424 vec![w_per],
425 Shape::from_dims(&[m, n_out, k], dtype),
426 );
427 let dx_3d_shape = Shape::from_dims(&[m, Dim::Static(1), k], dtype);
428 let dx_3d = bwd.matmul(up_3d, w_per_t, dx_3d_shape);
429 let dx = bwd.reshape(
430 dx_3d,
431 vec![m_static as i64, k_static as i64],
432 x_shape.clone(),
433 );
434
435 let x_3d = bwd.reshape(
436 x_bwd,
437 vec![m_static as i64, k_static as i64, 1],
438 Shape::from_dims(&[m, k, Dim::Static(1)], dtype),
439 );
440 let up_for_outer = bwd.reshape(
441 upstream,
442 vec![m_static as i64, 1, n_static as i64],
443 Shape::from_dims(&[m, Dim::Static(1), n_out], dtype),
444 );
445 let dw_per = bwd.matmul(x_3d, up_for_outer, Shape::from_dims(&[m, k, n_out], dtype));
446 let dw = bwd.add_node(
447 Op::ScatterAdd,
448 vec![dw_per, expert_bwd],
449 Shape::from_dims(&[e, k, n_out], dtype),
450 );
451 (dx, dw)
452}
453
454fn scalar_const(value: f32, bwd: &mut Graph) -> NodeId {
456 let bytes = value.to_le_bytes().to_vec();
457 let shape = Shape::from_dims(&[Dim::Static(1)], DType::F32);
458 bwd.add_node(Op::Constant { data: bytes }, vec![], shape)
459}
460
461fn vjp(
465 node: &Node,
466 upstream: NodeId,
467 fwd_map: &HashMap<NodeId, NodeId>,
468 bwd: &mut Graph,
469) -> Vec<(usize, NodeId)> {
470 let upstream_shape = bwd.node(upstream).shape.clone();
471 match &node.op {
472 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => vec![],
474
475 Op::Binary(BinaryOp::Add) => {
476 let a_bwd = fwd_map[&node.inputs[0]];
477 let b_bwd = fwd_map[&node.inputs[1]];
478 let a_shape = bwd.node(a_bwd).shape.clone();
479 let b_shape = bwd.node(b_bwd).shape.clone();
480 let g_a = unbroadcast(upstream, &a_shape, bwd);
481 let g_b = unbroadcast(upstream, &b_shape, bwd);
482 vec![(0, g_a), (1, g_b)]
483 }
484
485 Op::Binary(BinaryOp::Sub) => {
486 let a_bwd = fwd_map[&node.inputs[0]];
487 let b_bwd = fwd_map[&node.inputs[1]];
488 let a_shape = bwd.node(a_bwd).shape.clone();
489 let b_shape = bwd.node(b_bwd).shape.clone();
490 let neg = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
491 let g_a = unbroadcast(upstream, &a_shape, bwd);
492 let g_b = unbroadcast(neg, &b_shape, bwd);
493 vec![(0, g_a), (1, g_b)]
494 }
495
496 Op::Binary(BinaryOp::Mul) => {
497 let a_bwd = fwd_map[&node.inputs[0]];
498 let b_bwd = fwd_map[&node.inputs[1]];
499 let a_shape = bwd.node(a_bwd).shape.clone();
500 let b_shape = bwd.node(b_bwd).shape.clone();
501 let is_c64 = upstream_shape.dtype() == DType::C64;
507 let b_for_a = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
508 let a_for_b = if is_c64 { bwd.conjugate(a_bwd) } else { a_bwd };
509 let g_a_full = bwd.binary(BinaryOp::Mul, upstream, b_for_a, upstream_shape.clone());
510 let g_b_full = bwd.binary(BinaryOp::Mul, upstream, a_for_b, upstream_shape);
511 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
512 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
513 vec![(0, g_a), (1, g_b)]
514 }
515
516 Op::Activation(kind) => {
517 let x_bwd = fwd_map[&node.inputs[0]];
518 let dx = match kind {
523 Activation::Relu => bwd.relu_backward(x_bwd, upstream),
524 _ => bwd.activation_backward(*kind, x_bwd, upstream),
525 };
526 vec![(0, dx)]
527 }
528
529 Op::MatMul => {
530 let a_bwd = fwd_map[&node.inputs[0]];
541 let b_bwd = fwd_map[&node.inputs[1]];
542 let a_shape = bwd.node(a_bwd).shape.clone();
543 let b_shape = bwd.node(b_bwd).shape.clone();
544 assert!(
545 a_shape.rank() >= 2 && b_shape.rank() >= 2,
546 "MatMul VJP: rank must be ≥ 2, got {} and {}",
547 a_shape.rank(),
548 b_shape.rank()
549 );
550 let dtype = upstream_shape.dtype();
551
552 let trans_last_two = |bwd: &mut Graph, x: NodeId| -> NodeId {
554 let s = bwd.node(x).shape.clone();
555 let r = s.rank();
556 let mut perm: Vec<usize> = (0..r).collect();
557 perm.swap(r - 2, r - 1);
558 let mut dims: Vec<Dim> = s.dims().to_vec();
559 dims.swap(r - 2, r - 1);
560 let new_shape = Shape::from_dims(&dims, s.dtype());
561 bwd.add_node(Op::Transpose { perm }, vec![x], new_shape)
562 };
563
564 let upstream_dims: Vec<Dim> = upstream_shape.dims().to_vec();
567 let r_up = upstream_dims.len();
568
569 let is_c64 = dtype == DType::C64;
575
576 let b_t = trans_last_two(bwd, b_bwd);
578 let b_t = if is_c64 { bwd.conjugate(b_t) } else { b_t };
579 let mut ga_dims = upstream_dims.clone();
580 ga_dims[r_up - 1] = a_shape.dim(a_shape.rank() - 1); let ga_shape = Shape::from_dims(&ga_dims, dtype);
582 let g_a_full = bwd.matmul(upstream, b_t, ga_shape);
583 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
584
585 let a_t = trans_last_two(bwd, a_bwd);
587 let a_t = if is_c64 { bwd.conjugate(a_t) } else { a_t };
588 let mut gb_dims = upstream_dims.clone();
589 gb_dims[r_up - 2] = a_shape.dim(a_shape.rank() - 1); let gb_shape = Shape::from_dims(&gb_dims, dtype);
591 let g_b_full = bwd.matmul(a_t, upstream, gb_shape);
592 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
593
594 vec![(0, g_a), (1, g_b)]
595 }
596
597 Op::DenseSolve => {
598 let a_bwd = fwd_map[&node.inputs[0]];
607 let x_bwd = fwd_map[&node.id];
608 let a_shape = bwd.node(a_bwd).shape.clone();
609 let x_shape = bwd.node(x_bwd).shape.clone();
610 assert_eq!(a_shape.rank(), 2, "DenseSolve VJP: A must be 2-D");
611 let n = match a_shape.dim(0) {
612 Dim::Static(n) => n,
613 Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic N not supported"),
614 };
615 let dtype = a_shape.dtype();
616
617 let mut a_t_dims: Vec<Dim> = a_shape.dims().to_vec();
619 a_t_dims.swap(0, 1);
620 let a_t_shape = Shape::from_dims(&a_t_dims, dtype);
621 let a_t = bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![a_bwd], a_t_shape);
622
623 let d_b = bwd.dense_solve(a_t, upstream, x_shape.clone());
625
626 let neg_outer = match x_shape.rank() {
628 1 => {
629 let col_shape = Shape::from_dims(&[Dim::Static(n), Dim::Static(1)], dtype);
631 let row_shape = Shape::from_dims(&[Dim::Static(1), Dim::Static(n)], dtype);
632 let db_col = bwd.add_node(
633 Op::Reshape {
634 new_shape: vec![n as i64, 1],
635 },
636 vec![d_b],
637 col_shape,
638 );
639 let x_row = bwd.add_node(
640 Op::Reshape {
641 new_shape: vec![1, n as i64],
642 },
643 vec![x_bwd],
644 row_shape,
645 );
646 let outer = bwd.matmul(db_col, x_row, a_shape.clone());
647 bwd.activation(Activation::Neg, outer, a_shape)
648 }
649 2 => {
650 let k = match x_shape.dim(1) {
652 Dim::Static(k) => k,
653 Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic K not supported"),
654 };
655 let xt_dims = vec![Dim::Static(k), Dim::Static(n)];
656 let xt_shape = Shape::from_dims(&xt_dims, dtype);
657 let x_t =
658 bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![x_bwd], xt_shape);
659 let outer = bwd.matmul(d_b, x_t, a_shape.clone());
660 bwd.activation(Activation::Neg, outer, a_shape)
661 }
662 r => panic!("DenseSolve VJP: B must be rank 1 or 2, got rank {r}"),
663 };
664
665 vec![(0, neg_outer), (1, d_b)]
666 }
667
668 Op::BatchedDenseSolve => {
669 let a_bwd = fwd_map[&node.inputs[0]];
675 let x_bwd = fwd_map[&node.id];
676 let a_shape = bwd.node(a_bwd).shape.clone();
677 let x_shape = bwd.node(x_bwd).shape.clone();
678 assert_eq!(
679 a_shape.rank(),
680 3,
681 "BatchedDenseSolve VJP: A must be rank-3 [B, N, N]"
682 );
683 let b_dim = match a_shape.dim(0) {
684 Dim::Static(b) => b,
685 Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic B not supported"),
686 };
687 let n = match a_shape.dim(1) {
688 Dim::Static(n) => n,
689 Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic N not supported"),
690 };
691 let dtype = a_shape.dtype();
692
693 let a_t = bwd.add_node(
696 Op::Transpose {
697 perm: vec![0, 2, 1],
698 },
699 vec![a_bwd],
700 a_shape.clone(),
701 );
702
703 let d_b = bwd.batched_dense_solve(a_t, upstream, x_shape.clone());
705
706 let neg_outer = match x_shape.rank() {
708 2 => {
709 let col_shape = Shape::from_dims(
712 &[Dim::Static(b_dim), Dim::Static(n), Dim::Static(1)],
713 dtype,
714 );
715 let row_shape = Shape::from_dims(
716 &[Dim::Static(b_dim), Dim::Static(1), Dim::Static(n)],
717 dtype,
718 );
719 let db_col = bwd.add_node(
720 Op::Reshape {
721 new_shape: vec![b_dim as i64, n as i64, 1],
722 },
723 vec![d_b],
724 col_shape,
725 );
726 let x_row = bwd.add_node(
727 Op::Reshape {
728 new_shape: vec![b_dim as i64, 1, n as i64],
729 },
730 vec![x_bwd],
731 row_shape,
732 );
733 let outer = bwd.matmul(db_col, x_row, a_shape.clone());
734 bwd.activation(Activation::Neg, outer, a_shape)
735 }
736 3 => {
737 let k = match x_shape.dim(2) {
740 Dim::Static(k) => k,
741 Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic K not supported"),
742 };
743 let xt_shape = Shape::from_dims(
744 &[Dim::Static(b_dim), Dim::Static(k), Dim::Static(n)],
745 dtype,
746 );
747 let x_t = bwd.add_node(
748 Op::Transpose {
749 perm: vec![0, 2, 1],
750 },
751 vec![x_bwd],
752 xt_shape,
753 );
754 let outer = bwd.matmul(d_b, x_t, a_shape.clone());
755 bwd.activation(Activation::Neg, outer, a_shape)
756 }
757 r => panic!("BatchedDenseSolve VJP: b must be rank 2 or 3, got rank {r}"),
758 };
759
760 vec![(0, neg_outer), (1, d_b)]
761 }
762
763 Op::Scan {
764 body,
765 length,
766 save_trajectory,
767 num_bcast: _,
768 num_xs,
769 num_checkpoints,
770 } => {
771 let init_bwd = fwd_map[&node.inputs[0]];
779 let traj_bwd = fwd_map[&node.id];
780 let init_shape = bwd.node(init_bwd).shape.clone();
781
782 let mut body_input_ids: Vec<NodeId> = body
784 .nodes()
785 .iter()
786 .filter(|n| matches!(n.op, Op::Input { .. }))
787 .map(|n| n.id)
788 .collect();
789 body_input_ids.sort();
790
791 let body_vjp = grad(body, &body_input_ids);
792
793 let xs_bwd: Vec<NodeId> = (0..*num_xs as usize)
794 .map(|i| fwd_map[&node.inputs[1 + i]])
795 .collect();
796
797 let is_checkpointed = *num_checkpoints != 0 && *num_checkpoints != *length;
803 let forward_body_for_bwd = if is_checkpointed {
804 Some((**body).clone())
805 } else {
806 None
807 };
808
809 let dinit = bwd.scan_backward_with_checkpoints(
810 init_bwd,
811 traj_bwd,
812 upstream,
813 &xs_bwd,
814 body_vjp.clone(),
815 *length,
816 *save_trajectory,
817 *num_checkpoints,
818 forward_body_for_bwd.clone(),
819 init_shape,
820 );
821
822 let mut grads: Vec<(usize, NodeId)> = vec![(0, dinit)];
823 for i in 0..*num_xs as usize {
824 let outer_xs_id = node.inputs[1 + i];
825 let xs_shape = bwd.node(fwd_map[&outer_xs_id]).shape.clone();
826 let dxs_i = bwd.scan_backward_xs_with_checkpoints(
827 init_bwd,
828 traj_bwd,
829 upstream,
830 &xs_bwd,
831 body_vjp.clone(),
832 *length,
833 *save_trajectory,
834 i as u32,
835 *num_checkpoints,
836 forward_body_for_bwd.clone(),
837 xs_shape,
838 );
839 grads.push((1 + i, dxs_i));
840 }
841 grads
842 }
843
844 Op::Conv {
845 kernel_size,
846 stride,
847 padding,
848 dilation,
849 groups,
850 } => {
851 let x_bwd = fwd_map[&node.inputs[0]];
852 let w_bwd = fwd_map[&node.inputs[1]];
853 let x_shape = bwd.node(x_bwd).shape.clone();
854 let w_shape = bwd.node(w_bwd).shape.clone();
855 let dx = bwd.conv2d_backward_input(
856 upstream,
857 w_bwd,
858 x_shape,
859 kernel_size.clone(),
860 stride.clone(),
861 padding.clone(),
862 dilation.clone(),
863 *groups,
864 );
865 let _qat_bits: Option<u8> = None;
875 let dw = bwd.conv2d_backward_weight(
876 x_bwd,
877 upstream,
878 w_shape,
879 kernel_size.clone(),
880 stride.clone(),
881 padding.clone(),
882 dilation.clone(),
883 *groups,
884 );
885 vec![(0, dx), (1, dw)]
886 }
887
888 Op::Pool {
889 kind: ReduceOp::Max,
890 kernel_size,
891 stride,
892 padding,
893 } => {
894 let x_bwd = fwd_map[&node.inputs[0]];
895 let dx = bwd.maxpool2d_backward(
896 x_bwd,
897 upstream,
898 kernel_size.clone(),
899 stride.clone(),
900 padding.clone(),
901 );
902 vec![(0, dx)]
903 }
904
905 Op::SoftmaxCrossEntropyWithLogits => {
906 let logits_bwd = fwd_map[&node.inputs[0]];
907 let labels_bwd = fwd_map[&node.inputs[1]];
908 let dlogits = bwd.softmax_cross_entropy_backward(logits_bwd, labels_bwd, upstream);
909 vec![(0, dlogits)]
911 }
912
913 Op::Reduce {
914 op: ReduceOp::Sum,
915 axes,
916 keep_dim,
917 } => {
918 let x_bwd = fwd_map[&node.inputs[0]];
919 let x_shape = bwd.node(x_bwd).shape.clone();
920 let g = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
921 vec![(0, g)]
922 }
923
924 Op::Reduce {
925 op: ReduceOp::Mean,
926 axes,
927 keep_dim,
928 } => {
929 let x_bwd = fwd_map[&node.inputs[0]];
935 let x_shape = bwd.node(x_bwd).shape.clone();
936 let count: usize = axes
937 .iter()
938 .map(|&a| match x_shape.dim(a) {
939 Dim::Static(n) => n,
940 _ => panic!("Reduce::Mean VJP requires static reduced dims"),
941 })
942 .product();
943 let expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
944 let inv_count = scalar_const(1.0 / count as f32, bwd);
945 let g = bwd.binary(BinaryOp::Mul, expanded, inv_count, x_shape);
946 vec![(0, g)]
947 }
948
949 Op::Reshape { .. } => {
950 let x_bwd = fwd_map[&node.inputs[0]];
951 let x_shape = bwd.node(x_bwd).shape.clone();
952 let dx = reshape_to(upstream, &x_shape, bwd);
953 vec![(0, dx)]
954 }
955
956 Op::ComplexNormSq => {
957 let z_bwd = fwd_map[&node.inputs[0]];
960 let dz = bwd.complex_norm_sq_backward(z_bwd, upstream);
961 vec![(0, dz)]
962 }
963
964 Op::Conjugate => {
965 let dz = bwd.conjugate(upstream);
971 vec![(0, dz)]
972 }
973
974 Op::Cast { .. } => {
975 let x_bwd = fwd_map[&node.inputs[0]];
976 let x_shape = bwd.node(x_bwd).shape.clone();
977 let dx = bwd.add_node(
978 Op::Cast {
979 to: x_shape.dtype(),
980 },
981 vec![upstream],
982 x_shape,
983 );
984 vec![(0, dx)]
985 }
986
987 Op::StopGradient => vec![],
992
993 Op::Quantize { .. } | Op::Dequantize { .. } => {
1005 vec![(0, upstream)]
1006 }
1007
1008 Op::FakeQuantizeLSQ { bits, axis } => {
1009 let x_bwd = fwd_map[&node.inputs[0]];
1013 let scale_bwd = fwd_map[&node.inputs[1]];
1014 let x_shape = bwd.node(x_bwd).shape.clone();
1015 let scale_shape = bwd.node(scale_bwd).shape.clone();
1016 let dx = bwd.add_node(
1017 Op::FakeQuantizeLSQBackwardX {
1018 bits: *bits,
1019 axis: *axis,
1020 },
1021 vec![x_bwd, scale_bwd, upstream],
1022 x_shape,
1023 );
1024 let dscale = bwd.add_node(
1025 Op::FakeQuantizeLSQBackwardScale {
1026 bits: *bits,
1027 axis: *axis,
1028 },
1029 vec![x_bwd, scale_bwd, upstream],
1030 scale_shape,
1031 );
1032 vec![(0, dx), (1, dscale)]
1033 }
1034
1035 Op::FakeQuantize {
1040 bits, axis, ste, ..
1041 } => {
1042 use rlx_ir::op::SteKind;
1043 match ste {
1044 SteKind::Identity => vec![(0, upstream)],
1045 _ => {
1046 let x_bwd = fwd_map[&node.inputs[0]];
1047 let x_shape = bwd.node(x_bwd).shape.clone();
1048 let dx = bwd.add_node(
1049 Op::FakeQuantizeBackward {
1050 bits: *bits,
1051 axis: *axis,
1052 ste: *ste,
1053 },
1054 vec![x_bwd, upstream],
1055 x_shape,
1056 );
1057 vec![(0, dx)]
1058 }
1059 }
1060 }
1061
1062 Op::Expand { .. } => {
1063 let x_bwd = fwd_map[&node.inputs[0]];
1064 let x_shape = bwd.node(x_bwd).shape.clone();
1065 let dx = unbroadcast(upstream, &x_shape, bwd);
1066 vec![(0, dx)]
1067 }
1068
1069 Op::BatchNormInference { eps } => {
1070 let x_bwd = fwd_map[&node.inputs[0]];
1071 let gamma_bwd = fwd_map[&node.inputs[1]];
1072 let _beta_bwd = fwd_map[&node.inputs[2]];
1073 let mean_bwd = fwd_map[&node.inputs[3]];
1074 let var_bwd = fwd_map[&node.inputs[4]];
1075 let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1076 let dx = bwd.batch_norm_inference_backward_input(
1077 x_bwd, gamma_bwd, mean_bwd, var_bwd, upstream, *eps,
1078 );
1079 let dgamma = bwd.batch_norm_inference_backward_gamma(
1080 x_bwd,
1081 mean_bwd,
1082 var_bwd,
1083 upstream,
1084 gamma_shape.clone(),
1085 *eps,
1086 );
1087 let dbeta = bwd.batch_norm_inference_backward_beta(upstream, gamma_shape);
1088 vec![(0, dx), (1, dgamma), (2, dbeta)]
1090 }
1091
1092 Op::LayerNorm { axis, eps } => {
1093 let x_bwd = fwd_map[&node.inputs[0]];
1100 let gamma_bwd = fwd_map[&node.inputs[1]];
1101 let _beta_bwd = fwd_map[&node.inputs[2]];
1102 let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1103
1104 let dx = bwd.layer_norm_backward_input(x_bwd, gamma_bwd, upstream, *axis, *eps);
1105 let dgamma =
1106 bwd.layer_norm_backward_gamma(x_bwd, upstream, gamma_shape.clone(), *axis, *eps);
1107 let dbeta = unbroadcast(upstream, &gamma_shape, bwd);
1108 vec![(0, dx), (1, dgamma), (2, dbeta)]
1109 }
1110
1111 Op::Softmax { axis } => {
1112 let y_bwd = fwd_map[&node.id];
1128 let y_shape = bwd.node(y_bwd).shape.clone();
1129 let dtype = y_shape.dtype();
1130 let rank = y_shape.rank();
1131 let axis_pos = if *axis < 0 {
1132 (rank as i32 + *axis) as usize
1133 } else {
1134 *axis as usize
1135 };
1136
1137 let yg = bwd.binary(BinaryOp::Mul, y_bwd, upstream, y_shape.clone());
1138
1139 let mut kept_dims: Vec<Dim> = y_shape.dims().to_vec();
1140 kept_dims[axis_pos] = Dim::Static(1);
1141 let kept_shape = Shape::from_dims(&kept_dims, dtype);
1142 let s = bwd.add_node(
1143 Op::Reduce {
1144 op: ReduceOp::Sum,
1145 axes: vec![axis_pos],
1146 keep_dim: true,
1147 },
1148 vec![yg],
1149 kept_shape,
1150 );
1151
1152 let target_dims: Vec<i64> = y_shape
1153 .dims()
1154 .iter()
1155 .map(|d| match d {
1156 Dim::Static(n) => *n as i64,
1157 Dim::Dynamic(_) => -1,
1158 })
1159 .collect();
1160 let s_expanded = bwd.add_node(
1161 Op::Expand {
1162 target_shape: target_dims,
1163 },
1164 vec![s],
1165 y_shape.clone(),
1166 );
1167
1168 let diff = bwd.binary(BinaryOp::Sub, upstream, s_expanded, y_shape.clone());
1169 let dx = bwd.binary(BinaryOp::Mul, y_bwd, diff, y_shape);
1170 vec![(0, dx)]
1171 }
1172
1173 Op::Transpose { perm } => {
1175 let inv: Vec<usize> = {
1178 let mut v = vec![0usize; perm.len()];
1179 for (i, &p) in perm.iter().enumerate() {
1180 v[p] = i;
1181 }
1182 v
1183 };
1184 let x_bwd = fwd_map[&node.inputs[0]];
1185 let x_shape = bwd.node(x_bwd).shape.clone();
1186 let dx = bwd.add_node(Op::Transpose { perm: inv }, vec![upstream], x_shape);
1187 vec![(0, dx)]
1188 }
1189
1190 Op::Concat { axis } => {
1191 let mut grads = Vec::with_capacity(node.inputs.len());
1194 let mut offset: usize = 0;
1195 for (i, &input_id) in node.inputs.iter().enumerate() {
1196 let x_bwd = fwd_map[&input_id];
1197 let x_shape = bwd.node(x_bwd).shape.clone();
1198 let len = match x_shape.dim(*axis) {
1199 Dim::Static(n) => n,
1200 _ => panic!("Concat VJP: dynamic concat dim"),
1201 };
1202 let dx = bwd.add_node(
1203 Op::Narrow {
1204 axis: *axis,
1205 start: offset,
1206 len,
1207 },
1208 vec![upstream],
1209 x_shape,
1210 );
1211 grads.push((i, dx));
1212 offset += len;
1213 }
1214 grads
1215 }
1216
1217 Op::Narrow { axis, start, len } => {
1218 let x_bwd = fwd_map[&node.inputs[0]];
1222 let x_shape = bwd.node(x_bwd).shape.clone();
1223 let full_n = match x_shape.dim(*axis) {
1224 Dim::Static(n) => n,
1225 _ => panic!("Narrow VJP: dynamic axis"),
1226 };
1227 let pre = *start;
1228 let post = full_n - *start - *len;
1229
1230 let zero_buf = |bwd: &mut Graph, len_axis: usize| -> NodeId {
1231 if len_axis == 0 {
1232 return upstream; }
1234 let dtype = x_shape.dtype();
1235 let mut dims: Vec<Dim> = x_shape.dims().to_vec();
1236 dims[*axis] = Dim::Static(len_axis);
1237 let s = Shape::from_dims(&dims, dtype);
1238 let n_elems = dims.iter().fold(1usize, |a, d| match d {
1239 Dim::Static(k) => a * k,
1240 _ => a,
1241 });
1242 let bytes = vec![0u8; n_elems * dtype.size_bytes()];
1246 bwd.add_node(Op::Constant { data: bytes }, vec![], s)
1247 };
1248
1249 let mut parts: Vec<NodeId> = Vec::new();
1250 if pre > 0 {
1251 parts.push(zero_buf(bwd, pre));
1252 }
1253 parts.push(upstream);
1254 if post > 0 {
1255 parts.push(zero_buf(bwd, post));
1256 }
1257
1258 let dx = if parts.len() == 1 {
1259 parts[0]
1260 } else {
1261 bwd.add_node(Op::Concat { axis: *axis }, parts, x_shape)
1262 };
1263 vec![(0, dx)]
1264 }
1265
1266 Op::Gather { axis } => {
1267 let table_bwd = fwd_map[&node.inputs[0]];
1268 let indices_bwd = fwd_map[&node.inputs[1]];
1269 let table_shape = bwd.node(table_bwd).shape.clone();
1270 if *axis == 0 {
1271 let dtable = bwd.add_node(Op::ScatterAdd, vec![upstream, indices_bwd], table_shape);
1272 vec![(0, dtable)]
1273 } else {
1274 let dtable = bwd.gather_backward(
1275 upstream,
1276 indices_bwd,
1277 table_shape,
1278 (*axis).try_into().unwrap(),
1279 );
1280 vec![(0, dtable)]
1281 }
1282 }
1283
1284 Op::Compare(_) => {
1286 vec![]
1291 }
1292
1293 Op::Where => {
1294 let cond = fwd_map[&node.inputs[0]];
1298 let a_bwd = fwd_map[&node.inputs[1]];
1299 let b_bwd = fwd_map[&node.inputs[2]];
1300 let a_shape = bwd.node(a_bwd).shape.clone();
1301 let b_shape = bwd.node(b_bwd).shape.clone();
1302 let out_shape = upstream_shape.clone();
1303
1304 let zero_a_bytes = vec![0u8; a_shape.num_elements().expect("Where VJP: dynamic a") * 4];
1305 let zero_b_bytes = vec![0u8; b_shape.num_elements().expect("Where VJP: dynamic b") * 4];
1306 let zero_a = bwd.add_node(Op::Constant { data: zero_a_bytes }, vec![], a_shape.clone());
1307 let zero_b = bwd.add_node(Op::Constant { data: zero_b_bytes }, vec![], b_shape.clone());
1308 let zero_a_bcast = unbroadcast_inverse(zero_a, &out_shape, bwd);
1311 let zero_b_bcast = unbroadcast_inverse(zero_b, &out_shape, bwd);
1312 let g_a_full = bwd.add_node(
1313 Op::Where,
1314 vec![cond, upstream, zero_a_bcast],
1315 out_shape.clone(),
1316 );
1317 let g_b_full = bwd.add_node(Op::Where, vec![cond, zero_b_bcast, upstream], out_shape);
1318 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1319 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1320 vec![(1, g_a), (2, g_b)]
1321 }
1322
1323 Op::Binary(BinaryOp::Div) => {
1325 let a_bwd = fwd_map[&node.inputs[0]];
1333 let b_bwd = fwd_map[&node.inputs[1]];
1334 let y_bwd = fwd_map[&node.id];
1335 let a_shape = bwd.node(a_bwd).shape.clone();
1336 let b_shape = bwd.node(b_bwd).shape.clone();
1337 let is_c64 = upstream_shape.dtype() == DType::C64;
1338
1339 let b_term = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
1340 let y_term = if is_c64 { bwd.conjugate(y_bwd) } else { y_bwd };
1341
1342 let g_a_full = bwd.binary(BinaryOp::Div, upstream, b_term, upstream_shape.clone());
1344 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1345
1346 let neg_up = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
1348 let neg_up_y = bwd.binary(BinaryOp::Mul, neg_up, y_term, upstream_shape.clone());
1349 let g_b_full = bwd.binary(BinaryOp::Div, neg_up_y, b_term, upstream_shape);
1350 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1351
1352 vec![(0, g_a), (1, g_b)]
1353 }
1354
1355 Op::Reduce {
1357 op: ReduceOp::Max,
1358 axes,
1359 keep_dim,
1360 }
1361 | Op::Reduce {
1362 op: ReduceOp::Min,
1363 axes,
1364 keep_dim,
1365 } => {
1366 let is_max = matches!(
1370 node.op,
1371 Op::Reduce {
1372 op: ReduceOp::Max,
1373 ..
1374 }
1375 );
1376 let _ = is_max;
1377 let x_bwd = fwd_map[&node.inputs[0]];
1378 let y_bwd = fwd_map[&node.id];
1379 let x_shape = bwd.node(x_bwd).shape.clone();
1380 let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1381 let mask_bool = bwd.add_node(
1382 Op::Compare(CmpOp::Eq),
1383 vec![x_bwd, y_expanded],
1384 Shape::from_dims(x_shape.dims(), DType::F32),
1385 );
1386 let mask_f32 = bwd.add_node(
1390 Op::Cast {
1391 to: x_shape.dtype(),
1392 },
1393 vec![mask_bool],
1394 x_shape.clone(),
1395 );
1396 let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1397 let dx = bwd.binary(BinaryOp::Mul, upstream_expanded, mask_f32, x_shape);
1398 vec![(0, dx)]
1399 }
1400
1401 Op::Rope { head_dim, n_rot } => {
1407 let cos = fwd_map[&node.inputs[1]];
1408 let sin = fwd_map[&node.inputs[2]];
1409 let dx = bwd.rope_backward(upstream, cos, sin, *head_dim, *n_rot);
1410 vec![(0, dx)]
1411 }
1412
1413 Op::RmsNorm { axis, eps } => {
1414 let x = fwd_map[&node.inputs[0]];
1415 let gamma = fwd_map[&node.inputs[1]];
1416 let beta = fwd_map[&node.inputs[2]];
1417 let dx = bwd.rms_norm_backward_input(x, gamma, beta, upstream, *axis, *eps);
1418 let dgamma = bwd.rms_norm_backward_gamma(x, gamma, beta, upstream, *axis, *eps);
1419 let dbeta = bwd.rms_norm_backward_beta(x, gamma, beta, upstream, *axis, *eps);
1420 vec![(0, dx), (1, dgamma), (2, dbeta)]
1421 }
1422
1423 Op::GroupNorm { num_groups, eps } => {
1424 let x = fwd_map[&node.inputs[0]];
1425 let gamma = fwd_map[&node.inputs[1]];
1426 let beta = fwd_map[&node.inputs[2]];
1427 let gamma_shape = bwd.node(gamma).shape.clone();
1428 let beta_shape = bwd.node(beta).shape.clone();
1429 let dx = bwd.group_norm_backward_input(x, gamma, beta, upstream, *num_groups, *eps);
1430 let dgamma = bwd.group_norm_backward_gamma(x, upstream, gamma_shape, *num_groups, *eps);
1431 let dbeta = bwd.group_norm_backward_beta(x, upstream, beta_shape, *num_groups, *eps);
1432 vec![(0, dx), (1, dgamma), (2, dbeta)]
1433 }
1434
1435 Op::Attention {
1437 num_heads,
1438 head_dim,
1439 mask_kind,
1440 score_scale: _,
1441 attn_logit_softcap: _,
1442 } => {
1443 let q = fwd_map[&node.inputs[0]];
1444 let k = fwd_map[&node.inputs[1]];
1445 let v = fwd_map[&node.inputs[2]];
1446 let mask = match mask_kind {
1447 MaskKind::Custom | MaskKind::Bias => Some(fwd_map[&node.inputs[3]]),
1448 _ => None,
1449 };
1450 let (dq, dk, dv) = bwd
1451 .attention_backward_all(q, k, v, upstream, *num_heads, *head_dim, *mask_kind, mask);
1452 vec![(0, dq), (1, dk), (2, dv)]
1453 }
1454
1455 Op::Reduce {
1462 op: ReduceOp::Prod,
1463 axes,
1464 keep_dim,
1465 } => {
1466 let x_bwd = fwd_map[&node.inputs[0]];
1467 let y_bwd = fwd_map[&node.id];
1468 let x_shape = bwd.node(x_bwd).shape.clone();
1469 let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1470 let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1471 let num = bwd.binary(
1473 BinaryOp::Mul,
1474 upstream_expanded,
1475 y_expanded,
1476 x_shape.clone(),
1477 );
1478 let dx = bwd.binary(BinaryOp::Div, num, x_bwd, x_shape);
1479 vec![(0, dx)]
1480 }
1481
1482 Op::Pool {
1494 kind: ReduceOp::Mean,
1495 kernel_size,
1496 stride,
1497 padding,
1498 } => {
1499 assert_eq!(kernel_size.len(), 2, "Pool(Mean) VJP: 2-D pool only");
1500 let x_bwd = fwd_map[&node.inputs[0]];
1501 let x_shape = bwd.node(x_bwd).shape.clone();
1502 let dtype = x_shape.dtype();
1503 let c = match x_shape.dim(1) {
1505 Dim::Static(n) => n,
1506 _ => panic!("Pool(Mean) VJP: dynamic channel dim"),
1507 };
1508 let kh = kernel_size[0];
1509 let kw = kernel_size[1];
1510 let inv_n = 1.0_f32 / (kh as f32 * kw as f32);
1511 let kernel_n = c * kh * kw;
1512 let mut bytes: Vec<u8> = Vec::with_capacity(kernel_n * 4);
1513 for _ in 0..kernel_n {
1514 bytes.extend_from_slice(&inv_n.to_le_bytes());
1515 }
1516 let kernel_shape = Shape::from_dims(
1517 &[
1518 Dim::Static(c),
1519 Dim::Static(1),
1520 Dim::Static(kh),
1521 Dim::Static(kw),
1522 ],
1523 dtype,
1524 );
1525 let kernel = bwd.add_node(Op::Constant { data: bytes }, vec![], kernel_shape);
1526 let dx = bwd.conv2d_backward_input(
1527 upstream,
1528 kernel,
1529 x_shape,
1530 kernel_size.clone(),
1531 stride.clone(),
1532 padding.clone(),
1533 vec![1, 1],
1534 c, );
1536 vec![(0, dx)]
1537 }
1538
1539 Op::Binary(BinaryOp::Min) | Op::Binary(BinaryOp::Max) => {
1546 let a_bwd = fwd_map[&node.inputs[0]];
1547 let b_bwd = fwd_map[&node.inputs[1]];
1548 let y_bwd = fwd_map[&node.id];
1549 let a_shape = bwd.node(a_bwd).shape.clone();
1550 let b_shape = bwd.node(b_bwd).shape.clone();
1551 let dtype = upstream_shape.dtype();
1552
1553 let bool_shape = Shape::from_dims(upstream_shape.dims(), DType::Bool);
1554 let mask_pred = bwd.add_node(Op::Compare(CmpOp::Eq), vec![a_bwd, y_bwd], bool_shape);
1555 let mask_f32 = bwd.add_node(
1556 Op::Cast { to: dtype },
1557 vec![mask_pred],
1558 upstream_shape.clone(),
1559 );
1560 let zero_bytes = vec![
1561 0u8;
1562 upstream_shape
1563 .num_elements()
1564 .expect("Min/Max VJP: dyn shape")
1565 * 4
1566 ];
1567 let zero = bwd.add_node(
1568 Op::Constant { data: zero_bytes },
1569 vec![],
1570 upstream_shape.clone(),
1571 );
1572 let g_a_full = bwd.add_node(
1573 Op::Where,
1574 vec![mask_f32, upstream, zero],
1575 upstream_shape.clone(),
1576 );
1577 let g_b_full = bwd.add_node(Op::Where, vec![mask_f32, zero, upstream], upstream_shape);
1578 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1579 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1580 vec![(0, g_a), (1, g_b)]
1581 }
1582
1583 Op::Binary(BinaryOp::Pow) => {
1592 let a_bwd = fwd_map[&node.inputs[0]];
1593 let b_bwd = fwd_map[&node.inputs[1]];
1594 let y_bwd = fwd_map[&node.id]; let a_shape = bwd.node(a_bwd).shape.clone();
1596 let b_shape = bwd.node(b_bwd).shape.clone();
1597
1598 let yb = bwd.binary(BinaryOp::Mul, y_bwd, b_bwd, upstream_shape.clone());
1601 let yb_over_a = bwd.binary(BinaryOp::Div, yb, a_bwd, upstream_shape.clone());
1602 let g_a_full = bwd.binary(BinaryOp::Mul, upstream, yb_over_a, upstream_shape.clone());
1603 let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1604
1605 let ln_a = bwd.activation(Activation::Log, a_bwd, a_shape);
1607 let ln_a_b = unbroadcast_inverse(ln_a, &upstream_shape, bwd);
1608 let yln = bwd.binary(BinaryOp::Mul, y_bwd, ln_a_b, upstream_shape.clone());
1609 let g_b_full = bwd.binary(BinaryOp::Mul, upstream, yln, upstream_shape);
1610 let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1611
1612 vec![(0, g_a), (1, g_b)]
1613 }
1614
1615 Op::DequantMatMul { scheme: _ } => {
1634 let x_bwd = fwd_map[&node.inputs[0]];
1635 let w_q_bwd = fwd_map[&node.inputs[1]];
1636 let scale_bwd = fwd_map[&node.inputs[2]];
1637 let zp_bwd = fwd_map[&node.inputs[3]];
1638 let x_shape = bwd.node(x_bwd).shape.clone();
1639 let w_shape = bwd.node(w_q_bwd).shape.clone();
1640 let scale_shape = bwd.node(scale_bwd).shape.clone();
1641 let zp_shape = bwd.node(zp_bwd).shape.clone();
1642
1643 let dtype = x_shape.dtype();
1647 let w_q_f32 = bwd.add_node(
1648 Op::Cast { to: dtype },
1649 vec![w_q_bwd],
1650 Shape::from_dims(w_shape.dims(), dtype),
1651 );
1652 let scale_b =
1654 unbroadcast_inverse(scale_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1655 let zp_b = unbroadcast_inverse(zp_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1656 let w_centered = bwd.binary(
1657 BinaryOp::Sub,
1658 w_q_f32,
1659 zp_b,
1660 Shape::from_dims(w_shape.dims(), dtype),
1661 );
1662 let w_dq = bwd.binary(
1663 BinaryOp::Mul,
1664 w_centered,
1665 scale_b,
1666 Shape::from_dims(w_shape.dims(), dtype),
1667 );
1668
1669 let w_rank = w_shape.rank();
1671 let mut perm: Vec<usize> = (0..w_rank).collect();
1672 perm.swap(w_rank - 2, w_rank - 1);
1673 let mut wdt_dims: Vec<Dim> = w_shape.dims().to_vec();
1674 wdt_dims.swap(w_rank - 2, w_rank - 1);
1675 let w_dq_t_shape = Shape::from_dims(&wdt_dims, dtype);
1676 let w_dq_t = bwd.add_node(Op::Transpose { perm }, vec![w_dq], w_dq_t_shape);
1677 let dx = bwd.matmul(upstream, w_dq_t, x_shape.clone());
1678
1679 let x_rank = x_shape.rank();
1684 let mut x_perm: Vec<usize> = (0..x_rank).collect();
1685 x_perm.swap(x_rank - 2, x_rank - 1);
1686 let mut x_t_dims: Vec<Dim> = x_shape.dims().to_vec();
1687 x_t_dims.swap(x_rank - 2, x_rank - 1);
1688 let x_t = bwd.add_node(
1689 Op::Transpose { perm: x_perm },
1690 vec![x_bwd],
1691 Shape::from_dims(&x_t_dims, dtype),
1692 );
1693 let dw_unscaled = bwd.matmul(x_t, upstream, Shape::from_dims(w_shape.dims(), dtype));
1694 let dw_q_f32 = bwd.binary(
1695 BinaryOp::Mul,
1696 dw_unscaled,
1697 scale_b,
1698 Shape::from_dims(w_shape.dims(), dtype),
1699 );
1700 let dw_q = bwd.add_node(
1702 Op::Cast {
1703 to: w_shape.dtype(),
1704 },
1705 vec![dw_q_f32],
1706 w_shape,
1707 );
1708
1709 let zero_scale_bytes =
1711 vec![0u8; scale_shape.num_elements().expect("DQMM VJP: dyn scale") * 4];
1712 let zero_zp_bytes = vec![0u8; zp_shape.num_elements().expect("DQMM VJP: dyn zp") * 4];
1713 let dscale = bwd.add_node(
1714 Op::Constant {
1715 data: zero_scale_bytes,
1716 },
1717 vec![],
1718 scale_shape,
1719 );
1720 let dzp = bwd.add_node(
1721 Op::Constant {
1722 data: zero_zp_bytes,
1723 },
1724 vec![],
1725 zp_shape,
1726 );
1727
1728 vec![(0, dx), (1, dw_q), (2, dscale), (3, dzp)]
1729 }
1730
1731 Op::ScatterAdd => {
1737 let updates_bwd = fwd_map[&node.inputs[0]];
1738 let indices_bwd = fwd_map[&node.inputs[1]];
1739 let updates_shape = bwd.node(updates_bwd).shape.clone();
1740 let dupdates = bwd.add_node(
1741 Op::Gather { axis: 0 },
1742 vec![upstream, indices_bwd],
1743 updates_shape,
1744 );
1745 vec![(0, dupdates)]
1746 }
1747
1748 Op::Cumsum { axis, exclusive } => {
1751 let x_bwd = fwd_map[&node.inputs[0]];
1752 let x_shape = bwd.node(x_bwd).shape.clone();
1753 let dx = bwd.cumsum_backward(upstream, x_shape, *axis, *exclusive);
1754 vec![(0, dx)]
1755 }
1756
1757 Op::GroupedMatMul => {
1770 let x_bwd = fwd_map[&node.inputs[0]];
1771 let w_bwd = fwd_map[&node.inputs[1]];
1772 let expert_bwd = fwd_map[&node.inputs[2]];
1773 let x_shape = bwd.node(x_bwd).shape.clone();
1774 let w_shape = bwd.node(w_bwd).shape.clone();
1775 let (dx, dw) =
1776 grouped_matmul_vjp(bwd, upstream, x_bwd, w_bwd, expert_bwd, &x_shape, &w_shape);
1777 vec![(0, dx), (1, dw)]
1778 }
1779
1780 Op::DequantGroupedMatMul { scheme } => {
1786 let x_bwd = fwd_map[&node.inputs[0]];
1787 let w_packed = fwd_map[&node.inputs[1]];
1788 let expert_bwd = fwd_map[&node.inputs[2]];
1789 let x_shape = bwd.node(x_bwd).shape.clone();
1790 let w_packed_shape = bwd.node(w_packed).shape.clone();
1791 let dtype = x_shape.dtype();
1792 let k = x_shape.dim(1);
1793 let n_out = node.shape.dim(node.shape.rank() - 1);
1794 let k_static = match k {
1795 Dim::Static(v) => v,
1796 _ => panic!("DequantGroupedMatMul VJP: K must be static"),
1797 };
1798 let n_static = match n_out {
1799 Dim::Static(v) => v,
1800 _ => panic!("DequantGroupedMatMul VJP: N must be static"),
1801 };
1802 let block_elems = scheme.gguf_block_size() as usize;
1803 let block_bytes = scheme.gguf_block_bytes() as usize;
1804 let slab_bytes = (k_static * n_static) / block_elems * block_bytes;
1805 let total_bytes = w_packed_shape
1806 .num_elements()
1807 .expect("DequantGroupedMatMul VJP: dyn packed");
1808 let e_static = total_bytes / slab_bytes.max(1);
1809 let w_shape = Shape::from_dims(
1810 &[
1811 Dim::Static(e_static),
1812 Dim::Static(k_static),
1813 Dim::Static(n_static),
1814 ],
1815 dtype,
1816 );
1817 let w_dq = bwd.add_node(
1818 Op::DequantMoEWeights { scheme: *scheme },
1819 vec![w_packed],
1820 w_shape.clone(),
1821 );
1822 let (dx, _dw) =
1823 grouped_matmul_vjp(bwd, upstream, x_bwd, w_dq, expert_bwd, &x_shape, &w_shape);
1824 vec![(0, dx)]
1825 }
1826
1827 Op::QMatMul {
1840 x_zp,
1841 w_zp,
1842 out_zp: _,
1843 mult,
1844 } => {
1845 let x_bwd = fwd_map[&node.inputs[0]];
1846 let w_bwd = fwd_map[&node.inputs[1]];
1847 let bias_bwd = fwd_map[&node.inputs[2]];
1848 let x_shape = bwd.node(x_bwd).shape.clone();
1849 let w_shape = bwd.node(w_bwd).shape.clone();
1850 let bias_shape = bwd.node(bias_bwd).shape.clone();
1851 let dtype = upstream_shape.dtype();
1852
1853 let x_f32 = bwd.add_node(
1855 Op::Cast { to: dtype },
1856 vec![x_bwd],
1857 Shape::from_dims(x_shape.dims(), dtype),
1858 );
1859 let w_f32 = bwd.add_node(
1860 Op::Cast { to: dtype },
1861 vec![w_bwd],
1862 Shape::from_dims(w_shape.dims(), dtype),
1863 );
1864 let xzp_c = scalar_const(*x_zp as f32, bwd);
1865 let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
1866 let _ = bwd.binary(
1867 BinaryOp::Sub,
1868 x_f32,
1869 xzp_b,
1870 Shape::from_dims(x_shape.dims(), dtype),
1871 );
1872 let wzp_c = scalar_const(*w_zp as f32, bwd);
1873 let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1874 let w_centered = bwd.binary(
1875 BinaryOp::Sub,
1876 w_f32,
1877 wzp_b,
1878 Shape::from_dims(w_shape.dims(), dtype),
1879 );
1880
1881 let mult_c = scalar_const(*mult, bwd);
1883 let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
1884 let upstream_scaled =
1885 bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
1886
1887 let w_rank = w_shape.rank();
1890 let mut perm: Vec<usize> = (0..w_rank).collect();
1891 perm.swap(w_rank - 2, w_rank - 1);
1892 let mut wt_dims: Vec<Dim> = w_shape.dims().to_vec();
1893 wt_dims.swap(w_rank - 2, w_rank - 1);
1894 let w_t = bwd.add_node(
1895 Op::Transpose { perm },
1896 vec![w_centered],
1897 Shape::from_dims(&wt_dims, dtype),
1898 );
1899 let dx_f32 = bwd.matmul(
1900 upstream_scaled,
1901 w_t,
1902 Shape::from_dims(x_shape.dims(), dtype),
1903 );
1904 let dx = bwd.add_node(
1905 Op::Cast {
1906 to: x_shape.dtype(),
1907 },
1908 vec![dx_f32],
1909 x_shape.clone(),
1910 );
1911
1912 let x_rank = x_shape.rank();
1914 let mut x_perm: Vec<usize> = (0..x_rank).collect();
1915 x_perm.swap(x_rank - 2, x_rank - 1);
1916 let mut xt_dims: Vec<Dim> = x_shape.dims().to_vec();
1917 xt_dims.swap(x_rank - 2, x_rank - 1);
1918 let x_f32_2 = bwd.add_node(
1920 Op::Cast { to: dtype },
1921 vec![x_bwd],
1922 Shape::from_dims(x_shape.dims(), dtype),
1923 );
1924 let x_centered = bwd.binary(
1925 BinaryOp::Sub,
1926 x_f32_2,
1927 xzp_b,
1928 Shape::from_dims(x_shape.dims(), dtype),
1929 );
1930 let x_t = bwd.add_node(
1931 Op::Transpose { perm: x_perm },
1932 vec![x_centered],
1933 Shape::from_dims(&xt_dims, dtype),
1934 );
1935 let dw_f32 = bwd.matmul(
1936 x_t,
1937 upstream_scaled,
1938 Shape::from_dims(w_shape.dims(), dtype),
1939 );
1940 let dw = bwd.add_node(
1941 Op::Cast {
1942 to: w_shape.dtype(),
1943 },
1944 vec![dw_f32],
1945 w_shape,
1946 );
1947
1948 let bias_rank = bias_shape.rank();
1951 let reduce_axes: Vec<usize> = (0..upstream_shape.rank())
1952 .filter(|&i| i + bias_rank < upstream_shape.rank() || i == 0)
1953 .collect();
1954 let dbias_f32 = bwd.add_node(
1955 Op::Reduce {
1956 op: ReduceOp::Sum,
1957 axes: reduce_axes,
1958 keep_dim: false,
1959 },
1960 vec![upstream_scaled],
1961 Shape::from_dims(bias_shape.dims(), dtype),
1962 );
1963 let dbias = bwd.add_node(
1964 Op::Cast {
1965 to: bias_shape.dtype(),
1966 },
1967 vec![dbias_f32],
1968 bias_shape,
1969 );
1970
1971 vec![(0, dx), (1, dw), (2, dbias)]
1972 }
1973
1974 Op::QConv2d {
1975 kernel_size,
1976 stride,
1977 padding,
1978 dilation,
1979 groups,
1980 x_zp,
1981 w_zp,
1982 out_zp: _,
1983 mult,
1984 } => {
1985 let x_bwd = fwd_map[&node.inputs[0]];
1989 let w_bwd = fwd_map[&node.inputs[1]];
1990 let bias_bwd = fwd_map[&node.inputs[2]];
1991 let x_shape = bwd.node(x_bwd).shape.clone();
1992 let w_shape = bwd.node(w_bwd).shape.clone();
1993 let bias_shape = bwd.node(bias_bwd).shape.clone();
1994 let dtype = upstream_shape.dtype();
1995
1996 let x_f32 = bwd.add_node(
1998 Op::Cast { to: dtype },
1999 vec![x_bwd],
2000 Shape::from_dims(x_shape.dims(), dtype),
2001 );
2002 let w_f32 = bwd.add_node(
2003 Op::Cast { to: dtype },
2004 vec![w_bwd],
2005 Shape::from_dims(w_shape.dims(), dtype),
2006 );
2007 let xzp_c = scalar_const(*x_zp as f32, bwd);
2008 let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
2009 let x_centered = bwd.binary(
2010 BinaryOp::Sub,
2011 x_f32,
2012 xzp_b,
2013 Shape::from_dims(x_shape.dims(), dtype),
2014 );
2015 let wzp_c = scalar_const(*w_zp as f32, bwd);
2016 let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
2017 let w_centered = bwd.binary(
2018 BinaryOp::Sub,
2019 w_f32,
2020 wzp_b,
2021 Shape::from_dims(w_shape.dims(), dtype),
2022 );
2023
2024 let mult_c = scalar_const(*mult, bwd);
2026 let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
2027 let upstream_scaled =
2028 bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
2029
2030 let dx_f32 = bwd.conv2d_backward_input(
2032 upstream_scaled,
2033 w_centered,
2034 Shape::from_dims(x_shape.dims(), dtype),
2035 kernel_size.clone(),
2036 stride.clone(),
2037 padding.clone(),
2038 dilation.clone(),
2039 *groups,
2040 );
2041 let dx = bwd.add_node(
2042 Op::Cast {
2043 to: x_shape.dtype(),
2044 },
2045 vec![dx_f32],
2046 x_shape,
2047 );
2048 let dw_f32 = bwd.conv2d_backward_weight(
2049 x_centered,
2050 upstream_scaled,
2051 Shape::from_dims(w_shape.dims(), dtype),
2052 kernel_size.clone(),
2053 stride.clone(),
2054 padding.clone(),
2055 dilation.clone(),
2056 *groups,
2057 );
2058 let dw = bwd.add_node(
2059 Op::Cast {
2060 to: w_shape.dtype(),
2061 },
2062 vec![dw_f32],
2063 w_shape,
2064 );
2065
2066 let dbias_f32 = bwd.add_node(
2068 Op::Reduce {
2069 op: ReduceOp::Sum,
2070 axes: vec![0, 2, 3],
2071 keep_dim: false,
2072 },
2073 vec![upstream_scaled],
2074 Shape::from_dims(bias_shape.dims(), dtype),
2075 );
2076 let dbias = bwd.add_node(
2077 Op::Cast {
2078 to: bias_shape.dtype(),
2079 },
2080 vec![dbias_f32],
2081 bias_shape,
2082 );
2083
2084 vec![(0, dx), (1, dw), (2, dbias)]
2085 }
2086
2087 Op::TopK { .. } | Op::Sample { .. } | Op::RngNormal { .. } | Op::RngUniform { .. } => {
2089 vec![]
2093 }
2094
2095 Op::GaussianSplatRender {
2096 width,
2097 height,
2098 tile_size,
2099 radius_scale,
2100 alpha_cutoff,
2101 max_splat_steps,
2102 transmittance_threshold,
2103 max_list_entries,
2104 ..
2105 } => {
2106 use rlx_ir::ops::splat::{
2107 GaussianSplatBackwardParams, GaussianSplatInputs, GaussianSplatRenderParams,
2108 unpack_gaussian_splat_packed_grads,
2109 };
2110 let render = GaussianSplatRenderParams {
2111 width: *width,
2112 height: *height,
2113 tile_size: *tile_size,
2114 radius_scale: *radius_scale,
2115 alpha_cutoff: *alpha_cutoff,
2116 max_splat_steps: *max_splat_steps,
2117 transmittance_threshold: *transmittance_threshold,
2118 max_list_entries: *max_list_entries,
2119 };
2120 let inputs = GaussianSplatInputs {
2121 positions: fwd_map[&node.inputs[0]],
2122 scales: fwd_map[&node.inputs[1]],
2123 rotations: fwd_map[&node.inputs[2]],
2124 opacities: fwd_map[&node.inputs[3]],
2125 colors: fwd_map[&node.inputs[4]],
2126 sh_coeffs: fwd_map[&node.inputs[5]],
2127 meta: fwd_map[&node.inputs[6]],
2128 };
2129 let count = bwd.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
2130 let sh_len = bwd.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
2131 let meta_shape = bwd.shape(inputs.meta).clone();
2132 let packed = bwd.gaussian_splat_render_backward(
2133 inputs,
2134 upstream,
2135 GaussianSplatBackwardParams {
2136 render,
2137 loss_grad_clip: 1.0,
2138 sh_band: 0,
2139 max_anisotropy: 10.0,
2140 },
2141 );
2142 let sh_coeff_count = if count == 0 {
2143 1
2144 } else {
2145 (sh_len / (count * 3)).max(1)
2146 };
2147 let grads = unpack_gaussian_splat_packed_grads(bwd, packed, count, sh_coeff_count);
2148 let meta_n = meta_shape.num_elements().unwrap_or(0);
2149 let zero_meta = bwd.add_node(
2150 Op::Constant {
2151 data: vec![0u8; meta_n * meta_shape.dtype().size_bytes()],
2152 },
2153 vec![],
2154 meta_shape,
2155 );
2156 vec![
2157 (0, grads.positions),
2158 (1, grads.scales),
2159 (2, grads.rotations),
2160 (3, grads.opacities),
2161 (4, grads.colors),
2162 (5, grads.sh_coeffs),
2163 (6, zero_meta),
2164 ]
2165 }
2166
2167 Op::GaussianSplatRenderBackward { .. } => {
2168 vec![]
2170 }
2171
2172 Op::GaussianSplatPrepare { .. } | Op::GaussianSplatRasterize { .. } => {
2173 panic!(
2174 "autodiff: decomposed splat ops must be fused before AD — \
2175 `prepare_graph_for_ad` rewrites Prepare→Rasterize into \
2176 `GaussianSplatRender`, or use `Op::GaussianSplatRender` directly"
2177 );
2178 }
2179
2180 Op::CustomFn {
2202 vjp_body: Some(vjp_body),
2203 num_inputs,
2204 ..
2205 } => {
2206 let mut sub_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
2208
2209 let mut primal_input_ids: Vec<NodeId> = vjp_body
2213 .nodes()
2214 .iter()
2215 .filter_map(|n| match &n.op {
2216 Op::Input { name } if name != "primal_output" && name != "d_output" => {
2217 Some(n.id)
2218 }
2219 _ => None,
2220 })
2221 .collect();
2222 primal_input_ids.sort();
2223 assert_eq!(primal_input_ids.len(), *num_inputs as usize);
2224
2225 for sub_node in vjp_body.nodes() {
2228 let new_id = match &sub_node.op {
2229 Op::Input { name } if name == "primal_output" => fwd_map[&node.id],
2230 Op::Input { name } if name == "d_output" => upstream,
2231 Op::Input { .. } => {
2232 let idx = primal_input_ids
2234 .iter()
2235 .position(|&id| id == sub_node.id)
2236 .expect(
2237 "custom_fn vjp_body: primal Input \
2238 not found in primal list",
2239 );
2240 fwd_map[&node.inputs[idx]]
2241 }
2242 _ => {
2243 let new_inputs: Vec<NodeId> =
2244 sub_node.inputs.iter().map(|i| sub_to_bwd[i]).collect();
2245 bwd.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
2246 }
2247 };
2248 sub_to_bwd.insert(sub_node.id, new_id);
2249 }
2250
2251 let mut grads: Vec<(usize, NodeId)> = Vec::with_capacity(*num_inputs as usize);
2254 for (i, out_id) in vjp_body.outputs.iter().enumerate() {
2255 grads.push((i, sub_to_bwd[out_id]));
2256 }
2257 grads
2258 }
2259
2260 Op::CustomFn { vjp_body: None, .. } => {
2263 panic!(
2264 "autodiff: Op::CustomFn has no vjp_body and was not inlined. \
2265 This is an internal error in inline_custom_fn_for_autodiff."
2266 )
2267 }
2268
2269 Op::Custom { name, .. } => {
2274 let ext = rlx_ir::lookup_op(name).unwrap_or_else(|| {
2275 panic!(
2276 "autodiff: Op::Custom('{name}') is not registered \
2277 in the op registry — register it via \
2278 rlx_ir::register_op before compiling the graph"
2279 )
2280 });
2281 let mut ctx = rlx_ir::VjpContext {
2282 upstream,
2283 fwd_map,
2284 bwd,
2285 };
2286 ext.vjp(node, &mut ctx)
2287 }
2288
2289 Op::Conv2dBackwardInput {
2290 kernel_size,
2291 stride,
2292 padding,
2293 dilation,
2294 groups,
2295 } => {
2296 let dy_bwd = fwd_map[&node.inputs[0]];
2297 let w_bwd = fwd_map[&node.inputs[1]];
2298 let dy_shape = bwd.node(dy_bwd).shape.clone();
2299 let _x_shape = node.shape.clone();
2300 let d_dy = bwd.add_node(
2301 Op::Conv {
2302 kernel_size: kernel_size.clone(),
2303 stride: stride.clone(),
2304 padding: padding.clone(),
2305 dilation: dilation.clone(),
2306 groups: *groups,
2307 },
2308 vec![upstream, w_bwd],
2309 dy_shape,
2310 );
2311 vec![(0, d_dy)]
2312 }
2313
2314 Op::Conv2dBackwardWeight {
2315 kernel_size,
2316 stride,
2317 padding,
2318 dilation,
2319 groups,
2320 } => {
2321 let x_bwd = fwd_map[&node.inputs[0]];
2322 let dy_bwd = fwd_map[&node.inputs[1]];
2323 let x_shape = bwd.node(x_bwd).shape.clone();
2324 let dy_shape = bwd.node(dy_bwd).shape.clone();
2325 let d_x = bwd.conv2d_backward_input(
2326 dy_bwd,
2327 upstream,
2328 x_shape,
2329 kernel_size.clone(),
2330 stride.clone(),
2331 padding.clone(),
2332 dilation.clone(),
2333 *groups,
2334 );
2335 let d_dy = bwd.add_node(
2336 Op::Conv {
2337 kernel_size: kernel_size.clone(),
2338 stride: stride.clone(),
2339 padding: padding.clone(),
2340 dilation: dilation.clone(),
2341 groups: *groups,
2342 },
2343 vec![x_bwd, upstream],
2344 dy_shape,
2345 );
2346 vec![(0, d_x), (1, d_dy)]
2347 }
2348
2349 Op::Fft { inverse, norm } => {
2358 let n = rlx_ir::fft::fft_meta(bwd.shape(node.inputs[0])).n_complex;
2359 let s = norm.output_scale(n, *inverse) as f32;
2360 let z = if s != 1.0 {
2361 let sc = scalar_const(s, bwd);
2362 bwd.mul(upstream, sc)
2363 } else {
2364 upstream
2365 };
2366 let dx = bwd.fft(z, !*inverse);
2367 vec![(0, dx)]
2368 }
2369
2370 Op::LogMel => {
2371 let spec_bwd = fwd_map[&node.inputs[0]];
2372 let filt_bwd = fwd_map[&node.inputs[1]];
2373 let dx = bwd.log_mel_backward(spec_bwd, filt_bwd, upstream);
2374 vec![(0, dx)]
2375 }
2376
2377 other => panic!(
2381 "autodiff: no VJP rule for {other}. See the matching \
2382 entry in rlx-opt/src/autodiff.rs (catch-all panic) for \
2383 a pointer to what's needed to differentiate this op.",
2384 ),
2385 }
2386}
2387
2388fn materialize_bcasts_for_ad(g: Graph) -> Graph {
2421 use rlx_ir::op::BinaryOp;
2422
2423 let needs = g.nodes().iter().any(|n| {
2424 matches!(
2425 &n.op, Op::Scan { num_bcast, .. } if *num_bcast > 0
2426 )
2427 });
2428 if !needs {
2429 return g;
2430 }
2431
2432 let mut out = Graph::new(g.name.clone());
2433 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2434
2435 for node in g.nodes() {
2436 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2437 match &node.op {
2438 Op::Scan {
2439 body,
2440 length,
2441 save_trajectory,
2442 num_bcast,
2443 num_xs,
2444 num_checkpoints,
2445 } if *num_bcast > 0 => {
2446 let bcast_base = 1;
2451 let xs_base = 1 + *num_bcast as usize;
2452
2453 let mut new_scan_inputs = vec![new_inputs[0]];
2454
2455 let mut materialised_xs: Vec<NodeId> = Vec::new();
2457 for i in 0..*num_bcast as usize {
2458 let b_id = new_inputs[bcast_base + i];
2459 let b_shape = out.node(b_id).shape.clone();
2460 let dtype = b_shape.dtype();
2461
2462 let mut ones_dims: Vec<rlx_ir::Dim> =
2466 vec![rlx_ir::Dim::Static(*length as usize)];
2467 for _ in 0..b_shape.rank() {
2468 ones_dims.push(rlx_ir::Dim::Static(1));
2469 }
2470 let ones_shape = rlx_ir::Shape::from_dims(&ones_dims, dtype);
2471 let n_elems: usize = ones_dims
2472 .iter()
2473 .map(|d| match d {
2474 rlx_ir::Dim::Static(n) => *n,
2475 rlx_ir::Dim::Dynamic(_) => 1,
2476 })
2477 .product();
2478 let elem_size = dtype.size_bytes();
2479 let mut data = Vec::with_capacity(n_elems * elem_size);
2480 match dtype {
2481 rlx_ir::DType::F64 => {
2482 for _ in 0..n_elems {
2483 data.extend_from_slice(&1.0_f64.to_le_bytes());
2484 }
2485 }
2486 rlx_ir::DType::F32 => {
2487 for _ in 0..n_elems {
2488 data.extend_from_slice(&1.0_f32.to_le_bytes());
2489 }
2490 }
2491 other => {
2492 panic!("materialize_bcasts_for_ad: unsupported bcast dtype {other:?}")
2493 }
2494 }
2495 let ones = out.add_node(Op::Constant { data }, vec![], ones_shape);
2496
2497 let mut xs_dims: Vec<rlx_ir::Dim> = vec![rlx_ir::Dim::Static(*length as usize)];
2499 for i in 0..b_shape.rank() {
2500 xs_dims.push(b_shape.dim(i));
2501 }
2502 let xs_shape = rlx_ir::Shape::from_dims(&xs_dims, dtype);
2503 let xs_id = out.add_node(Op::Binary(BinaryOp::Mul), vec![ones, b_id], xs_shape);
2504 materialised_xs.push(xs_id);
2505 }
2506
2507 new_scan_inputs.extend_from_slice(&materialised_xs);
2508 for i in 0..*num_xs as usize {
2509 new_scan_inputs.push(new_inputs[xs_base + i]);
2510 }
2511
2512 let new_id = out.add_node(
2513 Op::Scan {
2514 body: body.clone(),
2515 length: *length,
2516 save_trajectory: *save_trajectory,
2517 num_bcast: 0,
2518 num_xs: *num_bcast + *num_xs,
2519 num_checkpoints: *num_checkpoints,
2520 },
2521 new_scan_inputs,
2522 node.shape.clone(),
2523 );
2524 id_map.insert(node.id, new_id);
2525 }
2526 _ => {
2527 let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2528 id_map.insert(node.id, new_id);
2529 }
2530 }
2531 }
2532
2533 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2534 out.set_outputs(new_outputs);
2535 out
2536}
2537
2538pub fn convert_scans_for_ad(g: Graph) -> Graph {
2539 use rlx_ir::shape::Shape as IrShape;
2540
2541 let g = materialize_bcasts_for_ad(g);
2546
2547 let needs = g.nodes().iter().any(|n| {
2550 matches!(
2551 &n.op,
2552 Op::Scan {
2553 save_trajectory: false,
2554 ..
2555 }
2556 )
2557 });
2558 if !needs {
2559 return g;
2560 }
2561
2562 let mut out = Graph::new(g.name.clone());
2563 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2564
2565 for node in g.nodes() {
2566 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2567 match &node.op {
2568 Op::Scan {
2569 body,
2570 length,
2571 save_trajectory: false,
2572 num_xs,
2573 num_checkpoints,
2574 ..
2575 } => {
2576 let carry_shape = node.shape.clone();
2577 let mut traj_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2598 traj_dims.push(Dim::Static(*length as usize));
2599 for i in 0..carry_shape.rank() {
2600 traj_dims.push(carry_shape.dim(i));
2601 }
2602 let traj_shape = IrShape::from_dims(&traj_dims, carry_shape.dtype());
2603 let traj = out.add_node(
2604 Op::Scan {
2605 body: body.clone(),
2606 length: *length,
2607 save_trajectory: true,
2608 num_bcast: 0,
2609 num_xs: *num_xs,
2610 num_checkpoints: *num_checkpoints,
2611 },
2612 new_inputs,
2613 traj_shape,
2614 );
2615 let mut narrow_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2617 narrow_dims.push(Dim::Static(1));
2618 for i in 0..carry_shape.rank() {
2619 narrow_dims.push(carry_shape.dim(i));
2620 }
2621 let narrow_shape = IrShape::from_dims(&narrow_dims, carry_shape.dtype());
2622 let narrowed = out.add_node(
2623 Op::Narrow {
2624 axis: 0,
2625 start: (*length as usize).saturating_sub(1),
2626 len: 1,
2627 },
2628 vec![traj],
2629 narrow_shape,
2630 );
2631 let new_shape: Vec<i64> = (0..carry_shape.rank())
2633 .map(|i| match carry_shape.dim(i) {
2634 Dim::Static(n) => n as i64,
2635 Dim::Dynamic(_) => -1,
2636 })
2637 .collect();
2638 let final_id = out.add_node(Op::Reshape { new_shape }, vec![narrowed], carry_shape);
2639 id_map.insert(node.id, final_id);
2640 }
2641 _ => {
2642 let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2643 id_map.insert(node.id, new_id);
2644 }
2645 }
2646 }
2647
2648 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2649 out.set_outputs(new_outputs);
2650 out
2651}
2652
2653pub fn inline_custom_fn_for_autodiff(g: Graph) -> Graph {
2658 use rlx_fusion::control_flow::inline_subgraph_into;
2659
2660 let mut out = Graph::new(g.name.clone());
2661 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2662 let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
2663
2664 for node in &nodes {
2665 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2666 let new_id = match &node.op {
2667 Op::CustomFn {
2668 vjp_body: None,
2669 jvp_body: None,
2670 fwd_body,
2671 num_inputs,
2672 ..
2673 } => {
2674 assert_eq!(
2675 new_inputs.len(),
2676 *num_inputs as usize,
2677 "custom_fn: outer input count mismatch"
2678 );
2679 inline_subgraph_into(fwd_body, &new_inputs, &mut out)
2680 }
2681 _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
2682 };
2683 id_map.insert(node.id, new_id);
2684 }
2685
2686 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
2687 out.set_outputs(new_outputs);
2688 out
2689}
2690
2691pub(crate) fn unbroadcast_inverse(x: NodeId, target: &Shape, bwd: &mut Graph) -> NodeId {
2695 let target_dims: Vec<i64> = target
2696 .dims()
2697 .iter()
2698 .map(|d| match d {
2699 Dim::Static(n) => *n as i64,
2700 Dim::Dynamic(_) => -1,
2701 })
2702 .collect();
2703 bwd.add_node(
2704 Op::Expand {
2705 target_shape: target_dims,
2706 },
2707 vec![x],
2708 target.clone(),
2709 )
2710}
2711
2712fn expand_to(
2717 grad: NodeId,
2718 x_shape: &Shape,
2719 axes: &[usize],
2720 keep_dim: bool,
2721 bwd: &mut Graph,
2722) -> NodeId {
2723 let mut current = grad;
2724 if !keep_dim {
2725 let kept_dims: Vec<Dim> = (0..x_shape.rank())
2728 .map(|i| {
2729 if axes.contains(&i) {
2730 Dim::Static(1)
2731 } else {
2732 x_shape.dim(i)
2733 }
2734 })
2735 .collect();
2736 let kept = Shape::from_dims(&kept_dims, x_shape.dtype());
2737 current = reshape_to(current, &kept, bwd);
2738 }
2739 let target_shape: Vec<i64> = x_shape
2740 .dims()
2741 .iter()
2742 .map(|d| match d {
2743 Dim::Static(n) => *n as i64,
2744 Dim::Dynamic(_) => -1,
2745 })
2746 .collect();
2747 bwd.add_node(Op::Expand { target_shape }, vec![current], x_shape.clone())
2748}
2749
2750#[cfg(test)]
2751mod tests {
2752 use super::*;
2753
2754 #[test]
2755 fn grad_of_add_is_identity() {
2756 let mut g = Graph::new("test");
2757 let x = g.input("x", Shape::new(&[4], DType::F32));
2758 let y = g.input("y", Shape::new(&[4], DType::F32));
2759 let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2760 g.set_outputs(vec![z]);
2761
2762 let bwd = grad(&g, &[x, y]);
2763 assert_eq!(bwd.outputs.len(), 2);
2765 }
2766
2767 #[test]
2768 fn grad_of_mul_uses_other_operand() {
2769 let mut g = Graph::new("test");
2770 let x = g.input("x", Shape::new(&[4], DType::F32));
2771 let y = g.input("y", Shape::new(&[4], DType::F32));
2772 let z = g.binary(BinaryOp::Mul, x, y, Shape::new(&[4], DType::F32));
2773 g.set_outputs(vec![z]);
2774
2775 let bwd = grad(&g, &[x, y]);
2776 assert!(
2778 bwd.nodes()
2779 .iter()
2780 .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
2781 .count()
2782 >= 2
2783 );
2784 }
2785
2786 #[test]
2787 fn grad_with_loss_returns_loss_first() {
2788 let mut g = Graph::new("loss");
2789 let x = g.input("x", Shape::new(&[4], DType::F32));
2790 let y = g.input("y", Shape::new(&[4], DType::F32));
2791 let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2792 g.set_outputs(vec![z]);
2793
2794 let bwd = grad_with_loss(&g, &[x, y]);
2795 assert_eq!(bwd.outputs.len(), 3);
2797 }
2798
#[test]
fn grad_of_dense_solve_emits_implicit_function_rule() {
    // loss = sum(dense_solve(A, b)). The gradient is expected to follow the
    // implicit-function rule described by the assertion messages below:
    // - a second (reverse) DenseSolve against Aᵀ,
    // - a MatMul outer product,
    // - a Neg applied to that outer product.
    let mut g = Graph::new("solve_test");
    let a = g.param("A", Shape::new(&[2, 2], DType::F32));
    let b = g.input("b", Shape::new(&[2], DType::F32));
    let x = g.dense_solve(a, b, Shape::new(&[2], DType::F32));
    let scalar = Shape::new(&[1], DType::F32);
    let loss = g.reduce(x, ReduceOp::Sum, vec![0], false, scalar);
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[a, b]);
    assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");

    // Number of nodes in `graph` whose op satisfies `wanted`.
    fn tally(graph: &Graph, wanted: fn(&Op) -> bool) -> usize {
        graph.nodes().iter().filter(|node| wanted(&node.op)).count()
    }

    let solves = tally(&bwd, |op| matches!(op, Op::DenseSolve));
    assert!(
        solves >= 2,
        "expected ≥2 DenseSolve nodes (forward mirror + reverse), got\n{bwd}"
    );

    let transposes = tally(&bwd, |op| matches!(op, Op::Transpose { .. }));
    assert!(transposes >= 1, "expected a Transpose for Aᵀ, got\n{bwd}");

    let matmuls = tally(&bwd, |op| matches!(op, Op::MatMul));
    assert!(
        matmuls >= 1,
        "expected a MatMul for the outer product, got\n{bwd}"
    );

    let negations = tally(&bwd, |op| matches!(op, Op::Activation(Activation::Neg)));
    assert!(negations >= 1, "expected a Neg for −outer, got\n{bwd}");
}
2852
#[test]
fn inline_if_replaces_with_where() {
    // Both branches read a single captured [4] tensor; the predicate is [1].
    let s = Shape::new(&[4], DType::F32);
    let pred_s = Shape::new(&[1], DType::F32);

    // then: relu(captured)
    let then_g = {
        let mut branch = Graph::new("then_branch");
        let captured = branch.input("captured", s.clone());
        let out = branch.activation(Activation::Relu, captured, s.clone());
        branch.set_outputs(vec![out]);
        branch
    };
    // else: sigmoid(captured)
    let else_g = {
        let mut branch = Graph::new("else_branch");
        let captured = branch.input("captured", s.clone());
        let out = branch.activation(Activation::Sigmoid, captured, s.clone());
        branch.set_outputs(vec![out]);
        branch
    };

    let mut g = Graph::new("parent");
    let x = g.input("x", s.clone());
    let pred = g.input("pred", pred_s);
    let conditional = Op::If {
        then_branch: Box::new(then_g),
        else_branch: Box::new(else_g),
    };
    let if_out = g.add_node(conditional, vec![pred, x], s);
    g.set_outputs(vec![if_out]);

    let inlined = rlx_fusion::control_flow::inline_if(g);

    // True if any node of the inlined graph has an op satisfying `wanted`.
    let contains = |wanted: fn(&Op) -> bool| inlined.nodes().iter().any(|n| wanted(&n.op));

    assert!(
        !contains(|op| matches!(op, Op::If { .. })),
        "Op::If should be inlined away"
    );
    assert!(
        contains(|op| matches!(op, Op::Where)),
        "Op::Where should replace the Op::If"
    );
    assert!(
        contains(|op| matches!(op, Op::Activation(Activation::Relu))),
        "then_branch's Activation(Relu) should be inlined"
    );
    assert!(
        contains(|op| matches!(op, Op::Activation(Activation::Sigmoid))),
        "else_branch's Activation(Sigmoid) should be inlined"
    );
    assert_eq!(inlined.outputs.len(), 1);
}
2914
#[test]
fn grad_through_if_propagates() {
    // grad_with_loss over an Op::If must yield [loss, d/dx].
    let s = Shape::new(&[4], DType::F32);
    let pred_s = Shape::new(&[1], DType::F32);

    // then: c * c
    let mut then_g = Graph::new("th");
    let then_c = then_g.input("c", s.clone());
    let squared = then_g.binary(BinaryOp::Mul, then_c, then_c, s.clone());
    then_g.set_outputs(vec![squared]);

    // else: relu(c)
    let mut else_g = Graph::new("el");
    let else_c = else_g.input("c", s.clone());
    let rectified = else_g.activation(Activation::Relu, else_c, s.clone());
    else_g.set_outputs(vec![rectified]);

    let mut g = Graph::new("parent");
    let x = g.input("x", s.clone());
    let pred = g.input("pred", pred_s);
    let branch = Op::If {
        then_branch: Box::new(then_g),
        else_branch: Box::new(else_g),
    };
    let z = g.add_node(branch, vec![pred, x], s);
    g.set_outputs(vec![z]);

    let bwd = grad_with_loss(&g, &[x]);
    assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
}
2949
#[test]
fn unroll_while_replicates_body_n_times() {
    let s = Shape::new(&[4], DType::F32);

    // The cond graph only forwards the carry. The Relu count asserted below
    // matches `max_iterations`, which is what this test expects to drive the
    // unroll.
    let mut cond_g = Graph::new("cond");
    let ci = cond_g.input("c", s.clone());
    cond_g.set_outputs(vec![ci]);

    // body: relu(c)
    let mut body_g = Graph::new("body");
    let bi = body_g.input("c", s.clone());
    let bo = body_g.activation(Activation::Relu, bi, s.clone());
    body_g.set_outputs(vec![bo]);

    let mut g = Graph::new("parent");
    let x = g.input("x", s.clone());
    let w = g.add_node(
        Op::While {
            cond: Box::new(cond_g),
            body: Box::new(body_g),
            max_iterations: Some(3),
        },
        vec![x],
        s,
    );
    g.set_outputs(vec![w]);

    let unrolled = rlx_fusion::control_flow::unroll_while(g);

    let has_while = unrolled
        .nodes()
        .iter()
        .any(|n| matches!(n.op, Op::While { .. }));
    let relu_count = unrolled
        .nodes()
        .iter()
        .filter(|n| matches!(n.op, Op::Activation(Activation::Relu)))
        .count();
    assert!(!has_while, "Op::While should be unrolled away");
    assert_eq!(
        relu_count, 3,
        "body's Activation(Relu) should appear once per iteration"
    );
    assert_eq!(unrolled.outputs.len(), 1);
}
3005
#[test]
fn grad_through_while_propagates() {
    // grad_with_loss over a bounded Op::While must yield [loss, d/dx].
    let s = Shape::new(&[4], DType::F32);

    // cond: forwards the carry unchanged.
    let cond_g = {
        let mut cond = Graph::new("cond");
        let carry = cond.input("c", s.clone());
        cond.set_outputs(vec![carry]);
        cond
    };
    // body: c * c
    let body_g = {
        let mut body = Graph::new("body");
        let carry = body.input("c", s.clone());
        let squared = body.binary(BinaryOp::Mul, carry, carry, s.clone());
        body.set_outputs(vec![squared]);
        body
    };

    let mut g = Graph::new("parent");
    let x = g.input("x", s.clone());
    let loop_op = Op::While {
        cond: Box::new(cond_g),
        body: Box::new(body_g),
        max_iterations: Some(2),
    };
    let w = g.add_node(loop_op, vec![x], s);
    g.set_outputs(vec![w]);

    let bwd = grad_with_loss(&g, &[x]);
    assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
}
3038
3039 fn build_ftl_graph(has_bias: bool) -> (Graph, NodeId, Vec<NodeId>) {
3042 let mut g = Graph::new("ftl_test");
3044 let h_shape = Shape::new(&[1, 2, 4], DType::F32);
3045 let h = g.input("h", h_shape.clone());
3046 let qkv_w = g.param("qkv_w", Shape::new(&[4, 12], DType::F32));
3047 let out_w = g.param("out_w", Shape::new(&[4, 4], DType::F32));
3048 let ln1_g = g.param("ln1_g", Shape::new(&[4], DType::F32));
3049 let fc1_w = g.param("fc1_w", Shape::new(&[4, 8], DType::F32));
3050 let fc2_w = g.param("fc2_w", Shape::new(&[8, 4], DType::F32));
3051 let ln2_g = g.param("ln2_g", Shape::new(&[4], DType::F32));
3052 let mask = g.input("mask", Shape::new(&[1, 2, 2, 2], DType::F32));
3053
3054 let (inputs, params) = if has_bias {
3055 let qkv_b = g.param("qkv_b", Shape::new(&[12], DType::F32));
3056 let out_b = g.param("out_b", Shape::new(&[4], DType::F32));
3057 let ln1_b = g.param("ln1_b", Shape::new(&[4], DType::F32));
3058 let fc1_b = g.param("fc1_b", Shape::new(&[8], DType::F32));
3059 let fc2_b = g.param("fc2_b", Shape::new(&[4], DType::F32));
3060 let ln2_b = g.param("ln2_b", Shape::new(&[4], DType::F32));
3061 (
3062 vec![
3063 h, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3064 ln2_b, mask,
3065 ],
3066 vec![
3067 qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3068 ln2_b,
3069 ],
3070 )
3071 } else {
3072 (
3073 vec![h, qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g, mask],
3074 vec![qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g],
3075 )
3076 };
3077 let y = g.add_node(
3078 Op::FusedTransformerLayer {
3079 num_heads: 2,
3080 head_dim: 2,
3081 intermediate_size: 8,
3082 eps1: 1e-5,
3083 eps2: 1e-5,
3084 activation: rlx_ir::op::Activation::Gelu,
3085 has_bias,
3086 },
3087 inputs,
3088 h_shape,
3089 );
3090 g.set_outputs(vec![y]);
3091 (g, h, params)
3092 }
3093
#[test]
fn unfuse_decomposes_fused_transformer_layer() {
    let (graph, _, _) = build_ftl_graph(true);
    let unfused = rlx_fusion::unfuse_fused_for_autodiff(graph);

    // Number of nodes in the unfused graph whose op satisfies `wanted`.
    let tally = |wanted: fn(&Op) -> bool| -> usize {
        unfused.nodes().iter().filter(|n| wanted(&n.op)).count()
    };

    assert!(
        tally(|o| matches!(o, Op::FusedTransformerLayer { .. })) == 0,
        "Op::FusedTransformerLayer should be unfused"
    );
    assert!(
        tally(|o| matches!(o, Op::MatMul)) >= 4,
        "expected >=4 MatMul after FTL unfuse"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Attention { .. })),
        1,
        "expected exactly 1 Attention after FTL unfuse"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::LayerNorm { .. })),
        2,
        "expected exactly 2 LayerNorm after FTL unfuse"
    );
    assert!(
        tally(|o| matches!(o, Op::Narrow { .. })) >= 3,
        "expected >=3 Narrow (Q/K/V split) after FTL unfuse"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Activation(_))),
        1,
        "expected exactly 1 Activation (FFN) after FTL unfuse"
    );
}
3135
#[test]
fn grad_through_fused_transformer_layer_propagates() {
    // Differentiate w.r.t. every weight and bias of a biased
    // FusedTransformerLayer. The output list is [loss, one grad per param].
    let (g, _h, params) = build_ftl_graph(true);
    let bwd = grad_with_loss(&g, &params);
    assert_eq!(
        bwd.outputs.len(),
        1 + params.len(),
        "expected loss + {} param grads",
        params.len()
    );
}
3150
#[test]
fn grad_through_fused_transformer_layer_no_bias() {
    // Same as the biased variant, but the layer has only weights/gains
    // (has_bias = false). The output list is [loss, one grad per param].
    let (g, _h, params) = build_ftl_graph(false);
    let bwd = grad_with_loss(&g, &params);
    assert_eq!(
        bwd.outputs.len(),
        1 + params.len(),
        "expected loss + {} param grads (no-bias)",
        params.len()
    );
}
3164
3165 fn build_ssm_graph() -> (Graph, NodeId, Vec<NodeId>) {
3168 let mut g = Graph::new("ssm_test");
3169 let bsh = Shape::new(&[1, 3, 2], DType::F32);
3170 let hn = Shape::new(&[2, 4], DType::F32);
3171 let bsn = Shape::new(&[1, 3, 4], DType::F32);
3172
3173 let x = g.input("x", bsh.clone());
3174 let delta = g.input("delta", bsh.clone());
3175 let a = g.param("a", hn);
3176 let b = g.input("b", bsn.clone());
3177 let c = g.input("c", bsn);
3178 let y = g.selective_scan(x, delta, a, b, c, 4, bsh);
3179 g.set_outputs(vec![y]);
3180 (g, x, vec![a])
3181 }
3182
#[test]
fn unfuse_decomposes_selective_scan() {
    let (graph, _, _) = build_ssm_graph();
    let unfused = rlx_fusion::unfuse_fused_for_autodiff(graph);

    // Number of nodes in the unfused graph whose op satisfies `wanted`.
    let tally = |wanted: fn(&Op) -> bool| -> usize {
        unfused.nodes().iter().filter(|n| wanted(&n.op)).count()
    };

    assert!(
        tally(|o| matches!(o, Op::SelectiveScan { .. })) == 0,
        "Op::SelectiveScan should be unfused"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Concat { .. })),
        1,
        "expected 1 Concat (over the 3 time steps)"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Reduce { op: ReduceOp::Sum, .. })),
        3,
        "expected one Reduce(Sum) per time step (S=3)"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Activation(Activation::Exp))),
        3,
        "expected one exp(δA) per time step (S=3)"
    );
    assert!(
        tally(|o| matches!(o, Op::Narrow { .. })) >= 12,
        "expected >=12 Narrows (4 per step × 3 steps)"
    );
}
3227
#[test]
fn grad_through_selective_scan_propagates() {
    // Differentiate a selective_scan output w.r.t. its `a` parameter.
    // The output list is [loss, one grad per param].
    let (g, _x, params) = build_ssm_graph();
    let bwd = grad_with_loss(&g, &params);
    assert_eq!(
        bwd.outputs.len(),
        1 + params.len(),
        "expected loss + {} param grads",
        params.len()
    );
}
3244
3245 fn build_gdn_graph() -> (Graph, NodeId, Vec<NodeId>) {
3247 let (b, s, h, n) = (1usize, 3, 2, 4);
3248 let mut g = Graph::new("gdn_test");
3249 let bshn = Shape::new(&[b, s, h, n], DType::F32);
3250 let bsh = Shape::new(&[b, s, h], DType::F32);
3251 let q = g.input("q", bshn.clone());
3252 let k = g.input("k", bshn.clone());
3253 let v = g.input("v", bshn.clone());
3254 let g_in = g.input("g", bsh.clone());
3255 let beta = g.input("beta", bsh);
3256 let y = g.gated_delta_net(q, k, v, g_in, beta, n, bshn);
3257 g.set_outputs(vec![y]);
3258 (g, q, vec![q, k, v, g_in, beta])
3259 }
3260
#[test]
fn unfuse_decomposes_gated_delta_net() {
    let (graph, _, _) = build_gdn_graph();
    let unfused = rlx_fusion::unfuse_fused_for_autodiff(graph);

    // Number of nodes in the unfused graph whose op satisfies `wanted`.
    let tally = |wanted: fn(&Op) -> bool| -> usize {
        unfused.nodes().iter().filter(|n| wanted(&n.op)).count()
    };

    assert!(
        tally(|o| matches!(o, Op::GatedDeltaNet { .. })) == 0,
        "Op::GatedDeltaNet should be unfused"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Concat { .. })),
        1,
        "expected 1 Concat over S=3 steps"
    );
    assert!(
        tally(|o| matches!(o, Op::MatMul)) >= 3,
        "expected >=3 MatMul per step (sk + out) × S=3"
    );
    assert_eq!(
        tally(|o| matches!(o, Op::Activation(Activation::Exp))),
        3,
        "expected one exp(g) per time step"
    );
}
3290
#[test]
fn grad_through_gated_delta_net_propagates() {
    // Differentiate a gated_delta_net output w.r.t. all five of its inputs
    // (q, k, v, g, beta). The output list is [loss, one grad per input].
    let (g, _q, params) = build_gdn_graph();
    let bwd = grad_with_loss(&g, &params);
    assert_eq!(
        bwd.outputs.len(),
        1 + params.len(),
        "expected loss + {} input grads",
        params.len()
    );
}
3302
#[test]
fn custom_fn_vjp_body_is_inlined_into_bwd() {
    // The forward body computes x * x, but the attached VJP returns
    // sin(d_output), which is deliberately not the true derivative. Finding a
    // Sin node in the backward graph therefore shows the custom VJP body was
    // used rather than autodiff of the forward body.
    const N: usize = 4;
    let shape = Shape::new(&[N], DType::F32);

    let fwd_body = {
        let mut body = Graph::new("square_fwd");
        let xb = body.input("x", shape.clone());
        let yb = body.binary(BinaryOp::Mul, xb, xb, shape.clone());
        body.set_outputs(vec![yb]);
        body
    };

    // VJP body inputs: (x, primal_output, d_output); only d_output is used.
    let vjp_body = {
        let mut body = Graph::new("square_vjp");
        let _ = body.input("x", shape.clone());
        let _ = body.input("primal_output", shape.clone());
        let upstream = body.input("d_output", shape.clone());
        let dx = body.activation(Activation::Sin, upstream, shape.clone());
        body.set_outputs(vec![dx]);
        body
    };

    let mut g = Graph::new("custom_fn_test");
    let x = g.input("x", shape.clone());
    let y = g.custom_fn(vec![x], fwd_body, Some(vjp_body), None);
    let scalar = Shape::new(&[1], DType::F32);
    let loss = g.reduce(y, ReduceOp::Sum, vec![0], false, scalar);
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[x]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    let sin_count = bwd
        .nodes()
        .iter()
        .map(|n| &n.op)
        .filter(|op| matches!(op, Op::Activation(Activation::Sin)))
        .count();
    assert!(
        sin_count >= 1,
        "expected the vjp_body's Sin to be inlined into bwd, got\n{bwd}"
    );
}
3353
#[test]
fn custom_fn_without_vjp_inlines_fwd_body_for_autodiff() {
    // With no custom VJP, the CustomFn must disappear from the backward graph.
    // Its x * x body is then differentiated normally, which yields Mul nodes.
    const N: usize = 4;
    let shape = Shape::new(&[N], DType::F32);

    let fwd_body = {
        let mut body = Graph::new("square_fwd");
        let xb = body.input("x", shape.clone());
        let yb = body.binary(BinaryOp::Mul, xb, xb, shape.clone());
        body.set_outputs(vec![yb]);
        body
    };

    let mut g = Graph::new("custom_fn_no_vjp");
    let x = g.input("x", shape.clone());
    let y = g.custom_fn(vec![x], fwd_body, None, None);
    let scalar = Shape::new(&[1], DType::F32);
    let loss = g.reduce(y, ReduceOp::Sum, vec![0], false, scalar);
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[x]);
    assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");

    // Single pass over the nodes: (CustomFn count, Mul count).
    let (custom_fn_count, mul_count) =
        bwd.nodes()
            .iter()
            .fold((0usize, 0usize), |(custom, muls), n| match n.op {
                Op::CustomFn { .. } => (custom + 1, muls),
                Op::Binary(BinaryOp::Mul) => (custom, muls + 1),
                _ => (custom, muls),
            });

    assert_eq!(
        custom_fn_count, 0,
        "CustomFn should be inlined away before autodiff"
    );
    assert!(mul_count >= 2, "expected Mul-based VJP for x², got\n{bwd}");
}
3397
#[test]
fn convert_scans_for_ad_forces_save_trajectory_true() {
    // A Scan built with save_trajectory = false must show up with
    // save_trajectory = true once grad_with_loss has prepared the graph for AD.
    const N: usize = 2;
    const LENGTH: u32 = 3;
    let carry = Shape::new(&[N], DType::F32);
    let xs_shape = Shape::new(&[LENGTH as usize, N], DType::F32);

    // body: next_carry = carry + x_t
    let body = {
        let mut step = Graph::new("scan_body");
        let prev = step.input("carry", carry.clone());
        let x_t = step.input("x_t", carry.clone());
        let next = step.binary(BinaryOp::Add, prev, x_t, carry.clone());
        step.set_outputs(vec![next]);
        step
    };

    let mut g = Graph::new("scan_save_false");
    let init = g.input("init", carry.clone());
    let xs = g.input("xs", xs_shape);
    let scan = Op::Scan {
        body: Box::new(body),
        length: LENGTH,
        save_trajectory: false,
        num_bcast: 0,
        num_xs: 1,
        num_checkpoints: 0,
    };
    let scan_out = g.add_node(scan, vec![init, xs], carry.clone());
    let scalar = Shape::new(&[1], DType::F32);
    let loss = g.reduce(scan_out, ReduceOp::Sum, vec![0], false, scalar);
    g.set_outputs(vec![loss]);

    let bwd = grad_with_loss(&g, &[init, xs]);
    let saved_traj = bwd.nodes().iter().map(|n| &n.op).any(|op| {
        matches!(
            op,
            Op::Scan {
                save_trajectory: true,
                ..
            }
        )
    });
    assert!(
        saved_traj,
        "convert_scans_for_ad should rewrite save_trajectory=false → \
         save_trajectory=true in the AD-prepared graph; got\n{bwd}"
    );
}
3458}