Skip to main content

burn_fusion/ops/
module.rs

1use crate::{
2    Fusion, FusionBackend,
3    stream::{OperationStreams, execution::Operation},
4};
5use burn_backend::{
6    Element,
7    ops::{
8        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
9        InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
10        MaxPool2dWithIndices, ModuleOps,
11    },
12    tensor::{FloatTensor, IntTensor},
13};
14use burn_ir::*;
15use std::marker::PhantomData;
16
/// Generates a one-off deferred-operation type for the fusion stream.
///
/// * `$name` - name of the generated operation struct (local to the calling fn).
/// * `$desc` - IR descriptor type captured inside the struct.
/// * `$fn`   - closure `(&$desc, &mut HandleContainer<B::Handle>)` that performs
///   the actual backend call when the fused stream executes.
macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: $desc,
            // Ties the struct to a backend type without storing one.
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                // `$fn` is usually a closure literal invoked immediately,
                // which trips this clippy lint; the call is intentional.
                #[allow(clippy::redundant_closure_call)]
                $fn(&self.desc, handles)
            }
        }
    };
}
33
34impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
35    fn conv1d(
36        x: FloatTensor<Self>,
37        weight: FloatTensor<Self>,
38        bias: Option<FloatTensor<Self>>,
39        options: ConvOptions<1>,
40    ) -> FloatTensor<Self> {
41        make_ops!(Conv1dOps, Conv1dOpIr, |desc: &Conv1dOpIr,
42                                          handles: &mut HandleContainer<
43            B::Handle,
44        >| {
45            let x = handles.get_float_tensor::<B>(&desc.x);
46            let weight = handles.get_float_tensor::<B>(&desc.weight);
47            let bias = desc
48                .bias
49                .as_ref()
50                .map(|bias| handles.get_float_tensor::<B>(bias));
51            let output = B::conv1d(x, weight, bias, desc.options.clone().into());
52            handles.register_float_tensor::<B>(&desc.out.id, output);
53        });
54
55        let mut streams = OperationStreams::with_inputs([&x, &weight]);
56        if let Some(bias) = bias.as_ref() {
57            streams.tensor(bias)
58        }
59
60        let client = x.client.clone();
61        let desc = Conv1dOpIr::create(
62            x.into_ir(),
63            weight.into_ir(),
64            bias.map(|bias| bias.into_ir()),
65            options.into(),
66            || client.create_empty_handle(),
67        );
68
69        client
70            .register(
71                streams,
72                OperationIr::Module(ModuleOperationIr::Conv1d(desc.clone())),
73                Conv1dOps::<B>::new(desc),
74            )
75            .output()
76    }
77
78    fn conv2d(
79        x: FloatTensor<Self>,
80        weight: FloatTensor<Self>,
81        bias: Option<FloatTensor<Self>>,
82        options: ConvOptions<2>,
83    ) -> FloatTensor<Self> {
84        make_ops!(Conv2dOps, Conv2dOpIr, |args: &Conv2dOpIr,
85                                          handles: &mut HandleContainer<
86            B::Handle,
87        >| {
88            let x = handles.get_float_tensor::<B>(&args.x);
89            let weight = handles.get_float_tensor::<B>(&args.weight);
90            let bias = args
91                .bias
92                .as_ref()
93                .map(|bias| handles.get_float_tensor::<B>(bias));
94
95            let output = B::conv2d(x, weight, bias, args.options.clone().into());
96
97            handles.register_float_tensor::<B>(&args.out.id, output);
98        });
99
100        let mut streams = OperationStreams::with_inputs([&x, &weight]);
101        if let Some(bias) = bias.as_ref() {
102            streams.tensor(bias)
103        }
104
105        let client = x.client.clone();
106        let desc = Conv2dOpIr::create(
107            x.into_ir(),
108            weight.into_ir(),
109            bias.map(|bias| bias.into_ir()),
110            options.into(),
111            || client.create_empty_handle(),
112        );
113
114        client
115            .register(
116                streams,
117                OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),
118                Conv2dOps::<B>::new(desc),
119            )
120            .output()
121    }
122
    // Deformable 2D convolution with optional mask and bias inputs.
    // The op is recorded into the fusion stream and executed lazily.
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            DeformConv2dOps,
            DeformConv2dOpIr,
            |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                // Mask and bias are optional; fetch handles only when recorded.
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output =
                    B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        // Mandatory inputs first; optional tensors join the stream set only
        // when present.
        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask)
        }

        let client = x.client.clone();
        let desc = DeformConv2dOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                // Boxed: this descriptor is larger than most op IR variants.
                OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),
                DeformConv2dOps::<B>::new(desc),
            )
            .output()
    }
180
    // Backward pass of the deformable 2D convolution. Produces gradients for
    // x, offset and weight always, plus mask/bias gradients when those inputs
    // were provided. The number of registered outputs is therefore variable.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        make_ops!(
            DeformConv2dBackwardOps,
            DeformConv2dBackwardOpIr,
            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);

                let output = B::deform_conv2d_backward(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);
                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);
                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);
                // Optional gradients are registered only when BOTH the backend
                // produced them and the IR reserved an output field for them.
                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, mask_grad);
                }
                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, bias_grad);
                }
            }
        );

        // Remember which optional inputs exist BEFORE they are moved into the
        // IR below; this drives how many outputs we consume afterwards.
        let has_bias = bias.is_some();
        let has_mask = mask.is_some();

        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight, &output_grad]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias);
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask);
        }

        let client = x.client.clone();
        let desc = DeformConv2dBackwardOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(
                    desc.clone(),
                ))),
                DeformConv2dBackwardOps::<B>::new(desc),
            )
            .into_iter();

        // When the number of outputs is variable, the order is important
        let input_grad = outputs.next().unwrap();
        let offset_grad = outputs.next().unwrap();
        let weight_grad = outputs.next().unwrap();
        let mask_grad = has_mask.then(|| outputs.next().unwrap());
        let bias_grad = has_bias.then(|| outputs.next().unwrap());

        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }
271
272    fn conv3d(
273        x: FloatTensor<Self>,
274        weight: FloatTensor<Self>,
275        bias: Option<FloatTensor<Self>>,
276        options: ConvOptions<3>,
277    ) -> FloatTensor<Self> {
278        make_ops!(Conv3dOps, Conv3dOpIr, |args: &Conv3dOpIr,
279                                          handles: &mut HandleContainer<
280            B::Handle,
281        >| {
282            let x = handles.get_float_tensor::<B>(&args.x);
283            let weight = handles.get_float_tensor::<B>(&args.weight);
284            let bias = args
285                .bias
286                .as_ref()
287                .map(|bias| handles.get_float_tensor::<B>(bias));
288
289            let output = B::conv3d(x, weight, bias, args.options.clone().into());
290
291            handles.register_float_tensor::<B>(&args.out.id, output);
292        });
293
294        let mut streams = OperationStreams::with_inputs([&x, &weight]);
295        if let Some(bias) = bias.as_ref() {
296            streams.tensor(bias)
297        }
298
299        let client = x.client.clone();
300        let desc = Conv3dOpIr::create(
301            x.into_ir(),
302            weight.into_ir(),
303            bias.map(|bias| bias.into_ir()),
304            options.into(),
305            || client.create_empty_handle(),
306        );
307
308        client
309            .register(
310                streams,
311                OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),
312                Conv3dOps::<B>::new(desc),
313            )
314            .output()
315    }
316
317    fn conv_transpose1d(
318        x: FloatTensor<Self>,
319        weight: FloatTensor<Self>,
320        bias: Option<FloatTensor<Self>>,
321        options: ConvTransposeOptions<1>,
322    ) -> FloatTensor<Self> {
323        make_ops!(
324            ConvTranspose1dOps,
325            ConvTranspose1dOpIr,
326            |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {
327                let x = handles.get_float_tensor::<B>(&args.x);
328                let weight = handles.get_float_tensor::<B>(&args.weight);
329                let bias = args
330                    .bias
331                    .as_ref()
332                    .map(|bias| handles.get_float_tensor::<B>(bias));
333
334                let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());
335
336                handles.register_float_tensor::<B>(&args.out.id, output);
337            }
338        );
339        let mut streams = OperationStreams::with_inputs([&x, &weight]);
340        if let Some(bias) = bias.as_ref() {
341            streams.tensor(bias)
342        }
343
344        let client = x.client.clone();
345        let desc = ConvTranspose1dOpIr::create(
346            x.into_ir(),
347            weight.into_ir(),
348            bias.map(|bias| bias.into_ir()),
349            options.into(),
350            || client.create_empty_handle(),
351        );
352
353        client
354            .register(
355                streams,
356                OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),
357                ConvTranspose1dOps::<B>::new(desc),
358            )
359            .output()
360    }
361
362    fn conv_transpose2d(
363        x: FloatTensor<Self>,
364        weight: FloatTensor<Self>,
365        bias: Option<FloatTensor<Self>>,
366        options: ConvTransposeOptions<2>,
367    ) -> FloatTensor<Self> {
368        make_ops!(
369            ConvTranspose2dOps,
370            ConvTranspose2dOpIr,
371            |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {
372                let x = handles.get_float_tensor::<B>(&args.x);
373                let weight = handles.get_float_tensor::<B>(&args.weight);
374                let bias = args
375                    .bias
376                    .as_ref()
377                    .map(|bias| handles.get_float_tensor::<B>(bias));
378
379                let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());
380
381                handles.register_float_tensor::<B>(&args.out.id, output);
382            }
383        );
384        let mut streams = OperationStreams::with_inputs([&x, &weight]);
385        if let Some(bias) = bias.as_ref() {
386            streams.tensor(bias)
387        }
388
389        let client = x.client.clone();
390        let desc = ConvTranspose2dOpIr::create(
391            x.into_ir(),
392            weight.into_ir(),
393            bias.map(|bias| bias.into_ir()),
394            options.into(),
395            || client.create_empty_handle(),
396        );
397
398        client
399            .register(
400                streams,
401                OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),
402                ConvTranspose2dOps::<B>::new(desc),
403            )
404            .output()
405    }
406
407    fn conv_transpose3d(
408        x: FloatTensor<Self>,
409        weight: FloatTensor<Self>,
410        bias: Option<FloatTensor<Self>>,
411        options: ConvTransposeOptions<3>,
412    ) -> FloatTensor<Self> {
413        make_ops!(
414            ConvTranspose3dOps,
415            ConvTranspose3dOpIr,
416            |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {
417                let x = handles.get_float_tensor::<B>(&args.x);
418                let weight = handles.get_float_tensor::<B>(&args.weight);
419                let bias = args
420                    .bias
421                    .as_ref()
422                    .map(|bias| handles.get_float_tensor::<B>(bias));
423
424                let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());
425
426                handles.register_float_tensor::<B>(&args.out.id, output);
427            }
428        );
429        let mut streams = OperationStreams::with_inputs([&x, &weight]);
430        if let Some(bias) = bias.as_ref() {
431            streams.tensor(bias)
432        }
433
434        let client = x.client.clone();
435        let desc = ConvTranspose3dOpIr::create(
436            x.into_ir(),
437            weight.into_ir(),
438            bias.map(|bias| bias.into_ir()),
439            options.into(),
440            || client.create_empty_handle(),
441        );
442
443        client
444            .register(
445                streams,
446                OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),
447                ConvTranspose3dOps::<B>::new(desc),
448            )
449            .output()
450    }
451
452    fn avg_pool1d(
453        x: FloatTensor<Self>,
454        kernel_size: usize,
455        stride: usize,
456        padding: usize,
457        count_include_pad: bool,
458        ceil_mode: bool,
459    ) -> FloatTensor<Self> {
460        make_ops!(
461            AvgPool1dOps,
462            AvgPool1dOpIr,
463            |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
464                let x = handles.get_float_tensor::<B>(&args.x);
465                let output = B::avg_pool1d(
466                    x,
467                    args.kernel_size,
468                    args.stride,
469                    args.padding,
470                    args.count_include_pad,
471                    args.ceil_mode,
472                );
473
474                handles.register_float_tensor::<B>(&args.out.id, output);
475            }
476        );
477        let streams = OperationStreams::with_inputs([&x]);
478
479        let client = x.client.clone();
480        let desc = AvgPool1dOpIr::create(
481            x.into_ir(),
482            kernel_size,
483            stride,
484            padding,
485            count_include_pad,
486            ceil_mode,
487            || client.create_empty_handle(),
488        );
489
490        client
491            .register(
492                streams,
493                OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),
494                AvgPool1dOps::<B>::new(desc),
495            )
496            .output()
497    }
498
499    fn avg_pool2d(
500        x: FloatTensor<Self>,
501        kernel_size: [usize; 2],
502        stride: [usize; 2],
503        padding: [usize; 2],
504        count_include_pad: bool,
505        ceil_mode: bool,
506    ) -> FloatTensor<Self> {
507        make_ops!(
508            AvgPool2dOps,
509            AvgPool2dOpIr,
510            |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
511                let x = handles.get_float_tensor::<B>(&args.x);
512                let output = B::avg_pool2d(
513                    x,
514                    args.kernel_size,
515                    args.stride,
516                    args.padding,
517                    args.count_include_pad,
518                    args.ceil_mode,
519                );
520
521                handles.register_float_tensor::<B>(&args.out.id, output);
522            }
523        );
524
525        let streams = OperationStreams::with_inputs([&x]);
526
527        let client = x.client.clone();
528        let desc = AvgPool2dOpIr::create(
529            x.into_ir(),
530            kernel_size,
531            stride,
532            padding,
533            count_include_pad,
534            ceil_mode,
535            || client.create_empty_handle(),
536        );
537
538        client
539            .register(
540                streams,
541                OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),
542                AvgPool2dOps::<B>::new(desc),
543            )
544            .output()
545    }
546
547    fn avg_pool1d_backward(
548        x: FloatTensor<Self>,
549        grad: FloatTensor<Self>,
550        kernel_size: usize,
551        stride: usize,
552        padding: usize,
553        count_include_pad: bool,
554        ceil_mode: bool,
555    ) -> FloatTensor<Self> {
556        make_ops!(
557            AvgPool1dBackwardOps,
558            AvgPool1dBackwardOpIr,
559            |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
560                let x = handles.get_float_tensor::<B>(&args.x);
561                let grad = handles.get_float_tensor::<B>(&args.grad);
562                let output = B::avg_pool1d_backward(
563                    x,
564                    grad,
565                    args.kernel_size,
566                    args.stride,
567                    args.padding,
568                    args.count_include_pad,
569                    args.ceil_mode,
570                );
571
572                handles.register_float_tensor::<B>(&args.out.id, output);
573            }
574        );
575
576        let streams = OperationStreams::with_inputs([&x, &grad]);
577
578        let client = x.client.clone();
579        let desc = AvgPool1dBackwardOpIr::create(
580            x.into_ir(),
581            grad.into_ir(),
582            kernel_size,
583            stride,
584            padding,
585            count_include_pad,
586            ceil_mode,
587            || client.create_empty_handle(),
588        );
589
590        client
591            .register(
592                streams,
593                OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),
594                AvgPool1dBackwardOps::<B>::new(desc),
595            )
596            .output()
597    }
598
599    fn avg_pool2d_backward(
600        x: FloatTensor<Self>,
601        grad: FloatTensor<Self>,
602        kernel_size: [usize; 2],
603        stride: [usize; 2],
604        padding: [usize; 2],
605        count_include_pad: bool,
606        ceil_mode: bool,
607    ) -> FloatTensor<Self> {
608        make_ops!(
609            AvgPool2dBackwardOps,
610            AvgPool2dBackwardOpIr,
611            |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
612                let x = handles.get_float_tensor::<B>(&args.x);
613                let grad = handles.get_float_tensor::<B>(&args.grad);
614                let output = B::avg_pool2d_backward(
615                    x,
616                    grad,
617                    args.kernel_size,
618                    args.stride,
619                    args.padding,
620                    args.count_include_pad,
621                    args.ceil_mode,
622                );
623
624                handles.register_float_tensor::<B>(&args.out.id, output);
625            }
626        );
627
628        let streams = OperationStreams::with_inputs([&x, &grad]);
629
630        let client = x.client.clone();
631        let desc = AvgPool2dBackwardOpIr::create(
632            x.into_ir(),
633            grad.into_ir(),
634            kernel_size,
635            stride,
636            padding,
637            count_include_pad,
638            ceil_mode,
639            || client.create_empty_handle(),
640        );
641
642        client
643            .register(
644                streams,
645                OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),
646                AvgPool2dBackwardOps::<B>::new(desc),
647            )
648            .output()
649    }
650
651    fn max_pool1d(
652        x: FloatTensor<Self>,
653        kernel_size: usize,
654        stride: usize,
655        padding: usize,
656        dilation: usize,
657        ceil_mode: bool,
658    ) -> FloatTensor<Self> {
659        make_ops!(
660            MaxPool1dOps,
661            MaxPool1dOpIr,
662            |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
663                let x = handles.get_float_tensor::<B>(&args.x);
664                let output = B::max_pool1d(
665                    x,
666                    args.kernel_size,
667                    args.stride,
668                    args.padding,
669                    args.dilation,
670                    args.ceil_mode,
671                );
672
673                handles.register_float_tensor::<B>(&args.out.id, output);
674            }
675        );
676
677        let streams = OperationStreams::with_inputs([&x]);
678
679        let client = x.client.clone();
680        let desc = MaxPool1dOpIr::create(
681            x.into_ir(),
682            kernel_size,
683            stride,
684            padding,
685            dilation,
686            ceil_mode,
687            || client.create_empty_handle(),
688        );
689
690        client
691            .register(
692                streams,
693                OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),
694                MaxPool1dOps::<B>::new(desc),
695            )
696            .output()
697    }
698
699    fn max_pool2d(
700        x: FloatTensor<Self>,
701        kernel_size: [usize; 2],
702        stride: [usize; 2],
703        padding: [usize; 2],
704        dilation: [usize; 2],
705        ceil_mode: bool,
706    ) -> FloatTensor<Self> {
707        make_ops!(
708            MaxPool2dOps,
709            MaxPool2dOpIr,
710            |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
711                let x = handles.get_float_tensor::<B>(&args.x);
712                let output = B::max_pool2d(
713                    x,
714                    args.kernel_size,
715                    args.stride,
716                    args.padding,
717                    args.dilation,
718                    args.ceil_mode,
719                );
720
721                handles.register_float_tensor::<B>(&args.out.id, output);
722            }
723        );
724
725        let streams = OperationStreams::with_inputs([&x]);
726
727        let client = x.client.clone();
728        let desc = MaxPool2dOpIr::create(
729            x.into_ir(),
730            kernel_size,
731            stride,
732            padding,
733            dilation,
734            ceil_mode,
735            || client.create_empty_handle(),
736        );
737
738        client
739            .register(
740                streams,
741                OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),
742                MaxPool2dOps::<B>::new(desc),
743            )
744            .output()
745    }
746
747    fn max_pool1d_with_indices(
748        x: FloatTensor<Self>,
749        kernel_size: usize,
750        stride: usize,
751        padding: usize,
752        dilation: usize,
753        ceil_mode: bool,
754    ) -> MaxPool1dWithIndices<Self> {
755        make_ops!(
756            MaxPool1dWithIndicesOps,
757            MaxPool1dWithIndicesOpIr,
758            |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
759                let x = handles.get_float_tensor::<B>(&args.x);
760                let output = B::max_pool1d_with_indices(
761                    x,
762                    args.kernel_size,
763                    args.stride,
764                    args.padding,
765                    args.dilation,
766                    args.ceil_mode,
767                );
768
769                handles.register_float_tensor::<B>(&args.out.id, output.output);
770                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
771            }
772        );
773
774        let streams = OperationStreams::with_inputs([&x]);
775
776        let client = x.client.clone();
777        let desc = MaxPool1dWithIndicesOpIr::create(
778            x.into_ir(),
779            kernel_size,
780            stride,
781            padding,
782            dilation,
783            ceil_mode,
784            B::IntElem::dtype(),
785            || client.create_empty_handle(),
786        );
787
788        let [out, out_indices] = client
789            .register(
790                streams,
791                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),
792                MaxPool1dWithIndicesOps::<B>::new(desc),
793            )
794            .outputs();
795
796        MaxPool1dWithIndices::new(out, out_indices)
797    }
798
799    fn max_pool2d_with_indices(
800        x: FloatTensor<Self>,
801        kernel_size: [usize; 2],
802        stride: [usize; 2],
803        padding: [usize; 2],
804        dilation: [usize; 2],
805        ceil_mode: bool,
806    ) -> MaxPool2dWithIndices<Self> {
807        make_ops!(
808            MaxPool2dWithIndicesOps,
809            MaxPool2dWithIndicesOpIr,
810            |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
811                let x = handles.get_float_tensor::<B>(&args.x);
812                let output = B::max_pool2d_with_indices(
813                    x,
814                    args.kernel_size,
815                    args.stride,
816                    args.padding,
817                    args.dilation,
818                    args.ceil_mode,
819                );
820
821                handles.register_float_tensor::<B>(&args.out.id, output.output);
822                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
823            }
824        );
825
826        let streams = OperationStreams::with_inputs([&x]);
827
828        let client = x.client.clone();
829        let desc = MaxPool2dWithIndicesOpIr::create(
830            x.into_ir(),
831            kernel_size,
832            stride,
833            padding,
834            dilation,
835            ceil_mode,
836            B::IntElem::dtype(),
837            || client.create_empty_handle(),
838        );
839
840        let [out, out_indices] = client
841            .register(
842                streams,
843                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),
844                MaxPool2dWithIndicesOps::<B>::new(desc),
845            )
846            .outputs();
847
848        MaxPool2dWithIndices::new(out, out_indices)
849    }
850
851    fn max_pool1d_with_indices_backward(
852        x: FloatTensor<Self>,
853        kernel_size: usize,
854        stride: usize,
855        padding: usize,
856        dilation: usize,
857        ceil_mode: bool,
858        output_grad: FloatTensor<Self>,
859        indices: IntTensor<Self>,
860    ) -> MaxPool1dBackward<Self> {
861        make_ops!(
862            MaxPool1dWithIndicesBackwardOps,
863            MaxPool1dWithIndicesBackwardOpIr,
864            |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
865                let x = handles.get_float_tensor::<B>(&args.x);
866                let grad = handles.get_float_tensor::<B>(&args.grad);
867                let indices = handles.get_int_tensor::<B>(&args.indices);
868                let output = B::max_pool1d_with_indices_backward(
869                    x,
870                    args.kernel_size,
871                    args.stride,
872                    args.padding,
873                    args.dilation,
874                    args.ceil_mode,
875                    grad,
876                    indices,
877                );
878
879                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
880            }
881        );
882
883        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);
884
885        let client = x.client.clone();
886        let desc = MaxPool1dWithIndicesBackwardOpIr::create(
887            x.into_ir(),
888            output_grad.into_ir(),
889            indices.into_ir(),
890            kernel_size,
891            stride,
892            padding,
893            dilation,
894            ceil_mode,
895            || client.create_empty_handle(),
896        );
897
898        let out = client
899            .register(
900                streams,
901                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(
902                    desc.clone(),
903                )),
904                MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
905            )
906            .output();
907
908        MaxPool1dBackward::new(out)
909    }
910
911    fn max_pool2d_with_indices_backward(
912        x: FloatTensor<Self>,
913        kernel_size: [usize; 2],
914        stride: [usize; 2],
915        padding: [usize; 2],
916        dilation: [usize; 2],
917        ceil_mode: bool,
918        output_grad: FloatTensor<Self>,
919        indices: IntTensor<Self>,
920    ) -> MaxPool2dBackward<Self> {
921        make_ops!(
922            MaxPool2dWithIndicesBackwardOps,
923            MaxPool2dWithIndicesBackwardOpIr,
924            |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
925                let x = handles.get_float_tensor::<B>(&args.x);
926                let grad = handles.get_float_tensor::<B>(&args.grad);
927                let indices = handles.get_int_tensor::<B>(&args.indices);
928                let output = B::max_pool2d_with_indices_backward(
929                    x,
930                    args.kernel_size,
931                    args.stride,
932                    args.padding,
933                    args.dilation,
934                    args.ceil_mode,
935                    grad,
936                    indices,
937                );
938
939                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
940            }
941        );
942
943        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);
944
945        let client = x.client.clone();
946        let desc = MaxPool2dWithIndicesBackwardOpIr::create(
947            x.into_ir(),
948            output_grad.into_ir(),
949            indices.into_ir(),
950            kernel_size,
951            stride,
952            padding,
953            dilation,
954            ceil_mode,
955            || client.create_empty_handle(),
956        );
957
958        let out = client
959            .register(
960                streams,
961                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(
962                    desc.clone(),
963                )),
964                MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
965            )
966            .output();
967
968        MaxPool2dBackward::new(out)
969    }
970
971    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
972        make_ops!(
973            AdaptiveAvgPool1dOps,
974            AdaptiveAvgPool1dOpIr,
975            |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
976                let x = handles.get_float_tensor::<B>(&args.x);
977                let output = B::adaptive_avg_pool1d(x, args.output_size);
978
979                handles.register_float_tensor::<B>(&args.out.id, output);
980            }
981        );
982
983        let streams = OperationStreams::with_inputs([&x]);
984
985        let client = x.client.clone();
986        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
987            client.create_empty_handle()
988        });
989
990        client
991            .register(
992                streams,
993                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),
994                AdaptiveAvgPool1dOps::<B>::new(desc),
995            )
996            .output()
997    }
998
999    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
1000        make_ops!(
1001            AdaptiveAvgPool2dOps,
1002            AdaptiveAvgPool2dOpIr,
1003            |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
1004                let x = handles.get_float_tensor::<B>(&args.x);
1005                let output = B::adaptive_avg_pool2d(x, args.output_size);
1006
1007                handles.register_float_tensor::<B>(&args.out.id, output);
1008            }
1009        );
1010
1011        let streams = OperationStreams::with_inputs([&x]);
1012
1013        let client = x.client.clone();
1014        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
1015            client.create_empty_handle()
1016        });
1017
1018        client
1019            .register(
1020                streams,
1021                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),
1022                AdaptiveAvgPool2dOps::<B>::new(desc),
1023            )
1024            .output()
1025    }
1026
1027    fn adaptive_avg_pool1d_backward(
1028        x: FloatTensor<Self>,
1029        grad: FloatTensor<Self>,
1030    ) -> FloatTensor<Self> {
1031        make_ops!(
1032            AdaptiveAvgPool1dBackwardOps,
1033            AdaptiveAvgPool1dBackwardOpIr,
1034            |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1035                let x = handles.get_float_tensor::<B>(&args.x);
1036                let grad = handles.get_float_tensor::<B>(&args.grad);
1037                let output = B::adaptive_avg_pool1d_backward(x, grad);
1038
1039                handles.register_float_tensor::<B>(&args.out.id, output);
1040            }
1041        );
1042
1043        let streams = OperationStreams::with_inputs([&x, &grad]);
1044
1045        let client = x.client.clone();
1046        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
1047            client.create_empty_handle()
1048        });
1049
1050        client
1051            .register(
1052                streams,
1053                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),
1054                AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
1055            )
1056            .output()
1057    }
1058
1059    fn adaptive_avg_pool2d_backward(
1060        x: FloatTensor<Self>,
1061        grad: FloatTensor<Self>,
1062    ) -> FloatTensor<Self> {
1063        make_ops!(
1064            AdaptiveAvgPool2dBackwardOps,
1065            AdaptiveAvgPool2dBackwardOpIr,
1066            |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1067                let x = handles.get_float_tensor::<B>(&args.x);
1068                let grad = handles.get_float_tensor::<B>(&args.grad);
1069                let output = B::adaptive_avg_pool2d_backward(x, grad);
1070
1071                handles.register_float_tensor::<B>(&args.out.id, output);
1072            }
1073        );
1074        let streams = OperationStreams::with_inputs([&x, &grad]);
1075
1076        let client = x.client.clone();
1077        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
1078            client.create_empty_handle()
1079        });
1080
1081        client
1082            .register(
1083                streams,
1084                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),
1085                AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
1086            )
1087            .output()
1088    }
1089
1090    fn interpolate(
1091        x: FloatTensor<Self>,
1092        output_size: [usize; 2],
1093        options: InterpolateOptions,
1094    ) -> FloatTensor<Self> {
1095        make_ops!(
1096            InterpolateOps,
1097            InterpolateOpIr,
1098            |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {
1099                let x = handles.get_float_tensor::<B>(&args.x);
1100                let output = B::interpolate(x, args.output_size, args.options.clone().into());
1101                handles.register_float_tensor::<B>(&args.out.id, output);
1102            }
1103        );
1104
1105        let streams = OperationStreams::with_inputs([&x]);
1106
1107        let client = x.client.clone();
1108        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
1109            client.create_empty_handle()
1110        });
1111
1112        client
1113            .register(
1114                streams,
1115                OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),
1116                InterpolateOps::<B>::new(desc),
1117            )
1118            .output()
1119    }
1120
1121    fn interpolate_backward(
1122        x: FloatTensor<Self>,
1123        grad: FloatTensor<Self>,
1124        output_size: [usize; 2],
1125        options: InterpolateOptions,
1126    ) -> FloatTensor<Self> {
1127        make_ops!(
1128            InterpolateBackwardOps,
1129            InterpolateBackwardOpIr,
1130            |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1131                let x = handles.get_float_tensor::<B>(&args.x);
1132                let grad = handles.get_float_tensor::<B>(&args.grad);
1133                let output =
1134                    B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());
1135
1136                handles.register_float_tensor::<B>(&args.out.id, output);
1137            }
1138        );
1139
1140        let streams = OperationStreams::with_inputs([&x, &grad]);
1141
1142        let client = x.client.clone();
1143        let desc = InterpolateBackwardOpIr::create(
1144            x.into_ir(),
1145            grad.into_ir(),
1146            output_size,
1147            options.into(),
1148            || client.create_empty_handle(),
1149        );
1150
1151        client
1152            .register(
1153                streams,
1154                OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),
1155                InterpolateBackwardOps::<B>::new(desc),
1156            )
1157            .output()
1158    }
1159}