1use alloc::boxed::Box;
2
3use burn_backend::Element;
4use burn_backend::ops::{
5 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
6 MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
7};
8use burn_backend::tensor::{FloatTensor, IntElem, IntTensor};
9use burn_ir::{
10 AdaptiveAvgPool1dBackwardOpIr, AdaptiveAvgPool1dOpIr, AdaptiveAvgPool2dBackwardOpIr,
11 AdaptiveAvgPool2dOpIr, AvgPool1dBackwardOpIr, AvgPool1dOpIr, AvgPool2dBackwardOpIr,
12 AvgPool2dOpIr, Conv1dOpIr, Conv2dOpIr, Conv3dOpIr, ConvTranspose1dOpIr, ConvTranspose2dOpIr,
13 ConvTranspose3dOpIr, DeformConv2dBackwardOpIr, DeformConv2dOpIr, InterpolateBackwardOpIr,
14 InterpolateOpIr, MaxPool1dOpIr, MaxPool1dWithIndicesBackwardOpIr, MaxPool1dWithIndicesOpIr,
15 MaxPool2dOpIr, MaxPool2dWithIndicesBackwardOpIr, MaxPool2dWithIndicesOpIr, ModuleOperationIr,
16 OperationIr, OperationOutput,
17};
18
19use crate::{BackendRouter, RunnerChannel, RunnerClient};
20
21impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
22 fn conv1d(
23 x: FloatTensor<Self>,
24 weight: FloatTensor<Self>,
25 bias: Option<FloatTensor<Self>>,
26 options: ConvOptions<1>,
27 ) -> FloatTensor<Self> {
28 let client = x.client.clone();
29 let desc = Conv1dOpIr::create(
30 x.into_ir(),
31 weight.into_ir(),
32 bias.map(|bias| bias.into_ir()),
33 options.into(),
34 || client.create_empty_handle(),
35 );
36
37 client
38 .register(OperationIr::Module(ModuleOperationIr::Conv1d(desc)))
39 .output()
40 }
41
42 fn conv2d(
43 x: FloatTensor<Self>,
44 weight: FloatTensor<Self>,
45 bias: Option<FloatTensor<Self>>,
46 options: ConvOptions<2>,
47 ) -> FloatTensor<Self> {
48 let client = x.client.clone();
49 let desc = Conv2dOpIr::create(
50 x.into_ir(),
51 weight.into_ir(),
52 bias.map(|bias| bias.into_ir()),
53 options.into(),
54 || client.create_empty_handle(),
55 );
56
57 client
58 .register(OperationIr::Module(ModuleOperationIr::Conv2d(desc)))
59 .output()
60 }
61
62 fn conv3d(
63 x: FloatTensor<Self>,
64 weight: FloatTensor<Self>,
65 bias: Option<FloatTensor<Self>>,
66 options: ConvOptions<3>,
67 ) -> FloatTensor<Self> {
68 let client = x.client.clone();
69 let desc = Conv3dOpIr::create(
70 x.into_ir(),
71 weight.into_ir(),
72 bias.map(|bias| bias.into_ir()),
73 options.into(),
74 || client.create_empty_handle(),
75 );
76
77 client
78 .register(OperationIr::Module(ModuleOperationIr::Conv3d(desc)))
79 .output()
80 }
81
82 fn conv_transpose1d(
83 x: FloatTensor<Self>,
84 weight: FloatTensor<Self>,
85 bias: Option<FloatTensor<Self>>,
86 options: ConvTransposeOptions<1>,
87 ) -> FloatTensor<Self> {
88 let client = x.client.clone();
89 let desc = ConvTranspose1dOpIr::create(
90 x.into_ir(),
91 weight.into_ir(),
92 bias.map(|bias| bias.into_ir()),
93 options.into(),
94 || client.create_empty_handle(),
95 );
96
97 client
98 .register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d(
99 desc,
100 )))
101 .output()
102 }
103
104 fn conv_transpose2d(
105 x: FloatTensor<Self>,
106 weight: FloatTensor<Self>,
107 bias: Option<FloatTensor<Self>>,
108 options: ConvTransposeOptions<2>,
109 ) -> FloatTensor<Self> {
110 let client = x.client.clone();
111 let desc = ConvTranspose2dOpIr::create(
112 x.into_ir(),
113 weight.into_ir(),
114 bias.map(|bias| bias.into_ir()),
115 options.into(),
116 || client.create_empty_handle(),
117 );
118
119 client
120 .register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d(
121 desc,
122 )))
123 .output()
124 }
125
126 fn conv_transpose3d(
127 x: FloatTensor<Self>,
128 weight: FloatTensor<Self>,
129 bias: Option<FloatTensor<Self>>,
130 options: ConvTransposeOptions<3>,
131 ) -> FloatTensor<Self> {
132 let client = x.client.clone();
133 let desc = ConvTranspose3dOpIr::create(
134 x.into_ir(),
135 weight.into_ir(),
136 bias.map(|bias| bias.into_ir()),
137 options.into(),
138 || client.create_empty_handle(),
139 );
140
141 client
142 .register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d(
143 desc,
144 )))
145 .output()
146 }
147
148 fn avg_pool1d(
149 x: FloatTensor<Self>,
150 kernel_size: usize,
151 stride: usize,
152 padding: usize,
153 count_include_pad: bool,
154 ceil_mode: bool,
155 ) -> FloatTensor<Self> {
156 let client = x.client.clone();
157 let desc = AvgPool1dOpIr::create(
158 x.into_ir(),
159 kernel_size,
160 stride,
161 padding,
162 count_include_pad,
163 ceil_mode,
164 || client.create_empty_handle(),
165 );
166
167 client
168 .register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc)))
169 .output()
170 }
171
172 fn avg_pool2d(
173 x: FloatTensor<Self>,
174 kernel_size: [usize; 2],
175 stride: [usize; 2],
176 padding: [usize; 2],
177 count_include_pad: bool,
178 ceil_mode: bool,
179 ) -> FloatTensor<Self> {
180 let client = x.client.clone();
181 let desc = AvgPool2dOpIr::create(
182 x.into_ir(),
183 kernel_size,
184 stride,
185 padding,
186 count_include_pad,
187 ceil_mode,
188 || client.create_empty_handle(),
189 );
190
191 client
192 .register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc)))
193 .output()
194 }
195
196 fn avg_pool1d_backward(
197 x: FloatTensor<Self>,
198 grad: FloatTensor<Self>,
199 kernel_size: usize,
200 stride: usize,
201 padding: usize,
202 count_include_pad: bool,
203 ceil_mode: bool,
204 ) -> FloatTensor<Self> {
205 let client = x.client.clone();
206 let desc = AvgPool1dBackwardOpIr::create(
207 x.into_ir(),
208 grad.into_ir(),
209 kernel_size,
210 stride,
211 padding,
212 count_include_pad,
213 ceil_mode,
214 || client.create_empty_handle(),
215 );
216
217 client
218 .register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(
219 desc,
220 )))
221 .output()
222 }
223
224 fn avg_pool2d_backward(
225 x: FloatTensor<Self>,
226 grad: FloatTensor<Self>,
227 kernel_size: [usize; 2],
228 stride: [usize; 2],
229 padding: [usize; 2],
230 count_include_pad: bool,
231 ceil_mode: bool,
232 ) -> FloatTensor<Self> {
233 let client = x.client.clone();
234 let desc = AvgPool2dBackwardOpIr::create(
235 x.into_ir(),
236 grad.into_ir(),
237 kernel_size,
238 stride,
239 padding,
240 count_include_pad,
241 ceil_mode,
242 || client.create_empty_handle(),
243 );
244
245 client
246 .register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(
247 desc,
248 )))
249 .output()
250 }
251
252 fn max_pool1d(
253 x: FloatTensor<Self>,
254 kernel_size: usize,
255 stride: usize,
256 padding: usize,
257 dilation: usize,
258 ceil_mode: bool,
259 ) -> FloatTensor<Self> {
260 let client = x.client.clone();
261 let desc = MaxPool1dOpIr::create(
262 x.into_ir(),
263 kernel_size,
264 stride,
265 padding,
266 dilation,
267 ceil_mode,
268 || client.create_empty_handle(),
269 );
270
271 client
272 .register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc)))
273 .output()
274 }
275
276 fn max_pool2d(
277 x: FloatTensor<Self>,
278 kernel_size: [usize; 2],
279 stride: [usize; 2],
280 padding: [usize; 2],
281 dilation: [usize; 2],
282 ceil_mode: bool,
283 ) -> FloatTensor<Self> {
284 let client = x.client.clone();
285 let desc = MaxPool2dOpIr::create(
286 x.into_ir(),
287 kernel_size,
288 stride,
289 padding,
290 dilation,
291 ceil_mode,
292 || client.create_empty_handle(),
293 );
294
295 client
296 .register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc)))
297 .output()
298 }
299
300 fn max_pool1d_with_indices(
301 x: FloatTensor<Self>,
302 kernel_size: usize,
303 stride: usize,
304 padding: usize,
305 dilation: usize,
306 ceil_mode: bool,
307 ) -> MaxPool1dWithIndices<Self> {
308 let client = x.client.clone();
309 let desc = MaxPool1dWithIndicesOpIr::create(
310 x.into_ir(),
311 kernel_size,
312 stride,
313 padding,
314 dilation,
315 ceil_mode,
316 IntElem::<Self>::dtype(),
317 || client.create_empty_handle(),
318 );
319
320 let [out, out_indices] = client
321 .register(OperationIr::Module(
322 ModuleOperationIr::MaxPool1dWithIndices(desc),
323 ))
324 .outputs();
325
326 MaxPool1dWithIndices::new(out, out_indices)
327 }
328
329 fn max_pool2d_with_indices(
330 x: FloatTensor<Self>,
331 kernel_size: [usize; 2],
332 stride: [usize; 2],
333 padding: [usize; 2],
334 dilation: [usize; 2],
335 ceil_mode: bool,
336 ) -> MaxPool2dWithIndices<Self> {
337 let client = x.client.clone();
338 let desc = MaxPool2dWithIndicesOpIr::create(
339 x.into_ir(),
340 kernel_size,
341 stride,
342 padding,
343 dilation,
344 ceil_mode,
345 IntElem::<Self>::dtype(),
346 || client.create_empty_handle(),
347 );
348
349 let [out, out_indices] = client
350 .register(OperationIr::Module(
351 ModuleOperationIr::MaxPool2dWithIndices(desc),
352 ))
353 .outputs();
354
355 MaxPool2dWithIndices::new(out, out_indices)
356 }
357
358 fn max_pool1d_with_indices_backward(
359 x: FloatTensor<Self>,
360 kernel_size: usize,
361 stride: usize,
362 padding: usize,
363 dilation: usize,
364 ceil_mode: bool,
365 output_grad: FloatTensor<Self>,
366 indices: IntTensor<Self>,
367 ) -> MaxPool1dBackward<Self> {
368 let client = x.client.clone();
369
370 let desc = MaxPool1dWithIndicesBackwardOpIr::create(
371 x.into_ir(),
372 output_grad.into_ir(),
373 indices.into_ir(),
374 kernel_size,
375 stride,
376 padding,
377 dilation,
378 ceil_mode,
379 || client.create_empty_handle(),
380 );
381
382 let out = client
383 .register(OperationIr::Module(
384 ModuleOperationIr::MaxPool1dWithIndicesBackward(desc),
385 ))
386 .output();
387
388 MaxPool1dBackward::new(out)
389 }
390
391 fn max_pool2d_with_indices_backward(
392 x: FloatTensor<Self>,
393 kernel_size: [usize; 2],
394 stride: [usize; 2],
395 padding: [usize; 2],
396 dilation: [usize; 2],
397 ceil_mode: bool,
398 output_grad: FloatTensor<Self>,
399 indices: IntTensor<Self>,
400 ) -> MaxPool2dBackward<Self> {
401 let client = x.client.clone();
402
403 let desc = MaxPool2dWithIndicesBackwardOpIr::create(
404 x.into_ir(),
405 output_grad.into_ir(),
406 indices.into_ir(),
407 kernel_size,
408 stride,
409 padding,
410 dilation,
411 ceil_mode,
412 || client.create_empty_handle(),
413 );
414
415 let out = client
416 .register(OperationIr::Module(
417 ModuleOperationIr::MaxPool2dWithIndicesBackward(desc),
418 ))
419 .output();
420
421 MaxPool2dBackward::new(out)
422 }
423
424 fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
425 let client = x.client.clone();
426
427 let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {
428 client.create_empty_handle()
429 });
430
431 client
432 .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(
433 desc,
434 )))
435 .output()
436 }
437
438 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
439 let client = x.client.clone();
440
441 let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {
442 client.create_empty_handle()
443 });
444
445 client
446 .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(
447 desc,
448 )))
449 .output()
450 }
451
452 fn adaptive_avg_pool1d_backward(
453 x: FloatTensor<Self>,
454 grad: FloatTensor<Self>,
455 ) -> FloatTensor<Self> {
456 let client = x.client.clone();
457
458 let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
459 client.create_empty_handle()
460 });
461
462 client
463 .register(OperationIr::Module(
464 ModuleOperationIr::AdaptiveAvgPool1dBackward(desc),
465 ))
466 .output()
467 }
468
469 fn adaptive_avg_pool2d_backward(
470 x: FloatTensor<Self>,
471 grad: FloatTensor<Self>,
472 ) -> FloatTensor<Self> {
473 let client = x.client.clone();
474
475 let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {
476 client.create_empty_handle()
477 });
478
479 client
480 .register(OperationIr::Module(
481 ModuleOperationIr::AdaptiveAvgPool2dBackward(desc),
482 ))
483 .output()
484 }
485
486 fn interpolate(
487 x: FloatTensor<Self>,
488 output_size: [usize; 2],
489 options: InterpolateOptions,
490 ) -> FloatTensor<Self> {
491 let client = x.client.clone();
492 let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {
493 client.create_empty_handle()
494 });
495
496 client
497 .register(OperationIr::Module(ModuleOperationIr::Interpolate(desc)))
498 .output()
499 }
500
501 fn interpolate_backward(
502 x: FloatTensor<Self>,
503 grad: FloatTensor<Self>,
504 output_size: [usize; 2],
505 options: InterpolateOptions,
506 ) -> FloatTensor<Self> {
507 let client = x.client.clone();
508 let desc = InterpolateBackwardOpIr::create(
509 x.into_ir(),
510 grad.into_ir(),
511 output_size,
512 options.into(),
513 || client.create_empty_handle(),
514 );
515
516 client
517 .register(OperationIr::Module(ModuleOperationIr::InterpolateBackward(
518 desc,
519 )))
520 .output()
521 }
522
523 fn deform_conv2d(
524 x: FloatTensor<Self>,
525 offset: FloatTensor<Self>,
526 weight: FloatTensor<Self>,
527 mask: Option<FloatTensor<Self>>,
528 bias: Option<FloatTensor<Self>>,
529 options: DeformConvOptions<2>,
530 ) -> FloatTensor<Self> {
531 let client = x.client.clone();
532 let desc = DeformConv2dOpIr::create(
533 x.into_ir(),
534 offset.into_ir(),
535 weight.into_ir(),
536 mask.map(|mask| mask.into_ir()),
537 bias.map(|bias| bias.into_ir()),
538 options.into(),
539 || client.create_empty_handle(),
540 );
541
542 client
543 .register(OperationIr::Module(ModuleOperationIr::DeformableConv2d(
544 Box::new(desc),
545 )))
546 .output()
547 }
548
549 fn deform_conv2d_backward(
550 x: FloatTensor<Self>,
551 offset: FloatTensor<Self>,
552 weight: FloatTensor<Self>,
553 mask: Option<FloatTensor<Self>>,
554 bias: Option<FloatTensor<Self>>,
555 output_grad: FloatTensor<Self>,
556 options: DeformConvOptions<2>,
557 ) -> DeformConv2dBackward<Self> {
558 let client = x.client.clone();
559 let has_bias = bias.is_some();
560 let has_mask = mask.is_some();
561
562 let desc = DeformConv2dBackwardOpIr::create(
563 x.into_ir(),
564 offset.into_ir(),
565 weight.into_ir(),
566 mask.map(|mask| mask.into_ir()),
567 bias.map(|bias| bias.into_ir()),
568 output_grad.into_ir(),
569 options.into(),
570 || client.create_empty_handle(),
571 );
572 let mut outputs = client
573 .register(OperationIr::Module(
574 ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)),
575 ))
576 .into_iter();
577
578 let input_grad = outputs.next().unwrap();
580 let offset_grad = outputs.next().unwrap();
581 let weight_grad = outputs.next().unwrap();
582 let mask_grad = has_mask.then(|| outputs.next().unwrap());
583 let bias_grad = has_bias.then(|| outputs.next().unwrap());
584
585 DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
586 }
587}