use crate::{Tensor, TensorPrimitive, backend::Backend};
use burn_backend::tensor::quantization;
pub use burn_backend::{QTensorPrimitive, quantization::*};
/// Quantization parameters ([`QParams`]) whose values are stored as rank-1 [`Tensor`]s on
/// backend `B`.
pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
/// A `min` / `max` calibration range from which quantization parameters are derived.
///
/// Produced by [`compute_range`] and consumed by [`compute_q_params`].
#[derive(Clone, Debug)]
pub struct CalibrationRange<B: Backend> {
    /// Minimum of the calibration range.
    pub min: Tensor<B, 1>,
    /// Maximum of the calibration range.
    pub max: Tensor<B, 1>,
}
/// Computes the calibration range (`min`, `max`) of `tensor` for the given quantization
/// `scheme` and `calibration` method.
///
/// The computation is delegated to the backend-level [`quantization::compute_range`]; its
/// two outputs are wrapped back into rank-1 [`Tensor`]s.
///
/// # Panics
///
/// Panics if `tensor` is already quantized (`TensorPrimitive::QFloat`): the range can only
/// be computed from a float tensor.
pub fn compute_range<B: Backend, const D: usize>(
    scheme: &QuantScheme,
    tensor: &Tensor<B, D>,
    calibration: &Calibration,
) -> CalibrationRange<B> {
    let (min, max) = match &tensor.primitive {
        // The backend function takes the primitive by value while we only borrow the
        // tensor, hence the clone.
        TensorPrimitive::Float(tensor) => {
            quantization::compute_range::<B>(scheme, tensor.clone(), calibration)
        }
        // This arm is reachable from user code (the signature accepts any tensor), so report
        // the misuse explicitly rather than with `unreachable!()`.
        TensorPrimitive::QFloat(_) => {
            panic!("Cannot compute the calibration range of an already quantized tensor")
        }
    };

    CalibrationRange {
        min: Tensor::from_primitive(TensorPrimitive::Float(min)),
        max: Tensor::from_primitive(TensorPrimitive::Float(max)),
    }
}
/// Computes the quantization parameters for `scheme` from a calibration `range`.
///
/// The computation is delegated to the backend-level [`quantization::compute_q_params`];
/// the resulting `scales` are wrapped back into a [`Tensor`].
///
/// # Panics
///
/// Panics if `range.min` or `range.max` holds a quantized tensor (`TensorPrimitive::QFloat`).
/// [`CalibrationRange`] has public fields, so callers can construct such a range by hand;
/// a range produced by [`compute_range`] always holds float tensors.
pub fn compute_q_params<B: Backend>(
    scheme: &QuantScheme,
    range: CalibrationRange<B>,
) -> QuantizationParameters<B> {
    match (range.min.primitive, range.max.primitive) {
        (TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => {
            let qparams = quantization::compute_q_params::<B>(scheme, min, max);
            QuantizationParameters {
                scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)),
            }
        }
        // Reachable through a hand-built `CalibrationRange`; fail with a descriptive message
        // instead of `unreachable!()`.
        _ => panic!("Cannot compute quantization parameters from a quantized calibration range"),
    }
}