burn_tensor/tensor/quantization/
scheme.rs

1// We re-export those types.
2pub use cubecl_quant::scheme::{
3    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
4};
5
6use serde::{Deserialize, Serialize};
7
8use crate::{Shape, Tensor, TensorMetadata, TensorPrimitive, backend::Backend};
9
10use super::{
11    Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive,
12};
13
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
/// The precision of accumulating elements.
///
/// Defaults to [`QuantAcc::F32`] (full precision).
pub enum QuantAcc {
    /// Full precision (default).
    #[default]
    F32,
    /// Half precision.
    F16,
    /// bfloat16 precision.
    BF16,
}
27
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
///
/// Defaults to [`QuantPropagation::Inhibit`] (unquantized output).
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized (default).
    #[default]
    Inhibit,
}
40
41/// Compute the quantization range mapping.
42pub fn compute_range<B: Backend, const D: usize>(
43    scheme: &QuantScheme,
44    tensor: &Tensor<B, D>,
45    calibration: &Calibration,
46) -> CalibrationRange<B> {
47    let (min, max) = match &tensor.primitive {
48        TensorPrimitive::Float(tensor) => {
49            compute_range_primitive::<B>(scheme, tensor.clone(), calibration)
50        }
51        TensorPrimitive::QFloat(_) => unreachable!(),
52    };
53
54    CalibrationRange {
55        min: Tensor::from_primitive(TensorPrimitive::Float(min)),
56        max: Tensor::from_primitive(TensorPrimitive::Float(max)),
57    }
58}
59
60/// Calculate the shape of the quantization parameters for a given tensor and level
61pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
62    match level {
63        QuantLevel::Tensor => Shape::new([1]),
64        QuantLevel::Block(block_size) => {
65            let mut params_shape = data_shape.clone();
66            let block_size = block_size.to_dim_vec(data_shape.num_dims());
67
68            for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) {
69                *shape = (*shape).div_ceil(block_size as usize);
70            }
71
72            params_shape
73        }
74    }
75}
76
77pub(crate) fn compute_range_primitive<B: Backend>(
78    scheme: &QuantScheme,
79    tensor: B::FloatTensorPrimitive,
80    calibration: &Calibration,
81) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
82    match calibration {
83        Calibration::MinMax => match scheme.level {
84            QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)),
85            QuantLevel::Block(block_size) => {
86                let block_elems = block_size.num_elements();
87                let shape = tensor.shape();
88                let numel = shape.num_elements();
89
90                assert_eq!(
91                    numel % block_elems,
92                    0,
93                    "Tensor {shape:?} must be evenly divisible by block size {block_elems}"
94                );
95
96                let num_blocks = numel / block_elems;
97
98                let params_shape = params_shape(&shape, scheme.level);
99
100                let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems]));
101                let blocks_min =
102                    B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone());
103                let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape);
104                (blocks_min, blocks_max)
105            }
106        },
107    }
108}
109
110/// Compute the quantization parameters.
111pub fn compute_q_params<B: Backend>(
112    scheme: &QuantScheme,
113    range: CalibrationRange<B>,
114) -> QuantizationParameters<B> {
115    match scheme {
116        QuantScheme {
117            level: QuantLevel::Tensor | QuantLevel::Block(_),
118            mode: QuantMode::Symmetric,
119            ..
120        } => {
121            // Quantized range `[a, b]`
122            let (a, b) = scheme.value.range();
123
124            // Compute scale to convert an input value in range `[-alpha, alpha]`
125            let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
126
127            QuantizationParameters {
128                scales: values_range.div_scalar(b - a),
129            }
130        }
131    }
132}
133
134/// Compute the quantization parameters.
135pub(crate) fn compute_q_params_primitive<B: Backend>(
136    scheme: &QuantScheme,
137    min: B::FloatTensorPrimitive,
138    max: B::FloatTensorPrimitive,
139) -> QuantizationParametersPrimitive<B> {
140    let range = CalibrationRange {
141        min: Tensor::from_primitive(TensorPrimitive::Float(min)),
142        max: Tensor::from_primitive(TensorPrimitive::Float(max)),
143    };
144    compute_q_params(scheme, range).into()
145}