// burn_backend/tensor/ops/float.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorMetadata,
6    TensorPrimitive, get_device_settings,
7    ops::TransactionPrimitive,
8    tensor::{
9        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
10        TensorKind, TransactionOp,
11    },
12};
13
/// Dispatches a binary float operation over the four `TensorPrimitive` variant pairs.
///
/// * quantized ⊕ quantized → the quantized kernel `B::$q_op` (which returns a full primitive);
/// * float ⊕ float → the float kernel `B::$op`;
/// * mixed → the quantized side is dequantized into the dtype of the float side, then the float
///   kernel `B::$op` runs.
///
/// Expects a backend type parameter named `B` to be in scope at the expansion site.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            (TensorPrimitive::QFloat(a), TensorPrimitive::QFloat(b)) => B::$q_op(a, b),
            (TensorPrimitive::Float(a), TensorPrimitive::Float(b)) => {
                TensorPrimitive::Float(B::$op(a, b))
            }
            (TensorPrimitive::Float(a), TensorPrimitive::QFloat(b)) => {
                // Match the float operand's dtype so the float kernel sees uniform inputs.
                let target = a.dtype();
                TensorPrimitive::Float(B::$op(a, B::dequantize(b, target.into())))
            }
            (TensorPrimitive::QFloat(a), TensorPrimitive::Float(b)) => {
                let target = b.dtype();
                TensorPrimitive::Float(B::$op(B::dequantize(a, target.into()), b))
            }
        }
    };
}
32impl<B: Backend> TransactionOp<B> for Float {
33    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
34        tr.register_float(tensor);
35    }
36}
37impl<B: Backend> BasicOps<B> for Float {
38    type Elem = B::FloatElem;
39
40    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41        TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
42    }
43
44    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
45        TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
46    }
47    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
48        TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
49    }
50
51    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
52        TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
53    }
54
55    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
56        match tensor {
57            TensorPrimitive::Float(tensor) => {
58                TensorPrimitive::Float(B::float_reshape(tensor, shape))
59            }
60            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
61        }
62    }
63
64    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
65        match tensor {
66            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
67            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
68        }
69    }
70
71    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
72        match tensor {
73            TensorPrimitive::Float(tensor) => {
74                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
75            }
76            TensorPrimitive::QFloat(tensor) => {
77                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
78            }
79        }
80    }
81
82    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
83        match tensor {
84            TensorPrimitive::Float(tensor) => {
85                TensorPrimitive::Float(B::float_slice(tensor, slices))
86            }
87            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
88        }
89    }
90
91    fn slice_assign(
92        tensor: Self::Primitive,
93        slices: &[Slice],
94        value: Self::Primitive,
95    ) -> Self::Primitive {
96        TensorPrimitive::Float(B::float_slice_assign(
97            tensor.tensor(),
98            slices,
99            value.tensor(),
100        ))
101    }
102
103    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
104        match tensor {
105            TensorPrimitive::Float(tensor) => {
106                TensorPrimitive::Float(B::float_select(tensor, dim, indices))
107            }
108            TensorPrimitive::QFloat(tensor) => {
109                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
110            }
111        }
112    }
113
114    fn select_assign(
115        tensor: Self::Primitive,
116        dim: usize,
117        indices: IntTensor<B>,
118        values: Self::Primitive,
119        update: IndexingUpdateOp,
120    ) -> Self::Primitive {
121        // Select assign is ambiguous for QFloat
122        match update {
123            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
124                tensor.tensor(),
125                dim,
126                indices,
127                values.tensor(),
128            )),
129            _ => unimplemented!(),
130        }
131    }
132
133    fn mask_where(
134        tensor: Self::Primitive,
135        mask: B::BoolTensorPrimitive,
136        source: Self::Primitive,
137    ) -> Self::Primitive {
138        TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
139    }
140
141    fn mask_fill(
142        tensor: Self::Primitive,
143        mask: B::BoolTensorPrimitive,
144        value: Scalar,
145    ) -> Self::Primitive {
146        TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
147    }
148
149    fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
150        match tensor {
151            TensorPrimitive::Float(tensor) => {
152                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
153            }
154            TensorPrimitive::QFloat(tensor) => {
155                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
156            }
157        }
158    }
159
160    fn scatter(
161        dim: usize,
162        tensor: Self::Primitive,
163        indices: IntTensor<B>,
164        values: Self::Primitive,
165        update: IndexingUpdateOp,
166    ) -> Self::Primitive {
167        match update {
168            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
169                dim,
170                tensor.tensor(),
171                indices,
172                values.tensor(),
173            )),
174            _ => unimplemented!(),
175        }
176    }
177
178    fn scatter_nd(
179        data: Self::Primitive,
180        indices: IntTensor<B>,
181        values: Self::Primitive,
182        reduction: IndexingUpdateOp,
183    ) -> Self::Primitive {
184        TensorPrimitive::Float(B::float_scatter_nd(
185            data.tensor(),
186            indices,
187            values.tensor(),
188            reduction,
189        ))
190    }
191
192    fn gather_nd(data: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
193        TensorPrimitive::Float(B::float_gather_nd(data.tensor(), indices))
194    }
195
196    fn device(tensor: &Self::Primitive) -> Device<B> {
197        match tensor {
198            TensorPrimitive::Float(tensor) => B::float_device(tensor),
199            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
200        }
201    }
202
203    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
204        match tensor {
205            TensorPrimitive::Float(tensor) => {
206                TensorPrimitive::Float(B::float_to_device(tensor, device))
207            }
208            TensorPrimitive::QFloat(tensor) => {
209                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
210            }
211        }
212    }
213
214    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
215        match tensor {
216            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
217            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
218        }
219    }
220
221    fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
222        if matches!(data.dtype, DType::QFloat(_)) {
223            // When the source is QFloat, there is no conversion path possible.
224            TensorPrimitive::QFloat(B::q_from_data(data, device))
225        } else if dtype.is_float() {
226            TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
227        } else {
228            panic!("Expected float dtype, got {dtype:?}")
229        }
230    }
231
232    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
233        match tensor {
234            TensorPrimitive::Float(tensor) => {
235                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
236            }
237            TensorPrimitive::QFloat(tensor) => {
238                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
239            }
240        }
241    }
242
243    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
244        match vectors.first().unwrap() {
245            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
246                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
247                dim,
248            )),
249            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
250                vectors
251                    .into_iter()
252                    .map(|tensor| {
253                        if let TensorPrimitive::QFloat(t) = tensor {
254                            t
255                        } else {
256                            panic!("Concatenation only works with vector of QFloat")
257                        }
258                    })
259                    .collect(),
260                dim,
261            )),
262        }
263    }
264
265    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
266        let lhs = lhs.tensor();
267        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
268        B::float_equal(lhs, rhs.tensor(), out_dtype)
269    }
270
271    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
272        let lhs = lhs.tensor();
273        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
274        B::float_not_equal(lhs, rhs.tensor(), out_dtype)
275    }
276
277    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
278        let lhs = lhs.tensor();
279        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
280        B::float_equal_elem(lhs, rhs, out_dtype)
281    }
282
283    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
284        let lhs = lhs.tensor();
285        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
286        B::float_not_equal_elem(lhs, rhs, out_dtype)
287    }
288
289    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
290        let tensor = tensor.tensor();
291        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
292        B::float_any(tensor, out_dtype)
293    }
294
295    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
296        let tensor = tensor.tensor();
297        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
298        B::float_any_dim(tensor, dim, out_dtype)
299    }
300
301    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
302        let tensor = tensor.tensor();
303        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
304        B::float_all(tensor, out_dtype)
305    }
306
307    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
308        let tensor = tensor.tensor();
309        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
310        B::float_all_dim(tensor, dim, out_dtype)
311    }
312
313    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
314        match tensor {
315            TensorPrimitive::Float(tensor) => {
316                TensorPrimitive::Float(B::float_permute(tensor, axes))
317            }
318            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
319        }
320    }
321
322    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
323        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
324    }
325
326    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
327        match tensor {
328            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
329            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
330        }
331    }
332
333    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
334        TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
335    }
336}
337
338impl<B: Backend> Numeric<B> for Float {
339    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
340        q_bin_ops!(lhs, rhs, float_add, q_add)
341    }
342
343    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
344        match lhs {
345            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
346            TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
347        }
348    }
349
350    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
351        q_bin_ops!(lhs, rhs, float_sub, q_sub)
352    }
353
354    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
355        match lhs {
356            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
357            TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
358        }
359    }
360
361    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
362        q_bin_ops!(lhs, rhs, float_div, q_div)
363    }
364
365    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
366        match lhs {
367            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
368            TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
369        }
370    }
371    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
372        TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
373    }
374
375    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
376        TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
377    }
378
379    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
380        q_bin_ops!(lhs, rhs, float_mul, q_mul)
381    }
382
383    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
384        match lhs {
385            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
386            TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
387        }
388    }
389    fn neg(tensor: Self::Primitive) -> Self::Primitive {
390        match tensor {
391            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
392            TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
393        }
394    }
395
396    fn sum(tensor: Self::Primitive) -> Self::Primitive {
397        match tensor {
398            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
399            TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
400        }
401    }
402
403    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
404        match tensor {
405            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
406            TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
407        }
408    }
409
410    fn prod(tensor: Self::Primitive) -> Self::Primitive {
411        match tensor {
412            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
413            TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
414        }
415    }
416
417    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
418        match tensor {
419            TensorPrimitive::Float(tensor) => {
420                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
421            }
422            TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
423        }
424    }
425
426    fn mean(tensor: Self::Primitive) -> Self::Primitive {
427        match tensor {
428            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
429            TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
430        }
431    }
432
433    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
434        match tensor {
435            TensorPrimitive::Float(tensor) => {
436                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
437            }
438            TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
439        }
440    }
441
442    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
443        match tensor {
444            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
445            TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
446        }
447    }
448
449    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
450        match tensor {
451            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
452            TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
453        }
454    }
455
456    fn abs(tensor: Self::Primitive) -> Self::Primitive {
457        match tensor {
458            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
459            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
460        }
461    }
462
463    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
464        q_bin_ops!(lhs, rhs, float_powf, q_powf)
465    }
466
467    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
468        match lhs {
469            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
470            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
471        }
472    }
473
474    fn random(
475        shape: Shape,
476        distribution: Distribution,
477        device: &Device<B>,
478        dtype: DType,
479    ) -> Self::Primitive {
480        TensorPrimitive::Float(B::float_random(shape, distribution, device, dtype.into()))
481    }
482
483    fn sign(tensor: Self::Primitive) -> Self::Primitive {
484        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
485    }
486
487    /// Applies the matrix multiplication operation.
488    ///
489    /// `C = AB`
490    ///
491    /// # Panics
492    ///
493    /// If the two tensors don't have a compatible shape.
494    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
495        match (lhs, rhs) {
496            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
497                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
498            }
499            (lhs, rhs) => B::q_matmul(lhs, rhs),
500        }
501    }
502}
503impl<B: Backend> Ordered<B> for Float {
504    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
505        match tensor {
506            TensorPrimitive::Float(tensor) => {
507                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
508            }
509            TensorPrimitive::QFloat(tensor) => {
510                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
511            }
512        }
513    }
514
515    fn sort_with_indices(
516        tensor: Self::Primitive,
517        dim: usize,
518        descending: bool,
519    ) -> (Self::Primitive, IntTensor<B>) {
520        match tensor {
521            TensorPrimitive::Float(tensor) => {
522                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
523                let (values, indices) =
524                    B::float_sort_with_indices(tensor, dim, descending, out_dtype);
525                (TensorPrimitive::Float(values), indices)
526            }
527            TensorPrimitive::QFloat(tensor) => {
528                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
529                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending, out_dtype);
530                (TensorPrimitive::QFloat(values), indices)
531            }
532        }
533    }
534
535    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
536        match tensor {
537            TensorPrimitive::Float(tensor) => {
538                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
539                B::float_argsort(tensor, dim, descending, out_dtype)
540            }
541            TensorPrimitive::QFloat(tensor) => {
542                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
543                B::q_argsort(tensor, dim, descending, out_dtype)
544            }
545        }
546    }
547
548    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
549        match tensor {
550            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
551            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
552        }
553    }
554
555    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
556        match tensor {
557            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
558            TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
559        }
560    }
561
562    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
563        let lhs = lhs.tensor();
564        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
565        B::float_greater(lhs, rhs.tensor(), out_dtype)
566    }
567
568    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
569        let lhs = lhs.tensor();
570        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
571        B::float_greater_elem(lhs, rhs, out_dtype)
572    }
573
574    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
575        let lhs = lhs.tensor();
576        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
577        B::float_greater_equal(lhs, rhs.tensor(), out_dtype)
578    }
579
580    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
581        let lhs = lhs.tensor();
582        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
583        B::float_greater_equal_elem(lhs, rhs, out_dtype)
584    }
585
586    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
587        let lhs = lhs.tensor();
588        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
589        B::float_lower(lhs, rhs.tensor(), out_dtype)
590    }
591
592    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
593        let lhs = lhs.tensor();
594        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
595        B::float_lower_elem(lhs, rhs, out_dtype)
596    }
597
598    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
599        let lhs = lhs.tensor();
600        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
601        B::float_lower_equal(lhs, rhs.tensor(), out_dtype)
602    }
603
604    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
605        let lhs = lhs.tensor();
606        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
607        B::float_lower_equal_elem(lhs, rhs, out_dtype)
608    }
609
610    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
611        match tensor {
612            TensorPrimitive::Float(tensor) => {
613                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
614                B::float_argmax(tensor, dim, out_dtype)
615            }
616            TensorPrimitive::QFloat(tensor) => {
617                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
618                B::q_argmax(tensor, dim, out_dtype)
619            }
620        }
621    }
622
623    fn argtopk(tensor: Self::Primitive, dim: usize, k: usize) -> IntTensor<B> {
624        match tensor {
625            TensorPrimitive::Float(tensor) => {
626                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
627                B::float_argtopk(tensor, dim, k, out_dtype)
628            }
629            TensorPrimitive::QFloat(tensor) => {
630                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
631                B::q_argtopk(tensor, dim, k, out_dtype)
632            }
633        }
634    }
635
636    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
637        match tensor {
638            TensorPrimitive::Float(tensor) => {
639                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
640                B::float_argmin(tensor, dim, out_dtype)
641            }
642            TensorPrimitive::QFloat(tensor) => {
643                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
644                B::q_argmin(tensor, dim, out_dtype)
645            }
646        }
647    }
648
649    fn max(tensor: Self::Primitive) -> Self::Primitive {
650        match tensor {
651            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
652            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
653        }
654    }
655
656    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
657        match tensor {
658            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
659            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
660        }
661    }
662
663    fn topk(tensor: Self::Primitive, dim: usize, k: usize) -> Self::Primitive {
664        match tensor {
665            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_topk(tensor, dim, k)),
666            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_topk(tensor, dim, k)),
667        }
668    }
669
670    fn max_dim_with_indices(
671        tensor: Self::Primitive,
672        dim: usize,
673    ) -> (Self::Primitive, IntTensor<B>) {
674        match tensor {
675            TensorPrimitive::Float(tensor) => {
676                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
677                let (values, indices) = B::float_max_dim_with_indices(tensor, dim, out_dtype);
678                (TensorPrimitive::Float(values), indices)
679            }
680            TensorPrimitive::QFloat(tensor) => {
681                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
682                let (values, indices) = B::q_max_dim_with_indices(tensor, dim, out_dtype);
683                (TensorPrimitive::QFloat(values), indices)
684            }
685        }
686    }
687
688    fn min(tensor: Self::Primitive) -> Self::Primitive {
689        match tensor {
690            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
691            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
692        }
693    }
694
695    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
696        match tensor {
697            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
698            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
699        }
700    }
701
702    fn min_dim_with_indices(
703        tensor: Self::Primitive,
704        dim: usize,
705    ) -> (Self::Primitive, IntTensor<B>) {
706        match tensor {
707            TensorPrimitive::Float(tensor) => {
708                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
709                let (values, indices) = B::float_min_dim_with_indices(tensor, dim, out_dtype);
710                (TensorPrimitive::Float(values), indices)
711            }
712            TensorPrimitive::QFloat(tensor) => {
713                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
714                let (values, indices) = B::q_min_dim_with_indices(tensor, dim, out_dtype);
715                (TensorPrimitive::QFloat(values), indices)
716            }
717        }
718    }
719
720    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
721        match tensor {
722            TensorPrimitive::Float(tensor) => {
723                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
724            }
725            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
726        }
727    }
728
729    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
730        match tensor {
731            TensorPrimitive::Float(tensor) => {
732                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
733            }
734            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
735        }
736    }
737
738    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
739        match tensor {
740            TensorPrimitive::Float(tensor) => {
741                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
742            }
743            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
744        }
745    }
746
747    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
748        match tensor {
749            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
750            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
751        }
752    }
753
754    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
755        match tensor {
756            TensorPrimitive::Float(tensor) => {
757                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
758            }
759            TensorPrimitive::QFloat(tensor) => {
760                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
761            }
762        }
763    }
764}
765
766/// Trait that lists some floating-point mathematical operations are common to all float-like dtypes.
767///
768/// # Warnings
769///
770/// This is an internal trait, use the public API provided by the
771#[cfg_attr(doc, doc = crate::doc_tensor!())]
772#[cfg_attr(not(doc), doc = "`Tensor`")]
773/// struct.
774pub trait FloatMathOps<B: Backend>: Numeric<B> {
775    /// Applies element wise square operation
776    ///
777    #[cfg_attr(doc, doc = "$y_i = x^{2}$")]
778    #[cfg_attr(not(doc), doc = "`y = x^2`")]
779    fn square(tensor: Self::Primitive) -> Self::Primitive;
780
781    /// Applies element wise exponential operation.
782    ///
783    #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")]
784    #[cfg_attr(not(doc), doc = "`y = e^x`")]
785    fn exp(tensor: Self::Primitive) -> Self::Primitive;
786
787    /// Applies the natural logarithm of one plus the input tensor, element-wise.
788    ///
789    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)]
790    #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")]
791    fn log1p(tensor: Self::Primitive) -> Self::Primitive;
792
793    /// Applies element wise natural log operation *ln*.
794    ///
795    #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)]
796    #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")]
797    fn log(tensor: Self::Primitive) -> Self::Primitive;
798
799    /// Applies element wise root square operation.
800    ///
801    #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)]
802    #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")]
803    fn sqrt(tensor: Self::Primitive) -> Self::Primitive;
804    /// Returns a new tensor with cosine values.
805    ///
806    /// # Arguments
807    ///
808    /// * `tensor` - The input tensor.
809    ///
810    /// # Returns
811    ///
812    /// A tensor with the same shape as `tensor` with cosine values.
813    ///
814    /// # Remarks
815    ///
816    /// This is a low-level function used internally by the library to call different backend functions
817    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
818    /// or use this function directly.
819    ///
820    /// For the cosine of a tensor, users should prefer the
821    #[cfg_attr(doc, doc = crate::doc_tensor!("cos"))]
822    #[cfg_attr(not(doc), doc = "`Tensor::cos`")]
823    /// function, which is more high-level and designed for public use.
824    fn cos(tensor: Self::Primitive) -> Self::Primitive;
825
826    /// Returns a new tensor with sine values.
827    ///
828    /// # Arguments
829    ///
830    /// * `tensor` - The input tensor.
831    ///
832    /// # Returns
833    ///
834    /// A tensor with the same shape as `tensor` with sine values.
835    ///
836    /// # Remarks
837    ///
838    /// This is a low-level function used internally by the library to call different backend functions
839    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
840    /// or use this function directly.
841    ///
842    /// For the sine of a tensor, users should prefer the
843    #[cfg_attr(doc, doc = crate::doc_tensor!("sin"))]
844    #[cfg_attr(not(doc), doc = "`Tensor::sin`")]
845    /// function, which is more high-level and designed for public use.
846    fn sin(tensor: Self::Primitive) -> Self::Primitive;
847
848    /// Returns a new tensor with tangent values.
849    ///
850    /// # Arguments
851    ///
852    /// * `tensor` - The input tensor.
853    ///
854    /// # Returns
855    ///
856    /// A tensor with the same shape as `tensor` with tangent values.
857    ///
858    /// # Remarks
859    ///
860    /// This is a low-level function used internally by the library to call different backend functions
861    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
862    /// or use this function directly.
863    ///
864    /// For the tangent of a tensor, users should prefer the
865    #[cfg_attr(doc, doc = crate::doc_tensor!("tan"))]
866    #[cfg_attr(not(doc), doc = "`Tensor::tan`")]
867    /// function, which is more high-level and designed for public use.
868    fn tan(tensor: Self::Primitive) -> Self::Primitive;
869
870    /// Returns a new tensor with hyperbolic cosine values.
871    ///
872    /// # Arguments
873    ///
874    /// * `tensor` - The input tensor.
875    ///
876    /// # Returns
877    ///
878    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
879    ///
880    /// # Remarks
881    ///
882    /// This is a low-level function used internally by the library to call different backend functions
883    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
884    /// or use this function directly.
885    ///
886    /// For the hyperbolic cosine of a tensor, users should prefer the
887    #[cfg_attr(doc, doc = crate::doc_tensor!("cosh"))]
888    #[cfg_attr(not(doc), doc = "`Tensor::cosh`")]
889    /// function, which is more high-level and designed for public use.
890    fn cosh(tensor: Self::Primitive) -> Self::Primitive;
891
892    /// Returns a new tensor with hyperbolic sine values.
893    ///
894    /// # Arguments
895    ///
896    /// * `tensor` - The input tensor.
897    ///
898    /// # Returns
899    ///
900    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
901    ///
902    /// # Remarks
903    ///
904    /// This is a low-level function used internally by the library to call different backend functions
905    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
906    /// or use this function directly.
907    ///
908    /// For the hyperbolic sine of a tensor, users should prefer the
909    #[cfg_attr(doc, doc = crate::doc_tensor!("sinh"))]
910    #[cfg_attr(not(doc), doc = "`Tensor::sinh`")]
911    /// function, which is more high-level and designed for public use.
912    fn sinh(tensor: Self::Primitive) -> Self::Primitive;
913
914    /// Returns a new tensor with hyperbolic tangent values.
915    ///
916    /// # Arguments
917    ///
918    /// * `tensor` - The input tensor.
919    ///
920    /// # Returns
921    ///
922    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
923    ///
924    /// # Remarks
925    ///
926    /// This is a low-level function used internally by the library to call different backend functions
927    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
928    /// or use this function directly.
929    ///
930    /// For the hyperbolic tangent of a tensor, users should prefer the
931    #[cfg_attr(doc, doc = crate::doc_tensor!("tanh"))]
932    #[cfg_attr(not(doc), doc = "`Tensor::tanh`")]
933    /// function, which is more high-level and designed for public use.
934    fn tanh(tensor: Self::Primitive) -> Self::Primitive;
935
936    /// Returns a new tensor with inverse cosine values.
937    ///
938    /// # Arguments
939    ///
940    /// * `tensor` - The input tensor.
941    ///
942    /// # Returns
943    ///
944    /// A tensor with the same shape as `tensor` with inverse cosine values.
945    ///
946    /// # Remarks
947    ///
948    /// This is a low-level function used internally by the library to call different backend functions
949    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
950    /// or use this function directly.
951    ///
952    /// For the inverse cosine of a tensor, users should prefer the
953    #[cfg_attr(doc, doc = crate::doc_tensor!("acos"))]
954    #[cfg_attr(not(doc), doc = "`Tensor::acos`")]
955    /// function, which is more high-level and designed for public use.
956    fn acos(tensor: Self::Primitive) -> Self::Primitive;
957
958    /// Returns a new tensor with inverse hyperbolic cosine values.
959    ///
960    /// # Arguments
961    ///
962    /// * `tensor` - The input tensor.
963    ///
964    /// # Returns
965    ///
966    /// A tensor with the same shape as `tensor` with inverse hyperbolic cosine values.
967    ///
968    /// # Remarks
969    ///
970    /// This is a low-level function used internally by the library to call different backend functions
971    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
972    /// or use this function directly.
973    ///
974    /// For the inverse hyperbolic cosine of a tensor, users should prefer the
975    #[cfg_attr(doc, doc = crate::doc_tensor!("acosh"))]
976    #[cfg_attr(not(doc), doc = "`Tensor::acosh`")]
977    /// function, which is more high-level and designed for public use.
978    fn acosh(tensor: Self::Primitive) -> Self::Primitive;
979
980    /// Returns a new tensor with inverse sine values.
981    ///
982    /// # Arguments
983    ///
984    /// * `tensor` - The input tensor.
985    ///
986    /// # Returns
987    ///
988    /// A tensor with the same shape as `tensor` with inverse sine values.
989    ///
990    /// # Remarks
991    ///
992    /// This is a low-level function used internally by the library to call different backend functions
993    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
994    /// or use this function directly.
995    ///
996    /// For the inverse sine of a tensor, users should prefer the
997    #[cfg_attr(doc, doc = crate::doc_tensor!("asin"))]
998    #[cfg_attr(not(doc), doc = "`Tensor::asin`")]
999    /// function, which is more high-level and designed for public use.
1000    fn asin(tensor: Self::Primitive) -> Self::Primitive;
1001
1002    /// Returns a new tensor with inverse hyperbolic sine values.
1003    ///
1004    /// # Arguments
1005    ///
1006    /// * `tensor` - The input tensor.
1007    ///
1008    /// # Returns
1009    ///
1010    /// A tensor with the same shape as `tensor` with inverse hyperbolic sine values.
1011    ///
1012    /// # Remarks
1013    ///
1014    /// This is a low-level function used internally by the library to call different backend functions
1015    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1016    /// or use this function directly.
1017    ///
1018    /// For the inverse hyperbolic sine of a tensor, users should prefer the
1019    #[cfg_attr(doc, doc = crate::doc_tensor!("asinh"))]
1020    #[cfg_attr(not(doc), doc = "`Tensor::asinh`")]
1021    /// function, which is more high-level and designed for public use.
1022    fn asinh(tensor: Self::Primitive) -> Self::Primitive;
1023
1024    /// Returns a new tensor with inverse tangent values.
1025    ///
1026    /// # Arguments
1027    ///
1028    /// * `tensor` - The input tensor.
1029    ///
1030    /// # Returns
1031    ///
1032    /// A tensor with the same shape as `tensor` with inverse tangent values.
1033    ///
1034    /// # Remarks
1035    ///
1036    /// This is a low-level function used internally by the library to call different backend functions
1037    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1038    /// or use this function directly.
1039    ///
1040    /// For the inverse tangent of a tensor, users should prefer the
1041    #[cfg_attr(doc, doc = crate::doc_tensor!("atan"))]
1042    #[cfg_attr(not(doc), doc = "`Tensor::atan`")]
1043    /// function, which is more high-level and designed for public use.
1044    fn atan(tensor: Self::Primitive) -> Self::Primitive;
1045
1046    /// Returns a new tensor with inverse hyperbolic tangent values.
1047    ///
1048    /// # Arguments
1049    ///
1050    /// * `tensor` - The input tensor.
1051    ///
1052    /// # Returns
1053    ///
1054    /// A tensor with the same shape as `tensor` with inverse hyperbolic tangent values.
1055    ///
1056    /// # Remarks
1057    ///
1058    /// This is a low-level function used internally by the library to call different backend functions
1059    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1060    /// or use this function directly.
1061    ///
1062    /// For the inverse hyperbolic tangent of a tensor, users should prefer the
1063    #[cfg_attr(doc, doc = crate::doc_tensor!("atanh"))]
1064    #[cfg_attr(not(doc), doc = "`Tensor::atanh`")]
1065    /// function, which is more high-level and designed for public use.
1066    fn atanh(tensor: Self::Primitive) -> Self::Primitive;
1067
1068    /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.
1069    ///
1070    /// # Arguments
1071    ///
1072    /// * `lhs` - The tensor with y coordinates.
1073    /// * `rhs` - The tensor with x coordinates.
1074    ///
1075    /// # Returns
1076    ///
1077    /// A tensor with the four-quadrant inverse tangent values.
1078    ///
1079    /// # Remarks
1080    ///
1081    /// This is a low-level function used internally by the library to call different backend functions
1082    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1083    /// or use this function directly.
1084    ///
1085    /// For the four-quadrant inverse tangent of two tensors, users should prefer the
1086    #[cfg_attr(doc, doc = crate::doc_tensor!("atan2"))]
1087    #[cfg_attr(not(doc), doc = "`Tensor::atan2`")]
1088    /// function, which is more high-level and designed for public use.
1089    fn atan2(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
1090}
1091
1092impl<B: Backend> FloatMathOps<B> for Float {
1093    fn square(tensor: Self::Primitive) -> Self::Primitive {
1094        TensorPrimitive::Float(B::float_powi_scalar(tensor.tensor(), 2.into()))
1095    }
1096    fn sqrt(tensor: Self::Primitive) -> Self::Primitive {
1097        TensorPrimitive::Float(B::float_sqrt(tensor.tensor()))
1098    }
1099    fn cos(tensor: Self::Primitive) -> Self::Primitive {
1100        TensorPrimitive::Float(B::float_cos(tensor.tensor()))
1101    }
1102
1103    fn sin(tensor: Self::Primitive) -> Self::Primitive {
1104        TensorPrimitive::Float(B::float_sin(tensor.tensor()))
1105    }
1106
1107    fn tan(tensor: Self::Primitive) -> Self::Primitive {
1108        TensorPrimitive::Float(B::float_tan(tensor.tensor()))
1109    }
1110
1111    fn cosh(tensor: Self::Primitive) -> Self::Primitive {
1112        TensorPrimitive::Float(B::float_cosh(tensor.tensor()))
1113    }
1114
1115    fn sinh(tensor: Self::Primitive) -> Self::Primitive {
1116        TensorPrimitive::Float(B::float_sinh(tensor.tensor()))
1117    }
1118
1119    fn tanh(tensor: Self::Primitive) -> Self::Primitive {
1120        TensorPrimitive::Float(B::float_tanh(tensor.tensor()))
1121    }
1122
1123    fn acos(tensor: Self::Primitive) -> Self::Primitive {
1124        TensorPrimitive::Float(B::float_acos(tensor.tensor()))
1125    }
1126
1127    fn acosh(tensor: Self::Primitive) -> Self::Primitive {
1128        TensorPrimitive::Float(B::float_acosh(tensor.tensor()))
1129    }
1130
1131    fn asin(tensor: Self::Primitive) -> Self::Primitive {
1132        TensorPrimitive::Float(B::float_asin(tensor.tensor()))
1133    }
1134
1135    fn asinh(tensor: Self::Primitive) -> Self::Primitive {
1136        TensorPrimitive::Float(B::float_asinh(tensor.tensor()))
1137    }
1138
1139    fn atan(tensor: Self::Primitive) -> Self::Primitive {
1140        TensorPrimitive::Float(B::float_atan(tensor.tensor()))
1141    }
1142
1143    fn atanh(tensor: Self::Primitive) -> Self::Primitive {
1144        TensorPrimitive::Float(B::float_atanh(tensor.tensor()))
1145    }
1146
1147    fn atan2(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
1148        TensorPrimitive::Float(B::float_atan2(lhs.tensor(), rhs.tensor()))
1149    }
1150
1151    fn exp(tensor: Self::Primitive) -> Self::Primitive {
1152        TensorPrimitive::Float(B::float_exp(tensor.tensor()))
1153    }
1154
1155    fn log(tensor: Self::Primitive) -> Self::Primitive {
1156        TensorPrimitive::Float(B::float_log(tensor.tensor()))
1157    }
1158
1159    fn log1p(tensor: Self::Primitive) -> Self::Primitive {
1160        TensorPrimitive::Float(B::float_log1p(tensor.tensor()))
1161    }
1162}
1163
1164impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
1165    type InnerKind = Float;
1166
1167    fn inner(
1168        tensor: <Self as TensorKind<B>>::Primitive,
1169    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
1170        match tensor {
1171            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
1172            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
1173        }
1174    }
1175
1176    fn from_inner(
1177        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
1178    ) -> <Self as TensorKind<B>>::Primitive {
1179        match inner {
1180            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
1181            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
1182        }
1183    }
1184}