1use crate::{
2 Fusion, FusionBackend,
3 stream::{OperationStreams, execution::Operation},
4};
5use burn_backend::{
6 Element,
7 ops::{
8 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
9 InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
10 MaxPool2dWithIndices, ModuleOps,
11 },
12 tensor::{FloatTensor, IntTensor},
13};
14use burn_ir::*;
15use std::marker::PhantomData;
16
/// Defines a fusion [`Operation`] wrapper type for a single module op.
///
/// `$name` is the struct to generate, `$desc` the operation-IR type it stores,
/// and `$fn` a closure `(&$desc, &mut HandleContainer<B::Handle>)` that is
/// invoked when the fused stream actually executes the operation.
macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: $desc,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                // `$fn` may expand to a closure literal, which clippy would
                // otherwise flag as a redundant closure call.
                #[allow(clippy::redundant_closure_call)]
                $fn(&self.desc, handles)
            }
        }
    };
}
33
34impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
35 fn conv1d(
36 x: FloatTensor<Self>,
37 weight: FloatTensor<Self>,
38 bias: Option<FloatTensor<Self>>,
39 options: ConvOptions<1>,
40 ) -> FloatTensor<Self> {
41 make_ops!(Conv1dOps, Conv1dOpIr, |desc: &Conv1dOpIr,
42 handles: &mut HandleContainer<
43 B::Handle,
44 >| {
45 let x = handles.get_float_tensor::<B>(&desc.x);
46 let weight = handles.get_float_tensor::<B>(&desc.weight);
47 let bias = desc
48 .bias
49 .as_ref()
50 .map(|bias| handles.get_float_tensor::<B>(bias));
51 let output = B::conv1d(x, weight, bias, desc.options.clone().into());
52 handles.register_float_tensor::<B>(&desc.out.id, output);
53 });
54
55 let mut streams = OperationStreams::with_inputs([&x, &weight]);
56 if let Some(bias) = bias.as_ref() {
57 streams.tensor(bias)
58 }
59
60 let client = x.client.clone();
61 let desc = Conv1dOpIr::create(
62 x.into_ir(),
63 weight.into_ir(),
64 bias.map(|bias| bias.into_ir()),
65 options.into(),
66 || client.create_empty_handle(),
67 );
68
69 client
70 .register(
71 streams,
72 OperationIr::Module(ModuleOperationIr::Conv1d(desc.clone())),
73 Conv1dOps::<B>::new(desc),
74 )
75 .output()
76 }
77
78 fn conv2d(
79 x: FloatTensor<Self>,
80 weight: FloatTensor<Self>,
81 bias: Option<FloatTensor<Self>>,
82 options: ConvOptions<2>,
83 ) -> FloatTensor<Self> {
84 make_ops!(Conv2dOps, Conv2dOpIr, |args: &Conv2dOpIr,
85 handles: &mut HandleContainer<
86 B::Handle,
87 >| {
88 let x = handles.get_float_tensor::<B>(&args.x);
89 let weight = handles.get_float_tensor::<B>(&args.weight);
90 let bias = args
91 .bias
92 .as_ref()
93 .map(|bias| handles.get_float_tensor::<B>(bias));
94
95 let output = B::conv2d(x, weight, bias, args.options.clone().into());
96
97 handles.register_float_tensor::<B>(&args.out.id, output);
98 });
99
100 let mut streams = OperationStreams::with_inputs([&x, &weight]);
101 if let Some(bias) = bias.as_ref() {
102 streams.tensor(bias)
103 }
104
105 let client = x.client.clone();
106 let desc = Conv2dOpIr::create(
107 x.into_ir(),
108 weight.into_ir(),
109 bias.map(|bias| bias.into_ir()),
110 options.into(),
111 || client.create_empty_handle(),
112 );
113
114 client
115 .register(
116 streams,
117 OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),
118 Conv2dOps::<B>::new(desc),
119 )
120 .output()
121 }
122
123 fn deform_conv2d(
124 x: FloatTensor<Self>,
125 offset: FloatTensor<Self>,
126 weight: FloatTensor<Self>,
127 mask: Option<FloatTensor<Self>>,
128 bias: Option<FloatTensor<Self>>,
129 options: DeformConvOptions<2>,
130 ) -> FloatTensor<Self> {
131 make_ops!(
132 DeformConv2dOps,
133 DeformConv2dOpIr,
134 |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
135 let x = handles.get_float_tensor::<B>(&args.x);
136 let offset = handles.get_float_tensor::<B>(&args.offset);
137 let weight = handles.get_float_tensor::<B>(&args.weight);
138 let mask = args
139 .mask
140 .as_ref()
141 .map(|mask| handles.get_float_tensor::<B>(mask));
142 let bias = args
143 .bias
144 .as_ref()
145 .map(|bias| handles.get_float_tensor::<B>(bias));
146
147 let output =
148 B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());
149
150 handles.register_float_tensor::<B>(&args.out.id, output);
151 }
152 );
153 let mut streams = OperationStreams::with_inputs([&x, &offset, &weight]);
154 if let Some(bias) = bias.as_ref() {
155 streams.tensor(bias)
156 }
157 if let Some(mask) = mask.as_ref() {
158 streams.tensor(mask)
159 }
160
161 let client = x.client.clone();
162 let desc = DeformConv2dOpIr::create(
163 x.into_ir(),
164 offset.into_ir(),
165 weight.into_ir(),
166 mask.map(|mask| mask.into_ir()),
167 bias.map(|bias| bias.into_ir()),
168 options.into(),
169 || client.create_empty_handle(),
170 );
171
172 client
173 .register(
174 streams,
175 OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),
176 DeformConv2dOps::<B>::new(desc),
177 )
178 .output()
179 }
180
    /// Backward pass of the fused deformable 2D convolution.
    ///
    /// Registers one fused operation producing the gradients for `x`, `offset`
    /// and `weight`, plus the mask/bias gradients when those inputs exist.
    /// NOTE: the registered outputs are consumed positionally below
    /// (input, offset, weight, then optional mask, then optional bias) and
    /// must stay in sync with the IR's output order.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        // Deferred execution: resolves every handle, runs the backend backward
        // kernel, then registers each produced gradient tensor.
        make_ops!(
            DeformConv2dBackwardOps,
            DeformConv2dBackwardOpIr,
            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);

                let output = B::deform_conv2d_backward(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);
                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);
                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);
                // Optional gradients are registered only when both the backend
                // produced them AND the IR reserved an output slot for them.
                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, mask_grad);
                }
                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, bias_grad);
                }
            }
        );

        // Remember which optional inputs exist before they are moved into the
        // IR, so the matching optional outputs can be popped below.
        let has_bias = bias.is_some();
        let has_mask = mask.is_some();

        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight, &output_grad]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias);
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask);
        }

        let client = x.client.clone();
        let desc = DeformConv2dBackwardOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(
                    desc.clone(),
                ))),
                DeformConv2dBackwardOps::<B>::new(desc),
            )
            .into_iter();

        // Positional unpacking: must mirror the IR output order above.
        let input_grad = outputs.next().unwrap();
        let offset_grad = outputs.next().unwrap();
        let weight_grad = outputs.next().unwrap();
        let mask_grad = has_mask.then(|| outputs.next().unwrap());
        let bias_grad = has_bias.then(|| outputs.next().unwrap());

        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }
271
272 fn conv3d(
273 x: FloatTensor<Self>,
274 weight: FloatTensor<Self>,
275 bias: Option<FloatTensor<Self>>,
276 options: ConvOptions<3>,
277 ) -> FloatTensor<Self> {
278 make_ops!(Conv3dOps, Conv3dOpIr, |args: &Conv3dOpIr,
279 handles: &mut HandleContainer<
280 B::Handle,
281 >| {
282 let x = handles.get_float_tensor::<B>(&args.x);
283 let weight = handles.get_float_tensor::<B>(&args.weight);
284 let bias = args
285 .bias
286 .as_ref()
287 .map(|bias| handles.get_float_tensor::<B>(bias));
288
289 let output = B::conv3d(x, weight, bias, args.options.clone().into());
290
291 handles.register_float_tensor::<B>(&args.out.id, output);
292 });
293
294 let mut streams = OperationStreams::with_inputs([&x, &weight]);
295 if let Some(bias) = bias.as_ref() {
296 streams.tensor(bias)
297 }
298
299 let client = x.client.clone();
300 let desc = Conv3dOpIr::create(
301 x.into_ir(),
302 weight.into_ir(),
303 bias.map(|bias| bias.into_ir()),
304 options.into(),
305 || client.create_empty_handle(),
306 );
307
308 client
309 .register(
310 streams,
311 OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),
312 Conv3dOps::<B>::new(desc),
313 )
314 .output()
315 }
316
317 fn conv_transpose1d(
318 x: FloatTensor<Self>,
319 weight: FloatTensor<Self>,
320 bias: Option<FloatTensor<Self>>,
321 options: ConvTransposeOptions<1>,
322 ) -> FloatTensor<Self> {
323 make_ops!(
324 ConvTranspose1dOps,
325 ConvTranspose1dOpIr,
326 |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {
327 let x = handles.get_float_tensor::<B>(&args.x);
328 let weight = handles.get_float_tensor::<B>(&args.weight);
329 let bias = args
330 .bias
331 .as_ref()
332 .map(|bias| handles.get_float_tensor::<B>(bias));
333
334 let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());
335
336 handles.register_float_tensor::<B>(&args.out.id, output);
337 }
338 );
339 let mut streams = OperationStreams::with_inputs([&x, &weight]);
340 if let Some(bias) = bias.as_ref() {
341 streams.tensor(bias)
342 }
343
344 let client = x.client.clone();
345 let desc = ConvTranspose1dOpIr::create(
346 x.into_ir(),
347 weight.into_ir(),
348 bias.map(|bias| bias.into_ir()),
349 options.into(),
350 || client.create_empty_handle(),
351 );
352
353 client
354 .register(
355 streams,
356 OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),
357 ConvTranspose1dOps::<B>::new(desc),
358 )
359 .output()
360 }
361
362 fn conv_transpose2d(
363 x: FloatTensor<Self>,
364 weight: FloatTensor<Self>,
365 bias: Option<FloatTensor<Self>>,
366 options: ConvTransposeOptions<2>,
367 ) -> FloatTensor<Self> {
368 make_ops!(
369 ConvTranspose2dOps,
370 ConvTranspose2dOpIr,
371 |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {
372 let x = handles.get_float_tensor::<B>(&args.x);
373 let weight = handles.get_float_tensor::<B>(&args.weight);
374 let bias = args
375 .bias
376 .as_ref()
377 .map(|bias| handles.get_float_tensor::<B>(bias));
378
379 let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());
380
381 handles.register_float_tensor::<B>(&args.out.id, output);
382 }
383 );
384 let mut streams = OperationStreams::with_inputs([&x, &weight]);
385 if let Some(bias) = bias.as_ref() {
386 streams.tensor(bias)
387 }
388
389 let client = x.client.clone();
390 let desc = ConvTranspose2dOpIr::create(
391 x.into_ir(),
392 weight.into_ir(),
393 bias.map(|bias| bias.into_ir()),
394 options.into(),
395 || client.create_empty_handle(),
396 );
397
398 client
399 .register(
400 streams,
401 OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),
402 ConvTranspose2dOps::<B>::new(desc),
403 )
404 .output()
405 }
406
407 fn conv_transpose3d(
408 x: FloatTensor<Self>,
409 weight: FloatTensor<Self>,
410 bias: Option<FloatTensor<Self>>,
411 options: ConvTransposeOptions<3>,
412 ) -> FloatTensor<Self> {
413 make_ops!(
414 ConvTranspose3dOps,
415 ConvTranspose3dOpIr,
416 |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {
417 let x = handles.get_float_tensor::<B>(&args.x);
418 let weight = handles.get_float_tensor::<B>(&args.weight);
419 let bias = args
420 .bias
421 .as_ref()
422 .map(|bias| handles.get_float_tensor::<B>(bias));
423
424 let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());
425
426 handles.register_float_tensor::<B>(&args.out.id, output);
427 }
428 );
429 let mut streams = OperationStreams::with_inputs([&x, &weight]);
430 if let Some(bias) = bias.as_ref() {
431 streams.tensor(bias)
432 }
433
434 let client = x.client.clone();
435 let desc = ConvTranspose3dOpIr::create(
436 x.into_ir(),
437 weight.into_ir(),
438 bias.map(|bias| bias.into_ir()),
439 options.into(),
440 || client.create_empty_handle(),
441 );
442
443 client
444 .register(
445 streams,
446 OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),
447 ConvTranspose3dOps::<B>::new(desc),
448 )
449 .output()
450 }
451
452 fn avg_pool1d(
453 x: FloatTensor<Self>,
454 kernel_size: usize,
455 stride: usize,
456 padding: usize,
457 count_include_pad: bool,
458 ceil_mode: bool,
459 ) -> FloatTensor<Self> {
460 make_ops!(
461 AvgPool1dOps,
462 AvgPool1dOpIr,
463 |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
464 let x = handles.get_float_tensor::<B>(&args.x);
465 let output = B::avg_pool1d(
466 x,
467 args.kernel_size,
468 args.stride,
469 args.padding,
470 args.count_include_pad,
471 args.ceil_mode,
472 );
473
474 handles.register_float_tensor::<B>(&args.out.id, output);
475 }
476 );
477 let streams = OperationStreams::with_inputs([&x]);
478
479 let client = x.client.clone();
480 let desc = AvgPool1dOpIr::create(
481 x.into_ir(),
482 kernel_size,
483 stride,
484 padding,
485 count_include_pad,
486 ceil_mode,
487 || client.create_empty_handle(),
488 );
489
490 client
491 .register(
492 streams,
493 OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),
494 AvgPool1dOps::<B>::new(desc),
495 )
496 .output()
497 }
498
499 fn avg_pool2d(
500 x: FloatTensor<Self>,
501 kernel_size: [usize; 2],
502 stride: [usize; 2],
503 padding: [usize; 2],
504 count_include_pad: bool,
505 ceil_mode: bool,
506 ) -> FloatTensor<Self> {
507 make_ops!(
508 AvgPool2dOps,
509 AvgPool2dOpIr,
510 |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
511 let x = handles.get_float_tensor::<B>(&args.x);
512 let output = B::avg_pool2d(
513 x,
514 args.kernel_size,
515 args.stride,
516 args.padding,
517 args.count_include_pad,
518 args.ceil_mode,
519 );
520
521 handles.register_float_tensor::<B>(&args.out.id, output);
522 }
523 );
524
525 let streams = OperationStreams::with_inputs([&x]);
526
527 let client = x.client.clone();
528 let desc = AvgPool2dOpIr::create(
529 x.into_ir(),
530 kernel_size,
531 stride,
532 padding,
533 count_include_pad,
534 ceil_mode,
535 || client.create_empty_handle(),
536 );
537
538 client
539 .register(
540 streams,
541 OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),
542 AvgPool2dOps::<B>::new(desc),
543 )
544 .output()
545 }
546
547 fn avg_pool1d_backward(
548 x: FloatTensor<Self>,
549 grad: FloatTensor<Self>,
550 kernel_size: usize,
551 stride: usize,
552 padding: usize,
553 count_include_pad: bool,
554 ceil_mode: bool,
555 ) -> FloatTensor<Self> {
556 make_ops!(
557 AvgPool1dBackwardOps,
558 AvgPool1dBackwardOpIr,
559 |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
560 let x = handles.get_float_tensor::<B>(&args.x);
561 let grad = handles.get_float_tensor::<B>(&args.grad);
562 let output = B::avg_pool1d_backward(
563 x,
564 grad,
565 args.kernel_size,
566 args.stride,
567 args.padding,
568 args.count_include_pad,
569 args.ceil_mode,
570 );
571
572 handles.register_float_tensor::<B>(&args.out.id, output);
573 }
574 );
575
576 let streams = OperationStreams::with_inputs([&x, &grad]);
577
578 let client = x.client.clone();
579 let desc = AvgPool1dBackwardOpIr::create(
580 x.into_ir(),
581 grad.into_ir(),
582 kernel_size,
583 stride,
584 padding,
585 count_include_pad,
586 ceil_mode,
587 || client.create_empty_handle(),
588 );
589
590 client
591 .register(
592 streams,
593 OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),
594 AvgPool1dBackwardOps::<B>::new(desc),
595 )
596 .output()
597 }
598
599 fn avg_pool2d_backward(
600 x: FloatTensor<Self>,
601 grad: FloatTensor<Self>,
602 kernel_size: [usize; 2],
603 stride: [usize; 2],
604 padding: [usize; 2],
605 count_include_pad: bool,
606 ceil_mode: bool,
607 ) -> FloatTensor<Self> {
608 make_ops!(
609 AvgPool2dBackwardOps,
610 AvgPool2dBackwardOpIr,
611 |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
612 let x = handles.get_float_tensor::<B>(&args.x);
613 let grad = handles.get_float_tensor::<B>(&args.grad);
614 let output = B::avg_pool2d_backward(
615 x,
616 grad,
617 args.kernel_size,
618 args.stride,
619 args.padding,
620 args.count_include_pad,
621 args.ceil_mode,
622 );
623
624 handles.register_float_tensor::<B>(&args.out.id, output);
625 }
626 );
627
628 let streams = OperationStreams::with_inputs([&x, &grad]);
629
630 let client = x.client.clone();
631 let desc = AvgPool2dBackwardOpIr::create(
632 x.into_ir(),
633 grad.into_ir(),
634 kernel_size,
635 stride,
636 padding,
637 count_include_pad,
638 ceil_mode,
639 || client.create_empty_handle(),
640 );
641
642 client
643 .register(
644 streams,
645 OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),
646 AvgPool2dBackwardOps::<B>::new(desc),
647 )
648 .output()
649 }
650
651 fn max_pool1d(
652 x: FloatTensor<Self>,
653 kernel_size: usize,
654 stride: usize,
655 padding: usize,
656 dilation: usize,
657 ceil_mode: bool,
658 ) -> FloatTensor<Self> {
659 make_ops!(
660 MaxPool1dOps,
661 MaxPool1dOpIr,
662 |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
663 let x = handles.get_float_tensor::<B>(&args.x);
664 let output = B::max_pool1d(
665 x,
666 args.kernel_size,
667 args.stride,
668 args.padding,
669 args.dilation,
670 args.ceil_mode,
671 );
672
673 handles.register_float_tensor::<B>(&args.out.id, output);
674 }
675 );
676
677 let streams = OperationStreams::with_inputs([&x]);
678
679 let client = x.client.clone();
680 let desc = MaxPool1dOpIr::create(
681 x.into_ir(),
682 kernel_size,
683 stride,
684 padding,
685 dilation,
686 ceil_mode,
687 || client.create_empty_handle(),
688 );
689
690 client
691 .register(
692 streams,
693 OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),
694 MaxPool1dOps::<B>::new(desc),
695 )
696 .output()
697 }
698
699 fn max_pool2d(
700 x: FloatTensor<Self>,
701 kernel_size: [usize; 2],
702 stride: [usize; 2],
703 padding: [usize; 2],
704 dilation: [usize; 2],
705 ceil_mode: bool,
706 ) -> FloatTensor<Self> {
707 make_ops!(
708 MaxPool2dOps,
709 MaxPool2dOpIr,
710 |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
711 let x = handles.get_float_tensor::<B>(&args.x);
712 let output = B::max_pool2d(
713 x,
714 args.kernel_size,
715 args.stride,
716 args.padding,
717 args.dilation,
718 args.ceil_mode,
719 );
720
721 handles.register_float_tensor::<B>(&args.out.id, output);
722 }
723 );
724
725 let streams = OperationStreams::with_inputs([&x]);
726
727 let client = x.client.clone();
728 let desc = MaxPool2dOpIr::create(
729 x.into_ir(),
730 kernel_size,
731 stride,
732 padding,
733 dilation,
734 ceil_mode,
735 || client.create_empty_handle(),
736 );
737
738 client
739 .register(
740 streams,
741 OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),
742 MaxPool2dOps::<B>::new(desc),
743 )
744 .output()
745 }
746
747 fn max_pool1d_with_indices(
748 x: FloatTensor<Self>,
749 kernel_size: usize,
750 stride: usize,
751 padding: usize,
752 dilation: usize,
753 ceil_mode: bool,
754 ) -> MaxPool1dWithIndices<Self> {
755 make_ops!(
756 MaxPool1dWithIndicesOps,
757 MaxPool1dWithIndicesOpIr,
758 |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
759 let x = handles.get_float_tensor::<B>(&args.x);
760 let output = B::max_pool1d_with_indices(
761 x,
762 args.kernel_size,
763 args.stride,
764 args.padding,
765 args.dilation,
766 args.ceil_mode,
767 );
768
769 handles.register_float_tensor::<B>(&args.out.id, output.output);
770 handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
771 }
772 );
773
774 let streams = OperationStreams::with_inputs([&x]);
775
776 let client = x.client.clone();
777 let desc = MaxPool1dWithIndicesOpIr::create(
778 x.into_ir(),
779 kernel_size,
780 stride,
781 padding,
782 dilation,
783 ceil_mode,
784 B::IntElem::dtype(),
785 || client.create_empty_handle(),
786 );
787
788 let [out, out_indices] = client
789 .register(
790 streams,
791 OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),
792 MaxPool1dWithIndicesOps::<B>::new(desc),
793 )
794 .outputs();
795
796 MaxPool1dWithIndices::new(out, out_indices)
797 }
798
799 fn max_pool2d_with_indices(
800 x: FloatTensor<Self>,
801 kernel_size: [usize; 2],
802 stride: [usize; 2],
803 padding: [usize; 2],
804 dilation: [usize; 2],
805 ceil_mode: bool,
806 ) -> MaxPool2dWithIndices<Self> {
807 make_ops!(
808 MaxPool2dWithIndicesOps,
809 MaxPool2dWithIndicesOpIr,
810 |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
811 let x = handles.get_float_tensor::<B>(&args.x);
812 let output = B::max_pool2d_with_indices(
813 x,
814 args.kernel_size,
815 args.stride,
816 args.padding,
817 args.dilation,
818 args.ceil_mode,
819 );
820
821 handles.register_float_tensor::<B>(&args.out.id, output.output);
822 handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
823 }
824 );
825
826 let streams = OperationStreams::with_inputs([&x]);
827
828 let client = x.client.clone();
829 let desc = MaxPool2dWithIndicesOpIr::create(
830 x.into_ir(),
831 kernel_size,
832 stride,
833 padding,
834 dilation,
835 ceil_mode,
836 B::IntElem::dtype(),
837 || client.create_empty_handle(),
838 );
839
840 let [out, out_indices] = client
841 .register(
842 streams,
843 OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),
844 MaxPool2dWithIndicesOps::<B>::new(desc),
845 )
846 .outputs();
847
848 MaxPool2dWithIndices::new(out, out_indices)
849 }
850
851 fn max_pool1d_with_indices_backward(
852 x: FloatTensor<Self>,
853 kernel_size: usize,
854 stride: usize,
855 padding: usize,
856 dilation: usize,
857 ceil_mode: bool,
858 output_grad: FloatTensor<Self>,
859 indices: IntTensor<Self>,
860 ) -> MaxPool1dBackward<Self> {
861 make_ops!(
862 MaxPool1dWithIndicesBackwardOps,
863 MaxPool1dWithIndicesBackwardOpIr,
864 |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
865 let x = handles.get_float_tensor::<B>(&args.x);
866 let grad = handles.get_float_tensor::<B>(&args.grad);
867 let indices = handles.get_int_tensor::<B>(&args.indices);
868 let output = B::max_pool1d_with_indices_backward(
869 x,
870 args.kernel_size,
871 args.stride,
872 args.padding,
873 args.dilation,
874 args.ceil_mode,
875 grad,
876 indices,
877 );
878
879 handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
880 }
881 );
882
883 let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);
884
885 let client = x.client.clone();
886 let desc = MaxPool1dWithIndicesBackwardOpIr::create(
887 x.into_ir(),
888 output_grad.into_ir(),
889 indices.into_ir(),
890 kernel_size,
891 stride,
892 padding,
893 dilation,
894 ceil_mode,
895 || client.create_empty_handle(),
896 );
897
898 let out = client
899 .register(
900 streams,
901 OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(
902 desc.clone(),
903 )),
904 MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
905 )
906 .output();
907
908 MaxPool1dBackward::new(out)
909 }
910
911 fn max_pool2d_with_indices_backward(
912 x: FloatTensor<Self>,
913 kernel_size: [usize; 2],
914 stride: [usize; 2],
915 padding: [usize; 2],
916 dilation: [usize; 2],
917 ceil_mode: bool,
918 output_grad: FloatTensor<Self>,
919 indices: IntTensor<Self>,
920 ) -> MaxPool2dBackward<Self> {
921 make_ops!(
922 MaxPool2dWithIndicesBackwardOps,
923 MaxPool2dWithIndicesBackwardOpIr,
924 |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
925 let x = handles.get_float_tensor::<B>(&args.x);
926 let grad = handles.get_float_tensor::<B>(&args.grad);
927 let indices = handles.get_int_tensor::<B>(&args.indices);
928 let output = B::max_pool2d_with_indices_backward(
929 x,
930 args.kernel_size,
931 args.stride,
932 args.padding,
933 args.dilation,
934 args.ceil_mode,
935 grad,
936 indices,
937 );
938
939 handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
940 }
941 );
942
943 let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);
944
945 let client = x.client.clone();
946 let desc = MaxPool2dWithIndicesBackwardOpIr::create(
947 x.into_ir(),
948 output_grad.into_ir(),
949 indices.into_ir(),
950 kernel_size,
951 stride,
952 padding,
953 dilation,
954 ceil_mode,
955 || client.create_empty_handle(),
956 );
957
958 let out = client
959 .register(
960 streams,
961 OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(
962 desc.clone(),
963 )),
964 MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
965 )
966 .output();
967
968 MaxPool2dBackward::new(out)
969 }
970
971 fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
972 make_ops!(
973 AdaptiveAvgPool1dOps,
974 AdaptiveAvgPool1dOpIr,
975 |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
976 let x = handles.get_float_tensor::<B>(&args.x);
977 let output = B::adaptive_avg_pool1d(x, args.output_size);
978
979 handles.register_float_tensor::<B>(&args.out.id, output);
980 }
981 );
982
983 let streams = OperationStreams::with_inputs([&x]);
984
985 let client = x.client.clone();
986 let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
987 client.create_empty_handle()
988 });
989
990 client
991 .register(
992 streams,
993 OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),
994 AdaptiveAvgPool1dOps::<B>::new(desc),
995 )
996 .output()
997 }
998
999 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
1000 make_ops!(
1001 AdaptiveAvgPool2dOps,
1002 AdaptiveAvgPool2dOpIr,
1003 |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
1004 let x = handles.get_float_tensor::<B>(&args.x);
1005 let output = B::adaptive_avg_pool2d(x, args.output_size);
1006
1007 handles.register_float_tensor::<B>(&args.out.id, output);
1008 }
1009 );
1010
1011 let streams = OperationStreams::with_inputs([&x]);
1012
1013 let client = x.client.clone();
1014 let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
1015 client.create_empty_handle()
1016 });
1017
1018 client
1019 .register(
1020 streams,
1021 OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),
1022 AdaptiveAvgPool2dOps::<B>::new(desc),
1023 )
1024 .output()
1025 }
1026
1027 fn adaptive_avg_pool1d_backward(
1028 x: FloatTensor<Self>,
1029 grad: FloatTensor<Self>,
1030 ) -> FloatTensor<Self> {
1031 make_ops!(
1032 AdaptiveAvgPool1dBackwardOps,
1033 AdaptiveAvgPool1dBackwardOpIr,
1034 |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1035 let x = handles.get_float_tensor::<B>(&args.x);
1036 let grad = handles.get_float_tensor::<B>(&args.grad);
1037 let output = B::adaptive_avg_pool1d_backward(x, grad);
1038
1039 handles.register_float_tensor::<B>(&args.out.id, output);
1040 }
1041 );
1042
1043 let streams = OperationStreams::with_inputs([&x, &grad]);
1044
1045 let client = x.client.clone();
1046 let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
1047 client.create_empty_handle()
1048 });
1049
1050 client
1051 .register(
1052 streams,
1053 OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),
1054 AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
1055 )
1056 .output()
1057 }
1058
1059 fn adaptive_avg_pool2d_backward(
1060 x: FloatTensor<Self>,
1061 grad: FloatTensor<Self>,
1062 ) -> FloatTensor<Self> {
1063 make_ops!(
1064 AdaptiveAvgPool2dBackwardOps,
1065 AdaptiveAvgPool2dBackwardOpIr,
1066 |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1067 let x = handles.get_float_tensor::<B>(&args.x);
1068 let grad = handles.get_float_tensor::<B>(&args.grad);
1069 let output = B::adaptive_avg_pool2d_backward(x, grad);
1070
1071 handles.register_float_tensor::<B>(&args.out.id, output);
1072 }
1073 );
1074 let streams = OperationStreams::with_inputs([&x, &grad]);
1075
1076 let client = x.client.clone();
1077 let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
1078 client.create_empty_handle()
1079 });
1080
1081 client
1082 .register(
1083 streams,
1084 OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),
1085 AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
1086 )
1087 .output()
1088 }
1089
1090 fn interpolate(
1091 x: FloatTensor<Self>,
1092 output_size: [usize; 2],
1093 options: InterpolateOptions,
1094 ) -> FloatTensor<Self> {
1095 make_ops!(
1096 InterpolateOps,
1097 InterpolateOpIr,
1098 |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {
1099 let x = handles.get_float_tensor::<B>(&args.x);
1100 let output = B::interpolate(x, args.output_size, args.options.clone().into());
1101 handles.register_float_tensor::<B>(&args.out.id, output);
1102 }
1103 );
1104
1105 let streams = OperationStreams::with_inputs([&x]);
1106
1107 let client = x.client.clone();
1108 let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
1109 client.create_empty_handle()
1110 });
1111
1112 client
1113 .register(
1114 streams,
1115 OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),
1116 InterpolateOps::<B>::new(desc),
1117 )
1118 .output()
1119 }
1120
1121 fn interpolate_backward(
1122 x: FloatTensor<Self>,
1123 grad: FloatTensor<Self>,
1124 output_size: [usize; 2],
1125 options: InterpolateOptions,
1126 ) -> FloatTensor<Self> {
1127 make_ops!(
1128 InterpolateBackwardOps,
1129 InterpolateBackwardOpIr,
1130 |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
1131 let x = handles.get_float_tensor::<B>(&args.x);
1132 let grad = handles.get_float_tensor::<B>(&args.grad);
1133 let output =
1134 B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());
1135
1136 handles.register_float_tensor::<B>(&args.out.id, output);
1137 }
1138 );
1139
1140 let streams = OperationStreams::with_inputs([&x, &grad]);
1141
1142 let client = x.client.clone();
1143 let desc = InterpolateBackwardOpIr::create(
1144 x.into_ir(),
1145 grad.into_ir(),
1146 output_size,
1147 options.into(),
1148 || client.create_empty_handle(),
1149 );
1150
1151 client
1152 .register(
1153 streams,
1154 OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),
1155 InterpolateBackwardOps::<B>::new(desc),
1156 )
1157 .output()
1158 }
1159}