use ndarray::{Array1, Array2};
use statrs::distribution::{Continuous, ContinuousCDF, Normal};
/// Returns the `2^bit_width` scalar quantization centroids, scaled by `d`.
///
/// * `bit_width` — number of bits per quantized value.
/// * `d` — scales every centroid by `1 / sqrt(d)` (presumably the vector
///   dimension — confirm against callers).
///
/// Codebook selection:
/// * 1 bit: the closed-form pair `±sqrt(2 / (π·d))`.
/// * 2 bits: the tabulated levels `±0.453 / sqrt(d)` and `±1.51 / sqrt(d)`.
/// * anything else: 100 rounds of [`lloyds_gaussian`] with standard
///   deviation `1 / sqrt(d)`.
pub fn optimal_centroids(bit_width: usize, d: usize) -> Array1<f64> {
    let levels = 1usize << bit_width;
    let dim = d as f64;
    match bit_width {
        1 => {
            let magnitude = (2.0 / (std::f64::consts::PI * dim)).sqrt();
            Array1::from_vec(vec![-magnitude, magnitude])
        }
        2 => {
            let scale = 1.0 / dim.sqrt();
            Array1::from_vec(vec![
                -1.51 * scale,
                -0.453 * scale,
                0.453 * scale,
                1.51 * scale,
            ])
        }
        _ => lloyds_gaussian(levels, 1.0 / dim.sqrt(), 100),
    }
}
/// Recomputes every centroid as the conditional mean of its cell.
///
/// Cell `i` spans from `boundaries[i - 1]` (or `-∞` for the first cell) to
/// `boundaries[i]` (or `+∞` for the last cell), so `centroids` is expected to
/// hold exactly `boundaries.len() + 1` entries.
fn recompute_centroids(sigma: f64, boundaries: &[f64], centroids: &mut [f64]) {
    let lower_edges = std::iter::once(f64::NEG_INFINITY).chain(boundaries.iter().copied());
    let upper_edges = boundaries
        .iter()
        .copied()
        .chain(std::iter::once(f64::INFINITY));
    for (centroid, (lo, hi)) in centroids.iter_mut().zip(lower_edges.zip(upper_edges)) {
        *centroid = gaussian_conditional_expectation(sigma, lo, hi);
    }
}

/// Runs Lloyd's algorithm for a zero-mean Gaussian with standard deviation
/// `sigma` and returns `n_centroids` centroids sorted in ascending order.
///
/// The cell boundaries start at the equal-probability quantiles of the
/// distribution. Each of the `n_iter` rounds then moves every boundary to the
/// midpoint of its two neighbouring centroids and resets every centroid to
/// the conditional mean of its cell.
///
/// Degenerate sizes are handled without touching the distribution:
/// `n_centroids == 0` yields an empty array and `n_centroids == 1` yields
/// `[0.0]`, the mean of the whole distribution.
///
/// # Panics
///
/// Panics if `statrs` rejects `sigma` as a standard deviation, or if a
/// centroid evaluates to NaN (the final sort cannot order it).
fn lloyds_gaussian(n_centroids: usize, sigma: f64, n_iter: usize) -> Array1<f64> {
    // With fewer than two centroids there are no interior boundaries; the
    // index arithmetic below (`boundaries[0]`, `n_centroids - 1`) would
    // otherwise panic or underflow.
    match n_centroids {
        0 => return Array1::from_vec(Vec::new()),
        1 => return Array1::from_vec(vec![0.0]),
        _ => {}
    }

    let normal = Normal::new(0.0, sigma)
        .expect("lloyds_gaussian: sigma is not a valid standard deviation");

    // Initial boundaries: the i/n quantiles, i = 1..n-1.
    let mut boundaries: Vec<f64> = (1..n_centroids)
        .map(|i| normal.inverse_cdf(i as f64 / n_centroids as f64))
        .collect();

    let mut centroids = vec![0.0f64; n_centroids];
    recompute_centroids(sigma, &boundaries, &mut centroids);

    for _ in 0..n_iter {
        // All boundaries are refreshed from the previous centroids before any
        // centroid is updated.
        for (boundary, pair) in boundaries.iter_mut().zip(centroids.windows(2)) {
            *boundary = (pair[0] + pair[1]) / 2.0;
        }
        recompute_centroids(sigma, &boundaries, &mut centroids);
    }

    centroids.sort_by(|a, b| {
        a.partial_cmp(b)
            .expect("lloyds_gaussian: centroid evaluated to NaN")
    });
    Array1::from_vec(centroids)
}
/// Conditional mean `E[X | a < X < b]` for `X ~ N(0, sigma²)`.
///
/// `a` may be `-∞` and `b` may be `+∞` to describe half-open tails. With
/// `α = a / sigma`, `β = b / sigma`, standard normal density `φ` and CDF `Φ`,
/// the result is `sigma · (φ(α) − φ(β)) / (Φ(β) − Φ(α))`.
///
/// When the interval's probability mass is below `1e-15` the ratio is not
/// evaluated; a stand-in is returned instead: `a + sigma` for an upper tail,
/// `b - sigma` for a lower tail, the midpoint for a finite interval, and
/// `0.0` when neither bound is finite.
fn gaussian_conditional_expectation(sigma: f64, a: f64, b: f64) -> f64 {
    let std_normal = Normal::new(0.0, 1.0).unwrap();
    // Infinite bounds are passed through unchanged rather than divided.
    let a_std = if a.is_finite() { a / sigma } else { a };
    let b_std = if b.is_finite() { b / sigma } else { b };

    let prob = if a_std == f64::NEG_INFINITY {
        std_normal.cdf(b_std)
    } else if b_std == f64::INFINITY {
        // Upper tail mass via symmetry: P(X > α) = Φ(−α).
        std_normal.cdf(-a_std)
    } else if a_std > 0.0 {
        // Interval lies entirely in the upper tail. Φ(β) − Φ(α) would subtract
        // two values close to 1 and cancel most significant digits, so the
        // equal mass of the mirrored interval (−β, −α) is used instead.
        std_normal.cdf(-a_std) - std_normal.cdf(-b_std)
    } else {
        std_normal.cdf(b_std) - std_normal.cdf(a_std)
    };

    if prob < 1e-15 {
        // Too little mass for a meaningful ratio; fall back to a stand-in
        // chosen by which bounds are finite.
        return match (a.is_finite(), b.is_finite()) {
            (true, false) => a + sigma,
            (false, true) => b - sigma,
            (true, true) => (a + b) / 2.0,
            (false, false) => 0.0,
        };
    }

    let pdf_diff = std_normal.pdf(a_std) - std_normal.pdf(b_std);
    sigma * pdf_diff / prob
}
/// Assigns each entry of `values` to a centroid index.
///
/// The index is the number of midpoints between adjacent centroids that lie
/// strictly below the value, found by binary search. This equals the nearest
/// centroid when `centroids` is sorted in ascending order (assumed, not
/// checked). A value exactly on a midpoint is assigned to the lower centroid.
pub fn nearest_centroid_indices(values: &Array1<f64>, centroids: &Array1<f64>) -> Array1<usize> {
    let cuts = compute_boundaries(centroids);
    values
        .iter()
        .map(|&v| cuts.partition_point(|&cut| cut < v))
        .collect()
}
/// Two-dimensional counterpart of [`nearest_centroid_indices`].
///
/// Returns an array with the same dimensions as `values` in which each entry
/// is the number of centroid midpoints strictly below the corresponding
/// input value. `centroids` is assumed to be sorted in ascending order.
pub fn nearest_centroid_indices_batch(
    values: &Array2<f64>,
    centroids: &Array1<f64>,
) -> Array2<usize> {
    let cuts = compute_boundaries(centroids);
    // Binary search for the first midpoint that is not below `v`.
    let cell_of = |v: f64| cuts.partition_point(|&cut| cut < v);
    Array2::from_shape_fn(values.dim(), |idx| cell_of(values[idx]))
}
/// Returns the midpoints between each pair of adjacent centroids.
///
/// The result has `centroids.len() - 1` entries; a single centroid yields an
/// empty vector.
///
/// # Panics
///
/// Panics if `centroids` is empty. Stating the precondition explicitly
/// replaces the `len() - 1` underflow, which behaved differently in debug and
/// release builds.
fn compute_boundaries(centroids: &Array1<f64>) -> Vec<f64> {
    assert!(
        !centroids.is_empty(),
        "compute_boundaries requires at least one centroid"
    );
    // Pair each centroid with its successor; no index arithmetic needed.
    centroids
        .iter()
        .zip(centroids.iter().skip(1))
        .map(|(lo, hi)| (lo + hi) / 2.0)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The two 1-bit centroids must be exact negatives of each other.
    #[test]
    fn test_1bit_centroids_symmetric() {
        let c = optimal_centroids(1, 128);
        assert_eq!(c.len(), 2);
        assert!((c[0] + c[1]).abs() < 1e-15, "1-bit centroids should be symmetric");
    }

    /// The 2-bit codebook has four strictly increasing centroids.
    #[test]
    fn test_2bit_centroids_count() {
        let c = optimal_centroids(2, 128);
        assert_eq!(c.len(), 4);
        for i in 1..4 {
            assert!(c[i] > c[i - 1], "Centroids should be sorted");
        }
    }

    /// The 3-bit codebook comes from Lloyd's algorithm: eight strictly
    /// increasing centroids, symmetric about zero within 1e-6.
    #[test]
    fn test_3bit_centroids_via_lloyds() {
        let c = optimal_centroids(3, 128);
        assert_eq!(c.len(), 8);
        for i in 1..8 {
            assert!(c[i] > c[i - 1], "Centroids should be sorted");
        }
        for i in 0..4 {
            assert!(
                (c[i] + c[7 - i]).abs() < 1e-6,
                "Centroids should be approximately symmetric: {} vs {}",
                c[i],
                c[7 - i]
            );
        }
    }

    /// Each value maps to the index of its closest centroid.
    #[test]
    fn test_nearest_centroid_basic() {
        let centroids = Array1::from_vec(vec![-1.0, 0.0, 1.0]);
        let values = Array1::from_vec(vec![-0.8, 0.1, 0.7, -0.1]);
        let indices = nearest_centroid_indices(&values, &centroids);
        assert_eq!(indices[0], 0);
        assert_eq!(indices[1], 1);
        assert_eq!(indices[2], 2);
        assert_eq!(indices[3], 1);
    }
}