Skip to main content

burn_backend/backend/ops/
qtensor.rs

1use alloc::vec::Vec;
2use burn_std::{
3    BoolDType, FloatDType, IntDType, Shape, Slice,
4    quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9    get_device_settings,
10};
11use crate::{
12    Scalar,
13    tensor::{
14        BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,
15        quantization::{
16            Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,
17        },
18    },
19};
20
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// The expansion refers to `Self`, `B` and `get_device_settings` as they are named at the
/// call site, so it is intended to be invoked from a `QTensorOps<B>` method body. Inputs are
/// dequantized to the float dtype configured for the (lhs) input's device, and the float
/// result is re-quantized with the (lhs) input's scheme through `quantize_dynamic`.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        // `dequantize` takes an explicit target dtype. Use the lhs device's float dtype for
        // both operands so the float op receives matching dtypes (same rule as the binary
        // arm of `dequant_op_flow!`).
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let t1_f = Self::dequantize($t1, dtype);
        let t2_f = Self::dequantize($t2, dtype);
        // The op is usually an inline closure, which clippy flags when called immediately.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let tensor_f = Self::dequantize($tensor, dtype);
        // The op is usually an inline closure, which clippy flags when called immediately.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
}
54
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional: the float result is re-quantized (with the
/// input's scheme, via `quantize_dynamic`) only when the input's propagation mode is
/// `QuantPropagation::Propagate`. With `QuantPropagation::Inhibit` the float result is
/// returned as `TensorPrimitive::Float`.
///
/// For binary ops the lhs tensor supplies the scheme, the propagation mode, and the device
/// whose configured float dtype both operands are dequantized to. The expansion refers to
/// `Self`, `B`, `get_device_settings`, `QuantPropagation` and `TensorPrimitive` as they are
/// named at the call site.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary op: `$float_op` receives the dequantized (lhs, rhs) pair.
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Everything that decides the output comes from the lhs (lhs-priority heuristic).
        let out_scheme = $t1.scheme().clone();
        let out_propagation = $t1.propagation();
        let float_dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let lhs_float = Self::dequantize($t1, float_dtype);
        let rhs_float = Self::dequantize($t2, float_dtype);
        // The op is usually an inline closure, which clippy flags when called immediately.
        #[allow(clippy::redundant_closure_call)]
        let result = $float_op(lhs_float, rhs_float);

        match out_propagation {
            QuantPropagation::Inhibit => TensorPrimitive::Float(result),
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(result, &out_scheme))
            }
        }
    }};
    // Unary op: `$float_op` receives the dequantized tensor.
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let out_scheme = $tensor.scheme().clone();
        let out_propagation = $tensor.propagation();
        let float_dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let dequantized = Self::dequantize($tensor, float_dtype);
        // The op is usually an inline closure, which clippy flags when called immediately.
        #[allow(clippy::redundant_closure_call)]
        let result = $float_op(dequantized);

        match out_propagation {
            QuantPropagation::Inhibit => TensorPrimitive::Float(result),
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(result, &out_scheme))
            }
        }
    }};
}
102
103/// Operations on quantized tensors.
104///
105/// # Return Type Semantics
106///
107/// The return type of each operation indicates how quantization is handled:
108///
109/// ## [`QuantizedTensor<B>`]
110/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
111/// representation. Implementations should avoid dequantizing when possible to maintain performance.
112/// For example, shape or layout changes such as expand or transpose preserve quantization.
113///
114/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
115/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
116/// the quantization parameters to match the new layout.*
117///
118///
119/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme: the output either remains quantized
/// ([`TensorPrimitive::QFloat`]) or is returned in floating-point form ([`TensorPrimitive::Float`]).
123///
124/// This distinction allows for fine-grained control over mixed-precision flows while still operating
125/// through a unified API.
126pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// NOTE(review): presumably `data` must already hold quantized values together with
    /// its quantization metadata — confirm against backend implementations.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor with the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
138
    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme describing the target representation.
    /// * `qparams` - Precomputed quantization parameters. `quantize_dynamic` shows one way to
    ///   derive them, from the tensor's min/max range.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
145
146    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
147    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
148        // Dynamically compute min/max tensor range and qparams before quantizing
149        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
150        let qparams = compute_q_params(scheme, min, max);
151        Self::quantize(tensor, scheme, qparams)
152    }
153
    /// Convert the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    /// * `dtype` - The floating-point data type of the returned tensor.
    ///
    /// # Returns
    ///
    /// The dequantized float tensor.
    fn dequantize(tensor: QuantizedTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
156
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
167
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
179
    /// Reshapes a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape, still in quantized form.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
191
    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A `Send` future that resolves to the data structure with the tensor's data.
    ///
    /// # Errors
    ///
    /// The future resolves to an [`ExecutionError`] if reading the data fails.
    fn q_into_data(
        tensor: QuantizedTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
204
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns `tensor` unchanged.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }

    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores the flag and returns `tensor` unchanged.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }

    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always returns `false`.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
222
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcast tensor, still in quantized form.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
225
226    /// Transposes a tensor.
227    ///
228    /// # Arguments
229    ///
230    /// * `tensor` - The tensor to transpose.
231    ///
232    /// # Returns
233    ///
234    /// The transposed tensor.
235    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
236        let ndims = tensor.shape().num_dims();
237        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
238    }
239
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
252
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
263
    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
273
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
290
    /// Select tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new quantized tensor.
    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
302
303    /// Gather elements from a tensor.
304    ///
305    /// # Arguments
306    ///
307    /// * `dim` - The dimension to gather from.
308    /// * `tensor` - The tensor to gather from.
309    /// * `indices` - The indices to gather.
310    ///
311    /// # Returns
312    ///
313    /// The gathered elements.
314    fn q_gather(
315        dim: usize,
316        tensor: QuantizedTensor<B>,
317        indices: IntTensor<B>,
318    ) -> QuantizedTensor<B> {
319        // Default implementation. Backends can gather on the quantized values when supported.
320        dequant_op_quant!(
321            float_op | tensor | B::float_gather(dim, tensor, indices),
322            tensor
323        )
324    }
325
326    /// Repeat the tensor along the given dimension.
327    ///
328    /// # Arguments
329    ///
330    /// * `tensor` - The tensor.
331    /// * `dim` - The dimension to repeat.
332    /// * `times` - The number of times to repeat the dimension.
333    ///
334    /// # Returns
335    ///
336    /// The tensor with the given dimension repeated.
337    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
338        dequant_op_quant!(
339            float_op | tensor | B::float_repeat_dim(tensor, dim, times),
340            tensor
341        )
342    }
343
344    /// Adds two tensors together.
345    ///
346    /// # Arguments
347    ///
348    /// * `lhs` - The left hand side tensor.
349    /// * `rhs` - The right hand side tensor.
350    ///
351    /// # Returns
352    ///
353    /// The result of adding the two tensors together.
354    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
355        dequant_op_flow!(float_op | lhs, rhs | B::float_add(lhs, rhs), lhs, rhs)
356    }
357
358    /// Adds a scalar to a tensor.
359    ///
360    /// # Arguments
361    ///
362    /// * `lhs` - The left hand side tensor.
363    /// * `rhs` - The right hand side scalar.
364    ///
365    /// # Returns
366    ///
367    /// The result of adding the scalar to the tensor.
368    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
369        dequant_op_flow!(float_op | tensor | B::float_add_scalar(tensor, rhs), lhs)
370    }
371
372    /// Clamps a tensor under a minimum value.
373    ///
374    /// # Arguments
375    ///
376    /// * `tensor` - The tensor to clamp.
377    /// * `min` - The minimum value.
378    ///
379    /// # Returns
380    ///
381    /// The clamped tensor.
382    fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {
383        dequant_op_flow!(float_op | tensor | B::float_clamp_min(tensor, min), tensor)
384    }
385
386    /// Clamps a tensor over a maximum value.
387    ///
388    /// # Arguments
389    ///
390    /// * `tensor` - The tensor to clamp.
391    /// * `max` - The maximum value.
392    ///
393    /// # Returns
394    ///
395    /// The clamped tensor.
396    fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {
397        dequant_op_flow!(float_op | tensor | B::float_clamp_max(tensor, max), tensor)
398    }
399
400    /// Clamps a tensor between a minimum and maximum value.
401    ///
402    /// # Arguments
403    ///
404    /// * `tensor` - The tensor to clamp.
405    /// * `min` - The minimum value.
406    /// * `max` - The maximum value.
407    ///
408    /// # Returns
409    ///
410    /// The clamped tensor.
411    fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {
412        dequant_op_flow!(float_op | tensor | B::float_clamp(tensor, min, max), tensor)
413    }
414
415    /// Subtracts two tensors.
416    ///
417    /// # Arguments
418    ///
419    /// * `lhs` - The left hand side tensor.
420    /// * `rhs` - The right hand side tensor.
421    ///
422    /// # Returns
423    ///
424    /// The result of subtracting the two tensors.
425    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
426        dequant_op_flow!(float_op | lhs, rhs | B::float_sub(lhs, rhs), lhs, rhs)
427    }
428
429    /// Subtracts a scalar from a tensor.
430    ///
431    /// # Arguments
432    ///
433    /// * `lhs` - The left hand side tensor.
434    /// * `rhs` - The right hand side scalar.
435    ///
436    /// # Returns
437    ///
438    /// The result of subtracting the scalar from the tensor.
439    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
440        dequant_op_flow!(float_op | tensor | B::float_sub_scalar(tensor, rhs), lhs)
441    }
442
443    /// Multiplies two tensors together element-wise.
444    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
445        dequant_op_flow!(float_op | lhs, rhs | B::float_mul(lhs, rhs), lhs, rhs)
446    }
447
448    /// Multiplies a tensor by a scalar.
449    ///
450    /// # Arguments
451    ///
452    /// * `lhs` - The left hand side tensor.
453    /// * `rhs` - The right hand side scalar.
454    ///
455    /// # Returns
456    ///
457    /// The result of multiplying the tensor by the scalar.
458    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
459        dequant_op_flow!(float_op | tensor | B::float_mul_scalar(tensor, rhs), lhs)
460    }
461
462    /// Divides two tensors element-wise.
463    ///
464    /// # Arguments
465    ///
466    /// * `lhs` - The left hand side tensor.
467    /// * `rhs` - The right hand side tensor.
468    ///
469    /// # Returns
470    ///
471    /// The result of dividing the two tensors.
472    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
473        dequant_op_flow!(float_op | lhs, rhs | B::float_div(lhs, rhs), lhs, rhs)
474    }
475
476    /// Divides a tensor by a scalar.
477    ///
478    /// # Arguments
479    ///
480    /// * `lhs` - The left hand side tensor.
481    /// * `rhs` - The right hand side scalar.
482    ///
483    /// # Returns
484    ///
485    /// The result of dividing the tensor by the scalar.
486    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
487        dequant_op_flow!(float_op | tensor | B::float_div_scalar(tensor, rhs), lhs)
488    }
489
490    /// Multiplies two tensors together using matrix multiplication.
491    ///
492    /// # Arguments
493    ///
494    /// * `lhs` - The left hand side tensor.
495    /// * `rhs` - The right hand side tensor.
496    ///
497    /// # Returns
498    ///
499    /// The result of multiplying the two tensors together using matrix multiplication.
500    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
501        let mut propagation = QuantPropagation::Inhibit;
502        let mut scheme = QuantScheme::default();
503
504        // Pick a target dtype for any dequantization. If either operand is already
505        // a Float tensor, take its dtype so a Float-QFloat (or QFloat-Float) pair
506        // ends up matching after dequantize and `float_matmul` doesn't see a
507        // dtype mismatch. Only when both operands are QFloat do we fall back to
508        // the device default.
509        let target_dtype: Option<FloatDType> = match (&lhs, &rhs) {
510            (TensorPrimitive::Float(t), _) | (_, TensorPrimitive::Float(t)) => {
511                Some(t.dtype().into())
512            }
513            _ => None,
514        };
515
516        let lhs = match lhs {
517            TensorPrimitive::Float(lhs) => lhs,
518            TensorPrimitive::QFloat(lhs) => {
519                propagation = lhs.propagation();
520                scheme = *lhs.scheme();
521                let float_dtype = target_dtype
522                    .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&lhs)).float_dtype);
523
524                Self::dequantize(lhs, float_dtype)
525            }
526        };
527        let rhs = match rhs {
528            TensorPrimitive::Float(rhs) => rhs,
529            TensorPrimitive::QFloat(rhs) => {
530                propagation = rhs.propagation();
531                scheme = *rhs.scheme();
532                let float_dtype = target_dtype
533                    .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&rhs)).float_dtype);
534
535                Self::dequantize(rhs, float_dtype)
536            }
537        };
538
539        let out_f = B::float_matmul(lhs, rhs);
540        match propagation {
541            QuantPropagation::Propagate => {
542                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
543            }
544            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
545        }
546    }
547
548    /// Negates a tensor element-wise.
549    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
550        dequant_op_flow!(float_op | tensor | B::float_neg(tensor), tensor)
551    }
552
553    /// Calculates the reciprocals element-wise
554    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
555        dequant_op_flow!(float_op | tensor | B::float_recip(tensor), tensor)
556    }
557
558    /// Sum of all elements in a tensor.
559    ///
560    /// # Arguments
561    ///
562    /// * `tensor` - The tensor to sum.
563    ///
564    /// # Returns
565    ///
566    /// A scalar tensor with the sum of all elements in `tensor`.
567    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
568        dequant_op_flow!(float_op | tensor | B::float_sum(tensor), tensor)
569    }
570
571    /// Sum of all elements in a tensor along a dimension.
572    ///
573    /// # Arguments
574    ///
575    /// * `tensor` - The tensor to sum.
576    /// * `dim` - The dimension along which to sum.
577    ///
578    /// # Returns
579    ///
580    /// A tensor with the sum of all elements in `tensor` along `dim`.
581    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
582        dequant_op_flow!(float_op | tensor | B::float_sum_dim(tensor, dim), tensor)
583    }
584
585    /// Product of all elements in a tensor.
586    ///
587    /// # Arguments
588    ///
589    /// * `tensor` - The tensor to product.
590    ///
591    /// # Returns
592    ///
593    /// A scalar tensor with the product of all elements in `tensor`.
594    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
595        dequant_op_flow!(float_op | tensor | B::float_prod(tensor), tensor)
596    }
597
598    /// Product of all elements in a tensor along a dimension.
599    ///
600    /// # Arguments
601    ///
602    /// * `tensor` - The tensor to product.
603    ///
604    /// # Returns
605    ///
606    /// A tensor with the product of all elements in `tensor` along `dim`.
607    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
608        dequant_op_flow!(float_op | tensor | B::float_prod_dim(tensor, dim), tensor)
609    }
610
611    /// Mean of all elements in a tensor.
612    ///
613    /// # Arguments
614    ///
615    /// * `tensor` - The tensor to mean.
616    ///
617    /// # Returns
618    ///
619    /// A scalar tensor with the mean of all elements in `tensor`.
620    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
621        dequant_op_flow!(float_op | tensor | B::float_mean(tensor), tensor)
622    }
623
624    /// Mean of all elements in a tensor along a dimension.
625    ///
626    /// # Arguments
627    ///
628    /// * `tensor` - The tensor to mean.
629    /// * `dim` - The dimension along which to mean.
630    ///
631    /// # Returns
632    ///
633    /// A tensor with the mean of all elements in `tensor` along `dim`.
634    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
635        dequant_op_flow!(float_op | tensor | B::float_mean_dim(tensor, dim), tensor)
636    }
637
638    /// Computes the cumulative sum of elements along a dimension.
639    ///
640    /// # Arguments
641    ///
642    /// * `tensor` - The tensor to compute the cumulative sum of.
643    /// * `dim` - The dimension along which to compute the cumulative sum.
644    ///
645    /// # Returns
646    ///
647    /// A tensor with the same shape where each element is the cumulative sum
648    /// of all elements up to and including that position along the dimension.
649    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
650        dequant_op_flow!(float_op | tensor | B::float_cumsum(tensor, dim), tensor)
651    }
652
653    /// Computes the cumulative product of elements along a dimension.
654    ///
655    /// # Arguments
656    ///
657    /// * `tensor` - The tensor to compute the cumulative product of.
658    /// * `dim` - The dimension along which to compute the cumulative product.
659    ///
660    /// # Returns
661    ///
662    /// A tensor with the same shape where each element is the cumulative product
663    /// of all elements up to and including that position along the dimension.
664    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
665        dequant_op_flow!(float_op | tensor | B::float_cumprod(tensor, dim), tensor)
666    }
667
668    /// Computes the cumulative minimum of elements along a dimension.
669    ///
670    /// # Arguments
671    ///
672    /// * `tensor` - The tensor to compute the cumulative minimum of.
673    /// * `dim` - The dimension along which to compute the cumulative minimum.
674    ///
675    /// # Returns
676    ///
677    /// A tensor with the same shape where each element is the minimum
678    /// of all elements up to and including that position along the dimension.
679    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
680        dequant_op_flow!(float_op | tensor | B::float_cummin(tensor, dim), tensor)
681    }
682
683    /// Computes the cumulative maximum of elements along a dimension.
684    ///
685    /// # Arguments
686    ///
687    /// * `tensor` - The tensor to compute the cumulative maximum of.
688    /// * `dim` - The dimension along which to compute the cumulative maximum.
689    ///
690    /// # Returns
691    ///
692    /// A tensor with the same shape where each element is the maximum
693    /// of all elements up to and including that position along the dimension.
694    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
695        dequant_op_flow!(float_op | tensor | B::float_cummax(tensor, dim), tensor)
696    }
697
698    /// Returns a new tensor with exponential values.
699    ///
700    /// # Arguments
701    ///
702    /// * `tensor` - The tensor to exponentiate.
703    ///
704    /// # Returns
705    ///
706    /// A tensor with the same shape as `tensor` with exponential values.
707    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
708        dequant_op_flow!(float_op | tensor | B::float_exp(tensor), tensor)
709    }
710
711    /// Returns a new tensor with natural logarithm values.
712    ///
713    /// # Arguments
714    ///
715    /// * `tensor` - The tensor to take the logarithm of.
716    ///
717    /// # Returns
718    ///
719    /// A tensor with the same shape as `tensor` with natural logarithm values.
720    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
721        dequant_op_flow!(float_op | tensor | B::float_log(tensor), tensor)
722    }
723
724    /// Returns a new tensor with logarithm values of (1 + Xi).
725    ///
726    /// # Arguments
727    ///
728    /// * `tensor` - The tensor to take the logarithm of.
729    ///
730    /// # Returns
731    ///
732    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
733    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
734        dequant_op_flow!(float_op | tensor | B::float_log1p(tensor), tensor)
735    }
736
737    /// Element-wise power with another tensor.
738    ///
739    /// # Arguments
740    ///
741    /// * `lhs` - The left hand side tensor.
742    /// * `rhs` - The right hand side tensor.
743    ///
744    /// # Returns
745    ///
746    /// The elements of `lhs` raised to the power of the elements of `rhs`.
747    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
748        dequant_op_flow!(float_op | lhs, rhs | B::float_powf(lhs, rhs), lhs, rhs)
749    }
750
751    /// Element-wise power with an IntTensor.
752    ///
753    /// # Arguments
754    ///
755    /// * `lhs` - The left hand side tensor.
756    /// * `rhs` - The right hand side floatTensor.
757    ///
758    /// # Returns
759    ///
760    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
761    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
762        dequant_op_flow!(float_op | tensor | B::float_powi(tensor, rhs), lhs)
763    }
764
765    /// Element-wise power with an int scalar.
766    ///
767    /// # Arguments
768    ///
769    /// * `lhs` - The left hand side tensor.
770    /// * `rhs` - The right hand side scalar.
771    ///
772    /// # Returns
773    ///
774    /// The elements of `lhs` raised to the value of `rhs`.
775    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
776        dequant_op_flow!(float_op | tensor | B::float_powi_scalar(tensor, rhs), lhs)
777    }
778
779    /// Element-wise power with a float scalar.
780    ///
781    /// # Arguments
782    ///
783    /// * `tensor` - The tensor to exponentiate.
784    /// * `value` - The exponent.
785    ///
786    /// # Returns
787    ///
788    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
789    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {
790        dequant_op_flow!(
791            float_op | tensor | B::float_powf_scalar(tensor, value),
792            tensor
793        )
794    }
795
796    /// Returns a new tensor with square root values.
797    ///
798    /// # Arguments
799    ///
800    /// * `tensor` - The tensor to take the square root of.
801    ///
802    /// # Returns
803    ///
804    /// A tensor with the same shape as `tensor` with square root values.
805    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
806        dequant_op_flow!(float_op | tensor | B::float_sqrt(tensor), tensor)
807    }
808
809    /// Returns a new tensor with absolute values.
810    ///
811    /// # Arguments
812    ///
813    /// * `tensor` - The tensor to take absolute value of.
814    ///
815    /// # Returns
816    ///
817    /// A tensor with the same shape as `tensor` with absolute values.
818    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
819        dequant_op_quant!(float_op | tensor | B::float_abs(tensor), tensor)
820    }
821
822    /// Returns a new tensor with cosine values.
823    ///
824    /// # Arguments
825    ///
826    /// * `tensor` - The tensor to take the cosine of.
827    ///
828    /// # Returns
829    ///
830    /// A tensor with the same shape as `tensor` with cosine values.
831    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
832        dequant_op_flow!(float_op | tensor | B::float_cos(tensor), tensor)
833    }
834
835    /// Returns a new tensor with sine values.
836    ///
837    /// # Arguments
838    ///
839    /// * `tensor` - The tensor to take the sine of.
840    ///
841    /// # Returns
842    ///
843    /// A tensor with the same shape as `tensor` with sine values.
844    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
845        dequant_op_flow!(float_op | tensor | B::float_sin(tensor), tensor)
846    }
847
848    /// Returns a new tensor with tangent values.
849    ///
850    /// # Arguments
851    ///
852    /// * `tensor` - The tensor to take the tangent of.
853    ///
854    /// # Returns
855    ///
856    /// A tensor with the same shape as `tensor` with tangent values.
857    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
858        dequant_op_flow!(float_op | tensor | B::float_tan(tensor), tensor)
859    }
860
861    /// Returns a new tensor with hyperbolic cosine values.
862    ///
863    /// # Arguments
864    ///
865    /// * `tensor` - The tensor to take the hyperbolic cosine of.
866    ///
867    /// # Returns
868    ///
869    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
870    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
871        dequant_op_flow!(float_op | tensor | B::float_cosh(tensor), tensor)
872    }
873
874    /// Returns a new tensor with hyperbolic sine values.
875    ///
876    /// # Arguments
877    ///
878    /// * `tensor` - The tensor to take the hyperbolic sine of.
879    ///
880    /// # Returns
881    ///
882    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
883    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
884        dequant_op_flow!(float_op | tensor | B::float_sinh(tensor), tensor)
885    }
886
887    /// Returns a new tensor with hyperbolic tangent values.
888    ///
889    /// # Arguments
890    ///
891    /// * `tensor` - The tensor to take the hyperbolic tangent of.
892    ///
893    /// # Returns
894    ///
895    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
896    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
897        dequant_op_flow!(float_op | tensor | B::float_tanh(tensor), tensor)
898    }
899
900    /// Returns a new tensor with the error function values.
901    ///
902    /// # Arguments
903    ///
904    /// * `tensor` - The tensor to take the error function of.
905    ///
906    /// # Returns
907    ///
908    /// A tensor with the same shape as `tensor` with error function values.
909    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
910        dequant_op_flow!(float_op | tensor | B::float_erf(tensor), tensor)
911    }
912
913    /// Concatenates tensors along a dimension.
914    ///
915    /// # Arguments
916    ///
917    /// * `tensors` - The tensors to concatenate.
918    /// * `dim` - The dimension along which to concatenate.
919    ///
920    /// # Returns
921    ///
922    /// A tensor with the concatenated tensors along `dim`.
923    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
924        // Heuristic: prioritize first tensor scheme
925        let first = tensors.first().unwrap();
926        let scheme = *first.scheme();
927        let dtype = get_device_settings::<B>(&Self::q_device(first)).float_dtype;
928
929        let tensor_f = tensors
930            .into_iter()
931            .map(|tensor| Self::dequantize(tensor, dtype))
932            .collect();
933
934        let out_f = B::float_cat(tensor_f, dim);
935
936        Self::quantize_dynamic(out_f, &scheme)
937    }
938
939    /// Gets the indices of the maximum elements of a tensor along an axis.
940    ///
941    /// # Arguments
942    ///
943    /// * `tensor` - The tensor to get the maximum elements of.
944    /// * `dim` - The dimension along which to get the maximum elements.
945    /// * `out_dtype` - The output tensor dtype.
946    ///
947    /// # Returns
948    ///
949    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
950    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
951        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
952        let tensor_f = Self::dequantize(tensor, dtype);
953        B::float_argmax(tensor_f, dim, out_dtype)
954    }
955
956    /// Gets the indices of the minimum elements of a tensor along an axis.
957    ///
958    /// # Arguments
959    ///
960    /// * `tensor` - The tensor to get the minimum elements of.
961    /// * `dim` - The dimension along which to get the minimum elements.
962    /// * `out_dtype` - The output tensor dtype.
963    ///
964    /// # Returns
965    ///
966    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
967    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
968        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
969        let tensor_f = Self::dequantize(tensor, dtype);
970        B::float_argmin(tensor_f, dim, out_dtype)
971    }
972
973    /// Gets the maximum element of a tensor.
974    ///
975    /// # Arguments
976    ///
977    /// * `tensor` - The tensor to get the maximum elements of.
978    ///
979    /// # Returns
980    ///
981    /// A tensor with the maximum element of `tensor`.
982    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
983        let shape = tensor.shape();
984        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
985
986        B::q_max_dim(tensor, 0)
987    }
988
989    /// Gets the maximum elements of a tensor along an axis.
990    ///
991    /// # Arguments
992    ///
993    /// * `tensor` - The tensor to get the maximum elements of.
994    /// * `dim` - The dimension along which to get the maximum elements.
995    ///
996    /// # Returns
997    ///
998    /// A tensor with the maximum elements of `tensor` along `dim`.
999    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1000        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1001        let index = B::q_argmax(tensor.clone(), dim, int_dtype);
1002
1003        B::q_gather(dim, tensor, index)
1004    }
1005
1006    /// Gets the maximum elements of a tensor along an axis and their indices.
1007    ///
1008    /// # Arguments
1009    ///
1010    /// * `tensor` - The tensor to get the maximum elements of.
1011    /// * `dim` - The dimension along which to get the maximum elements.
1012    ///
1013    /// # Returns
1014    ///
1015    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1016    fn q_max_dim_with_indices(
1017        tensor: QuantizedTensor<B>,
1018        dim: usize,
1019        out_dtype: IntDType,
1020    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1021        let index = B::q_argmax(tensor.clone(), dim, out_dtype);
1022        let values = B::q_gather(dim, tensor, index.clone());
1023
1024        (values, index)
1025    }
1026
1027    /// Gets the minimum element of a tensor.
1028    ///
1029    /// # Arguments
1030    ///
1031    /// * `tensor` - The tensor to get the minimum elements of.
1032    ///
1033    /// # Returns
1034    ///
1035    /// A tensor with the minimum element of `tensor`.
1036    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1037        let shape = tensor.shape();
1038        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1039
1040        B::q_min_dim(tensor, 0)
1041    }
1042
1043    /// Gets the minimum elements of a tensor along an axis.
1044    ///
1045    /// # Arguments
1046    ///
1047    /// * `tensor` - The tensor to get the minimum elements of.
1048    /// * `dim` - The dimension along which to get the minimum elements.
1049    ///
1050    /// # Returns
1051    ///
1052    /// A tensor with the minimum elements of `tensor` along `dim`.
1053    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1054        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1055        let index = B::q_argmin(tensor.clone(), dim, int_dtype);
1056
1057        B::q_gather(dim, tensor, index)
1058    }
1059
1060    /// Gets the minimum elements of a tensor along an axis and their indices.
1061    ///
1062    /// # Arguments
1063    ///
1064    /// * `tensor` - The tensor to get the minimum elements of.
1065    /// * `dim` - The dimension along which to get the minimum elements.
1066    ///
1067    /// # Returns
1068    ///
1069    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1070    fn q_min_dim_with_indices(
1071        tensor: QuantizedTensor<B>,
1072        dim: usize,
1073        out_dtype: IntDType,
1074    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1075        let index = B::q_argmin(tensor.clone(), dim, out_dtype);
1076        let values = B::q_gather(dim, tensor, index.clone());
1077
1078        (values, index)
1079    }
1080
1081    /// Gets the maximum element of a tensor.
1082    ///
1083    /// # Arguments
1084    ///
1085    /// * `tensor` - The tensor to get the maximum elements of.
1086    ///
1087    /// # Returns
1088    ///
1089    /// A tensor with the maximum element of `tensor`.
1090    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1091        let shape = tensor.shape();
1092        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1093
1094        B::q_max_abs_dim(tensor, 0)
1095    }
1096
1097    /// Gets the maximum elements of a tensor along an axis.
1098    ///
1099    /// # Arguments
1100    ///
1101    /// * `tensor` - The tensor to get the maximum elements of.
1102    /// * `dim` - The dimension along which to get the maximum elements.
1103    ///
1104    /// # Returns
1105    ///
1106    /// A tensor with the maximum elements of `tensor` along `dim`.
1107    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1108        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1109        let index = B::q_argmax(B::q_abs(tensor.clone()), dim, int_dtype);
1110
1111        B::q_gather(dim, tensor, index)
1112    }
1113
1114    /// Tests if any element in the `tensor` evaluates to True.
1115    ///
1116    /// # Arguments
1117    ///
1118    /// * `tensor` - The tensor to test.
1119    ///
1120    /// # Returns
1121    ///
1122    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1123    fn q_any(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1124        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1125        let tensor_f = Self::dequantize(tensor, dtype);
1126        B::float_any(tensor_f, out_dtype)
1127    }
1128
1129    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1130    ///
1131    /// # Arguments
1132    ///
1133    /// * `tensor` - The tensor to test.
1134    /// * `dim` - The axis along which to test.
1135    ///
1136    /// # Returns
1137    ///
1138    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1139    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1140    /// input evaluates to True, False otherwise.
1141    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1142        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1143        let tensor_f = Self::dequantize(tensor, dtype);
1144        B::float_any_dim(tensor_f, dim, out_dtype)
1145    }
1146
1147    /// Tests if all elements in the `tensor` evaluate to True.
1148    ///
1149    /// # Arguments
1150    ///
1151    /// * `tensor` - The tensor to test.
1152    ///
1153    /// # Returns
1154    ///
1155    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1156    /// evaluate to True, False otherwise.
1157    fn q_all(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1158        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1159        let tensor_f = Self::dequantize(tensor, dtype);
1160        B::float_all(tensor_f, out_dtype)
1161    }
1162
1163    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1164    ///
1165    /// # Arguments
1166    ///
1167    /// * `tensor` - The tensor to test.
1168    /// * `dim` - The axis along which to test.
1169    ///
1170    /// # Returns
1171    ///
1172    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1173    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1174    /// evaluates to True, False otherwise.
1175    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1176        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1177        let tensor_f = Self::dequantize(tensor, dtype);
1178        B::float_all_dim(tensor_f, dim, out_dtype)
1179    }
1180
1181    /// Sort the elements of the input `tensor` by value in along a given dimension.
1182    ///
1183    /// This sort is unstable (i.e., may reorder equal elements).
1184    ///
1185    /// # Arguments
1186    ///
1187    /// * `tensor` - The input tensor.
1188    /// * `dim` - The axis along which to sort.
1189    /// * `descending` - The sorting order.
1190    ///
1191    /// # Returns
1192    ///
1193    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1194    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1195        // Default implementation. Backends can sort on the int values since qparams remain the same.
1196        dequant_op_quant!(
1197            float_op | tensor | B::float_sort(tensor, dim, descending),
1198            tensor
1199        )
1200    }
1201
1202    /// Sort the elements of the input `tensor` by value in along a given dimension.
1203    ///
1204    /// This sort is unstable (i.e., may reorder equal elements).
1205    ///
1206    /// # Arguments
1207    ///
1208    /// * `tensor` - The input tensor.
1209    /// * `dim` - The axis along which to sort.
1210    /// * `descending` - The sorting order.
1211    ///
1212    /// # Returns
1213    ///
1214    /// A tensor with the same shape as the input tensor and corresponding indices, where
1215    /// the elements are sorted by value and the indices map back to the original input tensor.
1216    fn q_sort_with_indices(
1217        tensor: QuantizedTensor<B>,
1218        dim: usize,
1219        descending: bool,
1220        out_dtype: IntDType,
1221    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1222        let scheme = *tensor.scheme();
1223        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1224
1225        let tensor_f = Self::dequantize(tensor, dtype);
1226        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending, out_dtype);
1227
1228        (Self::quantize_dynamic(out_f, &scheme), indices)
1229    }
1230
1231    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1232    ///
1233    /// This sort is unstable (i.e., may reorder equal elements).
1234    ///
1235    /// # Arguments
1236    ///
1237    /// * `tensor` - The input tensor.
1238    /// * `dim` - The axis along which to sort.
1239    /// * `descending` - The sorting order.
1240    ///
1241    /// # Returns
1242    ///
1243    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1244    fn q_argsort(
1245        tensor: QuantizedTensor<B>,
1246        dim: usize,
1247        descending: bool,
1248        out_dtype: IntDType,
1249    ) -> IntTensor<B> {
1250        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1251        let tensor_f = Self::dequantize(tensor, dtype);
1252        B::float_argsort(tensor_f, dim, descending, out_dtype)
1253    }
1254}