burn_backend/backend/ops/qtensor.rs
1use alloc::vec::Vec;
2use burn_std::{
3 BoolDType, FloatDType, IntDType, Shape, Slice,
4 quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8 Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9 get_device_settings,
10};
11use crate::{
12 Scalar,
13 tensor::{
14 BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,
15 quantization::{
16 Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,
17 },
18 },
19};
20
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        // Dequantize both operands to the device's default float dtype so that
        // the float op sees matching dtypes (`dequantize` requires an explicit
        // target dtype; see the unary arm and `dequant_op_flow!`).
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let t1_f = Self::dequantize($t1, dtype);
        let t2_f = Self::dequantize($t2, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let tensor_f = Self::dequantize($tensor, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
}
54
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        let propagation = $t1.propagation();
        // Dequantize both operands to the device default float dtype so they match.
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let t1_f = Self::dequantize($t1, dtype);
        let t2_f = Self::dequantize($t2, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        // Requantize only when the input scheme asks for propagation;
        // otherwise hand back the float result directly.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let propagation = $tensor.propagation();
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let tensor_f = Self::dequantize($tensor, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        // Requantize only when the input scheme asks for propagation.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
102
103/// Operations on quantized tensors.
104///
105/// # Return Type Semantics
106///
107/// The return type of each operation indicates how quantization is handled:
108///
109/// ## [`QuantizedTensor<B>`]
110/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
111/// representation. Implementations should avoid dequantizing when possible to maintain performance.
112/// For example, shape or layout changes such as expand or transpose preserve quantization.
113///
114/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
115/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
116/// the quantization parameters to match the new layout.*
117///
118///
119/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the propagation
/// strategy specified in the quantization scheme. The output should either remain quantized
/// ([`TensorPrimitive::QFloat`]) or be returned in floating-point form ([`TensorPrimitive::Float`]).
123///
124/// This distinction allows for fine-grained control over mixed-precision flows while still operating
125/// through a unified API.
126pub trait QTensorOps<B: Backend> {
127 /// Creates a new tensor from the data structure.
128 ///
129 /// # Arguments
130 ///
131 /// * `data` - The data structure.
132 /// * `device` - The device to create the tensor on.
133 ///
134 /// # Returns
135 ///
136 /// The tensor with the given data.
137 fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
138
    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme to apply.
    /// * `qparams` - The precomputed quantization parameters.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
145
146 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
147 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
148 // Dynamically compute min/max tensor range and qparams before quantizing
149 let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
150 let qparams = compute_q_params(scheme, min, max);
151 Self::quantize(tensor, scheme, qparams)
152 }
153
    /// Convert the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor to dequantize.
    /// * `dtype` - The target floating-point data type.
    ///
    /// # Returns
    ///
    /// The dequantized float tensor.
    fn dequantize(tensor: QuantizedTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
156
157 /// Gets the device of the tensor.
158 ///
159 /// # Arguments
160 ///
161 /// * `tensor` - The tensor.
162 ///
163 /// # Returns
164 ///
165 /// The device of the tensor.
166 fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
167
168 /// Moves the tensor to the given device.
169 ///
170 /// # Arguments
171 ///
172 /// * `tensor` - The tensor.
173 /// * `device` - The device to move the tensor to.
174 ///
175 /// # Returns
176 ///
177 /// The tensor on the given device.
178 fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
179
180 /// Reshapes a tensor.
181 ///
182 /// # Arguments
183 ///
184 /// * `tensor` - The tensor to reshape.
185 /// * `shape` - The new shape of the tensor.
186 ///
187 /// # Returns
188 ///
189 /// The tensor with the new shape.
190 fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
191
192 /// Converts the tensor to a data structure.
193 ///
194 /// # Arguments
195 ///
196 /// * `tensor` - The tensor.
197 ///
198 /// # Returns
199 ///
200 /// The data structure with the tensor's data.
201 fn q_into_data(
202 tensor: QuantizedTensor<B>,
203 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
204
205 /// Detaches a tensor from the computation graph.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        // For all other backends this is the identity.
        tensor
    }
210
211 /// Sets the `require_grad` flag of a tensor.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        // Non-autodiff backends ignore the flag and return the tensor unchanged.
        tensor
    }
216
217 /// Returns the `require_grad` flag of a tensor.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        // Non-autodiff backends never track gradients.
        false
    }
222
    /// Broadcasts the `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcasted tensor.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
225
226 /// Transposes a tensor.
227 ///
228 /// # Arguments
229 ///
230 /// * `tensor` - The tensor to transpose.
231 ///
232 /// # Returns
233 ///
234 /// The transposed tensor.
235 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
236 let ndims = tensor.shape().num_dims();
237 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
238 }
239
240 /// Swaps two dimensions of a tensor.
241 ///
242 /// # Arguments
243 ///
244 /// * `tensor` - The tensor to swap the dimensions of.
245 /// * `dim1` - The first dimension to swap.
246 /// * `dim2` - The second dimension to swap.
247 ///
248 /// # Returns
249 ///
250 /// The tensor with the dimensions swapped.
251 fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
252
253 /// Permutes the dimensions of a tensor.
254 ///
255 /// # Arguments
256 ///
257 /// * `tensor` - The tensor to permute the dimensions of.
258 /// * `axes` - The new order of the dimensions.
259 /// # Returns
260 ///
261 /// The tensor with the dimensions permuted.
262 fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
263
264 /// Reverse the order of elements in a tensor along the given axes.
265 ///
266 /// # Arguments
267 ///
268 /// * `tensor` - The tensor to reverse.
269 /// * `axes` - The axes to reverse.
270 ///
271 /// The tensor with the elements reversed.
272 fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
273
    /// Select tensor elements along the given dimension corresponding to the given indices.
275 ///
276 /// # Arguments
277 ///
278 /// * `tensor` - The tensor to select from.
279 /// * `dim` - The dimension to select from.
280 /// * `indices` - The indices to select.
281 ///
282 /// # Returns
283 ///
284 /// The selected elements.
285 fn q_select(
286 tensor: QuantizedTensor<B>,
287 dim: usize,
288 indices: IntTensor<B>,
289 ) -> QuantizedTensor<B>;
290
291 /// Select tensor elements corresponding to the given slices.
292 ///
293 /// # Arguments
294 ///
295 /// * `tensor` - The tensor to select from.
296 /// * `slices` - The slices specifying ranges and steps for each dimension.
297 ///
298 /// # Returns
299 ///
300 /// The selected elements in a new tensor.
301 fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
302
303 /// Gather elements from a tensor.
304 ///
305 /// # Arguments
306 ///
307 /// * `dim` - The dimension to gather from.
308 /// * `tensor` - The tensor to gather from.
309 /// * `indices` - The indices to gather.
310 ///
311 /// # Returns
312 ///
313 /// The gathered elements.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<B>,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B> {
        // Default implementation. Backends can gather on the quantized values when supported.
        // Dequantizes, gathers in float, then requantizes with the input scheme.
        dequant_op_quant!(
            float_op | tensor | B::float_gather(dim, tensor, indices),
            tensor
        )
    }
325
326 /// Repeat the tensor along the given dimension.
327 ///
328 /// # Arguments
329 ///
330 /// * `tensor` - The tensor.
331 /// * `dim` - The dimension to repeat.
332 /// * `times` - The number of times to repeat the dimension.
333 ///
334 /// # Returns
335 ///
336 /// The tensor with the given dimension repeated.
    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
        // Default impl: dequantize -> float repeat -> requantize (output stays quantized).
        dequant_op_quant!(
            float_op | tensor | B::float_repeat_dim(tensor, dim, times),
            tensor
        )
    }
343
344 /// Adds two tensors together.
345 ///
346 /// # Arguments
347 ///
348 /// * `lhs` - The left hand side tensor.
349 /// * `rhs` - The right hand side tensor.
350 ///
351 /// # Returns
352 ///
353 /// The result of adding the two tensors together.
    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | lhs, rhs | B::float_add(lhs, rhs), lhs, rhs)
    }
357
358 /// Adds a scalar to a tensor.
359 ///
360 /// # Arguments
361 ///
362 /// * `lhs` - The left hand side tensor.
363 /// * `rhs` - The right hand side scalar.
364 ///
365 /// # Returns
366 ///
367 /// The result of adding the scalar to the tensor.
    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_add_scalar(tensor, rhs), lhs)
    }
371
372 /// Clamps a tensor under a minimum value.
373 ///
374 /// # Arguments
375 ///
376 /// * `tensor` - The tensor to clamp.
377 /// * `min` - The minimum value.
378 ///
379 /// # Returns
380 ///
381 /// The clamped tensor.
    fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_clamp_min(tensor, min), tensor)
    }
385
386 /// Clamps a tensor over a maximum value.
387 ///
388 /// # Arguments
389 ///
390 /// * `tensor` - The tensor to clamp.
391 /// * `max` - The maximum value.
392 ///
393 /// # Returns
394 ///
395 /// The clamped tensor.
    fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_clamp_max(tensor, max), tensor)
    }
399
400 /// Clamps a tensor between a minimum and maximum value.
401 ///
402 /// # Arguments
403 ///
404 /// * `tensor` - The tensor to clamp.
405 /// * `min` - The minimum value.
406 /// * `max` - The maximum value.
407 ///
408 /// # Returns
409 ///
410 /// The clamped tensor.
    fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_clamp(tensor, min, max), tensor)
    }
414
415 /// Subtracts two tensors.
416 ///
417 /// # Arguments
418 ///
419 /// * `lhs` - The left hand side tensor.
420 /// * `rhs` - The right hand side tensor.
421 ///
422 /// # Returns
423 ///
424 /// The result of subtracting the two tensors.
    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | lhs, rhs | B::float_sub(lhs, rhs), lhs, rhs)
    }
428
429 /// Subtracts a scalar from a tensor.
430 ///
431 /// # Arguments
432 ///
433 /// * `lhs` - The left hand side tensor.
434 /// * `rhs` - The right hand side scalar.
435 ///
436 /// # Returns
437 ///
438 /// The result of subtracting the scalar from the tensor.
    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_sub_scalar(tensor, rhs), lhs)
    }
442
443 /// Multiplies two tensors together element-wise.
    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | lhs, rhs | B::float_mul(lhs, rhs), lhs, rhs)
    }
447
448 /// Multiplies a tensor by a scalar.
449 ///
450 /// # Arguments
451 ///
452 /// * `lhs` - The left hand side tensor.
453 /// * `rhs` - The right hand side scalar.
454 ///
455 /// # Returns
456 ///
457 /// The result of multiplying the tensor by the scalar.
    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_mul_scalar(tensor, rhs), lhs)
    }
461
462 /// Divides two tensors element-wise.
463 ///
464 /// # Arguments
465 ///
466 /// * `lhs` - The left hand side tensor.
467 /// * `rhs` - The right hand side tensor.
468 ///
469 /// # Returns
470 ///
471 /// The result of dividing the two tensors.
    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | lhs, rhs | B::float_div(lhs, rhs), lhs, rhs)
    }
475
476 /// Divides a tensor by a scalar.
477 ///
478 /// # Arguments
479 ///
480 /// * `lhs` - The left hand side tensor.
481 /// * `rhs` - The right hand side scalar.
482 ///
483 /// # Returns
484 ///
485 /// The result of dividing the tensor by the scalar.
    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_div_scalar(tensor, rhs), lhs)
    }
489
490 /// Multiplies two tensors together using matrix multiplication.
491 ///
492 /// # Arguments
493 ///
494 /// * `lhs` - The left hand side tensor.
495 /// * `rhs` - The right hand side tensor.
496 ///
497 /// # Returns
498 ///
499 /// The result of multiplying the two tensors together using matrix multiplication.
500 fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
501 let mut propagation = QuantPropagation::Inhibit;
502 let mut scheme = QuantScheme::default();
503
504 // Pick a target dtype for any dequantization. If either operand is already
505 // a Float tensor, take its dtype so a Float-QFloat (or QFloat-Float) pair
506 // ends up matching after dequantize and `float_matmul` doesn't see a
507 // dtype mismatch. Only when both operands are QFloat do we fall back to
508 // the device default.
509 let target_dtype: Option<FloatDType> = match (&lhs, &rhs) {
510 (TensorPrimitive::Float(t), _) | (_, TensorPrimitive::Float(t)) => {
511 Some(t.dtype().into())
512 }
513 _ => None,
514 };
515
516 let lhs = match lhs {
517 TensorPrimitive::Float(lhs) => lhs,
518 TensorPrimitive::QFloat(lhs) => {
519 propagation = lhs.propagation();
520 scheme = *lhs.scheme();
521 let float_dtype = target_dtype
522 .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&lhs)).float_dtype);
523
524 Self::dequantize(lhs, float_dtype)
525 }
526 };
527 let rhs = match rhs {
528 TensorPrimitive::Float(rhs) => rhs,
529 TensorPrimitive::QFloat(rhs) => {
530 propagation = rhs.propagation();
531 scheme = *rhs.scheme();
532 let float_dtype = target_dtype
533 .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&rhs)).float_dtype);
534
535 Self::dequantize(rhs, float_dtype)
536 }
537 };
538
539 let out_f = B::float_matmul(lhs, rhs);
540 match propagation {
541 QuantPropagation::Propagate => {
542 TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
543 }
544 QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
545 }
546 }
547
548 /// Negates a tensor element-wise.
    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_neg(tensor), tensor)
    }
552
553 /// Calculates the reciprocals element-wise
    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_recip(tensor), tensor)
    }
557
558 /// Sum of all elements in a tensor.
559 ///
560 /// # Arguments
561 ///
562 /// * `tensor` - The tensor to sum.
563 ///
564 /// # Returns
565 ///
566 /// A scalar tensor with the sum of all elements in `tensor`.
    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_sum(tensor), tensor)
    }
570
571 /// Sum of all elements in a tensor along a dimension.
572 ///
573 /// # Arguments
574 ///
575 /// * `tensor` - The tensor to sum.
576 /// * `dim` - The dimension along which to sum.
577 ///
578 /// # Returns
579 ///
580 /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_sum_dim(tensor, dim), tensor)
    }
584
585 /// Product of all elements in a tensor.
586 ///
587 /// # Arguments
588 ///
589 /// * `tensor` - The tensor to product.
590 ///
591 /// # Returns
592 ///
593 /// A scalar tensor with the product of all elements in `tensor`.
    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_prod(tensor), tensor)
    }
597
598 /// Product of all elements in a tensor along a dimension.
599 ///
600 /// # Arguments
601 ///
602 /// * `tensor` - The tensor to product.
603 ///
604 /// # Returns
605 ///
606 /// A tensor with the product of all elements in `tensor` along `dim`.
    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_prod_dim(tensor, dim), tensor)
    }
610
611 /// Mean of all elements in a tensor.
612 ///
613 /// # Arguments
614 ///
615 /// * `tensor` - The tensor to mean.
616 ///
617 /// # Returns
618 ///
619 /// A scalar tensor with the mean of all elements in `tensor`.
    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_mean(tensor), tensor)
    }
623
624 /// Mean of all elements in a tensor along a dimension.
625 ///
626 /// # Arguments
627 ///
628 /// * `tensor` - The tensor to mean.
629 /// * `dim` - The dimension along which to mean.
630 ///
631 /// # Returns
632 ///
633 /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_mean_dim(tensor, dim), tensor)
    }
637
638 /// Computes the cumulative sum of elements along a dimension.
639 ///
640 /// # Arguments
641 ///
642 /// * `tensor` - The tensor to compute the cumulative sum of.
643 /// * `dim` - The dimension along which to compute the cumulative sum.
644 ///
645 /// # Returns
646 ///
647 /// A tensor with the same shape where each element is the cumulative sum
648 /// of all elements up to and including that position along the dimension.
    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_cumsum(tensor, dim), tensor)
    }
652
653 /// Computes the cumulative product of elements along a dimension.
654 ///
655 /// # Arguments
656 ///
657 /// * `tensor` - The tensor to compute the cumulative product of.
658 /// * `dim` - The dimension along which to compute the cumulative product.
659 ///
660 /// # Returns
661 ///
662 /// A tensor with the same shape where each element is the cumulative product
663 /// of all elements up to and including that position along the dimension.
    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_cumprod(tensor, dim), tensor)
    }
667
668 /// Computes the cumulative minimum of elements along a dimension.
669 ///
670 /// # Arguments
671 ///
672 /// * `tensor` - The tensor to compute the cumulative minimum of.
673 /// * `dim` - The dimension along which to compute the cumulative minimum.
674 ///
675 /// # Returns
676 ///
677 /// A tensor with the same shape where each element is the minimum
678 /// of all elements up to and including that position along the dimension.
    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_cummin(tensor, dim), tensor)
    }
682
683 /// Computes the cumulative maximum of elements along a dimension.
684 ///
685 /// # Arguments
686 ///
687 /// * `tensor` - The tensor to compute the cumulative maximum of.
688 /// * `dim` - The dimension along which to compute the cumulative maximum.
689 ///
690 /// # Returns
691 ///
692 /// A tensor with the same shape where each element is the maximum
693 /// of all elements up to and including that position along the dimension.
    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_cummax(tensor, dim), tensor)
    }
697
698 /// Returns a new tensor with exponential values.
699 ///
700 /// # Arguments
701 ///
702 /// * `tensor` - The tensor to exponentiate.
703 ///
704 /// # Returns
705 ///
706 /// A tensor with the same shape as `tensor` with exponential values.
    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_exp(tensor), tensor)
    }
710
711 /// Returns a new tensor with natural logarithm values.
712 ///
713 /// # Arguments
714 ///
715 /// * `tensor` - The tensor to take the logarithm of.
716 ///
717 /// # Returns
718 ///
719 /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_log(tensor), tensor)
    }
723
724 /// Returns a new tensor with logarithm values of (1 + Xi).
725 ///
726 /// # Arguments
727 ///
728 /// * `tensor` - The tensor to take the logarithm of.
729 ///
730 /// # Returns
731 ///
732 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_log1p(tensor), tensor)
    }
736
737 /// Element-wise power with another tensor.
738 ///
739 /// # Arguments
740 ///
741 /// * `lhs` - The left hand side tensor.
742 /// * `rhs` - The right hand side tensor.
743 ///
744 /// # Returns
745 ///
746 /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | lhs, rhs | B::float_powf(lhs, rhs), lhs, rhs)
    }
750
    /// Element-wise power with an `IntTensor` exponent.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side int tensor (exponents).
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_powi(tensor, rhs), lhs)
    }
764
765 /// Element-wise power with an int scalar.
766 ///
767 /// # Arguments
768 ///
769 /// * `lhs` - The left hand side tensor.
770 /// * `rhs` - The right hand side scalar.
771 ///
772 /// # Returns
773 ///
774 /// The elements of `lhs` raised to the value of `rhs`.
    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_powi_scalar(tensor, rhs), lhs)
    }
778
779 /// Element-wise power with a float scalar.
780 ///
781 /// # Arguments
782 ///
783 /// * `tensor` - The tensor to exponentiate.
784 /// * `value` - The exponent.
785 ///
786 /// # Returns
787 ///
788 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(
            float_op | tensor | B::float_powf_scalar(tensor, value),
            tensor
        )
    }
795
796 /// Returns a new tensor with square root values.
797 ///
798 /// # Arguments
799 ///
800 /// * `tensor` - The tensor to take the square root of.
801 ///
802 /// # Returns
803 ///
804 /// A tensor with the same shape as `tensor` with square root values.
    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_sqrt(tensor), tensor)
    }
808
809 /// Returns a new tensor with absolute values.
810 ///
811 /// # Arguments
812 ///
813 /// * `tensor` - The tensor to take absolute value of.
814 ///
815 /// # Returns
816 ///
817 /// A tensor with the same shape as `tensor` with absolute values.
    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Default impl: dequantize -> float abs -> requantize (output stays quantized).
        dequant_op_quant!(float_op | tensor | B::float_abs(tensor), tensor)
    }
821
822 /// Returns a new tensor with cosine values.
823 ///
824 /// # Arguments
825 ///
826 /// * `tensor` - The tensor to take the cosine of.
827 ///
828 /// # Returns
829 ///
830 /// A tensor with the same shape as `tensor` with cosine values.
    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_cos(tensor), tensor)
    }
834
835 /// Returns a new tensor with sine values.
836 ///
837 /// # Arguments
838 ///
839 /// * `tensor` - The tensor to take the sine of.
840 ///
841 /// # Returns
842 ///
843 /// A tensor with the same shape as `tensor` with sine values.
    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_sin(tensor), tensor)
    }
847
848 /// Returns a new tensor with tangent values.
849 ///
850 /// # Arguments
851 ///
852 /// * `tensor` - The tensor to take the tangent of.
853 ///
854 /// # Returns
855 ///
856 /// A tensor with the same shape as `tensor` with tangent values.
    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_tan(tensor), tensor)
    }
860
861 /// Returns a new tensor with hyperbolic cosine values.
862 ///
863 /// # Arguments
864 ///
865 /// * `tensor` - The tensor to take the hyperbolic cosine of.
866 ///
867 /// # Returns
868 ///
869 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_cosh(tensor), tensor)
    }
873
874 /// Returns a new tensor with hyperbolic sine values.
875 ///
876 /// # Arguments
877 ///
878 /// * `tensor` - The tensor to take the hyperbolic sine of.
879 ///
880 /// # Returns
881 ///
882 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_sinh(tensor), tensor)
    }
886
887 /// Returns a new tensor with hyperbolic tangent values.
888 ///
889 /// # Arguments
890 ///
891 /// * `tensor` - The tensor to take the hyperbolic tangent of.
892 ///
893 /// # Returns
894 ///
895 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_tanh(tensor), tensor)
    }
899
900 /// Returns a new tensor with the error function values.
901 ///
902 /// # Arguments
903 ///
904 /// * `tensor` - The tensor to take the error function of.
905 ///
906 /// # Returns
907 ///
908 /// A tensor with the same shape as `tensor` with error function values.
    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default impl: dequantize -> float op -> requantize if the scheme propagates.
        dequant_op_flow!(float_op | tensor | B::float_erf(tensor), tensor)
    }
912
913 /// Concatenates tensors along a dimension.
914 ///
915 /// # Arguments
916 ///
917 /// * `tensors` - The tensors to concatenate.
918 /// * `dim` - The dimension along which to concatenate.
919 ///
920 /// # Returns
921 ///
922 /// A tensor with the concatenated tensors along `dim`.
923 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
924 // Heuristic: prioritize first tensor scheme
925 let first = tensors.first().unwrap();
926 let scheme = *first.scheme();
927 let dtype = get_device_settings::<B>(&Self::q_device(first)).float_dtype;
928
929 let tensor_f = tensors
930 .into_iter()
931 .map(|tensor| Self::dequantize(tensor, dtype))
932 .collect();
933
934 let out_f = B::float_cat(tensor_f, dim);
935
936 Self::quantize_dynamic(out_f, &scheme)
937 }
938
939 /// Gets the indices of the maximum elements of a tensor along an axis.
940 ///
941 /// # Arguments
942 ///
943 /// * `tensor` - The tensor to get the maximum elements of.
944 /// * `dim` - The dimension along which to get the maximum elements.
945 /// * `out_dtype` - The output tensor dtype.
946 ///
947 /// # Returns
948 ///
949 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
950 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
951 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
952 let tensor_f = Self::dequantize(tensor, dtype);
953 B::float_argmax(tensor_f, dim, out_dtype)
954 }
955
956 /// Gets the indices of the minimum elements of a tensor along an axis.
957 ///
958 /// # Arguments
959 ///
960 /// * `tensor` - The tensor to get the minimum elements of.
961 /// * `dim` - The dimension along which to get the minimum elements.
962 /// * `out_dtype` - The output tensor dtype.
963 ///
964 /// # Returns
965 ///
966 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
967 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
968 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
969 let tensor_f = Self::dequantize(tensor, dtype);
970 B::float_argmin(tensor_f, dim, out_dtype)
971 }
972
973 /// Gets the maximum element of a tensor.
974 ///
975 /// # Arguments
976 ///
977 /// * `tensor` - The tensor to get the maximum elements of.
978 ///
979 /// # Returns
980 ///
981 /// A tensor with the maximum element of `tensor`.
982 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
983 let shape = tensor.shape();
984 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
985
986 B::q_max_dim(tensor, 0)
987 }
988
989 /// Gets the maximum elements of a tensor along an axis.
990 ///
991 /// # Arguments
992 ///
993 /// * `tensor` - The tensor to get the maximum elements of.
994 /// * `dim` - The dimension along which to get the maximum elements.
995 ///
996 /// # Returns
997 ///
998 /// A tensor with the maximum elements of `tensor` along `dim`.
999 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1000 let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1001 let index = B::q_argmax(tensor.clone(), dim, int_dtype);
1002
1003 B::q_gather(dim, tensor, index)
1004 }
1005
1006 /// Gets the maximum elements of a tensor along an axis and their indices.
1007 ///
1008 /// # Arguments
1009 ///
1010 /// * `tensor` - The tensor to get the maximum elements of.
1011 /// * `dim` - The dimension along which to get the maximum elements.
1012 ///
1013 /// # Returns
1014 ///
1015 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1016 fn q_max_dim_with_indices(
1017 tensor: QuantizedTensor<B>,
1018 dim: usize,
1019 out_dtype: IntDType,
1020 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1021 let index = B::q_argmax(tensor.clone(), dim, out_dtype);
1022 let values = B::q_gather(dim, tensor, index.clone());
1023
1024 (values, index)
1025 }
1026
1027 /// Gets the minimum element of a tensor.
1028 ///
1029 /// # Arguments
1030 ///
1031 /// * `tensor` - The tensor to get the minimum elements of.
1032 ///
1033 /// # Returns
1034 ///
1035 /// A tensor with the minimum element of `tensor`.
1036 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1037 let shape = tensor.shape();
1038 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1039
1040 B::q_min_dim(tensor, 0)
1041 }
1042
1043 /// Gets the minimum elements of a tensor along an axis.
1044 ///
1045 /// # Arguments
1046 ///
1047 /// * `tensor` - The tensor to get the minimum elements of.
1048 /// * `dim` - The dimension along which to get the minimum elements.
1049 ///
1050 /// # Returns
1051 ///
1052 /// A tensor with the minimum elements of `tensor` along `dim`.
1053 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1054 let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1055 let index = B::q_argmin(tensor.clone(), dim, int_dtype);
1056
1057 B::q_gather(dim, tensor, index)
1058 }
1059
1060 /// Gets the minimum elements of a tensor along an axis and their indices.
1061 ///
1062 /// # Arguments
1063 ///
1064 /// * `tensor` - The tensor to get the minimum elements of.
1065 /// * `dim` - The dimension along which to get the minimum elements.
1066 ///
1067 /// # Returns
1068 ///
1069 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1070 fn q_min_dim_with_indices(
1071 tensor: QuantizedTensor<B>,
1072 dim: usize,
1073 out_dtype: IntDType,
1074 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1075 let index = B::q_argmin(tensor.clone(), dim, out_dtype);
1076 let values = B::q_gather(dim, tensor, index.clone());
1077
1078 (values, index)
1079 }
1080
1081 /// Gets the maximum element of a tensor.
1082 ///
1083 /// # Arguments
1084 ///
1085 /// * `tensor` - The tensor to get the maximum elements of.
1086 ///
1087 /// # Returns
1088 ///
1089 /// A tensor with the maximum element of `tensor`.
1090 fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1091 let shape = tensor.shape();
1092 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1093
1094 B::q_max_abs_dim(tensor, 0)
1095 }
1096
1097 /// Gets the maximum elements of a tensor along an axis.
1098 ///
1099 /// # Arguments
1100 ///
1101 /// * `tensor` - The tensor to get the maximum elements of.
1102 /// * `dim` - The dimension along which to get the maximum elements.
1103 ///
1104 /// # Returns
1105 ///
1106 /// A tensor with the maximum elements of `tensor` along `dim`.
1107 fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1108 let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1109 let index = B::q_argmax(B::q_abs(tensor.clone()), dim, int_dtype);
1110
1111 B::q_gather(dim, tensor, index)
1112 }
1113
1114 /// Tests if any element in the `tensor` evaluates to True.
1115 ///
1116 /// # Arguments
1117 ///
1118 /// * `tensor` - The tensor to test.
1119 ///
1120 /// # Returns
1121 ///
1122 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1123 fn q_any(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1124 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1125 let tensor_f = Self::dequantize(tensor, dtype);
1126 B::float_any(tensor_f, out_dtype)
1127 }
1128
1129 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1130 ///
1131 /// # Arguments
1132 ///
1133 /// * `tensor` - The tensor to test.
1134 /// * `dim` - The axis along which to test.
1135 ///
1136 /// # Returns
1137 ///
1138 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1139 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1140 /// input evaluates to True, False otherwise.
1141 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1142 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1143 let tensor_f = Self::dequantize(tensor, dtype);
1144 B::float_any_dim(tensor_f, dim, out_dtype)
1145 }
1146
1147 /// Tests if all elements in the `tensor` evaluate to True.
1148 ///
1149 /// # Arguments
1150 ///
1151 /// * `tensor` - The tensor to test.
1152 ///
1153 /// # Returns
1154 ///
1155 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1156 /// evaluate to True, False otherwise.
1157 fn q_all(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1158 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1159 let tensor_f = Self::dequantize(tensor, dtype);
1160 B::float_all(tensor_f, out_dtype)
1161 }
1162
1163 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1164 ///
1165 /// # Arguments
1166 ///
1167 /// * `tensor` - The tensor to test.
1168 /// * `dim` - The axis along which to test.
1169 ///
1170 /// # Returns
1171 ///
1172 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1173 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1174 /// evaluates to True, False otherwise.
1175 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1176 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1177 let tensor_f = Self::dequantize(tensor, dtype);
1178 B::float_all_dim(tensor_f, dim, out_dtype)
1179 }
1180
1181 /// Sort the elements of the input `tensor` by value in along a given dimension.
1182 ///
1183 /// This sort is unstable (i.e., may reorder equal elements).
1184 ///
1185 /// # Arguments
1186 ///
1187 /// * `tensor` - The input tensor.
1188 /// * `dim` - The axis along which to sort.
1189 /// * `descending` - The sorting order.
1190 ///
1191 /// # Returns
1192 ///
1193 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1194 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1195 // Default implementation. Backends can sort on the int values since qparams remain the same.
1196 dequant_op_quant!(
1197 float_op | tensor | B::float_sort(tensor, dim, descending),
1198 tensor
1199 )
1200 }
1201
1202 /// Sort the elements of the input `tensor` by value in along a given dimension.
1203 ///
1204 /// This sort is unstable (i.e., may reorder equal elements).
1205 ///
1206 /// # Arguments
1207 ///
1208 /// * `tensor` - The input tensor.
1209 /// * `dim` - The axis along which to sort.
1210 /// * `descending` - The sorting order.
1211 ///
1212 /// # Returns
1213 ///
1214 /// A tensor with the same shape as the input tensor and corresponding indices, where
1215 /// the elements are sorted by value and the indices map back to the original input tensor.
1216 fn q_sort_with_indices(
1217 tensor: QuantizedTensor<B>,
1218 dim: usize,
1219 descending: bool,
1220 out_dtype: IntDType,
1221 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1222 let scheme = *tensor.scheme();
1223 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1224
1225 let tensor_f = Self::dequantize(tensor, dtype);
1226 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending, out_dtype);
1227
1228 (Self::quantize_dynamic(out_f, &scheme), indices)
1229 }
1230
1231 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1232 ///
1233 /// This sort is unstable (i.e., may reorder equal elements).
1234 ///
1235 /// # Arguments
1236 ///
1237 /// * `tensor` - The input tensor.
1238 /// * `dim` - The axis along which to sort.
1239 /// * `descending` - The sorting order.
1240 ///
1241 /// # Returns
1242 ///
1243 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1244 fn q_argsort(
1245 tensor: QuantizedTensor<B>,
1246 dim: usize,
1247 descending: bool,
1248 out_dtype: IntDType,
1249 ) -> IntTensor<B> {
1250 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1251 let tensor_f = Self::dequantize(tensor, dtype);
1252 B::float_argsort(tensor_f, dim, descending, out_dtype)
1253 }
1254}