use nalgebra::{DMatrix, DVector};
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
use crate::codebook::Codebook;
use crate::error::{Result, TurboQuantError};
use crate::scalar_quant::ScalarQuantizer;
use crate::utils::validate_finite_vector;
/// Recursive polar representation of a (preconditioned) vector, as produced by
/// `PolarQuant::to_polar` and consumed by `PolarQuant::from_polar`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolarCoords {
    /// Radius remaining after the last pairing level.
    pub radius: f64,
    /// `angles[level]` holds that level's `atan2` angles, one per coordinate
    /// pair; `from_polar` expects `dim >> (level + 1)` of them.
    pub angles: Vec<Vec<f64>>,
    /// Dimension of the original vector.
    pub dim: usize,
}
/// Quantized form of a vector produced by `PolarQuant::quantize`: one codebook
/// index for the final radius plus one index per angle at every level.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct PolarQuantized {
    /// Radius-codebook index of the final radius.
    pub radius_idx: u8,
    /// `angle_indices[level]` holds one angle-codebook index per angle of that level.
    pub angle_indices: Vec<Vec<u8>>,
    /// Bits per stored index; `PolarQuant::dequantize` requires it to match.
    pub bit_width: u8,
    /// Dimension of the original vector.
    pub dim: usize,
    /// Number of pairing levels (log2 of `dim` for quantizers built by `PolarQuant::new`).
    pub num_levels: usize,
}
impl PolarQuantized {
    /// Storage footprint in bytes: `bit_width` bits for every angle index plus
    /// one more index for the radius. The result is fractional when the total
    /// bit count is not a multiple of 8.
    pub fn bytes(&self) -> f64 {
        let index_count = 1 + self
            .angle_indices
            .iter()
            .fold(0usize, |acc, level| acc + level.len());
        let total_bits = index_count as f64 * f64::from(self.bit_width);
        total_bits / 8.0
    }

    /// Ratio of the uncompressed size to [`Self::bytes`].
    ///
    /// The uncompressed size is `dim` values at 4 bytes each (an `f32`
    /// baseline), not the 8-byte `f64` values the quantizer actually consumes.
    pub fn compression_ratio(&self) -> f64 {
        let uncompressed_bytes = 4.0 * self.dim as f64;
        uncompressed_bytes / self.bytes()
    }
}
/// Recursive polar-coordinate quantizer.
///
/// Inputs are rotated by a seeded orthogonal `preconditioner`, then converted
/// level by level into `atan2` angles of coordinate pairs (halving the count at
/// each level) until a single radius remains; angles and radius are quantized
/// with scalar codebooks.
#[derive(Debug, Serialize, Deserialize)]
pub struct PolarQuant {
    /// `dim × dim` Q factor of the QR decomposition of a seeded standard-normal
    /// matrix, applied to inputs before the polar transform.
    preconditioner: DMatrix<f64>,
    /// Angle codebooks, one per pairing level.
    codebooks: Vec<Codebook>,
    /// Codebook for the final radius.
    radius_codebook: Codebook,
    /// Input dimension (a power of two, as enforced by `new`).
    pub dim: usize,
    /// Number of pairing levels, log2(`dim`).
    pub num_levels: usize,
    /// Bits per quantized index (in `1..=8`, as enforced by `new`).
    pub bit_width: u8,
}
impl PolarQuant {
    /// Number of angles produced at recursion `level` (0-based).
    ///
    /// Level 0 pairs up the `dim` preconditioned coordinates and every later
    /// level pairs up the previous level's radii, so the count is
    /// `dim / 2^(level + 1)`.
    fn expected_angle_count(&self, level: usize) -> usize {
        self.dim >> (level + 1)
    }

    /// Angle codebook used at recursion `level`.
    ///
    /// `new` builds exactly one codebook per level. The index is clamped so an
    /// instance whose (deserialized) codebook list is shorter reuses its last
    /// codebook instead of indexing out of range.
    fn angle_codebook(&self, level: usize) -> &Codebook {
        &self.codebooks[level.min(self.codebooks.len().saturating_sub(1))]
    }

    /// Builds a codebook with `2^bit_width` centroids spread uniformly over
    /// `[lo, hi]` (one at the centre of each equal-width cell), with decision
    /// boundaries at the midpoints between neighbouring centroids.
    fn uniform_codebook(lo: f64, hi: f64, bit_width: u8) -> Codebook {
        let k = 1usize << bit_width;
        let centroids: Vec<f64> = (0..k)
            .map(|i| lo + (i as f64 + 0.5) / k as f64 * (hi - lo))
            .collect();
        let boundaries = centroids
            .windows(2)
            .map(|w| (w[0] + w[1]) / 2.0)
            .collect();
        Codebook {
            centroids,
            boundaries,
            bit_width,
        }
    }

    /// Checks that `p` matches this quantizer's layout.
    ///
    /// Requires the same `dim`, a finite non-negative radius, exactly
    /// `num_levels` angle vectors, each with `expected_angle_count(level)`
    /// entries, and all angles finite.
    ///
    /// # Errors
    /// `DimensionMismatch`, `InvalidValue` (radius), `LengthMismatch`, or the
    /// error from `validate_finite_vector` for non-finite angles.
    fn validate_polar_coords(&self, p: &PolarCoords) -> Result<()> {
        if p.dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: p.dim,
            });
        }
        if !p.radius.is_finite() || p.radius < 0.0 {
            return Err(TurboQuantError::InvalidValue {
                context: "polar radius".into(),
                value: p.radius,
            });
        }
        if p.angles.len() != self.num_levels {
            return Err(TurboQuantError::LengthMismatch {
                context: "polar angle levels".into(),
                expected: self.num_levels,
                got: p.angles.len(),
            });
        }
        for (level, angles) in p.angles.iter().enumerate() {
            let expected = self.expected_angle_count(level);
            if angles.len() != expected {
                return Err(TurboQuantError::LengthMismatch {
                    context: format!("polar angles at level {level}"),
                    expected,
                    got: angles.len(),
                });
            }
            validate_finite_vector(angles, "polar angle")?;
        }
        Ok(())
    }

    /// Checks that `q` was produced by a quantizer with this layout.
    ///
    /// `dim`, `bit_width` and `num_levels` must match, every level must carry
    /// the expected number of indices, and every index must be accepted by the
    /// corresponding codebook's `ScalarQuantizer::validate_indices`.
    ///
    /// # Errors
    /// `DimensionMismatch`, `BitWidthMismatch`, `LengthMismatch`, or the error
    /// returned by `validate_indices`.
    fn validate_quantized(&self, q: &PolarQuantized) -> Result<()> {
        if q.dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: q.dim,
            });
        }
        if q.bit_width != self.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: self.bit_width,
                got: q.bit_width,
            });
        }
        if q.num_levels != self.num_levels {
            return Err(TurboQuantError::LengthMismatch {
                context: "polar quantization level count".into(),
                expected: self.num_levels,
                got: q.num_levels,
            });
        }
        if q.angle_indices.len() != self.num_levels {
            return Err(TurboQuantError::LengthMismatch {
                context: "polar angle index levels".into(),
                expected: self.num_levels,
                got: q.angle_indices.len(),
            });
        }
        ScalarQuantizer::from_codebook(self.radius_codebook.clone())
            .validate_indices(&[q.radius_idx])?;
        for (level, indices) in q.angle_indices.iter().enumerate() {
            let expected = self.expected_angle_count(level);
            if indices.len() != expected {
                return Err(TurboQuantError::LengthMismatch {
                    context: format!("polar angle indices at level {level}"),
                    expected,
                    got: indices.len(),
                });
            }
            ScalarQuantizer::from_codebook(self.angle_codebook(level).clone())
                .validate_indices(indices)?;
        }
        Ok(())
    }

    /// Creates a quantizer for `dim`-dimensional vectors.
    ///
    /// The preconditioner is the orthogonal Q factor of the QR decomposition of
    /// a `dim × dim` standard-normal matrix drawn from `seed`. One angle
    /// codebook is built per level, plus a radius codebook; each has
    /// `2^bit_width` uniformly spaced centroids.
    ///
    /// # Errors
    /// - `InvalidDimension` if `dim` is not a power of two or is smaller than 2.
    /// - `InvalidBitWidth` if `bit_width` is outside `1..=8`.
    pub fn new(dim: usize, seed: u64, bit_width: u8) -> Result<Self> {
        // The recursion needs at least one coordinate pair. With dim == 1 there
        // are no angles, and the lone (possibly negative) coordinate would
        // become the radius, which `from_polar` rejects.
        if dim < 2 || !dim.is_power_of_two() {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        if !(1..=8).contains(&bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }
        // Exact log2 for a power of two, without a float round-trip.
        let num_levels = dim.trailing_zeros() as usize;

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).expect("mean 0 / std-dev 1 are valid Normal parameters");
        let data: Vec<f64> = (0..dim * dim).map(|_| normal.sample(&mut rng)).collect();
        let preconditioner = DMatrix::from_vec(dim, dim, data).qr().q();

        let codebooks = (0..num_levels)
            .map(|level| {
                // Level 0 pairs signed coordinates, so atan2 covers [-π, π].
                // Deeper levels pair two non-negative radii, so their angles
                // lie in [0, π/2]; a full-circle codebook would leave 3/4 of
                // its codes unused there.
                if level == 0 {
                    Self::uniform_codebook(
                        -std::f64::consts::PI,
                        std::f64::consts::PI,
                        bit_width,
                    )
                } else {
                    Self::uniform_codebook(0.0, std::f64::consts::FRAC_PI_2, bit_width)
                }
            })
            .collect();

        // NOTE(review): the final radius equals the Euclidean norm of the input
        // (the preconditioner is orthogonal), yet the radius centroids cover
        // only [0, 2]. Inputs are presumably unit-normalized upstream; confirm,
        // since larger norms saturate.
        let radius_codebook = Self::uniform_codebook(0.0, 2.0, bit_width);

        Ok(Self {
            preconditioner,
            codebooks,
            radius_codebook,
            dim,
            num_levels,
            bit_width,
        })
    }

    /// Applies the preconditioner: returns `Q · x`.
    fn precondition(&self, x: &[f64]) -> Vec<f64> {
        let rotated = &self.preconditioner * DVector::from_vec(x.to_vec());
        rotated.data.into()
    }

    /// Undoes [`Self::precondition`]. `Q` is orthogonal, so its inverse is `Qᵀ`.
    /// `tr_mul` computes `Qᵀ · x` without materializing a `dim × dim` transpose.
    fn precondition_inverse(&self, x: &[f64]) -> Vec<f64> {
        let restored = self.preconditioner.tr_mul(&DVector::from_vec(x.to_vec()));
        restored.data.into()
    }

    /// Converts `x` to recursive polar coordinates.
    ///
    /// `x` is preconditioned. Then, at each level, consecutive pairs `(a, b)`
    /// become a radius `hypot(a, b)` and an angle `atan2(b, a)`. The radii feed
    /// the next level until a single radius remains.
    ///
    /// # Errors
    /// `DimensionMismatch` if `x.len() != dim`, or the error from
    /// `validate_finite_vector` if `x` contains non-finite values.
    pub fn to_polar(&self, x: &[f64]) -> Result<PolarCoords> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_finite_vector(x, "PolarQuant input")?;

        let mut current = self.precondition(x);
        let mut all_angles: Vec<Vec<f64>> = Vec::with_capacity(self.num_levels);
        for _ in 0..self.num_levels {
            // `hypot` avoids the intermediate overflow of `a * a + b * b` for
            // large finite inputs.
            let (radii, level_angles): (Vec<f64>, Vec<f64>) = current
                .chunks_exact(2)
                .map(|pair| (pair[0].hypot(pair[1]), pair[1].atan2(pair[0])))
                .unzip();
            all_angles.push(level_angles);
            current = radii;
        }

        Ok(PolarCoords {
            radius: current[0],
            angles: all_angles,
            dim: self.dim,
        })
    }

    /// Reconstructs a vector from polar coordinates (inverse of [`Self::to_polar`]).
    ///
    /// # Errors
    /// Any error from `validate_polar_coords` when `p` does not match this
    /// quantizer's layout.
    pub fn from_polar(&self, p: &PolarCoords) -> Result<Vec<f64>> {
        self.validate_polar_coords(p)?;
        let mut current = vec![p.radius];
        // Undo the levels from the top: each radius splits back into its pair.
        for level_angles in p.angles.iter().rev() {
            let mut expanded = Vec::with_capacity(level_angles.len() * 2);
            for (&r, &theta) in current.iter().zip(level_angles) {
                expanded.push(r * theta.cos());
                expanded.push(r * theta.sin());
            }
            current = expanded;
        }
        Ok(self.precondition_inverse(&current))
    }

    /// Quantizes `x`: converts it with [`Self::to_polar`], then maps the radius
    /// and every angle to their codebook indices.
    ///
    /// # Errors
    /// Any error from [`Self::to_polar`].
    pub fn quantize(&self, x: &[f64]) -> Result<PolarQuantized> {
        let polar = self.to_polar(x)?;
        let radius_idx = self.radius_codebook.quantize_scalar(polar.radius);
        let angle_indices: Vec<Vec<u8>> = polar
            .angles
            .iter()
            .enumerate()
            .map(|(level, angles)| {
                let cb = self.angle_codebook(level);
                angles.iter().map(|&a| cb.quantize_scalar(a)).collect()
            })
            .collect();
        Ok(PolarQuantized {
            radius_idx,
            angle_indices,
            bit_width: self.bit_width,
            dim: self.dim,
            num_levels: self.num_levels,
        })
    }

    /// Reconstructs an approximation of the original vector from `q`.
    ///
    /// # Errors
    /// Any error from `validate_quantized` or [`Self::from_polar`].
    pub fn dequantize(&self, q: &PolarQuantized) -> Result<Vec<f64>> {
        self.validate_quantized(q)?;
        let radius = self.radius_codebook.dequantize_scalar(q.radius_idx);
        let angles: Vec<Vec<f64>> = q
            .angle_indices
            .iter()
            .enumerate()
            .map(|(level, indices)| {
                let cb = self.angle_codebook(level);
                indices.iter().map(|&i| cb.dequantize_scalar(i)).collect()
            })
            .collect();
        self.from_polar(&PolarCoords {
            radius,
            angles,
            dim: self.dim,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic vector of `dim` standard-normal samples drawn from `seed`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f64> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let normal = Normal::new(0.0, 1.0).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        std::iter::repeat_with(|| normal.sample(&mut rng))
            .take(dim)
            .collect()
    }

    #[test]
    fn test_polar_roundtrip() {
        const DIM: usize = 8;
        let pq = PolarQuant::new(DIM, 42, 4).unwrap();
        let x = random_vector(DIM, 1);
        let back = pq.from_polar(&pq.to_polar(&x).unwrap()).unwrap();
        x.iter().zip(&back).for_each(|(a, b)| {
            assert!((a - b).abs() < 1e-8, "a={}, b={}", a, b);
        });
    }

    #[test]
    fn test_quantize_dequantize_shape() {
        const DIM: usize = 16;
        let pq = PolarQuant::new(DIM, 7, 4).unwrap();
        let codes = pq.quantize(&random_vector(DIM, 5)).unwrap();
        let reconstructed = pq.dequantize(&codes).unwrap();
        assert_eq!(reconstructed.len(), DIM);
    }

    #[test]
    fn test_polar_angles_count() {
        const DIM: usize = 8;
        let pq = PolarQuant::new(DIM, 1, 2).unwrap();
        let polar = pq.to_polar(&random_vector(DIM, 2)).unwrap();
        let total_angles = polar.angles.iter().fold(0, |n, level| n + level.len());
        let expected = DIM - 1;
        assert_eq!(
            total_angles,
            expected,
            "expected {} angles, got {}",
            expected,
            total_angles
        );
    }

    #[test]
    fn test_invalid_dim_not_power_of_2() {
        let result = PolarQuant::new(10, 1, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_dim_zero() {
        let result = PolarQuant::new(0, 1, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_bit_width() {
        for bits in [0u8, 9] {
            assert!(PolarQuant::new(8, 1, bits).is_err());
        }
    }

    #[test]
    fn test_dequantize_dimension_mismatch() {
        let pq = PolarQuant::new(8, 42, 4).unwrap();
        // Claims dim 16 / four levels, but the quantizer was built for dim 8.
        let bad_q = PolarQuantized {
            radius_idx: 0,
            angle_indices: [8usize, 4, 2].iter().map(|&n| vec![0; n]).collect(),
            bit_width: 4,
            dim: 16,
            num_levels: 4,
        };
        assert!(pq.dequantize(&bad_q).is_err());
    }

    #[test]
    fn test_to_polar_dimension_mismatch() {
        let pq = PolarQuant::new(8, 42, 4).unwrap();
        let too_long = vec![1.0; 16];
        assert!(pq.to_polar(&too_long).is_err());
    }

    #[test]
    fn test_polar_compression_ratio() {
        const DIM: usize = 16;
        let pq = PolarQuant::new(DIM, 42, 4).unwrap();
        let ratio = pq
            .quantize(&random_vector(DIM, 1))
            .unwrap()
            .compression_ratio();
        assert!(ratio > 1.0, "compression ratio should be > 1: {}", ratio);
    }

    #[test]
    fn test_from_polar_rejects_missing_levels() {
        let pq = PolarQuant::new(8, 42, 4).unwrap();
        // dim 8 needs three levels of angles; only two are supplied.
        let coords = PolarCoords {
            radius: 1.0,
            angles: vec![vec![0.0; 4], vec![0.0; 2]],
            dim: 8,
        };
        let outcome = pq.from_polar(&coords);
        assert!(matches!(
            outcome,
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }
}