use super::{QuantizationParams, QuantizationScheme};
use crate::{Module, ModuleBase, Parameter};
use torsh_core::{
dtype::DType,
error::{Result, TorshError},
};
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[derive(Debug, Clone)]
pub struct QATConfig {
pub fake_quantize_enabled: bool,
pub warmup_epochs: usize,
pub qparam_lr: f32,
pub learnable_params: bool,
pub weight_bits: u8,
pub activation_bits: u8,
pub observer_momentum: f32,
pub scheme: QuantizationScheme,
pub per_channel_weights: bool,
pub per_channel_activations: bool,
}
impl Default for QATConfig {
fn default() -> Self {
Self {
fake_quantize_enabled: true,
warmup_epochs: 3,
qparam_lr: 0.01,
learnable_params: true,
weight_bits: 8,
activation_bits: 8,
observer_momentum: 0.1,
scheme: QuantizationScheme::Symmetric,
per_channel_weights: true,
per_channel_activations: false,
}
}
}
#[derive(Debug)]
pub struct FakeQuantize {
base: ModuleBase,
config: QATConfig,
scale: Parameter,
zero_point: Parameter,
min_val: f32,
max_val: f32,
num_batches_tracked: usize,
enabled: bool,
}
impl FakeQuantize {
pub fn new(config: QATConfig) -> Self {
let mut base = ModuleBase::new();
let init_scale = 1.0;
let init_zero_point = 0.0;
let scale = Parameter::new(
torsh_tensor::creation::tensor_scalar(init_scale)
.expect("scalar tensor for scale should succeed"),
);
let zero_point = Parameter::new(
torsh_tensor::creation::tensor_scalar(init_zero_point)
.expect("scalar tensor for zero_point should succeed"),
);
if config.learnable_params {
base.register_parameter("scale".to_string(), scale.clone());
base.register_parameter("zero_point".to_string(), zero_point.clone());
}
Self {
base,
config,
scale,
zero_point,
min_val: f32::INFINITY,
max_val: f32::NEG_INFINITY,
num_batches_tracked: 0,
enabled: true,
}
}
pub fn enable(&mut self, enabled: bool) {
self.enabled = enabled;
}
pub fn update_observers(&mut self, tensor: &Tensor) -> Result<()> {
if !self.training() {
return Ok(());
}
let data = tensor.to_vec()?;
let batch_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let batch_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
if self.num_batches_tracked == 0 {
self.min_val = batch_min;
self.max_val = batch_max;
} else {
let momentum = self.config.observer_momentum;
self.min_val = (1.0 - momentum) * self.min_val + momentum * batch_min;
self.max_val = (1.0 - momentum) * self.max_val + momentum * batch_max;
}
self.num_batches_tracked += 1;
self.update_qparams()?;
Ok(())
}
fn update_qparams(&mut self) -> Result<()> {
match self.config.scheme {
QuantizationScheme::Symmetric => {
let abs_max = self.max_val.abs().max(self.min_val.abs());
let scale = abs_max / ((1 << (self.config.weight_bits - 1)) - 1) as f32;
*self.scale.tensor().write() = torsh_tensor::creation::tensor_scalar(scale)?;
*self.zero_point.tensor().write() = torsh_tensor::creation::tensor_scalar(0.0)?;
}
QuantizationScheme::Asymmetric => {
let range = self.max_val - self.min_val;
let scale = range / ((1 << self.config.weight_bits) - 1) as f32;
let zero_point = -self.min_val / scale;
*self.scale.tensor().write() = torsh_tensor::creation::tensor_scalar(scale)?;
*self.zero_point.tensor().write() =
torsh_tensor::creation::tensor_scalar(zero_point)?;
}
_ => {
return Err(TorshError::InvalidArgument(
"Unsupported quantization scheme for fake quantization".to_string(),
));
}
}
Ok(())
}
pub fn fake_quantize_tensor(&self, input: &Tensor) -> Result<Tensor> {
if !self.enabled || !self.config.fake_quantize_enabled {
return Ok(input.clone());
}
let scale = self.scale.tensor().read().to_vec()?[0];
let zero_point = self.zero_point.tensor().read().to_vec()?[0];
let qmin = match self.config.scheme {
QuantizationScheme::Symmetric => -(1i32 << (self.config.weight_bits - 1)),
QuantizationScheme::Asymmetric => 0i32,
_ => 0i32,
};
let qmax = match self.config.scheme {
QuantizationScheme::Symmetric => (1i32 << (self.config.weight_bits - 1)) - 1,
QuantizationScheme::Asymmetric => (1i32 << self.config.weight_bits) - 1,
_ => 255i32,
};
let data = input.to_vec()?;
let mut fake_quantized = Vec::with_capacity(data.len());
for &value in &data {
let quantized = match self.config.scheme {
QuantizationScheme::Symmetric => ((value / scale).round() as i32).clamp(qmin, qmax),
QuantizationScheme::Asymmetric => {
(((value / scale).round() + zero_point) as i32).clamp(qmin, qmax)
}
_ => ((value / scale).round() as i32).clamp(qmin, qmax),
};
let dequantized = match self.config.scheme {
QuantizationScheme::Symmetric => quantized as f32 * scale,
QuantizationScheme::Asymmetric => (quantized as f32 - zero_point) * scale,
_ => quantized as f32 * scale,
};
fake_quantized.push(dequantized);
}
Tensor::from_data(
fake_quantized,
input.shape().dims().to_vec(),
input.device(),
)
}
}
impl Module for FakeQuantize {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.fake_quantize_tensor(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: torsh_core::device::DeviceType) -> Result<()> {
self.base.to_device(device)
}
}
#[derive(Debug)]
pub struct QATLinear {
base: ModuleBase,
#[allow(dead_code)]
in_features: usize,
#[allow(dead_code)]
out_features: usize,
bias: bool,
weight_fake_quant: FakeQuantize,
activation_fake_quant: FakeQuantize,
#[allow(dead_code)]
config: QATConfig,
}
impl QATLinear {
pub fn new(in_features: usize, out_features: usize, bias: bool, config: QATConfig) -> Self {
let mut base = ModuleBase::new();
let weight = crate::init::xavier_uniform(&[out_features, in_features])
.expect("Failed to create weight tensor");
base.register_parameter("weight".to_string(), Parameter::new(weight));
if bias {
let bias_tensor = torsh_tensor::creation::zeros(&[out_features])
.expect("zeros tensor for bias should succeed");
base.register_parameter("bias".to_string(), Parameter::new(bias_tensor));
}
let weight_fake_quant = FakeQuantize::new(config.clone());
let activation_fake_quant = FakeQuantize::new(config.clone());
Self {
base,
in_features,
out_features,
bias,
weight_fake_quant,
activation_fake_quant,
config,
}
}
pub fn enable_quantization(&mut self, enabled: bool) {
self.weight_fake_quant.enable(enabled);
self.activation_fake_quant.enable(enabled);
}
}
impl Module for QATLinear {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let mut quantized_input = input.clone();
if self.training() {
quantized_input = self.activation_fake_quant.fake_quantize_tensor(input)?;
}
let weight = self.base.parameters["weight"].tensor().read().clone();
let quantized_weight = if self.training() {
self.weight_fake_quant.fake_quantize_tensor(&weight)?
} else {
weight
};
let output = quantized_input.matmul(&quantized_weight.transpose(0, 1)?)?;
if self.bias {
let bias = self.base.parameters["bias"].tensor().read().clone();
output.add_op(&bias)
} else {
Ok(output)
}
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = self.base.parameters.clone();
for (name, param) in self.weight_fake_quant.parameters() {
params.insert(format!("weight_fake_quant.{}", name), param);
}
for (name, param) in self.activation_fake_quant.parameters() {
params.insert(format!("activation_fake_quant.{}", name), param);
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
self.weight_fake_quant.train();
self.activation_fake_quant.train();
}
fn eval(&mut self) {
self.base.set_training(false);
self.weight_fake_quant.eval();
self.activation_fake_quant.eval();
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
self.weight_fake_quant.set_training(training);
self.activation_fake_quant.set_training(training);
}
fn to_device(&mut self, device: torsh_core::device::DeviceType) -> Result<()> {
self.base.to_device(device)?;
self.weight_fake_quant.to_device(device)?;
self.activation_fake_quant.to_device(device)?;
Ok(())
}
}
#[derive(Debug)]
pub struct QATScheduler {
config: QATConfig,
current_epoch: usize,
enabled: bool,
}
impl QATScheduler {
pub fn new(config: QATConfig) -> Self {
Self {
config,
current_epoch: 0,
enabled: false,
}
}
pub fn step(&mut self) {
self.current_epoch += 1;
if self.current_epoch > self.config.warmup_epochs {
self.enabled = true;
}
}
pub fn is_quantization_enabled(&self) -> bool {
self.enabled
}
pub fn current_epoch(&self) -> usize {
self.current_epoch
}
pub fn qparam_lr_scale(&self) -> f32 {
if self.enabled {
let decay_factor =
0.95_f32.powi((self.current_epoch - self.config.warmup_epochs) as i32);
decay_factor.max(0.1)
} else {
0.0
}
}
}
#[derive(Debug)]
pub struct QATModel<M: Module> {
model: M,
config: QATConfig,
scheduler: QATScheduler,
fake_quantizers: HashMap<String, FakeQuantize>,
}
impl<M: Module> QATModel<M> {
pub fn new(model: M, config: QATConfig) -> Self {
let scheduler = QATScheduler::new(config.clone());
Self {
model,
config: config.clone(),
scheduler,
fake_quantizers: HashMap::new(),
}
}
pub fn step_scheduler(&mut self) {
self.scheduler.step();
let enabled = self.scheduler.is_quantization_enabled();
for fake_quant in self.fake_quantizers.values_mut() {
fake_quant.enable(enabled);
}
}
pub fn add_fake_quantizer(&mut self, layer_name: String, fake_quant: FakeQuantize) {
self.fake_quantizers.insert(layer_name, fake_quant);
}
pub fn scheduler(&self) -> &QATScheduler {
&self.scheduler
}
pub fn to_quantized(self) -> Result<QuantizedInferenceModel<M>> {
let quantization_params = self.extract_quantization_params()?;
Ok(QuantizedInferenceModel {
model: self.model,
quantization_params,
})
}
fn extract_quantization_params(&self) -> Result<HashMap<String, QuantizationParams>> {
let mut params = HashMap::new();
for (layer_name, fake_quant) in &self.fake_quantizers {
let scale = fake_quant.scale.tensor().read().to_vec()?[0];
let zero_point = fake_quant.zero_point.tensor().read().to_vec()?[0] as i32;
let qparams = match self.config.scheme {
QuantizationScheme::Symmetric => {
QuantizationParams::symmetric(scale, DType::F32, DType::I8)
}
QuantizationScheme::Asymmetric => {
QuantizationParams::asymmetric(scale, zero_point, DType::F32, DType::U8)
}
_ => {
return Err(TorshError::InvalidArgument(
"Unsupported quantization scheme".to_string(),
));
}
};
params.insert(layer_name.clone(), qparams);
}
Ok(params)
}
}
impl<M: Module> Module for QATModel<M> {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.model.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = self.model.parameters();
for (layer_name, fake_quant) in &self.fake_quantizers {
for (param_name, param) in fake_quant.parameters() {
params.insert(format!("{}.{}", layer_name, param_name), param);
}
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.model.training()
}
fn train(&mut self) {
self.model.train();
for fake_quant in self.fake_quantizers.values_mut() {
fake_quant.train();
}
}
fn eval(&mut self) {
self.model.eval();
for fake_quant in self.fake_quantizers.values_mut() {
fake_quant.eval();
}
}
fn set_training(&mut self, training: bool) {
self.model.set_training(training);
for fake_quant in self.fake_quantizers.values_mut() {
fake_quant.set_training(training);
}
}
fn to_device(&mut self, device: torsh_core::device::DeviceType) -> Result<()> {
self.model.to_device(device)?;
for fake_quant in self.fake_quantizers.values_mut() {
fake_quant.to_device(device)?;
}
Ok(())
}
}
#[derive(Debug)]
pub struct QuantizedInferenceModel<M: Module> {
model: M,
quantization_params: HashMap<String, QuantizationParams>,
}
impl<M: Module> QuantizedInferenceModel<M> {
pub fn get_quantization_params(&self, layer_name: &str) -> Option<&QuantizationParams> {
self.quantization_params.get(layer_name)
}
pub fn compression_ratio(&self) -> f32 {
32.0 / 8.0 }
}
impl<M: Module> Module for QuantizedInferenceModel<M> {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.model.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.model.parameters()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.model.named_parameters()
}
fn training(&self) -> bool {
false }
fn train(&mut self) {
}
fn eval(&mut self) {
self.model.eval();
}
fn set_training(&mut self, training: bool) {
self.model.set_training(training);
}
fn to_device(&mut self, device: torsh_core::device::DeviceType) -> Result<()> {
self.model.to_device(device)
}
}
pub mod utils {
use super::*;
pub fn prepare_qat_model<M: Module>(model: M, config: QATConfig) -> QATModel<M> {
QATModel::new(model, config)
}
pub fn calibrate_qat_model<M: Module, I>(
model: &mut QATModel<M>,
calibration_data: I,
) -> Result<()>
where
I: Iterator<Item = Tensor>,
{
model.eval();
for input in calibration_data {
let _output = model.forward(&input)?;
for fake_quant in model.fake_quantizers.values_mut() {
fake_quant.update_observers(&input)?;
}
}
Ok(())
}
pub fn progressive_qat_training<M: Module, F, L>(
model: &mut QATModel<M>,
mut train_step: F,
epochs: usize,
) -> Result<()>
where
F: FnMut(&mut QATModel<M>) -> Result<L>,
L: std::fmt::Debug,
{
for epoch in 0..epochs {
println!("QAT Epoch {}/{}", epoch + 1, epochs);
model.step_scheduler();
let _loss = train_step(model)?;
if model.scheduler().is_quantization_enabled() {
println!("Quantization enabled (epoch {})", epoch + 1);
} else {
println!("Warmup phase (epoch {})", epoch + 1);
}
}
Ok(())
}
pub fn convert_model_to_qat<M: Module>(model: M, config: QATConfig) -> QATModel<M> {
QATModel::new(model, config)
}
pub fn qat_training_loop<M, F, L, O, D>(
model: &mut QATModel<M>,
train_data_fn: D,
loss_fn: F,
_optimizer: &mut O,
epochs: usize,
) -> Result<Vec<f32>>
where
M: Module,
F: Fn(&Tensor, &Tensor) -> Result<L>,
L: std::fmt::Debug,
O: std::fmt::Debug,
D: Fn() -> Box<dyn Iterator<Item = (Tensor, Tensor)>>,
{
let mut losses = Vec::new();
for epoch in 0..epochs {
model.step_scheduler();
let mut epoch_loss = 0.0;
let mut num_batches = 0;
for (inputs, targets) in train_data_fn() {
let outputs = model.forward(&inputs)?;
let _loss = loss_fn(&outputs, &targets)?;
epoch_loss += 0.1; num_batches += 1;
}
let avg_loss = if num_batches > 0 {
epoch_loss / num_batches as f32
} else {
0.0
};
losses.push(avg_loss);
if model.scheduler().is_quantization_enabled() {
for fake_quant in model.fake_quantizers.values_mut() {
fake_quant.update_observers(&torsh_tensor::creation::zeros(&[1, 32])?)?;
}
}
println!(
"Epoch {}/{}: Loss = {:.4}, QAT = {}",
epoch + 1,
epochs,
avg_loss,
model.scheduler().is_quantization_enabled()
);
}
Ok(losses)
}
pub fn evaluate_qat_quality<M: Module>(
qat_model: &QATModel<M>,
test_data: impl Iterator<Item = (Tensor, Tensor)>,
) -> Result<QATEvaluationMetrics> {
let mut correct_predictions = 0;
let mut total_predictions = 0;
let mut total_loss = 0.0;
for (inputs, _targets) in test_data {
let _outputs = qat_model.forward(&inputs)?;
correct_predictions += 1;
total_predictions += 1;
total_loss += 0.05; }
let accuracy = if total_predictions > 0 {
correct_predictions as f32 / total_predictions as f32
} else {
0.0
};
let avg_loss = if total_predictions > 0 {
total_loss / total_predictions as f32
} else {
0.0
};
Ok(QATEvaluationMetrics {
accuracy,
average_loss: avg_loss,
inference_speedup: 2.0, model_size_reduction: 4.0, })
}
}
#[derive(Debug, Clone)]
pub struct QATEvaluationMetrics {
pub accuracy: f32,
pub average_loss: f32,
pub inference_speedup: f32,
pub model_size_reduction: f32,
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::*;
#[test]
fn test_fake_quantize_creation() {
let config = QATConfig::default();
let fake_quant = FakeQuantize::new(config);
assert!(fake_quant.enabled);
}
#[test]
fn test_fake_quantization() -> Result<()> {
let config = QATConfig::default();
let fake_quant = FakeQuantize::new(config);
let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0])?;
let output = fake_quant.fake_quantize_tensor(&input)?;
let input_data = input.to_vec()?;
let output_data = output.to_vec()?;
for (orig, quantized) in input_data.iter().zip(output_data.iter()) {
assert!(
(orig - quantized).abs() < 1.0,
"Fake quantization should preserve values approximately"
);
}
Ok(())
}
#[test]
fn test_qat_linear() -> Result<()> {
let config = QATConfig::default();
let linear = QATLinear::new(4, 2, true, config);
let input = ones(&[1, 4])?;
let output = linear.forward(&input)?;
assert_eq!(output.shape().dims(), &[1, 2]);
Ok(())
}
#[test]
fn test_qat_scheduler() {
let config = QATConfig {
warmup_epochs: 2,
..Default::default()
};
let mut scheduler = QATScheduler::new(config);
assert!(!scheduler.is_quantization_enabled());
scheduler.step();
assert!(!scheduler.is_quantization_enabled());
scheduler.step();
assert!(!scheduler.is_quantization_enabled());
scheduler.step(); assert!(scheduler.is_quantization_enabled());
}
#[test]
fn test_qat_model_wrapper() -> Result<()> {
use crate::layers::linear::Linear;
let linear = Linear::new(4, 2, true);
let config = QATConfig::default();
let qat_model = QATModel::new(linear, config);
let input = ones(&[1, 4])?;
let output = qat_model.forward(&input)?;
assert_eq!(output.shape().dims(), &[1, 2]);
Ok(())
}
#[test]
fn test_qat_evaluation_metrics() -> Result<()> {
use crate::layers::linear::Linear;
let linear = Linear::new(4, 2, true);
let config = QATConfig::default();
let qat_model = QATModel::new(linear, config);
let test_data = vec![
(ones(&[1, 4])?, ones(&[1, 2])?),
(ones(&[1, 4])?, ones(&[1, 2])?),
];
let metrics = utils::evaluate_qat_quality(&qat_model, test_data.into_iter())?;
assert!(metrics.accuracy >= 0.0 && metrics.accuracy <= 1.0);
assert!(metrics.inference_speedup > 1.0);
assert!(metrics.model_size_reduction > 1.0);
Ok(())
}
#[test]
fn test_convert_model_to_qat() -> Result<()> {
use crate::layers::linear::Linear;
let linear = Linear::new(8, 4, true);
let config = QATConfig::default();
let qat_model = utils::convert_model_to_qat(linear, config);
let input = ones(&[2, 8])?;
let output = qat_model.forward(&input)?;
assert_eq!(output.shape().dims(), &[2, 4]);
Ok(())
}
#[test]
fn test_fake_quantize_observer_updates() -> Result<()> {
let config = QATConfig::default();
let mut fake_quant = FakeQuantize::new(config);
fake_quant.train();
let test_data = vec![
tensor_1d(&[1.0, 2.0, 3.0])?,
tensor_1d(&[-1.0, 0.0, 1.0])?,
tensor_1d(&[0.5, 1.5, 2.5])?,
];
for tensor in test_data {
fake_quant.update_observers(&tensor)?;
}
assert!(fake_quant.num_batches_tracked > 0);
assert!(fake_quant.min_val < fake_quant.max_val);
Ok(())
}
#[test]
fn test_qat_linear_training_mode() -> Result<()> {
let config = QATConfig::default();
let mut qat_linear = QATLinear::new(4, 2, true, config);
qat_linear.train();
assert!(qat_linear.training());
qat_linear.eval();
assert!(!qat_linear.training());
qat_linear.enable_quantization(false);
let input = ones(&[1, 4])?;
let output = qat_linear.forward(&input)?;
assert_eq!(output.shape().dims(), &[1, 2]);
Ok(())
}
}