burn_backend/backend/ops/qtensor.rs
1use alloc::vec::Vec;
2use burn_std::{
3 Shape, Slice,
4 quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8 Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9 tensor::{Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range},
10};
11
12use crate::tensor::{
13 BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
14};
15
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// # Forms
///
/// * `dequant_op_quant!(ty T, float_op f, lhs, rhs)` - binary op; the output is re-quantized with
///   the scheme of `lhs`.
/// * `dequant_op_quant!(ty T, float_op f, tensor)` - unary op; the output is re-quantized with the
///   scheme of `tensor`.
///
/// `T` must provide `dequantize` and `quantize_dynamic` associated functions, and the tensor
/// arguments must provide a `scheme()` method. Each tensor argument expression is evaluated
/// exactly once.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind each operand once: the operand is used for both `scheme()` and `dequantize`, and
        // re-expanding `$t1` would evaluate the argument expression twice.
        let t1 = $t1;
        let t2 = $t2;
        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize(t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once (see the binary arm).
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
48
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional. It is only performed when the propagation setting
/// of the (left-hand side) input is `QuantPropagation::Propagate`. Otherwise the floating-point
/// result is returned as-is.
///
/// # Forms
///
/// * `dequant_op_flow!(ty T, float_op f, lhs, rhs)` - binary op; the scheme and propagation
///   setting of `lhs` are used for the output.
/// * `dequant_op_flow!(ty T, float_op f, tensor)` - unary op; the scheme and propagation setting
///   of `tensor` are used for the output.
///
/// The expansion names `TensorPrimitive` and `QuantPropagation` unqualified, so both must be in
/// scope at the call site. `T` must provide `dequantize` and `quantize_dynamic`, and the tensor
/// arguments must provide `scheme()` and `propagation()`. Each tensor argument expression is
/// evaluated exactly once.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind each operand once: the lhs is needed for `scheme()`, `propagation()` and
        // `dequantize`, and re-expanding `$t1` would evaluate the argument expression repeatedly.
        let t1 = $t1;
        let t2 = $t2;
        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();
        let propagation = t1.propagation();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize(t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once (see the binary arm).
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();
        let propagation = tensor.propagation();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
94
95/// Operations on quantized tensors.
96///
97/// # Return Type Semantics
98///
99/// The return type of each operation indicates how quantization is handled:
100///
101/// ## [`QuantizedTensor<B>`]
102/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
103/// representation. Implementations should avoid dequantizing when possible to maintain performance.
104/// For example, shape or layout changes such as expand or transpose preserve quantization.
105///
106/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
107/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
108/// the quantization parameters to match the new layout.*
109///
110///
111/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme: the output either remains quantized
/// ([`TensorPrimitive::QFloat`]) or is returned in floating-point form ([`TensorPrimitive::Float`]).
115///
116/// This distinction allows for fine-grained control over mixed-precision flows while still operating
117/// through a unified API.
118pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor holding the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
130
    /// Converts the tensor to a lower precision data type based on the quantization scheme and
    /// the provided quantization parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The floating-point tensor to quantize.
    /// * `scheme` - The quantization scheme to apply.
    /// * `qparams` - The quantization parameters to use.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
137
138 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
139 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
140 // Dynamically compute min/max tensor range and qparams before quantizing
141 let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
142 let qparams = compute_q_params(scheme, min, max);
143 Self::quantize(tensor, scheme, qparams)
144 }
145
    /// Converts the quantized tensor back to a higher precision (floating-point) data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    ///
    /// # Returns
    ///
    /// The floating-point tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
148
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device the tensor lives on.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
159
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
171
    /// Reshapes a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
183
    /// Converts the quantized tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data, or an
    /// [`ExecutionError`] (presumably raised when the backend fails to produce the data —
    /// NOTE(review): confirm against implementors).
    fn q_into_data(
        tensor: QuantizedTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
196
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns the tensor unchanged.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
202
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores the flag and returns the tensor unchanged.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
208
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always returns `false`.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
214
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcast tensor.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
217
218 /// Transposes a tensor.
219 ///
220 /// # Arguments
221 ///
222 /// * `tensor` - The tensor to transpose.
223 ///
224 /// # Returns
225 ///
226 /// The transposed tensor.
227 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
228 let ndims = tensor.shape().num_dims();
229 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
230 }
231
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
244
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
255
    /// Reverses the order of elements in a quantized tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
265
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
282
    /// Selects tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
294
295 /// Gather elements from a tensor.
296 ///
297 /// # Arguments
298 ///
299 /// * `dim` - The dimension to gather from.
300 /// * `tensor` - The tensor to gather from.
301 /// * `indices` - The indices to gather.
302 ///
303 /// # Returns
304 ///
305 /// The gathered elements.
306 fn q_gather(
307 dim: usize,
308 tensor: QuantizedTensor<B>,
309 indices: IntTensor<B>,
310 ) -> QuantizedTensor<B> {
311 // Default implementation. Backends can gather on the quantized values when supported.
312 dequant_op_quant!(
313 ty Self,
314 float_op |tensor| B::float_gather(dim, tensor, indices),
315 tensor
316 )
317 }
318
319 /// Repeat the tensor along the given dimension.
320 ///
321 /// # Arguments
322 ///
323 /// * `tensor` - The tensor.
324 /// * `dim` - The dimension to repeat.
325 /// * `times` - The number of times to repeat the dimension.
326 ///
327 /// # Returns
328 ///
329 /// The tensor with the given dimension repeated.
330 fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
331 dequant_op_quant!(
332 ty Self,
333 float_op |tensor| B::float_repeat_dim(tensor, dim, times),
334 tensor
335 )
336 }
337
338 /// Adds two tensors together.
339 ///
340 /// # Arguments
341 ///
342 /// * `lhs` - The left hand side tensor.
343 /// * `rhs` - The right hand side tensor.
344 ///
345 /// # Returns
346 ///
347 /// The result of adding the two tensors together.
348 fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
349 dequant_op_flow!(
350 ty Self,
351 float_op |lhs, rhs| B::float_add(lhs, rhs),
352 lhs,
353 rhs
354 )
355 }
356
357 /// Adds a scalar to a tensor.
358 ///
359 /// # Arguments
360 ///
361 /// * `lhs` - The left hand side tensor.
362 /// * `rhs` - The right hand side scalar.
363 ///
364 /// # Returns
365 ///
366 /// The result of adding the scalar to the tensor.
367 fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
368 dequant_op_flow!(
369 ty Self,
370 float_op |tensor| B::float_add_scalar(tensor, rhs),
371 lhs
372 )
373 }
374
375 /// Clamps a tensor under a minimum value.
376 ///
377 /// # Arguments
378 ///
379 /// * `tensor` - The tensor to clamp.
380 /// * `min` - The minimum value.
381 ///
382 /// # Returns
383 ///
384 /// The clamped tensor.
385 fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
386 dequant_op_flow!(
387 ty Self,
388 float_op |tensor| B::float_clamp_min(tensor, min),
389 tensor
390 )
391 }
392
393 /// Clamps a tensor over a maximum value.
394 ///
395 /// # Arguments
396 ///
397 /// * `tensor` - The tensor to clamp.
398 /// * `max` - The maximum value.
399 ///
400 /// # Returns
401 ///
402 /// The clamped tensor.
403 fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
404 dequant_op_flow!(
405 ty Self,
406 float_op |tensor| B::float_clamp_max(tensor, max),
407 tensor
408 )
409 }
410
411 /// Clamps a tensor between a minimum and maximum value.
412 ///
413 /// # Arguments
414 ///
415 /// * `tensor` - The tensor to clamp.
416 /// * `min` - The minimum value.
417 /// * `max` - The maximum value.
418 ///
419 /// # Returns
420 ///
421 /// The clamped tensor.
422 fn q_clamp(
423 tensor: QuantizedTensor<B>,
424 min: FloatElem<B>,
425 max: FloatElem<B>,
426 ) -> TensorPrimitive<B> {
427 dequant_op_flow!(
428 ty Self,
429 float_op |tensor| B::float_clamp(tensor, min, max),
430 tensor
431 )
432 }
433
434 /// Subtracts two tensors.
435 ///
436 /// # Arguments
437 ///
438 /// * `lhs` - The left hand side tensor.
439 /// * `rhs` - The right hand side tensor.
440 ///
441 /// # Returns
442 ///
443 /// The result of subtracting the two tensors.
444 fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
445 dequant_op_flow!(
446 ty Self,
447 float_op |lhs, rhs| B::float_sub(lhs, rhs),
448 lhs,
449 rhs
450 )
451 }
452
453 /// Subtracts a scalar from a tensor.
454 ///
455 /// # Arguments
456 ///
457 /// * `lhs` - The left hand side tensor.
458 /// * `rhs` - The right hand side scalar.
459 ///
460 /// # Returns
461 ///
462 /// The result of subtracting the scalar from the tensor.
463 fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
464 dequant_op_flow!(
465 ty Self,
466 float_op |tensor| B::float_sub_scalar(tensor, rhs),
467 lhs
468 )
469 }
470
471 /// Multiplies two tensors together element-wise.
472 fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
473 dequant_op_flow!(
474 ty Self,
475 float_op |lhs, rhs| B::float_mul(lhs, rhs),
476 lhs,
477 rhs
478 )
479 }
480
481 /// Multiplies a tensor by a scalar.
482 ///
483 /// # Arguments
484 ///
485 /// * `lhs` - The left hand side tensor.
486 /// * `rhs` - The right hand side scalar.
487 ///
488 /// # Returns
489 ///
490 /// The result of multiplying the tensor by the scalar.
491 fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
492 dequant_op_flow!(
493 ty Self,
494 float_op |tensor| B::float_mul_scalar(tensor, rhs),
495 lhs
496 )
497 }
498
499 /// Divides two tensors element-wise.
500 ///
501 /// # Arguments
502 ///
503 /// * `lhs` - The left hand side tensor.
504 /// * `rhs` - The right hand side tensor.
505 ///
506 /// # Returns
507 ///
508 /// The result of dividing the two tensors.
509 fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
510 dequant_op_flow!(
511 ty Self,
512 float_op |lhs, rhs| B::float_div(lhs, rhs),
513 lhs,
514 rhs
515 )
516 }
517
518 /// Divides a tensor by a scalar.
519 ///
520 /// # Arguments
521 ///
522 /// * `lhs` - The left hand side tensor.
523 /// * `rhs` - The right hand side scalar.
524 ///
525 /// # Returns
526 ///
527 /// The result of dividing the tensor by the scalar.
528 fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
529 dequant_op_flow!(
530 ty Self,
531 float_op |tensor| B::float_div_scalar(tensor, rhs),
532 lhs
533 )
534 }
535
536 /// Multiplies two tensors together using matrix multiplication.
537 ///
538 /// # Arguments
539 ///
540 /// * `lhs` - The left hand side tensor.
541 /// * `rhs` - The right hand side tensor.
542 ///
543 /// # Returns
544 ///
545 /// The result of multiplying the two tensors together using matrix multiplication.
546 fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
547 let mut propagation = QuantPropagation::Inhibit;
548 let mut scheme = QuantScheme::default();
549 let lhs = match lhs {
550 TensorPrimitive::Float(lhs) => lhs,
551 TensorPrimitive::QFloat(lhs) => {
552 propagation = lhs.propagation();
553 scheme = *lhs.scheme();
554 Self::dequantize(lhs)
555 }
556 };
557 let rhs = match rhs {
558 TensorPrimitive::Float(rhs) => rhs,
559 TensorPrimitive::QFloat(rhs) => {
560 propagation = rhs.propagation();
561 scheme = *rhs.scheme();
562 Self::dequantize(rhs)
563 }
564 };
565
566 let out_f = B::float_matmul(lhs, rhs);
567 match propagation {
568 QuantPropagation::Propagate => {
569 TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
570 }
571 QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
572 }
573 }
574
575 /// Negates a tensor element-wise.
576 fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
577 dequant_op_flow!(
578 ty Self,
579 float_op |tensor| B::float_neg(tensor),
580 tensor
581 )
582 }
583
584 /// Calculates the reciprocals element-wise
585 fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
586 dequant_op_flow!(
587 ty Self,
588 float_op |tensor| B::float_recip(tensor),
589 tensor
590 )
591 }
592
593 /// Sum of all elements in a tensor.
594 ///
595 /// # Arguments
596 ///
597 /// * `tensor` - The tensor to sum.
598 ///
599 /// # Returns
600 ///
601 /// A scalar tensor with the sum of all elements in `tensor`.
602 fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
603 dequant_op_flow!(
604 ty Self,
605 float_op |tensor| B::float_sum(tensor),
606 tensor
607 )
608 }
609
610 /// Sum of all elements in a tensor along a dimension.
611 ///
612 /// # Arguments
613 ///
614 /// * `tensor` - The tensor to sum.
615 /// * `dim` - The dimension along which to sum.
616 ///
617 /// # Returns
618 ///
619 /// A tensor with the sum of all elements in `tensor` along `dim`.
620 fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
621 dequant_op_flow!(
622 ty Self,
623 float_op |tensor| B::float_sum_dim(tensor, dim),
624 tensor
625 )
626 }
627
628 /// Product of all elements in a tensor.
629 ///
630 /// # Arguments
631 ///
632 /// * `tensor` - The tensor to product.
633 ///
634 /// # Returns
635 ///
636 /// A scalar tensor with the product of all elements in `tensor`.
637 fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
638 dequant_op_flow!(
639 ty Self,
640 float_op |tensor| B::float_prod(tensor),
641 tensor
642 )
643 }
644
645 /// Product of all elements in a tensor along a dimension.
646 ///
647 /// # Arguments
648 ///
649 /// * `tensor` - The tensor to product.
650 ///
651 /// # Returns
652 ///
653 /// A tensor with the product of all elements in `tensor` along `dim`.
654 fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
655 dequant_op_flow!(
656 ty Self,
657 float_op |tensor| B::float_prod_dim(tensor, dim),
658 tensor
659 )
660 }
661
662 /// Mean of all elements in a tensor.
663 ///
664 /// # Arguments
665 ///
666 /// * `tensor` - The tensor to mean.
667 ///
668 /// # Returns
669 ///
670 /// A scalar tensor with the mean of all elements in `tensor`.
671 fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
672 dequant_op_flow!(
673 ty Self,
674 float_op |tensor| B::float_mean(tensor),
675 tensor
676 )
677 }
678
679 /// Mean of all elements in a tensor along a dimension.
680 ///
681 /// # Arguments
682 ///
683 /// * `tensor` - The tensor to mean.
684 /// * `dim` - The dimension along which to mean.
685 ///
686 /// # Returns
687 ///
688 /// A tensor with the mean of all elements in `tensor` along `dim`.
689 fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
690 dequant_op_flow!(
691 ty Self,
692 float_op |tensor| B::float_mean_dim(tensor, dim),
693 tensor
694 )
695 }
696
697 /// Computes the cumulative sum of elements along a dimension.
698 ///
699 /// # Arguments
700 ///
701 /// * `tensor` - The tensor to compute the cumulative sum of.
702 /// * `dim` - The dimension along which to compute the cumulative sum.
703 ///
704 /// # Returns
705 ///
706 /// A tensor with the same shape where each element is the cumulative sum
707 /// of all elements up to and including that position along the dimension.
708 fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
709 dequant_op_flow!(
710 ty Self,
711 float_op |tensor| B::float_cumsum(tensor, dim),
712 tensor
713 )
714 }
715
716 /// Computes the cumulative product of elements along a dimension.
717 ///
718 /// # Arguments
719 ///
720 /// * `tensor` - The tensor to compute the cumulative product of.
721 /// * `dim` - The dimension along which to compute the cumulative product.
722 ///
723 /// # Returns
724 ///
725 /// A tensor with the same shape where each element is the cumulative product
726 /// of all elements up to and including that position along the dimension.
727 fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
728 dequant_op_flow!(
729 ty Self,
730 float_op |tensor| B::float_cumprod(tensor, dim),
731 tensor
732 )
733 }
734
735 /// Computes the cumulative minimum of elements along a dimension.
736 ///
737 /// # Arguments
738 ///
739 /// * `tensor` - The tensor to compute the cumulative minimum of.
740 /// * `dim` - The dimension along which to compute the cumulative minimum.
741 ///
742 /// # Returns
743 ///
744 /// A tensor with the same shape where each element is the minimum
745 /// of all elements up to and including that position along the dimension.
746 fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
747 dequant_op_flow!(
748 ty Self,
749 float_op |tensor| B::float_cummin(tensor, dim),
750 tensor
751 )
752 }
753
754 /// Computes the cumulative maximum of elements along a dimension.
755 ///
756 /// # Arguments
757 ///
758 /// * `tensor` - The tensor to compute the cumulative maximum of.
759 /// * `dim` - The dimension along which to compute the cumulative maximum.
760 ///
761 /// # Returns
762 ///
763 /// A tensor with the same shape where each element is the maximum
764 /// of all elements up to and including that position along the dimension.
765 fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
766 dequant_op_flow!(
767 ty Self,
768 float_op |tensor| B::float_cummax(tensor, dim),
769 tensor
770 )
771 }
772
773 /// Returns a new tensor with exponential values.
774 ///
775 /// # Arguments
776 ///
777 /// * `tensor` - The tensor to exponentiate.
778 ///
779 /// # Returns
780 ///
781 /// A tensor with the same shape as `tensor` with exponential values.
782 fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
783 dequant_op_flow!(
784 ty Self,
785 float_op |tensor| B::float_exp(tensor),
786 tensor
787 )
788 }
789
790 /// Returns a new tensor with natural logarithm values.
791 ///
792 /// # Arguments
793 ///
794 /// * `tensor` - The tensor to take the logarithm of.
795 ///
796 /// # Returns
797 ///
798 /// A tensor with the same shape as `tensor` with natural logarithm values.
799 fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
800 dequant_op_flow!(
801 ty Self,
802 float_op |tensor| B::float_log(tensor),
803 tensor
804 )
805 }
806
807 /// Returns a new tensor with logarithm values of (1 + Xi).
808 ///
809 /// # Arguments
810 ///
811 /// * `tensor` - The tensor to take the logarithm of.
812 ///
813 /// # Returns
814 ///
815 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
816 fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
817 dequant_op_flow!(
818 ty Self,
819 float_op |tensor| B::float_log1p(tensor),
820 tensor
821 )
822 }
823
824 /// Element-wise power with another tensor.
825 ///
826 /// # Arguments
827 ///
828 /// * `lhs` - The left hand side tensor.
829 /// * `rhs` - The right hand side tensor.
830 ///
831 /// # Returns
832 ///
833 /// The elements of `lhs` raised to the power of the elements of `rhs`.
834 fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
835 dequant_op_flow!(
836 ty Self,
837 float_op |lhs, rhs| B::float_powf(lhs, rhs),
838 lhs,
839 rhs
840 )
841 }
842
843 /// Element-wise power with an IntTensor.
844 ///
845 /// # Arguments
846 ///
847 /// * `lhs` - The left hand side tensor.
848 /// * `rhs` - The right hand side floatTensor.
849 ///
850 /// # Returns
851 ///
852 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
853 fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
854 dequant_op_flow!(
855 ty Self,
856 float_op |tensor| B::float_powi(tensor, rhs),
857 lhs
858 )
859 }
860
861 /// Element-wise power with an int scalar.
862 ///
863 /// # Arguments
864 ///
865 /// * `lhs` - The left hand side tensor.
866 /// * `rhs` - The right hand side scalar.
867 ///
868 /// # Returns
869 ///
870 /// The elements of `lhs` raised to the value of `rhs`.
871 fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
872 dequant_op_flow!(
873 ty Self,
874 float_op |tensor| B::float_powi_scalar(tensor, rhs),
875 lhs
876 )
877 }
878
879 /// Element-wise power with a float scalar.
880 ///
881 /// # Arguments
882 ///
883 /// * `tensor` - The tensor to exponentiate.
884 /// * `value` - The exponent.
885 ///
886 /// # Returns
887 ///
888 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
889 fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
890 dequant_op_flow!(
891 ty Self,
892 float_op |tensor| B::float_powf_scalar(tensor, value),
893 tensor
894 )
895 }
896
897 /// Returns a new tensor with square root values.
898 ///
899 /// # Arguments
900 ///
901 /// * `tensor` - The tensor to take the square root of.
902 ///
903 /// # Returns
904 ///
905 /// A tensor with the same shape as `tensor` with square root values.
906 fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
907 dequant_op_flow!(
908 ty Self,
909 float_op |tensor| B::float_sqrt(tensor),
910 tensor
911 )
912 }
913
914 /// Returns a new tensor with absolute values.
915 ///
916 /// # Arguments
917 ///
918 /// * `tensor` - The tensor to take absolute value of.
919 ///
920 /// # Returns
921 ///
922 /// A tensor with the same shape as `tensor` with absolute values.
923 fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
924 dequant_op_quant!(
925 ty Self,
926 float_op |tensor| B::float_abs(tensor),
927 tensor
928 )
929 }
930
931 /// Returns a new tensor with cosine values.
932 ///
933 /// # Arguments
934 ///
935 /// * `tensor` - The tensor to take the cosine of.
936 ///
937 /// # Returns
938 ///
939 /// A tensor with the same shape as `tensor` with cosine values.
940 fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
941 dequant_op_flow!(
942 ty Self,
943 float_op |tensor| B::float_cos(tensor),
944 tensor
945 )
946 }
947
948 /// Returns a new tensor with sine values.
949 ///
950 /// # Arguments
951 ///
952 /// * `tensor` - The tensor to take the sine of.
953 ///
954 /// # Returns
955 ///
956 /// A tensor with the same shape as `tensor` with sine values.
957 fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
958 dequant_op_flow!(
959 ty Self,
960 float_op |tensor| B::float_sin(tensor),
961 tensor
962 )
963 }
964
965 /// Returns a new tensor with tangent values.
966 ///
967 /// # Arguments
968 ///
969 /// * `tensor` - The tensor to take the tangent of.
970 ///
971 /// # Returns
972 ///
973 /// A tensor with the same shape as `tensor` with tangent values.
974 fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
975 dequant_op_flow!(
976 ty Self,
977 float_op |tensor| B::float_tan(tensor),
978 tensor
979 )
980 }
981
982 /// Returns a new tensor with hyperbolic cosine values.
983 ///
984 /// # Arguments
985 ///
986 /// * `tensor` - The tensor to take the hyperbolic cosine of.
987 ///
988 /// # Returns
989 ///
990 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
991 fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
992 dequant_op_flow!(
993 ty Self,
994 float_op |tensor| B::float_cosh(tensor),
995 tensor
996 )
997 }
998
999 /// Returns a new tensor with hyperbolic sine values.
1000 ///
1001 /// # Arguments
1002 ///
1003 /// * `tensor` - The tensor to take the hyperbolic sine of.
1004 ///
1005 /// # Returns
1006 ///
1007 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1008 fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1009 dequant_op_flow!(
1010 ty Self,
1011 float_op |tensor| B::float_sinh(tensor),
1012 tensor
1013 )
1014 }
1015
1016 /// Returns a new tensor with hyperbolic tangent values.
1017 ///
1018 /// # Arguments
1019 ///
1020 /// * `tensor` - The tensor to take the hyperbolic tangent of.
1021 ///
1022 /// # Returns
1023 ///
1024 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1025 fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1026 dequant_op_flow!(
1027 ty Self,
1028 float_op |tensor| B::float_tanh(tensor),
1029 tensor
1030 )
1031 }
1032
1033 /// Returns a new tensor with the error function values.
1034 ///
1035 /// # Arguments
1036 ///
1037 /// * `tensor` - The tensor to take the error function of.
1038 ///
1039 /// # Returns
1040 ///
1041 /// A tensor with the same shape as `tensor` with error function values.
1042 fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1043 dequant_op_flow!(
1044 ty Self,
1045 float_op |tensor| B::float_erf(tensor),
1046 tensor
1047 )
1048 }
1049
1050 /// Concatenates tensors along a dimension.
1051 ///
1052 /// # Arguments
1053 ///
1054 /// * `tensors` - The tensors to concatenate.
1055 /// * `dim` - The dimension along which to concatenate.
1056 ///
1057 /// # Returns
1058 ///
1059 /// A tensor with the concatenated tensors along `dim`.
1060 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1061 // Heuristic: prioritize first tensor scheme
1062 let scheme = *tensors.first().unwrap().scheme();
1063
1064 let tensor_f = tensors
1065 .into_iter()
1066 .map(|tensor| Self::dequantize(tensor))
1067 .collect();
1068
1069 let out_f = B::float_cat(tensor_f, dim);
1070
1071 Self::quantize_dynamic(out_f, &scheme)
1072 }
1073
1074 /// Gets the indices of the maximum elements of a tensor along an axis.
1075 ///
1076 /// # Arguments
1077 ///
1078 /// * `tensor` - The tensor to get the maximum elements of.
1079 /// * `dim` - The dimension along which to get the maximum elements.
1080 ///
1081 /// # Returns
1082 ///
1083 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1084 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1085 // Default implementation. Backends can sort on the int values since qparams remain the same.
1086 let tensor_f = Self::dequantize(tensor);
1087 B::float_argmax(tensor_f, dim)
1088 }
1089
1090 /// Gets the indices of the minimum elements of a tensor along an axis.
1091 ///
1092 /// # Arguments
1093 ///
1094 /// * `tensor` - The tensor to get the minimum elements of.
1095 /// * `dim` - The dimension along which to get the minimum elements.
1096 ///
1097 /// # Returns
1098 ///
1099 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1100 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1101 // Default implementation. Backends can sort on the int values since qparams remain the same.
1102 let tensor_f = Self::dequantize(tensor);
1103 B::float_argmin(tensor_f, dim)
1104 }
1105
1106 /// Gets the maximum element of a tensor.
1107 ///
1108 /// # Arguments
1109 ///
1110 /// * `tensor` - The tensor to get the maximum elements of.
1111 ///
1112 /// # Returns
1113 ///
1114 /// A tensor with the maximum element of `tensor`.
1115 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1116 let shape = tensor.shape();
1117 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1118
1119 B::q_max_dim(tensor, 0)
1120 }
1121
1122 /// Gets the maximum elements of a tensor along an axis.
1123 ///
1124 /// # Arguments
1125 ///
1126 /// * `tensor` - The tensor to get the maximum elements of.
1127 /// * `dim` - The dimension along which to get the maximum elements.
1128 ///
1129 /// # Returns
1130 ///
1131 /// A tensor with the maximum elements of `tensor` along `dim`.
1132 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1133 let index = B::q_argmax(tensor.clone(), dim);
1134
1135 B::q_gather(dim, tensor, index)
1136 }
1137
1138 /// Gets the maximum elements of a tensor along an axis and their indices.
1139 ///
1140 /// # Arguments
1141 ///
1142 /// * `tensor` - The tensor to get the maximum elements of.
1143 /// * `dim` - The dimension along which to get the maximum elements.
1144 ///
1145 /// # Returns
1146 ///
1147 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1148 fn q_max_dim_with_indices(
1149 tensor: QuantizedTensor<B>,
1150 dim: usize,
1151 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1152 let index = B::q_argmax(tensor.clone(), dim);
1153 let values = B::q_gather(dim, tensor, index.clone());
1154
1155 (values, index)
1156 }
1157
1158 /// Gets the minimum element of a tensor.
1159 ///
1160 /// # Arguments
1161 ///
1162 /// * `tensor` - The tensor to get the minimum elements of.
1163 ///
1164 /// # Returns
1165 ///
1166 /// A tensor with the minimum element of `tensor`.
1167 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1168 let shape = tensor.shape();
1169 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1170
1171 B::q_min_dim(tensor, 0)
1172 }
1173
1174 /// Gets the minimum elements of a tensor along an axis.
1175 ///
1176 /// # Arguments
1177 ///
1178 /// * `tensor` - The tensor to get the minimum elements of.
1179 /// * `dim` - The dimension along which to get the minimum elements.
1180 ///
1181 /// # Returns
1182 ///
1183 /// A tensor with the minimum elements of `tensor` along `dim`.
1184 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1185 let index = B::q_argmin(tensor.clone(), dim);
1186
1187 B::q_gather(dim, tensor, index)
1188 }
1189
1190 /// Gets the minimum elements of a tensor along an axis and their indices.
1191 ///
1192 /// # Arguments
1193 ///
1194 /// * `tensor` - The tensor to get the minimum elements of.
1195 /// * `dim` - The dimension along which to get the minimum elements.
1196 ///
1197 /// # Returns
1198 ///
1199 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1200 fn q_min_dim_with_indices(
1201 tensor: QuantizedTensor<B>,
1202 dim: usize,
1203 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1204 let index = B::q_argmin(tensor.clone(), dim);
1205 let values = B::q_gather(dim, tensor, index.clone());
1206
1207 (values, index)
1208 }
1209
1210 /// Gets the maximum element of a tensor.
1211 ///
1212 /// # Arguments
1213 ///
1214 /// * `tensor` - The tensor to get the maximum elements of.
1215 ///
1216 /// # Returns
1217 ///
1218 /// A tensor with the maximum element of `tensor`.
1219 fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1220 let shape = tensor.shape();
1221 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1222
1223 B::q_max_abs_dim(tensor, 0)
1224 }
1225
1226 /// Gets the maximum elements of a tensor along an axis.
1227 ///
1228 /// # Arguments
1229 ///
1230 /// * `tensor` - The tensor to get the maximum elements of.
1231 /// * `dim` - The dimension along which to get the maximum elements.
1232 ///
1233 /// # Returns
1234 ///
1235 /// A tensor with the maximum elements of `tensor` along `dim`.
1236 fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1237 let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1238
1239 B::q_gather(dim, tensor, index)
1240 }
1241
1242 /// Tests if any element in the `tensor` evaluates to True.
1243 ///
1244 /// # Arguments
1245 ///
1246 /// * `tensor` - The tensor to test.
1247 ///
1248 /// # Returns
1249 ///
1250 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1251 fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1252 let tensor_f = Self::dequantize(tensor);
1253 B::float_any(tensor_f)
1254 }
1255
1256 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1257 ///
1258 /// # Arguments
1259 ///
1260 /// * `tensor` - The tensor to test.
1261 /// * `dim` - The axis along which to test.
1262 ///
1263 /// # Returns
1264 ///
1265 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1266 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1267 /// input evaluates to True, False otherwise.
1268 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1269 let tensor_f = Self::dequantize(tensor);
1270 B::float_any_dim(tensor_f, dim)
1271 }
1272
1273 /// Tests if all elements in the `tensor` evaluate to True.
1274 ///
1275 /// # Arguments
1276 ///
1277 /// * `tensor` - The tensor to test.
1278 ///
1279 /// # Returns
1280 ///
1281 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1282 /// evaluate to True, False otherwise.
1283 fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1284 let tensor_f = Self::dequantize(tensor);
1285 B::float_all(tensor_f)
1286 }
1287
1288 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1289 ///
1290 /// # Arguments
1291 ///
1292 /// * `tensor` - The tensor to test.
1293 /// * `dim` - The axis along which to test.
1294 ///
1295 /// # Returns
1296 ///
1297 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1298 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1299 /// evaluates to True, False otherwise.
1300 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1301 let tensor_f = Self::dequantize(tensor);
1302 B::float_all_dim(tensor_f, dim)
1303 }
1304
1305 /// Sort the elements of the input `tensor` by value in along a given dimension.
1306 ///
1307 /// This sort is unstable (i.e., may reorder equal elements).
1308 ///
1309 /// # Arguments
1310 ///
1311 /// * `tensor` - The input tensor.
1312 /// * `dim` - The axis along which to sort.
1313 /// * `descending` - The sorting order.
1314 ///
1315 /// # Returns
1316 ///
1317 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1318 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1319 // Default implementation. Backends can sort on the int values since qparams remain the same.
1320 dequant_op_quant!(
1321 ty Self,
1322 float_op |tensor| B::float_sort(tensor, dim, descending),
1323 tensor
1324 )
1325 }
1326
1327 /// Sort the elements of the input `tensor` by value in along a given dimension.
1328 ///
1329 /// This sort is unstable (i.e., may reorder equal elements).
1330 ///
1331 /// # Arguments
1332 ///
1333 /// * `tensor` - The input tensor.
1334 /// * `dim` - The axis along which to sort.
1335 /// * `descending` - The sorting order.
1336 ///
1337 /// # Returns
1338 ///
1339 /// A tensor with the same shape as the input tensor and corresponding indices, where
1340 /// the elements are sorted by value and the indices map back to the original input tensor.
1341 fn q_sort_with_indices(
1342 tensor: QuantizedTensor<B>,
1343 dim: usize,
1344 descending: bool,
1345 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1346 // Default implementation. Backends can sort on the int values since qparams remain the same.
1347 let scheme = *tensor.scheme();
1348
1349 let tensor_f = Self::dequantize(tensor);
1350 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1351
1352 (Self::quantize_dynamic(out_f, &scheme), indices)
1353 }
1354
1355 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1356 ///
1357 /// This sort is unstable (i.e., may reorder equal elements).
1358 ///
1359 /// # Arguments
1360 ///
1361 /// * `tensor` - The input tensor.
1362 /// * `dim` - The axis along which to sort.
1363 /// * `descending` - The sorting order.
1364 ///
1365 /// # Returns
1366 ///
1367 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1368 fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1369 // Default implementation. Backends can sort on the int values since qparams remain the same.
1370 let tensor_f = Self::dequantize(tensor);
1371 B::float_argsort(tensor_f, dim, descending)
1372 }
1373}