use ndarray::{Array1, Array2, Axis};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use crate::codebook::{nearest_centroid_indices_batch, optimal_centroids};
use crate::rotation::random_rotation_dense;
/// Vector quantizer that stores a vector as its Euclidean norm plus one
/// centroid index per coordinate of the normalized, rotated vector.
///
/// Instances are built with [`PolarQuant::new`]; the rotation matrix and the
/// centroid table are private and fixed for the lifetime of the value.
#[derive(Debug, Clone)]
pub struct PolarQuant {
    /// Dimensionality of the vectors this quantizer was built for.
    pub d: usize,
    /// Number of bits used per coordinate; determines `n_centroids`.
    pub bit_width: usize,
    /// Size of the codebook, computed as `1 << bit_width`.
    pub n_centroids: usize,
    /// When `true`, dequantization rescales each looked-up centroid row to
    /// unit length before rotating it back and applying the stored norm.
    pub norm_correction: bool,
    /// Matrix produced by `random_rotation_dense(d, ..)`.
    /// NOTE(review): presumably a `d x d` orthogonal matrix — confirm against
    /// `crate::rotation`.
    rotation: Array2<f64>,
    /// Codebook produced by `optimal_centroids(bit_width, d)`; indexed by the
    /// values stored in the quantized output.
    centroids: Array1<f64>,
}
impl PolarQuant {
    /// Creates a quantizer for `d`-dimensional vectors with `bit_width` bits
    /// per coordinate.
    ///
    /// The rotation matrix is generated by `random_rotation_dense` from a
    /// `ChaCha8Rng` seeded with `seed`, and the codebook comes from
    /// `optimal_centroids(bit_width, d)`. `norm_correction` is stored as-is
    /// and only affects dequantization.
    ///
    /// # Panics
    ///
    /// Panics if `bit_width >= usize::BITS`, because `2^bit_width` centroids
    /// cannot be represented. (Previously this overflowed the shift, which
    /// silently wraps in release builds.)
    pub fn new(d: usize, bit_width: usize, seed: u64, norm_correction: bool) -> Self {
        assert!(
            bit_width < usize::BITS as usize,
            "bit_width {bit_width} is too large: 2^bit_width must fit in usize"
        );
        let n_centroids = 1usize << bit_width;
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let rotation = random_rotation_dense(d, &mut rng);
        let centroids = optimal_centroids(bit_width, d);
        Self {
            d,
            bit_width,
            n_centroids,
            norm_correction,
            rotation,
            centroids,
        }
    }

    /// Quantizes a single vector.
    ///
    /// Thin wrapper over [`Self::quantize_batch`] with a batch of one row.
    /// Returns the per-coordinate centroid indices and the Euclidean norm of
    /// `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != self.d`.
    pub fn quantize(&self, x: &Array1<f64>) -> (Array1<usize>, f64) {
        // `quantize_batch` takes an owned 2-D array, so the copy is required.
        let single_row = x.clone().insert_axis(Axis(0));
        let (indices, norms) = self.quantize_batch(&single_row);
        // Move the only row out instead of copying it a second time.
        (indices.index_axis_move(Axis(0), 0), norms[0])
    }

    /// Quantizes every row of `x`.
    ///
    /// Each row is divided by its Euclidean norm (rows whose norm is not
    /// strictly positive are left unscaled), multiplied by the transpose of
    /// the stored rotation, and passed to `nearest_centroid_indices_batch`
    /// together with the codebook.
    /// NOTE(review): that helper presumably returns, per coordinate, the
    /// index of the closest centroid — confirm against `crate::codebook`.
    ///
    /// Returns `(indices, norms)` where `norms[i]` is the Euclidean norm of
    /// row `i` of `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x.ncols() != self.d`.
    pub fn quantize_batch(&self, x: &Array2<f64>) -> (Array2<usize>, Array1<f64>) {
        // Fail early with a readable message instead of an index or shape
        // panic from deep inside the normalization / matrix product.
        assert_eq!(
            x.ncols(),
            self.d,
            "quantize_batch: input has {} columns but the quantizer was built for d = {}",
            x.ncols(),
            self.d
        );

        let norms = x.map_axis(Axis(1), |row| row.dot(&row).sqrt());

        let mut x_normalized = x.clone();
        for (mut row, &norm) in x_normalized.axis_iter_mut(Axis(0)).zip(norms.iter()) {
            // Zero (or NaN) norms are skipped: dividing by a stand-in of 1.0
            // would leave the row unchanged anyway.
            if norm > 0.0 {
                row /= norm;
            }
        }

        let y = x_normalized.dot(&self.rotation.t());
        let indices = nearest_centroid_indices_batch(&y, &self.centroids);
        (indices, norms)
    }

    /// Reconstructs a single vector from its centroid indices and norm.
    ///
    /// Thin wrapper over [`Self::dequantize_batch`] with a batch of one row.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::dequantize_batch`].
    pub fn dequantize(&self, indices: &Array1<usize>, norm: f64) -> Array1<f64> {
        let indices_2d = indices.clone().insert_axis(Axis(0));
        let norms = Array1::from_elem(1, norm);
        self.dequantize_batch(&indices_2d, &norms)
            .index_axis_move(Axis(0), 0)
    }

    /// Reconstructs one vector per row of `indices`.
    ///
    /// Every index is replaced by its codebook entry. If `norm_correction`
    /// is enabled, each resulting row is rescaled to unit length (rows whose
    /// length is at most `1e-10` are left untouched). The rows are then
    /// multiplied by the stored rotation and scaled by the matching entry of
    /// `norms`.
    ///
    /// `norms` must provide at least one entry per row of `indices`; only the
    /// first `indices.nrows()` entries are used.
    ///
    /// # Panics
    ///
    /// Panics if `norms` has fewer entries than `indices` has rows, if any
    /// index is out of range for the codebook, or if the number of columns of
    /// `indices` is incompatible with the rotation matrix.
    pub fn dequantize_batch(&self, indices: &Array2<usize>, norms: &Array1<f64>) -> Array2<f64> {
        assert!(
            norms.len() >= indices.nrows(),
            "dequantize_batch: got {} norms for {} rows of indices",
            norms.len(),
            indices.nrows()
        );

        // Codebook lookup for every coordinate.
        let mut y_hat = indices.mapv(|idx| self.centroids[idx]);

        if self.norm_correction {
            for mut row in y_hat.axis_iter_mut(Axis(0)) {
                let y_norm = row.dot(&row).sqrt();
                // Guard against dividing by a (near-)zero length.
                if y_norm > 1e-10 {
                    row /= y_norm;
                }
            }
        }

        let mut x_hat = y_hat.dot(&self.rotation);
        for (mut row, &norm) in x_hat.axis_iter_mut(Axis(0)).zip(norms.iter()) {
            row *= norm;
        }
        x_hat
    }

    /// Quantizes `x` and also returns the reconstruction error.
    ///
    /// Returns `(indices, norm, residual)` where `residual` is `x` minus the
    /// result of dequantizing `(indices, norm)`.
    pub fn quantize_and_residual(
        &self,
        x: &Array1<f64>,
    ) -> (Array1<usize>, f64, Array1<f64>) {
        let (indices, norm) = self.quantize(x);
        let x_hat = self.dequantize(&indices, norm);
        let residual = x - &x_hat;
        (indices, norm, residual)
    }

    /// Batch counterpart of [`Self::quantize_and_residual`].
    ///
    /// Returns `(indices, norms, residual)` where row `i` of `residual` is
    /// row `i` of `x` minus its reconstruction.
    pub fn quantize_and_residual_batch(
        &self,
        x: &Array2<f64>,
    ) -> (Array2<usize>, Array1<f64>, Array2<f64>) {
        let (indices, norms) = self.quantize_batch(x);
        let x_hat = self.dequantize_batch(&indices, &norms);
        let residual = x - &x_hat;
        (indices, norms, residual)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic `(batch, d)` matrix from sines and cosines of
    /// the element position, with a per-row scale that cycles every five rows.
    fn synthetic_batch(batch: usize, d: usize) -> Array2<f64> {
        let dim = d as f64;
        Array2::from_shape_fn((batch, d), |(row, col)| {
            let phase = (row * 13 + col * 7) as f64;
            let scale = 0.5 + (row % 5) as f64 * 0.35;
            ((phase / dim).sin() + 0.3 * (phase / 11.0).cos()) * scale
        })
    }

    /// Quantizing and dequantizing a batch keeps the shapes and stays within
    /// a loose relative squared-error bound.
    #[test]
    fn test_quantize_dequantize_roundtrip() {
        const DIM: usize = 128;
        const ROWS: usize = 64;

        let quantizer = PolarQuant::new(DIM, 2, 42, true);
        let input = synthetic_batch(ROWS, DIM);

        let (codes, norms) = quantizer.quantize_batch(&input);
        let reconstructed = quantizer.dequantize_batch(&codes, &norms);

        assert_eq!(codes.dim(), (ROWS, DIM));
        assert_eq!(reconstructed.dim(), (ROWS, DIM));

        let squared_sum = |a: &Array2<f64>| -> f64 { a.mapv(|v| v * v).sum() };
        let difference = &input - &reconstructed;
        let relative_error = squared_sum(&difference) / squared_sum(&input);
        assert!(
            relative_error < 0.75,
            "Relative MSE too large: {relative_error}"
        );
    }

    /// An all-zero input must reconstruct to (numerically) zero.
    #[test]
    fn test_zero_vector() {
        let quantizer = PolarQuant::new(16, 2, 42, true);
        let zeros = Array1::zeros(16);

        let (codes, norm) = quantizer.quantize(&zeros);
        let reconstructed = quantizer.dequantize(&codes, norm);

        let error: f64 = reconstructed.mapv(|v| v * v).sum();
        assert!(error < 1e-20, "Zero vector reconstruction error: {error}");
    }

    /// Quantizing vectors one at a time must agree with quantizing them as
    /// rows of a single batch.
    #[test]
    fn test_batch_matches_single() {
        let quantizer = PolarQuant::new(16, 2, 42, true);
        let singles = [
            Array1::from_shape_fn(16, |i| (i as f64 + 1.0) / 16.0),
            Array1::from_shape_fn(16, |i| -(i as f64 + 1.0) / 16.0),
        ];

        let mut stacked = Array2::zeros((singles.len(), 16));
        for (mut dst, src) in stacked.axis_iter_mut(Axis(0)).zip(&singles) {
            dst.assign(src);
        }
        let (batch_idx, batch_norms) = quantizer.quantize_batch(&stacked);

        for (row, single) in singles.iter().enumerate() {
            let (idx, norm) = quantizer.quantize(single);
            assert_eq!(idx, batch_idx.row(row).to_owned());
            assert!((norm - batch_norms[row]).abs() < 1e-15);
        }
    }
}