burn_core/module/quantize.rs

use burn_tensor::{
    Tensor,
    backend::Backend,
    quantization::{Calibration, QuantScheme, compute_q_params, compute_range},
};

use crate::module::{ModuleMapper, Param};
/// Applies the configured quantization scheme to a module's float parameters.
pub struct Quantizer {
    /// The calibration method used to compute the quantization range.
    pub calibration: Calibration,
    /// The quantization scheme to apply.
    pub scheme: QuantScheme,
}

impl<B: Backend> ModuleMapper<B> for Quantizer {
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        // Compute the value range via the configured calibration method,
        // derive the quantization parameters for the scheme, then quantize.
        let range = compute_range(&self.scheme, &tensor, &self.calibration);
        let qparams = compute_q_params(&self.scheme, range);
        let tensor = tensor.quantize(&self.scheme, qparams);
        Param::from_mapped_value(id, tensor, mapper)
    }
}
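
// Usage sketch (the `model` value and its type are hypothetical; this snippet
// is not part of the original file): any `Module` can be quantized by passing
// a `Quantizer` to `quantize_weights`, which walks the module tree and invokes
// `map_float` above on every float parameter. The same flow is exercised
// concretely in the test below.
//
//     let mut quantizer = Quantizer {
//         calibration: Calibration::MinMax,
//         scheme,
//     };
//     let q_model = model.quantize_weights(&mut quantizer);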

#[cfg(all(test, not(feature = "test-tch")))]
mod tests {
    use crate::test_utils::SimpleLinear;
    use crate::{
        TestBackend,
        module::{Module, Quantizer},
    };
    use burn_tensor::{
        Device, Tolerance,
        ops::QuantizedTensor,
        quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue},
    };

    type B = TestBackend;

    #[test]
    fn should_quantize_module() {
        let device: Device<B> = Default::default();
        let module = SimpleLinear::<B>::new(32, 32, &device);
        // Per-tensor, signed 8-bit quantization with f32 quantization parameters.
        let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Tensor)
            .with_param(QuantParam::F32);

        // Keep the original float weights for comparison after the round trip.
        let result = module.weight.val();

        let calibration = Calibration::MinMax;
        let mut quantizer = Quantizer {
            calibration,
            scheme,
        };
        let q_module = module.quantize_weights(&mut quantizer);
        let q_result = q_module.weight.val().dequantize();

        // Quantizing then dequantizing should approximately preserve the values.
        result
            .into_data()
            .assert_approx_eq::<f32>(&q_result.into_data(), Tolerance::permissive());
    }
}