burn_fusion/ops/module.rs
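
//! Fusion implementations of `ModuleOps`: each operation is described as IR,
//! registered with the fusion client, and executed lazily when its stream runs.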

use crate::{
    Fusion, FusionBackend,
    stream::{OperationStreams, execution::Operation},
};
use burn_backend::{
    Element,
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
        InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
        MaxPool2dWithIndices, ModuleOps,
    },
    tensor::{FloatTensor, IntTensor},
};
use burn_ir::*;
use std::marker::PhantomData;

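/// Wraps an IR description in a small struct implementing [`Operation`], so the
/// captured description can be replayed against concrete tensor handles when the
/// recorded stream is executed.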
macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: $desc,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                #[allow(clippy::redundant_closure_call)]
                $fn(&self.desc, handles)
            }
        }
    };
}

impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
    // Linear and its backward ops fall back to the default `ModuleOps` impl,
    // which decomposes them into matmul + add / matmul + sum. This preserves
    // downstream fusion in burn-cubecl-fusion, which matches on those
    // primitive IR nodes.
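
    // Every method below follows the same pattern: build an IR description of
    // the op, collect the input tensors' streams, and register a deferred
    // `Operation` with the fusion client.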
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        make_ops!(Conv1dOps, Conv1dOpIr, |desc: &Conv1dOpIr,
                                          handles: &mut HandleContainer<
            B::Handle,
        >| {
            let x = handles.get_float_tensor::<B>(&desc.x);
            let weight = handles.get_float_tensor::<B>(&desc.weight);
            let bias = desc
                .bias
                .as_ref()
                .map(|bias| handles.get_float_tensor::<B>(bias));
            let output = B::conv1d(x, weight, bias, desc.options.clone().into());
            handles.register_float_tensor::<B>(&desc.out.id, output);
        });

        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
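        // The closure supplies fresh, empty handles for the description's outputs.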
        let desc = Conv1dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1d(desc.clone())),
                Conv1dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv1d_x_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv1dXBackwardOps,
            Conv1dXBackwardOpIr,
            |desc: &Conv1dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv1d_x_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv1dXBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1dXBackward(desc.clone())),
                Conv1dXBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv1d_weight_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv1dWeightBackwardOps,
            Conv1dWeightBackwardOpIr,
            |desc: &Conv1dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv1d_weight_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv1dWeightBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1dWeightBackward(desc.clone())),
                Conv1dWeightBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv1d_bias_backward(
        x: FloatTensor<Fusion<B>>,
        bias: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv1dBiasBackwardOps,
            Conv1dBiasBackwardOpIr,
            |desc: &Conv1dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let bias = handles.get_float_tensor::<B>(&desc.bias);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output = B::conv1d_bias_backward(x, bias, output_grad);
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);

        let client = x.client.clone();
        let desc = Conv1dBiasBackwardOpIr::create(
            x.into_ir(),
            bias.into_ir(),
            output_grad.into_ir(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(desc.clone())),
                Conv1dBiasBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(Conv2dOps, Conv2dOpIr, |args: &Conv2dOpIr,
                                          handles: &mut HandleContainer<
            B::Handle,
        >| {
            let x = handles.get_float_tensor::<B>(&args.x);
            let weight = handles.get_float_tensor::<B>(&args.weight);
            let bias = args
                .bias
                .as_ref()
                .map(|bias| handles.get_float_tensor::<B>(bias));

            let output = B::conv2d(x, weight, bias, args.options.clone().into());

            handles.register_float_tensor::<B>(&args.out.id, output);
        });

        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = Conv2dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),
                Conv2dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d_x_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv2dXBackwardOps,
            Conv2dXBackwardOpIr,
            |desc: &Conv2dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv2d_x_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv2dXBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2dXBackward(desc.clone())),
                Conv2dXBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d_weight_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv2dWeightBackwardOps,
            Conv2dWeightBackwardOpIr,
            |desc: &Conv2dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv2d_weight_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv2dWeightBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2dWeightBackward(desc.clone())),
                Conv2dWeightBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d_bias_backward(
        x: FloatTensor<Fusion<B>>,
        bias: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv2dBiasBackwardOps,
            Conv2dBiasBackwardOpIr,
            |desc: &Conv2dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let bias = handles.get_float_tensor::<B>(&desc.bias);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output = B::conv2d_bias_backward(x, bias, output_grad);
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);

        let client = x.client.clone();
        let desc = Conv2dBiasBackwardOpIr::create(
            x.into_ir(),
            bias.into_ir(),
            output_grad.into_ir(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(desc.clone())),
                Conv2dBiasBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            DeformConv2dOps,
            DeformConv2dOpIr,
            |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output =
                    B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask)
        }

        let client = x.client.clone();
        let desc = DeformConv2dOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),
                DeformConv2dOps::<B>::new(desc),
            )
            .output()
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        make_ops!(
            DeformConv2dBackwardOps,
            DeformConv2dBackwardOpIr,
            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);

                let output = B::deform_conv2d_backward(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);
                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);
                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);
                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, mask_grad);
                }
                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, bias_grad);
                }
            }
        );

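        // Record which optional gradients will exist before the tensors are moved
        // into the IR description below.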
        let has_bias = bias.is_some();
        let has_mask = mask.is_some();

        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight, &output_grad]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias);
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask);
        }

        let client = x.client.clone();
        let desc = DeformConv2dBackwardOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(
                    desc.clone(),
                ))),
                DeformConv2dBackwardOps::<B>::new(desc),
            )
            .into_iter();

        // The number of outputs is variable, so unpack them in registration order.
        let input_grad = outputs.next().unwrap();
        let offset_grad = outputs.next().unwrap();
        let weight_grad = outputs.next().unwrap();
        let mask_grad = has_mask.then(|| outputs.next().unwrap());
        let bias_grad = has_bias.then(|| outputs.next().unwrap());

        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        make_ops!(Conv3dOps, Conv3dOpIr, |args: &Conv3dOpIr,
                                          handles: &mut HandleContainer<
            B::Handle,
        >| {
            let x = handles.get_float_tensor::<B>(&args.x);
            let weight = handles.get_float_tensor::<B>(&args.weight);
            let bias = args
                .bias
                .as_ref()
                .map(|bias| handles.get_float_tensor::<B>(bias));

            let output = B::conv3d(x, weight, bias, args.options.clone().into());

            handles.register_float_tensor::<B>(&args.out.id, output);
        });

        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = Conv3dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),
                Conv3dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv3d_x_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv3dXBackwardOps,
            Conv3dXBackwardOpIr,
            |desc: &Conv3dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv3d_x_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv3dXBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3dXBackward(desc.clone())),
                Conv3dXBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv3d_weight_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv3dWeightBackwardOps,
            Conv3dWeightBackwardOpIr,
            |desc: &Conv3dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv3d_weight_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv3dWeightBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3dWeightBackward(desc.clone())),
                Conv3dWeightBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv3d_bias_backward(
        x: FloatTensor<Fusion<B>>,
        bias: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv3dBiasBackwardOps,
            Conv3dBiasBackwardOpIr,
            |desc: &Conv3dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let bias = handles.get_float_tensor::<B>(&desc.bias);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output = B::conv3d_bias_backward(x, bias, output_grad);
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);

        let client = x.client.clone();
        let desc = Conv3dBiasBackwardOpIr::create(
            x.into_ir(),
            bias.into_ir(),
            output_grad.into_ir(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(desc.clone())),
                Conv3dBiasBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose1dOps,
            ConvTranspose1dOpIr,
            |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = ConvTranspose1dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),
                ConvTranspose1dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose2dOps,
            ConvTranspose2dOpIr,
            |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = ConvTranspose2dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),
                ConvTranspose2dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose3dOps,
            ConvTranspose3dOpIr,
            |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = ConvTranspose3dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),
                ConvTranspose3dOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool1dOps,
            AvgPool1dOpIr,
            |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::avg_pool1d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AvgPool1dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),
                AvgPool1dOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool2dOps,
            AvgPool2dOpIr,
            |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::avg_pool2d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AvgPool2dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),
                AvgPool2dOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool1dBackwardOps,
            AvgPool1dBackwardOpIr,
            |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::avg_pool1d_backward(
                    x,
                    grad,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AvgPool1dBackwardOpIr::create(
            x.into_ir(),
            grad.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),
                AvgPool1dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool2dBackwardOps,
            AvgPool2dBackwardOpIr,
            |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::avg_pool2d_backward(
                    x,
                    grad,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AvgPool2dBackwardOpIr::create(
            x.into_ir(),
            grad.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),
                AvgPool2dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn max_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            MaxPool1dOps,
            MaxPool1dOpIr,
            |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool1d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool1dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),
                MaxPool1dOps::<B>::new(desc),
            )
            .output()
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            MaxPool2dOps,
            MaxPool2dOpIr,
            |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool2d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool2dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),
                MaxPool2dOps::<B>::new(desc),
            )
            .output()
    }

    fn max_pool1d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<Self> {
        make_ops!(
            MaxPool1dWithIndicesOps,
            MaxPool1dWithIndicesOpIr,
            |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool1d_with_indices(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.output);
                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool1dWithIndicesOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            B::IntElem::dtype(),
            || client.create_empty_handle(),
        );

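        // Two outputs are registered for this op: the pooled values and their indices.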
        let [out, out_indices] = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),
                MaxPool1dWithIndicesOps::<B>::new(desc),
            )
            .outputs();

        MaxPool1dWithIndices::new(out, out_indices)
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Self> {
        make_ops!(
            MaxPool2dWithIndicesOps,
            MaxPool2dWithIndicesOpIr,
            |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool2d_with_indices(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.output);
                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool2dWithIndicesOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            B::IntElem::dtype(),
            || client.create_empty_handle(),
        );

        let [out, out_indices] = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),
                MaxPool2dWithIndicesOps::<B>::new(desc),
            )
            .outputs();

        MaxPool2dWithIndices::new(out, out_indices)
    }

    fn max_pool1d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool1dBackward<Self> {
        make_ops!(
            MaxPool1dWithIndicesBackwardOps,
            MaxPool1dWithIndicesBackwardOpIr,
            |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let indices = handles.get_int_tensor::<B>(&args.indices);
                let output = B::max_pool1d_with_indices_backward(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                    grad,
                    indices,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);

        let client = x.client.clone();
        let desc = MaxPool1dWithIndicesBackwardOpIr::create(
            x.into_ir(),
            output_grad.into_ir(),
            indices.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        let out = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(
                    desc.clone(),
                )),
                MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
            )
            .output();

        MaxPool1dBackward::new(out)
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Self> {
        make_ops!(
            MaxPool2dWithIndicesBackwardOps,
            MaxPool2dWithIndicesBackwardOpIr,
            |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let indices = handles.get_int_tensor::<B>(&args.indices);
                let output = B::max_pool2d_with_indices_backward(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                    grad,
                    indices,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);

        let client = x.client.clone();
        let desc = MaxPool2dWithIndicesBackwardOpIr::create(
            x.into_ir(),
            output_grad.into_ir(),
            indices.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        let out = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(
                    desc.clone(),
                )),
                MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
            )
            .output();

        MaxPool2dBackward::new(out)
    }

    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool1dOps,
            AdaptiveAvgPool1dOpIr,
            |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::adaptive_avg_pool1d(x, args.output_size);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),
                AdaptiveAvgPool1dOps::<B>::new(desc),
            )
            .output()
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool2dOps,
            AdaptiveAvgPool2dOpIr,
            |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::adaptive_avg_pool2d(x, args.output_size);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),
                AdaptiveAvgPool2dOps::<B>::new(desc),
            )
            .output()
    }

    fn adaptive_avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool1dBackwardOps,
            AdaptiveAvgPool1dBackwardOpIr,
            |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::adaptive_avg_pool1d_backward(x, grad);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),
                AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool2dBackwardOps,
            AdaptiveAvgPool2dBackwardOpIr,
            |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::adaptive_avg_pool2d_backward(x, grad);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),
                AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        make_ops!(
            InterpolateOps,
            InterpolateOpIr,
            |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::interpolate(x, args.output_size, args.options.clone().into());
                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),
                InterpolateOps::<B>::new(desc),
            )
            .output()
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        make_ops!(
            InterpolateBackwardOps,
            InterpolateBackwardOpIr,
            |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output =
                    B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = InterpolateBackwardOpIr::create(
            x.into_ir(),
            grad.into_ir(),
            output_size,
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),
                InterpolateBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn attention(
        query: FloatTensor<Fusion<B>>,
        key: FloatTensor<Fusion<B>>,
        value: FloatTensor<Fusion<B>>,
        mask: Option<burn_backend::tensor::BoolTensor<Fusion<B>>>,
        attn_bias: Option<FloatTensor<Fusion<B>>>,
        options: burn_backend::ops::AttentionModuleOptions,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            AttentionOps,
            AttentionOpIr,
            |args: &AttentionOpIr, handles: &mut HandleContainer<B::Handle>| {
                let query = handles.get_float_tensor::<B>(&args.query);
                let key = handles.get_float_tensor::<B>(&args.key);
                let value = handles.get_float_tensor::<B>(&args.value);
                let mask = args.mask.as_ref().map(|m| handles.get_bool_tensor::<B>(m));
                let attn_bias = args
                    .attn_bias
                    .as_ref()
                    .map(|ab| handles.get_float_tensor::<B>(ab));

                let output = B::attention(
                    query,
                    key,
                    value,
                    mask,
                    attn_bias,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

1540        let mut streams = OperationStreams::with_inputs([&query, &key, &value]);
1541        if let Some(mask) = &mask {
1542            streams.tensor(mask);
1543        }
1544        if let Some(attn_bias) = &attn_bias {
1545            streams.tensor(attn_bias);
1546        }
1547
1548        let client = query.client.clone();
1549        let desc = AttentionOpIr::create(
1550            query.into_ir(),
1551            key.into_ir(),
1552            value.into_ir(),
1553            mask.map(|m| m.into_ir()),
1554            attn_bias.map(|ab| ab.into_ir()),
1555            options.into(),
1556            || client.create_empty_handle(),
1557        );
1558
1559        client
1560            .register(
1561                streams,
1562                OperationIr::Module(ModuleOperationIr::Attention(desc.clone())),
1563                AttentionOps::<B>::new(desc),
1564            )
1565            .output()
1566    }
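    // Added commentary on the pattern above: the optional `mask` and
    // `attn_bias` inputs only join `OperationStreams` when they are present,
    // and are carried into the IR descriptor through `Option::map`, so the
    // registered node references exactly the tensors the caller supplied.
    // Presumably this is what lets the executor order the attention node
    // after whichever streams produced those optional inputs.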

    fn rfft(
        signal: FloatTensor<Fusion<B>>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<Fusion<B>>, FloatTensor<Fusion<B>>) {
        make_ops!(RfftOps, RfftOpIr, |desc: &RfftOpIr,
                                      handles: &mut HandleContainer<
            B::Handle,
        >| {
            let signal = handles.get_float_tensor::<B>(&desc.signal);
            let (re, im) = B::rfft(signal, desc.dim, desc.n);

            handles.register_float_tensor::<B>(&desc.out_re.id, re);
            handles.register_float_tensor::<B>(&desc.out_im.id, im);
        });

        let streams = OperationStreams::with_inputs([&signal]);
        let client = signal.client.clone();

        let desc = RfftOpIr::create(signal.into_ir(), dim, n, || client.create_empty_handle());

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Rfft(desc.clone())),
                RfftOps::<B>::new(desc),
            )
            .into_iter();

        (outputs.next().unwrap(), outputs.next().unwrap())
    }
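    // Added commentary on the pattern above: `rfft` produces two outputs, so
    // the registration result is drained with `into_iter()` rather than
    // `.output()`. The tuple assumes outputs are yielded in the order the IR
    // descriptor declares them, i.e. the real part (`out_re`) first and the
    // imaginary part (`out_im`) second.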

    fn irfft(
        spectrum_re: FloatTensor<Fusion<B>>,
        spectrum_im: FloatTensor<Fusion<B>>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(IRfftOps, IRfftOpIr, |desc: &IRfftOpIr,
                                        handles: &mut HandleContainer<
            B::Handle,
        >| {
            let input_re = handles.get_float_tensor::<B>(&desc.input_re);
            let input_im = handles.get_float_tensor::<B>(&desc.input_im);

            let signal = B::irfft(input_re, input_im, desc.dim, desc.n);
            handles.register_float_tensor::<B>(&desc.out_signal.id, signal);
        });

        let streams = OperationStreams::with_inputs([&spectrum_re, &spectrum_im]);
        let client = spectrum_re.client.clone();

        let desc = IRfftOpIr::create(spectrum_re.into_ir(), spectrum_im.into_ir(), dim, n, || {
            client.create_empty_handle()
        });

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::IRfft(desc.clone())),
                IRfftOps::<B>::new(desc),
            )
            .into_iter();

        outputs.next().unwrap()
    }

    fn has_ctc_loss_backward() -> bool {
        B::has_ctc_loss_backward()
    }

    fn ctc_loss(
        log_probs: FloatTensor<Fusion<B>>,
        targets: IntTensor<Fusion<B>>,
        input_lengths: IntTensor<Fusion<B>>,
        target_lengths: IntTensor<Fusion<B>>,
        blank: usize,
    ) -> FloatTensor<Fusion<B>> {
        // CTC is treated as its own non-fuseable IR node, the same way
        // `attention` and `conv2d` are. The execute callback drains the input
        // handles and dispatches to `B::ctc_loss` on the inner backend, which
        // either runs a native kernel (cubecl, libtorch) or the default
        // decomposed implementation; either way it executes on raw
        // inner-backend tensors, never re-entering the fusion stream.
        make_ops!(CtcLossOps, CtcLossOpIr, |args: &CtcLossOpIr,
                                            handles: &mut HandleContainer<
            B::Handle,
        >| {
            let log_probs = handles.get_float_tensor::<B>(&args.log_probs);
            let targets = handles.get_int_tensor::<B>(&args.targets);
            let input_lengths = handles.get_int_tensor::<B>(&args.input_lengths);
            let target_lengths = handles.get_int_tensor::<B>(&args.target_lengths);
            let output = B::ctc_loss(
                log_probs,
                targets,
                input_lengths,
                target_lengths,
                args.blank,
            );
            handles.register_float_tensor::<B>(&args.out.id, output);
        });

        let streams =
            OperationStreams::with_inputs([&log_probs, &targets, &input_lengths, &target_lengths]);
        let client = log_probs.client.clone();
        let desc = CtcLossOpIr::create(
            log_probs.into_ir(),
            targets.into_ir(),
            input_lengths.into_ir(),
            target_lengths.into_ir(),
            blank,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::CtcLoss(desc.clone())),
                CtcLossOps::<B>::new(desc),
            )
            .output()
    }
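    // Added commentary, a rough sketch (an assumption, not taken from this
    // crate) of how a backend with a dedicated CTC kernel could recognise the
    // node registered above while walking the operation stream:
    //
    //     match op {
    //         OperationIr::Module(ModuleOperationIr::CtcLoss(desc)) => {
    //             // resolve desc.log_probs, desc.targets, ... to device
    //             // buffers and launch the native kernel
    //         }
    //         _ => { /* other nodes, possibly fused together */ }
    //     }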

    fn ctc_loss_backward(
        log_probs: FloatTensor<Fusion<B>>,
        targets: IntTensor<Fusion<B>>,
        input_lengths: IntTensor<Fusion<B>>,
        target_lengths: IntTensor<Fusion<B>>,
        grad_loss: FloatTensor<Fusion<B>>,
        blank: usize,
    ) -> FloatTensor<Fusion<B>> {
        // Mirrors `ctc_loss`: a typed IR node that dispatches to the inner
        // backend's native backward kernel.
        make_ops!(
            CtcLossBackwardOps,
            CtcLossBackwardOpIr,
            |args: &CtcLossBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let log_probs = handles.get_float_tensor::<B>(&args.log_probs);
                let targets = handles.get_int_tensor::<B>(&args.targets);
                let input_lengths = handles.get_int_tensor::<B>(&args.input_lengths);
                let target_lengths = handles.get_int_tensor::<B>(&args.target_lengths);
                let grad_loss = handles.get_float_tensor::<B>(&args.grad_loss);
                let output = B::ctc_loss_backward(
                    log_probs,
                    targets,
                    input_lengths,
                    target_lengths,
                    grad_loss,
                    args.blank,
                );
                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([
            &log_probs,
            &targets,
            &input_lengths,
            &target_lengths,
            &grad_loss,
        ]);
        let client = log_probs.client.clone();
        let desc = CtcLossBackwardOpIr::create(
            log_probs.into_ir(),
            targets.into_ir(),
            input_lengths.into_ir(),
            target_lengths.into_ir(),
            grad_loss.into_ir(),
            blank,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::CtcLossBackward(desc.clone())),
                CtcLossBackwardOps::<B>::new(desc),
            )
            .output()
    }
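    // Added commentary, a hedged sketch: a layer above this backend could use
    // the `has_ctc_loss_backward()` flag exposed earlier in this impl to pick
    // between the typed backward op above and a decomposed fallback, roughly:
    //
    //     let grad_log_probs = if B::has_ctc_loss_backward() {
    //         B::ctc_loss_backward(
    //             log_probs, targets, input_lengths, target_lengths, grad, blank,
    //         )
    //     } else {
    //         // differentiate a decomposed forward pass instead
    //         unimplemented!()
    //     };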
}