Skip to main content

burn_backend/backend/ops/
qtensor.rs

1use alloc::vec::Vec;
2use burn_std::{
3    Shape, Slice,
4    quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9};
10use crate::{
11    Scalar,
12    tensor::{
13        BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,
14        quantization::{
15            Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,
16        },
17    },
18};
19
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// Accepted forms:
///
/// * `dequant_op_quant!(ty T, float_op f, lhs, rhs)` — binary op. The float result is
///   re-quantized with the scheme of `lhs` (heuristic: the lhs scheme takes priority).
/// * `dequant_op_quant!(ty T, float_op f, tensor)` — unary op. The float result is
///   re-quantized with the scheme of `tensor`.
///
/// `T` must provide `dequantize` and `quantize_dynamic` associated functions, and the first
/// tensor operand must expose a `scheme()` method. Every operand expression is evaluated
/// exactly once, so passing an expression with side effects (or an expensive one) is safe.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the lhs once: it is needed both for its scheme and for its values, and
        // expanding `$t1` twice would evaluate the caller's expression twice.
        let t1 = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize($t2);
        // `$float_op` is usually a closure literal that is called immediately.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the input once so `$tensor` is evaluated a single time.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        // `$float_op` is usually a closure literal that is called immediately.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
52
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated.
///
/// Accepted forms:
///
/// * `dequant_op_flow!(ty T, float_op f, lhs, rhs)` — binary op. The propagation strategy and
///   the re-quantization scheme are both taken from `lhs` (heuristic: lhs takes priority).
/// * `dequant_op_flow!(ty T, float_op f, tensor)` — unary op, governed by `tensor`.
///
/// The expansion evaluates to `TensorPrimitive::QFloat(..)` when the propagation is
/// `QuantPropagation::Propagate`, and to `TensorPrimitive::Float(..)` when it is
/// `QuantPropagation::Inhibit`. Both names are referenced unqualified, so they must be in
/// scope at the call site. Every operand expression is evaluated exactly once.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the lhs once: it provides the scheme, the propagation strategy and the
        // values, and expanding `$t1` repeatedly would re-evaluate the caller's expression.
        let t1 = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();
        let propagation = t1.propagation();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize($t2);
        // `$float_op` is usually a closure literal that is called immediately.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the input once so `$tensor` is evaluated a single time.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();
        let propagation = tensor.propagation();

        let tensor_f = <$ty>::dequantize(tensor);
        // `$float_op` is usually a closure literal that is called immediately.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
98
99/// Operations on quantized tensors.
100///
101/// # Return Type Semantics
102///
103/// The return type of each operation indicates how quantization is handled:
104///
105/// ## [`QuantizedTensor<B>`]
106/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
107/// representation. Implementations should avoid dequantizing when possible to maintain performance.
108/// For example, shape or layout changes such as expand or transpose preserve quantization.
109///
110/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
111/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
112/// the quantization parameters to match the new layout.*
113///
114///
115/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme: the output should either remain
/// quantized ([`TensorPrimitive::QFloat`]) or be returned in floating-point form ([`TensorPrimitive::Float`]).
119///
120/// This distinction allows for fine-grained control over mixed-precision flows while still operating
121/// through a unified API.
122pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor with the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
134
    /// Converts the tensor to a lower precision data type based on the quantization scheme
    /// and parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme to apply.
    /// * `qparams` - The quantization parameters, e.g. as produced by `compute_q_params`.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
141
142    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
143    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
144        // Dynamically compute min/max tensor range and qparams before quantizing
145        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
146        let qparams = compute_q_params(scheme, min, max);
147        Self::quantize(tensor, scheme, qparams)
148    }
149
    /// Converts the quantized tensor back to a higher precision (float) data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    ///
    /// # Returns
    ///
    /// The dequantized float tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
152
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
163
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The quantized tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
175
    /// Reshapes a quantized tensor.
    ///
    /// Being a layout change, this is expected to keep the quantized representation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
187
    /// Converts the quantized tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A `Send` future resolving to the data structure with the tensor's data, or to an
    /// [`ExecutionError`].
    fn q_into_data(
        tensor: QuantizedTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
200
201    /// Detaches a tensor from the computation graph.
202    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
203        // Should only be overridden by autodiff backends.
204        tensor
205    }
206
207    /// Sets the `require_grad` flag of a tensor.
208    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
209        // Should only be overridden by autodiff backends.
210        tensor
211    }
212
213    /// Returns the `require_grad` flag of a tensor.
214    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
215        // Should only be overridden by autodiff backends.
216        false
217    }
218
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// Being a layout change, this is expected to keep the quantized representation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcasted quantized tensor.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
221
222    /// Transposes a tensor.
223    ///
224    /// # Arguments
225    ///
226    /// * `tensor` - The tensor to transpose.
227    ///
228    /// # Returns
229    ///
230    /// The transposed tensor.
231    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
232        let ndims = tensor.shape().num_dims();
233        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
234    }
235
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// Being a layout change, this is expected to keep the quantized representation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
248
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// Being a layout change, this is expected to keep the quantized representation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
259
    /// Reverses the order of elements in a quantized tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
269
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements, as a quantized tensor.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
286
    /// Selects tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new quantized tensor.
    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
298
299    /// Gather elements from a tensor.
300    ///
301    /// # Arguments
302    ///
303    /// * `dim` - The dimension to gather from.
304    /// * `tensor` - The tensor to gather from.
305    /// * `indices` - The indices to gather.
306    ///
307    /// # Returns
308    ///
309    /// The gathered elements.
310    fn q_gather(
311        dim: usize,
312        tensor: QuantizedTensor<B>,
313        indices: IntTensor<B>,
314    ) -> QuantizedTensor<B> {
315        // Default implementation. Backends can gather on the quantized values when supported.
316        dequant_op_quant!(
317            ty Self,
318            float_op |tensor| B::float_gather(dim, tensor, indices),
319            tensor
320        )
321    }
322
323    /// Repeat the tensor along the given dimension.
324    ///
325    /// # Arguments
326    ///
327    /// * `tensor` - The tensor.
328    /// * `dim` - The dimension to repeat.
329    /// * `times` - The number of times to repeat the dimension.
330    ///
331    /// # Returns
332    ///
333    /// The tensor with the given dimension repeated.
334    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
335        dequant_op_quant!(
336            ty Self,
337            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
338            tensor
339        )
340    }
341
342    /// Adds two tensors together.
343    ///
344    /// # Arguments
345    ///
346    /// * `lhs` - The left hand side tensor.
347    /// * `rhs` - The right hand side tensor.
348    ///
349    /// # Returns
350    ///
351    /// The result of adding the two tensors together.
352    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
353        dequant_op_flow!(
354            ty Self,
355            float_op |lhs, rhs| B::float_add(lhs, rhs),
356            lhs,
357            rhs
358        )
359    }
360
361    /// Adds a scalar to a tensor.
362    ///
363    /// # Arguments
364    ///
365    /// * `lhs` - The left hand side tensor.
366    /// * `rhs` - The right hand side scalar.
367    ///
368    /// # Returns
369    ///
370    /// The result of adding the scalar to the tensor.
371    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
372        dequant_op_flow!(
373            ty Self,
374            float_op |tensor| B::float_add_scalar(tensor, rhs),
375            lhs
376        )
377    }
378
379    /// Clamps a tensor under a minimum value.
380    ///
381    /// # Arguments
382    ///
383    /// * `tensor` - The tensor to clamp.
384    /// * `min` - The minimum value.
385    ///
386    /// # Returns
387    ///
388    /// The clamped tensor.
389    fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {
390        dequant_op_flow!(
391            ty Self,
392            float_op |tensor| B::float_clamp_min(tensor, min),
393            tensor
394        )
395    }
396
397    /// Clamps a tensor over a maximum value.
398    ///
399    /// # Arguments
400    ///
401    /// * `tensor` - The tensor to clamp.
402    /// * `max` - The maximum value.
403    ///
404    /// # Returns
405    ///
406    /// The clamped tensor.
407    fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {
408        dequant_op_flow!(
409            ty Self,
410            float_op |tensor| B::float_clamp_max(tensor, max),
411            tensor
412        )
413    }
414
415    /// Clamps a tensor between a minimum and maximum value.
416    ///
417    /// # Arguments
418    ///
419    /// * `tensor` - The tensor to clamp.
420    /// * `min` - The minimum value.
421    /// * `max` - The maximum value.
422    ///
423    /// # Returns
424    ///
425    /// The clamped tensor.
426    fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {
427        dequant_op_flow!(
428            ty Self,
429            float_op |tensor| B::float_clamp(tensor, min, max),
430            tensor
431        )
432    }
433
434    /// Subtracts two tensors.
435    ///
436    /// # Arguments
437    ///
438    /// * `lhs` - The left hand side tensor.
439    /// * `rhs` - The right hand side tensor.
440    ///
441    /// # Returns
442    ///
443    /// The result of subtracting the two tensors.
444    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
445        dequant_op_flow!(
446            ty Self,
447            float_op |lhs, rhs| B::float_sub(lhs, rhs),
448            lhs,
449            rhs
450        )
451    }
452
453    /// Subtracts a scalar from a tensor.
454    ///
455    /// # Arguments
456    ///
457    /// * `lhs` - The left hand side tensor.
458    /// * `rhs` - The right hand side scalar.
459    ///
460    /// # Returns
461    ///
462    /// The result of subtracting the scalar from the tensor.
463    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
464        dequant_op_flow!(
465            ty Self,
466            float_op |tensor| B::float_sub_scalar(tensor, rhs),
467            lhs
468        )
469    }
470
471    /// Multiplies two tensors together element-wise.
472    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
473        dequant_op_flow!(
474            ty Self,
475            float_op |lhs, rhs| B::float_mul(lhs, rhs),
476            lhs,
477            rhs
478        )
479    }
480
481    /// Multiplies a tensor by a scalar.
482    ///
483    /// # Arguments
484    ///
485    /// * `lhs` - The left hand side tensor.
486    /// * `rhs` - The right hand side scalar.
487    ///
488    /// # Returns
489    ///
490    /// The result of multiplying the tensor by the scalar.
491    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
492        dequant_op_flow!(
493            ty Self,
494            float_op |tensor| B::float_mul_scalar(tensor, rhs),
495            lhs
496        )
497    }
498
499    /// Divides two tensors element-wise.
500    ///
501    /// # Arguments
502    ///
503    /// * `lhs` - The left hand side tensor.
504    /// * `rhs` - The right hand side tensor.
505    ///
506    /// # Returns
507    ///
508    /// The result of dividing the two tensors.
509    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
510        dequant_op_flow!(
511            ty Self,
512            float_op |lhs, rhs| B::float_div(lhs, rhs),
513            lhs,
514            rhs
515        )
516    }
517
518    /// Divides a tensor by a scalar.
519    ///
520    /// # Arguments
521    ///
522    /// * `lhs` - The left hand side tensor.
523    /// * `rhs` - The right hand side scalar.
524    ///
525    /// # Returns
526    ///
527    /// The result of dividing the tensor by the scalar.
528    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
529        dequant_op_flow!(
530            ty Self,
531            float_op |tensor| B::float_div_scalar(tensor, rhs),
532            lhs
533        )
534    }
535
536    /// Multiplies two tensors together using matrix multiplication.
537    ///
538    /// # Arguments
539    ///
540    /// * `lhs` - The left hand side tensor.
541    /// * `rhs` - The right hand side tensor.
542    ///
543    /// # Returns
544    ///
545    /// The result of multiplying the two tensors together using matrix multiplication.
546    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
547        let mut propagation = QuantPropagation::Inhibit;
548        let mut scheme = QuantScheme::default();
549        let lhs = match lhs {
550            TensorPrimitive::Float(lhs) => lhs,
551            TensorPrimitive::QFloat(lhs) => {
552                propagation = lhs.propagation();
553                scheme = *lhs.scheme();
554                Self::dequantize(lhs)
555            }
556        };
557        let rhs = match rhs {
558            TensorPrimitive::Float(rhs) => rhs,
559            TensorPrimitive::QFloat(rhs) => {
560                propagation = rhs.propagation();
561                scheme = *rhs.scheme();
562                Self::dequantize(rhs)
563            }
564        };
565
566        let out_f = B::float_matmul(lhs, rhs);
567        match propagation {
568            QuantPropagation::Propagate => {
569                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
570            }
571            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
572        }
573    }
574
575    /// Negates a tensor element-wise.
576    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
577        dequant_op_flow!(
578            ty Self,
579            float_op |tensor| B::float_neg(tensor),
580            tensor
581        )
582    }
583
584    /// Calculates the reciprocals element-wise
585    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
586        dequant_op_flow!(
587            ty Self,
588            float_op |tensor| B::float_recip(tensor),
589            tensor
590        )
591    }
592
593    /// Sum of all elements in a tensor.
594    ///
595    /// # Arguments
596    ///
597    /// * `tensor` - The tensor to sum.
598    ///
599    /// # Returns
600    ///
601    /// A scalar tensor with the sum of all elements in `tensor`.
602    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
603        dequant_op_flow!(
604            ty Self,
605            float_op |tensor| B::float_sum(tensor),
606            tensor
607        )
608    }
609
610    /// Sum of all elements in a tensor along a dimension.
611    ///
612    /// # Arguments
613    ///
614    /// * `tensor` - The tensor to sum.
615    /// * `dim` - The dimension along which to sum.
616    ///
617    /// # Returns
618    ///
619    /// A tensor with the sum of all elements in `tensor` along `dim`.
620    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
621        dequant_op_flow!(
622            ty Self,
623            float_op |tensor| B::float_sum_dim(tensor, dim),
624            tensor
625        )
626    }
627
628    /// Product of all elements in a tensor.
629    ///
630    /// # Arguments
631    ///
632    /// * `tensor` - The tensor to product.
633    ///
634    /// # Returns
635    ///
636    /// A scalar tensor with the product of all elements in `tensor`.
637    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
638        dequant_op_flow!(
639            ty Self,
640            float_op |tensor| B::float_prod(tensor),
641            tensor
642        )
643    }
644
645    /// Product of all elements in a tensor along a dimension.
646    ///
647    /// # Arguments
648    ///
649    /// * `tensor` - The tensor to product.
650    ///
651    /// # Returns
652    ///
653    /// A tensor with the product of all elements in `tensor` along `dim`.
654    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
655        dequant_op_flow!(
656            ty Self,
657            float_op |tensor| B::float_prod_dim(tensor, dim),
658            tensor
659        )
660    }
661
662    /// Mean of all elements in a tensor.
663    ///
664    /// # Arguments
665    ///
666    /// * `tensor` - The tensor to mean.
667    ///
668    /// # Returns
669    ///
670    /// A scalar tensor with the mean of all elements in `tensor`.
671    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
672        dequant_op_flow!(
673            ty Self,
674            float_op |tensor| B::float_mean(tensor),
675            tensor
676        )
677    }
678
679    /// Mean of all elements in a tensor along a dimension.
680    ///
681    /// # Arguments
682    ///
683    /// * `tensor` - The tensor to mean.
684    /// * `dim` - The dimension along which to mean.
685    ///
686    /// # Returns
687    ///
688    /// A tensor with the mean of all elements in `tensor` along `dim`.
689    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
690        dequant_op_flow!(
691            ty Self,
692            float_op |tensor| B::float_mean_dim(tensor, dim),
693            tensor
694        )
695    }
696
697    /// Computes the cumulative sum of elements along a dimension.
698    ///
699    /// # Arguments
700    ///
701    /// * `tensor` - The tensor to compute the cumulative sum of.
702    /// * `dim` - The dimension along which to compute the cumulative sum.
703    ///
704    /// # Returns
705    ///
706    /// A tensor with the same shape where each element is the cumulative sum
707    /// of all elements up to and including that position along the dimension.
708    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
709        dequant_op_flow!(
710            ty Self,
711            float_op |tensor| B::float_cumsum(tensor, dim),
712            tensor
713        )
714    }
715
716    /// Computes the cumulative product of elements along a dimension.
717    ///
718    /// # Arguments
719    ///
720    /// * `tensor` - The tensor to compute the cumulative product of.
721    /// * `dim` - The dimension along which to compute the cumulative product.
722    ///
723    /// # Returns
724    ///
725    /// A tensor with the same shape where each element is the cumulative product
726    /// of all elements up to and including that position along the dimension.
727    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
728        dequant_op_flow!(
729            ty Self,
730            float_op |tensor| B::float_cumprod(tensor, dim),
731            tensor
732        )
733    }
734
735    /// Computes the cumulative minimum of elements along a dimension.
736    ///
737    /// # Arguments
738    ///
739    /// * `tensor` - The tensor to compute the cumulative minimum of.
740    /// * `dim` - The dimension along which to compute the cumulative minimum.
741    ///
742    /// # Returns
743    ///
744    /// A tensor with the same shape where each element is the minimum
745    /// of all elements up to and including that position along the dimension.
746    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
747        dequant_op_flow!(
748            ty Self,
749            float_op |tensor| B::float_cummin(tensor, dim),
750            tensor
751        )
752    }
753
754    /// Computes the cumulative maximum of elements along a dimension.
755    ///
756    /// # Arguments
757    ///
758    /// * `tensor` - The tensor to compute the cumulative maximum of.
759    /// * `dim` - The dimension along which to compute the cumulative maximum.
760    ///
761    /// # Returns
762    ///
763    /// A tensor with the same shape where each element is the maximum
764    /// of all elements up to and including that position along the dimension.
765    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
766        dequant_op_flow!(
767            ty Self,
768            float_op |tensor| B::float_cummax(tensor, dim),
769            tensor
770        )
771    }
772
773    /// Returns a new tensor with exponential values.
774    ///
775    /// # Arguments
776    ///
777    /// * `tensor` - The tensor to exponentiate.
778    ///
779    /// # Returns
780    ///
781    /// A tensor with the same shape as `tensor` with exponential values.
782    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
783        dequant_op_flow!(
784            ty Self,
785            float_op |tensor| B::float_exp(tensor),
786            tensor
787        )
788    }
789
790    /// Returns a new tensor with natural logarithm values.
791    ///
792    /// # Arguments
793    ///
794    /// * `tensor` - The tensor to take the logarithm of.
795    ///
796    /// # Returns
797    ///
798    /// A tensor with the same shape as `tensor` with natural logarithm values.
799    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
800        dequant_op_flow!(
801            ty Self,
802            float_op |tensor| B::float_log(tensor),
803            tensor
804        )
805    }
806
807    /// Returns a new tensor with logarithm values of (1 + Xi).
808    ///
809    /// # Arguments
810    ///
811    /// * `tensor` - The tensor to take the logarithm of.
812    ///
813    /// # Returns
814    ///
815    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
816    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
817        dequant_op_flow!(
818            ty Self,
819            float_op |tensor| B::float_log1p(tensor),
820            tensor
821        )
822    }
823
824    /// Element-wise power with another tensor.
825    ///
826    /// # Arguments
827    ///
828    /// * `lhs` - The left hand side tensor.
829    /// * `rhs` - The right hand side tensor.
830    ///
831    /// # Returns
832    ///
833    /// The elements of `lhs` raised to the power of the elements of `rhs`.
834    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
835        dequant_op_flow!(
836            ty Self,
837            float_op |lhs, rhs| B::float_powf(lhs, rhs),
838            lhs,
839            rhs
840        )
841    }
842
843    /// Element-wise power with an IntTensor.
844    ///
845    /// # Arguments
846    ///
847    /// * `lhs` - The left hand side tensor.
848    /// * `rhs` - The right hand side floatTensor.
849    ///
850    /// # Returns
851    ///
852    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
853    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
854        dequant_op_flow!(
855            ty Self,
856            float_op |tensor| B::float_powi(tensor, rhs),
857            lhs
858        )
859    }
860
861    /// Element-wise power with an int scalar.
862    ///
863    /// # Arguments
864    ///
865    /// * `lhs` - The left hand side tensor.
866    /// * `rhs` - The right hand side scalar.
867    ///
868    /// # Returns
869    ///
870    /// The elements of `lhs` raised to the value of `rhs`.
871    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
872        dequant_op_flow!(
873            ty Self,
874            float_op |tensor| B::float_powi_scalar(tensor, rhs),
875            lhs
876        )
877    }
878
879    /// Element-wise power with a float scalar.
880    ///
881    /// # Arguments
882    ///
883    /// * `tensor` - The tensor to exponentiate.
884    /// * `value` - The exponent.
885    ///
886    /// # Returns
887    ///
888    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
889    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {
890        dequant_op_flow!(
891            ty Self,
892            float_op |tensor| B::float_powf_scalar(tensor, value),
893            tensor
894        )
895    }
896
897    /// Returns a new tensor with square root values.
898    ///
899    /// # Arguments
900    ///
901    /// * `tensor` - The tensor to take the square root of.
902    ///
903    /// # Returns
904    ///
905    /// A tensor with the same shape as `tensor` with square root values.
906    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
907        dequant_op_flow!(
908            ty Self,
909            float_op |tensor| B::float_sqrt(tensor),
910            tensor
911        )
912    }
913
914    /// Returns a new tensor with absolute values.
915    ///
916    /// # Arguments
917    ///
918    /// * `tensor` - The tensor to take absolute value of.
919    ///
920    /// # Returns
921    ///
922    /// A tensor with the same shape as `tensor` with absolute values.
923    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
924        dequant_op_quant!(
925            ty Self,
926            float_op |tensor| B::float_abs(tensor),
927            tensor
928        )
929    }
930
931    /// Returns a new tensor with cosine values.
932    ///
933    /// # Arguments
934    ///
935    /// * `tensor` - The tensor to take the cosine of.
936    ///
937    /// # Returns
938    ///
939    /// A tensor with the same shape as `tensor` with cosine values.
940    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
941        dequant_op_flow!(
942            ty Self,
943            float_op |tensor| B::float_cos(tensor),
944            tensor
945        )
946    }
947
948    /// Returns a new tensor with sine values.
949    ///
950    /// # Arguments
951    ///
952    /// * `tensor` - The tensor to take the sine of.
953    ///
954    /// # Returns
955    ///
956    /// A tensor with the same shape as `tensor` with sine values.
957    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
958        dequant_op_flow!(
959            ty Self,
960            float_op |tensor| B::float_sin(tensor),
961            tensor
962        )
963    }
964
965    /// Returns a new tensor with tangent values.
966    ///
967    /// # Arguments
968    ///
969    /// * `tensor` - The tensor to take the tangent of.
970    ///
971    /// # Returns
972    ///
973    /// A tensor with the same shape as `tensor` with tangent values.
974    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
975        dequant_op_flow!(
976            ty Self,
977            float_op |tensor| B::float_tan(tensor),
978            tensor
979        )
980    }
981
982    /// Returns a new tensor with hyperbolic cosine values.
983    ///
984    /// # Arguments
985    ///
986    /// * `tensor` - The tensor to take the hyperbolic cosine of.
987    ///
988    /// # Returns
989    ///
990    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
991    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
992        dequant_op_flow!(
993            ty Self,
994            float_op |tensor| B::float_cosh(tensor),
995            tensor
996        )
997    }
998
999    /// Returns a new tensor with hyperbolic sine values.
1000    ///
1001    /// # Arguments
1002    ///
1003    /// * `tensor` - The tensor to take the hyperbolic sine of.
1004    ///
1005    /// # Returns
1006    ///
1007    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1008    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1009        dequant_op_flow!(
1010            ty Self,
1011            float_op |tensor| B::float_sinh(tensor),
1012            tensor
1013        )
1014    }
1015
1016    /// Returns a new tensor with hyperbolic tangent values.
1017    ///
1018    /// # Arguments
1019    ///
1020    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1021    ///
1022    /// # Returns
1023    ///
1024    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1025    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1026        dequant_op_flow!(
1027            ty Self,
1028            float_op |tensor| B::float_tanh(tensor),
1029            tensor
1030        )
1031    }
1032
1033    /// Returns a new tensor with the error function values.
1034    ///
1035    /// # Arguments
1036    ///
1037    /// * `tensor` - The tensor to take the error function of.
1038    ///
1039    /// # Returns
1040    ///
1041    /// A tensor with the same shape as `tensor` with error function values.
1042    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1043        dequant_op_flow!(
1044            ty Self,
1045            float_op |tensor| B::float_erf(tensor),
1046            tensor
1047        )
1048    }
1049
1050    /// Concatenates tensors along a dimension.
1051    ///
1052    /// # Arguments
1053    ///
1054    /// * `tensors` - The tensors to concatenate.
1055    /// * `dim` - The dimension along which to concatenate.
1056    ///
1057    /// # Returns
1058    ///
1059    /// A tensor with the concatenated tensors along `dim`.
1060    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1061        // Heuristic: prioritize first tensor scheme
1062        let scheme = *tensors.first().unwrap().scheme();
1063
1064        let tensor_f = tensors
1065            .into_iter()
1066            .map(|tensor| Self::dequantize(tensor))
1067            .collect();
1068
1069        let out_f = B::float_cat(tensor_f, dim);
1070
1071        Self::quantize_dynamic(out_f, &scheme)
1072    }
1073
1074    /// Gets the indices of the maximum elements of a tensor along an axis.
1075    ///
1076    /// # Arguments
1077    ///
1078    /// * `tensor` - The tensor to get the maximum elements of.
1079    /// * `dim` - The dimension along which to get the maximum elements.
1080    ///
1081    /// # Returns
1082    ///
1083    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1084    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1085        // Default implementation. Backends can sort on the int values since qparams remain the same.
1086        let tensor_f = Self::dequantize(tensor);
1087        B::float_argmax(tensor_f, dim)
1088    }
1089
1090    /// Gets the indices of the minimum elements of a tensor along an axis.
1091    ///
1092    /// # Arguments
1093    ///
1094    /// * `tensor` - The tensor to get the minimum elements of.
1095    /// * `dim` - The dimension along which to get the minimum elements.
1096    ///
1097    /// # Returns
1098    ///
1099    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1100    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1101        // Default implementation. Backends can sort on the int values since qparams remain the same.
1102        let tensor_f = Self::dequantize(tensor);
1103        B::float_argmin(tensor_f, dim)
1104    }
1105
1106    /// Gets the maximum element of a tensor.
1107    ///
1108    /// # Arguments
1109    ///
1110    /// * `tensor` - The tensor to get the maximum elements of.
1111    ///
1112    /// # Returns
1113    ///
1114    /// A tensor with the maximum element of `tensor`.
1115    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1116        let shape = tensor.shape();
1117        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1118
1119        B::q_max_dim(tensor, 0)
1120    }
1121
1122    /// Gets the maximum elements of a tensor along an axis.
1123    ///
1124    /// # Arguments
1125    ///
1126    /// * `tensor` - The tensor to get the maximum elements of.
1127    /// * `dim` - The dimension along which to get the maximum elements.
1128    ///
1129    /// # Returns
1130    ///
1131    /// A tensor with the maximum elements of `tensor` along `dim`.
1132    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1133        let index = B::q_argmax(tensor.clone(), dim);
1134
1135        B::q_gather(dim, tensor, index)
1136    }
1137
1138    /// Gets the maximum elements of a tensor along an axis and their indices.
1139    ///
1140    /// # Arguments
1141    ///
1142    /// * `tensor` - The tensor to get the maximum elements of.
1143    /// * `dim` - The dimension along which to get the maximum elements.
1144    ///
1145    /// # Returns
1146    ///
1147    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1148    fn q_max_dim_with_indices(
1149        tensor: QuantizedTensor<B>,
1150        dim: usize,
1151    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1152        let index = B::q_argmax(tensor.clone(), dim);
1153        let values = B::q_gather(dim, tensor, index.clone());
1154
1155        (values, index)
1156    }
1157
1158    /// Gets the minimum element of a tensor.
1159    ///
1160    /// # Arguments
1161    ///
1162    /// * `tensor` - The tensor to get the minimum elements of.
1163    ///
1164    /// # Returns
1165    ///
1166    /// A tensor with the minimum element of `tensor`.
1167    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1168        let shape = tensor.shape();
1169        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1170
1171        B::q_min_dim(tensor, 0)
1172    }
1173
1174    /// Gets the minimum elements of a tensor along an axis.
1175    ///
1176    /// # Arguments
1177    ///
1178    /// * `tensor` - The tensor to get the minimum elements of.
1179    /// * `dim` - The dimension along which to get the minimum elements.
1180    ///
1181    /// # Returns
1182    ///
1183    /// A tensor with the minimum elements of `tensor` along `dim`.
1184    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1185        let index = B::q_argmin(tensor.clone(), dim);
1186
1187        B::q_gather(dim, tensor, index)
1188    }
1189
1190    /// Gets the minimum elements of a tensor along an axis and their indices.
1191    ///
1192    /// # Arguments
1193    ///
1194    /// * `tensor` - The tensor to get the minimum elements of.
1195    /// * `dim` - The dimension along which to get the minimum elements.
1196    ///
1197    /// # Returns
1198    ///
1199    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1200    fn q_min_dim_with_indices(
1201        tensor: QuantizedTensor<B>,
1202        dim: usize,
1203    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1204        let index = B::q_argmin(tensor.clone(), dim);
1205        let values = B::q_gather(dim, tensor, index.clone());
1206
1207        (values, index)
1208    }
1209
1210    /// Gets the maximum element of a tensor.
1211    ///
1212    /// # Arguments
1213    ///
1214    /// * `tensor` - The tensor to get the maximum elements of.
1215    ///
1216    /// # Returns
1217    ///
1218    /// A tensor with the maximum element of `tensor`.
1219    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1220        let shape = tensor.shape();
1221        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1222
1223        B::q_max_abs_dim(tensor, 0)
1224    }
1225
1226    /// Gets the maximum elements of a tensor along an axis.
1227    ///
1228    /// # Arguments
1229    ///
1230    /// * `tensor` - The tensor to get the maximum elements of.
1231    /// * `dim` - The dimension along which to get the maximum elements.
1232    ///
1233    /// # Returns
1234    ///
1235    /// A tensor with the maximum elements of `tensor` along `dim`.
1236    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1237        let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1238
1239        B::q_gather(dim, tensor, index)
1240    }
1241
1242    /// Tests if any element in the `tensor` evaluates to True.
1243    ///
1244    /// # Arguments
1245    ///
1246    /// * `tensor` - The tensor to test.
1247    ///
1248    /// # Returns
1249    ///
1250    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1251    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1252        let tensor_f = Self::dequantize(tensor);
1253        B::float_any(tensor_f)
1254    }
1255
1256    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1257    ///
1258    /// # Arguments
1259    ///
1260    /// * `tensor` - The tensor to test.
1261    /// * `dim` - The axis along which to test.
1262    ///
1263    /// # Returns
1264    ///
1265    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1266    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1267    /// input evaluates to True, False otherwise.
1268    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1269        let tensor_f = Self::dequantize(tensor);
1270        B::float_any_dim(tensor_f, dim)
1271    }
1272
1273    /// Tests if all elements in the `tensor` evaluate to True.
1274    ///
1275    /// # Arguments
1276    ///
1277    /// * `tensor` - The tensor to test.
1278    ///
1279    /// # Returns
1280    ///
1281    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1282    /// evaluate to True, False otherwise.
1283    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1284        let tensor_f = Self::dequantize(tensor);
1285        B::float_all(tensor_f)
1286    }
1287
1288    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1289    ///
1290    /// # Arguments
1291    ///
1292    /// * `tensor` - The tensor to test.
1293    /// * `dim` - The axis along which to test.
1294    ///
1295    /// # Returns
1296    ///
1297    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1298    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1299    /// evaluates to True, False otherwise.
1300    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1301        let tensor_f = Self::dequantize(tensor);
1302        B::float_all_dim(tensor_f, dim)
1303    }
1304
1305    /// Sort the elements of the input `tensor` by value in along a given dimension.
1306    ///
1307    /// This sort is unstable (i.e., may reorder equal elements).
1308    ///
1309    /// # Arguments
1310    ///
1311    /// * `tensor` - The input tensor.
1312    /// * `dim` - The axis along which to sort.
1313    /// * `descending` - The sorting order.
1314    ///
1315    /// # Returns
1316    ///
1317    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1318    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1319        // Default implementation. Backends can sort on the int values since qparams remain the same.
1320        dequant_op_quant!(
1321            ty Self,
1322            float_op |tensor| B::float_sort(tensor, dim, descending),
1323            tensor
1324        )
1325    }
1326
1327    /// Sort the elements of the input `tensor` by value in along a given dimension.
1328    ///
1329    /// This sort is unstable (i.e., may reorder equal elements).
1330    ///
1331    /// # Arguments
1332    ///
1333    /// * `tensor` - The input tensor.
1334    /// * `dim` - The axis along which to sort.
1335    /// * `descending` - The sorting order.
1336    ///
1337    /// # Returns
1338    ///
1339    /// A tensor with the same shape as the input tensor and corresponding indices, where
1340    /// the elements are sorted by value and the indices map back to the original input tensor.
1341    fn q_sort_with_indices(
1342        tensor: QuantizedTensor<B>,
1343        dim: usize,
1344        descending: bool,
1345    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1346        // Default implementation. Backends can sort on the int values since qparams remain the same.
1347        let scheme = *tensor.scheme();
1348
1349        let tensor_f = Self::dequantize(tensor);
1350        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1351
1352        (Self::quantize_dynamic(out_f, &scheme), indices)
1353    }
1354
1355    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1356    ///
1357    /// This sort is unstable (i.e., may reorder equal elements).
1358    ///
1359    /// # Arguments
1360    ///
1361    /// * `tensor` - The input tensor.
1362    /// * `dim` - The axis along which to sort.
1363    /// * `descending` - The sorting order.
1364    ///
1365    /// # Returns
1366    ///
1367    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1368    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1369        // Default implementation. Backends can sort on the int values since qparams remain the same.
1370        let tensor_f = Self::dequantize(tensor);
1371        B::float_argsort(tensor_f, dim, descending)
1372    }
1373}