1use super::{expand, numeric, permute};
2use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
3use crate::kernel::unary_basic::BasicFloatUnaryKind;
4use crate::kernel::{
5 self, launch_unary_float, reduce, unary_basic, FloatUnaryOp, FloatUnaryOpFamily,
6};
7use crate::{
8 element::BoolElement,
9 kernel::matmul::{matmul, MatmulStrategy},
10};
11use crate::{execute_with_dtype, JitBackend};
12use crate::{FloatElement, IntElement, JitRuntime};
13use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
14use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData};
15use burn_tensor::{DType, ElementConversion, FloatDType};
16use cubecl::prelude::*;
17use half::{bf16, f16};
18use std::ops::Range;
19
20impl<R, F, I, BT> FloatTensorOps<Self> for JitBackend<R, F, I, BT>
21where
22 R: JitRuntime,
23 F: FloatElement,
24 I: IntElement,
25 BT: BoolElement,
26{
27 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
28 super::from_data::<R, F>(data, device)
29 }
30
31 fn float_random(
32 shape: Shape,
33 distribution: Distribution,
34 device: &Device<Self>,
35 ) -> FloatTensor<Self> {
36 match distribution {
37 Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 1.elem()),
38 Distribution::Uniform(low, high) => {
39 random_uniform(shape, device, low.elem::<F>(), high.elem())
40 }
41 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
42 Distribution::Normal(mean, std) => {
43 random_normal(shape, device, mean.elem::<F>(), std.elem())
44 }
45 }
46 }
47
48 async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
49 execute_with_dtype!(
50 float(tensor.dtype),
51 E,
52 super::into_data::<R, E>(tensor).await
53 )
54 }
55
56 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
57 tensor.device.clone()
58 }
59
60 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
61 super::to_device(tensor, device)
62 }
63
64 fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
65 super::empty::<R, F>(shape, device)
66 }
67
68 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
69 execute_with_dtype!(
70 float(lhs.dtype, rhs.dtype),
71 E,
72 numeric::add::<R, E>(lhs, rhs)
73 )
74 }
75
76 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
77 execute_with_dtype!(
78 float(lhs.dtype),
79 E,
80 numeric::add_scalar::<R, E>(lhs, rhs.elem())
81 )
82 }
83
84 fn float_zeros(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
85 numeric::zeros::<R, F>(shape, device)
86 }
87
88 fn float_full(
89 shape: Shape,
90 fill_value: FloatElem<Self>,
91 device: &R::Device,
92 ) -> FloatTensor<Self> {
93 numeric::full::<R, F>(shape, device, fill_value)
94 }
95
96 fn float_ones(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
97 numeric::ones::<R, F>(shape, device)
98 }
99
100 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
101 execute_with_dtype!(
102 float(lhs.dtype, rhs.dtype),
103 E,
104 numeric::sub::<R, E>(lhs, rhs)
105 )
106 }
107
108 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
109 execute_with_dtype!(
110 float(lhs.dtype),
111 E,
112 numeric::sub_scalar::<R, E>(lhs, rhs.elem())
113 )
114 }
115
116 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
117 execute_with_dtype!(
118 float(lhs.dtype, rhs.dtype),
119 E,
120 numeric::mul::<R, E>(lhs, rhs)
121 )
122 }
123
124 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
125 execute_with_dtype!(
126 float(lhs.dtype),
127 E,
128 numeric::mul_scalar::<R, E>(lhs, rhs.elem())
129 )
130 }
131
132 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
133 execute_with_dtype!(
134 float(lhs.dtype, rhs.dtype),
135 E,
136 numeric::div::<R, E>(lhs, rhs)
137 )
138 }
139
140 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
141 execute_with_dtype!(
142 float(lhs.dtype),
143 E,
144 numeric::div_scalar::<R, E>(lhs, rhs.elem())
145 )
146 }
147
148 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149 execute_with_dtype!(
150 float(lhs.dtype, rhs.dtype),
151 E,
152 numeric::remainder::<R, E>(lhs, rhs)
153 )
154 }
155
156 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
157 execute_with_dtype!(
158 float(lhs.dtype),
159 E,
160 numeric::remainder_scalar::<R, E>(lhs, rhs.elem())
161 )
162 }
163
164 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
165 execute_with_dtype!(
166 float(lhs.dtype, rhs.dtype),
167 E,
168 matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
169 )
170 }
171
172 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
173 super::swap_dims(tensor, dim1, dim2)
174 }
175
176 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
177 super::reshape(tensor, shape)
178 }
179
180 fn float_gather(
181 dim: usize,
182 tensor: FloatTensor<Self>,
183 indices: IntTensor<Self>,
184 ) -> FloatTensor<Self> {
185 execute_with_dtype!(
186 float(tensor.dtype),
187 E,
188 kernel::gather::<R, E, I>(dim, tensor, indices)
189 )
190 }
191
192 fn float_scatter(
193 dim: usize,
194 tensor: FloatTensor<Self>,
195 indices: IntTensor<Self>,
196 value: FloatTensor<Self>,
197 ) -> FloatTensor<Self> {
198 execute_with_dtype!(
199 float(tensor.dtype, value.dtype),
200 E,
201 kernel::scatter::<R, E, I>(dim, tensor, indices, value)
202 )
203 }
204
205 fn float_select(
206 tensor: FloatTensor<Self>,
207 dim: usize,
208 indices: IntTensor<Self>,
209 ) -> FloatTensor<Self> {
210 execute_with_dtype!(
211 float(tensor.dtype),
212 E,
213 kernel::select::<R, E, I>(tensor, dim, indices)
214 )
215 }
216
217 fn float_select_assign(
218 tensor: FloatTensor<Self>,
219 dim: usize,
220 indices: IntTensor<Self>,
221 value: FloatTensor<Self>,
222 ) -> FloatTensor<Self> {
223 execute_with_dtype!(
224 float(tensor.dtype, value.dtype),
225 E,
226 kernel::select_assign::<R, E, I>(tensor, dim, indices, value)
227 )
228 }
229
230 fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
231 execute_with_dtype!(
232 float(tensor.dtype),
233 E,
234 kernel::slice::<R, E>(tensor, ranges)
235 )
236 }
237
238 fn float_slice_assign(
239 tensor: FloatTensor<Self>,
240 ranges: &[Range<usize>],
241 value: FloatTensor<Self>,
242 ) -> FloatTensor<Self> {
243 execute_with_dtype!(
244 float(tensor.dtype, value.dtype),
245 E,
246 kernel::slice_assign::<R, E>(tensor, ranges, value)
247 )
248 }
249
250 fn float_mask_where(
251 tensor: FloatTensor<Self>,
252 mask: BoolTensor<Self>,
253 value: FloatTensor<Self>,
254 ) -> FloatTensor<Self> {
255 execute_with_dtype!(
256 float(tensor.dtype, value.dtype),
257 E,
258 kernel::mask_where_auto::<R, E, BT>(tensor, mask, value)
259 )
260 }
261
262 fn float_mask_fill(
263 tensor: FloatTensor<Self>,
264 mask: BoolTensor<Self>,
265 value: FloatElem<Self>,
266 ) -> FloatTensor<Self> {
267 execute_with_dtype!(
268 float(tensor.dtype),
269 E,
270 kernel::mask_fill_auto::<R, E, BT>(tensor, mask, value.elem())
271 )
272 }
273
274 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
275 execute_with_dtype!(
276 float(lhs.dtype, rhs.dtype),
277 E,
278 kernel::equal::<R, E, BT>(lhs, rhs)
279 )
280 }
281
282 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
283 execute_with_dtype!(
284 float(lhs.dtype),
285 E,
286 kernel::equal_elem::<R, E, BT>(lhs, rhs.elem())
287 )
288 }
289
290 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
291 execute_with_dtype!(
292 float(lhs.dtype, rhs.dtype),
293 E,
294 kernel::greater::<R, E, BT>(lhs, rhs)
295 )
296 }
297
298 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
299 execute_with_dtype!(
300 float(lhs.dtype),
301 E,
302 kernel::greater_elem::<R, E, BT>(lhs, rhs.elem())
303 )
304 }
305
306 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
307 execute_with_dtype!(
308 float(lhs.dtype, rhs.dtype),
309 E,
310 kernel::greater_equal::<R, E, BT>(lhs, rhs)
311 )
312 }
313
314 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
315 execute_with_dtype!(
316 float(lhs.dtype),
317 E,
318 kernel::greater_equal_elem::<R, E, BT>(lhs, rhs.elem())
319 )
320 }
321
322 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
323 execute_with_dtype!(
324 float(lhs.dtype, rhs.dtype),
325 E,
326 kernel::lower::<R, E, BT>(lhs, rhs)
327 )
328 }
329
330 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
331 execute_with_dtype!(
332 float(lhs.dtype),
333 E,
334 kernel::lower_elem::<R, E, BT>(lhs, rhs.elem())
335 )
336 }
337
338 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
339 execute_with_dtype!(
340 float(lhs.dtype, rhs.dtype),
341 E,
342 kernel::lower_equal::<R, E, BT>(lhs, rhs)
343 )
344 }
345
346 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
347 execute_with_dtype!(
348 float(lhs.dtype),
349 E,
350 kernel::lower_equal_elem::<R, E, BT>(lhs, rhs.elem())
351 )
352 }
353
354 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
355 execute_with_dtype!(
356 float(tensor.dtype),
357 E,
358 reduce::reduce::<R, E, E, reduce::Sum>(tensor, Default::default()).unwrap()
359 )
360 }
361
362 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
363 execute_with_dtype!(
364 float(tensor.dtype),
365 E,
366 reduce::reduce_dim::<R, E, E, reduce::Sum>(tensor, dim, Default::default()).unwrap()
367 )
368 }
369
370 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
371 execute_with_dtype!(
372 float(tensor.dtype),
373 E,
374 reduce::reduce_dim::<R, E, E, reduce::Mean>(tensor, dim, Default::default()).unwrap()
375 )
376 }
377
378 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
379 execute_with_dtype!(
380 float(tensor.dtype),
381 E,
382 reduce::reduce::<R, E, E, reduce::Prod>(tensor, Default::default()).unwrap()
383 )
384 }
385
386 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
387 execute_with_dtype!(
388 float(tensor.dtype),
389 E,
390 reduce::reduce_dim::<R, E, E, reduce::Prod>(tensor, dim, Default::default()).unwrap()
391 )
392 }
393
394 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
395 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Exp)
396 }
397
398 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
399 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log)
400 }
401
402 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
403 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Log1p)
404 }
405
406 fn float_powf_scalar(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
407 struct Powf;
408
409 #[cube]
410 impl<F: Float> FloatUnaryOp<F> for Powf {
411 type Options = F;
412
413 fn execute(input: Line<F>, options: &Self::Options) -> Line<F> {
414 Line::powf(input, Line::new(*options))
415 }
416 }
417
418 impl FloatUnaryOpFamily for Powf {
419 type Options<F: Float> = F;
420 type Unary<F: Float> = Self;
421 }
422
423 execute_with_dtype!(
424 float(lhs.dtype),
425 F,
426 launch_unary_float::<R, F, Powf, _>(lhs, |_| ScalarArg::new(rhs.elem::<F>()))
427 )
428 }
429
430 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
431 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sqrt)
432 }
433
434 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
435 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Abs)
436 }
437
438 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
439 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Cos)
440 }
441
442 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
443 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Sin)
444 }
445
446 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
447 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Tanh)
448 }
449
450 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
451 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Round)
452 }
453
454 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Floor)
456 }
457
458 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
459 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Ceil)
460 }
461
462 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Erf)
464 }
465
466 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
467 execute_with_dtype!(
468 float(tensor.dtype),
469 E,
470 reduce::reduce_dim::<R, E, I, reduce::ArgMax>(tensor, dim, Default::default()).unwrap()
471 )
472 }
473
474 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
475 execute_with_dtype!(
476 float(tensor.dtype),
477 E,
478 reduce::reduce_dim::<R, E, I, reduce::ArgMin>(tensor, dim, Default::default()).unwrap()
479 )
480 }
481
482 fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
483 execute_with_dtype!(float(tensor.dtype), E, kernel::cast::<R, E, I>(tensor))
484 }
485
486 fn float_clamp(
487 tensor: FloatTensor<Self>,
488 min: FloatElem<Self>,
489 max: FloatElem<Self>,
490 ) -> FloatTensor<Self> {
491 execute_with_dtype!(
492 float(tensor.dtype),
493 E,
494 kernel::clamp::<R, E>(tensor, min.elem(), max.elem())
495 )
496 }
497
498 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499 unary_basic::launch::<R, _>(tensor, |_| &BasicFloatUnaryKind::Recip)
500 }
501
502 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
503 execute_with_dtype!(
504 float(tensor.dtype),
505 E,
506 kernel::repeat_dim::<R, E>(tensor, dim, times)
507 )
508 }
509
510 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
511 execute_with_dtype!(float(lhs.dtype), E, numeric::pow::<R, E>(lhs, rhs))
512 }
513
514 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
515 permute(tensor, axes)
516 }
517
518 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
519 expand(tensor, shape)
520 }
521
522 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
523 execute_with_dtype!(
524 float(tensor.dtype),
525 E,
526 kernel::flip::<R, E, BT>(tensor, axes)
527 )
528 }
529
530 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
531 match (tensor.dtype, dtype) {
532 (DType::F64, FloatDType::F64)
533 | (DType::F32, FloatDType::F32)
534 | (DType::BF16, FloatDType::BF16)
535 | (DType::F16, FloatDType::F16) => tensor,
536 (DType::F64, FloatDType::F32) => kernel::cast::<R, f64, f32>(tensor),
537 (DType::F64, FloatDType::F16) => kernel::cast::<R, f64, f16>(tensor),
538 (DType::F64, FloatDType::BF16) => kernel::cast::<R, f64, bf16>(tensor),
539 (DType::F32, FloatDType::F64) => kernel::cast::<R, f32, f64>(tensor),
540 (DType::F32, FloatDType::F16) => kernel::cast::<R, f32, f16>(tensor),
541 (DType::F32, FloatDType::BF16) => kernel::cast::<R, f32, bf16>(tensor),
542 (DType::F16, FloatDType::F64) => kernel::cast::<R, f16, f64>(tensor),
543 (DType::F16, FloatDType::F32) => kernel::cast::<R, f16, f32>(tensor),
544 (DType::F16, FloatDType::BF16) => kernel::cast::<R, f16, bf16>(tensor),
545 (DType::BF16, FloatDType::F64) => kernel::cast::<R, bf16, f64>(tensor),
546 (DType::BF16, FloatDType::F32) => kernel::cast::<R, bf16, f32>(tensor),
547 (DType::BF16, FloatDType::F16) => kernel::cast::<R, bf16, f16>(tensor),
548 _ => unimplemented!("Unsupported floating point type cast"),
549 }
550 }
551}