// burn_tensor/tensor/quantization/scheme.rs
pub use cubecl_quant::scheme::{
    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
};

use serde::{Deserialize, Serialize};

use crate::{Shape, Tensor, TensorMetadata, TensorPrimitive, backend::Backend};

use super::{
    Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive,
};
13
/// Floating-point precision used for accumulation in quantized operations.
///
/// NOTE(review): semantics inferred from the variant names — confirm at call
/// sites how this selects the accumulator dtype.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantAcc {
    /// Accumulate in `f32` (default).
    #[default]
    F32,
    /// Accumulate in `f16`.
    F16,
    /// Accumulate in `bf16`.
    BF16,
}
27
/// Controls whether quantization is carried through to the outputs of an
/// operation on quantized inputs.
///
/// NOTE(review): semantics inferred from the variant names — verify against
/// the kernels that consume this flag.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// Keep the output quantized.
    Propagate,
    /// Do not propagate quantization to the output (default).
    #[default]
    Inhibit,
}
40
41pub fn compute_range<B: Backend, const D: usize>(
43 scheme: &QuantScheme,
44 tensor: &Tensor<B, D>,
45 calibration: &Calibration,
46) -> CalibrationRange<B> {
47 let (min, max) = match &tensor.primitive {
48 TensorPrimitive::Float(tensor) => {
49 compute_range_primitive::<B>(scheme, tensor.clone(), calibration)
50 }
51 TensorPrimitive::QFloat(_) => unreachable!(),
52 };
53
54 CalibrationRange {
55 min: Tensor::from_primitive(TensorPrimitive::Float(min)),
56 max: Tensor::from_primitive(TensorPrimitive::Float(max)),
57 }
58}
59
60pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
62 match level {
63 QuantLevel::Tensor => Shape::new([1]),
64 QuantLevel::Block(block_size) => {
65 let mut params_shape = data_shape.clone();
66 let block_size = block_size.to_dim_vec(data_shape.num_dims());
67
68 for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) {
69 *shape = (*shape).div_ceil(block_size as usize);
70 }
71
72 params_shape
73 }
74 }
75}
76
77pub(crate) fn compute_range_primitive<B: Backend>(
78 scheme: &QuantScheme,
79 tensor: B::FloatTensorPrimitive,
80 calibration: &Calibration,
81) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
82 match calibration {
83 Calibration::MinMax => match scheme.level {
84 QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)),
85 QuantLevel::Block(block_size) => {
86 let block_elems = block_size.num_elements();
87 let shape = tensor.shape();
88 let numel = shape.num_elements();
89
90 assert_eq!(
91 numel % block_elems,
92 0,
93 "Tensor {shape:?} must be evenly divisible by block size {block_elems}"
94 );
95
96 let num_blocks = numel / block_elems;
97
98 let params_shape = params_shape(&shape, scheme.level);
99
100 let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems]));
101 let blocks_min =
102 B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone());
103 let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape);
104 (blocks_min, blocks_max)
105 }
106 },
107 }
108}
109
110pub fn compute_q_params<B: Backend>(
112 scheme: &QuantScheme,
113 range: CalibrationRange<B>,
114) -> QuantizationParameters<B> {
115 match scheme {
116 QuantScheme {
117 level: QuantLevel::Tensor | QuantLevel::Block(_),
118 mode: QuantMode::Symmetric,
119 ..
120 } => {
121 let (a, b) = scheme.value.range();
123
124 let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
126
127 QuantizationParameters {
128 scales: values_range.div_scalar(b - a),
129 }
130 }
131 }
132}
133
134pub(crate) fn compute_q_params_primitive<B: Backend>(
136 scheme: &QuantScheme,
137 min: B::FloatTensorPrimitive,
138 max: B::FloatTensorPrimitive,
139) -> QuantizationParametersPrimitive<B> {
140 let range = CalibrationRange {
141 min: Tensor::from_primitive(TensorPrimitive::Float(min)),
142 max: Tensor::from_primitive(TensorPrimitive::Float(max)),
143 };
144 compute_q_params(scheme, range).into()
145}