1use burn_backend::{
2 ops::{
3 DeformConv2dBackward, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
4 MaxPool2dWithIndices, ModuleOps,
5 },
6 tensor::{FloatTensor, IntTensor},
7};
8
9use crate::Dispatch;
10use crate::backends::*;
11
12impl ModuleOps<Self> for Dispatch {
13 fn conv2d(
14 x: FloatTensor<Self>,
15 weight: FloatTensor<Self>,
16 bias: Option<FloatTensor<Self>>,
17 options: burn_backend::ops::ConvOptions<2>,
18 ) -> FloatTensor<Self> {
19 multi_op!(
20 inputs[(x, float), (weight, float)],
21 opt_inputs[(bias, float)],
22 => Float,
23 B::conv2d(x, weight, bias, options)
24 )
25 }
26
27 fn deform_conv2d(
28 x: FloatTensor<Self>,
29 offset: FloatTensor<Self>,
30 weight: FloatTensor<Self>,
31 mask: Option<FloatTensor<Self>>,
32 bias: Option<FloatTensor<Self>>,
33 options: burn_backend::ops::DeformConvOptions<2>,
34 ) -> FloatTensor<Self> {
35 multi_op!(
36 inputs[(x, float), (offset, float), (weight, float)],
37 opt_inputs[(mask, float), (bias, float)],
38 => Float,
39 B::deform_conv2d(x, offset, weight, mask, bias, options)
40 )
41 }
42
43 fn deform_conv2d_backward(
44 x: FloatTensor<Self>,
45 offset: FloatTensor<Self>,
46 weight: FloatTensor<Self>,
47 mask: Option<FloatTensor<Self>>,
48 bias: Option<FloatTensor<Self>>,
49 output_grad: FloatTensor<Self>,
50 options: burn_backend::ops::DeformConvOptions<2>,
51 ) -> DeformConv2dBackward<Self> {
52 let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = multi_op!(
53 inputs[(x, float), (offset, float), (weight, float), (output_grad, float)],
54 opt_inputs[(mask, float), (bias, float)],
55 outputs[(x_grad, Float), (offset_grad, Float), (weight_grad, Float)],
56 opt_outputs[mask_grad, bias_grad],
57 {
58 let res = B::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options);
59 (res.x_grad, res.offset_grad, res.weight_grad, res.mask_grad, res.bias_grad)
60 }
61 );
62 DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
63 }
64
65 fn conv3d(
66 x: FloatTensor<Self>,
67 weight: FloatTensor<Self>,
68 bias: Option<FloatTensor<Self>>,
69 options: burn_backend::ops::ConvOptions<3>,
70 ) -> FloatTensor<Self> {
71 multi_op!(
72 inputs[(x, float), (weight, float)],
73 opt_inputs[(bias, float)],
74 => Float,
75 B::conv3d(x, weight, bias, options)
76 )
77 }
78
79 fn conv_transpose2d(
80 x: FloatTensor<Self>,
81 weight: FloatTensor<Self>,
82 bias: Option<FloatTensor<Self>>,
83 options: burn_backend::ops::ConvTransposeOptions<2>,
84 ) -> FloatTensor<Self> {
85 multi_op!(
86 inputs[(x, float), (weight, float)],
87 opt_inputs[(bias, float)],
88 => Float,
89 B::conv_transpose2d(x, weight, bias, options)
90 )
91 }
92
93 fn conv_transpose3d(
94 x: FloatTensor<Self>,
95 weight: FloatTensor<Self>,
96 bias: Option<FloatTensor<Self>>,
97 options: burn_backend::ops::ConvTransposeOptions<3>,
98 ) -> FloatTensor<Self> {
99 multi_op!(
100 inputs[(x, float), (weight, float)],
101 opt_inputs[(bias, float)],
102 => Float,
103 B::conv_transpose3d(x, weight, bias, options)
104 )
105 }
106
107 fn avg_pool2d(
108 x: FloatTensor<Self>,
109 kernel_size: [usize; 2],
110 stride: [usize; 2],
111 padding: [usize; 2],
112 count_include_pad: bool,
113 ceil_mode: bool,
114 ) -> FloatTensor<Self> {
115 multi_op!(inputs[(x, float)],
116 => Float,
117 B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
118 )
119 }
120
121 fn avg_pool2d_backward(
122 x: FloatTensor<Self>,
123 grad: FloatTensor<Self>,
124 kernel_size: [usize; 2],
125 stride: [usize; 2],
126 padding: [usize; 2],
127 count_include_pad: bool,
128 ceil_mode: bool,
129 ) -> FloatTensor<Self> {
130 multi_op!(
131 inputs[(x, float), (grad, float)],
132 => Float,
133 B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
134 )
135 }
136
137 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
138 multi_op!(
139 inputs[(x, float)],
140 => Float,
141 B::adaptive_avg_pool2d(x, output_size)
142 )
143 }
144
145 fn adaptive_avg_pool2d_backward(
146 x: FloatTensor<Self>,
147 grad: FloatTensor<Self>,
148 ) -> FloatTensor<Self> {
149 multi_op!(
150 inputs[(x, float), (grad, float)],
151 => Float,
152 B::adaptive_avg_pool2d_backward(x, grad)
153 )
154 }
155
156 fn max_pool2d(
157 x: FloatTensor<Self>,
158 kernel_size: [usize; 2],
159 stride: [usize; 2],
160 padding: [usize; 2],
161 dilation: [usize; 2],
162 ceil_mode: bool,
163 ) -> FloatTensor<Self> {
164 multi_op!(
165 inputs[(x, float)],
166 => Float,
167 B::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
168 )
169 }
170
171 fn max_pool2d_with_indices(
172 x: FloatTensor<Self>,
173 kernel_size: [usize; 2],
174 stride: [usize; 2],
175 padding: [usize; 2],
176 dilation: [usize; 2],
177 ceil_mode: bool,
178 ) -> MaxPool2dWithIndices<Self> {
179 let (out, indices) = multi_op!(
180 inputs[(x, float)],
181 outputs[(out, Float), (indices, Int)],
182 {
183 let res = B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
184 (res.output, res.indices)
185 }
186 );
187 MaxPool2dWithIndices::new(out, indices)
188 }
189
190 fn max_pool2d_with_indices_backward(
191 x: FloatTensor<Self>,
192 kernel_size: [usize; 2],
193 stride: [usize; 2],
194 padding: [usize; 2],
195 dilation: [usize; 2],
196 ceil_mode: bool,
197 output_grad: FloatTensor<Self>,
198 indices: IntTensor<Self>,
199 ) -> MaxPool2dBackward<Self> {
200 let x_grad = multi_op!(
201 inputs[(x, float), (output_grad, float), (indices, int)],
202 => Float,
203 {
204 let res = B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
205 res.x_grad
206 }
207 );
208 MaxPool2dBackward::new(x_grad)
209 }
210
211 fn interpolate(
212 x: FloatTensor<Self>,
213 output_size: [usize; 2],
214 options: burn_backend::ops::InterpolateOptions,
215 ) -> FloatTensor<Self> {
216 multi_op!(
217 inputs[(x, float)],
218 => Float,
219 B::interpolate(x, output_size, options)
220 )
221 }
222
223 fn interpolate_backward(
224 x: FloatTensor<Self>,
225 grad: FloatTensor<Self>,
226 output_size: [usize; 2],
227 options: burn_backend::ops::InterpolateOptions,
228 ) -> FloatTensor<Self> {
229 multi_op!(
230 inputs[(x, float), (grad, float)],
231 => Float,
232 B::interpolate_backward(x, grad, output_size, options)
233 )
234 }
235
236 fn embedding(weights: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
237 multi_op!(
238 inputs[(weights, float), (indices, int)],
239 => Float,
240 B::embedding(weights, indices)
241 )
242 }
243
244 fn embedding_backward(
245 weights: FloatTensor<Self>,
246 output_grad: FloatTensor<Self>,
247 indices: IntTensor<Self>,
248 ) -> FloatTensor<Self> {
249 multi_op!(
250 inputs[(weights, float), (output_grad, float), (indices, int)],
251 => Float,
252 B::embedding_backward(weights, output_grad, indices)
253 )
254 }
255
256 fn conv1d(
257 x: FloatTensor<Self>,
258 weight: FloatTensor<Self>,
259 bias: Option<FloatTensor<Self>>,
260 options: burn_backend::ops::ConvOptions<1>,
261 ) -> FloatTensor<Self> {
262 multi_op!(
263 inputs[(x, float), (weight, float)],
264 opt_inputs[(bias, float)],
265 => Float,
266 B::conv1d(x, weight, bias, options)
267 )
268 }
269
270 fn conv1d_x_backward(
271 x: FloatTensor<Self>,
272 weight: FloatTensor<Self>,
273 output_grad: FloatTensor<Self>,
274 options: burn_backend::ops::ConvOptions<1>,
275 ) -> FloatTensor<Self> {
276 multi_op!(
277 inputs[(x, float), (weight, float), (output_grad, float)],
278 => Float,
279 B::conv1d_x_backward(x, weight, output_grad, options)
280 )
281 }
282
283 fn conv1d_weight_backward(
284 x: FloatTensor<Self>,
285 weight: FloatTensor<Self>,
286 output_grad: FloatTensor<Self>,
287 options: burn_backend::ops::ConvOptions<1>,
288 ) -> FloatTensor<Self> {
289 multi_op!(
290 inputs[(x, float), (weight, float), (output_grad, float)],
291 => Float,
292 B::conv1d_weight_backward(x, weight, output_grad, options)
293 )
294 }
295
296 fn conv1d_bias_backward(
297 x: FloatTensor<Self>,
298 bias: FloatTensor<Self>,
299 output_grad: FloatTensor<Self>,
300 ) -> FloatTensor<Self> {
301 multi_op!(
302 inputs[(x, float), (bias, float), (output_grad, float)],
303 => Float,
304 B::conv1d_bias_backward(x, bias, output_grad)
305 )
306 }
307
308 fn conv2d_x_backward(
309 x: FloatTensor<Self>,
310 weight: FloatTensor<Self>,
311 output_grad: FloatTensor<Self>,
312 options: burn_backend::ops::ConvOptions<2>,
313 ) -> FloatTensor<Self> {
314 multi_op!(
315 inputs[(x, float), (weight, float), (output_grad, float)],
316 => Float,
317 B::conv2d_x_backward(x, weight, output_grad, options)
318 )
319 }
320
321 fn conv2d_weight_backward(
322 x: FloatTensor<Self>,
323 weight: FloatTensor<Self>,
324 output_grad: FloatTensor<Self>,
325 options: burn_backend::ops::ConvOptions<2>,
326 ) -> FloatTensor<Self> {
327 multi_op!(
328 inputs[(x, float), (weight, float), (output_grad, float)],
329 => Float,
330 B::conv2d_weight_backward(x, weight, output_grad, options)
331 )
332 }
333
334 fn conv2d_bias_backward(
335 x: FloatTensor<Self>,
336 bias: FloatTensor<Self>,
337 output_grad: FloatTensor<Self>,
338 ) -> FloatTensor<Self> {
339 multi_op!(
340 inputs[(x, float), (bias, float), (output_grad, float)],
341 => Float,
342 B::conv2d_bias_backward(x, bias, output_grad)
343 )
344 }
345
346 fn conv3d_x_backward(
347 x: FloatTensor<Self>,
348 weight: FloatTensor<Self>,
349 output_grad: FloatTensor<Self>,
350 options: burn_backend::ops::ConvOptions<3>,
351 ) -> FloatTensor<Self> {
352 multi_op!(
353 inputs[(x, float), (weight, float), (output_grad, float)],
354 => Float,
355 B::conv3d_x_backward(x, weight, output_grad, options)
356 )
357 }
358
359 fn conv3d_weight_backward(
360 x: FloatTensor<Self>,
361 weight: FloatTensor<Self>,
362 output_grad: FloatTensor<Self>,
363 options: burn_backend::ops::ConvOptions<3>,
364 ) -> FloatTensor<Self> {
365 multi_op!(
366 inputs[(x, float), (weight, float), (output_grad, float)],
367 => Float,
368 B::conv3d_weight_backward(x, weight, output_grad, options)
369 )
370 }
371
372 fn conv3d_bias_backward(
373 x: FloatTensor<Self>,
374 bias: FloatTensor<Self>,
375 output_grad: FloatTensor<Self>,
376 ) -> FloatTensor<Self> {
377 multi_op!(
378 inputs[(x, float), (bias, float), (output_grad, float)],
379 => Float,
380 B::conv3d_bias_backward(x, bias, output_grad)
381 )
382 }
383
384 fn conv_transpose1d(
385 x: FloatTensor<Self>,
386 weight: FloatTensor<Self>,
387 bias: Option<FloatTensor<Self>>,
388 options: burn_backend::ops::ConvTransposeOptions<1>,
389 ) -> FloatTensor<Self> {
390 multi_op!(
391 inputs[(x, float), (weight, float)],
392 opt_inputs[(bias, float)],
393 => Float,
394 B::conv_transpose1d(x, weight, bias, options)
395 )
396 }
397
398 fn conv_transpose1d_x_backward(
399 weight: FloatTensor<Self>,
400 output_grad: FloatTensor<Self>,
401 options: burn_backend::ops::ConvTransposeOptions<1>,
402 ) -> FloatTensor<Self> {
403 multi_op!(
404 inputs[(weight, float), (output_grad, float)],
405 => Float,
406 B::conv_transpose1d_x_backward(weight, output_grad, options)
407 )
408 }
409
410 fn conv_transpose1d_weight_backward(
411 x: FloatTensor<Self>,
412 weight: FloatTensor<Self>,
413 output_grad: FloatTensor<Self>,
414 options: burn_backend::ops::ConvTransposeOptions<1>,
415 ) -> FloatTensor<Self> {
416 multi_op!(
417 inputs[(x, float), (weight, float), (output_grad, float)],
418 => Float,
419 B::conv_transpose1d_weight_backward(x, weight, output_grad, options)
420 )
421 }
422
423 fn conv_transpose1d_bias_backward(
424 x: FloatTensor<Self>,
425 bias: FloatTensor<Self>,
426 output_grad: FloatTensor<Self>,
427 ) -> FloatTensor<Self> {
428 multi_op!(
429 inputs[(x, float), (bias, float), (output_grad, float)],
430 => Float,
431 B::conv_transpose1d_bias_backward(x, bias, output_grad)
432 )
433 }
434
435 fn conv_transpose2d_x_backward(
436 weight: FloatTensor<Self>,
437 output_grad: FloatTensor<Self>,
438 options: burn_backend::ops::ConvTransposeOptions<2>,
439 ) -> FloatTensor<Self> {
440 multi_op!(
441 inputs[(weight, float), (output_grad, float)],
442 => Float,
443 B::conv_transpose2d_x_backward(weight, output_grad, options)
444 )
445 }
446
447 fn conv_transpose2d_weight_backward(
448 x: FloatTensor<Self>,
449 weight: FloatTensor<Self>,
450 output_grad: FloatTensor<Self>,
451 options: burn_backend::ops::ConvTransposeOptions<2>,
452 ) -> FloatTensor<Self> {
453 multi_op!(
454 inputs[(x, float), (weight, float), (output_grad, float)],
455 => Float,
456 B::conv_transpose2d_weight_backward(x, weight, output_grad, options)
457 )
458 }
459
460 fn conv_transpose2d_bias_backward(
461 x: FloatTensor<Self>,
462 bias: FloatTensor<Self>,
463 output_grad: FloatTensor<Self>,
464 ) -> FloatTensor<Self> {
465 multi_op!(
466 inputs[(x, float), (bias, float), (output_grad, float)],
467 => Float,
468 B::conv_transpose2d_bias_backward(x, bias, output_grad)
469 )
470 }
471
472 fn conv_transpose3d_x_backward(
473 weight: FloatTensor<Self>,
474 output_grad: FloatTensor<Self>,
475 options: burn_backend::ops::ConvTransposeOptions<3>,
476 ) -> FloatTensor<Self> {
477 multi_op!(
478 inputs[(weight, float), (output_grad, float)],
479 => Float,
480 B::conv_transpose3d_x_backward(weight, output_grad, options)
481 )
482 }
483
484 fn conv_transpose3d_weight_backward(
485 x: FloatTensor<Self>,
486 weight: FloatTensor<Self>,
487 output_grad: FloatTensor<Self>,
488 options: burn_backend::ops::ConvTransposeOptions<3>,
489 ) -> FloatTensor<Self> {
490 multi_op!(
491 inputs[(x, float), (weight, float), (output_grad, float)],
492 => Float,
493 B::conv_transpose3d_weight_backward(x, weight, output_grad, options)
494 )
495 }
496
497 fn conv_transpose3d_bias_backward(
498 x: FloatTensor<Self>,
499 bias: FloatTensor<Self>,
500 output_grad: FloatTensor<Self>,
501 ) -> FloatTensor<Self> {
502 multi_op!(
503 inputs[(x, float), (bias, float), (output_grad, float)],
504 => Float,
505 B::conv_transpose3d_bias_backward(x, bias, output_grad)
506 )
507 }
508
509 fn unfold4d(
510 x: FloatTensor<Self>,
511 kernel_size: [usize; 2],
512 options: burn_backend::ops::UnfoldOptions,
513 ) -> FloatTensor<Self> {
514 multi_op!(inputs[(x, float)], => Float, B::unfold4d(x, kernel_size, options))
515 }
516
517 fn avg_pool1d(
518 x: FloatTensor<Self>,
519 kernel_size: usize,
520 stride: usize,
521 padding: usize,
522 count_include_pad: bool,
523 ceil_mode: bool,
524 ) -> FloatTensor<Self> {
525 multi_op!(inputs[(x, float)], => Float,
526 B::avg_pool1d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
527 )
528 }
529
530 fn avg_pool1d_backward(
531 x: FloatTensor<Self>,
532 grad: FloatTensor<Self>,
533 kernel_size: usize,
534 stride: usize,
535 padding: usize,
536 count_include_pad: bool,
537 ceil_mode: bool,
538 ) -> FloatTensor<Self> {
539 multi_op!(
540 inputs[(x, float), (grad, float)],
541 => Float,
542 B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
543 )
544 }
545
546 fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
547 multi_op!(inputs[(x, float)], => Float, B::adaptive_avg_pool1d(x, output_size))
548 }
549
550 fn adaptive_avg_pool1d_backward(
551 x: FloatTensor<Self>,
552 grad: FloatTensor<Self>,
553 ) -> FloatTensor<Self> {
554 multi_op!(
555 inputs[(x, float), (grad, float)],
556 => Float,
557 B::adaptive_avg_pool1d_backward(x, grad)
558 )
559 }
560
561 fn max_pool1d(
562 x: FloatTensor<Self>,
563 kernel_size: usize,
564 stride: usize,
565 padding: usize,
566 dilation: usize,
567 ceil_mode: bool,
568 ) -> FloatTensor<Self> {
569 multi_op!(inputs[(x, float)], => Float,
570 B::max_pool1d(x, kernel_size, stride, padding, dilation, ceil_mode))
571 }
572
573 fn max_pool1d_with_indices(
574 x: FloatTensor<Self>,
575 kernel_size: usize,
576 stride: usize,
577 padding: usize,
578 dilation: usize,
579 ceil_mode: bool,
580 ) -> MaxPool1dWithIndices<Self> {
581 let (out, indices) = multi_op!(
582 inputs[(x, float)],
583 outputs[(out, Float), (indices, Int)],
584 {
585 let res = B::max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
586 (res.output, res.indices)
587 }
588 );
589 MaxPool1dWithIndices::new(out, indices)
590 }
591
592 fn max_pool1d_with_indices_backward(
593 x: FloatTensor<Self>,
594 kernel_size: usize,
595 stride: usize,
596 padding: usize,
597 dilation: usize,
598 ceil_mode: bool,
599 output_grad: FloatTensor<Self>,
600 indices: IntTensor<Self>,
601 ) -> MaxPool1dBackward<Self> {
602 let x_grad = multi_op!(
603 inputs[(x, float), (output_grad, float), (indices, int)],
604 => Float,
605 {
606 let res = B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
607 res.x_grad
608 }
609 );
610 MaxPool1dBackward::new(x_grad)
611 }
612
613 fn attention(
614 query: FloatTensor<Self>,
615 key: FloatTensor<Self>,
616 value: FloatTensor<Self>,
617 mask: Option<burn_backend::tensor::BoolTensor<Self>>,
618 attn_bias: Option<FloatTensor<Self>>,
619 options: burn_backend::ops::AttentionModuleOptions,
620 ) -> FloatTensor<Self> {
621 multi_op!(
622 inputs[(query, float), (key, float), (value, float)],
623 opt_inputs[(mask, bool), (attn_bias, float)],
624 => Float,
625 B::attention(query, key, value, mask, attn_bias, options)
626 )
627 }
628}