Skip to main content

burn_flex/ops/
module.rs

1//! Module operations for the Flex backend.
2//!
3//! These operations power neural network modules like convolutions and pooling.
4
5use crate::ops::{conv, conv_transpose, deform_conv, interpolate, pool};
6use crate::{Flex, FlexTensor, Layout};
7use burn_backend::{
8    DType, Element, TensorMetadata,
9    ops::{
10        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
11        DeformConvOptions, FloatTensorOps, IntTensorOps, InterpolateMode, InterpolateOptions,
12        MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
13    },
14    tensor::{BoolTensor, FloatTensor, IntTensor},
15};
16use burn_std::{Bytes, Shape};
17use bytemuck::Pod;
18
19/// Cast a tensor from half-precision type E to f32.
20pub(crate) fn cast_to_f32<E: Element + Pod + Copy>(
21    tensor: FlexTensor,
22    to_f32: fn(E) -> f32,
23) -> FlexTensor {
24    let tensor = tensor.to_contiguous();
25    let shape = tensor.layout().shape().clone();
26    let data: &[E] = tensor.storage();
27    let f32_data: alloc::vec::Vec<f32> = data.iter().map(|&v| to_f32(v)).collect();
28    let bytes = Bytes::from_elems(f32_data);
29    FlexTensor::new(bytes, Layout::contiguous(shape), DType::F32)
30}
31
32/// Cast a tensor from f32 back to half-precision type E.
33pub(crate) fn cast_from_f32<E: Element + Pod + Copy>(
34    tensor: FlexTensor,
35    from_f32: fn(f32) -> E,
36) -> FlexTensor {
37    let tensor = tensor.to_contiguous();
38    let shape = tensor.layout().shape().clone();
39    let data: &[f32] = tensor.storage();
40    let half_data: alloc::vec::Vec<E> = data.iter().map(|&v| from_f32(v)).collect();
41    let bytes = Bytes::from_elems(half_data);
42    FlexTensor::new(bytes, Layout::contiguous(shape), E::dtype())
43}
44
impl ModuleOps<Flex> for Flex {
    /// 1D convolution, dispatched on the input's float dtype.
    fn conv1d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv::conv1d_f32(x, weight, bias, &options),
            DType::F64 => conv::conv1d_f64(x, weight, bias, &options),
            DType::F16 => conv::conv1d_f16(x, weight, bias, &options),
            DType::BF16 => conv::conv1d_bf16(x, weight, bias, &options),
            dtype => panic!("conv1d: unsupported dtype {:?}", dtype),
        }
    }

    /// 2D convolution, dispatched on the input's float dtype.
    fn conv2d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv::conv2d_f32(x, weight, bias, &options),
            DType::F64 => conv::conv2d_f64(x, weight, bias, &options),
            DType::F16 => conv::conv2d_f16(x, weight, bias, &options),
            DType::BF16 => conv::conv2d_bf16(x, weight, bias, &options),
            dtype => panic!("conv2d: unsupported dtype {:?}", dtype),
        }
    }

    /// Deformable 2D convolution.
    ///
    /// F32/F64 have native kernels; F16/BF16 are computed by casting all
    /// operands up to f32, running the f32 kernel, and casting the result
    /// back down.
    fn deform_conv2d(
        x: FloatTensor<Flex>,
        offset: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        mask: Option<FloatTensor<Flex>>,
        bias: Option<FloatTensor<Flex>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => deform_conv::deform_conv2d_f32(
                x,
                offset,
                weight,
                mask,
                bias,
                options.stride,
                options.padding,
                options.dilation,
                options.weight_groups,
                options.offset_groups,
            ),
            DType::F64 => deform_conv::deform_conv2d_f64(
                x,
                offset,
                weight,
                mask,
                bias,
                options.stride,
                options.padding,
                options.dilation,
                options.weight_groups,
                options.offset_groups,
            ),
            DType::F16 => {
                use burn_std::f16;
                // Compute in f32, then cast the result back to f16.
                let result = deform_conv::deform_conv2d_f32(
                    cast_to_f32(x, f16::to_f32),
                    cast_to_f32(offset, f16::to_f32),
                    cast_to_f32(weight, f16::to_f32),
                    mask.map(|m| cast_to_f32(m, f16::to_f32)),
                    bias.map(|b| cast_to_f32(b, f16::to_f32)),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                cast_from_f32(result, f16::from_f32)
            }
            DType::BF16 => {
                use burn_std::bf16;
                // Compute in f32, then cast the result back to bf16.
                let result = deform_conv::deform_conv2d_f32(
                    cast_to_f32(x, bf16::to_f32),
                    cast_to_f32(offset, bf16::to_f32),
                    cast_to_f32(weight, bf16::to_f32),
                    mask.map(|m| cast_to_f32(m, bf16::to_f32)),
                    bias.map(|b| cast_to_f32(b, bf16::to_f32)),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                cast_from_f32(result, bf16::from_f32)
            }
            dtype => panic!("deform_conv2d: unsupported dtype {:?}", dtype),
        }
    }

    /// Backward pass of deformable 2D convolution.
    ///
    /// Only an f32 backward kernel exists: F16, BF16 and even F64 inputs are
    /// routed through it with casts on the way in and out (see the F64 note
    /// below on the precision trade-off).
    fn deform_conv2d_backward(
        x: FloatTensor<Flex>,
        offset: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        mask: Option<FloatTensor<Flex>>,
        bias: Option<FloatTensor<Flex>>,
        output_grad: FloatTensor<Flex>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Flex> {
        let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = match x.dtype() {
            DType::F32 => deform_conv::deform_conv2d_backward_f32(
                x,
                offset,
                weight,
                mask,
                bias,
                output_grad,
                options.stride,
                options.padding,
                options.dilation,
                options.weight_groups,
                options.offset_groups,
            ),
            DType::F16 => {
                use burn_std::f16;
                // Upcast everything to f32, run the f32 backward, downcast
                // each returned gradient (mask/bias grads are optional).
                let (xg, og, wg, mg, bg) = deform_conv::deform_conv2d_backward_f32(
                    cast_to_f32(x, f16::to_f32),
                    cast_to_f32(offset, f16::to_f32),
                    cast_to_f32(weight, f16::to_f32),
                    mask.map(|m| cast_to_f32(m, f16::to_f32)),
                    bias.map(|b| cast_to_f32(b, f16::to_f32)),
                    cast_to_f32(output_grad, f16::to_f32),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                (
                    cast_from_f32(xg, f16::from_f32),
                    cast_from_f32(og, f16::from_f32),
                    cast_from_f32(wg, f16::from_f32),
                    mg.map(|m| cast_from_f32(m, f16::from_f32)),
                    bg.map(|b| cast_from_f32(b, f16::from_f32)),
                )
            }
            DType::BF16 => {
                use burn_std::bf16;
                // Same upcast/downcast scheme as the F16 arm.
                let (xg, og, wg, mg, bg) = deform_conv::deform_conv2d_backward_f32(
                    cast_to_f32(x, bf16::to_f32),
                    cast_to_f32(offset, bf16::to_f32),
                    cast_to_f32(weight, bf16::to_f32),
                    mask.map(|m| cast_to_f32(m, bf16::to_f32)),
                    bias.map(|b| cast_to_f32(b, bf16::to_f32)),
                    cast_to_f32(output_grad, bf16::to_f32),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                (
                    cast_from_f32(xg, bf16::from_f32),
                    cast_from_f32(og, bf16::from_f32),
                    cast_from_f32(wg, bf16::from_f32),
                    mg.map(|m| cast_from_f32(m, bf16::from_f32)),
                    bg.map(|b| cast_from_f32(b, bf16::from_f32)),
                )
            }
            // f64 backward computed via f32: precision loss for large/small values.
            // A native f64 implementation would require duplicating ~400 lines of
            // deform_conv2d_backward. f64 deform_conv is rare in practice.
            DType::F64 => {
                let to = |v: f64| v as f32;
                let from = |v: f32| v as f64;
                let (xg, og, wg, mg, bg) = deform_conv::deform_conv2d_backward_f32(
                    cast_to_f32(x, to),
                    cast_to_f32(offset, to),
                    cast_to_f32(weight, to),
                    mask.map(|m| cast_to_f32(m, to)),
                    bias.map(|b| cast_to_f32(b, to)),
                    cast_to_f32(output_grad, to),
                    options.stride,
                    options.padding,
                    options.dilation,
                    options.weight_groups,
                    options.offset_groups,
                );
                (
                    cast_from_f32(xg, from),
                    cast_from_f32(og, from),
                    cast_from_f32(wg, from),
                    mg.map(|m| cast_from_f32(m, from)),
                    bg.map(|b| cast_from_f32(b, from)),
                )
            }
            dtype => panic!("deform_conv2d_backward: unsupported dtype {:?}", dtype),
        };
        DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    /// 3D convolution, dispatched on the input's float dtype.
    fn conv3d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv::conv3d_f32(x, weight, bias, &options),
            DType::F64 => conv::conv3d_f64(x, weight, bias, &options),
            DType::F16 => conv::conv3d_f16(x, weight, bias, &options),
            DType::BF16 => conv::conv3d_bf16(x, weight, bias, &options),
            dtype => panic!("conv3d: unsupported dtype {:?}", dtype),
        }
    }

    /// 1D transposed convolution, dispatched on the input's float dtype.
    fn conv_transpose1d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv_transpose::conv_transpose1d_f32(x, weight, bias, &options),
            DType::F64 => conv_transpose::conv_transpose1d_f64(x, weight, bias, &options),
            DType::F16 => conv_transpose::conv_transpose1d_f16(x, weight, bias, &options),
            DType::BF16 => conv_transpose::conv_transpose1d_bf16(x, weight, bias, &options),
            dtype => panic!("conv_transpose1d: unsupported dtype {:?}", dtype),
        }
    }

    /// 2D transposed convolution, dispatched on the input's float dtype.
    fn conv_transpose2d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv_transpose::conv_transpose2d_f32(x, weight, bias, &options),
            DType::F64 => conv_transpose::conv_transpose2d_f64(x, weight, bias, &options),
            DType::F16 => conv_transpose::conv_transpose2d_f16(x, weight, bias, &options),
            DType::BF16 => conv_transpose::conv_transpose2d_bf16(x, weight, bias, &options),
            dtype => panic!("conv_transpose2d: unsupported dtype {:?}", dtype),
        }
    }

    /// 3D transposed convolution, dispatched on the input's float dtype.
    fn conv_transpose3d(
        x: FloatTensor<Flex>,
        weight: FloatTensor<Flex>,
        bias: Option<FloatTensor<Flex>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => conv_transpose::conv_transpose3d_f32(x, weight, bias, &options),
            DType::F64 => conv_transpose::conv_transpose3d_f64(x, weight, bias, &options),
            DType::F16 => conv_transpose::conv_transpose3d_f16(x, weight, bias, &options),
            DType::BF16 => conv_transpose::conv_transpose3d_bf16(x, weight, bias, &options),
            dtype => panic!("conv_transpose3d: unsupported dtype {:?}", dtype),
        }
    }

    /// 2D average pooling, dispatched on the input's float dtype.
    fn avg_pool2d(
        x: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::avg_pool2d_f32(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            DType::F64 => pool::avg_pool2d_f64(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            DType::F16 => pool::avg_pool2d_f16(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            DType::BF16 => pool::avg_pool2d_bf16(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            ),
            dtype => panic!("avg_pool2d: unsupported dtype {:?}", dtype),
        }
    }

    /// Backward pass of 2D average pooling.
    ///
    /// `_divisor_override` is accepted to satisfy the trait signature but is
    /// not forwarded to the pool kernels.
    fn avg_pool2d_backward(
        x: FloatTensor<Flex>,
        grad: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        _divisor_override: bool,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::avg_pool2d_backward_f32(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            DType::F64 => pool::avg_pool2d_backward_f64(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            DType::F16 => pool::avg_pool2d_backward_f16(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            DType::BF16 => pool::avg_pool2d_backward_bf16(
                x,
                grad,
                kernel_size,
                stride,
                padding,
                count_include_pad,
            ),
            dtype => panic!("avg_pool2d_backward: unsupported dtype {:?}", dtype),
        }
    }

    /// Adaptive 2D average pooling to a fixed `output_size`.
    fn adaptive_avg_pool2d(x: FloatTensor<Flex>, output_size: [usize; 2]) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::adaptive_avg_pool2d_f32(x, output_size),
            DType::F64 => pool::adaptive_avg_pool2d_f64(x, output_size),
            DType::F16 => pool::adaptive_avg_pool2d_f16(x, output_size),
            DType::BF16 => pool::adaptive_avg_pool2d_bf16(x, output_size),
            dtype => panic!("adaptive_avg_pool2d: unsupported dtype {:?}", dtype),
        }
    }

    /// Backward pass of adaptive 2D average pooling.
    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Flex>,
        grad: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => pool::adaptive_avg_pool2d_backward_f32(x, grad),
            DType::F64 => pool::adaptive_avg_pool2d_backward_f64(x, grad),
            DType::F16 => pool::adaptive_avg_pool2d_backward_f16(x, grad),
            DType::BF16 => pool::adaptive_avg_pool2d_backward_bf16(x, grad),
            dtype => panic!(
                "adaptive_avg_pool2d_backward: unsupported dtype {:?}",
                dtype
            ),
        }
    }

    /// 2D max pooling, dispatched on the input's float dtype.
    fn max_pool2d(
        x: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Flex> {
        match x.dtype() {
            DType::F32 => {
                pool::max_pool2d_f32(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            DType::F64 => {
                pool::max_pool2d_f64(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            DType::F16 => {
                pool::max_pool2d_f16(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            DType::BF16 => {
                pool::max_pool2d_bf16(x, kernel_size, stride, padding, dilation, ceil_mode)
            }
            dtype => panic!("max_pool2d: unsupported dtype {:?}", dtype),
        }
    }

    /// 2D max pooling that also returns the argmax indices (needed for the
    /// backward pass).
    fn max_pool2d_with_indices(
        x: FloatTensor<Flex>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Flex> {
        let (output, indices) = match x.dtype() {
            DType::F32 => pool::max_pool2d_with_indices_f32(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            DType::F64 => pool::max_pool2d_with_indices_f64(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            DType::F16 => pool::max_pool2d_with_indices_f16(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            DType::BF16 => pool::max_pool2d_with_indices_bf16(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            ),
            dtype => panic!("max_pool2d_with_indices: unsupported dtype {:?}", dtype),
        };
        MaxPool2dWithIndices::new(output, indices)
    }

    /// Backward pass of 2D max pooling.
    ///
    /// The kernel-geometry parameters are unused (hence the `_` prefixes):
    /// the saved `indices` already pin where each gradient value must be
    /// scattered, so the backward kernels only need `x`, `output_grad`, and
    /// `indices`.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Flex>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _dilation: [usize; 2],
        _ceil_mode: bool,
        output_grad: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> MaxPool2dBackward<Flex> {
        let x_grad = match x.dtype() {
            DType::F32 => pool::max_pool2d_backward_f32(x, output_grad, indices),
            DType::F64 => pool::max_pool2d_backward_f64(x, output_grad, indices),
            DType::F16 => pool::max_pool2d_backward_f16(x, output_grad, indices),
            DType::BF16 => pool::max_pool2d_backward_bf16(x, output_grad, indices),
            dtype => panic!(
                "max_pool2d_with_indices_backward: unsupported dtype {:?}",
                dtype
            ),
        };
        MaxPool2dBackward::new(x_grad)
    }

    /// 2D interpolation (resize), dispatched on (mode, dtype).
    ///
    /// Supports Nearest, Bilinear, Bicubic, and Lanczos3 for F32/F64/F16/BF16.
    fn interpolate(
        x: FloatTensor<Flex>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Flex> {
        match (options.mode, x.dtype()) {
            (InterpolateMode::Nearest, DType::F32) => {
                interpolate::interpolate_nearest_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Nearest, DType::F64) => {
                interpolate::interpolate_nearest_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Nearest, DType::F16) => {
                interpolate::interpolate_nearest_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Nearest, DType::BF16) => {
                interpolate::interpolate_nearest_bf16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::F32) => {
                interpolate::interpolate_bilinear_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::F64) => {
                interpolate::interpolate_bilinear_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::F16) => {
                interpolate::interpolate_bilinear_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bilinear, DType::BF16) => {
                interpolate::interpolate_bilinear_bf16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::F32) => {
                interpolate::interpolate_bicubic_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::F64) => {
                interpolate::interpolate_bicubic_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::F16) => {
                interpolate::interpolate_bicubic_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Bicubic, DType::BF16) => {
                interpolate::interpolate_bicubic_bf16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::F32) => {
                interpolate::interpolate_lanczos3_f32(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::F64) => {
                interpolate::interpolate_lanczos3_f64(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::F16) => {
                interpolate::interpolate_lanczos3_f16(x, output_size, options.align_corners)
            }
            (InterpolateMode::Lanczos3, DType::BF16) => {
                interpolate::interpolate_lanczos3_bf16(x, output_size, options.align_corners)
            }
            (mode, dtype) => panic!(
                "interpolate: unsupported mode {:?} / dtype {:?}",
                mode, dtype
            ),
        }
    }

    /// Backward pass of 2D interpolation, dispatched on (mode, dtype).
    ///
    /// NOTE(review): only Nearest/Bilinear/Bicubic arms exist here, so a
    /// Lanczos3 backward falls through to the panic arm even though the
    /// forward `interpolate` supports Lanczos3 — confirm whether that gap is
    /// intentional or a missing backward kernel.
    fn interpolate_backward(
        x: FloatTensor<Flex>,
        grad: FloatTensor<Flex>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Flex> {
        match (options.mode, x.dtype()) {
            (InterpolateMode::Nearest, DType::F32) => {
                interpolate::interpolate_nearest_backward_f32(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Nearest, DType::F64) => {
                interpolate::interpolate_nearest_backward_f64(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Nearest, DType::F16) => {
                interpolate::interpolate_nearest_backward_f16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Nearest, DType::BF16) => {
                interpolate::interpolate_nearest_backward_bf16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::F32) => {
                interpolate::interpolate_bilinear_backward_f32(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::F64) => {
                interpolate::interpolate_bilinear_backward_f64(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::F16) => {
                interpolate::interpolate_bilinear_backward_f16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bilinear, DType::BF16) => {
                interpolate::interpolate_bilinear_backward_bf16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::F32) => {
                interpolate::interpolate_bicubic_backward_f32(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::F64) => {
                interpolate::interpolate_bicubic_backward_f64(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::F16) => {
                interpolate::interpolate_bicubic_backward_f16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (InterpolateMode::Bicubic, DType::BF16) => {
                interpolate::interpolate_bicubic_backward_bf16(
                    x,
                    grad,
                    output_size,
                    options.align_corners,
                )
            }
            (mode, dtype) => {
                panic!(
                    "interpolate_backward: unsupported mode {:?} / dtype {:?}",
                    mode, dtype
                )
            }
        }
    }

    /// Scaled dot-product attention; delegated to the attention module,
    /// which handles dtype dispatch itself.
    fn attention(
        query: FloatTensor<Flex>,
        key: FloatTensor<Flex>,
        value: FloatTensor<Flex>,
        mask: Option<BoolTensor<Flex>>,
        attn_bias: Option<FloatTensor<Flex>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Flex> {
        crate::ops::attention::attention(query, key, value, mask, attn_bias, options)
    }

    /// Real-input FFT along `dim`, returning (real, imaginary) parts.
    fn rfft(
        signal: FloatTensor<Flex>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<Flex>, FloatTensor<Flex>) {
        match signal.dtype() {
            DType::F32 => crate::ops::fft::rfft_f32(signal, dim, n),
            DType::F64 => crate::ops::fft::rfft_f64(signal, dim, n),
            DType::F16 => crate::ops::fft::rfft_f16(signal, dim, n),
            DType::BF16 => crate::ops::fft::rfft_bf16(signal, dim, n),
            dtype => panic!("rfft: unsupported dtype {:?}", dtype),
        }
    }

    /// Inverse real FFT: takes the spectrum's real/imaginary parts and
    /// returns a real-valued signal. Dispatch keys off the real part's dtype.
    fn irfft(
        spectrum_re: FloatTensor<Flex>,
        spectrum_im: FloatTensor<Flex>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<Flex> {
        match spectrum_re.dtype() {
            DType::F32 => crate::ops::fft::irfft_f32(spectrum_re, spectrum_im, dim, n),
            DType::F64 => crate::ops::fft::irfft_f64(spectrum_re, spectrum_im, dim, n),
            DType::F16 => crate::ops::fft::irfft_f16(spectrum_re, spectrum_im, dim, n),
            DType::BF16 => crate::ops::fft::irfft_bf16(spectrum_re, spectrum_im, dim, n),
            dtype => panic!("irfft: unsupported dtype {:?}", dtype),
        }
    }

    /// Embedding lookup: gathers rows of `weights` by `indices`.
    ///
    /// Implemented as flatten-indices -> row select -> reshape back to
    /// `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<Flex>, indices: IntTensor<Flex>) -> FloatTensor<Flex> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        let indices = Flex::int_reshape(indices, Shape::from(alloc::vec![batch_size * seq_length]));
        let output = Flex::float_select(weights, 0, indices);
        Flex::float_reshape(
            output,
            Shape::from(alloc::vec![batch_size, seq_length, d_model]),
        )
    }

    /// Layer normalization; delegated to the activation module.
    fn layer_norm(
        tensor: FloatTensor<Flex>,
        gamma: FloatTensor<Flex>,
        beta: Option<FloatTensor<Flex>>,
        epsilon: f64,
    ) -> FloatTensor<Flex> {
        crate::ops::activation::layer_norm(tensor, gamma, beta, epsilon)
    }

    /// Backward pass of the embedding lookup: scatter-adds each row of
    /// `output_grad` into a zeroed `[n_embeddings, d_model]` gradient at the
    /// row given by the corresponding index (accumulating duplicates).
    fn embedding_backward(
        weights: FloatTensor<Flex>,
        output_grad: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> FloatTensor<Flex> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let dtype = output_grad.dtype();

        let indices = Flex::int_reshape(indices, Shape::from(alloc::vec![batch_size * seq_length]));
        let output_grad = Flex::float_reshape(
            output_grad,
            Shape::from(alloc::vec![batch_size * seq_length, d_model]),
        );
        // Gradient buffer matches the weight matrix's shape and the grad's dtype.
        let grad = Flex::float_zeros(
            Shape::from(alloc::vec![n_embeddings, d_model]),
            &Default::default(),
            dtype.into(),
        );
        Flex::float_select_add(grad, 0, indices, output_grad)
    }
}