use crate::Autodiff;
use crate::checkpoint::base::Checkpointer;
use crate::checkpoint::strategy::CheckpointStrategy;
use crate::grads::Gradients;
use crate::graph::NodeId;
use crate::ops::{Backward, Ops, unary};
use crate::tensor::AutodiffTensor;

use burn_backend::Backend;
use burn_backend::ops::*;
use burn_backend::tensor::{FloatTensor, IntTensor};

use super::OpsKind;

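// Every module op below follows the same shape: a zero-sized struct implements
// `Backward<B, N>` (N = number of tracked parents), the forward pass registers
// the op on the graph with `prepare`, and `stateful()` splits into tracked and
// untracked paths. A minimal sketch of that skeleton (`MyOp`/`my_op` are
// placeholders, not real ops):
//
//     match MyOp.prepare::<C>([x.node.clone()]).compute_bound().stateful() {
//         OpsKind::Tracked(mut prep) => {
//             let state = prep.checkpoint(&x);           // remember x for backward
//             prep.finish(state, B::my_op(x.primitive))  // forward + record state
//         }
//         OpsKind::UnTracked(prep) => prep.finish(B::my_op(x.primitive)),
//     }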
impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C> {
    fn embedding(weights: AutodiffTensor<B>, indices: IntTensor<B>) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Embedding;

        impl<B: Backend> Backward<B, 1> for Embedding {
            type State = (B::FloatTensorPrimitive, IntTensor<B>);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                let (weights, indices) = ops.state;

                // Route the output gradient back into the weight matrix rows
                // that were selected by `indices` during the forward pass.
                unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                    B::embedding_backward(weights, grad, indices)
                });
            }
        }

        match Embedding
            .prepare::<C>([weights.node])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(prep) => prep.finish(
                (weights.primitive.clone(), indices.clone()),
                B::embedding(weights.primitive, indices),
            ),
            OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)),
        }
    }
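
    // A hedged sketch of what `B::embedding_backward` is expected to compute
    // (hypothetical flat row-major buffers, for illustration only): each row of
    // the output gradient is scatter-added into the weight-gradient row that the
    // matching index selected in the forward pass.
    //
    //     for (i, &idx) in indices.iter().enumerate() {
    //         for j in 0..embedding_dim {
    //             weights_grad[idx * embedding_dim + j] += grad[i * embedding_dim + j];
    //         }
    //     }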

    fn embedding_backward(
        _weights: AutodiffTensor<B>,
        _output: AutodiffTensor<B>,
        _indices: IntTensor<B>,
    ) -> AutodiffTensor<B> {
        panic!("Can't differentiate embedding backward.");
    }

    fn conv1d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvOptions<1>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Conv1DWithBias;
        #[derive(Debug)]
        struct Conv1DNoBias;

        impl<B: Backend> Backward<B, 3> for Conv1DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                // Materialize the checkpointed forward inputs.
                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                // Each parent only receives a gradient if it is tracked.
                if let Some(node) = node_x {
                    let grad = B::conv1d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv1d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for Conv1DNoBias {
            type State = (NodeId, NodeId, ConvOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv1d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv1d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match Conv1DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match Conv1DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv1d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => {
                    prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
                }
            },
        }
    }
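
    // The conv ops checkpoint by `NodeId`: the forward pass stores ids with
    // `prep.checkpoint(&tensor)`, and the backward pass rebuilds the tensors with
    // `checkpointer.retrieve_node_output`. Whether the value was kept around or
    // must be recomputed is the `CheckpointStrategy` `C`'s decision, not this
    // code's:
    //
    //     let id = prep.checkpoint(&x);  // forward: remember x by id
    //     // ... later, in backward:
    //     let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(id);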

    fn conv_transpose1d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct ConvTranspose1DWithBias;
        #[derive(Debug)]
        struct ConvTranspose1DNoBias;

        impl<B: Backend> Backward<B, 3> for ConvTranspose1DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose1d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose1d_weight_backward(
                        x.clone(),
                        weight,
                        grad.clone(),
                        options,
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for ConvTranspose1DNoBias {
            type State = (NodeId, NodeId, ConvTransposeOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose1d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match ConvTranspose1DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv_transpose1d(
                            x.primitive,
                            weight.primitive,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match ConvTranspose1DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv_transpose1d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
                    x.primitive,
                    weight.primitive,
                    None,
                    options,
                )),
            },
        }
    }
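
    // Unlike the plain convolutions, `conv_transpose*_x_backward` is called
    // without `x`: the input gradient of a transposed convolution is computed
    // from the weight, the output gradient, and the options alone, which is why
    // `x` is retrieved here solely for the weight and bias gradients.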

    fn conv2d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvOptions<2>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Conv2DWithBias;
        #[derive(Debug)]
        struct Conv2DNoBias;

        impl<B: Backend> Backward<B, 3> for Conv2DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv2d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad =
                        B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv2d_bias_backward(x, weight, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for Conv2DNoBias {
            type State = (NodeId, NodeId, ConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv2d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv2d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match Conv2DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match Conv2DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv2d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => {
                    prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
                }
            },
        }
    }

    fn deform_conv2d(
        x: AutodiffTensor<B>,
        offset: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        mask: Option<AutodiffTensor<B>>,
        bias: Option<AutodiffTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct DeformConv2DWithMaskWithBias;
        #[derive(Debug)]
        struct DeformConv2DWithMaskNoBias;
        #[derive(Debug)]
        struct DeformConv2DNoMaskWithBias;
        #[derive(Debug)]
        struct DeformConv2DNoMaskNoBias;

        impl<B: Backend> Backward<B, 5> for DeformConv2DWithMaskWithBias {
            type State = (NodeId, NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 5>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, mask_state, bias_state, options) =
                    ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);
                let mask = Some(checkpointer.retrieve_node_output(mask_state));
                let bias = Some(checkpointer.retrieve_node_output(bias_state));

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
                if let Some(node) = node_mask {
                    grads.register::<B>(node.id, backward.mask_grad.unwrap())
                }
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, backward.bias_grad.unwrap())
                }
            }
        }

        impl<B: Backend> Backward<B, 4> for DeformConv2DWithMaskNoBias {
            type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 4>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight, node_mask] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, mask_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);
                let mask = Some(checkpointer.retrieve_node_output(mask_state));

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
                if let Some(node) = node_mask {
                    grads.register::<B>(node.id, backward.mask_grad.unwrap())
                }
            }
        }

        impl<B: Backend> Backward<B, 4> for DeformConv2DNoMaskWithBias {
            type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 4>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);
                let bias = Some(checkpointer.retrieve_node_output(bias_state));

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, backward.bias_grad.unwrap())
                }
            }
        }

        impl<B: Backend> Backward<B, 3> for DeformConv2DNoMaskNoBias {
            type State = (NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, None, None, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
            }
        }

        match (mask, bias) {
            (Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias
                .prepare::<C>([
                    x.node.clone(),
                    offset.node.clone(),
                    weight.node.clone(),
                    mask.node.clone(),
                    bias.node.clone(),
                ])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    let mask_state = prep.checkpoint(&mask);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (
                            x_state,
                            offset_state,
                            weight_state,
                            mask_state,
                            bias_state,
                            options.clone(),
                        ),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            Some(mask.primitive),
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    Some(mask.primitive),
                    Some(bias.primitive),
                    options,
                )),
            },
            (Some(mask), None) => match DeformConv2DWithMaskNoBias
                .prepare::<C>([
                    x.node.clone(),
                    offset.node.clone(),
                    weight.node.clone(),
                    mask.node.clone(),
                ])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    let mask_state = prep.checkpoint(&mask);
                    prep.finish(
                        (
                            x_state,
                            offset_state,
                            weight_state,
                            mask_state,
                            options.clone(),
                        ),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            Some(mask.primitive),
                            None,
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    Some(mask.primitive),
                    None,
                    options,
                )),
            },
            (None, Some(bias)) => match DeformConv2DNoMaskWithBias
                .prepare::<C>([
                    x.node.clone(),
                    offset.node.clone(),
                    weight.node.clone(),
                    bias.node.clone(),
                ])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (
                            x_state,
                            offset_state,
                            weight_state,
                            bias_state,
                            options.clone(),
                        ),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            None,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    None,
                    Some(bias.primitive),
                    options,
                )),
            },
            (None, None) => match DeformConv2DNoMaskNoBias
                .prepare::<C>([x.node.clone(), offset.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, offset_state, weight_state, options.clone()),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            None,
                            None,
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    None,
                    None,
                    options,
                )),
            },
        }
    }
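
    // `deform_conv2d` needs four `Backward` impls because the parent count is
    // part of the type (`Backward<B, N>`): mask and bias each add one tracked
    // parent, so (mask, bias) maps to N as (Some, Some) -> 5, (Some, None) -> 4,
    // (None, Some) -> 4, (None, None) -> 3. All four variants funnel into the
    // single `B::deform_conv2d_backward`, which returns every gradient at once;
    // the `unwrap()`s on `mask_grad`/`bias_grad` only occur in the variants where
    // the corresponding input was actually passed.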

    fn deform_conv2d_backward(
        _x: AutodiffTensor<B>,
        _offset: AutodiffTensor<B>,
        _weight: AutodiffTensor<B>,
        _mask: Option<AutodiffTensor<B>>,
        _bias: Option<AutodiffTensor<B>>,
        _output_grad: AutodiffTensor<B>,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        panic!("Can't differentiate deform conv 2d backward.");
    }

    fn conv_transpose2d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct ConvTranspose2DWithBias;
        #[derive(Debug)]
        struct ConvTranspose2DNoBias;

        impl<B: Backend> Backward<B, 3> for ConvTranspose2DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose2d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose2d_weight_backward(
                        x.clone(),
                        weight,
                        grad.clone(),
                        options,
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for ConvTranspose2DNoBias {
            type State = (NodeId, NodeId, ConvTransposeOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose2d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match ConvTranspose2DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);

                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv_transpose2d(
                            x.primitive,
                            weight.primitive,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match ConvTranspose2DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);

                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv_transpose2d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
                    x.primitive,
                    weight.primitive,
                    None,
                    options,
                )),
            },
        }
    }

    fn conv3d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvOptions<3>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Conv3DWithBias;
        #[derive(Debug)]
        struct Conv3DNoBias;

        impl<B: Backend> Backward<B, 3> for Conv3DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv3d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad =
                        B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv3d_bias_backward(x, weight, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for Conv3DNoBias {
            type State = (NodeId, NodeId, ConvOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv3d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv3d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match Conv3DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv3d(x.primitive, weight.primitive, Some(bias.primitive), options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv3d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match Conv3DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv3d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => {
                    prep.finish(B::conv3d(x.primitive, weight.primitive, None, options))
                }
            },
        }
    }

    fn conv_transpose3d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct ConvTranspose3DWithBias;
        #[derive(Debug)]
        struct ConvTranspose3DNoBias;

        impl<B: Backend> Backward<B, 3> for ConvTranspose3DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose3d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose3d_weight_backward(
                        x.clone(),
                        weight,
                        grad.clone(),
                        options,
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for ConvTranspose3DNoBias {
            type State = (NodeId, NodeId, ConvTransposeOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose3d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match ConvTranspose3DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);

                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv_transpose3d(
                            x.primitive,
                            weight.primitive,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match ConvTranspose3DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);

                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv_transpose3d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d(
                    x.primitive,
                    weight.primitive,
                    None,
                    options,
                )),
            },
        }
    }

    fn avg_pool1d(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AvgPool1D;

        impl<B: Backend> Backward<B, 1> for AvgPool1D {
            type State = (NodeId, usize, usize, usize, bool, bool);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =
                    ops.state;
                let x = checkpointer.retrieve_node_output(x_state);

                if let Some(node) = node_parent {
                    let grad = B::avg_pool1d_backward(
                        x,
                        grad,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    );
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AvgPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(
                    (
                        x_state,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                    B::avg_pool1d(
                        x.primitive.clone(),
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )),
        }
    }
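
    // The avg pool backward needs no indices, but it does take the original
    // input (presumably for its shape) plus every hyperparameter, so the whole
    // configuration rides along in the state tuple:
    //
    //     type State = (NodeId, usize, usize, usize, bool, bool);
    //     //            x       kernel  stride  pad   count_include_pad, ceil_mode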

    fn avg_pool2d(
        x: AutodiffTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AvgPool2D;

        impl<B: Backend> Backward<B, 1> for AvgPool2D {
            type State = (NodeId, [usize; 2], [usize; 2], [usize; 2], bool, bool);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =
                    ops.state;
                let x = checkpointer.retrieve_node_output(x_state);

                if let Some(node) = node_parent {
                    let grad = B::avg_pool2d_backward(
                        x,
                        grad,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    );
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AvgPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(
                    (
                        x_state,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                    B::avg_pool2d(
                        x.primitive.clone(),
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )),
        }
    }

    fn avg_pool2d_backward(
        _x: AutodiffTensor<B>,
        _grad: AutodiffTensor<B>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _count_include_pad: bool,
        _ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        panic!("Can't differentiate avg pool 2d backward.");
    }
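
    // As with the other `*_backward` entry points in this impl, the panic above
    // is deliberate: these methods are themselves gradient computations, and
    // differentiating through them would amount to second-order autodiff, which
    // this implementation does not support.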

    fn max_pool1d(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        match MaxPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::max_pool1d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                prep.finish(
                    (
                        x_state,
                        output.indices,
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            )),
        }
    }

    fn max_pool1d_with_indices(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<Self> {
        match MaxPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::max_pool1d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );

                let output_tensor = prep.finish(
                    (
                        x_state,
                        output.indices.clone(),
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                );

                MaxPool1dWithIndices::new(output_tensor, output.indices)
            }
            OpsKind::UnTracked(prep) => {
                let output = B::max_pool1d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                let output_tensor = prep.finish(output.output);

                MaxPool1dWithIndices::new(output_tensor, output.indices)
            }
        }
    }

    fn max_pool1d_with_indices_backward(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: AutodiffTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<Self> {
        let output = B::max_pool1d_with_indices_backward(
            x.primitive,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad.primitive,
            indices,
        );
        MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad))
    }
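
    // Note the asymmetry: `max_pool1d_with_indices_backward` wraps its result in
    // a fresh `AutodiffTensor::new(..)`, returning a new untracked root on the
    // graph rather than panicking, while the 2d variant below refuses to be
    // differentiated at all.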

    fn max_pool2d(
        x: AutodiffTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        match MaxPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::max_pool2d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                prep.finish(
                    (
                        x_state,
                        output.indices,
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            )),
        }
    }

    fn max_pool2d_with_indices(
        x: AutodiffTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Self> {
        match MaxPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);

                let output = B::max_pool2d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );

                let output_tensor = prep.finish(
                    (
                        x_state,
                        output.indices.clone(),
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                );

                MaxPool2dWithIndices::new(output_tensor, output.indices)
            }
            OpsKind::UnTracked(prep) => {
                let output = B::max_pool2d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                let output_tensor = prep.finish(output.output);

                MaxPool2dWithIndices::new(output_tensor, output.indices)
            }
        }
    }

    fn max_pool2d_with_indices_backward(
        _x: AutodiffTensor<B>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _dilation: [usize; 2],
        _ceil_mode: bool,
        _output_grad: AutodiffTensor<B>,
        _indices: IntTensor<B>,
    ) -> MaxPool2dBackward<Self> {
        panic!("Can't differentiate max pool2d with indices backward.");
    }

    fn adaptive_avg_pool1d(x: AutodiffTensor<B>, output_size: usize) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AdaptiveAvgPool1D;

        impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool1D {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let state = checkpointer.retrieve_node_output(ops.state);

                if let Some(node) = node_parent {
                    let grad = B::adaptive_avg_pool1d_backward(state, grad);
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AdaptiveAvgPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(x_state, B::adaptive_avg_pool1d(x.primitive, output_size))
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size))
            }
        }
    }

    fn adaptive_avg_pool2d(x: AutodiffTensor<B>, output_size: [usize; 2]) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AdaptiveAvgPool2D;

        impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool2D {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let state = checkpointer.retrieve_node_output(ops.state);

                if let Some(node) = node_parent {
                    let grad = B::adaptive_avg_pool2d_backward(state, grad);
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AdaptiveAvgPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(x_state, B::adaptive_avg_pool2d(x.primitive, output_size))
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size))
            }
        }
    }

    fn adaptive_avg_pool2d_backward(
        _x: AutodiffTensor<B>,
        _grad: AutodiffTensor<B>,
    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
        panic!("Can't differentiate adaptive avg pool2d backward.");
    }

    fn interpolate(
        x: AutodiffTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Interpolate;

        impl<B: Backend> Backward<B, 1> for Interpolate {
            type State = (NodeId, [usize; 2], InterpolateOptions);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, output_size, options) = ops.state;
                let state = checkpointer.retrieve_node_output(x_state);

                if let Some(node) = node_parent {
                    let grad = B::interpolate_backward(state, grad, output_size, options);
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match Interpolate
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::interpolate(x.primitive.clone(), output_size, options.clone());
                prep.finish((x_state, output_size, options), output)
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::interpolate(x.primitive, output_size, options))
            }
        }
    }

    fn interpolate_backward(
        _x: FloatTensor<Autodiff<B, C>>,
        _grad: FloatTensor<Autodiff<B, C>>,
        _output_size: [usize; 2],
        _options: InterpolateOptions,
    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
        panic!("Can't differentiate interpolate backward.");
    }
}
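
// `MaxPool1D` and `MaxPool2D` are defined at module scope, unlike the op-local
// structs above, because each backward impl is shared by two forward entry
// points (`max_pool1d` and `max_pool1d_with_indices`, and likewise for 2d).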
#[derive(Debug)]
struct MaxPool1D;

impl<B: Backend> Backward<B, 1> for MaxPool1D {
    /// (x, indices, kernel_size, stride, padding, dilation, ceil_mode)
    type State = (NodeId, IntTensor<B>, usize, usize, usize, usize, bool);

    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        checkpointer: &mut Checkpointer,
    ) {
        let [node_parent] = ops.parents;
        let grad = grads.consume::<B>(&ops.node);
        let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;
        let x = checkpointer.retrieve_node_output(x_state);

        if let Some(node) = node_parent {
            // The argmax positions recorded in the forward pass route the
            // gradient back to the input.
            let grad = B::max_pool1d_with_indices_backward(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
                grad,
                indices,
            );

            grads.register::<B>(node.id, grad.x_grad);
        }
    }
}

#[derive(Debug)]
struct MaxPool2D;

impl<B: Backend> Backward<B, 1> for MaxPool2D {
    /// (x, indices, kernel_size, stride, padding, dilation, ceil_mode)
    type State = (
        NodeId,
        IntTensor<B>,
        [usize; 2],
        [usize; 2],
        [usize; 2],
        [usize; 2],
        bool,
    );

    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        checkpointer: &mut Checkpointer,
    ) {
        let [node_parent] = ops.parents;
        let grad = grads.consume::<B>(&ops.node);
        let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;
        let x = checkpointer.retrieve_node_output(x_state);

        if let Some(node) = node_parent {
            let grad = B::max_pool2d_with_indices_backward(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
                grad,
                indices,
            );

            grads.register::<B>(node.id, grad.x_grad);
        }
    }
}