use serde::{Deserialize, Serialize};
use crate::codebook::Codebook;
use crate::error::{Result, TurboQuantError};
/// Scalar quantizer backed by a [`Codebook`].
///
/// The quantize/dequantize methods on this type forward to the wrapped
/// codebook; the type itself adds batch helpers, a mean-squared-error
/// measurement and index validation on top of it.
///
/// Serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
/// Codebook that the scalar quantize/dequantize calls are delegated to.
pub codebook: Codebook,
}
impl ScalarQuantizer {
    /// Wraps an existing [`Codebook`] in a quantizer.
    pub fn from_codebook(codebook: Codebook) -> Self {
        Self { codebook }
    }

    /// Maps `value` to a codebook index by delegating to
    /// [`Codebook::quantize_scalar`].
    ///
    /// How invalid input (e.g. NaN) is treated is decided by the codebook;
    /// use [`Self::checked_quantize_scalar`] to get a `Result` instead.
    pub fn quantize_scalar(&self, value: f64) -> u8 {
        self.codebook.quantize_scalar(value)
    }

    /// Fallible counterpart of [`Self::quantize_scalar`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Codebook::checked_quantize_scalar`] reports,
    /// unchanged.
    pub fn checked_quantize_scalar(&self, value: f64) -> Result<u8> {
        self.codebook.checked_quantize_scalar(value)
    }

    /// Maps a codebook `index` back to its reconstruction value by delegating
    /// to [`Codebook::dequantize_scalar`].
    ///
    /// How an out-of-range index is treated is decided by the codebook; use
    /// [`Self::checked_dequantize_scalar`] to get a `Result` instead.
    pub fn dequantize_scalar(&self, index: u8) -> f64 {
        self.codebook.dequantize_scalar(index)
    }

    /// Fallible counterpart of [`Self::dequantize_scalar`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Codebook::checked_dequantize_scalar`] reports,
    /// unchanged.
    pub fn checked_dequantize_scalar(&self, index: u8) -> Result<f64> {
        self.codebook.checked_dequantize_scalar(index)
    }

    /// Quantizes every element of `values`, preserving order.
    ///
    /// The result has the same length as `values`.
    pub fn quantize_batch(&self, values: &[f64]) -> Vec<u8> {
        values.iter().map(|&v| self.quantize_scalar(v)).collect()
    }

    /// Dequantizes every element of `indices`, preserving order.
    ///
    /// The result has the same length as `indices`. Indices are passed to
    /// [`Self::dequantize_scalar`] without validation; call
    /// [`Self::validate_indices`] first when they come from an untrusted
    /// source.
    pub fn dequantize_batch(&self, indices: &[u8]) -> Vec<f64> {
        indices.iter().map(|&i| self.dequantize_scalar(i)).collect()
    }

    /// Mean squared error between `values` and their quantize → dequantize
    /// round trip.
    ///
    /// Returns `0.0` for an empty slice so callers never see a `0 / 0` NaN.
    pub fn mse(&self, values: &[f64]) -> f64 {
        if values.is_empty() {
            return 0.0;
        }
        // Single pass: round-trip each value and accumulate its squared error
        // directly, without materialising intermediate index/reconstruction
        // vectors. The summation order is the slice order.
        let sq_err: f64 = values
            .iter()
            .map(|&v| {
                let recon = self.dequantize_scalar(self.quantize_scalar(v));
                (v - recon) * (v - recon)
            })
            .sum();
        sq_err / values.len() as f64
    }

    /// Checks that every entry of `indices` addresses a level of the codebook.
    ///
    /// An empty slice is always valid.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::InvalidQuantizationIndex`] describing the
    /// first index that is greater than `num_levels() - 1`.
    pub fn validate_indices(&self, indices: &[u8]) -> Result<()> {
        // Subtract *before* narrowing to `u8`. Narrowing first turns a
        // 256-level codebook into `0`, and `0 - 1` then overflows. The
        // saturating subtraction also keeps an (unexpected) zero-level
        // codebook from underflowing, and the clamp makes the final cast
        // lossless.
        let max_idx = self.codebook.num_levels().saturating_sub(1).min(255) as u8;

        match indices.iter().copied().find(|&idx| idx > max_idx) {
            Some(index) => Err(TurboQuantError::InvalidQuantizationIndex {
                index,
                max: max_idx,
                bit_width: self.codebook.bit_width,
            }),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codebook::generate_codebook;

    /// Builds a quantizer from a codebook generated for `dim` and `bits`.
    ///
    /// The third `generate_codebook` argument is fixed at 50 (presumably an
    /// iteration budget — confirm against `generate_codebook`). Panics if
    /// codebook generation fails, which is acceptable inside tests.
    fn make_quantizer(dim: usize, bits: u8) -> ScalarQuantizer {
        generate_codebook(dim, bits, 50)
            .map(ScalarQuantizer::from_codebook)
            .unwrap()
    }

    /// Every index produced by a 2-bit quantizer must fit in two bits.
    #[test]
    fn test_quantize_scalar_in_range() {
        let quantizer = make_quantizer(64, 2);
        let largest_allowed = (1u8 << 2) - 1;
        let samples = [-0.9, -0.5, 0.0, 0.5, 0.9];
        for idx in samples.iter().map(|&sample| quantizer.quantize_scalar(sample)) {
            assert!(
                idx <= largest_allowed,
                "index {} out of range for 2-bit",
                idx
            );
        }
    }

    /// A 4-bit quantize → dequantize round trip stays within 0.1 of the input.
    #[test]
    fn test_batch_roundtrip() {
        let quantizer = make_quantizer(128, 4);
        let values: Vec<f64> = (0..16).map(|i| (i as f64 - 7.5) * 0.03).collect();
        let recon = quantizer.dequantize_batch(&quantizer.quantize_batch(&values));
        for (v, r) in values.iter().zip(recon.iter()) {
            let deviation = (v - r).abs();
            assert!(deviation < 0.1, "v={}, r={}", v, r);
        }
    }

    /// For a 2-bit codebook, indices 0..=3 validate and 4 is rejected.
    #[test]
    fn test_validate_indices_out_of_range() {
        let quantizer = make_quantizer(64, 2);
        let in_range = quantizer.validate_indices(&[0, 1, 2, 3]);
        assert!(in_range.is_ok());
        let out_of_range = quantizer.validate_indices(&[4]);
        assert!(out_of_range.is_err());
    }

    /// MSE of an empty slice is defined as exactly zero.
    #[test]
    fn test_mse_empty_input() {
        let quantizer = make_quantizer(64, 2);
        let empty_mse = quantizer.mse(&[]);
        assert_eq!(empty_mse, 0.0);
    }

    /// More bits must give a strictly smaller reconstruction error on the
    /// same sine-wave samples.
    #[test]
    fn test_mse_decreases_with_bits() {
        let values: Vec<f64> = (0..1000).map(|i| ((i as f64) * 0.1).sin()).collect();
        let mse_at = |bits: u8| make_quantizer(128, bits).mse(&values);
        let mse2 = mse_at(2);
        let mse4 = mse_at(4);
        assert!(
            mse4 < mse2,
            "4-bit MSE {} should be less than 2-bit MSE {}",
            mse4,
            mse2
        );
    }

    /// The checked variants accept valid input and surface typed errors for
    /// NaN values and out-of-range indices.
    #[test]
    fn test_checked_scalar_methods() {
        let quantizer = make_quantizer(64, 2);

        let finite = quantizer.checked_quantize_scalar(0.0);
        assert!(finite.is_ok());

        let not_a_number = quantizer.checked_quantize_scalar(f64::NAN);
        assert!(matches!(
            not_a_number,
            Err(TurboQuantError::InvalidValue { .. })
        ));

        let bad_index = quantizer.checked_dequantize_scalar(4);
        assert!(matches!(
            bad_index,
            Err(TurboQuantError::InvalidQuantizationIndex { .. })
        ));
    }
}