// burn_cubecl/ops/float_ops.rs

use super::{expand, numeric, permute};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::kernel::unary_basic::BasicFloatUnaryKind;
use crate::kernel::{
    self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
};
use crate::{CubeBackend, execute_with_dtype};
use crate::{CubeRuntime, FloatElement, IntElement};
use crate::{
    element::BoolElement,
    kernel::matmul::{MatmulStrategy, matmul},
};
use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use burn_tensor::{DType, ElementConversion, FloatDType};
use burn_tensor::{Distribution, Shape, TensorData, ops::FloatTensorOps};
use cubecl::prelude::*;
use cubecl::reduce::instructions::ReduceFnConfig;
use half::{bf16, f16};
use std::ops::Range;

21impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
22where
23    R: CubeRuntime,
24    F: FloatElement,
25    I: IntElement,
26    BT: BoolElement,
27{
28    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
29        match data.dtype {
30            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
31                super::from_data::<R>(data, device)
32            }
33            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
34        }
35    }
36
37    fn float_random(
38        shape: Shape,
39        distribution: Distribution,
40        device: &Device<Self>,
41    ) -> FloatTensor<Self> {
42        match distribution {
43            Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 1.elem()),
44            Distribution::Uniform(low, high) => {
45                random_uniform(shape, device, low.elem::<F>(), high.elem())
46            }
47            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
48            Distribution::Normal(mean, std) => {
49                random_normal(shape, device, mean.elem::<F>(), std.elem())
50            }
51        }
52    }
53
54    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
55        execute_with_dtype!(
56            float(tensor.dtype),
57            E,
58            super::into_data::<R, E>(tensor).await
59        )
60    }
61
62    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
63        tensor.device.clone()
64    }
65
66    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
67        super::to_device(tensor, device)
68    }
69
70    fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
71        super::empty::<R, F>(shape, device)
72    }
73
74    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
75        execute_with_dtype!(
76            float(lhs.dtype, rhs.dtype),
77            E,
78            numeric::add::<R, E>(lhs, rhs)
79        )
80    }
81
82    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
83        execute_with_dtype!(
84            float(lhs.dtype),
85            E,
86            numeric::add_scalar::<R, E>(lhs, rhs.elem())
87        )
88    }
89
90    fn float_zeros(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
91        numeric::zeros::<R, F>(shape, device)
92    }
93
94    fn float_full(
95        shape: Shape,
96        fill_value: FloatElem<Self>,
97        device: &R::Device,
98    ) -> FloatTensor<Self> {
99        numeric::full::<R, F>(shape, device, fill_value)
100    }
101
102    fn float_ones(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
103        numeric::ones::<R, F>(shape, device)
104    }
105
106    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
107        execute_with_dtype!(
108            float(lhs.dtype, rhs.dtype),
109            E,
110            numeric::sub::<R, E>(lhs, rhs)
111        )
112    }
113
114    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
115        execute_with_dtype!(
116            float(lhs.dtype),
117            E,
118            numeric::sub_scalar::<R, E>(lhs, rhs.elem())
119        )
120    }
121
122    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
123        execute_with_dtype!(
124            float(lhs.dtype, rhs.dtype),
125            E,
126            numeric::mul::<R, E>(lhs, rhs)
127        )
128    }
129
130    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
131        execute_with_dtype!(
132            float(lhs.dtype),
133            E,
134            numeric::mul_scalar::<R, E>(lhs, rhs.elem())
135        )
136    }
137
138    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139        execute_with_dtype!(
140            float(lhs.dtype, rhs.dtype),
141            E,
142            numeric::div::<R, E>(lhs, rhs)
143        )
144    }
145
146    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
147        execute_with_dtype!(
148            float(lhs.dtype),
149            E,
150            numeric::div_scalar::<R, E>(lhs, rhs.elem())
151        )
152    }
153
154    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
155        execute_with_dtype!(
156            float(lhs.dtype, rhs.dtype),
157            E,
158            numeric::remainder::<R, E>(lhs, rhs)
159        )
160    }
161
162    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
163        execute_with_dtype!(
164            float(lhs.dtype),
165            E,
166            numeric::remainder_scalar::<R, E>(lhs, rhs.elem())
167        )
168    }
169
170    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
171        execute_with_dtype!(
172            float(lhs.dtype, rhs.dtype),
173            E,
174            matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
175        )
176    }
177
178    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
179        super::swap_dims(tensor, dim1, dim2)
180    }
181
182    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
183        super::reshape(tensor, shape)
184    }
185
186    fn float_gather(
187        dim: usize,
188        tensor: FloatTensor<Self>,
189        indices: IntTensor<Self>,
190    ) -> FloatTensor<Self> {
191        execute_with_dtype!(
192            float(tensor.dtype),
193            E,
194            kernel::gather::<R, E, I>(dim, tensor, indices)
195        )
196    }
197
198    fn float_scatter(
199        dim: usize,
200        tensor: FloatTensor<Self>,
201        indices: IntTensor<Self>,
202        value: FloatTensor<Self>,
203    ) -> FloatTensor<Self> {
204        execute_with_dtype!(
205            float(tensor.dtype, value.dtype),
206            E,
207            kernel::scatter::<R, E, I>(dim, tensor, indices, value)
208        )
209    }
210
211    fn float_select(
212        tensor: FloatTensor<Self>,
213        dim: usize,
214        indices: IntTensor<Self>,
215    ) -> FloatTensor<Self> {
216        execute_with_dtype!(
217            float(tensor.dtype),
218            E,
219            kernel::select::<R, E, I>(tensor, dim, indices)
220        )
221    }
222
223    fn float_select_assign(
224        tensor: FloatTensor<Self>,
225        dim: usize,
226        indices: IntTensor<Self>,
227        value: FloatTensor<Self>,
228    ) -> FloatTensor<Self> {
229        execute_with_dtype!(
230            float(tensor.dtype, value.dtype),
231            E,
232            kernel::select_assign::<R, E, I>(tensor, dim, indices, value)
233        )
234    }
235
236    fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
237        execute_with_dtype!(
238            float(tensor.dtype),
239            E,
240            kernel::slice::<R, E>(tensor, ranges)
241        )
242    }
243
244    fn float_slice_assign(
245        tensor: FloatTensor<Self>,
246        ranges: &[Range<usize>],
247        value: FloatTensor<Self>,
248    ) -> FloatTensor<Self> {
249        execute_with_dtype!(
250            float(tensor.dtype, value.dtype),
251            E,
252            kernel::slice_assign::<R, E>(tensor, ranges, value)
253        )
254    }
255
256    fn float_mask_where(
257        tensor: FloatTensor<Self>,
258        mask: BoolTensor<Self>,
259        value: FloatTensor<Self>,
260    ) -> FloatTensor<Self> {
261        execute_with_dtype!(
262            float(tensor.dtype, value.dtype),
263            E,
264            kernel::mask_where_auto::<R, E, BT>(tensor, mask, value)
265        )
266    }
267
268    fn float_mask_fill(
269        tensor: FloatTensor<Self>,
270        mask: BoolTensor<Self>,
271        value: FloatElem<Self>,
272    ) -> FloatTensor<Self> {
273        execute_with_dtype!(
274            float(tensor.dtype),
275            E,
276            kernel::mask_fill_auto::<R, E, BT>(tensor, mask, value.elem())
277        )
278    }
279
280    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
281        execute_with_dtype!(
282            float(lhs.dtype, rhs.dtype),
283            E,
284            kernel::equal::<R, E, BT>(lhs, rhs)
285        )
286    }
287
288    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
289        execute_with_dtype!(
290            float(lhs.dtype),
291            E,
292            kernel::equal_elem::<R, E, BT>(lhs, rhs.elem())
293        )
294    }
295
296    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
297        execute_with_dtype!(
298            float(lhs.dtype, rhs.dtype),
299            E,
300            kernel::greater::<R, E, BT>(lhs, rhs)
301        )
302    }
303
304    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
305        execute_with_dtype!(
306            float(lhs.dtype),
307            E,
308            kernel::greater_elem::<R, E, BT>(lhs, rhs.elem())
309        )
310    }
311
312    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
313        execute_with_dtype!(
314            float(lhs.dtype, rhs.dtype),
315            E,
316            kernel::greater_equal::<R, E, BT>(lhs, rhs)
317        )
318    }
319
320    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
321        execute_with_dtype!(
322            float(lhs.dtype),
323            E,
324            kernel::greater_equal_elem::<R, E, BT>(lhs, rhs.elem())
325        )
326    }
327
328    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
329        execute_with_dtype!(
330            float(lhs.dtype, rhs.dtype),
331            E,
332            kernel::lower::<R, E, BT>(lhs, rhs)
333        )
334    }
335
336    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
337        execute_with_dtype!(
338            float(lhs.dtype),
339            E,
340            kernel::lower_elem::<R, E, BT>(lhs, rhs.elem())
341        )
342    }
343
344    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
345        execute_with_dtype!(
346            float(lhs.dtype, rhs.dtype),
347            E,
348            kernel::lower_equal::<R, E, BT>(lhs, rhs)
349        )
350    }
351
352    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
353        execute_with_dtype!(
354            float(lhs.dtype),
355            E,
356            kernel::lower_equal_elem::<R, E, BT>(lhs, rhs.elem())
357        )
358    }
359
360    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
361        execute_with_dtype!(
362            float(tensor.dtype),
363            E,
364            reduce::sum::<R, E>(tensor, Default::default()).unwrap()
365        )
366    }
367
368    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
369        execute_with_dtype!(
370            float(tensor.dtype),
371            E,
372            reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
373        )
374    }
375
376    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
377        execute_with_dtype!(
378            float(tensor.dtype),
379            E,
380            reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Max)
381                .unwrap()
382        )
383    }
384
385    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
386        execute_with_dtype!(
387            float(tensor.dtype),
388            E,
389            reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
390        )
391    }
392
393    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
394        execute_with_dtype!(
395            float(tensor.dtype),
396            E,
397            reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Min)
398                .unwrap()
399        )
400    }
401
402    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
403        execute_with_dtype!(
404            float(tensor.dtype),
405            E,
406            reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::MaxAbs).unwrap()
407        )
408    }
409
410    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
411        execute_with_dtype!(
412            float(tensor.dtype),
413            E,
414            reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::MaxAbs)
415                .unwrap()
416        )
417    }
418
419    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
420        execute_with_dtype!(
421            float(tensor.dtype),
422            E,
423            reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Sum)
424                .unwrap()
425        )
426    }
427
428    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
429        execute_with_dtype!(
430            float(tensor.dtype),
431            E,
432            reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Mean)
433                .unwrap()
434        )
435    }
436
437    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
438        execute_with_dtype!(
439            float(tensor.dtype),
440            E,
441            reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::Prod).unwrap()
442        )
443    }
444
445    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
446        execute_with_dtype!(
447            float(tensor.dtype),
448            E,
449            reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Prod)
450                .unwrap()
451        )
452    }
453
454    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Exp)
456    }
457
458    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
459        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log)
460    }
461
462    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log1p)
464    }
465
466    fn float_powf_scalar(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
467        struct Powf;
468
469        #[cube]
470        impl<F: Float> FloatUnaryOp<F> for Powf {
471            type Options = F;
472
473            fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
474                Line::powf(input, Line::new(*options))
475            }
476        }
477
478        impl FloatUnaryOpFamily for Powf {
479            type Options<F: Float> = F;
480            type Unary<F: Float> = Self;
481        }
482
483        execute_with_dtype!(
484            float(lhs.dtype),
485            F,
486            launch_unary_float::<R, F, Powf, _>(lhs, |_| ScalarArg::new(rhs.elem::<F>()))
487        )
488    }
489
490    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
491        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sqrt)
492    }
493
494    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
495        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Abs)
496    }
497
498    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Cos)
500    }
501
502    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
503        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sin)
504    }
505
506    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
507        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Tanh)
508    }
509
510    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
511        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Round)
512    }
513
514    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
515        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Floor)
516    }
517
518    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
519        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Ceil)
520    }
521
522    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
523        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Erf)
524    }
525
526    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
527        execute_with_dtype!(
528            float(tensor.dtype),
529            E,
530            reduce::reduce_dim::<R, E, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMax)
531                .unwrap()
532        )
533    }
534
535    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
536        execute_with_dtype!(
537            float(tensor.dtype),
538            E,
539            reduce::reduce_dim::<R, E, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMin)
540                .unwrap()
541        )
542    }
543
544    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
545        execute_with_dtype!(float(tensor.dtype), E, kernel::cast::<R, E, I>(tensor))
546    }
547
548    fn float_clamp(
549        tensor: FloatTensor<Self>,
550        min: FloatElem<Self>,
551        max: FloatElem<Self>,
552    ) -> FloatTensor<Self> {
553        execute_with_dtype!(
554            float(tensor.dtype),
555            E,
556            kernel::clamp::<R, E>(tensor, min.elem(), max.elem())
557        )
558    }
559
560    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
561        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Recip)
562    }
563
564    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
565        execute_with_dtype!(
566            float(tensor.dtype),
567            E,
568            kernel::repeat_dim::<R, E>(tensor, dim, times)
569        )
570    }
571
572    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
573        execute_with_dtype!(float(lhs.dtype), E, numeric::pow::<R, E>(lhs, rhs))
574    }
575
576    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
577        permute(tensor, axes)
578    }
579
580    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
581        expand(tensor, shape)
582    }
583
584    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
585        execute_with_dtype!(
586            float(tensor.dtype),
587            E,
588            kernel::flip::<R, E, BT>(tensor, axes)
589        )
590    }
591
592    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
593        match (tensor.dtype, dtype) {
594            (DType::F64, FloatDType::F64)
595            | (DType::F32, FloatDType::F32)
596            | (DType::BF16, FloatDType::BF16)
597            | (DType::F16, FloatDType::F16) => tensor,
598            (DType::F64, FloatDType::F32) => kernel::cast::<R, f64, f32>(tensor),
599            (DType::F64, FloatDType::F16) => kernel::cast::<R, f64, f16>(tensor),
600            (DType::F64, FloatDType::BF16) => kernel::cast::<R, f64, bf16>(tensor),
601            (DType::F32, FloatDType::F64) => kernel::cast::<R, f32, f64>(tensor),
602            (DType::F32, FloatDType::F16) => kernel::cast::<R, f32, f16>(tensor),
603            (DType::F32, FloatDType::BF16) => kernel::cast::<R, f32, bf16>(tensor),
604            (DType::F16, FloatDType::F64) => kernel::cast::<R, f16, f64>(tensor),
605            (DType::F16, FloatDType::F32) => kernel::cast::<R, f16, f32>(tensor),
606            (DType::F16, FloatDType::BF16) => kernel::cast::<R, f16, bf16>(tensor),
607            (DType::BF16, FloatDType::F64) => kernel::cast::<R, bf16, f64>(tensor),
608            (DType::BF16, FloatDType::F32) => kernel::cast::<R, bf16, f32>(tensor),
609            (DType::BF16, FloatDType::F16) => kernel::cast::<R, bf16, f16>(tensor),
610            _ => unimplemented!("Unsupported floating point type cast"),
611        }
612    }
613}