//! Quantized tensor operations (`burn_backend/backend/ops/qtensor.rs`).

1use alloc::vec::Vec;
2use burn_std::{
3    BoolDType, FloatDType, IntDType, Shape, Slice,
4    quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9    get_device_settings,
10};
11use crate::{
12    Scalar,
13    tensor::{
14        BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,
15        quantization::{
16            Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,
17        },
18    },
19};
20
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        // Dequantize both operands to the device's configured float dtype
        // (mirrors the unary arm below; `dequantize` requires an explicit dtype).
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let t1_f = Self::dequantize($t1, dtype);
        let t2_f = Self::dequantize($t2, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let tensor_f = Self::dequantize($tensor, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
}
54
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        let propagation = $t1.propagation();
        // Dequantize to the device's configured float dtype.
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let t1_f = Self::dequantize($t1, dtype);
        let t2_f = Self::dequantize($t2, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        // Re-quantize only when the scheme is marked to propagate; otherwise
        // the float result is returned as-is.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let propagation = $tensor.propagation();
        // Dequantize to the device's configured float dtype.
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let tensor_f = Self::dequantize($tensor, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        // Re-quantize only when the scheme is marked to propagate.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
102
103/// Operations on quantized tensors.
104///
105/// # Return Type Semantics
106///
107/// The return type of each operation indicates how quantization is handled:
108///
109/// ## [`QuantizedTensor<B>`]
110/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
111/// representation. Implementations should avoid dequantizing when possible to maintain performance.
112/// For example, shape or layout changes such as expand or transpose preserve quantization.
113///
114/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
115/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
116/// the quantization parameters to match the new layout.*
117///
118///
119/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the propagation
/// strategy specified in the quantization scheme. The output should either remain quantized
/// ([`TensorPrimitive::QFloat`]) or be returned in floating-point form ([`TensorPrimitive::Float`]).
123///
124/// This distinction allows for fine-grained control over mixed-precision flows while still operating
125/// through a unified API.
126pub trait QTensorOps<B: Backend> {
    /// Creates a new tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;

    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The pre-computed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
145
146    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
147    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
148        // Dynamically compute min/max tensor range and qparams before quantizing
149        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
150        let qparams = compute_q_params(scheme, min, max);
151        Self::quantize(tensor, scheme, qparams)
152    }
153
    /// Convert the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    /// * `dtype` - The float data type of the dequantized output.
    ///
    /// # Returns
    ///
    /// The dequantized (float) tensor.
    fn dequantize(tensor: QuantizedTensor<B>, dtype: FloatDType) -> FloatTensor<B>;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
167
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;

    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the tensor's data, or an [`ExecutionError`].
    fn q_into_data(
        tensor: QuantizedTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;

    /// Detaches a tensor from the computation graph.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }

    /// Sets the `require_grad` flag of a tensor.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }

    /// Returns the `require_grad` flag of a tensor.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }

    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The shape to broadcast to.
    ///
    /// # Returns
    ///
    /// The broadcast tensor.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
225
226    /// Transposes a tensor.
227    ///
228    /// # Arguments
229    ///
230    /// * `tensor` - The tensor to transpose.
231    ///
232    /// # Returns
233    ///
234    /// The transposed tensor.
235    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
236        let ndims = tensor.shape().num_dims();
237        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
238    }
239
    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;

    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;

    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;

    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;

    /// Select tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
302
    /// Gather elements from a tensor.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<B>,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B> {
        // Default implementation. Backends can gather on the quantized values when supported.
        // Falls back to `dequantize -> float_gather -> quantize`.
        dequant_op_quant!(
            float_op | tensor | B::float_gather(dim, tensor, indices),
            tensor
        )
    }
325
    /// Repeat the tensor along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to repeat.
    /// * `times` - The number of times to repeat the dimension.
    ///
    /// # Returns
    ///
    /// The tensor with the given dimension repeated.
    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
        // Default: `dequantize -> float_repeat_dim -> quantize`.
        dequant_op_quant!(
            float_op | tensor | B::float_repeat_dim(tensor, dim, times),
            tensor
        )
    }
343
    /// Adds two tensors together.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of adding the two tensors together.
    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Output form (quantized or float) follows the lhs propagation strategy.
        dequant_op_flow!(float_op | lhs, rhs | B::float_add(lhs, rhs), lhs, rhs)
    }

    /// Adds a scalar to a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of adding the scalar to the tensor.
    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_add_scalar(tensor, rhs), lhs)
    }

    /// Clamps a tensor under a minimum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    ///
    /// # Returns
    ///
    /// The clamped tensor.
    fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_clamp_min(tensor, min), tensor)
    }

    /// Clamps a tensor over a maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// The clamped tensor.
    fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_clamp_max(tensor, max), tensor)
    }

    /// Clamps a tensor between a minimum and maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// The clamped tensor.
    fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_clamp(tensor, min, max), tensor)
    }

    /// Subtracts two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of subtracting the two tensors.
    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | lhs, rhs | B::float_sub(lhs, rhs), lhs, rhs)
    }

    /// Subtracts a scalar from a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of subtracting the scalar from the tensor.
    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_sub_scalar(tensor, rhs), lhs)
    }
442
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together.
    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | lhs, rhs | B::float_mul(lhs, rhs), lhs, rhs)
    }

    /// Multiplies a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of multiplying the tensor by the scalar.
    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_mul_scalar(tensor, rhs), lhs)
    }

    /// Divides two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of dividing the two tensors.
    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | lhs, rhs | B::float_div(lhs, rhs), lhs, rhs)
    }

    /// Divides a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of dividing the tensor by the scalar.
    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_div_scalar(tensor, rhs), lhs)
    }
489
490    /// Multiplies two tensors together using matrix multiplication.
491    ///
492    /// # Arguments
493    ///
494    /// * `lhs` - The left hand side tensor.
495    /// * `rhs` - The right hand side tensor.
496    ///
497    /// # Returns
498    ///
499    /// The result of multiplying the two tensors together using matrix multiplication.
500    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
501        let mut propagation = QuantPropagation::Inhibit;
502        let mut scheme = QuantScheme::default();
503        let mut dtype = None;
504
505        let lhs = match lhs {
506            TensorPrimitive::Float(lhs) => lhs,
507            TensorPrimitive::QFloat(lhs) => {
508                propagation = lhs.propagation();
509                scheme = *lhs.scheme();
510                let float_dtype = get_device_settings::<B>(&Self::q_device(&lhs)).float_dtype;
511                dtype = Some(float_dtype);
512
513                Self::dequantize(lhs, float_dtype)
514            }
515        };
516        let rhs = match rhs {
517            TensorPrimitive::Float(rhs) => rhs,
518            TensorPrimitive::QFloat(rhs) => {
519                propagation = rhs.propagation();
520                scheme = *rhs.scheme();
521                let float_dtype = dtype
522                    .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&rhs)).float_dtype);
523
524                Self::dequantize(rhs, float_dtype)
525            }
526        };
527
528        let out_f = B::float_matmul(lhs, rhs);
529        match propagation {
530            QuantPropagation::Propagate => {
531                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
532            }
533            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
534        }
535    }
536
    /// Negates a tensor element-wise.
    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default: `dequantize -> float_neg -> [quantize]` per propagation.
        dequant_op_flow!(float_op | tensor | B::float_neg(tensor), tensor)
    }

    /// Calculates the reciprocals element-wise
    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_recip(tensor), tensor)
    }

    /// Sum of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_sum(tensor), tensor)
    }

    /// Sum of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_sum_dim(tensor, dim), tensor)
    }

    /// Product of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to product.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the product of all elements in `tensor`.
    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_prod(tensor), tensor)
    }

    /// Product of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to product.
    /// * `dim` - The dimension along which to take the product.
    ///
    /// # Returns
    ///
    /// A tensor with the product of all elements in `tensor` along `dim`.
    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_prod_dim(tensor, dim), tensor)
    }

    /// Mean of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to mean.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the mean of all elements in `tensor`.
    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_mean(tensor), tensor)
    }

    /// Mean of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to mean.
    /// * `dim` - The dimension along which to mean.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_mean_dim(tensor, dim), tensor)
    }
626
    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative sum
    /// of all elements up to and including that position along the dimension.
    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default: `dequantize -> float_cumsum -> [quantize]` per propagation.
        dequant_op_flow!(float_op | tensor | B::float_cumsum(tensor, dim), tensor)
    }

    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative product
    /// of all elements up to and including that position along the dimension.
    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_cumprod(tensor, dim), tensor)
    }

    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the minimum
    /// of all elements up to and including that position along the dimension.
    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_cummin(tensor, dim), tensor)
    }

    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the maximum
    /// of all elements up to and including that position along the dimension.
    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_cummax(tensor, dim), tensor)
    }
686
    /// Returns a new tensor with exponential values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default: `dequantize -> float_exp -> [quantize]` per propagation.
        dequant_op_flow!(float_op | tensor | B::float_exp(tensor), tensor)
    }

    /// Returns a new tensor with natural logarithm values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_log(tensor), tensor)
    }

    /// Returns a new tensor with logarithm values of (1 + Xi).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_log1p(tensor), tensor)
    }
725
    /// Element-wise power with another tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | lhs, rhs | B::float_powf(lhs, rhs), lhs, rhs)
    }

    /// Element-wise power with an integer tensor exponent.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side integer tensor (the exponents).
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_powi(tensor, rhs), lhs)
    }

    /// Element-wise power with an int scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the value of `rhs`.
    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_powi_scalar(tensor, rhs), lhs)
    }

    /// Element-wise power with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {
        dequant_op_flow!(
            float_op | tensor | B::float_powf_scalar(tensor, value),
            tensor
        )
    }
784
    /// Returns a new tensor with square root values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_sqrt(tensor), tensor)
    }

    /// Returns a new tensor with absolute values.
    ///
    /// Unlike most unary ops here, the output is always re-quantized
    /// (`dequant_op_quant!`), so a [`QuantizedTensor`] is returned.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        dequant_op_quant!(float_op | tensor | B::float_abs(tensor), tensor)
    }
810
    /// Returns a new tensor with cosine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default: `dequantize -> float_cos -> [quantize]` per propagation.
        dequant_op_flow!(float_op | tensor | B::float_cos(tensor), tensor)
    }

    /// Returns a new tensor with sine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_sin(tensor), tensor)
    }

    /// Returns a new tensor with tangent values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with tangent values.
    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_tan(tensor), tensor)
    }

    /// Returns a new tensor with hyperbolic cosine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_cosh(tensor), tensor)
    }

    /// Returns a new tensor with hyperbolic sine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_sinh(tensor), tensor)
    }

    /// Returns a new tensor with hyperbolic tangent values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_tanh(tensor), tensor)
    }

    /// Returns a new tensor with the error function values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_erf(tensor), tensor)
    }
901
902    /// Concatenates tensors along a dimension.
903    ///
904    /// # Arguments
905    ///
906    /// * `tensors` - The tensors to concatenate.
907    /// * `dim` - The dimension along which to concatenate.
908    ///
909    /// # Returns
910    ///
911    /// A tensor with the concatenated tensors along `dim`.
912    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
913        // Heuristic: prioritize first tensor scheme
914        let first = tensors.first().unwrap();
915        let scheme = *first.scheme();
916        let dtype = get_device_settings::<B>(&Self::q_device(first)).float_dtype;
917
918        let tensor_f = tensors
919            .into_iter()
920            .map(|tensor| Self::dequantize(tensor, dtype))
921            .collect();
922
923        let out_f = B::float_cat(tensor_f, dim);
924
925        Self::quantize_dynamic(out_f, &scheme)
926    }
927
928    /// Gets the indices of the maximum elements of a tensor along an axis.
929    ///
930    /// # Arguments
931    ///
932    /// * `tensor` - The tensor to get the maximum elements of.
933    /// * `dim` - The dimension along which to get the maximum elements.
934    /// * `out_dtype` - The output tensor dtype.
935    ///
936    /// # Returns
937    ///
938    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
939    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
940        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
941        let tensor_f = Self::dequantize(tensor, dtype);
942        B::float_argmax(tensor_f, dim, out_dtype)
943    }
944
945    /// Gets the indices of the minimum elements of a tensor along an axis.
946    ///
947    /// # Arguments
948    ///
949    /// * `tensor` - The tensor to get the minimum elements of.
950    /// * `dim` - The dimension along which to get the minimum elements.
951    /// * `out_dtype` - The output tensor dtype.
952    ///
953    /// # Returns
954    ///
955    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
956    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
957        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
958        let tensor_f = Self::dequantize(tensor, dtype);
959        B::float_argmin(tensor_f, dim, out_dtype)
960    }
961
962    /// Gets the maximum element of a tensor.
963    ///
964    /// # Arguments
965    ///
966    /// * `tensor` - The tensor to get the maximum elements of.
967    ///
968    /// # Returns
969    ///
970    /// A tensor with the maximum element of `tensor`.
971    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
972        let shape = tensor.shape();
973        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
974
975        B::q_max_dim(tensor, 0)
976    }
977
978    /// Gets the maximum elements of a tensor along an axis.
979    ///
980    /// # Arguments
981    ///
982    /// * `tensor` - The tensor to get the maximum elements of.
983    /// * `dim` - The dimension along which to get the maximum elements.
984    ///
985    /// # Returns
986    ///
987    /// A tensor with the maximum elements of `tensor` along `dim`.
988    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
989        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
990        let index = B::q_argmax(tensor.clone(), dim, int_dtype);
991
992        B::q_gather(dim, tensor, index)
993    }
994
995    /// Gets the maximum elements of a tensor along an axis and their indices.
996    ///
997    /// # Arguments
998    ///
999    /// * `tensor` - The tensor to get the maximum elements of.
1000    /// * `dim` - The dimension along which to get the maximum elements.
1001    ///
1002    /// # Returns
1003    ///
1004    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1005    fn q_max_dim_with_indices(
1006        tensor: QuantizedTensor<B>,
1007        dim: usize,
1008        out_dtype: IntDType,
1009    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1010        let index = B::q_argmax(tensor.clone(), dim, out_dtype);
1011        let values = B::q_gather(dim, tensor, index.clone());
1012
1013        (values, index)
1014    }
1015
1016    /// Gets the minimum element of a tensor.
1017    ///
1018    /// # Arguments
1019    ///
1020    /// * `tensor` - The tensor to get the minimum elements of.
1021    ///
1022    /// # Returns
1023    ///
1024    /// A tensor with the minimum element of `tensor`.
1025    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1026        let shape = tensor.shape();
1027        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1028
1029        B::q_min_dim(tensor, 0)
1030    }
1031
1032    /// Gets the minimum elements of a tensor along an axis.
1033    ///
1034    /// # Arguments
1035    ///
1036    /// * `tensor` - The tensor to get the minimum elements of.
1037    /// * `dim` - The dimension along which to get the minimum elements.
1038    ///
1039    /// # Returns
1040    ///
1041    /// A tensor with the minimum elements of `tensor` along `dim`.
1042    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1043        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1044        let index = B::q_argmin(tensor.clone(), dim, int_dtype);
1045
1046        B::q_gather(dim, tensor, index)
1047    }
1048
1049    /// Gets the minimum elements of a tensor along an axis and their indices.
1050    ///
1051    /// # Arguments
1052    ///
1053    /// * `tensor` - The tensor to get the minimum elements of.
1054    /// * `dim` - The dimension along which to get the minimum elements.
1055    ///
1056    /// # Returns
1057    ///
1058    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1059    fn q_min_dim_with_indices(
1060        tensor: QuantizedTensor<B>,
1061        dim: usize,
1062        out_dtype: IntDType,
1063    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1064        let index = B::q_argmin(tensor.clone(), dim, out_dtype);
1065        let values = B::q_gather(dim, tensor, index.clone());
1066
1067        (values, index)
1068    }
1069
1070    /// Gets the maximum element of a tensor.
1071    ///
1072    /// # Arguments
1073    ///
1074    /// * `tensor` - The tensor to get the maximum elements of.
1075    ///
1076    /// # Returns
1077    ///
1078    /// A tensor with the maximum element of `tensor`.
1079    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1080        let shape = tensor.shape();
1081        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1082
1083        B::q_max_abs_dim(tensor, 0)
1084    }
1085
1086    /// Gets the maximum elements of a tensor along an axis.
1087    ///
1088    /// # Arguments
1089    ///
1090    /// * `tensor` - The tensor to get the maximum elements of.
1091    /// * `dim` - The dimension along which to get the maximum elements.
1092    ///
1093    /// # Returns
1094    ///
1095    /// A tensor with the maximum elements of `tensor` along `dim`.
1096    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1097        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1098        let index = B::q_argmax(B::q_abs(tensor.clone()), dim, int_dtype);
1099
1100        B::q_gather(dim, tensor, index)
1101    }
1102
1103    /// Tests if any element in the `tensor` evaluates to True.
1104    ///
1105    /// # Arguments
1106    ///
1107    /// * `tensor` - The tensor to test.
1108    ///
1109    /// # Returns
1110    ///
1111    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1112    fn q_any(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1113        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1114        let tensor_f = Self::dequantize(tensor, dtype);
1115        B::float_any(tensor_f, out_dtype)
1116    }
1117
1118    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1119    ///
1120    /// # Arguments
1121    ///
1122    /// * `tensor` - The tensor to test.
1123    /// * `dim` - The axis along which to test.
1124    ///
1125    /// # Returns
1126    ///
1127    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1128    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1129    /// input evaluates to True, False otherwise.
1130    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1131        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1132        let tensor_f = Self::dequantize(tensor, dtype);
1133        B::float_any_dim(tensor_f, dim, out_dtype)
1134    }
1135
1136    /// Tests if all elements in the `tensor` evaluate to True.
1137    ///
1138    /// # Arguments
1139    ///
1140    /// * `tensor` - The tensor to test.
1141    ///
1142    /// # Returns
1143    ///
1144    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1145    /// evaluate to True, False otherwise.
1146    fn q_all(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1147        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1148        let tensor_f = Self::dequantize(tensor, dtype);
1149        B::float_all(tensor_f, out_dtype)
1150    }
1151
1152    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1153    ///
1154    /// # Arguments
1155    ///
1156    /// * `tensor` - The tensor to test.
1157    /// * `dim` - The axis along which to test.
1158    ///
1159    /// # Returns
1160    ///
1161    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1162    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1163    /// evaluates to True, False otherwise.
1164    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1165        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1166        let tensor_f = Self::dequantize(tensor, dtype);
1167        B::float_all_dim(tensor_f, dim, out_dtype)
1168    }
1169
1170    /// Sort the elements of the input `tensor` by value in along a given dimension.
1171    ///
1172    /// This sort is unstable (i.e., may reorder equal elements).
1173    ///
1174    /// # Arguments
1175    ///
1176    /// * `tensor` - The input tensor.
1177    /// * `dim` - The axis along which to sort.
1178    /// * `descending` - The sorting order.
1179    ///
1180    /// # Returns
1181    ///
1182    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1183    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1184        // Default implementation. Backends can sort on the int values since qparams remain the same.
1185        dequant_op_quant!(
1186            float_op | tensor | B::float_sort(tensor, dim, descending),
1187            tensor
1188        )
1189    }
1190
1191    /// Sort the elements of the input `tensor` by value in along a given dimension.
1192    ///
1193    /// This sort is unstable (i.e., may reorder equal elements).
1194    ///
1195    /// # Arguments
1196    ///
1197    /// * `tensor` - The input tensor.
1198    /// * `dim` - The axis along which to sort.
1199    /// * `descending` - The sorting order.
1200    ///
1201    /// # Returns
1202    ///
1203    /// A tensor with the same shape as the input tensor and corresponding indices, where
1204    /// the elements are sorted by value and the indices map back to the original input tensor.
1205    fn q_sort_with_indices(
1206        tensor: QuantizedTensor<B>,
1207        dim: usize,
1208        descending: bool,
1209        out_dtype: IntDType,
1210    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1211        let scheme = *tensor.scheme();
1212        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1213
1214        let tensor_f = Self::dequantize(tensor, dtype);
1215        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending, out_dtype);
1216
1217        (Self::quantize_dynamic(out_f, &scheme), indices)
1218    }
1219
1220    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1221    ///
1222    /// This sort is unstable (i.e., may reorder equal elements).
1223    ///
1224    /// # Arguments
1225    ///
1226    /// * `tensor` - The input tensor.
1227    /// * `dim` - The axis along which to sort.
1228    /// * `descending` - The sorting order.
1229    ///
1230    /// # Returns
1231    ///
1232    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1233    fn q_argsort(
1234        tensor: QuantizedTensor<B>,
1235        dim: usize,
1236        descending: bool,
1237        out_dtype: IntDType,
1238    ) -> IntTensor<B> {
1239        let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1240        let tensor_f = Self::dequantize(tensor, dtype);
1241        B::float_argsort(tensor_f, dim, descending, out_dtype)
1242    }
1243}