1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorMetadata,
6 TensorPrimitive, get_device_settings,
7 ops::TransactionPrimitive,
8 tensor::{
9 BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
10 TensorKind, TransactionOp,
11 },
12};
13
/// Dispatches a binary operation over every `Float`/`QFloat` pairing of two
/// `TensorPrimitive` operands.
///
/// * `Float` with `Float`: calls `B::$op` and wraps the result in `TensorPrimitive::Float`.
/// * `QFloat` with `QFloat`: returns the value of `B::$q_op` unchanged.
/// * Mixed: the quantized operand is passed through `B::dequantize` together with the dtype
///   reported by the float operand, then `B::$op` is applied and wrapped in
///   `TensorPrimitive::Float`.
///
/// `$lhs` and `$rhs` must be identifiers bound at the call site, and a type named `B` must be
/// in scope there.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            (TensorPrimitive::QFloat(q_left), TensorPrimitive::QFloat(q_right)) => {
                B::$q_op(q_left, q_right)
            }
            (TensorPrimitive::Float(left), TensorPrimitive::Float(right)) => {
                TensorPrimitive::Float(B::$op(left, right))
            }
            (TensorPrimitive::Float(left), TensorPrimitive::QFloat(q_right)) => {
                // Read the dtype before `left` is moved into the operation.
                let float_dtype = left.dtype();
                let right = B::dequantize(q_right, float_dtype.into());
                TensorPrimitive::Float(B::$op(left, right))
            }
            (TensorPrimitive::QFloat(q_left), TensorPrimitive::Float(right)) => {
                let float_dtype = right.dtype();
                let left = B::dequantize(q_left, float_dtype.into());
                TensorPrimitive::Float(B::$op(left, right))
            }
        }
    };
}
32impl<B: Backend> TransactionOp<B> for Float {
33 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
34 tr.register_float(tensor);
35 }
36}
37impl<B: Backend> BasicOps<B> for Float {
38 type Elem = B::FloatElem;
39
40 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41 TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
42 }
43
44 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
45 TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
46 }
47 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
48 TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
49 }
50
51 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
52 TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
53 }
54
55 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
56 match tensor {
57 TensorPrimitive::Float(tensor) => {
58 TensorPrimitive::Float(B::float_reshape(tensor, shape))
59 }
60 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
61 }
62 }
63
64 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
65 match tensor {
66 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
67 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
68 }
69 }
70
71 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
72 match tensor {
73 TensorPrimitive::Float(tensor) => {
74 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
75 }
76 TensorPrimitive::QFloat(tensor) => {
77 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
78 }
79 }
80 }
81
82 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
83 match tensor {
84 TensorPrimitive::Float(tensor) => {
85 TensorPrimitive::Float(B::float_slice(tensor, slices))
86 }
87 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
88 }
89 }
90
91 fn slice_assign(
92 tensor: Self::Primitive,
93 slices: &[Slice],
94 value: Self::Primitive,
95 ) -> Self::Primitive {
96 TensorPrimitive::Float(B::float_slice_assign(
97 tensor.tensor(),
98 slices,
99 value.tensor(),
100 ))
101 }
102
103 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
104 match tensor {
105 TensorPrimitive::Float(tensor) => {
106 TensorPrimitive::Float(B::float_select(tensor, dim, indices))
107 }
108 TensorPrimitive::QFloat(tensor) => {
109 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
110 }
111 }
112 }
113
114 fn select_assign(
115 tensor: Self::Primitive,
116 dim: usize,
117 indices: IntTensor<B>,
118 values: Self::Primitive,
119 update: IndexingUpdateOp,
120 ) -> Self::Primitive {
121 match update {
123 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
124 tensor.tensor(),
125 dim,
126 indices,
127 values.tensor(),
128 )),
129 _ => unimplemented!(),
130 }
131 }
132
133 fn mask_where(
134 tensor: Self::Primitive,
135 mask: B::BoolTensorPrimitive,
136 source: Self::Primitive,
137 ) -> Self::Primitive {
138 TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
139 }
140
141 fn mask_fill(
142 tensor: Self::Primitive,
143 mask: B::BoolTensorPrimitive,
144 value: Scalar,
145 ) -> Self::Primitive {
146 TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
147 }
148
149 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
150 match tensor {
151 TensorPrimitive::Float(tensor) => {
152 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
153 }
154 TensorPrimitive::QFloat(tensor) => {
155 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
156 }
157 }
158 }
159
160 fn scatter(
161 dim: usize,
162 tensor: Self::Primitive,
163 indices: IntTensor<B>,
164 values: Self::Primitive,
165 update: IndexingUpdateOp,
166 ) -> Self::Primitive {
167 match update {
168 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
169 dim,
170 tensor.tensor(),
171 indices,
172 values.tensor(),
173 )),
174 _ => unimplemented!(),
175 }
176 }
177
178 fn scatter_nd(
179 data: Self::Primitive,
180 indices: IntTensor<B>,
181 values: Self::Primitive,
182 reduction: IndexingUpdateOp,
183 ) -> Self::Primitive {
184 TensorPrimitive::Float(B::float_scatter_nd(
185 data.tensor(),
186 indices,
187 values.tensor(),
188 reduction,
189 ))
190 }
191
192 fn gather_nd(data: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
193 TensorPrimitive::Float(B::float_gather_nd(data.tensor(), indices))
194 }
195
196 fn device(tensor: &Self::Primitive) -> Device<B> {
197 match tensor {
198 TensorPrimitive::Float(tensor) => B::float_device(tensor),
199 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
200 }
201 }
202
203 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
204 match tensor {
205 TensorPrimitive::Float(tensor) => {
206 TensorPrimitive::Float(B::float_to_device(tensor, device))
207 }
208 TensorPrimitive::QFloat(tensor) => {
209 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
210 }
211 }
212 }
213
214 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
215 match tensor {
216 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
217 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
218 }
219 }
220
221 fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
222 if matches!(data.dtype, DType::QFloat(_)) {
223 TensorPrimitive::QFloat(B::q_from_data(data, device))
225 } else if dtype.is_float() {
226 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
227 } else {
228 panic!("Expected float dtype, got {dtype:?}")
229 }
230 }
231
232 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
233 match tensor {
234 TensorPrimitive::Float(tensor) => {
235 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
236 }
237 TensorPrimitive::QFloat(tensor) => {
238 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
239 }
240 }
241 }
242
243 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
244 match vectors.first().unwrap() {
245 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
246 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
247 dim,
248 )),
249 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
250 vectors
251 .into_iter()
252 .map(|tensor| {
253 if let TensorPrimitive::QFloat(t) = tensor {
254 t
255 } else {
256 panic!("Concatenation only works with vector of QFloat")
257 }
258 })
259 .collect(),
260 dim,
261 )),
262 }
263 }
264
265 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
266 let lhs = lhs.tensor();
267 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
268 B::float_equal(lhs, rhs.tensor(), out_dtype)
269 }
270
271 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
272 let lhs = lhs.tensor();
273 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
274 B::float_not_equal(lhs, rhs.tensor(), out_dtype)
275 }
276
277 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
278 let lhs = lhs.tensor();
279 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
280 B::float_equal_elem(lhs, rhs, out_dtype)
281 }
282
283 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
284 let lhs = lhs.tensor();
285 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
286 B::float_not_equal_elem(lhs, rhs, out_dtype)
287 }
288
289 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
290 let tensor = tensor.tensor();
291 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
292 B::float_any(tensor, out_dtype)
293 }
294
295 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
296 let tensor = tensor.tensor();
297 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
298 B::float_any_dim(tensor, dim, out_dtype)
299 }
300
301 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
302 let tensor = tensor.tensor();
303 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
304 B::float_all(tensor, out_dtype)
305 }
306
307 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
308 let tensor = tensor.tensor();
309 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
310 B::float_all_dim(tensor, dim, out_dtype)
311 }
312
313 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
314 match tensor {
315 TensorPrimitive::Float(tensor) => {
316 TensorPrimitive::Float(B::float_permute(tensor, axes))
317 }
318 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
319 }
320 }
321
322 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
323 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
324 }
325
326 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
327 match tensor {
328 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
329 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
330 }
331 }
332
333 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
334 TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
335 }
336}
337
338impl<B: Backend> Numeric<B> for Float {
339 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
340 q_bin_ops!(lhs, rhs, float_add, q_add)
341 }
342
343 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
344 match lhs {
345 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
346 TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
347 }
348 }
349
350 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
351 q_bin_ops!(lhs, rhs, float_sub, q_sub)
352 }
353
354 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
355 match lhs {
356 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
357 TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
358 }
359 }
360
361 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
362 q_bin_ops!(lhs, rhs, float_div, q_div)
363 }
364
365 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
366 match lhs {
367 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
368 TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
369 }
370 }
371 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
372 TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
373 }
374
375 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
376 TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
377 }
378
379 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
380 q_bin_ops!(lhs, rhs, float_mul, q_mul)
381 }
382
383 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
384 match lhs {
385 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
386 TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
387 }
388 }
389 fn neg(tensor: Self::Primitive) -> Self::Primitive {
390 match tensor {
391 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
392 TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
393 }
394 }
395
396 fn sum(tensor: Self::Primitive) -> Self::Primitive {
397 match tensor {
398 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
399 TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
400 }
401 }
402
403 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
404 match tensor {
405 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
406 TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
407 }
408 }
409
410 fn prod(tensor: Self::Primitive) -> Self::Primitive {
411 match tensor {
412 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
413 TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
414 }
415 }
416
417 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
418 match tensor {
419 TensorPrimitive::Float(tensor) => {
420 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
421 }
422 TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
423 }
424 }
425
426 fn mean(tensor: Self::Primitive) -> Self::Primitive {
427 match tensor {
428 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
429 TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
430 }
431 }
432
433 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
434 match tensor {
435 TensorPrimitive::Float(tensor) => {
436 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
437 }
438 TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
439 }
440 }
441
442 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
443 match tensor {
444 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
445 TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
446 }
447 }
448
449 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
450 match tensor {
451 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
452 TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
453 }
454 }
455
456 fn abs(tensor: Self::Primitive) -> Self::Primitive {
457 match tensor {
458 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
459 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
460 }
461 }
462
463 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
464 q_bin_ops!(lhs, rhs, float_powf, q_powf)
465 }
466
467 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
468 match lhs {
469 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
470 TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
471 }
472 }
473
474 fn random(
475 shape: Shape,
476 distribution: Distribution,
477 device: &Device<B>,
478 dtype: DType,
479 ) -> Self::Primitive {
480 TensorPrimitive::Float(B::float_random(shape, distribution, device, dtype.into()))
481 }
482
483 fn sign(tensor: Self::Primitive) -> Self::Primitive {
484 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
485 }
486
487 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
495 match (lhs, rhs) {
496 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
497 TensorPrimitive::Float(B::float_matmul(lhs, rhs))
498 }
499 (lhs, rhs) => B::q_matmul(lhs, rhs),
500 }
501 }
502}
503impl<B: Backend> Ordered<B> for Float {
504 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
505 match tensor {
506 TensorPrimitive::Float(tensor) => {
507 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
508 }
509 TensorPrimitive::QFloat(tensor) => {
510 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
511 }
512 }
513 }
514
515 fn sort_with_indices(
516 tensor: Self::Primitive,
517 dim: usize,
518 descending: bool,
519 ) -> (Self::Primitive, IntTensor<B>) {
520 match tensor {
521 TensorPrimitive::Float(tensor) => {
522 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
523 let (values, indices) =
524 B::float_sort_with_indices(tensor, dim, descending, out_dtype);
525 (TensorPrimitive::Float(values), indices)
526 }
527 TensorPrimitive::QFloat(tensor) => {
528 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
529 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending, out_dtype);
530 (TensorPrimitive::QFloat(values), indices)
531 }
532 }
533 }
534
535 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
536 match tensor {
537 TensorPrimitive::Float(tensor) => {
538 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
539 B::float_argsort(tensor, dim, descending, out_dtype)
540 }
541 TensorPrimitive::QFloat(tensor) => {
542 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
543 B::q_argsort(tensor, dim, descending, out_dtype)
544 }
545 }
546 }
547
548 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
549 match tensor {
550 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
551 TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
552 }
553 }
554
555 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
556 match tensor {
557 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
558 TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
559 }
560 }
561
562 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
563 let lhs = lhs.tensor();
564 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
565 B::float_greater(lhs, rhs.tensor(), out_dtype)
566 }
567
568 fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
569 let lhs = lhs.tensor();
570 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
571 B::float_greater_elem(lhs, rhs, out_dtype)
572 }
573
574 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
575 let lhs = lhs.tensor();
576 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
577 B::float_greater_equal(lhs, rhs.tensor(), out_dtype)
578 }
579
580 fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
581 let lhs = lhs.tensor();
582 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
583 B::float_greater_equal_elem(lhs, rhs, out_dtype)
584 }
585
586 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
587 let lhs = lhs.tensor();
588 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
589 B::float_lower(lhs, rhs.tensor(), out_dtype)
590 }
591
592 fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
593 let lhs = lhs.tensor();
594 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
595 B::float_lower_elem(lhs, rhs, out_dtype)
596 }
597
598 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
599 let lhs = lhs.tensor();
600 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
601 B::float_lower_equal(lhs, rhs.tensor(), out_dtype)
602 }
603
604 fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
605 let lhs = lhs.tensor();
606 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
607 B::float_lower_equal_elem(lhs, rhs, out_dtype)
608 }
609
610 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
611 match tensor {
612 TensorPrimitive::Float(tensor) => {
613 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
614 B::float_argmax(tensor, dim, out_dtype)
615 }
616 TensorPrimitive::QFloat(tensor) => {
617 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
618 B::q_argmax(tensor, dim, out_dtype)
619 }
620 }
621 }
622
623 fn argtopk(tensor: Self::Primitive, dim: usize, k: usize) -> IntTensor<B> {
624 match tensor {
625 TensorPrimitive::Float(tensor) => {
626 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
627 B::float_argtopk(tensor, dim, k, out_dtype)
628 }
629 TensorPrimitive::QFloat(tensor) => {
630 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
631 B::q_argtopk(tensor, dim, k, out_dtype)
632 }
633 }
634 }
635
636 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
637 match tensor {
638 TensorPrimitive::Float(tensor) => {
639 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
640 B::float_argmin(tensor, dim, out_dtype)
641 }
642 TensorPrimitive::QFloat(tensor) => {
643 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
644 B::q_argmin(tensor, dim, out_dtype)
645 }
646 }
647 }
648
649 fn max(tensor: Self::Primitive) -> Self::Primitive {
650 match tensor {
651 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
652 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
653 }
654 }
655
656 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
657 match tensor {
658 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
659 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
660 }
661 }
662
663 fn topk(tensor: Self::Primitive, dim: usize, k: usize) -> Self::Primitive {
664 match tensor {
665 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_topk(tensor, dim, k)),
666 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_topk(tensor, dim, k)),
667 }
668 }
669
670 fn max_dim_with_indices(
671 tensor: Self::Primitive,
672 dim: usize,
673 ) -> (Self::Primitive, IntTensor<B>) {
674 match tensor {
675 TensorPrimitive::Float(tensor) => {
676 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
677 let (values, indices) = B::float_max_dim_with_indices(tensor, dim, out_dtype);
678 (TensorPrimitive::Float(values), indices)
679 }
680 TensorPrimitive::QFloat(tensor) => {
681 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
682 let (values, indices) = B::q_max_dim_with_indices(tensor, dim, out_dtype);
683 (TensorPrimitive::QFloat(values), indices)
684 }
685 }
686 }
687
688 fn min(tensor: Self::Primitive) -> Self::Primitive {
689 match tensor {
690 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
691 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
692 }
693 }
694
695 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
696 match tensor {
697 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
698 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
699 }
700 }
701
702 fn min_dim_with_indices(
703 tensor: Self::Primitive,
704 dim: usize,
705 ) -> (Self::Primitive, IntTensor<B>) {
706 match tensor {
707 TensorPrimitive::Float(tensor) => {
708 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
709 let (values, indices) = B::float_min_dim_with_indices(tensor, dim, out_dtype);
710 (TensorPrimitive::Float(values), indices)
711 }
712 TensorPrimitive::QFloat(tensor) => {
713 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
714 let (values, indices) = B::q_min_dim_with_indices(tensor, dim, out_dtype);
715 (TensorPrimitive::QFloat(values), indices)
716 }
717 }
718 }
719
720 fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
721 match tensor {
722 TensorPrimitive::Float(tensor) => {
723 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
724 }
725 TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
726 }
727 }
728
729 fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
730 match tensor {
731 TensorPrimitive::Float(tensor) => {
732 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
733 }
734 TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
735 }
736 }
737
738 fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
739 match tensor {
740 TensorPrimitive::Float(tensor) => {
741 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
742 }
743 TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
744 }
745 }
746
747 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
748 match tensor {
749 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
750 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
751 }
752 }
753
754 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
755 match tensor {
756 TensorPrimitive::Float(tensor) => {
757 TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
758 }
759 TensorPrimitive::QFloat(tensor) => {
760 TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
761 }
762 }
763 }
764}
765
766#[cfg_attr(doc, doc = crate::doc_tensor!())]
772#[cfg_attr(not(doc), doc = "`Tensor`")]
773pub trait FloatMathOps<B: Backend>: Numeric<B> {
775 #[cfg_attr(doc, doc = "$y_i = x^{2}$")]
778 #[cfg_attr(not(doc), doc = "`y = x^2`")]
779 fn square(tensor: Self::Primitive) -> Self::Primitive;
780
781 #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
784 #[cfg_attr(not(doc), doc = "`y = e^x`")]
785 fn exp(tensor: Self::Primitive) -> Self::Primitive;
786
787 #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
790 #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
791 fn log1p(tensor: Self::Primitive) -> Self::Primitive;
792
793 #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
796 #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
797 fn log(tensor: Self::Primitive) -> Self::Primitive;
798
799 #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
802 #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
803 fn sqrt(tensor: Self::Primitive) -> Self::Primitive;
804 #[cfg_attr(doc, doc = crate::doc_tensor!("cos"))]
822 #[cfg_attr(not(doc), doc = "`Tensor::cos`")]
823 fn cos(tensor: Self::Primitive) -> Self::Primitive;
825
826 #[cfg_attr(doc, doc = crate::doc_tensor!("sin"))]
844 #[cfg_attr(not(doc), doc = "`Tensor::sin`")]
845 fn sin(tensor: Self::Primitive) -> Self::Primitive;
847
848 #[cfg_attr(doc, doc = crate::doc_tensor!("tan"))]
866 #[cfg_attr(not(doc), doc = "`Tensor::tan`")]
867 fn tan(tensor: Self::Primitive) -> Self::Primitive;
869
870 #[cfg_attr(doc, doc = crate::doc_tensor!("cosh"))]
888 #[cfg_attr(not(doc), doc = "`Tensor::cosh`")]
889 fn cosh(tensor: Self::Primitive) -> Self::Primitive;
891
892 #[cfg_attr(doc, doc = crate::doc_tensor!("sinh"))]
910 #[cfg_attr(not(doc), doc = "`Tensor::sinh`")]
911 fn sinh(tensor: Self::Primitive) -> Self::Primitive;
913
914 #[cfg_attr(doc, doc = crate::doc_tensor!("tanh"))]
932 #[cfg_attr(not(doc), doc = "`Tensor::tanh`")]
933 fn tanh(tensor: Self::Primitive) -> Self::Primitive;
935
936 #[cfg_attr(doc, doc = crate::doc_tensor!("acos"))]
954 #[cfg_attr(not(doc), doc = "`Tensor::acos`")]
955 fn acos(tensor: Self::Primitive) -> Self::Primitive;
957
958 #[cfg_attr(doc, doc = crate::doc_tensor!("acosh"))]
976 #[cfg_attr(not(doc), doc = "`Tensor::acosh`")]
977 fn acosh(tensor: Self::Primitive) -> Self::Primitive;
979
980 #[cfg_attr(doc, doc = crate::doc_tensor!("asin"))]
998 #[cfg_attr(not(doc), doc = "`Tensor::asin`")]
999 fn asin(tensor: Self::Primitive) -> Self::Primitive;
1001
1002 #[cfg_attr(doc, doc = crate::doc_tensor!("asinh"))]
1020 #[cfg_attr(not(doc), doc = "`Tensor::asinh`")]
1021 fn asinh(tensor: Self::Primitive) -> Self::Primitive;
1023
1024 #[cfg_attr(doc, doc = crate::doc_tensor!("atan"))]
1042 #[cfg_attr(not(doc), doc = "`Tensor::atan`")]
1043 fn atan(tensor: Self::Primitive) -> Self::Primitive;
1045
1046 #[cfg_attr(doc, doc = crate::doc_tensor!("atanh"))]
1064 #[cfg_attr(not(doc), doc = "`Tensor::atanh`")]
1065 fn atanh(tensor: Self::Primitive) -> Self::Primitive;
1067
1068 #[cfg_attr(doc, doc = crate::doc_tensor!("atan2"))]
1087 #[cfg_attr(not(doc), doc = "`Tensor::atan2`")]
1088 fn atan2(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
1090}
1091
1092impl<B: Backend> FloatMathOps<B> for Float {
1093 fn square(tensor: Self::Primitive) -> Self::Primitive {
1094 TensorPrimitive::Float(B::float_powi_scalar(tensor.tensor(), 2.into()))
1095 }
1096 fn sqrt(tensor: Self::Primitive) -> Self::Primitive {
1097 TensorPrimitive::Float(B::float_sqrt(tensor.tensor()))
1098 }
1099 fn cos(tensor: Self::Primitive) -> Self::Primitive {
1100 TensorPrimitive::Float(B::float_cos(tensor.tensor()))
1101 }
1102
1103 fn sin(tensor: Self::Primitive) -> Self::Primitive {
1104 TensorPrimitive::Float(B::float_sin(tensor.tensor()))
1105 }
1106
1107 fn tan(tensor: Self::Primitive) -> Self::Primitive {
1108 TensorPrimitive::Float(B::float_tan(tensor.tensor()))
1109 }
1110
1111 fn cosh(tensor: Self::Primitive) -> Self::Primitive {
1112 TensorPrimitive::Float(B::float_cosh(tensor.tensor()))
1113 }
1114
1115 fn sinh(tensor: Self::Primitive) -> Self::Primitive {
1116 TensorPrimitive::Float(B::float_sinh(tensor.tensor()))
1117 }
1118
1119 fn tanh(tensor: Self::Primitive) -> Self::Primitive {
1120 TensorPrimitive::Float(B::float_tanh(tensor.tensor()))
1121 }
1122
1123 fn acos(tensor: Self::Primitive) -> Self::Primitive {
1124 TensorPrimitive::Float(B::float_acos(tensor.tensor()))
1125 }
1126
1127 fn acosh(tensor: Self::Primitive) -> Self::Primitive {
1128 TensorPrimitive::Float(B::float_acosh(tensor.tensor()))
1129 }
1130
1131 fn asin(tensor: Self::Primitive) -> Self::Primitive {
1132 TensorPrimitive::Float(B::float_asin(tensor.tensor()))
1133 }
1134
1135 fn asinh(tensor: Self::Primitive) -> Self::Primitive {
1136 TensorPrimitive::Float(B::float_asinh(tensor.tensor()))
1137 }
1138
1139 fn atan(tensor: Self::Primitive) -> Self::Primitive {
1140 TensorPrimitive::Float(B::float_atan(tensor.tensor()))
1141 }
1142
1143 fn atanh(tensor: Self::Primitive) -> Self::Primitive {
1144 TensorPrimitive::Float(B::float_atanh(tensor.tensor()))
1145 }
1146
1147 fn atan2(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
1148 TensorPrimitive::Float(B::float_atan2(lhs.tensor(), rhs.tensor()))
1149 }
1150
1151 fn exp(tensor: Self::Primitive) -> Self::Primitive {
1152 TensorPrimitive::Float(B::float_exp(tensor.tensor()))
1153 }
1154
1155 fn log(tensor: Self::Primitive) -> Self::Primitive {
1156 TensorPrimitive::Float(B::float_log(tensor.tensor()))
1157 }
1158
1159 fn log1p(tensor: Self::Primitive) -> Self::Primitive {
1160 TensorPrimitive::Float(B::float_log1p(tensor.tensor()))
1161 }
1162}
1163
1164impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
1165 type InnerKind = Float;
1166
1167 fn inner(
1168 tensor: <Self as TensorKind<B>>::Primitive,
1169 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
1170 match tensor {
1171 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
1172 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
1173 }
1174 }
1175
1176 fn from_inner(
1177 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
1178 ) -> <Self as TensorKind<B>>::Primitive {
1179 match inner {
1180 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
1181 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
1182 }
1183 }
1184}