use rlx_ir::op::*;
use rlx_ir::shape::Dim;
use rlx_ir::*;
use std::collections::HashMap;
pub use crate::prepare_ad::{
AutodiffError, PrepareForAutodiff, grad_with_loss_module, jvp_module, prepare_graph_for_ad,
prepare_mir_for_ad, prepare_module_for_ad,
};
/// Differentiates a single-output `forward` graph and returns a new graph that
/// produces the loss followed by its gradients.
///
/// How it works:
/// 1. `forward` is first passed through `prepare_ad::prepare_graph_for_ad`.
/// 2. Every node of the prepared graph is replayed into the new graph.
/// 3. One extra input, `d_output`, is added with the loss's shape. It is the
///    seed gradient.
/// 4. Reverse-mode accumulation walks the prepared nodes back-to-front and
///    calls `vjp` for every node that has received a gradient.
/// 5. Contributions that reach the same node more than once are summed with
///    `BinaryOp::Add`.
///
/// Outputs of the returned graph: `[loss, grad(wrt[0]), grad(wrt[1]), ...]`.
///
/// # Panics
/// * If `forward` does not have exactly one output.
/// * If no gradient reaches some node in `wrt`: either the loss does not
///   depend on it, or none of its consumers produced a VJP contribution.
pub fn grad_with_loss(forward: &Graph, wrt: &[NodeId]) -> Graph {
    assert_eq!(
        forward.outputs.len(),
        1,
        "grad_with_loss: forward must have exactly one output"
    );
    let prepared = crate::prepare_ad::prepare_graph_for_ad(forward.clone());
    let mut bwd = Graph::new(format!("{}_grad", prepared.name));

    // Replay the forward pass. Inputs are looked up in `replayed`, so this
    // relies on `nodes()` listing producers before their consumers.
    let mut replayed: HashMap<NodeId, NodeId> = HashMap::new();
    for node in prepared.nodes() {
        let mapped: Vec<NodeId> = node.inputs.iter().map(|src| replayed[src]).collect();
        let copy = bwd.add_node(node.op.clone(), mapped, node.shape.clone());
        replayed.insert(node.id, copy);
    }

    let loss_fwd = prepared.outputs[0];
    let loss = replayed[&loss_fwd];
    let seed = bwd.input("d_output", prepared.node(loss_fwd).shape.clone());

    // Adjoint accumulated so far for each replayed node.
    let mut adjoint: HashMap<NodeId, NodeId> = HashMap::new();
    adjoint.insert(loss, seed);

    for fwd_node in prepared.nodes().iter().rev() {
        let upstream = match adjoint.get(&replayed[&fwd_node.id]) {
            Some(&g) => g,
            // Nothing has flowed into this node, so there is nothing to propagate.
            None => continue,
        };
        for (slot, contribution) in vjp(fwd_node, upstream, &replayed, &mut bwd) {
            let dest = replayed[&fwd_node.inputs[slot]];
            let summed = match adjoint.get(&dest) {
                None => contribution,
                Some(&prior) => {
                    let shape = bwd.node(prior).shape.clone();
                    bwd.binary(BinaryOp::Add, prior, contribution, shape)
                }
            };
            adjoint.insert(dest, summed);
        }
    }

    let outputs: Vec<NodeId> = std::iter::once(loss)
        .chain(wrt.iter().map(|&id| {
            adjoint.get(&replayed[&id]).copied().unwrap_or_else(|| {
                panic!(
                    "no gradient flowed to {id} — \
                     either the forward graph doesn't depend on it, or one \
                     of its consumer ops has no VJP rule"
                )
            })
        }))
        .collect();
    bwd.set_outputs(outputs);
    bwd
}
/// Builds the gradient graph of `forward` with respect to `wrt`.
///
/// This is [`grad_with_loss`] with its leading loss output removed, so the
/// outputs are exactly `[grad(wrt[0]), grad(wrt[1]), ...]`. The returned graph
/// still takes the `d_output` seed-gradient input added by `grad_with_loss`.
///
/// # Panics
/// Under the same conditions as [`grad_with_loss`].
pub fn grad(forward: &Graph, wrt: &[NodeId]) -> Graph {
    let mut g = grad_with_loss(forward, wrt);
    // Output 0 of `grad_with_loss` is the replayed loss; keep only the gradients.
    let grads_only: Vec<NodeId> = g.outputs.iter().skip(1).copied().collect();
    g.set_outputs(grads_only);
    g
}
/// Returns the `bits` setting of the node `node_id` when it is a fake-quantize
/// op (`FakeQuantize` or `FakeQuantizeLSQ`), and `None` for any other op.
pub fn quantized_weight_bits(forward: &Graph, node_id: NodeId) -> Option<u8> {
    let op = &forward.node(node_id).op;
    match op {
        Op::FakeQuantize { bits, .. } | Op::FakeQuantizeLSQ { bits, .. } => Some(*bits),
        _ => None,
    }
}
fn unbroadcast(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
let grad_shape = bwd.node(grad).shape.clone();
if grad_shape == *target_shape {
return grad;
}
let g_rank = grad_shape.rank();
let t_rank = target_shape.rank();
let extra = g_rank.saturating_sub(t_rank);
let mut axes: Vec<usize> = (0..extra).collect();
for i in 0..t_rank {
let g_dim = grad_shape.dim(extra + i);
let t_dim = target_shape.dim(i);
if matches!(t_dim, Dim::Static(1)) && !matches!(g_dim, Dim::Static(1)) {
axes.push(extra + i);
}
}
let mut current = grad;
if !axes.is_empty() {
let mut running_dims: Vec<Dim> = (0..g_rank).map(|i| grad_shape.dim(i)).collect();
for &ax in &axes {
running_dims[ax] = Dim::Static(1);
let step_shape = Shape::from_dims(&running_dims, target_shape.dtype());
current = bwd.add_node(
Op::Reduce {
op: ReduceOp::Sum,
axes: vec![ax],
keep_dim: true,
},
vec![current],
step_shape,
);
}
}
if bwd.node(current).shape.rank() != t_rank {
let new_shape: Vec<i64> = target_shape
.dims()
.iter()
.map(|d| match d {
Dim::Static(n) => *n as i64,
Dim::Dynamic(_) => -1,
})
.collect();
current = bwd.add_node(
Op::Reshape { new_shape },
vec![current],
target_shape.clone(),
);
}
current
}
/// Reshapes `grad` to `target_shape`, or returns it unchanged if it already
/// has that shape.
///
/// In the emitted `Reshape` op's `new_shape`, static dimensions are written as
/// their size and dynamic dimensions as `-1`.
fn reshape_to(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
    if bwd.node(grad).shape == *target_shape {
        grad
    } else {
        let new_shape: Vec<i64> = target_shape
            .dims()
            .iter()
            .map(|dim| match *dim {
                Dim::Static(n) => n as i64,
                Dim::Dynamic(_) => -1,
            })
            .collect();
        bwd.add_node(Op::Reshape { new_shape }, vec![grad], target_shape.clone())
    }
}
/// Vector-Jacobian product of `GroupedMatMul`.
///
/// Shapes as this function reads them:
/// * `x_shape` is `[M, K]` (dims 0 and 1).
/// * `w_shape` is `[E, K, N]` (dims 0 and 2 are read).
/// * `expert_bwd` is gathered against `w` along axis 0 to give a `[M, K, N]`
///   per-row weight.
/// * `upstream` is reshaped to `[M, 1, N]`.
///
/// Returns `(dx, dw)`:
/// * `dx[i] = upstream[i] · W[expert[i]]ᵀ`, with shape `x_shape`.
/// * `dw` is the per-row outer products `x[i]ᵀ ⊗ upstream[i]`, scatter-added
///   into their expert slots, with shape `[E, K, N]`.
///
/// # Panics
/// If M, K or N is not `Dim::Static`.
fn grouped_matmul_vjp(
    bwd: &mut Graph,
    upstream: NodeId,
    x_bwd: NodeId,
    w_bwd: NodeId,
    expert_bwd: NodeId,
    x_shape: &Shape,
    w_shape: &Shape,
) -> (NodeId, NodeId) {
    let dtype = x_shape.dtype();
    let m = x_shape.dim(0);
    let k = x_shape.dim(1);
    let e = w_shape.dim(0);
    let n_out = w_shape.dim(2);
    let m_static = match m {
        Dim::Static(v) => v,
        _ => panic!("GroupedMatMul VJP: M must be static"),
    };
    let k_static = match k {
        Dim::Static(v) => v,
        _ => panic!("GroupedMatMul VJP: K must be static"),
    };
    let n_static = match n_out {
        Dim::Static(v) => v,
        _ => panic!("GroupedMatMul VJP: N must be static"),
    };

    // Per-row weight slice W[expert[i]]: [M, K, N].
    let w_per = bwd.add_node(
        Op::Gather { axis: 0 },
        vec![w_bwd, expert_bwd],
        Shape::from_dims(&[m, k, n_out], dtype),
    );

    // Upstream viewed as a batch of row vectors, [M, 1, N]. Built once and
    // shared by both the dx and dw products.
    let up_3d = bwd.reshape(
        upstream,
        vec![m_static as i64, 1, n_static as i64],
        Shape::from_dims(&[m, Dim::Static(1), n_out], dtype),
    );

    // dx: [M,1,N] @ [M,N,K] -> [M,1,K] -> [M,K].
    let w_per_t = bwd.add_node(
        Op::Transpose {
            perm: vec![0, 2, 1],
        },
        vec![w_per],
        Shape::from_dims(&[m, n_out, k], dtype),
    );
    let dx_3d = bwd.matmul(
        up_3d,
        w_per_t,
        Shape::from_dims(&[m, Dim::Static(1), k], dtype),
    );
    let dx = bwd.reshape(
        dx_3d,
        vec![m_static as i64, k_static as i64],
        x_shape.clone(),
    );

    // dw: per-row outer product [M,K,1] @ [M,1,N] -> [M,K,N], then
    // scatter-added into each row's expert slot -> [E,K,N].
    let x_3d = bwd.reshape(
        x_bwd,
        vec![m_static as i64, k_static as i64, 1],
        Shape::from_dims(&[m, k, Dim::Static(1)], dtype),
    );
    let dw_per = bwd.matmul(x_3d, up_3d, Shape::from_dims(&[m, k, n_out], dtype));
    let dw = bwd.add_node(
        Op::ScatterAdd,
        vec![dw_per, expert_bwd],
        Shape::from_dims(&[e, k, n_out], dtype),
    );
    (dx, dw)
}
/// Adds a constant node holding `value` and returns its id.
///
/// The node has shape `[1]`, dtype `F32`, and its data is the little-endian
/// byte encoding of `value`.
///
/// NOTE(review): the constant is always `F32`, even when callers combine it
/// with tensors of another dtype — confirm that backends promote or cast.
fn scalar_const(value: f32, bwd: &mut Graph) -> NodeId {
    bwd.add_node(
        Op::Constant {
            data: Vec::from(value.to_le_bytes()),
        },
        Vec::new(),
        Shape::from_dims(&[Dim::Static(1)], DType::F32),
    )
}
fn vjp(
node: &Node,
upstream: NodeId,
fwd_map: &HashMap<NodeId, NodeId>,
bwd: &mut Graph,
) -> Vec<(usize, NodeId)> {
let upstream_shape = bwd.node(upstream).shape.clone();
match &node.op {
Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => vec![],
Op::Binary(BinaryOp::Add) => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let g_a = unbroadcast(upstream, &a_shape, bwd);
let g_b = unbroadcast(upstream, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::Binary(BinaryOp::Sub) => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let neg = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
let g_a = unbroadcast(upstream, &a_shape, bwd);
let g_b = unbroadcast(neg, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::Binary(BinaryOp::Mul) => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let is_c64 = upstream_shape.dtype() == DType::C64;
let b_for_a = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
let a_for_b = if is_c64 { bwd.conjugate(a_bwd) } else { a_bwd };
let g_a_full = bwd.binary(BinaryOp::Mul, upstream, b_for_a, upstream_shape.clone());
let g_b_full = bwd.binary(BinaryOp::Mul, upstream, a_for_b, upstream_shape);
let g_a = unbroadcast(g_a_full, &a_shape, bwd);
let g_b = unbroadcast(g_b_full, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::Activation(kind) => {
let x_bwd = fwd_map[&node.inputs[0]];
let dx = match kind {
Activation::Relu => bwd.relu_backward(x_bwd, upstream),
_ => bwd.activation_backward(*kind, x_bwd, upstream),
};
vec![(0, dx)]
}
Op::MatMul => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
assert!(
a_shape.rank() >= 2 && b_shape.rank() >= 2,
"MatMul VJP: rank must be ≥ 2, got {} and {}",
a_shape.rank(),
b_shape.rank()
);
let dtype = upstream_shape.dtype();
let trans_last_two = |bwd: &mut Graph, x: NodeId| -> NodeId {
let s = bwd.node(x).shape.clone();
let r = s.rank();
let mut perm: Vec<usize> = (0..r).collect();
perm.swap(r - 2, r - 1);
let mut dims: Vec<Dim> = s.dims().to_vec();
dims.swap(r - 2, r - 1);
let new_shape = Shape::from_dims(&dims, s.dtype());
bwd.add_node(Op::Transpose { perm }, vec![x], new_shape)
};
let upstream_dims: Vec<Dim> = upstream_shape.dims().to_vec();
let r_up = upstream_dims.len();
let b_t = trans_last_two(bwd, b_bwd);
let mut ga_dims = upstream_dims.clone();
ga_dims[r_up - 1] = a_shape.dim(a_shape.rank() - 1); let ga_shape = Shape::from_dims(&ga_dims, dtype);
let g_a_full = bwd.matmul(upstream, b_t, ga_shape);
let g_a = unbroadcast(g_a_full, &a_shape, bwd);
let a_t = trans_last_two(bwd, a_bwd);
let mut gb_dims = upstream_dims.clone();
gb_dims[r_up - 2] = a_shape.dim(a_shape.rank() - 1); let gb_shape = Shape::from_dims(&gb_dims, dtype);
let g_b_full = bwd.matmul(a_t, upstream, gb_shape);
let g_b = unbroadcast(g_b_full, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::DenseSolve => {
let a_bwd = fwd_map[&node.inputs[0]];
let x_bwd = fwd_map[&node.id];
let a_shape = bwd.node(a_bwd).shape.clone();
let x_shape = bwd.node(x_bwd).shape.clone();
assert_eq!(a_shape.rank(), 2, "DenseSolve VJP: A must be 2-D");
let n = match a_shape.dim(0) {
Dim::Static(n) => n,
Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic N not supported"),
};
let dtype = a_shape.dtype();
let mut a_t_dims: Vec<Dim> = a_shape.dims().to_vec();
a_t_dims.swap(0, 1);
let a_t_shape = Shape::from_dims(&a_t_dims, dtype);
let a_t = bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![a_bwd], a_t_shape);
let d_b = bwd.dense_solve(a_t, upstream, x_shape.clone());
let neg_outer = match x_shape.rank() {
1 => {
let col_shape = Shape::from_dims(&[Dim::Static(n), Dim::Static(1)], dtype);
let row_shape = Shape::from_dims(&[Dim::Static(1), Dim::Static(n)], dtype);
let db_col = bwd.add_node(
Op::Reshape {
new_shape: vec![n as i64, 1],
},
vec![d_b],
col_shape,
);
let x_row = bwd.add_node(
Op::Reshape {
new_shape: vec![1, n as i64],
},
vec![x_bwd],
row_shape,
);
let outer = bwd.matmul(db_col, x_row, a_shape.clone());
bwd.activation(Activation::Neg, outer, a_shape)
}
2 => {
let k = match x_shape.dim(1) {
Dim::Static(k) => k,
Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic K not supported"),
};
let xt_dims = vec![Dim::Static(k), Dim::Static(n)];
let xt_shape = Shape::from_dims(&xt_dims, dtype);
let x_t =
bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![x_bwd], xt_shape);
let outer = bwd.matmul(d_b, x_t, a_shape.clone());
bwd.activation(Activation::Neg, outer, a_shape)
}
r => panic!("DenseSolve VJP: B must be rank 1 or 2, got rank {r}"),
};
vec![(0, neg_outer), (1, d_b)]
}
Op::BatchedDenseSolve => {
let a_bwd = fwd_map[&node.inputs[0]];
let x_bwd = fwd_map[&node.id];
let a_shape = bwd.node(a_bwd).shape.clone();
let x_shape = bwd.node(x_bwd).shape.clone();
assert_eq!(
a_shape.rank(),
3,
"BatchedDenseSolve VJP: A must be rank-3 [B, N, N]"
);
let b_dim = match a_shape.dim(0) {
Dim::Static(b) => b,
Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic B not supported"),
};
let n = match a_shape.dim(1) {
Dim::Static(n) => n,
Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic N not supported"),
};
let dtype = a_shape.dtype();
let a_t = bwd.add_node(
Op::Transpose {
perm: vec![0, 2, 1],
},
vec![a_bwd],
a_shape.clone(),
);
let d_b = bwd.batched_dense_solve(a_t, upstream, x_shape.clone());
let neg_outer = match x_shape.rank() {
2 => {
let col_shape = Shape::from_dims(
&[Dim::Static(b_dim), Dim::Static(n), Dim::Static(1)],
dtype,
);
let row_shape = Shape::from_dims(
&[Dim::Static(b_dim), Dim::Static(1), Dim::Static(n)],
dtype,
);
let db_col = bwd.add_node(
Op::Reshape {
new_shape: vec![b_dim as i64, n as i64, 1],
},
vec![d_b],
col_shape,
);
let x_row = bwd.add_node(
Op::Reshape {
new_shape: vec![b_dim as i64, 1, n as i64],
},
vec![x_bwd],
row_shape,
);
let outer = bwd.matmul(db_col, x_row, a_shape.clone());
bwd.activation(Activation::Neg, outer, a_shape)
}
3 => {
let k = match x_shape.dim(2) {
Dim::Static(k) => k,
Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic K not supported"),
};
let xt_shape = Shape::from_dims(
&[Dim::Static(b_dim), Dim::Static(k), Dim::Static(n)],
dtype,
);
let x_t = bwd.add_node(
Op::Transpose {
perm: vec![0, 2, 1],
},
vec![x_bwd],
xt_shape,
);
let outer = bwd.matmul(d_b, x_t, a_shape.clone());
bwd.activation(Activation::Neg, outer, a_shape)
}
r => panic!("BatchedDenseSolve VJP: b must be rank 2 or 3, got rank {r}"),
};
vec![(0, neg_outer), (1, d_b)]
}
Op::Scan {
body,
length,
save_trajectory,
num_bcast: _,
num_xs,
num_checkpoints,
} => {
let init_bwd = fwd_map[&node.inputs[0]];
let traj_bwd = fwd_map[&node.id];
let init_shape = bwd.node(init_bwd).shape.clone();
let mut body_input_ids: Vec<NodeId> = body
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Input { .. }))
.map(|n| n.id)
.collect();
body_input_ids.sort();
let body_vjp = grad(body, &body_input_ids);
let xs_bwd: Vec<NodeId> = (0..*num_xs as usize)
.map(|i| fwd_map[&node.inputs[1 + i]])
.collect();
let is_checkpointed = *num_checkpoints != 0 && *num_checkpoints != *length;
let forward_body_for_bwd = if is_checkpointed {
Some((**body).clone())
} else {
None
};
let dinit = bwd.scan_backward_with_checkpoints(
init_bwd,
traj_bwd,
upstream,
&xs_bwd,
body_vjp.clone(),
*length,
*save_trajectory,
*num_checkpoints,
forward_body_for_bwd.clone(),
init_shape,
);
let mut grads: Vec<(usize, NodeId)> = vec![(0, dinit)];
for i in 0..*num_xs as usize {
let outer_xs_id = node.inputs[1 + i];
let xs_shape = bwd.node(fwd_map[&outer_xs_id]).shape.clone();
let dxs_i = bwd.scan_backward_xs_with_checkpoints(
init_bwd,
traj_bwd,
upstream,
&xs_bwd,
body_vjp.clone(),
*length,
*save_trajectory,
i as u32,
*num_checkpoints,
forward_body_for_bwd.clone(),
xs_shape,
);
grads.push((1 + i, dxs_i));
}
grads
}
Op::Conv {
kernel_size,
stride,
padding,
dilation,
groups,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let w_bwd = fwd_map[&node.inputs[1]];
let x_shape = bwd.node(x_bwd).shape.clone();
let w_shape = bwd.node(w_bwd).shape.clone();
let dx = bwd.conv2d_backward_input(
upstream,
w_bwd,
x_shape,
kernel_size.clone(),
stride.clone(),
padding.clone(),
dilation.clone(),
*groups,
);
let _qat_bits: Option<u8> = None;
let dw = bwd.conv2d_backward_weight(
x_bwd,
upstream,
w_shape,
kernel_size.clone(),
stride.clone(),
padding.clone(),
dilation.clone(),
*groups,
);
vec![(0, dx), (1, dw)]
}
Op::Pool {
kind: ReduceOp::Max,
kernel_size,
stride,
padding,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let dx = bwd.maxpool2d_backward(
x_bwd,
upstream,
kernel_size.clone(),
stride.clone(),
padding.clone(),
);
vec![(0, dx)]
}
Op::SoftmaxCrossEntropyWithLogits => {
let logits_bwd = fwd_map[&node.inputs[0]];
let labels_bwd = fwd_map[&node.inputs[1]];
let dlogits = bwd.softmax_cross_entropy_backward(logits_bwd, labels_bwd, upstream);
vec![(0, dlogits)]
}
Op::Reduce {
op: ReduceOp::Sum,
axes,
keep_dim,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let g = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
vec![(0, g)]
}
Op::Reduce {
op: ReduceOp::Mean,
axes,
keep_dim,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let count: usize = axes
.iter()
.map(|&a| match x_shape.dim(a) {
Dim::Static(n) => n,
_ => panic!("Reduce::Mean VJP requires static reduced dims"),
})
.product();
let expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
let inv_count = scalar_const(1.0 / count as f32, bwd);
let g = bwd.binary(BinaryOp::Mul, expanded, inv_count, x_shape);
vec![(0, g)]
}
Op::Reshape { .. } => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dx = reshape_to(upstream, &x_shape, bwd);
vec![(0, dx)]
}
Op::ComplexNormSq => {
let z_bwd = fwd_map[&node.inputs[0]];
let dz = bwd.complex_norm_sq_backward(z_bwd, upstream);
vec![(0, dz)]
}
Op::Conjugate => {
let dz = bwd.conjugate(upstream);
vec![(0, dz)]
}
Op::Cast { .. } => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dx = bwd.add_node(
Op::Cast {
to: x_shape.dtype(),
},
vec![upstream],
x_shape,
);
vec![(0, dx)]
}
Op::Quantize { .. } | Op::Dequantize { .. } => {
vec![(0, upstream)]
}
Op::FakeQuantizeLSQ { bits, axis } => {
let x_bwd = fwd_map[&node.inputs[0]];
let scale_bwd = fwd_map[&node.inputs[1]];
let x_shape = bwd.node(x_bwd).shape.clone();
let scale_shape = bwd.node(scale_bwd).shape.clone();
let dx = bwd.add_node(
Op::FakeQuantizeLSQBackwardX {
bits: *bits,
axis: *axis,
},
vec![x_bwd, scale_bwd, upstream],
x_shape,
);
let dscale = bwd.add_node(
Op::FakeQuantizeLSQBackwardScale {
bits: *bits,
axis: *axis,
},
vec![x_bwd, scale_bwd, upstream],
scale_shape,
);
vec![(0, dx), (1, dscale)]
}
Op::FakeQuantize {
bits, axis, ste, ..
} => {
use rlx_ir::op::SteKind;
match ste {
SteKind::Identity => vec![(0, upstream)],
_ => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dx = bwd.add_node(
Op::FakeQuantizeBackward {
bits: *bits,
axis: *axis,
ste: *ste,
},
vec![x_bwd, upstream],
x_shape,
);
vec![(0, dx)]
}
}
}
Op::Expand { .. } => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dx = unbroadcast(upstream, &x_shape, bwd);
vec![(0, dx)]
}
Op::LayerNorm { axis, eps } => {
let x_bwd = fwd_map[&node.inputs[0]];
let gamma_bwd = fwd_map[&node.inputs[1]];
let _beta_bwd = fwd_map[&node.inputs[2]];
let gamma_shape = bwd.node(gamma_bwd).shape.clone();
let dx = bwd.layer_norm_backward_input(x_bwd, gamma_bwd, upstream, *axis, *eps);
let dgamma =
bwd.layer_norm_backward_gamma(x_bwd, upstream, gamma_shape.clone(), *axis, *eps);
let dbeta = unbroadcast(upstream, &gamma_shape, bwd);
vec![(0, dx), (1, dgamma), (2, dbeta)]
}
Op::Softmax { axis } => {
let y_bwd = fwd_map[&node.id];
let y_shape = bwd.node(y_bwd).shape.clone();
let dtype = y_shape.dtype();
let rank = y_shape.rank();
let axis_pos = if *axis < 0 {
(rank as i32 + *axis) as usize
} else {
*axis as usize
};
let yg = bwd.binary(BinaryOp::Mul, y_bwd, upstream, y_shape.clone());
let mut kept_dims: Vec<Dim> = y_shape.dims().to_vec();
kept_dims[axis_pos] = Dim::Static(1);
let kept_shape = Shape::from_dims(&kept_dims, dtype);
let s = bwd.add_node(
Op::Reduce {
op: ReduceOp::Sum,
axes: vec![axis_pos],
keep_dim: true,
},
vec![yg],
kept_shape,
);
let target_dims: Vec<i64> = y_shape
.dims()
.iter()
.map(|d| match d {
Dim::Static(n) => *n as i64,
Dim::Dynamic(_) => -1,
})
.collect();
let s_expanded = bwd.add_node(
Op::Expand {
target_shape: target_dims,
},
vec![s],
y_shape.clone(),
);
let diff = bwd.binary(BinaryOp::Sub, upstream, s_expanded, y_shape.clone());
let dx = bwd.binary(BinaryOp::Mul, y_bwd, diff, y_shape);
vec![(0, dx)]
}
Op::Transpose { perm } => {
let inv: Vec<usize> = {
let mut v = vec![0usize; perm.len()];
for (i, &p) in perm.iter().enumerate() {
v[p] = i;
}
v
};
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dx = bwd.add_node(Op::Transpose { perm: inv }, vec![upstream], x_shape);
vec![(0, dx)]
}
Op::Concat { axis } => {
let mut grads = Vec::with_capacity(node.inputs.len());
let mut offset: usize = 0;
for (i, &input_id) in node.inputs.iter().enumerate() {
let x_bwd = fwd_map[&input_id];
let x_shape = bwd.node(x_bwd).shape.clone();
let len = match x_shape.dim(*axis) {
Dim::Static(n) => n,
_ => panic!("Concat VJP: dynamic concat dim"),
};
let dx = bwd.add_node(
Op::Narrow {
axis: *axis,
start: offset,
len,
},
vec![upstream],
x_shape,
);
grads.push((i, dx));
offset += len;
}
grads
}
Op::Narrow { axis, start, len } => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let full_n = match x_shape.dim(*axis) {
Dim::Static(n) => n,
_ => panic!("Narrow VJP: dynamic axis"),
};
let pre = *start;
let post = full_n - *start - *len;
let zero_buf = |bwd: &mut Graph, len_axis: usize| -> NodeId {
if len_axis == 0 {
return upstream; }
let dtype = x_shape.dtype();
let mut dims: Vec<Dim> = x_shape.dims().to_vec();
dims[*axis] = Dim::Static(len_axis);
let s = Shape::from_dims(&dims, dtype);
let n_elems = dims.iter().fold(1usize, |a, d| match d {
Dim::Static(k) => a * k,
_ => a,
});
let bytes = vec![0u8; n_elems * dtype.size_bytes()];
bwd.add_node(Op::Constant { data: bytes }, vec![], s)
};
let mut parts: Vec<NodeId> = Vec::new();
if pre > 0 {
parts.push(zero_buf(bwd, pre));
}
parts.push(upstream);
if post > 0 {
parts.push(zero_buf(bwd, post));
}
let dx = if parts.len() == 1 {
parts[0]
} else {
bwd.add_node(Op::Concat { axis: *axis }, parts, x_shape)
};
vec![(0, dx)]
}
Op::Gather { axis } => {
let table_bwd = fwd_map[&node.inputs[0]];
let indices_bwd = fwd_map[&node.inputs[1]];
let table_shape = bwd.node(table_bwd).shape.clone();
if *axis == 0 {
let dtable = bwd.add_node(Op::ScatterAdd, vec![upstream, indices_bwd], table_shape);
vec![(0, dtable)]
} else {
let dtable = bwd.gather_backward(
upstream,
indices_bwd,
table_shape,
(*axis).try_into().unwrap(),
);
vec![(0, dtable)]
}
}
Op::Compare(_) => {
vec![]
}
Op::Where => {
let cond = fwd_map[&node.inputs[0]];
let a_bwd = fwd_map[&node.inputs[1]];
let b_bwd = fwd_map[&node.inputs[2]];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let out_shape = upstream_shape.clone();
let zero_a_bytes = vec![0u8; a_shape.num_elements().expect("Where VJP: dynamic a") * 4];
let zero_b_bytes = vec![0u8; b_shape.num_elements().expect("Where VJP: dynamic b") * 4];
let zero_a = bwd.add_node(Op::Constant { data: zero_a_bytes }, vec![], a_shape.clone());
let zero_b = bwd.add_node(Op::Constant { data: zero_b_bytes }, vec![], b_shape.clone());
let zero_a_bcast = unbroadcast_inverse(zero_a, &out_shape, bwd);
let zero_b_bcast = unbroadcast_inverse(zero_b, &out_shape, bwd);
let g_a_full = bwd.add_node(
Op::Where,
vec![cond, upstream, zero_a_bcast],
out_shape.clone(),
);
let g_b_full = bwd.add_node(Op::Where, vec![cond, zero_b_bcast, upstream], out_shape);
let g_a = unbroadcast(g_a_full, &a_shape, bwd);
let g_b = unbroadcast(g_b_full, &b_shape, bwd);
vec![(1, g_a), (2, g_b)]
}
Op::Binary(BinaryOp::Div) => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let y_bwd = fwd_map[&node.id];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let is_c64 = upstream_shape.dtype() == DType::C64;
let b_term = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
let y_term = if is_c64 { bwd.conjugate(y_bwd) } else { y_bwd };
let g_a_full = bwd.binary(BinaryOp::Div, upstream, b_term, upstream_shape.clone());
let g_a = unbroadcast(g_a_full, &a_shape, bwd);
let neg_up = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
let neg_up_y = bwd.binary(BinaryOp::Mul, neg_up, y_term, upstream_shape.clone());
let g_b_full = bwd.binary(BinaryOp::Div, neg_up_y, b_term, upstream_shape);
let g_b = unbroadcast(g_b_full, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::Reduce {
op: ReduceOp::Max,
axes,
keep_dim,
}
| Op::Reduce {
op: ReduceOp::Min,
axes,
keep_dim,
} => {
let is_max = matches!(
node.op,
Op::Reduce {
op: ReduceOp::Max,
..
}
);
let _ = is_max;
let x_bwd = fwd_map[&node.inputs[0]];
let y_bwd = fwd_map[&node.id];
let x_shape = bwd.node(x_bwd).shape.clone();
let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
let mask_bool = bwd.add_node(
Op::Compare(CmpOp::Eq),
vec![x_bwd, y_expanded],
Shape::from_dims(x_shape.dims(), DType::F32),
);
let mask_f32 = bwd.add_node(
Op::Cast {
to: x_shape.dtype(),
},
vec![mask_bool],
x_shape.clone(),
);
let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
let dx = bwd.binary(BinaryOp::Mul, upstream_expanded, mask_f32, x_shape);
vec![(0, dx)]
}
Op::Rope { head_dim, n_rot } => {
let cos = fwd_map[&node.inputs[1]];
let sin = fwd_map[&node.inputs[2]];
let dx = bwd.rope_backward(upstream, cos, sin, *head_dim, *n_rot);
vec![(0, dx)]
}
Op::RmsNorm { axis, eps } => {
let x = fwd_map[&node.inputs[0]];
let gamma = fwd_map[&node.inputs[1]];
let beta = fwd_map[&node.inputs[2]];
let dx = bwd.rms_norm_backward_input(x, gamma, beta, upstream, *axis, *eps);
let dgamma = bwd.rms_norm_backward_gamma(x, gamma, beta, upstream, *axis, *eps);
let dbeta = bwd.rms_norm_backward_beta(x, gamma, beta, upstream, *axis, *eps);
vec![(0, dx), (1, dgamma), (2, dbeta)]
}
Op::GroupNorm { num_groups, eps } => {
let x = fwd_map[&node.inputs[0]];
let gamma = fwd_map[&node.inputs[1]];
let beta = fwd_map[&node.inputs[2]];
let gamma_shape = bwd.node(gamma).shape.clone();
let beta_shape = bwd.node(beta).shape.clone();
let dx = bwd.group_norm_backward_input(x, gamma, beta, upstream, *num_groups, *eps);
let dgamma = bwd.group_norm_backward_gamma(x, upstream, gamma_shape, *num_groups, *eps);
let dbeta = bwd.group_norm_backward_beta(x, upstream, beta_shape, *num_groups, *eps);
vec![(0, dx), (1, dgamma), (2, dbeta)]
}
Op::Attention {
num_heads,
head_dim,
mask_kind,
score_scale: _,
attn_logit_softcap: _,
} => {
let q = fwd_map[&node.inputs[0]];
let k = fwd_map[&node.inputs[1]];
let v = fwd_map[&node.inputs[2]];
let mask = match mask_kind {
MaskKind::Custom | MaskKind::Bias => Some(fwd_map[&node.inputs[3]]),
_ => None,
};
let (dq, dk, dv) = bwd
.attention_backward_all(q, k, v, upstream, *num_heads, *head_dim, *mask_kind, mask);
vec![(0, dq), (1, dk), (2, dv)]
}
Op::Reduce {
op: ReduceOp::Prod,
axes,
keep_dim,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let y_bwd = fwd_map[&node.id];
let x_shape = bwd.node(x_bwd).shape.clone();
let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
let num = bwd.binary(
BinaryOp::Mul,
upstream_expanded,
y_expanded,
x_shape.clone(),
);
let dx = bwd.binary(BinaryOp::Div, num, x_bwd, x_shape);
vec![(0, dx)]
}
Op::Pool {
kind: ReduceOp::Mean,
kernel_size,
stride,
padding,
} => {
assert_eq!(kernel_size.len(), 2, "Pool(Mean) VJP: 2-D pool only");
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dtype = x_shape.dtype();
let c = match x_shape.dim(1) {
Dim::Static(n) => n,
_ => panic!("Pool(Mean) VJP: dynamic channel dim"),
};
let kh = kernel_size[0];
let kw = kernel_size[1];
let inv_n = 1.0_f32 / (kh as f32 * kw as f32);
let kernel_n = c * kh * kw;
let mut bytes: Vec<u8> = Vec::with_capacity(kernel_n * 4);
for _ in 0..kernel_n {
bytes.extend_from_slice(&inv_n.to_le_bytes());
}
let kernel_shape = Shape::from_dims(
&[
Dim::Static(c),
Dim::Static(1),
Dim::Static(kh),
Dim::Static(kw),
],
dtype,
);
let kernel = bwd.add_node(Op::Constant { data: bytes }, vec![], kernel_shape);
let dx = bwd.conv2d_backward_input(
upstream,
kernel,
x_shape,
kernel_size.clone(),
stride.clone(),
padding.clone(),
vec![1, 1],
c, );
vec![(0, dx)]
}
Op::Binary(BinaryOp::Min) | Op::Binary(BinaryOp::Max) => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let y_bwd = fwd_map[&node.id];
let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let dtype = upstream_shape.dtype();
let bool_shape = Shape::from_dims(upstream_shape.dims(), DType::Bool);
let mask_pred = bwd.add_node(Op::Compare(CmpOp::Eq), vec![a_bwd, y_bwd], bool_shape);
let mask_f32 = bwd.add_node(
Op::Cast { to: dtype },
vec![mask_pred],
upstream_shape.clone(),
);
let zero_bytes = vec![
0u8;
upstream_shape
.num_elements()
.expect("Min/Max VJP: dyn shape")
* 4
];
let zero = bwd.add_node(
Op::Constant { data: zero_bytes },
vec![],
upstream_shape.clone(),
);
let g_a_full = bwd.add_node(
Op::Where,
vec![mask_f32, upstream, zero],
upstream_shape.clone(),
);
let g_b_full = bwd.add_node(Op::Where, vec![mask_f32, zero, upstream], upstream_shape);
let g_a = unbroadcast(g_a_full, &a_shape, bwd);
let g_b = unbroadcast(g_b_full, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::Binary(BinaryOp::Pow) => {
let a_bwd = fwd_map[&node.inputs[0]];
let b_bwd = fwd_map[&node.inputs[1]];
let y_bwd = fwd_map[&node.id]; let a_shape = bwd.node(a_bwd).shape.clone();
let b_shape = bwd.node(b_bwd).shape.clone();
let yb = bwd.binary(BinaryOp::Mul, y_bwd, b_bwd, upstream_shape.clone());
let yb_over_a = bwd.binary(BinaryOp::Div, yb, a_bwd, upstream_shape.clone());
let g_a_full = bwd.binary(BinaryOp::Mul, upstream, yb_over_a, upstream_shape.clone());
let g_a = unbroadcast(g_a_full, &a_shape, bwd);
let ln_a = bwd.activation(Activation::Log, a_bwd, a_shape);
let ln_a_b = unbroadcast_inverse(ln_a, &upstream_shape, bwd);
let yln = bwd.binary(BinaryOp::Mul, y_bwd, ln_a_b, upstream_shape.clone());
let g_b_full = bwd.binary(BinaryOp::Mul, upstream, yln, upstream_shape);
let g_b = unbroadcast(g_b_full, &b_shape, bwd);
vec![(0, g_a), (1, g_b)]
}
Op::DequantMatMul { scheme: _ } => {
let x_bwd = fwd_map[&node.inputs[0]];
let w_q_bwd = fwd_map[&node.inputs[1]];
let scale_bwd = fwd_map[&node.inputs[2]];
let zp_bwd = fwd_map[&node.inputs[3]];
let x_shape = bwd.node(x_bwd).shape.clone();
let w_shape = bwd.node(w_q_bwd).shape.clone();
let scale_shape = bwd.node(scale_bwd).shape.clone();
let zp_shape = bwd.node(zp_bwd).shape.clone();
let dtype = x_shape.dtype();
let w_q_f32 = bwd.add_node(
Op::Cast { to: dtype },
vec![w_q_bwd],
Shape::from_dims(w_shape.dims(), dtype),
);
let scale_b =
unbroadcast_inverse(scale_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
let zp_b = unbroadcast_inverse(zp_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
let w_centered = bwd.binary(
BinaryOp::Sub,
w_q_f32,
zp_b,
Shape::from_dims(w_shape.dims(), dtype),
);
let w_dq = bwd.binary(
BinaryOp::Mul,
w_centered,
scale_b,
Shape::from_dims(w_shape.dims(), dtype),
);
let w_rank = w_shape.rank();
let mut perm: Vec<usize> = (0..w_rank).collect();
perm.swap(w_rank - 2, w_rank - 1);
let mut wdt_dims: Vec<Dim> = w_shape.dims().to_vec();
wdt_dims.swap(w_rank - 2, w_rank - 1);
let w_dq_t_shape = Shape::from_dims(&wdt_dims, dtype);
let w_dq_t = bwd.add_node(Op::Transpose { perm }, vec![w_dq], w_dq_t_shape);
let dx = bwd.matmul(upstream, w_dq_t, x_shape.clone());
let x_rank = x_shape.rank();
let mut x_perm: Vec<usize> = (0..x_rank).collect();
x_perm.swap(x_rank - 2, x_rank - 1);
let mut x_t_dims: Vec<Dim> = x_shape.dims().to_vec();
x_t_dims.swap(x_rank - 2, x_rank - 1);
let x_t = bwd.add_node(
Op::Transpose { perm: x_perm },
vec![x_bwd],
Shape::from_dims(&x_t_dims, dtype),
);
let dw_unscaled = bwd.matmul(x_t, upstream, Shape::from_dims(w_shape.dims(), dtype));
let dw_q_f32 = bwd.binary(
BinaryOp::Mul,
dw_unscaled,
scale_b,
Shape::from_dims(w_shape.dims(), dtype),
);
let dw_q = bwd.add_node(
Op::Cast {
to: w_shape.dtype(),
},
vec![dw_q_f32],
w_shape,
);
let zero_scale_bytes =
vec![0u8; scale_shape.num_elements().expect("DQMM VJP: dyn scale") * 4];
let zero_zp_bytes = vec![0u8; zp_shape.num_elements().expect("DQMM VJP: dyn zp") * 4];
let dscale = bwd.add_node(
Op::Constant {
data: zero_scale_bytes,
},
vec![],
scale_shape,
);
let dzp = bwd.add_node(
Op::Constant {
data: zero_zp_bytes,
},
vec![],
zp_shape,
);
vec![(0, dx), (1, dw_q), (2, dscale), (3, dzp)]
}
Op::ScatterAdd => {
let updates_bwd = fwd_map[&node.inputs[0]];
let indices_bwd = fwd_map[&node.inputs[1]];
let updates_shape = bwd.node(updates_bwd).shape.clone();
let dupdates = bwd.add_node(
Op::Gather { axis: 0 },
vec![upstream, indices_bwd],
updates_shape,
);
vec![(0, dupdates)]
}
Op::Cumsum { axis, exclusive } => {
let x_bwd = fwd_map[&node.inputs[0]];
let x_shape = bwd.node(x_bwd).shape.clone();
let dx = bwd.cumsum_backward(upstream, x_shape, *axis, *exclusive);
vec![(0, dx)]
}
Op::GroupedMatMul => {
let x_bwd = fwd_map[&node.inputs[0]];
let w_bwd = fwd_map[&node.inputs[1]];
let expert_bwd = fwd_map[&node.inputs[2]];
let x_shape = bwd.node(x_bwd).shape.clone();
let w_shape = bwd.node(w_bwd).shape.clone();
let (dx, dw) =
grouped_matmul_vjp(bwd, upstream, x_bwd, w_bwd, expert_bwd, &x_shape, &w_shape);
vec![(0, dx), (1, dw)]
}
Op::DequantGroupedMatMul { scheme } => {
let x_bwd = fwd_map[&node.inputs[0]];
let w_packed = fwd_map[&node.inputs[1]];
let expert_bwd = fwd_map[&node.inputs[2]];
let x_shape = bwd.node(x_bwd).shape.clone();
let w_packed_shape = bwd.node(w_packed).shape.clone();
let dtype = x_shape.dtype();
let k = x_shape.dim(1);
let n_out = node.shape.dim(node.shape.rank() - 1);
let k_static = match k {
Dim::Static(v) => v,
_ => panic!("DequantGroupedMatMul VJP: K must be static"),
};
let n_static = match n_out {
Dim::Static(v) => v,
_ => panic!("DequantGroupedMatMul VJP: N must be static"),
};
let block_elems = scheme.gguf_block_size() as usize;
let block_bytes = scheme.gguf_block_bytes() as usize;
let slab_bytes = (k_static * n_static) / block_elems * block_bytes;
let total_bytes = w_packed_shape
.num_elements()
.expect("DequantGroupedMatMul VJP: dyn packed");
let e_static = total_bytes / slab_bytes.max(1);
let w_shape = Shape::from_dims(
&[
Dim::Static(e_static),
Dim::Static(k_static),
Dim::Static(n_static),
],
dtype,
);
let w_dq = bwd.add_node(
Op::DequantMoEWeights { scheme: *scheme },
vec![w_packed],
w_shape.clone(),
);
let (dx, _dw) =
grouped_matmul_vjp(bwd, upstream, x_bwd, w_dq, expert_bwd, &x_shape, &w_shape);
vec![(0, dx)]
}
Op::QMatMul {
x_zp,
w_zp,
out_zp: _,
mult,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let w_bwd = fwd_map[&node.inputs[1]];
let bias_bwd = fwd_map[&node.inputs[2]];
let x_shape = bwd.node(x_bwd).shape.clone();
let w_shape = bwd.node(w_bwd).shape.clone();
let bias_shape = bwd.node(bias_bwd).shape.clone();
let dtype = upstream_shape.dtype();
let x_f32 = bwd.add_node(
Op::Cast { to: dtype },
vec![x_bwd],
Shape::from_dims(x_shape.dims(), dtype),
);
let w_f32 = bwd.add_node(
Op::Cast { to: dtype },
vec![w_bwd],
Shape::from_dims(w_shape.dims(), dtype),
);
let xzp_c = scalar_const(*x_zp as f32, bwd);
let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
let _ = bwd.binary(
BinaryOp::Sub,
x_f32,
xzp_b,
Shape::from_dims(x_shape.dims(), dtype),
);
let wzp_c = scalar_const(*w_zp as f32, bwd);
let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
let w_centered = bwd.binary(
BinaryOp::Sub,
w_f32,
wzp_b,
Shape::from_dims(w_shape.dims(), dtype),
);
let mult_c = scalar_const(*mult, bwd);
let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
let upstream_scaled =
bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
let w_rank = w_shape.rank();
let mut perm: Vec<usize> = (0..w_rank).collect();
perm.swap(w_rank - 2, w_rank - 1);
let mut wt_dims: Vec<Dim> = w_shape.dims().to_vec();
wt_dims.swap(w_rank - 2, w_rank - 1);
let w_t = bwd.add_node(
Op::Transpose { perm },
vec![w_centered],
Shape::from_dims(&wt_dims, dtype),
);
let dx_f32 = bwd.matmul(
upstream_scaled,
w_t,
Shape::from_dims(x_shape.dims(), dtype),
);
let dx = bwd.add_node(
Op::Cast {
to: x_shape.dtype(),
},
vec![dx_f32],
x_shape.clone(),
);
let x_rank = x_shape.rank();
let mut x_perm: Vec<usize> = (0..x_rank).collect();
x_perm.swap(x_rank - 2, x_rank - 1);
let mut xt_dims: Vec<Dim> = x_shape.dims().to_vec();
xt_dims.swap(x_rank - 2, x_rank - 1);
let x_f32_2 = bwd.add_node(
Op::Cast { to: dtype },
vec![x_bwd],
Shape::from_dims(x_shape.dims(), dtype),
);
let x_centered = bwd.binary(
BinaryOp::Sub,
x_f32_2,
xzp_b,
Shape::from_dims(x_shape.dims(), dtype),
);
let x_t = bwd.add_node(
Op::Transpose { perm: x_perm },
vec![x_centered],
Shape::from_dims(&xt_dims, dtype),
);
let dw_f32 = bwd.matmul(
x_t,
upstream_scaled,
Shape::from_dims(w_shape.dims(), dtype),
);
let dw = bwd.add_node(
Op::Cast {
to: w_shape.dtype(),
},
vec![dw_f32],
w_shape,
);
let bias_rank = bias_shape.rank();
let reduce_axes: Vec<usize> = (0..upstream_shape.rank())
.filter(|&i| i + bias_rank < upstream_shape.rank() || i == 0)
.collect();
let dbias_f32 = bwd.add_node(
Op::Reduce {
op: ReduceOp::Sum,
axes: reduce_axes,
keep_dim: false,
},
vec![upstream_scaled],
Shape::from_dims(bias_shape.dims(), dtype),
);
let dbias = bwd.add_node(
Op::Cast {
to: bias_shape.dtype(),
},
vec![dbias_f32],
bias_shape,
);
vec![(0, dx), (1, dw), (2, dbias)]
}
Op::QConv2d {
kernel_size,
stride,
padding,
dilation,
groups,
x_zp,
w_zp,
out_zp: _,
mult,
} => {
let x_bwd = fwd_map[&node.inputs[0]];
let w_bwd = fwd_map[&node.inputs[1]];
let bias_bwd = fwd_map[&node.inputs[2]];
let x_shape = bwd.node(x_bwd).shape.clone();
let w_shape = bwd.node(w_bwd).shape.clone();
let bias_shape = bwd.node(bias_bwd).shape.clone();
let dtype = upstream_shape.dtype();
let x_f32 = bwd.add_node(
Op::Cast { to: dtype },
vec![x_bwd],
Shape::from_dims(x_shape.dims(), dtype),
);
let w_f32 = bwd.add_node(
Op::Cast { to: dtype },
vec![w_bwd],
Shape::from_dims(w_shape.dims(), dtype),
);
let xzp_c = scalar_const(*x_zp as f32, bwd);
let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
let x_centered = bwd.binary(
BinaryOp::Sub,
x_f32,
xzp_b,
Shape::from_dims(x_shape.dims(), dtype),
);
let wzp_c = scalar_const(*w_zp as f32, bwd);
let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
let w_centered = bwd.binary(
BinaryOp::Sub,
w_f32,
wzp_b,
Shape::from_dims(w_shape.dims(), dtype),
);
let mult_c = scalar_const(*mult, bwd);
let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
let upstream_scaled =
bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
let dx_f32 = bwd.conv2d_backward_input(
upstream_scaled,
w_centered,
Shape::from_dims(x_shape.dims(), dtype),
kernel_size.clone(),
stride.clone(),
padding.clone(),
dilation.clone(),
*groups,
);
let dx = bwd.add_node(
Op::Cast {
to: x_shape.dtype(),
},
vec![dx_f32],
x_shape,
);
let dw_f32 = bwd.conv2d_backward_weight(
x_centered,
upstream_scaled,
Shape::from_dims(w_shape.dims(), dtype),
kernel_size.clone(),
stride.clone(),
padding.clone(),
dilation.clone(),
*groups,
);
let dw = bwd.add_node(
Op::Cast {
to: w_shape.dtype(),
},
vec![dw_f32],
w_shape,
);
let dbias_f32 = bwd.add_node(
Op::Reduce {
op: ReduceOp::Sum,
axes: vec![0, 2, 3],
keep_dim: false,
},
vec![upstream_scaled],
Shape::from_dims(bias_shape.dims(), dtype),
);
let dbias = bwd.add_node(
Op::Cast {
to: bias_shape.dtype(),
},
vec![dbias_f32],
bias_shape,
);
vec![(0, dx), (1, dw), (2, dbias)]
}
Op::TopK { .. } | Op::Sample { .. } => {
vec![]
}
Op::GaussianSplatRender {
width,
height,
tile_size,
radius_scale,
alpha_cutoff,
max_splat_steps,
transmittance_threshold,
max_list_entries,
..
} => {
use rlx_ir::ops::splat::{
GaussianSplatBackwardParams, GaussianSplatInputs, GaussianSplatRenderParams,
unpack_gaussian_splat_packed_grads,
};
let render = GaussianSplatRenderParams {
width: *width,
height: *height,
tile_size: *tile_size,
radius_scale: *radius_scale,
alpha_cutoff: *alpha_cutoff,
max_splat_steps: *max_splat_steps,
transmittance_threshold: *transmittance_threshold,
max_list_entries: *max_list_entries,
};
let inputs = GaussianSplatInputs {
positions: fwd_map[&node.inputs[0]],
scales: fwd_map[&node.inputs[1]],
rotations: fwd_map[&node.inputs[2]],
opacities: fwd_map[&node.inputs[3]],
colors: fwd_map[&node.inputs[4]],
sh_coeffs: fwd_map[&node.inputs[5]],
meta: fwd_map[&node.inputs[6]],
};
let count = bwd.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
let sh_len = bwd.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
let meta_shape = bwd.shape(inputs.meta).clone();
let packed = bwd.gaussian_splat_render_backward(
inputs,
upstream,
GaussianSplatBackwardParams {
render,
loss_grad_clip: 1.0,
sh_band: 0,
max_anisotropy: 10.0,
},
);
let sh_coeff_count = if count == 0 {
1
} else {
(sh_len / (count * 3)).max(1)
};
let grads = unpack_gaussian_splat_packed_grads(bwd, packed, count, sh_coeff_count);
let meta_n = meta_shape.num_elements().unwrap_or(0);
let zero_meta = bwd.add_node(
Op::Constant {
data: vec![0u8; meta_n * meta_shape.dtype().size_bytes()],
},
vec![],
meta_shape,
);
vec![
(0, grads.positions),
(1, grads.scales),
(2, grads.rotations),
(3, grads.opacities),
(4, grads.colors),
(5, grads.sh_coeffs),
(6, zero_meta),
]
}
Op::GaussianSplatRenderBackward { .. } => {
vec![]
}
Op::GaussianSplatPrepare { .. } | Op::GaussianSplatRasterize { .. } => {
panic!(
"autodiff: decomposed splat ops must be fused before AD — \
`prepare_graph_for_ad` rewrites Prepare→Rasterize into \
`GaussianSplatRender`, or use `Op::GaussianSplatRender` directly"
);
}
Op::CustomFn {
vjp_body: Some(vjp_body),
num_inputs,
..
} => {
let mut sub_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
let mut primal_input_ids: Vec<NodeId> = vjp_body
.nodes()
.iter()
.filter_map(|n| match &n.op {
Op::Input { name } if name != "primal_output" && name != "d_output" => {
Some(n.id)
}
_ => None,
})
.collect();
primal_input_ids.sort();
assert_eq!(primal_input_ids.len(), *num_inputs as usize);
for sub_node in vjp_body.nodes() {
let new_id = match &sub_node.op {
Op::Input { name } if name == "primal_output" => fwd_map[&node.id],
Op::Input { name } if name == "d_output" => upstream,
Op::Input { .. } => {
let idx = primal_input_ids
.iter()
.position(|&id| id == sub_node.id)
.expect(
"custom_fn vjp_body: primal Input \
not found in primal list",
);
fwd_map[&node.inputs[idx]]
}
_ => {
let new_inputs: Vec<NodeId> =
sub_node.inputs.iter().map(|i| sub_to_bwd[i]).collect();
bwd.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
}
};
sub_to_bwd.insert(sub_node.id, new_id);
}
let mut grads: Vec<(usize, NodeId)> = Vec::with_capacity(*num_inputs as usize);
for (i, out_id) in vjp_body.outputs.iter().enumerate() {
grads.push((i, sub_to_bwd[out_id]));
}
grads
}
Op::CustomFn { vjp_body: None, .. } => {
panic!(
"autodiff: Op::CustomFn has no vjp_body and was not inlined. \
This is an internal error in inline_custom_fn_for_autodiff."
)
}
Op::Custom { name, .. } => {
let ext = rlx_ir::lookup_op(name).unwrap_or_else(|| {
panic!(
"autodiff: Op::Custom('{name}') is not registered \
in the op registry — register it via \
rlx_ir::register_op before compiling the graph"
)
});
let mut ctx = rlx_ir::VjpContext {
upstream,
fwd_map,
bwd,
};
ext.vjp(node, &mut ctx)
}
Op::Fft { inverse, norm } => {
let n = rlx_ir::fft::fft_meta(bwd.shape(node.inputs[0])).n_complex;
let s = norm.output_scale(n, *inverse) as f32;
let z = if s != 1.0 {
let sc = scalar_const(s, bwd);
bwd.mul(upstream, sc)
} else {
upstream
};
let dx = bwd.fft(z, !*inverse);
vec![(0, dx)]
}
other => panic!(
"autodiff: no VJP rule for {other}. See the matching \
entry in rlx-opt/src/autodiff.rs (catch-all panic) for \
a pointer to what's needed to differentiate this op.",
),
}
}
/// Rewrites every `Op::Scan` that has broadcast operands (`num_bcast > 0`)
/// so that those operands are passed as ordinary per-step `xs` inputs.
///
/// Each broadcast operand `b` with shape `[d0, .., dk]` is replaced by a
/// `[length, d0, .., dk]` tensor computed as `ones([length, 1, .., 1]) * b`.
/// The rewritten scan has `num_bcast = 0` and `num_xs = num_bcast + num_xs`.
/// Its operand order `[init, materialized bcasts.., original xs..]` mirrors
/// the original `[init, bcasts.., xs..]`, and the body graph is reused as-is.
///
/// Graphs without such scans are returned untouched.
///
/// # Panics
///
/// Panics if a broadcast operand's dtype is neither `F32` nor `F64`.
fn materialize_bcasts_for_ad(g: Graph) -> Graph {
    use rlx_ir::op::BinaryOp;

    let has_bcast_scan = g
        .nodes()
        .iter()
        .any(|n| matches!(&n.op, Op::Scan { num_bcast, .. } if *num_bcast > 0));
    if !has_bcast_scan {
        return g;
    }

    let mut rewritten = Graph::new(g.name.clone());
    let mut remap: HashMap<NodeId, NodeId> = HashMap::new();
    for node in g.nodes() {
        let operands: Vec<NodeId> = node.inputs.iter().map(|i| remap[i]).collect();
        let replacement = match &node.op {
            Op::Scan {
                body,
                length,
                save_trajectory,
                num_bcast,
                num_xs,
                num_checkpoints,
            } if *num_bcast > 0 => {
                let steps = *length as usize;
                let bcast_count = *num_bcast as usize;
                let xs_count = *num_xs as usize;

                // Operand layout: [init, bcast_0 .. bcast_{B-1}, xs_0 .. xs_{X-1}].
                let mut scan_operands = vec![operands[0]];
                for i in 0..bcast_count {
                    let b_id = operands[1 + i];
                    let b_shape = rewritten.node(b_id).shape.clone();
                    let dtype = b_shape.dtype();
                    let rank = b_shape.rank();

                    // `[steps, 1, .., 1]` of ones: multiplying by `b` repeats
                    // `b` once per scan step.
                    let ones_dims: Vec<rlx_ir::Dim> =
                        std::iter::once(rlx_ir::Dim::Static(steps))
                            .chain((0..rank).map(|_| rlx_ir::Dim::Static(1)))
                            .collect();
                    let ones_bytes: Vec<u8> = match dtype {
                        rlx_ir::DType::F64 => 1.0_f64.to_le_bytes().repeat(steps),
                        rlx_ir::DType::F32 => 1.0_f32.to_le_bytes().repeat(steps),
                        other => {
                            panic!("materialize_bcasts_for_ad: unsupported bcast dtype {other:?}")
                        }
                    };
                    let ones = rewritten.add_node(
                        Op::Constant { data: ones_bytes },
                        vec![],
                        rlx_ir::Shape::from_dims(&ones_dims, dtype),
                    );

                    let xs_dims: Vec<rlx_ir::Dim> = std::iter::once(rlx_ir::Dim::Static(steps))
                        .chain((0..rank).map(|d| b_shape.dim(d)))
                        .collect();
                    let materialized = rewritten.add_node(
                        Op::Binary(BinaryOp::Mul),
                        vec![ones, b_id],
                        rlx_ir::Shape::from_dims(&xs_dims, dtype),
                    );
                    scan_operands.push(materialized);
                }
                scan_operands.extend((0..xs_count).map(|i| operands[1 + bcast_count + i]));

                rewritten.add_node(
                    Op::Scan {
                        body: body.clone(),
                        length: *length,
                        save_trajectory: *save_trajectory,
                        num_bcast: 0,
                        num_xs: *num_bcast + *num_xs,
                        num_checkpoints: *num_checkpoints,
                    },
                    scan_operands,
                    node.shape.clone(),
                )
            }
            _ => rewritten.add_node(node.op.clone(), operands, node.shape.clone()),
        };
        remap.insert(node.id, replacement);
    }

    let outputs: Vec<NodeId> = g.outputs.iter().map(|o| remap[o]).collect();
    rewritten.set_outputs(outputs);
    rewritten
}
/// Prepares `Op::Scan` nodes for autodiff by forcing every scan to record
/// its full trajectory.
///
/// Broadcast scan operands are first turned into per-step `xs` by
/// `materialize_bcasts_for_ad`. Each `Scan { save_trajectory: false }` whose
/// carry has shape `C` is then replaced by:
///
/// 1. the same scan with `save_trajectory: true`, typed `[length, ..C]`;
/// 2. a `Narrow` on axis 0 selecting the last step
///    (`start = length - 1`, saturating at 0), typed `[1, ..C]`;
/// 3. a `Reshape` back to `C`, with dynamic dims encoded as `-1`.
///
/// All other nodes are copied unchanged. If no scan needs rewriting, the
/// bcast-materialized graph is returned as-is.
pub fn convert_scans_for_ad(g: Graph) -> Graph {
    use rlx_ir::shape::Shape as IrShape;

    let g = materialize_bcasts_for_ad(g);
    let has_unsaved_scan = g
        .nodes()
        .iter()
        .any(|n| matches!(&n.op, Op::Scan { save_trajectory: false, .. }));
    if !has_unsaved_scan {
        return g;
    }

    let mut rewritten = Graph::new(g.name.clone());
    let mut remap: HashMap<NodeId, NodeId> = HashMap::new();
    for node in g.nodes() {
        let operands: Vec<NodeId> = node.inputs.iter().map(|i| remap[i]).collect();
        let replacement = match &node.op {
            Op::Scan {
                body,
                length,
                save_trajectory: false,
                num_xs,
                num_checkpoints,
                ..
            } => {
                let carry_shape = node.shape.clone();
                let dtype = carry_shape.dtype();
                let carry_dims: Vec<Dim> =
                    (0..carry_shape.rank()).map(|i| carry_shape.dim(i)).collect();
                // Builds `[lead, ..carry]` with the carry's dtype.
                let prefixed = |lead: usize| {
                    let dims: Vec<Dim> = std::iter::once(Dim::Static(lead))
                        .chain(carry_dims.iter().cloned())
                        .collect();
                    IrShape::from_dims(&dims, dtype)
                };
                let steps = *length as usize;

                // 1. Same scan, but keep every step's carry.
                let trajectory = rewritten.add_node(
                    Op::Scan {
                        body: body.clone(),
                        length: *length,
                        save_trajectory: true,
                        num_bcast: 0,
                        num_xs: *num_xs,
                        num_checkpoints: *num_checkpoints,
                    },
                    operands,
                    prefixed(steps),
                );
                // 2. Select only the final step.
                let last_step = rewritten.add_node(
                    Op::Narrow {
                        axis: 0,
                        start: steps.saturating_sub(1),
                        len: 1,
                    },
                    vec![trajectory],
                    prefixed(1),
                );
                // 3. Drop the unit leading axis to recover the carry shape.
                let new_shape: Vec<i64> = carry_dims
                    .iter()
                    .map(|d| match d {
                        Dim::Static(n) => *n as i64,
                        Dim::Dynamic(_) => -1,
                    })
                    .collect();
                rewritten.add_node(Op::Reshape { new_shape }, vec![last_step], carry_shape)
            }
            _ => rewritten.add_node(node.op.clone(), operands, node.shape.clone()),
        };
        remap.insert(node.id, replacement);
    }

    let outputs: Vec<NodeId> = g.outputs.iter().map(|o| remap[o]).collect();
    rewritten.set_outputs(outputs);
    rewritten
}
/// Inlines every `Op::CustomFn` that has neither a `vjp_body` nor a
/// `jvp_body`, splicing its `fwd_body` into the graph in place of the node.
///
/// `CustomFn` nodes that carry a `vjp_body` or a `jvp_body` are copied
/// unchanged, as are all other nodes.
///
/// NOTE(review): a `CustomFn` with only a `jvp_body` is also kept. A later
/// reverse-mode pass will therefore reach the `CustomFn { vjp_body: None }`
/// panic in `vjp`. Confirm that this is intended.
///
/// # Panics
///
/// Panics if an inlined `CustomFn` receives a different number of inputs
/// than its `num_inputs` field declares.
pub fn inline_custom_fn_for_autodiff(g: Graph) -> Graph {
    use rlx_fusion::control_flow::inline_subgraph_into;
    let mut out = Graph::new(g.name.clone());
    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
    // `out` is a separate graph, so `g`'s nodes can be borrowed directly.
    // No snapshot of the node list (or of the subgraph bodies) is needed.
    for node in g.nodes() {
        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
        let new_id = match &node.op {
            Op::CustomFn {
                vjp_body: None,
                jvp_body: None,
                fwd_body,
                num_inputs,
                ..
            } => {
                assert_eq!(
                    new_inputs.len(),
                    *num_inputs as usize,
                    "custom_fn: outer input count mismatch"
                );
                inline_subgraph_into(fwd_body, &new_inputs, &mut out)
            }
            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
        };
        id_map.insert(node.id, new_id);
    }
    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
    out.set_outputs(new_outputs);
    out
}
/// Broadcasts `x` up to `target` by appending an `Op::Expand` node to `bwd`.
///
/// Despite the name, this expands rather than reduces. The expand's
/// `target_shape` lists each static dim as its size and each dynamic dim as
/// `-1`. Returns the new node, whose shape is exactly `target`.
pub(crate) fn unbroadcast_inverse(x: NodeId, target: &Shape, bwd: &mut Graph) -> NodeId {
    let mut target_shape: Vec<i64> = Vec::with_capacity(target.rank());
    for dim in target.dims() {
        let extent = if let Dim::Static(n) = dim { *n as i64 } else { -1 };
        target_shape.push(extent);
    }
    bwd.add_node(Op::Expand { target_shape }, vec![x], target.clone())
}
/// Expands `grad` back to `x_shape`. `grad` is presumably the gradient of a
/// reduction of `x` over `axes`.
///
/// If the reduction dropped its axes (`keep_dim == false`), `grad` is first
/// reshaped so that every axis in `axes` reappears with size 1. The result is
/// then expanded to `x_shape`, with static dims given as their size and
/// dynamic dims as `-1`.
fn expand_to(
    grad: NodeId,
    x_shape: &Shape,
    axes: &[usize],
    keep_dim: bool,
    bwd: &mut Graph,
) -> NodeId {
    let source = if keep_dim {
        grad
    } else {
        let kept_dims: Vec<Dim> = (0..x_shape.rank())
            .map(|i| if axes.contains(&i) { Dim::Static(1) } else { x_shape.dim(i) })
            .collect();
        reshape_to(grad, &Shape::from_dims(&kept_dims, x_shape.dtype()), bwd)
    };

    let mut target_shape: Vec<i64> = Vec::with_capacity(x_shape.rank());
    for dim in x_shape.dims() {
        let extent = if let Dim::Static(n) = dim { *n as i64 } else { -1 };
        target_shape.push(extent);
    }
    bwd.add_node(Op::Expand { target_shape }, vec![source], x_shape.clone())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn grad_of_add_is_identity() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);
        let bwd = grad(&g, &[x, y]);
        assert_eq!(bwd.outputs.len(), 2);
    }

    #[test]
    fn grad_of_mul_uses_other_operand() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(BinaryOp::Mul, x, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);
        let bwd = grad(&g, &[x, y]);
        assert!(
            bwd.nodes()
                .iter()
                .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
                .count()
                >= 2
        );
    }

    /// `grad_with_loss` returns `[loss, d/dx, d/dy]`, with the mirrored
    /// forward loss node first.
    #[test]
    fn grad_with_loss_returns_loss_first() {
        let mut g = Graph::new("loss");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let y = g.input("y", Shape::new(&[4], DType::F32));
        let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![z]);
        let bwd = grad_with_loss(&g, &[x, y]);
        assert_eq!(bwd.outputs.len(), 3);
        assert!(
            matches!(bwd.node(bwd.outputs[0]).op, Op::Binary(BinaryOp::Add)),
            "outputs[0] should be the mirrored forward loss (the Add), got\n{bwd}"
        );
    }

    #[test]
    fn grad_of_dense_solve_emits_implicit_function_rule() {
        let mut g = Graph::new("solve_test");
        let a = g.param("A", Shape::new(&[2, 2], DType::F32));
        let b = g.input("b", Shape::new(&[2], DType::F32));
        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F32));
        let loss = g.reduce(
            x,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F32),
        );
        g.set_outputs(vec![loss]);
        let bwd = grad_with_loss(&g, &[a, b]);
        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
        let count =
            |pred: fn(&Op) -> bool| -> usize { bwd.nodes().iter().filter(|n| pred(&n.op)).count() };
        assert!(
            count(|o| matches!(o, Op::DenseSolve)) >= 2,
            "expected ≥2 DenseSolve nodes (forward mirror + reverse), got\n{bwd}"
        );
        assert!(
            count(|o| matches!(o, Op::Transpose { .. })) >= 1,
            "expected a Transpose for Aᵀ, got\n{bwd}"
        );
        assert!(
            count(|o| matches!(o, Op::MatMul)) >= 1,
            "expected a MatMul for the outer product, got\n{bwd}"
        );
        assert!(
            count(|o| matches!(o, Op::Activation(Activation::Neg))) >= 1,
            "expected a Neg for −outer, got\n{bwd}"
        );
    }

    #[test]
    fn inline_if_replaces_with_where() {
        let s = Shape::new(&[4], DType::F32);
        let pred_s = Shape::new(&[1], DType::F32);
        let mut then_g = Graph::new("then_branch");
        let then_in = then_g.input("captured", s.clone());
        let then_out = then_g.activation(Activation::Relu, then_in, s.clone());
        then_g.set_outputs(vec![then_out]);
        let mut else_g = Graph::new("else_branch");
        let else_in = else_g.input("captured", s.clone());
        let else_out = else_g.activation(Activation::Sigmoid, else_in, s.clone());
        else_g.set_outputs(vec![else_out]);
        let mut g = Graph::new("parent");
        let x = g.input("x", s.clone());
        let pred = g.input("pred", pred_s);
        let if_out = g.add_node(
            Op::If {
                then_branch: Box::new(then_g),
                else_branch: Box::new(else_g),
            },
            vec![pred, x],
            s,
        );
        g.set_outputs(vec![if_out]);
        let inlined = rlx_fusion::control_flow::inline_if(g);
        let has_if = inlined
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::If { .. }));
        let has_where = inlined.nodes().iter().any(|n| matches!(n.op, Op::Where));
        let has_relu = inlined
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Activation(Activation::Relu)));
        let has_sigmoid = inlined
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Activation(Activation::Sigmoid)));
        assert!(!has_if, "Op::If should be inlined away");
        assert!(has_where, "Op::Where should replace the Op::If");
        assert!(has_relu, "then_branch's Activation(Relu) should be inlined");
        assert!(
            has_sigmoid,
            "else_branch's Activation(Sigmoid) should be inlined"
        );
        assert_eq!(inlined.outputs.len(), 1);
    }

    #[test]
    fn grad_through_if_propagates() {
        let s = Shape::new(&[4], DType::F32);
        let pred_s = Shape::new(&[1], DType::F32);
        let mut then_g = Graph::new("th");
        let ti = then_g.input("c", s.clone());
        let to = then_g.binary(BinaryOp::Mul, ti, ti, s.clone());
        then_g.set_outputs(vec![to]);
        let mut else_g = Graph::new("el");
        let ei = else_g.input("c", s.clone());
        let eo = else_g.activation(Activation::Relu, ei, s.clone());
        else_g.set_outputs(vec![eo]);
        let mut g = Graph::new("parent");
        let x = g.input("x", s.clone());
        let pred = g.input("pred", pred_s);
        let z = g.add_node(
            Op::If {
                then_branch: Box::new(then_g),
                else_branch: Box::new(else_g),
            },
            vec![pred, x],
            s,
        );
        g.set_outputs(vec![z]);
        let bwd = grad_with_loss(&g, &[x]);
        assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
    }

    #[test]
    fn unroll_while_replicates_body_n_times() {
        let s = Shape::new(&[4], DType::F32);
        let mut cond_g = Graph::new("cond");
        let ci = cond_g.input("c", s.clone());
        cond_g.set_outputs(vec![ci]);
        let mut body_g = Graph::new("body");
        let bi = body_g.input("c", s.clone());
        let bo = body_g.activation(Activation::Relu, bi, s.clone());
        body_g.set_outputs(vec![bo]);
        let mut g = Graph::new("parent");
        let x = g.input("x", s.clone());
        let w = g.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(3),
            },
            vec![x],
            s,
        );
        g.set_outputs(vec![w]);
        let unrolled = rlx_fusion::control_flow::unroll_while(g);
        let has_while = unrolled
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::While { .. }));
        let relu_count = unrolled
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Activation(Activation::Relu)))
            .count();
        assert!(!has_while, "Op::While should be unrolled away");
        assert_eq!(
            relu_count, 3,
            "body's Activation(Relu) should appear once per iteration"
        );
        assert_eq!(unrolled.outputs.len(), 1);
    }

    #[test]
    fn grad_through_while_propagates() {
        let s = Shape::new(&[4], DType::F32);
        let mut cond_g = Graph::new("cond");
        let ci = cond_g.input("c", s.clone());
        cond_g.set_outputs(vec![ci]);
        let mut body_g = Graph::new("body");
        let bi = body_g.input("c", s.clone());
        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
        body_g.set_outputs(vec![bo]);
        let mut g = Graph::new("parent");
        let x = g.input("x", s.clone());
        let w = g.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(2),
            },
            vec![x],
            s,
        );
        g.set_outputs(vec![w]);
        let bwd = grad_with_loss(&g, &[x]);
        assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
    }

    /// Builds a one-layer `FusedTransformerLayer` graph (hidden 4, 2 heads of
    /// dim 2, FFN 8). Returns `(graph, h, trainable params)`.
    fn build_ftl_graph(has_bias: bool) -> (Graph, NodeId, Vec<NodeId>) {
        let mut g = Graph::new("ftl_test");
        let h_shape = Shape::new(&[1, 2, 4], DType::F32);
        let h = g.input("h", h_shape.clone());
        let qkv_w = g.param("qkv_w", Shape::new(&[4, 12], DType::F32));
        let out_w = g.param("out_w", Shape::new(&[4, 4], DType::F32));
        let ln1_g = g.param("ln1_g", Shape::new(&[4], DType::F32));
        let fc1_w = g.param("fc1_w", Shape::new(&[4, 8], DType::F32));
        let fc2_w = g.param("fc2_w", Shape::new(&[8, 4], DType::F32));
        let ln2_g = g.param("ln2_g", Shape::new(&[4], DType::F32));
        let mask = g.input("mask", Shape::new(&[1, 2, 2, 2], DType::F32));
        let (inputs, params) = if has_bias {
            let qkv_b = g.param("qkv_b", Shape::new(&[12], DType::F32));
            let out_b = g.param("out_b", Shape::new(&[4], DType::F32));
            let ln1_b = g.param("ln1_b", Shape::new(&[4], DType::F32));
            let fc1_b = g.param("fc1_b", Shape::new(&[8], DType::F32));
            let fc2_b = g.param("fc2_b", Shape::new(&[4], DType::F32));
            let ln2_b = g.param("ln2_b", Shape::new(&[4], DType::F32));
            (
                vec![
                    h, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
                    ln2_b, mask,
                ],
                vec![
                    qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
                    ln2_b,
                ],
            )
        } else {
            (
                vec![h, qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g, mask],
                vec![qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g],
            )
        };
        let y = g.add_node(
            Op::FusedTransformerLayer {
                num_heads: 2,
                head_dim: 2,
                intermediate_size: 8,
                eps1: 1e-5,
                eps2: 1e-5,
                activation: rlx_ir::op::Activation::Gelu,
                has_bias,
            },
            inputs,
            h_shape,
        );
        g.set_outputs(vec![y]);
        (g, h, params)
    }

    #[test]
    fn unfuse_decomposes_fused_transformer_layer() {
        let (g, _h, _params) = build_ftl_graph(true);
        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
        let has_ftl = unfused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::FusedTransformerLayer { .. }));
        assert!(!has_ftl, "Op::FusedTransformerLayer should be unfused");
        let count = |pred: fn(&Op) -> bool| -> usize {
            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
        };
        assert!(
            count(|o| matches!(o, Op::MatMul)) >= 4,
            "expected >=4 MatMul after FTL unfuse"
        );
        assert_eq!(
            count(|o| matches!(o, Op::Attention { .. })),
            1,
            "expected exactly 1 Attention after FTL unfuse"
        );
        assert_eq!(
            count(|o| matches!(o, Op::LayerNorm { .. })),
            2,
            "expected exactly 2 LayerNorm after FTL unfuse"
        );
        assert!(
            count(|o| matches!(o, Op::Narrow { .. })) >= 3,
            "expected >=3 Narrow (Q/K/V split) after FTL unfuse"
        );
        assert_eq!(
            count(|o| matches!(o, Op::Activation(_))),
            1,
            "expected exactly 1 Activation (FFN) after FTL unfuse"
        );
    }

    #[test]
    fn grad_through_fused_transformer_layer_propagates() {
        let (g, _h, params) = build_ftl_graph(true);
        let bwd = grad_with_loss(&g, &params);
        assert_eq!(
            bwd.outputs.len(),
            1 + params.len(),
            "expected loss + {} param grads",
            params.len()
        );
    }

    #[test]
    fn grad_through_fused_transformer_layer_no_bias() {
        let (g, _h, params) = build_ftl_graph(false);
        let bwd = grad_with_loss(&g, &params);
        assert_eq!(
            bwd.outputs.len(),
            1 + params.len(),
            "expected loss + {} param grads (no-bias)",
            params.len()
        );
    }

    /// Builds a `SelectiveScan` graph with batch 1, seq 3, hidden 2 and state
    /// 4. Returns `(graph, x, [a])`.
    fn build_ssm_graph() -> (Graph, NodeId, Vec<NodeId>) {
        let mut g = Graph::new("ssm_test");
        let bsh = Shape::new(&[1, 3, 2], DType::F32);
        let hn = Shape::new(&[2, 4], DType::F32);
        let bsn = Shape::new(&[1, 3, 4], DType::F32);
        let x = g.input("x", bsh.clone());
        let delta = g.input("delta", bsh.clone());
        let a = g.param("a", hn);
        let b = g.input("b", bsn.clone());
        let c = g.input("c", bsn);
        let y = g.selective_scan(x, delta, a, b, c, 4, bsh);
        g.set_outputs(vec![y]);
        (g, x, vec![a])
    }

    #[test]
    fn unfuse_decomposes_selective_scan() {
        let (g, _x, _params) = build_ssm_graph();
        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
        let has_ssm = unfused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::SelectiveScan { .. }));
        assert!(!has_ssm, "Op::SelectiveScan should be unfused");
        let count = |pred: fn(&Op) -> bool| -> usize {
            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
        };
        assert_eq!(
            count(|o| matches!(o, Op::Concat { .. })),
            1,
            "expected 1 Concat (over the 3 time steps)"
        );
        assert_eq!(
            count(|o| matches!(
                o,
                Op::Reduce {
                    op: ReduceOp::Sum,
                    ..
                }
            )),
            3,
            "expected one Reduce(Sum) per time step (S=3)"
        );
        assert_eq!(
            count(|o| matches!(o, Op::Activation(Activation::Exp))),
            3,
            "expected one exp(δA) per time step (S=3)"
        );
        assert!(
            count(|o| matches!(o, Op::Narrow { .. })) >= 12,
            "expected >=12 Narrows (4 per step × 3 steps)"
        );
    }

    #[test]
    fn grad_through_selective_scan_propagates() {
        let (g, _x, params) = build_ssm_graph();
        let bwd = grad_with_loss(&g, &params);
        assert_eq!(
            bwd.outputs.len(),
            1 + params.len(),
            "expected loss + {} param grads",
            params.len()
        );
    }

    /// Builds a `GatedDeltaNet` graph with batch 1, seq 3, heads 2 and dim 4.
    /// Returns `(graph, q, [q, k, v, g, beta])`.
    fn build_gdn_graph() -> (Graph, NodeId, Vec<NodeId>) {
        let (b, s, h, n) = (1usize, 3, 2, 4);
        let mut g = Graph::new("gdn_test");
        let bshn = Shape::new(&[b, s, h, n], DType::F32);
        let bsh = Shape::new(&[b, s, h], DType::F32);
        let q = g.input("q", bshn.clone());
        let k = g.input("k", bshn.clone());
        let v = g.input("v", bshn.clone());
        let g_in = g.input("g", bsh.clone());
        let beta = g.input("beta", bsh);
        let y = g.gated_delta_net(q, k, v, g_in, beta, n, bshn);
        g.set_outputs(vec![y]);
        (g, q, vec![q, k, v, g_in, beta])
    }

    #[test]
    fn unfuse_decomposes_gated_delta_net() {
        let (g, _q, _params) = build_gdn_graph();
        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
        let has_gdn = unfused
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::GatedDeltaNet { .. }));
        assert!(!has_gdn, "Op::GatedDeltaNet should be unfused");
        let count = |pred: fn(&Op) -> bool| -> usize {
            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
        };
        assert_eq!(
            count(|o| matches!(o, Op::Concat { .. })),
            1,
            "expected 1 Concat over S=3 steps"
        );
        assert!(
            count(|o| matches!(o, Op::MatMul)) >= 3,
            "expected >=3 MatMul per step (sk + out) × S=3"
        );
        assert_eq!(
            count(|o| matches!(o, Op::Activation(Activation::Exp))),
            3,
            "expected one exp(g) per time step"
        );
    }

    #[test]
    fn grad_through_gated_delta_net_propagates() {
        let (g, _q, params) = build_gdn_graph();
        let bwd = grad_with_loss(&g, &params);
        assert_eq!(
            bwd.outputs.len(),
            1 + params.len(),
            "expected loss + {} input grads",
            params.len()
        );
    }

    /// The vjp_body uses `Sin` purely as a marker, so its presence in `bwd`
    /// proves the user-supplied rule was spliced in rather than derived.
    #[test]
    fn custom_fn_vjp_body_is_inlined_into_bwd() {
        let n = 4usize;
        let shape = Shape::new(&[n], DType::F32);
        let mut fwd_body = Graph::new("square_fwd");
        let xb = fwd_body.input("x", shape.clone());
        let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
        fwd_body.set_outputs(vec![yb]);
        let mut vjp_body = Graph::new("square_vjp");
        let _vx = vjp_body.input("x", shape.clone());
        let _vp = vjp_body.input("primal_output", shape.clone());
        let vd = vjp_body.input("d_output", shape.clone());
        let dx = vjp_body.activation(Activation::Sin, vd, shape.clone());
        vjp_body.set_outputs(vec![dx]);
        let mut g = Graph::new("custom_fn_test");
        let x = g.input("x", shape.clone());
        let y = g.custom_fn(vec![x], fwd_body, Some(vjp_body), None);
        let loss = g.reduce(
            y,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F32),
        );
        g.set_outputs(vec![loss]);
        let bwd = grad_with_loss(&g, &[x]);
        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
        let sin_count = bwd
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Activation(Activation::Sin)))
            .count();
        assert!(
            sin_count >= 1,
            "expected the vjp_body's Sin to be inlined into bwd, got\n{bwd}"
        );
    }

    #[test]
    fn custom_fn_without_vjp_inlines_fwd_body_for_autodiff() {
        let n = 4usize;
        let shape = Shape::new(&[n], DType::F32);
        let mut fwd_body = Graph::new("square_fwd");
        let xb = fwd_body.input("x", shape.clone());
        let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
        fwd_body.set_outputs(vec![yb]);
        let mut g = Graph::new("custom_fn_no_vjp");
        let x = g.input("x", shape.clone());
        let y = g.custom_fn(vec![x], fwd_body, None, None);
        let loss = g.reduce(
            y,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F32),
        );
        g.set_outputs(vec![loss]);
        let bwd = grad_with_loss(&g, &[x]);
        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
        let custom_fn_count = bwd
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::CustomFn { .. }))
            .count();
        assert_eq!(
            custom_fn_count, 0,
            "CustomFn should be inlined away before autodiff"
        );
        let mul_count = bwd
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
            .count();
        assert!(mul_count >= 2, "expected Mul-based VJP for x², got\n{bwd}");
    }

    #[test]
    fn convert_scans_for_ad_forces_save_trajectory_true() {
        let n = 2usize;
        let length = 3u32;
        let carry = Shape::new(&[n], DType::F32);
        let xs_shape = Shape::new(&[length as usize, n], DType::F32);
        let mut body = Graph::new("scan_body");
        let bc = body.input("carry", carry.clone());
        let bx = body.input("x_t", carry.clone());
        let by = body.binary(BinaryOp::Add, bc, bx, carry.clone());
        body.set_outputs(vec![by]);
        let mut g = Graph::new("scan_save_false");
        let init = g.input("init", carry.clone());
        let xs = g.input("xs", xs_shape);
        let scan_out = g.add_node(
            Op::Scan {
                body: Box::new(body),
                length,
                save_trajectory: false,
                num_bcast: 0,
                num_xs: 1,
                num_checkpoints: 0,
            },
            vec![init, xs],
            carry.clone(),
        );
        let loss = g.reduce(
            scan_out,
            ReduceOp::Sum,
            vec![0],
            false,
            Shape::new(&[1], DType::F32),
        );
        g.set_outputs(vec![loss]);
        let bwd = grad_with_loss(&g, &[init, xs]);
        let saved_traj = bwd.nodes().iter().any(|n| {
            matches!(
                &n.op,
                Op::Scan {
                    save_trajectory: true,
                    ..
                }
            )
        });
        assert!(
            saved_traj,
            "convert_scans_for_ad should rewrite save_trajectory=false → \
             save_trajectory=true in the AD-prepared graph; got\n{bwd}"
        );
    }
}