use serde::{Deserialize, Serialize};
use crate::error::{Result, TurboQuantError};
use crate::utils::{beta_pdf, sample_beta_marginal};
/// Scalar quantization codebook: reconstruction levels plus the decision
/// boundaries separating adjacent levels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Codebook {
/// Reconstruction values, one per quantization level, sorted ascending
/// (enforced by `generate_codebook`).
pub centroids: Vec<f64>,
/// Decision thresholds between adjacent centroids; `boundaries[i]` splits
/// level `i` from level `i + 1` (len == centroids.len() - 1).
pub boundaries: Vec<f64>,
/// Bits per quantized index; `generate_codebook` enforces 1..=8 and
/// produces `1 << bit_width` centroids.
pub bit_width: u8,
}
impl Codebook {
    /// Number of quantization levels (equals `centroids.len()`).
    pub fn num_levels(&self) -> usize {
        self.centroids.len()
    }

    /// Check that `index` addresses an existing centroid.
    ///
    /// # Errors
    /// `InvalidQuantizationIndex` when `index` is out of range.
    pub fn validate_index(&self, index: u8) -> Result<()> {
        let levels = self.num_levels();
        if (index as usize) < levels {
            Ok(())
        } else {
            Err(TurboQuantError::InvalidQuantizationIndex {
                index,
                max: levels.saturating_sub(1) as u8,
                bit_width: self.bit_width,
            })
        }
    }

    /// Map `value` to the index of its quantization cell via binary search
    /// over the decision boundaries. NaN maps to index 0.
    pub fn quantize_scalar(&self, value: f64) -> u8 {
        if value.is_nan() {
            return 0;
        }
        // Boundaries are finite in generated codebooks, so partial_cmp only
        // falls back to Equal defensively.
        let probe = |b: &f64| b.partial_cmp(&value).unwrap_or(std::cmp::Ordering::Equal);
        let cell = match self.boundaries.binary_search_by(probe) {
            Ok(pos) => pos,
            Err(pos) => pos.min(self.centroids.len() - 1),
        };
        cell as u8
    }

    /// Like [`Codebook::quantize_scalar`], but rejects NaN and ±infinity.
    ///
    /// # Errors
    /// `InvalidValue` when `value` is not finite.
    pub fn checked_quantize_scalar(&self, value: f64) -> Result<u8> {
        if value.is_finite() {
            Ok(self.quantize_scalar(value))
        } else {
            Err(TurboQuantError::InvalidValue {
                context: "codebook scalar value".into(),
                value,
            })
        }
    }

    /// Reconstruction value for `index`. Panics on an out-of-range index;
    /// see [`Codebook::checked_dequantize_scalar`] for the fallible variant.
    pub fn dequantize_scalar(&self, index: u8) -> f64 {
        self.centroids[index as usize]
    }

    /// Validate `index`, then dequantize.
    ///
    /// # Errors
    /// `InvalidQuantizationIndex` when `index` is out of range.
    pub fn checked_dequantize_scalar(&self, index: u8) -> Result<f64> {
        self.validate_index(index)?;
        Ok(self.dequantize_scalar(index))
    }
}
/// Generate a Lloyd–Max style scalar codebook with `1 << bit_width` levels
/// for quantizing components of `dim`-dimensional unit vectors.
///
/// Centroids are seeded at the k mid-quantiles of the per-component marginal
/// (via `sample_beta_marginal`), then refined with up to `iterations` Lloyd
/// iterations: boundaries are midpoints between adjacent centroids, and each
/// centroid is updated to the conditional mean of its cell.
///
/// # Errors
/// `InvalidBitWidth` unless `1 <= bit_width <= 8`; `InvalidDimension` when
/// `dim == 0`.
pub fn generate_codebook(dim: usize, bit_width: u8, iterations: usize) -> Result<Codebook> {
    if !(1..=8).contains(&bit_width) {
        return Err(TurboQuantError::InvalidBitWidth(bit_width));
    }
    if dim == 0 {
        return Err(TurboQuantError::InvalidDimension(dim));
    }
    let k = 1usize << bit_width;
    // Seed at the k mid-quantiles u = (i + 0.5) / k of the marginal.
    let mut centroids: Vec<f64> = (0..k)
        .map(|i| {
            let u = (i as f64 + 0.5) / k as f64;
            sample_beta_marginal(dim, u)
        })
        .collect();
    centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    for _ in 0..iterations {
        // Nearest-neighbor decision boundaries: midpoints between neighbors.
        let boundaries: Vec<f64> = centroids.windows(2).map(|w| (w[0] + w[1]) / 2.0).collect();
        // BUGFIX: was `compute_centroids(¢roids, ...)` — a mojibake of
        // `&centroids` (HTML `&cent;` garbling) that did not compile.
        let new_centroids = compute_centroids(&centroids, &boundaries, dim);
        let converged = centroids
            .iter()
            .zip(new_centroids.iter())
            .all(|(a, b)| (a - b).abs() < 1e-12);
        centroids = new_centroids;
        if converged {
            break;
        }
    }
    let boundaries: Vec<f64> = centroids.windows(2).map(|w| (w[0] + w[1]) / 2.0).collect();
    Ok(Codebook {
        centroids,
        boundaries,
        bit_width,
    })
}
/// One Lloyd update step: for each cell between consecutive boundaries, the
/// new centroid is the conditional mean of x over the cell under the
/// dim-specific beta marginal, approximated with Simpson's rule.
fn compute_centroids(old_centroids: &[f64], boundaries: &[f64], dim: usize) -> Vec<f64> {
    const N_POINTS: usize = 200;
    let k = old_centroids.len();
    (0..k)
        .map(|i| {
            // Outermost cells extend to the support limits ±1, pulled
            // slightly inward before integrating.
            let lo = (if i == 0 { -1.0_f64 } else { boundaries[i - 1] }).max(-0.9999);
            let hi = (if i == k - 1 { 1.0_f64 } else { boundaries[i] }).min(0.9999);
            if hi <= lo {
                // Degenerate (empty) cell: keep the previous centroid.
                return old_centroids[i];
            }
            // Integrate x*pdf and pdf simultaneously over the cell.
            let (num, den) = simpson_integrate(lo, hi, N_POINTS, |x| {
                let pdf = beta_pdf(x, dim);
                (x * pdf, pdf)
            });
            if den.abs() < 1e-15 {
                // Negligible probability mass: avoid dividing by ~zero.
                old_centroids[i]
            } else {
                num / den
            }
        })
        .collect()
}
/// Composite Simpson's rule for two integrands evaluated together.
///
/// `f(x)` returns a pair `(p(x), q(x))`; the result is `(∫p dx, ∫q dx)` over
/// `[a, b]`. `n` is the requested subinterval count; it is clamped to at
/// least 2 and rounded up to an even number, as Simpson's rule requires.
fn simpson_integrate<F>(a: f64, b: f64, n: usize, f: F) -> (f64, f64)
where
    F: Fn(f64) -> (f64, f64),
{
    // Guard n == 0, which previously produced h = (b - a) / 0 = inf/NaN.
    let n = n.max(2);
    let n = if n % 2 == 0 { n } else { n + 1 };
    let h = (b - a) / n as f64;
    let (f0, g0) = f(a);
    let (fn_, gn) = f(b);
    let mut sum_f = f0 + fn_;
    let mut sum_g = g0 + gn;
    for i in 1..n {
        let x = a + i as f64 * h;
        let (fi, gi) = f(x);
        // Simpson weights: 4 on odd interior nodes, 2 on even ones.
        let w = if i % 2 == 0 { 2.0 } else { 4.0 };
        sum_f += w * fi;
        sum_g += w * gi;
    }
    (sum_f * h / 3.0, sum_g * h / 3.0)
}
/// Memoizing cache of generated codebooks, keyed by `(dim, bit_width)`.
pub struct CodebookCache {
// One codebook per (dimension, bit width) configuration.
books: std::collections::HashMap<(usize, u8), Codebook>,
}
impl CodebookCache {
    /// Create an empty cache.
    pub fn new() -> Self {
        Self {
            books: std::collections::HashMap::new(),
        }
    }

    /// Return the codebook for `(dim, bit_width)`, generating it (with up to
    /// 100 Lloyd iterations) and caching it on first use.
    ///
    /// # Errors
    /// Propagates `generate_codebook` errors (invalid bit width or
    /// dimension); nothing is cached on failure.
    pub fn get_or_generate(&mut self, dim: usize, bit_width: u8) -> Result<&Codebook> {
        use std::collections::hash_map::Entry;
        // Single-lookup entry API: the previous version inserted via a
        // Vacant entry and then re-hashed the key with `self.books[&key]`.
        match self.books.entry((dim, bit_width)) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => Ok(e.insert(generate_codebook(dim, bit_width, 100)?)),
        }
    }
}
impl Default for CodebookCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
// 1-bit codebook over dim=128: two centroids symmetric about zero, with the
// single decision boundary at (approximately) zero.
#[test]
fn test_codebook_1bit() {
let cb = generate_codebook(128, 1, 50).unwrap();
assert_eq!(cb.centroids.len(), 2);
assert_eq!(cb.boundaries.len(), 1);
assert!((cb.centroids[0] + cb.centroids[1]).abs() < 1e-6);
assert!(cb.boundaries[0].abs() < 1e-6);
}
// 4-bit codebook: 16 levels, 15 boundaries, centroids strictly increasing.
#[test]
fn test_codebook_4bit() {
let cb = generate_codebook(128, 4, 100).unwrap();
assert_eq!(cb.centroids.len(), 16);
assert_eq!(cb.boundaries.len(), 15);
for w in cb.centroids.windows(2) {
assert!(w[0] < w[1]);
}
}
// Quantize-then-dequantize lands within a coarse tolerance of the input
// for values well inside the support.
#[test]
fn test_quantize_dequantize_roundtrip() {
let cb = generate_codebook(64, 3, 50).unwrap();
for &val in &[-0.5, -0.1, 0.0, 0.1, 0.5] {
let idx = cb.quantize_scalar(val);
let recon = cb.dequantize_scalar(idx);
assert!(
(val - recon).abs() < 0.5,
"val={}, recon={}, idx={}",
val,
recon,
idx
);
}
}
// bit_width must lie in 1..=8; 0 and 9 are rejected.
#[test]
fn test_invalid_bit_width() {
assert!(generate_codebook(64, 0, 50).is_err());
assert!(generate_codebook(64, 9, 50).is_err());
}
// dim == 0 is rejected.
#[test]
fn test_invalid_dimension() {
assert!(generate_codebook(0, 4, 50).is_err());
}
// The unchecked quantizer maps NaN to index 0 rather than panicking.
#[test]
fn test_quantize_nan_returns_zero() {
let cb = generate_codebook(64, 2, 50).unwrap();
let idx = cb.quantize_scalar(f64::NAN);
assert_eq!(idx, 0);
}
// All centroids stay strictly inside the open support (-1, 1).
#[test]
fn test_codebook_centroids_in_range() {
let cb = generate_codebook(32, 2, 50).unwrap();
for c in &cb.centroids {
assert!(*c > -1.0 && *c < 1.0, "centroid {} out of range", c);
}
}
// Cache returns identical centroids on a repeated (dim, bit_width) lookup.
#[test]
fn test_codebook_cache_get_or_generate() {
let mut cache = CodebookCache::new();
let cb1 = cache.get_or_generate(64, 2).unwrap();
let centroids1 = cb1.centroids.clone();
let cb2 = cache.get_or_generate(64, 2).unwrap();
assert_eq!(centroids1, cb2.centroids);
}
// Different bit widths yield codebooks of different sizes.
#[test]
fn test_codebook_cache_different_configs() {
let mut cache = CodebookCache::new();
let len1 = cache.get_or_generate(64, 2).unwrap().centroids.len();
let len2 = cache.get_or_generate(64, 4).unwrap().centroids.len();
assert_ne!(len1, len2);
}
// Default-constructed cache behaves like CodebookCache::new().
#[test]
fn test_codebook_cache_default() {
let mut cache = CodebookCache::default();
assert!(cache.get_or_generate(32, 2).is_ok());
}
// Cache propagates generation errors for invalid parameters.
#[test]
fn test_codebook_cache_invalid_params() {
let mut cache = CodebookCache::new();
assert!(cache.get_or_generate(0, 2).is_err());
assert!(cache.get_or_generate(64, 0).is_err());
}
// Large dimension: centroids remain sorted and concentrate near zero.
#[test]
fn test_codebook_large_dim_gaussian_path() {
let cb = generate_codebook(256, 4, 100).unwrap();
assert_eq!(cb.centroids.len(), 16);
for w in cb.centroids.windows(2) {
assert!(w[0] < w[1], "centroids not sorted: {} >= {}", w[0], w[1]);
}
for c in &cb.centroids {
assert!(c.abs() < 0.5, "centroid {} too far from 0 for dim=256", c);
}
}
// Each boundary lies strictly between its two adjacent centroids.
#[test]
fn test_codebook_boundary_midpoints() {
let cb = generate_codebook(64, 2, 50).unwrap();
for (i, &b) in cb.boundaries.iter().enumerate() {
assert!(
b > cb.centroids[i] && b < cb.centroids[i + 1],
"boundary {} not between centroids {} and {}",
b,
cb.centroids[i],
cb.centroids[i + 1]
);
}
}
// Maximum supported width: 8 bits gives 256 levels and 255 boundaries.
#[test]
fn test_codebook_8bit() {
let cb = generate_codebook(128, 8, 50).unwrap();
assert_eq!(cb.centroids.len(), 256);
assert_eq!(cb.boundaries.len(), 255);
}
// Values beyond all boundaries clamp to the first/last level.
#[test]
fn test_quantize_extreme_values() {
let cb = generate_codebook(64, 4, 50).unwrap();
let idx_pos = cb.quantize_scalar(100.0);
assert_eq!(idx_pos as usize, cb.centroids.len() - 1);
let idx_neg = cb.quantize_scalar(-100.0);
assert_eq!(idx_neg, 0);
}
// The checked quantizer rejects NaN and infinity with InvalidValue.
#[test]
fn test_checked_quantize_rejects_non_finite() {
let cb = generate_codebook(64, 4, 50).unwrap();
assert!(matches!(
cb.checked_quantize_scalar(f64::NAN),
Err(TurboQuantError::InvalidValue { .. })
));
assert!(matches!(
cb.checked_quantize_scalar(f64::INFINITY),
Err(TurboQuantError::InvalidValue { .. })
));
}
// The checked dequantizer rejects an index past the last level (4 levels
// for a 2-bit codebook, so index 4 is out of range).
#[test]
fn test_checked_dequantize_rejects_invalid_index() {
let cb = generate_codebook(64, 2, 50).unwrap();
assert!(matches!(
cb.checked_dequantize_scalar(4),
Err(TurboQuantError::InvalidQuantizationIndex { .. })
));
}
}