burn_tensor/tensor/ops/
qtensor.rs

1use alloc::vec::Vec;
2use core::{future::Future, ops::Range};
3
4use crate::{
5    Device, Shape, TensorData, TensorMetadata,
6    backend::Backend,
7    quantization::{
8        Calibration, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
9    },
10};
11
12use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
13
/// Automatically applies dequantization -> float operation -> quantization.
///
/// Each quantized operand is dequantized with `<$ty>::dequantize`, and the float operation
/// is applied to the results. The output is then re-quantized with
/// `<$ty>::quantize_dynamic`, using the scheme of the first (left-hand side) operand.
///
/// Every operand expression is evaluated exactly once, so operands with side effects
/// (e.g. function calls) are safe to pass.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind each operand once so that it is not evaluated a second time when it is
        // both queried for its scheme and dequantized.
        let t1 = $t1;
        let t2 = $t2;

        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize(t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once so that `$tensor` is evaluated a single time.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
44
45/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
46/// for documentation on each function.
47pub trait QTensorOps<B: Backend> {
48    /// Creates a new tensor from the data structure.
49    ///
50    /// # Arguments
51    ///
52    /// * `data` - The data structure.
53    /// * `device` - The device to create the tensor on.
54    ///
55    /// # Returns
56    ///
57    /// The tensor with the given data.
58    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
59
60    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
61    fn quantize(
62        tensor: FloatTensor<B>,
63        scheme: &QuantizationScheme,
64        qparams: QuantizationParametersPrimitive<B>,
65    ) -> QuantizedTensor<B>;
66
67    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
68    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantizationScheme) -> QuantizedTensor<B> {
69        // Dynamically compute min/max tensor range and qparams before quantizing
70        let (min, max) = scheme.compute_range_primitive::<B>(tensor.clone(), &Calibration::MinMax);
71        let qparams = scheme.compute_q_params_primitive(min, max);
72        Self::quantize(tensor, scheme, qparams)
73    }
74
75    /// Convert the tensor back to a higher precision data type.
76    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
77
78    /// Gets the device of the tensor.
79    ///
80    /// # Arguments
81    ///
82    /// * `tensor` - The tensor.
83    ///
84    /// # Returns
85    ///
86    /// The device of the tensor.
87    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
88
89    /// Moves the tensor to the given device.
90    ///
91    /// # Arguments
92    ///
93    /// * `tensor` - The tensor.
94    /// * `device` - The device to move the tensor to.
95    ///
96    /// # Returns
97    ///
98    /// The tensor on the given device.
99    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
100
101    /// Reshapes a tensor.
102    ///
103    /// # Arguments
104    ///
105    /// * `tensor` - The tensor to reshape.
106    /// * `shape` - The new shape of the tensor.
107    ///
108    /// # Returns
109    ///
110    /// The tensor with the new shape.
111    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
112
113    /// Converts the tensor to a data structure.
114    ///
115    /// # Arguments
116    ///
117    /// * `tensor` - The tensor.
118    ///
119    /// # Returns
120    ///
121    /// The data structure with the tensor's data.
122    fn q_into_data(tensor: QuantizedTensor<B>)
123    -> impl Future<Output = TensorData> + 'static + Send;
124
125    /// Detaches a tensor from the computation graph.
126    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
127        // Should only be overridden by autodiff backends.
128        tensor
129    }
130
131    /// Sets the `require_grad` flag of a tensor.
132    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
133        // Should only be overridden by autodiff backends.
134        tensor
135    }
136
137    /// Returns the `require_grad` flag of a tensor.
138    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
139        // Should only be overridden by autodiff backends.
140        false
141    }
142
143    /// Repeat the tensor along the given dimension.
144    ///
145    /// # Arguments
146    ///
147    /// * `tensor` - The tensor.
148    /// * `dim` - The dimension to repeat.
149    /// * `times` - The number of times to repeat the dimension.
150    ///
151    /// # Returns
152    ///
153    /// The tensor with the given dimension repeated.
154    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
155        dequant_op_quant!(
156            ty Self,
157            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
158            tensor
159        )
160    }
161
162    /// Adds two tensors together.
163    ///
164    /// # Arguments
165    ///
166    /// * `lhs` - The left hand side tensor.
167    /// * `rhs` - The right hand side tensor.
168    ///
169    /// # Returns
170    ///
171    /// The result of adding the two tensors together.
172    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
173        dequant_op_quant!(
174            ty Self,
175            float_op |lhs, rhs| B::float_add(lhs, rhs),
176            lhs,
177            rhs
178        )
179    }
180
181    /// Adds a scalar to a tensor.
182    ///
183    /// # Arguments
184    ///
185    /// * `lhs` - The left hand side tensor.
186    /// * `rhs` - The right hand side scalar.
187    ///
188    /// # Returns
189    ///
190    /// The result of adding the scalar to the tensor.
191    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
192        let scheme = *lhs.scheme();
193
194        let lhs_f = Self::dequantize(lhs);
195        let out_f = B::float_add_scalar(lhs_f, rhs);
196
197        Self::quantize_dynamic(out_f, &scheme)
198    }
199
200    /// Clamps a tensor under a minimum value.
201    ///
202    /// # Arguments
203    ///
204    /// * `tensor` - The tensor to clamp.
205    /// * `min` - The minimum value.
206    ///
207    /// # Returns
208    ///
209    /// The clamped tensor.
210    fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> QuantizedTensor<B> {
211        let scheme = *tensor.scheme();
212
213        let tensor_f = Self::dequantize(tensor);
214        let out_f = B::float_clamp_min(tensor_f, min);
215
216        Self::quantize_dynamic(out_f, &scheme)
217    }
218
219    /// Clamps a tensor over a maximum value.
220    ///
221    /// # Arguments
222    ///
223    /// * `tensor` - The tensor to clamp.
224    /// * `max` - The maximum value.
225    ///
226    /// # Returns
227    ///
228    /// The clamped tensor.
229    fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> QuantizedTensor<B> {
230        let scheme = *tensor.scheme();
231
232        let tensor_f = Self::dequantize(tensor);
233        let out_f = B::float_clamp_max(tensor_f, max);
234
235        Self::quantize_dynamic(out_f, &scheme)
236    }
237
238    /// Clamps a tensor between a minimum and maximum value.
239    ///
240    /// # Arguments
241    ///
242    /// * `tensor` - The tensor to clamp.
243    /// * `min` - The minimum value.
244    /// * `max` - The maximum value.
245    ///
246    /// # Returns
247    ///
248    /// The clamped tensor.
249    fn q_clamp(
250        tensor: QuantizedTensor<B>,
251        min: FloatElem<B>,
252        max: FloatElem<B>,
253    ) -> QuantizedTensor<B> {
254        let scheme = *tensor.scheme();
255
256        let tensor_f = Self::dequantize(tensor);
257        let out_f = B::float_clamp(tensor_f, min, max);
258
259        Self::quantize_dynamic(out_f, &scheme)
260    }
261
262    /// Subtracts two tensors.
263    ///
264    /// # Arguments
265    ///
266    /// * `lhs` - The left hand side tensor.
267    /// * `rhs` - The right hand side tensor.
268    ///
269    /// # Returns
270    ///
271    /// The result of subtracting the two tensors.
272    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
273        dequant_op_quant!(
274            ty Self,
275            float_op |lhs, rhs| B::float_sub(lhs, rhs),
276            lhs,
277            rhs
278        )
279    }
280
281    /// Subtracts a scalar from a tensor.
282    ///
283    /// # Arguments
284    ///
285    /// * `lhs` - The left hand side tensor.
286    /// * `rhs` - The right hand side scalar.
287    ///
288    /// # Returns
289    ///
290    /// The result of subtracting the scalar from the tensor.
291    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
292        let scheme = *lhs.scheme();
293
294        let lhs_f = Self::dequantize(lhs);
295        let out_f = B::float_sub_scalar(lhs_f, rhs);
296
297        Self::quantize_dynamic(out_f, &scheme)
298    }
299
300    /// Multiplies two tensors together element-wise.
301    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
302        dequant_op_quant!(
303            ty Self,
304            float_op |lhs, rhs| B::float_mul(lhs, rhs),
305            lhs,
306            rhs
307        )
308    }
309
310    /// Multiplies a tensor by a scalar.
311    ///
312    /// # Arguments
313    ///
314    /// * `lhs` - The left hand side tensor.
315    /// * `rhs` - The right hand side scalar.
316    ///
317    /// # Returns
318    ///
319    /// The result of multiplying the tensor by the scalar.
320    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
321        let scheme = *lhs.scheme();
322
323        let lhs_f = Self::dequantize(lhs);
324        let out_f = B::float_mul_scalar(lhs_f, rhs);
325
326        Self::quantize_dynamic(out_f, &scheme)
327    }
328
329    /// Divides two tensors element-wise.
330    ///
331    /// # Arguments
332    ///
333    /// * `lhs` - The left hand side tensor.
334    /// * `rhs` - The right hand side tensor.
335    ///
336    /// # Returns
337    ///
338    /// The result of dividing the two tensors.
339    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
340        dequant_op_quant!(
341            ty Self,
342            float_op |lhs, rhs| B::float_div(lhs, rhs),
343            lhs,
344            rhs
345        )
346    }
347
348    /// Divides a tensor by a scalar.
349    ///
350    /// # Arguments
351    ///
352    /// * `lhs` - The left hand side tensor.
353    /// * `rhs` - The right hand side scalar.
354    ///
355    /// # Returns
356    ///
357    /// The result of dividing the tensor by the scalar.
358    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
359        let scheme = *lhs.scheme();
360
361        let lhs_f = Self::dequantize(lhs);
362        let out_f = B::float_div_scalar(lhs_f, rhs);
363
364        Self::quantize_dynamic(out_f, &scheme)
365    }
366
367    /// Computes the remainder of division between two tensors element-wise.
368    ///
369    /// # Arguments
370    ///
371    /// * `lhs` - The left hand side tensor.
372    /// * `rhs` - The right hand side tensor.
373    ///
374    /// # Returns
375    ///
376    /// The element-wise remainder when dividing `lhs` by `rhs`.
377    fn q_remainder(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
378        dequant_op_quant!(
379            ty Self,
380            float_op |lhs, rhs| B::float_remainder(lhs, rhs),
381            lhs,
382            rhs
383        )
384    }
385
386    /// Computes the modulus of a tensor given a scalar.
387    ///
388    /// # Arguments
389    /// * `lhs` - The left hand side tensor.
390    /// * `rhs` - The right hand side scalar.
391    ///
392    /// # Returns
393    ///
394    /// The result of applying the modulus of the scalar to the tensor.
395    fn q_remainder_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
396        let scheme = *lhs.scheme();
397
398        let lhs_f = Self::dequantize(lhs);
399        let out_f = B::float_remainder_scalar(lhs_f, rhs);
400
401        Self::quantize_dynamic(out_f, &scheme)
402    }
403
404    /// Multiplies two tensors together using matrix multiplication.
405    ///
406    /// # Arguments
407    ///
408    /// * `lhs` - The left hand side tensor.
409    /// * `rhs` - The right hand side tensor.
410    ///
411    /// # Returns
412    ///
413    /// The result of multiplying the two tensors together using matrix multiplication.
414    fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
415        dequant_op_quant!(
416            ty Self,
417            float_op |lhs, rhs| B::float_matmul(lhs, rhs),
418            lhs,
419            rhs
420        )
421    }
422
423    /// Negates a tensor element-wise.
424    fn q_neg(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
425        let scheme = *tensor.scheme();
426
427        let tensor_f = Self::dequantize(tensor);
428        let out_f = B::float_neg(tensor_f);
429
430        Self::quantize_dynamic(out_f, &scheme)
431    }
432
433    /// Calculates the reciprocals element-wise
434    fn q_recip(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
435        let scheme = *tensor.scheme();
436
437        let tensor_f = Self::dequantize(tensor);
438        let out_f = B::float_recip(tensor_f);
439
440        Self::quantize_dynamic(out_f, &scheme)
441    }
442
443    /// Transposes a tensor.
444    ///
445    /// # Arguments
446    ///
447    /// * `tensor` - The tensor to transpose.
448    ///
449    /// # Returns
450    ///
451    /// The transposed tensor.
452    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
453        let ndims = tensor.shape().num_dims();
454        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
455    }
456
457    /// Swaps two dimensions of a tensor.
458    ///
459    /// # Arguments
460    ///
461    /// * `tensor` - The tensor to swap the dimensions of.
462    /// * `dim1` - The first dimension to swap.
463    /// * `dim2` - The second dimension to swap.
464    ///
465    /// # Returns
466    ///
467    /// The tensor with the dimensions swapped.
468    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
469
470    /// Permutes the dimensions of a tensor.
471    ///
472    /// # Arguments
473    ///
474    /// * `tensor` - The tensor to permute the dimensions of.
475    /// * `axes` - The new order of the dimensions.
476    /// # Returns
477    ///
478    /// The tensor with the dimensions permuted.
479    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
480
481    /// Reverse the order of elements in a tensor along the given axes.
482    ///
483    /// # Arguments
484    ///
485    /// * `tensor` - The tensor to reverse.
486    /// * `axes` - The axes to reverse.
487    ///
488    /// The tensor with the elements reversed.
489    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
490
491    /// Gather elements from a tensor.
492    ///
493    /// # Arguments
494    ///
495    /// * `dim` - The dimension to gather from.
496    /// * `tensor` - The tensor to gather from.
497    /// * `indices` - The indices to gather.
498    ///
499    /// # Returns
500    ///
501    /// The gathered elements.
502    fn q_gather(
503        dim: usize,
504        tensor: QuantizedTensor<B>,
505        indices: IntTensor<B>,
506    ) -> QuantizedTensor<B> {
507        // Default implementation. Backends can gather on the quantized values when supported.
508        dequant_op_quant!(
509            ty Self,
510            float_op |tensor| B::float_gather(dim, tensor, indices),
511            tensor
512        )
513    }
514
515    /// Scatter elements into a tensor.
516    ///
517    /// # Arguments
518    ///
519    /// * `dim` - The dimension to scatter into.
520    /// * `tensor` - The tensor to scatter into.
521    /// * `indices` - The indices to scatter into.
522    /// * `value` - The value to scatter.
523    ///
524    /// # Returns
525    ///
526    /// The tensor with the scattered elements.
527    fn q_scatter(
528        dim: usize,
529        tensor: QuantizedTensor<B>,
530        indices: IntTensor<B>,
531        value: QuantizedTensor<B>,
532    ) -> QuantizedTensor<B> {
533        dequant_op_quant!(
534            ty Self,
535            float_op |tensor, value| B::float_scatter(dim, tensor, indices, value),
536            tensor,
537            value
538        )
539    }
540
541    /// Select tensor elements along the given dimension corresponding for the given indices.
542    ///
543    /// # Arguments
544    ///
545    /// * `tensor` - The tensor to select from.
546    /// * `dim` - The dimension to select from.
547    /// * `indices` - The indices to select.
548    ///
549    /// # Returns
550    ///
551    /// The selected elements.
552    fn q_select(
553        tensor: QuantizedTensor<B>,
554        dim: usize,
555        indices: IntTensor<B>,
556    ) -> QuantizedTensor<B>;
557
558    /// Assign the selected elements along the given dimension corresponding for the given indices
559    /// to the given value.
560    ///
561    /// # Arguments
562    ///
563    /// * `tensor` - The tensor to select from.
564    /// * `dim` - The dimension to select from.
565    /// * `indices` - The indices to select.
566    /// * `value` - The value to assign.
567    ///
568    /// # Returns
569    ///
570    /// The tensor with the selected elements assigned to the given value.
571    fn q_select_assign(
572        tensor: QuantizedTensor<B>,
573        dim: usize,
574        indices: IntTensor<B>,
575        value: QuantizedTensor<B>,
576    ) -> QuantizedTensor<B> {
577        dequant_op_quant!(
578            ty Self,
579            float_op |tensor, value| B::float_select_assign(tensor, dim, indices, value),
580            tensor,
581            value
582        )
583    }
584
585    /// Select tensor elements corresponding for the given ranges.
586    ///
587    /// # Arguments
588    ///
589    /// * `tensor` - The tensor to select from.
590    /// * `ranges` - The ranges to select.
591    ///
592    /// # Returns
593    ///
594    /// The selected elements in a new tensor.
595    fn q_slice(tensor: QuantizedTensor<B>, ranges: &[Range<usize>]) -> QuantizedTensor<B>;
596
597    /// Assign the selected elements corresponding for the given ranges to the given value.
598    ///
599    /// # Arguments
600    ///
601    /// * `tensor` - The tensor to select from.
602    /// * `ranges` - The ranges to select.
603    /// * `value` - The value to assign.
604    ///
605    /// # Returns
606    ///
607    /// The tensor with the selected elements assigned to the given value.
608    fn q_slice_assign(
609        tensor: QuantizedTensor<B>,
610        ranges: &[Range<usize>],
611        value: QuantizedTensor<B>,
612    ) -> QuantizedTensor<B> {
613        dequant_op_quant!(
614            ty Self,
615            float_op |tensor, value| B::float_slice_assign(tensor, ranges, value),
616            tensor,
617            value
618        )
619    }
620
621    /// Update the given tensor with the value tensor where the mask is true.
622    ///
623    /// # Arguments
624    ///
625    /// * `tensor` - The tensor to select from.
626    /// * `mask` - The boolean mask to select with.
627    /// * `value` - The value to assign to the selected elements from the value tensor.
628    ///
629    /// # Returns
630    ///
631    /// The tensor with the selected elements assigned to the given value.
632    fn q_mask_where(
633        tensor: QuantizedTensor<B>,
634        mask: BoolTensor<B>,
635        value: QuantizedTensor<B>,
636    ) -> QuantizedTensor<B> {
637        dequant_op_quant!(
638            ty Self,
639            float_op |tensor, value| B::float_mask_where(tensor, mask, value),
640            tensor,
641            value
642        )
643    }
644
645    /// Update the given tensor with the value where the mask is true.
646    ///
647    /// # Arguments
648    ///
649    /// * `tensor` - The tensor to select from.
650    /// * `mask` - The boolean mask to select with.
651    /// * `value` - The value to assign to the selected elements.
652    ///
653    /// # Returns
654    ///
655    /// The tensor with the selected elements assigned to the given value.
656    fn q_mask_fill(
657        tensor: QuantizedTensor<B>,
658        mask: BoolTensor<B>,
659        value: FloatElem<B>,
660    ) -> QuantizedTensor<B> {
661        dequant_op_quant!(
662            ty Self,
663            float_op |tensor| B::float_mask_fill(tensor, mask, value),
664            tensor
665        )
666    }
667
668    /// Sum of all elements in a tensor.
669    ///
670    /// # Arguments
671    ///
672    /// * `tensor` - The tensor to sum.
673    ///
674    /// # Returns
675    ///
676    /// A scalar tensor with the sum of all elements in `tensor`.
677    fn q_sum(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
678        dequant_op_quant!(
679            ty Self,
680            float_op |tensor| B::float_sum(tensor),
681            tensor
682        )
683    }
684
685    /// Sum of all elements in a tensor along a dimension.
686    ///
687    /// # Arguments
688    ///
689    /// * `tensor` - The tensor to sum.
690    /// * `dim` - The dimension along which to sum.
691    ///
692    /// # Returns
693    ///
694    /// A tensor with the sum of all elements in `tensor` along `dim`.
695    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
696        dequant_op_quant!(
697            ty Self,
698            float_op |tensor| B::float_sum_dim(tensor, dim),
699            tensor
700        )
701    }
702
703    /// Product of all elements in a tensor.
704    ///
705    /// # Arguments
706    ///
707    /// * `tensor` - The tensor to product.
708    ///
709    /// # Returns
710    ///
711    /// A scalar tensor with the product of all elements in `tensor`.
712    fn q_prod(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
713        dequant_op_quant!(
714            ty Self,
715            float_op |tensor| B::float_prod(tensor),
716            tensor
717        )
718    }
719
720    /// Product of all elements in a tensor along a dimension.
721    ///
722    /// # Arguments
723    ///
724    /// * `tensor` - The tensor to product.
725    ///
726    /// # Returns
727    ///
728    /// A tensor with the product of all elements in `tensor` along `dim`.
729    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
730        dequant_op_quant!(
731            ty Self,
732            float_op |tensor| B::float_prod_dim(tensor, dim),
733            tensor
734        )
735    }
736
737    /// Mean of all elements in a tensor.
738    ///
739    /// # Arguments
740    ///
741    /// * `tensor` - The tensor to mean.
742    ///
743    /// # Returns
744    ///
745    /// A scalar tensor with the mean of all elements in `tensor`.
746    fn q_mean(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
747        dequant_op_quant!(
748            ty Self,
749            float_op |tensor| B::float_mean(tensor),
750            tensor
751        )
752    }
753
754    /// Mean of all elements in a tensor along a dimension.
755    ///
756    /// # Arguments
757    ///
758    /// * `tensor` - The tensor to mean.
759    /// * `dim` - The dimension along which to mean.
760    ///
761    /// # Returns
762    ///
763    /// A tensor with the mean of all elements in `tensor` along `dim`.
764    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
765        dequant_op_quant!(
766            ty Self,
767            float_op |tensor| B::float_mean_dim(tensor, dim),
768            tensor
769        )
770    }
771
772    /// Returns a new tensor with exponential values.
773    ///
774    /// # Arguments
775    ///
776    /// * `tensor` - The tensor to exponentiate.
777    ///
778    /// # Returns
779    ///
780    /// A tensor with the same shape as `tensor` with exponential values.
781    fn q_exp(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
782        dequant_op_quant!(
783            ty Self,
784            float_op |tensor| B::float_exp(tensor),
785            tensor
786        )
787    }
788
789    /// Returns a new tensor with natural logarithm values.
790    ///
791    /// # Arguments
792    ///
793    /// * `tensor` - The tensor to take the logarithm of.
794    ///
795    /// # Returns
796    ///
797    /// A tensor with the same shape as `tensor` with natural logarithm values.
798    fn q_log(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
799        dequant_op_quant!(
800            ty Self,
801            float_op |tensor| B::float_log(tensor),
802            tensor
803        )
804    }
805
806    /// Returns a new tensor with logarithm values of (1 + Xi).
807    ///
808    /// # Arguments
809    ///
810    /// * `tensor` - The tensor to take the logarithm of.
811    ///
812    /// # Returns
813    ///
814    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
815    fn q_log1p(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
816        dequant_op_quant!(
817            ty Self,
818            float_op |tensor| B::float_log1p(tensor),
819            tensor
820        )
821    }
822
823    /// Element-wise power with another tensor.
824    ///
825    /// # Arguments
826    ///
827    /// * `lhs` - The left hand side tensor.
828    /// * `rhs` - The right hand side tensor.
829    ///
830    /// # Returns
831    ///
832    /// The elements of `lhs` raised to the power of the elements of `rhs`.
833    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
834        dequant_op_quant!(
835            ty Self,
836            float_op |lhs, rhs| B::float_powf(lhs, rhs),
837            lhs,
838            rhs
839        )
840    }
841
842    /// Element-wise power with an IntTensor.
843    ///
844    /// # Arguments
845    ///
846    /// * `lhs` - The left hand side tensor.
847    /// * `rhs` - The right hand side floatTensor.
848    ///
849    /// # Returns
850    ///
851    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
852    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> QuantizedTensor<B> {
853        dequant_op_quant!(
854            ty Self,
855            float_op |tensor| B::float_powi(tensor, rhs),
856            lhs
857        )
858    }
859
860    /// Element-wise power with an int scalar.
861    ///
862    /// # Arguments
863    ///
864    /// * `lhs` - The left hand side tensor.
865    /// * `rhs` - The right hand side scalar.
866    ///
867    /// # Returns
868    ///
869    /// The elements of `lhs` raised to the value of `rhs`.
870    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> QuantizedTensor<B> {
871        dequant_op_quant!(
872            ty Self,
873            float_op |tensor| B::float_powi_scalar(tensor, rhs),
874            lhs
875        )
876    }
877
878    /// Element-wise power with a float scalar.
879    ///
880    /// # Arguments
881    ///
882    /// * `tensor` - The tensor to exponentiate.
883    /// * `value` - The exponent.
884    ///
885    /// # Returns
886    ///
887    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
888    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> QuantizedTensor<B> {
889        dequant_op_quant!(
890            ty Self,
891            float_op |tensor| B::float_powf_scalar(tensor, value),
892            tensor
893        )
894    }
895
896    /// Returns a new tensor with square root values.
897    ///
898    /// # Arguments
899    ///
900    /// * `tensor` - The tensor to take the square root of.
901    ///
902    /// # Returns
903    ///
904    /// A tensor with the same shape as `tensor` with square root values.
905    fn q_sqrt(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
906        dequant_op_quant!(
907            ty Self,
908            float_op |tensor| B::float_sqrt(tensor),
909            tensor
910        )
911    }
912
913    /// Returns a new tensor with absolute values.
914    ///
915    /// # Arguments
916    ///
917    /// * `tensor` - The tensor to take absolute value of.
918    ///
919    /// # Returns
920    ///
921    /// A tensor with the same shape as `tensor` with absolute values.
922    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
923        dequant_op_quant!(
924            ty Self,
925            float_op |tensor| B::float_abs(tensor),
926            tensor
927        )
928    }
929
930    /// Returns a new tensor with cosine values.
931    ///
932    /// # Arguments
933    ///
934    /// * `tensor` - The tensor to take the cosine of.
935    ///
936    /// # Returns
937    ///
938    /// A tensor with the same shape as `tensor` with cosine values.
939    fn q_cos(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
940        dequant_op_quant!(
941            ty Self,
942            float_op |tensor| B::float_cos(tensor),
943            tensor
944        )
945    }
946
947    /// Returns a new tensor with sine values.
948    ///
949    /// # Arguments
950    ///
951    /// * `tensor` - The tensor to take the sine of.
952    ///
953    /// # Returns
954    ///
955    /// A tensor with the same shape as `tensor` with sine values.
956    fn q_sin(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
957        dequant_op_quant!(
958            ty Self,
959            float_op |tensor| B::float_sin(tensor),
960            tensor
961        )
962    }
963
964    /// Returns a new tensor with tangent values.
965    ///
966    /// # Arguments
967    ///
968    /// * `tensor` - The tensor to take the tangent of.
969    ///
970    /// # Returns
971    ///
972    /// A tensor with the same shape as `tensor` with tangent values.
973    fn q_tan(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
974        dequant_op_quant!(
975            ty Self,
976            float_op |tensor| B::float_tan(tensor),
977            tensor
978        )
979    }
980
981    /// Returns a new tensor with hyperbolic cosine values.
982    ///
983    /// # Arguments
984    ///
985    /// * `tensor` - The tensor to take the hyperbolic cosine of.
986    ///
987    /// # Returns
988    ///
989    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
990    fn q_cosh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
991        dequant_op_quant!(
992            ty Self,
993            float_op |tensor| B::float_cosh(tensor),
994            tensor
995        )
996    }
997
998    /// Returns a new tensor with hyperbolic sine values.
999    ///
1000    /// # Arguments
1001    ///
1002    /// * `tensor` - The tensor to take the hyperbolic sine of.
1003    ///
1004    /// # Returns
1005    ///
1006    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1007    fn q_sinh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1008        dequant_op_quant!(
1009            ty Self,
1010            float_op |tensor| B::float_sinh(tensor),
1011            tensor
1012        )
1013    }
1014
1015    /// Returns a new tensor with hyperbolic tangent values.
1016    ///
1017    /// # Arguments
1018    ///
1019    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1020    ///
1021    /// # Returns
1022    ///
1023    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1024    fn q_tanh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1025        dequant_op_quant!(
1026            ty Self,
1027            float_op |tensor| B::float_tanh(tensor),
1028            tensor
1029        )
1030    }
1031
1032    /// Returns a new tensor with the error function values.
1033    ///
1034    /// # Arguments
1035    ///
1036    /// * `tensor` - The tensor to take the error function of.
1037    ///
1038    /// # Returns
1039    ///
1040    /// A tensor with the same shape as `tensor` with error function values.
1041    fn q_erf(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1042        dequant_op_quant!(
1043            ty Self,
1044            float_op |tensor| B::float_erf(tensor),
1045            tensor
1046        )
1047    }
1048
1049    /// Concatenates tensors along a dimension.
1050    ///
1051    /// # Arguments
1052    ///
1053    /// * `tensors` - The tensors to concatenate.
1054    /// * `dim` - The dimension along which to concatenate.
1055    ///
1056    /// # Returns
1057    ///
1058    /// A tensor with the concatenated tensors along `dim`.
1059    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1060        // Heuristic: prioritize first tensor scheme
1061        let scheme = *tensors.first().unwrap().scheme();
1062
1063        let tensor_f = tensors
1064            .into_iter()
1065            .map(|tensor| Self::dequantize(tensor))
1066            .collect();
1067
1068        let out_f = B::float_cat(tensor_f, dim);
1069
1070        Self::quantize_dynamic(out_f, &scheme)
1071    }
1072
1073    /// Gets the indices of the maximum elements of a tensor along an axis.
1074    ///
1075    /// # Arguments
1076    ///
1077    /// * `tensor` - The tensor to get the maximum elements of.
1078    /// * `dim` - The dimension along which to get the maximum elements.
1079    ///
1080    /// # Returns
1081    ///
1082    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1083    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1084        // Default implementation. Backends can sort on the int values since qparams remain the same.
1085        let tensor_f = Self::dequantize(tensor);
1086        B::float_argmax(tensor_f, dim)
1087    }
1088
1089    /// Gets the indices of the minimum elements of a tensor along an axis.
1090    ///
1091    /// # Arguments
1092    ///
1093    /// * `tensor` - The tensor to get the minimum elements of.
1094    /// * `dim` - The dimension along which to get the minimum elements.
1095    ///
1096    /// # Returns
1097    ///
1098    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1099    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1100        // Default implementation. Backends can sort on the int values since qparams remain the same.
1101        let tensor_f = Self::dequantize(tensor);
1102        B::float_argmin(tensor_f, dim)
1103    }
1104
1105    /// Gets the maximum element of a tensor.
1106    ///
1107    /// # Arguments
1108    ///
1109    /// * `tensor` - The tensor to get the maximum elements of.
1110    ///
1111    /// # Returns
1112    ///
1113    /// A tensor with the maximum element of `tensor`.
1114    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1115        let shape = tensor.shape();
1116        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1117
1118        B::q_max_dim(tensor, 0)
1119    }
1120
1121    /// Gets the maximum elements of a tensor along an axis.
1122    ///
1123    /// # Arguments
1124    ///
1125    /// * `tensor` - The tensor to get the maximum elements of.
1126    /// * `dim` - The dimension along which to get the maximum elements.
1127    ///
1128    /// # Returns
1129    ///
1130    /// A tensor with the maximum elements of `tensor` along `dim`.
1131    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1132        let index = B::q_argmax(tensor.clone(), dim);
1133
1134        B::q_gather(dim, tensor, index)
1135    }
1136
1137    /// Gets the maximum elements of a tensor along an axis and their indices.
1138    ///
1139    /// # Arguments
1140    ///
1141    /// * `tensor` - The tensor to get the maximum elements of.
1142    /// * `dim` - The dimension along which to get the maximum elements.
1143    ///
1144    /// # Returns
1145    ///
1146    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1147    fn q_max_dim_with_indices(
1148        tensor: QuantizedTensor<B>,
1149        dim: usize,
1150    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1151        let index = B::q_argmax(tensor.clone(), dim);
1152        let values = B::q_gather(dim, tensor, index.clone());
1153
1154        (values, index)
1155    }
1156
1157    /// Gets the minimum element of a tensor.
1158    ///
1159    /// # Arguments
1160    ///
1161    /// * `tensor` - The tensor to get the minimum elements of.
1162    ///
1163    /// # Returns
1164    ///
1165    /// A tensor with the minimum element of `tensor`.
1166    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1167        let shape = tensor.shape();
1168        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1169
1170        B::q_min_dim(tensor, 0)
1171    }
1172
1173    /// Gets the minimum elements of a tensor along an axis.
1174    ///
1175    /// # Arguments
1176    ///
1177    /// * `tensor` - The tensor to get the minimum elements of.
1178    /// * `dim` - The dimension along which to get the minimum elements.
1179    ///
1180    /// # Returns
1181    ///
1182    /// A tensor with the minimum elements of `tensor` along `dim`.
1183    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1184        let index = B::q_argmin(tensor.clone(), dim);
1185
1186        B::q_gather(dim, tensor, index)
1187    }
1188
1189    /// Gets the minimum elements of a tensor along an axis and their indices.
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `tensor` - The tensor to get the minimum elements of.
1194    /// * `dim` - The dimension along which to get the minimum elements.
1195    ///
1196    /// # Returns
1197    ///
1198    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1199    fn q_min_dim_with_indices(
1200        tensor: QuantizedTensor<B>,
1201        dim: usize,
1202    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1203        let index = B::q_argmin(tensor.clone(), dim);
1204        let values = B::q_gather(dim, tensor, index.clone());
1205
1206        (values, index)
1207    }
1208
1209    /// Gets the maximum element of a tensor.
1210    ///
1211    /// # Arguments
1212    ///
1213    /// * `tensor` - The tensor to get the maximum elements of.
1214    ///
1215    /// # Returns
1216    ///
1217    /// A tensor with the maximum element of `tensor`.
1218    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1219        let shape = tensor.shape();
1220        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1221
1222        B::q_max_abs_dim(tensor, 0)
1223    }
1224
1225    /// Gets the maximum elements of a tensor along an axis.
1226    ///
1227    /// # Arguments
1228    ///
1229    /// * `tensor` - The tensor to get the maximum elements of.
1230    /// * `dim` - The dimension along which to get the maximum elements.
1231    ///
1232    /// # Returns
1233    ///
1234    /// A tensor with the maximum elements of `tensor` along `dim`.
1235    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1236        let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1237
1238        B::q_gather(dim, tensor, index)
1239    }
1240
1241    /// Returns a new tensor with the given dimension narrowed to the given range.
1242    ///
1243    /// # Arguments
1244    ///
1245    /// * `dim` - The dimension along which the tensor will be narrowed.
1246    /// * `start` - The starting point of the given range.
1247    /// * `length` - The ending point of the given range.
1248    /// # Panics
1249    ///
1250    /// - If the dimension is greater than the number of dimensions of the tensor.
1251    /// - If the given range exceeds the number of elements on the given dimension.
1252    ///
1253    /// # Returns
1254    ///
1255    /// A new tensor with the given dimension narrowed to the given range.
1256    fn q_narrow(
1257        tensor: QuantizedTensor<B>,
1258        dim: usize,
1259        start: usize,
1260        length: usize,
1261    ) -> QuantizedTensor<B> {
1262        dequant_op_quant!(
1263            ty Self,
1264            float_op |tensor| B::float_narrow(tensor, dim, start, length),
1265            tensor
1266        )
1267    }
1268
1269    /// Split the tensor along the given dimension into chunks.
1270    ///
1271    /// # Arguments
1272    ///
1273    /// * `tensor` - The tensor.
1274    /// * `chunks` - The number of chunks to be produced.
1275    /// * `times` - The dimension along which the tensor will be split.
1276    ///
1277    /// # Returns
1278    ///
1279    /// A vector of tensors.
1280    fn q_chunk(tensor: QuantizedTensor<B>, chunks: usize, dim: usize) -> Vec<QuantizedTensor<B>> {
1281        let scheme = *tensor.scheme();
1282
1283        let tensor_f = Self::dequantize(tensor);
1284        let out_f = B::float_chunk(tensor_f, chunks, dim);
1285
1286        out_f
1287            .into_iter()
1288            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1289            .collect()
1290    }
1291
1292    /// Split the tensor along the given dimension into chunks of `split_size`.
1293    ///
1294    /// # Arguments
1295    ///
1296    /// * `tensor` - The tensor.
1297    /// * `split_size` - The size of a single chunk.
1298    /// * `times` - The dimension along which the tensor will be split.
1299    ///
1300    /// # Returns
1301    ///
1302    /// A vector of tensors.
1303    fn q_split(
1304        tensor: QuantizedTensor<B>,
1305        split_size: usize,
1306        dim: usize,
1307    ) -> Vec<QuantizedTensor<B>> {
1308        let scheme = *tensor.scheme();
1309
1310        let tensor_f = Self::dequantize(tensor);
1311        let out_f = B::float_split(tensor_f, split_size, dim);
1312
1313        out_f
1314            .into_iter()
1315            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1316            .collect()
1317    }
1318
1319    /// Split the tensor along the given dimension into chunks with sizes in
1320    /// `dim` according to `split_sizes`.
1321    ///
1322    /// # Arguments
1323    ///
1324    /// * `tensor` - The tensor.
1325    /// * `split_sizes` - Vector of sizes for each chunk.
1326    /// * `times` - The dimension along which the tensor will be split.
1327    ///
1328    /// # Returns
1329    ///
1330    /// A vector of tensors.
1331    fn q_split_with_sizes(
1332        tensor: QuantizedTensor<B>,
1333        split_sizes: Vec<usize>,
1334        dim: usize,
1335    ) -> Vec<QuantizedTensor<B>> {
1336        let scheme = *tensor.scheme();
1337
1338        let tensor_f = Self::dequantize(tensor);
1339        let out_f = B::float_split_with_sizes(tensor_f, split_sizes, dim);
1340
1341        out_f
1342            .into_iter()
1343            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1344            .collect()
1345    }
1346
1347    /// Tests if any element in the `tensor` evaluates to True.
1348    ///
1349    /// # Arguments
1350    ///
1351    /// * `tensor` - The tensor to test.
1352    ///
1353    /// # Returns
1354    ///
1355    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1356    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1357        let tensor_f = Self::dequantize(tensor);
1358        B::float_any(tensor_f)
1359    }
1360
1361    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1362    ///
1363    /// # Arguments
1364    ///
1365    /// * `tensor` - The tensor to test.
1366    /// * `dim` - The axis along which to test.
1367    ///
1368    /// # Returns
1369    ///
1370    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1371    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1372    /// input evaluates to True, False otherwise.
1373    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1374        let tensor_f = Self::dequantize(tensor);
1375        B::float_any_dim(tensor_f, dim)
1376    }
1377
1378    /// Tests if all elements in the `tensor` evaluate to True.
1379    ///
1380    /// # Arguments
1381    ///
1382    /// * `tensor` - The tensor to test.
1383    ///
1384    /// # Returns
1385    ///
1386    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1387    /// evaluate to True, False otherwise.
1388    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1389        let tensor_f = Self::dequantize(tensor);
1390        B::float_all(tensor_f)
1391    }
1392
1393    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1394    ///
1395    /// # Arguments
1396    ///
1397    /// * `tensor` - The tensor to test.
1398    /// * `dim` - The axis along which to test.
1399    ///
1400    /// # Returns
1401    ///
1402    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1403    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1404    /// evaluates to True, False otherwise.
1405    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1406        let tensor_f = Self::dequantize(tensor);
1407        B::float_all_dim(tensor_f, dim)
1408    }
1409
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// This method has no default implementation; every backend must provide one.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The quantized tensor broadcast to `shape`.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
1412
1413    /// Sort the elements of the input `tensor` by value in along a given dimension.
1414    ///
1415    /// This sort is unstable (i.e., may reorder equal elements).
1416    ///
1417    /// # Arguments
1418    ///
1419    /// * `tensor` - The input tensor.
1420    /// * `dim` - The axis along which to sort.
1421    /// * `descending` - The sorting order.
1422    ///
1423    /// # Returns
1424    ///
1425    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1426    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1427        // Default implementation. Backends can sort on the int values since qparams remain the same.
1428        dequant_op_quant!(
1429            ty Self,
1430            float_op |tensor| B::float_sort(tensor, dim, descending),
1431            tensor
1432        )
1433    }
1434
1435    /// Sort the elements of the input `tensor` by value in along a given dimension.
1436    ///
1437    /// This sort is unstable (i.e., may reorder equal elements).
1438    ///
1439    /// # Arguments
1440    ///
1441    /// * `tensor` - The input tensor.
1442    /// * `dim` - The axis along which to sort.
1443    /// * `descending` - The sorting order.
1444    ///
1445    /// # Returns
1446    ///
1447    /// A tensor with the same shape as the input tensor and corresponding indices, where
1448    /// the elements are sorted by value and the indices map back to the original input tensor.
1449    fn q_sort_with_indices(
1450        tensor: QuantizedTensor<B>,
1451        dim: usize,
1452        descending: bool,
1453    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1454        // Default implementation. Backends can sort on the int values since qparams remain the same.
1455        let scheme = *tensor.scheme();
1456
1457        let tensor_f = Self::dequantize(tensor);
1458        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1459
1460        (Self::quantize_dynamic(out_f, &scheme), indices)
1461    }
1462
1463    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1464    ///
1465    /// This sort is unstable (i.e., may reorder equal elements).
1466    ///
1467    /// # Arguments
1468    ///
1469    /// * `tensor` - The input tensor.
1470    /// * `dim` - The axis along which to sort.
1471    /// * `descending` - The sorting order.
1472    ///
1473    /// # Returns
1474    ///
1475    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1476    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1477        // Default implementation. Backends can sort on the int values since qparams remain the same.
1478        let tensor_f = Self::dequantize(tensor);
1479        B::float_argsort(tensor_f, dim, descending)
1480    }
1481}