// burn_router/ops/module.rs
1use alloc::boxed::Box;
2
3use burn_backend::Element;
4use burn_backend::ops::{
5    AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
6    DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices,
7    MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
8};
9use burn_backend::tensor::{BoolTensor, FloatTensor, IntElem, IntTensor};
10use burn_ir::*;
11
12use crate::{BackendRouter, RunnerChannel, RunnerClient};
13
14impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
15    fn linear(
16        x: FloatTensor<Self>,
17        weight: FloatTensor<Self>,
18        bias: Option<FloatTensor<Self>>,
19    ) -> FloatTensor<Self> {
20        let client = x.client.clone();
21        let desc = LinearOpIr::create(
22            x.into_ir(),
23            weight.into_ir(),
24            bias.map(|bias| bias.into_ir()),
25            || client.create_empty_handle(),
26        );
27
28        client
29            .register(OperationIr::Module(ModuleOperationIr::Linear(desc)))
30            .output()
31    }
32
33    fn linear_x_backward(
34        weight: FloatTensor<Self>,
35        output_grad: FloatTensor<Self>,
36    ) -> FloatTensor<Self> {
37        let client = weight.client.clone();
38        let desc = LinearXBackwardOpIr::create(weight.into_ir(), output_grad.into_ir(), || {
39            client.create_empty_handle()
40        });
41
42        client
43            .register(OperationIr::Module(ModuleOperationIr::LinearXBackward(
44                desc,
45            )))
46            .output()
47    }
48
49    fn linear_weight_backward(
50        x: FloatTensor<Self>,
51        output_grad: FloatTensor<Self>,
52    ) -> FloatTensor<Self> {
53        let client = x.client.clone();
54        let desc = LinearWeightBackwardOpIr::create(x.into_ir(), output_grad.into_ir(), || {
55            client.create_empty_handle()
56        });
57
58        client
59            .register(OperationIr::Module(
60                ModuleOperationIr::LinearWeightBackward(desc),
61            ))
62            .output()
63    }
64
65    fn linear_bias_backward(output_grad: FloatTensor<Self>) -> FloatTensor<Self> {
66        let client = output_grad.client.clone();
67        let desc =
68            LinearBiasBackwardOpIr::create(output_grad.into_ir(), || client.create_empty_handle());
69
70        client
71            .register(OperationIr::Module(ModuleOperationIr::LinearBiasBackward(
72                desc,
73            )))
74            .output()
75    }
76
77    fn conv1d(
78        x: FloatTensor<Self>,
79        weight: FloatTensor<Self>,
80        bias: Option<FloatTensor<Self>>,
81        options: ConvOptions<1>,
82    ) -> FloatTensor<Self> {
83        let client = x.client.clone();
84        let desc = Conv1dOpIr::create(
85            x.into_ir(),
86            weight.into_ir(),
87            bias.map(|bias| bias.into_ir()),
88            options.into(),
89            || client.create_empty_handle(),
90        );
91
92        client
93            .register(OperationIr::Module(ModuleOperationIr::Conv1d(desc)))
94            .output()
95    }
96
97    fn conv1d_x_backward(
98        x: FloatTensor<Self>,
99        weight: FloatTensor<Self>,
100        output_grad: FloatTensor<Self>,
101        options: ConvOptions<1>,
102    ) -> FloatTensor<Self> {
103        let client = x.client.clone();
104        let desc = Conv1dXBackwardOpIr::create(
105            x.into_ir(),
106            weight.into_ir(),
107            output_grad.into_ir(),
108            options.into(),
109            || client.create_empty_handle(),
110        );
111
112        client
113            .register(OperationIr::Module(ModuleOperationIr::Conv1dXBackward(
114                desc,
115            )))
116            .output()
117    }
118
119    fn conv1d_weight_backward(
120        x: FloatTensor<Self>,
121        weight: FloatTensor<Self>,
122        output_grad: FloatTensor<Self>,
123        options: ConvOptions<1>,
124    ) -> FloatTensor<Self> {
125        let client = x.client.clone();
126        let desc = Conv1dWeightBackwardOpIr::create(
127            x.into_ir(),
128            weight.into_ir(),
129            output_grad.into_ir(),
130            options.into(),
131            || client.create_empty_handle(),
132        );
133
134        client
135            .register(OperationIr::Module(
136                ModuleOperationIr::Conv1dWeightBackward(desc),
137            ))
138            .output()
139    }
140
141    fn conv1d_bias_backward(
142        x: FloatTensor<Self>,
143        bias: FloatTensor<Self>,
144        output_grad: FloatTensor<Self>,
145    ) -> FloatTensor<Self> {
146        let client = x.client.clone();
147        let desc = Conv1dBiasBackwardOpIr::create(
148            x.into_ir(),
149            bias.into_ir(),
150            output_grad.into_ir(),
151            || client.create_empty_handle(),
152        );
153
154        client
155            .register(OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(
156                desc,
157            )))
158            .output()
159    }
160
161    fn conv2d(
162        x: FloatTensor<Self>,
163        weight: FloatTensor<Self>,
164        bias: Option<FloatTensor<Self>>,
165        options: ConvOptions<2>,
166    ) -> FloatTensor<Self> {
167        let client = x.client.clone();
168        let desc = Conv2dOpIr::create(
169            x.into_ir(),
170            weight.into_ir(),
171            bias.map(|bias| bias.into_ir()),
172            options.into(),
173            || client.create_empty_handle(),
174        );
175
176        client
177            .register(OperationIr::Module(ModuleOperationIr::Conv2d(desc)))
178            .output()
179    }
180
181    fn conv2d_x_backward(
182        x: FloatTensor<Self>,
183        weight: FloatTensor<Self>,
184        output_grad: FloatTensor<Self>,
185        options: ConvOptions<2>,
186    ) -> FloatTensor<Self> {
187        let client = x.client.clone();
188        let desc = Conv2dXBackwardOpIr::create(
189            x.into_ir(),
190            weight.into_ir(),
191            output_grad.into_ir(),
192            options.into(),
193            || client.create_empty_handle(),
194        );
195
196        client
197            .register(OperationIr::Module(ModuleOperationIr::Conv2dXBackward(
198                desc,
199            )))
200            .output()
201    }
202
203    fn conv2d_weight_backward(
204        x: FloatTensor<Self>,
205        weight: FloatTensor<Self>,
206        output_grad: FloatTensor<Self>,
207        options: ConvOptions<2>,
208    ) -> FloatTensor<Self> {
209        let client = x.client.clone();
210        let desc = Conv2dWeightBackwardOpIr::create(
211            x.into_ir(),
212            weight.into_ir(),
213            output_grad.into_ir(),
214            options.into(),
215            || client.create_empty_handle(),
216        );
217
218        client
219            .register(OperationIr::Module(
220                ModuleOperationIr::Conv2dWeightBackward(desc),
221            ))
222            .output()
223    }
224
225    fn conv2d_bias_backward(
226        x: FloatTensor<Self>,
227        bias: FloatTensor<Self>,
228        output_grad: FloatTensor<Self>,
229    ) -> FloatTensor<Self> {
230        let client = x.client.clone();
231        let desc = Conv2dBiasBackwardOpIr::create(
232            x.into_ir(),
233            bias.into_ir(),
234            output_grad.into_ir(),
235            || client.create_empty_handle(),
236        );
237
238        client
239            .register(OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(
240                desc,
241            )))
242            .output()
243    }
244
245    fn conv3d(
246        x: FloatTensor<Self>,
247        weight: FloatTensor<Self>,
248        bias: Option<FloatTensor<Self>>,
249        options: ConvOptions<3>,
250    ) -> FloatTensor<Self> {
251        let client = x.client.clone();
252        let desc = Conv3dOpIr::create(
253            x.into_ir(),
254            weight.into_ir(),
255            bias.map(|bias| bias.into_ir()),
256            options.into(),
257            || client.create_empty_handle(),
258        );
259
260        client
261            .register(OperationIr::Module(ModuleOperationIr::Conv3d(desc)))
262            .output()
263    }
264
265    fn conv3d_x_backward(
266        x: FloatTensor<Self>,
267        weight: FloatTensor<Self>,
268        output_grad: FloatTensor<Self>,
269        options: ConvOptions<3>,
270    ) -> FloatTensor<Self> {
271        let client = x.client.clone();
272        let desc = Conv3dXBackwardOpIr::create(
273            x.into_ir(),
274            weight.into_ir(),
275            output_grad.into_ir(),
276            options.into(),
277            || client.create_empty_handle(),
278        );
279
280        client
281            .register(OperationIr::Module(ModuleOperationIr::Conv3dXBackward(
282                desc,
283            )))
284            .output()
285    }
286
287    fn conv3d_weight_backward(
288        x: FloatTensor<Self>,
289        weight: FloatTensor<Self>,
290        output_grad: FloatTensor<Self>,
291        options: ConvOptions<3>,
292    ) -> FloatTensor<Self> {
293        let client = x.client.clone();
294        let desc = Conv3dWeightBackwardOpIr::create(
295            x.into_ir(),
296            weight.into_ir(),
297            output_grad.into_ir(),
298            options.into(),
299            || client.create_empty_handle(),
300        );
301
302        client
303            .register(OperationIr::Module(
304                ModuleOperationIr::Conv3dWeightBackward(desc),
305            ))
306            .output()
307    }
308
309    fn conv3d_bias_backward(
310        x: FloatTensor<Self>,
311        bias: FloatTensor<Self>,
312        output_grad: FloatTensor<Self>,
313    ) -> FloatTensor<Self> {
314        let client = x.client.clone();
315        let desc = Conv3dBiasBackwardOpIr::create(
316            x.into_ir(),
317            bias.into_ir(),
318            output_grad.into_ir(),
319            || client.create_empty_handle(),
320        );
321
322        client
323            .register(OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(
324                desc,
325            )))
326            .output()
327    }
328
329    fn conv_transpose1d(
330        x: FloatTensor<Self>,
331        weight: FloatTensor<Self>,
332        bias: Option<FloatTensor<Self>>,
333        options: ConvTransposeOptions<1>,
334    ) -> FloatTensor<Self> {
335        let client = x.client.clone();
336        let desc = ConvTranspose1dOpIr::create(
337            x.into_ir(),
338            weight.into_ir(),
339            bias.map(|bias| bias.into_ir()),
340            options.into(),
341            || client.create_empty_handle(),
342        );
343
344        client
345            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d(
346                desc,
347            )))
348            .output()
349    }
350
351    fn conv_transpose2d(
352        x: FloatTensor<Self>,
353        weight: FloatTensor<Self>,
354        bias: Option<FloatTensor<Self>>,
355        options: ConvTransposeOptions<2>,
356    ) -> FloatTensor<Self> {
357        let client = x.client.clone();
358        let desc = ConvTranspose2dOpIr::create(
359            x.into_ir(),
360            weight.into_ir(),
361            bias.map(|bias| bias.into_ir()),
362            options.into(),
363            || client.create_empty_handle(),
364        );
365
366        client
367            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d(
368                desc,
369            )))
370            .output()
371    }
372
373    fn conv_transpose3d(
374        x: FloatTensor<Self>,
375        weight: FloatTensor<Self>,
376        bias: Option<FloatTensor<Self>>,
377        options: ConvTransposeOptions<3>,
378    ) -> FloatTensor<Self> {
379        let client = x.client.clone();
380        let desc = ConvTranspose3dOpIr::create(
381            x.into_ir(),
382            weight.into_ir(),
383            bias.map(|bias| bias.into_ir()),
384            options.into(),
385            || client.create_empty_handle(),
386        );
387
388        client
389            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d(
390                desc,
391            )))
392            .output()
393    }
394
395    fn avg_pool1d(
396        x: FloatTensor<Self>,
397        kernel_size: usize,
398        stride: usize,
399        padding: usize,
400        count_include_pad: bool,
401        ceil_mode: bool,
402    ) -> FloatTensor<Self> {
403        let client = x.client.clone();
404        let desc = AvgPool1dOpIr::create(
405            x.into_ir(),
406            kernel_size,
407            stride,
408            padding,
409            count_include_pad,
410            ceil_mode,
411            || client.create_empty_handle(),
412        );
413
414        client
415            .register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc)))
416            .output()
417    }
418
419    fn avg_pool2d(
420        x: FloatTensor<Self>,
421        kernel_size: [usize; 2],
422        stride: [usize; 2],
423        padding: [usize; 2],
424        count_include_pad: bool,
425        ceil_mode: bool,
426    ) -> FloatTensor<Self> {
427        let client = x.client.clone();
428        let desc = AvgPool2dOpIr::create(
429            x.into_ir(),
430            kernel_size,
431            stride,
432            padding,
433            count_include_pad,
434            ceil_mode,
435            || client.create_empty_handle(),
436        );
437
438        client
439            .register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc)))
440            .output()
441    }
442
443    fn avg_pool1d_backward(
444        x: FloatTensor<Self>,
445        grad: FloatTensor<Self>,
446        kernel_size: usize,
447        stride: usize,
448        padding: usize,
449        count_include_pad: bool,
450        ceil_mode: bool,
451    ) -> FloatTensor<Self> {
452        let client = x.client.clone();
453        let desc = AvgPool1dBackwardOpIr::create(
454            x.into_ir(),
455            grad.into_ir(),
456            kernel_size,
457            stride,
458            padding,
459            count_include_pad,
460            ceil_mode,
461            || client.create_empty_handle(),
462        );
463
464        client
465            .register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(
466                desc,
467            )))
468            .output()
469    }
470
471    fn avg_pool2d_backward(
472        x: FloatTensor<Self>,
473        grad: FloatTensor<Self>,
474        kernel_size: [usize; 2],
475        stride: [usize; 2],
476        padding: [usize; 2],
477        count_include_pad: bool,
478        ceil_mode: bool,
479    ) -> FloatTensor<Self> {
480        let client = x.client.clone();
481        let desc = AvgPool2dBackwardOpIr::create(
482            x.into_ir(),
483            grad.into_ir(),
484            kernel_size,
485            stride,
486            padding,
487            count_include_pad,
488            ceil_mode,
489            || client.create_empty_handle(),
490        );
491
492        client
493            .register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(
494                desc,
495            )))
496            .output()
497    }
498
499    fn max_pool1d(
500        x: FloatTensor<Self>,
501        kernel_size: usize,
502        stride: usize,
503        padding: usize,
504        dilation: usize,
505        ceil_mode: bool,
506    ) -> FloatTensor<Self> {
507        let client = x.client.clone();
508        let desc = MaxPool1dOpIr::create(
509            x.into_ir(),
510            kernel_size,
511            stride,
512            padding,
513            dilation,
514            ceil_mode,
515            || client.create_empty_handle(),
516        );
517
518        client
519            .register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc)))
520            .output()
521    }
522
523    fn max_pool2d(
524        x: FloatTensor<Self>,
525        kernel_size: [usize; 2],
526        stride: [usize; 2],
527        padding: [usize; 2],
528        dilation: [usize; 2],
529        ceil_mode: bool,
530    ) -> FloatTensor<Self> {
531        let client = x.client.clone();
532        let desc = MaxPool2dOpIr::create(
533            x.into_ir(),
534            kernel_size,
535            stride,
536            padding,
537            dilation,
538            ceil_mode,
539            || client.create_empty_handle(),
540        );
541
542        client
543            .register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc)))
544            .output()
545    }
546
547    fn max_pool1d_with_indices(
548        x: FloatTensor<Self>,
549        kernel_size: usize,
550        stride: usize,
551        padding: usize,
552        dilation: usize,
553        ceil_mode: bool,
554    ) -> MaxPool1dWithIndices<Self> {
555        let client = x.client.clone();
556        let desc = MaxPool1dWithIndicesOpIr::create(
557            x.into_ir(),
558            kernel_size,
559            stride,
560            padding,
561            dilation,
562            ceil_mode,
563            IntElem::<Self>::dtype(),
564            || client.create_empty_handle(),
565        );
566
567        let [out, out_indices] = client
568            .register(OperationIr::Module(
569                ModuleOperationIr::MaxPool1dWithIndices(desc),
570            ))
571            .outputs();
572
573        MaxPool1dWithIndices::new(out, out_indices)
574    }
575
576    fn max_pool2d_with_indices(
577        x: FloatTensor<Self>,
578        kernel_size: [usize; 2],
579        stride: [usize; 2],
580        padding: [usize; 2],
581        dilation: [usize; 2],
582        ceil_mode: bool,
583    ) -> MaxPool2dWithIndices<Self> {
584        let client = x.client.clone();
585        let desc = MaxPool2dWithIndicesOpIr::create(
586            x.into_ir(),
587            kernel_size,
588            stride,
589            padding,
590            dilation,
591            ceil_mode,
592            IntElem::<Self>::dtype(),
593            || client.create_empty_handle(),
594        );
595
596        let [out, out_indices] = client
597            .register(OperationIr::Module(
598                ModuleOperationIr::MaxPool2dWithIndices(desc),
599            ))
600            .outputs();
601
602        MaxPool2dWithIndices::new(out, out_indices)
603    }
604
605    fn max_pool1d_with_indices_backward(
606        x: FloatTensor<Self>,
607        kernel_size: usize,
608        stride: usize,
609        padding: usize,
610        dilation: usize,
611        ceil_mode: bool,
612        output_grad: FloatTensor<Self>,
613        indices: IntTensor<Self>,
614    ) -> MaxPool1dBackward<Self> {
615        let client = x.client.clone();
616
617        let desc = MaxPool1dWithIndicesBackwardOpIr::create(
618            x.into_ir(),
619            output_grad.into_ir(),
620            indices.into_ir(),
621            kernel_size,
622            stride,
623            padding,
624            dilation,
625            ceil_mode,
626            || client.create_empty_handle(),
627        );
628
629        let out = client
630            .register(OperationIr::Module(
631                ModuleOperationIr::MaxPool1dWithIndicesBackward(desc),
632            ))
633            .output();
634
635        MaxPool1dBackward::new(out)
636    }
637
638    fn max_pool2d_with_indices_backward(
639        x: FloatTensor<Self>,
640        kernel_size: [usize; 2],
641        stride: [usize; 2],
642        padding: [usize; 2],
643        dilation: [usize; 2],
644        ceil_mode: bool,
645        output_grad: FloatTensor<Self>,
646        indices: IntTensor<Self>,
647    ) -> MaxPool2dBackward<Self> {
648        let client = x.client.clone();
649
650        let desc = MaxPool2dWithIndicesBackwardOpIr::create(
651            x.into_ir(),
652            output_grad.into_ir(),
653            indices.into_ir(),
654            kernel_size,
655            stride,
656            padding,
657            dilation,
658            ceil_mode,
659            || client.create_empty_handle(),
660        );
661
662        let out = client
663            .register(OperationIr::Module(
664                ModuleOperationIr::MaxPool2dWithIndicesBackward(desc),
665            ))
666            .output();
667
668        MaxPool2dBackward::new(out)
669    }
670
671    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
672        let client = x.client.clone();
673
674        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
675            client.create_empty_handle()
676        });
677
678        client
679            .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(
680                desc,
681            )))
682            .output()
683    }
684
685    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
686        let client = x.client.clone();
687
688        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
689            client.create_empty_handle()
690        });
691
692        client
693            .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(
694                desc,
695            )))
696            .output()
697    }
698
699    fn adaptive_avg_pool1d_backward(
700        x: FloatTensor<Self>,
701        grad: FloatTensor<Self>,
702    ) -> FloatTensor<Self> {
703        let client = x.client.clone();
704
705        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
706            client.create_empty_handle()
707        });
708
709        client
710            .register(OperationIr::Module(
711                ModuleOperationIr::AdaptiveAvgPool1dBackward(desc),
712            ))
713            .output()
714    }
715
716    fn adaptive_avg_pool2d_backward(
717        x: FloatTensor<Self>,
718        grad: FloatTensor<Self>,
719    ) -> FloatTensor<Self> {
720        let client = x.client.clone();
721
722        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
723            client.create_empty_handle()
724        });
725
726        client
727            .register(OperationIr::Module(
728                ModuleOperationIr::AdaptiveAvgPool2dBackward(desc),
729            ))
730            .output()
731    }
732
733    fn interpolate(
734        x: FloatTensor<Self>,
735        output_size: [usize; 2],
736        options: InterpolateOptions,
737    ) -> FloatTensor<Self> {
738        let client = x.client.clone();
739        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
740            client.create_empty_handle()
741        });
742
743        client
744            .register(OperationIr::Module(ModuleOperationIr::Interpolate(desc)))
745            .output()
746    }
747
748    fn interpolate_backward(
749        x: FloatTensor<Self>,
750        grad: FloatTensor<Self>,
751        output_size: [usize; 2],
752        options: InterpolateOptions,
753    ) -> FloatTensor<Self> {
754        let client = x.client.clone();
755        let desc = InterpolateBackwardOpIr::create(
756            x.into_ir(),
757            grad.into_ir(),
758            output_size,
759            options.into(),
760            || client.create_empty_handle(),
761        );
762
763        client
764            .register(OperationIr::Module(ModuleOperationIr::InterpolateBackward(
765                desc,
766            )))
767            .output()
768    }
769
770    fn deform_conv2d(
771        x: FloatTensor<Self>,
772        offset: FloatTensor<Self>,
773        weight: FloatTensor<Self>,
774        mask: Option<FloatTensor<Self>>,
775        bias: Option<FloatTensor<Self>>,
776        options: DeformConvOptions<2>,
777    ) -> FloatTensor<Self> {
778        let client = x.client.clone();
779        let desc = DeformConv2dOpIr::create(
780            x.into_ir(),
781            offset.into_ir(),
782            weight.into_ir(),
783            mask.map(|mask| mask.into_ir()),
784            bias.map(|bias| bias.into_ir()),
785            options.into(),
786            || client.create_empty_handle(),
787        );
788
789        client
790            .register(OperationIr::Module(ModuleOperationIr::DeformableConv2d(
791                Box::new(desc),
792            )))
793            .output()
794    }
795
796    fn deform_conv2d_backward(
797        x: FloatTensor<Self>,
798        offset: FloatTensor<Self>,
799        weight: FloatTensor<Self>,
800        mask: Option<FloatTensor<Self>>,
801        bias: Option<FloatTensor<Self>>,
802        output_grad: FloatTensor<Self>,
803        options: DeformConvOptions<2>,
804    ) -> DeformConv2dBackward<Self> {
805        let client = x.client.clone();
806        let has_bias = bias.is_some();
807        let has_mask = mask.is_some();
808
809        let desc = DeformConv2dBackwardOpIr::create(
810            x.into_ir(),
811            offset.into_ir(),
812            weight.into_ir(),
813            mask.map(|mask| mask.into_ir()),
814            bias.map(|bias| bias.into_ir()),
815            output_grad.into_ir(),
816            options.into(),
817            || client.create_empty_handle(),
818        );
819        let mut outputs = client
820            .register(OperationIr::Module(
821                ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)),
822            ))
823            .into_iter();
824
825        // When the number of outputs is variable, the order is important
826        let input_grad = outputs.next().unwrap();
827        let offset_grad = outputs.next().unwrap();
828        let weight_grad = outputs.next().unwrap();
829        let mask_grad = has_mask.then(|| outputs.next().unwrap());
830        let bias_grad = has_bias.then(|| outputs.next().unwrap());
831
832        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
833    }
834
835    fn attention(
836        query: FloatTensor<Self>,
837        key: FloatTensor<Self>,
838        value: FloatTensor<Self>,
839        mask: Option<BoolTensor<Self>>,
840        attn_bias: Option<FloatTensor<Self>>,
841        options: AttentionModuleOptions,
842    ) -> FloatTensor<Self> {
843        let client = query.client.clone();
844        let desc = AttentionOpIr::create(
845            query.into_ir(),
846            key.into_ir(),
847            value.into_ir(),
848            mask.map(|m: BoolTensor<Self>| m.into_ir()),
849            attn_bias.map(|ab| ab.into_ir()),
850            options.into(),
851            || client.create_empty_handle(),
852        );
853
854        client
855            .register(OperationIr::Module(ModuleOperationIr::Attention(desc)))
856            .output()
857    }
858
859    fn rfft(
860        _signal: FloatTensor<Self>,
861        _dim: usize,
862        _n: Option<usize>,
863    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
864        todo!("rfft is not supported for backend-router")
865    }
866
867    fn irfft(
868        _spectrum_re: FloatTensor<Self>,
869        _spectrum_im: FloatTensor<Self>,
870        _dim: usize,
871        _n: Option<usize>,
872    ) -> FloatTensor<Self> {
873        todo!("irfft is not supported for backend-router")
874    }
875}