burn_tensor/tensor/ops/qtensor.rs
1use alloc::vec::Vec;
2use cubecl_quant::scheme::QuantScheme;
3
4use crate::{
5 Device, Shape, TensorData, TensorMetadata, TensorPrimitive,
6 backend::Backend,
7 quantization::{
8 Calibration, QTensorPrimitive, QuantPropagation, QuantizationParametersPrimitive,
9 compute_q_params_primitive, compute_range_primitive,
10 },
11};
12
13use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
14
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// Every tensor argument expression is evaluated exactly once. In the binary form the result is
/// re-quantized with the scheme of the first (lhs) tensor.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind each argument once: `$t1`/`$t2` may be arbitrary expressions, and expanding them
        // both for the scheme lookup and for dequantization would evaluate them twice.
        let lhs = $t1;
        let rhs = $t2;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();

        let lhs_f = <$ty>::dequantize(lhs);
        let rhs_f = <$ty>::dequantize(rhs);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the argument once so a non-trivial expression is not evaluated twice.
        let input = $tensor;
        let scheme = input.scheme().clone();

        let input_f = <$ty>::dequantize(input);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(input_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
47
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated; otherwise the float
/// result is returned as-is. For binary ops, the lhs (`$t1`) propagation and scheme are used.
///
/// Every tensor argument expression is evaluated exactly once, and the `TensorPrimitive` /
/// `QuantPropagation` paths in the expansion are `$crate`-qualified, so callers do not need to
/// import them.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind each argument once: `$t1` is needed for the scheme, the propagation and the
        // dequantization, and expanding it repeatedly would re-evaluate non-trivial expressions.
        let lhs = $t1;
        let rhs = $t2;
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();
        let propagation = lhs.propagation();

        let lhs_f = <$ty>::dequantize(lhs);
        let rhs_f = <$ty>::dequantize(rhs);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(lhs_f, rhs_f);

        match propagation {
            $crate::quantization::QuantPropagation::Propagate => {
                $crate::TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            $crate::quantization::QuantPropagation::Inhibit => {
                $crate::TensorPrimitive::Float(out_f)
            }
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the argument once so a non-trivial expression is not evaluated repeatedly.
        let input = $tensor;
        let scheme = input.scheme().clone();
        let propagation = input.propagation();

        let input_f = <$ty>::dequantize(input);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(input_f);

        match propagation {
            $crate::quantization::QuantPropagation::Propagate => {
                $crate::TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            $crate::quantization::QuantPropagation::Inhibit => {
                $crate::TensorPrimitive::Float(out_f)
            }
        }
    }};
}
93
94/// Operations on quantized tensors.
95///
96/// # Return Type Semantics
97///
98/// The return type of each operation indicates how quantization is handled:
99///
100/// ## [`QuantizedTensor<B>`]
101/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
102/// representation. Implementations should avoid dequantizing when possible to maintain performance.
103/// For example, shape or layout changes such as expand or transpose preserve quantization.
104///
105/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
106/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
107/// the quantization parameters to match the new layout.*
108///
109///
110/// ## [`TensorPrimitive<B>`]
111/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with propagation
112/// strategy specified in the quantization scheme. The output should remain quantized ([`TensorPrimitive::QFloat`])
113/// returned in floating-point form ([`TensorPrimitive::Float`]).
114///
115/// This distinction allows for fine-grained control over mixed-precision flows while still operating
116/// through a unified API.
117pub trait QTensorOps<B: Backend> {
118 /// Creates a new tensor from the data structure.
119 ///
120 /// # Arguments
121 ///
122 /// * `data` - The data structure.
123 /// * `device` - The device to create the tensor on.
124 ///
125 /// # Returns
126 ///
127 /// The tensor with the given data.
128 fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
129
130 /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
131 fn quantize(
132 tensor: FloatTensor<B>,
133 scheme: &QuantScheme,
134 qparams: QuantizationParametersPrimitive<B>,
135 ) -> QuantizedTensor<B>;
136
137 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
138 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
139 // Dynamically compute min/max tensor range and qparams before quantizing
140 let (min, max) = compute_range_primitive::<B>(scheme, tensor.clone(), &Calibration::MinMax);
141 let qparams = compute_q_params_primitive(scheme, min, max);
142 Self::quantize(tensor, scheme, qparams)
143 }
144
145 /// Convert the tensor back to a higher precision data type.
146 fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
147
148 /// Gets the device of the tensor.
149 ///
150 /// # Arguments
151 ///
152 /// * `tensor` - The tensor.
153 ///
154 /// # Returns
155 ///
156 /// The device of the tensor.
157 fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
158
159 /// Moves the tensor to the given device.
160 ///
161 /// # Arguments
162 ///
163 /// * `tensor` - The tensor.
164 /// * `device` - The device to move the tensor to.
165 ///
166 /// # Returns
167 ///
168 /// The tensor on the given device.
169 fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
170
171 /// Reshapes a tensor.
172 ///
173 /// # Arguments
174 ///
175 /// * `tensor` - The tensor to reshape.
176 /// * `shape` - The new shape of the tensor.
177 ///
178 /// # Returns
179 ///
180 /// The tensor with the new shape.
181 fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
182
183 /// Converts the tensor to a data structure.
184 ///
185 /// # Arguments
186 ///
187 /// * `tensor` - The tensor.
188 ///
189 /// # Returns
190 ///
191 /// The data structure with the tensor's data.
192 fn q_into_data(tensor: QuantizedTensor<B>) -> impl Future<Output = TensorData> + Send;
193
194 /// Detaches a tensor from the computation graph.
195 fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
196 // Should only be overridden by autodiff backends.
197 tensor
198 }
199
200 /// Sets the `require_grad` flag of a tensor.
201 fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
202 // Should only be overridden by autodiff backends.
203 tensor
204 }
205
206 /// Returns the `require_grad` flag of a tensor.
207 fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
208 // Should only be overridden by autodiff backends.
209 false
210 }
211
212 /// Broadcasts the `tensor` to the given `shape`.
213 fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
214
215 /// Transposes a tensor.
216 ///
217 /// # Arguments
218 ///
219 /// * `tensor` - The tensor to transpose.
220 ///
221 /// # Returns
222 ///
223 /// The transposed tensor.
224 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
225 let ndims = tensor.shape().num_dims();
226 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
227 }
228
229 /// Swaps two dimensions of a tensor.
230 ///
231 /// # Arguments
232 ///
233 /// * `tensor` - The tensor to swap the dimensions of.
234 /// * `dim1` - The first dimension to swap.
235 /// * `dim2` - The second dimension to swap.
236 ///
237 /// # Returns
238 ///
239 /// The tensor with the dimensions swapped.
240 fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
241
242 /// Permutes the dimensions of a tensor.
243 ///
244 /// # Arguments
245 ///
246 /// * `tensor` - The tensor to permute the dimensions of.
247 /// * `axes` - The new order of the dimensions.
248 /// # Returns
249 ///
250 /// The tensor with the dimensions permuted.
251 fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
252
253 /// Reverse the order of elements in a tensor along the given axes.
254 ///
255 /// # Arguments
256 ///
257 /// * `tensor` - The tensor to reverse.
258 /// * `axes` - The axes to reverse.
259 ///
260 /// The tensor with the elements reversed.
261 fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
262
263 /// Select tensor elements along the given dimension corresponding for the given indices.
264 ///
265 /// # Arguments
266 ///
267 /// * `tensor` - The tensor to select from.
268 /// * `dim` - The dimension to select from.
269 /// * `indices` - The indices to select.
270 ///
271 /// # Returns
272 ///
273 /// The selected elements.
274 fn q_select(
275 tensor: QuantizedTensor<B>,
276 dim: usize,
277 indices: IntTensor<B>,
278 ) -> QuantizedTensor<B>;
279
280 /// Select tensor elements corresponding to the given slices.
281 ///
282 /// # Arguments
283 ///
284 /// * `tensor` - The tensor to select from.
285 /// * `slices` - The slices specifying ranges and steps for each dimension.
286 ///
287 /// # Returns
288 ///
289 /// The selected elements in a new tensor.
290 fn q_slice(tensor: QuantizedTensor<B>, slices: &[crate::Slice]) -> QuantizedTensor<B>;
291
292 /// Gather elements from a tensor.
293 ///
294 /// # Arguments
295 ///
296 /// * `dim` - The dimension to gather from.
297 /// * `tensor` - The tensor to gather from.
298 /// * `indices` - The indices to gather.
299 ///
300 /// # Returns
301 ///
302 /// The gathered elements.
303 fn q_gather(
304 dim: usize,
305 tensor: QuantizedTensor<B>,
306 indices: IntTensor<B>,
307 ) -> QuantizedTensor<B> {
308 // Default implementation. Backends can gather on the quantized values when supported.
309 dequant_op_quant!(
310 ty Self,
311 float_op |tensor| B::float_gather(dim, tensor, indices),
312 tensor
313 )
314 }
315
316 /// Repeat the tensor along the given dimension.
317 ///
318 /// # Arguments
319 ///
320 /// * `tensor` - The tensor.
321 /// * `dim` - The dimension to repeat.
322 /// * `times` - The number of times to repeat the dimension.
323 ///
324 /// # Returns
325 ///
326 /// The tensor with the given dimension repeated.
327 fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
328 dequant_op_quant!(
329 ty Self,
330 float_op |tensor| B::float_repeat_dim(tensor, dim, times),
331 tensor
332 )
333 }
334
335 /// Adds two tensors together.
336 ///
337 /// # Arguments
338 ///
339 /// * `lhs` - The left hand side tensor.
340 /// * `rhs` - The right hand side tensor.
341 ///
342 /// # Returns
343 ///
344 /// The result of adding the two tensors together.
345 fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
346 dequant_op_flow!(
347 ty Self,
348 float_op |lhs, rhs| B::float_add(lhs, rhs),
349 lhs,
350 rhs
351 )
352 }
353
354 /// Adds a scalar to a tensor.
355 ///
356 /// # Arguments
357 ///
358 /// * `lhs` - The left hand side tensor.
359 /// * `rhs` - The right hand side scalar.
360 ///
361 /// # Returns
362 ///
363 /// The result of adding the scalar to the tensor.
364 fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
365 dequant_op_flow!(
366 ty Self,
367 float_op |tensor| B::float_add_scalar(tensor, rhs),
368 lhs
369 )
370 }
371
372 /// Clamps a tensor under a minimum value.
373 ///
374 /// # Arguments
375 ///
376 /// * `tensor` - The tensor to clamp.
377 /// * `min` - The minimum value.
378 ///
379 /// # Returns
380 ///
381 /// The clamped tensor.
382 fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> {
383 dequant_op_flow!(
384 ty Self,
385 float_op |tensor| B::float_clamp_min(tensor, min),
386 tensor
387 )
388 }
389
390 /// Clamps a tensor over a maximum value.
391 ///
392 /// # Arguments
393 ///
394 /// * `tensor` - The tensor to clamp.
395 /// * `max` - The maximum value.
396 ///
397 /// # Returns
398 ///
399 /// The clamped tensor.
400 fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> {
401 dequant_op_flow!(
402 ty Self,
403 float_op |tensor| B::float_clamp_max(tensor, max),
404 tensor
405 )
406 }
407
408 /// Clamps a tensor between a minimum and maximum value.
409 ///
410 /// # Arguments
411 ///
412 /// * `tensor` - The tensor to clamp.
413 /// * `min` - The minimum value.
414 /// * `max` - The maximum value.
415 ///
416 /// # Returns
417 ///
418 /// The clamped tensor.
419 fn q_clamp(
420 tensor: QuantizedTensor<B>,
421 min: FloatElem<B>,
422 max: FloatElem<B>,
423 ) -> TensorPrimitive<B> {
424 dequant_op_flow!(
425 ty Self,
426 float_op |tensor| B::float_clamp(tensor, min, max),
427 tensor
428 )
429 }
430
431 /// Subtracts two tensors.
432 ///
433 /// # Arguments
434 ///
435 /// * `lhs` - The left hand side tensor.
436 /// * `rhs` - The right hand side tensor.
437 ///
438 /// # Returns
439 ///
440 /// The result of subtracting the two tensors.
441 fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
442 dequant_op_flow!(
443 ty Self,
444 float_op |lhs, rhs| B::float_sub(lhs, rhs),
445 lhs,
446 rhs
447 )
448 }
449
450 /// Subtracts a scalar from a tensor.
451 ///
452 /// # Arguments
453 ///
454 /// * `lhs` - The left hand side tensor.
455 /// * `rhs` - The right hand side scalar.
456 ///
457 /// # Returns
458 ///
459 /// The result of subtracting the scalar from the tensor.
460 fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
461 dequant_op_flow!(
462 ty Self,
463 float_op |tensor| B::float_sub_scalar(tensor, rhs),
464 lhs
465 )
466 }
467
468 /// Multiplies two tensors together element-wise.
469 fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
470 dequant_op_flow!(
471 ty Self,
472 float_op |lhs, rhs| B::float_mul(lhs, rhs),
473 lhs,
474 rhs
475 )
476 }
477
478 /// Multiplies a tensor by a scalar.
479 ///
480 /// # Arguments
481 ///
482 /// * `lhs` - The left hand side tensor.
483 /// * `rhs` - The right hand side scalar.
484 ///
485 /// # Returns
486 ///
487 /// The result of multiplying the tensor by the scalar.
488 fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
489 dequant_op_flow!(
490 ty Self,
491 float_op |tensor| B::float_mul_scalar(tensor, rhs),
492 lhs
493 )
494 }
495
496 /// Divides two tensors element-wise.
497 ///
498 /// # Arguments
499 ///
500 /// * `lhs` - The left hand side tensor.
501 /// * `rhs` - The right hand side tensor.
502 ///
503 /// # Returns
504 ///
505 /// The result of dividing the two tensors.
506 fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
507 dequant_op_flow!(
508 ty Self,
509 float_op |lhs, rhs| B::float_div(lhs, rhs),
510 lhs,
511 rhs
512 )
513 }
514
515 /// Divides a tensor by a scalar.
516 ///
517 /// # Arguments
518 ///
519 /// * `lhs` - The left hand side tensor.
520 /// * `rhs` - The right hand side scalar.
521 ///
522 /// # Returns
523 ///
524 /// The result of dividing the tensor by the scalar.
525 fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> {
526 dequant_op_flow!(
527 ty Self,
528 float_op |tensor| B::float_div_scalar(tensor, rhs),
529 lhs
530 )
531 }
532
533 /// Multiplies two tensors together using matrix multiplication.
534 ///
535 /// # Arguments
536 ///
537 /// * `lhs` - The left hand side tensor.
538 /// * `rhs` - The right hand side tensor.
539 ///
540 /// # Returns
541 ///
542 /// The result of multiplying the two tensors together using matrix multiplication.
543 fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
544 let mut propagation = QuantPropagation::Inhibit;
545 let mut scheme = QuantScheme::default();
546 let lhs = match lhs {
547 TensorPrimitive::Float(lhs) => lhs,
548 TensorPrimitive::QFloat(lhs) => {
549 propagation = lhs.propagation();
550 scheme = *lhs.scheme();
551 Self::dequantize(lhs)
552 }
553 };
554 let rhs = match rhs {
555 TensorPrimitive::Float(rhs) => rhs,
556 TensorPrimitive::QFloat(rhs) => {
557 propagation = rhs.propagation();
558 scheme = *rhs.scheme();
559 Self::dequantize(rhs)
560 }
561 };
562
563 let out_f = B::float_matmul(lhs, rhs);
564 match propagation {
565 QuantPropagation::Propagate => {
566 TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
567 }
568 QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
569 }
570 }
571
572 /// Negates a tensor element-wise.
573 fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
574 dequant_op_flow!(
575 ty Self,
576 float_op |tensor| B::float_neg(tensor),
577 tensor
578 )
579 }
580
581 /// Calculates the reciprocals element-wise
582 fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
583 dequant_op_flow!(
584 ty Self,
585 float_op |tensor| B::float_recip(tensor),
586 tensor
587 )
588 }
589
590 /// Sum of all elements in a tensor.
591 ///
592 /// # Arguments
593 ///
594 /// * `tensor` - The tensor to sum.
595 ///
596 /// # Returns
597 ///
598 /// A scalar tensor with the sum of all elements in `tensor`.
599 fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
600 dequant_op_flow!(
601 ty Self,
602 float_op |tensor| B::float_sum(tensor),
603 tensor
604 )
605 }
606
607 /// Sum of all elements in a tensor along a dimension.
608 ///
609 /// # Arguments
610 ///
611 /// * `tensor` - The tensor to sum.
612 /// * `dim` - The dimension along which to sum.
613 ///
614 /// # Returns
615 ///
616 /// A tensor with the sum of all elements in `tensor` along `dim`.
617 fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
618 dequant_op_flow!(
619 ty Self,
620 float_op |tensor| B::float_sum_dim(tensor, dim),
621 tensor
622 )
623 }
624
625 /// Product of all elements in a tensor.
626 ///
627 /// # Arguments
628 ///
629 /// * `tensor` - The tensor to product.
630 ///
631 /// # Returns
632 ///
633 /// A scalar tensor with the product of all elements in `tensor`.
634 fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
635 dequant_op_flow!(
636 ty Self,
637 float_op |tensor| B::float_prod(tensor),
638 tensor
639 )
640 }
641
642 /// Product of all elements in a tensor along a dimension.
643 ///
644 /// # Arguments
645 ///
646 /// * `tensor` - The tensor to product.
647 ///
648 /// # Returns
649 ///
650 /// A tensor with the product of all elements in `tensor` along `dim`.
651 fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
652 dequant_op_flow!(
653 ty Self,
654 float_op |tensor| B::float_prod_dim(tensor, dim),
655 tensor
656 )
657 }
658
659 /// Mean of all elements in a tensor.
660 ///
661 /// # Arguments
662 ///
663 /// * `tensor` - The tensor to mean.
664 ///
665 /// # Returns
666 ///
667 /// A scalar tensor with the mean of all elements in `tensor`.
668 fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
669 dequant_op_flow!(
670 ty Self,
671 float_op |tensor| B::float_mean(tensor),
672 tensor
673 )
674 }
675
676 /// Mean of all elements in a tensor along a dimension.
677 ///
678 /// # Arguments
679 ///
680 /// * `tensor` - The tensor to mean.
681 /// * `dim` - The dimension along which to mean.
682 ///
683 /// # Returns
684 ///
685 /// A tensor with the mean of all elements in `tensor` along `dim`.
686 fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
687 dequant_op_flow!(
688 ty Self,
689 float_op |tensor| B::float_mean_dim(tensor, dim),
690 tensor
691 )
692 }
693
694 /// Computes the cumulative sum of elements along a dimension.
695 ///
696 /// # Arguments
697 ///
698 /// * `tensor` - The tensor to compute the cumulative sum of.
699 /// * `dim` - The dimension along which to compute the cumulative sum.
700 ///
701 /// # Returns
702 ///
703 /// A tensor with the same shape where each element is the cumulative sum
704 /// of all elements up to and including that position along the dimension.
705 fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
706 dequant_op_flow!(
707 ty Self,
708 float_op |tensor| B::float_cumsum(tensor, dim),
709 tensor
710 )
711 }
712
713 /// Computes the cumulative product of elements along a dimension.
714 ///
715 /// # Arguments
716 ///
717 /// * `tensor` - The tensor to compute the cumulative product of.
718 /// * `dim` - The dimension along which to compute the cumulative product.
719 ///
720 /// # Returns
721 ///
722 /// A tensor with the same shape where each element is the cumulative product
723 /// of all elements up to and including that position along the dimension.
724 fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
725 dequant_op_flow!(
726 ty Self,
727 float_op |tensor| B::float_cumprod(tensor, dim),
728 tensor
729 )
730 }
731
732 /// Computes the cumulative minimum of elements along a dimension.
733 ///
734 /// # Arguments
735 ///
736 /// * `tensor` - The tensor to compute the cumulative minimum of.
737 /// * `dim` - The dimension along which to compute the cumulative minimum.
738 ///
739 /// # Returns
740 ///
741 /// A tensor with the same shape where each element is the minimum
742 /// of all elements up to and including that position along the dimension.
743 fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
744 dequant_op_flow!(
745 ty Self,
746 float_op |tensor| B::float_cummin(tensor, dim),
747 tensor
748 )
749 }
750
751 /// Computes the cumulative maximum of elements along a dimension.
752 ///
753 /// # Arguments
754 ///
755 /// * `tensor` - The tensor to compute the cumulative maximum of.
756 /// * `dim` - The dimension along which to compute the cumulative maximum.
757 ///
758 /// # Returns
759 ///
760 /// A tensor with the same shape where each element is the maximum
761 /// of all elements up to and including that position along the dimension.
762 fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
763 dequant_op_flow!(
764 ty Self,
765 float_op |tensor| B::float_cummax(tensor, dim),
766 tensor
767 )
768 }
769
770 /// Returns a new tensor with exponential values.
771 ///
772 /// # Arguments
773 ///
774 /// * `tensor` - The tensor to exponentiate.
775 ///
776 /// # Returns
777 ///
778 /// A tensor with the same shape as `tensor` with exponential values.
779 fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
780 dequant_op_flow!(
781 ty Self,
782 float_op |tensor| B::float_exp(tensor),
783 tensor
784 )
785 }
786
787 /// Returns a new tensor with natural logarithm values.
788 ///
789 /// # Arguments
790 ///
791 /// * `tensor` - The tensor to take the logarithm of.
792 ///
793 /// # Returns
794 ///
795 /// A tensor with the same shape as `tensor` with natural logarithm values.
796 fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
797 dequant_op_flow!(
798 ty Self,
799 float_op |tensor| B::float_log(tensor),
800 tensor
801 )
802 }
803
804 /// Returns a new tensor with logarithm values of (1 + Xi).
805 ///
806 /// # Arguments
807 ///
808 /// * `tensor` - The tensor to take the logarithm of.
809 ///
810 /// # Returns
811 ///
812 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
813 fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
814 dequant_op_flow!(
815 ty Self,
816 float_op |tensor| B::float_log1p(tensor),
817 tensor
818 )
819 }
820
821 /// Element-wise power with another tensor.
822 ///
823 /// # Arguments
824 ///
825 /// * `lhs` - The left hand side tensor.
826 /// * `rhs` - The right hand side tensor.
827 ///
828 /// # Returns
829 ///
830 /// The elements of `lhs` raised to the power of the elements of `rhs`.
831 fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
832 dequant_op_flow!(
833 ty Self,
834 float_op |lhs, rhs| B::float_powf(lhs, rhs),
835 lhs,
836 rhs
837 )
838 }
839
840 /// Element-wise power with an IntTensor.
841 ///
842 /// # Arguments
843 ///
844 /// * `lhs` - The left hand side tensor.
845 /// * `rhs` - The right hand side floatTensor.
846 ///
847 /// # Returns
848 ///
849 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
850 fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
851 dequant_op_flow!(
852 ty Self,
853 float_op |tensor| B::float_powi(tensor, rhs),
854 lhs
855 )
856 }
857
858 /// Element-wise power with an int scalar.
859 ///
860 /// # Arguments
861 ///
862 /// * `lhs` - The left hand side tensor.
863 /// * `rhs` - The right hand side scalar.
864 ///
865 /// # Returns
866 ///
867 /// The elements of `lhs` raised to the value of `rhs`.
868 fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
869 dequant_op_flow!(
870 ty Self,
871 float_op |tensor| B::float_powi_scalar(tensor, rhs),
872 lhs
873 )
874 }
875
876 /// Element-wise power with a float scalar.
877 ///
878 /// # Arguments
879 ///
880 /// * `tensor` - The tensor to exponentiate.
881 /// * `value` - The exponent.
882 ///
883 /// # Returns
884 ///
885 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
886 fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
887 dequant_op_flow!(
888 ty Self,
889 float_op |tensor| B::float_powf_scalar(tensor, value),
890 tensor
891 )
892 }
893
894 /// Returns a new tensor with square root values.
895 ///
896 /// # Arguments
897 ///
898 /// * `tensor` - The tensor to take the square root of.
899 ///
900 /// # Returns
901 ///
902 /// A tensor with the same shape as `tensor` with square root values.
903 fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
904 dequant_op_flow!(
905 ty Self,
906 float_op |tensor| B::float_sqrt(tensor),
907 tensor
908 )
909 }
910
911 /// Returns a new tensor with absolute values.
912 ///
913 /// # Arguments
914 ///
915 /// * `tensor` - The tensor to take absolute value of.
916 ///
917 /// # Returns
918 ///
919 /// A tensor with the same shape as `tensor` with absolute values.
920 fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
921 dequant_op_quant!(
922 ty Self,
923 float_op |tensor| B::float_abs(tensor),
924 tensor
925 )
926 }
927
928 /// Returns a new tensor with cosine values.
929 ///
930 /// # Arguments
931 ///
932 /// * `tensor` - The tensor to take the cosine of.
933 ///
934 /// # Returns
935 ///
936 /// A tensor with the same shape as `tensor` with cosine values.
937 fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
938 dequant_op_flow!(
939 ty Self,
940 float_op |tensor| B::float_cos(tensor),
941 tensor
942 )
943 }
944
945 /// Returns a new tensor with sine values.
946 ///
947 /// # Arguments
948 ///
949 /// * `tensor` - The tensor to take the sine of.
950 ///
951 /// # Returns
952 ///
953 /// A tensor with the same shape as `tensor` with sine values.
954 fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
955 dequant_op_flow!(
956 ty Self,
957 float_op |tensor| B::float_sin(tensor),
958 tensor
959 )
960 }
961
962 /// Returns a new tensor with tangent values.
963 ///
964 /// # Arguments
965 ///
966 /// * `tensor` - The tensor to take the tangent of.
967 ///
968 /// # Returns
969 ///
970 /// A tensor with the same shape as `tensor` with tangent values.
971 fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
972 dequant_op_flow!(
973 ty Self,
974 float_op |tensor| B::float_tan(tensor),
975 tensor
976 )
977 }
978
979 /// Returns a new tensor with hyperbolic cosine values.
980 ///
981 /// # Arguments
982 ///
983 /// * `tensor` - The tensor to take the hyperbolic cosine of.
984 ///
985 /// # Returns
986 ///
987 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
988 fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
989 dequant_op_flow!(
990 ty Self,
991 float_op |tensor| B::float_cosh(tensor),
992 tensor
993 )
994 }
995
996 /// Returns a new tensor with hyperbolic sine values.
997 ///
998 /// # Arguments
999 ///
1000 /// * `tensor` - The tensor to take the hyperbolic sine of.
1001 ///
1002 /// # Returns
1003 ///
1004 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1005 fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1006 dequant_op_flow!(
1007 ty Self,
1008 float_op |tensor| B::float_sinh(tensor),
1009 tensor
1010 )
1011 }
1012
1013 /// Returns a new tensor with hyperbolic tangent values.
1014 ///
1015 /// # Arguments
1016 ///
1017 /// * `tensor` - The tensor to take the hyperbolic tangent of.
1018 ///
1019 /// # Returns
1020 ///
1021 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1022 fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1023 dequant_op_flow!(
1024 ty Self,
1025 float_op |tensor| B::float_tanh(tensor),
1026 tensor
1027 )
1028 }
1029
1030 /// Returns a new tensor with the error function values.
1031 ///
1032 /// # Arguments
1033 ///
1034 /// * `tensor` - The tensor to take the error function of.
1035 ///
1036 /// # Returns
1037 ///
1038 /// A tensor with the same shape as `tensor` with error function values.
1039 fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
1040 dequant_op_flow!(
1041 ty Self,
1042 float_op |tensor| B::float_erf(tensor),
1043 tensor
1044 )
1045 }
1046
1047 /// Concatenates tensors along a dimension.
1048 ///
1049 /// # Arguments
1050 ///
1051 /// * `tensors` - The tensors to concatenate.
1052 /// * `dim` - The dimension along which to concatenate.
1053 ///
1054 /// # Returns
1055 ///
1056 /// A tensor with the concatenated tensors along `dim`.
1057 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1058 // Heuristic: prioritize first tensor scheme
1059 let scheme = *tensors.first().unwrap().scheme();
1060
1061 let tensor_f = tensors
1062 .into_iter()
1063 .map(|tensor| Self::dequantize(tensor))
1064 .collect();
1065
1066 let out_f = B::float_cat(tensor_f, dim);
1067
1068 Self::quantize_dynamic(out_f, &scheme)
1069 }
1070
1071 /// Gets the indices of the maximum elements of a tensor along an axis.
1072 ///
1073 /// # Arguments
1074 ///
1075 /// * `tensor` - The tensor to get the maximum elements of.
1076 /// * `dim` - The dimension along which to get the maximum elements.
1077 ///
1078 /// # Returns
1079 ///
1080 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1081 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1082 // Default implementation. Backends can sort on the int values since qparams remain the same.
1083 let tensor_f = Self::dequantize(tensor);
1084 B::float_argmax(tensor_f, dim)
1085 }
1086
1087 /// Gets the indices of the minimum elements of a tensor along an axis.
1088 ///
1089 /// # Arguments
1090 ///
1091 /// * `tensor` - The tensor to get the minimum elements of.
1092 /// * `dim` - The dimension along which to get the minimum elements.
1093 ///
1094 /// # Returns
1095 ///
1096 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1097 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1098 // Default implementation. Backends can sort on the int values since qparams remain the same.
1099 let tensor_f = Self::dequantize(tensor);
1100 B::float_argmin(tensor_f, dim)
1101 }
1102
1103 /// Gets the maximum element of a tensor.
1104 ///
1105 /// # Arguments
1106 ///
1107 /// * `tensor` - The tensor to get the maximum elements of.
1108 ///
1109 /// # Returns
1110 ///
1111 /// A tensor with the maximum element of `tensor`.
1112 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1113 let shape = tensor.shape();
1114 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1115
1116 B::q_max_dim(tensor, 0)
1117 }
1118
1119 /// Gets the maximum elements of a tensor along an axis.
1120 ///
1121 /// # Arguments
1122 ///
1123 /// * `tensor` - The tensor to get the maximum elements of.
1124 /// * `dim` - The dimension along which to get the maximum elements.
1125 ///
1126 /// # Returns
1127 ///
1128 /// A tensor with the maximum elements of `tensor` along `dim`.
1129 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1130 let index = B::q_argmax(tensor.clone(), dim);
1131
1132 B::q_gather(dim, tensor, index)
1133 }
1134
1135 /// Gets the maximum elements of a tensor along an axis and their indices.
1136 ///
1137 /// # Arguments
1138 ///
1139 /// * `tensor` - The tensor to get the maximum elements of.
1140 /// * `dim` - The dimension along which to get the maximum elements.
1141 ///
1142 /// # Returns
1143 ///
1144 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1145 fn q_max_dim_with_indices(
1146 tensor: QuantizedTensor<B>,
1147 dim: usize,
1148 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1149 let index = B::q_argmax(tensor.clone(), dim);
1150 let values = B::q_gather(dim, tensor, index.clone());
1151
1152 (values, index)
1153 }
1154
1155 /// Gets the minimum element of a tensor.
1156 ///
1157 /// # Arguments
1158 ///
1159 /// * `tensor` - The tensor to get the minimum elements of.
1160 ///
1161 /// # Returns
1162 ///
1163 /// A tensor with the minimum element of `tensor`.
1164 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1165 let shape = tensor.shape();
1166 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1167
1168 B::q_min_dim(tensor, 0)
1169 }
1170
1171 /// Gets the minimum elements of a tensor along an axis.
1172 ///
1173 /// # Arguments
1174 ///
1175 /// * `tensor` - The tensor to get the minimum elements of.
1176 /// * `dim` - The dimension along which to get the minimum elements.
1177 ///
1178 /// # Returns
1179 ///
1180 /// A tensor with the minimum elements of `tensor` along `dim`.
1181 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1182 let index = B::q_argmin(tensor.clone(), dim);
1183
1184 B::q_gather(dim, tensor, index)
1185 }
1186
1187 /// Gets the minimum elements of a tensor along an axis and their indices.
1188 ///
1189 /// # Arguments
1190 ///
1191 /// * `tensor` - The tensor to get the minimum elements of.
1192 /// * `dim` - The dimension along which to get the minimum elements.
1193 ///
1194 /// # Returns
1195 ///
1196 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1197 fn q_min_dim_with_indices(
1198 tensor: QuantizedTensor<B>,
1199 dim: usize,
1200 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1201 let index = B::q_argmin(tensor.clone(), dim);
1202 let values = B::q_gather(dim, tensor, index.clone());
1203
1204 (values, index)
1205 }
1206
1207 /// Gets the maximum element of a tensor.
1208 ///
1209 /// # Arguments
1210 ///
1211 /// * `tensor` - The tensor to get the maximum elements of.
1212 ///
1213 /// # Returns
1214 ///
1215 /// A tensor with the maximum element of `tensor`.
1216 fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1217 let shape = tensor.shape();
1218 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1219
1220 B::q_max_abs_dim(tensor, 0)
1221 }
1222
1223 /// Gets the maximum elements of a tensor along an axis.
1224 ///
1225 /// # Arguments
1226 ///
1227 /// * `tensor` - The tensor to get the maximum elements of.
1228 /// * `dim` - The dimension along which to get the maximum elements.
1229 ///
1230 /// # Returns
1231 ///
1232 /// A tensor with the maximum elements of `tensor` along `dim`.
1233 fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1234 let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1235
1236 B::q_gather(dim, tensor, index)
1237 }
1238
1239 /// Tests if any element in the `tensor` evaluates to True.
1240 ///
1241 /// # Arguments
1242 ///
1243 /// * `tensor` - The tensor to test.
1244 ///
1245 /// # Returns
1246 ///
1247 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1248 fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1249 let tensor_f = Self::dequantize(tensor);
1250 B::float_any(tensor_f)
1251 }
1252
1253 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1254 ///
1255 /// # Arguments
1256 ///
1257 /// * `tensor` - The tensor to test.
1258 /// * `dim` - The axis along which to test.
1259 ///
1260 /// # Returns
1261 ///
1262 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1263 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1264 /// input evaluates to True, False otherwise.
1265 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1266 let tensor_f = Self::dequantize(tensor);
1267 B::float_any_dim(tensor_f, dim)
1268 }
1269
1270 /// Tests if all elements in the `tensor` evaluate to True.
1271 ///
1272 /// # Arguments
1273 ///
1274 /// * `tensor` - The tensor to test.
1275 ///
1276 /// # Returns
1277 ///
1278 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1279 /// evaluate to True, False otherwise.
1280 fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1281 let tensor_f = Self::dequantize(tensor);
1282 B::float_all(tensor_f)
1283 }
1284
1285 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1286 ///
1287 /// # Arguments
1288 ///
1289 /// * `tensor` - The tensor to test.
1290 /// * `dim` - The axis along which to test.
1291 ///
1292 /// # Returns
1293 ///
1294 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1295 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1296 /// evaluates to True, False otherwise.
1297 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1298 let tensor_f = Self::dequantize(tensor);
1299 B::float_all_dim(tensor_f, dim)
1300 }
1301
1302 /// Sort the elements of the input `tensor` by value in along a given dimension.
1303 ///
1304 /// This sort is unstable (i.e., may reorder equal elements).
1305 ///
1306 /// # Arguments
1307 ///
1308 /// * `tensor` - The input tensor.
1309 /// * `dim` - The axis along which to sort.
1310 /// * `descending` - The sorting order.
1311 ///
1312 /// # Returns
1313 ///
1314 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1315 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1316 // Default implementation. Backends can sort on the int values since qparams remain the same.
1317 dequant_op_quant!(
1318 ty Self,
1319 float_op |tensor| B::float_sort(tensor, dim, descending),
1320 tensor
1321 )
1322 }
1323
1324 /// Sort the elements of the input `tensor` by value in along a given dimension.
1325 ///
1326 /// This sort is unstable (i.e., may reorder equal elements).
1327 ///
1328 /// # Arguments
1329 ///
1330 /// * `tensor` - The input tensor.
1331 /// * `dim` - The axis along which to sort.
1332 /// * `descending` - The sorting order.
1333 ///
1334 /// # Returns
1335 ///
1336 /// A tensor with the same shape as the input tensor and corresponding indices, where
1337 /// the elements are sorted by value and the indices map back to the original input tensor.
1338 fn q_sort_with_indices(
1339 tensor: QuantizedTensor<B>,
1340 dim: usize,
1341 descending: bool,
1342 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1343 // Default implementation. Backends can sort on the int values since qparams remain the same.
1344 let scheme = *tensor.scheme();
1345
1346 let tensor_f = Self::dequantize(tensor);
1347 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1348
1349 (Self::quantize_dynamic(out_f, &scheme), indices)
1350 }
1351
1352 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1353 ///
1354 /// This sort is unstable (i.e., may reorder equal elements).
1355 ///
1356 /// # Arguments
1357 ///
1358 /// * `tensor` - The input tensor.
1359 /// * `dim` - The axis along which to sort.
1360 /// * `descending` - The sorting order.
1361 ///
1362 /// # Returns
1363 ///
1364 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1365 fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1366 // Default implementation. Backends can sort on the int values since qparams remain the same.
1367 let tensor_f = Self::dequantize(tensor);
1368 B::float_argsort(tensor_f, dim, descending)
1369 }
1370}