1use alloc::boxed::Box;
2
3use burn_backend::Element;
4use burn_backend::ops::{
5 AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
6 DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices,
7 MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
8};
9use burn_backend::tensor::{BoolTensor, FloatTensor, IntElem, IntTensor};
10use burn_ir::*;
11
12use crate::{BackendRouter, RunnerChannel, RunnerClient};
13
14impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
15 fn linear(
16 x: FloatTensor<Self>,
17 weight: FloatTensor<Self>,
18 bias: Option<FloatTensor<Self>>,
19 ) -> FloatTensor<Self> {
20 let client = x.client.clone();
21 let desc = LinearOpIr::create(
22 x.into_ir(),
23 weight.into_ir(),
24 bias.map(|bias| bias.into_ir()),
25 || client.create_empty_handle(),
26 );
27
28 client
29 .register(OperationIr::Module(ModuleOperationIr::Linear(desc)))
30 .output()
31 }
32
33 fn linear_x_backward(
34 weight: FloatTensor<Self>,
35 output_grad: FloatTensor<Self>,
36 ) -> FloatTensor<Self> {
37 let client = weight.client.clone();
38 let desc = LinearXBackwardOpIr::create(weight.into_ir(), output_grad.into_ir(), || {
39 client.create_empty_handle()
40 });
41
42 client
43 .register(OperationIr::Module(ModuleOperationIr::LinearXBackward(
44 desc,
45 )))
46 .output()
47 }
48
49 fn linear_weight_backward(
50 x: FloatTensor<Self>,
51 output_grad: FloatTensor<Self>,
52 ) -> FloatTensor<Self> {
53 let client = x.client.clone();
54 let desc = LinearWeightBackwardOpIr::create(x.into_ir(), output_grad.into_ir(), || {
55 client.create_empty_handle()
56 });
57
58 client
59 .register(OperationIr::Module(
60 ModuleOperationIr::LinearWeightBackward(desc),
61 ))
62 .output()
63 }
64
65 fn linear_bias_backward(output_grad: FloatTensor<Self>) -> FloatTensor<Self> {
66 let client = output_grad.client.clone();
67 let desc =
68 LinearBiasBackwardOpIr::create(output_grad.into_ir(), || client.create_empty_handle());
69
70 client
71 .register(OperationIr::Module(ModuleOperationIr::LinearBiasBackward(
72 desc,
73 )))
74 .output()
75 }
76
77 fn conv1d(
78 x: FloatTensor<Self>,
79 weight: FloatTensor<Self>,
80 bias: Option<FloatTensor<Self>>,
81 options: ConvOptions<1>,
82 ) -> FloatTensor<Self> {
83 let client = x.client.clone();
84 let desc = Conv1dOpIr::create(
85 x.into_ir(),
86 weight.into_ir(),
87 bias.map(|bias| bias.into_ir()),
88 options.into(),
89 || client.create_empty_handle(),
90 );
91
92 client
93 .register(OperationIr::Module(ModuleOperationIr::Conv1d(desc)))
94 .output()
95 }
96
97 fn conv1d_x_backward(
98 x: FloatTensor<Self>,
99 weight: FloatTensor<Self>,
100 output_grad: FloatTensor<Self>,
101 options: ConvOptions<1>,
102 ) -> FloatTensor<Self> {
103 let client = x.client.clone();
104 let desc = Conv1dXBackwardOpIr::create(
105 x.into_ir(),
106 weight.into_ir(),
107 output_grad.into_ir(),
108 options.into(),
109 || client.create_empty_handle(),
110 );
111
112 client
113 .register(OperationIr::Module(ModuleOperationIr::Conv1dXBackward(
114 desc,
115 )))
116 .output()
117 }
118
119 fn conv1d_weight_backward(
120 x: FloatTensor<Self>,
121 weight: FloatTensor<Self>,
122 output_grad: FloatTensor<Self>,
123 options: ConvOptions<1>,
124 ) -> FloatTensor<Self> {
125 let client = x.client.clone();
126 let desc = Conv1dWeightBackwardOpIr::create(
127 x.into_ir(),
128 weight.into_ir(),
129 output_grad.into_ir(),
130 options.into(),
131 || client.create_empty_handle(),
132 );
133
134 client
135 .register(OperationIr::Module(
136 ModuleOperationIr::Conv1dWeightBackward(desc),
137 ))
138 .output()
139 }
140
141 fn conv1d_bias_backward(
142 x: FloatTensor<Self>,
143 bias: FloatTensor<Self>,
144 output_grad: FloatTensor<Self>,
145 ) -> FloatTensor<Self> {
146 let client = x.client.clone();
147 let desc = Conv1dBiasBackwardOpIr::create(
148 x.into_ir(),
149 bias.into_ir(),
150 output_grad.into_ir(),
151 || client.create_empty_handle(),
152 );
153
154 client
155 .register(OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(
156 desc,
157 )))
158 .output()
159 }
160
161 fn conv2d(
162 x: FloatTensor<Self>,
163 weight: FloatTensor<Self>,
164 bias: Option<FloatTensor<Self>>,
165 options: ConvOptions<2>,
166 ) -> FloatTensor<Self> {
167 let client = x.client.clone();
168 let desc = Conv2dOpIr::create(
169 x.into_ir(),
170 weight.into_ir(),
171 bias.map(|bias| bias.into_ir()),
172 options.into(),
173 || client.create_empty_handle(),
174 );
175
176 client
177 .register(OperationIr::Module(ModuleOperationIr::Conv2d(desc)))
178 .output()
179 }
180
181 fn conv2d_x_backward(
182 x: FloatTensor<Self>,
183 weight: FloatTensor<Self>,
184 output_grad: FloatTensor<Self>,
185 options: ConvOptions<2>,
186 ) -> FloatTensor<Self> {
187 let client = x.client.clone();
188 let desc = Conv2dXBackwardOpIr::create(
189 x.into_ir(),
190 weight.into_ir(),
191 output_grad.into_ir(),
192 options.into(),
193 || client.create_empty_handle(),
194 );
195
196 client
197 .register(OperationIr::Module(ModuleOperationIr::Conv2dXBackward(
198 desc,
199 )))
200 .output()
201 }
202
203 fn conv2d_weight_backward(
204 x: FloatTensor<Self>,
205 weight: FloatTensor<Self>,
206 output_grad: FloatTensor<Self>,
207 options: ConvOptions<2>,
208 ) -> FloatTensor<Self> {
209 let client = x.client.clone();
210 let desc = Conv2dWeightBackwardOpIr::create(
211 x.into_ir(),
212 weight.into_ir(),
213 output_grad.into_ir(),
214 options.into(),
215 || client.create_empty_handle(),
216 );
217
218 client
219 .register(OperationIr::Module(
220 ModuleOperationIr::Conv2dWeightBackward(desc),
221 ))
222 .output()
223 }
224
225 fn conv2d_bias_backward(
226 x: FloatTensor<Self>,
227 bias: FloatTensor<Self>,
228 output_grad: FloatTensor<Self>,
229 ) -> FloatTensor<Self> {
230 let client = x.client.clone();
231 let desc = Conv2dBiasBackwardOpIr::create(
232 x.into_ir(),
233 bias.into_ir(),
234 output_grad.into_ir(),
235 || client.create_empty_handle(),
236 );
237
238 client
239 .register(OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(
240 desc,
241 )))
242 .output()
243 }
244
245 fn conv3d(
246 x: FloatTensor<Self>,
247 weight: FloatTensor<Self>,
248 bias: Option<FloatTensor<Self>>,
249 options: ConvOptions<3>,
250 ) -> FloatTensor<Self> {
251 let client = x.client.clone();
252 let desc = Conv3dOpIr::create(
253 x.into_ir(),
254 weight.into_ir(),
255 bias.map(|bias| bias.into_ir()),
256 options.into(),
257 || client.create_empty_handle(),
258 );
259
260 client
261 .register(OperationIr::Module(ModuleOperationIr::Conv3d(desc)))
262 .output()
263 }
264
265 fn conv3d_x_backward(
266 x: FloatTensor<Self>,
267 weight: FloatTensor<Self>,
268 output_grad: FloatTensor<Self>,
269 options: ConvOptions<3>,
270 ) -> FloatTensor<Self> {
271 let client = x.client.clone();
272 let desc = Conv3dXBackwardOpIr::create(
273 x.into_ir(),
274 weight.into_ir(),
275 output_grad.into_ir(),
276 options.into(),
277 || client.create_empty_handle(),
278 );
279
280 client
281 .register(OperationIr::Module(ModuleOperationIr::Conv3dXBackward(
282 desc,
283 )))
284 .output()
285 }
286
287 fn conv3d_weight_backward(
288 x: FloatTensor<Self>,
289 weight: FloatTensor<Self>,
290 output_grad: FloatTensor<Self>,
291 options: ConvOptions<3>,
292 ) -> FloatTensor<Self> {
293 let client = x.client.clone();
294 let desc = Conv3dWeightBackwardOpIr::create(
295 x.into_ir(),
296 weight.into_ir(),
297 output_grad.into_ir(),
298 options.into(),
299 || client.create_empty_handle(),
300 );
301
302 client
303 .register(OperationIr::Module(
304 ModuleOperationIr::Conv3dWeightBackward(desc),
305 ))
306 .output()
307 }
308
309 fn conv3d_bias_backward(
310 x: FloatTensor<Self>,
311 bias: FloatTensor<Self>,
312 output_grad: FloatTensor<Self>,
313 ) -> FloatTensor<Self> {
314 let client = x.client.clone();
315 let desc = Conv3dBiasBackwardOpIr::create(
316 x.into_ir(),
317 bias.into_ir(),
318 output_grad.into_ir(),
319 || client.create_empty_handle(),
320 );
321
322 client
323 .register(OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(
324 desc,
325 )))
326 .output()
327 }
328
329 fn conv_transpose1d(
330 x: FloatTensor<Self>,
331 weight: FloatTensor<Self>,
332 bias: Option<FloatTensor<Self>>,
333 options: ConvTransposeOptions<1>,
334 ) -> FloatTensor<Self> {
335 let client = x.client.clone();
336 let desc = ConvTranspose1dOpIr::create(
337 x.into_ir(),
338 weight.into_ir(),
339 bias.map(|bias| bias.into_ir()),
340 options.into(),
341 || client.create_empty_handle(),
342 );
343
344 client
345 .register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d(
346 desc,
347 )))
348 .output()
349 }
350
351 fn conv_transpose2d(
352 x: FloatTensor<Self>,
353 weight: FloatTensor<Self>,
354 bias: Option<FloatTensor<Self>>,
355 options: ConvTransposeOptions<2>,
356 ) -> FloatTensor<Self> {
357 let client = x.client.clone();
358 let desc = ConvTranspose2dOpIr::create(
359 x.into_ir(),
360 weight.into_ir(),
361 bias.map(|bias| bias.into_ir()),
362 options.into(),
363 || client.create_empty_handle(),
364 );
365
366 client
367 .register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d(
368 desc,
369 )))
370 .output()
371 }
372
373 fn conv_transpose3d(
374 x: FloatTensor<Self>,
375 weight: FloatTensor<Self>,
376 bias: Option<FloatTensor<Self>>,
377 options: ConvTransposeOptions<3>,
378 ) -> FloatTensor<Self> {
379 let client = x.client.clone();
380 let desc = ConvTranspose3dOpIr::create(
381 x.into_ir(),
382 weight.into_ir(),
383 bias.map(|bias| bias.into_ir()),
384 options.into(),
385 || client.create_empty_handle(),
386 );
387
388 client
389 .register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d(
390 desc,
391 )))
392 .output()
393 }
394
395 fn avg_pool1d(
396 x: FloatTensor<Self>,
397 kernel_size: usize,
398 stride: usize,
399 padding: usize,
400 count_include_pad: bool,
401 ceil_mode: bool,
402 ) -> FloatTensor<Self> {
403 let client = x.client.clone();
404 let desc = AvgPool1dOpIr::create(
405 x.into_ir(),
406 kernel_size,
407 stride,
408 padding,
409 count_include_pad,
410 ceil_mode,
411 || client.create_empty_handle(),
412 );
413
414 client
415 .register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc)))
416 .output()
417 }
418
419 fn avg_pool2d(
420 x: FloatTensor<Self>,
421 kernel_size: [usize; 2],
422 stride: [usize; 2],
423 padding: [usize; 2],
424 count_include_pad: bool,
425 ceil_mode: bool,
426 ) -> FloatTensor<Self> {
427 let client = x.client.clone();
428 let desc = AvgPool2dOpIr::create(
429 x.into_ir(),
430 kernel_size,
431 stride,
432 padding,
433 count_include_pad,
434 ceil_mode,
435 || client.create_empty_handle(),
436 );
437
438 client
439 .register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc)))
440 .output()
441 }
442
443 fn avg_pool1d_backward(
444 x: FloatTensor<Self>,
445 grad: FloatTensor<Self>,
446 kernel_size: usize,
447 stride: usize,
448 padding: usize,
449 count_include_pad: bool,
450 ceil_mode: bool,
451 ) -> FloatTensor<Self> {
452 let client = x.client.clone();
453 let desc = AvgPool1dBackwardOpIr::create(
454 x.into_ir(),
455 grad.into_ir(),
456 kernel_size,
457 stride,
458 padding,
459 count_include_pad,
460 ceil_mode,
461 || client.create_empty_handle(),
462 );
463
464 client
465 .register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(
466 desc,
467 )))
468 .output()
469 }
470
471 fn avg_pool2d_backward(
472 x: FloatTensor<Self>,
473 grad: FloatTensor<Self>,
474 kernel_size: [usize; 2],
475 stride: [usize; 2],
476 padding: [usize; 2],
477 count_include_pad: bool,
478 ceil_mode: bool,
479 ) -> FloatTensor<Self> {
480 let client = x.client.clone();
481 let desc = AvgPool2dBackwardOpIr::create(
482 x.into_ir(),
483 grad.into_ir(),
484 kernel_size,
485 stride,
486 padding,
487 count_include_pad,
488 ceil_mode,
489 || client.create_empty_handle(),
490 );
491
492 client
493 .register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(
494 desc,
495 )))
496 .output()
497 }
498
499 fn max_pool1d(
500 x: FloatTensor<Self>,
501 kernel_size: usize,
502 stride: usize,
503 padding: usize,
504 dilation: usize,
505 ceil_mode: bool,
506 ) -> FloatTensor<Self> {
507 let client = x.client.clone();
508 let desc = MaxPool1dOpIr::create(
509 x.into_ir(),
510 kernel_size,
511 stride,
512 padding,
513 dilation,
514 ceil_mode,
515 || client.create_empty_handle(),
516 );
517
518 client
519 .register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc)))
520 .output()
521 }
522
523 fn max_pool2d(
524 x: FloatTensor<Self>,
525 kernel_size: [usize; 2],
526 stride: [usize; 2],
527 padding: [usize; 2],
528 dilation: [usize; 2],
529 ceil_mode: bool,
530 ) -> FloatTensor<Self> {
531 let client = x.client.clone();
532 let desc = MaxPool2dOpIr::create(
533 x.into_ir(),
534 kernel_size,
535 stride,
536 padding,
537 dilation,
538 ceil_mode,
539 || client.create_empty_handle(),
540 );
541
542 client
543 .register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc)))
544 .output()
545 }
546
547 fn max_pool1d_with_indices(
548 x: FloatTensor<Self>,
549 kernel_size: usize,
550 stride: usize,
551 padding: usize,
552 dilation: usize,
553 ceil_mode: bool,
554 ) -> MaxPool1dWithIndices<Self> {
555 let client = x.client.clone();
556 let desc = MaxPool1dWithIndicesOpIr::create(
557 x.into_ir(),
558 kernel_size,
559 stride,
560 padding,
561 dilation,
562 ceil_mode,
563 IntElem::<Self>::dtype(),
564 || client.create_empty_handle(),
565 );
566
567 let [out, out_indices] = client
568 .register(OperationIr::Module(
569 ModuleOperationIr::MaxPool1dWithIndices(desc),
570 ))
571 .outputs();
572
573 MaxPool1dWithIndices::new(out, out_indices)
574 }
575
576 fn max_pool2d_with_indices(
577 x: FloatTensor<Self>,
578 kernel_size: [usize; 2],
579 stride: [usize; 2],
580 padding: [usize; 2],
581 dilation: [usize; 2],
582 ceil_mode: bool,
583 ) -> MaxPool2dWithIndices<Self> {
584 let client = x.client.clone();
585 let desc = MaxPool2dWithIndicesOpIr::create(
586 x.into_ir(),
587 kernel_size,
588 stride,
589 padding,
590 dilation,
591 ceil_mode,
592 IntElem::<Self>::dtype(),
593 || client.create_empty_handle(),
594 );
595
596 let [out, out_indices] = client
597 .register(OperationIr::Module(
598 ModuleOperationIr::MaxPool2dWithIndices(desc),
599 ))
600 .outputs();
601
602 MaxPool2dWithIndices::new(out, out_indices)
603 }
604
605 fn max_pool1d_with_indices_backward(
606 x: FloatTensor<Self>,
607 kernel_size: usize,
608 stride: usize,
609 padding: usize,
610 dilation: usize,
611 ceil_mode: bool,
612 output_grad: FloatTensor<Self>,
613 indices: IntTensor<Self>,
614 ) -> MaxPool1dBackward<Self> {
615 let client = x.client.clone();
616
617 let desc = MaxPool1dWithIndicesBackwardOpIr::create(
618 x.into_ir(),
619 output_grad.into_ir(),
620 indices.into_ir(),
621 kernel_size,
622 stride,
623 padding,
624 dilation,
625 ceil_mode,
626 || client.create_empty_handle(),
627 );
628
629 let out = client
630 .register(OperationIr::Module(
631 ModuleOperationIr::MaxPool1dWithIndicesBackward(desc),
632 ))
633 .output();
634
635 MaxPool1dBackward::new(out)
636 }
637
638 fn max_pool2d_with_indices_backward(
639 x: FloatTensor<Self>,
640 kernel_size: [usize; 2],
641 stride: [usize; 2],
642 padding: [usize; 2],
643 dilation: [usize; 2],
644 ceil_mode: bool,
645 output_grad: FloatTensor<Self>,
646 indices: IntTensor<Self>,
647 ) -> MaxPool2dBackward<Self> {
648 let client = x.client.clone();
649
650 let desc = MaxPool2dWithIndicesBackwardOpIr::create(
651 x.into_ir(),
652 output_grad.into_ir(),
653 indices.into_ir(),
654 kernel_size,
655 stride,
656 padding,
657 dilation,
658 ceil_mode,
659 || client.create_empty_handle(),
660 );
661
662 let out = client
663 .register(OperationIr::Module(
664 ModuleOperationIr::MaxPool2dWithIndicesBackward(desc),
665 ))
666 .output();
667
668 MaxPool2dBackward::new(out)
669 }
670
671 fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
672 let client = x.client.clone();
673
674 let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
675 client.create_empty_handle()
676 });
677
678 client
679 .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(
680 desc,
681 )))
682 .output()
683 }
684
685 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
686 let client = x.client.clone();
687
688 let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
689 client.create_empty_handle()
690 });
691
692 client
693 .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(
694 desc,
695 )))
696 .output()
697 }
698
699 fn adaptive_avg_pool1d_backward(
700 x: FloatTensor<Self>,
701 grad: FloatTensor<Self>,
702 ) -> FloatTensor<Self> {
703 let client = x.client.clone();
704
705 let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
706 client.create_empty_handle()
707 });
708
709 client
710 .register(OperationIr::Module(
711 ModuleOperationIr::AdaptiveAvgPool1dBackward(desc),
712 ))
713 .output()
714 }
715
716 fn adaptive_avg_pool2d_backward(
717 x: FloatTensor<Self>,
718 grad: FloatTensor<Self>,
719 ) -> FloatTensor<Self> {
720 let client = x.client.clone();
721
722 let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
723 client.create_empty_handle()
724 });
725
726 client
727 .register(OperationIr::Module(
728 ModuleOperationIr::AdaptiveAvgPool2dBackward(desc),
729 ))
730 .output()
731 }
732
733 fn interpolate(
734 x: FloatTensor<Self>,
735 output_size: [usize; 2],
736 options: InterpolateOptions,
737 ) -> FloatTensor<Self> {
738 let client = x.client.clone();
739 let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
740 client.create_empty_handle()
741 });
742
743 client
744 .register(OperationIr::Module(ModuleOperationIr::Interpolate(desc)))
745 .output()
746 }
747
748 fn interpolate_backward(
749 x: FloatTensor<Self>,
750 grad: FloatTensor<Self>,
751 output_size: [usize; 2],
752 options: InterpolateOptions,
753 ) -> FloatTensor<Self> {
754 let client = x.client.clone();
755 let desc = InterpolateBackwardOpIr::create(
756 x.into_ir(),
757 grad.into_ir(),
758 output_size,
759 options.into(),
760 || client.create_empty_handle(),
761 );
762
763 client
764 .register(OperationIr::Module(ModuleOperationIr::InterpolateBackward(
765 desc,
766 )))
767 .output()
768 }
769
770 fn deform_conv2d(
771 x: FloatTensor<Self>,
772 offset: FloatTensor<Self>,
773 weight: FloatTensor<Self>,
774 mask: Option<FloatTensor<Self>>,
775 bias: Option<FloatTensor<Self>>,
776 options: DeformConvOptions<2>,
777 ) -> FloatTensor<Self> {
778 let client = x.client.clone();
779 let desc = DeformConv2dOpIr::create(
780 x.into_ir(),
781 offset.into_ir(),
782 weight.into_ir(),
783 mask.map(|mask| mask.into_ir()),
784 bias.map(|bias| bias.into_ir()),
785 options.into(),
786 || client.create_empty_handle(),
787 );
788
789 client
790 .register(OperationIr::Module(ModuleOperationIr::DeformableConv2d(
791 Box::new(desc),
792 )))
793 .output()
794 }
795
796 fn deform_conv2d_backward(
797 x: FloatTensor<Self>,
798 offset: FloatTensor<Self>,
799 weight: FloatTensor<Self>,
800 mask: Option<FloatTensor<Self>>,
801 bias: Option<FloatTensor<Self>>,
802 output_grad: FloatTensor<Self>,
803 options: DeformConvOptions<2>,
804 ) -> DeformConv2dBackward<Self> {
805 let client = x.client.clone();
806 let has_bias = bias.is_some();
807 let has_mask = mask.is_some();
808
809 let desc = DeformConv2dBackwardOpIr::create(
810 x.into_ir(),
811 offset.into_ir(),
812 weight.into_ir(),
813 mask.map(|mask| mask.into_ir()),
814 bias.map(|bias| bias.into_ir()),
815 output_grad.into_ir(),
816 options.into(),
817 || client.create_empty_handle(),
818 );
819 let mut outputs = client
820 .register(OperationIr::Module(
821 ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)),
822 ))
823 .into_iter();
824
825 let input_grad = outputs.next().unwrap();
827 let offset_grad = outputs.next().unwrap();
828 let weight_grad = outputs.next().unwrap();
829 let mask_grad = has_mask.then(|| outputs.next().unwrap());
830 let bias_grad = has_bias.then(|| outputs.next().unwrap());
831
832 DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
833 }
834
835 fn attention(
836 query: FloatTensor<Self>,
837 key: FloatTensor<Self>,
838 value: FloatTensor<Self>,
839 mask: Option<BoolTensor<Self>>,
840 attn_bias: Option<FloatTensor<Self>>,
841 options: AttentionModuleOptions,
842 ) -> FloatTensor<Self> {
843 let client = query.client.clone();
844 let desc = AttentionOpIr::create(
845 query.into_ir(),
846 key.into_ir(),
847 value.into_ir(),
848 mask.map(|m: BoolTensor<Self>| m.into_ir()),
849 attn_bias.map(|ab| ab.into_ir()),
850 options.into(),
851 || client.create_empty_handle(),
852 );
853
854 client
855 .register(OperationIr::Module(ModuleOperationIr::Attention(desc)))
856 .output()
857 }
858
859 fn rfft(
860 _signal: FloatTensor<Self>,
861 _dim: usize,
862 _n: Option<usize>,
863 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
864 todo!("rfft is not supported for backend-router")
865 }
866
867 fn irfft(
868 _spectrum_re: FloatTensor<Self>,
869 _spectrum_im: FloatTensor<Self>,
870 _dim: usize,
871 _n: Option<usize>,
872 ) -> FloatTensor<Self> {
873 todo!("irfft is not supported for backend-router")
874 }
875}