// burn_autodiff/ops/module.rs

use crate::Autodiff;
use crate::checkpoint::base::Checkpointer;
use crate::checkpoint::strategy::CheckpointStrategy;
use crate::grads::Gradients;
use crate::graph::NodeId;
use crate::ops::{Backward, Ops, unary};
use crate::tensor::AutodiffTensor;

use burn_backend::Backend;
use burn_backend::ops::*;
use burn_backend::tensor::{FloatTensor, IntTensor};

use super::OpsKind;

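// Module operations (embedding, convolutions, pooling) for the autodiff
// backend decorator. Every op here follows the same pattern: declare a
// zero-sized struct, give it a `Backward` impl whose const parameter is the
// number of tracked parents, then `prepare` the op with those parent nodes.
// A `Tracked` result checkpoints whatever the backward pass will need and
// stores it in `State`; an `UnTracked` result just forwards to the inner
// backend `B`.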
impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C> {
    fn embedding(weights: AutodiffTensor<B>, indices: IntTensor<B>) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Embedding;

        impl<B: Backend> Backward<B, 1> for Embedding {
            type State = (B::FloatTensorPrimitive, IntTensor<B>);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                let (weights, indices) = ops.state;

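                // Only `weights` is differentiable: `embedding_backward`
                // scatter-adds the rows of `grad` back into a weights-shaped
                // tensor at `indices`. Integer `indices` carry no gradient,
                // which is why this op has a single parent.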
                unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                    B::embedding_backward(weights, grad, indices)
                });
            }
        }

        match Embedding
            .prepare::<C>([weights.node])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(prep) => prep.finish(
                (weights.primitive.clone(), indices.clone()),
                B::embedding(weights.primitive, indices),
            ),
            OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)),
        }
    }

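    // `embedding_backward` is itself a gradient computation; differentiating
    // through it (double backward) isn't supported, so the autodiff decorator
    // panics instead of building a graph node for it.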
    fn embedding_backward(
        _weights: AutodiffTensor<B>,
        _output: AutodiffTensor<B>,
        _indices: IntTensor<B>,
    ) -> AutodiffTensor<B> {
        panic!("Can't differentiate embedding backward.");
    }

    fn conv1d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvOptions<1>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Conv1DWithBias;
        #[derive(Debug)]
        struct Conv1DNoBias;

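        // Two structs because the parent count is a const generic on
        // `Backward`: with a bias there are three tracked parents
        // (x, weight, bias), without one there are only two, so the two cases
        // can't share a single impl.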
        impl<B: Backend> Backward<B, 3> for Conv1DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv1d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv1d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for Conv1DNoBias {
            type State = (NodeId, NodeId, ConvOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv1d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv1d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }
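        // Tracked path: checkpoint the inputs and keep only their `NodeId`s in
        // state; `backward` retrieves the actual tensors from the checkpointer
        // on demand. Untracked path: no state, just run the inner backend op.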
        match bias {
            Some(bias) => match Conv1DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match Conv1DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv1d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => {
                    prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
                }
            },
        }
    }

    fn conv_transpose1d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct ConvTranspose1DWithBias;
        #[derive(Debug)]
        struct ConvTranspose1DNoBias;

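        // Note the asymmetry with `conv1d`: the input gradient of a transposed
        // convolution is computed from `weight` and `grad` alone, so
        // `conv_transpose1d_x_backward` below never touches `x`.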
        impl<B: Backend> Backward<B, 3> for ConvTranspose1DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose1d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose1d_weight_backward(
                        x.clone(),
                        weight,
                        grad.clone(),
                        options,
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for ConvTranspose1DNoBias {
            type State = (NodeId, NodeId, ConvTransposeOptions<1>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose1d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match ConvTranspose1DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv_transpose1d(
                            x.primitive,
                            weight.primitive,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match ConvTranspose1DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv_transpose1d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
                    x.primitive,
                    weight.primitive,
                    None,
                    options,
                )),
            },
        }
    }

    fn conv2d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvOptions<2>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Conv2DWithBias;
        #[derive(Debug)]
        struct Conv2DNoBias;

        impl<B: Backend> Backward<B, 3> for Conv2DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv2d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad =
                        B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv2d_bias_backward(x, weight, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for Conv2DNoBias {
            type State = (NodeId, NodeId, ConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv2d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv2d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match Conv2DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match Conv2DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv2d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => {
                    prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
                }
            },
        }
    }

    fn deform_conv2d(
        x: AutodiffTensor<B>,
        offset: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        mask: Option<AutodiffTensor<B>>,
        bias: Option<AutodiffTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct DeformConv2DWithMaskWithBias;
        #[derive(Debug)]
        struct DeformConv2DWithMaskNoBias;
        #[derive(Debug)]
        struct DeformConv2DNoMaskWithBias;
        #[derive(Debug)]
        struct DeformConv2DNoMaskNoBias;

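        // Mask and bias are independently optional, so four backward variants
        // cover the combinations. All of them funnel into the single fused
        // `B::deform_conv2d_backward`, which returns every input gradient at
        // once; the `unwrap`s on `mask_grad`/`bias_grad` rely on the backend
        // returning a gradient for each optional input it was actually given.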
        impl<B: Backend> Backward<B, 5> for DeformConv2DWithMaskWithBias {
            type State = (NodeId, NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 5>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, mask_state, bias_state, options) =
                    ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);
                let mask = Some(checkpointer.retrieve_node_output(mask_state));
                let bias = Some(checkpointer.retrieve_node_output(bias_state));

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
                if let Some(node) = node_mask {
                    grads.register::<B>(node.id, backward.mask_grad.unwrap())
                }
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, backward.bias_grad.unwrap())
                }
            }
        }

        impl<B: Backend> Backward<B, 4> for DeformConv2DWithMaskNoBias {
            type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 4>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight, node_mask] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, mask_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);
                let mask = Some(checkpointer.retrieve_node_output(mask_state));

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
                if let Some(node) = node_mask {
                    grads.register::<B>(node.id, backward.mask_grad.unwrap())
                }
            }
        }

        impl<B: Backend> Backward<B, 4> for DeformConv2DNoMaskWithBias {
            type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 4>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);
                let bias = Some(checkpointer.retrieve_node_output(bias_state));

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, backward.bias_grad.unwrap())
                }
            }
        }

        impl<B: Backend> Backward<B, 3> for DeformConv2DNoMaskNoBias {
            type State = (NodeId, NodeId, NodeId, DeformConvOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_offset, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, offset_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output(x_state);
                let offset = checkpointer.retrieve_node_output(offset_state);
                let weight = checkpointer.retrieve_node_output(weight_state);

                let backward =
                    B::deform_conv2d_backward(x, offset, weight, None, None, grad, options);

                if let Some(node) = node_x {
                    grads.register::<B>(node.id, backward.x_grad)
                }
                if let Some(node) = node_offset {
                    grads.register::<B>(node.id, backward.offset_grad)
                }
                if let Some(node) = node_weight {
                    grads.register::<B>(node.id, backward.weight_grad)
                }
            }
        }

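        // Dispatch on which optional inputs are present; each arm mirrors the
        // Tracked/UnTracked split used by the plain convolutions, with the
        // matching set of checkpointed parents.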
        match (mask, bias) {
            (Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias
                .prepare::<C>([
                    x.node.clone(),
                    offset.node.clone(),
                    weight.node.clone(),
                    mask.node.clone(),
                    bias.node.clone(),
                ])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    let mask_state = prep.checkpoint(&mask);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (
                            x_state,
                            offset_state,
                            weight_state,
                            mask_state,
                            bias_state,
                            options.clone(),
                        ),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            Some(mask.primitive),
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    Some(mask.primitive),
                    Some(bias.primitive),
                    options,
                )),
            },
            (Some(mask), None) => match DeformConv2DWithMaskNoBias
                .prepare::<C>([
                    x.node.clone(),
                    offset.node.clone(),
                    weight.node.clone(),
                    mask.node.clone(),
                ])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    let mask_state = prep.checkpoint(&mask);
                    prep.finish(
                        (
                            x_state,
                            offset_state,
                            weight_state,
                            mask_state,
                            options.clone(),
                        ),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            Some(mask.primitive),
                            None,
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    Some(mask.primitive),
                    None,
                    options,
                )),
            },
            (None, Some(bias)) => match DeformConv2DNoMaskWithBias
                .prepare::<C>([
                    x.node.clone(),
                    offset.node.clone(),
                    weight.node.clone(),
                    bias.node.clone(),
                ])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (
                            x_state,
                            offset_state,
                            weight_state,
                            bias_state,
                            options.clone(),
                        ),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            None,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    None,
                    Some(bias.primitive),
                    options,
                )),
            },
            (None, None) => match DeformConv2DNoMaskNoBias
                .prepare::<C>([x.node.clone(), offset.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let offset_state = prep.checkpoint(&offset);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, offset_state, weight_state, options.clone()),
                        B::deform_conv2d(
                            x.primitive,
                            offset.primitive,
                            weight.primitive,
                            None,
                            None,
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
                    x.primitive,
                    offset.primitive,
                    weight.primitive,
                    None,
                    None,
                    options,
                )),
            },
        }
    }

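    // As with `embedding_backward`, the fused backward op itself is not
    // differentiable: requesting second-order gradients through it panics.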
    fn deform_conv2d_backward(
        _x: AutodiffTensor<B>,
        _offset: AutodiffTensor<B>,
        _weight: AutodiffTensor<B>,
        _mask: Option<AutodiffTensor<B>>,
        _bias: Option<AutodiffTensor<B>>,
        _output_grad: AutodiffTensor<B>,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        panic!("Can't differentiate deform conv 2d backward.");
    }

    fn conv_transpose2d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct ConvTranspose2DWithBias;
        #[derive(Debug)]
        struct ConvTranspose2DNoBias;

        impl<B: Backend> Backward<B, 3> for ConvTranspose2DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose2d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose2d_weight_backward(
                        x.clone(),
                        weight,
                        grad.clone(),
                        options,
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for ConvTranspose2DNoBias {
            type State = (NodeId, NodeId, ConvTransposeOptions<2>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose2d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match ConvTranspose2DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);

                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv_transpose2d(
                            x.primitive,
                            weight.primitive,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match ConvTranspose2DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);

                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv_transpose2d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
                    x.primitive,
                    weight.primitive,
                    None,
                    options,
                )),
            },
        }
    }

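    // The 3D convolutions below repeat the 1D/2D structure; only the option
    // rank and the backend entry points differ (e.g. `conv3d_bias_backward`,
    // like its 2D counterpart, also takes the weight).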
    fn conv3d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvOptions<3>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Conv3DWithBias;
        #[derive(Debug)]
        struct Conv3DNoBias;

        impl<B: Backend> Backward<B, 3> for Conv3DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv3d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad =
                        B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv3d_bias_backward(x, weight, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for Conv3DNoBias {
            type State = (NodeId, NodeId, ConvOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv3d_x_backward(
                        x.clone(),
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv3d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match Conv3DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);
                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv3d(x.primitive, weight.primitive, Some(bias.primitive), options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv3d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match Conv3DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv3d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => {
                    prep.finish(B::conv3d(x.primitive, weight.primitive, None, options))
                }
            },
        }
    }

    fn conv_transpose3d(
        x: AutodiffTensor<B>,
        weight: AutodiffTensor<B>,
        bias: Option<AutodiffTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct ConvTranspose3DWithBias;
        #[derive(Debug)]
        struct ConvTranspose3DNoBias;

        impl<B: Backend> Backward<B, 3> for ConvTranspose3DWithBias {
            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight, node_bias] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, bias_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);
                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose3d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose3d_weight_backward(
                        x.clone(),
                        weight,
                        grad.clone(),
                        options,
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_bias {
                    let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        impl<B: Backend> Backward<B, 2> for ConvTranspose3DNoBias {
            type State = (NodeId, NodeId, ConvTransposeOptions<3>);

            fn backward(
                self,
                ops: Ops<Self::State, 2>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_x, node_weight] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, weight_state, options) = ops.state;
                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);
                let weight =
                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);

                if let Some(node) = node_x {
                    let grad = B::conv_transpose3d_x_backward(
                        weight.clone(),
                        grad.clone(),
                        options.clone(),
                    );
                    grads.register::<B>(node.id, grad)
                }
                if let Some(node) = node_weight {
                    let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
                    grads.register::<B>(node.id, grad)
                }
            }
        }

        match bias {
            Some(bias) => match ConvTranspose3DWithBias
                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);
                    let bias_state = prep.checkpoint(&bias);

                    prep.finish(
                        (x_state, weight_state, bias_state, options.clone()),
                        B::conv_transpose3d(
                            x.primitive,
                            weight.primitive,
                            Some(bias.primitive),
                            options,
                        ),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d(
                    x.primitive,
                    weight.primitive,
                    Some(bias.primitive),
                    options,
                )),
            },
            None => match ConvTranspose3DNoBias
                .prepare::<C>([x.node.clone(), weight.node.clone()])
                .compute_bound()
                .stateful()
            {
                OpsKind::Tracked(mut prep) => {
                    let x_state = prep.checkpoint(&x);
                    let weight_state = prep.checkpoint(&weight);

                    prep.finish(
                        (x_state, weight_state, options.clone()),
                        B::conv_transpose3d(x.primitive, weight.primitive, None, options),
                    )
                }
                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d(
                    x.primitive,
                    weight.primitive,
                    None,
                    options,
                )),
            },
        }
    }

    // TODO: Support a custom unfold4d operation by overriding the default implementation.
    //
    // We don't override it yet because the fold operation isn't available for the backward
    // pass. As a consequence, when autodiff is enabled, custom unfold operations defined by
    // backends won't be used; unfold4d instead falls back to a conv2d with a custom weight
    // matrix, so conv2d's backward pass also serves as the unfold4d backward pass.
    //
    // fn unfold4d(
    //     x: AutodiffTensor<B>,
    //     kernel_size: [usize; 2],
    //     options: UnfoldOptions,
    // ) -> AutodiffTensor<B> {
    //     todo!()
    // }

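    // Sketch of that fallback (illustrative, not Burn's exact helper): build an
    // indicator weight of shape [C * k_h * k_w, C, k_h, k_w] with
    // weight[i][c][h][w] = 1 exactly when the flat index i encodes (c, h, w).
    // A conv2d with that weight copies each sliding window into the output
    // channels, and conv2d's existing backward then handles the gradient.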
    fn avg_pool1d(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AvgPool1D;

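        // Average pooling has no learnable parameters: the backward pass simply
        // spreads each output gradient uniformly over its input window, with
        // `count_include_pad` deciding whether padded positions count toward
        // the divisor.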
        impl<B: Backend> Backward<B, 1> for AvgPool1D {
            type State = (NodeId, usize, usize, usize, bool, bool);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =
                    ops.state;
                let x = checkpointer.retrieve_node_output(x_state);

                if let Some(node) = node_parent {
                    let grad = B::avg_pool1d_backward(
                        x,
                        grad,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    );
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AvgPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(
                    (
                        x_state,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                    B::avg_pool1d(
                        x.primitive.clone(),
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )),
        }
    }

    fn avg_pool2d(
        x: AutodiffTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AvgPool2D;

        impl<B: Backend> Backward<B, 1> for AvgPool2D {
            type State = (NodeId, [usize; 2], [usize; 2], [usize; 2], bool, bool);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =
                    ops.state;
                let x = checkpointer.retrieve_node_output(x_state);

                if let Some(node) = node_parent {
                    let grad = B::avg_pool2d_backward(
                        x,
                        grad,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    );
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AvgPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(
                    (
                        x_state,
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                    B::avg_pool2d(
                        x.primitive.clone(),
                        kernel_size,
                        stride,
                        padding,
                        count_include_pad,
                        ceil_mode,
                    ),
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )),
        }
    }

    fn avg_pool2d_backward(
        _x: AutodiffTensor<B>,
        _grad: AutodiffTensor<B>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _count_include_pad: bool,
        _ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        panic!("Can't differentiate avg pool 2d backward.");
    }

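    // The tracked forward path below uses the `with_indices` variant so the
    // argmax positions are saved for backward, where each output gradient is
    // routed to the input element that produced the max.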
    fn max_pool1d(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        match MaxPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::max_pool1d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                prep.finish(
                    (
                        x_state,
                        output.indices,
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            )),
        }
    }

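    // Same as `max_pool1d`, but the indices are also returned to the caller,
    // which is why the tracked branch has to clone them into the saved state.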
    fn max_pool1d_with_indices(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<Self> {
        match MaxPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::max_pool1d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );

                // The indices are needed twice: cloned into the backward
                // state here, and returned to the caller below.
                let output_tensor = prep.finish(
                    (
                        x_state,
                        output.indices.clone(),
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                );

                MaxPool1dWithIndices::new(output_tensor, output.indices)
            }
            OpsKind::UnTracked(prep) => {
                let output = B::max_pool1d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                let output_tensor = prep.finish(output.output);

                MaxPool1dWithIndices::new(output_tensor, output.indices)
            }
        }
    }

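    // The backward kernel is evaluated eagerly and wrapped in a fresh,
    // untracked tensor via `AutodiffTensor::new`, so second-order gradients
    // through max pooling are not supported here.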
    fn max_pool1d_with_indices_backward(
        x: AutodiffTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: AutodiffTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<Self> {
        let output = B::max_pool1d_with_indices_backward(
            x.primitive,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad.primitive,
            indices,
        );
        MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad))
    }

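    // 2D max pooling; mirrors `max_pool1d`, switching to the `with_indices`
    // kernel whenever the result has to be differentiated.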
    fn max_pool2d(
        x: AutodiffTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> AutodiffTensor<B> {
        match MaxPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::max_pool2d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                prep.finish(
                    (
                        x_state,
                        output.indices,
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                )
            }
            OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d(
                x.primitive,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            )),
        }
    }

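    // 2D counterpart of `max_pool1d_with_indices`: the indices are cloned
    // into the backward state and handed back to the caller.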
    fn max_pool2d_with_indices(
        x: AutodiffTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Self> {
        match MaxPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);

                let output = B::max_pool2d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );

                let output_tensor = prep.finish(
                    (
                        x_state,
                        output.indices.clone(),
                        kernel_size,
                        stride,
                        padding,
                        dilation,
                        ceil_mode,
                    ),
                    output.output,
                );

                MaxPool2dWithIndices::new(output_tensor, output.indices)
            }
            OpsKind::UnTracked(prep) => {
                let output = B::max_pool2d_with_indices(
                    x.primitive,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                );
                let output_tensor = prep.finish(output.output);

                MaxPool2dWithIndices::new(output_tensor, output.indices)
            }
        }
    }

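    // Unlike the 1D case above, the 2D backward is not evaluated eagerly;
    // differentiating it is simply unsupported.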
    fn max_pool2d_with_indices_backward(
        _x: AutodiffTensor<B>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _dilation: [usize; 2],
        _ceil_mode: bool,
        _output_grad: AutodiffTensor<B>,
        _indices: IntTensor<B>,
    ) -> MaxPool2dBackward<Self> {
        panic!("Can't differentiate max pool2d with indices backward.");
    }
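
    // Adaptive average pooling only needs the checkpointed input in the
    // backward pass; the output size is implied by the gradient's shape.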
    fn adaptive_avg_pool1d(x: AutodiffTensor<B>, output_size: usize) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AdaptiveAvgPool1D;

        impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool1D {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let state = checkpointer.retrieve_node_output(ops.state);

                if let Some(node) = node_parent {
                    let grad = B::adaptive_avg_pool1d_backward(state, grad);
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AdaptiveAvgPool1D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(x_state, B::adaptive_avg_pool1d(x.primitive, output_size))
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size))
            }
        }
    }

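    // 2D variant of `adaptive_avg_pool1d`; the state is again just the
    // checkpointed input node.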
    fn adaptive_avg_pool2d(x: AutodiffTensor<B>, output_size: [usize; 2]) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct AdaptiveAvgPool2D;

        impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool2D {
            type State = NodeId;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);
                let state = checkpointer.retrieve_node_output(ops.state);

                if let Some(node) = node_parent {
                    let grad = B::adaptive_avg_pool2d_backward(state, grad);
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match AdaptiveAvgPool2D
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                prep.finish(x_state, B::adaptive_avg_pool2d(x.primitive, output_size))
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size))
            }
        }
    }

    fn adaptive_avg_pool2d_backward(
        _x: AutodiffTensor<B>,
        _grad: AutodiffTensor<B>,
    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
        panic!("Can't differentiate adaptive avg pool2d backward.");
    }

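    // Up/down-sampling. Both the target size and the interpolation options
    // are saved alongside the checkpointed input, since
    // `B::interpolate_backward` needs the full forward configuration.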
    fn interpolate(
        x: AutodiffTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> AutodiffTensor<B> {
        #[derive(Debug)]
        struct Interpolate;

        impl<B: Backend> Backward<B, 1> for Interpolate {
            type State = (NodeId, [usize; 2], InterpolateOptions);

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let [node_parent] = ops.parents;
                let grad = grads.consume::<B>(&ops.node);

                let (x_state, output_size, options) = ops.state;
                let state = checkpointer.retrieve_node_output(x_state);

                if let Some(node) = node_parent {
                    let grad = B::interpolate_backward(state, grad, output_size, options);
                    grads.register::<B>(node.id, grad);
                }
            }
        }

        match Interpolate
            .prepare::<C>([x.node.clone()])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let x_state = prep.checkpoint(&x);
                let output = B::interpolate(x.primitive.clone(), output_size, options.clone());
                prep.finish((x_state, output_size, options), output)
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::interpolate(x.primitive, output_size, options))
            }
        }
    }

    fn interpolate_backward(
        _x: FloatTensor<Autodiff<B, C>>,
        _grad: FloatTensor<Autodiff<B, C>>,
        _output_size: [usize; 2],
        _options: InterpolateOptions,
    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
        panic!("Can't differentiate interpolate backward.");
    }
}

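// Backward op shared by `max_pool1d` and `max_pool1d_with_indices`, defined
// at module scope (unlike the per-function ops above) because both entry
// points use it. The state bundles the checkpointed input, the argmax
// indices from the forward pass, and the pooling configuration.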
#[derive(Debug)]
struct MaxPool1D;

impl<B: Backend> Backward<B, 1> for MaxPool1D {
    type State = (NodeId, IntTensor<B>, usize, usize, usize, usize, bool);

    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        checkpointer: &mut Checkpointer,
    ) {
        let [node_parent] = ops.parents;
        let grad = grads.consume::<B>(&ops.node);
        let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;
        let x = checkpointer.retrieve_node_output(x_state);

        if let Some(node) = node_parent {
            // Scatter the incoming gradient back to the argmax positions
            // recorded during the forward pass.
            let grad = B::max_pool1d_with_indices_backward(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
                grad,
                indices,
            );

            grads.register::<B>(node.id, grad.x_grad);
        }
    }
}

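// 2D analogue of `MaxPool1D`; only the scalar pooling parameters become
// `[usize; 2]` pairs.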
#[derive(Debug)]
struct MaxPool2D;

impl<B: Backend> Backward<B, 1> for MaxPool2D {
    type State = (
        NodeId,
        IntTensor<B>,
        [usize; 2],
        [usize; 2],
        [usize; 2],
        [usize; 2],
        bool,
    );

    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        checkpointer: &mut Checkpointer,
    ) {
        let [node_parent] = ops.parents;
        let grad = grads.consume::<B>(&ops.node);
        let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;
        let x = checkpointer.retrieve_node_output(x_state);

        if let Some(node) = node_parent {
            let grad = B::max_pool2d_with_indices_backward(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
                grad,
                indices,
            );

            grads.register::<B>(node.id, grad.x_grad);
        }
    }
}
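
// Illustrative usage sketch (not part of the original source). Assuming an
// inner backend such as `burn-ndarray`, and assuming the tensor-level
// `max_pool2d` module function mirrors the backend signature used above:
//
//     type AD = Autodiff<NdArray>;
//     let x: Tensor<AD, 4> = Tensor::ones([1, 1, 4, 4], &device).require_grad();
//     let y = max_pool2d(x.clone(), [2, 2], [2, 2], [0, 0], [1, 1], false);
//     let grads = y.sum().backward();
//     let x_grad = x.grad(&grads); // nonzero only at the argmax positions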