use super::{centroid_count, Codebook, SUPPORT_MAX, SUPPORT_MIN};
use crate::math::{converge, ln_gamma, simpsons_integrate, HALF};
// Passed as the first argument to `converge` in `lloyd_max_iterate`
// (presumably the cap on refinement iterations; confirm against `converge`).
const MAX_ITERATIONS: usize = 200;
// Relative tolerance on the change in distortion, used by `has_converged`.
const CONVERGENCE_EPS: f64 = 1e-12;
// Step count handed to `simpsons_integrate` by `integrate`.
const INTEGRATION_STEPS: usize = 1024;
// Magnitude treated as "effectively zero": the floor for the convergence scale
// in `has_converged` and the denominator threshold in
// `select_conditional_or_midpoint`.
const EPSILON_ZERO: f64 = 1e-30;
// `beta_pdf` returns 0.0 for any dimension strictly below this value.
const MIN_DIMENSION_FOR_PDF: usize = 3;
// Subtracted from the dimension before scaling by `HALF` to form the kernel
// exponent in `beta_pdf`.
const KERNEL_EXPONENT_OFFSET: f64 = 3.0;
/// Evaluates the density used for codebook construction at `x` for dimension `d`.
///
/// The value is `exp(ln_gamma(d/2) - ln(pi)/2 - ln_gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)`,
/// assuming `HALF` is 0.5 and `ln_gamma` is the log-gamma function (both come
/// from `crate::math`; confirm there).
///
/// Returns `0.0` when `d < MIN_DIMENSION_FOR_PDF` or when `1 - x^2 <= 0`
/// (i.e. `|x| >= 1`).
pub fn beta_pdf(x: f64, d: usize) -> f64 {
    let dim = d as f64;
    let base = 1.0 - x * x;

    // Outside the open interval (-1, 1), or for too-small dimensions, the
    // density is defined to be zero here.
    if d < MIN_DIMENSION_FOR_PDF || base <= 0.0 {
        return 0.0;
    }

    // The normalisation constant is assembled in log space and exponentiated
    // once at the end.
    let log_norm = ln_gamma(dim * HALF)
        - HALF * core::f64::consts::PI.ln()
        - ln_gamma((dim - 1.0) * HALF);
    let kernel = base.powf((dim - KERNEL_EXPONENT_OFFSET) * HALF);

    log_norm.exp() * kernel
}
/// Produces `k` evenly spaced starting centroids between `SUPPORT_MIN` and
/// `SUPPORT_MAX`.
///
/// Centroid `i` is placed at `SUPPORT_MIN + range * (i + HALF) / k`, i.e. at
/// the centre of the `i`-th of `k` equal-width cells if `HALF` is 0.5.
/// Returns an empty vector when `k == 0`.
fn initialize_centroids(k: usize) -> Vec<f64> {
    let span = SUPPORT_MAX - SUPPORT_MIN;
    let count = k as f64;

    let mut seeds = Vec::with_capacity(k);
    for slot in 0..k {
        let offset = span * (slot as f64 + HALF);
        seeds.push(SUPPORT_MIN + offset / count);
    }
    seeds
}
/// Returns, for each pair of adjacent centroids, their sum scaled by `HALF`
/// (the midpoint if `HALF` is 0.5).
///
/// The result has `centroids.len() - 1` entries, or none when fewer than two
/// centroids are given.
fn midpoint_boundaries(centroids: &[f64]) -> Vec<f64> {
    let right_neighbours = centroids.iter().skip(1);
    centroids
        .iter()
        .zip(right_neighbours)
        .map(|(left, right)| (left + right) * HALF)
        .collect()
}
/// Lower edge of bin `i`: `SUPPORT_MIN` for the first bin, otherwise the
/// boundary at index `i - 1`.
///
/// # Panics
/// Panics if `i > 0` and `i - 1` is out of range for `boundaries`.
fn bin_lower_bound(i: usize, boundaries: &[f64]) -> f64 {
    match i {
        0 => SUPPORT_MIN,
        _ => boundaries[i - 1],
    }
}
/// Upper edge of bin `i` out of `k` bins: `SUPPORT_MAX` for the last bin
/// (`i == k - 1`), otherwise the boundary at index `i`.
///
/// # Panics
/// Panics if `i` is not the last bin and is out of range for `boundaries`.
/// `k - 1` is evaluated unconditionally, so `k == 0` overflows (a panic in
/// builds with overflow checks enabled).
fn bin_upper_bound(i: usize, k: usize, boundaries: &[f64]) -> f64 {
    let last_bin = k - 1;
    if i != last_bin {
        return boundaries[i];
    }
    SUPPORT_MAX
}
/// Reports whether the distortion changed by less than a relative tolerance.
///
/// The change `|prev_distortion - distortion|` is compared (strictly) against
/// `CONVERGENCE_EPS * max(|prev_distortion|, EPSILON_ZERO)`; the floor keeps
/// the tolerance positive when the previous distortion is zero.
fn has_converged(prev_distortion: f64, distortion: f64) -> bool {
    let scale = prev_distortion.abs().max(EPSILON_ZERO);
    let tolerance = CONVERGENCE_EPS * scale;
    let change = (prev_distortion - distortion).abs();
    change < tolerance
}
/// Returns `numerator / denominator`, unless the denominator's magnitude is
/// below `EPSILON_ZERO`, in which case `(a + b) * HALF` is returned instead
/// (the midpoint of `[a, b]` if `HALF` is 0.5).
fn select_conditional_or_midpoint(numerator: f64, denominator: f64, a: f64, b: f64) -> f64 {
    // A vanishing denominator would make the ratio meaningless; fall back to
    // a value derived from the interval endpoints.
    if denominator.abs() < EPSILON_ZERO {
        return (a + b) * HALF;
    }
    numerator / denominator
}
/// Integrates `f` over `[a, b]` by delegating to `simpsons_integrate` with
/// `INTEGRATION_STEPS` steps.
fn integrate<F>(f: F, a: f64, b: f64) -> f64
where
    F: Fn(f64) -> f64,
{
    let steps = INTEGRATION_STEPS;
    simpsons_integrate(f, a, b, steps)
}
/// Integrates `beta_pdf(·, d)` over `[a, b]` via `integrate`.
fn integrate_pdf(a: f64, b: f64, d: usize) -> f64 {
    let density = |x: f64| beta_pdf(x, d);
    integrate(density, a, b)
}
/// Integrates `x * beta_pdf(x, d)` over `[a, b]` via `integrate`.
fn integrate_x_pdf(a: f64, b: f64, d: usize) -> f64 {
    let weighted_density = |x: f64| x * beta_pdf(x, d);
    integrate(weighted_density, a, b)
}
/// Computes the ratio of `integrate_x_pdf` to `integrate_pdf` over `[a, b]`
/// for dimension `d`, with the fallback for a near-zero denominator handled
/// by `select_conditional_or_midpoint`.
fn conditional_expectation(a: f64, b: f64, d: usize) -> f64 {
    let (mass, first_moment) = (integrate_pdf(a, b, d), integrate_x_pdf(a, b, d));
    select_conditional_or_midpoint(first_moment, mass, a, b)
}
fn bin_distortion(lo: f64, hi: f64, c: f64, d: usize) -> f64 {
integrate(|x| (x - c).powi(2) * beta_pdf(x, d), lo, hi)
}
/// Sums `bin_distortion` over every bin.
///
/// Bin `i` spans from `bin_lower_bound(i, boundaries)` to
/// `bin_upper_bound(i, k, boundaries)` and is represented by `centroids[i]`,
/// where `k = centroids.len()`. An empty `centroids` slice yields `0.0`.
///
/// # Panics
/// Panics (via the bound helpers) if `boundaries` holds fewer than `k - 1`
/// entries.
fn compute_distortion(centroids: &[f64], boundaries: &[f64], d: usize) -> f64 {
    let k = centroids.len();
    centroids
        .iter()
        .enumerate()
        // Destructure the reference so `centroid` is a plain `f64`.
        .map(|(i, &centroid)| {
            let lo = bin_lower_bound(i, boundaries);
            let hi = bin_upper_bound(i, k, boundaries);
            bin_distortion(lo, hi, centroid, d)
        })
        .sum()
}
/// Recomputes one centroid per bin as `conditional_expectation` over that
/// bin's `[lower, upper]` edges.
///
/// Returns a vector of length `centroids_len`.
///
/// # Panics
/// Panics (via the bound helpers) if `boundaries` holds fewer than
/// `centroids_len - 1` entries.
fn update_centroids(centroids_len: usize, boundaries: &[f64], d: usize) -> Vec<f64> {
    let mut updated = Vec::with_capacity(centroids_len);
    for bin in 0..centroids_len {
        let lower = bin_lower_bound(bin, boundaries);
        let upper = bin_upper_bound(bin, centroids_len, boundaries);
        updated.push(conditional_expectation(lower, upper, d));
    }
    updated
}
/// Performs one refinement step.
///
/// Boundaries are taken from `midpoint_boundaries(centroids)`, centroids are
/// recomputed for those boundaries, and the distortion of the new centroids
/// against the same boundaries is measured.
///
/// Returns `(new_centroids, distortion, converged)`, where `converged` is
/// `has_converged(prev_distortion, distortion)`.
fn lloyd_max_step(centroids: &[f64], prev_distortion: f64, d: usize) -> (Vec<f64>, f64, bool) {
    let edges = midpoint_boundaries(centroids);
    let refined = update_centroids(centroids.len(), &edges, d);
    let cost = compute_distortion(&refined, &edges, d);
    (refined, cost, has_converged(prev_distortion, cost))
}
/// Refines `centroids` by repeatedly applying `lloyd_max_step`, then packages
/// the result as a `Codebook`.
///
/// The step closure is driven by `converge` with `MAX_ITERATIONS`; it returns
/// the step's `converged` flag (presumably `converge` stops on `true` or when
/// the iteration cap is hit; confirm against `crate::math::converge`).
/// The previous distortion starts at `f64::MAX` so the first step is never
/// judged converged against a real value.
///
/// The returned boundaries are the midpoints of the final centroids.
fn lloyd_max_iterate(mut centroids: Vec<f64>, d: usize) -> Codebook {
    let mut prev_distortion = f64::MAX;
    converge(MAX_ITERATIONS, || {
        let (new_centroids, distortion, converged) =
            lloyd_max_step(&centroids, prev_distortion, d);
        centroids = new_centroids;
        prev_distortion = distortion;
        converged
    });
    // Recompute boundaries from the final centroids so the two stay consistent.
    let boundaries = midpoint_boundaries(&centroids);
    Codebook {
        centroids,
        boundaries,
    }
}
/// Builds a codebook for vectors of dimension `dim`.
///
/// The number of centroids is `centroid_count(bits)`; they are seeded by
/// `initialize_centroids` and refined by `lloyd_max_iterate`.
pub fn generate_codebook(bits: u8, dim: usize) -> Codebook {
    let seeds = initialize_centroids(centroid_count(bits));
    lloyd_max_iterate(seeds, dim)
}
#[cfg(test)]
mod tests {
    // Unit tests for the codebook-construction helpers defined in the parent
    // module. Floating-point comparisons go through `approx`.
    use super::*;
    use approx::assert_relative_eq;

    // Dimensions and sizes shared across tests.
    const TEST_DIM: usize = 128;
    const TEST_CENTROIDS_8: usize = 8;
    const TEST_CENTROIDS_16: usize = 16;
    const TEST_BITS_3: u8 = 3;
    const TEST_DIM_64: usize = 64;
    // Inputs for the ratio-or-midpoint selector.
    const TEST_NUMERATOR: f64 = 3.0;
    const TEST_DENOMINATOR: f64 = 2.0;
    // Smaller in magnitude than `EPSILON_ZERO`, so it triggers the fallback.
    const TEST_NEAR_ZERO_DENOM: f64 = 1e-31;

    #[test]
    fn initialize_centroids_correct_count() {
        assert_eq!(
            initialize_centroids(TEST_CENTROIDS_8).len(),
            TEST_CENTROIDS_8
        );
        assert_eq!(
            initialize_centroids(TEST_CENTROIDS_16).len(),
            TEST_CENTROIDS_16
        );
    }

    #[test]
    fn initialize_centroids_sorted() {
        let c = initialize_centroids(TEST_CENTROIDS_8);
        for w in c.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    #[test]
    fn initialize_centroids_symmetric() {
        let c = initialize_centroids(TEST_CENTROIDS_8);
        let half = TEST_CENTROIDS_8 / 2;
        for i in 0..half {
            assert_relative_eq!(c[i], -c[TEST_CENTROIDS_8 - 1 - i], epsilon = 1e-14);
        }
    }

    #[test]
    fn initialize_centroids_within_support() {
        let c = initialize_centroids(TEST_CENTROIDS_16);
        for &v in &c {
            assert!(v > SUPPORT_MIN && v < SUPPORT_MAX);
        }
    }

    #[test]
    fn midpoint_boundaries_correct_values() {
        let centroids = vec![-0.5, 0.0, 0.5];
        let b = midpoint_boundaries(&centroids);
        assert_eq!(b.len(), 2);
        assert_relative_eq!(b[0], -0.25, epsilon = 1e-14);
        assert_relative_eq!(b[1], 0.25, epsilon = 1e-14);
    }

    #[test]
    fn bin_lower_bound_first() {
        let boundaries = vec![0.0];
        assert_relative_eq!(
            bin_lower_bound(0, &boundaries),
            SUPPORT_MIN,
            epsilon = 1e-15
        );
    }

    #[test]
    fn bin_lower_bound_second() {
        let boundaries = vec![0.0];
        assert_relative_eq!(bin_lower_bound(1, &boundaries), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn bin_upper_bound_last() {
        let boundaries = vec![0.0];
        assert_relative_eq!(
            bin_upper_bound(1, 2, &boundaries),
            SUPPORT_MAX,
            epsilon = 1e-15
        );
    }

    #[test]
    fn bin_upper_bound_first() {
        let boundaries = vec![0.0];
        assert_relative_eq!(bin_upper_bound(0, 2, &boundaries), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn has_converged_identical_values() {
        assert!(has_converged(1.0, 1.0));
    }

    #[test]
    fn has_converged_large_change() {
        assert!(!has_converged(1.0, 0.5));
    }

    #[test]
    fn select_conditional_or_midpoint_normal_case() {
        let result = select_conditional_or_midpoint(TEST_NUMERATOR, TEST_DENOMINATOR, 0.0, 1.0);
        assert_relative_eq!(result, TEST_NUMERATOR / TEST_DENOMINATOR, epsilon = 1e-15);
    }

    #[test]
    fn select_conditional_or_midpoint_near_zero_denom() {
        let result =
            select_conditional_or_midpoint(TEST_NUMERATOR, TEST_NEAR_ZERO_DENOM, 0.0, 1.0);
        assert_relative_eq!(result, 0.5, epsilon = 1e-15);
    }

    #[test]
    fn conditional_expectation_symmetric_interval() {
        // Over the whole support the expectation is asserted to be zero.
        let result = conditional_expectation(SUPPORT_MIN, SUPPORT_MAX, TEST_DIM);
        assert_relative_eq!(result, 0.0, epsilon = 1e-8);
    }

    #[test]
    fn compute_distortion_nonnegative() {
        let centroids = vec![-0.5, 0.0, 0.5];
        let boundaries = vec![-0.25, 0.25];
        let d = compute_distortion(&centroids, &boundaries, TEST_DIM);
        assert!(d >= 0.0);
    }

    #[test]
    fn update_centroids_correct_count() {
        let boundaries = midpoint_boundaries(&initialize_centroids(TEST_CENTROIDS_8));
        let updated = update_centroids(TEST_CENTROIDS_8, &boundaries, TEST_DIM);
        assert_eq!(updated.len(), TEST_CENTROIDS_8);
    }

    #[test]
    fn update_centroids_within_support() {
        let boundaries = midpoint_boundaries(&initialize_centroids(TEST_CENTROIDS_8));
        let updated = update_centroids(TEST_CENTROIDS_8, &boundaries, TEST_DIM);
        for &c in &updated {
            assert!((SUPPORT_MIN..=SUPPORT_MAX).contains(&c));
        }
    }

    #[test]
    fn generate_codebook_valid_structure() {
        let cb = generate_codebook(TEST_BITS_3, TEST_DIM_64);
        assert_eq!(cb.centroids.len(), TEST_CENTROIDS_8);
        assert_eq!(cb.boundaries.len(), TEST_CENTROIDS_8 - 1);
        for w in cb.centroids.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    #[test]
    fn lloyd_max_step_reduces_distortion() {
        let centroids = initialize_centroids(TEST_CENTROIDS_8);
        let boundaries = midpoint_boundaries(&centroids);
        let initial_dist = compute_distortion(&centroids, &boundaries, TEST_DIM);
        // `f64::MAX` as the previous distortion mirrors the first iteration.
        let (new_centroids, new_dist, _) = lloyd_max_step(&centroids, f64::MAX, TEST_DIM);
        assert!(new_dist <= initial_dist + 1e-15);
        assert_eq!(new_centroids.len(), TEST_CENTROIDS_8);
    }
}