burn_tensor/tensor/quantization/parameters.rs
1use crate::{backend::Backend, Int, Tensor};
2
3/// The tensor quantization parameters.
4pub type QuantizationParameters<B> = QParams<Tensor<B, 1>, Tensor<B, 1, Int>>;
5
/// Generic container for quantization parameters.
///
/// `S` is the type holding the scaling factor and `O` the type holding the zero-point offset.
/// The container is generic so that the same shape can describe tensor-level parameters
/// (see [`QuantizationParameters`]) as well as other representations.
#[derive(Debug, Clone)]
pub struct QParams<S, O> {
    /// Scaling factor.
    pub scale: S,
    /// Zero-point offset; `None` when no offset is provided.
    pub offset: Option<O>,
}
14
15/// The quantization parameters primitive.
16///
17/// # Remarks
18///
19/// This is a low-level struct used internally by the library to provide the quantization parameters
20/// to the backends. It is not designed for direct usage by users, and not recommended to import
21/// or use this struct directly.
22///
23/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
24pub struct QuantizationParametersPrimitive<B: Backend> {
25 /// The scaling factor.
26 pub scale: B::FloatTensorPrimitive,
27 /// The zero-point offset.
28 pub offset: Option<B::IntTensorPrimitive>,
29}
30
31impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> {
32 fn from(value: QuantizationParameters<B>) -> Self {
33 QuantizationParametersPrimitive {
34 scale: value.scale.primitive.tensor(),
35 offset: value.offset.map(|x| x.primitive),
36 }
37 }
38}