burn_backend/backend/ops/qtensor.rs
1use alloc::vec::Vec;
2use burn_std::{
3 BoolDType, FloatDType, IntDType, Shape, Slice,
4 quantization::{QuantPropagation, QuantScheme},
5};
6
7use crate::{
8 Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,
9 get_device_settings,
10};
11use crate::{
12 Scalar,
13 tensor::{
14 BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,
15 quantization::{
16 Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,
17 },
18 },
19};
20
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        // `dequantize` requires an explicit target float dtype; resolve it from
        // the device settings before consuming `$t1` (mirrors the unary arm and
        // the binary arm of `dequant_op_flow`).
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        let t1_f = Self::dequantize($t1, dtype);
        let t2_f = Self::dequantize($t2, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        let tensor_f = Self::dequantize($tensor, dtype);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        Self::quantize_dynamic(out_f, &scheme)
    }};
}
54
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        let propagation = $t1.propagation();
        let dtype = get_device_settings::<B>(&Self::q_device(&$t1)).float_dtype;

        // Run the float op on the dequantized operands.
        let lhs_f = Self::dequantize($t1, dtype);
        let rhs_f = Self::dequantize($t2, dtype);
        #[allow(clippy::redundant_closure_call)]
        let result = $float_op(lhs_f, rhs_f);

        // Re-quantize only when the input scheme propagates.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(result, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(result),
        }
    }};
    // Unary tensor float op
    (
        float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let propagation = $tensor.propagation();
        let dtype = get_device_settings::<B>(&Self::q_device(&$tensor)).float_dtype;

        // Run the float op on the dequantized input.
        let input_f = Self::dequantize($tensor, dtype);
        #[allow(clippy::redundant_closure_call)]
        let result = $float_op(input_f);

        // Re-quantize only when the input scheme propagates.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(result, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(result),
        }
    }};
}
102
103/// Operations on quantized tensors.
104///
105/// # Return Type Semantics
106///
107/// The return type of each operation indicates how quantization is handled:
108///
109/// ## [`QuantizedTensor<B>`]
110/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
111/// representation. Implementations should avoid dequantizing when possible to maintain performance.
112/// For example, shape or layout changes such as expand or transpose preserve quantization.
113///
114/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
115/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
116/// the quantization parameters to match the new layout.*
117///
118///
119/// ## [`TensorPrimitive<B>`]
120/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with propagation
121/// strategy specified in the quantization scheme. The output should remain quantized ([`TensorPrimitive::QFloat`])
122/// returned in floating-point form ([`TensorPrimitive::Float`]).
123///
124/// This distinction allows for fine-grained control over mixed-precision flows while still operating
125/// through a unified API.
126pub trait QTensorOps<B: Backend> {
127 /// Creates a new tensor from the data structure.
128 ///
129 /// # Arguments
130 ///
131 /// * `data` - The data structure.
132 /// * `device` - The device to create the tensor on.
133 ///
134 /// # Returns
135 ///
136 /// The tensor with the given data.
137 fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
138
139 /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
140 fn quantize(
141 tensor: FloatTensor<B>,
142 scheme: &QuantScheme,
143 qparams: QuantizationParametersPrimitive<B>,
144 ) -> QuantizedTensor<B>;
145
146 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
147 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
148 // Dynamically compute min/max tensor range and qparams before quantizing
149 let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);
150 let qparams = compute_q_params(scheme, min, max);
151 Self::quantize(tensor, scheme, qparams)
152 }
153
154 /// Convert the tensor back to a higher precision data type.
155 fn dequantize(tensor: QuantizedTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
156
157 /// Gets the device of the tensor.
158 ///
159 /// # Arguments
160 ///
161 /// * `tensor` - The tensor.
162 ///
163 /// # Returns
164 ///
165 /// The device of the tensor.
166 fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
167
168 /// Moves the tensor to the given device.
169 ///
170 /// # Arguments
171 ///
172 /// * `tensor` - The tensor.
173 /// * `device` - The device to move the tensor to.
174 ///
175 /// # Returns
176 ///
177 /// The tensor on the given device.
178 fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
179
180 /// Reshapes a tensor.
181 ///
182 /// # Arguments
183 ///
184 /// * `tensor` - The tensor to reshape.
185 /// * `shape` - The new shape of the tensor.
186 ///
187 /// # Returns
188 ///
189 /// The tensor with the new shape.
190 fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
191
192 /// Converts the tensor to a data structure.
193 ///
194 /// # Arguments
195 ///
196 /// * `tensor` - The tensor.
197 ///
198 /// # Returns
199 ///
200 /// The data structure with the tensor's data.
201 fn q_into_data(
202 tensor: QuantizedTensor<B>,
203 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
204
    /// Detaches a tensor from the computation graph.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Identity here. Should only be overridden by autodiff backends.
        tensor
    }

    /// Sets the `require_grad` flag of a tensor.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // No-op here (the flag is ignored). Should only be overridden by autodiff backends.
        tensor
    }

    /// Returns the `require_grad` flag of a tensor.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Non-autodiff backends never track gradients.
        // Should only be overridden by autodiff backends.
        false
    }
222
223 /// Broadcasts the `tensor` to the given `shape`.
224 fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
225
226 /// Transposes a tensor.
227 ///
228 /// # Arguments
229 ///
230 /// * `tensor` - The tensor to transpose.
231 ///
232 /// # Returns
233 ///
234 /// The transposed tensor.
235 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
236 let ndims = tensor.shape().num_dims();
237 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
238 }
239
240 /// Swaps two dimensions of a tensor.
241 ///
242 /// # Arguments
243 ///
244 /// * `tensor` - The tensor to swap the dimensions of.
245 /// * `dim1` - The first dimension to swap.
246 /// * `dim2` - The second dimension to swap.
247 ///
248 /// # Returns
249 ///
250 /// The tensor with the dimensions swapped.
251 fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
252
253 /// Permutes the dimensions of a tensor.
254 ///
255 /// # Arguments
256 ///
257 /// * `tensor` - The tensor to permute the dimensions of.
258 /// * `axes` - The new order of the dimensions.
259 /// # Returns
260 ///
261 /// The tensor with the dimensions permuted.
262 fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
263
264 /// Reverse the order of elements in a tensor along the given axes.
265 ///
266 /// # Arguments
267 ///
268 /// * `tensor` - The tensor to reverse.
269 /// * `axes` - The axes to reverse.
270 ///
271 /// The tensor with the elements reversed.
272 fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
273
274 /// Select tensor elements along the given dimension corresponding for the given indices.
275 ///
276 /// # Arguments
277 ///
278 /// * `tensor` - The tensor to select from.
279 /// * `dim` - The dimension to select from.
280 /// * `indices` - The indices to select.
281 ///
282 /// # Returns
283 ///
284 /// The selected elements.
285 fn q_select(
286 tensor: QuantizedTensor<B>,
287 dim: usize,
288 indices: IntTensor<B>,
289 ) -> QuantizedTensor<B>;
290
291 /// Select tensor elements corresponding to the given slices.
292 ///
293 /// # Arguments
294 ///
295 /// * `tensor` - The tensor to select from.
296 /// * `slices` - The slices specifying ranges and steps for each dimension.
297 ///
298 /// # Returns
299 ///
300 /// The selected elements in a new tensor.
301 fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;
302
    /// Gather elements from a tensor.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<B>,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B> {
        // Default implementation. Backends can gather on the quantized values when supported.
        // Falls back to dequantize -> float_gather -> re-quantize (dynamic).
        dequant_op_quant!(
            float_op | tensor | B::float_gather(dim, tensor, indices),
            tensor
        )
    }
325
    /// Repeat the tensor along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to repeat.
    /// * `times` - The number of times to repeat the dimension.
    ///
    /// # Returns
    ///
    /// The tensor with the given dimension repeated.
    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
        // Default fallback: dequantize -> float_repeat_dim -> re-quantize (dynamic).
        // Backends can override to repeat the quantized values directly.
        dequant_op_quant!(
            float_op | tensor | B::float_repeat_dim(tensor, dim, times),
            tensor
        )
    }
343
344 /// Adds two tensors together.
345 ///
346 /// # Arguments
347 ///
348 /// * `lhs` - The left hand side tensor.
349 /// * `rhs` - The right hand side tensor.
350 ///
351 /// # Returns
352 ///
353 /// The result of adding the two tensors together.
354 fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
355 dequant_op_flow!(float_op | lhs, rhs | B::float_add(lhs, rhs), lhs, rhs)
356 }
357
358 /// Adds a scalar to a tensor.
359 ///
360 /// # Arguments
361 ///
362 /// * `lhs` - The left hand side tensor.
363 /// * `rhs` - The right hand side scalar.
364 ///
365 /// # Returns
366 ///
367 /// The result of adding the scalar to the tensor.
368 fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
369 dequant_op_flow!(float_op | tensor | B::float_add_scalar(tensor, rhs), lhs)
370 }
371
372 /// Clamps a tensor under a minimum value.
373 ///
374 /// # Arguments
375 ///
376 /// * `tensor` - The tensor to clamp.
377 /// * `min` - The minimum value.
378 ///
379 /// # Returns
380 ///
381 /// The clamped tensor.
382 fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {
383 dequant_op_flow!(float_op | tensor | B::float_clamp_min(tensor, min), tensor)
384 }
385
386 /// Clamps a tensor over a maximum value.
387 ///
388 /// # Arguments
389 ///
390 /// * `tensor` - The tensor to clamp.
391 /// * `max` - The maximum value.
392 ///
393 /// # Returns
394 ///
395 /// The clamped tensor.
396 fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {
397 dequant_op_flow!(float_op | tensor | B::float_clamp_max(tensor, max), tensor)
398 }
399
400 /// Clamps a tensor between a minimum and maximum value.
401 ///
402 /// # Arguments
403 ///
404 /// * `tensor` - The tensor to clamp.
405 /// * `min` - The minimum value.
406 /// * `max` - The maximum value.
407 ///
408 /// # Returns
409 ///
410 /// The clamped tensor.
411 fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {
412 dequant_op_flow!(float_op | tensor | B::float_clamp(tensor, min, max), tensor)
413 }
414
415 /// Subtracts two tensors.
416 ///
417 /// # Arguments
418 ///
419 /// * `lhs` - The left hand side tensor.
420 /// * `rhs` - The right hand side tensor.
421 ///
422 /// # Returns
423 ///
424 /// The result of subtracting the two tensors.
425 fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
426 dequant_op_flow!(float_op | lhs, rhs | B::float_sub(lhs, rhs), lhs, rhs)
427 }
428
429 /// Subtracts a scalar from a tensor.
430 ///
431 /// # Arguments
432 ///
433 /// * `lhs` - The left hand side tensor.
434 /// * `rhs` - The right hand side scalar.
435 ///
436 /// # Returns
437 ///
438 /// The result of subtracting the scalar from the tensor.
439 fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
440 dequant_op_flow!(float_op | tensor | B::float_sub_scalar(tensor, rhs), lhs)
441 }
442
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together element-wise.
    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | lhs, rhs | B::float_mul(lhs, rhs), lhs, rhs)
    }
447
448 /// Multiplies a tensor by a scalar.
449 ///
450 /// # Arguments
451 ///
452 /// * `lhs` - The left hand side tensor.
453 /// * `rhs` - The right hand side scalar.
454 ///
455 /// # Returns
456 ///
457 /// The result of multiplying the tensor by the scalar.
458 fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
459 dequant_op_flow!(float_op | tensor | B::float_mul_scalar(tensor, rhs), lhs)
460 }
461
462 /// Divides two tensors element-wise.
463 ///
464 /// # Arguments
465 ///
466 /// * `lhs` - The left hand side tensor.
467 /// * `rhs` - The right hand side tensor.
468 ///
469 /// # Returns
470 ///
471 /// The result of dividing the two tensors.
472 fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
473 dequant_op_flow!(float_op | lhs, rhs | B::float_div(lhs, rhs), lhs, rhs)
474 }
475
476 /// Divides a tensor by a scalar.
477 ///
478 /// # Arguments
479 ///
480 /// * `lhs` - The left hand side tensor.
481 /// * `rhs` - The right hand side scalar.
482 ///
483 /// # Returns
484 ///
485 /// The result of dividing the tensor by the scalar.
486 fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
487 dequant_op_flow!(float_op | tensor | B::float_div_scalar(tensor, rhs), lhs)
488 }
489
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// Inputs may be quantized or float; any quantized input is dequantized before
    /// the float matmul, and the output is re-quantized only when propagation is requested.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {
        // Defaults used when neither input is quantized: plain float output.
        let mut propagation = QuantPropagation::Inhibit;
        let mut scheme = QuantScheme::default();
        // Float dtype chosen for dequantization; remembered so both operands agree.
        let mut dtype = None;

        let lhs = match lhs {
            TensorPrimitive::Float(lhs) => lhs,
            TensorPrimitive::QFloat(lhs) => {
                propagation = lhs.propagation();
                scheme = *lhs.scheme();
                let float_dtype = get_device_settings::<B>(&Self::q_device(&lhs)).float_dtype;
                dtype = Some(float_dtype);

                Self::dequantize(lhs, float_dtype)
            }
        };
        let rhs = match rhs {
            TensorPrimitive::Float(rhs) => rhs,
            TensorPrimitive::QFloat(rhs) => {
                // NOTE(review): when both inputs are quantized, the rhs scheme and
                // propagation overwrite the lhs values set above. This differs from
                // the "prioritize lhs scheme" heuristic used elsewhere in this file —
                // confirm this asymmetry is intended.
                propagation = rhs.propagation();
                scheme = *rhs.scheme();
                // Reuse the dtype already resolved for lhs (if any).
                let float_dtype = dtype
                    .unwrap_or_else(|| get_device_settings::<B>(&Self::q_device(&rhs)).float_dtype);

                Self::dequantize(rhs, float_dtype)
            }
        };

        let out_f = B::float_matmul(lhs, rhs);
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }
536
537 /// Negates a tensor element-wise.
538 fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
539 dequant_op_flow!(float_op | tensor | B::float_neg(tensor), tensor)
540 }
541
542 /// Calculates the reciprocals element-wise
543 fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
544 dequant_op_flow!(float_op | tensor | B::float_recip(tensor), tensor)
545 }
546
547 /// Sum of all elements in a tensor.
548 ///
549 /// # Arguments
550 ///
551 /// * `tensor` - The tensor to sum.
552 ///
553 /// # Returns
554 ///
555 /// A scalar tensor with the sum of all elements in `tensor`.
556 fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
557 dequant_op_flow!(float_op | tensor | B::float_sum(tensor), tensor)
558 }
559
560 /// Sum of all elements in a tensor along a dimension.
561 ///
562 /// # Arguments
563 ///
564 /// * `tensor` - The tensor to sum.
565 /// * `dim` - The dimension along which to sum.
566 ///
567 /// # Returns
568 ///
569 /// A tensor with the sum of all elements in `tensor` along `dim`.
570 fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
571 dequant_op_flow!(float_op | tensor | B::float_sum_dim(tensor, dim), tensor)
572 }
573
574 /// Product of all elements in a tensor.
575 ///
576 /// # Arguments
577 ///
578 /// * `tensor` - The tensor to product.
579 ///
580 /// # Returns
581 ///
582 /// A scalar tensor with the product of all elements in `tensor`.
583 fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
584 dequant_op_flow!(float_op | tensor | B::float_prod(tensor), tensor)
585 }
586
587 /// Product of all elements in a tensor along a dimension.
588 ///
589 /// # Arguments
590 ///
591 /// * `tensor` - The tensor to product.
592 ///
593 /// # Returns
594 ///
595 /// A tensor with the product of all elements in `tensor` along `dim`.
596 fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
597 dequant_op_flow!(float_op | tensor | B::float_prod_dim(tensor, dim), tensor)
598 }
599
600 /// Mean of all elements in a tensor.
601 ///
602 /// # Arguments
603 ///
604 /// * `tensor` - The tensor to mean.
605 ///
606 /// # Returns
607 ///
608 /// A scalar tensor with the mean of all elements in `tensor`.
609 fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
610 dequant_op_flow!(float_op | tensor | B::float_mean(tensor), tensor)
611 }
612
613 /// Mean of all elements in a tensor along a dimension.
614 ///
615 /// # Arguments
616 ///
617 /// * `tensor` - The tensor to mean.
618 /// * `dim` - The dimension along which to mean.
619 ///
620 /// # Returns
621 ///
622 /// A tensor with the mean of all elements in `tensor` along `dim`.
623 fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
624 dequant_op_flow!(float_op | tensor | B::float_mean_dim(tensor, dim), tensor)
625 }
626
627 /// Computes the cumulative sum of elements along a dimension.
628 ///
629 /// # Arguments
630 ///
631 /// * `tensor` - The tensor to compute the cumulative sum of.
632 /// * `dim` - The dimension along which to compute the cumulative sum.
633 ///
634 /// # Returns
635 ///
636 /// A tensor with the same shape where each element is the cumulative sum
637 /// of all elements up to and including that position along the dimension.
638 fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
639 dequant_op_flow!(float_op | tensor | B::float_cumsum(tensor, dim), tensor)
640 }
641
642 /// Computes the cumulative product of elements along a dimension.
643 ///
644 /// # Arguments
645 ///
646 /// * `tensor` - The tensor to compute the cumulative product of.
647 /// * `dim` - The dimension along which to compute the cumulative product.
648 ///
649 /// # Returns
650 ///
651 /// A tensor with the same shape where each element is the cumulative product
652 /// of all elements up to and including that position along the dimension.
653 fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
654 dequant_op_flow!(float_op | tensor | B::float_cumprod(tensor, dim), tensor)
655 }
656
657 /// Computes the cumulative minimum of elements along a dimension.
658 ///
659 /// # Arguments
660 ///
661 /// * `tensor` - The tensor to compute the cumulative minimum of.
662 /// * `dim` - The dimension along which to compute the cumulative minimum.
663 ///
664 /// # Returns
665 ///
666 /// A tensor with the same shape where each element is the minimum
667 /// of all elements up to and including that position along the dimension.
668 fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
669 dequant_op_flow!(float_op | tensor | B::float_cummin(tensor, dim), tensor)
670 }
671
672 /// Computes the cumulative maximum of elements along a dimension.
673 ///
674 /// # Arguments
675 ///
676 /// * `tensor` - The tensor to compute the cumulative maximum of.
677 /// * `dim` - The dimension along which to compute the cumulative maximum.
678 ///
679 /// # Returns
680 ///
681 /// A tensor with the same shape where each element is the maximum
682 /// of all elements up to and including that position along the dimension.
683 fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
684 dequant_op_flow!(float_op | tensor | B::float_cummax(tensor, dim), tensor)
685 }
686
687 /// Returns a new tensor with exponential values.
688 ///
689 /// # Arguments
690 ///
691 /// * `tensor` - The tensor to exponentiate.
692 ///
693 /// # Returns
694 ///
695 /// A tensor with the same shape as `tensor` with exponential values.
696 fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
697 dequant_op_flow!(float_op | tensor | B::float_exp(tensor), tensor)
698 }
699
700 /// Returns a new tensor with natural logarithm values.
701 ///
702 /// # Arguments
703 ///
704 /// * `tensor` - The tensor to take the logarithm of.
705 ///
706 /// # Returns
707 ///
708 /// A tensor with the same shape as `tensor` with natural logarithm values.
709 fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
710 dequant_op_flow!(float_op | tensor | B::float_log(tensor), tensor)
711 }
712
713 /// Returns a new tensor with logarithm values of (1 + Xi).
714 ///
715 /// # Arguments
716 ///
717 /// * `tensor` - The tensor to take the logarithm of.
718 ///
719 /// # Returns
720 ///
721 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
722 fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
723 dequant_op_flow!(float_op | tensor | B::float_log1p(tensor), tensor)
724 }
725
726 /// Element-wise power with another tensor.
727 ///
728 /// # Arguments
729 ///
730 /// * `lhs` - The left hand side tensor.
731 /// * `rhs` - The right hand side tensor.
732 ///
733 /// # Returns
734 ///
735 /// The elements of `lhs` raised to the power of the elements of `rhs`.
736 fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
737 dequant_op_flow!(float_op | lhs, rhs | B::float_powf(lhs, rhs), lhs, rhs)
738 }
739
    /// Element-wise power with an IntTensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side int tensor holding the exponents.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    /// Returned as a float-flow result: re-quantized only when the input scheme propagates.
    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
        dequant_op_flow!(float_op | tensor | B::float_powi(tensor, rhs), lhs)
    }
753
754 /// Element-wise power with an int scalar.
755 ///
756 /// # Arguments
757 ///
758 /// * `lhs` - The left hand side tensor.
759 /// * `rhs` - The right hand side scalar.
760 ///
761 /// # Returns
762 ///
763 /// The elements of `lhs` raised to the value of `rhs`.
764 fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {
765 dequant_op_flow!(float_op | tensor | B::float_powi_scalar(tensor, rhs), lhs)
766 }
767
768 /// Element-wise power with a float scalar.
769 ///
770 /// # Arguments
771 ///
772 /// * `tensor` - The tensor to exponentiate.
773 /// * `value` - The exponent.
774 ///
775 /// # Returns
776 ///
777 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
778 fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {
779 dequant_op_flow!(
780 float_op | tensor | B::float_powf_scalar(tensor, value),
781 tensor
782 )
783 }
784
785 /// Returns a new tensor with square root values.
786 ///
787 /// # Arguments
788 ///
789 /// * `tensor` - The tensor to take the square root of.
790 ///
791 /// # Returns
792 ///
793 /// A tensor with the same shape as `tensor` with square root values.
794 fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
795 dequant_op_flow!(float_op | tensor | B::float_sqrt(tensor), tensor)
796 }
797
798 /// Returns a new tensor with absolute values.
799 ///
800 /// # Arguments
801 ///
802 /// * `tensor` - The tensor to take absolute value of.
803 ///
804 /// # Returns
805 ///
806 /// A tensor with the same shape as `tensor` with absolute values.
807 fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
808 dequant_op_quant!(float_op | tensor | B::float_abs(tensor), tensor)
809 }
810
811 /// Returns a new tensor with cosine values.
812 ///
813 /// # Arguments
814 ///
815 /// * `tensor` - The tensor to take the cosine of.
816 ///
817 /// # Returns
818 ///
819 /// A tensor with the same shape as `tensor` with cosine values.
820 fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
821 dequant_op_flow!(float_op | tensor | B::float_cos(tensor), tensor)
822 }
823
824 /// Returns a new tensor with sine values.
825 ///
826 /// # Arguments
827 ///
828 /// * `tensor` - The tensor to take the sine of.
829 ///
830 /// # Returns
831 ///
832 /// A tensor with the same shape as `tensor` with sine values.
833 fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
834 dequant_op_flow!(float_op | tensor | B::float_sin(tensor), tensor)
835 }
836
837 /// Returns a new tensor with tangent values.
838 ///
839 /// # Arguments
840 ///
841 /// * `tensor` - The tensor to take the tangent of.
842 ///
843 /// # Returns
844 ///
845 /// A tensor with the same shape as `tensor` with tangent values.
846 fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
847 dequant_op_flow!(float_op | tensor | B::float_tan(tensor), tensor)
848 }
849
850 /// Returns a new tensor with hyperbolic cosine values.
851 ///
852 /// # Arguments
853 ///
854 /// * `tensor` - The tensor to take the hyperbolic cosine of.
855 ///
856 /// # Returns
857 ///
858 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
859 fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
860 dequant_op_flow!(float_op | tensor | B::float_cosh(tensor), tensor)
861 }
862
863 /// Returns a new tensor with hyperbolic sine values.
864 ///
865 /// # Arguments
866 ///
867 /// * `tensor` - The tensor to take the hyperbolic sine of.
868 ///
869 /// # Returns
870 ///
871 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
872 fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
873 dequant_op_flow!(float_op | tensor | B::float_sinh(tensor), tensor)
874 }
875
876 /// Returns a new tensor with hyperbolic tangent values.
877 ///
878 /// # Arguments
879 ///
880 /// * `tensor` - The tensor to take the hyperbolic tangent of.
881 ///
882 /// # Returns
883 ///
884 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
885 fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
886 dequant_op_flow!(float_op | tensor | B::float_tanh(tensor), tensor)
887 }
888
889 /// Returns a new tensor with the error function values.
890 ///
891 /// # Arguments
892 ///
893 /// * `tensor` - The tensor to take the error function of.
894 ///
895 /// # Returns
896 ///
897 /// A tensor with the same shape as `tensor` with error function values.
898 fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
899 dequant_op_flow!(float_op | tensor | B::float_erf(tensor), tensor)
900 }
901
902 /// Concatenates tensors along a dimension.
903 ///
904 /// # Arguments
905 ///
906 /// * `tensors` - The tensors to concatenate.
907 /// * `dim` - The dimension along which to concatenate.
908 ///
909 /// # Returns
910 ///
911 /// A tensor with the concatenated tensors along `dim`.
912 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
913 // Heuristic: prioritize first tensor scheme
914 let first = tensors.first().unwrap();
915 let scheme = *first.scheme();
916 let dtype = get_device_settings::<B>(&Self::q_device(first)).float_dtype;
917
918 let tensor_f = tensors
919 .into_iter()
920 .map(|tensor| Self::dequantize(tensor, dtype))
921 .collect();
922
923 let out_f = B::float_cat(tensor_f, dim);
924
925 Self::quantize_dynamic(out_f, &scheme)
926 }
927
928 /// Gets the indices of the maximum elements of a tensor along an axis.
929 ///
930 /// # Arguments
931 ///
932 /// * `tensor` - The tensor to get the maximum elements of.
933 /// * `dim` - The dimension along which to get the maximum elements.
934 /// * `out_dtype` - The output tensor dtype.
935 ///
936 /// # Returns
937 ///
938 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
939 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
940 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
941 let tensor_f = Self::dequantize(tensor, dtype);
942 B::float_argmax(tensor_f, dim, out_dtype)
943 }
944
945 /// Gets the indices of the minimum elements of a tensor along an axis.
946 ///
947 /// # Arguments
948 ///
949 /// * `tensor` - The tensor to get the minimum elements of.
950 /// * `dim` - The dimension along which to get the minimum elements.
951 /// * `out_dtype` - The output tensor dtype.
952 ///
953 /// # Returns
954 ///
955 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
956 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B> {
957 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
958 let tensor_f = Self::dequantize(tensor, dtype);
959 B::float_argmin(tensor_f, dim, out_dtype)
960 }
961
962 /// Gets the maximum element of a tensor.
963 ///
964 /// # Arguments
965 ///
966 /// * `tensor` - The tensor to get the maximum elements of.
967 ///
968 /// # Returns
969 ///
970 /// A tensor with the maximum element of `tensor`.
971 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
972 let shape = tensor.shape();
973 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
974
975 B::q_max_dim(tensor, 0)
976 }
977
978 /// Gets the maximum elements of a tensor along an axis.
979 ///
980 /// # Arguments
981 ///
982 /// * `tensor` - The tensor to get the maximum elements of.
983 /// * `dim` - The dimension along which to get the maximum elements.
984 ///
985 /// # Returns
986 ///
987 /// A tensor with the maximum elements of `tensor` along `dim`.
988 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
989 let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
990 let index = B::q_argmax(tensor.clone(), dim, int_dtype);
991
992 B::q_gather(dim, tensor, index)
993 }
994
995 /// Gets the maximum elements of a tensor along an axis and their indices.
996 ///
997 /// # Arguments
998 ///
999 /// * `tensor` - The tensor to get the maximum elements of.
1000 /// * `dim` - The dimension along which to get the maximum elements.
1001 ///
1002 /// # Returns
1003 ///
1004 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1005 fn q_max_dim_with_indices(
1006 tensor: QuantizedTensor<B>,
1007 dim: usize,
1008 out_dtype: IntDType,
1009 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1010 let index = B::q_argmax(tensor.clone(), dim, out_dtype);
1011 let values = B::q_gather(dim, tensor, index.clone());
1012
1013 (values, index)
1014 }
1015
1016 /// Gets the minimum element of a tensor.
1017 ///
1018 /// # Arguments
1019 ///
1020 /// * `tensor` - The tensor to get the minimum elements of.
1021 ///
1022 /// # Returns
1023 ///
1024 /// A tensor with the minimum element of `tensor`.
1025 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1026 let shape = tensor.shape();
1027 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1028
1029 B::q_min_dim(tensor, 0)
1030 }
1031
1032 /// Gets the minimum elements of a tensor along an axis.
1033 ///
1034 /// # Arguments
1035 ///
1036 /// * `tensor` - The tensor to get the minimum elements of.
1037 /// * `dim` - The dimension along which to get the minimum elements.
1038 ///
1039 /// # Returns
1040 ///
1041 /// A tensor with the minimum elements of `tensor` along `dim`.
1042 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1043 let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
1044 let index = B::q_argmin(tensor.clone(), dim, int_dtype);
1045
1046 B::q_gather(dim, tensor, index)
1047 }
1048
1049 /// Gets the minimum elements of a tensor along an axis and their indices.
1050 ///
1051 /// # Arguments
1052 ///
1053 /// * `tensor` - The tensor to get the minimum elements of.
1054 /// * `dim` - The dimension along which to get the minimum elements.
1055 ///
1056 /// # Returns
1057 ///
1058 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1059 fn q_min_dim_with_indices(
1060 tensor: QuantizedTensor<B>,
1061 dim: usize,
1062 out_dtype: IntDType,
1063 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1064 let index = B::q_argmin(tensor.clone(), dim, out_dtype);
1065 let values = B::q_gather(dim, tensor, index.clone());
1066
1067 (values, index)
1068 }
1069
    /// Gets the element of a tensor with the largest absolute value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reduce.
    ///
    /// # Returns
    ///
    /// A single-element tensor holding the element of `tensor` located at the position
    /// of the maximum absolute value, as selected by `q_max_abs_dim` over the flattened
    /// tensor. NOTE(review): the selected element keeps its original sign (it is gathered
    /// from the signed tensor); confirm whether `max(|x|)` itself was intended instead.
    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Flatten to 1-D so the reduction over dim 0 covers every element.
        let shape = tensor.shape();
        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));

        B::q_max_abs_dim(tensor, 0)
    }
1085
    /// Gets the elements with the largest absolute value along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reduce.
    /// * `dim` - The dimension along which to reduce.
    ///
    /// # Returns
    ///
    /// A tensor where each entry along `dim` is the element of `tensor` found at the
    /// index of the maximum absolute value. Because values are gathered from the
    /// original (signed) tensor rather than from `q_abs(tensor)`, the sign of the
    /// selected element is preserved. NOTE(review): confirm this sign-preserving
    /// behavior is intended; a `max_abs` reduction conventionally returns `max(|x|)`.
    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
        // Argmax runs on |tensor| to find the positions of the largest magnitudes;
        // the gather then reads from the untouched signed input.
        let int_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
        let index = B::q_argmax(B::q_abs(tensor.clone()), dim, int_dtype);

        B::q_gather(dim, tensor, index)
    }
1102
1103 /// Tests if any element in the `tensor` evaluates to True.
1104 ///
1105 /// # Arguments
1106 ///
1107 /// * `tensor` - The tensor to test.
1108 ///
1109 /// # Returns
1110 ///
1111 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1112 fn q_any(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1113 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1114 let tensor_f = Self::dequantize(tensor, dtype);
1115 B::float_any(tensor_f, out_dtype)
1116 }
1117
1118 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1119 ///
1120 /// # Arguments
1121 ///
1122 /// * `tensor` - The tensor to test.
1123 /// * `dim` - The axis along which to test.
1124 ///
1125 /// # Returns
1126 ///
1127 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1128 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1129 /// input evaluates to True, False otherwise.
1130 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1131 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1132 let tensor_f = Self::dequantize(tensor, dtype);
1133 B::float_any_dim(tensor_f, dim, out_dtype)
1134 }
1135
1136 /// Tests if all elements in the `tensor` evaluate to True.
1137 ///
1138 /// # Arguments
1139 ///
1140 /// * `tensor` - The tensor to test.
1141 ///
1142 /// # Returns
1143 ///
1144 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1145 /// evaluate to True, False otherwise.
1146 fn q_all(tensor: QuantizedTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1147 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1148 let tensor_f = Self::dequantize(tensor, dtype);
1149 B::float_all(tensor_f, out_dtype)
1150 }
1151
1152 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1153 ///
1154 /// # Arguments
1155 ///
1156 /// * `tensor` - The tensor to test.
1157 /// * `dim` - The axis along which to test.
1158 ///
1159 /// # Returns
1160 ///
1161 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1162 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1163 /// evaluates to True, False otherwise.
1164 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1165 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1166 let tensor_f = Self::dequantize(tensor, dtype);
1167 B::float_all_dim(tensor_f, dim, out_dtype)
1168 }
1169
1170 /// Sort the elements of the input `tensor` by value in along a given dimension.
1171 ///
1172 /// This sort is unstable (i.e., may reorder equal elements).
1173 ///
1174 /// # Arguments
1175 ///
1176 /// * `tensor` - The input tensor.
1177 /// * `dim` - The axis along which to sort.
1178 /// * `descending` - The sorting order.
1179 ///
1180 /// # Returns
1181 ///
1182 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1183 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1184 // Default implementation. Backends can sort on the int values since qparams remain the same.
1185 dequant_op_quant!(
1186 float_op | tensor | B::float_sort(tensor, dim, descending),
1187 tensor
1188 )
1189 }
1190
1191 /// Sort the elements of the input `tensor` by value in along a given dimension.
1192 ///
1193 /// This sort is unstable (i.e., may reorder equal elements).
1194 ///
1195 /// # Arguments
1196 ///
1197 /// * `tensor` - The input tensor.
1198 /// * `dim` - The axis along which to sort.
1199 /// * `descending` - The sorting order.
1200 ///
1201 /// # Returns
1202 ///
1203 /// A tensor with the same shape as the input tensor and corresponding indices, where
1204 /// the elements are sorted by value and the indices map back to the original input tensor.
1205 fn q_sort_with_indices(
1206 tensor: QuantizedTensor<B>,
1207 dim: usize,
1208 descending: bool,
1209 out_dtype: IntDType,
1210 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1211 let scheme = *tensor.scheme();
1212 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1213
1214 let tensor_f = Self::dequantize(tensor, dtype);
1215 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending, out_dtype);
1216
1217 (Self::quantize_dynamic(out_f, &scheme), indices)
1218 }
1219
1220 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1221 ///
1222 /// This sort is unstable (i.e., may reorder equal elements).
1223 ///
1224 /// # Arguments
1225 ///
1226 /// * `tensor` - The input tensor.
1227 /// * `dim` - The axis along which to sort.
1228 /// * `descending` - The sorting order.
1229 ///
1230 /// # Returns
1231 ///
1232 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1233 fn q_argsort(
1234 tensor: QuantizedTensor<B>,
1235 dim: usize,
1236 descending: bool,
1237 out_dtype: IntDType,
1238 ) -> IntTensor<B> {
1239 let dtype = get_device_settings::<B>(&Self::q_device(&tensor)).float_dtype;
1240 let tensor_f = Self::dequantize(tensor, dtype);
1241 B::float_argsort(tensor_f, dim, descending, out_dtype)
1242 }
1243}