burn_core/module/quantize.rs

1use burn_tensor::{
2    Tensor,
3    backend::Backend,
4    quantization::{Calibration, QuantScheme, compute_q_params, compute_range},
5};
6
7use crate::module::{ModuleMapper, Param};
8
9/// Describes how to quantize a module.
10pub struct Quantizer {
11    /// The calibration method used in quantization.
12    pub calibration: Calibration,
13    /// The quantization scheme.
14    pub scheme: QuantScheme,
15}
16
17impl<B: Backend> ModuleMapper<B> for Quantizer {
18    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
19        let (id, tensor, mapper) = param.consume();
20        let range = compute_range(&self.scheme, &tensor, &self.calibration);
21        let qparams = compute_q_params(&self.scheme, range);
22        let tensor = tensor.quantize(&self.scheme, qparams);
23        Param::from_mapped_value(id, tensor, mapper)
24    }
25}
26
#[cfg(all(test, not(feature = "test-tch")))]
mod tests {
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend,
        module::{Module, Quantizer},
    };
    use burn_tensor::{
        Device, Tolerance,
        ops::QuantizedTensor,
        quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue},
    };

    type B = TestBackend;

    /// Round-trips a linear layer's weights through quantization and back. It then checks
    /// that the dequantized values are approximately equal to the original float weights.
    #[test]
    fn should_quantize_module() {
        let device: Device<B> = Default::default();
        let module = SimpleLinear::<B>::new(32, 32, &device);

        // Keep a handle on the float weights before `quantize_weights` consumes the module.
        let expected = module.weight.val();

        // Q8S values, one set of quantization parameters for the whole tensor,
        // with those parameters stored as F32.
        let mut quantizer = Quantizer {
            calibration: Calibration::MinMax,
            scheme: <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
                .with_value(QuantValue::Q8S)
                .with_level(QuantLevel::Tensor)
                .with_param(QuantParam::F32),
        };

        let actual = module
            .quantize_weights(&mut quantizer)
            .weight
            .val()
            .dequantize();

        expected
            .into_data()
            .assert_approx_eq::<f32>(&actual.into_data(), Tolerance::permissive());
    }
}