// burn_cubecl/ops/float_ops.rs

1use super::{expand, numeric, permute, unfold};
2use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
3use crate::kernel::{
4    self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
5};
6use crate::kernel::{into_contiguous, unary_basic::BasicFloatUnaryKind};
7use crate::{CubeBackend, execute_with_dtype};
8use crate::{CubeRuntime, FloatElement, IntElement};
9use crate::{
10    element::BoolElement,
11    kernel::matmul::{MatmulStrategy, matmul},
12};
13use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
14use burn_tensor::{DType, ElementConversion, FloatDType};
15use burn_tensor::{Distribution, Shape, TensorData, ops::FloatTensorOps};
16use cubecl::prelude::*;
17use cubecl::reduce::ReducePrecision;
18use cubecl::reduce::instructions::ReduceFnConfig;
19use half::{bf16, f16};
20use std::ops::Range;
21
22impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
23where
24    R: CubeRuntime,
25    F: FloatElement,
26    I: IntElement,
27    BT: BoolElement,
28{
29    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
30        match data.dtype {
31            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
32                super::from_data::<R>(data, device)
33            }
34            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
35        }
36    }
37
38    fn float_random(
39        shape: Shape,
40        distribution: Distribution,
41        device: &Device<Self>,
42    ) -> FloatTensor<Self> {
43        match distribution {
44            Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 1.elem()),
45            Distribution::Uniform(low, high) => {
46                random_uniform(shape, device, low.elem::<F>(), high.elem())
47            }
48            Distribution::Bernoulli(prob) => random_bernoulli::<R, F>(shape, device, prob as f32),
49            Distribution::Normal(mean, std) => {
50                random_normal(shape, device, mean.elem::<F>(), std.elem())
51            }
52        }
53    }
54
55    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
56        execute_with_dtype!(
57            float(tensor.dtype),
58            E,
59            super::into_data::<R, E>(tensor).await
60        )
61    }
62
63    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
64        tensor.device.clone()
65    }
66
67    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
68        super::to_device(tensor, device)
69    }
70
71    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
72        let dtype = dtype.into();
73        execute_with_dtype!(float(dtype), E, super::empty::<R, E>(shape, device))
74    }
75
76    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
77        execute_with_dtype!(
78            float(lhs.dtype, rhs.dtype),
79            E,
80            numeric::add::<R, E>(lhs, rhs)
81        )
82    }
83
84    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
85        execute_with_dtype!(
86            float(lhs.dtype),
87            E,
88            numeric::add_scalar::<R, E>(lhs, rhs.elem())
89        )
90    }
91
92    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
93        let dtype = dtype.into();
94        execute_with_dtype!(float(dtype), E, numeric::zeros::<R, E>(shape, device))
95    }
96
97    fn float_full(
98        shape: Shape,
99        fill_value: FloatElem<Self>,
100        device: &R::Device,
101        dtype: FloatDType,
102    ) -> FloatTensor<Self> {
103        let dtype = dtype.into();
104        execute_with_dtype!(
105            float(dtype),
106            E,
107            numeric::full::<R, E>(shape, device, fill_value.elem())
108        )
109    }
110
111    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
112        let dtype = dtype.into();
113        execute_with_dtype!(float(dtype), E, numeric::ones::<R, E>(shape, device))
114    }
115
116    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
117        execute_with_dtype!(
118            float(lhs.dtype, rhs.dtype),
119            E,
120            numeric::sub::<R, E>(lhs, rhs)
121        )
122    }
123
124    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
125        execute_with_dtype!(
126            float(lhs.dtype),
127            E,
128            numeric::sub_scalar::<R, E>(lhs, rhs.elem())
129        )
130    }
131
132    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
133        execute_with_dtype!(
134            float(lhs.dtype, rhs.dtype),
135            E,
136            numeric::mul::<R, E>(lhs, rhs)
137        )
138    }
139
140    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
141        execute_with_dtype!(
142            float(lhs.dtype),
143            E,
144            numeric::mul_scalar::<R, E>(lhs, rhs.elem())
145        )
146    }
147
148    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149        execute_with_dtype!(
150            float(lhs.dtype, rhs.dtype),
151            E,
152            numeric::div::<R, E>(lhs, rhs)
153        )
154    }
155
156    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
157        execute_with_dtype!(
158            float(lhs.dtype),
159            E,
160            numeric::div_scalar::<R, E>(lhs, rhs.elem())
161        )
162    }
163
164    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
165        execute_with_dtype!(
166            float(lhs.dtype, rhs.dtype),
167            E,
168            numeric::remainder::<R, E>(lhs, rhs)
169        )
170    }
171
172    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
173        execute_with_dtype!(
174            float(lhs.dtype),
175            E,
176            numeric::remainder_scalar::<R, E>(lhs, rhs.elem())
177        )
178    }
179
180    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
181        execute_with_dtype!(
182            float(lhs.dtype, rhs.dtype),
183            E,
184            matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
185        )
186    }
187
188    fn float_cross(
189        lhs: FloatTensor<Self>,
190        rhs: FloatTensor<Self>,
191        dim: usize,
192    ) -> FloatTensor<Self> {
193        execute_with_dtype!(
194            float(lhs.dtype, rhs.dtype),
195            E,
196            kernel::cross::<R, E>(lhs, rhs, dim)
197        )
198    }
199
200    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
201        super::swap_dims(tensor, dim1, dim2)
202    }
203
204    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
205        super::reshape(tensor, shape)
206    }
207
208    fn float_gather(
209        dim: usize,
210        tensor: FloatTensor<Self>,
211        indices: IntTensor<Self>,
212    ) -> FloatTensor<Self> {
213        execute_with_dtype!(
214            int(indices.dtype),
215            I,
216            execute_with_dtype!(
217                float(tensor.dtype),
218                E,
219                kernel::gather::<R, E, I>(dim, tensor, indices)
220            )
221        )
222    }
223
224    fn float_scatter(
225        dim: usize,
226        tensor: FloatTensor<Self>,
227        indices: IntTensor<Self>,
228        value: FloatTensor<Self>,
229    ) -> FloatTensor<Self> {
230        execute_with_dtype!(
231            int(indices.dtype),
232            I,
233            execute_with_dtype!(
234                float(tensor.dtype, value.dtype),
235                E,
236                kernel::scatter::<R, E, I>(dim, tensor, indices, value)
237            )
238        )
239    }
240
241    fn float_select(
242        tensor: FloatTensor<Self>,
243        dim: usize,
244        indices: IntTensor<Self>,
245    ) -> FloatTensor<Self> {
246        execute_with_dtype!(
247            int(indices.dtype),
248            I,
249            execute_with_dtype!(
250                float(tensor.dtype),
251                E,
252                kernel::select::<R, E, I>(tensor, dim, indices)
253            )
254        )
255    }
256
257    fn float_select_assign(
258        tensor: FloatTensor<Self>,
259        dim: usize,
260        indices: IntTensor<Self>,
261        value: FloatTensor<Self>,
262    ) -> FloatTensor<Self> {
263        execute_with_dtype!(
264            int(indices.dtype),
265            I,
266            execute_with_dtype!(
267                float(tensor.dtype, value.dtype),
268                E,
269                kernel::select_assign::<R, E, I>(tensor, dim, indices, value, false)
270            )
271        )
272    }
273
274    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
275        // Check if all steps are 1
276        let all_steps_one = slices.iter().all(|info| info.step == 1);
277
278        if all_steps_one {
279            // Use optimized slice for step=1
280            let simple_ranges: Vec<Range<usize>> = slices
281                .iter()
282                .enumerate()
283                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
284                .collect();
285
286            execute_with_dtype!(
287                float(tensor.dtype),
288                E,
289                kernel::slice::<R, E>(tensor, &simple_ranges)
290            )
291        } else {
292            // Use slice with steps kernel
293            execute_with_dtype!(
294                float(tensor.dtype),
295                E,
296                kernel::slice_with_steps::<R, E>(tensor, slices)
297            )
298        }
299    }
300
301    fn float_slice_assign(
302        tensor: FloatTensor<Self>,
303        ranges: &[burn_tensor::Slice],
304        value: FloatTensor<Self>,
305    ) -> FloatTensor<Self> {
306        execute_with_dtype!(
307            float(tensor.dtype, value.dtype),
308            E,
309            kernel::slice_assign::<R, E>(tensor, ranges, value)
310        )
311    }
312
313    fn float_mask_where(
314        tensor: FloatTensor<Self>,
315        mask: BoolTensor<Self>,
316        value: FloatTensor<Self>,
317    ) -> FloatTensor<Self> {
318        execute_with_dtype!(
319            float(tensor.dtype, value.dtype),
320            E,
321            kernel::mask_where_auto::<R, E, BT>(tensor, mask, value)
322        )
323    }
324
325    fn float_mask_fill(
326        tensor: FloatTensor<Self>,
327        mask: BoolTensor<Self>,
328        value: FloatElem<Self>,
329    ) -> FloatTensor<Self> {
330        execute_with_dtype!(
331            float(tensor.dtype),
332            E,
333            kernel::mask_fill_auto::<R, E, BT>(tensor, mask, value.elem())
334        )
335    }
336
337    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
338        execute_with_dtype!(
339            float(lhs.dtype, rhs.dtype),
340            E,
341            kernel::equal::<R, E, BT>(lhs, rhs)
342        )
343    }
344
345    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
346        execute_with_dtype!(
347            float(lhs.dtype),
348            E,
349            kernel::equal_elem::<R, E, BT>(lhs, rhs.elem())
350        )
351    }
352
353    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
354        execute_with_dtype!(
355            float(lhs.dtype, rhs.dtype),
356            E,
357            kernel::greater::<R, E, BT>(lhs, rhs)
358        )
359    }
360
361    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
362        execute_with_dtype!(
363            float(lhs.dtype),
364            E,
365            kernel::greater_elem::<R, E, BT>(lhs, rhs.elem())
366        )
367    }
368
369    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
370        execute_with_dtype!(
371            float(lhs.dtype, rhs.dtype),
372            E,
373            kernel::greater_equal::<R, E, BT>(lhs, rhs)
374        )
375    }
376
377    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
378        execute_with_dtype!(
379            float(lhs.dtype),
380            E,
381            kernel::greater_equal_elem::<R, E, BT>(lhs, rhs.elem())
382        )
383    }
384
385    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
386        execute_with_dtype!(
387            float(lhs.dtype, rhs.dtype),
388            E,
389            kernel::lower::<R, E, BT>(lhs, rhs)
390        )
391    }
392
393    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
394        execute_with_dtype!(
395            float(lhs.dtype),
396            E,
397            kernel::lower_elem::<R, E, BT>(lhs, rhs.elem())
398        )
399    }
400
401    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
402        execute_with_dtype!(
403            float(lhs.dtype, rhs.dtype),
404            E,
405            kernel::lower_equal::<R, E, BT>(lhs, rhs)
406        )
407    }
408
409    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
410        execute_with_dtype!(
411            float(lhs.dtype),
412            E,
413            kernel::lower_equal_elem::<R, E, BT>(lhs, rhs.elem())
414        )
415    }
416
417    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
418        let tensor = into_contiguous::<R>(tensor);
419        execute_with_dtype!(
420            float(tensor.dtype),
421            E,
422            reduce::sum_fallback::<R, E>(tensor, Default::default()).unwrap()
423        )
424    }
425
426    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
427        execute_with_dtype!(
428            float(tensor.dtype),
429            E,
430            reduce::reduce::<R, E, E, E>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
431        )
432    }
433
434    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
435        execute_with_dtype!(
436            float(tensor.dtype),
437            E,
438            reduce::reduce_dim::<R, E, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Max)
439                .unwrap()
440        )
441    }
442
443    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
444        execute_with_dtype!(
445            float(tensor.dtype),
446            E,
447            reduce::reduce::<R, E, E, E>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
448        )
449    }
450
451    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
452        execute_with_dtype!(
453            float(tensor.dtype),
454            E,
455            reduce::reduce_dim::<R, E, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Min)
456                .unwrap()
457        )
458    }
459
460    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
461        execute_with_dtype!(
462            float(tensor.dtype),
463            E,
464            reduce::reduce::<R, E, E, E>(tensor, Default::default(), ReduceFnConfig::MaxAbs)
465                .unwrap()
466        )
467    }
468
469    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
470        execute_with_dtype!(
471            float(tensor.dtype),
472            E,
473            reduce::reduce_dim::<R, E, E, E>(
474                tensor,
475                dim,
476                Default::default(),
477                ReduceFnConfig::MaxAbs
478            )
479            .unwrap()
480        )
481    }
482
483    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
484        execute_with_dtype!(
485            float(tensor.dtype),
486            E,
487            reduce::reduce_dim::<R, E, E, <E as ReducePrecision>::EA>(
488                tensor,
489                dim,
490                Default::default(),
491                ReduceFnConfig::Sum
492            )
493            .unwrap()
494        )
495    }
496
497    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
498        execute_with_dtype!(
499            float(tensor.dtype),
500            E,
501            reduce::reduce_dim::<R, E, E, <E as ReducePrecision>::EA>(
502                tensor,
503                dim,
504                Default::default(),
505                ReduceFnConfig::Mean
506            )
507            .unwrap()
508        )
509    }
510
511    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
512        execute_with_dtype!(float(tensor.dtype), E, numeric::cumsum::<R, E>(tensor, dim))
513    }
514
515    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
516        execute_with_dtype!(
517            float(tensor.dtype),
518            E,
519            numeric::cumprod::<R, E>(tensor, dim)
520        )
521    }
522
523    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
524        execute_with_dtype!(float(tensor.dtype), E, numeric::cummin::<R, E>(tensor, dim))
525    }
526
527    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
528        execute_with_dtype!(float(tensor.dtype), E, numeric::cummax::<R, E>(tensor, dim))
529    }
530
531    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
532        execute_with_dtype!(
533            float(tensor.dtype),
534            E,
535            reduce::reduce::<R, E, E, <E as ReducePrecision>::EA>(
536                tensor,
537                Default::default(),
538                ReduceFnConfig::Prod
539            )
540            .unwrap()
541        )
542    }
543
544    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
545        execute_with_dtype!(
546            float(tensor.dtype),
547            E,
548            reduce::reduce_dim::<R, E, E, <E as ReducePrecision>::EA>(
549                tensor,
550                dim,
551                Default::default(),
552                ReduceFnConfig::Prod
553            )
554            .unwrap()
555        )
556    }
557
558    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
559        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)
560    }
561
562    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
563        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)
564    }
565
566    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
567        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)
568    }
569
570    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
571        struct Powf;
572
573        #[cube]
574        impl<F: Float> FloatUnaryOp<F> for Powf {
575            type Options = F;
576
577            fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
578                Line::powf(input, Line::new(*options))
579            }
580        }
581
582        impl FloatUnaryOpFamily for Powf {
583            type Options<F: Float> = F;
584            type Unary<F: Float> = Self;
585        }
586
587        execute_with_dtype!(
588            float(lhs.dtype),
589            F,
590            launch_unary_float::<R, F, Powf, _>(lhs, |_| ScalarArg::new(rhs.elem::<F>()))
591        )
592    }
593
594    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
595        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)
596    }
597
598    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
599        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)
600    }
601
602    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
603        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)
604    }
605
606    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
607        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)
608    }
609
610    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
611        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)
612    }
613
614    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
615        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)
616    }
617
618    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
619        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)
620    }
621
622    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
623        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)
624    }
625
626    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
627        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)
628    }
629
630    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
631        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)
632    }
633
634    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
635        execute_with_dtype!(
636            float(tensor.dtype),
637            E,
638            reduce::reduce_dim::<R, E, I, E>(
639                tensor,
640                dim,
641                Default::default(),
642                ReduceFnConfig::ArgMax
643            )
644            .unwrap()
645        )
646    }
647
648    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
649        execute_with_dtype!(
650            float(tensor.dtype),
651            E,
652            reduce::reduce_dim::<R, E, I, E>(
653                tensor,
654                dim,
655                Default::default(),
656                ReduceFnConfig::ArgMin
657            )
658            .unwrap()
659        )
660    }
661
662    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
663        execute_with_dtype!(float(tensor.dtype), E, kernel::cast::<R, E, I>(tensor))
664    }
665
666    fn float_clamp(
667        tensor: FloatTensor<Self>,
668        min: FloatElem<Self>,
669        max: FloatElem<Self>,
670    ) -> FloatTensor<Self> {
671        execute_with_dtype!(
672            float(tensor.dtype),
673            E,
674            kernel::clamp::<R, E>(tensor, min.elem(), max.elem())
675        )
676    }
677
678    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
679        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)
680    }
681
682    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
683        execute_with_dtype!(
684            float(tensor.dtype),
685            E,
686            kernel::repeat_dim::<R, E>(tensor, dim, times)
687        )
688    }
689
690    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
691        execute_with_dtype!(float(lhs.dtype), E, numeric::pow::<R, E>(lhs, rhs))
692    }
693
694    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
695        permute(tensor, axes)
696    }
697
698    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
699        expand(tensor, shape)
700    }
701
702    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
703        execute_with_dtype!(
704            float(tensor.dtype),
705            E,
706            kernel::flip::<R, E, BT>(tensor, axes)
707        )
708    }
709
710    fn float_cast(mut tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
711        match (tensor.dtype, dtype) {
712            (DType::F64, FloatDType::F64)
713            | (DType::F32, FloatDType::F32)
714            | (DType::Flex32, FloatDType::Flex32)
715            | (DType::BF16, FloatDType::BF16)
716            | (DType::F16, FloatDType::F16) => tensor,
717            (DType::F32, FloatDType::Flex32) | (DType::Flex32, FloatDType::F32) => {
718                tensor.dtype = dtype.into();
719                tensor
720            }
721            (DType::F64, FloatDType::F32) => kernel::cast::<R, f64, f32>(tensor),
722            (DType::F64, FloatDType::Flex32) => kernel::cast::<R, f64, flex32>(tensor),
723            (DType::F64, FloatDType::F16) => kernel::cast::<R, f64, f16>(tensor),
724            (DType::F64, FloatDType::BF16) => kernel::cast::<R, f64, bf16>(tensor),
725            (DType::F32, FloatDType::F64) => kernel::cast::<R, f32, f64>(tensor),
726            (DType::F32, FloatDType::F16) => kernel::cast::<R, f32, f16>(tensor),
727            (DType::F32, FloatDType::BF16) => kernel::cast::<R, f32, bf16>(tensor),
728            (DType::Flex32, FloatDType::F64) => kernel::cast::<R, flex32, f64>(tensor),
729            (DType::Flex32, FloatDType::F16) => kernel::cast::<R, flex32, f16>(tensor),
730            (DType::Flex32, FloatDType::BF16) => kernel::cast::<R, flex32, bf16>(tensor),
731            (DType::F16, FloatDType::F64) => kernel::cast::<R, f16, f64>(tensor),
732            (DType::F16, FloatDType::F32) => kernel::cast::<R, f16, f32>(tensor),
733            (DType::F16, FloatDType::Flex32) => kernel::cast::<R, f16, flex32>(tensor),
734            (DType::F16, FloatDType::BF16) => kernel::cast::<R, f16, bf16>(tensor),
735            (DType::BF16, FloatDType::F64) => kernel::cast::<R, bf16, f64>(tensor),
736            (DType::BF16, FloatDType::F32) => kernel::cast::<R, bf16, f32>(tensor),
737            (DType::BF16, FloatDType::Flex32) => kernel::cast::<R, bf16, flex32>(tensor),
738            (DType::BF16, FloatDType::F16) => kernel::cast::<R, bf16, f16>(tensor),
739            _ => unimplemented!("Unsupported floating point type cast"),
740        }
741    }
742
743    fn float_unfold(
744        tensor: FloatTensor<Self>,
745        dim: usize,
746        size: usize,
747        step: usize,
748    ) -> FloatTensor<Self> {
749        unfold(tensor, dim, size, step)
750    }
751
752    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
753        execute_with_dtype!(float(tensor.dtype), E, kernel::is_nan::<R, E, BT>(tensor))
754    }
755
756    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
757        execute_with_dtype!(float(tensor.dtype), E, kernel::is_inf::<R, E, BT>(tensor))
758    }
759}