use rlx_ir::infer::GraphExt;
use rlx_ir::op::*;
use rlx_ir::shape::Dim;
use rlx_ir::*;
use std::collections::HashMap;
use crate::autodiff::unbroadcast_inverse;
/// Builds a forward-mode (Jacobian-vector product) graph for `forward`.
///
/// The input graph is cloned and run through `prepare_ad::prepare_graph_for_ad`;
/// every node of the prepared graph is then copied into a new graph named
/// `"<name>_jvp"`. Each id in `tangent_for` receives a fresh input named
/// `tangent_<name>` with the primal's shape, and tangents are pushed through the
/// remaining nodes, in iteration order, by `jvp_rule`.
///
/// The returned graph's outputs are all primal outputs followed by one tangent
/// per primal output. An output that no tangent reaches gets a zero constant.
///
/// # Panics
/// Panics when an id in `tangent_for` does not refer to an `Op::Input` or
/// `Op::Param` node.
pub fn jvp(forward: &Graph, tangent_for: &[NodeId]) -> Graph {
    let prepared = crate::prepare_ad::prepare_graph_for_ad(forward.clone());
    let forward = &prepared;
    let mut bwd = Graph::new(format!("{}_jvp", forward.name));

    // Mirror the primal computation; `fwd_to_bwd` translates forward ids to `bwd` ids.
    let mut fwd_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
    for node in forward.nodes() {
        let mut mapped_inputs: Vec<NodeId> = Vec::with_capacity(node.inputs.len());
        for src in node.inputs.iter() {
            mapped_inputs.push(fwd_to_bwd[src]);
        }
        let copy_id = bwd.add_node(node.op.clone(), mapped_inputs, node.shape.clone());
        fwd_to_bwd.insert(node.id, copy_id);
    }

    // Seed one tangent input per requested primal.
    let mut tangents: HashMap<NodeId, NodeId> = HashMap::new();
    for &id in tangent_for {
        let primal = forward.node(id);
        let name = match &primal.op {
            Op::Input { name } | Op::Param { name } => name.clone(),
            other => panic!("jvp: tangent_for[{id}] must be Input/Param, got {other:?}"),
        };
        let seed = bwd.input(format!("tangent_{name}"), primal.shape.clone());
        tangents.insert(id, seed);
    }

    // Propagate tangents; seeded nodes and nodes with no incoming tangent are skipped.
    for fwd_node in forward.nodes() {
        if tangents.contains_key(&fwd_node.id) {
            continue;
        }
        let mut in_tangents: Vec<Option<NodeId>> = Vec::with_capacity(fwd_node.inputs.len());
        for src in fwd_node.inputs.iter() {
            in_tangents.push(tangents.get(src).copied());
        }
        let any_active = in_tangents.iter().any(Option::is_some);
        if !any_active {
            continue;
        }
        match jvp_rule(fwd_node, &in_tangents, &fwd_to_bwd, &mut bwd) {
            Some(t_out) => {
                tangents.insert(fwd_node.id, t_out);
            }
            None => {}
        }
    }

    // Outputs: primals first, then the matching tangents (zero-filled when absent).
    let mut outs: Vec<NodeId> = Vec::with_capacity(2 * forward.outputs.len());
    outs.extend(forward.outputs.iter().map(|out| fwd_to_bwd[out]));
    for out in forward.outputs.iter() {
        let t = tangents
            .get(out)
            .copied()
            .unwrap_or_else(|| zero_like(fwd_to_bwd[out], &mut bwd));
        outs.push(t);
    }
    bwd.set_outputs(outs);
    bwd
}
/// Builds a Hessian-vector-product graph by applying `jvp` to the graph
/// returned by `autodiff::grad_with_loss(forward, wrt)`.
///
/// The `wrt` nodes are identified in the gradient graph by name: the first
/// `Input`/`Param` node carrying the same name is used as the tangent seed.
///
/// # Panics
/// Panics when a `wrt` id is not an `Op::Input`/`Op::Param` in `forward`, or
/// when no node with the same name exists in the gradient graph.
pub fn hvp(forward: &Graph, wrt: &[NodeId]) -> Graph {
    let grad_graph = crate::autodiff::grad_with_loss(forward, wrt);

    // Phase 1: resolve every `wrt` id to its name in the forward graph.
    let mut wrt_names: Vec<String> = Vec::with_capacity(wrt.len());
    for &id in wrt {
        match &forward.node(id).op {
            Op::Input { name } | Op::Param { name } => wrt_names.push(name.clone()),
            other => panic!("hvp: wrt[{id}] must be Input/Param, got {other:?}"),
        }
    }

    // Phase 2: find the same-named node in the gradient graph.
    let mut seeds: Vec<NodeId> = Vec::with_capacity(wrt_names.len());
    for name in wrt_names.iter() {
        let hit = grad_graph.nodes().iter().find(|n| {
            matches!(
                &n.op,
                Op::Input { name: n_name } | Op::Param { name: n_name } if n_name == name
            )
        });
        match hit {
            Some(n) => seeds.push(n.id),
            None => panic!("hvp: input '{name}' missing in backward graph"),
        }
    }

    jvp(&grad_graph, &seeds)
}
/// Appends an all-zero-bytes `Op::Constant` to `bwd` with the same shape as node `like`.
///
/// The payload length comes from `Shape::size_bytes`; when that yields no
/// size the constant is created with an empty payload.
fn zero_like(like: NodeId, bwd: &mut Graph) -> NodeId {
    let like_shape = bwd.node(like).shape.clone();
    let byte_len = like_shape.size_bytes().unwrap_or(0);
    let zeros: Vec<u8> = std::iter::repeat(0u8).take(byte_len).collect();
    bwd.add_node(Op::Constant { data: zeros }, Vec::new(), like_shape)
}
/// Emits the tangent (forward-mode derivative) of one forward node into `bwd`.
///
/// * `node` – the forward-graph node being differentiated.
/// * `t_inputs` – one entry per input of `node`: the `bwd` id of that input's
///   tangent, or `None` when the input carries no tangent.
/// * `fwd_map` – maps forward-graph ids to the ids of their primal copies in `bwd`.
/// * `bwd` – the graph under construction; new nodes are appended to it.
///
/// Returns the `bwd` id of the node's output tangent, or `None` when the op
/// contributes no tangent (for example comparisons, rounding, or when the
/// differentiable input has no tangent).
///
/// # Panics
/// Panics for ops without a rule, for `Op::CustomFn` without a `jvp_body`, for
/// an unregistered `Op::Custom`, for `Op::Attention` with a key/value tangent,
/// and for `Op::DenseSolve` with a dynamic leading dimension or a right-hand
/// side whose rank is not 1 or 2.
fn jvp_rule(
    node: &Node,
    t_inputs: &[Option<NodeId>],
    fwd_map: &HashMap<NodeId, NodeId>,
    bwd: &mut Graph,
) -> Option<NodeId> {
    match &node.op {
        // Leaves never produce a tangent here; seeds are created by the caller.
        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => None,
        Op::Binary(op) => {
            let a_p = fwd_map[&node.inputs[0]];
            let b_p = fwd_map[&node.inputs[1]];
            let out_shape = node.shape.clone();
            match op {
                // d(a + b) = ta + tb
                BinaryOp::Add => match (t_inputs[0], t_inputs[1]) {
                    (Some(ta), Some(tb)) => Some(bwd.binary(BinaryOp::Add, ta, tb, out_shape)),
                    (Some(ta), None) => Some(ta),
                    (None, Some(tb)) => Some(tb),
                    (None, None) => None,
                },
                // d(a - b) = ta - tb
                BinaryOp::Sub => match (t_inputs[0], t_inputs[1]) {
                    (Some(ta), Some(tb)) => Some(bwd.binary(BinaryOp::Sub, ta, tb, out_shape)),
                    (Some(ta), None) => Some(ta),
                    (None, Some(tb)) => {
                        let s = bwd.node(tb).shape.clone();
                        Some(bwd.activation(Activation::Neg, tb, s))
                    }
                    (None, None) => None,
                },
                // d(a * b) = ta*b + a*tb
                BinaryOp::Mul => {
                    let ta_b =
                        t_inputs[0].map(|ta| bwd.binary(BinaryOp::Mul, ta, b_p, out_shape.clone()));
                    let a_tb =
                        t_inputs[1].map(|tb| bwd.binary(BinaryOp::Mul, a_p, tb, out_shape.clone()));
                    match (ta_b, a_tb) {
                        (Some(x), Some(y)) => Some(bwd.binary(BinaryOp::Add, x, y, out_shape)),
                        (Some(x), None) | (None, Some(x)) => Some(x),
                        (None, None) => None,
                    }
                }
                // d(a / b) = (ta*b - a*tb) / b²
                BinaryOp::Div => {
                    let ta_b =
                        t_inputs[0].map(|ta| bwd.binary(BinaryOp::Mul, ta, b_p, out_shape.clone()));
                    let a_tb =
                        t_inputs[1].map(|tb| bwd.binary(BinaryOp::Mul, a_p, tb, out_shape.clone()));
                    let numer = match (ta_b, a_tb) {
                        (Some(x), Some(y)) => {
                            Some(bwd.binary(BinaryOp::Sub, x, y, out_shape.clone()))
                        }
                        (Some(x), None) => Some(x),
                        (None, Some(y)) => {
                            let s = bwd.node(y).shape.clone();
                            Some(bwd.activation(Activation::Neg, y, s))
                        }
                        (None, None) => None,
                    };
                    numer.map(|n| {
                        let bb = bwd.binary(BinaryOp::Mul, b_p, b_p, out_shape.clone());
                        bwd.binary(BinaryOp::Div, n, bb, out_shape)
                    })
                }
                // Select ta where a < b, otherwise tb; a missing tangent counts as zero.
                BinaryOp::Min => {
                    let zero = scalar_const(0.0, &out_shape, bwd);
                    let bool_shape = Shape::from_dims(out_shape.dims(), DType::Bool);
                    let cond = bwd.add_node(Op::Compare(CmpOp::Lt), vec![a_p, b_p], bool_shape);
                    let cond_f = bwd.add_node(
                        Op::Cast {
                            to: out_shape.dtype(),
                        },
                        vec![cond],
                        out_shape.clone(),
                    );
                    let ta = t_inputs[0].unwrap_or(zero);
                    let tb = t_inputs[1].unwrap_or(zero);
                    Some(bwd.add_node(Op::Where, vec![cond_f, ta, tb], out_shape))
                }
                // Select ta where b < a, otherwise tb; a missing tangent counts as zero.
                BinaryOp::Max => {
                    let zero = scalar_const(0.0, &out_shape, bwd);
                    let bool_shape = Shape::from_dims(out_shape.dims(), DType::Bool);
                    let cond = bwd.add_node(Op::Compare(CmpOp::Lt), vec![b_p, a_p], bool_shape);
                    let cond_f = bwd.add_node(
                        Op::Cast {
                            to: out_shape.dtype(),
                        },
                        vec![cond],
                        out_shape.clone(),
                    );
                    let ta = t_inputs[0].unwrap_or(zero);
                    let tb = t_inputs[1].unwrap_or(zero);
                    Some(bwd.add_node(Op::Where, vec![cond_f, ta, tb], out_shape))
                }
                // d(a^b) = b * a^(b-1) * ta + a^b * ln(a) * tb
                BinaryOp::Pow => {
                    let ta = t_inputs[0];
                    let tb = t_inputs[1];
                    let mut terms = Vec::new();
                    if let Some(t_a) = ta {
                        let one = scalar_const(1.0, &out_shape, bwd);
                        let exp = bwd.binary(BinaryOp::Sub, b_p, one, out_shape.clone());
                        let apow = bwd.binary(BinaryOp::Pow, a_p, exp, out_shape.clone());
                        let apow_ta = bwd.binary(BinaryOp::Mul, apow, t_a, out_shape.clone());
                        let term = bwd.binary(BinaryOp::Mul, b_p, apow_ta, out_shape.clone());
                        terms.push(term);
                    }
                    if let Some(t_b) = tb {
                        let ln_a = bwd.activation(Activation::Log, a_p, out_shape.clone());
                        let apow = bwd.binary(BinaryOp::Pow, a_p, b_p, out_shape.clone());
                        let ln_t_b = bwd.binary(BinaryOp::Mul, ln_a, t_b, out_shape.clone());
                        let term = bwd.binary(BinaryOp::Mul, apow, ln_t_b, out_shape.clone());
                        terms.push(term);
                    }
                    terms
                        .into_iter()
                        .reduce(|a, b| bwd.binary(BinaryOp::Add, a, b, out_shape.clone()))
                }
            }
        }
        // Unary element-wise ops: tangent = f'(x) * t_x. `x` is the primal input,
        // `fwd_map[&node.id]` the primal output `y`, reused where f' is cheaper in y.
        Op::Activation(kind) => {
            let t_x = t_inputs[0]?;
            let x = fwd_map[&node.inputs[0]];
            let s = node.shape.clone();
            let deriv = match kind {
                Activation::Neg => {
                    return Some(bwd.activation(Activation::Neg, t_x, s));
                }
                // exp' = y
                Activation::Exp => fwd_map[&node.id],
                // log' = 1 / x
                Activation::Log => {
                    let one = scalar_const(1.0, &s, bwd);
                    bwd.binary(BinaryOp::Div, one, x, s.clone())
                }
                // sqrt' = 0.5 / y
                Activation::Sqrt => {
                    let half = scalar_const(0.5, &s, bwd);
                    let y = fwd_map[&node.id];
                    bwd.binary(BinaryOp::Div, half, y, s.clone())
                }
                // rsqrt' = -0.5 * y³
                Activation::Rsqrt => {
                    let y = fwd_map[&node.id];
                    let y2 = bwd.binary(BinaryOp::Mul, y, y, s.clone());
                    let y3 = bwd.binary(BinaryOp::Mul, y2, y, s.clone());
                    let neg_half = scalar_const(-0.5, &s, bwd);
                    bwd.binary(BinaryOp::Mul, neg_half, y3, s.clone())
                }
                // tanh' = 1 - y²
                Activation::Tanh => {
                    let y = fwd_map[&node.id];
                    let y2 = bwd.binary(BinaryOp::Mul, y, y, s.clone());
                    let one = scalar_const(1.0, &s, bwd);
                    bwd.binary(BinaryOp::Sub, one, y2, s.clone())
                }
                // sigmoid' = y * (1 - y)
                Activation::Sigmoid => {
                    let y = fwd_map[&node.id];
                    let one = scalar_const(1.0, &s, bwd);
                    let one_minus_y = bwd.binary(BinaryOp::Sub, one, y, s.clone());
                    bwd.binary(BinaryOp::Mul, y, one_minus_y, s.clone())
                }
                // Pass t_x through where x > 0, zero elsewhere.
                Activation::Relu => {
                    let zero = scalar_const(0.0, &s, bwd);
                    let mask = bwd.add_node(
                        Op::Compare(CmpOp::Gt),
                        vec![x, zero],
                        Shape::from_dims(s.dims(), DType::Bool),
                    );
                    let zero2 = scalar_const(0.0, &s, bwd);
                    return Some(bwd.add_node(Op::Where, vec![mask, t_x, zero2], s));
                }
                // sin' = cos
                Activation::Sin => bwd.activation(Activation::Cos, x, s.clone()),
                // cos' = -sin
                Activation::Cos => {
                    let sx = bwd.activation(Activation::Sin, x, s.clone());
                    bwd.activation(Activation::Neg, sx, s.clone())
                }
                // tan' = 1 + y²
                Activation::Tan => {
                    let y = fwd_map[&node.id];
                    let y2 = bwd.binary(BinaryOp::Mul, y, y, s.clone());
                    let one = scalar_const(1.0, &s, bwd);
                    bwd.binary(BinaryOp::Add, one, y2, s.clone())
                }
                // atan' = 1 / (1 + x²)
                Activation::Atan => {
                    let x2 = bwd.binary(BinaryOp::Mul, x, x, s.clone());
                    let one = scalar_const(1.0, &s, bwd);
                    let denom = bwd.binary(BinaryOp::Add, one, x2, s.clone());
                    let one2 = scalar_const(1.0, &s, bwd);
                    bwd.binary(BinaryOp::Div, one2, denom, s.clone())
                }
                // t_x where x > 0, -t_x elsewhere (including x == 0).
                Activation::Abs => {
                    let zero = scalar_const(0.0, &s, bwd);
                    let mask = bwd.add_node(
                        Op::Compare(CmpOp::Gt),
                        vec![x, zero],
                        Shape::from_dims(s.dims(), DType::Bool),
                    );
                    let neg_tx = bwd.activation(Activation::Neg, t_x, s.clone());
                    return Some(bwd.add_node(Op::Where, vec![mask, t_x, neg_tx], s));
                }
                Activation::Round => return None,
                // Derivative of the tanh-form GELU,
                //   gelu(x) = 0.5 * x * (1 + tanh(u)),  u = k * (x + a * x³),
                //   gelu'(x) = 0.5 * ((1 + tanh u) + x * sech²(u) * k * (1 + 3a * x²)),
                // with k = sqrt(2/pi) and a = 0.044715.
                // NOTE(review): assumes the forward Gelu kernel is (or is well
                // approximated by) the tanh form — confirm against the backend.
                Activation::Gelu => {
                    let k = scalar_const(0.7978845608, &s, bwd);
                    let a = scalar_const(0.044715, &s, bwd);
                    let three_a = scalar_const(3.0 * 0.044715, &s, bwd);
                    let half = scalar_const(0.5, &s, bwd);
                    let one = scalar_const(1.0, &s, bwd);
                    let x2 = bwd.binary(BinaryOp::Mul, x, x, s.clone());
                    let x3 = bwd.binary(BinaryOp::Mul, x, x2, s.clone());
                    let a_x3 = bwd.binary(BinaryOp::Mul, a, x3, s.clone());
                    let poly = bwd.binary(BinaryOp::Add, x, a_x3, s.clone());
                    let u = bwd.binary(BinaryOp::Mul, k, poly, s.clone());
                    let t = bwd.activation(Activation::Tanh, u, s.clone());
                    let t2 = bwd.binary(BinaryOp::Mul, t, t, s.clone());
                    let sech2 = bwd.binary(BinaryOp::Sub, one, t2, s.clone());
                    let three_a_x2 = bwd.binary(BinaryOp::Mul, three_a, x2, s.clone());
                    let dpoly = bwd.binary(BinaryOp::Add, one, three_a_x2, s.clone());
                    let du = bwd.binary(BinaryOp::Mul, k, dpoly, s.clone());
                    let x_sech2 = bwd.binary(BinaryOp::Mul, x, sech2, s.clone());
                    let x_sech2_du = bwd.binary(BinaryOp::Mul, x_sech2, du, s.clone());
                    let one_plus_t = bwd.binary(BinaryOp::Add, one, t, s.clone());
                    let sum = bwd.binary(BinaryOp::Add, one_plus_t, x_sech2_du, s.clone());
                    bwd.binary(BinaryOp::Mul, half, sum, s.clone())
                }
                // silu' = sig(x) + x * sig(x) * (1 - sig(x)).
                // NOTE(review): GeluApprox shares this rule; presumably an
                // approximation for that activation — confirm it is intended.
                Activation::GeluApprox | Activation::Silu => {
                    let sig = bwd.activation(Activation::Sigmoid, x, s.clone());
                    let one = scalar_const(1.0, &s, bwd);
                    let one_minus = bwd.binary(BinaryOp::Sub, one, sig, s.clone());
                    let sig_om = bwd.binary(BinaryOp::Mul, sig, one_minus, s.clone());
                    let x_sig_om = bwd.binary(BinaryOp::Mul, x, sig_om, s.clone());
                    bwd.binary(BinaryOp::Add, sig, x_sig_om, s.clone())
                }
            };
            Some(bwd.binary(BinaryOp::Mul, deriv, t_x, node.shape.clone()))
        }
        // d(A·B) = tA·B + A·tB
        Op::MatMul => {
            let a_p = fwd_map[&node.inputs[0]];
            let b_p = fwd_map[&node.inputs[1]];
            let out_shape = node.shape.clone();
            let ta_b = t_inputs[0].map(|ta| bwd.matmul(ta, b_p, out_shape.clone()));
            let a_tb = t_inputs[1].map(|tb| bwd.matmul(a_p, tb, out_shape.clone()));
            match (ta_b, a_tb) {
                (Some(x), Some(y)) => Some(bwd.binary(BinaryOp::Add, x, y, out_shape)),
                (Some(x), None) | (None, Some(x)) => Some(x),
                (None, None) => None,
            }
        }
        // With x the primal solution, the tangent solves A·t_x = t_b - t_A·x,
        // i.e. one more solve against the same primal matrix.
        Op::DenseSolve => {
            let a_p = fwd_map[&node.inputs[0]];
            let x = fwd_map[&node.id];
            let x_shape = node.shape.clone();
            let dtype = x_shape.dtype();
            // Builds t_A·x. A rank-1 x is reshaped to a column for the matmul
            // and the product is reshaped back afterwards.
            let make_ta_x = |t_a: NodeId, bwd: &mut Graph| -> NodeId {
                match x_shape.rank() {
                    1 => {
                        let n = match x_shape.dim(0) {
                            Dim::Static(n) => n,
                            Dim::Dynamic(_) => panic!("jvp: DenseSolve dynamic N not supported"),
                        };
                        let x_col_shape =
                            Shape::from_dims(&[Dim::Static(n), Dim::Static(1)], dtype);
                        let x_col = bwd.add_node(
                            Op::Reshape {
                                new_shape: vec![n as i64, 1],
                            },
                            vec![x],
                            x_col_shape.clone(),
                        );
                        let prod_col = bwd.matmul(t_a, x_col, x_col_shape);
                        bwd.add_node(
                            Op::Reshape {
                                new_shape: vec![n as i64],
                            },
                            vec![prod_col],
                            x_shape.clone(),
                        )
                    }
                    2 => bwd.matmul(t_a, x, x_shape.clone()),
                    r => panic!("jvp: DenseSolve B must be rank 1 or 2, got rank {r}"),
                }
            };
            let rhs = match (t_inputs[0], t_inputs[1]) {
                (Some(t_a), Some(t_b)) => {
                    let prod = make_ta_x(t_a, bwd);
                    bwd.binary(BinaryOp::Sub, t_b, prod, x_shape.clone())
                }
                (Some(t_a), None) => {
                    let prod = make_ta_x(t_a, bwd);
                    bwd.activation(Activation::Neg, prod, x_shape.clone())
                }
                (None, Some(t_b)) => t_b,
                (None, None) => return None,
            };
            Some(bwd.dense_solve(a_p, rhs, x_shape))
        }
        // Pure data-movement ops: apply the same op to the tangent.
        Op::Reshape { new_shape } => {
            let t_x = t_inputs[0]?;
            Some(bwd.add_node(
                Op::Reshape {
                    new_shape: new_shape.clone(),
                },
                vec![t_x],
                node.shape.clone(),
            ))
        }
        Op::Transpose { perm } => {
            let t_x = t_inputs[0]?;
            Some(bwd.add_node(
                Op::Transpose { perm: perm.clone() },
                vec![t_x],
                node.shape.clone(),
            ))
        }
        Op::Expand { target_shape } => {
            let t_x = t_inputs[0]?;
            Some(bwd.add_node(
                Op::Expand {
                    target_shape: target_shape.clone(),
                },
                vec![t_x],
                node.shape.clone(),
            ))
        }
        Op::Narrow { axis, start, len } => {
            let t_x = t_inputs[0]?;
            Some(bwd.add_node(
                Op::Narrow {
                    axis: *axis,
                    start: *start,
                    len: *len,
                },
                vec![t_x],
                node.shape.clone(),
            ))
        }
        // Same transform applied to the tangent, multiplied by the
        // normalisation scale whenever that scale is not 1.
        Op::Fft { inverse, norm } => {
            let t_x = t_inputs[0]?;
            // `node.inputs[0]` is a forward-graph id; the shape query runs on
            // `bwd`, so translate the id to its primal copy first.
            let n = rlx_ir::fft::fft_meta(bwd.shape(fwd_map[&node.inputs[0]])).n_complex;
            let s = norm.output_scale(n, *inverse) as f32;
            let t_y = bwd.fft(t_x, *inverse);
            if s != 1.0 {
                let sc = scalar_const(s as f64, &node.shape, bwd);
                Some(bwd.mul(t_y, sc))
            } else {
                Some(t_y)
            }
        }
        Op::Conjugate => {
            let t_x = t_inputs[0]?;
            Some(bwd.conjugate(t_x))
        }
        // Concatenate the tangents, substituting zeros for inputs without one.
        Op::Concat { axis } => {
            if t_inputs.iter().all(Option::is_none) {
                return None;
            }
            let mut t_ins: Vec<NodeId> = Vec::with_capacity(t_inputs.len());
            for (i, t) in t_inputs.iter().enumerate() {
                match t {
                    Some(node_id) => t_ins.push(*node_id),
                    None => {
                        let primal_in = fwd_map[&node.inputs[i]];
                        t_ins.push(zero_like(primal_in, bwd));
                    }
                }
            }
            Some(bwd.add_node(Op::Concat { axis: *axis }, t_ins, node.shape.clone()))
        }
        Op::Reduce { op, axes, keep_dim } => {
            let t_x = t_inputs[0]?;
            match op {
                // Sum and Mean are linear: reduce the tangent the same way.
                ReduceOp::Sum | ReduceOp::Mean => {
                    Some(bwd.reduce(t_x, *op, axes.clone(), *keep_dim, node.shape.clone()))
                }
                // No tangent is propagated through Min/Max reductions.
                ReduceOp::Min | ReduceOp::Max => None,
                // d(prod x) = sum_i (y / x_i) * t_i over the reduced axes.
                // The small epsilon keeps the division finite at x_i == 0.
                ReduceOp::Prod => {
                    let x_p = fwd_map[&node.inputs[0]];
                    let y = fwd_map[&node.id];
                    let x_shape = bwd.node(x_p).shape.clone();
                    let eps = scalar_const(1e-12, &x_shape, bwd);
                    let denom = bwd.binary(BinaryOp::Add, x_p, eps, x_shape.clone());
                    let factor = bwd.binary(BinaryOp::Div, y, denom, x_shape.clone());
                    // The per-element contributions live at input shape; summing
                    // them over `axes` brings the tangent to the output shape.
                    let weighted = bwd.binary(BinaryOp::Mul, factor, t_x, x_shape);
                    Some(bwd.reduce(
                        weighted,
                        ReduceOp::Sum,
                        axes.clone(),
                        *keep_dim,
                        node.shape.clone(),
                    ))
                }
            }
        }
        // Select between the branch tangents with the primal condition; a
        // missing branch tangent is replaced by zeros.
        Op::Where => {
            let cond_p = fwd_map[&node.inputs[0]];
            let s = node.shape.clone();
            match (t_inputs[1], t_inputs[2]) {
                (Some(ta), Some(tb)) => Some(bwd.add_node(Op::Where, vec![cond_p, ta, tb], s)),
                (Some(ta), None) => {
                    let zero = zero_like(ta, bwd);
                    Some(bwd.add_node(Op::Where, vec![cond_p, ta, zero], s))
                }
                (None, Some(tb)) => {
                    let zero = zero_like(tb, bwd);
                    Some(bwd.add_node(Op::Where, vec![cond_p, zero, tb], s))
                }
                (None, None) => None,
            }
        }
        Op::Compare(_) => None,
        Op::Cast { to } => {
            let t_x = t_inputs[0]?;
            Some(bwd.add_node(Op::Cast { to: *to }, vec![t_x], node.shape.clone()))
        }
        // Inline the user-supplied JVP body into `bwd`. Inputs of the body are
        // bound by name: "primal_output" -> this node's primal output,
        // "tangent_<i>" -> tangent of input i (a zero scalar when absent), and
        // every other Input -> a primal input, matched by ascending node id.
        Op::CustomFn {
            jvp_body: Some(jvp_body),
            num_inputs,
            ..
        } => {
            // Sorted ids of the body's primal inputs. Computed once up front
            // because the list does not change while the body is inlined.
            let mut primal_input_ids: Vec<NodeId> = jvp_body
                .nodes()
                .iter()
                .filter_map(|n| match &n.op {
                    Op::Input { name }
                        if !name.starts_with("tangent_") && name != "primal_output" =>
                    {
                        Some(n.id)
                    }
                    _ => None,
                })
                .collect();
            primal_input_ids.sort();

            let mut sub_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
            for sub_node in jvp_body.nodes() {
                let new_id = match &sub_node.op {
                    Op::Input { name } if name == "primal_output" => fwd_map[&node.id],
                    Op::Input { name } if name.starts_with("tangent_") => {
                        let idx: usize = name["tangent_".len()..].parse().expect(
                            "custom_fn jvp_body: tangent name must be \
'tangent_<i>' where i is a usize",
                        );
                        assert!(idx < *num_inputs as usize);
                        match t_inputs[idx] {
                            Some(t) => t,
                            None => scalar_const(0.0, &sub_node.shape, bwd),
                        }
                    }
                    Op::Input { .. } => {
                        let idx = primal_input_ids
                            .iter()
                            .position(|&id| id == sub_node.id)
                            .expect("custom_fn jvp_body: primal Input not found");
                        fwd_map[&node.inputs[idx]]
                    }
                    _ => {
                        let new_inputs: Vec<NodeId> =
                            sub_node.inputs.iter().map(|i| sub_to_bwd[i]).collect();
                        bwd.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
                    }
                };
                sub_to_bwd.insert(sub_node.id, new_id);
            }
            Some(sub_to_bwd[&jvp_body.outputs[0]])
        }
        Op::CustomFn { jvp_body: None, .. } => {
            panic!(
                "jvp: Op::CustomFn has no jvp_body. Either supply \
one to Graph::custom_fn(...), or inline the forward \
body into the parent graph before differentiating."
            )
        }
        // Registered extension ops provide their own rule.
        Op::Custom { name, .. } => {
            let ext = rlx_ir::lookup_op(name)
                .unwrap_or_else(|| panic!("jvp: Op::Custom('{name}') not registered"));
            let mut ctx = rlx_ir::JvpContext {
                tangents: t_inputs,
                fwd_map,
                bwd,
            };
            ext.jvp(node, &mut ctx)
        }
        // Only the tangent of input 0 is propagated. The op is re-applied to
        // t_x with the primal cos/sin tables, because an op that is linear in
        // x has itself as its JVP. Negating `sin` would give the adjoint.
        // NOTE(review): assumes Op::Rope is linear in its first input — confirm.
        Op::Rope { head_dim, n_rot } => {
            let t_x = t_inputs[0]?;
            let cos = fwd_map[&node.inputs[1]];
            let sin = fwd_map[&node.inputs[2]];
            Some(bwd.add_node(
                Op::Rope {
                    head_dim: *head_dim,
                    n_rot: *n_rot,
                },
                vec![t_x, cos, sin],
                node.shape.clone(),
            ))
        }
        Op::Cumsum { axis, exclusive } => {
            let t_x = t_inputs[0]?;
            Some(bwd.add_node(
                Op::Cumsum {
                    axis: *axis,
                    exclusive: *exclusive,
                },
                vec![t_x],
                node.shape.clone(),
            ))
        }
        // Gather the table tangent with the primal indices.
        Op::Gather { axis } => {
            let t_table = t_inputs[0]?;
            let indices = fwd_map[&node.inputs[1]];
            Some(bwd.add_node(
                Op::Gather { axis: *axis },
                vec![t_table, indices],
                node.shape.clone(),
            ))
        }
        // Emits t_x * gamma * rsqrt(var(x) + eps) along `axis`.
        // NOTE(review): only the tangent of input 0 is used, and the
        // dependence of mean/variance on x is not differentiated — confirm
        // that this simplified rule is intended.
        Op::LayerNorm { axis, eps } => {
            let t_x = t_inputs[0]?;
            let gamma = fwd_map[&node.inputs[1]];
            let x = fwd_map[&node.inputs[0]];
            let x_shape = node.shape.clone();
            let rank = x_shape.rank();
            // Normalise a negative axis to a position counted from the front.
            let axis_pos = if *axis < 0 {
                (rank as i32 + *axis) as usize
            } else {
                *axis as usize
            };
            let mut keep_dims: Vec<Dim> = x_shape.dims().to_vec();
            keep_dims[axis_pos] = Dim::Static(1);
            let keep_shape = Shape::from_dims(&keep_dims, x_shape.dtype());
            let mean = bwd.add_node(
                Op::Reduce {
                    op: ReduceOp::Mean,
                    axes: vec![axis_pos],
                    keep_dim: true,
                },
                vec![x],
                keep_shape.clone(),
            );
            let diff = bwd.binary(BinaryOp::Sub, x, mean, x_shape.clone());
            let diff_sq = bwd.binary(BinaryOp::Mul, diff, diff, x_shape.clone());
            let var = bwd.add_node(
                Op::Reduce {
                    op: ReduceOp::Mean,
                    axes: vec![axis_pos],
                    keep_dim: true,
                },
                vec![diff_sq],
                keep_shape.clone(),
            );
            let eps_c = scalar_const(*eps as f64, &keep_shape, bwd);
            let var_eps = bwd.binary(BinaryOp::Add, var, eps_c, keep_shape.clone());
            let inv_std = bwd.activation(Activation::Rsqrt, var_eps, keep_shape.clone());
            let inv_std_b = unbroadcast_inverse(inv_std, &x_shape, bwd);
            let gamma_b = unbroadcast_inverse(gamma, &x_shape, bwd);
            let inner = bwd.binary(BinaryOp::Mul, t_x, gamma_b, x_shape.clone());
            Some(bwd.binary(BinaryOp::Mul, inner, inv_std_b, x_shape))
        }
        // Emits t_x * gamma * rsqrt(mean(x²) + eps) along `axis`.
        // NOTE(review): the dependence of the RMS on x is not differentiated
        // and only the tangent of input 0 is used — confirm intent.
        Op::RmsNorm { axis, eps } => {
            let t_x = t_inputs[0]?;
            let gamma = fwd_map[&node.inputs[1]];
            let x = fwd_map[&node.inputs[0]];
            let x_shape = node.shape.clone();
            let rank = x_shape.rank();
            let axis_pos = if *axis < 0 {
                (rank as i32 + *axis) as usize
            } else {
                *axis as usize
            };
            let mut keep_dims: Vec<Dim> = x_shape.dims().to_vec();
            keep_dims[axis_pos] = Dim::Static(1);
            let keep_shape = Shape::from_dims(&keep_dims, x_shape.dtype());
            let xsq = bwd.binary(BinaryOp::Mul, x, x, x_shape.clone());
            let r_sq = bwd.add_node(
                Op::Reduce {
                    op: ReduceOp::Mean,
                    axes: vec![axis_pos],
                    keep_dim: true,
                },
                vec![xsq],
                keep_shape.clone(),
            );
            let eps_c = scalar_const(*eps as f64, &keep_shape, bwd);
            let r_sq_eps = bwd.binary(BinaryOp::Add, r_sq, eps_c, keep_shape.clone());
            let inv_r = bwd.activation(Activation::Rsqrt, r_sq_eps, keep_shape.clone());
            let inv_r_b = unbroadcast_inverse(inv_r, &x_shape, bwd);
            let gamma_b = unbroadcast_inverse(gamma, &x_shape, bwd);
            let inner = bwd.binary(BinaryOp::Mul, t_x, gamma_b, x_shape.clone());
            Some(bwd.binary(BinaryOp::Mul, inner, inv_r_b, x_shape))
        }
        // Treats the node shape as four static dims (n, c, h, w), splits c into
        // `num_groups` groups and emits t_x * gamma * rsqrt(mean(x²) + eps),
        // with the mean taken per (sample, group).
        // NOTE(review): the statistic is a mean of squares without mean
        // subtraction, and its dependence on x is not differentiated — confirm.
        Op::GroupNorm { num_groups, eps } => {
            let t_x = t_inputs[0]?;
            let x = fwd_map[&node.inputs[0]];
            let gamma = fwd_map[&node.inputs[1]];
            let dtype = node.shape.dtype();
            let dims: Vec<usize> = node
                .shape
                .dims()
                .iter()
                .map(|d| d.unwrap_static())
                .collect();
            let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);
            let cpg = c / num_groups;
            let inner = (cpg * h * w) as i64;
            let x5 = bwd.reshape_(
                x,
                vec![n as i64, *num_groups as i64, cpg as i64, h as i64, w as i64],
            );
            let t_x5 = bwd.reshape_(
                t_x,
                vec![n as i64, *num_groups as i64, cpg as i64, h as i64, w as i64],
            );
            let inner_u = inner as usize;
            // Flatten each group so the statistic reduces over a single axis.
            let x3 = bwd.reshape_(x5, vec![n as i64, *num_groups as i64, inner]);
            let keep_shape = Shape::new(&[n, *num_groups, 1], dtype);
            let xsq = bwd.binary(
                BinaryOp::Mul,
                x3,
                x3,
                Shape::new(&[n, *num_groups, inner_u], dtype),
            );
            let r_sq = bwd.mean(xsq, vec![2], true);
            let eps_c = scalar_const(*eps as f64, &keep_shape, bwd);
            let r_sq_eps = bwd.binary(BinaryOp::Add, r_sq, eps_c, keep_shape.clone());
            let inv_r = bwd.activation(Activation::Rsqrt, r_sq_eps, keep_shape);
            let inv_r5 = bwd.reshape_(inv_r, vec![n as i64, *num_groups as i64, 1, 1, 1]);
            let gamma5 = bwd.reshape_(gamma, vec![1, *num_groups as i64, cpg as i64, 1, 1]);
            let t_scaled = bwd.binary(
                BinaryOp::Mul,
                t_x5,
                gamma5,
                Shape::new(&[n, *num_groups, cpg, h, w], dtype),
            );
            let t_out5 = bwd.binary(
                BinaryOp::Mul,
                t_scaled,
                inv_r5,
                Shape::new(&[n, *num_groups, cpg, h, w], dtype),
            );
            Some(bwd.reshape_(t_out5, vec![n as i64, c as i64, h as i64, w as i64]))
        }
        // Only a query tangent is supported: the op is re-applied to t_q with
        // the primal key/value (and mask, for Custom/Bias mask kinds).
        // NOTE(review): if Op::Attention applies a softmax over the scores,
        // this is not the exact derivative w.r.t. the query — confirm intent.
        Op::Attention {
            num_heads,
            head_dim,
            mask_kind,
            score_scale,
            attn_logit_softcap,
        } => {
            if t_inputs[1].is_some() || t_inputs[2].is_some() {
                panic!("jvp: Attention tangent only supported for Query");
            }
            let t_q = t_inputs[0]?;
            let k = fwd_map[&node.inputs[1]];
            let v = fwd_map[&node.inputs[2]];
            let inputs = match mask_kind {
                MaskKind::Custom | MaskKind::Bias => {
                    let mask = fwd_map[&node.inputs[3]];
                    vec![t_q, k, v, mask]
                }
                _ => vec![t_q, k, v],
            };
            Some(bwd.add_node(
                Op::Attention {
                    num_heads: *num_heads,
                    head_dim: *head_dim,
                    mask_kind: *mask_kind,
                    score_scale: *score_scale,
                    attn_logit_softcap: *attn_logit_softcap,
                },
                inputs,
                node.shape.clone(),
            ))
        }
        other => panic!("jvp: no rule for op {other:?}"),
    }
}
/// Appends a one-element `Op::Constant` holding `value` to `bwd`.
///
/// `shape` supplies only the dtype: the value is encoded as little-endian
/// `f32` or `f64` bytes, and the new node's shape is always a single static
/// dimension of size 1.
///
/// # Panics
/// Panics for any dtype other than `F32` or `F64`.
fn scalar_const(value: f64, shape: &Shape, bwd: &mut Graph) -> NodeId {
    let dtype = shape.dtype();
    let data: Vec<u8> = match dtype {
        DType::F64 => Vec::from(value.to_le_bytes()),
        DType::F32 => {
            let narrowed = value as f32;
            Vec::from(narrowed.to_le_bytes())
        }
        other => panic!("scalar_const: dtype {other:?} not supported"),
    };
    let one_elem = Shape::from_dims(&[Dim::Static(1)], dtype);
    bwd.add_node(Op::Constant { data }, Vec::new(), one_elem)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `x = dense_solve(A, b)` with a 2x2 `A` and a length-2 `b` (both
    /// F64), marks `x` as the only output, and returns the graph with the ids
    /// of `A` and `b`.
    fn solve_graph(name: &str) -> (Graph, NodeId, NodeId) {
        let mut g = Graph::new(name);
        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
        let b = g.input("b", Shape::new(&[2], DType::F64));
        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
        g.set_outputs(vec![x]);
        (g, a, b)
    }

    /// Counts the nodes of `g` whose op satisfies `wanted`.
    fn count_ops(g: &Graph, wanted: impl Fn(&Op) -> bool) -> usize {
        g.nodes().iter().filter(|n| wanted(&n.op)).count()
    }

    /// Seeding only `b` must add a second solve for the tangent.
    #[test]
    fn jvp_dense_solve_b_only() {
        let (g, _a, b) = solve_graph("jvp_db");
        let jg = jvp(&g, &[b]);
        assert_eq!(jg.outputs.len(), 2);
        let n_solves = count_ops(&jg, |op| matches!(op, Op::DenseSolve));
        assert!(
            n_solves >= 2,
            "tangent path should add a DenseSolve, got\n{jg}"
        );
    }

    /// Seeding only `A` must produce a matmul, a negation and a second solve.
    #[test]
    fn jvp_dense_solve_a_only() {
        let (g, a, _b) = solve_graph("jvp_da");
        let jg = jvp(&g, &[a]);
        let n_solves = count_ops(&jg, |op| matches!(op, Op::DenseSolve));
        let n_neg = count_ops(&jg, |op| matches!(op, Op::Activation(Activation::Neg)));
        let n_mm = count_ops(&jg, |op| matches!(op, Op::MatMul));
        assert!(n_solves >= 2, "expected ≥2 DenseSolve, got\n{jg}");
        assert!(n_neg >= 1, "expected a Neg for −t_A·x, got\n{jg}");
        assert!(n_mm >= 1, "expected a MatMul for t_A·x, got\n{jg}");
    }

    /// With no seeds the tangent output falls back to a zero constant.
    #[test]
    fn jvp_with_no_seeded_tangents_produces_zero_output() {
        let (g, _a, _b) = solve_graph("jvp_no_seed");
        let jg = jvp(&g, &[]);
        assert_eq!(jg.outputs.len(), 2);
        let tangent_out = jg.node(jg.outputs[1]);
        assert!(
            matches!(tangent_out.op, Op::Constant { .. }),
            "expected zero Constant for tangent, got {:?}",
            tangent_out.op
        );
    }
}