use crate::{
    Fusion, FusionBackend,
    stream::{OperationStreams, execution::Operation},
};
use burn_ir::*;
use burn_tensor::{
    Element, Shape,
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
        IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
        MaxPool2dWithIndices, ModuleOps,
        conv::{
            calculate_conv_output_size, calculate_conv_transpose_output_size,
            calculate_pool_output_size,
        },
    },
};
use std::marker::PhantomData;
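
/// Defines a fused operation type for a module op: the generated struct captures the
/// op's intermediate representation (`$desc`) and, on execution, resolves the tensor
/// handles and dispatches to the backend function given as `$fn`.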
macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: $desc,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                #[allow(clippy::redundant_closure_call)]
                $fn(&self.desc, handles)
            }
        }
    };
}
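
// Every method below follows the same pattern: eagerly compute the output shape from
// the input shapes and options, collect the execution streams of all input tensors,
// allocate an uninitialized output on the client, then register the operation's IR
// together with the op that will execute it when the fusion stream is processed.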
impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        make_ops!(
            Conv1dOps,
            Conv1dOpIr,
            |desc: &Conv1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let bias = desc
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output = B::conv1d(x, weight, bias, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let size = calculate_conv_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.dilation[0],
            x.shape[2],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let shape = vec![x.shape[0], weight.shape[0], size];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let description = Conv1dOpIr {
            x: x.into_ir(),
            weight: weight.into_ir(),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::Conv1d(description.clone())),
            Conv1dOps::<B>::new(description),
        );

        out
    }

    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            Conv2dOps,
            Conv2dOpIr,
            |args: &Conv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv2d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 = calculate_conv_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_conv_output_size(
            weight.shape[3],
            options.stride[1],
            options.padding[1],
            options.dilation[1],
            x.shape[3],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = Conv2dOpIr {
            x: x.into_ir(),
            weight: weight.into_ir(),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),
            Conv2dOps::<B>::new(desc),
        );

        out
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            DeformConv2dOps,
            DeformConv2dOpIr,
            |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output =
                    B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 = calculate_conv_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_conv_output_size(
            weight.shape[3],
            options.stride[1],
            options.padding[1],
            options.dilation[1],
            x.shape[3],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&offset);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask)
        }

        let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = DeformConv2dOpIr {
            x: x.into_ir(),
            offset: offset.into_ir(),
            weight: weight.into_ir(),
            mask: mask.map(|mask| mask.into_ir()),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),
            DeformConv2dOps::<B>::new(desc),
        );

        out
    }
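
    // The backward op registers one gradient per input; the mask and bias gradients are
    // only produced and registered when those optional inputs were provided.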
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        make_ops!(
            DeformConv2dBackwardOps,
            DeformConv2dBackwardOpIr,
            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);

                let output = B::deform_conv2d_backward(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);
                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);
                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);
                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, mask_grad);
                }
                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, bias_grad);
                }
            }
        );

        let input_grad = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
        let offset_grad = offset
            .client
            .tensor_uninitialized(offset.shape.clone(), B::FloatElem::dtype());
        let weight_grad = offset
            .client
            .tensor_uninitialized(weight.shape.clone(), B::FloatElem::dtype());
        let mask_grad = mask.as_ref().map(|mask| {
            offset
                .client
                .tensor_uninitialized(mask.shape.clone(), B::FloatElem::dtype())
        });
        let bias_grad = bias.as_ref().map(|bias| {
            offset
                .client
                .tensor_uninitialized(bias.shape.clone(), B::FloatElem::dtype())
        });

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&offset);
        streams.tensor(&weight);
        streams.tensor(&output_grad);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask)
        }

        let desc = DeformConv2dBackwardOpIr {
            x: x.into_ir(),
            offset: offset.into_ir(),
            weight: weight.into_ir(),
            mask: mask.map(|mask| mask.into_ir()),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out_grad: output_grad.into_ir(),
            input_grad: input_grad.to_ir_out(),
            offset_grad: offset_grad.to_ir_out(),
            weight_grad: weight_grad.to_ir_out(),
            mask_grad: mask_grad.as_ref().map(|mask_grad| mask_grad.to_ir_out()),
            bias_grad: bias_grad.as_ref().map(|bias_grad| bias_grad.to_ir_out()),
        };

        input_grad.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(
                desc.clone(),
            ))),
            DeformConv2dBackwardOps::<B>::new(desc),
        );

        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        make_ops!(
            Conv3dOps,
            Conv3dOpIr,
            |args: &Conv3dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv3d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 = calculate_conv_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_conv_output_size(
            weight.shape[3],
            options.stride[1],
            options.padding[1],
            options.dilation[1],
            x.shape[3],
        );
        let size_2 = calculate_conv_output_size(
            weight.shape[4],
            options.stride[2],
            options.padding[2],
            options.dilation[2],
            x.shape[4],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let shape = vec![x.shape[0], weight.shape[0], size_0, size_1, size_2];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = Conv3dOpIr {
            x: x.into_ir(),
            weight: weight.into_ir(),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),
            Conv3dOps::<B>::new(desc),
        );

        out
    }
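
    // Transposed convolutions use `calculate_conv_transpose_output_size` (which also
    // takes `padding_out`) and produce `weight.shape[1] * groups` output channels.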
    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose1dOps,
            ConvTranspose1dOpIr,
            |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size = calculate_conv_transpose_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.padding_out[0],
            options.dilation[0],
            x.shape[2],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = ConvTranspose1dOpIr {
            x: x.into_ir(),
            weight: weight.into_ir(),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),
            ConvTranspose1dOps::<B>::new(desc),
        );

        out
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose2dOps,
            ConvTranspose2dOpIr,
            |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 = calculate_conv_transpose_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.padding_out[0],
            options.dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_conv_transpose_output_size(
            weight.shape[3],
            options.stride[1],
            options.padding[1],
            options.padding_out[1],
            options.dilation[1],
            x.shape[3],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = ConvTranspose2dOpIr {
            x: x.into_ir(),
            weight: weight.into_ir(),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),
            ConvTranspose2dOps::<B>::new(desc),
        );

        out
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose3dOps,
            ConvTranspose3dOpIr,
            |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 = calculate_conv_transpose_output_size(
            weight.shape[2],
            options.stride[0],
            options.padding[0],
            options.padding_out[0],
            options.dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_conv_transpose_output_size(
            weight.shape[3],
            options.stride[1],
            options.padding[1],
            options.padding_out[1],
            options.dilation[1],
            x.shape[3],
        );
        let size_2 = calculate_conv_transpose_output_size(
            weight.shape[4],
            options.stride[2],
            options.padding[2],
            options.padding_out[2],
            options.dilation[2],
            x.shape[4],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&weight);

        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let shape = vec![
            x.shape[0],
            weight.shape[1] * options.groups,
            size_0,
            size_1,
            size_2,
        ];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = ConvTranspose3dOpIr {
            x: x.into_ir(),
            weight: weight.into_ir(),
            bias: bias.map(|bias| bias.into_ir()),
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),
            ConvTranspose3dOps::<B>::new(desc),
        );

        out
    }
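
    // Average pooling: the output spatial size comes from `calculate_pool_output_size`
    // with a dilation of 1.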
    fn avg_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool1dOps,
            AvgPool1dOpIr,
            |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::avg_pool1d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
        let shape = vec![x.shape[0], x.shape[1], size];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = AvgPool1dOpIr {
            x: x.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),
            AvgPool1dOps::<B>::new(desc),
        );

        out
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool2dOps,
            AvgPool2dOpIr,
            |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::avg_pool2d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 =
            calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]);
        let size_1 =
            calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = AvgPool2dOpIr {
            x: x.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),
            AvgPool2dOps::<B>::new(desc),
        );

        out
    }
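
    // Pooling backward ops produce a gradient with the same shape as the input, so no
    // output-size calculation is needed here.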
    fn avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool1dBackwardOps,
            AvgPool1dBackwardOpIr,
            |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::avg_pool1d_backward(
                    x,
                    grad,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&grad);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());

        let desc = AvgPool1dBackwardOpIr {
            x: x.into_ir(),
            grad: grad.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),
            AvgPool1dBackwardOps::<B>::new(desc),
        );

        out
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool2dBackwardOps,
            AvgPool2dBackwardOpIr,
            |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::avg_pool2d_backward(
                    x,
                    grad,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&grad);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());

        let desc = AvgPool2dBackwardOpIr {
            x: x.into_ir(),
            grad: grad.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),
            AvgPool2dBackwardOps::<B>::new(desc),
        );

        out
    }

    fn max_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<Self> {
        make_ops!(
            MaxPool1dOps,
            MaxPool1dOpIr,
            |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool1d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], size];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = MaxPool1dOpIr {
            x: x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),
            MaxPool1dOps::<B>::new(desc),
        );

        out
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<Self> {
        make_ops!(
            MaxPool2dOps,
            MaxPool2dOpIr,
            |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool2d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let size_0 = calculate_pool_output_size(
            kernel_size[0],
            stride[0],
            padding[0],
            dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_pool_output_size(
            kernel_size[1],
            stride[1],
            padding[1],
            dilation[1],
            x.shape[3],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = MaxPool2dOpIr {
            x: x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),
            MaxPool2dOps::<B>::new(desc),
        );

        out
    }
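
    // The `with_indices` variants register two outputs: the pooled values and an int
    // tensor of the same shape containing the indices of the selected elements.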
    fn max_pool1d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<Self> {
        make_ops!(
            MaxPool1dWithIndicesOps,
            MaxPool1dWithIndicesOpIr,
            |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool1d_with_indices(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.output);
                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
        let shape = vec![x.shape[0], x.shape[1], size];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape.clone()), B::FloatElem::dtype());
        let out_indices = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::IntElem::dtype());

        let desc = MaxPool1dWithIndicesOpIr {
            x: x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            out: out.to_ir_out(),
            out_indices: out_indices.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),
            MaxPool1dWithIndicesOps::<B>::new(desc),
        );

        MaxPool1dWithIndices::new(out, out_indices)
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<Self> {
        make_ops!(
            MaxPool2dWithIndicesOps,
            MaxPool2dWithIndicesOpIr,
            |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool2d_with_indices(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.output);
                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
            }
        );

        let size_0 = calculate_pool_output_size(
            kernel_size[0],
            stride[0],
            padding[0],
            dilation[0],
            x.shape[2],
        );
        let size_1 = calculate_pool_output_size(
            kernel_size[1],
            stride[1],
            padding[1],
            dilation[1],
            x.shape[3],
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape.clone()), B::FloatElem::dtype());
        let out_indices = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::IntElem::dtype());

        let desc = MaxPool2dWithIndicesOpIr {
            x: x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            out: out.to_ir_out(),
            out_indices: out_indices.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),
            MaxPool2dWithIndicesOps::<B>::new(desc),
        );

        MaxPool2dWithIndices::new(out, out_indices)
    }
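
    // The backward variants reuse the indices saved by the forward pass and return a
    // gradient shaped like the original input.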
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool1dBackward<Self> {
        make_ops!(
            MaxPool1dWithIndicesBackwardOps,
            MaxPool1dWithIndicesBackwardOpIr,
            |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let indices = handles.get_int_tensor::<B>(&args.indices);
                let output = B::max_pool1d_with_indices_backward(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    grad,
                    indices,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&output_grad);
        streams.tensor(&indices);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());

        let desc = MaxPool1dWithIndicesBackwardOpIr {
            x: x.into_ir(),
            grad: output_grad.into_ir(),
            indices: indices.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(
                desc.clone(),
            )),
            MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
        );

        MaxPool1dBackward::new(out)
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Self> {
        make_ops!(
            MaxPool2dWithIndicesBackwardOps,
            MaxPool2dWithIndicesBackwardOpIr,
            |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let indices = handles.get_int_tensor::<B>(&args.indices);
                let output = B::max_pool2d_with_indices_backward(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    grad,
                    indices,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&output_grad);
        streams.tensor(&indices);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());

        let desc = MaxPool2dWithIndicesBackwardOpIr {
            x: x.into_ir(),
            grad: output_grad.into_ir(),
            indices: indices.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(
                desc.clone(),
            )),
            MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
        );

        MaxPool2dBackward::new(out)
    }
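
    // Adaptive average pooling: the requested `output_size` directly determines the
    // output's spatial dimensions.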
    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool1dOps,
            AdaptiveAvgPool1dOpIr,
            |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::adaptive_avg_pool1d(x, args.output_size);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], output_size];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = AdaptiveAvgPool1dOpIr {
            x: x.into_ir(),
            output_size,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),
            AdaptiveAvgPool1dOps::<B>::new(desc),
        );

        out
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool2dOps,
            AdaptiveAvgPool2dOpIr,
            |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::adaptive_avg_pool2d(x, args.output_size);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = AdaptiveAvgPool2dOpIr {
            x: x.into_ir(),
            output_size,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),
            AdaptiveAvgPool2dOps::<B>::new(desc),
        );

        out
    }

    fn adaptive_avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool1dBackwardOps,
            AdaptiveAvgPool1dBackwardOpIr,
            |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::adaptive_avg_pool1d_backward(x, grad);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&grad);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
        let desc = AdaptiveAvgPool1dBackwardOpIr {
            x: x.into_ir(),
            grad: grad.into_ir(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),
            AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
        );

        out
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool2dBackwardOps,
            AdaptiveAvgPool2dBackwardOpIr,
            |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::adaptive_avg_pool2d_backward(x, grad);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&grad);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());

        let desc = AdaptiveAvgPool2dBackwardOpIr {
            x: x.into_ir(),
            grad: grad.into_ir(),
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),
            AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
        );

        out
    }
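
    // Interpolation: the output's spatial dimensions are given by `output_size`; the
    // backward op produces a gradient with the input's shape.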
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        make_ops!(
            InterpolateOps,
            InterpolateOpIr,
            |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::interpolate(x, args.output_size, args.options.clone().into());
                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);

        let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
        let out = x
            .client
            .tensor_uninitialized(Shape::from(shape), B::FloatElem::dtype());

        let desc = InterpolateOpIr {
            x: x.into_ir(),
            output_size,
            options: options.into(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),
            InterpolateOps::<B>::new(desc),
        );

        out
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        make_ops!(
            InterpolateBackwardOps,
            InterpolateBackwardOpIr,
            |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output =
                    B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::default();
        streams.tensor(&x);
        streams.tensor(&grad);

        let out = x
            .client
            .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());

        let desc = InterpolateBackwardOpIr {
            x: x.into_ir(),
            grad: grad.into_ir(),
            output_size,
            options: options.into(),
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),
            InterpolateBackwardOps::<B>::new(desc),
        );

        out
    }
}