burn_tensor/tensor/quantization/
parameters.rs

1use crate::{backend::Backend, Int, Tensor};
2
3/// The tensor quantization parameters.
4pub type QuantizationParameters<B> = QParams<Tensor<B, 1>, Tensor<B, 1, Int>>;
5
/// Generic container for quantization parameters.
///
/// `S` is the type used to store the scaling factor and `O` the type used to store the
/// zero-point offset. The offset is optional and may be `None`.
#[derive(Debug, Clone)]
pub struct QParams<S, O> {
    /// Scaling factor.
    pub scale: S,
    /// Zero-point offset, if any.
    pub offset: Option<O>,
}

15/// The quantization parameters primitive.
16///
17/// # Remarks
18///
19/// This is a low-level struct used internally by the library to provide the quantization parameters
20/// to the backends. It is not designed for direct usage by users, and not recommended to import
21/// or use this struct directly.
22///
23/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
24pub struct QuantizationParametersPrimitive<B: Backend> {
25    /// The scaling factor.
26    pub scale: B::FloatTensorPrimitive,
27    /// The zero-point offset.
28    pub offset: Option<B::IntTensorPrimitive>,
29}
30
31impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> {
32    fn from(value: QuantizationParameters<B>) -> Self {
33        QuantizationParametersPrimitive {
34            scale: value.scale.primitive.tensor(),
35            offset: value.offset.map(|x| x.primitive),
36        }
37    }
38}