burn_tensor/tensor/ops/qtensor.rs
1use alloc::vec::Vec;
2use core::{future::Future, ops::Range};
3
4use crate::{
5 backend::Backend,
6 quantization::{QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme},
7 Device, Shape, TensorData, TensorMetadata,
8};
9
10use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor};
11
/// Automatically applies dequantization -> float operation -> quantization.
///
/// Two forms are accepted:
///
/// * `dequant_op_quant!(ty T, float_op f, lhs, rhs)` dequantizes both operands with
///   `T::dequantize` and applies the binary closure `f`. It then re-quantizes the result with
///   `T::quantize_dynamic`, using the scheme of `lhs`; the scheme of `rhs` is not used.
/// * `dequant_op_quant!(ty T, float_op f, tensor)` does the same for a unary closure. The
///   result is re-quantized with the scheme of `tensor`.
///
/// Each tensor argument is evaluated exactly once, so any expression (not only a local
/// variable) may be passed as an operand.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Bind the operands once: expanding `$t1` both for `.scheme()` and for
        // `dequantize` would evaluate the caller's expression twice.
        let t1 = $t1;
        let t2 = $t2;
        // Heuristic: prioritize lhs scheme
        let scheme = t1.scheme().clone();

        let t1_f = <$ty>::dequantize(t1);
        let t2_f = <$ty>::dequantize(t2);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        // Bind the operand once so the caller's expression is evaluated a single time.
        let tensor = $tensor;
        let scheme = tensor.scheme().clone();

        let tensor_f = <$ty>::dequantize(tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);

        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
42
43/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
44/// for documentation on each function.
45pub trait QTensorOps<B: Backend> {
    /// Creates a new quantized tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure holding the tensor values.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The quantized tensor with the given data.
    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;
57
    /// Converts the tensor to a lower precision data type based on the quantization scheme
    /// and the provided quantization parameters.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to quantize.
    /// * `scheme` - The quantization scheme.
    /// * `qparams` - The quantization parameters to apply.
    ///
    /// # Returns
    ///
    /// The quantized tensor.
    fn quantize(
        tensor: FloatTensor<B>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B>;
64
65 /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
66 fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantizationScheme) -> QuantizedTensor<B> {
67 // Dynamically compute min/max tensor range and qparams before quantizing
68 let min = B::float_min(tensor.clone());
69 let max = B::float_max(tensor.clone());
70 let qparams = scheme.compute_q_params_primitive(min, max);
71 Self::quantize(tensor, scheme, qparams)
72 }
73
    /// Converts the tensor back to a higher precision data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor.
    ///
    /// # Returns
    ///
    /// The dequantized float tensor.
    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;
76
    /// Gets the device of the quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;
87
    /// Moves the quantized tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;
99
    /// Reshapes a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
111
    /// Converts the quantized tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A `Send + 'static` future that resolves to the data structure holding the
    /// tensor's data.
    fn q_into_data(tensor: QuantizedTensor<B>)
        -> impl Future<Output = TensorData> + 'static + Send;
123
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns `tensor` unchanged.
    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
129
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores the flag and returns `tensor` unchanged.
    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
135
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always returns `false`.
    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
141
142 /// Repeat the tensor along the given dimension.
143 ///
144 /// # Arguments
145 ///
146 /// * `tensor` - The tensor.
147 /// * `dim` - The dimension to repeat.
148 /// * `times` - The number of times to repeat the dimension.
149 ///
150 /// # Returns
151 ///
152 /// The tensor with the given dimension repeated.
153 fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {
154 dequant_op_quant!(
155 ty Self,
156 float_op |tensor| B::float_repeat_dim(tensor, dim, times),
157 tensor
158 )
159 }
160
161 /// Adds two tensors together.
162 ///
163 /// # Arguments
164 ///
165 /// * `lhs` - The left hand side tensor.
166 /// * `rhs` - The right hand side tensor.
167 ///
168 /// # Returns
169 ///
170 /// The result of adding the two tensors together.
171 fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
172 dequant_op_quant!(
173 ty Self,
174 float_op |lhs, rhs| B::float_add(lhs, rhs),
175 lhs,
176 rhs
177 )
178 }
179
180 /// Adds a scalar to a tensor.
181 ///
182 /// # Arguments
183 ///
184 /// * `lhs` - The left hand side tensor.
185 /// * `rhs` - The right hand side scalar.
186 ///
187 /// # Returns
188 ///
189 /// The result of adding the scalar to the tensor.
190 fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
191 let scheme = *lhs.scheme();
192
193 let lhs_f = Self::dequantize(lhs);
194 let out_f = B::float_add_scalar(lhs_f, rhs);
195
196 Self::quantize_dynamic(out_f, &scheme)
197 }
198
199 /// Clamps a tensor under a minimum value.
200 ///
201 /// # Arguments
202 ///
203 /// * `tensor` - The tensor to clamp.
204 /// * `min` - The minimum value.
205 ///
206 /// # Returns
207 ///
208 /// The clamped tensor.
209 fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> QuantizedTensor<B> {
210 let scheme = *tensor.scheme();
211
212 let tensor_f = Self::dequantize(tensor);
213 let out_f = B::float_clamp_min(tensor_f, min);
214
215 Self::quantize_dynamic(out_f, &scheme)
216 }
217
218 /// Clamps a tensor over a maximum value.
219 ///
220 /// # Arguments
221 ///
222 /// * `tensor` - The tensor to clamp.
223 /// * `max` - The maximum value.
224 ///
225 /// # Returns
226 ///
227 /// The clamped tensor.
228 fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> QuantizedTensor<B> {
229 let scheme = *tensor.scheme();
230
231 let tensor_f = Self::dequantize(tensor);
232 let out_f = B::float_clamp_max(tensor_f, max);
233
234 Self::quantize_dynamic(out_f, &scheme)
235 }
236
237 /// Clamps a tensor between a minimum and maximum value.
238 ///
239 /// # Arguments
240 ///
241 /// * `tensor` - The tensor to clamp.
242 /// * `min` - The minimum value.
243 /// * `max` - The maximum value.
244 ///
245 /// # Returns
246 ///
247 /// The clamped tensor.
248 fn q_clamp(
249 tensor: QuantizedTensor<B>,
250 min: FloatElem<B>,
251 max: FloatElem<B>,
252 ) -> QuantizedTensor<B> {
253 let scheme = *tensor.scheme();
254
255 let tensor_f = Self::dequantize(tensor);
256 let out_f = B::float_clamp(tensor_f, min, max);
257
258 Self::quantize_dynamic(out_f, &scheme)
259 }
260
261 /// Subtracts two tensors.
262 ///
263 /// # Arguments
264 ///
265 /// * `lhs` - The left hand side tensor.
266 /// * `rhs` - The right hand side tensor.
267 ///
268 /// # Returns
269 ///
270 /// The result of subtracting the two tensors.
271 fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
272 dequant_op_quant!(
273 ty Self,
274 float_op |lhs, rhs| B::float_sub(lhs, rhs),
275 lhs,
276 rhs
277 )
278 }
279
280 /// Subtracts a scalar from a tensor.
281 ///
282 /// # Arguments
283 ///
284 /// * `lhs` - The left hand side tensor.
285 /// * `rhs` - The right hand side scalar.
286 ///
287 /// # Returns
288 ///
289 /// The result of subtracting the scalar from the tensor.
290 fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
291 let scheme = *lhs.scheme();
292
293 let lhs_f = Self::dequantize(lhs);
294 let out_f = B::float_sub_scalar(lhs_f, rhs);
295
296 Self::quantize_dynamic(out_f, &scheme)
297 }
298
299 /// Multiplies two tensors together element-wise.
300 fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
301 dequant_op_quant!(
302 ty Self,
303 float_op |lhs, rhs| B::float_mul(lhs, rhs),
304 lhs,
305 rhs
306 )
307 }
308
309 /// Multiplies a tensor by a scalar.
310 ///
311 /// # Arguments
312 ///
313 /// * `lhs` - The left hand side tensor.
314 /// * `rhs` - The right hand side scalar.
315 ///
316 /// # Returns
317 ///
318 /// The result of multiplying the tensor by the scalar.
319 fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
320 let scheme = *lhs.scheme();
321
322 let lhs_f = Self::dequantize(lhs);
323 let out_f = B::float_mul_scalar(lhs_f, rhs);
324
325 Self::quantize_dynamic(out_f, &scheme)
326 }
327
328 /// Divides two tensors element-wise.
329 ///
330 /// # Arguments
331 ///
332 /// * `lhs` - The left hand side tensor.
333 /// * `rhs` - The right hand side tensor.
334 ///
335 /// # Returns
336 ///
337 /// The result of dividing the two tensors.
338 fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
339 dequant_op_quant!(
340 ty Self,
341 float_op |lhs, rhs| B::float_div(lhs, rhs),
342 lhs,
343 rhs
344 )
345 }
346
347 /// Divides a tensor by a scalar.
348 ///
349 /// # Arguments
350 ///
351 /// * `lhs` - The left hand side tensor.
352 /// * `rhs` - The right hand side scalar.
353 ///
354 /// # Returns
355 ///
356 /// The result of dividing the tensor by the scalar.
357 fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
358 let scheme = *lhs.scheme();
359
360 let lhs_f = Self::dequantize(lhs);
361 let out_f = B::float_div_scalar(lhs_f, rhs);
362
363 Self::quantize_dynamic(out_f, &scheme)
364 }
365
366 /// Computes the remainder of division between two tensors element-wise.
367 ///
368 /// # Arguments
369 ///
370 /// * `lhs` - The left hand side tensor.
371 /// * `rhs` - The right hand side tensor.
372 ///
373 /// # Returns
374 ///
375 /// The element-wise remainder when dividing `lhs` by `rhs`.
376 fn q_remainder(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
377 dequant_op_quant!(
378 ty Self,
379 float_op |lhs, rhs| B::float_remainder(lhs, rhs),
380 lhs,
381 rhs
382 )
383 }
384
385 /// Computes the modulus of a tensor given a scalar.
386 ///
387 /// # Arguments
388 /// * `lhs` - The left hand side tensor.
389 /// * `rhs` - The right hand side scalar.
390 ///
391 /// # Returns
392 ///
393 /// The result of applying the modulus of the scalar to the tensor.
394 fn q_remainder_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> QuantizedTensor<B> {
395 let scheme = *lhs.scheme();
396
397 let lhs_f = Self::dequantize(lhs);
398 let out_f = B::float_remainder_scalar(lhs_f, rhs);
399
400 Self::quantize_dynamic(out_f, &scheme)
401 }
402
403 /// Multiplies two tensors together using matrix multiplication.
404 ///
405 /// # Arguments
406 ///
407 /// * `lhs` - The left hand side tensor.
408 /// * `rhs` - The right hand side tensor.
409 ///
410 /// # Returns
411 ///
412 /// The result of multiplying the two tensors together using matrix multiplication.
413 fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
414 dequant_op_quant!(
415 ty Self,
416 float_op |lhs, rhs| B::float_matmul(lhs, rhs),
417 lhs,
418 rhs
419 )
420 }
421
422 /// Negates a tensor element-wise.
423 fn q_neg(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
424 let scheme = *tensor.scheme();
425
426 let tensor_f = Self::dequantize(tensor);
427 let out_f = B::float_neg(tensor_f);
428
429 Self::quantize_dynamic(out_f, &scheme)
430 }
431
432 /// Calculates the reciprocals element-wise
433 fn q_recip(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
434 let scheme = *tensor.scheme();
435
436 let tensor_f = Self::dequantize(tensor);
437 let out_f = B::float_recip(tensor_f);
438
439 Self::quantize_dynamic(out_f, &scheme)
440 }
441
442 /// Transposes a tensor.
443 ///
444 /// # Arguments
445 ///
446 /// * `tensor` - The tensor to transpose.
447 ///
448 /// # Returns
449 ///
450 /// The transposed tensor.
451 fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
452 let ndims = tensor.shape().num_dims();
453 Self::q_swap_dims(tensor, ndims - 2, ndims - 1)
454 }
455
    /// Swaps two dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;
468
    /// Permutes the dimensions of a quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
479
    /// Reverses the order of elements in a quantized tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;
489
490 /// Gather elements from a tensor.
491 ///
492 /// # Arguments
493 ///
494 /// * `dim` - The dimension to gather from.
495 /// * `tensor` - The tensor to gather from.
496 /// * `indices` - The indices to gather.
497 ///
498 /// # Returns
499 ///
500 /// The gathered elements.
501 fn q_gather(
502 dim: usize,
503 tensor: QuantizedTensor<B>,
504 indices: IntTensor<B>,
505 ) -> QuantizedTensor<B> {
506 // Default implementation. Backends can gather on the quantized values when supported.
507 dequant_op_quant!(
508 ty Self,
509 float_op |tensor| B::float_gather(dim, tensor, indices),
510 tensor
511 )
512 }
513
514 /// Scatter elements into a tensor.
515 ///
516 /// # Arguments
517 ///
518 /// * `dim` - The dimension to scatter into.
519 /// * `tensor` - The tensor to scatter into.
520 /// * `indices` - The indices to scatter into.
521 /// * `value` - The value to scatter.
522 ///
523 /// # Returns
524 ///
525 /// The tensor with the scattered elements.
526 fn q_scatter(
527 dim: usize,
528 tensor: QuantizedTensor<B>,
529 indices: IntTensor<B>,
530 value: QuantizedTensor<B>,
531 ) -> QuantizedTensor<B> {
532 dequant_op_quant!(
533 ty Self,
534 float_op |tensor, value| B::float_scatter(dim, tensor, indices, value),
535 tensor,
536 value
537 )
538 }
539
    /// Selects tensor elements along the given dimension at the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn q_select(
        tensor: QuantizedTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
    ) -> QuantizedTensor<B>;
556
557 /// Assign the selected elements along the given dimension corresponding for the given indices
558 /// to the given value.
559 ///
560 /// # Arguments
561 ///
562 /// * `tensor` - The tensor to select from.
563 /// * `dim` - The dimension to select from.
564 /// * `indices` - The indices to select.
565 /// * `value` - The value to assign.
566 ///
567 /// # Returns
568 ///
569 /// The tensor with the selected elements assigned to the given value.
570 fn q_select_assign(
571 tensor: QuantizedTensor<B>,
572 dim: usize,
573 indices: IntTensor<B>,
574 value: QuantizedTensor<B>,
575 ) -> QuantizedTensor<B> {
576 dequant_op_quant!(
577 ty Self,
578 float_op |tensor, value| B::float_select_assign(tensor, dim, indices, value),
579 tensor,
580 value
581 )
582 }
583
    /// Selects the tensor elements covered by the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `ranges` - The ranges to select.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn q_slice(tensor: QuantizedTensor<B>, ranges: &[Range<usize>]) -> QuantizedTensor<B>;
595
596 /// Assign the selected elements corresponding for the given ranges to the given value.
597 ///
598 /// # Arguments
599 ///
600 /// * `tensor` - The tensor to select from.
601 /// * `ranges` - The ranges to select.
602 /// * `value` - The value to assign.
603 ///
604 /// # Returns
605 ///
606 /// The tensor with the selected elements assigned to the given value.
607 fn q_slice_assign(
608 tensor: QuantizedTensor<B>,
609 ranges: &[Range<usize>],
610 value: QuantizedTensor<B>,
611 ) -> QuantizedTensor<B> {
612 dequant_op_quant!(
613 ty Self,
614 float_op |tensor, value| B::float_slice_assign(tensor, ranges, value),
615 tensor,
616 value
617 )
618 }
619
620 /// Update the given tensor with the value tensor where the mask is true.
621 ///
622 /// # Arguments
623 ///
624 /// * `tensor` - The tensor to select from.
625 /// * `mask` - The boolean mask to select with.
626 /// * `value` - The value to assign to the selected elements from the value tensor.
627 ///
628 /// # Returns
629 ///
630 /// The tensor with the selected elements assigned to the given value.
631 fn q_mask_where(
632 tensor: QuantizedTensor<B>,
633 mask: BoolTensor<B>,
634 value: QuantizedTensor<B>,
635 ) -> QuantizedTensor<B> {
636 dequant_op_quant!(
637 ty Self,
638 float_op |tensor, value| B::float_mask_where(tensor, mask, value),
639 tensor,
640 value
641 )
642 }
643
644 /// Update the given tensor with the value where the mask is true.
645 ///
646 /// # Arguments
647 ///
648 /// * `tensor` - The tensor to select from.
649 /// * `mask` - The boolean mask to select with.
650 /// * `value` - The value to assign to the selected elements.
651 ///
652 /// # Returns
653 ///
654 /// The tensor with the selected elements assigned to the given value.
655 fn q_mask_fill(
656 tensor: QuantizedTensor<B>,
657 mask: BoolTensor<B>,
658 value: FloatElem<B>,
659 ) -> QuantizedTensor<B> {
660 dequant_op_quant!(
661 ty Self,
662 float_op |tensor| B::float_mask_fill(tensor, mask, value),
663 tensor
664 )
665 }
666
667 /// Sum of all elements in a tensor.
668 ///
669 /// # Arguments
670 ///
671 /// * `tensor` - The tensor to sum.
672 ///
673 /// # Returns
674 ///
675 /// A scalar tensor with the sum of all elements in `tensor`.
676 fn q_sum(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
677 dequant_op_quant!(
678 ty Self,
679 float_op |tensor| B::float_sum(tensor),
680 tensor
681 )
682 }
683
684 /// Sum of all elements in a tensor along a dimension.
685 ///
686 /// # Arguments
687 ///
688 /// * `tensor` - The tensor to sum.
689 /// * `dim` - The dimension along which to sum.
690 ///
691 /// # Returns
692 ///
693 /// A tensor with the sum of all elements in `tensor` along `dim`.
694 fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
695 dequant_op_quant!(
696 ty Self,
697 float_op |tensor| B::float_sum_dim(tensor, dim),
698 tensor
699 )
700 }
701
702 /// Product of all elements in a tensor.
703 ///
704 /// # Arguments
705 ///
706 /// * `tensor` - The tensor to product.
707 ///
708 /// # Returns
709 ///
710 /// A scalar tensor with the product of all elements in `tensor`.
711 fn q_prod(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
712 dequant_op_quant!(
713 ty Self,
714 float_op |tensor| B::float_prod(tensor),
715 tensor
716 )
717 }
718
719 /// Product of all elements in a tensor along a dimension.
720 ///
721 /// # Arguments
722 ///
723 /// * `tensor` - The tensor to product.
724 ///
725 /// # Returns
726 ///
727 /// A tensor with the product of all elements in `tensor` along `dim`.
728 fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
729 dequant_op_quant!(
730 ty Self,
731 float_op |tensor| B::float_prod_dim(tensor, dim),
732 tensor
733 )
734 }
735
736 /// Mean of all elements in a tensor.
737 ///
738 /// # Arguments
739 ///
740 /// * `tensor` - The tensor to mean.
741 ///
742 /// # Returns
743 ///
744 /// A scalar tensor with the mean of all elements in `tensor`.
745 fn q_mean(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
746 dequant_op_quant!(
747 ty Self,
748 float_op |tensor| B::float_mean(tensor),
749 tensor
750 )
751 }
752
753 /// Mean of all elements in a tensor along a dimension.
754 ///
755 /// # Arguments
756 ///
757 /// * `tensor` - The tensor to mean.
758 /// * `dim` - The dimension along which to mean.
759 ///
760 /// # Returns
761 ///
762 /// A tensor with the mean of all elements in `tensor` along `dim`.
763 fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
764 dequant_op_quant!(
765 ty Self,
766 float_op |tensor| B::float_mean_dim(tensor, dim),
767 tensor
768 )
769 }
770
771 /// Returns a new tensor with exponential values.
772 ///
773 /// # Arguments
774 ///
775 /// * `tensor` - The tensor to exponentiate.
776 ///
777 /// # Returns
778 ///
779 /// A tensor with the same shape as `tensor` with exponential values.
780 fn q_exp(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
781 dequant_op_quant!(
782 ty Self,
783 float_op |tensor| B::float_exp(tensor),
784 tensor
785 )
786 }
787
788 /// Returns a new tensor with natural logarithm values.
789 ///
790 /// # Arguments
791 ///
792 /// * `tensor` - The tensor to take the logarithm of.
793 ///
794 /// # Returns
795 ///
796 /// A tensor with the same shape as `tensor` with natural logarithm values.
797 fn q_log(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
798 dequant_op_quant!(
799 ty Self,
800 float_op |tensor| B::float_log(tensor),
801 tensor
802 )
803 }
804
805 /// Returns a new tensor with logarithm values of (1 + Xi).
806 ///
807 /// # Arguments
808 ///
809 /// * `tensor` - The tensor to take the logarithm of.
810 ///
811 /// # Returns
812 ///
813 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
814 fn q_log1p(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
815 dequant_op_quant!(
816 ty Self,
817 float_op |tensor| B::float_log1p(tensor),
818 tensor
819 )
820 }
821
822 /// Element-wise power with another tensor.
823 ///
824 /// # Arguments
825 ///
826 /// * `lhs` - The left hand side tensor.
827 /// * `rhs` - The right hand side tensor.
828 ///
829 /// # Returns
830 ///
831 /// The elements of `lhs` raised to the power of the elements of `rhs`.
832 fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
833 dequant_op_quant!(
834 ty Self,
835 float_op |lhs, rhs| B::float_powf(lhs, rhs),
836 lhs,
837 rhs
838 )
839 }
840
841 /// Element-wise power with an IntTensor.
842 ///
843 /// # Arguments
844 ///
845 /// * `lhs` - The left hand side tensor.
846 /// * `rhs` - The right hand side floatTensor.
847 ///
848 /// # Returns
849 ///
850 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
851 fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> QuantizedTensor<B> {
852 dequant_op_quant!(
853 ty Self,
854 float_op |tensor| B::float_powi(tensor, rhs),
855 lhs
856 )
857 }
858
859 /// Element-wise power with an int scalar.
860 ///
861 /// # Arguments
862 ///
863 /// * `lhs` - The left hand side tensor.
864 /// * `rhs` - The right hand side scalar.
865 ///
866 /// # Returns
867 ///
868 /// The elements of `lhs` raised to the value of `rhs`.
869 fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> QuantizedTensor<B> {
870 dequant_op_quant!(
871 ty Self,
872 float_op |tensor| B::float_powi_scalar(tensor, rhs),
873 lhs
874 )
875 }
876
877 /// Element-wise power with a float scalar.
878 ///
879 /// # Arguments
880 ///
881 /// * `tensor` - The tensor to exponentiate.
882 /// * `value` - The exponent.
883 ///
884 /// # Returns
885 ///
886 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
887 fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> QuantizedTensor<B> {
888 dequant_op_quant!(
889 ty Self,
890 float_op |tensor| B::float_powf_scalar(tensor, value),
891 tensor
892 )
893 }
894
895 /// Returns a new tensor with square root values.
896 ///
897 /// # Arguments
898 ///
899 /// * `tensor` - The tensor to take the square root of.
900 ///
901 /// # Returns
902 ///
903 /// A tensor with the same shape as `tensor` with square root values.
904 fn q_sqrt(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
905 dequant_op_quant!(
906 ty Self,
907 float_op |tensor| B::float_sqrt(tensor),
908 tensor
909 )
910 }
911
912 /// Returns a new tensor with absolute values.
913 ///
914 /// # Arguments
915 ///
916 /// * `tensor` - The tensor to take absolute value of.
917 ///
918 /// # Returns
919 ///
920 /// A tensor with the same shape as `tensor` with absolute values.
921 fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
922 dequant_op_quant!(
923 ty Self,
924 float_op |tensor| B::float_abs(tensor),
925 tensor
926 )
927 }
928
929 /// Returns a new tensor with cosine values.
930 ///
931 /// # Arguments
932 ///
933 /// * `tensor` - The tensor to take the cosine of.
934 ///
935 /// # Returns
936 ///
937 /// A tensor with the same shape as `tensor` with cosine values.
938 fn q_cos(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
939 dequant_op_quant!(
940 ty Self,
941 float_op |tensor| B::float_cos(tensor),
942 tensor
943 )
944 }
945
946 /// Returns a new tensor with sine values.
947 ///
948 /// # Arguments
949 ///
950 /// * `tensor` - The tensor to take the sine of.
951 ///
952 /// # Returns
953 ///
954 /// A tensor with the same shape as `tensor` with sine values.
955 fn q_sin(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
956 dequant_op_quant!(
957 ty Self,
958 float_op |tensor| B::float_sin(tensor),
959 tensor
960 )
961 }
962
963 /// Returns a new tensor with tangent values.
964 ///
965 /// # Arguments
966 ///
967 /// * `tensor` - The tensor to take the tangent of.
968 ///
969 /// # Returns
970 ///
971 /// A tensor with the same shape as `tensor` with tangent values.
972 fn q_tanh(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
973 dequant_op_quant!(
974 ty Self,
975 float_op |tensor| B::float_tanh(tensor),
976 tensor
977 )
978 }
979
980 /// Returns a new tensor with the error function values.
981 ///
982 /// # Arguments
983 ///
984 /// * `tensor` - The tensor to take the error function of.
985 ///
986 /// # Returns
987 ///
988 /// A tensor with the same shape as `tensor` with error function values.
989 fn q_erf(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
990 dequant_op_quant!(
991 ty Self,
992 float_op |tensor| B::float_erf(tensor),
993 tensor
994 )
995 }
996
997 /// Concatenates tensors along a dimension.
998 ///
999 /// # Arguments
1000 ///
1001 /// * `tensors` - The tensors to concatenate.
1002 /// * `dim` - The dimension along which to concatenate.
1003 ///
1004 /// # Returns
1005 ///
1006 /// A tensor with the concatenated tensors along `dim`.
1007 fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {
1008 // Heuristic: prioritize first tensor scheme
1009 let scheme = *tensors.first().unwrap().scheme();
1010
1011 let tensor_f = tensors
1012 .into_iter()
1013 .map(|tensor| Self::dequantize(tensor))
1014 .collect();
1015
1016 let out_f = B::float_cat(tensor_f, dim);
1017
1018 Self::quantize_dynamic(out_f, &scheme)
1019 }
1020
1021 /// Gets the indices of the maximum elements of a tensor along an axis.
1022 ///
1023 /// # Arguments
1024 ///
1025 /// * `tensor` - The tensor to get the maximum elements of.
1026 /// * `dim` - The dimension along which to get the maximum elements.
1027 ///
1028 /// # Returns
1029 ///
1030 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1031 fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1032 // Default implementation. Backends can sort on the int values since qparams remain the same.
1033 let tensor_f = Self::dequantize(tensor);
1034 B::float_argmax(tensor_f, dim)
1035 }
1036
1037 /// Gets the indices of the minimum elements of a tensor along an axis.
1038 ///
1039 /// # Arguments
1040 ///
1041 /// * `tensor` - The tensor to get the minimum elements of.
1042 /// * `dim` - The dimension along which to get the minimum elements.
1043 ///
1044 /// # Returns
1045 ///
1046 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1047 fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {
1048 // Default implementation. Backends can sort on the int values since qparams remain the same.
1049 let tensor_f = Self::dequantize(tensor);
1050 B::float_argmin(tensor_f, dim)
1051 }
1052
1053 /// Gets the maximum element of a tensor.
1054 ///
1055 /// # Arguments
1056 ///
1057 /// * `tensor` - The tensor to get the maximum elements of.
1058 ///
1059 /// # Returns
1060 ///
1061 /// A tensor with the maximum element of `tensor`.
1062 fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1063 let shape = tensor.shape();
1064 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1065
1066 B::q_max_dim(tensor, 0)
1067 }
1068
1069 /// Gets the maximum elements of a tensor along an axis.
1070 ///
1071 /// # Arguments
1072 ///
1073 /// * `tensor` - The tensor to get the maximum elements of.
1074 /// * `dim` - The dimension along which to get the maximum elements.
1075 ///
1076 /// # Returns
1077 ///
1078 /// A tensor with the maximum elements of `tensor` along `dim`.
1079 fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1080 let index = B::q_argmax(tensor.clone(), dim);
1081
1082 B::q_gather(dim, tensor, index)
1083 }
1084
1085 /// Gets the maximum elements of a tensor along an axis and their indices.
1086 ///
1087 /// # Arguments
1088 ///
1089 /// * `tensor` - The tensor to get the maximum elements of.
1090 /// * `dim` - The dimension along which to get the maximum elements.
1091 ///
1092 /// # Returns
1093 ///
1094 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1095 fn q_max_dim_with_indices(
1096 tensor: QuantizedTensor<B>,
1097 dim: usize,
1098 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1099 let index = B::q_argmax(tensor.clone(), dim);
1100 let values = B::q_gather(dim, tensor, index.clone());
1101
1102 (values, index)
1103 }
1104
1105 /// Gets the minimum element of a tensor.
1106 ///
1107 /// # Arguments
1108 ///
1109 /// * `tensor` - The tensor to get the minimum elements of.
1110 ///
1111 /// # Returns
1112 ///
1113 /// A tensor with the minimum element of `tensor`.
1114 fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
1115 let shape = tensor.shape();
1116 let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
1117
1118 B::q_min_dim(tensor, 0)
1119 }
1120
1121 /// Gets the minimum elements of a tensor along an axis.
1122 ///
1123 /// # Arguments
1124 ///
1125 /// * `tensor` - The tensor to get the minimum elements of.
1126 /// * `dim` - The dimension along which to get the minimum elements.
1127 ///
1128 /// # Returns
1129 ///
1130 /// A tensor with the minimum elements of `tensor` along `dim`.
1131 fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {
1132 let index = B::q_argmin(tensor.clone(), dim);
1133
1134 B::q_gather(dim, tensor, index)
1135 }
1136
1137 /// Gets the minimum elements of a tensor along an axis and their indices.
1138 ///
1139 /// # Arguments
1140 ///
1141 /// * `tensor` - The tensor to get the minimum elements of.
1142 /// * `dim` - The dimension along which to get the minimum elements.
1143 ///
1144 /// # Returns
1145 ///
1146 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1147 fn q_min_dim_with_indices(
1148 tensor: QuantizedTensor<B>,
1149 dim: usize,
1150 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1151 let index = B::q_argmin(tensor.clone(), dim);
1152 let values = B::q_gather(dim, tensor, index.clone());
1153
1154 (values, index)
1155 }
1156
1157 /// Returns a new tensor with the given dimension narrowed to the given range.
1158 ///
1159 /// # Arguments
1160 ///
1161 /// * `dim` - The dimension along which the tensor will be narrowed.
1162 /// * `start` - The starting point of the given range.
1163 /// * `length` - The ending point of the given range.
1164 /// # Panics
1165 ///
1166 /// - If the dimension is greater than the number of dimensions of the tensor.
1167 /// - If the given range exceeds the number of elements on the given dimension.
1168 ///
1169 /// # Returns
1170 ///
1171 /// A new tensor with the given dimension narrowed to the given range.
1172 fn q_narrow(
1173 tensor: QuantizedTensor<B>,
1174 dim: usize,
1175 start: usize,
1176 length: usize,
1177 ) -> QuantizedTensor<B> {
1178 dequant_op_quant!(
1179 ty Self,
1180 float_op |tensor| B::float_narrow(tensor, dim, start, length),
1181 tensor
1182 )
1183 }
1184
1185 /// Split the tensor along the given dimension into chunks.
1186 ///
1187 /// # Arguments
1188 ///
1189 /// * `tensor` - The tensor.
1190 /// * `chunks` - The number of chunks to be produced.
1191 /// * `times` - The dimension along which the tensor will be split.
1192 ///
1193 /// # Returns
1194 ///
1195 /// A vector of tensors.
1196 fn q_chunk(tensor: QuantizedTensor<B>, chunks: usize, dim: usize) -> Vec<QuantizedTensor<B>> {
1197 let scheme = *tensor.scheme();
1198
1199 let tensor_f = Self::dequantize(tensor);
1200 let out_f = B::float_chunk(tensor_f, chunks, dim);
1201
1202 out_f
1203 .into_iter()
1204 .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1205 .collect()
1206 }
1207
1208 /// Split the tensor along the given dimension into chunks of `split_size`.
1209 ///
1210 /// # Arguments
1211 ///
1212 /// * `tensor` - The tensor.
1213 /// * `split_size` - The size of a single chunk.
1214 /// * `times` - The dimension along which the tensor will be split.
1215 ///
1216 /// # Returns
1217 ///
1218 /// A vector of tensors.
1219 fn q_split(
1220 tensor: QuantizedTensor<B>,
1221 split_size: usize,
1222 dim: usize,
1223 ) -> Vec<QuantizedTensor<B>> {
1224 let scheme = *tensor.scheme();
1225
1226 let tensor_f = Self::dequantize(tensor);
1227 let out_f = B::float_split(tensor_f, split_size, dim);
1228
1229 out_f
1230 .into_iter()
1231 .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1232 .collect()
1233 }
1234
1235 /// Split the tensor along the given dimension into chunks with sizes in
1236 /// `dim` according to `split_sizes`.
1237 ///
1238 /// # Arguments
1239 ///
1240 /// * `tensor` - The tensor.
1241 /// * `split_sizes` - Vector of sizes for each chunk.
1242 /// * `times` - The dimension along which the tensor will be split.
1243 ///
1244 /// # Returns
1245 ///
1246 /// A vector of tensors.
1247 fn q_split_with_sizes(
1248 tensor: QuantizedTensor<B>,
1249 split_sizes: Vec<usize>,
1250 dim: usize,
1251 ) -> Vec<QuantizedTensor<B>> {
1252 let scheme = *tensor.scheme();
1253
1254 let tensor_f = Self::dequantize(tensor);
1255 let out_f = B::float_split_with_sizes(tensor_f, split_sizes, dim);
1256
1257 out_f
1258 .into_iter()
1259 .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
1260 .collect()
1261 }
1262
1263 /// Tests if any element in the `tensor` evaluates to True.
1264 ///
1265 /// # Arguments
1266 ///
1267 /// * `tensor` - The tensor to test.
1268 ///
1269 /// # Returns
1270 ///
1271 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1272 fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1273 let tensor_f = Self::dequantize(tensor);
1274 B::float_any(tensor_f)
1275 }
1276
1277 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1278 ///
1279 /// # Arguments
1280 ///
1281 /// * `tensor` - The tensor to test.
1282 /// * `dim` - The axis along which to test.
1283 ///
1284 /// # Returns
1285 ///
1286 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1287 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1288 /// input evaluates to True, False otherwise.
1289 fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1290 let tensor_f = Self::dequantize(tensor);
1291 B::float_any_dim(tensor_f, dim)
1292 }
1293
1294 /// Tests if all elements in the `tensor` evaluate to True.
1295 ///
1296 /// # Arguments
1297 ///
1298 /// * `tensor` - The tensor to test.
1299 ///
1300 /// # Returns
1301 ///
1302 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1303 /// evaluate to True, False otherwise.
1304 fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {
1305 let tensor_f = Self::dequantize(tensor);
1306 B::float_all(tensor_f)
1307 }
1308
1309 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1310 ///
1311 /// # Arguments
1312 ///
1313 /// * `tensor` - The tensor to test.
1314 /// * `dim` - The axis along which to test.
1315 ///
1316 /// # Returns
1317 ///
1318 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1319 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1320 /// evaluates to True, False otherwise.
1321 fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {
1322 let tensor_f = Self::dequantize(tensor);
1323 B::float_all_dim(tensor_f, dim)
1324 }
1325
1326 /// Broadcasts the `tensor` to the given `shape`.
1327 fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
1328
1329 /// Sort the elements of the input `tensor` by value in along a given dimension.
1330 ///
1331 /// This sort is unstable (i.e., may reorder equal elements).
1332 ///
1333 /// # Arguments
1334 ///
1335 /// * `tensor` - The input tensor.
1336 /// * `dim` - The axis along which to sort.
1337 /// * `descending` - The sorting order.
1338 ///
1339 /// # Returns
1340 ///
1341 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1342 fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
1343 // Default implementation. Backends can sort on the int values since qparams remain the same.
1344 dequant_op_quant!(
1345 ty Self,
1346 float_op |tensor| B::float_sort(tensor, dim, descending),
1347 tensor
1348 )
1349 }
1350
1351 /// Sort the elements of the input `tensor` by value in along a given dimension.
1352 ///
1353 /// This sort is unstable (i.e., may reorder equal elements).
1354 ///
1355 /// # Arguments
1356 ///
1357 /// * `tensor` - The input tensor.
1358 /// * `dim` - The axis along which to sort.
1359 /// * `descending` - The sorting order.
1360 ///
1361 /// # Returns
1362 ///
1363 /// A tensor with the same shape as the input tensor and corresponding indices, where
1364 /// the elements are sorted by value and the indices map back to the original input tensor.
1365 fn q_sort_with_indices(
1366 tensor: QuantizedTensor<B>,
1367 dim: usize,
1368 descending: bool,
1369 ) -> (QuantizedTensor<B>, IntTensor<B>) {
1370 // Default implementation. Backends can sort on the int values since qparams remain the same.
1371 let scheme = *tensor.scheme();
1372
1373 let tensor_f = Self::dequantize(tensor);
1374 let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
1375
1376 (Self::quantize_dynamic(out_f, &scheme), indices)
1377 }
1378
1379 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1380 ///
1381 /// This sort is unstable (i.e., may reorder equal elements).
1382 ///
1383 /// # Arguments
1384 ///
1385 /// * `tensor` - The input tensor.
1386 /// * `dim` - The axis along which to sort.
1387 /// * `descending` - The sorting order.
1388 ///
1389 /// # Returns
1390 ///
1391 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1392 fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1393 // Default implementation. Backends can sort on the int values since qparams remain the same.
1394 let tensor_f = Self::dequantize(tensor);
1395 B::float_argsort(tensor_f, dim, descending)
1396 }
1397}