burn_jit/ops/
float_ops.rs

1use super::{expand, numeric, permute};
2use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
3use crate::kernel::unary_basic::BasicFloatUnaryKind;
4use crate::kernel::{
5    self, launch_unary_float, reduce, unary_basic, FloatUnaryOp, FloatUnaryOpFamily,
6};
7use crate::{
8    element::BoolElement,
9    kernel::matmul::{matmul, MatmulStrategy},
10};
11use crate::{execute_with_dtype, JitBackend};
12use crate::{FloatElement, IntElement, JitRuntime};
13use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
14use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData};
15use burn_tensor::{DType, ElementConversion, FloatDType};
16use cubecl::prelude::*;
17use half::{bf16, f16};
18use std::ops::Range;
19
20impl<R, F, I, BT> FloatTensorOps<Self> for JitBackend<R, F, I, BT>
21where
22    R: JitRuntime,
23    F: FloatElement,
24    I: IntElement,
25    BT: BoolElement,
26{
27    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
28        super::from_data::<R, F>(data, device)
29    }
30
31    fn float_random(
32        shape: Shape,
33        distribution: Distribution,
34        device: &Device<Self>,
35    ) -> FloatTensor<Self> {
36        match distribution {
37            Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 1.elem()),
38            Distribution::Uniform(low, high) => {
39                random_uniform(shape, device, low.elem::<F>(), high.elem())
40            }
41            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
42            Distribution::Normal(mean, std) => {
43                random_normal(shape, device, mean.elem::<F>(), std.elem())
44            }
45        }
46    }
47
48    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
49        execute_with_dtype!(
50            float(tensor.dtype),
51            E,
52            super::into_data::<R, E>(tensor).await
53        )
54    }
55
56    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
57        tensor.device.clone()
58    }
59
60    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
61        super::to_device(tensor, device)
62    }
63
64    fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
65        super::empty::<R, F>(shape, device)
66    }
67
68    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
69        execute_with_dtype!(
70            float(lhs.dtype, rhs.dtype),
71            E,
72            numeric::add::<R, E>(lhs, rhs)
73        )
74    }
75
76    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
77        execute_with_dtype!(
78            float(lhs.dtype),
79            E,
80            numeric::add_scalar::<R, E>(lhs, rhs.elem())
81        )
82    }
83
84    fn float_zeros(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
85        numeric::zeros::<R, F>(shape, device)
86    }
87
88    fn float_full(
89        shape: Shape,
90        fill_value: FloatElem<Self>,
91        device: &R::Device,
92    ) -> FloatTensor<Self> {
93        numeric::full::<R, F>(shape, device, fill_value)
94    }
95
96    fn float_ones(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
97        numeric::ones::<R, F>(shape, device)
98    }
99
100    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
101        execute_with_dtype!(
102            float(lhs.dtype, rhs.dtype),
103            E,
104            numeric::sub::<R, E>(lhs, rhs)
105        )
106    }
107
108    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
109        execute_with_dtype!(
110            float(lhs.dtype),
111            E,
112            numeric::sub_scalar::<R, E>(lhs, rhs.elem())
113        )
114    }
115
116    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
117        execute_with_dtype!(
118            float(lhs.dtype, rhs.dtype),
119            E,
120            numeric::mul::<R, E>(lhs, rhs)
121        )
122    }
123
124    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
125        execute_with_dtype!(
126            float(lhs.dtype),
127            E,
128            numeric::mul_scalar::<R, E>(lhs, rhs.elem())
129        )
130    }
131
132    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
133        execute_with_dtype!(
134            float(lhs.dtype, rhs.dtype),
135            E,
136            numeric::div::<R, E>(lhs, rhs)
137        )
138    }
139
140    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
141        execute_with_dtype!(
142            float(lhs.dtype),
143            E,
144            numeric::div_scalar::<R, E>(lhs, rhs.elem())
145        )
146    }
147
148    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149        execute_with_dtype!(
150            float(lhs.dtype, rhs.dtype),
151            E,
152            numeric::remainder::<R, E>(lhs, rhs)
153        )
154    }
155
156    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
157        execute_with_dtype!(
158            float(lhs.dtype),
159            E,
160            numeric::remainder_scalar::<R, E>(lhs, rhs.elem())
161        )
162    }
163
164    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
165        execute_with_dtype!(
166            float(lhs.dtype, rhs.dtype),
167            E,
168            matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
169        )
170    }
171
172    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
173        super::swap_dims(tensor, dim1, dim2)
174    }
175
176    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
177        super::reshape(tensor, shape)
178    }
179
180    fn float_gather(
181        dim: usize,
182        tensor: FloatTensor<Self>,
183        indices: IntTensor<Self>,
184    ) -> FloatTensor<Self> {
185        execute_with_dtype!(
186            float(tensor.dtype),
187            E,
188            kernel::gather::<R, E, I>(dim, tensor, indices)
189        )
190    }
191
192    fn float_scatter(
193        dim: usize,
194        tensor: FloatTensor<Self>,
195        indices: IntTensor<Self>,
196        value: FloatTensor<Self>,
197    ) -> FloatTensor<Self> {
198        execute_with_dtype!(
199            float(tensor.dtype, value.dtype),
200            E,
201            kernel::scatter::<R, E, I>(dim, tensor, indices, value)
202        )
203    }
204
205    fn float_select(
206        tensor: FloatTensor<Self>,
207        dim: usize,
208        indices: IntTensor<Self>,
209    ) -> FloatTensor<Self> {
210        execute_with_dtype!(
211            float(tensor.dtype),
212            E,
213            kernel::select::<R, E, I>(tensor, dim, indices)
214        )
215    }
216
217    fn float_select_assign(
218        tensor: FloatTensor<Self>,
219        dim: usize,
220        indices: IntTensor<Self>,
221        value: FloatTensor<Self>,
222    ) -> FloatTensor<Self> {
223        execute_with_dtype!(
224            float(tensor.dtype, value.dtype),
225            E,
226            kernel::select_assign::<R, E, I>(tensor, dim, indices, value)
227        )
228    }
229
230    fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
231        execute_with_dtype!(
232            float(tensor.dtype),
233            E,
234            kernel::slice::<R, E>(tensor, ranges)
235        )
236    }
237
238    fn float_slice_assign(
239        tensor: FloatTensor<Self>,
240        ranges: &[Range<usize>],
241        value: FloatTensor<Self>,
242    ) -> FloatTensor<Self> {
243        execute_with_dtype!(
244            float(tensor.dtype, value.dtype),
245            E,
246            kernel::slice_assign::<R, E>(tensor, ranges, value)
247        )
248    }
249
250    fn float_mask_where(
251        tensor: FloatTensor<Self>,
252        mask: BoolTensor<Self>,
253        value: FloatTensor<Self>,
254    ) -> FloatTensor<Self> {
255        execute_with_dtype!(
256            float(tensor.dtype, value.dtype),
257            E,
258            kernel::mask_where_auto::<R, E, BT>(tensor, mask, value)
259        )
260    }
261
262    fn float_mask_fill(
263        tensor: FloatTensor<Self>,
264        mask: BoolTensor<Self>,
265        value: FloatElem<Self>,
266    ) -> FloatTensor<Self> {
267        execute_with_dtype!(
268            float(tensor.dtype),
269            E,
270            kernel::mask_fill_auto::<R, E, BT>(tensor, mask, value.elem())
271        )
272    }
273
274    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
275        execute_with_dtype!(
276            float(lhs.dtype, rhs.dtype),
277            E,
278            kernel::equal::<R, E, BT>(lhs, rhs)
279        )
280    }
281
282    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
283        execute_with_dtype!(
284            float(lhs.dtype),
285            E,
286            kernel::equal_elem::<R, E, BT>(lhs, rhs.elem())
287        )
288    }
289
290    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
291        execute_with_dtype!(
292            float(lhs.dtype, rhs.dtype),
293            E,
294            kernel::greater::<R, E, BT>(lhs, rhs)
295        )
296    }
297
298    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
299        execute_with_dtype!(
300            float(lhs.dtype),
301            E,
302            kernel::greater_elem::<R, E, BT>(lhs, rhs.elem())
303        )
304    }
305
306    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
307        execute_with_dtype!(
308            float(lhs.dtype, rhs.dtype),
309            E,
310            kernel::greater_equal::<R, E, BT>(lhs, rhs)
311        )
312    }
313
314    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
315        execute_with_dtype!(
316            float(lhs.dtype),
317            E,
318            kernel::greater_equal_elem::<R, E, BT>(lhs, rhs.elem())
319        )
320    }
321
322    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
323        execute_with_dtype!(
324            float(lhs.dtype, rhs.dtype),
325            E,
326            kernel::lower::<R, E, BT>(lhs, rhs)
327        )
328    }
329
330    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
331        execute_with_dtype!(
332            float(lhs.dtype),
333            E,
334            kernel::lower_elem::<R, E, BT>(lhs, rhs.elem())
335        )
336    }
337
338    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
339        execute_with_dtype!(
340            float(lhs.dtype, rhs.dtype),
341            E,
342            kernel::lower_equal::<R, E, BT>(lhs, rhs)
343        )
344    }
345
346    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
347        execute_with_dtype!(
348            float(lhs.dtype),
349            E,
350            kernel::lower_equal_elem::<R, E, BT>(lhs, rhs.elem())
351        )
352    }
353
354    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
355        execute_with_dtype!(
356            float(tensor.dtype),
357            E,
358            reduce::reduce::<R, E, E, reduce::Sum>(tensor, Default::default()).unwrap()
359        )
360    }
361
362    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
363        execute_with_dtype!(
364            float(tensor.dtype),
365            E,
366            reduce::reduce_dim::<R, E, E, reduce::Sum>(tensor, dim, Default::default()).unwrap()
367        )
368    }
369
370    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
371        execute_with_dtype!(
372            float(tensor.dtype),
373            E,
374            reduce::reduce_dim::<R, E, E, reduce::Mean>(tensor, dim, Default::default()).unwrap()
375        )
376    }
377
378    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
379        execute_with_dtype!(
380            float(tensor.dtype),
381            E,
382            reduce::reduce::<R, E, E, reduce::Prod>(tensor, Default::default()).unwrap()
383        )
384    }
385
386    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
387        execute_with_dtype!(
388            float(tensor.dtype),
389            E,
390            reduce::reduce_dim::<R, E, E, reduce::Prod>(tensor, dim, Default::default()).unwrap()
391        )
392    }
393
394    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
395        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Exp)
396    }
397
398    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
399        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log)
400    }
401
402    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
403        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log1p)
404    }
405
406    fn float_powf_scalar(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
407        struct Powf;
408
409        #[cube]
410        impl<F: Float> FloatUnaryOp<F> for Powf {
411            type Options = F;
412
413            fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
414                Line::powf(input, Line::new(*options))
415            }
416        }
417
418        impl FloatUnaryOpFamily for Powf {
419            type Options<F: Float> = F;
420            type Unary<F: Float> = Self;
421        }
422
423        execute_with_dtype!(
424            float(lhs.dtype),
425            F,
426            launch_unary_float::<R, F, Powf, _>(lhs, |_| ScalarArg::new(rhs.elem::<F>()))
427        )
428    }
429
430    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
431        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sqrt)
432    }
433
434    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
435        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Abs)
436    }
437
438    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
439        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Cos)
440    }
441
442    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
443        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sin)
444    }
445
446    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
447        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Tanh)
448    }
449
450    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
451        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Round)
452    }
453
454    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Floor)
456    }
457
458    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
459        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Ceil)
460    }
461
462    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Erf)
464    }
465
466    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
467        execute_with_dtype!(
468            float(tensor.dtype),
469            E,
470            reduce::reduce_dim::<R, E, I, reduce::ArgMax>(tensor, dim, Default::default()).unwrap()
471        )
472    }
473
474    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
475        execute_with_dtype!(
476            float(tensor.dtype),
477            E,
478            reduce::reduce_dim::<R, E, I, reduce::ArgMin>(tensor, dim, Default::default()).unwrap()
479        )
480    }
481
482    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
483        execute_with_dtype!(float(tensor.dtype), E, kernel::cast::<R, E, I>(tensor))
484    }
485
486    fn float_clamp(
487        tensor: FloatTensor<Self>,
488        min: FloatElem<Self>,
489        max: FloatElem<Self>,
490    ) -> FloatTensor<Self> {
491        execute_with_dtype!(
492            float(tensor.dtype),
493            E,
494            kernel::clamp::<R, E>(tensor, min.elem(), max.elem())
495        )
496    }
497
498    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499        unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Recip)
500    }
501
502    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
503        execute_with_dtype!(
504            float(tensor.dtype),
505            E,
506            kernel::repeat_dim::<R, E>(tensor, dim, times)
507        )
508    }
509
510    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
511        execute_with_dtype!(float(lhs.dtype), E, numeric::pow::<R, E>(lhs, rhs))
512    }
513
514    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
515        permute(tensor, axes)
516    }
517
518    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
519        expand(tensor, shape)
520    }
521
522    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
523        execute_with_dtype!(
524            float(tensor.dtype),
525            E,
526            kernel::flip::<R, E, BT>(tensor, axes)
527        )
528    }
529
530    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
531        match (tensor.dtype, dtype) {
532            (DType::F64, FloatDType::F64)
533            | (DType::F32, FloatDType::F32)
534            | (DType::BF16, FloatDType::BF16)
535            | (DType::F16, FloatDType::F16) => tensor,
536            (DType::F64, FloatDType::F32) => kernel::cast::<R, f64, f32>(tensor),
537            (DType::F64, FloatDType::F16) => kernel::cast::<R, f64, f16>(tensor),
538            (DType::F64, FloatDType::BF16) => kernel::cast::<R, f64, bf16>(tensor),
539            (DType::F32, FloatDType::F64) => kernel::cast::<R, f32, f64>(tensor),
540            (DType::F32, FloatDType::F16) => kernel::cast::<R, f32, f16>(tensor),
541            (DType::F32, FloatDType::BF16) => kernel::cast::<R, f32, bf16>(tensor),
542            (DType::F16, FloatDType::F64) => kernel::cast::<R, f16, f64>(tensor),
543            (DType::F16, FloatDType::F32) => kernel::cast::<R, f16, f32>(tensor),
544            (DType::F16, FloatDType::BF16) => kernel::cast::<R, f16, bf16>(tensor),
545            (DType::BF16, FloatDType::F64) => kernel::cast::<R, bf16, f64>(tensor),
546            (DType::BF16, FloatDType::F32) => kernel::cast::<R, bf16, f32>(tensor),
547            (DType::BF16, FloatDType::F16) => kernel::cast::<R, bf16, f16>(tensor),
548            _ => unimplemented!("Unsupported floating point type cast"),
549        }
550    }
551}