burn_tensor/tensor/
quantization.rs

1use crate::{Tensor, TensorPrimitive, backend::Backend};
2use burn_backend::tensor::quantization;
3
4// We re-export those types.
5pub use burn_backend::{QTensorPrimitive, quantization::*};
6
/// The tensor quantization parameters.
///
/// [`QParams`] specialized so that its values (e.g. `scales`) are stored as rank-1
/// tensors on backend `B`. This is the type returned by [`compute_q_params`].
pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
9
/// The observed input calibration range.
///
/// Holds the minimum and maximum values observed on an input, each as a rank-1 tensor
/// on backend `B`. It is produced by [`compute_range`] (which always stores float
/// tensors here) and consumed by [`compute_q_params`], which requires both `min` and
/// `max` to be float (non-quantized) tensors.
#[derive(Clone, Debug)]
pub struct CalibrationRange<B: Backend> {
    /// Minimum observed value(s).
    pub min: Tensor<B, 1>,
    /// Maximum observed value(s).
    pub max: Tensor<B, 1>,
}
18
19/// Compute the quantization range mapping.
20pub fn compute_range<B: Backend, const D: usize>(
21    scheme: &QuantScheme,
22    tensor: &Tensor<B, D>,
23    calibration: &Calibration,
24) -> CalibrationRange<B> {
25    let (min, max) = match &tensor.primitive {
26        TensorPrimitive::Float(tensor) => {
27            quantization::compute_range::<B>(scheme, tensor.clone(), calibration)
28        }
29        TensorPrimitive::QFloat(_) => unreachable!(),
30    };
31
32    CalibrationRange {
33        min: Tensor::from_primitive(TensorPrimitive::Float(min)),
34        max: Tensor::from_primitive(TensorPrimitive::Float(max)),
35    }
36}
37
38/// Compute the quantization parameters.
39pub fn compute_q_params<B: Backend>(
40    scheme: &QuantScheme,
41    range: CalibrationRange<B>,
42) -> QuantizationParameters<B> {
43    match (range.min.primitive, range.max.primitive) {
44        (TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => {
45            let qparams = quantization::compute_q_params::<B>(scheme, min, max);
46            QuantizationParameters {
47                scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)),
48            }
49        }
50        _ => unreachable!(),
51    }
52}