use crate::{
    Fusion, FusionBackend,
    stream::{OperationStreams, execution::Operation},
};
use burn_backend::{
    Element,
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
        InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
        MaxPool2dWithIndices, ModuleOps,
    },
    tensor::{FloatTensor, IntTensor},
};
use burn_ir::*;
use std::marker::PhantomData;

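/// Declares a fusion operation: a small struct that owns an IR descriptor and implements
/// `Operation` by forwarding the descriptor and the handle container to the provided closure.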
macro_rules! make_ops {
    ($name:ident, $desc:ty, $fn:expr) => {
        #[derive(new, Debug)]
        struct $name<B: FusionBackend> {
            desc: $desc,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                #[allow(clippy::redundant_closure_call)]
                $fn(&self.desc, handles)
            }
        }
    };
}

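// Every method below follows the same pattern: gather the operation streams of the inputs,
// convert the tensors into IR descriptors, register the IR together with its executable op on
// the fusion client, and return the resulting output handle(s).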
impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        make_ops!(
            Conv1dOps,
            Conv1dOpIr,
            |desc: &Conv1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let bias = desc
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output = B::conv1d(x, weight, bias, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

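        // The optional bias only joins the operation streams when it is present.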
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = Conv1dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1d(desc.clone())),
                Conv1dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv1d_x_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv1dXBackwardOps,
            Conv1dXBackwardOpIr,
            |desc: &Conv1dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv1d_x_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv1dXBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1dXBackward(desc.clone())),
                Conv1dXBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv1d_weight_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv1dWeightBackwardOps,
            Conv1dWeightBackwardOpIr,
            |desc: &Conv1dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv1d_weight_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv1dWeightBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1dWeightBackward(desc.clone())),
                Conv1dWeightBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv1d_bias_backward(
        x: FloatTensor<Fusion<B>>,
        bias: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv1dBiasBackwardOps,
            Conv1dBiasBackwardOpIr,
            |desc: &Conv1dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let bias = handles.get_float_tensor::<B>(&desc.bias);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output = B::conv1d_bias_backward(x, bias, output_grad);
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);

        let client = x.client.clone();
        let desc = Conv1dBiasBackwardOpIr::create(
            x.into_ir(),
            bias.into_ir(),
            output_grad.into_ir(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(desc.clone())),
                Conv1dBiasBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            Conv2dOps,
            Conv2dOpIr,
            |args: &Conv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv2d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = Conv2dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),
                Conv2dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d_x_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv2dXBackwardOps,
            Conv2dXBackwardOpIr,
            |desc: &Conv2dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv2d_x_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv2dXBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2dXBackward(desc.clone())),
                Conv2dXBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d_weight_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv2dWeightBackwardOps,
            Conv2dWeightBackwardOpIr,
            |desc: &Conv2dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv2d_weight_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv2dWeightBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2dWeightBackward(desc.clone())),
                Conv2dWeightBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv2d_bias_backward(
        x: FloatTensor<Fusion<B>>,
        bias: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv2dBiasBackwardOps,
            Conv2dBiasBackwardOpIr,
            |desc: &Conv2dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let bias = handles.get_float_tensor::<B>(&desc.bias);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output = B::conv2d_bias_backward(x, bias, output_grad);
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);

        let client = x.client.clone();
        let desc = Conv2dBiasBackwardOpIr::create(
            x.into_ir(),
            bias.into_ir(),
            output_grad.into_ir(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(desc.clone())),
                Conv2dBiasBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            DeformConv2dOps,
            DeformConv2dOpIr,
            |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output =
                    B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask)
        }

        let client = x.client.clone();
        let desc = DeformConv2dOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),
                DeformConv2dOps::<B>::new(desc),
            )
            .output()
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        make_ops!(
            DeformConv2dBackwardOps,
            DeformConv2dBackwardOpIr,
            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let offset = handles.get_float_tensor::<B>(&args.offset);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let mask = args
                    .mask
                    .as_ref()
                    .map(|mask| handles.get_float_tensor::<B>(mask));
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));
                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);

                let output = B::deform_conv2d_backward(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);
                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);
                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);
                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, mask_grad);
                }
                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
                    handles.register_float_tensor::<B>(&field.id, bias_grad);
                }
            }
        );

        let has_bias = bias.is_some();
        let has_mask = mask.is_some();

        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight, &output_grad]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias);
        }
        if let Some(mask) = mask.as_ref() {
            streams.tensor(mask);
        }

        let client = x.client.clone();
        let desc = DeformConv2dBackwardOpIr::create(
            x.into_ir(),
            offset.into_ir(),
            weight.into_ir(),
            mask.map(|mask| mask.into_ir()),
            bias.map(|bias| bias.into_ir()),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(
                    desc.clone(),
                ))),
                DeformConv2dBackwardOps::<B>::new(desc),
            )
            .into_iter();

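        // Outputs are consumed in the order they were registered by the closure above:
        // input, offset and weight gradients first, then the optional mask and bias gradients.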
        let input_grad = outputs.next().unwrap();
        let offset_grad = outputs.next().unwrap();
        let weight_grad = outputs.next().unwrap();
        let mask_grad = has_mask.then(|| outputs.next().unwrap());
        let bias_grad = has_bias.then(|| outputs.next().unwrap());

        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        make_ops!(
            Conv3dOps,
            Conv3dOpIr,
            |args: &Conv3dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv3d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = Conv3dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),
                Conv3dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv3d_x_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv3dXBackwardOps,
            Conv3dXBackwardOpIr,
            |desc: &Conv3dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv3d_x_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv3dXBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3dXBackward(desc.clone())),
                Conv3dXBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv3d_weight_backward(
        x: FloatTensor<Fusion<B>>,
        weight: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv3dWeightBackwardOps,
            Conv3dWeightBackwardOpIr,
            |desc: &Conv3dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let weight = handles.get_float_tensor::<B>(&desc.weight);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output =
                    B::conv3d_weight_backward(x, weight, output_grad, desc.options.clone().into());
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);

        let client = x.client.clone();
        let desc = Conv3dWeightBackwardOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            output_grad.into_ir(),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3dWeightBackward(desc.clone())),
                Conv3dWeightBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv3d_bias_backward(
        x: FloatTensor<Fusion<B>>,
        bias: FloatTensor<Fusion<B>>,
        output_grad: FloatTensor<Fusion<B>>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            Conv3dBiasBackwardOps,
            Conv3dBiasBackwardOpIr,
            |desc: &Conv3dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&desc.x);
                let bias = handles.get_float_tensor::<B>(&desc.bias);
                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);
                let output = B::conv3d_bias_backward(x, bias, output_grad);
                handles.register_float_tensor::<B>(&desc.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);

        let client = x.client.clone();
        let desc = Conv3dBiasBackwardOpIr::create(
            x.into_ir(),
            bias.into_ir(),
            output_grad.into_ir(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(desc.clone())),
                Conv3dBiasBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose1dOps,
            ConvTranspose1dOpIr,
            |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = ConvTranspose1dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),
                ConvTranspose1dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose2dOps,
            ConvTranspose2dOpIr,
            |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = ConvTranspose2dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),
                ConvTranspose2dOps::<B>::new(desc),
            )
            .output()
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        make_ops!(
            ConvTranspose3dOps,
            ConvTranspose3dOpIr,
            |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let weight = handles.get_float_tensor::<B>(&args.weight);
                let bias = args
                    .bias
                    .as_ref()
                    .map(|bias| handles.get_float_tensor::<B>(bias));

                let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let mut streams = OperationStreams::with_inputs([&x, &weight]);
        if let Some(bias) = bias.as_ref() {
            streams.tensor(bias)
        }

        let client = x.client.clone();
        let desc = ConvTranspose3dOpIr::create(
            x.into_ir(),
            weight.into_ir(),
            bias.map(|bias| bias.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),
                ConvTranspose3dOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool1dOps,
            AvgPool1dOpIr,
            |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::avg_pool1d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AvgPool1dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),
                AvgPool1dOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool2dOps,
            AvgPool2dOpIr,
            |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::avg_pool2d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AvgPool2dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),
                AvgPool2dOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool1dBackwardOps,
            AvgPool1dBackwardOpIr,
            |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::avg_pool1d_backward(
                    x,
                    grad,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AvgPool1dBackwardOpIr::create(
            x.into_ir(),
            grad.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),
                AvgPool1dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            AvgPool2dBackwardOps,
            AvgPool2dBackwardOpIr,
            |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::avg_pool2d_backward(
                    x,
                    grad,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.count_include_pad,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AvgPool2dBackwardOpIr::create(
            x.into_ir(),
            grad.into_ir(),
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),
                AvgPool2dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn max_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            MaxPool1dOps,
            MaxPool1dOpIr,
            |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool1d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool1dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),
                MaxPool1dOps::<B>::new(desc),
            )
            .output()
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        make_ops!(
            MaxPool2dOps,
            MaxPool2dOpIr,
            |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool2d(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool2dOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),
                MaxPool2dOps::<B>::new(desc),
            )
            .output()
    }

    fn max_pool1d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<Self> {
        make_ops!(
            MaxPool1dWithIndicesOps,
            MaxPool1dWithIndicesOpIr,
            |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool1d_with_indices(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.output);
                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool1dWithIndicesOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            B::IntElem::dtype(),
            || client.create_empty_handle(),
        );

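        // The closure registers two outputs: the pooled values and their int indices.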
        let [out, out_indices] = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),
                MaxPool1dWithIndicesOps::<B>::new(desc),
            )
            .outputs();

        MaxPool1dWithIndices::new(out, out_indices)
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Self> {
        make_ops!(
            MaxPool2dWithIndicesOps,
            MaxPool2dWithIndicesOpIr,
            |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::max_pool2d_with_indices(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.output);
                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = MaxPool2dWithIndicesOpIr::create(
            x.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            B::IntElem::dtype(),
            || client.create_empty_handle(),
        );

        let [out, out_indices] = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),
                MaxPool2dWithIndicesOps::<B>::new(desc),
            )
            .outputs();

        MaxPool2dWithIndices::new(out, out_indices)
    }

    fn max_pool1d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool1dBackward<Self> {
        make_ops!(
            MaxPool1dWithIndicesBackwardOps,
            MaxPool1dWithIndicesBackwardOpIr,
            |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let indices = handles.get_int_tensor::<B>(&args.indices);
                let output = B::max_pool1d_with_indices_backward(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                    grad,
                    indices,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);

        let client = x.client.clone();
        let desc = MaxPool1dWithIndicesBackwardOpIr::create(
            x.into_ir(),
            output_grad.into_ir(),
            indices.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        let out = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(
                    desc.clone(),
                )),
                MaxPool1dWithIndicesBackwardOps::<B>::new(desc),
            )
            .output();

        MaxPool1dBackward::new(out)
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Self> {
        make_ops!(
            MaxPool2dWithIndicesBackwardOps,
            MaxPool2dWithIndicesBackwardOpIr,
            |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let indices = handles.get_int_tensor::<B>(&args.indices);
                let output = B::max_pool2d_with_indices_backward(
                    x,
                    args.kernel_size,
                    args.stride,
                    args.padding,
                    args.dilation,
                    args.ceil_mode,
                    grad,
                    indices,
                );

                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);

        let client = x.client.clone();
        let desc = MaxPool2dWithIndicesBackwardOpIr::create(
            x.into_ir(),
            output_grad.into_ir(),
            indices.into_ir(),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            || client.create_empty_handle(),
        );

        let out = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(
                    desc.clone(),
                )),
                MaxPool2dWithIndicesBackwardOps::<B>::new(desc),
            )
            .output();

        MaxPool2dBackward::new(out)
    }

    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool1dOps,
            AdaptiveAvgPool1dOpIr,
            |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::adaptive_avg_pool1d(x, args.output_size);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),
                AdaptiveAvgPool1dOps::<B>::new(desc),
            )
            .output()
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool2dOps,
            AdaptiveAvgPool2dOpIr,
            |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::adaptive_avg_pool2d(x, args.output_size);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),
                AdaptiveAvgPool2dOps::<B>::new(desc),
            )
            .output()
    }

    fn adaptive_avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool1dBackwardOps,
            AdaptiveAvgPool1dBackwardOpIr,
            |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::adaptive_avg_pool1d_backward(x, grad);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),
                AdaptiveAvgPool1dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        make_ops!(
            AdaptiveAvgPool2dBackwardOps,
            AdaptiveAvgPool2dBackwardOpIr,
            |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output = B::adaptive_avg_pool2d_backward(x, grad);

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );
        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),
                AdaptiveAvgPool2dBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        make_ops!(
            InterpolateOps,
            InterpolateOpIr,
            |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let output = B::interpolate(x, args.output_size, args.options.clone().into());
                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x]);

        let client = x.client.clone();
        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
            client.create_empty_handle()
        });

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),
                InterpolateOps::<B>::new(desc),
            )
            .output()
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        make_ops!(
            InterpolateBackwardOps,
            InterpolateBackwardOpIr,
            |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let x = handles.get_float_tensor::<B>(&args.x);
                let grad = handles.get_float_tensor::<B>(&args.grad);
                let output =
                    B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([&x, &grad]);

        let client = x.client.clone();
        let desc = InterpolateBackwardOpIr::create(
            x.into_ir(),
            grad.into_ir(),
            output_size,
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),
                InterpolateBackwardOps::<B>::new(desc),
            )
            .output()
    }

    fn attention(
        query: FloatTensor<Fusion<B>>,
        key: FloatTensor<Fusion<B>>,
        value: FloatTensor<Fusion<B>>,
        mask: Option<burn_backend::tensor::BoolTensor<Fusion<B>>>,
        attn_bias: Option<FloatTensor<Fusion<B>>>,
        options: burn_backend::ops::AttentionModuleOptions,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            AttentionOps,
            AttentionOpIr,
            |args: &AttentionOpIr, handles: &mut HandleContainer<B::Handle>| {
                let query = handles.get_float_tensor::<B>(&args.query);
                let key = handles.get_float_tensor::<B>(&args.key);
                let value = handles.get_float_tensor::<B>(&args.value);
                let mask = args.mask.as_ref().map(|m| handles.get_bool_tensor::<B>(m));
                let attn_bias = args
                    .attn_bias
                    .as_ref()
                    .map(|ab| handles.get_float_tensor::<B>(ab));

                let output = B::attention(
                    query,
                    key,
                    value,
                    mask,
                    attn_bias,
                    args.options.clone().into(),
                );

                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let mut streams = OperationStreams::with_inputs([&query, &key, &value]);
        if let Some(mask) = &mask {
            streams.tensor(mask);
        }
        if let Some(attn_bias) = &attn_bias {
            streams.tensor(attn_bias);
        }

        let client = query.client.clone();
        let desc = AttentionOpIr::create(
            query.into_ir(),
            key.into_ir(),
            value.into_ir(),
            mask.map(|m| m.into_ir()),
            attn_bias.map(|ab| ab.into_ir()),
            options.into(),
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Attention(desc.clone())),
                AttentionOps::<B>::new(desc),
            )
            .output()
    }

    fn rfft(
        signal: FloatTensor<Fusion<B>>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<Fusion<B>>, FloatTensor<Fusion<B>>) {
        make_ops!(
            RfftOps,
            RfftOpIr,
            |desc: &RfftOpIr, handles: &mut HandleContainer<B::Handle>| {
                let signal = handles.get_float_tensor::<B>(&desc.signal);
                let (re, im) = B::rfft(signal, desc.dim, desc.n);

                handles.register_float_tensor::<B>(&desc.out_re.id, re);
                handles.register_float_tensor::<B>(&desc.out_im.id, im);
            }
        );

        let streams = OperationStreams::with_inputs([&signal]);
        let client = signal.client.clone();

        let desc = RfftOpIr::create(signal.into_ir(), dim, n, || client.create_empty_handle());

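        // The operation produces two float outputs: the real and imaginary parts of the spectrum.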
        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::Rfft(desc.clone())),
                RfftOps::<B>::new(desc),
            )
            .into_iter();

        (outputs.next().unwrap(), outputs.next().unwrap())
    }

    fn irfft(
        spectrum_re: FloatTensor<Fusion<B>>,
        spectrum_im: FloatTensor<Fusion<B>>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            IRfftOps,
            IRfftOpIr,
            |desc: &IRfftOpIr, handles: &mut HandleContainer<B::Handle>| {
                let input_re = handles.get_float_tensor::<B>(&desc.input_re);
                let input_im = handles.get_float_tensor::<B>(&desc.input_im);

                let signal = B::irfft(input_re, input_im, desc.dim, desc.n);
                handles.register_float_tensor::<B>(&desc.out_signal.id, signal);
            }
        );

        let streams = OperationStreams::with_inputs([&spectrum_re, &spectrum_im]);
        let client = spectrum_re.client.clone();

        let desc = IRfftOpIr::create(spectrum_re.into_ir(), spectrum_im.into_ir(), dim, n, || {
            client.create_empty_handle()
        });

        let mut outputs = client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::IRfft(desc.clone())),
                IRfftOps::<B>::new(desc),
            )
            .into_iter();

        outputs.next().unwrap()
    }

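    // Availability of the CTC loss backward pass is delegated to the inner backend.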
    fn has_ctc_loss_backward() -> bool {
        B::has_ctc_loss_backward()
    }

    fn ctc_loss(
        log_probs: FloatTensor<Fusion<B>>,
        targets: IntTensor<Fusion<B>>,
        input_lengths: IntTensor<Fusion<B>>,
        target_lengths: IntTensor<Fusion<B>>,
        blank: usize,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            CtcLossOps,
            CtcLossOpIr,
            |args: &CtcLossOpIr, handles: &mut HandleContainer<B::Handle>| {
                let log_probs = handles.get_float_tensor::<B>(&args.log_probs);
                let targets = handles.get_int_tensor::<B>(&args.targets);
                let input_lengths = handles.get_int_tensor::<B>(&args.input_lengths);
                let target_lengths = handles.get_int_tensor::<B>(&args.target_lengths);
                let output = B::ctc_loss(
                    log_probs,
                    targets,
                    input_lengths,
                    target_lengths,
                    args.blank,
                );
                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams =
            OperationStreams::with_inputs([&log_probs, &targets, &input_lengths, &target_lengths]);
        let client = log_probs.client.clone();
        let desc = CtcLossOpIr::create(
            log_probs.into_ir(),
            targets.into_ir(),
            input_lengths.into_ir(),
            target_lengths.into_ir(),
            blank,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::CtcLoss(desc.clone())),
                CtcLossOps::<B>::new(desc),
            )
            .output()
    }

    fn ctc_loss_backward(
        log_probs: FloatTensor<Fusion<B>>,
        targets: IntTensor<Fusion<B>>,
        input_lengths: IntTensor<Fusion<B>>,
        target_lengths: IntTensor<Fusion<B>>,
        grad_loss: FloatTensor<Fusion<B>>,
        blank: usize,
    ) -> FloatTensor<Fusion<B>> {
        make_ops!(
            CtcLossBackwardOps,
            CtcLossBackwardOpIr,
            |args: &CtcLossBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {
                let log_probs = handles.get_float_tensor::<B>(&args.log_probs);
                let targets = handles.get_int_tensor::<B>(&args.targets);
                let input_lengths = handles.get_int_tensor::<B>(&args.input_lengths);
                let target_lengths = handles.get_int_tensor::<B>(&args.target_lengths);
                let grad_loss = handles.get_float_tensor::<B>(&args.grad_loss);
                let output = B::ctc_loss_backward(
                    log_probs,
                    targets,
                    input_lengths,
                    target_lengths,
                    grad_loss,
                    args.blank,
                );
                handles.register_float_tensor::<B>(&args.out.id, output);
            }
        );

        let streams = OperationStreams::with_inputs([
            &log_probs,
            &targets,
            &input_lengths,
            &target_lengths,
            &grad_loss,
        ]);
        let client = log_probs.client.clone();
        let desc = CtcLossBackwardOpIr::create(
            log_probs.into_ir(),
            targets.into_ir(),
            input_lengths.into_ir(),
            target_lengths.into_ir(),
            grad_loss.into_ir(),
            blank,
            || client.create_empty_handle(),
        );

        client
            .register(
                streams,
                OperationIr::Module(ModuleOperationIr::CtcLossBackward(desc.clone())),
                CtcLossBackwardOps::<B>::new(desc),
            )
            .output()
    }
}