Skip to main content

burn_cubecl/ops/
tensor.rs

1use super::{expand, numeric, permute, unfold};
2use crate::CubeBackend;
3use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
4use crate::kernel::unary_basic::BasicFloatUnaryKind;
5use crate::kernel::{
6    self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
7};
8use crate::{CubeRuntime, FloatElement, IntElement};
9use crate::{
10    element::BoolElement,
11    kernel::matmul::{MatmulStrategy, matmul},
12};
13use burn_backend::ops::GridSampleOptions;
14use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
15use burn_backend::{DType, ElementConversion, FloatDType, Slice};
16use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};
17use burn_backend::{ExecutionError, Scalar, get_device_settings};
18use burn_std::{BoolDType, IntDType};
19use cubecl::prelude::*;
20use cubek::reduce::components::instructions::ReduceOperationConfig;
21use std::ops::Range;
22
23impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
24where
25    R: CubeRuntime,
26    F: FloatElement,
27    I: IntElement,
28    BT: BoolElement,
29{
30    #[cfg_attr(feature = "tracing", tracing::instrument(
31        level="trace",
32        skip(data),
33        fields(?data.shape, ?data.dtype)
34    ))]
35    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
36        match data.dtype {
37            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => super::from_data(data, device),
38            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
39        }
40    }
41
42    fn float_random(
43        shape: Shape,
44        distribution: Distribution,
45        device: &Device<Self>,
46        dtype: FloatDType,
47    ) -> FloatTensor<Self> {
48        let dtype = dtype.into();
49        match distribution {
50            Distribution::Default => random_uniform(shape, device, 0., 1., dtype),
51            Distribution::Uniform(low, high) => {
52                random_uniform(shape, device, low.elem(), high.elem(), dtype)
53            }
54            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
55            Distribution::Normal(mean, std) => {
56                random_normal(shape, device, mean.elem(), std.elem(), dtype)
57            }
58        }
59    }
60
61    #[cfg_attr(feature = "tracing", tracing::instrument(
62        level="trace",
63        skip(tensor),
64        fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype)
65    ))]
66    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
67        super::into_data(tensor).await
68    }
69
70    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
71        tensor.device.clone()
72    }
73
74    #[cfg_attr(feature = "tracing", tracing::instrument(
75        level="trace",
76        skip(tensor),
77        fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype)
78    ))]
79    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
80        super::to_device(tensor, device)
81    }
82
83    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
84        let dtype = dtype.into();
85        super::empty(shape, device, dtype)
86    }
87
88    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
89        numeric::add(lhs, rhs)
90    }
91
92    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
93        let dtype = lhs.dtype;
94        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
95    }
96
97    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
98        let dtype = dtype.into();
99        numeric::zeros(device.clone(), shape, dtype)
100    }
101
102    fn float_full(
103        shape: Shape,
104        fill_value: Scalar,
105        device: &R::Device,
106        dtype: FloatDType,
107    ) -> FloatTensor<Self> {
108        let dtype: DType = dtype.into();
109        let client = R::client(device);
110        numeric::full_device_dtype(
111            client,
112            shape,
113            device.clone(),
114            InputScalar::new(fill_value, dtype),
115            dtype,
116        )
117    }
118
119    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
120        let dtype = dtype.into();
121        numeric::ones(device.clone(), shape, dtype)
122    }
123
124    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
125        numeric::sub(lhs, rhs)
126    }
127
128    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
129        let dtype = lhs.dtype;
130        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
131    }
132
133    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
134        numeric::mul(lhs, rhs)
135    }
136
137    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
138        let dtype = lhs.dtype;
139        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
140    }
141
142    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
143        numeric::div(lhs, rhs)
144    }
145
146    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
147        let dtype = lhs.dtype;
148        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
149    }
150
151    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
152        numeric::remainder(lhs, rhs)
153    }
154
155    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
156        let dtype = lhs.dtype;
157        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
158    }
159
160    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
161        let dtype = lhs.dtype;
162        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
163    }
164
165    fn float_cross(
166        lhs: FloatTensor<Self>,
167        rhs: FloatTensor<Self>,
168        dim: usize,
169    ) -> FloatTensor<Self> {
170        kernel::cross(lhs, rhs, dim)
171    }
172
173    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
174        super::swap_dims(tensor, dim1, dim2)
175    }
176
177    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
178        super::reshape(tensor, shape)
179    }
180
181    fn float_gather(
182        dim: usize,
183        tensor: FloatTensor<Self>,
184        indices: IntTensor<Self>,
185    ) -> FloatTensor<Self> {
186        kernel::gather(dim, tensor, indices)
187    }
188
189    fn float_scatter_add(
190        dim: usize,
191        tensor: FloatTensor<Self>,
192        indices: IntTensor<Self>,
193        value: FloatTensor<Self>,
194    ) -> FloatTensor<Self> {
195        kernel::scatter(dim, tensor, indices, value, false)
196    }
197
198    fn float_scatter_nd(
199        data: FloatTensor<Self>,
200        indices: IntTensor<Self>,
201        values: FloatTensor<Self>,
202        reduction: burn_backend::tensor::IndexingUpdateOp,
203    ) -> FloatTensor<Self> {
204        kernel::scatter_nd(data, indices, values, reduction)
205    }
206
207    fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
208        kernel::gather_nd(data, indices)
209    }
210
211    fn float_select(
212        tensor: FloatTensor<Self>,
213        dim: usize,
214        indices: IntTensor<Self>,
215    ) -> FloatTensor<Self> {
216        kernel::select(tensor, dim, indices)
217    }
218
219    fn float_select_add(
220        tensor: FloatTensor<Self>,
221        dim: usize,
222        indices: IntTensor<Self>,
223        value: FloatTensor<Self>,
224    ) -> FloatTensor<Self> {
225        kernel::select_assign(tensor, dim, indices, value, false)
226    }
227
228    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
229        // Check if all steps are 1
230        let all_steps_one = slices.iter().all(|info| info.step == 1);
231
232        if all_steps_one {
233            // Use optimized slice for step=1
234            let simple_ranges: Vec<Range<usize>> = slices
235                .iter()
236                .enumerate()
237                .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
238                .collect();
239
240            kernel::slice(tensor, &simple_ranges)
241        } else {
242            // Use slice with steps kernel
243            kernel::slice_with_steps(tensor, slices)
244        }
245    }
246
247    fn float_slice_assign(
248        tensor: FloatTensor<Self>,
249        ranges: &[Slice],
250        value: FloatTensor<Self>,
251    ) -> FloatTensor<Self> {
252        kernel::slice_assign(tensor, ranges, value)
253    }
254
255    fn float_mask_where(
256        tensor: FloatTensor<Self>,
257        mask: BoolTensor<Self>,
258        value: FloatTensor<Self>,
259    ) -> FloatTensor<Self> {
260        let bool_dtype = mask.dtype;
261        kernel::mask_where_auto(tensor, mask, value, bool_dtype)
262    }
263
264    fn float_mask_fill(
265        tensor: FloatTensor<Self>,
266        mask: BoolTensor<Self>,
267        value: Scalar,
268    ) -> FloatTensor<Self> {
269        let dtype = tensor.dtype;
270        let bool_dtype = mask.dtype;
271        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), bool_dtype)
272    }
273
274    fn float_equal(
275        lhs: FloatTensor<Self>,
276        rhs: FloatTensor<Self>,
277        out_dtype: BoolDType,
278    ) -> BoolTensor<Self> {
279        kernel::equal(lhs, rhs, out_dtype.into())
280    }
281
282    fn float_equal_elem(
283        lhs: FloatTensor<Self>,
284        rhs: Scalar,
285        out_dtype: BoolDType,
286    ) -> BoolTensor<Self> {
287        let dtype = lhs.dtype;
288        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
289    }
290
291    fn float_greater(
292        lhs: FloatTensor<Self>,
293        rhs: FloatTensor<Self>,
294        out_dtype: BoolDType,
295    ) -> BoolTensor<Self> {
296        kernel::greater(lhs, rhs, out_dtype.into())
297    }
298
299    fn float_greater_elem(
300        lhs: FloatTensor<Self>,
301        rhs: Scalar,
302        out_dtype: BoolDType,
303    ) -> BoolTensor<Self> {
304        let dtype = lhs.dtype;
305        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
306    }
307
308    fn float_greater_equal(
309        lhs: FloatTensor<Self>,
310        rhs: FloatTensor<Self>,
311        out_dtype: BoolDType,
312    ) -> BoolTensor<Self> {
313        kernel::greater_equal(lhs, rhs, out_dtype.into())
314    }
315
316    fn float_greater_equal_elem(
317        lhs: FloatTensor<Self>,
318        rhs: Scalar,
319        out_dtype: BoolDType,
320    ) -> BoolTensor<Self> {
321        let dtype = lhs.dtype;
322        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
323    }
324
325    fn float_lower(
326        lhs: FloatTensor<Self>,
327        rhs: FloatTensor<Self>,
328        out_dtype: BoolDType,
329    ) -> BoolTensor<Self> {
330        kernel::lower(lhs, rhs, out_dtype.into())
331    }
332
333    fn float_lower_elem(
334        lhs: FloatTensor<Self>,
335        rhs: Scalar,
336        out_dtype: BoolDType,
337    ) -> BoolTensor<Self> {
338        let dtype = lhs.dtype;
339        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
340    }
341
342    fn float_lower_equal(
343        lhs: FloatTensor<Self>,
344        rhs: FloatTensor<Self>,
345        out_dtype: BoolDType,
346    ) -> BoolTensor<Self> {
347        kernel::lower_equal(lhs, rhs, out_dtype.into())
348    }
349
350    fn float_lower_equal_elem(
351        lhs: FloatTensor<Self>,
352        rhs: Scalar,
353        out_dtype: BoolDType,
354    ) -> BoolTensor<Self> {
355        let dtype = lhs.dtype;
356        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
357    }
358
359    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
360        reduce::sum_fallback(tensor, Default::default()).unwrap()
361    }
362
363    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
364        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
365    }
366
367    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
368        reduce::reduce_dim(
369            tensor,
370            None,
371            dim,
372            Default::default(),
373            ReduceOperationConfig::Max,
374        )
375        .unwrap()
376    }
377
378    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
379        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
380    }
381
382    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
383        reduce::reduce_dim(
384            tensor,
385            None,
386            dim,
387            Default::default(),
388            ReduceOperationConfig::Min,
389        )
390        .unwrap()
391    }
392
393    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
394        reduce::reduce(
395            tensor,
396            None,
397            Default::default(),
398            ReduceOperationConfig::MaxAbs,
399        )
400        .unwrap()
401    }
402
403    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
404        reduce::reduce_dim(
405            tensor,
406            None,
407            dim,
408            Default::default(),
409            ReduceOperationConfig::MaxAbs,
410        )
411        .unwrap()
412    }
413
414    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
415        reduce::reduce_dim(
416            tensor,
417            None,
418            dim,
419            Default::default(),
420            ReduceOperationConfig::Sum,
421        )
422        .unwrap()
423    }
424
425    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
426        reduce::reduce_dim(
427            tensor,
428            None,
429            dim,
430            Default::default(),
431            ReduceOperationConfig::Mean,
432        )
433        .unwrap()
434    }
435
436    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
437        reduce::reduce(
438            tensor,
439            None,
440            Default::default(),
441            ReduceOperationConfig::Mean,
442        )
443        .unwrap()
444    }
445
446    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
447        numeric::cumsum(tensor, dim)
448    }
449
450    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
451        numeric::cumprod(tensor, dim)
452    }
453
454    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
455        numeric::cummin(tensor, dim)
456    }
457
458    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
459        numeric::cummax(tensor, dim)
460    }
461
462    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463        reduce::reduce(
464            tensor,
465            None,
466            Default::default(),
467            ReduceOperationConfig::Prod,
468        )
469        .unwrap()
470    }
471
472    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
473        reduce::reduce_dim(
474            tensor,
475            None,
476            dim,
477            Default::default(),
478            ReduceOperationConfig::Prod,
479        )
480        .unwrap()
481    }
482
483    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
484        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)
485    }
486
487    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
488        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)
489    }
490
491    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
492        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)
493    }
494
495    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
496        struct Powf;
497
498        #[cube]
499        impl<F: Float, N: Size> FloatUnaryOp<F, N> for Powf {
500            type Options = InputScalar;
501
502            fn execute(input: Vector<F, N>, options: &Self::Options) -> Vector<F, N> {
503                Vector::powf(input, Vector::new(options.get::<F>()))
504            }
505        }
506
507        impl FloatUnaryOpFamily for Powf {
508            type Options = InputScalar;
509            type Unary<F: Float, N: Size> = Self;
510        }
511
512        let dtype = lhs.dtype;
513        launch_unary_float::<R, Powf, _>(lhs, |_| InputScalar::new(rhs, dtype))
514    }
515
516    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
517        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)
518    }
519
520    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
521        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)
522    }
523
524    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
525        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sign)
526    }
527
528    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
529        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)
530    }
531
532    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
533        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)
534    }
535
536    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
537        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tan)
538    }
539
540    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
541        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cosh)
542    }
543
544    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
545        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sinh)
546    }
547
548    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
549        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)
550    }
551
552    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
553        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCos)
554    }
555
556    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
557        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCosh)
558    }
559
560    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
561        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSin)
562    }
563
564    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
565        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSinh)
566    }
567
568    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
569        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTan)
570    }
571
572    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
573        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTanh)
574    }
575
576    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
577        crate::kernel::atan2::<R>(lhs, rhs)
578    }
579
580    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
581        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)
582    }
583
584    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
585        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)
586    }
587
588    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
589        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)
590    }
591
592    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
593        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)
594    }
595
596    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
597        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)
598    }
599
600    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
601        reduce::reduce_dim(
602            tensor,
603            Some(out_dtype.into()),
604            dim,
605            Default::default(),
606            ReduceOperationConfig::ArgMax,
607        )
608        .unwrap()
609    }
610
611    fn float_argtopk(
612        tensor: FloatTensor<Self>,
613        dim: usize,
614        k: usize,
615        out_dtype: IntDType,
616    ) -> IntTensor<Self> {
617        reduce::reduce_dim(
618            tensor,
619            Some(out_dtype.into()),
620            dim,
621            Default::default(),
622            ReduceOperationConfig::ArgTopK(k),
623        )
624        .unwrap()
625    }
626
627    fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
628        reduce::reduce_dim(
629            tensor,
630            None,
631            dim,
632            Default::default(),
633            ReduceOperationConfig::TopK(k),
634        )
635        .unwrap()
636    }
637
638    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
639        reduce::reduce_dim(
640            tensor,
641            Some(out_dtype.into()),
642            dim,
643            Default::default(),
644            ReduceOperationConfig::ArgMin,
645        )
646        .unwrap()
647    }
648
649    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
650        kernel::cast(tensor, out_dtype.into())
651    }
652
653    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
654        let dtype = tensor.dtype;
655        kernel::clamp(
656            tensor,
657            InputScalar::new(min, dtype),
658            InputScalar::new(max, dtype),
659        )
660    }
661
662    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
663        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)
664    }
665
666    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
667        kernel::repeat_dim(tensor, dim, times)
668    }
669
670    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
671        numeric::pow(lhs, rhs)
672    }
673
674    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
675        permute(tensor, axes)
676    }
677
678    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
679        expand(tensor, shape)
680    }
681
682    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
683        let bool_dtype = get_device_settings::<Self>(&tensor.device).bool_dtype;
684        kernel::flip(tensor, axes, bool_dtype.into())
685    }
686
687    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
688        kernel::cast(tensor, dtype.into())
689    }
690
691    fn float_unfold(
692        tensor: FloatTensor<Self>,
693        dim: usize,
694        size: usize,
695        step: usize,
696    ) -> FloatTensor<Self> {
697        unfold(tensor, dim, size, step)
698    }
699
700    fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
701        kernel::is_nan(tensor, out_dtype.into())
702    }
703
704    fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
705        kernel::is_inf(tensor, out_dtype.into())
706    }
707
708    fn float_grid_sample_2d(
709        tensor: FloatTensor<Self>,
710        grid: FloatTensor<Self>,
711        options: GridSampleOptions,
712    ) -> FloatTensor<Self> {
713        kernel::grid_sample::grid_sample(tensor, grid, options)
714    }
715}