1use super::{expand, numeric, permute, unfold};
2use crate::CubeBackend;
3use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
4use crate::kernel::unary_basic::BasicFloatUnaryKind;
5use crate::kernel::{
6 self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
7};
8use crate::{CubeRuntime, FloatElement, IntElement};
9use crate::{
10 element::BoolElement,
11 kernel::matmul::{MatmulStrategy, matmul},
12};
13use burn_backend::ops::GridSampleOptions;
14use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
15use burn_backend::{DType, ElementConversion, FloatDType, Slice};
16use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};
17use burn_backend::{ExecutionError, Scalar, get_device_settings};
18use burn_std::{BoolDType, IntDType};
19use cubecl::prelude::*;
20use cubek::reduce::components::instructions::ReduceOperationConfig;
21use std::ops::Range;
22
23impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
24where
25 R: CubeRuntime,
26 F: FloatElement,
27 I: IntElement,
28 BT: BoolElement,
29{
30 #[cfg_attr(feature = "tracing", tracing::instrument(
31 level="trace",
32 skip(data),
33 fields(?data.shape, ?data.dtype)
34 ))]
35 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
36 match data.dtype {
37 DType::F64 | DType::F32 | DType::F16 | DType::BF16 => super::from_data(data, device),
38 _ => unimplemented!("Unsupported dtype for `float_from_data`"),
39 }
40 }
41
42 fn float_random(
43 shape: Shape,
44 distribution: Distribution,
45 device: &Device<Self>,
46 dtype: FloatDType,
47 ) -> FloatTensor<Self> {
48 let dtype = dtype.into();
49 match distribution {
50 Distribution::Default => random_uniform(shape, device, 0., 1., dtype),
51 Distribution::Uniform(low, high) => {
52 random_uniform(shape, device, low.elem(), high.elem(), dtype)
53 }
54 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
55 Distribution::Normal(mean, std) => {
56 random_normal(shape, device, mean.elem(), std.elem(), dtype)
57 }
58 }
59 }
60
61 #[cfg_attr(feature = "tracing", tracing::instrument(
62 level="trace",
63 skip(tensor),
64 fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype)
65 ))]
66 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
67 super::into_data(tensor).await
68 }
69
70 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
71 tensor.device.clone()
72 }
73
74 #[cfg_attr(feature = "tracing", tracing::instrument(
75 level="trace",
76 skip(tensor),
77 fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype)
78 ))]
79 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
80 super::to_device(tensor, device)
81 }
82
83 fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
84 let dtype = dtype.into();
85 super::empty(shape, device, dtype)
86 }
87
88 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
89 numeric::add(lhs, rhs)
90 }
91
92 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
93 let dtype = lhs.dtype;
94 numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
95 }
96
97 fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
98 let dtype = dtype.into();
99 numeric::zeros(device.clone(), shape, dtype)
100 }
101
102 fn float_full(
103 shape: Shape,
104 fill_value: Scalar,
105 device: &R::Device,
106 dtype: FloatDType,
107 ) -> FloatTensor<Self> {
108 let dtype: DType = dtype.into();
109 let client = R::client(device);
110 numeric::full_device_dtype(
111 client,
112 shape,
113 device.clone(),
114 InputScalar::new(fill_value, dtype),
115 dtype,
116 )
117 }
118
119 fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
120 let dtype = dtype.into();
121 numeric::ones(device.clone(), shape, dtype)
122 }
123
124 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
125 numeric::sub(lhs, rhs)
126 }
127
128 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
129 let dtype = lhs.dtype;
130 numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
131 }
132
133 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
134 numeric::mul(lhs, rhs)
135 }
136
137 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
138 let dtype = lhs.dtype;
139 numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
140 }
141
142 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
143 numeric::div(lhs, rhs)
144 }
145
146 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
147 let dtype = lhs.dtype;
148 numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
149 }
150
151 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
152 numeric::remainder(lhs, rhs)
153 }
154
155 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
156 let dtype = lhs.dtype;
157 numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
158 }
159
160 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
161 let dtype = lhs.dtype;
162 matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
163 }
164
165 fn float_cross(
166 lhs: FloatTensor<Self>,
167 rhs: FloatTensor<Self>,
168 dim: usize,
169 ) -> FloatTensor<Self> {
170 kernel::cross(lhs, rhs, dim)
171 }
172
173 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
174 super::swap_dims(tensor, dim1, dim2)
175 }
176
177 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
178 super::reshape(tensor, shape)
179 }
180
181 fn float_gather(
182 dim: usize,
183 tensor: FloatTensor<Self>,
184 indices: IntTensor<Self>,
185 ) -> FloatTensor<Self> {
186 kernel::gather(dim, tensor, indices)
187 }
188
189 fn float_scatter_add(
190 dim: usize,
191 tensor: FloatTensor<Self>,
192 indices: IntTensor<Self>,
193 value: FloatTensor<Self>,
194 ) -> FloatTensor<Self> {
195 kernel::scatter(dim, tensor, indices, value, false)
196 }
197
198 fn float_scatter_nd(
199 data: FloatTensor<Self>,
200 indices: IntTensor<Self>,
201 values: FloatTensor<Self>,
202 reduction: burn_backend::tensor::IndexingUpdateOp,
203 ) -> FloatTensor<Self> {
204 kernel::scatter_nd(data, indices, values, reduction)
205 }
206
207 fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
208 kernel::gather_nd(data, indices)
209 }
210
211 fn float_select(
212 tensor: FloatTensor<Self>,
213 dim: usize,
214 indices: IntTensor<Self>,
215 ) -> FloatTensor<Self> {
216 kernel::select(tensor, dim, indices)
217 }
218
219 fn float_select_add(
220 tensor: FloatTensor<Self>,
221 dim: usize,
222 indices: IntTensor<Self>,
223 value: FloatTensor<Self>,
224 ) -> FloatTensor<Self> {
225 kernel::select_assign(tensor, dim, indices, value, false)
226 }
227
228 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
229 let all_steps_one = slices.iter().all(|info| info.step == 1);
231
232 if all_steps_one {
233 let simple_ranges: Vec<Range<usize>> = slices
235 .iter()
236 .enumerate()
237 .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
238 .collect();
239
240 kernel::slice(tensor, &simple_ranges)
241 } else {
242 kernel::slice_with_steps(tensor, slices)
244 }
245 }
246
247 fn float_slice_assign(
248 tensor: FloatTensor<Self>,
249 ranges: &[Slice],
250 value: FloatTensor<Self>,
251 ) -> FloatTensor<Self> {
252 kernel::slice_assign(tensor, ranges, value)
253 }
254
255 fn float_mask_where(
256 tensor: FloatTensor<Self>,
257 mask: BoolTensor<Self>,
258 value: FloatTensor<Self>,
259 ) -> FloatTensor<Self> {
260 let bool_dtype = mask.dtype;
261 kernel::mask_where_auto(tensor, mask, value, bool_dtype)
262 }
263
264 fn float_mask_fill(
265 tensor: FloatTensor<Self>,
266 mask: BoolTensor<Self>,
267 value: Scalar,
268 ) -> FloatTensor<Self> {
269 let dtype = tensor.dtype;
270 let bool_dtype = mask.dtype;
271 kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), bool_dtype)
272 }
273
274 fn float_equal(
275 lhs: FloatTensor<Self>,
276 rhs: FloatTensor<Self>,
277 out_dtype: BoolDType,
278 ) -> BoolTensor<Self> {
279 kernel::equal(lhs, rhs, out_dtype.into())
280 }
281
282 fn float_equal_elem(
283 lhs: FloatTensor<Self>,
284 rhs: Scalar,
285 out_dtype: BoolDType,
286 ) -> BoolTensor<Self> {
287 let dtype = lhs.dtype;
288 kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
289 }
290
291 fn float_greater(
292 lhs: FloatTensor<Self>,
293 rhs: FloatTensor<Self>,
294 out_dtype: BoolDType,
295 ) -> BoolTensor<Self> {
296 kernel::greater(lhs, rhs, out_dtype.into())
297 }
298
299 fn float_greater_elem(
300 lhs: FloatTensor<Self>,
301 rhs: Scalar,
302 out_dtype: BoolDType,
303 ) -> BoolTensor<Self> {
304 let dtype = lhs.dtype;
305 kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
306 }
307
308 fn float_greater_equal(
309 lhs: FloatTensor<Self>,
310 rhs: FloatTensor<Self>,
311 out_dtype: BoolDType,
312 ) -> BoolTensor<Self> {
313 kernel::greater_equal(lhs, rhs, out_dtype.into())
314 }
315
316 fn float_greater_equal_elem(
317 lhs: FloatTensor<Self>,
318 rhs: Scalar,
319 out_dtype: BoolDType,
320 ) -> BoolTensor<Self> {
321 let dtype = lhs.dtype;
322 kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
323 }
324
325 fn float_lower(
326 lhs: FloatTensor<Self>,
327 rhs: FloatTensor<Self>,
328 out_dtype: BoolDType,
329 ) -> BoolTensor<Self> {
330 kernel::lower(lhs, rhs, out_dtype.into())
331 }
332
333 fn float_lower_elem(
334 lhs: FloatTensor<Self>,
335 rhs: Scalar,
336 out_dtype: BoolDType,
337 ) -> BoolTensor<Self> {
338 let dtype = lhs.dtype;
339 kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
340 }
341
342 fn float_lower_equal(
343 lhs: FloatTensor<Self>,
344 rhs: FloatTensor<Self>,
345 out_dtype: BoolDType,
346 ) -> BoolTensor<Self> {
347 kernel::lower_equal(lhs, rhs, out_dtype.into())
348 }
349
350 fn float_lower_equal_elem(
351 lhs: FloatTensor<Self>,
352 rhs: Scalar,
353 out_dtype: BoolDType,
354 ) -> BoolTensor<Self> {
355 let dtype = lhs.dtype;
356 kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
357 }
358
359 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
360 reduce::sum_fallback(tensor, Default::default()).unwrap()
361 }
362
363 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
364 reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
365 }
366
367 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
368 reduce::reduce_dim(
369 tensor,
370 None,
371 dim,
372 Default::default(),
373 ReduceOperationConfig::Max,
374 )
375 .unwrap()
376 }
377
378 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
379 reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
380 }
381
382 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
383 reduce::reduce_dim(
384 tensor,
385 None,
386 dim,
387 Default::default(),
388 ReduceOperationConfig::Min,
389 )
390 .unwrap()
391 }
392
393 fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
394 reduce::reduce(
395 tensor,
396 None,
397 Default::default(),
398 ReduceOperationConfig::MaxAbs,
399 )
400 .unwrap()
401 }
402
403 fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
404 reduce::reduce_dim(
405 tensor,
406 None,
407 dim,
408 Default::default(),
409 ReduceOperationConfig::MaxAbs,
410 )
411 .unwrap()
412 }
413
414 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
415 reduce::reduce_dim(
416 tensor,
417 None,
418 dim,
419 Default::default(),
420 ReduceOperationConfig::Sum,
421 )
422 .unwrap()
423 }
424
425 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
426 reduce::reduce_dim(
427 tensor,
428 None,
429 dim,
430 Default::default(),
431 ReduceOperationConfig::Mean,
432 )
433 .unwrap()
434 }
435
436 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
437 reduce::reduce(
438 tensor,
439 None,
440 Default::default(),
441 ReduceOperationConfig::Mean,
442 )
443 .unwrap()
444 }
445
446 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
447 numeric::cumsum(tensor, dim)
448 }
449
450 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
451 numeric::cumprod(tensor, dim)
452 }
453
454 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
455 numeric::cummin(tensor, dim)
456 }
457
458 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
459 numeric::cummax(tensor, dim)
460 }
461
462 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463 reduce::reduce(
464 tensor,
465 None,
466 Default::default(),
467 ReduceOperationConfig::Prod,
468 )
469 .unwrap()
470 }
471
472 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
473 reduce::reduce_dim(
474 tensor,
475 None,
476 dim,
477 Default::default(),
478 ReduceOperationConfig::Prod,
479 )
480 .unwrap()
481 }
482
483 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
484 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)
485 }
486
487 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
488 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)
489 }
490
491 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
492 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)
493 }
494
495 fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
496 struct Powf;
497
498 #[cube]
499 impl<F: Float, N: Size> FloatUnaryOp<F, N> for Powf {
500 type Options = InputScalar;
501
502 fn execute(input: Vector<F, N>, options: &Self::Options) -> Vector<F, N> {
503 Vector::powf(input, Vector::new(options.get::<F>()))
504 }
505 }
506
507 impl FloatUnaryOpFamily for Powf {
508 type Options = InputScalar;
509 type Unary<F: Float, N: Size> = Self;
510 }
511
512 let dtype = lhs.dtype;
513 launch_unary_float::<R, Powf, _>(lhs, |_| InputScalar::new(rhs, dtype))
514 }
515
516 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
517 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)
518 }
519
520 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
521 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)
522 }
523
524 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
525 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sign)
526 }
527
528 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
529 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)
530 }
531
532 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
533 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)
534 }
535
536 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
537 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tan)
538 }
539
540 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
541 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cosh)
542 }
543
544 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
545 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sinh)
546 }
547
548 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
549 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)
550 }
551
552 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
553 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCos)
554 }
555
556 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
557 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCosh)
558 }
559
560 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
561 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSin)
562 }
563
564 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
565 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSinh)
566 }
567
568 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
569 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTan)
570 }
571
572 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
573 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTanh)
574 }
575
576 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
577 crate::kernel::atan2::<R>(lhs, rhs)
578 }
579
580 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
581 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)
582 }
583
584 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
585 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)
586 }
587
588 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
589 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)
590 }
591
592 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
593 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)
594 }
595
596 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
597 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)
598 }
599
600 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
601 reduce::reduce_dim(
602 tensor,
603 Some(out_dtype.into()),
604 dim,
605 Default::default(),
606 ReduceOperationConfig::ArgMax,
607 )
608 .unwrap()
609 }
610
611 fn float_argtopk(
612 tensor: FloatTensor<Self>,
613 dim: usize,
614 k: usize,
615 out_dtype: IntDType,
616 ) -> IntTensor<Self> {
617 reduce::reduce_dim(
618 tensor,
619 Some(out_dtype.into()),
620 dim,
621 Default::default(),
622 ReduceOperationConfig::ArgTopK(k),
623 )
624 .unwrap()
625 }
626
627 fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
628 reduce::reduce_dim(
629 tensor,
630 None,
631 dim,
632 Default::default(),
633 ReduceOperationConfig::TopK(k),
634 )
635 .unwrap()
636 }
637
638 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
639 reduce::reduce_dim(
640 tensor,
641 Some(out_dtype.into()),
642 dim,
643 Default::default(),
644 ReduceOperationConfig::ArgMin,
645 )
646 .unwrap()
647 }
648
649 fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
650 kernel::cast(tensor, out_dtype.into())
651 }
652
653 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
654 let dtype = tensor.dtype;
655 kernel::clamp(
656 tensor,
657 InputScalar::new(min, dtype),
658 InputScalar::new(max, dtype),
659 )
660 }
661
662 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
663 unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)
664 }
665
666 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
667 kernel::repeat_dim(tensor, dim, times)
668 }
669
670 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
671 numeric::pow(lhs, rhs)
672 }
673
674 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
675 permute(tensor, axes)
676 }
677
678 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
679 expand(tensor, shape)
680 }
681
682 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
683 let bool_dtype = get_device_settings::<Self>(&tensor.device).bool_dtype;
684 kernel::flip(tensor, axes, bool_dtype.into())
685 }
686
687 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
688 kernel::cast(tensor, dtype.into())
689 }
690
691 fn float_unfold(
692 tensor: FloatTensor<Self>,
693 dim: usize,
694 size: usize,
695 step: usize,
696 ) -> FloatTensor<Self> {
697 unfold(tensor, dim, size, step)
698 }
699
700 fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
701 kernel::is_nan(tensor, out_dtype.into())
702 }
703
704 fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
705 kernel::is_inf(tensor, out_dtype.into())
706 }
707
708 fn float_grid_sample_2d(
709 tensor: FloatTensor<Self>,
710 grid: FloatTensor<Self>,
711 options: GridSampleOptions,
712 ) -> FloatTensor<Self> {
713 kernel::grid_sample::grid_sample(tensor, grid, options)
714 }
715}