1use super::{expand, numeric, permute, unfold};
2use crate::CubeBackend;
3use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
4use crate::kernel::{
5 self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
6};
7use crate::kernel::{into_contiguous, unary_basic::BasicFloatUnaryKind};
8use crate::{CubeRuntime, FloatElement, IntElement};
9use crate::{
10 element::BoolElement,
11 kernel::matmul::{MatmulStrategy, matmul},
12};
13use burn_backend::ops::GridSampleOptions;
14use burn_backend::tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
15use burn_backend::{Backend, ExecutionError};
16use burn_backend::{DType, ElementConversion, FloatDType, Slice};
17use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};
18use cubecl::prelude::*;
19use cubek::reduce::components::instructions::ReduceOperationConfig;
20use std::ops::Range;
21
22impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
23where
24 R: CubeRuntime,
25 F: FloatElement,
26 I: IntElement,
27 BT: BoolElement,
28{
29 #[cfg_attr(feature = "tracing", tracing::instrument(
30 level="trace",
31 skip(data),
32 fields(?data.shape, ?data.dtype)
33 ))]
34 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
35 match data.dtype {
36 DType::F64 | DType::F32 | DType::F16 | DType::BF16 => super::from_data(data, device),
37 _ => unimplemented!("Unsupported dtype for `float_from_data`"),
38 }
39 }
40
41 fn float_random(
42 shape: Shape,
43 distribution: Distribution,
44 device: &Device<Self>,
45 ) -> FloatTensor<Self> {
46 let dtype = FloatElem::<Self>::dtype();
47 match distribution {
48 Distribution::Default => random_uniform(shape, device, 0., 1., dtype),
49 Distribution::Uniform(low, high) => {
50 random_uniform(shape, device, low.elem(), high.elem(), dtype)
51 }
52 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
53 Distribution::Normal(mean, std) => {
54 random_normal(shape, device, mean.elem(), std.elem(), dtype)
55 }
56 }
57 }
58
59 #[cfg_attr(feature = "tracing", tracing::instrument(
60 level="trace",
61 skip(tensor),
62 fields(from = ?tensor.device, shape = ?tensor.shape, dtype = ?tensor.dtype)
63 ))]
64 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
65 super::into_data(tensor).await
66 }
67
68 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
69 tensor.device.clone()
70 }
71
72 #[cfg_attr(feature = "tracing", tracing::instrument(
73 level="trace",
74 skip(tensor),
75 fields(from = ?tensor.device, shape = ?tensor.shape, dtype = ?tensor.dtype)
76 ))]
77 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
78 super::to_device(tensor, device)
79 }
80
81 fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
82 let dtype = dtype.into();
83 super::empty(shape, device, dtype)
84 }
85
86 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
87 numeric::add(lhs, rhs)
88 }
89
90 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
91 let dtype = lhs.dtype;
92 numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
93 }
94
95 fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
96 let dtype = dtype.into();
97 numeric::zeros(device.clone(), shape, dtype)
98 }
99
100 fn float_full(
101 shape: Shape,
102 fill_value: FloatElem<Self>,
103 device: &R::Device,
104 dtype: FloatDType,
105 ) -> FloatTensor<Self> {
106 let dtype: DType = dtype.into();
107 let client = R::client(device);
108 numeric::full_device_dtype(
109 client,
110 shape,
111 device.clone(),
112 InputScalar::new(fill_value, dtype),
113 dtype,
114 )
115 }
116
117 fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
118 let dtype = dtype.into();
119 numeric::ones(device.clone(), shape, dtype)
120 }
121
122 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
123 numeric::sub(lhs, rhs)
124 }
125
126 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
127 let dtype = lhs.dtype;
128 numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
129 }
130
131 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
132 numeric::mul(lhs, rhs)
133 }
134
135 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
136 let dtype = lhs.dtype;
137 numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
138 }
139
140 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
141 numeric::div(lhs, rhs)
142 }
143
144 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
145 let dtype = lhs.dtype;
146 numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
147 }
148
149 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
150 numeric::remainder(lhs, rhs)
151 }
152
153 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
154 let dtype = lhs.dtype;
155 numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
156 }
157
158 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
159 let dtype = lhs.dtype;
160 matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
161 }
162
163 fn float_cross(
164 lhs: FloatTensor<Self>,
165 rhs: FloatTensor<Self>,
166 dim: usize,
167 ) -> FloatTensor<Self> {
168 kernel::cross(lhs, rhs, dim)
169 }
170
171 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
172 super::swap_dims(tensor, dim1, dim2)
173 }
174
175 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
176 super::reshape(tensor, shape)
177 }
178
179 fn float_gather(
180 dim: usize,
181 tensor: FloatTensor<Self>,
182 indices: IntTensor<Self>,
183 ) -> FloatTensor<Self> {
184 kernel::gather(dim, tensor, indices)
185 }
186
187 fn float_scatter_add(
188 dim: usize,
189 tensor: FloatTensor<Self>,
190 indices: IntTensor<Self>,
191 value: FloatTensor<Self>,
192 ) -> FloatTensor<Self> {
193 kernel::scatter(dim, tensor, indices, value, false)
194 }
195
196 fn float_select(
197 tensor: FloatTensor<Self>,
198 dim: usize,
199 indices: IntTensor<Self>,
200 ) -> FloatTensor<Self> {
201 kernel::select(tensor, dim, indices)
202 }
203
204 fn float_select_add(
205 tensor: FloatTensor<Self>,
206 dim: usize,
207 indices: IntTensor<Self>,
208 value: FloatTensor<Self>,
209 ) -> FloatTensor<Self> {
210 kernel::select_assign(tensor, dim, indices, value, false)
211 }
212
213 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
214 let all_steps_one = slices.iter().all(|info| info.step == 1);
216
217 if all_steps_one {
218 let simple_ranges: Vec<Range<usize>> = slices
220 .iter()
221 .enumerate()
222 .map(|(i, slice)| slice.to_range(tensor.shape[i]))
223 .collect();
224
225 kernel::slice(tensor, &simple_ranges)
226 } else {
227 kernel::slice_with_steps(tensor, slices)
229 }
230 }
231
232 fn float_slice_assign(
233 tensor: FloatTensor<Self>,
234 ranges: &[Slice],
235 value: FloatTensor<Self>,
236 ) -> FloatTensor<Self> {
237 kernel::slice_assign(tensor, ranges, value)
238 }
239
240 fn float_mask_where(
241 tensor: FloatTensor<Self>,
242 mask: BoolTensor<Self>,
243 value: FloatTensor<Self>,
244 ) -> FloatTensor<Self> {
245 kernel::mask_where_auto(tensor, mask, value, BT::dtype())
246 }
247
248 fn float_mask_fill(
249 tensor: FloatTensor<Self>,
250 mask: BoolTensor<Self>,
251 value: FloatElem<Self>,
252 ) -> FloatTensor<Self> {
253 let dtype = tensor.dtype;
254 kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), BT::dtype())
255 }
256
257 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
258 kernel::equal(lhs, rhs, BT::dtype())
259 }
260
261 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
262 let dtype = lhs.dtype;
263 kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
264 }
265
266 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
267 kernel::greater(lhs, rhs, BT::dtype())
268 }
269
270 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
271 let dtype = lhs.dtype;
272 kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
273 }
274
275 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
276 kernel::greater_equal(lhs, rhs, BT::dtype())
277 }
278
279 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
280 let dtype = lhs.dtype;
281 kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
282 }
283
284 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
285 kernel::lower(lhs, rhs, BT::dtype())
286 }
287
288 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
289 let dtype = lhs.dtype;
290 kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
291 }
292
293 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
294 kernel::lower_equal(lhs, rhs, BT::dtype())
295 }
296
297 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
298 let dtype = lhs.dtype;
299 kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
300 }
301
302 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
303 let tensor = into_contiguous(tensor);
304 reduce::sum_fallback(tensor, Default::default()).unwrap()
305 }
306
307 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
308 reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
309 }
310
311 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
312 reduce::reduce_dim(
313 tensor,
314 None,
315 dim,
316 Default::default(),
317 ReduceOperationConfig::Max,
318 )
319 .unwrap()
320 }
321
322 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
323 reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
324 }
325
326 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
327 reduce::reduce_dim(
328 tensor,
329 None,
330 dim,
331 Default::default(),
332 ReduceOperationConfig::Min,
333 )
334 .unwrap()
335 }
336
337 fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
338 reduce::reduce(
339 tensor,
340 None,
341 Default::default(),
342 ReduceOperationConfig::MaxAbs,
343 )
344 .unwrap()
345 }
346
347 fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
348 reduce::reduce_dim(
349 tensor,
350 None,
351 dim,
352 Default::default(),
353 ReduceOperationConfig::MaxAbs,
354 )
355 .unwrap()
356 }
357
358 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
359 reduce::reduce_dim(
360 tensor,
361 None,
362 dim,
363 Default::default(),
364 ReduceOperationConfig::Sum,
365 )
366 .unwrap()
367 }
368
369 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
370 reduce::reduce_dim(
371 tensor,
372 None,
373 dim,
374 Default::default(),
375 ReduceOperationConfig::Mean,
376 )
377 .unwrap()
378 }
379
380 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
381 numeric::cumsum(tensor, dim)
382 }
383
384 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
385 numeric::cumprod(tensor, dim)
386 }
387
388 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
389 numeric::cummin(tensor, dim)
390 }
391
392 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
393 numeric::cummax(tensor, dim)
394 }
395
396 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
397 reduce::reduce(
398 tensor,
399 None,
400 Default::default(),
401 ReduceOperationConfig::Prod,
402 )
403 .unwrap()
404 }
405
406 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
407 reduce::reduce_dim(
408 tensor,
409 None,
410 dim,
411 Default::default(),
412 ReduceOperationConfig::Prod,
413 )
414 .unwrap()
415 }
416
417 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
418 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)
419 }
420
421 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
422 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)
423 }
424
425 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
426 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)
427 }
428
429 fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
430 struct Powf;
431
432 #[cube]
433 impl<F: Float> FloatUnaryOp<F> for Powf {
434 type Options = InputScalar;
435
436 fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
437 Line::powf(input, Line::new(options.get::<F>()))
438 }
439 }
440
441 impl FloatUnaryOpFamily for Powf {
442 type Options = InputScalar;
443 type Unary<F: Float> = Self;
444 }
445
446 let dtype = lhs.dtype;
447 launch_unary_float::<R, Powf, _>(lhs, |_| InputScalar::new(rhs, dtype))
448 }
449
450 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
451 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)
452 }
453
454 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)
456 }
457
458 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
459 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)
460 }
461
462 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)
464 }
465
466 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
467 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tan)
468 }
469
470 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
471 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cosh)
472 }
473
474 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
475 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sinh)
476 }
477
478 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
479 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)
480 }
481
482 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
483 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCos)
484 }
485
486 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
487 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCosh)
488 }
489
490 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
491 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSin)
492 }
493
494 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
495 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSinh)
496 }
497
498 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTan)
500 }
501
502 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
503 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTanh)
504 }
505
506 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
507 crate::kernel::atan2::<R>(lhs, rhs)
508 }
509
510 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
511 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)
512 }
513
514 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
515 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)
516 }
517
518 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
519 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)
520 }
521
522 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
523 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)
524 }
525
526 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
527 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)
528 }
529
530 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
531 reduce::reduce_dim(
532 tensor,
533 Some(<Self as Backend>::IntElem::dtype()),
534 dim,
535 Default::default(),
536 ReduceOperationConfig::ArgMax,
537 )
538 .unwrap()
539 }
540
541 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
542 reduce::reduce_dim(
543 tensor,
544 Some(<Self as Backend>::IntElem::dtype()),
545 dim,
546 Default::default(),
547 ReduceOperationConfig::ArgMin,
548 )
549 .unwrap()
550 }
551
552 fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
553 kernel::cast(tensor, I::dtype())
554 }
555
556 fn float_clamp(
557 tensor: FloatTensor<Self>,
558 min: FloatElem<Self>,
559 max: FloatElem<Self>,
560 ) -> FloatTensor<Self> {
561 let dtype = tensor.dtype;
562 kernel::clamp(
563 tensor,
564 InputScalar::new(min, dtype),
565 InputScalar::new(max, dtype),
566 )
567 }
568
569 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
570 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)
571 }
572
573 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
574 kernel::repeat_dim(tensor, dim, times)
575 }
576
577 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
578 numeric::pow(lhs, rhs)
579 }
580
581 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
582 permute(tensor, axes)
583 }
584
585 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
586 expand(tensor, shape)
587 }
588
589 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
590 kernel::flip(tensor, axes, BT::dtype())
591 }
592
593 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
594 kernel::cast(tensor, dtype.into())
595 }
596
597 fn float_unfold(
598 tensor: FloatTensor<Self>,
599 dim: usize,
600 size: usize,
601 step: usize,
602 ) -> FloatTensor<Self> {
603 unfold(tensor, dim, size, step)
604 }
605
606 fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
607 kernel::is_nan(tensor, BT::dtype())
608 }
609
610 fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
611 kernel::is_inf(tensor, BT::dtype())
612 }
613
614 fn float_grid_sample_2d(
615 tensor: FloatTensor<Self>,
616 grid: FloatTensor<Self>,
617 options: GridSampleOptions,
618 ) -> FloatTensor<Self> {
619 kernel::grid_sample::grid_sample(tensor, grid, options)
620 }
621}