burn_tensor/tensor/quantization/parameters.rs
1use crate::{DType, Shape, Tensor, backend::Backend};
2use alloc::vec::Vec;
3
4/// The tensor quantization parameters.
5pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
6
/// Container for quantization parameters, generic over the storage used for the scales.
///
/// `S` is whatever holds the scaling factor(s); for example, [`QuantizationParameters`]
/// instantiates it with a [`Tensor`].
#[derive(Debug, Clone)]
pub struct QParams<S> {
    /// The scaling factor(s).
    pub scales: S,
}
13
14/// The quantization parameters primitive.
15///
16/// # Remarks
17///
18/// This is a low-level struct used internally by the library to provide the quantization parameters
19/// to the backends. It is not designed for direct usage by users, and not recommended to import
20/// or use this struct directly.
21///
22/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
23pub struct QuantizationParametersPrimitive<B: Backend> {
24 /// The scaling factor.
25 pub scales: B::FloatTensorPrimitive,
26}
27
28impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> {
29 fn from(value: QuantizationParameters<B>) -> Self {
30 QuantizationParametersPrimitive {
31 scales: value.scales.primitive.tensor(),
32 }
33 }
34}
35
36/// A quantization parameter tensor descriptor.
37#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct QParamTensor {
39 /// Start of the tensor in the buffer
40 pub offset_start: usize,
41 /// Offset of tensor end from the end of the buffer
42 pub offset_end: usize,
43 /// Shape of the tensor
44 pub shape: Shape,
45 /// Strides of the tensor
46 pub strides: Vec<usize>,
47 /// Data type of the tensor
48 pub dtype: DType,
49}