// burn_backend/tensor/ops/float.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, TensorData, TensorPrimitive,
6    element::ElementConversion,
7    ops::TransactionPrimitive,
8    tensor::{
9        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, TensorKind,
10    },
11};
12
/// Dispatches a binary float operation over the two `TensorPrimitive` variants.
///
/// - `QFloat` with `QFloat`: calls `B::$quant_fn` and returns its result unchanged.
/// - `Float` with `Float`: calls `B::$float_fn` and wraps the result in `TensorPrimitive::Float`.
/// - Mixed operands: the quantized side is passed through `B::dequantize`, then the float
///   kernel runs, so the result is always a `TensorPrimitive::Float`.
///
/// The expansion names `B` and `TensorPrimitive` unqualified, so both are resolved at the
/// call site: the caller must have a backend type called `B` in scope.
macro_rules! q_bin_ops {
    ($a:ident, $b:ident, $float_fn:ident, $quant_fn:ident) => {
        match ($a, $b) {
            (TensorPrimitive::QFloat(x), TensorPrimitive::QFloat(y)) => B::$quant_fn(x, y),
            (TensorPrimitive::Float(x), TensorPrimitive::Float(y)) => {
                TensorPrimitive::Float(B::$float_fn(x, y))
            }
            (TensorPrimitive::Float(x), TensorPrimitive::QFloat(y)) => {
                TensorPrimitive::Float(B::$float_fn(x, B::dequantize(y)))
            }
            (TensorPrimitive::QFloat(x), TensorPrimitive::Float(y)) => {
                TensorPrimitive::Float(B::$float_fn(B::dequantize(x), y))
            }
        }
    };
}
29
30impl<B: Backend> BasicOps<B> for Float {
31    type Elem = B::FloatElem;
32
33    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
34        TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
35    }
36
37    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
38        TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
39    }
40    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41        TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
42    }
43
44    fn full<E: ElementConversion>(
45        shape: Shape,
46        fill_value: E,
47        device: &Device<B>,
48        dtype: DType,
49    ) -> Self::Primitive {
50        TensorPrimitive::Float(B::float_full(
51            shape,
52            fill_value.elem(),
53            device,
54            dtype.into(),
55        ))
56    }
57
58    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
59        tr.register_float(tensor);
60    }
61
62    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
63        match tensor {
64            TensorPrimitive::Float(tensor) => {
65                TensorPrimitive::Float(B::float_reshape(tensor, shape))
66            }
67            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
68        }
69    }
70
71    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
72        match tensor {
73            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
74            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
75        }
76    }
77
78    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
79        match tensor {
80            TensorPrimitive::Float(tensor) => {
81                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
82            }
83            TensorPrimitive::QFloat(tensor) => {
84                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
85            }
86        }
87    }
88
89    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
90        match tensor {
91            TensorPrimitive::Float(tensor) => {
92                TensorPrimitive::Float(B::float_slice(tensor, slices))
93            }
94            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
95        }
96    }
97
98    fn slice_assign(
99        tensor: Self::Primitive,
100        slices: &[Slice],
101        value: Self::Primitive,
102    ) -> Self::Primitive {
103        TensorPrimitive::Float(B::float_slice_assign(
104            tensor.tensor(),
105            slices,
106            value.tensor(),
107        ))
108    }
109
110    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
111        match tensor {
112            TensorPrimitive::Float(tensor) => {
113                TensorPrimitive::Float(B::float_select(tensor, dim, indices))
114            }
115            TensorPrimitive::QFloat(tensor) => {
116                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
117            }
118        }
119    }
120
121    fn select_assign(
122        tensor: Self::Primitive,
123        dim: usize,
124        indices: IntTensor<B>,
125        values: Self::Primitive,
126        update: IndexingUpdateOp,
127    ) -> Self::Primitive {
128        // Select assign is ambiguous for QFloat
129        match update {
130            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
131                tensor.tensor(),
132                dim,
133                indices,
134                values.tensor(),
135            )),
136        }
137    }
138
139    fn mask_where(
140        tensor: Self::Primitive,
141        mask: B::BoolTensorPrimitive,
142        source: Self::Primitive,
143    ) -> Self::Primitive {
144        TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
145    }
146
147    fn mask_fill(
148        tensor: Self::Primitive,
149        mask: B::BoolTensorPrimitive,
150        value: Self::Elem,
151    ) -> Self::Primitive {
152        TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
153    }
154
155    fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
156        match tensor {
157            TensorPrimitive::Float(tensor) => {
158                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
159            }
160            TensorPrimitive::QFloat(tensor) => {
161                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
162            }
163        }
164    }
165
166    fn scatter(
167        dim: usize,
168        tensor: Self::Primitive,
169        indices: IntTensor<B>,
170        values: Self::Primitive,
171        update: IndexingUpdateOp,
172    ) -> Self::Primitive {
173        match update {
174            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
175                dim,
176                tensor.tensor(),
177                indices,
178                values.tensor(),
179            )),
180        }
181    }
182
183    fn device(tensor: &Self::Primitive) -> Device<B> {
184        match tensor {
185            TensorPrimitive::Float(tensor) => B::float_device(tensor),
186            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
187        }
188    }
189
190    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
191        match tensor {
192            TensorPrimitive::Float(tensor) => {
193                TensorPrimitive::Float(B::float_to_device(tensor, device))
194            }
195            TensorPrimitive::QFloat(tensor) => {
196                TensorPrimitive::QFloat(B::q_to_device(tensor, device))
197            }
198        }
199    }
200
201    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
202        match tensor {
203            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
204            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
205        }
206    }
207
208    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
209        match &data.dtype {
210            DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
211            _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
212        }
213    }
214
215    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
216        match dtype {
217            DType::QFloat(_scheme) => {
218                TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
219            }
220            _ if dtype.is_float() => {
221                TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
222            }
223            _ => panic!("Expected float dtype, got {dtype:?}"),
224        }
225    }
226
227    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
228        match tensor {
229            TensorPrimitive::Float(tensor) => {
230                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
231            }
232            TensorPrimitive::QFloat(tensor) => {
233                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
234            }
235        }
236    }
237
238    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
239        match vectors.first().unwrap() {
240            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
241                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
242                dim,
243            )),
244            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
245                vectors
246                    .into_iter()
247                    .map(|tensor| {
248                        if let TensorPrimitive::QFloat(t) = tensor {
249                            t
250                        } else {
251                            panic!("Concatenation only works with vector of QFloat")
252                        }
253                    })
254                    .collect(),
255                dim,
256            )),
257        }
258    }
259
260    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
261        B::float_equal(lhs.tensor(), rhs.tensor())
262    }
263
264    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
265        B::float_not_equal(lhs.tensor(), rhs.tensor())
266    }
267
268    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
269        B::float_equal_elem(lhs.tensor(), rhs)
270    }
271
272    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
273        B::float_not_equal_elem(lhs.tensor(), rhs)
274    }
275
276    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
277        B::float_any(tensor.tensor())
278    }
279
280    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
281        B::float_any_dim(tensor.tensor(), dim)
282    }
283
284    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
285        B::float_all(tensor.tensor())
286    }
287
288    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
289        B::float_all_dim(tensor.tensor(), dim)
290    }
291
292    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
293        match tensor {
294            TensorPrimitive::Float(tensor) => {
295                TensorPrimitive::Float(B::float_permute(tensor, axes))
296            }
297            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
298        }
299    }
300
301    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
302        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
303    }
304
305    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
306        match tensor {
307            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
308            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
309        }
310    }
311
312    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
313        TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
314    }
315}
316
317impl<B: Backend> Numeric<B> for Float {
318    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
319        q_bin_ops!(lhs, rhs, float_add, q_add)
320    }
321
322    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
323        match lhs {
324            TensorPrimitive::Float(lhs) => {
325                TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
326            }
327            TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs.elem()),
328        }
329    }
330
331    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
332        q_bin_ops!(lhs, rhs, float_sub, q_sub)
333    }
334
335    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
336        match lhs {
337            TensorPrimitive::Float(lhs) => {
338                TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
339            }
340            TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs.elem()),
341        }
342    }
343
344    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
345        q_bin_ops!(lhs, rhs, float_div, q_div)
346    }
347
348    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
349        match lhs {
350            TensorPrimitive::Float(lhs) => {
351                TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
352            }
353            TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs.elem()),
354        }
355    }
356    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
357        TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
358    }
359
360    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
361        TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs.elem()))
362    }
363
364    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
365        q_bin_ops!(lhs, rhs, float_mul, q_mul)
366    }
367
368    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
369        match lhs {
370            TensorPrimitive::Float(lhs) => {
371                TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
372            }
373            TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs.elem()),
374        }
375    }
376    fn neg(tensor: Self::Primitive) -> Self::Primitive {
377        match tensor {
378            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
379            TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
380        }
381    }
382
383    fn sum(tensor: Self::Primitive) -> Self::Primitive {
384        match tensor {
385            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
386            TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
387        }
388    }
389
390    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
391        match tensor {
392            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
393            TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
394        }
395    }
396
397    fn prod(tensor: Self::Primitive) -> Self::Primitive {
398        match tensor {
399            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
400            TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
401        }
402    }
403
404    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
405        match tensor {
406            TensorPrimitive::Float(tensor) => {
407                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
408            }
409            TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
410        }
411    }
412
413    fn mean(tensor: Self::Primitive) -> Self::Primitive {
414        match tensor {
415            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
416            TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
417        }
418    }
419
420    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
421        match tensor {
422            TensorPrimitive::Float(tensor) => {
423                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
424            }
425            TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
426        }
427    }
428
429    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
430        match tensor {
431            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
432            TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
433        }
434    }
435
436    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
437        match tensor {
438            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
439            TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
440        }
441    }
442
443    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
444        match tensor {
445            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
446            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
447        }
448    }
449
450    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
451        match tensor {
452            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
453            TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
454        }
455    }
456
457    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
458        B::float_greater(lhs.tensor(), rhs.tensor())
459    }
460
461    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
462        B::float_greater_elem(lhs.tensor(), rhs)
463    }
464
465    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
466        B::float_greater_equal(lhs.tensor(), rhs.tensor())
467    }
468
469    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
470        B::float_greater_equal_elem(lhs.tensor(), rhs)
471    }
472
473    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
474        B::float_lower(lhs.tensor(), rhs.tensor())
475    }
476
477    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
478        B::float_lower_elem(lhs.tensor(), rhs)
479    }
480
481    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
482        B::float_lower_equal(lhs.tensor(), rhs.tensor())
483    }
484
485    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
486        B::float_lower_equal_elem(lhs.tensor(), rhs)
487    }
488
489    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
490        match tensor {
491            TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
492            TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
493        }
494    }
495
496    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
497        match tensor {
498            TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
499            TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
500        }
501    }
502
503    fn max(tensor: Self::Primitive) -> Self::Primitive {
504        match tensor {
505            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
506            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
507        }
508    }
509
510    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
511        match tensor {
512            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
513            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
514        }
515    }
516
517    fn max_dim_with_indices(
518        tensor: Self::Primitive,
519        dim: usize,
520    ) -> (Self::Primitive, IntTensor<B>) {
521        match tensor {
522            TensorPrimitive::Float(tensor) => {
523                let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
524                (TensorPrimitive::Float(values), indices)
525            }
526            TensorPrimitive::QFloat(tensor) => {
527                let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
528                (TensorPrimitive::QFloat(values), indices)
529            }
530        }
531    }
532
533    fn min(tensor: Self::Primitive) -> Self::Primitive {
534        match tensor {
535            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
536            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
537        }
538    }
539
540    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
541        match tensor {
542            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
543            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
544        }
545    }
546
547    fn min_dim_with_indices(
548        tensor: Self::Primitive,
549        dim: usize,
550    ) -> (Self::Primitive, IntTensor<B>) {
551        match tensor {
552            TensorPrimitive::Float(tensor) => {
553                let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
554                (TensorPrimitive::Float(values), indices)
555            }
556            TensorPrimitive::QFloat(tensor) => {
557                let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
558                (TensorPrimitive::QFloat(values), indices)
559            }
560        }
561    }
562
563    fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
564        match tensor {
565            TensorPrimitive::Float(tensor) => {
566                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
567            }
568            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
569        }
570    }
571
572    fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
573        match tensor {
574            TensorPrimitive::Float(tensor) => {
575                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
576            }
577            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
578        }
579    }
580
581    fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
582        match tensor {
583            TensorPrimitive::Float(tensor) => {
584                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
585            }
586            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
587        }
588    }
589
590    fn abs(tensor: Self::Primitive) -> Self::Primitive {
591        match tensor {
592            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
593            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
594        }
595    }
596
597    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
598        q_bin_ops!(lhs, rhs, float_powf, q_powf)
599    }
600
601    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
602        match lhs {
603            TensorPrimitive::Float(lhs) => {
604                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
605            }
606            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()),
607        }
608    }
609
610    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
611        q_bin_ops!(lhs, rhs, float_powf, q_powf)
612    }
613
614    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
615        match lhs {
616            TensorPrimitive::Float(lhs) => {
617                TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem()))
618            }
619            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()),
620        }
621    }
622
623    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
624        TensorPrimitive::Float(B::float_random(shape, distribution, device))
625    }
626
627    fn sign(tensor: Self::Primitive) -> Self::Primitive {
628        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
629    }
630
631    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
632        match tensor {
633            TensorPrimitive::Float(tensor) => {
634                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
635            }
636            TensorPrimitive::QFloat(tensor) => {
637                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
638            }
639        }
640    }
641
642    fn sort_with_indices(
643        tensor: Self::Primitive,
644        dim: usize,
645        descending: bool,
646    ) -> (Self::Primitive, IntTensor<B>) {
647        match tensor {
648            TensorPrimitive::Float(tensor) => {
649                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
650                (TensorPrimitive::Float(values), indices)
651            }
652            TensorPrimitive::QFloat(tensor) => {
653                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
654                (TensorPrimitive::QFloat(values), indices)
655            }
656        }
657    }
658
659    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
660        match tensor {
661            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
662            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
663        }
664    }
665
666    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
667        match tensor {
668            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
669            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
670        }
671    }
672
673    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
674        match tensor {
675            TensorPrimitive::Float(tensor) => {
676                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
677            }
678            TensorPrimitive::QFloat(tensor) => {
679                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
680            }
681        }
682    }
683
684    /// Applies the matrix multiplication operation.
685    ///
686    /// `C = AB`
687    ///
688    /// # Panics
689    ///
690    /// If the two tensors don't have a compatible shape.
691    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
692        match (lhs, rhs) {
693            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
694                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
695            }
696            (lhs, rhs) => B::q_matmul(lhs, rhs),
697        }
698    }
699}
700
701impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
702    type InnerKind = Float;
703
704    fn inner(
705        tensor: <Self as TensorKind<B>>::Primitive,
706    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
707        match tensor {
708            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
709            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
710        }
711    }
712
713    fn from_inner(
714        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
715    ) -> <Self as TensorKind<B>>::Primitive {
716        match inner {
717            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
718            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
719        }
720    }
721}