1use crate::op::{AttentionBwdWrt, MaskKind};
30use crate::{DType, Graph, NodeId, Op, Shape};
31
32impl Graph {
33 pub fn relu_backward(&mut self, x: NodeId, dy: NodeId) -> NodeId {
35 let x_shape = self.shape(x).clone();
36 debug_assert_eq!(
37 self.shape(x),
38 self.shape(dy),
39 "relu_backward: x and dy must have identical shapes"
40 );
41 self.push(Op::ReluBackward, vec![x, dy], x_shape, None)
42 }
43
44 pub fn activation_backward(
48 &mut self,
49 kind: crate::op::Activation,
50 x: NodeId,
51 dy: NodeId,
52 ) -> NodeId {
53 let x_shape = self.shape(x).clone();
54 debug_assert_eq!(
55 self.shape(x),
56 self.shape(dy),
57 "activation_backward: x and dy must have identical shapes"
58 );
59 self.push(Op::ActivationBackward { kind }, vec![x, dy], x_shape, None)
60 }
61
62 pub fn layer_norm_backward_input(
65 &mut self,
66 x: NodeId,
67 gamma: NodeId,
68 dy: NodeId,
69 axis: i32,
70 eps: f32,
71 ) -> NodeId {
72 let x_shape = self.shape(x).clone();
73 debug_assert_eq!(
74 self.shape(x),
75 self.shape(dy),
76 "layer_norm_backward_input: x and dy must match"
77 );
78 self.push(
79 Op::LayerNormBackwardInput { axis, eps },
80 vec![x, gamma, dy],
81 x_shape,
82 None,
83 )
84 }
85
86 pub fn rms_norm_backward_input(
88 &mut self,
89 x: NodeId,
90 gamma: NodeId,
91 beta: NodeId,
92 dy: NodeId,
93 axis: i32,
94 eps: f32,
95 ) -> NodeId {
96 let x_shape = self.shape(x).clone();
97 self.push(
98 Op::RmsNormBackwardInput { axis, eps },
99 vec![x, gamma, beta, dy],
100 x_shape,
101 None,
102 )
103 }
104
105 pub fn rms_norm_backward_gamma(
106 &mut self,
107 x: NodeId,
108 gamma: NodeId,
109 beta: NodeId,
110 dy: NodeId,
111 axis: i32,
112 eps: f32,
113 ) -> NodeId {
114 self.push(
115 Op::RmsNormBackwardGamma { axis, eps },
116 vec![x, gamma, beta, dy],
117 self.shape(gamma).clone(),
118 None,
119 )
120 }
121
122 pub fn rms_norm_backward_beta(
123 &mut self,
124 x: NodeId,
125 gamma: NodeId,
126 beta: NodeId,
127 dy: NodeId,
128 axis: i32,
129 eps: f32,
130 ) -> NodeId {
131 self.push(
132 Op::RmsNormBackwardBeta { axis, eps },
133 vec![x, gamma, beta, dy],
134 self.shape(beta).clone(),
135 None,
136 )
137 }
138
139 pub fn rope_backward(
140 &mut self,
141 dy: NodeId,
142 cos: NodeId,
143 sin: NodeId,
144 head_dim: usize,
145 n_rot: usize,
146 ) -> NodeId {
147 let out_shape = self.shape(dy).clone();
148 self.push(
149 Op::RopeBackward { head_dim, n_rot },
150 vec![dy, cos, sin],
151 out_shape,
152 None,
153 )
154 }
155
156 pub fn cumsum_backward(
157 &mut self,
158 dy: NodeId,
159 out_shape: Shape,
160 axis: i32,
161 exclusive: bool,
162 ) -> NodeId {
163 self.push(
164 Op::CumsumBackward { axis, exclusive },
165 vec![dy],
166 out_shape,
167 None,
168 )
169 }
170
171 pub fn gather_backward(
172 &mut self,
173 dy: NodeId,
174 indices: NodeId,
175 table_shape: Shape,
176 axis: i32,
177 ) -> NodeId {
178 self.push(
179 Op::GatherBackward { axis },
180 vec![dy, indices],
181 table_shape,
182 None,
183 )
184 }
185
186 pub fn group_norm_backward_input(
188 &mut self,
189 x: NodeId,
190 gamma: NodeId,
191 beta: NodeId,
192 dy: NodeId,
193 num_groups: usize,
194 eps: f32,
195 ) -> NodeId {
196 let x_shape = self.shape(x).clone();
197 self.push(
198 Op::GroupNormBackwardInput { num_groups, eps },
199 vec![x, gamma, beta, dy],
200 x_shape,
201 None,
202 )
203 }
204
205 pub fn group_norm_backward_gamma(
207 &mut self,
208 x: NodeId,
209 dy: NodeId,
210 gamma_shape: Shape,
211 num_groups: usize,
212 eps: f32,
213 ) -> NodeId {
214 self.push(
215 Op::GroupNormBackwardGamma { num_groups, eps },
216 vec![x, dy],
217 gamma_shape,
218 None,
219 )
220 }
221
222 pub fn group_norm_backward_beta(
224 &mut self,
225 x: NodeId,
226 dy: NodeId,
227 beta_shape: Shape,
228 num_groups: usize,
229 eps: f32,
230 ) -> NodeId {
231 self.push(
232 Op::GroupNormBackwardBeta { num_groups, eps },
233 vec![x, dy],
234 beta_shape,
235 None,
236 )
237 }
238
239 pub fn batch_norm_inference_backward_input(
241 &mut self,
242 x: NodeId,
243 gamma: NodeId,
244 mean: NodeId,
245 var: NodeId,
246 dy: NodeId,
247 eps: f32,
248 ) -> NodeId {
249 let x_shape = self.shape(x).clone();
250 debug_assert_eq!(self.shape(x), self.shape(dy));
251 self.push(
252 Op::BatchNormInferenceBackwardInput { eps },
253 vec![x, gamma, mean, var, dy],
254 x_shape,
255 None,
256 )
257 }
258
259 pub fn batch_norm_inference_backward_gamma(
261 &mut self,
262 x: NodeId,
263 mean: NodeId,
264 var: NodeId,
265 dy: NodeId,
266 gamma_shape: Shape,
267 eps: f32,
268 ) -> NodeId {
269 self.push(
270 Op::BatchNormInferenceBackwardGamma { eps },
271 vec![x, mean, var, dy],
272 gamma_shape,
273 None,
274 )
275 }
276
277 pub fn batch_norm_inference_backward_beta(&mut self, dy: NodeId, beta_shape: Shape) -> NodeId {
279 self.push(
280 Op::BatchNormInferenceBackwardBeta,
281 vec![dy],
282 beta_shape,
283 None,
284 )
285 }
286
287 pub fn layer_norm_backward_gamma(
291 &mut self,
292 x: NodeId,
293 dy: NodeId,
294 gamma_shape: Shape,
295 axis: i32,
296 eps: f32,
297 ) -> NodeId {
298 debug_assert_eq!(
299 self.shape(x),
300 self.shape(dy),
301 "layer_norm_backward_gamma: x and dy must match"
302 );
303 self.push(
304 Op::LayerNormBackwardGamma { axis, eps },
305 vec![x, dy],
306 gamma_shape,
307 None,
308 )
309 }
310
311 pub fn maxpool2d_backward(
315 &mut self,
316 x: NodeId,
317 dy: NodeId,
318 kernel_size: Vec<usize>,
319 stride: Vec<usize>,
320 padding: Vec<usize>,
321 ) -> NodeId {
322 let x_shape = self.shape(x).clone();
323 debug_assert_eq!(kernel_size.len(), 2, "maxpool2d_backward: 2-D only");
324 debug_assert_eq!(stride.len(), 2);
325 debug_assert_eq!(padding.len(), 2);
326 self.push(
327 Op::MaxPool2dBackward {
328 kernel_size,
329 stride,
330 padding,
331 },
332 vec![x, dy],
333 x_shape,
334 None,
335 )
336 }
337
338 pub fn conv2d_backward_input(
344 &mut self,
345 dy: NodeId,
346 w: NodeId,
347 x_shape: Shape,
348 kernel_size: Vec<usize>,
349 stride: Vec<usize>,
350 padding: Vec<usize>,
351 dilation: Vec<usize>,
352 groups: usize,
353 ) -> NodeId {
354 debug_assert_eq!(kernel_size.len(), 2);
355 debug_assert_eq!(stride.len(), 2);
356 debug_assert_eq!(padding.len(), 2);
357 debug_assert_eq!(dilation.len(), 2);
358 self.push(
359 Op::Conv2dBackwardInput {
360 kernel_size,
361 stride,
362 padding,
363 dilation,
364 groups,
365 },
366 vec![dy, w],
367 x_shape,
368 None,
369 )
370 }
371
372 pub fn conv2d_backward_weight(
375 &mut self,
376 x: NodeId,
377 dy: NodeId,
378 w_shape: Shape,
379 kernel_size: Vec<usize>,
380 stride: Vec<usize>,
381 padding: Vec<usize>,
382 dilation: Vec<usize>,
383 groups: usize,
384 ) -> NodeId {
385 debug_assert_eq!(kernel_size.len(), 2);
386 debug_assert_eq!(stride.len(), 2);
387 debug_assert_eq!(padding.len(), 2);
388 debug_assert_eq!(dilation.len(), 2);
389 self.push(
390 Op::Conv2dBackwardWeight {
391 kernel_size,
392 stride,
393 padding,
394 dilation,
395 groups,
396 },
397 vec![x, dy],
398 w_shape,
399 None,
400 )
401 }
402
403 pub fn softmax_cross_entropy_with_logits(&mut self, logits: NodeId, labels: NodeId) -> NodeId {
406 let logits_shape = self.shape(logits);
407 debug_assert_eq!(
408 logits_shape.rank(),
409 2,
410 "sce_with_logits: logits must be 2-D [N, C]"
411 );
412 let n = logits_shape.dim(0);
413 let dtype = logits_shape.dtype();
414 let out_shape = Shape::from_dims(&[n], dtype);
415 self.push(
416 Op::SoftmaxCrossEntropyWithLogits,
417 vec![logits, labels],
418 out_shape,
419 None,
420 )
421 }
422
423 pub fn softmax_cross_entropy_backward(
426 &mut self,
427 logits: NodeId,
428 labels: NodeId,
429 d_loss: NodeId,
430 ) -> NodeId {
431 let logits_shape = self.shape(logits).clone();
432 debug_assert_eq!(
433 logits_shape.rank(),
434 2,
435 "sce_backward: logits must be 2-D [N, C]"
436 );
437 self.push(
438 Op::SoftmaxCrossEntropyBackward,
439 vec![logits, labels, d_loss],
440 logits_shape,
441 None,
442 )
443 }
444
445 pub fn complex_norm_sq(&mut self, z: NodeId) -> NodeId {
450 let z_shape = self.shape(z).clone();
451 debug_assert_eq!(
452 z_shape.dtype(),
453 DType::C64,
454 "complex_norm_sq: input must be C64, got {:?}",
455 z_shape.dtype()
456 );
457 let out_shape = Shape::from_dims(z_shape.dims(), DType::F32);
458 self.push(Op::ComplexNormSq, vec![z], out_shape, None)
459 }
460
461 pub fn attention_backward(
465 &mut self,
466 wrt: AttentionBwdWrt,
467 q: NodeId,
468 k: NodeId,
469 v: NodeId,
470 dy: NodeId,
471 num_heads: usize,
472 head_dim: usize,
473 mask_kind: MaskKind,
474 mask: Option<NodeId>,
475 ) -> NodeId {
476 let out_shape = match wrt {
477 AttentionBwdWrt::Query => self.shape(q).clone(),
478 AttentionBwdWrt::Key => self.shape(k).clone(),
479 AttentionBwdWrt::Value => self.shape(v).clone(),
480 };
481 let mut inputs = vec![q, k, v, dy];
482 if matches!(mask_kind, MaskKind::Custom | MaskKind::Bias) {
483 inputs.push(mask.expect("attention_backward: mask required for Custom/Bias"));
484 }
485 self.push(
486 Op::AttentionBackward {
487 num_heads,
488 head_dim,
489 mask_kind,
490 wrt,
491 },
492 inputs,
493 out_shape,
494 None,
495 )
496 }
497
498 pub fn attention_backward_all(
500 &mut self,
501 q: NodeId,
502 k: NodeId,
503 v: NodeId,
504 dy: NodeId,
505 num_heads: usize,
506 head_dim: usize,
507 mask_kind: MaskKind,
508 mask: Option<NodeId>,
509 ) -> (NodeId, NodeId, NodeId) {
510 let dq = self.attention_backward(
511 AttentionBwdWrt::Query,
512 q,
513 k,
514 v,
515 dy,
516 num_heads,
517 head_dim,
518 mask_kind,
519 mask,
520 );
521 let dk = self.attention_backward(
522 AttentionBwdWrt::Key,
523 q,
524 k,
525 v,
526 dy,
527 num_heads,
528 head_dim,
529 mask_kind,
530 mask,
531 );
532 let dv = self.attention_backward(
533 AttentionBwdWrt::Value,
534 q,
535 k,
536 v,
537 dy,
538 num_heads,
539 head_dim,
540 mask_kind,
541 mask,
542 );
543 (dq, dk, dv)
544 }
545
546 pub fn complex_norm_sq_backward(&mut self, z: NodeId, g: NodeId) -> NodeId {
550 let z_shape = self.shape(z).clone();
551 debug_assert_eq!(z_shape.dtype(), DType::C64);
552 debug_assert_eq!(self.shape(g).dtype(), DType::F32);
553 debug_assert_eq!(
554 z_shape.dims(),
555 self.shape(g).dims(),
556 "complex_norm_sq_backward: z and g must share logical shape"
557 );
558 self.push(Op::ComplexNormSqBackward, vec![z, g], z_shape, None)
559 }
560
561 pub fn conjugate(&mut self, z: NodeId) -> NodeId {
565 let z_shape = self.shape(z).clone();
566 debug_assert_eq!(
567 z_shape.dtype(),
568 DType::C64,
569 "conjugate: input must be C64, got {:?}",
570 z_shape.dtype()
571 );
572 self.push(Op::Conjugate, vec![z], z_shape, None)
573 }
574}