use crate::op::{AttentionBwdWrt, MaskKind};
use crate::{DType, Graph, NodeId, Op, Shape};
impl Graph {
/// Pushes an `Op::ReluBackward` node whose inputs are `[x, dy]`.
///
/// The new node's output shape is a copy of `x`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when the shapes of `x` and `dy` differ.
pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId {
    debug_assert_eq!(
        self.shape(x),
        self.shape(dy),
        "relu_backward: x and dy must have identical shapes"
    );
    let grad_shape = self.shape(x).clone();
    let inputs = vec![x, dy];
    self.push(Op::ReluBackward, inputs, grad_shape, None)
}
/// Pushes an `Op::ActivationBackward { kind }` node whose inputs are `[x, dy]`.
///
/// The new node's output shape is a copy of `x`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when the shapes of `x` and `dy` differ.
pub fn activation_backward(
    &mut self,
    kind: crate::op::Activation,
    x: NodeId,
    dy: NodeId,
) -> NodeId {
    debug_assert_eq!(
        self.shape(x),
        self.shape(dy),
        "activation_backward: x and dy must have identical shapes"
    );
    let grad_shape = self.shape(x).clone();
    let op = Op::ActivationBackward { kind };
    self.push(op, vec![x, dy], grad_shape, None)
}
/// Pushes an `Op::LayerNormBackwardInput { axis, eps }` node whose inputs
/// are `[x, gamma, dy]`.
///
/// The new node's output shape is a copy of `x`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when the shapes of `x` and `dy` differ.
pub fn layer_norm_backward_input(
    &mut self,
    x: NodeId,
    gamma: NodeId,
    dy: NodeId,
    axis: i32,
    eps: f32,
) -> NodeId {
    debug_assert_eq!(
        self.shape(x),
        self.shape(dy),
        "layer_norm_backward_input: x and dy must match"
    );
    let grad_shape = self.shape(x).clone();
    let op = Op::LayerNormBackwardInput { axis, eps };
    let inputs = vec![x, gamma, dy];
    self.push(op, inputs, grad_shape, None)
}
/// Pushes an `Op::RmsNormBackwardInput { axis, eps }` node whose inputs
/// are `[x, gamma, beta, dy]`.
///
/// The new node's output shape is a copy of `x`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when the shapes of `x` and `dy` differ. This
/// mirrors the precondition already enforced by
/// `layer_norm_backward_input`.
pub fn rms_norm_backward_input(
    &mut self,
    x: NodeId,
    gamma: NodeId,
    beta: NodeId,
    dy: NodeId,
    axis: i32,
    eps: f32,
) -> NodeId {
    // The output is given `x`'s shape, so catch a mismatched upstream
    // gradient here rather than inside whatever consumes the node.
    debug_assert_eq!(
        self.shape(x),
        self.shape(dy),
        "rms_norm_backward_input: x and dy must match"
    );
    let x_shape = self.shape(x).clone();
    self.push(
        Op::RmsNormBackwardInput { axis, eps },
        vec![x, gamma, beta, dy],
        x_shape,
        None,
    )
}
/// Pushes an `Op::RmsNormBackwardGamma { axis, eps }` node whose inputs
/// are `[x, gamma, beta, dy]`.
///
/// The new node's output shape is a copy of `gamma`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
pub fn rms_norm_backward_gamma(
    &mut self,
    x: NodeId,
    gamma: NodeId,
    beta: NodeId,
    dy: NodeId,
    axis: i32,
    eps: f32,
) -> NodeId {
    let gamma_shape = self.shape(gamma).clone();
    let op = Op::RmsNormBackwardGamma { axis, eps };
    let inputs = vec![x, gamma, beta, dy];
    self.push(op, inputs, gamma_shape, None)
}
/// Pushes an `Op::RmsNormBackwardBeta { axis, eps }` node whose inputs
/// are `[x, gamma, beta, dy]`.
///
/// The new node's output shape is a copy of `beta`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
pub fn rms_norm_backward_beta(
    &mut self,
    x: NodeId,
    gamma: NodeId,
    beta: NodeId,
    dy: NodeId,
    axis: i32,
    eps: f32,
) -> NodeId {
    let beta_shape = self.shape(beta).clone();
    let op = Op::RmsNormBackwardBeta { axis, eps };
    let inputs = vec![x, gamma, beta, dy];
    self.push(op, inputs, beta_shape, None)
}
/// Pushes an `Op::RopeBackward { head_dim, n_rot }` node whose inputs are
/// `[dy, cos, sin]`.
///
/// The new node's output shape is a copy of `dy`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
pub fn rope_backward(
    &mut self,
    dy: NodeId,
    cos: NodeId,
    sin: NodeId,
    head_dim: usize,
    n_rot: usize,
) -> NodeId {
    let grad_shape = self.shape(dy).clone();
    let op = Op::RopeBackward { head_dim, n_rot };
    let inputs = vec![dy, cos, sin];
    self.push(op, inputs, grad_shape, None)
}
/// Pushes an `Op::CumsumBackward { axis, exclusive }` node whose only
/// input is `dy`.
///
/// The caller supplies the output shape through `out_shape`; it is passed
/// to `push` without being checked against `dy`. The `NodeId` returned by
/// `push` is handed back to the caller.
pub fn cumsum_backward(
    &mut self,
    dy: NodeId,
    out_shape: Shape,
    axis: i32,
    exclusive: bool,
) -> NodeId {
    let op = Op::CumsumBackward { axis, exclusive };
    let inputs = vec![dy];
    self.push(op, inputs, out_shape, None)
}
/// Pushes an `Op::GatherBackward { axis }` node whose inputs are
/// `[dy, indices]`.
///
/// The caller supplies the output shape through `table_shape`; it is
/// passed to `push` as-is. The `NodeId` returned by `push` is handed back
/// to the caller.
pub fn gather_backward(
    &mut self,
    dy: NodeId,
    indices: NodeId,
    table_shape: Shape,
    axis: i32,
) -> NodeId {
    let op = Op::GatherBackward { axis };
    let inputs = vec![dy, indices];
    self.push(op, inputs, table_shape, None)
}
/// Pushes an `Op::GroupNormBackwardInput { num_groups, eps }` node whose
/// inputs are `[x, gamma, beta, dy]`.
///
/// The new node's output shape is a copy of `x`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when the shapes of `x` and `dy` differ. This
/// mirrors the precondition already enforced by
/// `layer_norm_backward_input`.
pub fn group_norm_backward_input(
    &mut self,
    x: NodeId,
    gamma: NodeId,
    beta: NodeId,
    dy: NodeId,
    num_groups: usize,
    eps: f32,
) -> NodeId {
    // The output is given `x`'s shape, so catch a mismatched upstream
    // gradient here rather than inside whatever consumes the node.
    debug_assert_eq!(
        self.shape(x),
        self.shape(dy),
        "group_norm_backward_input: x and dy must match"
    );
    let x_shape = self.shape(x).clone();
    self.push(
        Op::GroupNormBackwardInput { num_groups, eps },
        vec![x, gamma, beta, dy],
        x_shape,
        None,
    )
}
/// Pushes an `Op::GroupNormBackwardGamma { num_groups, eps }` node whose
/// inputs are `[x, dy]`.
///
/// The caller supplies the output shape through `gamma_shape`; it is
/// passed to `push` as-is. The `NodeId` returned by `push` is handed back
/// to the caller.
pub fn group_norm_backward_gamma(
    &mut self,
    x: NodeId,
    dy: NodeId,
    gamma_shape: Shape,
    num_groups: usize,
    eps: f32,
) -> NodeId {
    let op = Op::GroupNormBackwardGamma { num_groups, eps };
    let inputs = vec![x, dy];
    self.push(op, inputs, gamma_shape, None)
}
/// Pushes an `Op::GroupNormBackwardBeta { num_groups, eps }` node whose
/// inputs are `[x, dy]`.
///
/// The caller supplies the output shape through `beta_shape`; it is
/// passed to `push` as-is. The `NodeId` returned by `push` is handed back
/// to the caller.
pub fn group_norm_backward_beta(
    &mut self,
    x: NodeId,
    dy: NodeId,
    beta_shape: Shape,
    num_groups: usize,
    eps: f32,
) -> NodeId {
    let op = Op::GroupNormBackwardBeta { num_groups, eps };
    let inputs = vec![x, dy];
    self.push(op, inputs, beta_shape, None)
}
/// Pushes an `Op::LayerNormBackwardGamma { axis, eps }` node whose inputs
/// are `[x, dy]`.
///
/// The caller supplies the output shape through `gamma_shape`; it is
/// passed to `push` as-is. The `NodeId` returned by `push` is handed back
/// to the caller.
///
/// # Panics
///
/// In debug builds, panics when the shapes of `x` and `dy` differ.
pub fn layer_norm_backward_gamma(
    &mut self,
    x: NodeId,
    dy: NodeId,
    gamma_shape: Shape,
    axis: i32,
    eps: f32,
) -> NodeId {
    debug_assert_eq!(
        self.shape(x),
        self.shape(dy),
        "layer_norm_backward_gamma: x and dy must match"
    );
    let op = Op::LayerNormBackwardGamma { axis, eps };
    let inputs = vec![x, dy];
    self.push(op, inputs, gamma_shape, None)
}
/// Pushes an `Op::MaxPool2dBackward` node whose inputs are `[x, dy]`.
///
/// `kernel_size`, `stride` and `padding` are moved into the op. The new
/// node's output shape is a copy of `x`'s shape. The `NodeId` returned by
/// `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics unless `kernel_size`, `stride` and `padding`
/// each contain exactly two entries.
pub fn maxpool2d_backward(
    &mut self,
    x: NodeId,
    dy: NodeId,
    kernel_size: Vec<usize>,
    stride: Vec<usize>,
    padding: Vec<usize>,
) -> NodeId {
    let x_shape = self.shape(x).clone();
    // Every check carries a message so a failure names the function and
    // the offending argument, not just the two compared lengths.
    debug_assert_eq!(kernel_size.len(), 2, "maxpool2d_backward: 2-D only");
    debug_assert_eq!(
        stride.len(),
        2,
        "maxpool2d_backward: stride must have 2 entries"
    );
    debug_assert_eq!(
        padding.len(),
        2,
        "maxpool2d_backward: padding must have 2 entries"
    );
    self.push(
        Op::MaxPool2dBackward {
            kernel_size,
            stride,
            padding,
        },
        vec![x, dy],
        x_shape,
        None,
    )
}
/// Pushes an `Op::Conv2dBackwardInput` node whose inputs are `[dy, w]`.
///
/// `kernel_size`, `stride`, `padding`, `dilation` and `groups` are stored
/// in the op. The caller supplies the output shape through `x_shape`. The
/// `NodeId` returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics unless `kernel_size`, `stride`, `padding` and
/// `dilation` each contain exactly two entries.
pub fn conv2d_backward_input(
    &mut self,
    dy: NodeId,
    w: NodeId,
    x_shape: Shape,
    kernel_size: Vec<usize>,
    stride: Vec<usize>,
    padding: Vec<usize>,
    dilation: Vec<usize>,
    groups: usize,
) -> NodeId {
    debug_assert_eq!(kernel_size.len(), 2);
    debug_assert_eq!(stride.len(), 2);
    debug_assert_eq!(padding.len(), 2);
    debug_assert_eq!(dilation.len(), 2);

    let op = Op::Conv2dBackwardInput {
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
    };
    let inputs = vec![dy, w];
    self.push(op, inputs, x_shape, None)
}
/// Pushes an `Op::Conv2dBackwardWeight` node whose inputs are `[x, dy]`.
///
/// `kernel_size`, `stride`, `padding`, `dilation` and `groups` are stored
/// in the op. The caller supplies the output shape through `w_shape`. The
/// `NodeId` returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics unless `kernel_size`, `stride`, `padding` and
/// `dilation` each contain exactly two entries.
pub fn conv2d_backward_weight(
    &mut self,
    x: NodeId,
    dy: NodeId,
    w_shape: Shape,
    kernel_size: Vec<usize>,
    stride: Vec<usize>,
    padding: Vec<usize>,
    dilation: Vec<usize>,
    groups: usize,
) -> NodeId {
    debug_assert_eq!(kernel_size.len(), 2);
    debug_assert_eq!(stride.len(), 2);
    debug_assert_eq!(padding.len(), 2);
    debug_assert_eq!(dilation.len(), 2);

    let op = Op::Conv2dBackwardWeight {
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
    };
    let inputs = vec![x, dy];
    self.push(op, inputs, w_shape, None)
}
/// Pushes an `Op::SoftmaxCrossEntropyWithLogits` node whose inputs are
/// `[logits, labels]`.
///
/// The output shape is one-dimensional: its single extent is dimension 0
/// of `logits`, and its dtype is the dtype of `logits`. The shape of
/// `labels` is not inspected here. The `NodeId` returned by `push` is
/// handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when `logits` is not rank 2.
pub fn softmax_cross_entropy_with_logits(&mut self, logits: NodeId, labels: NodeId) -> NodeId {
    // Keep the shared borrow of the logits shape inside this block so it
    // has ended before `push` needs `&mut self`.
    let loss_shape = {
        let logits_shape = self.shape(logits);
        debug_assert_eq!(
            logits_shape.rank(),
            2,
            "sce_with_logits: logits must be 2-D [N, C]"
        );
        Shape::from_dims(&[logits_shape.dim(0)], logits_shape.dtype())
    };
    let inputs = vec![logits, labels];
    self.push(Op::SoftmaxCrossEntropyWithLogits, inputs, loss_shape, None)
}
/// Pushes an `Op::SoftmaxCrossEntropyBackward` node whose inputs are
/// `[logits, labels, d_loss]`.
///
/// The new node's output shape is a copy of `logits`' shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when `logits` is not rank 2.
pub fn softmax_cross_entropy_backward(
    &mut self,
    logits: NodeId,
    labels: NodeId,
    d_loss: NodeId,
) -> NodeId {
    let grad_shape = self.shape(logits).clone();
    debug_assert_eq!(
        grad_shape.rank(),
        2,
        "sce_backward: logits must be 2-D [N, C]"
    );
    let inputs = vec![logits, labels, d_loss];
    self.push(Op::SoftmaxCrossEntropyBackward, inputs, grad_shape, None)
}
/// Pushes an `Op::ComplexNormSq` node whose only input is `z`.
///
/// The output shape reuses `z`'s dims with dtype `DType::F32`. The
/// `NodeId` returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when `z`'s dtype is not `DType::C64`.
pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId {
    let input_shape = self.shape(z).clone();
    debug_assert_eq!(
        input_shape.dtype(),
        DType::C64,
        "complex_norm_sq: input must be C64, got {:?}",
        input_shape.dtype()
    );
    let real_shape = Shape::from_dims(input_shape.dims(), DType::F32);
    let inputs = vec![z];
    self.push(Op::ComplexNormSq, inputs, real_shape, None)
}
/// Pushes an `Op::AttentionBackward` node for the operand selected by `wrt`.
///
/// The output shape is a copy of the shape of `q`, `k` or `v`, chosen by
/// `wrt` (`Query`, `Key` or `Value` respectively). The inputs are
/// `[q, k, v, dy]`, with `mask` appended only when `mask_kind` is
/// `MaskKind::Custom` or `MaskKind::Bias`. For any other `mask_kind`,
/// `mask` is not read. The `NodeId` returned by `push` is handed back to
/// the caller.
///
/// # Panics
///
/// Panics when `mask_kind` is `Custom` or `Bias` and `mask` is `None`.
pub fn attention_backward(
    &mut self,
    wrt: AttentionBwdWrt,
    q: NodeId,
    k: NodeId,
    v: NodeId,
    dy: NodeId,
    num_heads: usize,
    head_dim: usize,
    mask_kind: MaskKind,
    mask: Option<NodeId>,
) -> NodeId {
    // Pick the operand being differentiated, then copy its shape once.
    let target = match wrt {
        AttentionBwdWrt::Query => q,
        AttentionBwdWrt::Key => k,
        AttentionBwdWrt::Value => v,
    };
    let grad_shape = self.shape(target).clone();

    let mut inputs = vec![q, k, v, dy];
    let needs_mask = matches!(mask_kind, MaskKind::Custom | MaskKind::Bias);
    if needs_mask {
        inputs.push(mask.expect("attention_backward: mask required for Custom/Bias"));
    }

    let op = Op::AttentionBackward {
        num_heads,
        head_dim,
        mask_kind,
        wrt,
    };
    self.push(op, inputs, grad_shape, None)
}
pub fn attention_backward_all(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
dy: NodeId,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
mask: Option<NodeId>,
) -> (NodeId, NodeId, NodeId) {
let dq = self.attention_backward(
AttentionBwdWrt::Query,
q,
k,
v,
dy,
num_heads,
head_dim,
mask_kind,
mask,
);
let dk = self.attention_backward(
AttentionBwdWrt::Key,
q,
k,
v,
dy,
num_heads,
head_dim,
mask_kind,
mask,
);
let dv = self.attention_backward(
AttentionBwdWrt::Value,
q,
k,
v,
dy,
num_heads,
head_dim,
mask_kind,
mask,
);
(dq, dk, dv)
}
/// Pushes an `Op::ComplexNormSqBackward` node whose inputs are `[z, g]`.
///
/// The new node's output shape is a copy of `z`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when `z`'s dtype is not `DType::C64`, when
/// `g`'s dtype is not `DType::F32`, or when the dims of `z` and `g`
/// differ.
pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId {
    let z_shape = self.shape(z).clone();
    // The dtype checks report the offending dtype, matching the messages
    // used by `complex_norm_sq` and `conjugate`.
    debug_assert_eq!(
        z_shape.dtype(),
        DType::C64,
        "complex_norm_sq_backward: z must be C64, got {:?}",
        z_shape.dtype()
    );
    debug_assert_eq!(
        self.shape(g).dtype(),
        DType::F32,
        "complex_norm_sq_backward: g must be F32, got {:?}",
        self.shape(g).dtype()
    );
    debug_assert_eq!(
        z_shape.dims(),
        self.shape(g).dims(),
        "complex_norm_sq_backward: z and g must share logical shape"
    );
    self.push(Op::ComplexNormSqBackward, vec![z, g], z_shape, None)
}
/// Pushes an `Op::Conjugate` node whose only input is `z`.
///
/// The new node's output shape is a copy of `z`'s shape. The `NodeId`
/// returned by `push` is handed back to the caller.
///
/// # Panics
///
/// In debug builds, panics when `z`'s dtype is not `DType::C64`.
pub fn conjugate(&mut self, z: NodeId) -> NodeId {
    let input_shape = self.shape(z).clone();
    debug_assert_eq!(
        input_shape.dtype(),
        DType::C64,
        "conjugate: input must be C64, got {:?}",
        input_shape.dtype()
    );
    let inputs = vec![z];
    self.push(Op::Conjugate, inputs, input_shape, None)
}
}