use super::{expand, numeric, permute, unfold};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::kernel::{
    self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
};
use crate::kernel::{into_contiguous, unary_basic::BasicFloatUnaryKind};
use crate::{CubeBackend, execute_with_dtype};
use crate::{CubeRuntime, FloatElement, IntElement};
use crate::{
    element::BoolElement,
    kernel::matmul::{MatmulStrategy, matmul},
};
use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use burn_tensor::{DType, ElementConversion, FloatDType};
use burn_tensor::{Distribution, Shape, TensorData, ops::FloatTensorOps};
use cubecl::prelude::*;
use cubecl::reduce::ReducePrecision;
use cubecl::reduce::instructions::ReduceFnConfig;
use half::{bf16, f16};
use std::ops::Range;

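// `FloatTensorOps` for the CubeCL backend. Most methods resolve the tensor's runtime
// float dtype through `execute_with_dtype!` and forward to the matching kernel, so a
// single code path serves every supported float element type.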
impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
        match data.dtype {
            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
                super::from_data::<R>(data, device)
            }
            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
        }
    }

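    // Dispatches on the requested distribution; note that the Bernoulli probability is
    // narrowed to `f32` before it reaches the PRNG kernel.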
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> FloatTensor<Self> {
        match distribution {
            Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 1.elem()),
            Distribution::Uniform(low, high) => {
                random_uniform(shape, device, low.elem::<F>(), high.elem())
            }
            Distribution::Bernoulli(prob) => random_bernoulli::<R, F>(shape, device, prob as f32),
            Distribution::Normal(mean, std) => {
                random_normal(shape, device, mean.elem::<F>(), std.elem())
            }
        }
    }

    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            super::into_data::<R, E>(tensor).await
        )
    }

    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
        super::to_device(tensor, device)
    }

    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(float(dtype), E, super::empty::<R, E>(shape, device))
    }

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            numeric::add::<R, E>(lhs, rhs)
        )
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            numeric::add_scalar::<R, E>(lhs, rhs.elem())
        )
    }

    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(float(dtype), E, numeric::zeros::<R, E>(shape, device))
    }

    fn float_full(
        shape: Shape,
        fill_value: FloatElem<Self>,
        device: &R::Device,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(
            float(dtype),
            E,
            numeric::full::<R, E>(shape, device, fill_value.elem())
        )
    }

    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(float(dtype), E, numeric::ones::<R, E>(shape, device))
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            numeric::sub::<R, E>(lhs, rhs)
        )
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            numeric::sub_scalar::<R, E>(lhs, rhs.elem())
        )
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            numeric::mul::<R, E>(lhs, rhs)
        )
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            numeric::mul_scalar::<R, E>(lhs, rhs.elem())
        )
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            numeric::div::<R, E>(lhs, rhs)
        )
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            numeric::div_scalar::<R, E>(lhs, rhs.elem())
        )
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            numeric::remainder::<R, E>(lhs, rhs)
        )
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            numeric::remainder_scalar::<R, E>(lhs, rhs.elem())
        )
    }

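    // Runs matmul with the default `MatmulStrategy`; kernel errors are treated as
    // unrecoverable here (`unwrap`).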
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
        )
    }

    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            kernel::cross::<R, E>(lhs, rhs, dim)
        )
    }

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        super::swap_dims(tensor, dim1, dim2)
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        super::reshape(tensor, shape)
    }

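    // Index-based operations dispatch twice: once on the integer dtype of `indices`
    // and once on the float dtype of the tensor being indexed.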
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            int(indices.dtype),
            I,
            execute_with_dtype!(
                float(tensor.dtype),
                E,
                kernel::gather::<R, E, I>(dim, tensor, indices)
            )
        )
    }

    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            int(indices.dtype),
            I,
            execute_with_dtype!(
                float(tensor.dtype, value.dtype),
                E,
                kernel::scatter::<R, E, I>(dim, tensor, indices, value)
            )
        )
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            int(indices.dtype),
            I,
            execute_with_dtype!(
                float(tensor.dtype),
                E,
                kernel::select::<R, E, I>(tensor, dim, indices)
            )
        )
    }

    fn float_select_assign(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            int(indices.dtype),
            I,
            execute_with_dtype!(
                float(tensor.dtype, value.dtype),
                E,
                kernel::select_assign::<R, E, I>(tensor, dim, indices, value, false)
            )
        )
    }

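    // Slicing takes a fast path when every slice has a unit step, and otherwise falls
    // back to the strided slice kernel.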
    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
        let all_steps_one = slices.iter().all(|info| info.step == 1);

        if all_steps_one {
            let simple_ranges: Vec<Range<usize>> = slices
                .iter()
                .enumerate()
                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
                .collect();

            execute_with_dtype!(
                float(tensor.dtype),
                E,
                kernel::slice::<R, E>(tensor, &simple_ranges)
            )
        } else {
            execute_with_dtype!(
                float(tensor.dtype),
                E,
                kernel::slice_with_steps::<R, E>(tensor, slices)
            )
        }
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        ranges: &[burn_tensor::Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype, value.dtype),
            E,
            kernel::slice_assign::<R, E>(tensor, ranges, value)
        )
    }

    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype, value.dtype),
            E,
            kernel::mask_where_auto::<R, E, BT>(tensor, mask, value)
        )
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatElem<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            kernel::mask_fill_auto::<R, E, BT>(tensor, mask, value.elem())
        )
    }

    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            kernel::equal::<R, E, BT>(lhs, rhs)
        )
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            kernel::equal_elem::<R, E, BT>(lhs, rhs.elem())
        )
    }

    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            kernel::greater::<R, E, BT>(lhs, rhs)
        )
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            kernel::greater_elem::<R, E, BT>(lhs, rhs.elem())
        )
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            kernel::greater_equal::<R, E, BT>(lhs, rhs)
        )
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            kernel::greater_equal_elem::<R, E, BT>(lhs, rhs.elem())
        )
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            kernel::lower::<R, E, BT>(lhs, rhs)
        )
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            kernel::lower_elem::<R, E, BT>(lhs, rhs.elem())
        )
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            kernel::lower_equal::<R, E, BT>(lhs, rhs)
        )
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype),
            E,
            kernel::lower_equal_elem::<R, E, BT>(lhs, rhs.elem())
        )
    }

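    // The full sum makes the tensor contiguous first, then launches the fallback sum
    // reduction.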
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        let tensor = into_contiguous::<R>(tensor);
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::sum_fallback::<R, E>(tensor, Default::default()).unwrap()
        )
    }

    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce::<R, E, E, E>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
        )
    }

    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Max)
                .unwrap()
        )
    }

    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce::<R, E, E, E>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
        )
    }

    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Min)
                .unwrap()
        )
    }

    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce::<R, E, E, E>(tensor, Default::default(), ReduceFnConfig::MaxAbs)
                .unwrap()
        )
    }

    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, E, E>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::MaxAbs
            )
            .unwrap()
        )
    }

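    // Sum, mean and prod reductions accumulate through `<E as ReducePrecision>::EA`,
    // the accumulation element type paired with `E`, rather than in `E` itself.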
    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, E, <E as ReducePrecision>::EA>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::Sum
            )
            .unwrap()
        )
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, E, <E as ReducePrecision>::EA>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::Mean
            )
            .unwrap()
        )
    }

    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(float(tensor.dtype), E, numeric::cumsum::<R, E>(tensor, dim))
    }

    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            numeric::cumprod::<R, E>(tensor, dim)
        )
    }

    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(float(tensor.dtype), E, numeric::cummin::<R, E>(tensor, dim))
    }

    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(float(tensor.dtype), E, numeric::cummax::<R, E>(tensor, dim))
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce::<R, E, E, <E as ReducePrecision>::EA>(
                tensor,
                Default::default(),
                ReduceFnConfig::Prod
            )
            .unwrap()
        )
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, E, <E as ReducePrecision>::EA>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::Prod
            )
            .unwrap()
        )
    }

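    // Element-wise math functions are all routed through `unary_basic::launch`, with a
    // `BasicFloatUnaryKind` selecting the operation.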
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)
    }

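    // Raises every element to a scalar power via a dedicated cube unary op whose
    // options value carries the exponent.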
    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
        struct Powf;

        #[cube]
        impl<F: Float> FloatUnaryOp<F> for Powf {
            type Options = F;

            fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
                Line::powf(input, Line::new(*options))
            }
        }

        impl FloatUnaryOpFamily for Powf {
            type Options<F: Float> = F;
            type Unary<F: Float> = Self;
        }

        execute_with_dtype!(
            float(lhs.dtype),
            F,
            launch_unary_float::<R, F, Powf, _>(lhs, |_| ScalarArg::new(rhs.elem::<F>()))
        )
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)
    }

    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)
    }

    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)
    }

    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)
    }

    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)
    }

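    // Arg-reductions return indices, so the output element type is the backend's
    // integer element `I`.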
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, I, E>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::ArgMax
            )
            .unwrap()
        )
    }

    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            reduce::reduce_dim::<R, E, I, E>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::ArgMin
            )
            .unwrap()
        )
    }

    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(float(tensor.dtype), E, kernel::cast::<R, E, I>(tensor))
    }

    fn float_clamp(
        tensor: FloatTensor<Self>,
        min: FloatElem<Self>,
        max: FloatElem<Self>,
    ) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            kernel::clamp::<R, E>(tensor, min.elem(), max.elem())
        )
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)
    }

    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            kernel::repeat_dim::<R, E>(tensor, dim, times)
        )
    }

    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(lhs.dtype, rhs.dtype),
            E,
            numeric::pow::<R, E>(lhs, rhs)
        )
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        permute(tensor, axes)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        expand(tensor, shape)
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_dtype!(
            float(tensor.dtype),
            E,
            kernel::flip::<R, E, BT>(tensor, axes)
        )
    }

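    // Casting to the same dtype is a no-op; switching between `F32` and `Flex32` only
    // retags the dtype, and every other combination launches a cast kernel.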
    fn float_cast(mut tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        match (tensor.dtype, dtype) {
            (DType::F64, FloatDType::F64)
            | (DType::F32, FloatDType::F32)
            | (DType::Flex32, FloatDType::Flex32)
            | (DType::BF16, FloatDType::BF16)
            | (DType::F16, FloatDType::F16) => tensor,
            (DType::F32, FloatDType::Flex32) | (DType::Flex32, FloatDType::F32) => {
                tensor.dtype = dtype.into();
                tensor
            }
            (DType::F64, FloatDType::F32) => kernel::cast::<R, f64, f32>(tensor),
            (DType::F64, FloatDType::Flex32) => kernel::cast::<R, f64, flex32>(tensor),
            (DType::F64, FloatDType::F16) => kernel::cast::<R, f64, f16>(tensor),
            (DType::F64, FloatDType::BF16) => kernel::cast::<R, f64, bf16>(tensor),
            (DType::F32, FloatDType::F64) => kernel::cast::<R, f32, f64>(tensor),
            (DType::F32, FloatDType::F16) => kernel::cast::<R, f32, f16>(tensor),
            (DType::F32, FloatDType::BF16) => kernel::cast::<R, f32, bf16>(tensor),
            (DType::Flex32, FloatDType::F64) => kernel::cast::<R, flex32, f64>(tensor),
            (DType::Flex32, FloatDType::F16) => kernel::cast::<R, flex32, f16>(tensor),
            (DType::Flex32, FloatDType::BF16) => kernel::cast::<R, flex32, bf16>(tensor),
            (DType::F16, FloatDType::F64) => kernel::cast::<R, f16, f64>(tensor),
            (DType::F16, FloatDType::F32) => kernel::cast::<R, f16, f32>(tensor),
            (DType::F16, FloatDType::Flex32) => kernel::cast::<R, f16, flex32>(tensor),
            (DType::F16, FloatDType::BF16) => kernel::cast::<R, f16, bf16>(tensor),
            (DType::BF16, FloatDType::F64) => kernel::cast::<R, bf16, f64>(tensor),
            (DType::BF16, FloatDType::F32) => kernel::cast::<R, bf16, f32>(tensor),
            (DType::BF16, FloatDType::Flex32) => kernel::cast::<R, bf16, flex32>(tensor),
            (DType::BF16, FloatDType::F16) => kernel::cast::<R, bf16, f16>(tensor),
            _ => unimplemented!("Unsupported floating point type cast"),
        }
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(float(tensor.dtype), E, kernel::is_nan::<R, E, BT>(tensor))
    }

    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(float(tensor.dtype), E, kernel::is_inf::<R, E, BT>(tensor))
    }
}