burn_tensor/tensor/quantization/
calibration.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
use crate::{backend::Backend, Tensor};

/// Range of input values observed while calibrating a tensor for quantization.
///
/// Both bounds are stored as rank-1 tensors on backend `B`.
#[derive(Debug, Clone)]
pub struct CalibrationRange<B>
where
    B: Backend,
{
    /// Smallest value observed in the calibrated input.
    pub min: Tensor<B, 1>,
    /// Largest value observed in the calibrated input.
    pub max: Tensor<B, 1>,
}

/// Strategy used to derive the range that a quantization mapping must cover.
pub trait Calibration {
    /// Inspects `tensor` (of any rank `D`) and returns the range of values it spans.
    fn compute_range<B, const D: usize>(&self, tensor: &Tensor<B, D>) -> CalibrationRange<B>
    where
        B: Backend;
}

/// Computes the per-tensor quantization range mapping based on the min and max values.
///
/// This calibration holds no configuration. It is therefore cheap to copy and
/// can be default-constructed, as well as built with `MinMaxCalibration {}`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MinMaxCalibration {}

impl Calibration for MinMaxCalibration {
    /// Reduces the whole input tensor to its minimum and maximum values.
    fn compute_range<B, const D: usize>(&self, tensor: &Tensor<B, D>) -> CalibrationRange<B>
    where
        B: Backend,
    {
        // Only a borrow of `tensor` is available, so each reduction runs on its own clone.
        // Field initializers are evaluated in the order written: `min` first, then `max`.
        CalibrationRange {
            min: tensor.clone().min(),
            max: tensor.clone().max(),
        }
    }
}