1use crate::op::{AttentionBwdWrt, MaskKind};
30use crate::{DType, Graph, NodeId, Op, Shape};
31
32impl Graph {
33 pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId {
35 let x_shape = self.shape(x).clone();
36 debug_assert_eq!(
37 self.shape(x),
38 self.shape(dy),
39 "relu_backward: x and dy must have identical shapes"
40 );
41 self.push(Op::ReluBackward, vec![x, dy], x_shape, None)
42 }
43
44 pub fn activation_backward(
48 &mut self,
49 kind: crate::op::Activation,
50 x: NodeId,
51 dy: NodeId,
52 ) -> NodeId {
53 let x_shape = self.shape(x).clone();
54 debug_assert_eq!(
55 self.shape(x),
56 self.shape(dy),
57 "activation_backward: x and dy must have identical shapes"
58 );
59 self.push(Op::ActivationBackward { kind }, vec![x, dy], x_shape, None)
60 }
61
62 pub fn layer_norm_backward_input(
65 &mut self,
66 x: NodeId,
67 gamma: NodeId,
68 dy: NodeId,
69 axis: i32,
70 eps: f32,
71 ) -> NodeId {
72 let x_shape = self.shape(x).clone();
73 debug_assert_eq!(
74 self.shape(x),
75 self.shape(dy),
76 "layer_norm_backward_input: x and dy must match"
77 );
78 self.push(
79 Op::LayerNormBackwardInput { axis, eps },
80 vec![x, gamma, dy],
81 x_shape,
82 None,
83 )
84 }
85
86 pub fn rms_norm_backward_input(
88 &mut self,
89 x: NodeId,
90 gamma: NodeId,
91 beta: NodeId,
92 dy: NodeId,
93 axis: i32,
94 eps: f32,
95 ) -> NodeId {
96 let x_shape = self.shape(x).clone();
97 self.push(
98 Op::RmsNormBackwardInput { axis, eps },
99 vec![x, gamma, beta, dy],
100 x_shape,
101 None,
102 )
103 }
104
105 pub fn rms_norm_backward_gamma(
106 &mut self,
107 x: NodeId,
108 gamma: NodeId,
109 beta: NodeId,
110 dy: NodeId,
111 axis: i32,
112 eps: f32,
113 ) -> NodeId {
114 self.push(
115 Op::RmsNormBackwardGamma { axis, eps },
116 vec![x, gamma, beta, dy],
117 self.shape(gamma).clone(),
118 None,
119 )
120 }
121
122 pub fn rms_norm_backward_beta(
123 &mut self,
124 x: NodeId,
125 gamma: NodeId,
126 beta: NodeId,
127 dy: NodeId,
128 axis: i32,
129 eps: f32,
130 ) -> NodeId {
131 self.push(
132 Op::RmsNormBackwardBeta { axis, eps },
133 vec![x, gamma, beta, dy],
134 self.shape(beta).clone(),
135 None,
136 )
137 }
138
139 pub fn rope_backward(
140 &mut self,
141 dy: NodeId,
142 cos: NodeId,
143 sin: NodeId,
144 head_dim: usize,
145 n_rot: usize,
146 ) -> NodeId {
147 let out_shape = self.shape(dy).clone();
148 self.push(
149 Op::RopeBackward { head_dim, n_rot },
150 vec![dy, cos, sin],
151 out_shape,
152 None,
153 )
154 }
155
156 pub fn cumsum_backward(
157 &mut self,
158 dy: NodeId,
159 out_shape: Shape,
160 axis: i32,
161 exclusive: bool,
162 ) -> NodeId {
163 self.push(
164 Op::CumsumBackward { axis, exclusive },
165 vec![dy],
166 out_shape,
167 None,
168 )
169 }
170
171 pub fn gather_backward(
172 &mut self,
173 dy: NodeId,
174 indices: NodeId,
175 table_shape: Shape,
176 axis: i32,
177 ) -> NodeId {
178 self.push(
179 Op::GatherBackward { axis },
180 vec![dy, indices],
181 table_shape,
182 None,
183 )
184 }
185
186 pub fn group_norm_backward_input(
188 &mut self,
189 x: NodeId,
190 gamma: NodeId,
191 beta: NodeId,
192 dy: NodeId,
193 num_groups: usize,
194 eps: f32,
195 ) -> NodeId {
196 let x_shape = self.shape(x).clone();
197 self.push(
198 Op::GroupNormBackwardInput { num_groups, eps },
199 vec![x, gamma, beta, dy],
200 x_shape,
201 None,
202 )
203 }
204
205 pub fn group_norm_backward_gamma(
207 &mut self,
208 x: NodeId,
209 dy: NodeId,
210 gamma_shape: Shape,
211 num_groups: usize,
212 eps: f32,
213 ) -> NodeId {
214 self.push(
215 Op::GroupNormBackwardGamma { num_groups, eps },
216 vec![x, dy],
217 gamma_shape,
218 None,
219 )
220 }
221
222 pub fn group_norm_backward_beta(
224 &mut self,
225 x: NodeId,
226 dy: NodeId,
227 beta_shape: Shape,
228 num_groups: usize,
229 eps: f32,
230 ) -> NodeId {
231 self.push(
232 Op::GroupNormBackwardBeta { num_groups, eps },
233 vec![x, dy],
234 beta_shape,
235 None,
236 )
237 }
238
239 pub fn layer_norm_backward_gamma(
243 &mut self,
244 x: NodeId,
245 dy: NodeId,
246 gamma_shape: Shape,
247 axis: i32,
248 eps: f32,
249 ) -> NodeId {
250 debug_assert_eq!(
251 self.shape(x),
252 self.shape(dy),
253 "layer_norm_backward_gamma: x and dy must match"
254 );
255 self.push(
256 Op::LayerNormBackwardGamma { axis, eps },
257 vec![x, dy],
258 gamma_shape,
259 None,
260 )
261 }
262
263 pub fn maxpool2d_backward(
267 &mut self,
268 x: NodeId,
269 dy: NodeId,
270 kernel_size: Vec<usize>,
271 stride: Vec<usize>,
272 padding: Vec<usize>,
273 ) -> NodeId {
274 let x_shape = self.shape(x).clone();
275 debug_assert_eq!(kernel_size.len(), 2, "maxpool2d_backward: 2-D only");
276 debug_assert_eq!(stride.len(), 2);
277 debug_assert_eq!(padding.len(), 2);
278 self.push(
279 Op::MaxPool2dBackward {
280 kernel_size,
281 stride,
282 padding,
283 },
284 vec![x, dy],
285 x_shape,
286 None,
287 )
288 }
289
290 pub fn conv2d_backward_input(
296 &mut self,
297 dy: NodeId,
298 w: NodeId,
299 x_shape: Shape,
300 kernel_size: Vec<usize>,
301 stride: Vec<usize>,
302 padding: Vec<usize>,
303 dilation: Vec<usize>,
304 groups: usize,
305 ) -> NodeId {
306 debug_assert_eq!(kernel_size.len(), 2);
307 debug_assert_eq!(stride.len(), 2);
308 debug_assert_eq!(padding.len(), 2);
309 debug_assert_eq!(dilation.len(), 2);
310 self.push(
311 Op::Conv2dBackwardInput {
312 kernel_size,
313 stride,
314 padding,
315 dilation,
316 groups,
317 },
318 vec![dy, w],
319 x_shape,
320 None,
321 )
322 }
323
324 pub fn conv2d_backward_weight(
327 &mut self,
328 x: NodeId,
329 dy: NodeId,
330 w_shape: Shape,
331 kernel_size: Vec<usize>,
332 stride: Vec<usize>,
333 padding: Vec<usize>,
334 dilation: Vec<usize>,
335 groups: usize,
336 ) -> NodeId {
337 debug_assert_eq!(kernel_size.len(), 2);
338 debug_assert_eq!(stride.len(), 2);
339 debug_assert_eq!(padding.len(), 2);
340 debug_assert_eq!(dilation.len(), 2);
341 self.push(
342 Op::Conv2dBackwardWeight {
343 kernel_size,
344 stride,
345 padding,
346 dilation,
347 groups,
348 },
349 vec![x, dy],
350 w_shape,
351 None,
352 )
353 }
354
355 pub fn softmax_cross_entropy_with_logits(&mut self, logits: NodeId, labels: NodeId) -> NodeId {
358 let logits_shape = self.shape(logits);
359 debug_assert_eq!(
360 logits_shape.rank(),
361 2,
362 "sce_with_logits: logits must be 2-D [N, C]"
363 );
364 let n = logits_shape.dim(0);
365 let dtype = logits_shape.dtype();
366 let out_shape = Shape::from_dims(&[n], dtype);
367 self.push(
368 Op::SoftmaxCrossEntropyWithLogits,
369 vec![logits, labels],
370 out_shape,
371 None,
372 )
373 }
374
375 pub fn softmax_cross_entropy_backward(
378 &mut self,
379 logits: NodeId,
380 labels: NodeId,
381 d_loss: NodeId,
382 ) -> NodeId {
383 let logits_shape = self.shape(logits).clone();
384 debug_assert_eq!(
385 logits_shape.rank(),
386 2,
387 "sce_backward: logits must be 2-D [N, C]"
388 );
389 self.push(
390 Op::SoftmaxCrossEntropyBackward,
391 vec![logits, labels, d_loss],
392 logits_shape,
393 None,
394 )
395 }
396
397 pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId {
402 let z_shape = self.shape(z).clone();
403 debug_assert_eq!(
404 z_shape.dtype(),
405 DType::C64,
406 "complex_norm_sq: input must be C64, got {:?}",
407 z_shape.dtype()
408 );
409 let out_shape = Shape::from_dims(z_shape.dims(), DType::F32);
410 self.push(Op::ComplexNormSq, vec![z], out_shape, None)
411 }
412
413 pub fn attention_backward(
417 &mut self,
418 wrt: AttentionBwdWrt,
419 q: NodeId,
420 k: NodeId,
421 v: NodeId,
422 dy: NodeId,
423 num_heads: usize,
424 head_dim: usize,
425 mask_kind: MaskKind,
426 mask: Option<NodeId>,
427 ) -> NodeId {
428 let out_shape = match wrt {
429 AttentionBwdWrt::Query => self.shape(q).clone(),
430 AttentionBwdWrt::Key => self.shape(k).clone(),
431 AttentionBwdWrt::Value => self.shape(v).clone(),
432 };
433 let mut inputs = vec![q, k, v, dy];
434 if matches!(mask_kind, MaskKind::Custom | MaskKind::Bias) {
435 inputs.push(mask.expect("attention_backward: mask required for Custom/Bias"));
436 }
437 self.push(
438 Op::AttentionBackward {
439 num_heads,
440 head_dim,
441 mask_kind,
442 wrt,
443 },
444 inputs,
445 out_shape,
446 None,
447 )
448 }
449
450 pub fn attention_backward_all(
452 &mut self,
453 q: NodeId,
454 k: NodeId,
455 v: NodeId,
456 dy: NodeId,
457 num_heads: usize,
458 head_dim: usize,
459 mask_kind: MaskKind,
460 mask: Option<NodeId>,
461 ) -> (NodeId, NodeId, NodeId) {
462 let dq = self.attention_backward(
463 AttentionBwdWrt::Query,
464 q,
465 k,
466 v,
467 dy,
468 num_heads,
469 head_dim,
470 mask_kind,
471 mask,
472 );
473 let dk = self.attention_backward(
474 AttentionBwdWrt::Key,
475 q,
476 k,
477 v,
478 dy,
479 num_heads,
480 head_dim,
481 mask_kind,
482 mask,
483 );
484 let dv = self.attention_backward(
485 AttentionBwdWrt::Value,
486 q,
487 k,
488 v,
489 dy,
490 num_heads,
491 head_dim,
492 mask_kind,
493 mask,
494 );
495 (dq, dk, dv)
496 }
497
498 pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId {
502 let z_shape = self.shape(z).clone();
503 debug_assert_eq!(z_shape.dtype(), DType::C64);
504 debug_assert_eq!(self.shape(g).dtype(), DType::F32);
505 debug_assert_eq!(
506 z_shape.dims(),
507 self.shape(g).dims(),
508 "complex_norm_sq_backward: z and g must share logical shape"
509 );
510 self.push(Op::ComplexNormSqBackward, vec![z, g], z_shape, None)
511 }
512
513 pub fn conjugate(&mut self, z: NodeId) -> NodeId {
517 let z_shape = self.shape(z).clone();
518 debug_assert_eq!(
519 z_shape.dtype(),
520 DType::C64,
521 "conjugate: input must be C64, got {:?}",
522 z_shape.dtype()
523 );
524 self.push(Op::Conjugate, vec![z], z_shape, None)
525 }
526}