burn_tensor/tensor/ops/
qtensor.rs

1use alloc::vec::Vec;
2use cubecl_quant::scheme::QuantScheme;
3
4use crate::{
5    Device, Shape, TensorData, TensorMetadata, TensorPrimitive,
6    backend::Backend,
7    quantization::{
8        Calibration, QTensorPrimitive, QuantPropagation, QuantizationParametersPrimitive,
9        compute_q_params_primitive, compute_range_primitive,
10    },
11};
12
13use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
14
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output. The float result is
/// re-quantized dynamically with the scheme of the input (the lhs input for binary ops).
///
/// Each tensor argument is evaluated exactly once, so arbitrary expressions (not only plain
/// identifiers) can be passed without being re-evaluated.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind lhs once: it is read for its scheme and then consumed by `dequantize`.
        // Splicing `$t1` twice would evaluate the caller's expression twice.
        let lhs = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();

        let lhs_f = <$ty>::dequantize(lhs);
        let rhs_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the input once (see the binary arm).
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
47
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input (the lhs input for binary ops) has
/// `QuantPropagation::Propagate`; the result is then re-quantized with that input's scheme
/// and returned as `TensorPrimitive::QFloat`. Otherwise the float result is returned as
/// `TensorPrimitive::Float`.
///
/// Each tensor argument is evaluated exactly once. `QuantPropagation` and `TensorPrimitive`
/// are referenced through `$crate`, so call sites do not need to import them.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind lhs once: it is read for its scheme and propagation, then consumed.
        // Splicing `$t1` repeatedly would evaluate the caller's expression repeatedly.
        let lhs = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();
        let propagation = lhs.propagation();

        let lhs_f = <$ty>::dequantize(lhs);
        let rhs_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);

        match propagation {
            $crate::quantization::QuantPropagation::Propagate => {
                $crate::TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            $crate::quantization::QuantPropagation::Inhibit => {
                $crate::TensorPrimitive::Float(out_f)
            }
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the input once (see the binary arm).
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();
        let propagation = tensor.propagation();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        match propagation {
            $crate::quantization::QuantPropagation::Propagate => {
                $crate::TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            $crate::quantization::QuantPropagation::Inhibit => {
                $crate::TensorPrimitive::Float(out_f)
            }
        }
    }};
}
93
94/// Operations on quantized tensors.
95///
96/// # Return Type Semantics
97///
98/// The return type of each operation indicates how quantization is handled:
99///
100/// ## [`QuantizedTensor<B>`]
101/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
102/// representation. Implementations should avoid dequantizing when possible to maintain performance.
103/// For example, shape or layout changes such as expand or transpose preserve quantization.
104///
105/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
106/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
107/// the quantization parameters to match the new layout.*
108///
109///
110/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme. The output should either remain
/// quantized ([`TensorPrimitive::QFloat`]) or be returned in floating-point form
/// ([`TensorPrimitive::Float`]).
114///
115/// This distinction allows for fine-grained control over mixed-precision flows while still operating
116/// through a unified API.
117pub trait QTensorOps<B: Backend> {
118    /// Creates a new tensor from the data structure.
119    ///
120    /// # Arguments
121    ///
122    /// * `data` - The data structure.
123    /// * `device` - The device to create the tensor on.
124    ///
125    /// # Returns
126    ///
127    /// The tensor with the given data.
128    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
129
130    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
131    fn quantize(
132        tensor: FloatTensor<B>,
133        scheme: &QuantScheme,
134        qparams: QuantizationParametersPrimitive<B>,
135    ) -> QuantizedTensor<B>;
136
137    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
138    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
139        // Dynamically compute min/max tensor range and qparams before quantizing
140        let (min, max) = compute_range_primitive::<B>(scheme, tensor.clone(), &Calibration::MinMax);
141        let qparams = compute_q_params_primitive(scheme, min, max);
142        Self::quantize(tensor, scheme, qparams)
143    }
144
145    /// Convert the tensor back to a higher precision data type.
146    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
147
148    /// Gets the device of the tensor.
149    ///
150    /// # Arguments
151    ///
152    /// * `tensor` - The tensor.
153    ///
154    /// # Returns
155    ///
156    /// The device of the tensor.
157    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
158
159    /// Moves the tensor to the given device.
160    ///
161    /// # Arguments
162    ///
163    /// * `tensor` - The tensor.
164    /// * `device` - The device to move the tensor to.
165    ///
166    /// # Returns
167    ///
168    /// The tensor on the given device.
169    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
170
171    /// Reshapes a tensor.
172    ///
173    /// # Arguments
174    ///
175    /// * `tensor` - The tensor to reshape.
176    /// * `shape` - The new shape of the tensor.
177    ///
178    /// # Returns
179    ///
180    /// The tensor with the new shape.
181    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
182
183    /// Converts the tensor to a data structure.
184    ///
185    /// # Arguments
186    ///
187    /// * `tensor` - The tensor.
188    ///
189    /// # Returns
190    ///
191    /// The data structure with the tensor's data.
192    fn q_into_data(tensor: QuantizedTensor<B>) -> impl Future<Output = TensorData> + Send;
193
194    /// Detaches a tensor from the computation graph.
195    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
196        // Should only be overridden by autodiff backends.
197        tensor
198    }
199
200    /// Sets the `require_grad` flag of a tensor.
201    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
202        // Should only be overridden by autodiff backends.
203        tensor
204    }
205
206    /// Returns the `require_grad` flag of a tensor.
207    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
208        // Should only be overridden by autodiff backends.
209        false
210    }
211
212    /// Broadcasts the `tensor` to the given `shape`.
213    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
214
215    /// Transposes a tensor.
216    ///
217    /// # Arguments
218    ///
219    /// * `tensor` - The tensor to transpose.
220    ///
221    /// # Returns
222    ///
223    /// The transposed tensor.
224    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
225        let ndims = tensor.shape().num_dims();
226        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
227    }
228
229    /// Swaps two dimensions of a tensor.
230    ///
231    /// # Arguments
232    ///
233    /// * `tensor` - The tensor to swap the dimensions of.
234    /// * `dim1` - The first dimension to swap.
235    /// * `dim2` - The second dimension to swap.
236    ///
237    /// # Returns
238    ///
239    /// The tensor with the dimensions swapped.
240    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
241
242    /// Permutes the dimensions of a tensor.
243    ///
244    /// # Arguments
245    ///
246    /// * `tensor` - The tensor to permute the dimensions of.
247    /// * `axes` - The new order of the dimensions.
248    /// # Returns
249    ///
250    /// The tensor with the dimensions permuted.
251    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
252
253    /// Reverse the order of elements in a tensor along the given axes.
254    ///
255    /// # Arguments
256    ///
257    /// * `tensor` - The tensor to reverse.
258    /// * `axes` - The axes to reverse.
259    ///
260    /// The tensor with the elements reversed.
261    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
262
263    /// Select tensor elements along the given dimension corresponding for the given indices.
264    ///
265    /// # Arguments
266    ///
267    /// * `tensor` - The tensor to select from.
268    /// * `dim` - The dimension to select from.
269    /// * `indices` - The indices to select.
270    ///
271    /// # Returns
272    ///
273    /// The selected elements.
274    fn q_select(
275        tensor: QuantizedTensor<B>,
276        dim: usize,
277        indices: IntTensor<B>,
278    ) -> QuantizedTensor<B>;
279
280    /// Select tensor elements corresponding to the given slices.
281    ///
282    /// # Arguments
283    ///
284    /// * `tensor` - The tensor to select from.
285    /// * `slices` - The slices specifying ranges and steps for each dimension.
286    ///
287    /// # Returns
288    ///
289    /// The selected elements in a new tensor.
290    fn q_slice(tensor: QuantizedTensor<B>, slices: &[crate::Slice]) -> QuantizedTensor<B>;
291
292    /// Gather elements from a tensor.
293    ///
294    /// # Arguments
295    ///
296    /// * `dim` - The dimension to gather from.
297    /// * `tensor` - The tensor to gather from.
298    /// * `indices` - The indices to gather.
299    ///
300    /// # Returns
301    ///
302    /// The gathered elements.
303    fn q_gather(
304        dim: usize,
305        tensor: QuantizedTensor<B>,
306        indices: IntTensor<B>,
307    ) -> QuantizedTensor<B> {
308        // Default implementation. Backends can gather on the quantized values when supported.
309        dequant_op_quant!(
310            ty Self,
311            float_op |tensor| B::float_gather(dim, tensor, indices),
312            tensor
313        )
314    }
315
316    /// Repeat the tensor along the given dimension.
317    ///
318    /// # Arguments
319    ///
320    /// * `tensor` - The tensor.
321    /// * `dim` - The dimension to repeat.
322    /// * `times` - The number of times to repeat the dimension.
323    ///
324    /// # Returns
325    ///
326    /// The tensor with the given dimension repeated.
327    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
328        dequant_op_quant!(
329            ty Self,
330            float_op |tensor| B::float_repeat_dim(tensor, dim, times),
331            tensor
332        )
333    }
334
335    /// Adds two tensors together.
336    ///
337    /// # Arguments
338    ///
339    /// * `lhs` - The left hand side tensor.
340    /// * `rhs` - The right hand side tensor.
341    ///
342    /// # Returns
343    ///
344    /// The result of adding the two tensors together.
345    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
346        dequant_op_flow!(
347            ty Self,
348            float_op |lhs, rhs| B::float_add(lhs, rhs),
349            lhs,
350            rhs
351        )
352    }
353
354    /// Adds a scalar to a tensor.
355    ///
356    /// # Arguments
357    ///
358    /// * `lhs` - The left hand side tensor.
359    /// * `rhs` - The right hand side scalar.
360    ///
361    /// # Returns
362    ///
363    /// The result of adding the scalar to the tensor.
364    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
365        dequant_op_flow!(
366            ty Self,
367            float_op |tensor| B::float_add_scalar(tensor, rhs),
368            lhs
369        )
370    }
371
372    /// Clamps a tensor under a minimum value.
373    ///
374    /// # Arguments
375    ///
376    /// * `tensor` - The tensor to clamp.
377    /// * `min` - The minimum value.
378    ///
379    /// # Returns
380    ///
381    /// The clamped tensor.
382    fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
383        dequant_op_flow!(
384            ty Self,
385            float_op |tensor| B::float_clamp_min(tensor, min),
386            tensor
387        )
388    }
389
390    /// Clamps a tensor over a maximum value.
391    ///
392    /// # Arguments
393    ///
394    /// * `tensor` - The tensor to clamp.
395    /// * `max` - The maximum value.
396    ///
397    /// # Returns
398    ///
399    /// The clamped tensor.
400    fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
401        dequant_op_flow!(
402            ty Self,
403            float_op |tensor| B::float_clamp_max(tensor, max),
404            tensor
405        )
406    }
407
408    /// Clamps a tensor between a minimum and maximum value.
409    ///
410    /// # Arguments
411    ///
412    /// * `tensor` - The tensor to clamp.
413    /// * `min` - The minimum value.
414    /// * `max` - The maximum value.
415    ///
416    /// # Returns
417    ///
418    /// The clamped tensor.
419    fn q_clamp(
420        tensor: QuantizedTensor<B>,
421        min: FloatElem<B>,
422        max: FloatElem<B>,
423    ) -> TensorPrimitive<B> {
424        dequant_op_flow!(
425            ty Self,
426            float_op |tensor| B::float_clamp(tensor, min, max),
427            tensor
428        )
429    }
430
431    /// Subtracts two tensors.
432    ///
433    /// # Arguments
434    ///
435    /// * `lhs` - The left hand side tensor.
436    /// * `rhs` - The right hand side tensor.
437    ///
438    /// # Returns
439    ///
440    /// The result of subtracting the two tensors.
441    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
442        dequant_op_flow!(
443            ty Self,
444            float_op |lhs, rhs| B::float_sub(lhs, rhs),
445            lhs,
446            rhs
447        )
448    }
449
450    /// Subtracts a scalar from a tensor.
451    ///
452    /// # Arguments
453    ///
454    /// * `lhs` - The left hand side tensor.
455    /// * `rhs` - The right hand side scalar.
456    ///
457    /// # Returns
458    ///
459    /// The result of subtracting the scalar from the tensor.
460    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
461        dequant_op_flow!(
462            ty Self,
463            float_op |tensor| B::float_sub_scalar(tensor, rhs),
464            lhs
465        )
466    }
467
468    /// Multiplies two tensors together element-wise.
469    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
470        dequant_op_flow!(
471            ty Self,
472            float_op |lhs, rhs| B::float_mul(lhs, rhs),
473            lhs,
474            rhs
475        )
476    }
477
478    /// Multiplies a tensor by a scalar.
479    ///
480    /// # Arguments
481    ///
482    /// * `lhs` - The left hand side tensor.
483    /// * `rhs` - The right hand side scalar.
484    ///
485    /// # Returns
486    ///
487    /// The result of multiplying the tensor by the scalar.
488    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
489        dequant_op_flow!(
490            ty Self,
491            float_op |tensor| B::float_mul_scalar(tensor, rhs),
492            lhs
493        )
494    }
495
496    /// Divides two tensors element-wise.
497    ///
498    /// # Arguments
499    ///
500    /// * `lhs` - The left hand side tensor.
501    /// * `rhs` - The right hand side tensor.
502    ///
503    /// # Returns
504    ///
505    /// The result of dividing the two tensors.
506    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
507        dequant_op_flow!(
508            ty Self,
509            float_op |lhs, rhs| B::float_div(lhs, rhs),
510            lhs,
511            rhs
512        )
513    }
514
515    /// Divides a tensor by a scalar.
516    ///
517    /// # Arguments
518    ///
519    /// * `lhs` - The left hand side tensor.
520    /// * `rhs` - The right hand side scalar.
521    ///
522    /// # Returns
523    ///
524    /// The result of dividing the tensor by the scalar.
525    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
526        dequant_op_flow!(
527            ty Self,
528            float_op |tensor| B::float_div_scalar(tensor, rhs),
529            lhs
530        )
531    }
532
533    /// Multiplies two tensors together using matrix multiplication.
534    ///
535    /// # Arguments
536    ///
537    /// * `lhs` - The left hand side tensor.
538    /// * `rhs` - The right hand side tensor.
539    ///
540    /// # Returns
541    ///
542    /// The result of multiplying the two tensors together using matrix multiplication.
543    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
544        let mut propagation = QuantPropagation::Inhibit;
545        let mut scheme = QuantScheme::default();
546        let lhs = match lhs {
547            TensorPrimitive::Float(lhs) => lhs,
548            TensorPrimitive::QFloat(lhs) => {
549                propagation = lhs.propagation();
550                scheme = *lhs.scheme();
551                Self::dequantize(lhs)
552            }
553        };
554        let rhs = match rhs {
555            TensorPrimitive::Float(rhs) => rhs,
556            TensorPrimitive::QFloat(rhs) => {
557                propagation = rhs.propagation();
558                scheme = *rhs.scheme();
559                Self::dequantize(rhs)
560            }
561        };
562
563        let out_f = B::float_matmul(lhs, rhs);
564        match propagation {
565            QuantPropagation::Propagate => {
566                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
567            }
568            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
569        }
570    }
571
572    /// Negates a tensor element-wise.
573    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
574        dequant_op_flow!(
575            ty Self,
576            float_op |tensor| B::float_neg(tensor),
577            tensor
578        )
579    }
580
581    /// Calculates the reciprocals element-wise
582    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
583        dequant_op_flow!(
584            ty Self,
585            float_op |tensor| B::float_recip(tensor),
586            tensor
587        )
588    }
589
590    /// Sum of all elements in a tensor.
591    ///
592    /// # Arguments
593    ///
594    /// * `tensor` - The tensor to sum.
595    ///
596    /// # Returns
597    ///
598    /// A scalar tensor with the sum of all elements in `tensor`.
599    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
600        dequant_op_flow!(
601            ty Self,
602            float_op |tensor| B::float_sum(tensor),
603            tensor
604        )
605    }
606
607    /// Sum of all elements in a tensor along a dimension.
608    ///
609    /// # Arguments
610    ///
611    /// * `tensor` - The tensor to sum.
612    /// * `dim` - The dimension along which to sum.
613    ///
614    /// # Returns
615    ///
616    /// A tensor with the sum of all elements in `tensor` along `dim`.
617    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
618        dequant_op_flow!(
619            ty Self,
620            float_op |tensor| B::float_sum_dim(tensor, dim),
621            tensor
622        )
623    }
624
625    /// Product of all elements in a tensor.
626    ///
627    /// # Arguments
628    ///
629    /// * `tensor` - The tensor to product.
630    ///
631    /// # Returns
632    ///
633    /// A scalar tensor with the product of all elements in `tensor`.
634    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
635        dequant_op_flow!(
636            ty Self,
637            float_op |tensor| B::float_prod(tensor),
638            tensor
639        )
640    }
641
642    /// Product of all elements in a tensor along a dimension.
643    ///
644    /// # Arguments
645    ///
646    /// * `tensor` - The tensor to product.
647    ///
648    /// # Returns
649    ///
650    /// A tensor with the product of all elements in `tensor` along `dim`.
651    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
652        dequant_op_flow!(
653            ty Self,
654            float_op |tensor| B::float_prod_dim(tensor, dim),
655            tensor
656        )
657    }
658
659    /// Mean of all elements in a tensor.
660    ///
661    /// # Arguments
662    ///
663    /// * `tensor` - The tensor to mean.
664    ///
665    /// # Returns
666    ///
667    /// A scalar tensor with the mean of all elements in `tensor`.
668    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
669        dequant_op_flow!(
670            ty Self,
671            float_op |tensor| B::float_mean(tensor),
672            tensor
673        )
674    }
675
676    /// Mean of all elements in a tensor along a dimension.
677    ///
678    /// # Arguments
679    ///
680    /// * `tensor` - The tensor to mean.
681    /// * `dim` - The dimension along which to mean.
682    ///
683    /// # Returns
684    ///
685    /// A tensor with the mean of all elements in `tensor` along `dim`.
686    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
687        dequant_op_flow!(
688            ty Self,
689            float_op |tensor| B::float_mean_dim(tensor, dim),
690            tensor
691        )
692    }
693
694    /// Computes the cumulative sum of elements along a dimension.
695    ///
696    /// # Arguments
697    ///
698    /// * `tensor` - The tensor to compute the cumulative sum of.
699    /// * `dim` - The dimension along which to compute the cumulative sum.
700    ///
701    /// # Returns
702    ///
703    /// A tensor with the same shape where each element is the cumulative sum
704    /// of all elements up to and including that position along the dimension.
705    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
706        dequant_op_flow!(
707            ty Self,
708            float_op |tensor| B::float_cumsum(tensor, dim),
709            tensor
710        )
711    }
712
713    /// Computes the cumulative product of elements along a dimension.
714    ///
715    /// # Arguments
716    ///
717    /// * `tensor` - The tensor to compute the cumulative product of.
718    /// * `dim` - The dimension along which to compute the cumulative product.
719    ///
720    /// # Returns
721    ///
722    /// A tensor with the same shape where each element is the cumulative product
723    /// of all elements up to and including that position along the dimension.
724    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
725        dequant_op_flow!(
726            ty Self,
727            float_op |tensor| B::float_cumprod(tensor, dim),
728            tensor
729        )
730    }
731
732    /// Computes the cumulative minimum of elements along a dimension.
733    ///
734    /// # Arguments
735    ///
736    /// * `tensor` - The tensor to compute the cumulative minimum of.
737    /// * `dim` - The dimension along which to compute the cumulative minimum.
738    ///
739    /// # Returns
740    ///
741    /// A tensor with the same shape where each element is the minimum
742    /// of all elements up to and including that position along the dimension.
743    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
744        dequant_op_flow!(
745            ty Self,
746            float_op |tensor| B::float_cummin(tensor, dim),
747            tensor
748        )
749    }
750
751    /// Computes the cumulative maximum of elements along a dimension.
752    ///
753    /// # Arguments
754    ///
755    /// * `tensor` - The tensor to compute the cumulative maximum of.
756    /// * `dim` - The dimension along which to compute the cumulative maximum.
757    ///
758    /// # Returns
759    ///
760    /// A tensor with the same shape where each element is the maximum
761    /// of all elements up to and including that position along the dimension.
762    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
763        dequant_op_flow!(
764            ty Self,
765            float_op |tensor| B::float_cummax(tensor, dim),
766            tensor
767        )
768    }
769
770    /// Returns a new tensor with exponential values.
771    ///
772    /// # Arguments
773    ///
774    /// * `tensor` - The tensor to exponentiate.
775    ///
776    /// # Returns
777    ///
778    /// A tensor with the same shape as `tensor` with exponential values.
779    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
780        dequant_op_flow!(
781            ty Self,
782            float_op |tensor| B::float_exp(tensor),
783            tensor
784        )
785    }
786
787    /// Returns a new tensor with natural logarithm values.
788    ///
789    /// # Arguments
790    ///
791    /// * `tensor` - The tensor to take the logarithm of.
792    ///
793    /// # Returns
794    ///
795    /// A tensor with the same shape as `tensor` with natural logarithm values.
796    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
797        dequant_op_flow!(
798            ty Self,
799            float_op |tensor| B::float_log(tensor),
800            tensor
801        )
802    }
803
804    /// Returns a new tensor with logarithm values of (1 + Xi).
805    ///
806    /// # Arguments
807    ///
808    /// * `tensor` - The tensor to take the logarithm of.
809    ///
810    /// # Returns
811    ///
812    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
813    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
814        dequant_op_flow!(
815            ty Self,
816            float_op |tensor| B::float_log1p(tensor),
817            tensor
818        )
819    }
820
821    /// Element-wise power with another tensor.
822    ///
823    /// # Arguments
824    ///
825    /// * `lhs` - The left hand side tensor.
826    /// * `rhs` - The right hand side tensor.
827    ///
828    /// # Returns
829    ///
830    /// The elements of `lhs` raised to the power of the elements of `rhs`.
831    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
832        dequant_op_flow!(
833            ty Self,
834            float_op |lhs, rhs| B::float_powf(lhs, rhs),
835            lhs,
836            rhs
837        )
838    }
839
840    /// Element-wise power with an IntTensor.
841    ///
842    /// # Arguments
843    ///
844    /// * `lhs` - The left hand side tensor.
845    /// * `rhs` - The right hand side floatTensor.
846    ///
847    /// # Returns
848    ///
849    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
850    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
851        dequant_op_flow!(
852            ty Self,
853            float_op |tensor| B::float_powi(tensor, rhs),
854            lhs
855        )
856    }
857
858    /// Element-wise power with an int scalar.
859    ///
860    /// # Arguments
861    ///
862    /// * `lhs` - The left hand side tensor.
863    /// * `rhs` - The right hand side scalar.
864    ///
865    /// # Returns
866    ///
867    /// The elements of `lhs` raised to the value of `rhs`.
868    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
869        dequant_op_flow!(
870            ty Self,
871            float_op |tensor| B::float_powi_scalar(tensor, rhs),
872            lhs
873        )
874    }
875
876    /// Element-wise power with a float scalar.
877    ///
878    /// # Arguments
879    ///
880    /// * `tensor` - The tensor to exponentiate.
881    /// * `value` - The exponent.
882    ///
883    /// # Returns
884    ///
885    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
886    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
887        dequant_op_flow!(
888            ty Self,
889            float_op |tensor| B::float_powf_scalar(tensor, value),
890            tensor
891        )
892    }
893
894    /// Returns a new tensor with square root values.
895    ///
896    /// # Arguments
897    ///
898    /// * `tensor` - The tensor to take the square root of.
899    ///
900    /// # Returns
901    ///
902    /// A tensor with the same shape as `tensor` with square root values.
903    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
904        dequant_op_flow!(
905            ty Self,
906            float_op |tensor| B::float_sqrt(tensor),
907            tensor
908        )
909    }
910
911    /// Returns a new tensor with absolute values.
912    ///
913    /// # Arguments
914    ///
915    /// * `tensor` - The tensor to take absolute value of.
916    ///
917    /// # Returns
918    ///
919    /// A tensor with the same shape as `tensor` with absolute values.
920    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
921        dequant_op_quant!(
922            ty Self,
923            float_op |tensor| B::float_abs(tensor),
924            tensor
925        )
926    }
927
928    /// Returns a new tensor with cosine values.
929    ///
930    /// # Arguments
931    ///
932    /// * `tensor` - The tensor to take the cosine of.
933    ///
934    /// # Returns
935    ///
936    /// A tensor with the same shape as `tensor` with cosine values.
937    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
938        dequant_op_flow!(
939            ty Self,
940            float_op |tensor| B::float_cos(tensor),
941            tensor
942        )
943    }
944
945    /// Returns a new tensor with sine values.
946    ///
947    /// # Arguments
948    ///
949    /// * `tensor` - The tensor to take the sine of.
950    ///
951    /// # Returns
952    ///
953    /// A tensor with the same shape as `tensor` with sine values.
954    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
955        dequant_op_flow!(
956            ty Self,
957            float_op |tensor| B::float_sin(tensor),
958            tensor
959        )
960    }
961
962    /// Returns a new tensor with tangent values.
963    ///
964    /// # Arguments
965    ///
966    /// * `tensor` - The tensor to take the tangent of.
967    ///
968    /// # Returns
969    ///
970    /// A tensor with the same shape as `tensor` with tangent values.
971    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
972        dequant_op_flow!(
973            ty Self,
974            float_op |tensor| B::float_tan(tensor),
975            tensor
976        )
977    }
978
979    /// Returns a new tensor with hyperbolic cosine values.
980    ///
981    /// # Arguments
982    ///
983    /// * `tensor` - The tensor to take the hyperbolic cosine of.
984    ///
985    /// # Returns
986    ///
987    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
988    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
989        dequant_op_flow!(
990            ty Self,
991            float_op |tensor| B::float_cosh(tensor),
992            tensor
993        )
994    }
995
996    /// Returns a new tensor with hyperbolic sine values.
997    ///
998    /// # Arguments
999    ///
1000    /// * `tensor` - The tensor to take the hyperbolic sine of.
1001    ///
1002    /// # Returns
1003    ///
1004    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1005    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1006        dequant_op_flow!(
1007            ty Self,
1008            float_op |tensor| B::float_sinh(tensor),
1009            tensor
1010        )
1011    }
1012
1013    /// Returns a new tensor with hyperbolic tangent values.
1014    ///
1015    /// # Arguments
1016    ///
1017    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1018    ///
1019    /// # Returns
1020    ///
1021    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1022    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1023        dequant_op_flow!(
1024            ty Self,
1025            float_op |tensor| B::float_tanh(tensor),
1026            tensor
1027        )
1028    }
1029
1030    /// Returns a new tensor with the error function values.
1031    ///
1032    /// # Arguments
1033    ///
1034    /// * `tensor` - The tensor to take the error function of.
1035    ///
1036    /// # Returns
1037    ///
1038    /// A tensor with the same shape as `tensor` with error function values.
1039    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1040        dequant_op_flow!(
1041            ty Self,
1042            float_op |tensor| B::float_erf(tensor),
1043            tensor
1044        )
1045    }
1046
1047    /// Concatenates tensors along a dimension.
1048    ///
1049    /// # Arguments
1050    ///
1051    /// * `tensors` - The tensors to concatenate.
1052    /// * `dim` - The dimension along which to concatenate.
1053    ///
1054    /// # Returns
1055    ///
1056    /// A tensor with the concatenated tensors along `dim`.
1057    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1058        // Heuristic: prioritize first tensor scheme
1059        let scheme = *tensors.first().unwrap().scheme();
1060
1061        let tensor_f = tensors
1062            .into_iter()
1063            .map(|tensor| Self::dequantize(tensor))
1064            .collect();
1065
1066        let out_f = B::float_cat(tensor_f, dim);
1067
1068        Self::quantize_dynamic(out_f, &scheme)
1069    }
1070
1071    /// Gets the indices of the maximum elements of a tensor along an axis.
1072    ///
1073    /// # Arguments
1074    ///
1075    /// * `tensor` - The tensor to get the maximum elements of.
1076    /// * `dim` - The dimension along which to get the maximum elements.
1077    ///
1078    /// # Returns
1079    ///
1080    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1081    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1082        // Default implementation. Backends can sort on the int values since qparams remain the same.
1083        let tensor_f = Self::dequantize(tensor);
1084        B::float_argmax(tensor_f, dim)
1085    }
1086
1087    /// Gets the indices of the minimum elements of a tensor along an axis.
1088    ///
1089    /// # Arguments
1090    ///
1091    /// * `tensor` - The tensor to get the minimum elements of.
1092    /// * `dim` - The dimension along which to get the minimum elements.
1093    ///
1094    /// # Returns
1095    ///
1096    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1097    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1098        // Default implementation. Backends can sort on the int values since qparams remain the same.
1099        let tensor_f = Self::dequantize(tensor);
1100        B::float_argmin(tensor_f, dim)
1101    }
1102
1103    /// Gets the maximum element of a tensor.
1104    ///
1105    /// # Arguments
1106    ///
1107    /// * `tensor` - The tensor to get the maximum elements of.
1108    ///
1109    /// # Returns
1110    ///
1111    /// A tensor with the maximum element of `tensor`.
1112    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1113        let shape = tensor.shape();
1114        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1115
1116        B::q_max_dim(tensor, 0)
1117    }
1118
1119    /// Gets the maximum elements of a tensor along an axis.
1120    ///
1121    /// # Arguments
1122    ///
1123    /// * `tensor` - The tensor to get the maximum elements of.
1124    /// * `dim` - The dimension along which to get the maximum elements.
1125    ///
1126    /// # Returns
1127    ///
1128    /// A tensor with the maximum elements of `tensor` along `dim`.
1129    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1130        let index = B::q_argmax(tensor.clone(), dim);
1131
1132        B::q_gather(dim, tensor, index)
1133    }
1134
1135    /// Gets the maximum elements of a tensor along an axis and their indices.
1136    ///
1137    /// # Arguments
1138    ///
1139    /// * `tensor` - The tensor to get the maximum elements of.
1140    /// * `dim` - The dimension along which to get the maximum elements.
1141    ///
1142    /// # Returns
1143    ///
1144    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1145    fn q_max_dim_with_indices(
1146        tensor: QuantizedTensor<B>,
1147        dim: usize,
1148    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1149        let index = B::q_argmax(tensor.clone(), dim);
1150        let values = B::q_gather(dim, tensor, index.clone());
1151
1152        (values, index)
1153    }
1154
1155    /// Gets the minimum element of a tensor.
1156    ///
1157    /// # Arguments
1158    ///
1159    /// * `tensor` - The tensor to get the minimum elements of.
1160    ///
1161    /// # Returns
1162    ///
1163    /// A tensor with the minimum element of `tensor`.
1164    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1165        let shape = tensor.shape();
1166        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1167
1168        B::q_min_dim(tensor, 0)
1169    }
1170
1171    /// Gets the minimum elements of a tensor along an axis.
1172    ///
1173    /// # Arguments
1174    ///
1175    /// * `tensor` - The tensor to get the minimum elements of.
1176    /// * `dim` - The dimension along which to get the minimum elements.
1177    ///
1178    /// # Returns
1179    ///
1180    /// A tensor with the minimum elements of `tensor` along `dim`.
1181    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1182        let index = B::q_argmin(tensor.clone(), dim);
1183
1184        B::q_gather(dim, tensor, index)
1185    }
1186
1187    /// Gets the minimum elements of a tensor along an axis and their indices.
1188    ///
1189    /// # Arguments
1190    ///
1191    /// * `tensor` - The tensor to get the minimum elements of.
1192    /// * `dim` - The dimension along which to get the minimum elements.
1193    ///
1194    /// # Returns
1195    ///
1196    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1197    fn q_min_dim_with_indices(
1198        tensor: QuantizedTensor<B>,
1199        dim: usize,
1200    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1201        let index = B::q_argmin(tensor.clone(), dim);
1202        let values = B::q_gather(dim, tensor, index.clone());
1203
1204        (values, index)
1205    }
1206
1207    /// Gets the maximum element of a tensor.
1208    ///
1209    /// # Arguments
1210    ///
1211    /// * `tensor` - The tensor to get the maximum elements of.
1212    ///
1213    /// # Returns
1214    ///
1215    /// A tensor with the maximum element of `tensor`.
1216    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1217        let shape = tensor.shape();
1218        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1219
1220        B::q_max_abs_dim(tensor, 0)
1221    }
1222
1223    /// Gets the maximum elements of a tensor along an axis.
1224    ///
1225    /// # Arguments
1226    ///
1227    /// * `tensor` - The tensor to get the maximum elements of.
1228    /// * `dim` - The dimension along which to get the maximum elements.
1229    ///
1230    /// # Returns
1231    ///
1232    /// A tensor with the maximum elements of `tensor` along `dim`.
1233    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1234        let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1235
1236        B::q_gather(dim, tensor, index)
1237    }
1238
1239    /// Tests if any element in the `tensor` evaluates to True.
1240    ///
1241    /// # Arguments
1242    ///
1243    /// * `tensor` - The tensor to test.
1244    ///
1245    /// # Returns
1246    ///
1247    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1248    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1249        let tensor_f = Self::dequantize(tensor);
1250        B::float_any(tensor_f)
1251    }
1252
1253    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1254    ///
1255    /// # Arguments
1256    ///
1257    /// * `tensor` - The tensor to test.
1258    /// * `dim` - The axis along which to test.
1259    ///
1260    /// # Returns
1261    ///
1262    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1263    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1264    /// input evaluates to True, False otherwise.
1265    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1266        let tensor_f = Self::dequantize(tensor);
1267        B::float_any_dim(tensor_f, dim)
1268    }
1269
1270    /// Tests if all elements in the `tensor` evaluate to True.
1271    ///
1272    /// # Arguments
1273    ///
1274    /// * `tensor` - The tensor to test.
1275    ///
1276    /// # Returns
1277    ///
1278    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1279    /// evaluate to True, False otherwise.
1280    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1281        let tensor_f = Self::dequantize(tensor);
1282        B::float_all(tensor_f)
1283    }
1284
1285    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1286    ///
1287    /// # Arguments
1288    ///
1289    /// * `tensor` - The tensor to test.
1290    /// * `dim` - The axis along which to test.
1291    ///
1292    /// # Returns
1293    ///
1294    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1295    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1296    /// evaluates to True, False otherwise.
1297    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1298        let tensor_f = Self::dequantize(tensor);
1299        B::float_all_dim(tensor_f, dim)
1300    }
1301
1302    /// Sort the elements of the input `tensor` by value in along a given dimension.
1303    ///
1304    /// This sort is unstable (i.e., may reorder equal elements).
1305    ///
1306    /// # Arguments
1307    ///
1308    /// * `tensor` - The input tensor.
1309    /// * `dim` - The axis along which to sort.
1310    /// * `descending` - The sorting order.
1311    ///
1312    /// # Returns
1313    ///
1314    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1315    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1316        // Default implementation. Backends can sort on the int values since qparams remain the same.
1317        dequant_op_quant!(
1318            ty Self,
1319            float_op |tensor| B::float_sort(tensor, dim, descending),
1320            tensor
1321        )
1322    }
1323
1324    /// Sort the elements of the input `tensor` by value in along a given dimension.
1325    ///
1326    /// This sort is unstable (i.e., may reorder equal elements).
1327    ///
1328    /// # Arguments
1329    ///
1330    /// * `tensor` - The input tensor.
1331    /// * `dim` - The axis along which to sort.
1332    /// * `descending` - The sorting order.
1333    ///
1334    /// # Returns
1335    ///
1336    /// A tensor with the same shape as the input tensor and corresponding indices, where
1337    /// the elements are sorted by value and the indices map back to the original input tensor.
1338    fn q_sort_with_indices(
1339        tensor: QuantizedTensor<B>,
1340        dim: usize,
1341        descending: bool,
1342    ) -> (QuantizedTensor<B>, IntTensor<B>) {
1343        // Default implementation. Backends can sort on the int values since qparams remain the same.
1344        let scheme = *tensor.scheme();
1345
1346        let tensor_f = Self::dequantize(tensor);
1347        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1348
1349        (Self::quantize_dynamic(out_f, &scheme), indices)
1350    }
1351
1352    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1353    ///
1354    /// This sort is unstable (i.e., may reorder equal elements).
1355    ///
1356    /// # Arguments
1357    ///
1358    /// * `tensor` - The input tensor.
1359    /// * `dim` - The axis along which to sort.
1360    /// * `descending` - The sorting order.
1361    ///
1362    /// # Returns
1363    ///
1364    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1365    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1366        // Default implementation. Backends can sort on the int values since qparams remain the same.
1367        let tensor_f = Self::dequantize(tensor);
1368        B::float_argsort(tensor_f, dim, descending)
1369    }
1370}