Skip to main content

burn_backend/backend/ops/
qtensor.rs

1use alloc::vec::Vec;
2use burn_std::{
3    BoolDType, FloatDType, IntDType, Shape, Slice,
4    quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9    get_device_settings,
10};
11use crate::{
12    Scalar,
13    tensor::{
14        BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,
15        quantization::{
16            Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,
17        },
18    },
19};
20
21/// Automatically applies `dequantization -> float operation -> quantization`.
22///
23/// Used for tensor ops that should always return a quantized output.
24#[macro_export]
25macro_rules! dequant_op_quant {
26    // Binary tensor float op w/ lhs & rhs
27    (
28        float_op $float_op:expr, $t1:expr, $t2:expr
29    ) => {{
30        // Heuristic: prioritize lhs scheme
31        let scheme = $t1.scheme().clone();
32
33        let t1_f = Self::dequantize($t1);
34        let t2_f = Self::dequantize($t2);
35        #[allow(clippy::redundant_closure_call)]
36        let out_f = $float_op(t1_f, t2_f);
37
38        Self::quantize_dynamic(out_f, &scheme)
39    }};
40    // Unary tensor float op
41    (
42        float_op $float_op:expr, $tensor:expr
43    ) => {{
44        let scheme = $tensor.scheme().clone();
45        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;
46
47        let tensor_f = Self::dequantize($tensor, dtype);
48        #[allow(clippy::redundant_closure_call)]
49        let out_f = $float_op(tensor_f);
50
51        Self::quantize_dynamic(out_f, &scheme)
52    }};
53}
54
55/// Automatically applies `dequantization -> float operation [-> quantization]`.
56///
57/// The output quantization step is optional.
58/// It is only performed when the input quantization scheme is propagated.
59#[macro_export]
60macro_rules! dequant_op_flow {
61    // Binary tensor float op w/ lhs & rhs
62    (
63        float_op $float_op:expr, $t1:expr, $t2:expr
64    ) => {{
65        // Heuristic: prioritize lhs scheme
66        let scheme = $t1.scheme().clone();
67        let propagation = $t1.propagation();
68        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;
69
70        let t1_f = Self::dequantize($t1, dtype);
71        let t2_f = Self::dequantize($t2, dtype);
72        #[allow(clippy::redundant_closure_call)]
73        let out_f = $float_op(t1_f, t2_f);
74
75        match propagation {
76            QuantPropagation::Propagate => {
77                TensorPrimitive::QFloat(Self::quantize_dynamic(out_f, &scheme))
78            }
79            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
80        }
81    }};
82    // Unary tensor float op
83    (
84        float_op $float_op:expr, $tensor:expr
85    ) => {{
86        let scheme = $tensor.scheme().clone();
87        let propagation = $tensor.propagation();
88        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;
89
90        let tensor_f = Self::dequantize($tensor, dtype);
91        #[allow(clippy::redundant_closure_call)]
92        let out_f = $float_op(tensor_f);
93
94        match propagation {
95            QuantPropagation::Propagate => {
96                TensorPrimitive::QFloat(Self::quantize_dynamic(out_f, &scheme))
97            }
98            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
99        }
100    }};
101}
102
103/// Operations on quantized tensors.
104///
105/// # Return Type Semantics
106///
107/// The return type of each operation indicates how quantization is handled:
108///
109/// ## [`QuantizedTensor<B>`]
110/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
111/// representation. Implementations should avoid dequantizing when possible to maintain performance.
112/// For example, shape or layout changes such as expand or transpose preserve quantization.
113///
114/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
115/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
116/// the quantization parameters to match the new layout.*
117///
118///
119/// ## [`TensorPrimitive<B>`]
120/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with propagation
121/// strategy specified in the quantization scheme. The output should remain quantized ([`TensorPrimitive::QFloat`])
122/// returned in floating-point form ([`TensorPrimitive::Float`]).
123///
124/// This distinction allows for fine-grained control over mixed-precision flows while still operating
125/// through a unified API.
126pub trait QTensorOps<B: Backend> {
127    /// Creates a new tensor from the data structure.
128    ///
129    /// # Arguments
130    ///
131    /// * `data` - The data structure.
132    /// * `device` - The device to create the tensor on.
133    ///
134    /// # Returns
135    ///
136    /// The tensor with the given data.
137    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
138
139    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
140    fn quantize(
141        tensor: FloatTensor<B>,
142        scheme: &QuantScheme,
143        qparams: QuantizationParametersPrimitive<B>,
144    ) -> QuantizedTensor<B>;
145
146    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
147    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
148        // Dynamically compute min/max tensor range and qparams before quantizing
149        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
150        let qparams = compute_q_params(scheme, min, max);
151        Self::quantize(tensor, scheme, qparams)
152    }
153
154    /// Convert the tensor back to a higher precision data type.
155    fn dequantize(tensor: QuantizedTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
156
157    /// Gets the device of the tensor.
158    ///
159    /// # Arguments
160    ///
161    /// * `tensor` - The tensor.
162    ///
163    /// # Returns
164    ///
165    /// The device of the tensor.
166    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
167
168    /// Moves the tensor to the given device.
169    ///
170    /// # Arguments
171    ///
172    /// * `tensor` - The tensor.
173    /// * `device` - The device to move the tensor to.
174    ///
175    /// # Returns
176    ///
177    /// The tensor on the given device.
178    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
179
180    /// Reshapes a tensor.
181    ///
182    /// # Arguments
183    ///
184    /// * `tensor` - The tensor to reshape.
185    /// * `shape` - The new shape of the tensor.
186    ///
187    /// # Returns
188    ///
189    /// The tensor with the new shape.
190    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
191
192    /// Converts the tensor to a data structure.
193    ///
194    /// # Arguments
195    ///
196    /// * `tensor` - The tensor.
197    ///
198    /// # Returns
199    ///
200    /// The data structure with the tensor's data.
201    fn q_into_data(
202        tensor: QuantizedTensor<B>,
203    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
204
205    /// Detaches a tensor from the computation graph.
206    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
207        // Should only be overridden by autodiff backends.
208        tensor
209    }
210
211    /// Sets the `require_grad` flag of a tensor.
212    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
213        // Should only be overridden by autodiff backends.
214        tensor
215    }
216
217    /// Returns the `require_grad` flag of a tensor.
218    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
219        // Should only be overridden by autodiff backends.
220        false
221    }
222
223    /// Broadcasts the `tensor` to the given `shape`.
224    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
225
226    /// Transposes a tensor.
227    ///
228    /// # Arguments
229    ///
230    /// * `tensor` - The tensor to transpose.
231    ///
232    /// # Returns
233    ///
234    /// The transposed tensor.
235    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
236        let ndims = tensor.shape().num_dims();
237        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
238    }
239
240    /// Swaps two dimensions of a tensor.
241    ///
242    /// # Arguments
243    ///
244    /// * `tensor` - The tensor to swap the dimensions of.
245    /// * `dim1` - The first dimension to swap.
246    /// * `dim2` - The second dimension to swap.
247    ///
248    /// # Returns
249    ///
250    /// The tensor with the dimensions swapped.
251    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
252
253    /// Permutes the dimensions of a tensor.
254    ///
255    /// # Arguments
256    ///
257    /// * `tensor` - The tensor to permute the dimensions of.
258    /// * `axes` - The new order of the dimensions.
259    /// # Returns
260    ///
261    /// The tensor with the dimensions permuted.
262    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
263
264    /// Reverse the order of elements in a tensor along the given axes.
265    ///
266    /// # Arguments
267    ///
268    /// * `tensor` - The tensor to reverse.
269    /// * `axes` - The axes to reverse.
270    ///
271    /// The tensor with the elements reversed.
272    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
273
274    /// Select tensor elements along the given dimension corresponding for the given indices.
275    ///
276    /// # Arguments
277    ///
278    /// * `tensor` - The tensor to select from.
279    /// * `dim` - The dimension to select from.
280    /// * `indices` - The indices to select.
281    ///
282    /// # Returns
283    ///
284    /// The selected elements.
285    fn q_select(
286        tensor: QuantizedTensor<B>,
287        dim: usize,
288        indices: IntTensor<B>,
289    ) -> QuantizedTensor<B>;
290
291    /// Select tensor elements corresponding to the given slices.
292    ///
293    /// # Arguments
294    ///
295    /// * `tensor` - The tensor to select from.
296    /// * `slices` - The slices specifying ranges and steps for each dimension.
297    ///
298    /// # Returns
299    ///
300    /// The selected elements in a new tensor.
301    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
302
303    /// Gather elements from a tensor.
304    ///
305    /// # Arguments
306    ///
307    /// * `dim` - The dimension to gather from.
308    /// * `tensor` - The tensor to gather from.
309    /// * `indices` - The indices to gather.
310    ///
311    /// # Returns
312    ///
313    /// The gathered elements.
314    fn q_gather(
315        dim: usize,
316        tensor: QuantizedTensor<B>,
317        indices: IntTensor<B>,
318    ) -> QuantizedTensor<B> {
319        // Default implementation. Backends can gather on the quantized values when supported.
320        dequant_op_quant!(
321            float_op | tensor | B::float_gather(dim, tensor, indices),
322            tensor
323        )
324    }
325
326    /// Repeat the tensor along the given dimension.
327    ///
328    /// # Arguments
329    ///
330    /// * `tensor` - The tensor.
331    /// * `dim` - The dimension to repeat.
332    /// * `times` - The number of times to repeat the dimension.
333    ///
334    /// # Returns
335    ///
336    /// The tensor with the given dimension repeated.
337    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
338        dequant_op_quant!(
339            float_op | tensor | B::float_repeat_dim(tensor, dim, times),
340            tensor
341        )
342    }
343
344    /// Adds two tensors together.
345    ///
346    /// # Arguments
347    ///
348    /// * `lhs` - The left hand side tensor.
349    /// * `rhs` - The right hand side tensor.
350    ///
351    /// # Returns
352    ///
353    /// The result of adding the two tensors together.
354    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
355        dequant_op_flow!(float_op | lhs, rhs | B::float_add(lhs, rhs), lhs, rhs)
356    }
357
358    /// Adds a scalar to a tensor.
359    ///
360    /// # Arguments
361    ///
362    /// * `lhs` - The left hand side tensor.
363    /// * `rhs` - The right hand side scalar.
364    ///
365    /// # Returns
366    ///
367    /// The result of adding the scalar to the tensor.
368    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
369        dequant_op_flow!(float_op | tensor | B::float_add_scalar(tensor, rhs), lhs)
370    }
371
372    /// Clamps a tensor under a minimum value.
373    ///
374    /// # Arguments
375    ///
376    /// * `tensor` - The tensor to clamp.
377    /// * `min` - The minimum value.
378    ///
379    /// # Returns
380    ///
381    /// The clamped tensor.
382    fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {
383        dequant_op_flow!(float_op | tensor | B::float_clamp_min(tensor, min), tensor)
384    }
385
386    /// Clamps a tensor over a maximum value.
387    ///
388    /// # Arguments
389    ///
390    /// * `tensor` - The tensor to clamp.
391    /// * `max` - The maximum value.
392    ///
393    /// # Returns
394    ///
395    /// The clamped tensor.
396    fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {
397        dequant_op_flow!(float_op | tensor | B::float_clamp_max(tensor, max), tensor)
398    }
399
400    /// Clamps a tensor between a minimum and maximum value.
401    ///
402    /// # Arguments
403    ///
404    /// * `tensor` - The tensor to clamp.
405    /// * `min` - The minimum value.
406    /// * `max` - The maximum value.
407    ///
408    /// # Returns
409    ///
410    /// The clamped tensor.
411    fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {
412        dequant_op_flow!(float_op | tensor | B::float_clamp(tensor, min, max), tensor)
413    }
414
415    /// Subtracts two tensors.
416    ///
417    /// # Arguments
418    ///
419    /// * `lhs` - The left hand side tensor.
420    /// * `rhs` - The right hand side tensor.
421    ///
422    /// # Returns
423    ///
424    /// The result of subtracting the two tensors.
425    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
426        dequant_op_flow!(float_op | lhs, rhs | B::float_sub(lhs, rhs), lhs, rhs)
427    }
428
429    /// Subtracts a scalar from a tensor.
430    ///
431    /// # Arguments
432    ///
433    /// * `lhs` - The left hand side tensor.
434    /// * `rhs` - The right hand side scalar.
435    ///
436    /// # Returns
437    ///
438    /// The result of subtracting the scalar from the tensor.
439    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
440        dequant_op_flow!(float_op | tensor | B::float_sub_scalar(tensor, rhs), lhs)
441    }
442
443    /// Multiplies two tensors together element-wise.
444    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
445        dequant_op_flow!(float_op | lhs, rhs | B::float_mul(lhs, rhs), lhs, rhs)
446    }
447
448    /// Multiplies a tensor by a scalar.
449    ///
450    /// # Arguments
451    ///
452    /// * `lhs` - The left hand side tensor.
453    /// * `rhs` - The right hand side scalar.
454    ///
455    /// # Returns
456    ///
457    /// The result of multiplying the tensor by the scalar.
458    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
459        dequant_op_flow!(float_op | tensor | B::float_mul_scalar(tensor, rhs), lhs)
460    }
461
462    /// Divides two tensors element-wise.
463    ///
464    /// # Arguments
465    ///
466    /// * `lhs` - The left hand side tensor.
467    /// * `rhs` - The right hand side tensor.
468    ///
469    /// # Returns
470    ///
471    /// The result of dividing the two tensors.
472    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
473        dequant_op_flow!(float_op | lhs, rhs | B::float_div(lhs, rhs), lhs, rhs)
474    }
475
476    /// Divides a tensor by a scalar.
477    ///
478    /// # Arguments
479    ///
480    /// * `lhs` - The left hand side tensor.
481    /// * `rhs` - The right hand side scalar.
482    ///
483    /// # Returns
484    ///
485    /// The result of dividing the tensor by the scalar.
486    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
487        dequant_op_flow!(float_op | tensor | B::float_div_scalar(tensor, rhs), lhs)
488    }
489
490    /// Multiplies two tensors together using matrix multiplication.
491    ///
492    /// # Arguments
493    ///
494    /// * `lhs` - The left hand side tensor.
495    /// * `rhs` - The right hand side tensor.
496    ///
497    /// # Returns
498    ///
499    /// The result of multiplying the two tensors together using matrix multiplication.
500    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
501        let mut propagation = QuantPropagation::Inhibit;
502        let mut scheme = QuantScheme::default();
503
504        // Pick a target dtype for any dequantization. If either operand is already
505        // a Float tensor, take its dtype so a Float-QFloat (or QFloat-Float) pair
506        // ends up matching after dequantize and `float_matmul` doesn't see a
507        // dtype mismatch. Only when both operands are QFloat do we fall back to
508        // the device default.
509        let target_dtype: Option<FloatDType> = match (&lhs, &rhs) {
510            (TensorPrimitive::Float(t), _) | (_, TensorPrimitive::Float(t)) => {
511                Some(t.dtype().into())
512            }
513            _ => None,
514        };
515
516        let lhs = match lhs {
517            TensorPrimitive::Float(lhs) => lhs,
518            TensorPrimitive::QFloat(lhs) => {
519                propagation = lhs.propagation();
520                scheme = *lhs.scheme();
521                let float_dtype = target_dtype
522                    .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&lhs)).float_dtype);
523
524                Self::dequantize(lhs, float_dtype)
525            }
526        };
527        let rhs = match rhs {
528            TensorPrimitive::Float(rhs) => rhs,
529            TensorPrimitive::QFloat(rhs) => {
530                propagation = rhs.propagation();
531                scheme = *rhs.scheme();
532                let float_dtype = target_dtype
533                    .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&rhs)).float_dtype);
534
535                Self::dequantize(rhs, float_dtype)
536            }
537        };
538
539        let out_f = B::float_matmul(lhs, rhs);
540        match propagation {
541            QuantPropagation::Propagate => {
542                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
543            }
544            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
545        }
546    }
547
548    /// Negates a tensor element-wise.
549    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
550        dequant_op_flow!(float_op | tensor | B::float_neg(tensor), tensor)
551    }
552
553    /// Calculates the reciprocals element-wise
554    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
555        dequant_op_flow!(float_op | tensor | B::float_recip(tensor), tensor)
556    }
557
558    /// Sum of all elements in a tensor.
559    ///
560    /// # Arguments
561    ///
562    /// * `tensor` - The tensor to sum.
563    ///
564    /// # Returns
565    ///
566    /// A scalar tensor with the sum of all elements in `tensor`.
567    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
568        dequant_op_flow!(float_op | tensor | B::float_sum(tensor), tensor)
569    }
570
571    /// Sum of all elements in a tensor along a dimension.
572    ///
573    /// # Arguments
574    ///
575    /// * `tensor` - The tensor to sum.
576    /// * `dim` - The dimension along which to sum.
577    ///
578    /// # Returns
579    ///
580    /// A tensor with the sum of all elements in `tensor` along `dim`.
581    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
582        dequant_op_flow!(float_op | tensor | B::float_sum_dim(tensor, dim), tensor)
583    }
584
585    /// Product of all elements in a tensor.
586    ///
587    /// # Arguments
588    ///
589    /// * `tensor` - The tensor to product.
590    ///
591    /// # Returns
592    ///
593    /// A scalar tensor with the product of all elements in `tensor`.
594    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
595        dequant_op_flow!(float_op | tensor | B::float_prod(tensor), tensor)
596    }
597
598    /// Product of all elements in a tensor along a dimension.
599    ///
600    /// # Arguments
601    ///
602    /// * `tensor` - The tensor to product.
603    ///
604    /// # Returns
605    ///
606    /// A tensor with the product of all elements in `tensor` along `dim`.
607    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
608        dequant_op_flow!(float_op | tensor | B::float_prod_dim(tensor, dim), tensor)
609    }
610
611    /// Mean of all elements in a tensor.
612    ///
613    /// # Arguments
614    ///
615    /// * `tensor` - The tensor to mean.
616    ///
617    /// # Returns
618    ///
619    /// A scalar tensor with the mean of all elements in `tensor`.
620    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
621        dequant_op_flow!(float_op | tensor | B::float_mean(tensor), tensor)
622    }
623
624    /// Mean of all elements in a tensor along a dimension.
625    ///
626    /// # Arguments
627    ///
628    /// * `tensor` - The tensor to mean.
629    /// * `dim` - The dimension along which to mean.
630    ///
631    /// # Returns
632    ///
633    /// A tensor with the mean of all elements in `tensor` along `dim`.
634    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
635        dequant_op_flow!(float_op | tensor | B::float_mean_dim(tensor, dim), tensor)
636    }
637
638    /// Computes the cumulative sum of elements along a dimension.
639    ///
640    /// # Arguments
641    ///
642    /// * `tensor` - The tensor to compute the cumulative sum of.
643    /// * `dim` - The dimension along which to compute the cumulative sum.
644    ///
645    /// # Returns
646    ///
647    /// A tensor with the same shape where each element is the cumulative sum
648    /// of all elements up to and including that position along the dimension.
649    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
650        dequant_op_flow!(float_op | tensor | B::float_cumsum(tensor, dim), tensor)
651    }
652
653    /// Computes the cumulative product of elements along a dimension.
654    ///
655    /// # Arguments
656    ///
657    /// * `tensor` - The tensor to compute the cumulative product of.
658    /// * `dim` - The dimension along which to compute the cumulative product.
659    ///
660    /// # Returns
661    ///
662    /// A tensor with the same shape where each element is the cumulative product
663    /// of all elements up to and including that position along the dimension.
664    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
665        dequant_op_flow!(float_op | tensor | B::float_cumprod(tensor, dim), tensor)
666    }
667
668    /// Computes the cumulative minimum of elements along a dimension.
669    ///
670    /// # Arguments
671    ///
672    /// * `tensor` - The tensor to compute the cumulative minimum of.
673    /// * `dim` - The dimension along which to compute the cumulative minimum.
674    ///
675    /// # Returns
676    ///
677    /// A tensor with the same shape where each element is the minimum
678    /// of all elements up to and including that position along the dimension.
679    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
680        dequant_op_flow!(float_op | tensor | B::float_cummin(tensor, dim), tensor)
681    }
682
683    /// Computes the cumulative maximum of elements along a dimension.
684    ///
685    /// # Arguments
686    ///
687    /// * `tensor` - The tensor to compute the cumulative maximum of.
688    /// * `dim` - The dimension along which to compute the cumulative maximum.
689    ///
690    /// # Returns
691    ///
692    /// A tensor with the same shape where each element is the maximum
693    /// of all elements up to and including that position along the dimension.
694    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
695        dequant_op_flow!(float_op | tensor | B::float_cummax(tensor, dim), tensor)
696    }
697
698    /// Returns a new tensor with exponential values.
699    ///
700    /// # Arguments
701    ///
702    /// * `tensor` - The tensor to exponentiate.
703    ///
704    /// # Returns
705    ///
706    /// A tensor with the same shape as `tensor` with exponential values.
707    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
708        dequant_op_flow!(float_op | tensor | B::float_exp(tensor), tensor)
709    }
710
711    /// Returns a new tensor with natural logarithm values.
712    ///
713    /// # Arguments
714    ///
715    /// * `tensor` - The tensor to take the logarithm of.
716    ///
717    /// # Returns
718    ///
719    /// A tensor with the same shape as `tensor` with natural logarithm values.
720    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
721        dequant_op_flow!(float_op | tensor | B::float_log(tensor), tensor)
722    }
723
724    /// Returns a new tensor with logarithm values of (1 + Xi).
725    ///
726    /// # Arguments
727    ///
728    /// * `tensor` - The tensor to take the logarithm of.
729    ///
730    /// # Returns
731    ///
732    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
733    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
734        dequant_op_flow!(float_op | tensor | B::float_log1p(tensor), tensor)
735    }
736
737    /// Element-wise power with another tensor.
738    ///
739    /// # Arguments
740    ///
741    /// * `lhs` - The left hand side tensor.
742    /// * `rhs` - The right hand side tensor.
743    ///
744    /// # Returns
745    ///
746    /// The elements of `lhs` raised to the power of the elements of `rhs`.
747    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
748        dequant_op_flow!(float_op | lhs, rhs | B::float_powf(lhs, rhs), lhs, rhs)
749    }
750
751    /// Element-wise power with an IntTensor.
752    ///
753    /// # Arguments
754    ///
755    /// * `lhs` - The left hand side tensor.
756    /// * `rhs` - The right hand side floatTensor.
757    ///
758    /// # Returns
759    ///
760    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
761    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
762        dequant_op_flow!(float_op | tensor | B::float_powi(tensor, rhs), lhs)
763    }
764
765    /// Element-wise power with an int scalar.
766    ///
767    /// # Arguments
768    ///
769    /// * `lhs` - The left hand side tensor.
770    /// * `rhs` - The right hand side scalar.
771    ///
772    /// # Returns
773    ///
774    /// The elements of `lhs` raised to the value of `rhs`.
775    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
776        dequant_op_flow!(float_op | tensor | B::float_powi_scalar(tensor, rhs), lhs)
777    }
778
779    /// Element-wise power with a float scalar.
780    ///
781    /// # Arguments
782    ///
783    /// * `tensor` - The tensor to exponentiate.
784    /// * `value` - The exponent.
785    ///
786    /// # Returns
787    ///
788    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
789    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {
790        dequant_op_flow!(
791            float_op | tensor | B::float_powf_scalar(tensor, value),
792            tensor
793        )
794    }
795
796    /// Returns a new tensor with square root values.
797    ///
798    /// # Arguments
799    ///
800    /// * `tensor` - The tensor to take the square root of.
801    ///
802    /// # Returns
803    ///
804    /// A tensor with the same shape as `tensor` with square root values.
805    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
806        dequant_op_flow!(float_op | tensor | B::float_sqrt(tensor), tensor)
807    }
808
809    /// Returns a new tensor with absolute values.
810    ///
811    /// # Arguments
812    ///
813    /// * `tensor` - The tensor to take absolute value of.
814    ///
815    /// # Returns
816    ///
817    /// A tensor with the same shape as `tensor` with absolute values.
818    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
819        dequant_op_quant!(float_op | tensor | B::float_abs(tensor), tensor)
820    }
821
822    /// Returns a new tensor with cosine values.
823    ///
824    /// # Arguments
825    ///
826    /// * `tensor` - The tensor to take the cosine of.
827    ///
828    /// # Returns
829    ///
830    /// A tensor with the same shape as `tensor` with cosine values.
831    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
832        dequant_op_flow!(float_op | tensor | B::float_cos(tensor), tensor)
833    }
834
835    /// Returns a new tensor with sine values.
836    ///
837    /// # Arguments
838    ///
839    /// * `tensor` - The tensor to take the sine of.
840    ///
841    /// # Returns
842    ///
843    /// A tensor with the same shape as `tensor` with sine values.
844    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
845        dequant_op_flow!(float_op | tensor | B::float_sin(tensor), tensor)
846    }
847
848    /// Returns a new tensor with tangent values.
849    ///
850    /// # Arguments
851    ///
852    /// * `tensor` - The tensor to take the tangent of.
853    ///
854    /// # Returns
855    ///
856    /// A tensor with the same shape as `tensor` with tangent values.
857    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
858        dequant_op_flow!(float_op | tensor | B::float_tan(tensor), tensor)
859    }
860
861    /// Returns a new tensor with hyperbolic cosine values.
862    ///
863    /// # Arguments
864    ///
865    /// * `tensor` - The tensor to take the hyperbolic cosine of.
866    ///
867    /// # Returns
868    ///
869    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
870    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
871        dequant_op_flow!(float_op | tensor | B::float_cosh(tensor), tensor)
872    }
873
874    /// Returns a new tensor with hyperbolic sine values.
875    ///
876    /// # Arguments
877    ///
878    /// * `tensor` - The tensor to take the hyperbolic sine of.
879    ///
880    /// # Returns
881    ///
882    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
883    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
884        dequant_op_flow!(float_op | tensor | B::float_sinh(tensor), tensor)
885    }
886
887    /// Returns a new tensor with hyperbolic tangent values.
888    ///
889    /// # Arguments
890    ///
891    /// * `tensor` - The tensor to take the hyperbolic tangent of.
892    ///
893    /// # Returns
894    ///
895    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
896    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
897        dequant_op_flow!(float_op | tensor | B::float_tanh(tensor), tensor)
898    }
899
900    /// Returns a new tensor with the error function values.
901    ///
902    /// # Arguments
903    ///
904    /// * `tensor` - The tensor to take the error function of.
905    ///
906    /// # Returns
907    ///
908    /// A tensor with the same shape as `tensor` with error function values.
909    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
910        dequant_op_flow!(float_op | tensor | B::float_erf(tensor), tensor)
911    }
912
913    /// Concatenates tensors along a dimension.
914    ///
915    /// # Arguments
916    ///
917    /// * `tensors` - The tensors to concatenate.
918    /// * `dim` - The dimension along which to concatenate.
919    ///
920    /// # Returns
921    ///
922    /// A tensor with the concatenated tensors along `dim`.
923    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
924        // Heuristic: prioritize first tensor scheme
925        let first = tensors.first().unwrap();
926        let scheme = *first.scheme();
927        let dtype = get_device_settings::<B>(&Self::q_device(first)).float_dtype;
928
929        let tensor_f = tensors
930            .into_iter()
931            .map(|tensor| Self::dequantize(tensor, dtype))
932            .collect();
933
934        let out_f = B::float_cat(tensor_f, dim);
935
936        Self::quantize_dynamic(out_f, &scheme)
937    }
938
939    /// Gets the indices of the maximum elements of a tensor along an axis.
940    ///
941    /// # Arguments
942    ///
943    /// * `tensor` - The tensor to get the maximum elements of.
944    /// * `dim` - The dimension along which to get the maximum elements.
945    /// * `out_dtype` - The output tensor dtype.
946    ///
947    /// # Returns
948    ///
949    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
950    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
951        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
952        let tensor_f = Self::dequantize(tensor, dtype);
953        B::float_argmax(tensor_f, dim, out_dtype)
954    }
955
956    /// Gets the indices of the k maximum elements of a tensor along an axis.
957    /// If two elements are equals, order them by the lowest indices
958    ///
959    /// # Arguments
960    ///
961    /// * `tensor` - The tensor to get the k maximum elements of.
962    /// * `dim` - The dimension along which to get the maximum elements.
963    /// * `k` - number of k maximums to keep
964    /// * `out_dtype` - The output tensor dtype.
965    ///
966    /// # Returns
967    ///
968    /// A tensor with the indices of the `k` maximum elements of `tensor` along `dim`.
969    fn q_argtopk(
970        tensor: QuantizedTensor<B>,
971        dim: usize,
972        k: usize,
973        out_dtype: IntDType,
974    ) -> IntTensor<B> {
975        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
976        let tensor_f = Self::dequantize(tensor, dtype);
977        B::float_argtopk(tensor_f, dim, k, out_dtype)
978    }
979
980    /// Gets the values of the k maximum elements of a tensor along an axis.
981    ///
982    /// # Arguments
983    ///
984    /// * `tensor` - The tensor to get the k maximum elements of.
985    /// * `dim` - The dimension along which to get the maximum elements.
986    /// * `k` - number of k maximums to keep
987    /// * `out_dtype` - The output tensor dtype.
988    ///
989    /// # Returns
990    ///
991    /// A tensor with the values of the `k` maximum elements of `tensor` along `dim`.
992    fn q_topk(tensor: QuantizedTensor<B>, dim: usize, k: usize) -> QuantizedTensor<B> {
993        dequant_op_quant!(float_op | tensor | B::float_topk(tensor, dim, k), tensor)
994    }
995
996    /// Gets the indices of the minimum elements of a tensor along an axis.
997    ///
998    /// # Arguments
999    ///
1000    /// * `tensor` - The tensor to get the minimum elements of.
1001    /// * `dim` - The dimension along which to get the minimum elements.
1002    /// * `out_dtype` - The output tensor dtype.
1003    ///
1004    /// # Returns
1005    ///
1006    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1007    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
1008        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1009        let tensor_f = Self::dequantize(tensor, dtype);
1010        B::float_argmin(tensor_f, dim, out_dtype)
1011    }
1012
1013    /// Gets the maximum element of a tensor.
1014    ///
1015    /// # Arguments
1016    ///
1017    /// * `tensor` - The tensor to get the maximum elements of.
1018    ///
1019    /// # Returns
1020    ///
1021    /// A tensor with the maximum element of `tensor`.
1022    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1023        let shape = tensor.shape();
1024        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1025
1026        B::q_max_dim(tensor, 0)
1027    }
1028
1029    /// Gets the maximum elements of a tensor along an axis.
1030    ///
1031    /// # Arguments
1032    ///
1033    /// * `tensor` - The tensor to get the maximum elements of.
1034    /// * `dim` - The dimension along which to get the maximum elements.
1035    ///
1036    /// # Returns
1037    ///
1038    /// A tensor with the maximum elements of `tensor` along `dim`.
1039    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1040        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1041        let index = B::q_argmax(tensor.clone(), dim, int_dtype);
1042
1043        B::q_gather(dim, tensor, index)
1044    }
1045
1046    /// Gets the maximum elements of a tensor along an axis and their indices.
1047    ///
1048    /// # Arguments
1049    ///
1050    /// * `tensor` - The tensor to get the maximum elements of.
1051    /// * `dim` - The dimension along which to get the maximum elements.
1052    ///
1053    /// # Returns
1054    ///
1055    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1056    fn q_max_dim_with_indices(
1057        tensor: QuantizedTensor<B>,
1058        dim: usize,
1059        out_dtype: IntDType,
1060    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1061        let index = B::q_argmax(tensor.clone(), dim, out_dtype);
1062        let values = B::q_gather(dim, tensor, index.clone());
1063
1064        (values, index)
1065    }
1066
1067    /// Gets the minimum element of a tensor.
1068    ///
1069    /// # Arguments
1070    ///
1071    /// * `tensor` - The tensor to get the minimum elements of.
1072    ///
1073    /// # Returns
1074    ///
1075    /// A tensor with the minimum element of `tensor`.
1076    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1077        let shape = tensor.shape();
1078        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1079
1080        B::q_min_dim(tensor, 0)
1081    }
1082
1083    /// Gets the minimum elements of a tensor along an axis.
1084    ///
1085    /// # Arguments
1086    ///
1087    /// * `tensor` - The tensor to get the minimum elements of.
1088    /// * `dim` - The dimension along which to get the minimum elements.
1089    ///
1090    /// # Returns
1091    ///
1092    /// A tensor with the minimum elements of `tensor` along `dim`.
1093    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1094        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1095        let index = B::q_argmin(tensor.clone(), dim, int_dtype);
1096
1097        B::q_gather(dim, tensor, index)
1098    }
1099
1100    /// Gets the minimum elements of a tensor along an axis and their indices.
1101    ///
1102    /// # Arguments
1103    ///
1104    /// * `tensor` - The tensor to get the minimum elements of.
1105    /// * `dim` - The dimension along which to get the minimum elements.
1106    ///
1107    /// # Returns
1108    ///
1109    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1110    fn q_min_dim_with_indices(
1111        tensor: QuantizedTensor<B>,
1112        dim: usize,
1113        out_dtype: IntDType,
1114    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1115        let index = B::q_argmin(tensor.clone(), dim, out_dtype);
1116        let values = B::q_gather(dim, tensor, index.clone());
1117
1118        (values, index)
1119    }
1120
1121    /// Gets the maximum element of a tensor.
1122    ///
1123    /// # Arguments
1124    ///
1125    /// * `tensor` - The tensor to get the maximum elements of.
1126    ///
1127    /// # Returns
1128    ///
1129    /// A tensor with the maximum element of `tensor`.
1130    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1131        let shape = tensor.shape();
1132        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1133
1134        B::q_max_abs_dim(tensor, 0)
1135    }
1136
1137    /// Gets the maximum elements of a tensor along an axis.
1138    ///
1139    /// # Arguments
1140    ///
1141    /// * `tensor` - The tensor to get the maximum elements of.
1142    /// * `dim` - The dimension along which to get the maximum elements.
1143    ///
1144    /// # Returns
1145    ///
1146    /// A tensor with the maximum elements of `tensor` along `dim`.
1147    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1148        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1149        let index = B::q_argmax(B::q_abs(tensor.clone()), dim, int_dtype);
1150
1151        B::q_gather(dim, tensor, index)
1152    }
1153
1154    /// Tests if any element in the `tensor` evaluates to True.
1155    ///
1156    /// # Arguments
1157    ///
1158    /// * `tensor` - The tensor to test.
1159    ///
1160    /// # Returns
1161    ///
1162    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1163    fn q_any(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1164        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1165        let tensor_f = Self::dequantize(tensor, dtype);
1166        B::float_any(tensor_f, out_dtype)
1167    }
1168
1169    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1170    ///
1171    /// # Arguments
1172    ///
1173    /// * `tensor` - The tensor to test.
1174    /// * `dim` - The axis along which to test.
1175    ///
1176    /// # Returns
1177    ///
1178    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1179    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1180    /// input evaluates to True, False otherwise.
1181    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1182        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1183        let tensor_f = Self::dequantize(tensor, dtype);
1184        B::float_any_dim(tensor_f, dim, out_dtype)
1185    }
1186
1187    /// Tests if all elements in the `tensor` evaluate to True.
1188    ///
1189    /// # Arguments
1190    ///
1191    /// * `tensor` - The tensor to test.
1192    ///
1193    /// # Returns
1194    ///
1195    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1196    /// evaluate to True, False otherwise.
1197    fn q_all(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1198        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1199        let tensor_f = Self::dequantize(tensor, dtype);
1200        B::float_all(tensor_f, out_dtype)
1201    }
1202
1203    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1204    ///
1205    /// # Arguments
1206    ///
1207    /// * `tensor` - The tensor to test.
1208    /// * `dim` - The axis along which to test.
1209    ///
1210    /// # Returns
1211    ///
1212    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1213    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1214    /// evaluates to True, False otherwise.
1215    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1216        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1217        let tensor_f = Self::dequantize(tensor, dtype);
1218        B::float_all_dim(tensor_f, dim, out_dtype)
1219    }
1220
1221    /// Sort the elements of the input `tensor` by value in along a given dimension.
1222    ///
1223    /// This sort is unstable (i.e., may reorder equal elements).
1224    ///
1225    /// # Arguments
1226    ///
1227    /// * `tensor` - The input tensor.
1228    /// * `dim` - The axis along which to sort.
1229    /// * `descending` - The sorting order.
1230    ///
1231    /// # Returns
1232    ///
1233    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1234    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1235        // Default implementation. Backends can sort on the int values since qparams remain the same.
1236        dequant_op_quant!(
1237            float_op | tensor | B::float_sort(tensor, dim, descending),
1238            tensor
1239        )
1240    }
1241
1242    /// Sort the elements of the input `tensor` by value in along a given dimension.
1243    ///
1244    /// This sort is unstable (i.e., may reorder equal elements).
1245    ///
1246    /// # Arguments
1247    ///
1248    /// * `tensor` - The input tensor.
1249    /// * `dim` - The axis along which to sort.
1250    /// * `descending` - The sorting order.
1251    ///
1252    /// # Returns
1253    ///
1254    /// A tensor with the same shape as the input tensor and corresponding indices, where
1255    /// the elements are sorted by value and the indices map back to the original input tensor.
1256    fn q_sort_with_indices(
1257        tensor: QuantizedTensor<B>,
1258        dim: usize,
1259        descending: bool,
1260        out_dtype: IntDType,
1261    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1262        let scheme = *tensor.scheme();
1263        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1264
1265        let tensor_f = Self::dequantize(tensor, dtype);
1266        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending, out_dtype);
1267
1268        (Self::quantize_dynamic(out_f, &scheme), indices)
1269    }
1270
1271    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1272    ///
1273    /// This sort is unstable (i.e., may reorder equal elements).
1274    ///
1275    /// # Arguments
1276    ///
1277    /// * `tensor` - The input tensor.
1278    /// * `dim` - The axis along which to sort.
1279    /// * `descending` - The sorting order.
1280    ///
1281    /// # Returns
1282    ///
1283    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1284    fn q_argsort(
1285        tensor: QuantizedTensor<B>,
1286        dim: usize,
1287        descending: bool,
1288        out_dtype: IntDType,
1289    ) -> IntTensor<B> {
1290        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1291        let tensor_f = Self::dequantize(tensor, dtype);
1292        B::float_argsort(tensor_f, dim, descending, out_dtype)
1293    }
1294}