burn_cubecl/ops/tensor.rs

use super::{expand, numeric, permute, unfold};
use crate::CubeBackend;
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::kernel::{
    self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
};
use crate::kernel::{into_contiguous, unary_basic::BasicFloatUnaryKind};
use crate::{CubeRuntime, FloatElement, IntElement};
use crate::{
    element::BoolElement,
    kernel::matmul::{MatmulStrategy, matmul},
};
use burn_backend::ops::GridSampleOptions;
use burn_backend::tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use burn_backend::{Backend, ExecutionError};
use burn_backend::{DType, ElementConversion, FloatDType, Slice};
use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};
use cubecl::prelude::*;
use cubek::reduce::components::instructions::ReduceOperationConfig;
use std::ops::Range;

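// `FloatTensorOps` implementation for the CubeCL backend. Most methods are thin
// wrappers that dispatch to the shared `kernel::*`, `numeric::*`, and `reduce::*`
// launch helpers; scalar operands are passed as `InputScalar` values tagged with
// the tensor's dtype.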
impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    #[cfg_attr(feature = "tracing", tracing::instrument(
        level="trace",
        skip(data),
        fields(?data.shape, ?data.dtype)
    ))]
    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
        match data.dtype {
            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => super::from_data(data, device),
            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
        }
    }

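    // The PRNG kernel is selected from the `Distribution` variant; values are
    // produced in the backend's float element type (`FloatElem::<Self>::dtype()`).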
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> FloatTensor<Self> {
        let dtype = FloatElem::<Self>::dtype();
        match distribution {
            Distribution::Default => random_uniform(shape, device, 0., 1., dtype),
            Distribution::Uniform(low, high) => {
                random_uniform(shape, device, low.elem(), high.elem(), dtype)
            }
            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
            Distribution::Normal(mean, std) => {
                random_normal(shape, device, mean.elem(), std.elem(), dtype)
            }
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(
        level="trace",
        skip(tensor),
        fields(from = ?tensor.device, shape = ?tensor.shape, dtype = ?tensor.dtype)
    ))]
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        super::into_data(tensor).await
    }

    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(
        level="trace",
        skip(tensor),
        fields(from = ?tensor.device, shape = ?tensor.shape, dtype = ?tensor.dtype)
    ))]
    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
        super::to_device(tensor, device)
    }

    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into();
        super::empty(shape, device, dtype)
    }

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        numeric::add(lhs, rhs)
    }

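    // Tensor-scalar variants wrap the scalar in an `InputScalar` tagged with the
    // tensor's dtype so the launched kernel receives a correctly typed operand;
    // the same pattern is used by every `*_scalar` op below.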
    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        let dtype = lhs.dtype;
        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into();
        numeric::zeros(device.clone(), shape, dtype)
    }

    fn float_full(
        shape: Shape,
        fill_value: FloatElem<Self>,
        device: &R::Device,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        let dtype: DType = dtype.into();
        let client = R::client(device);
        numeric::full_device_dtype(
            client,
            shape,
            device.clone(),
            InputScalar::new(fill_value, dtype),
            dtype,
        )
    }

    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into();
        numeric::ones(device.clone(), shape, dtype)
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        numeric::sub(lhs, rhs)
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        let dtype = lhs.dtype;
        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        numeric::mul(lhs, rhs)
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        let dtype = lhs.dtype;
        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        numeric::div(lhs, rhs)
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        let dtype = lhs.dtype;
        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        numeric::remainder(lhs, rhs)
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        let dtype = lhs.dtype;
        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
    }

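    // Matrix multiplication dispatches to the matmul kernel with the default
    // `MatmulStrategy`, using the lhs dtype for the output.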
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        let dtype = lhs.dtype;
        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
    }

    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        kernel::cross(lhs, rhs, dim)
    }

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        super::swap_dims(tensor, dim1, dim2)
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        super::reshape(tensor, shape)
    }

    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::gather(dim, tensor, indices)
    }

    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::scatter(dim, tensor, indices, value, false)
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::select(tensor, dim, indices)
    }

    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::select_assign(tensor, dim, indices, value, false)
    }

    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
        // Check if all steps are 1
        let all_steps_one = slices.iter().all(|info| info.step == 1);

        if all_steps_one {
            // Use optimized slice for step=1
            let simple_ranges: Vec<Range<usize>> = slices
                .iter()
                .enumerate()
                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
                .collect();

            kernel::slice(tensor, &simple_ranges)
        } else {
            // Use slice with steps kernel
            kernel::slice_with_steps(tensor, slices)
        }
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        ranges: &[Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::slice_assign(tensor, ranges, value)
    }

    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::mask_where_auto(tensor, mask, value, BT::dtype())
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatElem<Self>,
    ) -> FloatTensor<Self> {
        let dtype = tensor.dtype;
        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), BT::dtype())
    }

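    // Comparison kernels produce boolean tensors stored in the backend's bool
    // element type, passed explicitly as `BT::dtype()`.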
    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::equal(lhs, rhs, BT::dtype())
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::greater(lhs, rhs, BT::dtype())
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::greater_equal(lhs, rhs, BT::dtype())
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::lower(lhs, rhs, BT::dtype())
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::lower_equal(lhs, rhs, BT::dtype())
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

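    // Reductions go through the shared `reduce::reduce` / `reduce::reduce_dim`
    // launchers, parameterized by `ReduceOperationConfig`; the full-tensor sum is
    // made contiguous first and uses the fallback path.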
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let tensor = into_contiguous(tensor);
        reduce::sum_fallback(tensor, Default::default()).unwrap()
    }

    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
    }

    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Max,
        )
        .unwrap()
    }

    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
    }

    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Min,
        )
        .unwrap()
    }

    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        reduce::reduce(
            tensor,
            None,
            Default::default(),
            ReduceOperationConfig::MaxAbs,
        )
        .unwrap()
    }

    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::MaxAbs,
        )
        .unwrap()
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Sum,
        )
        .unwrap()
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Mean,
        )
        .unwrap()
    }

    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        numeric::cumsum(tensor, dim)
    }

    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        numeric::cumprod(tensor, dim)
    }

    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        numeric::cummin(tensor, dim)
    }

    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        numeric::cummax(tensor, dim)
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        reduce::reduce(
            tensor,
            None,
            Default::default(),
            ReduceOperationConfig::Prod,
        )
        .unwrap()
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Prod,
        )
        .unwrap()
    }

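    // Basic element-wise unary ops map one-to-one onto `BasicFloatUnaryKind`
    // variants and share the single `unary_basic::launch` entry point.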
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)
    }

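    // `powf` with a scalar exponent defines a one-off unary op: the exponent is
    // carried as an `InputScalar` option and applied per `Line` via `Line::powf`.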
    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
        struct Powf;

        #[cube]
        impl<F: Float> FloatUnaryOp<F> for Powf {
            type Options = InputScalar;

            fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
                Line::powf(input, Line::new(options.get::<F>()))
            }
        }

        impl FloatUnaryOpFamily for Powf {
            type Options = InputScalar;
            type Unary<F: Float> = Self;
        }

        let dtype = lhs.dtype;
        launch_unary_float::<R, Powf, _>(lhs, |_| InputScalar::new(rhs, dtype))
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)
    }

    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)
    }

    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tan)
    }

    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cosh)
    }

    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sinh)
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)
    }

    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCos)
    }

    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCosh)
    }

    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSin)
    }

    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSinh)
    }

    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTan)
    }

    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTanh)
    }

    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        crate::kernel::atan2::<R>(lhs, rhs)
    }

    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)
    }

    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)
    }

    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)
    }

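    // Arg-reductions write indices rather than values; the output dtype is the
    // backend's int element type, passed explicitly to `reduce_dim`.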
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            Some(<Self as Backend>::IntElem::dtype()),
            dim,
            Default::default(),
            ReduceOperationConfig::ArgMax,
        )
        .unwrap()
    }

    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            Some(<Self as Backend>::IntElem::dtype()),
            dim,
            Default::default(),
            ReduceOperationConfig::ArgMin,
        )
        .unwrap()
    }

    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
        kernel::cast(tensor, I::dtype())
    }

    fn float_clamp(
        tensor: FloatTensor<Self>,
        min: FloatElem<Self>,
        max: FloatElem<Self>,
    ) -> FloatTensor<Self> {
        let dtype = tensor.dtype;
        kernel::clamp(
            tensor,
            InputScalar::new(min, dtype),
            InputScalar::new(max, dtype),
        )
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)
    }

    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
        kernel::repeat_dim(tensor, dim, times)
    }

    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        numeric::pow(lhs, rhs)
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        permute(tensor, axes)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        expand(tensor, shape)
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        kernel::flip(tensor, axes, BT::dtype())
    }

    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        kernel::cast(tensor, dtype.into())
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::is_nan(tensor, BT::dtype())
    }

    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        kernel::is_inf(tensor, BT::dtype())
    }

    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: GridSampleOptions,
    ) -> FloatTensor<Self> {
        kernel::grid_sample::grid_sample(tensor, grid, options)
    }
}