Skip to main content

burn_router/ops/
module.rs

1use alloc::boxed::Box;
2
3use burn_backend::Element;
4use burn_backend::ops::{
5    ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
6    MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
7};
8use burn_backend::tensor::{FloatTensor, IntElem, IntTensor};
9use burn_ir::{
10    AdaptiveAvgPool1dBackwardOpIr, AdaptiveAvgPool1dOpIr, AdaptiveAvgPool2dBackwardOpIr,
11    AdaptiveAvgPool2dOpIr, AvgPool1dBackwardOpIr, AvgPool1dOpIr, AvgPool2dBackwardOpIr,
12    AvgPool2dOpIr, Conv1dOpIr, Conv2dOpIr, Conv3dOpIr, ConvTranspose1dOpIr, ConvTranspose2dOpIr,
13    ConvTranspose3dOpIr, DeformConv2dBackwardOpIr, DeformConv2dOpIr, InterpolateBackwardOpIr,
14    InterpolateOpIr, MaxPool1dOpIr, MaxPool1dWithIndicesBackwardOpIr, MaxPool1dWithIndicesOpIr,
15    MaxPool2dOpIr, MaxPool2dWithIndicesBackwardOpIr, MaxPool2dWithIndicesOpIr, ModuleOperationIr,
16    OperationIr, OperationOutput,
17};
18
19use crate::{BackendRouter, RunnerChannel, RunnerClient};
20
21impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
22    fn conv1d(
23        x: FloatTensor<Self>,
24        weight: FloatTensor<Self>,
25        bias: Option<FloatTensor<Self>>,
26        options: ConvOptions<1>,
27    ) -> FloatTensor<Self> {
28        let client = x.client.clone();
29        let desc = Conv1dOpIr::create(
30            x.into_ir(),
31            weight.into_ir(),
32            bias.map(|bias| bias.into_ir()),
33            options.into(),
34            || client.create_empty_handle(),
35        );
36
37        client
38            .register(OperationIr::Module(ModuleOperationIr::Conv1d(desc)))
39            .output()
40    }
41
42    fn conv2d(
43        x: FloatTensor<Self>,
44        weight: FloatTensor<Self>,
45        bias: Option<FloatTensor<Self>>,
46        options: ConvOptions<2>,
47    ) -> FloatTensor<Self> {
48        let client = x.client.clone();
49        let desc = Conv2dOpIr::create(
50            x.into_ir(),
51            weight.into_ir(),
52            bias.map(|bias| bias.into_ir()),
53            options.into(),
54            || client.create_empty_handle(),
55        );
56
57        client
58            .register(OperationIr::Module(ModuleOperationIr::Conv2d(desc)))
59            .output()
60    }
61
62    fn conv3d(
63        x: FloatTensor<Self>,
64        weight: FloatTensor<Self>,
65        bias: Option<FloatTensor<Self>>,
66        options: ConvOptions<3>,
67    ) -> FloatTensor<Self> {
68        let client = x.client.clone();
69        let desc = Conv3dOpIr::create(
70            x.into_ir(),
71            weight.into_ir(),
72            bias.map(|bias| bias.into_ir()),
73            options.into(),
74            || client.create_empty_handle(),
75        );
76
77        client
78            .register(OperationIr::Module(ModuleOperationIr::Conv3d(desc)))
79            .output()
80    }
81
82    fn conv_transpose1d(
83        x: FloatTensor<Self>,
84        weight: FloatTensor<Self>,
85        bias: Option<FloatTensor<Self>>,
86        options: ConvTransposeOptions<1>,
87    ) -> FloatTensor<Self> {
88        let client = x.client.clone();
89        let desc = ConvTranspose1dOpIr::create(
90            x.into_ir(),
91            weight.into_ir(),
92            bias.map(|bias| bias.into_ir()),
93            options.into(),
94            || client.create_empty_handle(),
95        );
96
97        client
98            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d(
99                desc,
100            )))
101            .output()
102    }
103
104    fn conv_transpose2d(
105        x: FloatTensor<Self>,
106        weight: FloatTensor<Self>,
107        bias: Option<FloatTensor<Self>>,
108        options: ConvTransposeOptions<2>,
109    ) -> FloatTensor<Self> {
110        let client = x.client.clone();
111        let desc = ConvTranspose2dOpIr::create(
112            x.into_ir(),
113            weight.into_ir(),
114            bias.map(|bias| bias.into_ir()),
115            options.into(),
116            || client.create_empty_handle(),
117        );
118
119        client
120            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d(
121                desc,
122            )))
123            .output()
124    }
125
126    fn conv_transpose3d(
127        x: FloatTensor<Self>,
128        weight: FloatTensor<Self>,
129        bias: Option<FloatTensor<Self>>,
130        options: ConvTransposeOptions<3>,
131    ) -> FloatTensor<Self> {
132        let client = x.client.clone();
133        let desc = ConvTranspose3dOpIr::create(
134            x.into_ir(),
135            weight.into_ir(),
136            bias.map(|bias| bias.into_ir()),
137            options.into(),
138            || client.create_empty_handle(),
139        );
140
141        client
142            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d(
143                desc,
144            )))
145            .output()
146    }
147
148    fn avg_pool1d(
149        x: FloatTensor<Self>,
150        kernel_size: usize,
151        stride: usize,
152        padding: usize,
153        count_include_pad: bool,
154        ceil_mode: bool,
155    ) -> FloatTensor<Self> {
156        let client = x.client.clone();
157        let desc = AvgPool1dOpIr::create(
158            x.into_ir(),
159            kernel_size,
160            stride,
161            padding,
162            count_include_pad,
163            ceil_mode,
164            || client.create_empty_handle(),
165        );
166
167        client
168            .register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc)))
169            .output()
170    }
171
172    fn avg_pool2d(
173        x: FloatTensor<Self>,
174        kernel_size: [usize; 2],
175        stride: [usize; 2],
176        padding: [usize; 2],
177        count_include_pad: bool,
178        ceil_mode: bool,
179    ) -> FloatTensor<Self> {
180        let client = x.client.clone();
181        let desc = AvgPool2dOpIr::create(
182            x.into_ir(),
183            kernel_size,
184            stride,
185            padding,
186            count_include_pad,
187            ceil_mode,
188            || client.create_empty_handle(),
189        );
190
191        client
192            .register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc)))
193            .output()
194    }
195
196    fn avg_pool1d_backward(
197        x: FloatTensor<Self>,
198        grad: FloatTensor<Self>,
199        kernel_size: usize,
200        stride: usize,
201        padding: usize,
202        count_include_pad: bool,
203        ceil_mode: bool,
204    ) -> FloatTensor<Self> {
205        let client = x.client.clone();
206        let desc = AvgPool1dBackwardOpIr::create(
207            x.into_ir(),
208            grad.into_ir(),
209            kernel_size,
210            stride,
211            padding,
212            count_include_pad,
213            ceil_mode,
214            || client.create_empty_handle(),
215        );
216
217        client
218            .register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(
219                desc,
220            )))
221            .output()
222    }
223
224    fn avg_pool2d_backward(
225        x: FloatTensor<Self>,
226        grad: FloatTensor<Self>,
227        kernel_size: [usize; 2],
228        stride: [usize; 2],
229        padding: [usize; 2],
230        count_include_pad: bool,
231        ceil_mode: bool,
232    ) -> FloatTensor<Self> {
233        let client = x.client.clone();
234        let desc = AvgPool2dBackwardOpIr::create(
235            x.into_ir(),
236            grad.into_ir(),
237            kernel_size,
238            stride,
239            padding,
240            count_include_pad,
241            ceil_mode,
242            || client.create_empty_handle(),
243        );
244
245        client
246            .register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(
247                desc,
248            )))
249            .output()
250    }
251
252    fn max_pool1d(
253        x: FloatTensor<Self>,
254        kernel_size: usize,
255        stride: usize,
256        padding: usize,
257        dilation: usize,
258        ceil_mode: bool,
259    ) -> FloatTensor<Self> {
260        let client = x.client.clone();
261        let desc = MaxPool1dOpIr::create(
262            x.into_ir(),
263            kernel_size,
264            stride,
265            padding,
266            dilation,
267            ceil_mode,
268            || client.create_empty_handle(),
269        );
270
271        client
272            .register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc)))
273            .output()
274    }
275
276    fn max_pool2d(
277        x: FloatTensor<Self>,
278        kernel_size: [usize; 2],
279        stride: [usize; 2],
280        padding: [usize; 2],
281        dilation: [usize; 2],
282        ceil_mode: bool,
283    ) -> FloatTensor<Self> {
284        let client = x.client.clone();
285        let desc = MaxPool2dOpIr::create(
286            x.into_ir(),
287            kernel_size,
288            stride,
289            padding,
290            dilation,
291            ceil_mode,
292            || client.create_empty_handle(),
293        );
294
295        client
296            .register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc)))
297            .output()
298    }
299
300    fn max_pool1d_with_indices(
301        x: FloatTensor<Self>,
302        kernel_size: usize,
303        stride: usize,
304        padding: usize,
305        dilation: usize,
306        ceil_mode: bool,
307    ) -> MaxPool1dWithIndices<Self> {
308        let client = x.client.clone();
309        let desc = MaxPool1dWithIndicesOpIr::create(
310            x.into_ir(),
311            kernel_size,
312            stride,
313            padding,
314            dilation,
315            ceil_mode,
316            IntElem::<Self>::dtype(),
317            || client.create_empty_handle(),
318        );
319
320        let [out, out_indices] = client
321            .register(OperationIr::Module(
322                ModuleOperationIr::MaxPool1dWithIndices(desc),
323            ))
324            .outputs();
325
326        MaxPool1dWithIndices::new(out, out_indices)
327    }
328
329    fn max_pool2d_with_indices(
330        x: FloatTensor<Self>,
331        kernel_size: [usize; 2],
332        stride: [usize; 2],
333        padding: [usize; 2],
334        dilation: [usize; 2],
335        ceil_mode: bool,
336    ) -> MaxPool2dWithIndices<Self> {
337        let client = x.client.clone();
338        let desc = MaxPool2dWithIndicesOpIr::create(
339            x.into_ir(),
340            kernel_size,
341            stride,
342            padding,
343            dilation,
344            ceil_mode,
345            IntElem::<Self>::dtype(),
346            || client.create_empty_handle(),
347        );
348
349        let [out, out_indices] = client
350            .register(OperationIr::Module(
351                ModuleOperationIr::MaxPool2dWithIndices(desc),
352            ))
353            .outputs();
354
355        MaxPool2dWithIndices::new(out, out_indices)
356    }
357
358    fn max_pool1d_with_indices_backward(
359        x: FloatTensor<Self>,
360        kernel_size: usize,
361        stride: usize,
362        padding: usize,
363        dilation: usize,
364        ceil_mode: bool,
365        output_grad: FloatTensor<Self>,
366        indices: IntTensor<Self>,
367    ) -> MaxPool1dBackward<Self> {
368        let client = x.client.clone();
369
370        let desc = MaxPool1dWithIndicesBackwardOpIr::create(
371            x.into_ir(),
372            output_grad.into_ir(),
373            indices.into_ir(),
374            kernel_size,
375            stride,
376            padding,
377            dilation,
378            ceil_mode,
379            || client.create_empty_handle(),
380        );
381
382        let out = client
383            .register(OperationIr::Module(
384                ModuleOperationIr::MaxPool1dWithIndicesBackward(desc),
385            ))
386            .output();
387
388        MaxPool1dBackward::new(out)
389    }
390
391    fn max_pool2d_with_indices_backward(
392        x: FloatTensor<Self>,
393        kernel_size: [usize; 2],
394        stride: [usize; 2],
395        padding: [usize; 2],
396        dilation: [usize; 2],
397        ceil_mode: bool,
398        output_grad: FloatTensor<Self>,
399        indices: IntTensor<Self>,
400    ) -> MaxPool2dBackward<Self> {
401        let client = x.client.clone();
402
403        let desc = MaxPool2dWithIndicesBackwardOpIr::create(
404            x.into_ir(),
405            output_grad.into_ir(),
406            indices.into_ir(),
407            kernel_size,
408            stride,
409            padding,
410            dilation,
411            ceil_mode,
412            || client.create_empty_handle(),
413        );
414
415        let out = client
416            .register(OperationIr::Module(
417                ModuleOperationIr::MaxPool2dWithIndicesBackward(desc),
418            ))
419            .output();
420
421        MaxPool2dBackward::new(out)
422    }
423
424    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
425        let client = x.client.clone();
426
427        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
428            client.create_empty_handle()
429        });
430
431        client
432            .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(
433                desc,
434            )))
435            .output()
436    }
437
438    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
439        let client = x.client.clone();
440
441        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
442            client.create_empty_handle()
443        });
444
445        client
446            .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(
447                desc,
448            )))
449            .output()
450    }
451
452    fn adaptive_avg_pool1d_backward(
453        x: FloatTensor<Self>,
454        grad: FloatTensor<Self>,
455    ) -> FloatTensor<Self> {
456        let client = x.client.clone();
457
458        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
459            client.create_empty_handle()
460        });
461
462        client
463            .register(OperationIr::Module(
464                ModuleOperationIr::AdaptiveAvgPool1dBackward(desc),
465            ))
466            .output()
467    }
468
469    fn adaptive_avg_pool2d_backward(
470        x: FloatTensor<Self>,
471        grad: FloatTensor<Self>,
472    ) -> FloatTensor<Self> {
473        let client = x.client.clone();
474
475        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
476            client.create_empty_handle()
477        });
478
479        client
480            .register(OperationIr::Module(
481                ModuleOperationIr::AdaptiveAvgPool2dBackward(desc),
482            ))
483            .output()
484    }
485
486    fn interpolate(
487        x: FloatTensor<Self>,
488        output_size: [usize; 2],
489        options: InterpolateOptions,
490    ) -> FloatTensor<Self> {
491        let client = x.client.clone();
492        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
493            client.create_empty_handle()
494        });
495
496        client
497            .register(OperationIr::Module(ModuleOperationIr::Interpolate(desc)))
498            .output()
499    }
500
501    fn interpolate_backward(
502        x: FloatTensor<Self>,
503        grad: FloatTensor<Self>,
504        output_size: [usize; 2],
505        options: InterpolateOptions,
506    ) -> FloatTensor<Self> {
507        let client = x.client.clone();
508        let desc = InterpolateBackwardOpIr::create(
509            x.into_ir(),
510            grad.into_ir(),
511            output_size,
512            options.into(),
513            || client.create_empty_handle(),
514        );
515
516        client
517            .register(OperationIr::Module(ModuleOperationIr::InterpolateBackward(
518                desc,
519            )))
520            .output()
521    }
522
523    fn deform_conv2d(
524        x: FloatTensor<Self>,
525        offset: FloatTensor<Self>,
526        weight: FloatTensor<Self>,
527        mask: Option<FloatTensor<Self>>,
528        bias: Option<FloatTensor<Self>>,
529        options: DeformConvOptions<2>,
530    ) -> FloatTensor<Self> {
531        let client = x.client.clone();
532        let desc = DeformConv2dOpIr::create(
533            x.into_ir(),
534            offset.into_ir(),
535            weight.into_ir(),
536            mask.map(|mask| mask.into_ir()),
537            bias.map(|bias| bias.into_ir()),
538            options.into(),
539            || client.create_empty_handle(),
540        );
541
542        client
543            .register(OperationIr::Module(ModuleOperationIr::DeformableConv2d(
544                Box::new(desc),
545            )))
546            .output()
547    }
548
549    fn deform_conv2d_backward(
550        x: FloatTensor<Self>,
551        offset: FloatTensor<Self>,
552        weight: FloatTensor<Self>,
553        mask: Option<FloatTensor<Self>>,
554        bias: Option<FloatTensor<Self>>,
555        output_grad: FloatTensor<Self>,
556        options: DeformConvOptions<2>,
557    ) -> DeformConv2dBackward<Self> {
558        let client = x.client.clone();
559        let has_bias = bias.is_some();
560        let has_mask = mask.is_some();
561
562        let desc = DeformConv2dBackwardOpIr::create(
563            x.into_ir(),
564            offset.into_ir(),
565            weight.into_ir(),
566            mask.map(|mask| mask.into_ir()),
567            bias.map(|bias| bias.into_ir()),
568            output_grad.into_ir(),
569            options.into(),
570            || client.create_empty_handle(),
571        );
572        let mut outputs = client
573            .register(OperationIr::Module(
574                ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)),
575            ))
576            .into_iter();
577
578        // When the number of outputs is variable, the order is important
579        let input_grad = outputs.next().unwrap();
580        let offset_grad = outputs.next().unwrap();
581        let weight_grad = outputs.next().unwrap();
582        let mask_grad = has_mask.then(|| outputs.next().unwrap());
583        let bias_grad = has_bias.then(|| outputs.next().unwrap());
584
585        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
586    }
587}