burn_fusion/ops/
module.rs

1use crate::{
2    Fusion, FusionBackend,
3    stream::{OperationStreams, execution::Operation},
4};
5use burn_ir::*;
6use burn_tensor::{
7    Element, Shape,
8    ops::{
9        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
10        IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
11        MaxPool2dWithIndices, ModuleOps,
12        conv::{
13            calculate_conv_output_size, calculate_conv_transpose_output_size,
14            calculate_pool_output_size,
15        },
16    },
17};
18use std::marker::PhantomData;
19
/// Generates a fusion-stream operation type.
///
/// `$name` is the struct to generate, `$desc` the IR description it stores,
/// and `$fn` a closure `(&$desc, &mut HandleContainer<B::Handle>)` that is
/// invoked when the queued operation finally executes against the concrete
/// backend.
macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            // IR description captured when the op was registered.
            desc: $desc,
            // Backend only matters at the type level; nothing is stored.
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                // `$fn` is an expression holding a closure, so calling it
                // directly would trip clippy's redundant_closure_call lint.
                #[allow(clippy::redundant_closure_call)]
                $fn(&self.desc, handles)
            }
        }
    };
}
36
37impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
38    fn conv1d(
39        x: FloatTensor<Self>,
40        weight: FloatTensor<Self>,
41        bias: Option<FloatTensor<Self>>,
42        options: ConvOptions<1>,
43    ) -> FloatTensor<Self> {
44        make_ops!(Conv1dOps, Conv1dOpIr, |desc: &Conv1dOpIr,
45                                          handles: &mut HandleContainer<
46            B::Handle,
47        >| {
48            let x = handles.get_float_tensor::<B>(&desc.x);
49            let weight = handles.get_float_tensor::<B>(&desc.weight);
50            let bias = desc
51                .bias
52                .as_ref()
53                .map(|bias| handles.get_float_tensor::<B>(bias));
54            let output = B::conv1d(x, weight, bias, desc.options.clone().into());
55            handles.register_float_tensor::<B>(&desc.out.id, output);
56        });
57
58        let size = calculate_conv_output_size(
59            weight.shape[2],
60            options.stride[0],
61            options.padding[0],
62            options.dilation[0],
63            x.shape[2],
64        );
65
66        let mut streams = OperationStreams::default();
67        streams.tensor(&x);
68        streams.tensor(&weight);
69
70        if let Some(bias) = bias.as_ref() {
71            streams.tensor(bias)
72        }
73
74        let shape = vec![x.shape[0], weight.shape[0], size];
75        let out = x
76            .client
77            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
78
79        let description = Conv1dOpIr {
80            x: x.into_ir(),
81            weight: weight.into_ir(),
82            bias: bias.map(|bias| bias.into_ir()),
83            options: options.into(),
84            out: out.to_ir_out(),
85        };
86
87        out.client.clone().register(
88            streams,
89            OperationIr::Module(ModuleOperationIr::Conv1d(description.clone())),
90            Conv1dOps::<B>::new(description),
91        );
92
93        out
94    }
95
96    fn conv2d(
97        x: FloatTensor<Self>,
98        weight: FloatTensor<Self>,
99        bias: Option<FloatTensor<Self>>,
100        options: ConvOptions<2>,
101    ) -> FloatTensor<Self> {
102        make_ops!(Conv2dOps, Conv2dOpIr, |args: &Conv2dOpIr,
103                                          handles: &mut HandleContainer<
104            B::Handle,
105        >| {
106            let x = handles.get_float_tensor::<B>(&args.x);
107            let weight = handles.get_float_tensor::<B>(&args.weight);
108            let bias = args
109                .bias
110                .as_ref()
111                .map(|bias| handles.get_float_tensor::<B>(bias));
112
113            let output = B::conv2d(x, weight, bias, args.options.clone().into());
114
115            handles.register_float_tensor::<B>(&args.out.id, output);
116        });
117
118        let size_0 = calculate_conv_output_size(
119            weight.shape[2],
120            options.stride[0],
121            options.padding[0],
122            options.dilation[0],
123            x.shape[2],
124        );
125        let size_1 = calculate_conv_output_size(
126            weight.shape[3],
127            options.stride[1],
128            options.padding[1],
129            options.dilation[1],
130            x.shape[3],
131        );
132
133        let mut streams = OperationStreams::default();
134        streams.tensor(&x);
135        streams.tensor(&weight);
136
137        if let Some(bias) = bias.as_ref() {
138            streams.tensor(bias)
139        }
140        let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
141        let out = x
142            .client
143            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
144
145        let desc = Conv2dOpIr {
146            x: x.into_ir(),
147            weight: weight.into_ir(),
148            bias: bias.map(|bias| bias.into_ir()),
149            options: options.into(),
150            out: out.to_ir_out(),
151        };
152
153        out.client.register(
154            streams,
155            OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),
156            Conv2dOps::<B>::new(desc),
157        );
158
159        out
160    }
161
162    fn deform_conv2d(
163        x: FloatTensor<Self>,
164        offset: FloatTensor<Self>,
165        weight: FloatTensor<Self>,
166        mask: Option<FloatTensor<Self>>,
167        bias: Option<FloatTensor<Self>>,
168        options: DeformConvOptions<2>,
169    ) -> FloatTensor<Self> {
170        make_ops!(
171            DeformConv2dOps,
172            DeformConv2dOpIr,
173            |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
174                let x = handles.get_float_tensor::<B>(&args.x);
175                let offset = handles.get_float_tensor::<B>(&args.offset);
176                let weight = handles.get_float_tensor::<B>(&args.weight);
177                let mask = args
178                    .mask
179                    .as_ref()
180                    .map(|mask| handles.get_float_tensor::<B>(mask));
181                let bias = args
182                    .bias
183                    .as_ref()
184                    .map(|bias| handles.get_float_tensor::<B>(bias));
185
186                let output =
187                    B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());
188
189                handles.register_float_tensor::<B>(&args.out.id, output);
190            }
191        );
192
193        let size_0 = calculate_conv_output_size(
194            weight.shape[2],
195            options.stride[0],
196            options.padding[0],
197            options.dilation[0],
198            x.shape[2],
199        );
200        let size_1 = calculate_conv_output_size(
201            weight.shape[3],
202            options.stride[1],
203            options.padding[1],
204            options.dilation[1],
205            x.shape[3],
206        );
207
208        let mut streams = OperationStreams::default();
209        streams.tensor(&x);
210        streams.tensor(&offset);
211        streams.tensor(&weight);
212
213        if let Some(bias) = bias.as_ref() {
214            streams.tensor(bias)
215        }
216        if let Some(mask) = mask.as_ref() {
217            streams.tensor(mask)
218        }
219
220        let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
221        let out = x
222            .client
223            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
224
225        let desc = DeformConv2dOpIr {
226            x: x.into_ir(),
227            offset: offset.into_ir(),
228            weight: weight.into_ir(),
229            mask: mask.map(|mask| mask.into_ir()),
230            bias: bias.map(|bias| bias.into_ir()),
231            options: options.into(),
232            out: out.to_ir_out(),
233        };
234
235        out.client.register(
236            streams,
237            OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),
238            DeformConv2dOps::<B>::new(desc),
239        );
240
241        out
242    }
243
    /// Backward pass of the deformable 2D convolution.
    ///
    /// Queues one fused operation that produces gradients for the input,
    /// offset and weight tensors, plus optional mask/bias gradients when
    /// those inputs were provided.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        make_ops!(
            DeformConv2dBackwardOps,
            DeformConv2dBackwardOpIr,
            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);

                let output = B::deform_conv2d_backward(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);
                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);
                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);
                // Optional gradients are only registered when both the backend
                // produced a value AND the IR reserved an output slot for it.
                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, mask_grad);
                }
                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, bias_grad);
                }
            }
        );

        // Each gradient tensor mirrors the shape of the tensor it differentiates.
        // NOTE(review): weight/mask/bias gradients are allocated through
        // `offset.client` rather than their own tensor's client — presumably
        // all inputs share a single fusion client; confirm this is intentional.
        let input_grad = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
        let offset_grad = offset
            .client
            .tensor_uninitialized(offset.shape.clone(), B::FloatElem::dtype());
        let weight_grad = offset
            .client
            .tensor_uninitialized(weight.shape.clone(), B::FloatElem::dtype());
        let mask_grad = mask.as_ref().map(|mask| {
            offset
                .client
                .tensor_uninitialized(mask.shape.clone(), B::FloatElem::dtype())
        });
        let bias_grad = bias.as_ref().map(|bias| {
            offset
                .client
                .tensor_uninitialized(bias.shape.clone(), B::FloatElem::dtype())
        });

        // Record every input stream before the tensors are turned into IR.
        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&offset);
        streams.tensor(&weight);
        streams.tensor(&output_grad);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask)
        }

        let desc = DeformConv2dBackwardOpIr {
            x: x.into_ir(),
            offset: offset.into_ir(),
            weight: weight.into_ir(),
            mask: mask.map(|mask| mask.into_ir()),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out_grad: output_grad.into_ir(),
            input_grad: input_grad.to_ir_out(),
            offset_grad: offset_grad.to_ir_out(),
            weight_grad: weight_grad.to_ir_out(),
            mask_grad: mask_grad.as_ref().map(|mask_grad| mask_grad.to_ir_out()),
            bias_grad: bias_grad.as_ref().map(|bias_grad| bias_grad.to_ir_out()),
        };

        input_grad.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(
                desc.clone(),
            ))),
            DeformConv2dBackwardOps::<B>::new(desc),
        );

        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }
350
351    fn conv3d(
352        x: FloatTensor<Self>,
353        weight: FloatTensor<Self>,
354        bias: Option<FloatTensor<Self>>,
355        options: ConvOptions<3>,
356    ) -> FloatTensor<Self> {
357        make_ops!(Conv3dOps, Conv3dOpIr, |args: &Conv3dOpIr,
358                                          handles: &mut HandleContainer<
359            B::Handle,
360        >| {
361            let x = handles.get_float_tensor::<B>(&args.x);
362            let weight = handles.get_float_tensor::<B>(&args.weight);
363            let bias = args
364                .bias
365                .as_ref()
366                .map(|bias| handles.get_float_tensor::<B>(bias));
367
368            let output = B::conv3d(x, weight, bias, args.options.clone().into());
369
370            handles.register_float_tensor::<B>(&args.out.id, output);
371        });
372
373        let size_0 = calculate_conv_output_size(
374            weight.shape[2],
375            options.stride[0],
376            options.padding[0],
377            options.dilation[0],
378            x.shape[2],
379        );
380        let size_1 = calculate_conv_output_size(
381            weight.shape[3],
382            options.stride[1],
383            options.padding[1],
384            options.dilation[1],
385            x.shape[3],
386        );
387        let size_2 = calculate_conv_output_size(
388            weight.shape[4],
389            options.stride[2],
390            options.padding[2],
391            options.dilation[2],
392            x.shape[4],
393        );
394
395        let mut streams = OperationStreams::default();
396        streams.tensor(&x);
397        streams.tensor(&weight);
398
399        if let Some(bias) = bias.as_ref() {
400            streams.tensor(bias)
401        }
402
403        let shape = vec![x.shape[0], weight.shape[0], size_0, size_1, size_2];
404        let out = x
405            .client
406            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
407
408        let desc = Conv3dOpIr {
409            x: x.into_ir(),
410            weight: weight.into_ir(),
411            bias: bias.map(|bias| bias.into_ir()),
412            options: options.into(),
413            out: out.to_ir_out(),
414        };
415
416        out.client.register(
417            streams,
418            OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),
419            Conv3dOps::<B>::new(desc),
420        );
421
422        out
423    }
424
425    fn conv_transpose1d(
426        x: FloatTensor<Self>,
427        weight: FloatTensor<Self>,
428        bias: Option<FloatTensor<Self>>,
429        options: ConvTransposeOptions<1>,
430    ) -> FloatTensor<Self> {
431        make_ops!(
432            ConvTranspose1dOps,
433            ConvTranspose1dOpIr,
434            |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {
435                let x = handles.get_float_tensor::<B>(&args.x);
436                let weight = handles.get_float_tensor::<B>(&args.weight);
437                let bias = args
438                    .bias
439                    .as_ref()
440                    .map(|bias| handles.get_float_tensor::<B>(bias));
441
442                let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());
443
444                handles.register_float_tensor::<B>(&args.out.id, output);
445            }
446        );
447
448        let size = calculate_conv_transpose_output_size(
449            weight.shape[2],
450            options.stride[0],
451            options.padding[0],
452            options.padding_out[0],
453            options.dilation[0],
454            x.shape[2],
455        );
456
457        let mut streams = OperationStreams::default();
458        streams.tensor(&x);
459        streams.tensor(&weight);
460
461        if let Some(bias) = bias.as_ref() {
462            streams.tensor(bias)
463        }
464
465        let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
466        let out = x
467            .client
468            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
469
470        let desc = ConvTranspose1dOpIr {
471            x: x.into_ir(),
472            weight: weight.into_ir(),
473            bias: bias.map(|bias| bias.into_ir()),
474            options: options.into(),
475            out: out.to_ir_out(),
476        };
477
478        out.client.register(
479            streams,
480            OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),
481            ConvTranspose1dOps::<B>::new(desc),
482        );
483
484        out
485    }
486
487    fn conv_transpose2d(
488        x: FloatTensor<Self>,
489        weight: FloatTensor<Self>,
490        bias: Option<FloatTensor<Self>>,
491        options: ConvTransposeOptions<2>,
492    ) -> FloatTensor<Self> {
493        make_ops!(
494            ConvTranspose2dOps,
495            ConvTranspose2dOpIr,
496            |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {
497                let x = handles.get_float_tensor::<B>(&args.x);
498                let weight = handles.get_float_tensor::<B>(&args.weight);
499                let bias = args
500                    .bias
501                    .as_ref()
502                    .map(|bias| handles.get_float_tensor::<B>(bias));
503
504                let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());
505
506                handles.register_float_tensor::<B>(&args.out.id, output);
507            }
508        );
509
510        let size_0 = calculate_conv_transpose_output_size(
511            weight.shape[2],
512            options.stride[0],
513            options.padding[0],
514            options.padding_out[0],
515            options.dilation[0],
516            x.shape[2],
517        );
518        let size_1 = calculate_conv_transpose_output_size(
519            weight.shape[3],
520            options.stride[1],
521            options.padding[1],
522            options.padding_out[1],
523            options.dilation[1],
524            x.shape[3],
525        );
526
527        let mut streams = OperationStreams::default();
528        streams.tensor(&x);
529        streams.tensor(&weight);
530
531        if let Some(bias) = bias.as_ref() {
532            streams.tensor(bias)
533        }
534
535        let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
536        let out = x
537            .client
538            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
539
540        let desc = ConvTranspose2dOpIr {
541            x: x.into_ir(),
542            weight: weight.into_ir(),
543            bias: bias.map(|bias| bias.into_ir()),
544            options: options.into(),
545            out: out.to_ir_out(),
546        };
547
548        out.client.register(
549            streams,
550            OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),
551            ConvTranspose2dOps::<B>::new(desc),
552        );
553
554        out
555    }
556
557    fn conv_transpose3d(
558        x: FloatTensor<Self>,
559        weight: FloatTensor<Self>,
560        bias: Option<FloatTensor<Self>>,
561        options: ConvTransposeOptions<3>,
562    ) -> FloatTensor<Self> {
563        make_ops!(
564            ConvTranspose3dOps,
565            ConvTranspose3dOpIr,
566            |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {
567                let x = handles.get_float_tensor::<B>(&args.x);
568                let weight = handles.get_float_tensor::<B>(&args.weight);
569                let bias = args
570                    .bias
571                    .as_ref()
572                    .map(|bias| handles.get_float_tensor::<B>(bias));
573
574                let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());
575
576                handles.register_float_tensor::<B>(&args.out.id, output);
577            }
578        );
579
580        let size_0 = calculate_conv_transpose_output_size(
581            weight.shape[2],
582            options.stride[0],
583            options.padding[0],
584            options.padding_out[0],
585            options.dilation[0],
586            x.shape[2],
587        );
588        let size_1 = calculate_conv_transpose_output_size(
589            weight.shape[3],
590            options.stride[1],
591            options.padding[1],
592            options.padding_out[1],
593            options.dilation[1],
594            x.shape[3],
595        );
596        let size_2 = calculate_conv_transpose_output_size(
597            weight.shape[4],
598            options.stride[2],
599            options.padding[2],
600            options.padding_out[2],
601            options.dilation[2],
602            x.shape[4],
603        );
604
605        let mut streams = OperationStreams::default();
606        streams.tensor(&x);
607        streams.tensor(&weight);
608
609        if let Some(bias) = bias.as_ref() {
610            streams.tensor(bias)
611        }
612
613        let shape = vec![
614            x.shape[0],
615            weight.shape[1] * options.groups,
616            size_0,
617            size_1,
618            size_2,
619        ];
620        let out = x
621            .client
622            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
623
624        let desc = ConvTranspose3dOpIr {
625            x: x.into_ir(),
626            weight: weight.into_ir(),
627            bias: bias.map(|bias| bias.into_ir()),
628            options: options.into(),
629            out: out.to_ir_out(),
630        };
631
632        out.client.register(
633            streams,
634            OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),
635            ConvTranspose3dOps::<B>::new(desc),
636        );
637
638        out
639    }
640
641    fn avg_pool1d(
642        x: FloatTensor<Self>,
643        kernel_size: usize,
644        stride: usize,
645        padding: usize,
646        count_include_pad: bool,
647    ) -> FloatTensor<Self> {
648        make_ops!(
649            AvgPool1dOps,
650            AvgPool1dOpIr,
651            |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
652                let x = handles.get_float_tensor::<B>(&args.x);
653                let output = B::avg_pool1d(
654                    x,
655                    args.kernel_size,
656                    args.stride,
657                    args.padding,
658                    args.count_include_pad,
659                );
660
661                handles.register_float_tensor::<B>(&args.out.id, output);
662            }
663        );
664
665        let mut streams = OperationStreams::default();
666        streams.tensor(&x);
667
668        let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
669        let shape = vec![x.shape[0], x.shape[1], size];
670        let out = x
671            .client
672            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
673
674        let desc = AvgPool1dOpIr {
675            x: x.into_ir(),
676            kernel_size,
677            stride,
678            padding,
679            count_include_pad,
680            out: out.to_ir_out(),
681        };
682        out.client.register(
683            streams,
684            OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),
685            AvgPool1dOps::<B>::new(desc),
686        );
687
688        out
689    }
690
691    fn avg_pool2d(
692        x: FloatTensor<Self>,
693        kernel_size: [usize; 2],
694        stride: [usize; 2],
695        padding: [usize; 2],
696        count_include_pad: bool,
697    ) -> FloatTensor<Self> {
698        make_ops!(
699            AvgPool2dOps,
700            AvgPool2dOpIr,
701            |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
702                let x = handles.get_float_tensor::<B>(&args.x);
703                let output = B::avg_pool2d(
704                    x,
705                    args.kernel_size,
706                    args.stride,
707                    args.padding,
708                    args.count_include_pad,
709                );
710
711                handles.register_float_tensor::<B>(&args.out.id, output);
712            }
713        );
714
715        let size_0 =
716            calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]);
717        let size_1 =
718            calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);
719
720        let mut streams = OperationStreams::default();
721        streams.tensor(&x);
722
723        let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
724        let out = x
725            .client
726            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
727
728        let desc = AvgPool2dOpIr {
729            x: x.into_ir(),
730            kernel_size,
731            stride,
732            padding,
733            count_include_pad,
734            out: out.to_ir_out(),
735        };
736        out.client.register(
737            streams,
738            OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),
739            AvgPool2dOps::<B>::new(desc),
740        );
741
742        out
743    }
744
745    fn avg_pool1d_backward(
746        x: FloatTensor<Self>,
747        grad: FloatTensor<Self>,
748        kernel_size: usize,
749        stride: usize,
750        padding: usize,
751        count_include_pad: bool,
752    ) -> FloatTensor<Self> {
753        make_ops!(
754            AvgPool1dBackwardOps,
755            AvgPool1dBackwardOpIr,
756            |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
757                let x = handles.get_float_tensor::<B>(&args.x);
758                let grad = handles.get_float_tensor::<B>(&args.grad);
759                let output = B::avg_pool1d_backward(
760                    x,
761                    grad,
762                    args.kernel_size,
763                    args.stride,
764                    args.padding,
765                    args.count_include_pad,
766                );
767
768                handles.register_float_tensor::<B>(&args.out.id, output);
769            }
770        );
771
772        let mut streams = OperationStreams::default();
773        streams.tensor(&x);
774        streams.tensor(&grad);
775
776        let out = x
777            .client
778            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
779
780        let desc = AvgPool1dBackwardOpIr {
781            x: x.into_ir(),
782            grad: grad.into_ir(),
783            kernel_size,
784            stride,
785            padding,
786            count_include_pad,
787            out: out.to_ir_out(),
788        };
789        out.client.register(
790            streams,
791            OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),
792            AvgPool1dBackwardOps::<B>::new(desc),
793        );
794
795        out
796    }
797
798    fn avg_pool2d_backward(
799        x: FloatTensor<Self>,
800        grad: FloatTensor<Self>,
801        kernel_size: [usize; 2],
802        stride: [usize; 2],
803        padding: [usize; 2],
804        count_include_pad: bool,
805    ) -> FloatTensor<Self> {
806        make_ops!(
807            AvgPool2dBackwardOps,
808            AvgPool2dBackwardOpIr,
809            |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
810                let x = handles.get_float_tensor::<B>(&args.x);
811                let grad = handles.get_float_tensor::<B>(&args.grad);
812                let output = B::avg_pool2d_backward(
813                    x,
814                    grad,
815                    args.kernel_size,
816                    args.stride,
817                    args.padding,
818                    args.count_include_pad,
819                );
820
821                handles.register_float_tensor::<B>(&args.out.id, output);
822            }
823        );
824
825        let mut streams = OperationStreams::default();
826        streams.tensor(&x);
827        streams.tensor(&grad);
828
829        let out = x
830            .client
831            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
832
833        let desc = AvgPool2dBackwardOpIr {
834            x: x.into_ir(),
835            grad: grad.into_ir(),
836            kernel_size,
837            stride,
838            padding,
839            count_include_pad,
840            out: out.to_ir_out(),
841        };
842        out.client.register(
843            streams,
844            OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),
845            AvgPool2dBackwardOps::<B>::new(desc),
846        );
847
848        out
849    }
850
851    fn max_pool1d(
852        x: FloatTensor<Self>,
853        kernel_size: usize,
854        stride: usize,
855        padding: usize,
856        dilation: usize,
857    ) -> FloatTensor<Self> {
858        make_ops!(
859            MaxPool1dOps,
860            MaxPool1dOpIr,
861            |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
862                let x = handles.get_float_tensor::<B>(&args.x);
863                let output = B::max_pool1d(
864                    x,
865                    args.kernel_size,
866                    args.stride,
867                    args.padding,
868                    args.dilation,
869                );
870
871                handles.register_float_tensor::<B>(&args.out.id, output);
872            }
873        );
874
875        let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
876
877        let mut streams = OperationStreams::default();
878        streams.tensor(&x);
879
880        let shape = vec![x.shape[0], x.shape[1], size];
881        let out = x
882            .client
883            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
884
885        let desc = MaxPool1dOpIr {
886            x: x.into_ir(),
887            kernel_size,
888            stride,
889            padding,
890            dilation,
891            out: out.to_ir_out(),
892        };
893        out.client.register(
894            streams,
895            OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),
896            MaxPool1dOps::<B>::new(desc),
897        );
898
899        out
900    }
901
902    fn max_pool2d(
903        x: FloatTensor<Self>,
904        kernel_size: [usize; 2],
905        stride: [usize; 2],
906        padding: [usize; 2],
907        dilation: [usize; 2],
908    ) -> FloatTensor<Self> {
909        make_ops!(
910            MaxPool2dOps,
911            MaxPool2dOpIr,
912            |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
913                let x = handles.get_float_tensor::<B>(&args.x);
914                let output = B::max_pool2d(
915                    x,
916                    args.kernel_size,
917                    args.stride,
918                    args.padding,
919                    args.dilation,
920                );
921
922                handles.register_float_tensor::<B>(&args.out.id, output);
923            }
924        );
925
926        let size_0 = calculate_pool_output_size(
927            kernel_size[0],
928            stride[0],
929            padding[0],
930            dilation[0],
931            x.shape[2],
932        );
933        let size_1 = calculate_pool_output_size(
934            kernel_size[1],
935            stride[1],
936            padding[1],
937            dilation[1],
938            x.shape[3],
939        );
940
941        let mut streams = OperationStreams::default();
942        streams.tensor(&x);
943
944        let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
945        let out = x
946            .client
947            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
948
949        let desc = MaxPool2dOpIr {
950            x: x.into_ir(),
951            kernel_size,
952            stride,
953            padding,
954            dilation,
955            out: out.to_ir_out(),
956        };
957        out.client.register(
958            streams,
959            OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),
960            MaxPool2dOps::<B>::new(desc),
961        );
962
963        out
964    }
965
966    fn max_pool1d_with_indices(
967        x: FloatTensor<Self>,
968        kernel_size: usize,
969        stride: usize,
970        padding: usize,
971        dilation: usize,
972    ) -> MaxPool1dWithIndices<Self> {
973        make_ops!(
974            MaxPool1dWithIndicesOps,
975            MaxPool1dWithIndicesOpIr,
976            |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
977                let x = handles.get_float_tensor::<B>(&args.x);
978                let output = B::max_pool1d_with_indices(
979                    x,
980                    args.kernel_size,
981                    args.stride,
982                    args.padding,
983                    args.dilation,
984                );
985
986                handles.register_float_tensor::<B>(&args.out.id, output.output);
987                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
988            }
989        );
990
991        let mut streams = OperationStreams::default();
992        streams.tensor(&x);
993
994        let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
995        let shape = vec![x.shape[0], x.shape[1], size];
996        let out = x
997            .client
998            .tensor_uninitialized(Shape::from(shape.clone()), B::FloatElem::dtype());
999        let out_indices = x
1000            .client
1001            .tensor_uninitialized(Shape::from(shape), B::IntElem::dtype());
1002
1003        let desc = MaxPool1dWithIndicesOpIr {
1004            x: x.into_ir(),
1005            kernel_size,
1006            stride,
1007            padding,
1008            dilation,
1009            out: out.to_ir_out(),
1010            out_indices: out_indices.to_ir_out(),
1011        };
1012        out.client.register(
1013            streams,
1014            OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),
1015            MaxPool1dWithIndicesOps::<B>::new(desc),
1016        );
1017
1018        MaxPool1dWithIndices::new(out, out_indices)
1019    }
1020
1021    fn max_pool2d_with_indices(
1022        x: FloatTensor<Self>,
1023        kernel_size: [usize; 2],
1024        stride: [usize; 2],
1025        padding: [usize; 2],
1026        dilation: [usize; 2],
1027    ) -> MaxPool2dWithIndices<Self> {
1028        make_ops!(
1029            MaxPool2dWithIndicesOps,
1030            MaxPool2dWithIndicesOpIr,
1031            |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
1032                let x = handles.get_float_tensor::<B>(&args.x);
1033                let output = B::max_pool2d_with_indices(
1034                    x,
1035                    args.kernel_size,
1036                    args.stride,
1037                    args.padding,
1038                    args.dilation,
1039                );
1040
1041                handles.register_float_tensor::<B>(&args.out.id, output.output);
1042                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
1043            }
1044        );
1045
1046        let size_0 = calculate_pool_output_size(
1047            kernel_size[0],
1048            stride[0],
1049            padding[0],
1050            dilation[0],
1051            x.shape[2],
1052        );
1053        let size_1 = calculate_pool_output_size(
1054            kernel_size[1],
1055            stride[1],
1056            padding[1],
1057            dilation[1],
1058            x.shape[3],
1059        );
1060
1061        let mut streams = OperationStreams::default();
1062        streams.tensor(&x);
1063
1064        let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
1065        let out = x
1066            .client
1067            .tensor_uninitialized(Shape::from(shape.clone()), B::FloatElem::dtype());
1068        let out_indices = x
1069            .client
1070            .tensor_uninitialized(Shape::from(shape), B::IntElem::dtype());
1071
1072        let desc = MaxPool2dWithIndicesOpIr {
1073            x: x.into_ir(),
1074            kernel_size,
1075            stride,
1076            padding,
1077            dilation,
1078            out: out.to_ir_out(),
1079            out_indices: out_indices.to_ir_out(),
1080        };
1081        out.client.register(
1082            streams,
1083            OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),
1084            MaxPool2dWithIndicesOps::<B>::new(desc),
1085        );
1086
1087        MaxPool2dWithIndices::new(out, out_indices)
1088    }
1089
1090    fn max_pool1d_with_indices_backward(
1091        x: FloatTensor<Self>,
1092        kernel_size: usize,
1093        stride: usize,
1094        padding: usize,
1095        dilation: usize,
1096        output_grad: FloatTensor<Self>,
1097        indices: IntTensor<Self>,
1098    ) -> MaxPool1dBackward<Self> {
1099        make_ops!(
1100            MaxPool1dWithIndicesBackwardOps,
1101            MaxPool1dWithIndicesBackwardOpIr,
1102            |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1103                let x = handles.get_float_tensor::<B>(&args.x);
1104                let grad = handles.get_float_tensor::<B>(&args.grad);
1105                let indices = handles.get_int_tensor::<B>(&args.indices);
1106                let output = B::max_pool1d_with_indices_backward(
1107                    x,
1108                    args.kernel_size,
1109                    args.stride,
1110                    args.padding,
1111                    args.dilation,
1112                    grad,
1113                    indices,
1114                );
1115
1116                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
1117            }
1118        );
1119
1120        let mut streams = OperationStreams::default();
1121        streams.tensor(&x);
1122        streams.tensor(&output_grad);
1123        streams.tensor(&indices);
1124
1125        let out = x
1126            .client
1127            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
1128
1129        let desc = MaxPool1dWithIndicesBackwardOpIr {
1130            x: x.into_ir(),
1131            grad: output_grad.into_ir(),
1132            indices: indices.into_ir(),
1133            kernel_size,
1134            stride,
1135            padding,
1136            dilation,
1137            out: out.to_ir_out(),
1138        };
1139        out.client.register(
1140            streams,
1141            OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(
1142                desc.clone(),
1143            )),
1144            MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
1145        );
1146
1147        MaxPool1dBackward::new(out)
1148    }
1149
1150    fn max_pool2d_with_indices_backward(
1151        x: FloatTensor<Self>,
1152        kernel_size: [usize; 2],
1153        stride: [usize; 2],
1154        padding: [usize; 2],
1155        dilation: [usize; 2],
1156        output_grad: FloatTensor<Self>,
1157        indices: IntTensor<Self>,
1158    ) -> MaxPool2dBackward<Self> {
1159        make_ops!(
1160            MaxPool2dWithIndicesBackwardOps,
1161            MaxPool2dWithIndicesBackwardOpIr,
1162            |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1163                let x = handles.get_float_tensor::<B>(&args.x);
1164                let grad = handles.get_float_tensor::<B>(&args.grad);
1165                let indices = handles.get_int_tensor::<B>(&args.indices);
1166                let output = B::max_pool2d_with_indices_backward(
1167                    x,
1168                    args.kernel_size,
1169                    args.stride,
1170                    args.padding,
1171                    args.dilation,
1172                    grad,
1173                    indices,
1174                );
1175
1176                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
1177            }
1178        );
1179
1180        let mut streams = OperationStreams::default();
1181        streams.tensor(&x);
1182        streams.tensor(&output_grad);
1183        streams.tensor(&indices);
1184
1185        let out = x
1186            .client
1187            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
1188
1189        let desc = MaxPool2dWithIndicesBackwardOpIr {
1190            x: x.into_ir(),
1191            grad: output_grad.into_ir(),
1192            indices: indices.into_ir(),
1193            kernel_size,
1194            stride,
1195            padding,
1196            dilation,
1197            out: out.to_ir_out(),
1198        };
1199        out.client.register(
1200            streams,
1201            OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(
1202                desc.clone(),
1203            )),
1204            MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
1205        );
1206
1207        MaxPool2dBackward::new(out)
1208    }
1209
1210    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
1211        make_ops!(
1212            AdaptiveAvgPool1dOps,
1213            AdaptiveAvgPool1dOpIr,
1214            |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
1215                let x = handles.get_float_tensor::<B>(&args.x);
1216                let output = B::adaptive_avg_pool1d(x, args.output_size);
1217
1218                handles.register_float_tensor::<B>(&args.out.id, output);
1219            }
1220        );
1221
1222        let mut streams = OperationStreams::default();
1223        streams.tensor(&x);
1224
1225        let shape = vec![x.shape[0], x.shape[1], output_size];
1226        let out = x
1227            .client
1228            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
1229
1230        let desc = AdaptiveAvgPool1dOpIr {
1231            x: x.into_ir(),
1232            output_size,
1233            out: out.to_ir_out(),
1234        };
1235        out.client.register(
1236            streams,
1237            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),
1238            AdaptiveAvgPool1dOps::<B>::new(desc),
1239        );
1240
1241        out
1242    }
1243
1244    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
1245        make_ops!(
1246            AdaptiveAvgPool2dOps,
1247            AdaptiveAvgPool2dOpIr,
1248            |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
1249                let x = handles.get_float_tensor::<B>(&args.x);
1250                let output = B::adaptive_avg_pool2d(x, args.output_size);
1251
1252                handles.register_float_tensor::<B>(&args.out.id, output);
1253            }
1254        );
1255
1256        let mut streams = OperationStreams::default();
1257        streams.tensor(&x);
1258
1259        let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
1260        let out = x
1261            .client
1262            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
1263
1264        let desc = AdaptiveAvgPool2dOpIr {
1265            x: x.into_ir(),
1266            output_size,
1267            out: out.to_ir_out(),
1268        };
1269        out.client.register(
1270            streams,
1271            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),
1272            AdaptiveAvgPool2dOps::<B>::new(desc),
1273        );
1274
1275        out
1276    }
1277
1278    fn adaptive_avg_pool1d_backward(
1279        x: FloatTensor<Self>,
1280        grad: FloatTensor<Self>,
1281    ) -> FloatTensor<Self> {
1282        make_ops!(
1283            AdaptiveAvgPool1dBackwardOps,
1284            AdaptiveAvgPool1dBackwardOpIr,
1285            |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1286                let x = handles.get_float_tensor::<B>(&args.x);
1287                let grad = handles.get_float_tensor::<B>(&args.grad);
1288                let output = B::adaptive_avg_pool1d_backward(x, grad);
1289
1290                handles.register_float_tensor::<B>(&args.out.id, output);
1291            }
1292        );
1293
1294        let mut streams = OperationStreams::default();
1295        streams.tensor(&x);
1296        streams.tensor(&grad);
1297
1298        let out = x
1299            .client
1300            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
1301        let desc = AdaptiveAvgPool1dBackwardOpIr {
1302            x: x.into_ir(),
1303            grad: grad.into_ir(),
1304            out: out.to_ir_out(),
1305        };
1306
1307        out.client.register(
1308            streams,
1309            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),
1310            AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
1311        );
1312
1313        out
1314    }
1315
1316    fn adaptive_avg_pool2d_backward(
1317        x: FloatTensor<Self>,
1318        grad: FloatTensor<Self>,
1319    ) -> FloatTensor<Self> {
1320        make_ops!(
1321            AdaptiveAvgPool2dBackwardOps,
1322            AdaptiveAvgPool2dBackwardOpIr,
1323            |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1324                let x = handles.get_float_tensor::<B>(&args.x);
1325                let grad = handles.get_float_tensor::<B>(&args.grad);
1326                let output = B::adaptive_avg_pool2d_backward(x, grad);
1327
1328                handles.register_float_tensor::<B>(&args.out.id, output);
1329            }
1330        );
1331
1332        let mut streams = OperationStreams::default();
1333        streams.tensor(&x);
1334        streams.tensor(&grad);
1335
1336        let out = x
1337            .client
1338            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
1339
1340        let desc = AdaptiveAvgPool2dBackwardOpIr {
1341            x: x.into_ir(),
1342            grad: grad.into_ir(),
1343            out: out.to_ir_out(),
1344        };
1345        out.client.register(
1346            streams,
1347            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),
1348            AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
1349        );
1350
1351        out
1352    }
1353
1354    fn interpolate(
1355        x: FloatTensor<Self>,
1356        output_size: [usize; 2],
1357        options: InterpolateOptions,
1358    ) -> FloatTensor<Self> {
1359        make_ops!(
1360            InterpolateOps,
1361            InterpolateOpIr,
1362            |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {
1363                let x = handles.get_float_tensor::<B>(&args.x);
1364                let output = B::interpolate(x, args.output_size, args.options.clone().into());
1365                handles.register_float_tensor::<B>(&args.out.id, output);
1366            }
1367        );
1368
1369        let mut streams = OperationStreams::default();
1370        streams.tensor(&x);
1371
1372        let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
1373        let out = x
1374            .client
1375            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());
1376
1377        let desc = InterpolateOpIr {
1378            x: x.into_ir(),
1379            output_size,
1380            options: options.into(),
1381            out: out.to_ir_out(),
1382        };
1383
1384        out.client.register(
1385            streams,
1386            OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),
1387            InterpolateOps::<B>::new(desc),
1388        );
1389
1390        out
1391    }
1392
1393    fn interpolate_backward(
1394        x: FloatTensor<Self>,
1395        grad: FloatTensor<Self>,
1396        output_size: [usize; 2],
1397        options: InterpolateOptions,
1398    ) -> FloatTensor<Self> {
1399        make_ops!(
1400            InterpolateBackwardOps,
1401            InterpolateBackwardOpIr,
1402            |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1403                let x = handles.get_float_tensor::<B>(&args.x);
1404                let grad = handles.get_float_tensor::<B>(&args.grad);
1405                let output =
1406                    B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());
1407
1408                handles.register_float_tensor::<B>(&args.out.id, output);
1409            }
1410        );
1411
1412        let mut streams = OperationStreams::default();
1413        streams.tensor(&x);
1414        streams.tensor(&grad);
1415
1416        let out = x
1417            .client
1418            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
1419
1420        let desc = InterpolateBackwardOpIr {
1421            x: x.into_ir(),
1422            grad: grad.into_ir(),
1423            output_size,
1424            options: options.into(),
1425            out: out.to_ir_out(),
1426        };
1427        out.client.register(
1428            streams,
1429            OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),
1430            InterpolateBackwardOps::<B>::new(desc),
1431        );
1432        out
1433    }
1434}