// burn_dispatch/ops/module.rs
use burn_backend::{
    ops::{
        DeformConv2dBackward, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
        MaxPool2dWithIndices, ModuleOps,
    },
    tensor::{FloatTensor, IntTensor},
};

use crate::Dispatch;
use crate::backends::*;
/// Neural-network module operations for the [`Dispatch`] backend.
///
/// Every method is a thin forwarding shim: it hands its tensor arguments to
/// the `multi_op!` macro (brought in via `crate::backends::*`), which invokes
/// the corresponding `B::<op>` on the concrete backend and wraps the result
/// back into `Dispatch` tensors. The macro call sites follow one pattern:
///
/// * `inputs[(name, kind)]`     — required tensor args and their element kind
///   (`float` / `int` / `bool`);
/// * `opt_inputs[(name, kind)]` — `Option<_>` tensor args;
/// * `=> Float` — single output of the given kind, or
///   `outputs[(name, Kind)]` / `opt_outputs[name, ..]` for multiple /
///   optional outputs;
/// * the final expression is the backend call, with `B` bound to the selected
///   backend type.
///
/// Non-tensor arguments (kernel sizes, strides, `ConvOptions`, …) are passed
/// through unchanged inside the final expression.
///
/// NOTE(review): the dispatch semantics (how the concrete backend is chosen
/// and how tensors are unwrapped/rewrapped) live entirely in `multi_op!`;
/// the comments below describe only what is visible at these call sites —
/// confirm details against the macro definition in `crate::backends`.
impl ModuleOps<Self> for Dispatch {
    // --- Convolution forward passes -------------------------------------

    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::ConvOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float)],
            opt_inputs[(bias, float)],
            => Float,
            B::conv2d(x, weight, bias, options)
        )
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (offset, float), (weight, float)],
            opt_inputs[(mask, float), (bias, float)],
            => Float,
            B::deform_conv2d(x, offset, weight, mask, bias, options)
        )
    }

    /// Backward pass of deformable conv2d.
    ///
    /// The backend returns a `DeformConv2dBackward` struct; it is unpacked
    /// into a tuple inside the macro body (three required gradients plus the
    /// optional `mask_grad`/`bias_grad`, declared via `opt_outputs`) and
    /// re-packed into this backend's `DeformConv2dBackward` afterwards.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = multi_op!(
            inputs[(x, float), (offset, float), (weight, float), (output_grad, float)],
            opt_inputs[(mask, float), (bias, float)],
            outputs[(x_grad, Float), (offset_grad, Float), (weight_grad, Float)],
            opt_outputs[mask_grad, bias_grad],
            {
                let res = B::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options);
                (res.x_grad, res.offset_grad, res.weight_grad, res.mask_grad, res.bias_grad)
            }
        );
        DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::ConvOptions<3>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float)],
            opt_inputs[(bias, float)],
            => Float,
            B::conv3d(x, weight, bias, options)
        )
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float)],
            opt_inputs[(bias, float)],
            => Float,
            B::conv_transpose2d(x, weight, bias, options)
        )
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float)],
            opt_inputs[(bias, float)],
            => Float,
            B::conv_transpose3d(x, weight, bias, options)
        )
    }

    // --- 2D pooling -----------------------------------------------------

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        multi_op!(inputs[(x, float)],
            => Float,
            B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
        )
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (grad, float)],
            => Float,
            B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
        )
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float)],
            => Float,
            B::adaptive_avg_pool2d(x, output_size)
        )
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (grad, float)],
            => Float,
            B::adaptive_avg_pool2d_backward(x, grad)
        )
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float)],
            => Float,
            B::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
        )
    }

    /// Max-pool that also returns the argmax indices: two outputs of
    /// different kinds (`Float` values, `Int` indices), hence the
    /// `outputs[..]` form instead of `=> Float`.
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Self> {
        let (out, indices) = multi_op!(
            inputs[(x, float)],
            outputs[(out, Float), (indices, Int)],
            {
                let res = B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
                (res.output, res.indices)
            }
        );
        MaxPool2dWithIndices::new(out, indices)
    }

    /// Backward of the indexed max-pool; note the `indices` argument is an
    /// `int` input alongside the two float tensors, and only `x_grad` is
    /// extracted from the backend's result struct.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Self> {
        let x_grad = multi_op!(
            inputs[(x, float), (output_grad, float), (indices, int)],
            => Float,
            {
                let res = B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
                res.x_grad
            }
        );
        MaxPool2dBackward::new(x_grad)
    }

    // --- Interpolation and embedding ------------------------------------

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: burn_backend::ops::InterpolateOptions,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float)],
            => Float,
            B::interpolate(x, output_size, options)
        )
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: burn_backend::ops::InterpolateOptions,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (grad, float)],
            => Float,
            B::interpolate_backward(x, grad, output_size, options)
        )
    }

    fn embedding(weights: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
        multi_op!(
            inputs[(weights, float), (indices, int)],
            => Float,
            B::embedding(weights, indices)
        )
    }

    fn embedding_backward(
        weights: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(weights, float), (output_grad, float), (indices, int)],
            => Float,
            B::embedding_backward(weights, output_grad, indices)
        )
    }

    // --- 1D/2D/3D convolution gradient ops ------------------------------
    // Each conv has three split backward methods (x / weight / bias); they
    // all share the same three-float-input, one-float-output shape.

    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::ConvOptions<1>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float)],
            opt_inputs[(bias, float)],
            => Float,
            B::conv1d(x, weight, bias, options)
        )
    }

    fn conv1d_x_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvOptions<1>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv1d_x_backward(x, weight, output_grad, options)
        )
    }

    fn conv1d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvOptions<1>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv1d_weight_backward(x, weight, output_grad, options)
        )
    }

    fn conv1d_bias_backward(
        x: FloatTensor<Self>,
        bias: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (bias, float), (output_grad, float)],
            => Float,
            B::conv1d_bias_backward(x, bias, output_grad)
        )
    }

    fn conv2d_x_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv2d_x_backward(x, weight, output_grad, options)
        )
    }

    fn conv2d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv2d_weight_backward(x, weight, output_grad, options)
        )
    }

    fn conv2d_bias_backward(
        x: FloatTensor<Self>,
        bias: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (bias, float), (output_grad, float)],
            => Float,
            B::conv2d_bias_backward(x, bias, output_grad)
        )
    }

    fn conv3d_x_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvOptions<3>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv3d_x_backward(x, weight, output_grad, options)
        )
    }

    fn conv3d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvOptions<3>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv3d_weight_backward(x, weight, output_grad, options)
        )
    }

    fn conv3d_bias_backward(
        x: FloatTensor<Self>,
        bias: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (bias, float), (output_grad, float)],
            => Float,
            B::conv3d_bias_backward(x, bias, output_grad)
        )
    }

    // --- Transposed convolutions and their gradients --------------------
    // NOTE(review): the `*_x_backward` variants take no `x` argument — only
    // `weight` and `output_grad` — which mirrors the trait signatures above;
    // confirm against the `ModuleOps` trait definition.

    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float)],
            opt_inputs[(bias, float)],
            => Float,
            B::conv_transpose1d(x, weight, bias, options)
        )
    }

    fn conv_transpose1d_x_backward(
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(weight, float), (output_grad, float)],
            => Float,
            B::conv_transpose1d_x_backward(weight, output_grad, options)
        )
    }

    fn conv_transpose1d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv_transpose1d_weight_backward(x, weight, output_grad, options)
        )
    }

    fn conv_transpose1d_bias_backward(
        x: FloatTensor<Self>,
        bias: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (bias, float), (output_grad, float)],
            => Float,
            B::conv_transpose1d_bias_backward(x, bias, output_grad)
        )
    }

    fn conv_transpose2d_x_backward(
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(weight, float), (output_grad, float)],
            => Float,
            B::conv_transpose2d_x_backward(weight, output_grad, options)
        )
    }

    fn conv_transpose2d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv_transpose2d_weight_backward(x, weight, output_grad, options)
        )
    }

    fn conv_transpose2d_bias_backward(
        x: FloatTensor<Self>,
        bias: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (bias, float), (output_grad, float)],
            => Float,
            B::conv_transpose2d_bias_backward(x, bias, output_grad)
        )
    }

    fn conv_transpose3d_x_backward(
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(weight, float), (output_grad, float)],
            => Float,
            B::conv_transpose3d_x_backward(weight, output_grad, options)
        )
    }

    fn conv_transpose3d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: burn_backend::ops::ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (weight, float), (output_grad, float)],
            => Float,
            B::conv_transpose3d_weight_backward(x, weight, output_grad, options)
        )
    }

    fn conv_transpose3d_bias_backward(
        x: FloatTensor<Self>,
        bias: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (bias, float), (output_grad, float)],
            => Float,
            B::conv_transpose3d_bias_backward(x, bias, output_grad)
        )
    }

    // --- Unfold, 1D pooling, attention ----------------------------------

    fn unfold4d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        options: burn_backend::ops::UnfoldOptions,
    ) -> FloatTensor<Self> {
        multi_op!(inputs[(x, float)], => Float, B::unfold4d(x, kernel_size, options))
    }

    fn avg_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        multi_op!(inputs[(x, float)], => Float,
            B::avg_pool1d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)
        )
    }

    fn avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (grad, float)],
            => Float,
            B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)
        )
    }

    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
        multi_op!(inputs[(x, float)], => Float, B::adaptive_avg_pool1d(x, output_size))
    }

    fn adaptive_avg_pool1d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(x, float), (grad, float)],
            => Float,
            B::adaptive_avg_pool1d_backward(x, grad)
        )
    }

    fn max_pool1d(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        multi_op!(inputs[(x, float)], => Float,
            B::max_pool1d(x, kernel_size, stride, padding, dilation, ceil_mode))
    }

    /// 1D analogue of `max_pool2d_with_indices`: dual output
    /// (`Float` values + `Int` argmax indices).
    fn max_pool1d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<Self> {
        let (out, indices) = multi_op!(
            inputs[(x, float)],
            outputs[(out, Float), (indices, Int)],
            {
                let res = B::max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);
                (res.output, res.indices)
            }
        );
        MaxPool1dWithIndices::new(out, indices)
    }

    /// 1D analogue of `max_pool2d_with_indices_backward`: only `x_grad` is
    /// taken from the backend result.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool1dBackward<Self> {
        let x_grad = multi_op!(
            inputs[(x, float), (output_grad, float), (indices, int)],
            => Float,
            {
                let res = B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);
                res.x_grad
            }
        );
        MaxPool1dBackward::new(x_grad)
    }

    /// Attention: the only op here with a `bool` optional input (`mask`);
    /// `attn_bias` is an optional float input.
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<burn_backend::tensor::BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: burn_backend::ops::AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(query, float), (key, float), (value, float)],
            opt_inputs[(mask, bool), (attn_bias, float)],
            => Float,
            B::attention(query, key, value, mask, attn_bias, options)
        )
    }
}