use crate::ops::{conv, conv_transpose, deform_conv, interpolate, pool};
use crate::{Flex, FlexTensor, Layout};
use burn_backend::{
    DType, Element, TensorMetadata,
    ops::{
        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
        DeformConvOptions, FloatTensorOps, IntTensorOps, InterpolateMode, InterpolateOptions,
        MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
    },
    tensor::{BoolTensor, FloatTensor, IntTensor},
};
use burn_std::{Bytes, Shape};
use bytemuck::Pod;

/// Casts a tensor with element type `E` to a contiguous `f32` tensor by
/// mapping every element through `to_f32`.
pub(crate) fn cast_to_f32<E: Element + Pod + Copy>(
    tensor: FlexTensor,
    to_f32: fn(E) -> f32,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    let shape = tensor.layout().shape().clone();
    let data: &[E] = tensor.storage();
    let f32_data: alloc::vec::Vec<f32> = data.iter().map(|&v| to_f32(v)).collect();
    let bytes = Bytes::from_elems(f32_data);
    FlexTensor::new(bytes, Layout::contiguous(shape), DType::F32)
}

/// Casts a contiguous `f32` tensor back to element type `E` by mapping
/// every element through `from_f32`.
pub(crate) fn cast_from_f32<E: Element + Pod + Copy>(
    tensor: FlexTensor,
    from_f32: fn(f32) -> E,
) -> FlexTensor {
    let tensor = tensor.to_contiguous();
    let shape = tensor.layout().shape().clone();
    let data: &[f32] = tensor.storage();
    let out_data: alloc::vec::Vec<E> = data.iter().map(|&v| from_f32(v)).collect();
    let bytes = Bytes::from_elems(out_data);
    FlexTensor::new(bytes, Layout::contiguous(shape), E::dtype())
}
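
// Usage sketch (illustrative only): widen a half-precision tensor for an op
// that only ships an f32 kernel, then narrow the result back. The f16/bf16
// arms of `deform_conv2d` below follow exactly this pattern;
// `some_f32_only_kernel` is a hypothetical stand-in.
//
//     use burn_std::f16;
//     let wide = cast_to_f32(x, f16::to_f32);
//     let result = some_f32_only_kernel(wide);
//     cast_from_f32(result, f16::from_f32)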

impl ModuleOps<Flex> for Flex {
    fn conv1d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv::conv1d_f32(x, weight, bias, &options),
            DType::F64 => conv::conv1d_f64(x, weight, bias, &options),
            DType::F16 => conv::conv1d_f16(x, weight, bias, &options),
            DType::BF16 => conv::conv1d_bf16(x, weight, bias, &options),
            dtype => panic!("conv1d: unsupported dtype {:?}", dtype),
        }
    }

    fn conv2d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv::conv2d_f32(x, weight, bias, &options),
            DType::F64 => conv::conv2d_f64(x, weight, bias, &options),
            DType::F16 => conv::conv2d_f16(x, weight, bias, &options),
            DType::BF16 => conv::conv2d_bf16(x, weight, bias, &options),
            dtype => panic!("conv2d: unsupported dtype {:?}", dtype),
        }
    }

    fn deform_conv2d(
        x: FloatTensor<Flex>,
        offset: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        mask: Option<FloatTensor<Flex>>,
        bias: Option<FloatTensor<Flex>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => deform_conv::deform_conv2d_f32(
                x,
                offset,
                weight,
                mask,
                bias,
                options.stride,
                options.padding,
                options.dilation,
                options.weight_groups,
                options.offset_groups,
            ),
            DType::F64 => deform_conv::deform_conv2d_f64(
                x,
                offset,
                weight,
                mask,
                bias,
                options.stride,
                options.padding,
                options.dilation,
                options.weight_groups,
                options.offset_groups,
            ),
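            // No native f16/bf16 kernels here: upcast the inputs to f32, run
            // the f32 kernel, and downcast the result.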
            DType::F16 => {
                use burn_std::f16;
                let result = deform_conv::deform_conv2d_f32(
                    cast_to_f32(x, f16::to_f32),
                    cast_to_f32(offset, f16::to_f32),
                    cast_to_f32(weight, f16::to_f32),
                    mask.map(|m| cast_to_f32(m, f16::to_f32)),
                    bias.map(|b| cast_to_f32(b, f16::to_f32)),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                cast_from_f32(result, f16::from_f32)
            }
            DType::BF16 => {
                use burn_std::bf16;
                let result = deform_conv::deform_conv2d_f32(
                    cast_to_f32(x, bf16::to_f32),
                    cast_to_f32(offset, bf16::to_f32),
                    cast_to_f32(weight, bf16::to_f32),
                    mask.map(|m| cast_to_f32(m, bf16::to_f32)),
                    bias.map(|b| cast_to_f32(b, bf16::to_f32)),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                cast_from_f32(result, bf16::from_f32)
            }
            dtype => panic!("deform_conv2d: unsupported dtype {:?}", dtype),
        }
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Flex>,
        offset: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        mask: Option<FloatTensor<Flex>>,
        bias: Option<FloatTensor<Flex>>,
        output_grad: FloatTensor<Flex>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Flex> {
        let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = match x.dtype() {
            DType::F32 => deform_conv::deform_conv2d_backward_f32(
                x,
                offset,
                weight,
                mask,
                bias,
                output_grad,
                options.stride,
                options.padding,
                options.dilation,
                options.weight_groups,
                options.offset_groups,
            ),
            DType::F16 => {
                use burn_std::f16;
                let (xg, og, wg, mg, bg) = deform_conv::deform_conv2d_backward_f32(
                    cast_to_f32(x, f16::to_f32),
                    cast_to_f32(offset, f16::to_f32),
                    cast_to_f32(weight, f16::to_f32),
                    mask.map(|m| cast_to_f32(m, f16::to_f32)),
                    bias.map(|b| cast_to_f32(b, f16::to_f32)),
                    cast_to_f32(output_grad, f16::to_f32),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                (
                    cast_from_f32(xg, f16::from_f32),
                    cast_from_f32(og, f16::from_f32),
                    cast_from_f32(wg, f16::from_f32),
                    mg.map(|m| cast_from_f32(m, f16::from_f32)),
                    bg.map(|b| cast_from_f32(b, f16::from_f32)),
                )
            }
            DType::BF16 => {
                use burn_std::bf16;
                let (xg, og, wg, mg, bg) = deform_conv::deform_conv2d_backward_f32(
                    cast_to_f32(x, bf16::to_f32),
                    cast_to_f32(offset, bf16::to_f32),
                    cast_to_f32(weight, bf16::to_f32),
                    mask.map(|m| cast_to_f32(m, bf16::to_f32)),
                    bias.map(|b| cast_to_f32(b, bf16::to_f32)),
                    cast_to_f32(output_grad, bf16::to_f32),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                (
                    cast_from_f32(xg, bf16::from_f32),
                    cast_from_f32(og, bf16::from_f32),
                    cast_from_f32(wg, bf16::from_f32),
                    mg.map(|m| cast_from_f32(m, bf16::from_f32)),
                    bg.map(|b| cast_from_f32(b, bf16::from_f32)),
                )
            }
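            // There is no f64 backward kernel either, so the f64 path also
            // routes through f32; gradients are computed at reduced precision.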
            DType::F64 => {
                let to = |v: f64| v as f32;
                let from = |v: f32| v as f64;
                let (xg, og, wg, mg, bg) = deform_conv::deform_conv2d_backward_f32(
                    cast_to_f32(x, to),
                    cast_to_f32(offset, to),
                    cast_to_f32(weight, to),
                    mask.map(|m| cast_to_f32(m, to)),
                    bias.map(|b| cast_to_f32(b, to)),
                    cast_to_f32(output_grad, to),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                (
                    cast_from_f32(xg, from),
                    cast_from_f32(og, from),
                    cast_from_f32(wg, from),
                    mg.map(|m| cast_from_f32(m, from)),
                    bg.map(|b| cast_from_f32(b, from)),
                )
            }
            dtype => panic!("deform_conv2d_backward: unsupported dtype {:?}", dtype),
        };
        DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    fn conv3d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv::conv3d_f32(x, weight, bias, &options),
            DType::F64 => conv::conv3d_f64(x, weight, bias, &options),
            DType::F16 => conv::conv3d_f16(x, weight, bias, &options),
            DType::BF16 => conv::conv3d_bf16(x, weight, bias, &options),
            dtype => panic!("conv3d: unsupported dtype {:?}", dtype),
        }
    }

    fn conv_transpose1d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv_transpose::conv_transpose1d_f32(x, weight, bias, &options),
            DType::F64 => conv_transpose::conv_transpose1d_f64(x, weight, bias, &options),
            DType::F16 => conv_transpose::conv_transpose1d_f16(x, weight, bias, &options),
            DType::BF16 => conv_transpose::conv_transpose1d_bf16(x, weight, bias, &options),
            dtype => panic!("conv_transpose1d: unsupported dtype {:?}", dtype),
        }
    }

    fn conv_transpose2d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv_transpose::conv_transpose2d_f32(x, weight, bias, &options),
            DType::F64 => conv_transpose::conv_transpose2d_f64(x, weight, bias, &options),
            DType::F16 => conv_transpose::conv_transpose2d_f16(x, weight, bias, &options),
            DType::BF16 => conv_transpose::conv_transpose2d_bf16(x, weight, bias, &options),
            dtype => panic!("conv_transpose2d: unsupported dtype {:?}", dtype),
        }
    }

    fn conv_transpose3d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv_transpose::conv_transpose3d_f32(x, weight, bias, &options),
            DType::F64 => conv_transpose::conv_transpose3d_f64(x, weight, bias, &options),
            DType::F16 => conv_transpose::conv_transpose3d_f16(x, weight, bias, &options),
            DType::BF16 => conv_transpose::conv_transpose3d_bf16(x, weight, bias, &options),
            dtype => panic!("conv_transpose3d: unsupported dtype {:?}", dtype),
        }
    }

    fn avg_pool2d(
        x: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::avg_pool2d_f32(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            DType::F64 => pool::avg_pool2d_f64(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            DType::F16 => pool::avg_pool2d_f16(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            DType::BF16 => pool::avg_pool2d_bf16(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            dtype => panic!("avg_pool2d: unsupported dtype {:?}", dtype),
        }
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Flex>,
        grad: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        _divisor_override: bool,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::avg_pool2d_backward_f32(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            DType::F64 => pool::avg_pool2d_backward_f64(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            DType::F16 => pool::avg_pool2d_backward_f16(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            DType::BF16 => pool::avg_pool2d_backward_bf16(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            dtype => panic!("avg_pool2d_backward: unsupported dtype {:?}", dtype),
        }
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Flex>, output_size: [usize; 2]) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::adaptive_avg_pool2d_f32(x, output_size),
            DType::F64 => pool::adaptive_avg_pool2d_f64(x, output_size),
            DType::F16 => pool::adaptive_avg_pool2d_f16(x, output_size),
            DType::BF16 => pool::adaptive_avg_pool2d_bf16(x, output_size),
            dtype => panic!("adaptive_avg_pool2d: unsupported dtype {:?}", dtype),
        }
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Flex>,
        grad: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::adaptive_avg_pool2d_backward_f32(x, grad),
            DType::F64 => pool::adaptive_avg_pool2d_backward_f64(x, grad),
            DType::F16 => pool::adaptive_avg_pool2d_backward_f16(x, grad),
            DType::BF16 => pool::adaptive_avg_pool2d_backward_bf16(x, grad),
            dtype => panic!(
                "adaptive_avg_pool2d_backward: unsupported dtype {:?}",
                dtype
            ),
        }
    }

    fn max_pool2d(
        x: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => {
                pool::max_pool2d_f32(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            DType::F64 => {
                pool::max_pool2d_f64(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            DType::F16 => {
                pool::max_pool2d_f16(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            DType::BF16 => {
                pool::max_pool2d_bf16(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            dtype => panic!("max_pool2d: unsupported dtype {:?}", dtype),
        }
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Flex> {
        let (output, indices) = match x.dtype() {
            DType::F32 => pool::max_pool2d_with_indices_f32(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            DType::F64 => pool::max_pool2d_with_indices_f64(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            DType::F16 => pool::max_pool2d_with_indices_f16(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            DType::BF16 => pool::max_pool2d_with_indices_bf16(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            dtype => panic!("max_pool2d_with_indices: unsupported dtype {:?}", dtype),
        };
        MaxPool2dWithIndices::new(output, indices)
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Flex>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _dilation: [usize; 2],
        _ceil_mode: bool,
        output_grad: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> MaxPool2dBackward<Flex> {
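        // The forward pass saved flat argmax indices, so the pooling geometry
        // parameters are unused here: the gradient is scattered by index.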
        let x_grad = match x.dtype() {
            DType::F32 => pool::max_pool2d_backward_f32(x, output_grad, indices),
            DType::F64 => pool::max_pool2d_backward_f64(x, output_grad, indices),
            DType::F16 => pool::max_pool2d_backward_f16(x, output_grad, indices),
            DType::BF16 => pool::max_pool2d_backward_bf16(x, output_grad, indices),
            dtype => panic!(
                "max_pool2d_with_indices_backward: unsupported dtype {:?}",
                dtype
            ),
        };
        MaxPool2dBackward::new(x_grad)
    }

    fn interpolate(
        x: FloatTensor<Flex>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Flex> {
        match (options.mode, x.dtype()) {
            (InterpolateMode::Nearest, DType::F32) => {
                interpolate::interpolate_nearest_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Nearest, DType::F64) => {
                interpolate::interpolate_nearest_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Nearest, DType::F16) => {
                interpolate::interpolate_nearest_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Nearest, DType::BF16) => {
                interpolate::interpolate_nearest_bf16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::F32) => {
                interpolate::interpolate_bilinear_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::F64) => {
                interpolate::interpolate_bilinear_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::F16) => {
                interpolate::interpolate_bilinear_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::BF16) => {
                interpolate::interpolate_bilinear_bf16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::F32) => {
                interpolate::interpolate_bicubic_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::F64) => {
                interpolate::interpolate_bicubic_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::F16) => {
                interpolate::interpolate_bicubic_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::BF16) => {
                interpolate::interpolate_bicubic_bf16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::F32) => {
                interpolate::interpolate_lanczos3_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::F64) => {
                interpolate::interpolate_lanczos3_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::F16) => {
                interpolate::interpolate_lanczos3_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::BF16) => {
                interpolate::interpolate_lanczos3_bf16(x, output_size, options.align_corners)
            }
            (mode, dtype) => panic!(
                "interpolate: unsupported mode {:?} / dtype {:?}",
                mode, dtype
            ),
        }
    }

    fn interpolate_backward(
        x: FloatTensor<Flex>,
        grad: FloatTensor<Flex>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Flex> {
        match (options.mode, x.dtype()) {
            (InterpolateMode::Nearest, DType::F32) => {
                interpolate::interpolate_nearest_backward_f32(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Nearest, DType::F64) => {
                interpolate::interpolate_nearest_backward_f64(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Nearest, DType::F16) => {
                interpolate::interpolate_nearest_backward_f16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Nearest, DType::BF16) => {
                interpolate::interpolate_nearest_backward_bf16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::F32) => {
                interpolate::interpolate_bilinear_backward_f32(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::F64) => {
                interpolate::interpolate_bilinear_backward_f64(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::F16) => {
                interpolate::interpolate_bilinear_backward_f16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::BF16) => {
                interpolate::interpolate_bilinear_backward_bf16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::F32) => {
                interpolate::interpolate_bicubic_backward_f32(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::F64) => {
                interpolate::interpolate_bicubic_backward_f64(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::F16) => {
                interpolate::interpolate_bicubic_backward_f16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::BF16) => {
                interpolate::interpolate_bicubic_backward_bf16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (mode, dtype) => panic!(
                "interpolate_backward: unsupported mode {:?} / dtype {:?}",
                mode, dtype
            ),
        }
    }

    fn attention(
        query: FloatTensor<Flex>,
        key: FloatTensor<Flex>,
        value: FloatTensor<Flex>,
        mask: Option<BoolTensor<Flex>>,
        attn_bias: Option<FloatTensor<Flex>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Flex> {
        crate::ops::attention::attention(query, key, value, mask, attn_bias, options)
    }

    fn rfft(
        signal: FloatTensor<Flex>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<Flex>, FloatTensor<Flex>) {
        match signal.dtype() {
            DType::F32 => crate::ops::fft::rfft_f32(signal, dim, n),
            DType::F64 => crate::ops::fft::rfft_f64(signal, dim, n),
            DType::F16 => crate::ops::fft::rfft_f16(signal, dim, n),
            DType::BF16 => crate::ops::fft::rfft_bf16(signal, dim, n),
            dtype => panic!("rfft: unsupported dtype {:?}", dtype),
        }
    }

    fn irfft(
        spectrum_re: FloatTensor<Flex>,
        spectrum_im: FloatTensor<Flex>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<Flex> {
        match spectrum_re.dtype() {
            DType::F32 => crate::ops::fft::irfft_f32(spectrum_re, spectrum_im, dim, n),
            DType::F64 => crate::ops::fft::irfft_f64(spectrum_re, spectrum_im, dim, n),
            DType::F16 => crate::ops::fft::irfft_f16(spectrum_re, spectrum_im, dim, n),
            DType::BF16 => crate::ops::fft::irfft_bf16(spectrum_re, spectrum_im, dim, n),
            dtype => panic!("irfft: unsupported dtype {:?}", dtype),
        }
    }

    fn embedding(weights: FloatTensor<Flex>, indices: IntTensor<Flex>) -> FloatTensor<Flex> {
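        // Embedding is a gather: flatten the [batch, seq] indices, select the
        // matching rows of the weight table along dim 0, then restore the
        // batch and sequence dimensions.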
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        let indices = Flex::int_reshape(indices, Shape::from(alloc::vec![batch_size * seq_length]));
        let output = Flex::float_select(weights, 0, indices);
        Flex::float_reshape(
            output,
            Shape::from(alloc::vec![batch_size, seq_length, d_model]),
        )
    }

    fn layer_norm(
        tensor: FloatTensor<Flex>,
        gamma: FloatTensor<Flex>,
        beta: Option<FloatTensor<Flex>>,
        epsilon: f64,
    ) -> FloatTensor<Flex> {
        crate::ops::activation::layer_norm(tensor, gamma, beta, epsilon)
    }

    fn embedding_backward(
        weights: FloatTensor<Flex>,
        output_grad: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> FloatTensor<Flex> {
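        // The gradient of a gather is a scatter-add: accumulate each output
        // gradient row into a zero-initialized weight-shaped buffer at the
        // row given by the corresponding index.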
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let dtype = output_grad.dtype();

        let indices = Flex::int_reshape(indices, Shape::from(alloc::vec![batch_size * seq_length]));
        let output_grad = Flex::float_reshape(
            output_grad,
            Shape::from(alloc::vec![batch_size * seq_length, d_model]),
        );
        let grad = Flex::float_zeros(
            Shape::from(alloc::vec![n_embeddings, d_model]),
            &Default::default(),
            dtype.into(),
        );
        Flex::float_select_add(grad, 0, indices, output_grad)
    }
}
772}