burn_core/module/quantize.rs

1use burn_tensor::{
2    backend::Backend,
3    quantization::{Calibration, QuantizationScheme},
4    Tensor,
5};
6
7use crate::module::{ModuleMapper, ParamId};
8
9/// Describes how to quantize a module.
10pub struct Quantizer<C: Calibration> {
11    /// The calibration method used in quantization.
12    pub calibration: C,
13    /// The quantization scheme.
14    pub scheme: QuantizationScheme,
15}
16
17impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
18    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
19        let range = self.calibration.compute_range(&tensor);
20        let qparams = self.scheme.compute_q_params(range);
21        tensor.quantize(&self.scheme, qparams)
22    }
23}