burn_backend/tensor/quantization/scheme.rs

pub use burn_std::{QPARAM_ALIGN, params_shape};
use burn_std::{QuantLevel, QuantMode, QuantScheme, Shape};

use super::{Calibration, QuantizationParametersPrimitive};
use crate::{Backend, TensorMetadata, element::ElementConversion};

/// Computes the `(min, max)` range of the tensor values to quantize, using the
/// given calibration strategy, either per tensor or per block of elements
/// depending on the scheme's quantization level.
pub fn compute_range<B: Backend>(
    scheme: &QuantScheme,
    tensor: B::FloatTensorPrimitive,
    calibration: &Calibration,
) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
    match calibration {
        Calibration::MinMax => match scheme.level {
            QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)),
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                let shape = tensor.shape();
                let numel = shape.num_elements();

                assert_eq!(
                    numel % block_elems,
                    0,
                    "Tensor {shape:?} must be evenly divisible by block size {block_elems}"
                );

                let num_blocks = numel / block_elems;

                let params_shape = params_shape(&shape, scheme.level);

                // View the tensor as [num_blocks, block_elems] and reduce each
                // block to its own (min, max) pair.
                let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems]));
                let blocks_min =
                    B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone());
                let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape);
                (blocks_min, blocks_max)
            }
        },
    }
}
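
// A minimal standalone sketch of the per-block reduction above, operating on a
// flat `&[f32]` slice instead of a backend tensor primitive; the test module
// name and values are illustrative only. Each contiguous chunk of block-size
// values yields one (min, max) pair, mirroring the reshape to
// `[num_blocks, block_elems]` followed by `min_dim`/`max_dim` in `compute_range`.
#[cfg(test)]
mod block_range_sketch {
    #[test]
    fn per_block_min_max() {
        // Eight elements with a block size of four -> two (min, max) pairs.
        let values = [1.0f32, -2.0, 3.0, 0.5, -4.0, 4.0, 0.0, 2.0];
        let ranges: Vec<(f32, f32)> = values
            .chunks_exact(4)
            .map(|block| {
                (
                    block.iter().copied().fold(f32::INFINITY, f32::min),
                    block.iter().copied().fold(f32::NEG_INFINITY, f32::max),
                )
            })
            .collect();
        assert_eq!(ranges, vec![(-2.0, 3.0), (-4.0, 4.0)]);
    }
}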

/// Computes the quantization parameters (the scales) from a calibrated
/// `(min, max)` range, according to the scheme's quantization mode.
pub fn compute_q_params<B: Backend>(
    scheme: &QuantScheme,
    min: B::FloatTensorPrimitive,
    max: B::FloatTensorPrimitive,
) -> QuantizationParametersPrimitive<B> {
    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            ..
        } => {
            let (a, b) = scheme.value.range();

            let min_abs = B::float_abs(min);
            let max_abs = B::float_abs(max);

            // Symmetric quantization covers 2 * max(|min|, |max|), so the
            // scale maps that range onto the b - a quantized steps.
            let mask = B::float_lower(min_abs.clone(), max_abs.clone());
            let values_range =
                B::float_mul_scalar(B::float_mask_where(min_abs, mask, max_abs), 2.elem());

            QuantizationParametersPrimitive {
                scales: B::float_div_scalar(values_range, (b - a).elem()),
            }
        }
    }
}
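
// A minimal scalar sketch of the symmetric scale computed above, assuming a
// single (min, max) pair and an assumed signed 8-bit value range
// (a, b) = (-128, 127) for illustration; in the real code the range comes
// from `scheme.value.range()`.
#[cfg(test)]
mod symmetric_scale_sketch {
    #[test]
    fn scale_from_min_max() {
        let (a, b) = (-128i32, 127i32); // assumed 8-bit range, for illustration
        let (min, max) = (-0.5f32, 2.0f32);
        // scale = 2 * max(|min|, |max|) / (b - a)
        let scale = 2.0 * min.abs().max(max.abs()) / (b - a) as f32;
        assert!((scale - 4.0 / 255.0).abs() < 1e-7);
    }
}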