1use super::{expand, numeric, permute};
2use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
3use crate::kernel::unary_basic::BasicFloatUnaryKind;
4use crate::kernel::{
5 self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,
6};
7use crate::{CubeBackend, execute_with_dtype};
8use crate::{CubeRuntime, FloatElement, IntElement};
9use crate::{
10 element::BoolElement,
11 kernel::matmul::{MatmulStrategy, matmul},
12};
13use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
14use burn_tensor::{DType, ElementConversion, FloatDType};
15use burn_tensor::{Distribution, Shape, TensorData, ops::FloatTensorOps};
16use cubecl::prelude::*;
17use cubecl::reduce::instructions::ReduceFnConfig;
18use half::{bf16, f16};
19use std::ops::Range;
20
21impl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>
22where
23 R: CubeRuntime,
24 F: FloatElement,
25 I: IntElement,
26 BT: BoolElement,
27{
28 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
29 match data.dtype {
30 DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
31 super::from_data::<R>(data, device)
32 }
33 _ => unimplemented!("Unsupported dtype for `float_from_data`"),
34 }
35 }
36
37 fn float_random(
38 shape: Shape,
39 distribution: Distribution,
40 device: &Device<Self>,
41 ) -> FloatTensor<Self> {
42 match distribution {
43 Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 1.elem()),
44 Distribution::Uniform(low, high) => {
45 random_uniform(shape, device, low.elem::<F>(), high.elem())
46 }
47 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
48 Distribution::Normal(mean, std) => {
49 random_normal(shape, device, mean.elem::<F>(), std.elem())
50 }
51 }
52 }
53
54 async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
55 execute_with_dtype!(
56 float(tensor.dtype),
57 E,
58 super::into_data::<R, E>(tensor).await
59 )
60 }
61
62 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
63 tensor.device.clone()
64 }
65
66 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
67 super::to_device(tensor, device)
68 }
69
70 fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
71 super::empty::<R, F>(shape, device)
72 }
73
74 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
75 execute_with_dtype!(
76 float(lhs.dtype, rhs.dtype),
77 E,
78 numeric::add::<R, E>(lhs, rhs)
79 )
80 }
81
82 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
83 execute_with_dtype!(
84 float(lhs.dtype),
85 E,
86 numeric::add_scalar::<R, E>(lhs, rhs.elem())
87 )
88 }
89
90 fn float_zeros(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
91 numeric::zeros::<R, F>(shape, device)
92 }
93
94 fn float_full(
95 shape: Shape,
96 fill_value: FloatElem<Self>,
97 device: &R::Device,
98 ) -> FloatTensor<Self> {
99 numeric::full::<R, F>(shape, device, fill_value)
100 }
101
102 fn float_ones(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
103 numeric::ones::<R, F>(shape, device)
104 }
105
106 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
107 execute_with_dtype!(
108 float(lhs.dtype, rhs.dtype),
109 E,
110 numeric::sub::<R, E>(lhs, rhs)
111 )
112 }
113
114 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
115 execute_with_dtype!(
116 float(lhs.dtype),
117 E,
118 numeric::sub_scalar::<R, E>(lhs, rhs.elem())
119 )
120 }
121
122 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
123 execute_with_dtype!(
124 float(lhs.dtype, rhs.dtype),
125 E,
126 numeric::mul::<R, E>(lhs, rhs)
127 )
128 }
129
130 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
131 execute_with_dtype!(
132 float(lhs.dtype),
133 E,
134 numeric::mul_scalar::<R, E>(lhs, rhs.elem())
135 )
136 }
137
138 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139 execute_with_dtype!(
140 float(lhs.dtype, rhs.dtype),
141 E,
142 numeric::div::<R, E>(lhs, rhs)
143 )
144 }
145
146 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
147 execute_with_dtype!(
148 float(lhs.dtype),
149 E,
150 numeric::div_scalar::<R, E>(lhs, rhs.elem())
151 )
152 }
153
154 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
155 execute_with_dtype!(
156 float(lhs.dtype, rhs.dtype),
157 E,
158 numeric::remainder::<R, E>(lhs, rhs)
159 )
160 }
161
162 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
163 execute_with_dtype!(
164 float(lhs.dtype),
165 E,
166 numeric::remainder_scalar::<R, E>(lhs, rhs.elem())
167 )
168 }
169
170 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
171 execute_with_dtype!(
172 float(lhs.dtype, rhs.dtype),
173 E,
174 matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
175 )
176 }
177
178 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
179 super::swap_dims(tensor, dim1, dim2)
180 }
181
182 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
183 super::reshape(tensor, shape)
184 }
185
186 fn float_gather(
187 dim: usize,
188 tensor: FloatTensor<Self>,
189 indices: IntTensor<Self>,
190 ) -> FloatTensor<Self> {
191 execute_with_dtype!(
192 float(tensor.dtype),
193 E,
194 kernel::gather::<R, E, I>(dim, tensor, indices)
195 )
196 }
197
198 fn float_scatter(
199 dim: usize,
200 tensor: FloatTensor<Self>,
201 indices: IntTensor<Self>,
202 value: FloatTensor<Self>,
203 ) -> FloatTensor<Self> {
204 execute_with_dtype!(
205 float(tensor.dtype, value.dtype),
206 E,
207 kernel::scatter::<R, E, I>(dim, tensor, indices, value)
208 )
209 }
210
211 fn float_select(
212 tensor: FloatTensor<Self>,
213 dim: usize,
214 indices: IntTensor<Self>,
215 ) -> FloatTensor<Self> {
216 execute_with_dtype!(
217 float(tensor.dtype),
218 E,
219 kernel::select::<R, E, I>(tensor, dim, indices)
220 )
221 }
222
223 fn float_select_assign(
224 tensor: FloatTensor<Self>,
225 dim: usize,
226 indices: IntTensor<Self>,
227 value: FloatTensor<Self>,
228 ) -> FloatTensor<Self> {
229 execute_with_dtype!(
230 float(tensor.dtype, value.dtype),
231 E,
232 kernel::select_assign::<R, E, I>(tensor, dim, indices, value)
233 )
234 }
235
236 fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
237 execute_with_dtype!(
238 float(tensor.dtype),
239 E,
240 kernel::slice::<R, E>(tensor, ranges)
241 )
242 }
243
244 fn float_slice_assign(
245 tensor: FloatTensor<Self>,
246 ranges: &[Range<usize>],
247 value: FloatTensor<Self>,
248 ) -> FloatTensor<Self> {
249 execute_with_dtype!(
250 float(tensor.dtype, value.dtype),
251 E,
252 kernel::slice_assign::<R, E>(tensor, ranges, value)
253 )
254 }
255
256 fn float_mask_where(
257 tensor: FloatTensor<Self>,
258 mask: BoolTensor<Self>,
259 value: FloatTensor<Self>,
260 ) -> FloatTensor<Self> {
261 execute_with_dtype!(
262 float(tensor.dtype, value.dtype),
263 E,
264 kernel::mask_where_auto::<R, E, BT>(tensor, mask, value)
265 )
266 }
267
268 fn float_mask_fill(
269 tensor: FloatTensor<Self>,
270 mask: BoolTensor<Self>,
271 value: FloatElem<Self>,
272 ) -> FloatTensor<Self> {
273 execute_with_dtype!(
274 float(tensor.dtype),
275 E,
276 kernel::mask_fill_auto::<R, E, BT>(tensor, mask, value.elem())
277 )
278 }
279
280 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
281 execute_with_dtype!(
282 float(lhs.dtype, rhs.dtype),
283 E,
284 kernel::equal::<R, E, BT>(lhs, rhs)
285 )
286 }
287
288 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
289 execute_with_dtype!(
290 float(lhs.dtype),
291 E,
292 kernel::equal_elem::<R, E, BT>(lhs, rhs.elem())
293 )
294 }
295
296 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
297 execute_with_dtype!(
298 float(lhs.dtype, rhs.dtype),
299 E,
300 kernel::greater::<R, E, BT>(lhs, rhs)
301 )
302 }
303
304 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
305 execute_with_dtype!(
306 float(lhs.dtype),
307 E,
308 kernel::greater_elem::<R, E, BT>(lhs, rhs.elem())
309 )
310 }
311
312 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
313 execute_with_dtype!(
314 float(lhs.dtype, rhs.dtype),
315 E,
316 kernel::greater_equal::<R, E, BT>(lhs, rhs)
317 )
318 }
319
320 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
321 execute_with_dtype!(
322 float(lhs.dtype),
323 E,
324 kernel::greater_equal_elem::<R, E, BT>(lhs, rhs.elem())
325 )
326 }
327
328 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
329 execute_with_dtype!(
330 float(lhs.dtype, rhs.dtype),
331 E,
332 kernel::lower::<R, E, BT>(lhs, rhs)
333 )
334 }
335
336 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
337 execute_with_dtype!(
338 float(lhs.dtype),
339 E,
340 kernel::lower_elem::<R, E, BT>(lhs, rhs.elem())
341 )
342 }
343
344 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
345 execute_with_dtype!(
346 float(lhs.dtype, rhs.dtype),
347 E,
348 kernel::lower_equal::<R, E, BT>(lhs, rhs)
349 )
350 }
351
352 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
353 execute_with_dtype!(
354 float(lhs.dtype),
355 E,
356 kernel::lower_equal_elem::<R, E, BT>(lhs, rhs.elem())
357 )
358 }
359
360 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
361 execute_with_dtype!(
362 float(tensor.dtype),
363 E,
364 reduce::sum::<R, E>(tensor, Default::default()).unwrap()
365 )
366 }
367
368 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
369 execute_with_dtype!(
370 float(tensor.dtype),
371 E,
372 reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
373 )
374 }
375
376 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
377 execute_with_dtype!(
378 float(tensor.dtype),
379 E,
380 reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Max)
381 .unwrap()
382 )
383 }
384
385 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
386 execute_with_dtype!(
387 float(tensor.dtype),
388 E,
389 reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
390 )
391 }
392
393 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
394 execute_with_dtype!(
395 float(tensor.dtype),
396 E,
397 reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Min)
398 .unwrap()
399 )
400 }
401
402 fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
403 execute_with_dtype!(
404 float(tensor.dtype),
405 E,
406 reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::MaxAbs).unwrap()
407 )
408 }
409
410 fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
411 execute_with_dtype!(
412 float(tensor.dtype),
413 E,
414 reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::MaxAbs)
415 .unwrap()
416 )
417 }
418
419 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
420 execute_with_dtype!(
421 float(tensor.dtype),
422 E,
423 reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Sum)
424 .unwrap()
425 )
426 }
427
428 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
429 execute_with_dtype!(
430 float(tensor.dtype),
431 E,
432 reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Mean)
433 .unwrap()
434 )
435 }
436
437 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
438 execute_with_dtype!(
439 float(tensor.dtype),
440 E,
441 reduce::reduce::<R, E, E>(tensor, Default::default(), ReduceFnConfig::Prod).unwrap()
442 )
443 }
444
445 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
446 execute_with_dtype!(
447 float(tensor.dtype),
448 E,
449 reduce::reduce_dim::<R, E, E>(tensor, dim, Default::default(), ReduceFnConfig::Prod)
450 .unwrap()
451 )
452 }
453
454 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Exp)
456 }
457
458 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
459 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log)
460 }
461
462 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log1p)
464 }
465
466 fn float_powf_scalar(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
467 struct Powf;
468
469 #[cube]
470 impl<F: Float> FloatUnaryOp<F> for Powf {
471 type Options = F;
472
473 fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
474 Line::powf(input, Line::new(*options))
475 }
476 }
477
478 impl FloatUnaryOpFamily for Powf {
479 type Options<F: Float> = F;
480 type Unary<F: Float> = Self;
481 }
482
483 execute_with_dtype!(
484 float(lhs.dtype),
485 F,
486 launch_unary_float::<R, F, Powf, _>(lhs, |_| ScalarArg::new(rhs.elem::<F>()))
487 )
488 }
489
490 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
491 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sqrt)
492 }
493
494 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
495 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Abs)
496 }
497
498 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Cos)
500 }
501
502 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
503 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sin)
504 }
505
506 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
507 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Tanh)
508 }
509
510 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
511 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Round)
512 }
513
514 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
515 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Floor)
516 }
517
518 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
519 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Ceil)
520 }
521
522 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
523 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Erf)
524 }
525
526 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
527 execute_with_dtype!(
528 float(tensor.dtype),
529 E,
530 reduce::reduce_dim::<R, E, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMax)
531 .unwrap()
532 )
533 }
534
535 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
536 execute_with_dtype!(
537 float(tensor.dtype),
538 E,
539 reduce::reduce_dim::<R, E, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMin)
540 .unwrap()
541 )
542 }
543
544 fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
545 execute_with_dtype!(float(tensor.dtype), E, kernel::cast::<R, E, I>(tensor))
546 }
547
548 fn float_clamp(
549 tensor: FloatTensor<Self>,
550 min: FloatElem<Self>,
551 max: FloatElem<Self>,
552 ) -> FloatTensor<Self> {
553 execute_with_dtype!(
554 float(tensor.dtype),
555 E,
556 kernel::clamp::<R, E>(tensor, min.elem(), max.elem())
557 )
558 }
559
560 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
561 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Recip)
562 }
563
564 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
565 execute_with_dtype!(
566 float(tensor.dtype),
567 E,
568 kernel::repeat_dim::<R, E>(tensor, dim, times)
569 )
570 }
571
572 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
573 execute_with_dtype!(float(lhs.dtype), E, numeric::pow::<R, E>(lhs, rhs))
574 }
575
576 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
577 permute(tensor, axes)
578 }
579
580 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
581 expand(tensor, shape)
582 }
583
584 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
585 execute_with_dtype!(
586 float(tensor.dtype),
587 E,
588 kernel::flip::<R, E, BT>(tensor, axes)
589 )
590 }
591
592 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
593 match (tensor.dtype, dtype) {
594 (DType::F64, FloatDType::F64)
595 | (DType::F32, FloatDType::F32)
596 | (DType::BF16, FloatDType::BF16)
597 | (DType::F16, FloatDType::F16) => tensor,
598 (DType::F64, FloatDType::F32) => kernel::cast::<R, f64, f32>(tensor),
599 (DType::F64, FloatDType::F16) => kernel::cast::<R, f64, f16>(tensor),
600 (DType::F64, FloatDType::BF16) => kernel::cast::<R, f64, bf16>(tensor),
601 (DType::F32, FloatDType::F64) => kernel::cast::<R, f32, f64>(tensor),
602 (DType::F32, FloatDType::F16) => kernel::cast::<R, f32, f16>(tensor),
603 (DType::F32, FloatDType::BF16) => kernel::cast::<R, f32, bf16>(tensor),
604 (DType::F16, FloatDType::F64) => kernel::cast::<R, f16, f64>(tensor),
605 (DType::F16, FloatDType::F32) => kernel::cast::<R, f16, f32>(tensor),
606 (DType::F16, FloatDType::BF16) => kernel::cast::<R, f16, bf16>(tensor),
607 (DType::BF16, FloatDType::F64) => kernel::cast::<R, bf16, f64>(tensor),
608 (DType::BF16, FloatDType::F32) => kernel::cast::<R, bf16, f32>(tensor),
609 (DType::BF16, FloatDType::F16) => kernel::cast::<R, bf16, f16>(tensor),
610 _ => unimplemented!("Unsupported floating point type cast"),
611 }
612 }
613}