// burn_tensor/tensor/quantization/parameters.rs

1use crate::{DType, Shape, Tensor, backend::Backend};
2use alloc::vec::Vec;
3
4/// The tensor quantization parameters.
5pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
6
/// Parameters describing the quantization of tensor data.
///
/// The type parameter `S` is the storage type used for the parameters.
#[derive(Debug, Clone)]
pub struct QParams<S> {
    /// The scaling factor, stored as `S`.
    pub scales: S,
}
13
14/// The quantization parameters primitive.
15///
16/// # Remarks
17///
18/// This is a low-level struct used internally by the library to provide the quantization parameters
19/// to the backends. It is not designed for direct usage by users, and not recommended to import
20/// or use this struct directly.
21///
22/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
23pub struct QuantizationParametersPrimitive<B: Backend> {
24    /// The scaling factor.
25    pub scales: B::FloatTensorPrimitive,
26}
27
28impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> {
29    fn from(value: QuantizationParameters<B>) -> Self {
30        QuantizationParametersPrimitive {
31            scales: value.scales.primitive.tensor(),
32        }
33    }
34}
35
36/// A quantization parameter tensor descriptor.
37#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct QParamTensor {
39    /// Start of the tensor in the buffer
40    pub offset_start: usize,
41    /// Offset of tensor end from the end of the buffer
42    pub offset_end: usize,
43    /// Shape of the tensor
44    pub shape: Shape,
45    /// Strides of the tensor
46    pub strides: Vec<usize>,
47    /// Data type of the tensor
48    pub dtype: DType,
49}