burn_tensor/tensor/ops/
qtensor.rs

1use alloc::vec::Vec;
2use core::{future::Future, ops::Range};
3
4use crate::{
5    backend::Backend,
6    quantization::{QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme},
7    Device, Shape, TensorData, TensorMetadata,
8};
9
10use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
11
/// Runs a float operation on quantized operand(s) and quantizes the result again.
///
/// The expansion dequantizes every tensor operand with `<$ty>::dequantize`, calls
/// `$float_op` on the resulting float value(s) and hands the output to
/// `<$ty>::quantize_dynamic`. The scheme is cloned from the (first) operand before that
/// operand is consumed.
///
/// Accepted forms:
///
/// * `ty T, float_op op, lhs, rhs` - binary operation; the scheme of `lhs` is used.
/// * `ty T, float_op op, tensor` - unary operation.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary form: two quantized operands.
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: the output is quantized with the left operand's scheme.
        let out_scheme = $t1.scheme().clone();

        let lhs_float = <$ty>::dequantize($t1);
        let rhs_float = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let result = $float_op(lhs_float, rhs_float);

        <$ty>::quantize_dynamic(result, &out_scheme)
    }};
    // Unary form: a single quantized operand.
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        let out_scheme = $tensor.scheme().clone();

        let input_float = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let result = $float_op(input_float);

        <$ty>::quantize_dynamic(result, &out_scheme)
    }};
}
42
43/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
44/// for documentation on each function.
45pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor with the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
57
    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
64
65    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
66    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantizationScheme) -> QuantizedTensor<B> {
67        // Dynamically compute min/max tensor range and qparams before quantizing
68        let min = B::float_min(tensor.clone());
69        let max = B::float_max(tensor.clone());
70        let qparams = scheme.compute_q_params_primitive(min, max);
71        Self::quantize(tensor, scheme, qparams)
72    }
73
    /// Convert the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    ///
    /// # Returns
    ///
    /// The float tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
76
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor (borrowed, not consumed).
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
87
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
99
    /// Reshapes a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
111
    /// Converts the quantized tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future that resolves to the data structure with the tensor's data.
    fn q_into_data(tensor: QuantizedTensor<B>)
        -> impl Future<Output = TensorData> + 'static + Send;
123
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns the tensor unchanged.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to detach.
    ///
    /// # Returns
    ///
    /// The detached tensor.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
129
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores the flag and returns the tensor unchanged.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `_require_grad` - The requested flag value (unused by the default implementation).
    ///
    /// # Returns
    ///
    /// The tensor.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
135
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always returns `false`.
    ///
    /// # Arguments
    ///
    /// * `_tensor` - The tensor (unused by the default implementation).
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
141
142    /// Repeat the tensor along the given dimension.
143    ///
144    /// # Arguments
145    ///
146    /// * `tensor` - The tensor.
147    /// * `dim` - The dimension to repeat.
148    /// * `times` - The number of times to repeat the dimension.
149    ///
150    /// # Returns
151    ///
152    /// The tensor with the given dimension repeated.
153    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
154        dequant_op_quant!(
155            ty Self,
156            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
157            tensor
158        )
159    }
160
161    /// Adds two tensors together.
162    ///
163    /// # Arguments
164    ///
165    /// * `lhs` - The left hand side tensor.
166    /// * `rhs` - The right hand side tensor.
167    ///
168    /// # Returns
169    ///
170    /// The result of adding the two tensors together.
171    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
172        dequant_op_quant!(
173            ty Self,
174            float_op |lhs, rhs| B::float_add(lhs, rhs),
175            lhs,
176            rhs
177        )
178    }
179
180    /// Adds a scalar to a tensor.
181    ///
182    /// # Arguments
183    ///
184    /// * `lhs` - The left hand side tensor.
185    /// * `rhs` - The right hand side scalar.
186    ///
187    /// # Returns
188    ///
189    /// The result of adding the scalar to the tensor.
190    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
191        let scheme = *lhs.scheme();
192
193        let lhs_f = Self::dequantize(lhs);
194        let out_f = B::float_add_scalar(lhs_f, rhs);
195
196        Self::quantize_dynamic(out_f, &scheme)
197    }
198
199    /// Clamps a tensor under a minimum value.
200    ///
201    /// # Arguments
202    ///
203    /// * `tensor` - The tensor to clamp.
204    /// * `min` - The minimum value.
205    ///
206    /// # Returns
207    ///
208    /// The clamped tensor.
209    fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> QuantizedTensor<B> {
210        let scheme = *tensor.scheme();
211
212        let tensor_f = Self::dequantize(tensor);
213        let out_f = B::float_clamp_min(tensor_f, min);
214
215        Self::quantize_dynamic(out_f, &scheme)
216    }
217
218    /// Clamps a tensor over a maximum value.
219    ///
220    /// # Arguments
221    ///
222    /// * `tensor` - The tensor to clamp.
223    /// * `max` - The maximum value.
224    ///
225    /// # Returns
226    ///
227    /// The clamped tensor.
228    fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> QuantizedTensor<B> {
229        let scheme = *tensor.scheme();
230
231        let tensor_f = Self::dequantize(tensor);
232        let out_f = B::float_clamp_max(tensor_f, max);
233
234        Self::quantize_dynamic(out_f, &scheme)
235    }
236
237    /// Clamps a tensor between a minimum and maximum value.
238    ///
239    /// # Arguments
240    ///
241    /// * `tensor` - The tensor to clamp.
242    /// * `min` - The minimum value.
243    /// * `max` - The maximum value.
244    ///
245    /// # Returns
246    ///
247    /// The clamped tensor.
248    fn q_clamp(
249        tensor: QuantizedTensor<B>,
250        min: FloatElem<B>,
251        max: FloatElem<B>,
252    ) -> QuantizedTensor<B> {
253        let scheme = *tensor.scheme();
254
255        let tensor_f = Self::dequantize(tensor);
256        let out_f = B::float_clamp(tensor_f, min, max);
257
258        Self::quantize_dynamic(out_f, &scheme)
259    }
260
261    /// Subtracts two tensors.
262    ///
263    /// # Arguments
264    ///
265    /// * `lhs` - The left hand side tensor.
266    /// * `rhs` - The right hand side tensor.
267    ///
268    /// # Returns
269    ///
270    /// The result of subtracting the two tensors.
271    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
272        dequant_op_quant!(
273            ty Self,
274            float_op |lhs, rhs| B::float_sub(lhs, rhs),
275            lhs,
276            rhs
277        )
278    }
279
280    /// Subtracts a scalar from a tensor.
281    ///
282    /// # Arguments
283    ///
284    /// * `lhs` - The left hand side tensor.
285    /// * `rhs` - The right hand side scalar.
286    ///
287    /// # Returns
288    ///
289    /// The result of subtracting the scalar from the tensor.
290    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
291        let scheme = *lhs.scheme();
292
293        let lhs_f = Self::dequantize(lhs);
294        let out_f = B::float_sub_scalar(lhs_f, rhs);
295
296        Self::quantize_dynamic(out_f, &scheme)
297    }
298
299    /// Multiplies two tensors together element-wise.
300    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
301        dequant_op_quant!(
302            ty Self,
303            float_op |lhs, rhs| B::float_mul(lhs, rhs),
304            lhs,
305            rhs
306        )
307    }
308
309    /// Multiplies a tensor by a scalar.
310    ///
311    /// # Arguments
312    ///
313    /// * `lhs` - The left hand side tensor.
314    /// * `rhs` - The right hand side scalar.
315    ///
316    /// # Returns
317    ///
318    /// The result of multiplying the tensor by the scalar.
319    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
320        let scheme = *lhs.scheme();
321
322        let lhs_f = Self::dequantize(lhs);
323        let out_f = B::float_mul_scalar(lhs_f, rhs);
324
325        Self::quantize_dynamic(out_f, &scheme)
326    }
327
328    /// Divides two tensors element-wise.
329    ///
330    /// # Arguments
331    ///
332    /// * `lhs` - The left hand side tensor.
333    /// * `rhs` - The right hand side tensor.
334    ///
335    /// # Returns
336    ///
337    /// The result of dividing the two tensors.
338    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
339        dequant_op_quant!(
340            ty Self,
341            float_op |lhs, rhs| B::float_div(lhs, rhs),
342            lhs,
343            rhs
344        )
345    }
346
347    /// Divides a tensor by a scalar.
348    ///
349    /// # Arguments
350    ///
351    /// * `lhs` - The left hand side tensor.
352    /// * `rhs` - The right hand side scalar.
353    ///
354    /// # Returns
355    ///
356    /// The result of dividing the tensor by the scalar.
357    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
358        let scheme = *lhs.scheme();
359
360        let lhs_f = Self::dequantize(lhs);
361        let out_f = B::float_div_scalar(lhs_f, rhs);
362
363        Self::quantize_dynamic(out_f, &scheme)
364    }
365
366    /// Computes the remainder of division between two tensors element-wise.
367    ///
368    /// # Arguments
369    ///
370    /// * `lhs` - The left hand side tensor.
371    /// * `rhs` - The right hand side tensor.
372    ///
373    /// # Returns
374    ///
375    /// The element-wise remainder when dividing `lhs` by `rhs`.
376    fn q_remainder(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
377        dequant_op_quant!(
378            ty Self,
379            float_op |lhs, rhs| B::float_remainder(lhs, rhs),
380            lhs,
381            rhs
382        )
383    }
384
385    /// Computes the modulus of a tensor given a scalar.
386    ///
387    /// # Arguments
388    /// * `lhs` - The left hand side tensor.
389    /// * `rhs` - The right hand side scalar.
390    ///
391    /// # Returns
392    ///
393    /// The result of applying the modulus of the scalar to the tensor.
394    fn q_remainder_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
395        let scheme = *lhs.scheme();
396
397        let lhs_f = Self::dequantize(lhs);
398        let out_f = B::float_remainder_scalar(lhs_f, rhs);
399
400        Self::quantize_dynamic(out_f, &scheme)
401    }
402
403    /// Multiplies two tensors together using matrix multiplication.
404    ///
405    /// # Arguments
406    ///
407    /// * `lhs` - The left hand side tensor.
408    /// * `rhs` - The right hand side tensor.
409    ///
410    /// # Returns
411    ///
412    /// The result of multiplying the two tensors together using matrix multiplication.
413    fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
414        dequant_op_quant!(
415            ty Self,
416            float_op |lhs, rhs| B::float_matmul(lhs, rhs),
417            lhs,
418            rhs
419        )
420    }
421
422    /// Negates a tensor element-wise.
423    fn q_neg(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
424        let scheme = *tensor.scheme();
425
426        let tensor_f = Self::dequantize(tensor);
427        let out_f = B::float_neg(tensor_f);
428
429        Self::quantize_dynamic(out_f, &scheme)
430    }
431
432    /// Calculates the reciprocals element-wise
433    fn q_recip(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
434        let scheme = *tensor.scheme();
435
436        let tensor_f = Self::dequantize(tensor);
437        let out_f = B::float_recip(tensor_f);
438
439        Self::quantize_dynamic(out_f, &scheme)
440    }
441
442    /// Transposes a tensor.
443    ///
444    /// # Arguments
445    ///
446    /// * `tensor` - The tensor to transpose.
447    ///
448    /// # Returns
449    ///
450    /// The transposed tensor.
451    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
452        let ndims = tensor.shape().num_dims();
453        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
454    }
455
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
468
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
479
    /// Reverse the order of elements in a quantized tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
489
490    /// Gather elements from a tensor.
491    ///
492    /// # Arguments
493    ///
494    /// * `dim` - The dimension to gather from.
495    /// * `tensor` - The tensor to gather from.
496    /// * `indices` - The indices to gather.
497    ///
498    /// # Returns
499    ///
500    /// The gathered elements.
501    fn q_gather(
502        dim: usize,
503        tensor: QuantizedTensor<B>,
504        indices: IntTensor<B>,
505    ) -> QuantizedTensor<B> {
506        // Default implementation. Backends can gather on the quantized values when supported.
507        dequant_op_quant!(
508            ty Self,
509            float_op |tensor| B::float_gather(dim, tensor, indices),
510            tensor
511        )
512    }
513
514    /// Scatter elements into a tensor.
515    ///
516    /// # Arguments
517    ///
518    /// * `dim` - The dimension to scatter into.
519    /// * `tensor` - The tensor to scatter into.
520    /// * `indices` - The indices to scatter into.
521    /// * `value` - The value to scatter.
522    ///
523    /// # Returns
524    ///
525    /// The tensor with the scattered elements.
526    fn q_scatter(
527        dim: usize,
528        tensor: QuantizedTensor<B>,
529        indices: IntTensor<B>,
530        value: QuantizedTensor<B>,
531    ) -> QuantizedTensor<B> {
532        dequant_op_quant!(
533            ty Self,
534            float_op |tensor, value| B::float_scatter(dim, tensor, indices, value),
535            tensor,
536            value
537        )
538    }
539
    /// Select quantized tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
556
557    /// Assign the selected elements along the given dimension corresponding for the given indices
558    /// to the given value.
559    ///
560    /// # Arguments
561    ///
562    /// * `tensor` - The tensor to select from.
563    /// * `dim` - The dimension to select from.
564    /// * `indices` - The indices to select.
565    /// * `value` - The value to assign.
566    ///
567    /// # Returns
568    ///
569    /// The tensor with the selected elements assigned to the given value.
570    fn q_select_assign(
571        tensor: QuantizedTensor<B>,
572        dim: usize,
573        indices: IntTensor<B>,
574        value: QuantizedTensor<B>,
575    ) -> QuantizedTensor<B> {
576        dequant_op_quant!(
577            ty Self,
578            float_op |tensor, value| B::float_select_assign(tensor, dim, indices, value),
579            tensor,
580            value
581        )
582    }
583
    /// Select quantized tensor elements corresponding to the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `ranges` - The ranges to select.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn q_slice(tensor: QuantizedTensor<B>, ranges: &[Range<usize>]) -> QuantizedTensor<B>;
595
596    /// Assign the selected elements corresponding for the given ranges to the given value.
597    ///
598    /// # Arguments
599    ///
600    /// * `tensor` - The tensor to select from.
601    /// * `ranges` - The ranges to select.
602    /// * `value` - The value to assign.
603    ///
604    /// # Returns
605    ///
606    /// The tensor with the selected elements assigned to the given value.
607    fn q_slice_assign(
608        tensor: QuantizedTensor<B>,
609        ranges: &[Range<usize>],
610        value: QuantizedTensor<B>,
611    ) -> QuantizedTensor<B> {
612        dequant_op_quant!(
613            ty Self,
614            float_op |tensor, value| B::float_slice_assign(tensor, ranges, value),
615            tensor,
616            value
617        )
618    }
619
620    /// Update the given tensor with the value tensor where the mask is true.
621    ///
622    /// # Arguments
623    ///
624    /// * `tensor` - The tensor to select from.
625    /// * `mask` - The boolean mask to select with.
626    /// * `value` - The value to assign to the selected elements from the value tensor.
627    ///
628    /// # Returns
629    ///
630    /// The tensor with the selected elements assigned to the given value.
631    fn q_mask_where(
632        tensor: QuantizedTensor<B>,
633        mask: BoolTensor<B>,
634        value: QuantizedTensor<B>,
635    ) -> QuantizedTensor<B> {
636        dequant_op_quant!(
637            ty Self,
638            float_op |tensor, value| B::float_mask_where(tensor, mask, value),
639            tensor,
640            value
641        )
642    }
643
644    /// Update the given tensor with the value where the mask is true.
645    ///
646    /// # Arguments
647    ///
648    /// * `tensor` - The tensor to select from.
649    /// * `mask` - The boolean mask to select with.
650    /// * `value` - The value to assign to the selected elements.
651    ///
652    /// # Returns
653    ///
654    /// The tensor with the selected elements assigned to the given value.
655    fn q_mask_fill(
656        tensor: QuantizedTensor<B>,
657        mask: BoolTensor<B>,
658        value: FloatElem<B>,
659    ) -> QuantizedTensor<B> {
660        dequant_op_quant!(
661            ty Self,
662            float_op |tensor| B::float_mask_fill(tensor, mask, value),
663            tensor
664        )
665    }
666
667    /// Sum of all elements in a tensor.
668    ///
669    /// # Arguments
670    ///
671    /// * `tensor` - The tensor to sum.
672    ///
673    /// # Returns
674    ///
675    /// A scalar tensor with the sum of all elements in `tensor`.
676    fn q_sum(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
677        dequant_op_quant!(
678            ty Self,
679            float_op |tensor| B::float_sum(tensor),
680            tensor
681        )
682    }
683
684    /// Sum of all elements in a tensor along a dimension.
685    ///
686    /// # Arguments
687    ///
688    /// * `tensor` - The tensor to sum.
689    /// * `dim` - The dimension along which to sum.
690    ///
691    /// # Returns
692    ///
693    /// A tensor with the sum of all elements in `tensor` along `dim`.
694    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
695        dequant_op_quant!(
696            ty Self,
697            float_op |tensor| B::float_sum_dim(tensor, dim),
698            tensor
699        )
700    }
701
702    /// Product of all elements in a tensor.
703    ///
704    /// # Arguments
705    ///
706    /// * `tensor` - The tensor to product.
707    ///
708    /// # Returns
709    ///
710    /// A scalar tensor with the product of all elements in `tensor`.
711    fn q_prod(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
712        dequant_op_quant!(
713            ty Self,
714            float_op |tensor| B::float_prod(tensor),
715            tensor
716        )
717    }
718
719    /// Product of all elements in a tensor along a dimension.
720    ///
721    /// # Arguments
722    ///
723    /// * `tensor` - The tensor to product.
724    ///
725    /// # Returns
726    ///
727    /// A tensor with the product of all elements in `tensor` along `dim`.
728    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
729        dequant_op_quant!(
730            ty Self,
731            float_op |tensor| B::float_prod_dim(tensor, dim),
732            tensor
733        )
734    }
735
736    /// Mean of all elements in a tensor.
737    ///
738    /// # Arguments
739    ///
740    /// * `tensor` - The tensor to mean.
741    ///
742    /// # Returns
743    ///
744    /// A scalar tensor with the mean of all elements in `tensor`.
745    fn q_mean(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
746        dequant_op_quant!(
747            ty Self,
748            float_op |tensor| B::float_mean(tensor),
749            tensor
750        )
751    }
752
753    /// Mean of all elements in a tensor along a dimension.
754    ///
755    /// # Arguments
756    ///
757    /// * `tensor` - The tensor to mean.
758    /// * `dim` - The dimension along which to mean.
759    ///
760    /// # Returns
761    ///
762    /// A tensor with the mean of all elements in `tensor` along `dim`.
763    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
764        dequant_op_quant!(
765            ty Self,
766            float_op |tensor| B::float_mean_dim(tensor, dim),
767            tensor
768        )
769    }
770
771    /// Returns a new tensor with exponential values.
772    ///
773    /// # Arguments
774    ///
775    /// * `tensor` - The tensor to exponentiate.
776    ///
777    /// # Returns
778    ///
779    /// A tensor with the same shape as `tensor` with exponential values.
780    fn q_exp(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
781        dequant_op_quant!(
782            ty Self,
783            float_op |tensor| B::float_exp(tensor),
784            tensor
785        )
786    }
787
788    /// Returns a new tensor with natural logarithm values.
789    ///
790    /// # Arguments
791    ///
792    /// * `tensor` - The tensor to take the logarithm of.
793    ///
794    /// # Returns
795    ///
796    /// A tensor with the same shape as `tensor` with natural logarithm values.
797    fn q_log(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
798        dequant_op_quant!(
799            ty Self,
800            float_op |tensor| B::float_log(tensor),
801            tensor
802        )
803    }
804
805    /// Returns a new tensor with logarithm values of (1 + Xi).
806    ///
807    /// # Arguments
808    ///
809    /// * `tensor` - The tensor to take the logarithm of.
810    ///
811    /// # Returns
812    ///
813    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
814    fn q_log1p(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
815        dequant_op_quant!(
816            ty Self,
817            float_op |tensor| B::float_log1p(tensor),
818            tensor
819        )
820    }
821
822    /// Element-wise power with another tensor.
823    ///
824    /// # Arguments
825    ///
826    /// * `lhs` - The left hand side tensor.
827    /// * `rhs` - The right hand side tensor.
828    ///
829    /// # Returns
830    ///
831    /// The elements of `lhs` raised to the power of the elements of `rhs`.
832    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
833        dequant_op_quant!(
834            ty Self,
835            float_op |lhs, rhs| B::float_powf(lhs, rhs),
836            lhs,
837            rhs
838        )
839    }
840
841    /// Element-wise power with an IntTensor.
842    ///
843    /// # Arguments
844    ///
845    /// * `lhs` - The left hand side tensor.
846    /// * `rhs` - The right hand side floatTensor.
847    ///
848    /// # Returns
849    ///
850    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
851    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> QuantizedTensor<B> {
852        dequant_op_quant!(
853            ty Self,
854            float_op |tensor| B::float_powi(tensor, rhs),
855            lhs
856        )
857    }
858
859    /// Element-wise power with an int scalar.
860    ///
861    /// # Arguments
862    ///
863    /// * `lhs` - The left hand side tensor.
864    /// * `rhs` - The right hand side scalar.
865    ///
866    /// # Returns
867    ///
868    /// The elements of `lhs` raised to the value of `rhs`.
869    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> QuantizedTensor<B> {
870        dequant_op_quant!(
871            ty Self,
872            float_op |tensor| B::float_powi_scalar(tensor, rhs),
873            lhs
874        )
875    }
876
877    /// Element-wise power with a float scalar.
878    ///
879    /// # Arguments
880    ///
881    /// * `tensor` - The tensor to exponentiate.
882    /// * `value` - The exponent.
883    ///
884    /// # Returns
885    ///
886    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
887    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> QuantizedTensor<B> {
888        dequant_op_quant!(
889            ty Self,
890            float_op |tensor| B::float_powf_scalar(tensor, value),
891            tensor
892        )
893    }
894
895    /// Returns a new tensor with square root values.
896    ///
897    /// # Arguments
898    ///
899    /// * `tensor` - The tensor to take the square root of.
900    ///
901    /// # Returns
902    ///
903    /// A tensor with the same shape as `tensor` with square root values.
904    fn q_sqrt(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
905        dequant_op_quant!(
906            ty Self,
907            float_op |tensor| B::float_sqrt(tensor),
908            tensor
909        )
910    }
911
912    /// Returns a new tensor with absolute values.
913    ///
914    /// # Arguments
915    ///
916    /// * `tensor` - The tensor to take absolute value of.
917    ///
918    /// # Returns
919    ///
920    /// A tensor with the same shape as `tensor` with absolute values.
921    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
922        dequant_op_quant!(
923            ty Self,
924            float_op |tensor| B::float_abs(tensor),
925            tensor
926        )
927    }
928
929    /// Returns a new tensor with cosine values.
930    ///
931    /// # Arguments
932    ///
933    /// * `tensor` - The tensor to take the cosine of.
934    ///
935    /// # Returns
936    ///
937    /// A tensor with the same shape as `tensor` with cosine values.
938    fn q_cos(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
939        dequant_op_quant!(
940            ty Self,
941            float_op |tensor| B::float_cos(tensor),
942            tensor
943        )
944    }
945
946    /// Returns a new tensor with sine values.
947    ///
948    /// # Arguments
949    ///
950    /// * `tensor` - The tensor to take the sine of.
951    ///
952    /// # Returns
953    ///
954    /// A tensor with the same shape as `tensor` with sine values.
955    fn q_sin(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
956        dequant_op_quant!(
957            ty Self,
958            float_op |tensor| B::float_sin(tensor),
959            tensor
960        )
961    }
962
963    /// Returns a new tensor with tangent values.
964    ///
965    /// # Arguments
966    ///
967    /// * `tensor` - The tensor to take the tangent of.
968    ///
969    /// # Returns
970    ///
971    /// A tensor with the same shape as `tensor` with tangent values.
972    fn q_tanh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
973        dequant_op_quant!(
974            ty Self,
975            float_op |tensor| B::float_tanh(tensor),
976            tensor
977        )
978    }
979
980    /// Returns a new tensor with the error function values.
981    ///
982    /// # Arguments
983    ///
984    /// * `tensor` - The tensor to take the error function of.
985    ///
986    /// # Returns
987    ///
988    /// A tensor with the same shape as `tensor` with error function values.
989    fn q_erf(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
990        dequant_op_quant!(
991            ty Self,
992            float_op |tensor| B::float_erf(tensor),
993            tensor
994        )
995    }
996
997    /// Concatenates tensors along a dimension.
998    ///
999    /// # Arguments
1000    ///
1001    /// * `tensors` - The tensors to concatenate.
1002    /// * `dim` - The dimension along which to concatenate.
1003    ///
1004    /// # Returns
1005    ///
1006    /// A tensor with the concatenated tensors along `dim`.
1007    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1008        // Heuristic: prioritize first tensor scheme
1009        let scheme = *tensors.first().unwrap().scheme();
1010
1011        let tensor_f = tensors
1012            .into_iter()
1013            .map(|tensor| Self::dequantize(tensor))
1014            .collect();
1015
1016        let out_f = B::float_cat(tensor_f, dim);
1017
1018        Self::quantize_dynamic(out_f, &scheme)
1019    }
1020
1021    /// Gets the indices of the maximum elements of a tensor along an axis.
1022    ///
1023    /// # Arguments
1024    ///
1025    /// * `tensor` - The tensor to get the maximum elements of.
1026    /// * `dim` - The dimension along which to get the maximum elements.
1027    ///
1028    /// # Returns
1029    ///
1030    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1031    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1032        // Default implementation. Backends can sort on the int values since qparams remain the same.
1033        let tensor_f = Self::dequantize(tensor);
1034        B::float_argmax(tensor_f, dim)
1035    }
1036
1037    /// Gets the indices of the minimum elements of a tensor along an axis.
1038    ///
1039    /// # Arguments
1040    ///
1041    /// * `tensor` - The tensor to get the minimum elements of.
1042    /// * `dim` - The dimension along which to get the minimum elements.
1043    ///
1044    /// # Returns
1045    ///
1046    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1047    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1048        // Default implementation. Backends can sort on the int values since qparams remain the same.
1049        let tensor_f = Self::dequantize(tensor);
1050        B::float_argmin(tensor_f, dim)
1051    }
1052
1053    /// Gets the maximum element of a tensor.
1054    ///
1055    /// # Arguments
1056    ///
1057    /// * `tensor` - The tensor to get the maximum elements of.
1058    ///
1059    /// # Returns
1060    ///
1061    /// A tensor with the maximum element of `tensor`.
1062    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1063        let shape = tensor.shape();
1064        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1065
1066        B::q_max_dim(tensor, 0)
1067    }
1068
1069    /// Gets the maximum elements of a tensor along an axis.
1070    ///
1071    /// # Arguments
1072    ///
1073    /// * `tensor` - The tensor to get the maximum elements of.
1074    /// * `dim` - The dimension along which to get the maximum elements.
1075    ///
1076    /// # Returns
1077    ///
1078    /// A tensor with the maximum elements of `tensor` along `dim`.
1079    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1080        let index = B::q_argmax(tensor.clone(), dim);
1081
1082        B::q_gather(dim, tensor, index)
1083    }
1084
1085    /// Gets the maximum elements of a tensor along an axis and their indices.
1086    ///
1087    /// # Arguments
1088    ///
1089    /// * `tensor` - The tensor to get the maximum elements of.
1090    /// * `dim` - The dimension along which to get the maximum elements.
1091    ///
1092    /// # Returns
1093    ///
1094    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1095    fn q_max_dim_with_indices(
1096        tensor: QuantizedTensor<B>,
1097        dim: usize,
1098    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1099        let index = B::q_argmax(tensor.clone(), dim);
1100        let values = B::q_gather(dim, tensor, index.clone());
1101
1102        (values, index)
1103    }
1104
1105    /// Gets the minimum element of a tensor.
1106    ///
1107    /// # Arguments
1108    ///
1109    /// * `tensor` - The tensor to get the minimum elements of.
1110    ///
1111    /// # Returns
1112    ///
1113    /// A tensor with the minimum element of `tensor`.
1114    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1115        let shape = tensor.shape();
1116        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1117
1118        B::q_min_dim(tensor, 0)
1119    }
1120
1121    /// Gets the minimum elements of a tensor along an axis.
1122    ///
1123    /// # Arguments
1124    ///
1125    /// * `tensor` - The tensor to get the minimum elements of.
1126    /// * `dim` - The dimension along which to get the minimum elements.
1127    ///
1128    /// # Returns
1129    ///
1130    /// A tensor with the minimum elements of `tensor` along `dim`.
1131    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1132        let index = B::q_argmin(tensor.clone(), dim);
1133
1134        B::q_gather(dim, tensor, index)
1135    }
1136
1137    /// Gets the minimum elements of a tensor along an axis and their indices.
1138    ///
1139    /// # Arguments
1140    ///
1141    /// * `tensor` - The tensor to get the minimum elements of.
1142    /// * `dim` - The dimension along which to get the minimum elements.
1143    ///
1144    /// # Returns
1145    ///
1146    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1147    fn q_min_dim_with_indices(
1148        tensor: QuantizedTensor<B>,
1149        dim: usize,
1150    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1151        let index = B::q_argmin(tensor.clone(), dim);
1152        let values = B::q_gather(dim, tensor, index.clone());
1153
1154        (values, index)
1155    }
1156
1157    /// Returns a new tensor with the given dimension narrowed to the given range.
1158    ///
1159    /// # Arguments
1160    ///
1161    /// * `dim` - The dimension along which the tensor will be narrowed.
1162    /// * `start` - The starting point of the given range.
1163    /// * `length` - The ending point of the given range.
1164    /// # Panics
1165    ///
1166    /// - If the dimension is greater than the number of dimensions of the tensor.
1167    /// - If the given range exceeds the number of elements on the given dimension.
1168    ///
1169    /// # Returns
1170    ///
1171    /// A new tensor with the given dimension narrowed to the given range.
1172    fn q_narrow(
1173        tensor: QuantizedTensor<B>,
1174        dim: usize,
1175        start: usize,
1176        length: usize,
1177    ) -> QuantizedTensor<B> {
1178        dequant_op_quant!(
1179            ty Self,
1180            float_op |tensor| B::float_narrow(tensor, dim, start, length),
1181            tensor
1182        )
1183    }
1184
1185    /// Split the tensor along the given dimension into chunks.
1186    ///
1187    /// # Arguments
1188    ///
1189    /// * `tensor` - The tensor.
1190    /// * `chunks` - The number of chunks to be produced.
1191    /// * `times` - The dimension along which the tensor will be split.
1192    ///
1193    /// # Returns
1194    ///
1195    /// A vector of tensors.
1196    fn q_chunk(tensor: QuantizedTensor<B>, chunks: usize, dim: usize) -> Vec<QuantizedTensor<B>> {
1197        let scheme = *tensor.scheme();
1198
1199        let tensor_f = Self::dequantize(tensor);
1200        let out_f = B::float_chunk(tensor_f, chunks, dim);
1201
1202        out_f
1203            .into_iter()
1204            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1205            .collect()
1206    }
1207
1208    /// Split the tensor along the given dimension into chunks of `split_size`.
1209    ///
1210    /// # Arguments
1211    ///
1212    /// * `tensor` - The tensor.
1213    /// * `split_size` - The size of a single chunk.
1214    /// * `times` - The dimension along which the tensor will be split.
1215    ///
1216    /// # Returns
1217    ///
1218    /// A vector of tensors.
1219    fn q_split(
1220        tensor: QuantizedTensor<B>,
1221        split_size: usize,
1222        dim: usize,
1223    ) -> Vec<QuantizedTensor<B>> {
1224        let scheme = *tensor.scheme();
1225
1226        let tensor_f = Self::dequantize(tensor);
1227        let out_f = B::float_split(tensor_f, split_size, dim);
1228
1229        out_f
1230            .into_iter()
1231            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1232            .collect()
1233    }
1234
1235    /// Split the tensor along the given dimension into chunks with sizes in
1236    /// `dim` according to `split_sizes`.
1237    ///
1238    /// # Arguments
1239    ///
1240    /// * `tensor` - The tensor.
1241    /// * `split_sizes` - Vector of sizes for each chunk.
1242    /// * `times` - The dimension along which the tensor will be split.
1243    ///
1244    /// # Returns
1245    ///
1246    /// A vector of tensors.
1247    fn q_split_with_sizes(
1248        tensor: QuantizedTensor<B>,
1249        split_sizes: Vec<usize>,
1250        dim: usize,
1251    ) -> Vec<QuantizedTensor<B>> {
1252        let scheme = *tensor.scheme();
1253
1254        let tensor_f = Self::dequantize(tensor);
1255        let out_f = B::float_split_with_sizes(tensor_f, split_sizes, dim);
1256
1257        out_f
1258            .into_iter()
1259            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1260            .collect()
1261    }
1262
1263    /// Tests if any element in the `tensor` evaluates to True.
1264    ///
1265    /// # Arguments
1266    ///
1267    /// * `tensor` - The tensor to test.
1268    ///
1269    /// # Returns
1270    ///
1271    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1272    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1273        let tensor_f = Self::dequantize(tensor);
1274        B::float_any(tensor_f)
1275    }
1276
1277    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1278    ///
1279    /// # Arguments
1280    ///
1281    /// * `tensor` - The tensor to test.
1282    /// * `dim` - The axis along which to test.
1283    ///
1284    /// # Returns
1285    ///
1286    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1287    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1288    /// input evaluates to True, False otherwise.
1289    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1290        let tensor_f = Self::dequantize(tensor);
1291        B::float_any_dim(tensor_f, dim)
1292    }
1293
1294    /// Tests if all elements in the `tensor` evaluate to True.
1295    ///
1296    /// # Arguments
1297    ///
1298    /// * `tensor` - The tensor to test.
1299    ///
1300    /// # Returns
1301    ///
1302    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1303    /// evaluate to True, False otherwise.
1304    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1305        let tensor_f = Self::dequantize(tensor);
1306        B::float_all(tensor_f)
1307    }
1308
1309    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1310    ///
1311    /// # Arguments
1312    ///
1313    /// * `tensor` - The tensor to test.
1314    /// * `dim` - The axis along which to test.
1315    ///
1316    /// # Returns
1317    ///
1318    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1319    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1320    /// evaluates to True, False otherwise.
1321    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1322        let tensor_f = Self::dequantize(tensor);
1323        B::float_all_dim(tensor_f, dim)
1324    }
1325
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// This method has no default implementation; each backend must provide it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The shape to broadcast the tensor to.
    ///
    /// # Returns
    ///
    /// The quantized tensor broadcast to `shape`.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
1328
1329    /// Sort the elements of the input `tensor` by value in along a given dimension.
1330    ///
1331    /// This sort is unstable (i.e., may reorder equal elements).
1332    ///
1333    /// # Arguments
1334    ///
1335    /// * `tensor` - The input tensor.
1336    /// * `dim` - The axis along which to sort.
1337    /// * `descending` - The sorting order.
1338    ///
1339    /// # Returns
1340    ///
1341    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1342    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1343        // Default implementation. Backends can sort on the int values since qparams remain the same.
1344        dequant_op_quant!(
1345            ty Self,
1346            float_op |tensor| B::float_sort(tensor, dim, descending),
1347            tensor
1348        )
1349    }
1350
1351    /// Sort the elements of the input `tensor` by value in along a given dimension.
1352    ///
1353    /// This sort is unstable (i.e., may reorder equal elements).
1354    ///
1355    /// # Arguments
1356    ///
1357    /// * `tensor` - The input tensor.
1358    /// * `dim` - The axis along which to sort.
1359    /// * `descending` - The sorting order.
1360    ///
1361    /// # Returns
1362    ///
1363    /// A tensor with the same shape as the input tensor and corresponding indices, where
1364    /// the elements are sorted by value and the indices map back to the original input tensor.
1365    fn q_sort_with_indices(
1366        tensor: QuantizedTensor<B>,
1367        dim: usize,
1368        descending: bool,
1369    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1370        // Default implementation. Backends can sort on the int values since qparams remain the same.
1371        let scheme = *tensor.scheme();
1372
1373        let tensor_f = Self::dequantize(tensor);
1374        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1375
1376        (Self::quantize_dynamic(out_f, &scheme), indices)
1377    }
1378
1379    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1380    ///
1381    /// This sort is unstable (i.e., may reorder equal elements).
1382    ///
1383    /// # Arguments
1384    ///
1385    /// * `tensor` - The input tensor.
1386    /// * `dim` - The axis along which to sort.
1387    /// * `descending` - The sorting order.
1388    ///
1389    /// # Returns
1390    ///
1391    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1392    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1393        // Default implementation. Backends can sort on the int values since qparams remain the same.
1394        let tensor_f = Self::dequantize(tensor);
1395        B::float_argsort(tensor_f, dim, descending)
1396    }
1397}