//! Post-training activation quantization: calibration statistics collection,
//! per-layer configuration, and fake-quantize (quantize-dequantize) kernels.

use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Global configuration for activation quantization.
pub struct ActivationQuantConfig {
    /// Default quantization scheme applied to all layers.
    pub scheme: ActivationQuantScheme,
    /// Symmetric (zero-centered) vs. asymmetric (affine) quantization.
    pub symmetric: bool,
    /// Number of calibration updates to observe before quantizing.
    pub calibration_samples: usize,
    /// Clipping percentile used when deriving quantization ranges (e.g. 0.99).
    pub percentile: f32,
    /// Decay factor for the exponential moving averages of min/max.
    pub ema_decay: f32,
    /// Whether to fake-quantize activations during training (QAT-style).
    pub quantize_during_training: bool,
    /// Per-layer overrides, keyed by layer name.
    pub layer_configs: HashMap<String, LayerQuantConfig>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Supported activation quantization schemes.
pub enum ActivationQuantScheme {
    /// 8-bit integer quantization with calibrated ranges.
    Int8,
    /// 16-bit integer quantization with calibrated ranges.
    Int16,
    /// Dynamic-range quantization (currently handled like `Int8`).
    Dynamic,
    /// 8-bit quantization with variance-aware scaling and outlier clipping.
    Adaptive,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Per-layer overrides for the global quantization configuration.
pub struct LayerQuantConfig {
    /// Whether this layer is quantized at all.
    pub enabled: bool,
    /// Scheme override; falls back to the global scheme when `None`.
    pub scheme: Option<ActivationQuantScheme>,
    /// Bit-width override; derived from the scheme when `None`.
    pub bits: Option<u8>,
    /// Whether this layer participates in calibration.
    pub calibrate: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Running statistics for a layer's activations, collected during calibration.
pub struct ActivationStats {
    /// Smallest finite value observed.
    pub min_val: f32,
    /// Largest finite value observed.
    pub max_val: f32,
    /// Sum of all observed values (f64 to limit accumulation error).
    pub sum: f64,
    /// Sum of squared values, used for the variance estimate.
    pub sum_squares: f64,
    /// Number of finite values observed.
    pub count: usize,
    /// Approximate histogram as (bin center, count) pairs.
    pub histogram: Vec<(f32, usize)>,
    /// Exponential moving average of the observed minima.
    pub ema_min: f32,
    /// Exponential moving average of the observed maxima.
    pub ema_max: f32,
}
#[derive(Debug, Clone)]
/// An activation tensor in quantized form, together with the parameters
/// needed to dequantize it.
pub struct QuantizedActivation {
    /// Raw payload: one `u8` per value for 8-bit schemes, or little-endian
    /// `u16` pairs for `Int16`.
    pub data: Vec<u8>,
    /// Dequantization scale: `value = (q - zero_point) * scale`.
    pub scale: f32,
    /// Integer offset that maps to the real value 0.0.
    pub zero_point: i32,
    /// Original tensor shape.
    pub shape: Vec<usize>,
    /// Scheme used to produce `data`.
    pub scheme: ActivationQuantScheme,
    /// Bit width of each quantized value.
    pub bits: u8,
}
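/// Calibrating activation quantizer: collects per-layer statistics, then
/// fake-quantizes activations using the calibrated ranges.
///
/// A minimal usage sketch (hedged: tensor construction follows the tests at
/// the bottom of this module, and error handling is elided):
///
/// ```ignore
/// let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig::default());
/// // 1. Calibration: feed representative activations; tensors pass through
/// //    unchanged while statistics accumulate.
/// quantizer.quantize_activation(&activations, "layer1", false)?;
/// quantizer.end_calibration();
/// // 2. Inference: later calls quantize-dequantize using the collected stats.
/// let quantized = quantizer.quantize_activation(&activations, "layer1", false)?;
/// ```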
pub struct ActivationQuantizer {
    config: ActivationQuantConfig,
    /// Per-layer calibration statistics, keyed by layer name.
    layer_stats: HashMap<String, ActivationStats>,
    /// Whether statistics are still being collected.
    calibrating: bool,
    /// Number of statistics updates seen so far (shared across layers).
    calibration_count: usize,
}
impl Default for ActivationQuantConfig {
fn default() -> Self {
Self {
scheme: ActivationQuantScheme::Int8,
symmetric: false,
calibration_samples: 100,
percentile: 0.99,
ema_decay: 0.01,
quantize_during_training: false,
layer_configs: HashMap::new(),
}
}
}
impl Default for LayerQuantConfig {
fn default() -> Self {
Self {
enabled: true,
scheme: None,
bits: None,
calibrate: true,
}
}
}
impl Default for ActivationStats {
fn default() -> Self {
Self::new()
}
}
impl ActivationStats {
pub fn new() -> Self {
Self {
min_val: f32::INFINITY,
max_val: f32::NEG_INFINITY,
sum: 0.0,
sum_squares: 0.0,
count: 0,
histogram: Vec::new(),
ema_min: f32::INFINITY,
ema_max: f32::NEG_INFINITY,
}
}
    /// Accumulates statistics from a tensor, updating running min/max, the
    /// moment sums, EMA bounds, and an approximate value histogram.
    pub fn update(&mut self, tensor: &Tensor, ema_decay: f32) -> Result<()> {
        match tensor {
            Tensor::F32(arr) => {
                for &val in arr.iter() {
                    // Skip NaN/Inf so they cannot poison the ranges.
                    if !val.is_finite() {
                        continue;
                    }
                    self.min_val = self.min_val.min(val);
                    self.max_val = self.max_val.max(val);
                    self.sum += val as f64;
                    self.sum_squares += (val * val) as f64;
                    self.count += 1;
                    if self.ema_min.is_infinite() {
                        self.ema_min = val;
                        self.ema_max = val;
                    } else {
                        if val < self.ema_min {
                            self.ema_min = self.ema_min * (1.0 - ema_decay) + val * ema_decay;
                        }
                        if val > self.ema_max {
                            self.ema_max = self.ema_max * (1.0 - ema_decay) + val * ema_decay;
                        }
                    }
                }
                // Approximate histogram over [min_val, max_val]. Bins are laid
                // out against the *current* range, so counts accumulated before
                // the range widened are an approximation.
                let num_bins = 1000;
                let range = self.max_val - self.min_val;
                if range > 0.0 {
                    self.histogram.resize(num_bins, (0.0, 0));
                    for &val in arr.iter() {
                        if val.is_finite() {
                            let bin_idx =
                                ((val - self.min_val) / range * (num_bins - 1) as f32) as usize;
                            let bin_idx = bin_idx.min(num_bins - 1);
                            // Store the bin center as the representative value
                            // so percentile lookups are order-independent.
                            let bin_center =
                                self.min_val + (bin_idx as f32 + 0.5) / num_bins as f32 * range;
                            self.histogram[bin_idx].0 = bin_center;
                            self.histogram[bin_idx].1 += 1;
                        }
                    }
                }
            },
            _ => {
                return Err(TrustformersError::quantization_error(
                    "Unsupported tensor type for activation quantization".into(),
                ))
            },
        }
        Ok(())
    }
    /// Arithmetic mean of all observed values.
    pub fn mean(&self) -> f32 {
        if self.count == 0 {
            0.0
        } else {
            (self.sum / self.count as f64) as f32
        }
    }

    /// Population variance of all observed values.
    pub fn variance(&self) -> f32 {
        if self.count <= 1 {
            0.0
        } else {
            let mean = self.mean() as f64;
            let variance = (self.sum_squares / self.count as f64) - (mean * mean);
            variance.max(0.0) as f32
        }
    }

    /// Approximate `p`-quantile (with `p` in 0.0..=1.0) read off the
    /// histogram; falls back to `max_val` when no histogram is available.
    pub fn percentile(&self, p: f32) -> f32 {
        if self.histogram.is_empty() || self.count == 0 {
            return self.max_val;
        }
        let target_count = (self.count as f32 * p) as usize;
        let mut cumulative_count = 0;
        for &(val, count) in &self.histogram {
            cumulative_count += count;
            if cumulative_count >= target_count {
                return val;
            }
        }
        self.max_val
    }
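    /// Derives affine quantization parameters from the collected statistics.
    ///
    /// Asymmetric: `scale = (max - min) / (q_max - q_min)` and
    /// `zero_point = q_min - round(min / scale)`. Symmetric: the range is
    /// centered at zero, `scale = abs_max / q_max`, `zero_point = 0`.
    ///
    /// A worked 8-bit asymmetric example (illustrative values only):
    ///
    /// ```ignore
    /// // min = -1.0, max = 3.0, q_min = 0, q_max = 255
    /// // scale = (3.0 - (-1.0)) / 255.0 ≈ 0.0157
    /// // zero_point = 0 - (-1.0 / 0.0157).round() as i32 = 64
    /// ```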
    pub fn get_quantization_params(
        &self,
        symmetric: bool,
        bits: u8,
        percentile: f32,
    ) -> Result<(f32, i32)> {
        if self.count == 0 {
            return Err(TrustformersError::quantization_error(
                "No statistics available for quantization".into(),
            ));
        }
        let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
        let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
        // Optionally clip the range to percentiles to suppress outliers.
        let min_val = if percentile < 1.0 {
            self.percentile(1.0 - percentile)
        } else {
            self.min_val
        };
        let max_val = if percentile < 1.0 { self.percentile(percentile) } else { self.max_val };
        let (scale, zero_point) = if symmetric {
            let abs_max = max_val.abs().max(min_val.abs());
            if abs_max == 0.0 {
                return Ok((1.0, 0));
            }
            // Symmetric: map [-abs_max, abs_max] onto [q_min, q_max] with a
            // zero point of 0, so the scale is abs_max / q_max.
            let scale = abs_max / q_max as f32;
            (scale, 0)
        } else {
            if max_val == min_val {
                return Ok((1.0, q_min));
            }
            let scale = (max_val - min_val) / (q_max - q_min) as f32;
            let zero_point = q_min - (min_val / scale).round() as i32;
            let zero_point = zero_point.clamp(q_min, q_max);
            (scale, zero_point)
        };
        Ok((scale, zero_point))
    }
}
impl QuantizedActivation {
pub fn new(
data: Vec<u8>,
scale: f32,
zero_point: i32,
shape: Vec<usize>,
scheme: ActivationQuantScheme,
bits: u8,
) -> Self {
Self {
data,
scale,
zero_point,
shape,
scheme,
bits,
}
}
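    /// Reconstructs an f32 tensor from the quantized payload using
    /// `value = (q - zero_point) * scale`.
    ///
    /// A small sketch (hedged: illustrative numbers, not a fixed contract):
    ///
    /// ```ignore
    /// // q = 128, zero_point = 0, scale = 4.0 / 255.0  →  value ≈ 2.008
    /// let tensor = quantized_activation.dequantize()?;
    /// ```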
    pub fn dequantize(&self) -> Result<Tensor> {
        let total_elements: usize = self.shape.iter().product();
        let mut result = Vec::with_capacity(total_elements);
        match self.scheme {
            // All 8-bit schemes store one u8 per value.
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => {
                for &quantized_val in &self.data {
                    let int_val = quantized_val as i32 - self.zero_point;
                    result.push(int_val as f32 * self.scale);
                }
            },
            // Int16 stores two little-endian bytes per value.
            ActivationQuantScheme::Int16 => {
                for chunk in self.data.chunks(2) {
                    if chunk.len() == 2 {
                        let int_val =
                            u16::from_le_bytes([chunk[0], chunk[1]]) as i32 - self.zero_point;
                        result.push(int_val as f32 * self.scale);
                    }
                }
            },
        }
        Tensor::from_vec(result, &self.shape)
    }
}
impl ActivationQuantizer {
    /// Creates a new quantizer; calibration mode starts enabled.
    pub fn new(config: ActivationQuantConfig) -> Self {
        Self {
            config,
            layer_stats: HashMap::new(),
            calibrating: true,
            calibration_count: 0,
        }
    }

    /// Clears all statistics and re-enters calibration mode.
    pub fn start_calibration(&mut self) {
        self.calibrating = true;
        self.calibration_count = 0;
        self.layer_stats.clear();
    }

    /// Leaves calibration mode; later calls quantize using the collected stats.
    pub fn end_calibration(&mut self) {
        self.calibrating = false;
    }

    /// True once calibration has ended or enough samples have been observed.
    pub fn is_calibration_complete(&self) -> bool {
        !self.calibrating || self.calibration_count >= self.config.calibration_samples
    }
    /// Quantizes (fake-quantizes) an activation tensor for the given layer.
    ///
    /// During calibration, statistics are collected and the tensor is passed
    /// through unchanged until `calibration_samples` updates have been seen.
    pub fn quantize_activation(
        &mut self,
        tensor: &Tensor,
        layer_name: &str,
        training: bool,
    ) -> Result<Tensor> {
        let layer_config = self.config.layer_configs.get(layer_name).cloned().unwrap_or_default();
        if !layer_config.enabled {
            return Ok(tensor.clone());
        }
        // During training, only collect statistics unless QAT-style
        // quantization is explicitly enabled.
        if training && !self.config.quantize_during_training {
            if self.calibrating && layer_config.calibrate {
                self.update_statistics(tensor, layer_name)?;
            }
            return Ok(tensor.clone());
        }
        if self.calibrating && layer_config.calibrate {
            self.update_statistics(tensor, layer_name)?;
            if self.calibration_count < self.config.calibration_samples {
                return Ok(tensor.clone());
            }
        }
        self.apply_quantization(tensor, layer_name, &layer_config)
    }
    /// Records statistics for one layer. Note that `calibration_count` is a
    /// single counter shared across all layers.
    fn update_statistics(&mut self, tensor: &Tensor, layer_name: &str) -> Result<()> {
        let stats = self.layer_stats.entry(layer_name.to_string()).or_default();
        stats.update(tensor, self.config.ema_decay)?;
        self.calibration_count += 1;
        Ok(())
    }

    /// Resolves the effective scheme and bit width for a layer and dispatches
    /// to the matching fake-quantization kernel.
    fn apply_quantization(
        &self,
        tensor: &Tensor,
        layer_name: &str,
        layer_config: &LayerQuantConfig,
    ) -> Result<Tensor> {
        let stats = self.layer_stats.get(layer_name).ok_or_else(|| {
            TrustformersError::quantization_error(format!(
                "No calibration statistics found for layer {}",
                layer_name
            ))
        })?;
        let scheme = layer_config.scheme.unwrap_or(self.config.scheme);
        let bits = layer_config.bits.unwrap_or(match scheme {
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => 8,
            ActivationQuantScheme::Int16 => 16,
        });
        let (scale, zero_point) =
            stats.get_quantization_params(self.config.symmetric, bits, self.config.percentile)?;
        match scheme {
            ActivationQuantScheme::Int8 | ActivationQuantScheme::Dynamic => {
                self.quantize_int8(tensor, scale, zero_point)
            },
            ActivationQuantScheme::Int16 => self.quantize_int16(tensor, scale, zero_point),
            ActivationQuantScheme::Adaptive => {
                self.quantize_adaptive(tensor, stats, scale, zero_point)
            },
        }
    }
    /// Fake-quantizes to 8-bit: rounds onto the integer grid, clamps to the
    /// valid range, and dequantizes back to f32.
    fn quantize_int8(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                // The integer range must match get_quantization_params:
                // [-128, 127] for symmetric, [0, 255] for asymmetric.
                let (q_lo, q_hi) = if self.config.symmetric { (-128, 127) } else { (0, 255) };
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        let q_val = ((val / scale).round() as i32 + zero_point).clamp(q_lo, q_hi);
                        (q_val - zero_point) as f32 * scale
                    })
                    .collect();
                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for activation quantization".into(),
            )),
        }
    }
    /// Fake-quantizes to 16-bit: rounds onto the integer grid, clamps to the
    /// valid range, and dequantizes back to f32.
    fn quantize_int16(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                // Range must match get_quantization_params: [-32768, 32767]
                // for symmetric, [0, 65535] for asymmetric.
                let (q_lo, q_hi) =
                    if self.config.symmetric { (-32768, 32767) } else { (0, 65535) };
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        let q_val =
                            ((val / scale).round() as i32 + zero_point).clamp(q_lo, q_hi);
                        (q_val - zero_point) as f32 * scale
                    })
                    .collect();
                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for activation quantization".into(),
            )),
        }
    }
    /// Adaptive fake quantization: clips outliers beyond three standard
    /// deviations and tightens the scale for low-variance activations.
    fn quantize_adaptive(
        &self,
        tensor: &Tensor,
        stats: &ActivationStats,
        scale: f32,
        zero_point: i32,
    ) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let variance = stats.variance();
                let mean = stats.mean();
                // Heuristic: low-variance activations get a finer grid.
                let effective_scale = if variance < 0.1 { scale * 0.5 } else { scale };
                let clip_radius = 3.0 * variance.sqrt();
                let (q_lo, q_hi) = if self.config.symmetric { (-128, 127) } else { (0, 255) };
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        // Clip outliers to mean ± 3σ before quantizing.
                        let clipped_val = val.clamp(mean - clip_radius, mean + clip_radius);
                        let q_val = ((clipped_val / effective_scale).round() as i32 + zero_point)
                            .clamp(q_lo, q_hi);
                        (q_val - zero_point) as f32 * effective_scale
                    })
                    .collect();
                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for adaptive quantization".into(),
            )),
        }
    }
    /// Returns collected statistics for a single layer, if any.
    pub fn get_layer_stats(&self, layer_name: &str) -> Option<&ActivationStats> {
        self.layer_stats.get(layer_name)
    }

    /// Returns statistics for all layers seen so far.
    pub fn get_all_stats(&self) -> &HashMap<String, ActivationStats> {
        &self.layer_stats
    }
    /// Serializes all collected calibration statistics to a JSON file.
    pub fn save_calibration(&self, path: &str) -> Result<()> {
        let json_data = serde_json::to_string_pretty(&self.layer_stats).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to serialize statistics: {}", e))
        })?;
        std::fs::write(path, json_data).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to write file: {}", e))
        })?;
        Ok(())
    }

    /// Loads calibration statistics from a JSON file and leaves calibration mode.
    pub fn load_calibration(&mut self, path: &str) -> Result<()> {
        let json_data = std::fs::read_to_string(path).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to read file: {}", e))
        })?;
        self.layer_stats = serde_json::from_str(&json_data).map_err(|e| {
            TrustformersError::quantization_error(format!(
                "Failed to deserialize statistics: {}",
                e
            ))
        })?;
        self.calibrating = false;
        Ok(())
    }
    /// Sets or replaces the per-layer configuration for `layer_name`.
    pub fn configure_layer(&mut self, layer_name: &str, config: LayerQuantConfig) {
        self.config.layer_configs.insert(layer_name.to_string(), config);
    }

    /// Disables quantization for `layer_name`; tensors pass through unchanged.
    pub fn disable_layer(&mut self, layer_name: &str) {
        let config = LayerQuantConfig {
            enabled: false,
            ..Default::default()
        };
        self.config.layer_configs.insert(layer_name.to_string(), config);
    }
    /// Fraction of activation memory saved relative to f32 storage.
    pub fn get_memory_savings(&self) -> f32 {
        match self.config.scheme {
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => 0.75,
            ActivationQuantScheme::Int16 => 0.5,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_activation_stats_update() {
let mut stats = ActivationStats::new();
let tensor =
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).expect("Tensor from_vec failed");
stats.update(&tensor, 0.01).expect("tensor operation failed");
assert_eq!(stats.count, 5);
assert_eq!(stats.min_val, 1.0);
assert_eq!(stats.max_val, 5.0);
assert_eq!(stats.mean(), 3.0);
}
#[test]
fn test_activation_stats_quantization_params() {
let mut stats = ActivationStats::new();
let tensor = Tensor::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5])
.expect("Tensor from_vec failed");
stats.update(&tensor, 0.01).expect("tensor operation failed");
let (scale, zero_point) =
stats.get_quantization_params(true, 8, 1.0).expect("operation failed in test");
assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
    }
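    #[test]
    fn test_activation_stats_asymmetric_params() {
        // Hedged sketch: exercises the asymmetric branch of
        // get_quantization_params; min_val = 0.0 should map to q_min = 0.
        let mut stats = ActivationStats::new();
        let tensor = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5])
            .expect("Tensor from_vec failed");
        stats.update(&tensor, 0.01).expect("tensor operation failed");
        let (scale, zero_point) =
            stats.get_quantization_params(false, 8, 1.0).expect("operation failed in test");
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
    }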
#[test]
fn test_activation_quantizer_calibration() {
let config = ActivationQuantConfig {
calibration_samples: 2,
..Default::default()
};
let mut quantizer = ActivationQuantizer::new(config);
let tensor1 = Tensor::randn(&[10, 20]).expect("Failed to create random tensor");
let tensor2 = Tensor::randn(&[10, 20]).expect("Failed to create random tensor");
assert!(quantizer.calibrating);
quantizer
.quantize_activation(&tensor1, "layer1", false)
.expect("tensor operation failed");
quantizer
.quantize_activation(&tensor2, "layer1", false)
.expect("tensor operation failed");
assert!(quantizer.get_layer_stats("layer1").is_some());
}
#[test]
fn test_activation_quantizer_int8() {
let config = ActivationQuantConfig {
calibration_samples: 1,
scheme: ActivationQuantScheme::Int8,
..Default::default()
};
let mut quantizer = ActivationQuantizer::new(config);
let tensor =
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");
quantizer
.quantize_activation(&tensor, "test_layer", false)
.expect("tensor operation failed");
quantizer.end_calibration();
let result = quantizer
.quantize_activation(&tensor, "test_layer", false)
.expect("tensor operation failed");
assert_eq!(result.shape(), tensor.shape());
}
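    #[test]
    fn test_activation_quantizer_int8_symmetric() {
        // Hedged sketch: exercises the symmetric path, where negative
        // activations must survive fake quantization (zero_point = 0,
        // integer range [-128, 127]).
        let config = ActivationQuantConfig {
            calibration_samples: 1,
            symmetric: true,
            ..Default::default()
        };
        let mut quantizer = ActivationQuantizer::new(config);
        let tensor =
            Tensor::from_vec(vec![-2.0, -1.0, 1.0, 2.0], &[4]).expect("Tensor from_vec failed");
        quantizer
            .quantize_activation(&tensor, "sym_layer", false)
            .expect("tensor operation failed");
        quantizer.end_calibration();
        let result = quantizer
            .quantize_activation(&tensor, "sym_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), tensor.shape());
    }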
#[test]
fn test_activation_quantizer_layer_config() {
let config = ActivationQuantConfig::default();
let mut quantizer = ActivationQuantizer::new(config);
let layer_config = LayerQuantConfig {
enabled: true,
scheme: Some(ActivationQuantScheme::Int16),
bits: Some(16),
calibrate: true,
};
quantizer.configure_layer("special_layer", layer_config);
quantizer.disable_layer("disabled_layer");
let tensor = Tensor::randn(&[8, 8]).expect("Failed to create random tensor");
let result = quantizer
.quantize_activation(&tensor, "disabled_layer", false)
.expect("tensor operation failed");
assert_eq!(result.shape(), tensor.shape());
}
#[test]
fn test_activation_quantizer_adaptive() {
let config = ActivationQuantConfig {
scheme: ActivationQuantScheme::Adaptive,
calibration_samples: 1,
..Default::default()
};
let mut quantizer = ActivationQuantizer::new(config);
let tensor = Tensor::from_vec(vec![0.1, 0.2, 0.15, 0.18, 10.0], &[5])
.expect("Tensor from_vec failed");
quantizer
.quantize_activation(&tensor, "adaptive_layer", false)
.expect("tensor operation failed");
quantizer.end_calibration();
let result = quantizer
.quantize_activation(&tensor, "adaptive_layer", false)
.expect("tensor operation failed");
assert_eq!(result.shape(), tensor.shape());
}
#[test]
fn test_quantized_activation_dequantization() {
        // Quantized counterparts of [1.0, 2.0, 3.0, 4.0] under scale = 4/255.
        let shape = vec![4];
        let quantized_data = vec![64, 128, 192, 255];
        let scale = 4.0 / 255.0;
        let zero_point = 0;
let quant_activation = QuantizedActivation::new(
quantized_data,
scale,
zero_point,
shape.clone(),
ActivationQuantScheme::Int8,
8,
);
let dequantized = quant_activation.dequantize().expect("Dequantization failed");
assert_eq!(dequantized.shape(), shape);
}
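    #[test]
    fn test_quantized_activation_int16_dequantization() {
        // Hedged sketch: Int16 payloads are two little-endian bytes per value,
        // matching the decoding in dequantize().
        let values: Vec<u16> = vec![0, 1000, 65535];
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let shape = vec![3];
        let quant_activation = QuantizedActivation::new(
            data,
            0.001,
            0,
            shape.clone(),
            ActivationQuantScheme::Int16,
            16,
        );
        let dequantized = quant_activation.dequantize().expect("Dequantization failed");
        assert_eq!(dequantized.shape(), shape);
    }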
#[test]
fn test_memory_savings_calculation() {
let config = ActivationQuantConfig {
scheme: ActivationQuantScheme::Int8,
..Default::default()
};
let quantizer = ActivationQuantizer::new(config);
let savings = quantizer.get_memory_savings();
        assert_eq!(savings, 0.75);
    }
#[test]
fn test_percentile_calculation() {
let mut stats = ActivationStats::new();
let tensor = Tensor::from_vec((1..=100).map(|x| x as f32).collect(), &[100])
.expect("tensor operation failed");
stats.update(&tensor, 0.01).expect("tensor operation failed");
let p95 = stats.percentile(0.95);
        assert!((90.0..=100.0).contains(&p95));
    }
#[test]
fn test_serialization() {
let config = ActivationQuantConfig::default();
let serialized = serde_json::to_string(&config).expect("JSON serialization failed");
let deserialized: ActivationQuantConfig =
serde_json::from_str(&serialized).expect("JSON deserialization failed");
assert_eq!(config.scheme, deserialized.scheme);
assert_eq!(config.symmetric, deserialized.symmetric);
}
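    #[test]
    fn test_calibration_save_load_roundtrip() {
        // Hedged sketch: round-trips statistics through a temp file; the file
        // name is test-local and not part of the public API.
        let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig::default());
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Tensor from_vec failed");
        quantizer
            .quantize_activation(&tensor, "layer1", false)
            .expect("tensor operation failed");
        let path = std::env::temp_dir().join("activation_quant_calibration_test.json");
        let path_str = path.to_str().expect("temp path is not valid UTF-8");
        quantizer.save_calibration(path_str).expect("save_calibration failed");
        let mut restored = ActivationQuantizer::new(ActivationQuantConfig::default());
        restored.load_calibration(path_str).expect("load_calibration failed");
        assert!(restored.get_layer_stats("layer1").is_some());
        let _ = std::fs::remove_file(&path);
    }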
}