burn_tensor/tensor/ops/qtensor.rs
1use alloc::vec::Vec;
2use core::{future::Future, ops::Range};
3
4use crate::{
5 Device, Shape, TensorData, TensorMetadata,
6 backend::Backend,
7 quantization::{
8 Calibration, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
9 },
10};
11
12use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
13
/// Automatically applies dequantization -> float operation -> quantization.
///
/// Two forms are accepted:
///
/// * `ty T, float_op f, lhs, rhs` - binary op: both operands are dequantized with
///   `<T>::dequantize`, `f(lhs_f, rhs_f)` is evaluated and the result is re-quantized with
///   `<T>::quantize_dynamic` using the scheme of `lhs`.
/// * `ty T, float_op f, tensor` - unary op: same pipeline with a single operand.
///
/// Each operand expression is evaluated exactly once, so passing an expression that
/// produces a fresh tensor (or has side effects) is safe.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the lhs once: it is needed both for its scheme and for dequantization.
        let t1 = $t1;
        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize($t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once: it is needed both for its scheme and for dequantization.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
44
45/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
46/// for documentation on each function.
47pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor with the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
59
    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme to apply.
    /// * `qparams` - The quantization parameters to use.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
66
67 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
68 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantizationScheme) -> QuantizedTensor<B> {
69 // Dynamically compute min/max tensor range and qparams before quantizing
70 let (min, max) = scheme.compute_range_primitive::<B>(tensor.clone(), &Calibration::MinMax);
71 let qparams = scheme.compute_q_params_primitive(min, max);
72 Self::quantize(tensor, scheme, qparams)
73 }
74
    /// Convert the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor to convert.
    ///
    /// # Returns
    ///
    /// The float tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
77
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
88
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
100
    /// Reshapes a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
112
    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data. The future is
    /// required to be `Send` and `'static`.
    fn q_into_data(tensor: QuantizedTensor<B>)
    -> impl Future<Output = TensorData> + 'static + Send;
124
    /// Detaches a tensor from the computation graph.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to detach.
    ///
    /// # Returns
    ///
    /// The detached tensor. The default implementation returns `tensor` unchanged.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
130
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `_require_grad` - The requested flag value; ignored by the default implementation.
    ///
    /// # Returns
    ///
    /// The tensor. The default implementation returns `tensor` unchanged.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
136
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// # Arguments
    ///
    /// * `_tensor` - The tensor; not inspected by the default implementation.
    ///
    /// # Returns
    ///
    /// The flag value. The default implementation always returns `false`.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
142
143 /// Repeat the tensor along the given dimension.
144 ///
145 /// # Arguments
146 ///
147 /// * `tensor` - The tensor.
148 /// * `dim` - The dimension to repeat.
149 /// * `times` - The number of times to repeat the dimension.
150 ///
151 /// # Returns
152 ///
153 /// The tensor with the given dimension repeated.
154 fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
155 dequant_op_quant!(
156 ty Self,
157 float_op |tensor| B::float_repeat_dim(tensor, dim, times),
158 tensor
159 )
160 }
161
162 /// Adds two tensors together.
163 ///
164 /// # Arguments
165 ///
166 /// * `lhs` - The left hand side tensor.
167 /// * `rhs` - The right hand side tensor.
168 ///
169 /// # Returns
170 ///
171 /// The result of adding the two tensors together.
172 fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
173 dequant_op_quant!(
174 ty Self,
175 float_op |lhs, rhs| B::float_add(lhs, rhs),
176 lhs,
177 rhs
178 )
179 }
180
181 /// Adds a scalar to a tensor.
182 ///
183 /// # Arguments
184 ///
185 /// * `lhs` - The left hand side tensor.
186 /// * `rhs` - The right hand side scalar.
187 ///
188 /// # Returns
189 ///
190 /// The result of adding the scalar to the tensor.
191 fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
192 let scheme = *lhs.scheme();
193
194 let lhs_f = Self::dequantize(lhs);
195 let out_f = B::float_add_scalar(lhs_f, rhs);
196
197 Self::quantize_dynamic(out_f, &scheme)
198 }
199
200 /// Clamps a tensor under a minimum value.
201 ///
202 /// # Arguments
203 ///
204 /// * `tensor` - The tensor to clamp.
205 /// * `min` - The minimum value.
206 ///
207 /// # Returns
208 ///
209 /// The clamped tensor.
210 fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> QuantizedTensor<B> {
211 let scheme = *tensor.scheme();
212
213 let tensor_f = Self::dequantize(tensor);
214 let out_f = B::float_clamp_min(tensor_f, min);
215
216 Self::quantize_dynamic(out_f, &scheme)
217 }
218
219 /// Clamps a tensor over a maximum value.
220 ///
221 /// # Arguments
222 ///
223 /// * `tensor` - The tensor to clamp.
224 /// * `max` - The maximum value.
225 ///
226 /// # Returns
227 ///
228 /// The clamped tensor.
229 fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> QuantizedTensor<B> {
230 let scheme = *tensor.scheme();
231
232 let tensor_f = Self::dequantize(tensor);
233 let out_f = B::float_clamp_max(tensor_f, max);
234
235 Self::quantize_dynamic(out_f, &scheme)
236 }
237
238 /// Clamps a tensor between a minimum and maximum value.
239 ///
240 /// # Arguments
241 ///
242 /// * `tensor` - The tensor to clamp.
243 /// * `min` - The minimum value.
244 /// * `max` - The maximum value.
245 ///
246 /// # Returns
247 ///
248 /// The clamped tensor.
249 fn q_clamp(
250 tensor: QuantizedTensor<B>,
251 min: FloatElem<B>,
252 max: FloatElem<B>,
253 ) -> QuantizedTensor<B> {
254 let scheme = *tensor.scheme();
255
256 let tensor_f = Self::dequantize(tensor);
257 let out_f = B::float_clamp(tensor_f, min, max);
258
259 Self::quantize_dynamic(out_f, &scheme)
260 }
261
262 /// Subtracts two tensors.
263 ///
264 /// # Arguments
265 ///
266 /// * `lhs` - The left hand side tensor.
267 /// * `rhs` - The right hand side tensor.
268 ///
269 /// # Returns
270 ///
271 /// The result of subtracting the two tensors.
272 fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
273 dequant_op_quant!(
274 ty Self,
275 float_op |lhs, rhs| B::float_sub(lhs, rhs),
276 lhs,
277 rhs
278 )
279 }
280
281 /// Subtracts a scalar from a tensor.
282 ///
283 /// # Arguments
284 ///
285 /// * `lhs` - The left hand side tensor.
286 /// * `rhs` - The right hand side scalar.
287 ///
288 /// # Returns
289 ///
290 /// The result of subtracting the scalar from the tensor.
291 fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
292 let scheme = *lhs.scheme();
293
294 let lhs_f = Self::dequantize(lhs);
295 let out_f = B::float_sub_scalar(lhs_f, rhs);
296
297 Self::quantize_dynamic(out_f, &scheme)
298 }
299
300 /// Multiplies two tensors together element-wise.
301 fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
302 dequant_op_quant!(
303 ty Self,
304 float_op |lhs, rhs| B::float_mul(lhs, rhs),
305 lhs,
306 rhs
307 )
308 }
309
310 /// Multiplies a tensor by a scalar.
311 ///
312 /// # Arguments
313 ///
314 /// * `lhs` - The left hand side tensor.
315 /// * `rhs` - The right hand side scalar.
316 ///
317 /// # Returns
318 ///
319 /// The result of multiplying the tensor by the scalar.
320 fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
321 let scheme = *lhs.scheme();
322
323 let lhs_f = Self::dequantize(lhs);
324 let out_f = B::float_mul_scalar(lhs_f, rhs);
325
326 Self::quantize_dynamic(out_f, &scheme)
327 }
328
329 /// Divides two tensors element-wise.
330 ///
331 /// # Arguments
332 ///
333 /// * `lhs` - The left hand side tensor.
334 /// * `rhs` - The right hand side tensor.
335 ///
336 /// # Returns
337 ///
338 /// The result of dividing the two tensors.
339 fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
340 dequant_op_quant!(
341 ty Self,
342 float_op |lhs, rhs| B::float_div(lhs, rhs),
343 lhs,
344 rhs
345 )
346 }
347
348 /// Divides a tensor by a scalar.
349 ///
350 /// # Arguments
351 ///
352 /// * `lhs` - The left hand side tensor.
353 /// * `rhs` - The right hand side scalar.
354 ///
355 /// # Returns
356 ///
357 /// The result of dividing the tensor by the scalar.
358 fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
359 let scheme = *lhs.scheme();
360
361 let lhs_f = Self::dequantize(lhs);
362 let out_f = B::float_div_scalar(lhs_f, rhs);
363
364 Self::quantize_dynamic(out_f, &scheme)
365 }
366
367 /// Computes the remainder of division between two tensors element-wise.
368 ///
369 /// # Arguments
370 ///
371 /// * `lhs` - The left hand side tensor.
372 /// * `rhs` - The right hand side tensor.
373 ///
374 /// # Returns
375 ///
376 /// The element-wise remainder when dividing `lhs` by `rhs`.
377 fn q_remainder(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
378 dequant_op_quant!(
379 ty Self,
380 float_op |lhs, rhs| B::float_remainder(lhs, rhs),
381 lhs,
382 rhs
383 )
384 }
385
386 /// Computes the modulus of a tensor given a scalar.
387 ///
388 /// # Arguments
389 /// * `lhs` - The left hand side tensor.
390 /// * `rhs` - The right hand side scalar.
391 ///
392 /// # Returns
393 ///
394 /// The result of applying the modulus of the scalar to the tensor.
395 fn q_remainder_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
396 let scheme = *lhs.scheme();
397
398 let lhs_f = Self::dequantize(lhs);
399 let out_f = B::float_remainder_scalar(lhs_f, rhs);
400
401 Self::quantize_dynamic(out_f, &scheme)
402 }
403
404 /// Multiplies two tensors together using matrix multiplication.
405 ///
406 /// # Arguments
407 ///
408 /// * `lhs` - The left hand side tensor.
409 /// * `rhs` - The right hand side tensor.
410 ///
411 /// # Returns
412 ///
413 /// The result of multiplying the two tensors together using matrix multiplication.
414 fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
415 dequant_op_quant!(
416 ty Self,
417 float_op |lhs, rhs| B::float_matmul(lhs, rhs),
418 lhs,
419 rhs
420 )
421 }
422
423 /// Negates a tensor element-wise.
424 fn q_neg(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
425 let scheme = *tensor.scheme();
426
427 let tensor_f = Self::dequantize(tensor);
428 let out_f = B::float_neg(tensor_f);
429
430 Self::quantize_dynamic(out_f, &scheme)
431 }
432
433 /// Calculates the reciprocals element-wise
434 fn q_recip(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
435 let scheme = *tensor.scheme();
436
437 let tensor_f = Self::dequantize(tensor);
438 let out_f = B::float_recip(tensor_f);
439
440 Self::quantize_dynamic(out_f, &scheme)
441 }
442
443 /// Transposes a tensor.
444 ///
445 /// # Arguments
446 ///
447 /// * `tensor` - The tensor to transpose.
448 ///
449 /// # Returns
450 ///
451 /// The transposed tensor.
452 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
453 let ndims = tensor.shape().num_dims();
454 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
455 }
456
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
469
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
480
    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
490
491 /// Gather elements from a tensor.
492 ///
493 /// # Arguments
494 ///
495 /// * `dim` - The dimension to gather from.
496 /// * `tensor` - The tensor to gather from.
497 /// * `indices` - The indices to gather.
498 ///
499 /// # Returns
500 ///
501 /// The gathered elements.
502 fn q_gather(
503 dim: usize,
504 tensor: QuantizedTensor<B>,
505 indices: IntTensor<B>,
506 ) -> QuantizedTensor<B> {
507 // Default implementation. Backends can gather on the quantized values when supported.
508 dequant_op_quant!(
509 ty Self,
510 float_op |tensor| B::float_gather(dim, tensor, indices),
511 tensor
512 )
513 }
514
515 /// Scatter elements into a tensor.
516 ///
517 /// # Arguments
518 ///
519 /// * `dim` - The dimension to scatter into.
520 /// * `tensor` - The tensor to scatter into.
521 /// * `indices` - The indices to scatter into.
522 /// * `value` - The value to scatter.
523 ///
524 /// # Returns
525 ///
526 /// The tensor with the scattered elements.
527 fn q_scatter(
528 dim: usize,
529 tensor: QuantizedTensor<B>,
530 indices: IntTensor<B>,
531 value: QuantizedTensor<B>,
532 ) -> QuantizedTensor<B> {
533 dequant_op_quant!(
534 ty Self,
535 float_op |tensor, value| B::float_scatter(dim, tensor, indices, value),
536 tensor,
537 value
538 )
539 }
540
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
557
558 /// Assign the selected elements along the given dimension corresponding for the given indices
559 /// to the given value.
560 ///
561 /// # Arguments
562 ///
563 /// * `tensor` - The tensor to select from.
564 /// * `dim` - The dimension to select from.
565 /// * `indices` - The indices to select.
566 /// * `value` - The value to assign.
567 ///
568 /// # Returns
569 ///
570 /// The tensor with the selected elements assigned to the given value.
571 fn q_select_assign(
572 tensor: QuantizedTensor<B>,
573 dim: usize,
574 indices: IntTensor<B>,
575 value: QuantizedTensor<B>,
576 ) -> QuantizedTensor<B> {
577 dequant_op_quant!(
578 ty Self,
579 float_op |tensor, value| B::float_select_assign(tensor, dim, indices, value),
580 tensor,
581 value
582 )
583 }
584
    /// Select tensor elements corresponding to the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `ranges` - The ranges to select.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn q_slice(tensor: QuantizedTensor<B>, ranges: &[Range<usize>]) -> QuantizedTensor<B>;
596
597 /// Assign the selected elements corresponding for the given ranges to the given value.
598 ///
599 /// # Arguments
600 ///
601 /// * `tensor` - The tensor to select from.
602 /// * `ranges` - The ranges to select.
603 /// * `value` - The value to assign.
604 ///
605 /// # Returns
606 ///
607 /// The tensor with the selected elements assigned to the given value.
608 fn q_slice_assign(
609 tensor: QuantizedTensor<B>,
610 ranges: &[Range<usize>],
611 value: QuantizedTensor<B>,
612 ) -> QuantizedTensor<B> {
613 dequant_op_quant!(
614 ty Self,
615 float_op |tensor, value| B::float_slice_assign(tensor, ranges, value),
616 tensor,
617 value
618 )
619 }
620
621 /// Update the given tensor with the value tensor where the mask is true.
622 ///
623 /// # Arguments
624 ///
625 /// * `tensor` - The tensor to select from.
626 /// * `mask` - The boolean mask to select with.
627 /// * `value` - The value to assign to the selected elements from the value tensor.
628 ///
629 /// # Returns
630 ///
631 /// The tensor with the selected elements assigned to the given value.
632 fn q_mask_where(
633 tensor: QuantizedTensor<B>,
634 mask: BoolTensor<B>,
635 value: QuantizedTensor<B>,
636 ) -> QuantizedTensor<B> {
637 dequant_op_quant!(
638 ty Self,
639 float_op |tensor, value| B::float_mask_where(tensor, mask, value),
640 tensor,
641 value
642 )
643 }
644
645 /// Update the given tensor with the value where the mask is true.
646 ///
647 /// # Arguments
648 ///
649 /// * `tensor` - The tensor to select from.
650 /// * `mask` - The boolean mask to select with.
651 /// * `value` - The value to assign to the selected elements.
652 ///
653 /// # Returns
654 ///
655 /// The tensor with the selected elements assigned to the given value.
656 fn q_mask_fill(
657 tensor: QuantizedTensor<B>,
658 mask: BoolTensor<B>,
659 value: FloatElem<B>,
660 ) -> QuantizedTensor<B> {
661 dequant_op_quant!(
662 ty Self,
663 float_op |tensor| B::float_mask_fill(tensor, mask, value),
664 tensor
665 )
666 }
667
668 /// Sum of all elements in a tensor.
669 ///
670 /// # Arguments
671 ///
672 /// * `tensor` - The tensor to sum.
673 ///
674 /// # Returns
675 ///
676 /// A scalar tensor with the sum of all elements in `tensor`.
677 fn q_sum(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
678 dequant_op_quant!(
679 ty Self,
680 float_op |tensor| B::float_sum(tensor),
681 tensor
682 )
683 }
684
685 /// Sum of all elements in a tensor along a dimension.
686 ///
687 /// # Arguments
688 ///
689 /// * `tensor` - The tensor to sum.
690 /// * `dim` - The dimension along which to sum.
691 ///
692 /// # Returns
693 ///
694 /// A tensor with the sum of all elements in `tensor` along `dim`.
695 fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
696 dequant_op_quant!(
697 ty Self,
698 float_op |tensor| B::float_sum_dim(tensor, dim),
699 tensor
700 )
701 }
702
703 /// Product of all elements in a tensor.
704 ///
705 /// # Arguments
706 ///
707 /// * `tensor` - The tensor to product.
708 ///
709 /// # Returns
710 ///
711 /// A scalar tensor with the product of all elements in `tensor`.
712 fn q_prod(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
713 dequant_op_quant!(
714 ty Self,
715 float_op |tensor| B::float_prod(tensor),
716 tensor
717 )
718 }
719
720 /// Product of all elements in a tensor along a dimension.
721 ///
722 /// # Arguments
723 ///
724 /// * `tensor` - The tensor to product.
725 ///
726 /// # Returns
727 ///
728 /// A tensor with the product of all elements in `tensor` along `dim`.
729 fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
730 dequant_op_quant!(
731 ty Self,
732 float_op |tensor| B::float_prod_dim(tensor, dim),
733 tensor
734 )
735 }
736
737 /// Mean of all elements in a tensor.
738 ///
739 /// # Arguments
740 ///
741 /// * `tensor` - The tensor to mean.
742 ///
743 /// # Returns
744 ///
745 /// A scalar tensor with the mean of all elements in `tensor`.
746 fn q_mean(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
747 dequant_op_quant!(
748 ty Self,
749 float_op |tensor| B::float_mean(tensor),
750 tensor
751 )
752 }
753
754 /// Mean of all elements in a tensor along a dimension.
755 ///
756 /// # Arguments
757 ///
758 /// * `tensor` - The tensor to mean.
759 /// * `dim` - The dimension along which to mean.
760 ///
761 /// # Returns
762 ///
763 /// A tensor with the mean of all elements in `tensor` along `dim`.
764 fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
765 dequant_op_quant!(
766 ty Self,
767 float_op |tensor| B::float_mean_dim(tensor, dim),
768 tensor
769 )
770 }
771
772 /// Returns a new tensor with exponential values.
773 ///
774 /// # Arguments
775 ///
776 /// * `tensor` - The tensor to exponentiate.
777 ///
778 /// # Returns
779 ///
780 /// A tensor with the same shape as `tensor` with exponential values.
781 fn q_exp(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
782 dequant_op_quant!(
783 ty Self,
784 float_op |tensor| B::float_exp(tensor),
785 tensor
786 )
787 }
788
789 /// Returns a new tensor with natural logarithm values.
790 ///
791 /// # Arguments
792 ///
793 /// * `tensor` - The tensor to take the logarithm of.
794 ///
795 /// # Returns
796 ///
797 /// A tensor with the same shape as `tensor` with natural logarithm values.
798 fn q_log(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
799 dequant_op_quant!(
800 ty Self,
801 float_op |tensor| B::float_log(tensor),
802 tensor
803 )
804 }
805
806 /// Returns a new tensor with logarithm values of (1 + Xi).
807 ///
808 /// # Arguments
809 ///
810 /// * `tensor` - The tensor to take the logarithm of.
811 ///
812 /// # Returns
813 ///
814 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
815 fn q_log1p(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
816 dequant_op_quant!(
817 ty Self,
818 float_op |tensor| B::float_log1p(tensor),
819 tensor
820 )
821 }
822
823 /// Element-wise power with another tensor.
824 ///
825 /// # Arguments
826 ///
827 /// * `lhs` - The left hand side tensor.
828 /// * `rhs` - The right hand side tensor.
829 ///
830 /// # Returns
831 ///
832 /// The elements of `lhs` raised to the power of the elements of `rhs`.
833 fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
834 dequant_op_quant!(
835 ty Self,
836 float_op |lhs, rhs| B::float_powf(lhs, rhs),
837 lhs,
838 rhs
839 )
840 }
841
842 /// Element-wise power with an IntTensor.
843 ///
844 /// # Arguments
845 ///
846 /// * `lhs` - The left hand side tensor.
847 /// * `rhs` - The right hand side floatTensor.
848 ///
849 /// # Returns
850 ///
851 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
852 fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> QuantizedTensor<B> {
853 dequant_op_quant!(
854 ty Self,
855 float_op |tensor| B::float_powi(tensor, rhs),
856 lhs
857 )
858 }
859
860 /// Element-wise power with an int scalar.
861 ///
862 /// # Arguments
863 ///
864 /// * `lhs` - The left hand side tensor.
865 /// * `rhs` - The right hand side scalar.
866 ///
867 /// # Returns
868 ///
869 /// The elements of `lhs` raised to the value of `rhs`.
870 fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> QuantizedTensor<B> {
871 dequant_op_quant!(
872 ty Self,
873 float_op |tensor| B::float_powi_scalar(tensor, rhs),
874 lhs
875 )
876 }
877
878 /// Element-wise power with a float scalar.
879 ///
880 /// # Arguments
881 ///
882 /// * `tensor` - The tensor to exponentiate.
883 /// * `value` - The exponent.
884 ///
885 /// # Returns
886 ///
887 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
888 fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> QuantizedTensor<B> {
889 dequant_op_quant!(
890 ty Self,
891 float_op |tensor| B::float_powf_scalar(tensor, value),
892 tensor
893 )
894 }
895
896 /// Returns a new tensor with square root values.
897 ///
898 /// # Arguments
899 ///
900 /// * `tensor` - The tensor to take the square root of.
901 ///
902 /// # Returns
903 ///
904 /// A tensor with the same shape as `tensor` with square root values.
905 fn q_sqrt(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
906 dequant_op_quant!(
907 ty Self,
908 float_op |tensor| B::float_sqrt(tensor),
909 tensor
910 )
911 }
912
913 /// Returns a new tensor with absolute values.
914 ///
915 /// # Arguments
916 ///
917 /// * `tensor` - The tensor to take absolute value of.
918 ///
919 /// # Returns
920 ///
921 /// A tensor with the same shape as `tensor` with absolute values.
922 fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
923 dequant_op_quant!(
924 ty Self,
925 float_op |tensor| B::float_abs(tensor),
926 tensor
927 )
928 }
929
930 /// Returns a new tensor with cosine values.
931 ///
932 /// # Arguments
933 ///
934 /// * `tensor` - The tensor to take the cosine of.
935 ///
936 /// # Returns
937 ///
938 /// A tensor with the same shape as `tensor` with cosine values.
939 fn q_cos(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
940 dequant_op_quant!(
941 ty Self,
942 float_op |tensor| B::float_cos(tensor),
943 tensor
944 )
945 }
946
947 /// Returns a new tensor with sine values.
948 ///
949 /// # Arguments
950 ///
951 /// * `tensor` - The tensor to take the sine of.
952 ///
953 /// # Returns
954 ///
955 /// A tensor with the same shape as `tensor` with sine values.
956 fn q_sin(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
957 dequant_op_quant!(
958 ty Self,
959 float_op |tensor| B::float_sin(tensor),
960 tensor
961 )
962 }
963
964 /// Returns a new tensor with tangent values.
965 ///
966 /// # Arguments
967 ///
968 /// * `tensor` - The tensor to take the tangent of.
969 ///
970 /// # Returns
971 ///
972 /// A tensor with the same shape as `tensor` with tangent values.
973 fn q_tan(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
974 dequant_op_quant!(
975 ty Self,
976 float_op |tensor| B::float_tan(tensor),
977 tensor
978 )
979 }
980
981 /// Returns a new tensor with hyperbolic cosine values.
982 ///
983 /// # Arguments
984 ///
985 /// * `tensor` - The tensor to take the hyperbolic cosine of.
986 ///
987 /// # Returns
988 ///
989 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
990 fn q_cosh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
991 dequant_op_quant!(
992 ty Self,
993 float_op |tensor| B::float_cosh(tensor),
994 tensor
995 )
996 }
997
998 /// Returns a new tensor with hyperbolic sine values.
999 ///
1000 /// # Arguments
1001 ///
1002 /// * `tensor` - The tensor to take the hyperbolic sine of.
1003 ///
1004 /// # Returns
1005 ///
1006 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1007 fn q_sinh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1008 dequant_op_quant!(
1009 ty Self,
1010 float_op |tensor| B::float_sinh(tensor),
1011 tensor
1012 )
1013 }
1014
1015 /// Returns a new tensor with hyperbolic tangent values.
1016 ///
1017 /// # Arguments
1018 ///
1019 /// * `tensor` - The tensor to take the hyperbolic tangent of.
1020 ///
1021 /// # Returns
1022 ///
1023 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1024 fn q_tanh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1025 dequant_op_quant!(
1026 ty Self,
1027 float_op |tensor| B::float_tanh(tensor),
1028 tensor
1029 )
1030 }
1031
1032 /// Returns a new tensor with the error function values.
1033 ///
1034 /// # Arguments
1035 ///
1036 /// * `tensor` - The tensor to take the error function of.
1037 ///
1038 /// # Returns
1039 ///
1040 /// A tensor with the same shape as `tensor` with error function values.
1041 fn q_erf(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1042 dequant_op_quant!(
1043 ty Self,
1044 float_op |tensor| B::float_erf(tensor),
1045 tensor
1046 )
1047 }
1048
1049 /// Concatenates tensors along a dimension.
1050 ///
1051 /// # Arguments
1052 ///
1053 /// * `tensors` - The tensors to concatenate.
1054 /// * `dim` - The dimension along which to concatenate.
1055 ///
1056 /// # Returns
1057 ///
1058 /// A tensor with the concatenated tensors along `dim`.
1059 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1060 // Heuristic: prioritize first tensor scheme
1061 let scheme = *tensors.first().unwrap().scheme();
1062
1063 let tensor_f = tensors
1064 .into_iter()
1065 .map(|tensor| Self::dequantize(tensor))
1066 .collect();
1067
1068 let out_f = B::float_cat(tensor_f, dim);
1069
1070 Self::quantize_dynamic(out_f, &scheme)
1071 }
1072
1073 /// Gets the indices of the maximum elements of a tensor along an axis.
1074 ///
1075 /// # Arguments
1076 ///
1077 /// * `tensor` - The tensor to get the maximum elements of.
1078 /// * `dim` - The dimension along which to get the maximum elements.
1079 ///
1080 /// # Returns
1081 ///
1082 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1083 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1084 // Default implementation. Backends can sort on the int values since qparams remain the same.
1085 let tensor_f = Self::dequantize(tensor);
1086 B::float_argmax(tensor_f, dim)
1087 }
1088
1089 /// Gets the indices of the minimum elements of a tensor along an axis.
1090 ///
1091 /// # Arguments
1092 ///
1093 /// * `tensor` - The tensor to get the minimum elements of.
1094 /// * `dim` - The dimension along which to get the minimum elements.
1095 ///
1096 /// # Returns
1097 ///
1098 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1099 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1100 // Default implementation. Backends can sort on the int values since qparams remain the same.
1101 let tensor_f = Self::dequantize(tensor);
1102 B::float_argmin(tensor_f, dim)
1103 }
1104
1105 /// Gets the maximum element of a tensor.
1106 ///
1107 /// # Arguments
1108 ///
1109 /// * `tensor` - The tensor to get the maximum elements of.
1110 ///
1111 /// # Returns
1112 ///
1113 /// A tensor with the maximum element of `tensor`.
1114 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1115 let shape = tensor.shape();
1116 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1117
1118 B::q_max_dim(tensor, 0)
1119 }
1120
1121 /// Gets the maximum elements of a tensor along an axis.
1122 ///
1123 /// # Arguments
1124 ///
1125 /// * `tensor` - The tensor to get the maximum elements of.
1126 /// * `dim` - The dimension along which to get the maximum elements.
1127 ///
1128 /// # Returns
1129 ///
1130 /// A tensor with the maximum elements of `tensor` along `dim`.
1131 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1132 let index = B::q_argmax(tensor.clone(), dim);
1133
1134 B::q_gather(dim, tensor, index)
1135 }
1136
1137 /// Gets the maximum elements of a tensor along an axis and their indices.
1138 ///
1139 /// # Arguments
1140 ///
1141 /// * `tensor` - The tensor to get the maximum elements of.
1142 /// * `dim` - The dimension along which to get the maximum elements.
1143 ///
1144 /// # Returns
1145 ///
1146 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1147 fn q_max_dim_with_indices(
1148 tensor: QuantizedTensor<B>,
1149 dim: usize,
1150 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1151 let index = B::q_argmax(tensor.clone(), dim);
1152 let values = B::q_gather(dim, tensor, index.clone());
1153
1154 (values, index)
1155 }
1156
1157 /// Gets the minimum element of a tensor.
1158 ///
1159 /// # Arguments
1160 ///
1161 /// * `tensor` - The tensor to get the minimum elements of.
1162 ///
1163 /// # Returns
1164 ///
1165 /// A tensor with the minimum element of `tensor`.
1166 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1167 let shape = tensor.shape();
1168 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1169
1170 B::q_min_dim(tensor, 0)
1171 }
1172
1173 /// Gets the minimum elements of a tensor along an axis.
1174 ///
1175 /// # Arguments
1176 ///
1177 /// * `tensor` - The tensor to get the minimum elements of.
1178 /// * `dim` - The dimension along which to get the minimum elements.
1179 ///
1180 /// # Returns
1181 ///
1182 /// A tensor with the minimum elements of `tensor` along `dim`.
1183 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1184 let index = B::q_argmin(tensor.clone(), dim);
1185
1186 B::q_gather(dim, tensor, index)
1187 }
1188
1189 /// Gets the minimum elements of a tensor along an axis and their indices.
1190 ///
1191 /// # Arguments
1192 ///
1193 /// * `tensor` - The tensor to get the minimum elements of.
1194 /// * `dim` - The dimension along which to get the minimum elements.
1195 ///
1196 /// # Returns
1197 ///
1198 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1199 fn q_min_dim_with_indices(
1200 tensor: QuantizedTensor<B>,
1201 dim: usize,
1202 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1203 let index = B::q_argmin(tensor.clone(), dim);
1204 let values = B::q_gather(dim, tensor, index.clone());
1205
1206 (values, index)
1207 }
1208
1209 /// Gets the maximum element of a tensor.
1210 ///
1211 /// # Arguments
1212 ///
1213 /// * `tensor` - The tensor to get the maximum elements of.
1214 ///
1215 /// # Returns
1216 ///
1217 /// A tensor with the maximum element of `tensor`.
1218 fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1219 let shape = tensor.shape();
1220 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1221
1222 B::q_max_abs_dim(tensor, 0)
1223 }
1224
1225 /// Gets the maximum elements of a tensor along an axis.
1226 ///
1227 /// # Arguments
1228 ///
1229 /// * `tensor` - The tensor to get the maximum elements of.
1230 /// * `dim` - The dimension along which to get the maximum elements.
1231 ///
1232 /// # Returns
1233 ///
1234 /// A tensor with the maximum elements of `tensor` along `dim`.
1235 fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1236 let index = B::q_argmax(B::q_abs(tensor.clone()), dim);
1237
1238 B::q_gather(dim, tensor, index)
1239 }
1240
1241 /// Returns a new tensor with the given dimension narrowed to the given range.
1242 ///
1243 /// # Arguments
1244 ///
1245 /// * `dim` - The dimension along which the tensor will be narrowed.
1246 /// * `start` - The starting point of the given range.
1247 /// * `length` - The ending point of the given range.
1248 /// # Panics
1249 ///
1250 /// - If the dimension is greater than the number of dimensions of the tensor.
1251 /// - If the given range exceeds the number of elements on the given dimension.
1252 ///
1253 /// # Returns
1254 ///
1255 /// A new tensor with the given dimension narrowed to the given range.
1256 fn q_narrow(
1257 tensor: QuantizedTensor<B>,
1258 dim: usize,
1259 start: usize,
1260 length: usize,
1261 ) -> QuantizedTensor<B> {
1262 dequant_op_quant!(
1263 ty Self,
1264 float_op |tensor| B::float_narrow(tensor, dim, start, length),
1265 tensor
1266 )
1267 }
1268
1269 /// Split the tensor along the given dimension into chunks.
1270 ///
1271 /// # Arguments
1272 ///
1273 /// * `tensor` - The tensor.
1274 /// * `chunks` - The number of chunks to be produced.
1275 /// * `times` - The dimension along which the tensor will be split.
1276 ///
1277 /// # Returns
1278 ///
1279 /// A vector of tensors.
1280 fn q_chunk(tensor: QuantizedTensor<B>, chunks: usize, dim: usize) -> Vec<QuantizedTensor<B>> {
1281 let scheme = *tensor.scheme();
1282
1283 let tensor_f = Self::dequantize(tensor);
1284 let out_f = B::float_chunk(tensor_f, chunks, dim);
1285
1286 out_f
1287 .into_iter()
1288 .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1289 .collect()
1290 }
1291
1292 /// Split the tensor along the given dimension into chunks of `split_size`.
1293 ///
1294 /// # Arguments
1295 ///
1296 /// * `tensor` - The tensor.
1297 /// * `split_size` - The size of a single chunk.
1298 /// * `times` - The dimension along which the tensor will be split.
1299 ///
1300 /// # Returns
1301 ///
1302 /// A vector of tensors.
1303 fn q_split(
1304 tensor: QuantizedTensor<B>,
1305 split_size: usize,
1306 dim: usize,
1307 ) -> Vec<QuantizedTensor<B>> {
1308 let scheme = *tensor.scheme();
1309
1310 let tensor_f = Self::dequantize(tensor);
1311 let out_f = B::float_split(tensor_f, split_size, dim);
1312
1313 out_f
1314 .into_iter()
1315 .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1316 .collect()
1317 }
1318
1319 /// Split the tensor along the given dimension into chunks with sizes in
1320 /// `dim` according to `split_sizes`.
1321 ///
1322 /// # Arguments
1323 ///
1324 /// * `tensor` - The tensor.
1325 /// * `split_sizes` - Vector of sizes for each chunk.
1326 /// * `times` - The dimension along which the tensor will be split.
1327 ///
1328 /// # Returns
1329 ///
1330 /// A vector of tensors.
1331 fn q_split_with_sizes(
1332 tensor: QuantizedTensor<B>,
1333 split_sizes: Vec<usize>,
1334 dim: usize,
1335 ) -> Vec<QuantizedTensor<B>> {
1336 let scheme = *tensor.scheme();
1337
1338 let tensor_f = Self::dequantize(tensor);
1339 let out_f = B::float_split_with_sizes(tensor_f, split_sizes, dim);
1340
1341 out_f
1342 .into_iter()
1343 .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1344 .collect()
1345 }
1346
1347 /// Tests if any element in the `tensor` evaluates to True.
1348 ///
1349 /// # Arguments
1350 ///
1351 /// * `tensor` - The tensor to test.
1352 ///
1353 /// # Returns
1354 ///
1355 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1356 fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1357 let tensor_f = Self::dequantize(tensor);
1358 B::float_any(tensor_f)
1359 }
1360
1361 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1362 ///
1363 /// # Arguments
1364 ///
1365 /// * `tensor` - The tensor to test.
1366 /// * `dim` - The axis along which to test.
1367 ///
1368 /// # Returns
1369 ///
1370 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1371 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1372 /// input evaluates to True, False otherwise.
1373 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1374 let tensor_f = Self::dequantize(tensor);
1375 B::float_any_dim(tensor_f, dim)
1376 }
1377
1378 /// Tests if all elements in the `tensor` evaluate to True.
1379 ///
1380 /// # Arguments
1381 ///
1382 /// * `tensor` - The tensor to test.
1383 ///
1384 /// # Returns
1385 ///
1386 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1387 /// evaluate to True, False otherwise.
1388 fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1389 let tensor_f = Self::dequantize(tensor);
1390 B::float_all(tensor_f)
1391 }
1392
1393 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1394 ///
1395 /// # Arguments
1396 ///
1397 /// * `tensor` - The tensor to test.
1398 /// * `dim` - The axis along which to test.
1399 ///
1400 /// # Returns
1401 ///
1402 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1403 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1404 /// evaluates to True, False otherwise.
1405 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1406 let tensor_f = Self::dequantize(tensor);
1407 B::float_all_dim(tensor_f, dim)
1408 }
1409
/// Broadcasts the `tensor` to the given `shape`.
///
/// This method has no default implementation here; it must be provided by the implementor.
///
/// # Arguments
///
/// * `tensor` - The quantized tensor to broadcast.
/// * `shape` - The shape to broadcast the tensor to.
///
/// # Returns
///
/// The quantized tensor broadcast to `shape`.
fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
1412
1413 /// Sort the elements of the input `tensor` by value in along a given dimension.
1414 ///
1415 /// This sort is unstable (i.e., may reorder equal elements).
1416 ///
1417 /// # Arguments
1418 ///
1419 /// * `tensor` - The input tensor.
1420 /// * `dim` - The axis along which to sort.
1421 /// * `descending` - The sorting order.
1422 ///
1423 /// # Returns
1424 ///
1425 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1426 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1427 // Default implementation. Backends can sort on the int values since qparams remain the same.
1428 dequant_op_quant!(
1429 ty Self,
1430 float_op |tensor| B::float_sort(tensor, dim, descending),
1431 tensor
1432 )
1433 }
1434
1435 /// Sort the elements of the input `tensor` by value in along a given dimension.
1436 ///
1437 /// This sort is unstable (i.e., may reorder equal elements).
1438 ///
1439 /// # Arguments
1440 ///
1441 /// * `tensor` - The input tensor.
1442 /// * `dim` - The axis along which to sort.
1443 /// * `descending` - The sorting order.
1444 ///
1445 /// # Returns
1446 ///
1447 /// A tensor with the same shape as the input tensor and corresponding indices, where
1448 /// the elements are sorted by value and the indices map back to the original input tensor.
1449 fn q_sort_with_indices(
1450 tensor: QuantizedTensor<B>,
1451 dim: usize,
1452 descending: bool,
1453 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1454 // Default implementation. Backends can sort on the int values since qparams remain the same.
1455 let scheme = *tensor.scheme();
1456
1457 let tensor_f = Self::dequantize(tensor);
1458 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1459
1460 (Self::quantize_dynamic(out_f, &scheme), indices)
1461 }
1462
1463 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1464 ///
1465 /// This sort is unstable (i.e., may reorder equal elements).
1466 ///
1467 /// # Arguments
1468 ///
1469 /// * `tensor` - The input tensor.
1470 /// * `dim` - The axis along which to sort.
1471 /// * `descending` - The sorting order.
1472 ///
1473 /// # Returns
1474 ///
1475 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1476 fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1477 // Default implementation. Backends can sort on the int values since qparams remain the same.
1478 let tensor_f = Self::dequantize(tensor);
1479 B::float_argsort(tensor_f, dim, descending)
1480 }
1481}