pub mod calibration;
pub mod ops;
pub mod qat;
pub mod schemes;
pub mod utils;
use crate::{Module, Parameter};
use torsh_core::{
dtype::DType,
error::{Result, TorshError},
};
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Top-level settings consumed by [`QuantizedModel`].
///
/// `Default` yields `DType::I8` with `QuantizationScheme::Symmetric`, both
/// `quantize_weights` and `quantize_activations` enabled, and `per_channel`
/// disabled.
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Target quantized dtype; `QuantizedModel::compression_ratio` derives the
    /// quantized bit width from this field.
    pub dtype: DType,
    /// Quantization scheme (see [`QuantizationScheme`]).
    pub scheme: QuantizationScheme,
    /// Backend / deployment options.
    pub backend_config: BackendQuantConfig,
    /// Settings passed to the calibrator by `QuantizedModel::calibrate`.
    pub calibration: CalibrationConfig,
    // NOTE(review): presumably selects per-channel rather than per-tensor
    // parameters; not read anywhere in this file — confirm against consumers.
    pub per_channel: bool,
    // Whether weights should be quantized (not read in this file).
    pub quantize_weights: bool,
    // Whether activations should be quantized (not read in this file).
    pub quantize_activations: bool,
}
/// Strategy used to derive quantization parameters.
///
/// NOTE(review): the variants' behavior is implemented outside this file
/// (e.g. the `schemes` / `calibration` modules); the notes below follow the
/// variant names and should be confirmed there.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationScheme {
    /// Presumably pairs with `QuantizationParams::symmetric` (zero point 0).
    Symmetric,
    /// Presumably pairs with `QuantizationParams::asymmetric` (free zero point).
    Asymmetric,
    /// Presumably parameters computed at runtime rather than from calibration.
    Dynamic,
    /// Presumably ranges chosen by minimizing KL divergence.
    KLDivergence,
    /// Percentile-based range selection. Whether the payload is on a 0–100 or
    /// 0–1 scale is not established here — confirm before relying on it.
    Percentile(f32),
}
/// Backend-related switches for quantized execution.
///
/// `Default` enables all three boolean flags and targets
/// `DeploymentPlatform::CPU`. None of these fields are read in this file.
#[derive(Debug, Clone)]
pub struct BackendQuantConfig {
    // Presumably allows hardware-specific quantized kernels — confirm in backend code.
    pub use_hardware_acceleration: bool,
    // Presumably allows fusing adjacent quantized ops — confirm in backend code.
    pub enable_kernel_fusion: bool,
    // Presumably allows re-laying-out tensors for the target — confirm in backend code.
    pub optimize_memory_layout: bool,
    /// Platform the quantized model is intended to run on.
    pub target_platform: DeploymentPlatform,
}
/// Deployment target a quantized model is prepared for.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`. That allows it to
/// key a `HashMap` / `HashSet`, e.g. per-platform settings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeploymentPlatform {
    /// General-purpose CPU (the `BackendQuantConfig` default).
    CPU,
    /// GPU accelerator.
    GPU,
    /// Mobile device.
    Mobile,
    /// Edge / embedded device.
    Edge,
    /// Server deployment.
    Server,
    /// WebAssembly runtime.
    WASM,
}
/// Calibration settings; `QuantizedModel::calibrate` hands a reference to this
/// to `calibration::Calibrator::new`.
///
/// `Default`: 100 samples, `CalibrationMethod::MinMax`, outlier percentile
/// 99.99, moving average enabled, momentum 0.9.
#[derive(Debug, Clone)]
pub struct CalibrationConfig {
    // Presumably the number of calibration samples to consume — confirm in `calibration`.
    pub num_samples: usize,
    /// Range-estimation method.
    pub method: CalibrationMethod,
    // Default 99.99 suggests a 0–100 percentile scale — confirm in `calibration`.
    pub outlier_percentile: f32,
    // Presumably smooths observed ranges across batches — confirm in `calibration`.
    pub use_moving_average: bool,
    // Presumably the moving-average coefficient (default 0.9) — confirm which
    // term it weights in `calibration`.
    pub momentum: f32,
}
/// Criterion used by the calibrator to choose quantization ranges.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`. That allows it to
/// key a `HashMap` / `HashSet`.
///
/// NOTE(review): the criteria themselves are implemented in the `calibration`
/// module; the descriptions below follow the variant names.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CalibrationMethod {
    /// Observed minimum / maximum (the `CalibrationConfig` default).
    MinMax,
    /// Entropy-based range selection.
    Entropy,
    /// Mean-squared-error–based range selection.
    MSE,
    /// Cosine-similarity–based range selection.
    CosineSimilarity,
}
/// Per-layer quantization parameters.
///
/// Built by [`QuantizationParams::symmetric`] / [`QuantizationParams::asymmetric`],
/// which set `qmin`/`qmax` to the full integer range of `dst_dtype`.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor; how it is applied is defined by `ops::quantize_tensor` /
    /// `ops::dequantize_tensor`.
    pub scale: f32,
    /// Zero point (0 for symmetric parameters).
    pub zero_point: i32,
    /// Inclusive lower bound of the quantized range.
    pub qmin: i32,
    /// Inclusive upper bound of the quantized range.
    pub qmax: i32,
    /// Dtype of the unquantized tensor.
    pub src_dtype: DType,
    /// Quantized dtype; constructors accept only `I8`, `U8`, or `I16`.
    pub dst_dtype: DType,
}
impl QuantizationParams {
    /// Creates symmetric parameters: the zero point is fixed at `0`.
    ///
    /// `qmin`/`qmax` span the full integer range of `dst_dtype`.
    ///
    /// # Panics
    ///
    /// Panics if `dst_dtype` is not `DType::I8`, `DType::U8`, or `DType::I16`.
    ///
    /// NOTE(review): with `DType::U8` the range is `[0, 255]` and the zero point
    /// is `0`, so negative inputs cannot be represented — confirm callers only
    /// pair symmetric parameters with `U8` for non-negative data.
    pub fn symmetric(scale: f32, src_dtype: DType, dst_dtype: DType) -> Self {
        Self::asymmetric(scale, 0, src_dtype, dst_dtype)
    }

    /// Creates asymmetric parameters with an explicit `zero_point`.
    ///
    /// `qmin`/`qmax` span the full integer range of `dst_dtype`.
    ///
    /// # Panics
    ///
    /// Panics if `dst_dtype` is not `DType::I8`, `DType::U8`, or `DType::I16`.
    pub fn asymmetric(scale: f32, zero_point: i32, src_dtype: DType, dst_dtype: DType) -> Self {
        let (qmin, qmax) = Self::integer_range(&dst_dtype);
        Self {
            scale,
            zero_point,
            qmin,
            qmax,
            src_dtype,
            dst_dtype,
        }
    }

    /// Inclusive `(min, max)` integer range representable by `dtype`.
    ///
    /// # Panics
    ///
    /// Panics for any dtype other than `I8`, `U8`, or `I16`.
    fn integer_range(dtype: &DType) -> (i32, i32) {
        match dtype {
            DType::I8 => (i32::from(i8::MIN), i32::from(i8::MAX)),
            DType::U8 => (i32::from(u8::MIN), i32::from(u8::MAX)),
            DType::I16 => (i32::from(i16::MIN), i32::from(i16::MAX)),
            _ => panic!("Unsupported quantization dtype: {:?}", dtype),
        }
    }

    /// Quantizes `tensor` with these parameters (delegates to `ops::quantize_tensor`).
    ///
    /// # Errors
    ///
    /// Propagates any error from `ops::quantize_tensor`.
    pub fn quantize(&self, tensor: &Tensor) -> Result<Tensor> {
        ops::quantize_tensor(tensor, self)
    }

    /// Dequantizes `tensor` with these parameters (delegates to `ops::dequantize_tensor`).
    ///
    /// # Errors
    ///
    /// Propagates any error from `ops::dequantize_tensor`.
    pub fn dequantize(&self, tensor: &Tensor) -> Result<Tensor> {
        ops::dequantize_tensor(tensor, self)
    }
}
/// A model paired with its quantization configuration and calibration results.
#[derive(Debug)]
pub struct QuantizedModel<M: Module> {
    /// The wrapped model; all `Module` calls on the wrapper forward to it.
    pub model: M,
    /// Quantization settings.
    pub config: QuantizationConfig,
    /// Per-layer parameters keyed by layer name; empty until `calibrate` succeeds.
    pub layer_params: HashMap<String, QuantizationParams>,
    /// Statistics from the most recent successful `calibrate`, if any.
    pub calibration_stats: Option<CalibrationStats>,
}
impl<M: Module> QuantizedModel<M> {
    /// Wraps `model` with `config`; no layer parameters or stats are present yet.
    pub fn new(model: M, config: QuantizationConfig) -> Self {
        Self {
            model,
            config,
            layer_params: HashMap::new(),
            calibration_stats: None,
        }
    }

    /// Runs calibration over `calibration_data` and stores the resulting
    /// statistics and per-layer quantization parameters.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Calibrator::calibrate`. In that case the
    /// previously stored `calibration_stats` and `layer_params` are left untouched.
    pub fn calibrate<I>(&mut self, calibration_data: I) -> Result<()>
    where
        I: Iterator<Item = Tensor>,
    {
        let mut calibrator = calibration::Calibrator::new(&self.config.calibration);
        calibrator.calibrate(&mut self.model, calibration_data)?;
        self.calibration_stats = Some(calibrator.stats());
        self.layer_params = calibrator.quantization_params();
        Ok(())
    }

    /// Applies the calibrated per-layer parameters.
    ///
    /// NOTE(review): currently this only prints each layer's scale and zero
    /// point. The wrapped model is not modified, and print order follows
    /// `HashMap` iteration order.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` if `layer_params` is empty
    /// (i.e. `calibrate` has not been run successfully).
    pub fn quantize(&mut self) -> Result<()> {
        if self.layer_params.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Model must be calibrated before quantization".to_string(),
            ));
        }
        for (layer_name, params) in &self.layer_params {
            println!(
                "Quantizing layer {} with scale={}, zero_point={}",
                layer_name, params.scale, params.zero_point
            );
        }
        Ok(())
    }

    /// Ratio of float storage bits to quantized storage bits.
    ///
    /// Returns `1.0` before calibration. Otherwise returns 32 divided by the bit
    /// width of `config.dtype`: `I8`/`U8` → 4.0, `I16` → 2.0, any other dtype → 1.0.
    pub fn compression_ratio(&self) -> f32 {
        // The source weights are taken to be 32-bit floats.
        const FLOAT_BITS: f32 = 32.0;

        if self.layer_params.is_empty() {
            return 1.0;
        }
        let quantized_bits: f32 = match self.config.dtype {
            DType::I8 | DType::U8 => 8.0,
            DType::I16 => 16.0,
            _ => 32.0,
        };
        FLOAT_BITS / quantized_bits
    }
}
impl<M: Module> Module for QuantizedModel<M> {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.model.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.model.parameters()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.model.named_parameters()
}
fn training(&self) -> bool {
self.model.training()
}
fn train(&mut self) {
self.model.train()
}
fn eval(&mut self) {
self.model.eval()
}
fn set_training(&mut self, training: bool) {
self.model.set_training(training);
}
fn to_device(&mut self, device: torsh_core::device::DeviceType) -> Result<()> {
self.model.to_device(device)
}
}
/// Statistics gathered during calibration; `QuantizedModel::calibrate` stores
/// the value returned by `Calibrator::stats()`.
#[derive(Debug, Clone)]
pub struct CalibrationStats {
    // Presumably the number of calibration samples processed — confirm in `calibration`.
    pub num_samples: usize,
    // NOTE(review): keys presumably layer names and values (min, max) — confirm.
    pub activation_ranges: HashMap<String, (f32, f32)>,
    // NOTE(review): keys presumably layer names and values (min, max) — confirm.
    pub weight_ranges: HashMap<String, (f32, f32)>,
    /// Aggregate quality metrics for the calibration.
    pub metrics: CalibrationMetrics,
}
/// Quality metrics reported alongside calibration statistics.
///
/// NOTE(review): how each metric is computed (what is compared against what,
/// units such as dB for `snr`) is defined in the `calibration` module — confirm there.
#[derive(Debug, Clone)]
pub struct CalibrationMetrics {
    // Mean squared error.
    pub mse: f32,
    // Signal-to-noise ratio.
    pub snr: f32,
    // Cosine similarity.
    pub cosine_similarity: f32,
    // Kullback–Leibler divergence.
    pub kl_divergence: f32,
}
impl Default for QuantizationConfig {
    /// Baseline configuration:
    ///
    /// - symmetric `I8` quantization of both weights and activations;
    /// - `per_channel` disabled;
    /// - default backend and calibration sub-configs.
    fn default() -> Self {
        let backend_config = BackendQuantConfig::default();
        let calibration = CalibrationConfig::default();
        Self {
            quantize_weights: true,
            quantize_activations: true,
            per_channel: false,
            scheme: QuantizationScheme::Symmetric,
            dtype: DType::I8,
            backend_config,
            calibration,
        }
    }
}
impl Default for BackendQuantConfig {
    /// Every optimisation switch on, targeting `DeploymentPlatform::CPU`.
    fn default() -> Self {
        let enabled = true;
        Self {
            target_platform: DeploymentPlatform::CPU,
            use_hardware_acceleration: enabled,
            enable_kernel_fusion: enabled,
            optimize_memory_layout: enabled,
        }
    }
}
impl Default for CalibrationConfig {
    /// Default calibration settings:
    ///
    /// - 100 samples with min/max range estimation;
    /// - 99.99 outlier percentile;
    /// - moving average enabled with momentum 0.9.
    fn default() -> Self {
        let num_samples = 100;
        let outlier_percentile = 99.99;
        let momentum = 0.9;
        Self {
            method: CalibrationMethod::MinMax,
            use_moving_average: true,
            num_samples,
            outlier_percentile,
            momentum,
        }
    }
}
pub mod prelude {
pub use super::qat::utils::{calibrate_qat_model, prepare_qat_model, progressive_qat_training};
pub use super::qat::{
FakeQuantize, QATConfig, QATLinear, QATModel, QATScheduler, QuantizedInferenceModel,
};
pub use super::{
BackendQuantConfig, CalibrationConfig, CalibrationMethod, DeploymentPlatform,
QuantizationConfig, QuantizationParams, QuantizationScheme, QuantizedModel,
};
pub fn int8_symmetric() -> QuantizationConfig {
QuantizationConfig {
dtype: torsh_core::dtype::DType::I8,
scheme: QuantizationScheme::Symmetric,
..Default::default()
}
}
pub fn int8_asymmetric() -> QuantizationConfig {
QuantizationConfig {
dtype: torsh_core::dtype::DType::I8,
scheme: QuantizationScheme::Asymmetric,
..Default::default()
}
}
pub fn dynamic_quantization() -> QuantizationConfig {
QuantizationConfig {
scheme: QuantizationScheme::Dynamic,
..Default::default()
}
}
pub fn qat_int8_config() -> QATConfig {
QATConfig {
weight_bits: 8,
activation_bits: 8,
scheme: QuantizationScheme::Symmetric,
..Default::default()
}
}
pub fn qat_conservative_config() -> QATConfig {
QATConfig {
warmup_epochs: 5,
qparam_lr: 0.005,
observer_momentum: 0.05,
..Default::default()
}
}
pub fn qat_aggressive_config() -> QATConfig {
QATConfig {
warmup_epochs: 1,
qparam_lr: 0.02,
observer_momentum: 0.2,
..Default::default()
}
}
}