Skip to main content

burn_backend/tensor/ops/
float.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorPrimitive,
6    ops::TransactionPrimitive,
7    tensor::{
8        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
9        TensorKind,
10    },
11};
12
/// Dispatches a binary float operation over the float / quantized variants of
/// `TensorPrimitive`.
///
/// - both operands float: `$float_op` is applied and the result is wrapped in
///   `TensorPrimitive::Float`;
/// - both operands quantized: `$quant_op` is applied and its `TensorPrimitive` result is
///   returned as-is;
/// - mixed operands: the quantized side is converted with `B::dequantize` and `$float_op`
///   is applied, yielding a float primitive.
///
/// `B` and `TensorPrimitive` are resolved where the macro is invoked, so the call site must
/// have the backend type parameter named `B` in scope.
macro_rules! q_bin_ops {
    ($a:ident, $b:ident, $float_op:ident, $quant_op:ident) => {
        match ($a, $b) {
            (TensorPrimitive::QFloat(x), TensorPrimitive::QFloat(y)) => B::$quant_op(x, y),
            (TensorPrimitive::Float(x), TensorPrimitive::Float(y)) => {
                TensorPrimitive::Float(B::$float_op(x, y))
            }
            (TensorPrimitive::Float(x), TensorPrimitive::QFloat(y)) => {
                let y = B::dequantize(y);
                TensorPrimitive::Float(B::$float_op(x, y))
            }
            (TensorPrimitive::QFloat(x), TensorPrimitive::Float(y)) => {
                let x = B::dequantize(x);
                TensorPrimitive::Float(B::$float_op(x, y))
            }
        }
    };
}
29
30impl<B: Backend> BasicOps<B> for Float {
31    type Elem = B::FloatElem;
32
33    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
34        TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
35    }
36
37    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
38        TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
39    }
40    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41        TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
42    }
43
44    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
45        TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
46    }
47
48    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
49        tr.register_float(tensor);
50    }
51
52    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
53        match tensor {
54            TensorPrimitive::Float(tensor) => {
55                TensorPrimitive::Float(B::float_reshape(tensor, shape))
56            }
57            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
58        }
59    }
60
61    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
62        match tensor {
63            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
64            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
65        }
66    }
67
68    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
69        match tensor {
70            TensorPrimitive::Float(tensor) => {
71                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
72            }
73            TensorPrimitive::QFloat(tensor) => {
74                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
75            }
76        }
77    }
78
79    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
80        match tensor {
81            TensorPrimitive::Float(tensor) => {
82                TensorPrimitive::Float(B::float_slice(tensor, slices))
83            }
84            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
85        }
86    }
87
88    fn slice_assign(
89        tensor: Self::Primitive,
90        slices: &[Slice],
91        value: Self::Primitive,
92    ) -> Self::Primitive {
93        TensorPrimitive::Float(B::float_slice_assign(
94            tensor.tensor(),
95            slices,
96            value.tensor(),
97        ))
98    }
99
100    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
101        match tensor {
102            TensorPrimitive::Float(tensor) => {
103                TensorPrimitive::Float(B::float_select(tensor, dim, indices))
104            }
105            TensorPrimitive::QFloat(tensor) => {
106                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
107            }
108        }
109    }
110
111    fn select_assign(
112        tensor: Self::Primitive,
113        dim: usize,
114        indices: IntTensor<B>,
115        values: Self::Primitive,
116        update: IndexingUpdateOp,
117    ) -> Self::Primitive {
118        // Select assign is ambiguous for QFloat
119        match update {
120            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
121                tensor.tensor(),
122                dim,
123                indices,
124                values.tensor(),
125            )),
126        }
127    }
128
129    fn mask_where(
130        tensor: Self::Primitive,
131        mask: B::BoolTensorPrimitive,
132        source: Self::Primitive,
133    ) -> Self::Primitive {
134        TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
135    }
136
137    fn mask_fill(
138        tensor: Self::Primitive,
139        mask: B::BoolTensorPrimitive,
140        value: Scalar,
141    ) -> Self::Primitive {
142        TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
143    }
144
145    fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
146        match tensor {
147            TensorPrimitive::Float(tensor) => {
148                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
149            }
150            TensorPrimitive::QFloat(tensor) => {
151                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
152            }
153        }
154    }
155
156    fn scatter(
157        dim: usize,
158        tensor: Self::Primitive,
159        indices: IntTensor<B>,
160        values: Self::Primitive,
161        update: IndexingUpdateOp,
162    ) -> Self::Primitive {
163        match update {
164            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
165                dim,
166                tensor.tensor(),
167                indices,
168                values.tensor(),
169            )),
170        }
171    }
172
173    fn device(tensor: &Self::Primitive) -> Device<B> {
174        match tensor {
175            TensorPrimitive::Float(tensor) => B::float_device(tensor),
176            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
177        }
178    }
179
180    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
181        match tensor {
182            TensorPrimitive::Float(tensor) => {
183                TensorPrimitive::Float(B::float_to_device(tensor, device))
184            }
185            TensorPrimitive::QFloat(tensor) => {
186                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
187            }
188        }
189    }
190
191    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
192        match tensor {
193            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
194            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
195        }
196    }
197
198    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
199        match &data.dtype {
200            DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
201            _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
202        }
203    }
204
205    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
206        match dtype {
207            DType::QFloat(_scheme) => {
208                TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
209            }
210            _ if dtype.is_float() => {
211                TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
212            }
213            _ => panic!("Expected float dtype, got {dtype:?}"),
214        }
215    }
216
217    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
218        match tensor {
219            TensorPrimitive::Float(tensor) => {
220                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
221            }
222            TensorPrimitive::QFloat(tensor) => {
223                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
224            }
225        }
226    }
227
228    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
229        match vectors.first().unwrap() {
230            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
231                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
232                dim,
233            )),
234            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
235                vectors
236                    .into_iter()
237                    .map(|tensor| {
238                        if let TensorPrimitive::QFloat(t) = tensor {
239                            t
240                        } else {
241                            panic!("Concatenation only works with vector of QFloat")
242                        }
243                    })
244                    .collect(),
245                dim,
246            )),
247        }
248    }
249
250    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
251        B::float_equal(lhs.tensor(), rhs.tensor())
252    }
253
254    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
255        B::float_not_equal(lhs.tensor(), rhs.tensor())
256    }
257
258    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
259        B::float_equal_elem(lhs.tensor(), rhs)
260    }
261
262    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
263        B::float_not_equal_elem(lhs.tensor(), rhs)
264    }
265
266    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
267        B::float_any(tensor.tensor())
268    }
269
270    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
271        B::float_any_dim(tensor.tensor(), dim)
272    }
273
274    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
275        B::float_all(tensor.tensor())
276    }
277
278    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
279        B::float_all_dim(tensor.tensor(), dim)
280    }
281
282    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
283        match tensor {
284            TensorPrimitive::Float(tensor) => {
285                TensorPrimitive::Float(B::float_permute(tensor, axes))
286            }
287            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
288        }
289    }
290
291    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
292        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
293    }
294
295    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
296        match tensor {
297            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
298            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
299        }
300    }
301
302    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
303        TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
304    }
305}
306
307impl<B: Backend> Numeric<B> for Float {
308    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
309        q_bin_ops!(lhs, rhs, float_add, q_add)
310    }
311
312    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
313        match lhs {
314            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
315            TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
316        }
317    }
318
319    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
320        q_bin_ops!(lhs, rhs, float_sub, q_sub)
321    }
322
323    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
324        match lhs {
325            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
326            TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
327        }
328    }
329
330    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
331        q_bin_ops!(lhs, rhs, float_div, q_div)
332    }
333
334    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
335        match lhs {
336            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
337            TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
338        }
339    }
340    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
341        TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
342    }
343
344    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
345        TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
346    }
347
348    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
349        q_bin_ops!(lhs, rhs, float_mul, q_mul)
350    }
351
352    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
353        match lhs {
354            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
355            TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
356        }
357    }
358    fn neg(tensor: Self::Primitive) -> Self::Primitive {
359        match tensor {
360            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
361            TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
362        }
363    }
364
365    fn sum(tensor: Self::Primitive) -> Self::Primitive {
366        match tensor {
367            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
368            TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
369        }
370    }
371
372    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
373        match tensor {
374            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
375            TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
376        }
377    }
378
379    fn prod(tensor: Self::Primitive) -> Self::Primitive {
380        match tensor {
381            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
382            TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
383        }
384    }
385
386    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
387        match tensor {
388            TensorPrimitive::Float(tensor) => {
389                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
390            }
391            TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
392        }
393    }
394
395    fn mean(tensor: Self::Primitive) -> Self::Primitive {
396        match tensor {
397            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
398            TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
399        }
400    }
401
402    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
403        match tensor {
404            TensorPrimitive::Float(tensor) => {
405                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
406            }
407            TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
408        }
409    }
410
411    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
412        match tensor {
413            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
414            TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
415        }
416    }
417
418    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
419        match tensor {
420            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
421            TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
422        }
423    }
424
425    fn abs(tensor: Self::Primitive) -> Self::Primitive {
426        match tensor {
427            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
428            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
429        }
430    }
431
432    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
433        q_bin_ops!(lhs, rhs, float_powf, q_powf)
434    }
435
436    fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
437        match lhs {
438            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
439            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
440        }
441    }
442
443    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
444        q_bin_ops!(lhs, rhs, float_powf, q_powf)
445    }
446
447    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
448        match lhs {
449            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
450            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
451        }
452    }
453
454    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
455        TensorPrimitive::Float(B::float_random(shape, distribution, device))
456    }
457
458    fn sign(tensor: Self::Primitive) -> Self::Primitive {
459        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
460    }
461
462    /// Applies the matrix multiplication operation.
463    ///
464    /// `C = AB`
465    ///
466    /// # Panics
467    ///
468    /// If the two tensors don't have a compatible shape.
469    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
470        match (lhs, rhs) {
471            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
472                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
473            }
474            (lhs, rhs) => B::q_matmul(lhs, rhs),
475        }
476    }
477}
478impl<B: Backend> Ordered<B> for Float {
479    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
480        match tensor {
481            TensorPrimitive::Float(tensor) => {
482                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
483            }
484            TensorPrimitive::QFloat(tensor) => {
485                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
486            }
487        }
488    }
489
490    fn sort_with_indices(
491        tensor: Self::Primitive,
492        dim: usize,
493        descending: bool,
494    ) -> (Self::Primitive, IntTensor<B>) {
495        match tensor {
496            TensorPrimitive::Float(tensor) => {
497                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
498                (TensorPrimitive::Float(values), indices)
499            }
500            TensorPrimitive::QFloat(tensor) => {
501                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
502                (TensorPrimitive::QFloat(values), indices)
503            }
504        }
505    }
506
507    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
508        match tensor {
509            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
510            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
511        }
512    }
513
514    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
515        match tensor {
516            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
517            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
518        }
519    }
520
521    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
522        match tensor {
523            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
524            TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
525        }
526    }
527
528    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
529        B::float_greater(lhs.tensor(), rhs.tensor())
530    }
531
532    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
533        B::float_greater_elem(lhs.tensor(), rhs)
534    }
535
536    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
537        B::float_greater_equal(lhs.tensor(), rhs.tensor())
538    }
539
540    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
541        B::float_greater_equal_elem(lhs.tensor(), rhs)
542    }
543
544    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
545        B::float_lower(lhs.tensor(), rhs.tensor())
546    }
547
548    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
549        B::float_lower_elem(lhs.tensor(), rhs)
550    }
551
552    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
553        B::float_lower_equal(lhs.tensor(), rhs.tensor())
554    }
555
556    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
557        B::float_lower_equal_elem(lhs.tensor(), rhs)
558    }
559
560    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
561        match tensor {
562            TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
563            TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
564        }
565    }
566
567    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
568        match tensor {
569            TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
570            TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
571        }
572    }
573
574    fn max(tensor: Self::Primitive) -> Self::Primitive {
575        match tensor {
576            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
577            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
578        }
579    }
580
581    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
582        match tensor {
583            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
584            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
585        }
586    }
587
588    fn max_dim_with_indices(
589        tensor: Self::Primitive,
590        dim: usize,
591    ) -> (Self::Primitive, IntTensor<B>) {
592        match tensor {
593            TensorPrimitive::Float(tensor) => {
594                let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
595                (TensorPrimitive::Float(values), indices)
596            }
597            TensorPrimitive::QFloat(tensor) => {
598                let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
599                (TensorPrimitive::QFloat(values), indices)
600            }
601        }
602    }
603
604    fn min(tensor: Self::Primitive) -> Self::Primitive {
605        match tensor {
606            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
607            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
608        }
609    }
610
611    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
612        match tensor {
613            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
614            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
615        }
616    }
617
618    fn min_dim_with_indices(
619        tensor: Self::Primitive,
620        dim: usize,
621    ) -> (Self::Primitive, IntTensor<B>) {
622        match tensor {
623            TensorPrimitive::Float(tensor) => {
624                let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
625                (TensorPrimitive::Float(values), indices)
626            }
627            TensorPrimitive::QFloat(tensor) => {
628                let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
629                (TensorPrimitive::QFloat(values), indices)
630            }
631        }
632    }
633
634    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
635        match tensor {
636            TensorPrimitive::Float(tensor) => {
637                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
638            }
639            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
640        }
641    }
642
643    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
644        match tensor {
645            TensorPrimitive::Float(tensor) => {
646                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
647            }
648            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
649        }
650    }
651
652    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
653        match tensor {
654            TensorPrimitive::Float(tensor) => {
655                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
656            }
657            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
658        }
659    }
660
661    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
662        match tensor {
663            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
664            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
665        }
666    }
667
668    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
669        match tensor {
670            TensorPrimitive::Float(tensor) => {
671                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
672            }
673            TensorPrimitive::QFloat(tensor) => {
674                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
675            }
676        }
677    }
678}
679
680impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
681    type InnerKind = Float;
682
683    fn inner(
684        tensor: <Self as TensorKind<B>>::Primitive,
685    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
686        match tensor {
687            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
688            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
689        }
690    }
691
692    fn from_inner(
693        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
694    ) -> <Self as TensorKind<B>>::Primitive {
695        match inner {
696            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
697            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
698        }
699    }
700}