use std::{fs, path::Path};
use serde::{Deserialize, Serialize};
/// Weight and activation range statistics for a single layer.
///
/// Within this file, values of this type are produced by
/// `Calibrator::record_layer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerCalibration {
/// Layer identifier, stored exactly as supplied by the caller.
pub name: String,
/// Lower end of the recorded weight range.
pub w_min: f32,
/// Upper end of the recorded weight range.
pub w_max: f32,
/// Lower end of the activation range. `Calibrator::record_layer` fills in
/// the constant `-1.0`; no activation data is measured in this file.
pub act_min: f32,
/// Upper end of the activation range. `Calibrator::record_layer` fills in
/// the constant `1.0`; no activation data is measured in this file.
pub act_max: f32,
}
impl LayerCalibration {
    /// Returns the symmetric quantization step for this layer's weights.
    ///
    /// The step is the larger of `|w_min|` and `|w_max|` divided by 127, so
    /// that the largest recorded magnitude maps onto the code ±127. No lower
    /// bound is applied: an all-zero range yields a scale of `0.0`.
    pub fn weight_scale(&self) -> f32 {
        let largest_magnitude = f32::max(self.w_min.abs(), self.w_max.abs());
        largest_magnitude / 127.0
    }
}
/// One tensor's weights after symmetric quantization, plus the metadata
/// needed to map them back to `f32`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedWeights {
/// Layer or tensor name, stored exactly as supplied by the caller.
pub name: String,
/// Quantized weight codes. Note the element type is `i32` even though the
/// field is named `_i8`; `from_f32` clamps every code to `[-127, 127]`.
pub weights_i8: Vec<i32>,
/// Tensor shape as supplied by the caller. Nothing in this file checks it
/// against `weights_i8.len()`.
pub shape: Vec<usize>,
/// Step size used by `dequantize`: each code is multiplied by this value.
pub scale: f32,
/// Optional bias values. `from_f32` stores them unchanged (not quantized).
pub bias: Option<Vec<f32>>,
}
impl QuantizedWeights {
    /// Quantizes `w` symmetrically into integer codes in `[-127, 127]`.
    ///
    /// * `name`, `shape`, `bias` are stored on the result unchanged.
    /// * `scale`: when `Some`, that step size is used; when `None`, the step
    ///   is derived as `max(|w|) / 127` over the finite elements of `w`.
    ///
    /// In both cases the step is floored at `1e-8` so the division below can
    /// never be by zero (this also replaces a zero, negative or NaN explicit
    /// scale with `1e-8`).
    ///
    /// Non-finite weights are ignored when deriving the scale; otherwise a
    /// single `±inf` would make the scale infinite, collapse every code to 0
    /// and make `dequantize` return NaN. With a finite scale, `±inf` saturates
    /// to `±127` and NaN becomes code `0`.
    pub fn from_f32(
        name: String,
        w: &[f32],
        shape: Vec<usize>,
        bias: Option<Vec<f32>>,
        scale: Option<f32>,
    ) -> Self {
        let s = scale
            // Only scan the weights when no explicit scale was given.
            .unwrap_or_else(|| {
                let max_abs = w
                    .iter()
                    .copied()
                    .filter(|x| x.is_finite())
                    .fold(0.0f32, |acc, x| acc.max(x.abs()));
                max_abs / 127.0
            })
            .max(1e-8);
        let weights_i8: Vec<i32> = w
            .iter()
            // Clamp before the cast keeps codes in the int8 symmetric range;
            // NaN passes through round/clamp and the `as` cast turns it into 0.
            .map(|&x| (x / s).round().clamp(-127.0, 127.0) as i32)
            .collect();
        Self { name, weights_i8, shape, scale: s, bias }
    }

    /// Maps every stored code back to `f32` by multiplying it with `scale`.
    pub fn dequantize(&self) -> Vec<f32> {
        self.weights_i8
            .iter()
            .map(|&q| q as f32 * self.scale)
            .collect()
    }

    /// Nominal quantized size: one byte per weight.
    ///
    /// This is the int8 footprint, not the in-memory size of the
    /// `Vec<i32>` that backs `weights_i8`.
    pub fn size_bytes(&self) -> usize {
        self.weights_i8.len()
    }

    /// Size the same weights would occupy as `f32`: four bytes per weight.
    pub fn original_size_bytes(&self) -> usize {
        self.weights_i8.len() * 4
    }

    /// Ratio of `original_size_bytes` to `size_bytes`.
    ///
    /// The denominator is floored at 1 (as `QuantizedModel::compression_ratio`
    /// does), so an empty tensor yields `0.0` instead of NaN from `0 / 0`.
    pub fn compression_ratio(&self) -> f32 {
        self.original_size_bytes() as f32 / self.size_bytes().max(1) as f32
    }
}
/// A set of quantized layers together with size bookkeeping, serializable
/// through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedModel {
/// Model configuration text, stored verbatim; it is not parsed anywhere in
/// this file (presumably JSON, going by the name).
pub config_json: String,
/// Quantized tensors, in the order they were supplied.
pub layers: Vec<QuantizedWeights>,
/// Label for the quantization scheme. `quantize_model_weights` sets it to
/// `"symmetric_per_tensor_int8"`.
pub scheme: String,
/// Sum of `QuantizedWeights::size_bytes` over the layers when built by
/// `quantize_model_weights`.
pub total_quantized_bytes: usize,
/// Sum of `QuantizedWeights::original_size_bytes` over the layers when
/// built by `quantize_model_weights`.
pub total_fp32_bytes: usize,
}
impl QuantizedModel {
    /// Ratio of the recorded fp32 byte total to the recorded quantized byte
    /// total. The denominator is floored at 1, so a model with no quantized
    /// bytes does not divide by zero.
    pub fn compression_ratio(&self) -> f32 {
        let quantized_bytes = self.total_quantized_bytes.max(1);
        let fp32_bytes = self.total_fp32_bytes;
        fp32_bytes as f32 / quantized_bytes as f32
    }

    /// Serializes the model as pretty-printed JSON and writes it to `path`.
    ///
    /// # Errors
    /// Propagates serialization and file-write failures through `?`.
    pub fn save(&self, path: &Path) -> crate::error::Result<()> {
        fs::write(path, serde_json::to_string_pretty(self)?)?;
        Ok(())
    }

    /// Reads `path` as UTF-8 text and deserializes a model from its JSON.
    ///
    /// # Errors
    /// Propagates file-read and deserialization failures through `?`.
    pub fn load(path: &Path) -> crate::error::Result<Self> {
        let contents = fs::read_to_string(path)?;
        Ok(serde_json::from_str(&contents)?)
    }
}
/// Accumulates one `LayerCalibration` per recorded layer.
pub struct Calibrator {
// Entries in the order `record_layer` was called; handed out by `finish`.
calibrations: Vec<LayerCalibration>,
}
impl Calibrator {
    /// Creates a calibrator with no recorded layers.
    pub fn new() -> Self {
        Self { calibrations: Vec::new() }
    }

    /// Records the min/max weight range of one layer.
    ///
    /// NaN elements are skipped by `f32::min` / `f32::max`. If `weights` is
    /// empty or contains only NaN, there is no range to record and both
    /// bounds are stored as `0.0`, rather than the inverted `+inf` / `-inf`
    /// fold seeds, which would give an infinite `weight_scale`.
    ///
    /// The activation range is not measured here; it is filled with the
    /// fixed placeholder `[-1.0, 1.0]`.
    pub fn record_layer(&mut self, name: String, weights: &[f32]) {
        let (lo, hi) = weights.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &w| (lo.min(w), hi.max(w)),
        );
        // `lo > hi` can only hold when the seeds were never replaced, i.e. no
        // comparable value was seen.
        let (w_min, w_max) = if lo > hi { (0.0, 0.0) } else { (lo, hi) };
        self.calibrations.push(LayerCalibration {
            name,
            w_min,
            w_max,
            act_min: -1.0,
            act_max: 1.0,
        });
    }

    /// Consumes the calibrator and returns the entries in recording order.
    pub fn finish(self) -> Vec<LayerCalibration> {
        self.calibrations
    }
}
impl Default for Calibrator {
fn default() -> Self {
Self::new()
}
}
pub fn quantize_model_weights(
config_json: String,
layers: impl IntoIterator<Item = (String, Vec<f32>, Vec<usize>, Option<Vec<f32>>)>,
) -> QuantizedModel {
let mut quantized_layers = Vec::new();
let mut total_fp32_bytes = 0usize;
let mut total_quantized_bytes = 0usize;
for (name, weights, shape, bias) in layers {
let qw = QuantizedWeights::from_f32(name, &weights, shape, bias, None);
total_fp32_bytes += qw.original_size_bytes();
total_quantized_bytes += qw.size_bytes();
quantized_layers.push(qw);
}
QuantizedModel {
config_json,
layers: quantized_layers,
scheme: "symmetric_per_tensor_int8".to_string(),
total_quantized_bytes,
total_fp32_bytes,
}
}
/// Converts each `f32` to `half::f16` and returns the raw 16-bit patterns,
/// one per input element and in the same order.
pub fn fp32_to_fp16_bits(weights: &[f32]) -> Vec<u16> {
    let mut bits = Vec::with_capacity(weights.len());
    for &value in weights {
        bits.push(half::f16::from_f32(value).to_bits());
    }
    bits
}
/// Interprets each `u16` as the bit pattern of a `half::f16` and widens it to
/// `f32`, one output per input element and in the same order.
pub fn fp16_bits_to_fp32(bits: &[u16]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bits.len());
    for &pattern in bits {
        values.push(half::f16::from_bits(pattern).to_f32());
    }
    values
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantize then dequantize: the worst-case error must stay within half
    /// a quantization step (plus a small float tolerance).
    #[test]
    fn test_quantize_roundtrip() {
        let weights: Vec<f32> = (0..100).map(|i| i as f32 * 0.01 - 0.5).collect();
        let qw = QuantizedWeights::from_f32("test".to_string(), &weights, vec![10, 10], None, None);
        let restored = qw.dequantize();
        assert_eq!(restored.len(), weights.len());

        let mut max_err = 0.0f32;
        for (original, recovered) in weights.iter().zip(restored.iter()) {
            max_err = max_err.max((original - recovered).abs());
        }
        assert!(
            max_err <= qw.scale / 2.0 + 1e-6,
            "Max quant error {max_err} > scale/2 = {}",
            qw.scale / 2.0
        );
    }

    /// One nominal byte per weight versus four for fp32 gives a ratio of 4.
    #[test]
    fn test_compression_ratio() {
        let weights: Vec<f32> = vec![1.0f32; 1024];
        let qw = QuantizedWeights::from_f32("test".into(), &weights, vec![32, 32], None, None);
        let ratio = qw.compression_ratio();
        assert!((ratio - 4.0).abs() < 1e-5, "INT8 should compress ~4x vs FP32");
    }

    /// A model written with `save` must come back unchanged through `load`.
    #[test]
    fn test_save_load_roundtrip() {
        let layer = QuantizedWeights {
            name: "l1".into(),
            weights_i8: vec![1, -2, 3],
            shape: vec![1, 3],
            scale: 0.01,
            bias: None,
        };
        let qm = QuantizedModel {
            config_json: "{}".into(),
            layers: vec![layer],
            scheme: "symmetric_per_tensor_int8".into(),
            total_quantized_bytes: 3,
            total_fp32_bytes: 12,
        };

        let file = tempfile::NamedTempFile::new().unwrap();
        qm.save(file.path()).unwrap();
        let loaded = QuantizedModel::load(file.path()).unwrap();

        let first = &loaded.layers[0];
        assert_eq!(first.name, "l1");
        assert_eq!(first.weights_i8, vec![1, -2, 3]);
    }
}