use super::config::{CalibrationMethodConfig, TernaryConfig};
use super::types::TernaryTensor;
use crate::error::{Result, UnslothError};
use candle_core::{DType, Tensor};
/// Strategy used to pick the per-row ternarization threshold: weights with
/// |w| <= threshold quantize to zero, the rest to +/- the row scale.
#[derive(Debug, Clone, Copy)]
pub enum CalibrationMethod {
    /// threshold = `factor` * max(|w|) over the row.
    AbsMax {
        factor: f32,
    },
    /// threshold = the `percentile`-th percentile (0..=100) of |w| in the row.
    Percentile {
        percentile: f32,
    },
    /// threshold = mean(|w|) + `k` * std(|w|) over the row.
    MeanStd {
        k: f32,
    },
    /// Fixed, user-supplied threshold applied to every row.
    Manual {
        threshold: f32,
    },
}
impl Default for CalibrationMethod {
fn default() -> Self {
Self::AbsMax { factor: 0.7 }
}
}
impl From<CalibrationMethodConfig> for CalibrationMethod {
fn from(config: CalibrationMethodConfig) -> Self {
match config {
CalibrationMethodConfig::AbsMax => Self::AbsMax { factor: 0.7 },
CalibrationMethodConfig::Percentile(p) => Self::Percentile { percentile: p },
CalibrationMethodConfig::MeanStd(k) => Self::MeanStd { k },
CalibrationMethodConfig::Manual(t) => Self::Manual { threshold: t },
}
}
}
/// Summary statistics gathered while quantizing a 2-D weight tensor.
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Fraction of elements quantized to zero.
    pub sparsity: f32,
    /// Fraction of elements quantized to +scale.
    pub positive_ratio: f32,
    /// Fraction of elements quantized to -scale.
    pub negative_ratio: f32,
    /// Calibration threshold chosen for each output row.
    pub thresholds: Vec<f32>,
    /// Reconstruction scale chosen for each output row.
    pub scales: Vec<f32>,
    /// Mean absolute reconstruction error over all elements.
    pub mean_error: f32,
    /// Maximum absolute reconstruction error over all elements.
    pub max_error: f32,
}
/// Quantizes a 2-D f32 weight tensor of shape `(out_features, in_features)`
/// into a ternary {-scale, 0, +scale} representation with one scale per row.
///
/// Each row is calibrated independently (see [`CalibrationMethod`]) and packed
/// into two bit-planes (`plus` / `minus`, 32 columns per `u32` word) by
/// `quantize_row`. Reconstruction-error statistics are accumulated along the
/// way and returned alongside the packed tensor.
///
/// # Errors
/// Returns `UnslothError::ShapeMismatch` if the tensor is not 2-D, and
/// `UnslothError::InvalidConfig` if its dtype is not `F32`.
pub fn quantize_tensor(
    tensor: &Tensor,
    config: &TernaryConfig,
) -> Result<(TernaryTensor, QuantizationStats)> {
    let shape = tensor.shape();
    if shape.dims().len() != 2 {
        return Err(UnslothError::ShapeMismatch {
            expected: vec![2],
            actual: shape.dims().to_vec(),
        });
    }
    if tensor.dtype() != DType::F32 {
        return Err(UnslothError::InvalidConfig(format!(
            "quantize_tensor requires f32, got {:?}",
            tensor.dtype()
        )));
    }
    let (out_features, in_features) = (shape.dims()[0], shape.dims()[1]);
    let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
    let calibration = CalibrationMethod::from(config.calibration_method);
    // 32 columns are packed per u32 word in each bit-plane.
    let k_words = in_features.div_ceil(32);
    let mut plus_plane = vec![0u32; out_features * k_words];
    let mut minus_plane = vec![0u32; out_features * k_words];
    let mut scales = vec![0.0f32; out_features];
    let mut thresholds = vec![0.0f32; out_features];
    let mut total_positive = 0usize;
    let mut total_negative = 0usize;
    let mut total_zero = 0usize;
    // Errors accumulate in f64 to avoid drift over large tensors.
    let mut total_error = 0.0f64;
    let mut max_error = 0.0f32;
    for row in 0..out_features {
        let row_start = row * in_features;
        let row_data = &data[row_start..row_start + in_features];
        let threshold = compute_threshold(row_data, calibration);
        thresholds[row] = threshold;
        let (row_plus, row_minus, scale, pos, neg, zero) =
            quantize_row(row_data, threshold, k_words);
        let plane_offset = row * k_words;
        plus_plane[plane_offset..plane_offset + k_words].copy_from_slice(&row_plus);
        minus_plane[plane_offset..plane_offset + k_words].copy_from_slice(&row_minus);
        scales[row] = scale;
        total_positive += pos;
        total_negative += neg;
        total_zero += zero;
        // Measure reconstruction error directly against the packed planes so
        // the stats reflect exactly what dequantization would produce.
        for (i, &val) in row_data.iter().enumerate() {
            let word_idx = i / 32;
            let mask = 1u32 << (i % 32);
            let reconstructed = if (row_plus[word_idx] & mask) != 0 {
                scale
            } else if (row_minus[word_idx] & mask) != 0 {
                -scale
            } else {
                0.0
            };
            let error = (val - reconstructed).abs();
            total_error += f64::from(error);
            max_error = max_error.max(error);
        }
    }
    // Guard against 0-element tensors (e.g. shape (0, n) passes the rank
    // check): dividing by zero would yield NaN ratios, so clamp to 1 — all
    // numerators are 0 in that case, giving well-defined zero stats.
    let total_elements = (out_features * in_features).max(1);
    #[allow(clippy::cast_precision_loss)]
    let stats = QuantizationStats {
        sparsity: total_zero as f32 / total_elements as f32,
        positive_ratio: total_positive as f32 / total_elements as f32,
        negative_ratio: total_negative as f32 / total_elements as f32,
        thresholds,
        scales: scales.clone(),
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        mean_error: (total_error / total_elements as f64) as f32,
        max_error,
    };
    let ternary = TernaryTensor::new(plus_plane, minus_plane, scales, (out_features, in_features));
    Ok((ternary, stats))
}
/// Computes the ternarization threshold for one row of weights according to
/// the chosen calibration `method`.
///
/// An empty row is handled explicitly: without the guard, `Percentile` would
/// underflow on `abs_values.len() - 1` (panic in debug builds) and `MeanStd`
/// would divide by zero, producing NaN.
fn compute_threshold(data: &[f32], method: CalibrationMethod) -> f32 {
    if data.is_empty() {
        // Manual thresholds are data-independent; every other method
        // degenerates to 0.0 (matching AbsMax's fold over an empty iterator).
        return match method {
            CalibrationMethod::Manual { threshold } => threshold,
            _ => 0.0,
        };
    }
    match method {
        CalibrationMethod::AbsMax { factor } => {
            // factor * max(|w|)
            let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
            factor * max_abs
        }
        CalibrationMethod::Percentile { percentile } => {
            let mut abs_values: Vec<f32> = data.iter().map(|x| x.abs()).collect();
            // NaNs compare Equal here, keeping the sort total without panicking.
            abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            // Negative percentiles saturate to index 0 on the float->usize
            // cast; the .min() clamp covers percentiles above 100.
            #[allow(
                clippy::cast_possible_truncation,
                clippy::cast_sign_loss,
                clippy::cast_precision_loss
            )]
            let idx = ((percentile / 100.0) * (abs_values.len() - 1) as f32) as usize;
            abs_values[idx.min(abs_values.len() - 1)]
        }
        CalibrationMethod::MeanStd { k } => {
            // mean(|w|) + k * std(|w|), accumulated in f64 for stability.
            #[allow(clippy::cast_precision_loss)]
            let n = data.len() as f64;
            let abs_values: Vec<f64> = data.iter().map(|x| f64::from(x.abs())).collect();
            let mean = abs_values.iter().sum::<f64>() / n;
            let variance = abs_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
            let std = variance.sqrt();
            #[allow(clippy::cast_possible_truncation)]
            let threshold_value = (mean + f64::from(k) * std) as f32;
            threshold_value
        }
        CalibrationMethod::Manual { threshold } => threshold,
    }
}
/// Ternarizes one row of weights against `threshold`.
///
/// Bits are packed little-endian within each `u32` word: column `i` lands at
/// bit `i % 32` of word `i / 32`. Strictly-above-threshold values set the
/// plus plane; strictly-below `-threshold` values set the minus plane; the
/// rest stay zero in both planes.
///
/// Returns `(plus_plane, minus_plane, scale, positive_count, negative_count,
/// zero_count)` where `scale` is the mean |w| over the non-zero entries, or
/// `1.0` when the entire row quantizes to zero (reconstruction still yields
/// all zeros since no bits are set).
fn quantize_row(
    data: &[f32],
    threshold: f32,
    k_words: usize,
) -> (Vec<u32>, Vec<u32>, f32, usize, usize, usize) {
    let mut plus = vec![0u32; k_words];
    let mut minus = vec![0u32; k_words];
    let (mut pos_sum, mut neg_sum) = (0.0f64, 0.0f64);
    let (mut pos, mut neg, mut zero) = (0usize, 0usize, 0usize);
    for (i, &w) in data.iter().enumerate() {
        let bit = 1u32 << (i % 32);
        if w > threshold {
            plus[i / 32] |= bit;
            pos_sum += f64::from(w.abs());
            pos += 1;
        } else if w < -threshold {
            minus[i / 32] |= bit;
            neg_sum += f64::from(w.abs());
            neg += 1;
        } else {
            zero += 1;
        }
    }
    let nonzero = pos + neg;
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    let scale = match nonzero {
        0 => 1.0,
        n => ((pos_sum + neg_sum) / n as f64) as f32,
    };
    (plus, minus, scale, pos, neg, zero)
}
/// Expands a `TernaryTensor` back into a dense f32 CPU tensor of shape
/// `(out_features, in_features)`, multiplying each stored ternary value by
/// its row's scale.
pub fn dequantize_tensor(ternary: &TernaryTensor) -> Result<Tensor> {
    let (rows, cols) = ternary.dims();
    let mut values = Vec::with_capacity(rows * cols);
    for row in 0..rows {
        let scale = ternary.scales[row];
        let planes = ternary.get_row_planes(row);
        // planes.get(col) yields the unscaled ternary value for that column
        // (presumably -1/0/+1 — see TernaryTensor); scaling reconstructs the
        // approximate weight.
        for col in 0..cols {
            values.push(f32::from(planes.get(col)) * scale);
        }
    }
    Ok(Tensor::from_vec(
        values,
        (rows, cols),
        &candle_core::Device::Cpu,
    )?)
}
/// Convenience wrapper around [`quantize_tensor`] that discards the collected
/// statistics and returns only the packed ternary weights.
pub fn quantize_linear_weights(weights: &Tensor, config: &TernaryConfig) -> Result<TernaryTensor> {
    quantize_tensor(weights, config).map(|(ternary, _stats)| ternary)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    // With a manual threshold of 0.3 on a 2x8 tensor, some values fall in
    // (-0.3, 0.3] (quantized to zero) and some outside, so all three stat
    // buckets should be non-empty.
    #[test]
    fn test_quantize_simple() -> Result<()> {
        let data: Vec<f32> = vec![
            0.5, -0.5, 0.1, -0.1, 0.8, -0.8, 0.0, 0.3, 1.0, -1.0, 0.2, -0.2, 0.0, 0.0, 0.9, -0.9,
        ];
        let tensor = Tensor::from_vec(data, (2, 8), &Device::Cpu)?;
        let config = TernaryConfig {
            calibration_method: CalibrationMethodConfig::Manual(0.3),
            ..Default::default()
        };
        let (ternary, stats) = quantize_tensor(&tensor, &config)?;
        assert_eq!(ternary.dims(), (2, 8));
        assert!(stats.sparsity > 0.0);
        assert!(stats.positive_ratio > 0.0);
        assert!(stats.negative_ratio > 0.0);
        Ok(())
    }

    // Quantize then dequantize a smooth ramp in [-1, 1); ternary encoding is
    // lossy, so only check that the mean squared error stays moderate.
    #[test]
    fn test_quantize_dequantize_roundtrip() -> Result<()> {
        let data: Vec<f32> = (0..256)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                {
                    (i as f32 - 128.0) / 128.0
                }
            })
            .collect();
        let tensor = Tensor::from_vec(data.clone(), (4, 64), &Device::Cpu)?;
        let config = TernaryConfig::default();
        let (ternary, _stats) = quantize_tensor(&tensor, &config)?;
        let reconstructed = dequantize_tensor(&ternary)?;
        let recon_data: Vec<f32> = reconstructed.flatten_all()?.to_vec1()?;
        let mse: f32 = data
            .iter()
            .zip(recon_data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / {
                #[allow(clippy::cast_precision_loss)]
                {
                    data.len() as f32
                }
            };
        assert!(mse < 0.5, "MSE too high: {mse}");
        Ok(())
    }

    // AbsMax with factor 0.7 on data whose max |w| is 2.0 should give 1.4;
    // Manual passes its threshold straight through.
    #[test]
    fn test_calibration_methods() {
        let data: Vec<f32> = vec![0.1, 0.5, 1.0, -0.3, -0.8, 2.0, -1.5, 0.0];
        let t1 = compute_threshold(&data, CalibrationMethod::AbsMax { factor: 0.7 });
        assert!((t1 - 1.4).abs() < 0.01);
        let t2 = compute_threshold(&data, CalibrationMethod::Manual { threshold: 0.5 });
        assert!((t2 - 0.5).abs() < 0.001);
    }

    // 100 non-zero entries out of 1000 => 90% sparsity; both the stats and
    // the tensor's own sparsity() accessor should agree it exceeds 85%.
    #[test]
    fn test_sparsity_detection() -> Result<()> {
        let mut data = vec![0.0f32; 1000];
        for i in 0..100 {
            data[i * 10] = if i % 2 == 0 { 1.0 } else { -1.0 };
        }
        let tensor = Tensor::from_vec(data, (10, 100), &Device::Cpu)?;
        let config = TernaryConfig {
            calibration_method: CalibrationMethodConfig::Manual(0.1),
            ..Default::default()
        };
        let (ternary, stats) = quantize_tensor(&tensor, &config)?;
        assert!(stats.sparsity > 0.85, "Sparsity: {}", stats.sparsity);
        assert!(ternary.sparsity() > 0.85);
        Ok(())
    }

    // Two bit-planes (2 bits/weight) versus f32 (32 bits/weight) should give
    // a compression ratio comfortably above 10x for a large square matrix.
    #[test]
    fn test_compression_ratio() -> Result<()> {
        let data = vec![0.0f32; 4096 * 4096];
        let tensor = Tensor::from_vec(data, (4096, 4096), &Device::Cpu)?;
        let config = TernaryConfig::default();
        let (ternary, _) = quantize_tensor(&tensor, &config)?;
        let ratio = ternary.compression_ratio();
        assert!(ratio > 10.0, "Compression ratio too low: {ratio}");
        Ok(())
    }
}