// burn_backend/backend/ops/qtensor.rs

1use alloc::vec::Vec;
2use burn_std::{
3    Shape, Slice,
4    quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::tensor::{
8    BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
9    quantization::{Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range},
10};
11use crate::{
12    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
13};
14
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// Invocation forms:
///
/// ```ignore
/// dequant_op_quant!(ty Self, float_op |a, b| B::float_add(a, b), lhs, rhs) // binary
/// dequant_op_quant!(ty Self, float_op |t| B::float_abs(t), tensor)        // unary
/// ```
///
/// `ty` must provide `dequantize` and `quantize_dynamic` associated functions. The output is
/// re-quantized with the scheme of the (left-hand) input tensor. Each tensor argument is
/// evaluated exactly once, so passing an expression with side effects is safe.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the operand once: expanding `$t1` several times would evaluate it several times.
        let lhs = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();

        let lhs_f = <$ty>::dequantize(lhs);
        let rhs_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once so an expression argument is evaluated a single time.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
47
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated.
///
/// Invocation forms:
///
/// ```ignore
/// dequant_op_flow!(ty Self, float_op |a, b| B::float_add(a, b), lhs, rhs) // binary
/// dequant_op_flow!(ty Self, float_op |t| B::float_exp(t), tensor)        // unary
/// ```
///
/// The expansion evaluates to `TensorPrimitive::QFloat` (re-quantized with the input's scheme)
/// when the input's propagation is `QuantPropagation::Propagate`, and to
/// `TensorPrimitive::Float` when it is `QuantPropagation::Inhibit`. For binary ops both the
/// scheme and the propagation come from the left-hand tensor.
///
/// `ty` must provide `dequantize` and `quantize_dynamic`. The expansion names `TensorPrimitive`
/// and `QuantPropagation` unqualified, so both must be in scope at the call site. Each tensor
/// argument is evaluated exactly once.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the operand once: expanding `$t1` several times would evaluate it several times.
        let lhs = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();
        let propagation = lhs.propagation();

        let lhs_f = <$ty>::dequantize(lhs);
        let rhs_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once so an expression argument is evaluated a single time.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();
        let propagation = tensor.propagation();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
93
94/// Operations on quantized tensors.
95///
96/// # Return Type Semantics
97///
98/// The return type of each operation indicates how quantization is handled:
99///
100/// ## [`QuantizedTensor<B>`]
101/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
102/// representation. Implementations should avoid dequantizing when possible to maintain performance.
103/// For example, shape or layout changes such as expand or transpose preserve quantization.
104///
105/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
106/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
107/// the quantization parameters to match the new layout.*
108///
109///
110/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme: the output should either remain
/// quantized ([`TensorPrimitive::QFloat`]) or be returned in floating-point form
/// ([`TensorPrimitive::Float`]).
114///
115/// This distinction allows for fine-grained control over mixed-precision flows while still operating
116/// through a unified API.
117pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor with the given data, created on `device`.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
129
    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    ///
    /// Unlike `quantize_dynamic`, the quantization parameters are supplied by the caller
    /// instead of being computed from the tensor's value range.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The floating-point tensor to quantize.
    /// * `scheme` - The quantization scheme to apply.
    /// * `qparams` - The quantization parameters to use.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
136
137    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
138    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
139        // Dynamically compute min/max tensor range and qparams before quantizing
140        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
141        let qparams = compute_q_params(scheme, min, max);
142        Self::quantize(tensor, scheme, qparams)
143    }
144
    /// Convert the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    ///
    /// # Returns
    ///
    /// The floating-point tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
147
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
158
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device, still in quantized form.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
170
    /// Reshapes a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape, still in quantized form.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
182
    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data, or to an
    /// `ExecutionError` if the read fails.
    fn q_into_data(
        tensor: QuantizedTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
195
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns `tensor` unchanged.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
201
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores the flag and returns `tensor` unchanged.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
207
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always returns `false`.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
213
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcast tensor, still in quantized form.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
216
217    /// Transposes a tensor.
218    ///
219    /// # Arguments
220    ///
221    /// * `tensor` - The tensor to transpose.
222    ///
223    /// # Returns
224    ///
225    /// The transposed tensor.
226    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
227        let ndims = tensor.shape().num_dims();
228        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
229    }
230
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped, still in quantized form.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
243
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
254
    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
264
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
281
    /// Select tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor, still in quantized form.
    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
293
294    /// Gather elements from a tensor.
295    ///
296    /// # Arguments
297    ///
298    /// * `dim` - The dimension to gather from.
299    /// * `tensor` - The tensor to gather from.
300    /// * `indices` - The indices to gather.
301    ///
302    /// # Returns
303    ///
304    /// The gathered elements.
305    fn q_gather(
306        dim: usize,
307        tensor: QuantizedTensor<B>,
308        indices: IntTensor<B>,
309    ) -> QuantizedTensor<B> {
310        // Default implementation. Backends can gather on the quantized values when supported.
311        dequant_op_quant!(
312            ty Self,
313            float_op |tensor| B::float_gather(dim, tensor, indices),
314            tensor
315        )
316    }
317
318    /// Repeat the tensor along the given dimension.
319    ///
320    /// # Arguments
321    ///
322    /// * `tensor` - The tensor.
323    /// * `dim` - The dimension to repeat.
324    /// * `times` - The number of times to repeat the dimension.
325    ///
326    /// # Returns
327    ///
328    /// The tensor with the given dimension repeated.
329    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
330        dequant_op_quant!(
331            ty Self,
332            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
333            tensor
334        )
335    }
336
337    /// Adds two tensors together.
338    ///
339    /// # Arguments
340    ///
341    /// * `lhs` - The left hand side tensor.
342    /// * `rhs` - The right hand side tensor.
343    ///
344    /// # Returns
345    ///
346    /// The result of adding the two tensors together.
347    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
348        dequant_op_flow!(
349            ty Self,
350            float_op |lhs, rhs| B::float_add(lhs, rhs),
351            lhs,
352            rhs
353        )
354    }
355
356    /// Adds a scalar to a tensor.
357    ///
358    /// # Arguments
359    ///
360    /// * `lhs` - The left hand side tensor.
361    /// * `rhs` - The right hand side scalar.
362    ///
363    /// # Returns
364    ///
365    /// The result of adding the scalar to the tensor.
366    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
367        dequant_op_flow!(
368            ty Self,
369            float_op |tensor| B::float_add_scalar(tensor, rhs),
370            lhs
371        )
372    }
373
374    /// Clamps a tensor under a minimum value.
375    ///
376    /// # Arguments
377    ///
378    /// * `tensor` - The tensor to clamp.
379    /// * `min` - The minimum value.
380    ///
381    /// # Returns
382    ///
383    /// The clamped tensor.
384    fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
385        dequant_op_flow!(
386            ty Self,
387            float_op |tensor| B::float_clamp_min(tensor, min),
388            tensor
389        )
390    }
391
392    /// Clamps a tensor over a maximum value.
393    ///
394    /// # Arguments
395    ///
396    /// * `tensor` - The tensor to clamp.
397    /// * `max` - The maximum value.
398    ///
399    /// # Returns
400    ///
401    /// The clamped tensor.
402    fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
403        dequant_op_flow!(
404            ty Self,
405            float_op |tensor| B::float_clamp_max(tensor, max),
406            tensor
407        )
408    }
409
410    /// Clamps a tensor between a minimum and maximum value.
411    ///
412    /// # Arguments
413    ///
414    /// * `tensor` - The tensor to clamp.
415    /// * `min` - The minimum value.
416    /// * `max` - The maximum value.
417    ///
418    /// # Returns
419    ///
420    /// The clamped tensor.
421    fn q_clamp(
422        tensor: QuantizedTensor<B>,
423        min: FloatElem<B>,
424        max: FloatElem<B>,
425    ) -> TensorPrimitive<B> {
426        dequant_op_flow!(
427            ty Self,
428            float_op |tensor| B::float_clamp(tensor, min, max),
429            tensor
430        )
431    }
432
433    /// Subtracts two tensors.
434    ///
435    /// # Arguments
436    ///
437    /// * `lhs` - The left hand side tensor.
438    /// * `rhs` - The right hand side tensor.
439    ///
440    /// # Returns
441    ///
442    /// The result of subtracting the two tensors.
443    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
444        dequant_op_flow!(
445            ty Self,
446            float_op |lhs, rhs| B::float_sub(lhs, rhs),
447            lhs,
448            rhs
449        )
450    }
451
452    /// Subtracts a scalar from a tensor.
453    ///
454    /// # Arguments
455    ///
456    /// * `lhs` - The left hand side tensor.
457    /// * `rhs` - The right hand side scalar.
458    ///
459    /// # Returns
460    ///
461    /// The result of subtracting the scalar from the tensor.
462    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
463        dequant_op_flow!(
464            ty Self,
465            float_op |tensor| B::float_sub_scalar(tensor, rhs),
466            lhs
467        )
468    }
469
470    /// Multiplies two tensors together element-wise.
471    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
472        dequant_op_flow!(
473            ty Self,
474            float_op |lhs, rhs| B::float_mul(lhs, rhs),
475            lhs,
476            rhs
477        )
478    }
479
480    /// Multiplies a tensor by a scalar.
481    ///
482    /// # Arguments
483    ///
484    /// * `lhs` - The left hand side tensor.
485    /// * `rhs` - The right hand side scalar.
486    ///
487    /// # Returns
488    ///
489    /// The result of multiplying the tensor by the scalar.
490    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
491        dequant_op_flow!(
492            ty Self,
493            float_op |tensor| B::float_mul_scalar(tensor, rhs),
494            lhs
495        )
496    }
497
498    /// Divides two tensors element-wise.
499    ///
500    /// # Arguments
501    ///
502    /// * `lhs` - The left hand side tensor.
503    /// * `rhs` - The right hand side tensor.
504    ///
505    /// # Returns
506    ///
507    /// The result of dividing the two tensors.
508    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
509        dequant_op_flow!(
510            ty Self,
511            float_op |lhs, rhs| B::float_div(lhs, rhs),
512            lhs,
513            rhs
514        )
515    }
516
517    /// Divides a tensor by a scalar.
518    ///
519    /// # Arguments
520    ///
521    /// * `lhs` - The left hand side tensor.
522    /// * `rhs` - The right hand side scalar.
523    ///
524    /// # Returns
525    ///
526    /// The result of dividing the tensor by the scalar.
527    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
528        dequant_op_flow!(
529            ty Self,
530            float_op |tensor| B::float_div_scalar(tensor, rhs),
531            lhs
532        )
533    }
534
535    /// Multiplies two tensors together using matrix multiplication.
536    ///
537    /// # Arguments
538    ///
539    /// * `lhs` - The left hand side tensor.
540    /// * `rhs` - The right hand side tensor.
541    ///
542    /// # Returns
543    ///
544    /// The result of multiplying the two tensors together using matrix multiplication.
545    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
546        let mut propagation = QuantPropagation::Inhibit;
547        let mut scheme = QuantScheme::default();
548        let lhs = match lhs {
549            TensorPrimitive::Float(lhs) => lhs,
550            TensorPrimitive::QFloat(lhs) => {
551                propagation = lhs.propagation();
552                scheme = *lhs.scheme();
553                Self::dequantize(lhs)
554            }
555        };
556        let rhs = match rhs {
557            TensorPrimitive::Float(rhs) => rhs,
558            TensorPrimitive::QFloat(rhs) => {
559                propagation = rhs.propagation();
560                scheme = *rhs.scheme();
561                Self::dequantize(rhs)
562            }
563        };
564
565        let out_f = B::float_matmul(lhs, rhs);
566        match propagation {
567            QuantPropagation::Propagate => {
568                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
569            }
570            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
571        }
572    }
573
574    /// Negates a tensor element-wise.
575    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
576        dequant_op_flow!(
577            ty Self,
578            float_op |tensor| B::float_neg(tensor),
579            tensor
580        )
581    }
582
583    /// Calculates the reciprocals element-wise
584    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
585        dequant_op_flow!(
586            ty Self,
587            float_op |tensor| B::float_recip(tensor),
588            tensor
589        )
590    }
591
592    /// Sum of all elements in a tensor.
593    ///
594    /// # Arguments
595    ///
596    /// * `tensor` - The tensor to sum.
597    ///
598    /// # Returns
599    ///
600    /// A scalar tensor with the sum of all elements in `tensor`.
601    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
602        dequant_op_flow!(
603            ty Self,
604            float_op |tensor| B::float_sum(tensor),
605            tensor
606        )
607    }
608
609    /// Sum of all elements in a tensor along a dimension.
610    ///
611    /// # Arguments
612    ///
613    /// * `tensor` - The tensor to sum.
614    /// * `dim` - The dimension along which to sum.
615    ///
616    /// # Returns
617    ///
618    /// A tensor with the sum of all elements in `tensor` along `dim`.
619    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
620        dequant_op_flow!(
621            ty Self,
622            float_op |tensor| B::float_sum_dim(tensor, dim),
623            tensor
624        )
625    }
626
627    /// Product of all elements in a tensor.
628    ///
629    /// # Arguments
630    ///
631    /// * `tensor` - The tensor to product.
632    ///
633    /// # Returns
634    ///
635    /// A scalar tensor with the product of all elements in `tensor`.
636    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
637        dequant_op_flow!(
638            ty Self,
639            float_op |tensor| B::float_prod(tensor),
640            tensor
641        )
642    }
643
644    /// Product of all elements in a tensor along a dimension.
645    ///
646    /// # Arguments
647    ///
648    /// * `tensor` - The tensor to product.
649    ///
650    /// # Returns
651    ///
652    /// A tensor with the product of all elements in `tensor` along `dim`.
653    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
654        dequant_op_flow!(
655            ty Self,
656            float_op |tensor| B::float_prod_dim(tensor, dim),
657            tensor
658        )
659    }
660
661    /// Mean of all elements in a tensor.
662    ///
663    /// # Arguments
664    ///
665    /// * `tensor` - The tensor to mean.
666    ///
667    /// # Returns
668    ///
669    /// A scalar tensor with the mean of all elements in `tensor`.
670    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
671        dequant_op_flow!(
672            ty Self,
673            float_op |tensor| B::float_mean(tensor),
674            tensor
675        )
676    }
677
678    /// Mean of all elements in a tensor along a dimension.
679    ///
680    /// # Arguments
681    ///
682    /// * `tensor` - The tensor to mean.
683    /// * `dim` - The dimension along which to mean.
684    ///
685    /// # Returns
686    ///
687    /// A tensor with the mean of all elements in `tensor` along `dim`.
688    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
689        dequant_op_flow!(
690            ty Self,
691            float_op |tensor| B::float_mean_dim(tensor, dim),
692            tensor
693        )
694    }
695
696    /// Computes the cumulative sum of elements along a dimension.
697    ///
698    /// # Arguments
699    ///
700    /// * `tensor` - The tensor to compute the cumulative sum of.
701    /// * `dim` - The dimension along which to compute the cumulative sum.
702    ///
703    /// # Returns
704    ///
705    /// A tensor with the same shape where each element is the cumulative sum
706    /// of all elements up to and including that position along the dimension.
707    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
708        dequant_op_flow!(
709            ty Self,
710            float_op |tensor| B::float_cumsum(tensor, dim),
711            tensor
712        )
713    }
714
715    /// Computes the cumulative product of elements along a dimension.
716    ///
717    /// # Arguments
718    ///
719    /// * `tensor` - The tensor to compute the cumulative product of.
720    /// * `dim` - The dimension along which to compute the cumulative product.
721    ///
722    /// # Returns
723    ///
724    /// A tensor with the same shape where each element is the cumulative product
725    /// of all elements up to and including that position along the dimension.
726    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
727        dequant_op_flow!(
728            ty Self,
729            float_op |tensor| B::float_cumprod(tensor, dim),
730            tensor
731        )
732    }
733
734    /// Computes the cumulative minimum of elements along a dimension.
735    ///
736    /// # Arguments
737    ///
738    /// * `tensor` - The tensor to compute the cumulative minimum of.
739    /// * `dim` - The dimension along which to compute the cumulative minimum.
740    ///
741    /// # Returns
742    ///
743    /// A tensor with the same shape where each element is the minimum
744    /// of all elements up to and including that position along the dimension.
745    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
746        dequant_op_flow!(
747            ty Self,
748            float_op |tensor| B::float_cummin(tensor, dim),
749            tensor
750        )
751    }
752
753    /// Computes the cumulative maximum of elements along a dimension.
754    ///
755    /// # Arguments
756    ///
757    /// * `tensor` - The tensor to compute the cumulative maximum of.
758    /// * `dim` - The dimension along which to compute the cumulative maximum.
759    ///
760    /// # Returns
761    ///
762    /// A tensor with the same shape where each element is the maximum
763    /// of all elements up to and including that position along the dimension.
764    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
765        dequant_op_flow!(
766            ty Self,
767            float_op |tensor| B::float_cummax(tensor, dim),
768            tensor
769        )
770    }
771
772    /// Returns a new tensor with exponential values.
773    ///
774    /// # Arguments
775    ///
776    /// * `tensor` - The tensor to exponentiate.
777    ///
778    /// # Returns
779    ///
780    /// A tensor with the same shape as `tensor` with exponential values.
781    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
782        dequant_op_flow!(
783            ty Self,
784            float_op |tensor| B::float_exp(tensor),
785            tensor
786        )
787    }
788
789    /// Returns a new tensor with natural logarithm values.
790    ///
791    /// # Arguments
792    ///
793    /// * `tensor` - The tensor to take the logarithm of.
794    ///
795    /// # Returns
796    ///
797    /// A tensor with the same shape as `tensor` with natural logarithm values.
798    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
799        dequant_op_flow!(
800            ty Self,
801            float_op |tensor| B::float_log(tensor),
802            tensor
803        )
804    }
805
806    /// Returns a new tensor with logarithm values of (1 + Xi).
807    ///
808    /// # Arguments
809    ///
810    /// * `tensor` - The tensor to take the logarithm of.
811    ///
812    /// # Returns
813    ///
814    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
815    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
816        dequant_op_flow!(
817            ty Self,
818            float_op |tensor| B::float_log1p(tensor),
819            tensor
820        )
821    }
822
823    /// Element-wise power with another tensor.
824    ///
825    /// # Arguments
826    ///
827    /// * `lhs` - The left hand side tensor.
828    /// * `rhs` - The right hand side tensor.
829    ///
830    /// # Returns
831    ///
832    /// The elements of `lhs` raised to the power of the elements of `rhs`.
833    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
834        dequant_op_flow!(
835            ty Self,
836            float_op |lhs, rhs| B::float_powf(lhs, rhs),
837            lhs,
838            rhs
839        )
840    }
841
842    /// Element-wise power with an IntTensor.
843    ///
844    /// # Arguments
845    ///
846    /// * `lhs` - The left hand side tensor.
847    /// * `rhs` - The right hand side floatTensor.
848    ///
849    /// # Returns
850    ///
851    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
852    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
853        dequant_op_flow!(
854            ty Self,
855            float_op |tensor| B::float_powi(tensor, rhs),
856            lhs
857        )
858    }
859
860    /// Element-wise power with an int scalar.
861    ///
862    /// # Arguments
863    ///
864    /// * `lhs` - The left hand side tensor.
865    /// * `rhs` - The right hand side scalar.
866    ///
867    /// # Returns
868    ///
869    /// The elements of `lhs` raised to the value of `rhs`.
870    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
871        dequant_op_flow!(
872            ty Self,
873            float_op |tensor| B::float_powi_scalar(tensor, rhs),
874            lhs
875        )
876    }
877
878    /// Element-wise power with a float scalar.
879    ///
880    /// # Arguments
881    ///
882    /// * `tensor` - The tensor to exponentiate.
883    /// * `value` - The exponent.
884    ///
885    /// # Returns
886    ///
887    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
888    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
889        dequant_op_flow!(
890            ty Self,
891            float_op |tensor| B::float_powf_scalar(tensor, value),
892            tensor
893        )
894    }
895
896    /// Returns a new tensor with square root values.
897    ///
898    /// # Arguments
899    ///
900    /// * `tensor` - The tensor to take the square root of.
901    ///
902    /// # Returns
903    ///
904    /// A tensor with the same shape as `tensor` with square root values.
905    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
906        dequant_op_flow!(
907            ty Self,
908            float_op |tensor| B::float_sqrt(tensor),
909            tensor
910        )
911    }
912
913    /// Returns a new tensor with absolute values.
914    ///
915    /// # Arguments
916    ///
917    /// * `tensor` - The tensor to take absolute value of.
918    ///
919    /// # Returns
920    ///
921    /// A tensor with the same shape as `tensor` with absolute values.
922    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
923        dequant_op_quant!(
924            ty Self,
925            float_op |tensor| B::float_abs(tensor),
926            tensor
927        )
928    }
929
930    /// Returns a new tensor with cosine values.
931    ///
932    /// # Arguments
933    ///
934    /// * `tensor` - The tensor to take the cosine of.
935    ///
936    /// # Returns
937    ///
938    /// A tensor with the same shape as `tensor` with cosine values.
939    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
940        dequant_op_flow!(
941            ty Self,
942            float_op |tensor| B::float_cos(tensor),
943            tensor
944        )
945    }
946
947    /// Returns a new tensor with sine values.
948    ///
949    /// # Arguments
950    ///
951    /// * `tensor` - The tensor to take the sine of.
952    ///
953    /// # Returns
954    ///
955    /// A tensor with the same shape as `tensor` with sine values.
956    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
957        dequant_op_flow!(
958            ty Self,
959            float_op |tensor| B::float_sin(tensor),
960            tensor
961        )
962    }
963
964    /// Returns a new tensor with tangent values.
965    ///
966    /// # Arguments
967    ///
968    /// * `tensor` - The tensor to take the tangent of.
969    ///
970    /// # Returns
971    ///
972    /// A tensor with the same shape as `tensor` with tangent values.
973    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
974        dequant_op_flow!(
975            ty Self,
976            float_op |tensor| B::float_tan(tensor),
977            tensor
978        )
979    }
980
981    /// Returns a new tensor with hyperbolic cosine values.
982    ///
983    /// # Arguments
984    ///
985    /// * `tensor` - The tensor to take the hyperbolic cosine of.
986    ///
987    /// # Returns
988    ///
989    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
990    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
991        dequant_op_flow!(
992            ty Self,
993            float_op |tensor| B::float_cosh(tensor),
994            tensor
995        )
996    }
997
998    /// Returns a new tensor with hyperbolic sine values.
999    ///
1000    /// # Arguments
1001    ///
1002    /// * `tensor` - The tensor to take the hyperbolic sine of.
1003    ///
1004    /// # Returns
1005    ///
1006    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1007    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1008        dequant_op_flow!(
1009            ty Self,
1010            float_op |tensor| B::float_sinh(tensor),
1011            tensor
1012        )
1013    }
1014
1015    /// Returns a new tensor with hyperbolic tangent values.
1016    ///
1017    /// # Arguments
1018    ///
1019    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1020    ///
1021    /// # Returns
1022    ///
1023    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1024    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1025        dequant_op_flow!(
1026            ty Self,
1027            float_op |tensor| B::float_tanh(tensor),
1028            tensor
1029        )
1030    }
1031
1032    /// Returns a new tensor with the error function values.
1033    ///
1034    /// # Arguments
1035    ///
1036    /// * `tensor` - The tensor to take the error function of.
1037    ///
1038    /// # Returns
1039    ///
1040    /// A tensor with the same shape as `tensor` with error function values.
1041    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1042        dequant_op_flow!(
1043            ty Self,
1044            float_op |tensor| B::float_erf(tensor),
1045            tensor
1046        )
1047    }
1048
1049    /// Concatenates tensors along a dimension.
1050    ///
1051    /// # Arguments
1052    ///
1053    /// * `tensors` - The tensors to concatenate.
1054    /// * `dim` - The dimension along which to concatenate.
1055    ///
1056    /// # Returns
1057    ///
1058    /// A tensor with the concatenated tensors along `dim`.
1059    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1060        // Heuristic: prioritize first tensor scheme
1061        let scheme = *tensors.first().unwrap().scheme();
1062
1063        let tensor_f = tensors
1064            .into_iter()
1065            .map(|tensor| Self::dequantize(tensor))
1066            .collect();
1067
1068        let out_f = B::float_cat(tensor_f, dim);
1069
1070        Self::quantize_dynamic(out_f, &scheme)
1071    }
1072
1073    /// Gets the indices of the maximum elements of a tensor along an axis.
1074    ///
1075    /// # Arguments
1076    ///
1077    /// * `tensor` - The tensor to get the maximum elements of.
1078    /// * `dim` - The dimension along which to get the maximum elements.
1079    ///
1080    /// # Returns
1081    ///
1082    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1083    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1084        // Default implementation. Backends can sort on the int values since qparams remain the same.
1085        let tensor_f = Self::dequantize(tensor);
1086        B::float_argmax(tensor_f, dim)
1087    }
1088
1089    /// Gets the indices of the minimum elements of a tensor along an axis.
1090    ///
1091    /// # Arguments
1092    ///
1093    /// * `tensor` - The tensor to get the minimum elements of.
1094    /// * `dim` - The dimension along which to get the minimum elements.
1095    ///
1096    /// # Returns
1097    ///
1098    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1099    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1100        // Default implementation. Backends can sort on the int values since qparams remain the same.
1101        let tensor_f = Self::dequantize(tensor);
1102        B::float_argmin(tensor_f, dim)
1103    }
1104
1105    /// Gets the maximum element of a tensor.
1106    ///
1107    /// # Arguments
1108    ///
1109    /// * `tensor` - The tensor to get the maximum elements of.
1110    ///
1111    /// # Returns
1112    ///
1113    /// A tensor with the maximum element of `tensor`.
1114    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1115        let shape = tensor.shape();
1116        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1117
1118        B::q_max_dim(tensor, 0)
1119    }
1120
1121    /// Gets the maximum elements of a tensor along an axis.
1122    ///
1123    /// # Arguments
1124    ///
1125    /// * `tensor` - The tensor to get the maximum elements of.
1126    /// * `dim` - The dimension along which to get the maximum elements.
1127    ///
1128    /// # Returns
1129    ///
1130    /// A tensor with the maximum elements of `tensor` along `dim`.
1131    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1132        let index = B::q_argmax(tensor.clone(), dim);
1133
1134        B::q_gather(dim, tensor, index)
1135    }
1136
1137    /// Gets the maximum elements of a tensor along an axis and their indices.
1138    ///
1139    /// # Arguments
1140    ///
1141    /// * `tensor` - The tensor to get the maximum elements of.
1142    /// * `dim` - The dimension along which to get the maximum elements.
1143    ///
1144    /// # Returns
1145    ///
1146    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1147    fn q_max_dim_with_indices(
1148        tensor: QuantizedTensor<B>,
1149        dim: usize,
1150    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1151        let index = B::q_argmax(tensor.clone(), dim);
1152        let values = B::q_gather(dim, tensor, index.clone());
1153
1154        (values, index)
1155    }
1156
1157    /// Gets the minimum element of a tensor.
1158    ///
1159    /// # Arguments
1160    ///
1161    /// * `tensor` - The tensor to get the minimum elements of.
1162    ///
1163    /// # Returns
1164    ///
1165    /// A tensor with the minimum element of `tensor`.
1166    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1167        let shape = tensor.shape();
1168        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1169
1170        B::q_min_dim(tensor, 0)
1171    }
1172
1173    /// Gets the minimum elements of a tensor along an axis.
1174    ///
1175    /// # Arguments
1176    ///
1177    /// * `tensor` - The tensor to get the minimum elements of.
1178    /// * `dim` - The dimension along which to get the minimum elements.
1179    ///
1180    /// # Returns
1181    ///
1182    /// A tensor with the minimum elements of `tensor` along `dim`.
1183    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1184        let index = B::q_argmin(tensor.clone(), dim);
1185
1186        B::q_gather(dim, tensor, index)
1187    }
1188
1189    /// Gets the minimum elements of a tensor along an axis and their indices.
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `tensor` - The tensor to get the minimum elements of.
1194    /// * `dim` - The dimension along which to get the minimum elements.
1195    ///
1196    /// # Returns
1197    ///
1198    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1199    fn q_min_dim_with_indices(
1200        tensor: QuantizedTensor<B>,
1201        dim: usize,
1202    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1203        let index = B::q_argmin(tensor.clone(), dim);
1204        let values = B::q_gather(dim, tensor, index.clone());
1205
1206        (values, index)
1207    }
1208
1209    /// Gets the maximum element of a tensor.
1210    ///
1211    /// # Arguments
1212    ///
1213    /// * `tensor` - The tensor to get the maximum elements of.
1214    ///
1215    /// # Returns
1216    ///
1217    /// A tensor with the maximum element of `tensor`.
1218    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1219        let shape = tensor.shape();
1220        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1221
1222        B::q_max_abs_dim(tensor, 0)
1223    }
1224
1225    /// Gets the maximum elements of a tensor along an axis.
1226    ///
1227    /// # Arguments
1228    ///
1229    /// * `tensor` - The tensor to get the maximum elements of.
1230    /// * `dim` - The dimension along which to get the maximum elements.
1231    ///
1232    /// # Returns
1233    ///
1234    /// A tensor with the maximum elements of `tensor` along `dim`.
1235    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1236        let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1237
1238        B::q_gather(dim, tensor, index)
1239    }
1240
1241    /// Tests if any element in the `tensor` evaluates to True.
1242    ///
1243    /// # Arguments
1244    ///
1245    /// * `tensor` - The tensor to test.
1246    ///
1247    /// # Returns
1248    ///
1249    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1250    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1251        let tensor_f = Self::dequantize(tensor);
1252        B::float_any(tensor_f)
1253    }
1254
1255    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1256    ///
1257    /// # Arguments
1258    ///
1259    /// * `tensor` - The tensor to test.
1260    /// * `dim` - The axis along which to test.
1261    ///
1262    /// # Returns
1263    ///
1264    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1265    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1266    /// input evaluates to True, False otherwise.
1267    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1268        let tensor_f = Self::dequantize(tensor);
1269        B::float_any_dim(tensor_f, dim)
1270    }
1271
1272    /// Tests if all elements in the `tensor` evaluate to True.
1273    ///
1274    /// # Arguments
1275    ///
1276    /// * `tensor` - The tensor to test.
1277    ///
1278    /// # Returns
1279    ///
1280    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1281    /// evaluate to True, False otherwise.
1282    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1283        let tensor_f = Self::dequantize(tensor);
1284        B::float_all(tensor_f)
1285    }
1286
1287    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1288    ///
1289    /// # Arguments
1290    ///
1291    /// * `tensor` - The tensor to test.
1292    /// * `dim` - The axis along which to test.
1293    ///
1294    /// # Returns
1295    ///
1296    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1297    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1298    /// evaluates to True, False otherwise.
1299    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1300        let tensor_f = Self::dequantize(tensor);
1301        B::float_all_dim(tensor_f, dim)
1302    }
1303
1304    /// Sort the elements of the input `tensor` by value in along a given dimension.
1305    ///
1306    /// This sort is unstable (i.e., may reorder equal elements).
1307    ///
1308    /// # Arguments
1309    ///
1310    /// * `tensor` - The input tensor.
1311    /// * `dim` - The axis along which to sort.
1312    /// * `descending` - The sorting order.
1313    ///
1314    /// # Returns
1315    ///
1316    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1317    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1318        // Default implementation. Backends can sort on the int values since qparams remain the same.
1319        dequant_op_quant!(
1320            ty Self,
1321            float_op |tensor| B::float_sort(tensor, dim, descending),
1322            tensor
1323        )
1324    }
1325
1326    /// Sort the elements of the input `tensor` by value in along a given dimension.
1327    ///
1328    /// This sort is unstable (i.e., may reorder equal elements).
1329    ///
1330    /// # Arguments
1331    ///
1332    /// * `tensor` - The input tensor.
1333    /// * `dim` - The axis along which to sort.
1334    /// * `descending` - The sorting order.
1335    ///
1336    /// # Returns
1337    ///
1338    /// A tensor with the same shape as the input tensor and corresponding indices, where
1339    /// the elements are sorted by value and the indices map back to the original input tensor.
1340    fn q_sort_with_indices(
1341        tensor: QuantizedTensor<B>,
1342        dim: usize,
1343        descending: bool,
1344    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1345        // Default implementation. Backends can sort on the int values since qparams remain the same.
1346        let scheme = *tensor.scheme();
1347
1348        let tensor_f = Self::dequantize(tensor);
1349        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1350
1351        (Self::quantize_dynamic(out_f, &scheme), indices)
1352    }
1353
1354    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1355    ///
1356    /// This sort is unstable (i.e., may reorder equal elements).
1357    ///
1358    /// # Arguments
1359    ///
1360    /// * `tensor` - The input tensor.
1361    /// * `dim` - The axis along which to sort.
1362    /// * `descending` - The sorting order.
1363    ///
1364    /// # Returns
1365    ///
1366    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1367    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1368        // Default implementation. Backends can sort on the int values since qparams remain the same.
1369        let tensor_f = Self::dequantize(tensor);
1370        B::float_argsort(tensor_f, dim, descending)
1371    }
1372}