use crate::semiring::basic::{LogWeight, TropicalWeight};
use std::fmt;
/// Affine quantization parameters mapping a closed `f64` range onto 8-bit
/// (`0..=255`) or 4-bit (`0..=15`) unsigned codes.
///
/// `min_val` maps to code 0 and `max_val` to the highest code; finite values
/// outside the range saturate to the nearest end.
///
/// NOTE(review): `min_val` and `max_val` are public, but the 8-bit factors are
/// cached when the value is built. Assigning to the bounds afterwards leaves
/// that cache stale; build a fresh value with [`QuantizationParams::new`]
/// instead.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationParams {
    /// Lower bound of the representable range (decodes from code 0).
    pub min_val: f64,
    /// Upper bound of the representable range (decodes from the highest code).
    pub max_val: f64,
    /// Cached `255 / (max_val - min_val)`.
    scale: f64,
    /// Cached `(max_val - min_val) / 255`.
    inv_scale: f64,
}

impl QuantizationParams {
    /// Highest code of the 8-bit representation.
    const MAX_CODE_8BIT: u8 = 255;
    /// Highest code of the 4-bit representation.
    const MAX_CODE_4BIT: u8 = 15;

    /// Creates parameters for the range `[min_val, max_val]`.
    ///
    /// # Panics
    ///
    /// Panics if either bound is not finite, if `min_val >= max_val`, or if
    /// `max_val - min_val` overflows to infinity.
    pub fn new(min_val: f64, max_val: f64) -> Self {
        assert!(
            min_val.is_finite() && max_val.is_finite(),
            "quantization bounds must be finite: [{}, {}]",
            min_val,
            max_val
        );
        assert!(
            min_val < max_val,
            "min_val must be less than max_val: {} >= {}",
            min_val,
            max_val
        );
        let range = max_val - min_val;
        assert!(
            range.is_finite(),
            "quantization range must be finite: {} - {}",
            max_val,
            min_val
        );
        let levels = f64::from(Self::MAX_CODE_8BIT);
        Self {
            min_val,
            max_val,
            scale: levels / range,
            inv_scale: range / levels,
        }
    }

    /// Derives parameters from sample data.
    ///
    /// Non-finite samples are ignored. The observed `[min, max]` interval is
    /// widened by 1% of its width on each side.
    ///
    /// Returns `None` when there are no finite samples, when all finite
    /// samples are equal, or when the observed or widened interval cannot be
    /// represented with a finite width. This function never panics.
    pub fn from_data(values: impl Iterator<Item = f64>) -> Option<Self> {
        let (min_val, max_val) = values
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        // With no finite samples the accumulators stay at (+inf, -inf), so this
        // one comparison also covers the empty case.
        if min_val >= max_val {
            return None;
        }
        let range = max_val - min_val;
        if !range.is_finite() {
            return None;
        }
        let margin = range * 0.01;
        let lower = min_val - margin;
        let upper = max_val + margin;
        // `new` asserts on the *widened* interval, so its width must be checked
        // here too: a finite `range` can still overflow once the margins are
        // added, and that must yield `None` rather than a panic.
        let widened_ok = lower.is_finite() && upper.is_finite() && (upper - lower).is_finite();
        if !widened_ok {
            return None;
        }
        Some(Self::new(lower, upper))
    }

    /// Preset covering the range `[-20, 0]`.
    pub fn for_log_weights() -> Self {
        Self::new(-20.0, 0.0)
    }

    /// Preset covering the range `[0, 20]`.
    pub fn for_tropical_weights() -> Self {
        Self::new(0.0, 20.0)
    }

    /// Returns the cached multiplier that maps value offsets to 8-bit codes.
    #[inline]
    pub fn scale_8bit(&self) -> f64 {
        self.scale
    }

    /// Returns the multiplier that maps value offsets to 4-bit codes,
    /// computed from the current bounds.
    #[inline]
    pub fn scale_4bit(&self) -> f64 {
        f64::from(Self::MAX_CODE_4BIT) / (self.max_val - self.min_val)
    }

    /// Shared quantization routine for both code widths.
    #[inline]
    fn quantize_to(&self, value: f64, scale: f64, max_code: u8) -> u8 {
        if !value.is_finite() {
            // -inf saturates low. +inf and NaN saturate high: the sign bit of a
            // NaN carries no meaning and is not portable, so it must not pick
            // the code.
            return if value == f64::NEG_INFINITY { 0 } else { max_code };
        }
        let clamped = value.clamp(self.min_val, self.max_val);
        let steps = ((clamped - self.min_val) * scale).round();
        // The float-to-int cast saturates; `min` keeps the 4-bit case in range.
        (steps as u8).min(max_code)
    }

    /// Quantizes `value` to an 8-bit code.
    ///
    /// Finite values are clamped to the range first. `+inf` and NaN map to
    /// 255, `-inf` maps to 0.
    #[inline]
    pub fn quantize_8bit(&self, value: f64) -> u8 {
        self.quantize_to(value, self.scale, Self::MAX_CODE_8BIT)
    }

    /// Decodes an 8-bit code back to a value in `[min_val, max_val]`.
    #[inline]
    pub fn dequantize_8bit(&self, quantized: u8) -> f64 {
        self.min_val + f64::from(quantized) * self.inv_scale
    }

    /// Quantizes `value` to a 4-bit code (`0..=15`).
    ///
    /// Finite values are clamped to the range first. `+inf` and NaN map to
    /// 15, `-inf` maps to 0.
    #[inline]
    pub fn quantize_4bit(&self, value: f64) -> u8 {
        self.quantize_to(value, self.scale_4bit(), Self::MAX_CODE_4BIT)
    }

    /// Decodes a 4-bit code back to a value; only the low nibble of
    /// `quantized` is used.
    #[inline]
    pub fn dequantize_4bit(&self, quantized: u8) -> f64 {
        let step = (self.max_val - self.min_val) / f64::from(Self::MAX_CODE_4BIT);
        self.min_val + f64::from(quantized & 0x0F) * step
    }
}
/// A weight stored as a single 8-bit quantization code.
///
/// The code carries no range information of its own; decoding needs the same
/// [`QuantizationParams`] that were used for encoding.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Quantized8Weight {
    /// Raw code in `0..=255`.
    value: u8,
}

impl Quantized8Weight {
    /// The highest code (255).
    pub const ZERO: Self = Quantized8Weight { value: 255 };

    /// The lowest code (0).
    pub const ONE: Self = Quantized8Weight { value: 0 };

    /// Wraps a raw code without any conversion.
    #[inline]
    pub const fn from_raw(value: u8) -> Self {
        Quantized8Weight { value }
    }

    /// Returns the raw code.
    #[inline]
    pub const fn raw(&self) -> u8 {
        self.value
    }

    /// Encodes the value of a `LogWeight` with `params.quantize_8bit`.
    #[inline]
    pub fn from_log_weight(weight: LogWeight, params: &QuantizationParams) -> Self {
        let code = params.quantize_8bit(weight.value());
        Quantized8Weight { value: code }
    }

    /// Decodes the code with `params.dequantize_8bit` into a `LogWeight`.
    #[inline]
    pub fn to_log_weight(&self, params: &QuantizationParams) -> LogWeight {
        let decoded = params.dequantize_8bit(self.value);
        LogWeight::new(decoded)
    }

    /// Encodes the value of a `TropicalWeight` with `params.quantize_8bit`.
    #[inline]
    pub fn from_tropical_weight(weight: TropicalWeight, params: &QuantizationParams) -> Self {
        let code = params.quantize_8bit(weight.value());
        Quantized8Weight { value: code }
    }

    /// Decodes the code with `params.dequantize_8bit` into a `TropicalWeight`.
    #[inline]
    pub fn to_tropical_weight(&self, params: &QuantizationParams) -> TropicalWeight {
        let decoded = params.dequantize_8bit(self.value);
        TropicalWeight::new(decoded)
    }
}
/// Debug output has the form `Q8(<code>)`.
impl fmt::Debug for Quantized8Weight {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = self.value;
        write!(f, "Q8({code})")
    }
}
/// Displays the raw code as a plain decimal number.
impl fmt::Display for Quantized8Weight {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.value))
    }
}
impl Default for Quantized8Weight {
fn default() -> Self {
Self::ZERO
}
}
/// A weight stored as a 4-bit quantization code held in the low nibble of a
/// byte.
///
/// Decoding needs the same [`QuantizationParams`] that were used for encoding.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Quantized4Weight {
    /// Raw code; every constructor in this impl yields a value in `0..=15`.
    value: u8,
}

impl Quantized4Weight {
    /// The highest code (15).
    pub const ZERO: Self = Quantized4Weight { value: 15 };

    /// The lowest code (0).
    pub const ONE: Self = Quantized4Weight { value: 0 };

    /// Wraps a raw code, discarding everything above the low nibble.
    #[inline]
    pub const fn from_raw(value: u8) -> Self {
        let nibble = value & 0x0F;
        Quantized4Weight { value: nibble }
    }

    /// Returns the raw code.
    #[inline]
    pub const fn raw(&self) -> u8 {
        self.value
    }

    /// Encodes the value of a `LogWeight` with `params.quantize_4bit`.
    #[inline]
    pub fn from_log_weight(weight: LogWeight, params: &QuantizationParams) -> Self {
        let code = params.quantize_4bit(weight.value());
        Quantized4Weight { value: code }
    }

    /// Decodes the code with `params.dequantize_4bit` into a `LogWeight`.
    #[inline]
    pub fn to_log_weight(&self, params: &QuantizationParams) -> LogWeight {
        let decoded = params.dequantize_4bit(self.value);
        LogWeight::new(decoded)
    }

    /// Encodes the value of a `TropicalWeight` with `params.quantize_4bit`.
    #[inline]
    pub fn from_tropical_weight(weight: TropicalWeight, params: &QuantizationParams) -> Self {
        let code = params.quantize_4bit(weight.value());
        Quantized4Weight { value: code }
    }

    /// Decodes the code with `params.dequantize_4bit` into a `TropicalWeight`.
    #[inline]
    pub fn to_tropical_weight(&self, params: &QuantizationParams) -> TropicalWeight {
        let decoded = params.dequantize_4bit(self.value);
        TropicalWeight::new(decoded)
    }
}
/// Debug output has the form `Q4(<code>)`.
impl fmt::Debug for Quantized4Weight {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = self.value;
        write!(f, "Q4({code})")
    }
}
/// Displays the raw code as a plain decimal number.
impl fmt::Display for Quantized4Weight {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.value))
    }
}
impl Default for Quantized4Weight {
fn default() -> Self {
Self::ZERO
}
}
/// A growable sequence of 4-bit weights stored two per byte.
///
/// The element at an even index occupies the low nibble of its byte and the
/// following odd index occupies the high nibble of the same byte.
#[derive(Debug, Clone)]
pub struct PackedWeights4 {
    /// Packed storage, one byte per pair of weights.
    data: Vec<u8>,
    /// Number of weights stored (not the number of bytes).
    len: usize,
}

impl PackedWeights4 {
    /// Creates an empty container with room reserved for `capacity` weights.
    pub fn with_capacity(capacity: usize) -> Self {
        let bytes_needed = capacity.div_ceil(2);
        Self {
            data: Vec::with_capacity(bytes_needed),
            len: 0,
        }
    }

    /// Packs a slice of weights, preserving their order.
    pub fn from_weights(weights: &[Quantized4Weight]) -> Self {
        let mut packed = Self::with_capacity(weights.len());
        weights.iter().copied().for_each(|w| packed.push(w));
        packed
    }

    /// Appends one weight to the end of the sequence.
    pub fn push(&mut self, weight: Quantized4Weight) {
        let nibble = weight.raw();
        let starts_new_byte = self.len % 2 == 0;
        if starts_new_byte {
            // Even position: open a new byte with the weight in its low nibble.
            self.data.push(nibble);
        } else {
            // Odd position: fill the high nibble of the byte opened last time.
            let tail = self.data.last_mut().expect("non-empty after odd push");
            *tail |= nibble << 4;
        }
        self.len += 1;
    }

    /// Returns the weight at `index`, or `None` when `index` is out of range.
    #[inline]
    pub fn get(&self, index: usize) -> Option<Quantized4Weight> {
        if index >= self.len {
            return None;
        }
        // 0 for even indices (low nibble), 4 for odd indices (high nibble).
        let shift = (index % 2) * 4;
        let nibble = (self.data[index / 2] >> shift) & 0x0F;
        Some(Quantized4Weight::from_raw(nibble))
    }

    /// Overwrites the weight at `index`; an out-of-range `index` is silently
    /// ignored.
    pub fn set(&mut self, index: usize, weight: Quantized4Weight) {
        if index >= self.len {
            return;
        }
        let shift = (index % 2) * 4;
        let keep_mask: u8 = !(0x0F << shift);
        let slot = &mut self.data[index / 2];
        // Preserve the neighbouring nibble, then insert the new one.
        *slot = (*slot & keep_mask) | (weight.raw() << shift);
    }

    /// Number of weights stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when no weights are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of bytes currently used by the packed storage.
    #[inline]
    pub fn memory_size(&self) -> usize {
        self.data.len()
    }

    /// Iterates over the stored weights in index order.
    pub fn iter(&self) -> impl Iterator<Item = Quantized4Weight> + '_ {
        let count = self.len;
        (0..count).map(move |idx| self.get(idx).expect("valid index"))
    }
}
/// Running accumulator of absolute quantization errors.
#[derive(Debug, Clone, Default)]
pub struct QuantizationStats {
    /// Number of samples recorded.
    pub count: usize,
    /// Sum of absolute errors.
    pub total_error: f64,
    /// Largest absolute error recorded.
    pub max_error: f64,
    /// Sum of squared errors.
    pub sum_squared_error: f64,
}

impl QuantizationStats {
    /// Creates an accumulator with every counter at zero.
    pub fn new() -> Self {
        Self {
            count: 0,
            total_error: 0.0,
            max_error: 0.0,
            sum_squared_error: 0.0,
        }
    }

    /// Records one sample: the absolute difference between `original` and
    /// `quantized`.
    pub fn add(&mut self, original: f64, quantized: f64) {
        let abs_err = (original - quantized).abs();
        self.count += 1;
        self.total_error += abs_err;
        self.sum_squared_error += abs_err * abs_err;
        self.max_error = f64::max(self.max_error, abs_err);
    }

    /// Mean absolute error, or `0.0` when nothing has been recorded.
    pub fn mean_error(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.total_error / n as f64,
        }
    }

    /// Root-mean-square error, or `0.0` when nothing has been recorded.
    pub fn rmse(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => (self.sum_squared_error / n as f64).sqrt(),
        }
    }
}
/// Quantizes every weight in `weights` to an 8-bit code, preserving order.
pub fn quantize_log_weights_8bit(
    weights: &[LogWeight],
    params: &QuantizationParams,
) -> Vec<Quantized8Weight> {
    let mut codes = Vec::with_capacity(weights.len());
    for &weight in weights {
        codes.push(Quantized8Weight::from_log_weight(weight, params));
    }
    codes
}
/// Decodes every 8-bit code in `weights` back to a `LogWeight`, preserving
/// order.
pub fn dequantize_8bit_to_log(
    weights: &[Quantized8Weight],
    params: &QuantizationParams,
) -> Vec<LogWeight> {
    let mut decoded = Vec::with_capacity(weights.len());
    for code in weights {
        decoded.push(code.to_log_weight(params));
    }
    decoded
}
/// Quantizes every weight in `weights` to a 4-bit code, preserving order.
pub fn quantize_log_weights_4bit(
    weights: &[LogWeight],
    params: &QuantizationParams,
) -> Vec<Quantized4Weight> {
    let mut codes = Vec::with_capacity(weights.len());
    for &weight in weights {
        codes.push(Quantized4Weight::from_log_weight(weight, params));
    }
    codes
}
/// Quantizes every weight in `weights` to a 4-bit code and stores the codes
/// two per byte, preserving order.
///
/// Codes are pushed straight into the packed buffer, so no intermediate
/// `Vec` of unpacked codes is allocated.
pub fn quantize_and_pack_4bit(
    weights: &[LogWeight],
    params: &QuantizationParams,
) -> PackedWeights4 {
    let mut packed = PackedWeights4::with_capacity(weights.len());
    for weight in weights {
        packed.push(Quantized4Weight::from_log_weight(*weight, params));
    }
    packed
}
/// Measures the 8-bit round-trip error over `weights`.
///
/// Each weight is encoded and decoded with `params`, and the original and
/// decoded values are fed to `QuantizationStats::add`.
pub fn compute_quantization_stats_8bit(
    weights: &[LogWeight],
    params: &QuantizationParams,
) -> QuantizationStats {
    weights
        .iter()
        .fold(QuantizationStats::new(), |mut stats, weight| {
            let reference = weight.value();
            let code = Quantized8Weight::from_log_weight(*weight, params);
            let round_trip = code.to_log_weight(params).value();
            stats.add(reference, round_trip);
            stats
        })
}
/// Measures the 4-bit round-trip error over `weights`.
///
/// Each weight is encoded and decoded with `params`, and the original and
/// decoded values are fed to `QuantizationStats::add`.
pub fn compute_quantization_stats_4bit(
    weights: &[LogWeight],
    params: &QuantizationParams,
) -> QuantizationStats {
    weights
        .iter()
        .fold(QuantizationStats::new(), |mut stats, weight| {
            let reference = weight.value();
            let code = Quantized4Weight::from_log_weight(*weight, params);
            let round_trip = code.to_log_weight(params).value();
            stats.add(reference, round_trip);
            stats
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Range endpoints map to the extreme codes; the midpoint lands next to 127.
    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::new(-10.0, 10.0);
        assert_eq!(params.quantize_8bit(-10.0), 0);
        assert_eq!(params.quantize_8bit(10.0), 255);
        let mid = params.quantize_8bit(0.0);
        assert!((mid as i32 - 127).abs() <= 1);
    }

    #[test]
    #[should_panic(expected = "quantization bounds must be finite")]
    fn test_quantization_params_reject_nonfinite_bounds() {
        let _ = QuantizationParams::new(0.0, f64::INFINITY);
    }

    // Both bounds are finite, but their difference overflows to infinity.
    #[test]
    #[should_panic(expected = "quantization range must be finite")]
    fn test_quantization_params_reject_overflowing_range() {
        let _ = QuantizationParams::new(-f64::MAX, f64::MAX);
    }

    #[test]
    fn test_roundtrip_8bit() {
        let params = QuantizationParams::for_log_weights();
        let original = LogWeight::new(-5.0);
        let quantized = Quantized8Weight::from_log_weight(original, &params);
        let recovered = quantized.to_log_weight(&params);
        assert!((recovered.value() - original.value()).abs() < 0.1);
    }

    // Round-to-nearest must keep the error within half a quantization step.
    #[test]
    fn test_roundtrip_8bit_error_bound() {
        let params = QuantizationParams::new(-20.0, 20.0);
        let half_step = (params.max_val - params.min_val) / 255.0 / 2.0;
        for i in 0..=1000 {
            let value = params.min_val + (params.max_val - params.min_val) * (i as f64) / 1000.0;
            let quantized = params.quantize_8bit(value);
            let recovered = params.dequantize_8bit(quantized);
            let error = (recovered - value).abs();
            assert!(
                error <= half_step + 1e-12,
                "8-bit quantization error {} exceeded half-step {} for value {}",
                error,
                half_step,
                value
            );
        }
    }

    #[test]
    fn test_roundtrip_4bit() {
        let params = QuantizationParams::for_log_weights();
        let original = LogWeight::new(-10.0);
        let quantized = Quantized4Weight::from_log_weight(original, &params);
        let recovered = quantized.to_log_weight(&params);
        assert!((recovered.value() - original.value()).abs() < 2.0);
    }

    // Round-to-nearest must keep the error within half a quantization step.
    #[test]
    fn test_roundtrip_4bit_error_bound() {
        let params = QuantizationParams::new(-20.0, 20.0);
        let half_step = (params.max_val - params.min_val) / 15.0 / 2.0;
        for i in 0..=1000 {
            let value = params.min_val + (params.max_val - params.min_val) * (i as f64) / 1000.0;
            let quantized = params.quantize_4bit(value);
            let recovered = params.dequantize_4bit(quantized);
            let error = (recovered - value).abs();
            assert!(
                error <= half_step + 1e-12,
                "4-bit quantization error {} exceeded half-step {} for value {}",
                error,
                half_step,
                value
            );
        }
    }

    // Ten 4-bit weights must occupy exactly five bytes and decode in order.
    #[test]
    fn test_packed_weights() {
        let params = QuantizationParams::for_log_weights();
        let weights: Vec<LogWeight> = (0..10).map(|i| LogWeight::new(-i as f64 * 2.0)).collect();
        let packed = quantize_and_pack_4bit(&weights, &params);
        assert_eq!(packed.len(), 10);
        assert_eq!(packed.memory_size(), 5);
        for (i, w) in weights.iter().enumerate() {
            let packed_w = packed.get(i).expect("valid index");
            let recovered = packed_w.to_log_weight(&params);
            assert!((recovered.value() - w.value()).abs() < 2.0);
        }
    }

    #[test]
    fn test_quantization_stats() {
        let params = QuantizationParams::for_log_weights();
        let weights: Vec<LogWeight> = (0..100).map(|i| LogWeight::new(-i as f64 * 0.2)).collect();
        let stats = compute_quantization_stats_8bit(&weights, &params);
        assert_eq!(stats.count, 100);
        assert!(stats.mean_error() < 0.1);
        assert!(stats.rmse() < 0.1);
    }

    // +inf saturates to the top code; a huge negative finite value clamps to 0.
    #[test]
    fn test_infinity_handling() {
        let params = QuantizationParams::for_log_weights();
        let inf = LogWeight::new(f64::INFINITY);
        let below_range = LogWeight::new(-1.0e300);
        let q_inf = Quantized8Weight::from_log_weight(inf, &params);
        let q_below_range = Quantized8Weight::from_log_weight(below_range, &params);
        assert_eq!(q_inf.raw(), 255);
        assert_eq!(q_below_range.raw(), 0);
    }

    #[test]
    fn test_tropical_quantization() {
        let params = QuantizationParams::for_tropical_weights();
        let original = TropicalWeight::new(5.0);
        let quantized = Quantized8Weight::from_tropical_weight(original, &params);
        let recovered = quantized.to_tropical_weight(&params);
        assert!((recovered.value() - original.value()).abs() < 0.1);
    }

    // The derived range is widened by a margin beyond the observed extremes.
    #[test]
    fn test_params_from_data() {
        let values = vec![1.0, 5.0, 3.0, 7.0, 2.0];
        let params = QuantizationParams::from_data(values.into_iter()).expect("valid data");
        assert!(params.min_val < 1.0);
        assert!(params.max_val > 7.0);
    }

    #[test]
    fn test_params_from_data_rejects_nonfinite_range() {
        let values = vec![-f64::MAX, f64::MAX];
        assert!(QuantizationParams::from_data(values.into_iter()).is_none());
    }

    #[test]
    fn test_batch_quantize() {
        let params = QuantizationParams::for_log_weights();
        let weights: Vec<LogWeight> = vec![
            LogWeight::new(-1.0),
            LogWeight::new(-5.0),
            LogWeight::new(-10.0),
        ];
        let quantized = quantize_log_weights_8bit(&weights, &params);
        let dequantized = dequantize_8bit_to_log(&quantized, &params);
        assert_eq!(dequantized.len(), 3);
        for (orig, deq) in weights.iter().zip(dequantized.iter()) {
            assert!((orig.value() - deq.value()).abs() < 0.1);
        }
    }

    // Overwriting one nibble must leave its neighbours untouched.
    #[test]
    fn test_packed_weights_set() {
        let mut packed = PackedWeights4::with_capacity(4);
        packed.push(Quantized4Weight::from_raw(3));
        packed.push(Quantized4Weight::from_raw(7));
        packed.push(Quantized4Weight::from_raw(11));
        packed.push(Quantized4Weight::from_raw(15));
        assert_eq!(packed.get(0).expect("valid").raw(), 3);
        assert_eq!(packed.get(1).expect("valid").raw(), 7);
        packed.set(1, Quantized4Weight::from_raw(9));
        assert_eq!(packed.get(1).expect("valid").raw(), 9);
        assert_eq!(packed.get(0).expect("valid").raw(), 3);
        assert_eq!(packed.get(2).expect("valid").raw(), 11);
    }
}