burn_backend/backend/ops/
qtensor.rs

1use alloc::vec::Vec;
2use burn_std::{
3    Shape, Slice,
4    quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9    tensor::{Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range},
10};
11
12use crate::tensor::{
13    BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
14};
15
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// Two forms are accepted:
///
/// * binary: `dequant_op_quant!(ty Self, float_op |lhs, rhs| ..., lhs, rhs)` — the output is
///   re-quantized with the scheme of the lhs operand;
/// * unary: `dequant_op_quant!(ty Self, float_op |tensor| ..., tensor)` — the output is
///   re-quantized with the scheme of the input.
///
/// `ty` must provide `dequantize` and `quantize_dynamic` associated functions, and the operands
/// must expose a `scheme()` method. Each operand expression is evaluated exactly once, so passing
/// a call expression (rather than a plain variable) is safe.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the lhs once: it is needed both for its scheme and for dequantization, and
        // expanding `$t1` twice would evaluate a caller's expression twice.
        let lhs_input = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs_input.scheme().clone();

        let t1_f = <$ty>::dequantize(lhs_input);
        let t2_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the input once: it is needed both for its scheme and for dequantization.
        let input = $tensor;
        let scheme = input.scheme().clone();

        let tensor_f = <$ty>::dequantize(input);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
48
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated.
///
/// The settings are read from the lhs operand for binary ops, and from the single input for
/// unary ops:
///
/// * with `QuantPropagation::Propagate`, the float result is re-quantized with that operand's
///   scheme and returned as `TensorPrimitive::QFloat`;
/// * with `QuantPropagation::Inhibit`, the float result is returned as `TensorPrimitive::Float`.
///
/// `QuantPropagation` and `TensorPrimitive` are referenced unqualified, so both must be in scope
/// at the call site. Each operand expression is evaluated exactly once.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the lhs once: it provides the scheme and propagation and is then dequantized.
        // Expanding `$t1` three times would evaluate a caller's expression three times.
        let lhs_input = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs_input.scheme().clone();
        let propagation = lhs_input.propagation();

        let t1_f = <$ty>::dequantize(lhs_input);
        let t2_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the input once: it provides the scheme and propagation and is then dequantized.
        let input = $tensor;
        let scheme = input.scheme().clone();
        let propagation = input.propagation();

        let tensor_f = <$ty>::dequantize(input);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
94
95/// Operations on quantized tensors.
96///
97/// # Return Type Semantics
98///
99/// The return type of each operation indicates how quantization is handled:
100///
101/// ## [`QuantizedTensor<B>`]
102/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
103/// representation. Implementations should avoid dequantizing when possible to maintain performance.
104/// For example, shape or layout changes such as expand or transpose preserve quantization.
105///
106/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
107/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
108/// the quantization parameters to match the new layout.*
///
/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme: the output should either remain
/// quantized ([`TensorPrimitive::QFloat`]) or be returned in floating-point form
/// ([`TensorPrimitive::Float`]).
115///
116/// This distinction allows for fine-grained control over mixed-precision flows while still operating
117/// through a unified API.
118pub trait QTensorOps<B: Backend> {
119    /// Creates a new tensor from the data structure.
120    ///
121    /// # Arguments
122    ///
123    /// * `data` - The data structure.
124    /// * `device` - The device to create the tensor on.
125    ///
126    /// # Returns
127    ///
128    /// The tensor with the given data.
129    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
130
131    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
132    fn quantize(
133        tensor: FloatTensor<B>,
134        scheme: &QuantScheme,
135        qparams: QuantizationParametersPrimitive<B>,
136    ) -> QuantizedTensor<B>;
137
138    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
139    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
140        // Dynamically compute min/max tensor range and qparams before quantizing
141        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
142        let qparams = compute_q_params(scheme, min, max);
143        Self::quantize(tensor, scheme, qparams)
144    }
145
146    /// Convert the tensor back to a higher precision data type.
147    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
148
149    /// Gets the device of the tensor.
150    ///
151    /// # Arguments
152    ///
153    /// * `tensor` - The tensor.
154    ///
155    /// # Returns
156    ///
157    /// The device of the tensor.
158    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
159
160    /// Moves the tensor to the given device.
161    ///
162    /// # Arguments
163    ///
164    /// * `tensor` - The tensor.
165    /// * `device` - The device to move the tensor to.
166    ///
167    /// # Returns
168    ///
169    /// The tensor on the given device.
170    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
171
172    /// Reshapes a tensor.
173    ///
174    /// # Arguments
175    ///
176    /// * `tensor` - The tensor to reshape.
177    /// * `shape` - The new shape of the tensor.
178    ///
179    /// # Returns
180    ///
181    /// The tensor with the new shape.
182    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
183
184    /// Converts the tensor to a data structure.
185    ///
186    /// # Arguments
187    ///
188    /// * `tensor` - The tensor.
189    ///
190    /// # Returns
191    ///
192    /// The data structure with the tensor's data.
193    fn q_into_data(
194        tensor: QuantizedTensor<B>,
195    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
196
197    /// Detaches a tensor from the computation graph.
198    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
199        // Should only be overridden by autodiff backends.
200        tensor
201    }
202
203    /// Sets the `require_grad` flag of a tensor.
204    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
205        // Should only be overridden by autodiff backends.
206        tensor
207    }
208
209    /// Returns the `require_grad` flag of a tensor.
210    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
211        // Should only be overridden by autodiff backends.
212        false
213    }
214
215    /// Broadcasts the `tensor` to the given `shape`.
216    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
217
218    /// Transposes a tensor.
219    ///
220    /// # Arguments
221    ///
222    /// * `tensor` - The tensor to transpose.
223    ///
224    /// # Returns
225    ///
226    /// The transposed tensor.
227    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
228        let ndims = tensor.shape().num_dims();
229        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
230    }
231
232    /// Swaps two dimensions of a tensor.
233    ///
234    /// # Arguments
235    ///
236    /// * `tensor` - The tensor to swap the dimensions of.
237    /// * `dim1` - The first dimension to swap.
238    /// * `dim2` - The second dimension to swap.
239    ///
240    /// # Returns
241    ///
242    /// The tensor with the dimensions swapped.
243    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
244
245    /// Permutes the dimensions of a tensor.
246    ///
247    /// # Arguments
248    ///
249    /// * `tensor` - The tensor to permute the dimensions of.
250    /// * `axes` - The new order of the dimensions.
251    /// # Returns
252    ///
253    /// The tensor with the dimensions permuted.
254    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
255
256    /// Reverse the order of elements in a tensor along the given axes.
257    ///
258    /// # Arguments
259    ///
260    /// * `tensor` - The tensor to reverse.
261    /// * `axes` - The axes to reverse.
262    ///
263    /// The tensor with the elements reversed.
264    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
265
266    /// Select tensor elements along the given dimension corresponding for the given indices.
267    ///
268    /// # Arguments
269    ///
270    /// * `tensor` - The tensor to select from.
271    /// * `dim` - The dimension to select from.
272    /// * `indices` - The indices to select.
273    ///
274    /// # Returns
275    ///
276    /// The selected elements.
277    fn q_select(
278        tensor: QuantizedTensor<B>,
279        dim: usize,
280        indices: IntTensor<B>,
281    ) -> QuantizedTensor<B>;
282
283    /// Select tensor elements corresponding to the given slices.
284    ///
285    /// # Arguments
286    ///
287    /// * `tensor` - The tensor to select from.
288    /// * `slices` - The slices specifying ranges and steps for each dimension.
289    ///
290    /// # Returns
291    ///
292    /// The selected elements in a new tensor.
293    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
294
295    /// Gather elements from a tensor.
296    ///
297    /// # Arguments
298    ///
299    /// * `dim` - The dimension to gather from.
300    /// * `tensor` - The tensor to gather from.
301    /// * `indices` - The indices to gather.
302    ///
303    /// # Returns
304    ///
305    /// The gathered elements.
306    fn q_gather(
307        dim: usize,
308        tensor: QuantizedTensor<B>,
309        indices: IntTensor<B>,
310    ) -> QuantizedTensor<B> {
311        // Default implementation. Backends can gather on the quantized values when supported.
312        dequant_op_quant!(
313            ty Self,
314            float_op |tensor| B::float_gather(dim, tensor, indices),
315            tensor
316        )
317    }
318
319    /// Repeat the tensor along the given dimension.
320    ///
321    /// # Arguments
322    ///
323    /// * `tensor` - The tensor.
324    /// * `dim` - The dimension to repeat.
325    /// * `times` - The number of times to repeat the dimension.
326    ///
327    /// # Returns
328    ///
329    /// The tensor with the given dimension repeated.
330    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
331        dequant_op_quant!(
332            ty Self,
333            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
334            tensor
335        )
336    }
337
338    /// Adds two tensors together.
339    ///
340    /// # Arguments
341    ///
342    /// * `lhs` - The left hand side tensor.
343    /// * `rhs` - The right hand side tensor.
344    ///
345    /// # Returns
346    ///
347    /// The result of adding the two tensors together.
348    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
349        dequant_op_flow!(
350            ty Self,
351            float_op |lhs, rhs| B::float_add(lhs, rhs),
352            lhs,
353            rhs
354        )
355    }
356
357    /// Adds a scalar to a tensor.
358    ///
359    /// # Arguments
360    ///
361    /// * `lhs` - The left hand side tensor.
362    /// * `rhs` - The right hand side scalar.
363    ///
364    /// # Returns
365    ///
366    /// The result of adding the scalar to the tensor.
367    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
368        dequant_op_flow!(
369            ty Self,
370            float_op |tensor| B::float_add_scalar(tensor, rhs),
371            lhs
372        )
373    }
374
375    /// Clamps a tensor under a minimum value.
376    ///
377    /// # Arguments
378    ///
379    /// * `tensor` - The tensor to clamp.
380    /// * `min` - The minimum value.
381    ///
382    /// # Returns
383    ///
384    /// The clamped tensor.
385    fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
386        dequant_op_flow!(
387            ty Self,
388            float_op |tensor| B::float_clamp_min(tensor, min),
389            tensor
390        )
391    }
392
393    /// Clamps a tensor over a maximum value.
394    ///
395    /// # Arguments
396    ///
397    /// * `tensor` - The tensor to clamp.
398    /// * `max` - The maximum value.
399    ///
400    /// # Returns
401    ///
402    /// The clamped tensor.
403    fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
404        dequant_op_flow!(
405            ty Self,
406            float_op |tensor| B::float_clamp_max(tensor, max),
407            tensor
408        )
409    }
410
411    /// Clamps a tensor between a minimum and maximum value.
412    ///
413    /// # Arguments
414    ///
415    /// * `tensor` - The tensor to clamp.
416    /// * `min` - The minimum value.
417    /// * `max` - The maximum value.
418    ///
419    /// # Returns
420    ///
421    /// The clamped tensor.
422    fn q_clamp(
423        tensor: QuantizedTensor<B>,
424        min: FloatElem<B>,
425        max: FloatElem<B>,
426    ) -> TensorPrimitive<B> {
427        dequant_op_flow!(
428            ty Self,
429            float_op |tensor| B::float_clamp(tensor, min, max),
430            tensor
431        )
432    }
433
434    /// Subtracts two tensors.
435    ///
436    /// # Arguments
437    ///
438    /// * `lhs` - The left hand side tensor.
439    /// * `rhs` - The right hand side tensor.
440    ///
441    /// # Returns
442    ///
443    /// The result of subtracting the two tensors.
444    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
445        dequant_op_flow!(
446            ty Self,
447            float_op |lhs, rhs| B::float_sub(lhs, rhs),
448            lhs,
449            rhs
450        )
451    }
452
453    /// Subtracts a scalar from a tensor.
454    ///
455    /// # Arguments
456    ///
457    /// * `lhs` - The left hand side tensor.
458    /// * `rhs` - The right hand side scalar.
459    ///
460    /// # Returns
461    ///
462    /// The result of subtracting the scalar from the tensor.
463    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
464        dequant_op_flow!(
465            ty Self,
466            float_op |tensor| B::float_sub_scalar(tensor, rhs),
467            lhs
468        )
469    }
470
471    /// Multiplies two tensors together element-wise.
472    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
473        dequant_op_flow!(
474            ty Self,
475            float_op |lhs, rhs| B::float_mul(lhs, rhs),
476            lhs,
477            rhs
478        )
479    }
480
481    /// Multiplies a tensor by a scalar.
482    ///
483    /// # Arguments
484    ///
485    /// * `lhs` - The left hand side tensor.
486    /// * `rhs` - The right hand side scalar.
487    ///
488    /// # Returns
489    ///
490    /// The result of multiplying the tensor by the scalar.
491    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
492        dequant_op_flow!(
493            ty Self,
494            float_op |tensor| B::float_mul_scalar(tensor, rhs),
495            lhs
496        )
497    }
498
499    /// Divides two tensors element-wise.
500    ///
501    /// # Arguments
502    ///
503    /// * `lhs` - The left hand side tensor.
504    /// * `rhs` - The right hand side tensor.
505    ///
506    /// # Returns
507    ///
508    /// The result of dividing the two tensors.
509    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
510        dequant_op_flow!(
511            ty Self,
512            float_op |lhs, rhs| B::float_div(lhs, rhs),
513            lhs,
514            rhs
515        )
516    }
517
518    /// Divides a tensor by a scalar.
519    ///
520    /// # Arguments
521    ///
522    /// * `lhs` - The left hand side tensor.
523    /// * `rhs` - The right hand side scalar.
524    ///
525    /// # Returns
526    ///
527    /// The result of dividing the tensor by the scalar.
528    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
529        dequant_op_flow!(
530            ty Self,
531            float_op |tensor| B::float_div_scalar(tensor, rhs),
532            lhs
533        )
534    }
535
536    /// Multiplies two tensors together using matrix multiplication.
537    ///
538    /// # Arguments
539    ///
540    /// * `lhs` - The left hand side tensor.
541    /// * `rhs` - The right hand side tensor.
542    ///
543    /// # Returns
544    ///
545    /// The result of multiplying the two tensors together using matrix multiplication.
546    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
547        let mut propagation = QuantPropagation::Inhibit;
548        let mut scheme = QuantScheme::default();
549        let lhs = match lhs {
550            TensorPrimitive::Float(lhs) => lhs,
551            TensorPrimitive::QFloat(lhs) => {
552                propagation = lhs.propagation();
553                scheme = *lhs.scheme();
554                Self::dequantize(lhs)
555            }
556        };
557        let rhs = match rhs {
558            TensorPrimitive::Float(rhs) => rhs,
559            TensorPrimitive::QFloat(rhs) => {
560                propagation = rhs.propagation();
561                scheme = *rhs.scheme();
562                Self::dequantize(rhs)
563            }
564        };
565
566        let out_f = B::float_matmul(lhs, rhs);
567        match propagation {
568            QuantPropagation::Propagate => {
569                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
570            }
571            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
572        }
573    }
574
575    /// Negates a tensor element-wise.
576    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
577        dequant_op_flow!(
578            ty Self,
579            float_op |tensor| B::float_neg(tensor),
580            tensor
581        )
582    }
583
584    /// Calculates the reciprocals element-wise
585    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
586        dequant_op_flow!(
587            ty Self,
588            float_op |tensor| B::float_recip(tensor),
589            tensor
590        )
591    }
592
593    /// Sum of all elements in a tensor.
594    ///
595    /// # Arguments
596    ///
597    /// * `tensor` - The tensor to sum.
598    ///
599    /// # Returns
600    ///
601    /// A scalar tensor with the sum of all elements in `tensor`.
602    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
603        dequant_op_flow!(
604            ty Self,
605            float_op |tensor| B::float_sum(tensor),
606            tensor
607        )
608    }
609
610    /// Sum of all elements in a tensor along a dimension.
611    ///
612    /// # Arguments
613    ///
614    /// * `tensor` - The tensor to sum.
615    /// * `dim` - The dimension along which to sum.
616    ///
617    /// # Returns
618    ///
619    /// A tensor with the sum of all elements in `tensor` along `dim`.
620    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
621        dequant_op_flow!(
622            ty Self,
623            float_op |tensor| B::float_sum_dim(tensor, dim),
624            tensor
625        )
626    }
627
628    /// Product of all elements in a tensor.
629    ///
630    /// # Arguments
631    ///
632    /// * `tensor` - The tensor to product.
633    ///
634    /// # Returns
635    ///
636    /// A scalar tensor with the product of all elements in `tensor`.
637    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
638        dequant_op_flow!(
639            ty Self,
640            float_op |tensor| B::float_prod(tensor),
641            tensor
642        )
643    }
644
645    /// Product of all elements in a tensor along a dimension.
646    ///
647    /// # Arguments
648    ///
649    /// * `tensor` - The tensor to product.
650    ///
651    /// # Returns
652    ///
653    /// A tensor with the product of all elements in `tensor` along `dim`.
654    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
655        dequant_op_flow!(
656            ty Self,
657            float_op |tensor| B::float_prod_dim(tensor, dim),
658            tensor
659        )
660    }
661
662    /// Mean of all elements in a tensor.
663    ///
664    /// # Arguments
665    ///
666    /// * `tensor` - The tensor to mean.
667    ///
668    /// # Returns
669    ///
670    /// A scalar tensor with the mean of all elements in `tensor`.
671    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
672        dequant_op_flow!(
673            ty Self,
674            float_op |tensor| B::float_mean(tensor),
675            tensor
676        )
677    }
678
679    /// Mean of all elements in a tensor along a dimension.
680    ///
681    /// # Arguments
682    ///
683    /// * `tensor` - The tensor to mean.
684    /// * `dim` - The dimension along which to mean.
685    ///
686    /// # Returns
687    ///
688    /// A tensor with the mean of all elements in `tensor` along `dim`.
689    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
690        dequant_op_flow!(
691            ty Self,
692            float_op |tensor| B::float_mean_dim(tensor, dim),
693            tensor
694        )
695    }
696
697    /// Computes the cumulative sum of elements along a dimension.
698    ///
699    /// # Arguments
700    ///
701    /// * `tensor` - The tensor to compute the cumulative sum of.
702    /// * `dim` - The dimension along which to compute the cumulative sum.
703    ///
704    /// # Returns
705    ///
706    /// A tensor with the same shape where each element is the cumulative sum
707    /// of all elements up to and including that position along the dimension.
708    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
709        dequant_op_flow!(
710            ty Self,
711            float_op |tensor| B::float_cumsum(tensor, dim),
712            tensor
713        )
714    }
715
716    /// Computes the cumulative product of elements along a dimension.
717    ///
718    /// # Arguments
719    ///
720    /// * `tensor` - The tensor to compute the cumulative product of.
721    /// * `dim` - The dimension along which to compute the cumulative product.
722    ///
723    /// # Returns
724    ///
725    /// A tensor with the same shape where each element is the cumulative product
726    /// of all elements up to and including that position along the dimension.
727    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
728        dequant_op_flow!(
729            ty Self,
730            float_op |tensor| B::float_cumprod(tensor, dim),
731            tensor
732        )
733    }
734
735    /// Computes the cumulative minimum of elements along a dimension.
736    ///
737    /// # Arguments
738    ///
739    /// * `tensor` - The tensor to compute the cumulative minimum of.
740    /// * `dim` - The dimension along which to compute the cumulative minimum.
741    ///
742    /// # Returns
743    ///
744    /// A tensor with the same shape where each element is the minimum
745    /// of all elements up to and including that position along the dimension.
746    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
747        dequant_op_flow!(
748            ty Self,
749            float_op |tensor| B::float_cummin(tensor, dim),
750            tensor
751        )
752    }
753
754    /// Computes the cumulative maximum of elements along a dimension.
755    ///
756    /// # Arguments
757    ///
758    /// * `tensor` - The tensor to compute the cumulative maximum of.
759    /// * `dim` - The dimension along which to compute the cumulative maximum.
760    ///
761    /// # Returns
762    ///
763    /// A tensor with the same shape where each element is the maximum
764    /// of all elements up to and including that position along the dimension.
765    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
766        dequant_op_flow!(
767            ty Self,
768            float_op |tensor| B::float_cummax(tensor, dim),
769            tensor
770        )
771    }
772
773    /// Returns a new tensor with exponential values.
774    ///
775    /// # Arguments
776    ///
777    /// * `tensor` - The tensor to exponentiate.
778    ///
779    /// # Returns
780    ///
781    /// A tensor with the same shape as `tensor` with exponential values.
782    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
783        dequant_op_flow!(
784            ty Self,
785            float_op |tensor| B::float_exp(tensor),
786            tensor
787        )
788    }
789
790    /// Returns a new tensor with natural logarithm values.
791    ///
792    /// # Arguments
793    ///
794    /// * `tensor` - The tensor to take the logarithm of.
795    ///
796    /// # Returns
797    ///
798    /// A tensor with the same shape as `tensor` with natural logarithm values.
799    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
800        dequant_op_flow!(
801            ty Self,
802            float_op |tensor| B::float_log(tensor),
803            tensor
804        )
805    }
806
807    /// Returns a new tensor with logarithm values of (1 + Xi).
808    ///
809    /// # Arguments
810    ///
811    /// * `tensor` - The tensor to take the logarithm of.
812    ///
813    /// # Returns
814    ///
815    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
816    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
817        dequant_op_flow!(
818            ty Self,
819            float_op |tensor| B::float_log1p(tensor),
820            tensor
821        )
822    }
823
824    /// Element-wise power with another tensor.
825    ///
826    /// # Arguments
827    ///
828    /// * `lhs` - The left hand side tensor.
829    /// * `rhs` - The right hand side tensor.
830    ///
831    /// # Returns
832    ///
833    /// The elements of `lhs` raised to the power of the elements of `rhs`.
834    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
835        dequant_op_flow!(
836            ty Self,
837            float_op |lhs, rhs| B::float_powf(lhs, rhs),
838            lhs,
839            rhs
840        )
841    }
842
843    /// Element-wise power with an IntTensor.
844    ///
845    /// # Arguments
846    ///
847    /// * `lhs` - The left hand side tensor.
848    /// * `rhs` - The right hand side floatTensor.
849    ///
850    /// # Returns
851    ///
852    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
853    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
854        dequant_op_flow!(
855            ty Self,
856            float_op |tensor| B::float_powi(tensor, rhs),
857            lhs
858        )
859    }
860
861    /// Element-wise power with an int scalar.
862    ///
863    /// # Arguments
864    ///
865    /// * `lhs` - The left hand side tensor.
866    /// * `rhs` - The right hand side scalar.
867    ///
868    /// # Returns
869    ///
870    /// The elements of `lhs` raised to the value of `rhs`.
871    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
872        dequant_op_flow!(
873            ty Self,
874            float_op |tensor| B::float_powi_scalar(tensor, rhs),
875            lhs
876        )
877    }
878
879    /// Element-wise power with a float scalar.
880    ///
881    /// # Arguments
882    ///
883    /// * `tensor` - The tensor to exponentiate.
884    /// * `value` - The exponent.
885    ///
886    /// # Returns
887    ///
888    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
889    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
890        dequant_op_flow!(
891            ty Self,
892            float_op |tensor| B::float_powf_scalar(tensor, value),
893            tensor
894        )
895    }
896
897    /// Returns a new tensor with square root values.
898    ///
899    /// # Arguments
900    ///
901    /// * `tensor` - The tensor to take the square root of.
902    ///
903    /// # Returns
904    ///
905    /// A tensor with the same shape as `tensor` with square root values.
906    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
907        dequant_op_flow!(
908            ty Self,
909            float_op |tensor| B::float_sqrt(tensor),
910            tensor
911        )
912    }
913
914    /// Returns a new tensor with absolute values.
915    ///
916    /// # Arguments
917    ///
918    /// * `tensor` - The tensor to take absolute value of.
919    ///
920    /// # Returns
921    ///
922    /// A tensor with the same shape as `tensor` with absolute values.
923    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
924        dequant_op_quant!(
925            ty Self,
926            float_op |tensor| B::float_abs(tensor),
927            tensor
928        )
929    }
930
931    /// Returns a new tensor with cosine values.
932    ///
933    /// # Arguments
934    ///
935    /// * `tensor` - The tensor to take the cosine of.
936    ///
937    /// # Returns
938    ///
939    /// A tensor with the same shape as `tensor` with cosine values.
940    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
941        dequant_op_flow!(
942            ty Self,
943            float_op |tensor| B::float_cos(tensor),
944            tensor
945        )
946    }
947
948    /// Returns a new tensor with sine values.
949    ///
950    /// # Arguments
951    ///
952    /// * `tensor` - The tensor to take the sine of.
953    ///
954    /// # Returns
955    ///
956    /// A tensor with the same shape as `tensor` with sine values.
957    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
958        dequant_op_flow!(
959            ty Self,
960            float_op |tensor| B::float_sin(tensor),
961            tensor
962        )
963    }
964
965    /// Returns a new tensor with tangent values.
966    ///
967    /// # Arguments
968    ///
969    /// * `tensor` - The tensor to take the tangent of.
970    ///
971    /// # Returns
972    ///
973    /// A tensor with the same shape as `tensor` with tangent values.
974    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
975        dequant_op_flow!(
976            ty Self,
977            float_op |tensor| B::float_tan(tensor),
978            tensor
979        )
980    }
981
982    /// Returns a new tensor with hyperbolic cosine values.
983    ///
984    /// # Arguments
985    ///
986    /// * `tensor` - The tensor to take the hyperbolic cosine of.
987    ///
988    /// # Returns
989    ///
990    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
991    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
992        dequant_op_flow!(
993            ty Self,
994            float_op |tensor| B::float_cosh(tensor),
995            tensor
996        )
997    }
998
999    /// Returns a new tensor with hyperbolic sine values.
1000    ///
1001    /// # Arguments
1002    ///
1003    /// * `tensor` - The tensor to take the hyperbolic sine of.
1004    ///
1005    /// # Returns
1006    ///
1007    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1008    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1009        dequant_op_flow!(
1010            ty Self,
1011            float_op |tensor| B::float_sinh(tensor),
1012            tensor
1013        )
1014    }
1015
1016    /// Returns a new tensor with hyperbolic tangent values.
1017    ///
1018    /// # Arguments
1019    ///
1020    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1021    ///
1022    /// # Returns
1023    ///
1024    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1025    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1026        dequant_op_flow!(
1027            ty Self,
1028            float_op |tensor| B::float_tanh(tensor),
1029            tensor
1030        )
1031    }
1032
1033    /// Returns a new tensor with the error function values.
1034    ///
1035    /// # Arguments
1036    ///
1037    /// * `tensor` - The tensor to take the error function of.
1038    ///
1039    /// # Returns
1040    ///
1041    /// A tensor with the same shape as `tensor` with error function values.
1042    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1043        dequant_op_flow!(
1044            ty Self,
1045            float_op |tensor| B::float_erf(tensor),
1046            tensor
1047        )
1048    }
1049
1050    /// Concatenates tensors along a dimension.
1051    ///
1052    /// # Arguments
1053    ///
1054    /// * `tensors` - The tensors to concatenate.
1055    /// * `dim` - The dimension along which to concatenate.
1056    ///
1057    /// # Returns
1058    ///
1059    /// A tensor with the concatenated tensors along `dim`.
1060    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1061        // Heuristic: prioritize first tensor scheme
1062        let scheme = *tensors.first().unwrap().scheme();
1063
1064        let tensor_f = tensors
1065            .into_iter()
1066            .map(|tensor| Self::dequantize(tensor))
1067            .collect();
1068
1069        let out_f = B::float_cat(tensor_f, dim);
1070
1071        Self::quantize_dynamic(out_f, &scheme)
1072    }
1073
1074    /// Gets the indices of the maximum elements of a tensor along an axis.
1075    ///
1076    /// # Arguments
1077    ///
1078    /// * `tensor` - The tensor to get the maximum elements of.
1079    /// * `dim` - The dimension along which to get the maximum elements.
1080    ///
1081    /// # Returns
1082    ///
1083    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1084    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1085        // Default implementation. Backends can sort on the int values since qparams remain the same.
1086        let tensor_f = Self::dequantize(tensor);
1087        B::float_argmax(tensor_f, dim)
1088    }
1089
1090    /// Gets the indices of the minimum elements of a tensor along an axis.
1091    ///
1092    /// # Arguments
1093    ///
1094    /// * `tensor` - The tensor to get the minimum elements of.
1095    /// * `dim` - The dimension along which to get the minimum elements.
1096    ///
1097    /// # Returns
1098    ///
1099    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1100    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1101        // Default implementation. Backends can sort on the int values since qparams remain the same.
1102        let tensor_f = Self::dequantize(tensor);
1103        B::float_argmin(tensor_f, dim)
1104    }
1105
1106    /// Gets the maximum element of a tensor.
1107    ///
1108    /// # Arguments
1109    ///
1110    /// * `tensor` - The tensor to get the maximum elements of.
1111    ///
1112    /// # Returns
1113    ///
1114    /// A tensor with the maximum element of `tensor`.
1115    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1116        let shape = tensor.shape();
1117        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1118
1119        B::q_max_dim(tensor, 0)
1120    }
1121
1122    /// Gets the maximum elements of a tensor along an axis.
1123    ///
1124    /// # Arguments
1125    ///
1126    /// * `tensor` - The tensor to get the maximum elements of.
1127    /// * `dim` - The dimension along which to get the maximum elements.
1128    ///
1129    /// # Returns
1130    ///
1131    /// A tensor with the maximum elements of `tensor` along `dim`.
1132    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1133        let index = B::q_argmax(tensor.clone(), dim);
1134
1135        B::q_gather(dim, tensor, index)
1136    }
1137
1138    /// Gets the maximum elements of a tensor along an axis and their indices.
1139    ///
1140    /// # Arguments
1141    ///
1142    /// * `tensor` - The tensor to get the maximum elements of.
1143    /// * `dim` - The dimension along which to get the maximum elements.
1144    ///
1145    /// # Returns
1146    ///
1147    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1148    fn q_max_dim_with_indices(
1149        tensor: QuantizedTensor<B>,
1150        dim: usize,
1151    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1152        let index = B::q_argmax(tensor.clone(), dim);
1153        let values = B::q_gather(dim, tensor, index.clone());
1154
1155        (values, index)
1156    }
1157
1158    /// Gets the minimum element of a tensor.
1159    ///
1160    /// # Arguments
1161    ///
1162    /// * `tensor` - The tensor to get the minimum elements of.
1163    ///
1164    /// # Returns
1165    ///
1166    /// A tensor with the minimum element of `tensor`.
1167    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1168        let shape = tensor.shape();
1169        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1170
1171        B::q_min_dim(tensor, 0)
1172    }
1173
1174    /// Gets the minimum elements of a tensor along an axis.
1175    ///
1176    /// # Arguments
1177    ///
1178    /// * `tensor` - The tensor to get the minimum elements of.
1179    /// * `dim` - The dimension along which to get the minimum elements.
1180    ///
1181    /// # Returns
1182    ///
1183    /// A tensor with the minimum elements of `tensor` along `dim`.
1184    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1185        let index = B::q_argmin(tensor.clone(), dim);
1186
1187        B::q_gather(dim, tensor, index)
1188    }
1189
1190    /// Gets the minimum elements of a tensor along an axis and their indices.
1191    ///
1192    /// # Arguments
1193    ///
1194    /// * `tensor` - The tensor to get the minimum elements of.
1195    /// * `dim` - The dimension along which to get the minimum elements.
1196    ///
1197    /// # Returns
1198    ///
1199    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1200    fn q_min_dim_with_indices(
1201        tensor: QuantizedTensor<B>,
1202        dim: usize,
1203    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1204        let index = B::q_argmin(tensor.clone(), dim);
1205        let values = B::q_gather(dim, tensor, index.clone());
1206
1207        (values, index)
1208    }
1209
1210    /// Gets the maximum element of a tensor.
1211    ///
1212    /// # Arguments
1213    ///
1214    /// * `tensor` - The tensor to get the maximum elements of.
1215    ///
1216    /// # Returns
1217    ///
1218    /// A tensor with the maximum element of `tensor`.
1219    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1220        let shape = tensor.shape();
1221        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1222
1223        B::q_max_abs_dim(tensor, 0)
1224    }
1225
1226    /// Gets the maximum elements of a tensor along an axis.
1227    ///
1228    /// # Arguments
1229    ///
1230    /// * `tensor` - The tensor to get the maximum elements of.
1231    /// * `dim` - The dimension along which to get the maximum elements.
1232    ///
1233    /// # Returns
1234    ///
1235    /// A tensor with the maximum elements of `tensor` along `dim`.
1236    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1237        let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1238
1239        B::q_gather(dim, tensor, index)
1240    }
1241
1242    /// Tests if any element in the `tensor` evaluates to True.
1243    ///
1244    /// # Arguments
1245    ///
1246    /// * `tensor` - The tensor to test.
1247    ///
1248    /// # Returns
1249    ///
1250    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1251    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1252        let tensor_f = Self::dequantize(tensor);
1253        B::float_any(tensor_f)
1254    }
1255
1256    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1257    ///
1258    /// # Arguments
1259    ///
1260    /// * `tensor` - The tensor to test.
1261    /// * `dim` - The axis along which to test.
1262    ///
1263    /// # Returns
1264    ///
1265    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1266    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1267    /// input evaluates to True, False otherwise.
1268    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1269        let tensor_f = Self::dequantize(tensor);
1270        B::float_any_dim(tensor_f, dim)
1271    }
1272
1273    /// Tests if all elements in the `tensor` evaluate to True.
1274    ///
1275    /// # Arguments
1276    ///
1277    /// * `tensor` - The tensor to test.
1278    ///
1279    /// # Returns
1280    ///
1281    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1282    /// evaluate to True, False otherwise.
1283    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1284        let tensor_f = Self::dequantize(tensor);
1285        B::float_all(tensor_f)
1286    }
1287
1288    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1289    ///
1290    /// # Arguments
1291    ///
1292    /// * `tensor` - The tensor to test.
1293    /// * `dim` - The axis along which to test.
1294    ///
1295    /// # Returns
1296    ///
1297    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1298    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1299    /// evaluates to True, False otherwise.
1300    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1301        let tensor_f = Self::dequantize(tensor);
1302        B::float_all_dim(tensor_f, dim)
1303    }
1304
1305    /// Sort the elements of the input `tensor` by value in along a given dimension.
1306    ///
1307    /// This sort is unstable (i.e., may reorder equal elements).
1308    ///
1309    /// # Arguments
1310    ///
1311    /// * `tensor` - The input tensor.
1312    /// * `dim` - The axis along which to sort.
1313    /// * `descending` - The sorting order.
1314    ///
1315    /// # Returns
1316    ///
1317    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1318    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1319        // Default implementation. Backends can sort on the int values since qparams remain the same.
1320        dequant_op_quant!(
1321            ty Self,
1322            float_op |tensor| B::float_sort(tensor, dim, descending),
1323            tensor
1324        )
1325    }
1326
1327    /// Sort the elements of the input `tensor` by value in along a given dimension.
1328    ///
1329    /// This sort is unstable (i.e., may reorder equal elements).
1330    ///
1331    /// # Arguments
1332    ///
1333    /// * `tensor` - The input tensor.
1334    /// * `dim` - The axis along which to sort.
1335    /// * `descending` - The sorting order.
1336    ///
1337    /// # Returns
1338    ///
1339    /// A tensor with the same shape as the input tensor and corresponding indices, where
1340    /// the elements are sorted by value and the indices map back to the original input tensor.
1341    fn q_sort_with_indices(
1342        tensor: QuantizedTensor<B>,
1343        dim: usize,
1344        descending: bool,
1345    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1346        // Default implementation. Backends can sort on the int values since qparams remain the same.
1347        let scheme = *tensor.scheme();
1348
1349        let tensor_f = Self::dequantize(tensor);
1350        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1351
1352        (Self::quantize_dynamic(out_f, &scheme), indices)
1353    }
1354
1355    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1356    ///
1357    /// This sort is unstable (i.e., may reorder equal elements).
1358    ///
1359    /// # Arguments
1360    ///
1361    /// * `tensor` - The input tensor.
1362    /// * `dim` - The axis along which to sort.
1363    /// * `descending` - The sorting order.
1364    ///
1365    /// # Returns
1366    ///
1367    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1368    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1369        // Default implementation. Backends can sort on the int values since qparams remain the same.
1370        let tensor_f = Self::dequantize(tensor);
1371        B::float_argsort(tensor_f, dim, descending)
1372    }
1373}