burn_backend/tensor/quantization/
scheme.rs

1pub use burn_std::{QPARAM_ALIGN, params_shape};
2use burn_std::{QuantLevel, QuantMode, QuantScheme, Shape};
3
4use super::{Calibration, QuantizationParametersPrimitive};
5use crate::{Backend, TensorMetadata, element::ElementConversion};
6
7/// Compute the quantization range mapping.
8pub fn compute_range<B: Backend>(
9    scheme: &QuantScheme,
10    tensor: B::FloatTensorPrimitive,
11    calibration: &Calibration,
12) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
13    match calibration {
14        Calibration::MinMax => match scheme.level {
15            QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)),
16            QuantLevel::Block(block_size) => {
17                let block_elems = block_size.num_elements();
18                let shape = tensor.shape();
19                let numel = shape.num_elements();
20
21                assert_eq!(
22                    numel % block_elems,
23                    0,
24                    "Tensor {shape:?} must be evenly divisible by block size {block_elems}"
25                );
26
27                let num_blocks = numel / block_elems;
28
29                let params_shape = params_shape(&shape, scheme.level);
30
31                let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems]));
32                let blocks_min =
33                    B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone());
34                let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape);
35                (blocks_min, blocks_max)
36            }
37        },
38    }
39}
40
41/// Compute the quantization parameters.
42pub fn compute_q_params<B: Backend>(
43    scheme: &QuantScheme,
44    min: B::FloatTensorPrimitive,
45    max: B::FloatTensorPrimitive,
46) -> QuantizationParametersPrimitive<B> {
47    match scheme {
48        QuantScheme {
49            level: QuantLevel::Tensor | QuantLevel::Block(_),
50            mode: QuantMode::Symmetric,
51            ..
52        } => {
53            // Quantized range `[a, b]`
54            let (a, b) = scheme.value.range();
55
56            // Compute scale to convert an input value in range `[-alpha, alpha]`
57            let min_abs = B::float_abs(min);
58            let max_abs = B::float_abs(max);
59
60            // `min_abs.max_pair(max_abs)`
61            let mask = B::float_lower(min_abs.clone(), max_abs.clone());
62            let values_range =
63                B::float_mul_scalar(B::float_mask_where(min_abs, mask, max_abs), 2.elem());
64
65            QuantizationParametersPrimitive {
66                scales: B::float_div_scalar(values_range, (b - a).elem()),
67            }
68        }
69    }
70}