Skip to main content

burn_backend/tensor/ops/
float.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorMetadata,
6    TensorPrimitive, get_device_settings,
7    ops::TransactionPrimitive,
8    tensor::{
9        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
10        TensorKind, TransactionOp,
11    },
12};
13
// Dispatches an element-wise binary float op over two `TensorPrimitive` operands.
//
// - `QFloat` with `QFloat`: the quantized kernel `B::$q_op` is called and its result (already a
//   `TensorPrimitive`) is returned as-is.
// - `Float` with `Float`: the float kernel `B::$op` is called and wrapped in `TensorPrimitive::Float`.
// - mixed: the quantized operand is dequantized to the dtype of the float operand, then `B::$op`
//   runs and the result is wrapped in `TensorPrimitive::Float`.
//
// The expansion site must have a generic `B: Backend` in scope.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            (TensorPrimitive::QFloat(q_lhs), TensorPrimitive::QFloat(q_rhs)) => {
                B::$q_op(q_lhs, q_rhs)
            }
            (TensorPrimitive::Float(f_lhs), TensorPrimitive::Float(f_rhs)) => {
                TensorPrimitive::Float(B::$op(f_lhs, f_rhs))
            }
            (TensorPrimitive::Float(f_lhs), TensorPrimitive::QFloat(q_rhs)) => {
                // Match the float side's dtype so both operands agree before the float kernel.
                let target = f_lhs.dtype();
                TensorPrimitive::Float(B::$op(f_lhs, B::dequantize(q_rhs, target.into())))
            }
            (TensorPrimitive::QFloat(q_lhs), TensorPrimitive::Float(f_rhs)) => {
                let target = f_rhs.dtype();
                TensorPrimitive::Float(B::$op(B::dequantize(q_lhs, target.into()), f_rhs))
            }
        }
    };
}
32impl<B: Backend> TransactionOp<B> for Float {
33    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
34        tr.register_float(tensor);
35    }
36}
37impl<B: Backend> BasicOps<B> for Float {
38    type Elem = B::FloatElem;
39
40    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41        TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
42    }
43
44    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
45        TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
46    }
47    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
48        TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
49    }
50
51    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
52        TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
53    }
54
55    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
56        match tensor {
57            TensorPrimitive::Float(tensor) => {
58                TensorPrimitive::Float(B::float_reshape(tensor, shape))
59            }
60            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
61        }
62    }
63
64    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
65        match tensor {
66            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
67            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
68        }
69    }
70
71    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
72        match tensor {
73            TensorPrimitive::Float(tensor) => {
74                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
75            }
76            TensorPrimitive::QFloat(tensor) => {
77                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
78            }
79        }
80    }
81
82    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
83        match tensor {
84            TensorPrimitive::Float(tensor) => {
85                TensorPrimitive::Float(B::float_slice(tensor, slices))
86            }
87            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
88        }
89    }
90
91    fn slice_assign(
92        tensor: Self::Primitive,
93        slices: &[Slice],
94        value: Self::Primitive,
95    ) -> Self::Primitive {
96        TensorPrimitive::Float(B::float_slice_assign(
97            tensor.tensor(),
98            slices,
99            value.tensor(),
100        ))
101    }
102
103    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
104        match tensor {
105            TensorPrimitive::Float(tensor) => {
106                TensorPrimitive::Float(B::float_select(tensor, dim, indices))
107            }
108            TensorPrimitive::QFloat(tensor) => {
109                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
110            }
111        }
112    }
113
114    fn select_assign(
115        tensor: Self::Primitive,
116        dim: usize,
117        indices: IntTensor<B>,
118        values: Self::Primitive,
119        update: IndexingUpdateOp,
120    ) -> Self::Primitive {
121        // Select assign is ambiguous for QFloat
122        match update {
123            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
124                tensor.tensor(),
125                dim,
126                indices,
127                values.tensor(),
128            )),
129        }
130    }
131
132    fn mask_where(
133        tensor: Self::Primitive,
134        mask: B::BoolTensorPrimitive,
135        source: Self::Primitive,
136    ) -> Self::Primitive {
137        TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
138    }
139
140    fn mask_fill(
141        tensor: Self::Primitive,
142        mask: B::BoolTensorPrimitive,
143        value: Scalar,
144    ) -> Self::Primitive {
145        TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
146    }
147
148    fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
149        match tensor {
150            TensorPrimitive::Float(tensor) => {
151                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
152            }
153            TensorPrimitive::QFloat(tensor) => {
154                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
155            }
156        }
157    }
158
159    fn scatter(
160        dim: usize,
161        tensor: Self::Primitive,
162        indices: IntTensor<B>,
163        values: Self::Primitive,
164        update: IndexingUpdateOp,
165    ) -> Self::Primitive {
166        match update {
167            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
168                dim,
169                tensor.tensor(),
170                indices,
171                values.tensor(),
172            )),
173        }
174    }
175
176    fn device(tensor: &Self::Primitive) -> Device<B> {
177        match tensor {
178            TensorPrimitive::Float(tensor) => B::float_device(tensor),
179            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
180        }
181    }
182
183    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
184        match tensor {
185            TensorPrimitive::Float(tensor) => {
186                TensorPrimitive::Float(B::float_to_device(tensor, device))
187            }
188            TensorPrimitive::QFloat(tensor) => {
189                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
190            }
191        }
192    }
193
194    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
195        match tensor {
196            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
197            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
198        }
199    }
200
201    fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
202        if matches!(data.dtype, DType::QFloat(_)) {
203            // When the source is QFloat, there is no conversion path possible.
204            TensorPrimitive::QFloat(B::q_from_data(data, device))
205        } else if dtype.is_float() {
206            TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
207        } else {
208            panic!("Expected float dtype, got {dtype:?}")
209        }
210    }
211
212    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
213        match tensor {
214            TensorPrimitive::Float(tensor) => {
215                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
216            }
217            TensorPrimitive::QFloat(tensor) => {
218                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
219            }
220        }
221    }
222
223    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
224        match vectors.first().unwrap() {
225            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
226                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
227                dim,
228            )),
229            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
230                vectors
231                    .into_iter()
232                    .map(|tensor| {
233                        if let TensorPrimitive::QFloat(t) = tensor {
234                            t
235                        } else {
236                            panic!("Concatenation only works with vector of QFloat")
237                        }
238                    })
239                    .collect(),
240                dim,
241            )),
242        }
243    }
244
245    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
246        let lhs = lhs.tensor();
247        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
248        B::float_equal(lhs, rhs.tensor(), out_dtype)
249    }
250
251    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
252        let lhs = lhs.tensor();
253        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
254        B::float_not_equal(lhs, rhs.tensor(), out_dtype)
255    }
256
257    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
258        let lhs = lhs.tensor();
259        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
260        B::float_equal_elem(lhs, rhs, out_dtype)
261    }
262
263    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
264        let lhs = lhs.tensor();
265        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
266        B::float_not_equal_elem(lhs, rhs, out_dtype)
267    }
268
269    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
270        let tensor = tensor.tensor();
271        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
272        B::float_any(tensor, out_dtype)
273    }
274
275    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
276        let tensor = tensor.tensor();
277        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
278        B::float_any_dim(tensor, dim, out_dtype)
279    }
280
281    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
282        let tensor = tensor.tensor();
283        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
284        B::float_all(tensor, out_dtype)
285    }
286
287    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
288        let tensor = tensor.tensor();
289        let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
290        B::float_all_dim(tensor, dim, out_dtype)
291    }
292
293    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
294        match tensor {
295            TensorPrimitive::Float(tensor) => {
296                TensorPrimitive::Float(B::float_permute(tensor, axes))
297            }
298            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
299        }
300    }
301
302    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
303        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
304    }
305
306    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
307        match tensor {
308            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
309            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
310        }
311    }
312
313    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
314        TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
315    }
316}
317
318impl<B: Backend> Numeric<B> for Float {
319    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
320        q_bin_ops!(lhs, rhs, float_add, q_add)
321    }
322
323    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
324        match lhs {
325            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
326            TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
327        }
328    }
329
330    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
331        q_bin_ops!(lhs, rhs, float_sub, q_sub)
332    }
333
334    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
335        match lhs {
336            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
337            TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
338        }
339    }
340
341    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
342        q_bin_ops!(lhs, rhs, float_div, q_div)
343    }
344
345    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
346        match lhs {
347            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
348            TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
349        }
350    }
351    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
352        TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
353    }
354
355    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
356        TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
357    }
358
359    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
360        q_bin_ops!(lhs, rhs, float_mul, q_mul)
361    }
362
363    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
364        match lhs {
365            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
366            TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
367        }
368    }
369    fn neg(tensor: Self::Primitive) -> Self::Primitive {
370        match tensor {
371            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
372            TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
373        }
374    }
375
376    fn sum(tensor: Self::Primitive) -> Self::Primitive {
377        match tensor {
378            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
379            TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
380        }
381    }
382
383    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
384        match tensor {
385            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
386            TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
387        }
388    }
389
390    fn prod(tensor: Self::Primitive) -> Self::Primitive {
391        match tensor {
392            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
393            TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
394        }
395    }
396
397    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
398        match tensor {
399            TensorPrimitive::Float(tensor) => {
400                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
401            }
402            TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
403        }
404    }
405
406    fn mean(tensor: Self::Primitive) -> Self::Primitive {
407        match tensor {
408            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
409            TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
410        }
411    }
412
413    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
414        match tensor {
415            TensorPrimitive::Float(tensor) => {
416                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
417            }
418            TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
419        }
420    }
421
422    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
423        match tensor {
424            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
425            TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
426        }
427    }
428
429    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
430        match tensor {
431            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
432            TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
433        }
434    }
435
436    fn abs(tensor: Self::Primitive) -> Self::Primitive {
437        match tensor {
438            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
439            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
440        }
441    }
442
443    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
444        q_bin_ops!(lhs, rhs, float_powf, q_powf)
445    }
446
447    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
448        match lhs {
449            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
450            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
451        }
452    }
453
454    fn random(
455        shape: Shape,
456        distribution: Distribution,
457        device: &Device<B>,
458        dtype: DType,
459    ) -> Self::Primitive {
460        TensorPrimitive::Float(B::float_random(shape, distribution, device, dtype.into()))
461    }
462
463    fn sign(tensor: Self::Primitive) -> Self::Primitive {
464        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
465    }
466
467    /// Applies the matrix multiplication operation.
468    ///
469    /// `C = AB`
470    ///
471    /// # Panics
472    ///
473    /// If the two tensors don't have a compatible shape.
474    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
475        match (lhs, rhs) {
476            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
477                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
478            }
479            (lhs, rhs) => B::q_matmul(lhs, rhs),
480        }
481    }
482}
483impl<B: Backend> Ordered<B> for Float {
484    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
485        match tensor {
486            TensorPrimitive::Float(tensor) => {
487                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
488            }
489            TensorPrimitive::QFloat(tensor) => {
490                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
491            }
492        }
493    }
494
495    fn sort_with_indices(
496        tensor: Self::Primitive,
497        dim: usize,
498        descending: bool,
499    ) -> (Self::Primitive, IntTensor<B>) {
500        match tensor {
501            TensorPrimitive::Float(tensor) => {
502                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
503                let (values, indices) =
504                    B::float_sort_with_indices(tensor, dim, descending, out_dtype);
505                (TensorPrimitive::Float(values), indices)
506            }
507            TensorPrimitive::QFloat(tensor) => {
508                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
509                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending, out_dtype);
510                (TensorPrimitive::QFloat(values), indices)
511            }
512        }
513    }
514
515    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
516        match tensor {
517            TensorPrimitive::Float(tensor) => {
518                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
519                B::float_argsort(tensor, dim, descending, out_dtype)
520            }
521            TensorPrimitive::QFloat(tensor) => {
522                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
523                B::q_argsort(tensor, dim, descending, out_dtype)
524            }
525        }
526    }
527
528    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
529        match tensor {
530            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
531            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
532        }
533    }
534
535    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
536        match tensor {
537            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
538            TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
539        }
540    }
541
542    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
543        let lhs = lhs.tensor();
544        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
545        B::float_greater(lhs, rhs.tensor(), out_dtype)
546    }
547
548    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
549        let lhs = lhs.tensor();
550        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
551        B::float_greater_elem(lhs, rhs, out_dtype)
552    }
553
554    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
555        let lhs = lhs.tensor();
556        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
557        B::float_greater_equal(lhs, rhs.tensor(), out_dtype)
558    }
559
560    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
561        let lhs = lhs.tensor();
562        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
563        B::float_greater_equal_elem(lhs, rhs, out_dtype)
564    }
565
566    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
567        let lhs = lhs.tensor();
568        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
569        B::float_lower(lhs, rhs.tensor(), out_dtype)
570    }
571
572    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
573        let lhs = lhs.tensor();
574        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
575        B::float_lower_elem(lhs, rhs, out_dtype)
576    }
577
578    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
579        let lhs = lhs.tensor();
580        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
581        B::float_lower_equal(lhs, rhs.tensor(), out_dtype)
582    }
583
584    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
585        let lhs = lhs.tensor();
586        let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
587        B::float_lower_equal_elem(lhs, rhs, out_dtype)
588    }
589
590    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
591        match tensor {
592            TensorPrimitive::Float(tensor) => {
593                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
594                B::float_argmax(tensor, dim, out_dtype)
595            }
596            TensorPrimitive::QFloat(tensor) => {
597                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
598                B::q_argmax(tensor, dim, out_dtype)
599            }
600        }
601    }
602
603    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
604        match tensor {
605            TensorPrimitive::Float(tensor) => {
606                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
607                B::float_argmin(tensor, dim, out_dtype)
608            }
609            TensorPrimitive::QFloat(tensor) => {
610                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
611                B::q_argmin(tensor, dim, out_dtype)
612            }
613        }
614    }
615
616    fn max(tensor: Self::Primitive) -> Self::Primitive {
617        match tensor {
618            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
619            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
620        }
621    }
622
623    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
624        match tensor {
625            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
626            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
627        }
628    }
629
630    fn max_dim_with_indices(
631        tensor: Self::Primitive,
632        dim: usize,
633    ) -> (Self::Primitive, IntTensor<B>) {
634        match tensor {
635            TensorPrimitive::Float(tensor) => {
636                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
637                let (values, indices) = B::float_max_dim_with_indices(tensor, dim, out_dtype);
638                (TensorPrimitive::Float(values), indices)
639            }
640            TensorPrimitive::QFloat(tensor) => {
641                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
642                let (values, indices) = B::q_max_dim_with_indices(tensor, dim, out_dtype);
643                (TensorPrimitive::QFloat(values), indices)
644            }
645        }
646    }
647
648    fn min(tensor: Self::Primitive) -> Self::Primitive {
649        match tensor {
650            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
651            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
652        }
653    }
654
655    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
656        match tensor {
657            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
658            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
659        }
660    }
661
662    fn min_dim_with_indices(
663        tensor: Self::Primitive,
664        dim: usize,
665    ) -> (Self::Primitive, IntTensor<B>) {
666        match tensor {
667            TensorPrimitive::Float(tensor) => {
668                let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
669                let (values, indices) = B::float_min_dim_with_indices(tensor, dim, out_dtype);
670                (TensorPrimitive::Float(values), indices)
671            }
672            TensorPrimitive::QFloat(tensor) => {
673                let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
674                let (values, indices) = B::q_min_dim_with_indices(tensor, dim, out_dtype);
675                (TensorPrimitive::QFloat(values), indices)
676            }
677        }
678    }
679
680    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
681        match tensor {
682            TensorPrimitive::Float(tensor) => {
683                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
684            }
685            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
686        }
687    }
688
689    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
690        match tensor {
691            TensorPrimitive::Float(tensor) => {
692                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
693            }
694            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
695        }
696    }
697
698    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
699        match tensor {
700            TensorPrimitive::Float(tensor) => {
701                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
702            }
703            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
704        }
705    }
706
707    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
708        match tensor {
709            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
710            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
711        }
712    }
713
714    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
715        match tensor {
716            TensorPrimitive::Float(tensor) => {
717                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
718            }
719            TensorPrimitive::QFloat(tensor) => {
720                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
721            }
722        }
723    }
724}
725
726impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
727    type InnerKind = Float;
728
729    fn inner(
730        tensor: <Self as TensorKind<B>>::Primitive,
731    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
732        match tensor {
733            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
734            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
735        }
736    }
737
738    fn from_inner(
739        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
740    ) -> <Self as TensorKind<B>>::Primitive {
741        match inner {
742            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
743            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
744        }
745    }
746}