burn_backend/backend/ops/qtensor.rs
1use alloc::vec::Vec;
2use burn_std::{
3 Shape, Slice,
4 quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::tensor::{
8 BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,
9 quantization::{Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range},
10};
11use crate::{
12 Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
13};
14
/// Runs a float operation on quantized input(s) and always re-quantizes the result.
///
/// Expands to `dequantize -> float operation -> quantize_dynamic`, so the whole expression
/// evaluates to a quantized tensor.
///
/// # Arguments
///
/// * `ty` - The type providing `dequantize` and `quantize_dynamic`.
/// * `float_op` - A callable expression applied to the dequantized input(s).
/// * One or two trailing expressions - The quantized tensor(s). With two inputs, the output is
///   quantized using the scheme of the first one (lhs).
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary form: `float_op` receives the dequantized lhs and rhs.
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize the lhs scheme. It is captured before `$t1` is passed on.
        let out_scheme = $t1.scheme().clone();

        let lhs_f = <$ty>::dequantize($t1);
        let rhs_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let result_f = $float_op(lhs_f, rhs_f);

        <$ty>::quantize_dynamic(result_f, &out_scheme)
    }};
    // Unary form: `float_op` receives the single dequantized input.
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // The input scheme is reused for the output; captured before `$tensor` is passed on.
        let out_scheme = $tensor.scheme().clone();

        let input_f = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let result_f = $float_op(input_f);

        <$ty>::quantize_dynamic(result_f, &out_scheme)
    }};
}
47
/// Runs a float operation on quantized input(s) and re-quantizes the result only on request.
///
/// Expands to `dequantize -> float operation [-> quantize_dynamic]`. The final quantization
/// step happens only when the (first) input reports `QuantPropagation::Propagate`; with
/// `QuantPropagation::Inhibit` the float result is returned as is.
///
/// The expansion refers to `QuantPropagation` and `TensorPrimitive` by their bare names, so
/// both must be in scope at the call site.
///
/// # Arguments
///
/// * `ty` - The type providing `dequantize` and `quantize_dynamic`.
/// * `float_op` - A callable expression applied to the dequantized input(s).
/// * One or two trailing expressions - The quantized tensor(s). With two inputs, the scheme
///   and propagation of the first one (lhs) are used.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary form: `float_op` receives the dequantized lhs and rhs.
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize the lhs settings. Both are captured before `$t1` is passed on.
        let out_scheme = $t1.scheme().clone();
        let out_propagation = $t1.propagation();

        let lhs_f = <$ty>::dequantize($t1);
        let rhs_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let result_f = $float_op(lhs_f, rhs_f);

        match out_propagation {
            QuantPropagation::Inhibit => TensorPrimitive::Float(result_f),
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(result_f, &out_scheme))
            }
        }
    }};
    // Unary form: `float_op` receives the single dequantized input.
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Both settings are captured before `$tensor` is passed on.
        let out_scheme = $tensor.scheme().clone();
        let out_propagation = $tensor.propagation();

        let input_f = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let result_f = $float_op(input_f);

        match out_propagation {
            QuantPropagation::Inhibit => TensorPrimitive::Float(result_f),
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(result_f, &out_scheme))
            }
        }
    }};
}
93
94/// Operations on quantized tensors.
95///
96/// # Return Type Semantics
97///
98/// The return type of each operation indicates how quantization is handled:
99///
100/// ## [`QuantizedTensor<B>`]
101/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
102/// representation. Implementations should avoid dequantizing when possible to maintain performance.
103/// For example, shape or layout changes such as expand or transpose preserve quantization.
104///
105/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
106/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
107/// the quantization parameters to match the new layout.*
108///
109///
110/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the propagation
/// strategy specified in the quantization scheme. The output should either remain quantized
/// ([`TensorPrimitive::QFloat`]) or be returned in floating-point form ([`TensorPrimitive::Float`]).
114///
115/// This distinction allows for fine-grained control over mixed-precision flows while still operating
116/// through a unified API.
117pub trait QTensorOps<B: Backend> {
/// Creates a new quantized tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The quantized tensor holding the given data.
fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
129
/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
///
/// # Arguments
///
/// * `tensor` - The float tensor to quantize.
/// * `scheme` - The quantization scheme.
/// * `qparams` - The quantization parameters to apply.
///
/// # Returns
///
/// The quantized tensor.
fn quantize(
    tensor: FloatTensor<B>,
    scheme: &QuantScheme,
    qparams: QuantizationParametersPrimitive<B>,
) -> QuantizedTensor<B>;
136
137 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
138 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
139 // Dynamically compute min/max tensor range and qparams before quantizing
140 let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
141 let qparams = compute_q_params(scheme, min, max);
142 Self::quantize(tensor, scheme, qparams)
143 }
144
/// Convert the tensor back to a higher precision data type.
///
/// # Arguments
///
/// * `tensor` - The quantized tensor.
///
/// # Returns
///
/// The float tensor.
fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
147
/// Gets the device of the quantized tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device of the tensor.
fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
158
/// Moves the quantized tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
170
/// Reshapes a quantized tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to reshape.
/// * `shape` - The new shape of the tensor.
///
/// # Returns
///
/// The tensor with the new shape.
fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
182
/// Converts the quantized tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// A future resolving to the data structure with the tensor's data, or to an
/// [`ExecutionError`] on failure.
fn q_into_data(
    tensor: QuantizedTensor<B>,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
195
/// Detaches a tensor from the computation graph.
///
/// The default implementation returns the tensor unchanged.
fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
    // Should only be overridden by autodiff backends.
    tensor
}
201
/// Sets the `require_grad` flag of a tensor.
///
/// The default implementation ignores the flag and returns the tensor unchanged.
fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
    // Should only be overridden by autodiff backends.
    tensor
}
207
/// Returns the `require_grad` flag of a tensor.
///
/// The default implementation always reports `false`.
fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
    // Should only be overridden by autodiff backends.
    false
}
213
/// Broadcasts the `tensor` to the given `shape`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape.
///
/// # Returns
///
/// The tensor broadcast to `shape`.
fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
216
217 /// Transposes a tensor.
218 ///
219 /// # Arguments
220 ///
221 /// * `tensor` - The tensor to transpose.
222 ///
223 /// # Returns
224 ///
225 /// The transposed tensor.
226 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
227 let ndims = tensor.shape().num_dims();
228 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
229 }
230
/// Swaps two dimensions of a quantized tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped.
fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
243
/// Permutes the dimensions of a quantized tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to permute the dimensions of.
/// * `axes` - The new order of the dimensions.
///
/// # Returns
///
/// The tensor with the dimensions permuted.
fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
254
/// Reverse the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
264
/// Select tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices to select.
///
/// # Returns
///
/// The selected elements.
fn q_select(
    tensor: QuantizedTensor<B>,
    dim: usize,
    indices: IntTensor<B>,
) -> QuantizedTensor<B>;
281
/// Select tensor elements corresponding to the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `slices` - The slices specifying ranges and steps for each dimension.
///
/// # Returns
///
/// The selected elements in a new tensor.
fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
293
294 /// Gather elements from a tensor.
295 ///
296 /// # Arguments
297 ///
298 /// * `dim` - The dimension to gather from.
299 /// * `tensor` - The tensor to gather from.
300 /// * `indices` - The indices to gather.
301 ///
302 /// # Returns
303 ///
304 /// The gathered elements.
305 fn q_gather(
306 dim: usize,
307 tensor: QuantizedTensor<B>,
308 indices: IntTensor<B>,
309 ) -> QuantizedTensor<B> {
310 // Default implementation. Backends can gather on the quantized values when supported.
311 dequant_op_quant!(
312 ty Self,
313 float_op |tensor| B::float_gather(dim, tensor, indices),
314 tensor
315 )
316 }
317
318 /// Repeat the tensor along the given dimension.
319 ///
320 /// # Arguments
321 ///
322 /// * `tensor` - The tensor.
323 /// * `dim` - The dimension to repeat.
324 /// * `times` - The number of times to repeat the dimension.
325 ///
326 /// # Returns
327 ///
328 /// The tensor with the given dimension repeated.
329 fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
330 dequant_op_quant!(
331 ty Self,
332 float_op |tensor| B::float_repeat_dim(tensor, dim, times),
333 tensor
334 )
335 }
336
337 /// Adds two tensors together.
338 ///
339 /// # Arguments
340 ///
341 /// * `lhs` - The left hand side tensor.
342 /// * `rhs` - The right hand side tensor.
343 ///
344 /// # Returns
345 ///
346 /// The result of adding the two tensors together.
347 fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
348 dequant_op_flow!(
349 ty Self,
350 float_op |lhs, rhs| B::float_add(lhs, rhs),
351 lhs,
352 rhs
353 )
354 }
355
356 /// Adds a scalar to a tensor.
357 ///
358 /// # Arguments
359 ///
360 /// * `lhs` - The left hand side tensor.
361 /// * `rhs` - The right hand side scalar.
362 ///
363 /// # Returns
364 ///
365 /// The result of adding the scalar to the tensor.
366 fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
367 dequant_op_flow!(
368 ty Self,
369 float_op |tensor| B::float_add_scalar(tensor, rhs),
370 lhs
371 )
372 }
373
374 /// Clamps a tensor under a minimum value.
375 ///
376 /// # Arguments
377 ///
378 /// * `tensor` - The tensor to clamp.
379 /// * `min` - The minimum value.
380 ///
381 /// # Returns
382 ///
383 /// The clamped tensor.
384 fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
385 dequant_op_flow!(
386 ty Self,
387 float_op |tensor| B::float_clamp_min(tensor, min),
388 tensor
389 )
390 }
391
392 /// Clamps a tensor over a maximum value.
393 ///
394 /// # Arguments
395 ///
396 /// * `tensor` - The tensor to clamp.
397 /// * `max` - The maximum value.
398 ///
399 /// # Returns
400 ///
401 /// The clamped tensor.
402 fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
403 dequant_op_flow!(
404 ty Self,
405 float_op |tensor| B::float_clamp_max(tensor, max),
406 tensor
407 )
408 }
409
410 /// Clamps a tensor between a minimum and maximum value.
411 ///
412 /// # Arguments
413 ///
414 /// * `tensor` - The tensor to clamp.
415 /// * `min` - The minimum value.
416 /// * `max` - The maximum value.
417 ///
418 /// # Returns
419 ///
420 /// The clamped tensor.
421 fn q_clamp(
422 tensor: QuantizedTensor<B>,
423 min: FloatElem<B>,
424 max: FloatElem<B>,
425 ) -> TensorPrimitive<B> {
426 dequant_op_flow!(
427 ty Self,
428 float_op |tensor| B::float_clamp(tensor, min, max),
429 tensor
430 )
431 }
432
433 /// Subtracts two tensors.
434 ///
435 /// # Arguments
436 ///
437 /// * `lhs` - The left hand side tensor.
438 /// * `rhs` - The right hand side tensor.
439 ///
440 /// # Returns
441 ///
442 /// The result of subtracting the two tensors.
443 fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
444 dequant_op_flow!(
445 ty Self,
446 float_op |lhs, rhs| B::float_sub(lhs, rhs),
447 lhs,
448 rhs
449 )
450 }
451
452 /// Subtracts a scalar from a tensor.
453 ///
454 /// # Arguments
455 ///
456 /// * `lhs` - The left hand side tensor.
457 /// * `rhs` - The right hand side scalar.
458 ///
459 /// # Returns
460 ///
461 /// The result of subtracting the scalar from the tensor.
462 fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
463 dequant_op_flow!(
464 ty Self,
465 float_op |tensor| B::float_sub_scalar(tensor, rhs),
466 lhs
467 )
468 }
469
470 /// Multiplies two tensors together element-wise.
471 fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
472 dequant_op_flow!(
473 ty Self,
474 float_op |lhs, rhs| B::float_mul(lhs, rhs),
475 lhs,
476 rhs
477 )
478 }
479
480 /// Multiplies a tensor by a scalar.
481 ///
482 /// # Arguments
483 ///
484 /// * `lhs` - The left hand side tensor.
485 /// * `rhs` - The right hand side scalar.
486 ///
487 /// # Returns
488 ///
489 /// The result of multiplying the tensor by the scalar.
490 fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
491 dequant_op_flow!(
492 ty Self,
493 float_op |tensor| B::float_mul_scalar(tensor, rhs),
494 lhs
495 )
496 }
497
498 /// Divides two tensors element-wise.
499 ///
500 /// # Arguments
501 ///
502 /// * `lhs` - The left hand side tensor.
503 /// * `rhs` - The right hand side tensor.
504 ///
505 /// # Returns
506 ///
507 /// The result of dividing the two tensors.
508 fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
509 dequant_op_flow!(
510 ty Self,
511 float_op |lhs, rhs| B::float_div(lhs, rhs),
512 lhs,
513 rhs
514 )
515 }
516
517 /// Divides a tensor by a scalar.
518 ///
519 /// # Arguments
520 ///
521 /// * `lhs` - The left hand side tensor.
522 /// * `rhs` - The right hand side scalar.
523 ///
524 /// # Returns
525 ///
526 /// The result of dividing the tensor by the scalar.
527 fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
528 dequant_op_flow!(
529 ty Self,
530 float_op |tensor| B::float_div_scalar(tensor, rhs),
531 lhs
532 )
533 }
534
535 /// Multiplies two tensors together using matrix multiplication.
536 ///
537 /// # Arguments
538 ///
539 /// * `lhs` - The left hand side tensor.
540 /// * `rhs` - The right hand side tensor.
541 ///
542 /// # Returns
543 ///
544 /// The result of multiplying the two tensors together using matrix multiplication.
545 fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
546 let mut propagation = QuantPropagation::Inhibit;
547 let mut scheme = QuantScheme::default();
548 let lhs = match lhs {
549 TensorPrimitive::Float(lhs) => lhs,
550 TensorPrimitive::QFloat(lhs) => {
551 propagation = lhs.propagation();
552 scheme = *lhs.scheme();
553 Self::dequantize(lhs)
554 }
555 };
556 let rhs = match rhs {
557 TensorPrimitive::Float(rhs) => rhs,
558 TensorPrimitive::QFloat(rhs) => {
559 propagation = rhs.propagation();
560 scheme = *rhs.scheme();
561 Self::dequantize(rhs)
562 }
563 };
564
565 let out_f = B::float_matmul(lhs, rhs);
566 match propagation {
567 QuantPropagation::Propagate => {
568 TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
569 }
570 QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
571 }
572 }
573
574 /// Negates a tensor element-wise.
575 fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
576 dequant_op_flow!(
577 ty Self,
578 float_op |tensor| B::float_neg(tensor),
579 tensor
580 )
581 }
582
583 /// Calculates the reciprocals element-wise
584 fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
585 dequant_op_flow!(
586 ty Self,
587 float_op |tensor| B::float_recip(tensor),
588 tensor
589 )
590 }
591
592 /// Sum of all elements in a tensor.
593 ///
594 /// # Arguments
595 ///
596 /// * `tensor` - The tensor to sum.
597 ///
598 /// # Returns
599 ///
600 /// A scalar tensor with the sum of all elements in `tensor`.
601 fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
602 dequant_op_flow!(
603 ty Self,
604 float_op |tensor| B::float_sum(tensor),
605 tensor
606 )
607 }
608
609 /// Sum of all elements in a tensor along a dimension.
610 ///
611 /// # Arguments
612 ///
613 /// * `tensor` - The tensor to sum.
614 /// * `dim` - The dimension along which to sum.
615 ///
616 /// # Returns
617 ///
618 /// A tensor with the sum of all elements in `tensor` along `dim`.
619 fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
620 dequant_op_flow!(
621 ty Self,
622 float_op |tensor| B::float_sum_dim(tensor, dim),
623 tensor
624 )
625 }
626
627 /// Product of all elements in a tensor.
628 ///
629 /// # Arguments
630 ///
631 /// * `tensor` - The tensor to product.
632 ///
633 /// # Returns
634 ///
635 /// A scalar tensor with the product of all elements in `tensor`.
636 fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
637 dequant_op_flow!(
638 ty Self,
639 float_op |tensor| B::float_prod(tensor),
640 tensor
641 )
642 }
643
644 /// Product of all elements in a tensor along a dimension.
645 ///
646 /// # Arguments
647 ///
648 /// * `tensor` - The tensor to product.
649 ///
650 /// # Returns
651 ///
652 /// A tensor with the product of all elements in `tensor` along `dim`.
653 fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
654 dequant_op_flow!(
655 ty Self,
656 float_op |tensor| B::float_prod_dim(tensor, dim),
657 tensor
658 )
659 }
660
661 /// Mean of all elements in a tensor.
662 ///
663 /// # Arguments
664 ///
665 /// * `tensor` - The tensor to mean.
666 ///
667 /// # Returns
668 ///
669 /// A scalar tensor with the mean of all elements in `tensor`.
670 fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
671 dequant_op_flow!(
672 ty Self,
673 float_op |tensor| B::float_mean(tensor),
674 tensor
675 )
676 }
677
678 /// Mean of all elements in a tensor along a dimension.
679 ///
680 /// # Arguments
681 ///
682 /// * `tensor` - The tensor to mean.
683 /// * `dim` - The dimension along which to mean.
684 ///
685 /// # Returns
686 ///
687 /// A tensor with the mean of all elements in `tensor` along `dim`.
688 fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
689 dequant_op_flow!(
690 ty Self,
691 float_op |tensor| B::float_mean_dim(tensor, dim),
692 tensor
693 )
694 }
695
696 /// Computes the cumulative sum of elements along a dimension.
697 ///
698 /// # Arguments
699 ///
700 /// * `tensor` - The tensor to compute the cumulative sum of.
701 /// * `dim` - The dimension along which to compute the cumulative sum.
702 ///
703 /// # Returns
704 ///
705 /// A tensor with the same shape where each element is the cumulative sum
706 /// of all elements up to and including that position along the dimension.
707 fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
708 dequant_op_flow!(
709 ty Self,
710 float_op |tensor| B::float_cumsum(tensor, dim),
711 tensor
712 )
713 }
714
715 /// Computes the cumulative product of elements along a dimension.
716 ///
717 /// # Arguments
718 ///
719 /// * `tensor` - The tensor to compute the cumulative product of.
720 /// * `dim` - The dimension along which to compute the cumulative product.
721 ///
722 /// # Returns
723 ///
724 /// A tensor with the same shape where each element is the cumulative product
725 /// of all elements up to and including that position along the dimension.
726 fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
727 dequant_op_flow!(
728 ty Self,
729 float_op |tensor| B::float_cumprod(tensor, dim),
730 tensor
731 )
732 }
733
734 /// Computes the cumulative minimum of elements along a dimension.
735 ///
736 /// # Arguments
737 ///
738 /// * `tensor` - The tensor to compute the cumulative minimum of.
739 /// * `dim` - The dimension along which to compute the cumulative minimum.
740 ///
741 /// # Returns
742 ///
743 /// A tensor with the same shape where each element is the minimum
744 /// of all elements up to and including that position along the dimension.
745 fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
746 dequant_op_flow!(
747 ty Self,
748 float_op |tensor| B::float_cummin(tensor, dim),
749 tensor
750 )
751 }
752
753 /// Computes the cumulative maximum of elements along a dimension.
754 ///
755 /// # Arguments
756 ///
757 /// * `tensor` - The tensor to compute the cumulative maximum of.
758 /// * `dim` - The dimension along which to compute the cumulative maximum.
759 ///
760 /// # Returns
761 ///
762 /// A tensor with the same shape where each element is the maximum
763 /// of all elements up to and including that position along the dimension.
764 fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
765 dequant_op_flow!(
766 ty Self,
767 float_op |tensor| B::float_cummax(tensor, dim),
768 tensor
769 )
770 }
771
772 /// Returns a new tensor with exponential values.
773 ///
774 /// # Arguments
775 ///
776 /// * `tensor` - The tensor to exponentiate.
777 ///
778 /// # Returns
779 ///
780 /// A tensor with the same shape as `tensor` with exponential values.
781 fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
782 dequant_op_flow!(
783 ty Self,
784 float_op |tensor| B::float_exp(tensor),
785 tensor
786 )
787 }
788
789 /// Returns a new tensor with natural logarithm values.
790 ///
791 /// # Arguments
792 ///
793 /// * `tensor` - The tensor to take the logarithm of.
794 ///
795 /// # Returns
796 ///
797 /// A tensor with the same shape as `tensor` with natural logarithm values.
798 fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
799 dequant_op_flow!(
800 ty Self,
801 float_op |tensor| B::float_log(tensor),
802 tensor
803 )
804 }
805
806 /// Returns a new tensor with logarithm values of (1 + Xi).
807 ///
808 /// # Arguments
809 ///
810 /// * `tensor` - The tensor to take the logarithm of.
811 ///
812 /// # Returns
813 ///
814 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
815 fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
816 dequant_op_flow!(
817 ty Self,
818 float_op |tensor| B::float_log1p(tensor),
819 tensor
820 )
821 }
822
823 /// Element-wise power with another tensor.
824 ///
825 /// # Arguments
826 ///
827 /// * `lhs` - The left hand side tensor.
828 /// * `rhs` - The right hand side tensor.
829 ///
830 /// # Returns
831 ///
832 /// The elements of `lhs` raised to the power of the elements of `rhs`.
833 fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
834 dequant_op_flow!(
835 ty Self,
836 float_op |lhs, rhs| B::float_powf(lhs, rhs),
837 lhs,
838 rhs
839 )
840 }
841
842 /// Element-wise power with an IntTensor.
843 ///
844 /// # Arguments
845 ///
846 /// * `lhs` - The left hand side tensor.
847 /// * `rhs` - The right hand side floatTensor.
848 ///
849 /// # Returns
850 ///
851 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
852 fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
853 dequant_op_flow!(
854 ty Self,
855 float_op |tensor| B::float_powi(tensor, rhs),
856 lhs
857 )
858 }
859
860 /// Element-wise power with an int scalar.
861 ///
862 /// # Arguments
863 ///
864 /// * `lhs` - The left hand side tensor.
865 /// * `rhs` - The right hand side scalar.
866 ///
867 /// # Returns
868 ///
869 /// The elements of `lhs` raised to the value of `rhs`.
870 fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
871 dequant_op_flow!(
872 ty Self,
873 float_op |tensor| B::float_powi_scalar(tensor, rhs),
874 lhs
875 )
876 }
877
878 /// Element-wise power with a float scalar.
879 ///
880 /// # Arguments
881 ///
882 /// * `tensor` - The tensor to exponentiate.
883 /// * `value` - The exponent.
884 ///
885 /// # Returns
886 ///
887 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
888 fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
889 dequant_op_flow!(
890 ty Self,
891 float_op |tensor| B::float_powf_scalar(tensor, value),
892 tensor
893 )
894 }
895
896 /// Returns a new tensor with square root values.
897 ///
898 /// # Arguments
899 ///
900 /// * `tensor` - The tensor to take the square root of.
901 ///
902 /// # Returns
903 ///
904 /// A tensor with the same shape as `tensor` with square root values.
905 fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
906 dequant_op_flow!(
907 ty Self,
908 float_op |tensor| B::float_sqrt(tensor),
909 tensor
910 )
911 }
912
913 /// Returns a new tensor with absolute values.
914 ///
915 /// # Arguments
916 ///
917 /// * `tensor` - The tensor to take absolute value of.
918 ///
919 /// # Returns
920 ///
921 /// A tensor with the same shape as `tensor` with absolute values.
922 fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
923 dequant_op_quant!(
924 ty Self,
925 float_op |tensor| B::float_abs(tensor),
926 tensor
927 )
928 }
929
930 /// Returns a new tensor with cosine values.
931 ///
932 /// # Arguments
933 ///
934 /// * `tensor` - The tensor to take the cosine of.
935 ///
936 /// # Returns
937 ///
938 /// A tensor with the same shape as `tensor` with cosine values.
939 fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
940 dequant_op_flow!(
941 ty Self,
942 float_op |tensor| B::float_cos(tensor),
943 tensor
944 )
945 }
946
947 /// Returns a new tensor with sine values.
948 ///
949 /// # Arguments
950 ///
951 /// * `tensor` - The tensor to take the sine of.
952 ///
953 /// # Returns
954 ///
955 /// A tensor with the same shape as `tensor` with sine values.
956 fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
957 dequant_op_flow!(
958 ty Self,
959 float_op |tensor| B::float_sin(tensor),
960 tensor
961 )
962 }
963
964 /// Returns a new tensor with tangent values.
965 ///
966 /// # Arguments
967 ///
968 /// * `tensor` - The tensor to take the tangent of.
969 ///
970 /// # Returns
971 ///
972 /// A tensor with the same shape as `tensor` with tangent values.
973 fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
974 dequant_op_flow!(
975 ty Self,
976 float_op |tensor| B::float_tan(tensor),
977 tensor
978 )
979 }
980
981 /// Returns a new tensor with hyperbolic cosine values.
982 ///
983 /// # Arguments
984 ///
985 /// * `tensor` - The tensor to take the hyperbolic cosine of.
986 ///
987 /// # Returns
988 ///
989 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
990 fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
991 dequant_op_flow!(
992 ty Self,
993 float_op |tensor| B::float_cosh(tensor),
994 tensor
995 )
996 }
997
998 /// Returns a new tensor with hyperbolic sine values.
999 ///
1000 /// # Arguments
1001 ///
1002 /// * `tensor` - The tensor to take the hyperbolic sine of.
1003 ///
1004 /// # Returns
1005 ///
1006 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1007 fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1008 dequant_op_flow!(
1009 ty Self,
1010 float_op |tensor| B::float_sinh(tensor),
1011 tensor
1012 )
1013 }
1014
1015 /// Returns a new tensor with hyperbolic tangent values.
1016 ///
1017 /// # Arguments
1018 ///
1019 /// * `tensor` - The tensor to take the hyperbolic tangent of.
1020 ///
1021 /// # Returns
1022 ///
1023 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1024 fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1025 dequant_op_flow!(
1026 ty Self,
1027 float_op |tensor| B::float_tanh(tensor),
1028 tensor
1029 )
1030 }
1031
1032 /// Returns a new tensor with the error function values.
1033 ///
1034 /// # Arguments
1035 ///
1036 /// * `tensor` - The tensor to take the error function of.
1037 ///
1038 /// # Returns
1039 ///
1040 /// A tensor with the same shape as `tensor` with error function values.
1041 fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1042 dequant_op_flow!(
1043 ty Self,
1044 float_op |tensor| B::float_erf(tensor),
1045 tensor
1046 )
1047 }
1048
1049 /// Concatenates tensors along a dimension.
1050 ///
1051 /// # Arguments
1052 ///
1053 /// * `tensors` - The tensors to concatenate.
1054 /// * `dim` - The dimension along which to concatenate.
1055 ///
1056 /// # Returns
1057 ///
1058 /// A tensor with the concatenated tensors along `dim`.
1059 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1060 // Heuristic: prioritize first tensor scheme
1061 let scheme = *tensors.first().unwrap().scheme();
1062
1063 let tensor_f = tensors
1064 .into_iter()
1065 .map(|tensor| Self::dequantize(tensor))
1066 .collect();
1067
1068 let out_f = B::float_cat(tensor_f, dim);
1069
1070 Self::quantize_dynamic(out_f, &scheme)
1071 }
1072
1073 /// Gets the indices of the maximum elements of a tensor along an axis.
1074 ///
1075 /// # Arguments
1076 ///
1077 /// * `tensor` - The tensor to get the maximum elements of.
1078 /// * `dim` - The dimension along which to get the maximum elements.
1079 ///
1080 /// # Returns
1081 ///
1082 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1083 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1084 // Default implementation. Backends can sort on the int values since qparams remain the same.
1085 let tensor_f = Self::dequantize(tensor);
1086 B::float_argmax(tensor_f, dim)
1087 }
1088
1089 /// Gets the indices of the minimum elements of a tensor along an axis.
1090 ///
1091 /// # Arguments
1092 ///
1093 /// * `tensor` - The tensor to get the minimum elements of.
1094 /// * `dim` - The dimension along which to get the minimum elements.
1095 ///
1096 /// # Returns
1097 ///
1098 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1099 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1100 // Default implementation. Backends can sort on the int values since qparams remain the same.
1101 let tensor_f = Self::dequantize(tensor);
1102 B::float_argmin(tensor_f, dim)
1103 }
1104
1105 /// Gets the maximum element of a tensor.
1106 ///
1107 /// # Arguments
1108 ///
1109 /// * `tensor` - The tensor to get the maximum elements of.
1110 ///
1111 /// # Returns
1112 ///
1113 /// A tensor with the maximum element of `tensor`.
1114 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1115 let shape = tensor.shape();
1116 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1117
1118 B::q_max_dim(tensor, 0)
1119 }
1120
1121 /// Gets the maximum elements of a tensor along an axis.
1122 ///
1123 /// # Arguments
1124 ///
1125 /// * `tensor` - The tensor to get the maximum elements of.
1126 /// * `dim` - The dimension along which to get the maximum elements.
1127 ///
1128 /// # Returns
1129 ///
1130 /// A tensor with the maximum elements of `tensor` along `dim`.
1131 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1132 let index = B::q_argmax(tensor.clone(), dim);
1133
1134 B::q_gather(dim, tensor, index)
1135 }
1136
1137 /// Gets the maximum elements of a tensor along an axis and their indices.
1138 ///
1139 /// # Arguments
1140 ///
1141 /// * `tensor` - The tensor to get the maximum elements of.
1142 /// * `dim` - The dimension along which to get the maximum elements.
1143 ///
1144 /// # Returns
1145 ///
1146 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1147 fn q_max_dim_with_indices(
1148 tensor: QuantizedTensor<B>,
1149 dim: usize,
1150 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1151 let index = B::q_argmax(tensor.clone(), dim);
1152 let values = B::q_gather(dim, tensor, index.clone());
1153
1154 (values, index)
1155 }
1156
1157 /// Gets the minimum element of a tensor.
1158 ///
1159 /// # Arguments
1160 ///
1161 /// * `tensor` - The tensor to get the minimum elements of.
1162 ///
1163 /// # Returns
1164 ///
1165 /// A tensor with the minimum element of `tensor`.
1166 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1167 let shape = tensor.shape();
1168 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1169
1170 B::q_min_dim(tensor, 0)
1171 }
1172
1173 /// Gets the minimum elements of a tensor along an axis.
1174 ///
1175 /// # Arguments
1176 ///
1177 /// * `tensor` - The tensor to get the minimum elements of.
1178 /// * `dim` - The dimension along which to get the minimum elements.
1179 ///
1180 /// # Returns
1181 ///
1182 /// A tensor with the minimum elements of `tensor` along `dim`.
1183 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1184 let index = B::q_argmin(tensor.clone(), dim);
1185
1186 B::q_gather(dim, tensor, index)
1187 }
1188
1189 /// Gets the minimum elements of a tensor along an axis and their indices.
1190 ///
1191 /// # Arguments
1192 ///
1193 /// * `tensor` - The tensor to get the minimum elements of.
1194 /// * `dim` - The dimension along which to get the minimum elements.
1195 ///
1196 /// # Returns
1197 ///
1198 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1199 fn q_min_dim_with_indices(
1200 tensor: QuantizedTensor<B>,
1201 dim: usize,
1202 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1203 let index = B::q_argmin(tensor.clone(), dim);
1204 let values = B::q_gather(dim, tensor, index.clone());
1205
1206 (values, index)
1207 }
1208
1209 /// Gets the maximum element of a tensor.
1210 ///
1211 /// # Arguments
1212 ///
1213 /// * `tensor` - The tensor to get the maximum elements of.
1214 ///
1215 /// # Returns
1216 ///
1217 /// A tensor with the maximum element of `tensor`.
1218 fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1219 let shape = tensor.shape();
1220 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1221
1222 B::q_max_abs_dim(tensor, 0)
1223 }
1224
1225 /// Gets the maximum elements of a tensor along an axis.
1226 ///
1227 /// # Arguments
1228 ///
1229 /// * `tensor` - The tensor to get the maximum elements of.
1230 /// * `dim` - The dimension along which to get the maximum elements.
1231 ///
1232 /// # Returns
1233 ///
1234 /// A tensor with the maximum elements of `tensor` along `dim`.
1235 fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1236 let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1237
1238 B::q_gather(dim, tensor, index)
1239 }
1240
1241 /// Tests if any element in the `tensor` evaluates to True.
1242 ///
1243 /// # Arguments
1244 ///
1245 /// * `tensor` - The tensor to test.
1246 ///
1247 /// # Returns
1248 ///
1249 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1250 fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1251 let tensor_f = Self::dequantize(tensor);
1252 B::float_any(tensor_f)
1253 }
1254
1255 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1256 ///
1257 /// # Arguments
1258 ///
1259 /// * `tensor` - The tensor to test.
1260 /// * `dim` - The axis along which to test.
1261 ///
1262 /// # Returns
1263 ///
1264 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1265 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1266 /// input evaluates to True, False otherwise.
1267 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1268 let tensor_f = Self::dequantize(tensor);
1269 B::float_any_dim(tensor_f, dim)
1270 }
1271
1272 /// Tests if all elements in the `tensor` evaluate to True.
1273 ///
1274 /// # Arguments
1275 ///
1276 /// * `tensor` - The tensor to test.
1277 ///
1278 /// # Returns
1279 ///
1280 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1281 /// evaluate to True, False otherwise.
1282 fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1283 let tensor_f = Self::dequantize(tensor);
1284 B::float_all(tensor_f)
1285 }
1286
1287 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1288 ///
1289 /// # Arguments
1290 ///
1291 /// * `tensor` - The tensor to test.
1292 /// * `dim` - The axis along which to test.
1293 ///
1294 /// # Returns
1295 ///
1296 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1297 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1298 /// evaluates to True, False otherwise.
1299 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1300 let tensor_f = Self::dequantize(tensor);
1301 B::float_all_dim(tensor_f, dim)
1302 }
1303
1304 /// Sort the elements of the input `tensor` by value in along a given dimension.
1305 ///
1306 /// This sort is unstable (i.e., may reorder equal elements).
1307 ///
1308 /// # Arguments
1309 ///
1310 /// * `tensor` - The input tensor.
1311 /// * `dim` - The axis along which to sort.
1312 /// * `descending` - The sorting order.
1313 ///
1314 /// # Returns
1315 ///
1316 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1317 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1318 // Default implementation. Backends can sort on the int values since qparams remain the same.
1319 dequant_op_quant!(
1320 ty Self,
1321 float_op |tensor| B::float_sort(tensor, dim, descending),
1322 tensor
1323 )
1324 }
1325
1326 /// Sort the elements of the input `tensor` by value in along a given dimension.
1327 ///
1328 /// This sort is unstable (i.e., may reorder equal elements).
1329 ///
1330 /// # Arguments
1331 ///
1332 /// * `tensor` - The input tensor.
1333 /// * `dim` - The axis along which to sort.
1334 /// * `descending` - The sorting order.
1335 ///
1336 /// # Returns
1337 ///
1338 /// A tensor with the same shape as the input tensor and corresponding indices, where
1339 /// the elements are sorted by value and the indices map back to the original input tensor.
1340 fn q_sort_with_indices(
1341 tensor: QuantizedTensor<B>,
1342 dim: usize,
1343 descending: bool,
1344 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1345 // Default implementation. Backends can sort on the int values since qparams remain the same.
1346 let scheme = *tensor.scheme();
1347
1348 let tensor_f = Self::dequantize(tensor);
1349 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1350
1351 (Self::quantize_dynamic(out_f, &scheme), indices)
1352 }
1353
1354 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1355 ///
1356 /// This sort is unstable (i.e., may reorder equal elements).
1357 ///
1358 /// # Arguments
1359 ///
1360 /// * `tensor` - The input tensor.
1361 /// * `dim` - The axis along which to sort.
1362 /// * `descending` - The sorting order.
1363 ///
1364 /// # Returns
1365 ///
1366 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1367 fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1368 // Default implementation. Backends can sort on the int values since qparams remain the same.
1369 let tensor_f = Self::dequantize(tensor);
1370 B::float_argsort(tensor_f, dim, descending)
1371 }
1372}