burn_tensor/tensor/
quantization.rs1use crate::{Tensor, TensorPrimitive, backend::Backend};
2use burn_backend::tensor::quantization;
3
4pub use burn_backend::{QTensorPrimitive, quantization::*};
6
7pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
9
10#[derive(Clone, Debug)]
12pub struct CalibrationRange<B: Backend> {
13 pub min: Tensor<B, 1>,
15 pub max: Tensor<B, 1>,
17}
18
19pub fn compute_range<B: Backend, const D: usize>(
21 scheme: &QuantScheme,
22 tensor: &Tensor<B, D>,
23 calibration: &Calibration,
24) -> CalibrationRange<B> {
25 let (min, max) = match &tensor.primitive {
26 TensorPrimitive::Float(tensor) => {
27 quantization::compute_range::<B>(scheme, tensor.clone(), calibration)
28 }
29 TensorPrimitive::QFloat(_) => unreachable!(),
30 };
31
32 CalibrationRange {
33 min: Tensor::from_primitive(TensorPrimitive::Float(min)),
34 max: Tensor::from_primitive(TensorPrimitive::Float(max)),
35 }
36}
37
38pub fn compute_q_params<B: Backend>(
40 scheme: &QuantScheme,
41 range: CalibrationRange<B>,
42) -> QuantizationParameters<B> {
43 match (range.min.primitive, range.max.primitive) {
44 (TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => {
45 let qparams = quantization::compute_q_params::<B>(scheme, min, max);
46 QuantizationParameters {
47 scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)),
48 }
49 }
50 _ => unreachable!(),
51 }
52}