burn_tensor/tensor/quantization/
calibration.rs

1use crate::{backend::Backend, Tensor};
2
3/// The observed input calibration range.
4#[derive(Clone, Debug)]
5pub struct CalibrationRange<B: Backend> {
6    /// Minimum observed value.
7    pub min: Tensor<B, 1>,
8    /// Maximum observed value.
9    pub max: Tensor<B, 1>,
10}
11
12/// Calibration method used to compute the quantization range mapping.
13pub trait Calibration {
14    /// Compute the input tensor range.
15    fn compute_range<B: Backend, const D: usize>(
16        &self,
17        tensor: &Tensor<B, D>,
18    ) -> CalibrationRange<B>;
19}
20
/// Computes the per-tensor quantization range mapping based on the min and max values.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// either `MinMaxCalibration {}` or `MinMaxCalibration::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MinMaxCalibration {}
23
24impl Calibration for MinMaxCalibration {
25    fn compute_range<B: Backend, const D: usize>(
26        &self,
27        tensor: &Tensor<B, D>,
28    ) -> CalibrationRange<B> {
29        let min = tensor.clone().min();
30        let max = tensor.clone().max();
31
32        CalibrationRange { min, max }
33    }
34}