mod gen;
mod tables;
use crate::error::{require, Result, TurboQuantError};
use crate::packed::is_valid_bits;
pub use gen::{beta_pdf, generate_codebook};
/// Lower end of the interval on which [`beta_pdf`] is defined; the density is
/// zero at and below this point.
pub(crate) const SUPPORT_MIN: f64 = -1.0_f64;
/// Upper end of the interval on which [`beta_pdf`] is defined; the density is
/// zero at and above this point.
pub(crate) const SUPPORT_MAX: f64 = 1.0_f64;
/// A scalar quantizer: reconstruction values plus the thresholds between them.
///
/// [`nearest_centroid`] binary-searches `boundaries` and returns the resulting
/// bin index, which callers use to address `centroids`.
/// NOTE(review): this presumes `boundaries` is sorted ascending and holds
/// `centroids.len() - 1` entries — confirm against the table/generator code.
#[derive(Clone, Debug)]
pub struct Codebook {
    /// Reconstruction value for each quantization level.
    pub centroids: Vec<f64>,
    /// Decision thresholds searched to pick a level (expected ascending).
    pub boundaries: Vec<f64>,
}
/// Checks that `bits` is a supported quantization width.
///
/// # Errors
/// Passes the [`is_valid_bits`] verdict to `require` together with a
/// `TurboQuantError::UnsupportedBits(bits)` error value. The tests show that
/// widths such as 0, 1 and 5 are rejected and 2, 3 and 4 are accepted.
fn validate_bits(bits: u8) -> Result<()> {
    let supported = is_valid_bits(bits);
    require(supported, TurboQuantError::UnsupportedBits(bits))
}
/// Number of quantization levels (centroids) for a `bits`-wide code, i.e. `2^bits`.
///
/// # Panics
/// Panics if `2^bits` does not fit in `usize` (`bits >= usize::BITS`).
/// A plain `1 << bits` would panic only in debug builds. In release builds it
/// silently masks the shift amount and returns a wrong count, so the overflow
/// is checked explicitly here.
pub(crate) fn centroid_count(bits: u8) -> usize {
    1usize
        .checked_shl(u32::from(bits))
        .expect("bit width too large: 2^bits does not fit in usize")
}
/// Returns the index of the bin that `value` falls into.
///
/// `boundaries` must be sorted ascending. The result is the number of
/// boundaries strictly less than `value`, so it lies in `0..=boundaries.len()`.
/// A value exactly equal to a boundary belongs to the lower bin. A NaN `value`
/// compares false against every boundary and maps to bin 0.
///
/// # Panics
/// Panics if the bin index does not fit in `u8` (more than 255 boundaries),
/// instead of silently truncating it.
fn boundary_binary_search(value: f64, boundaries: &[f64]) -> u8 {
    // `partition_point` returns the first index where the predicate is false,
    // i.e. the first boundary >= value: the same result as the manual
    // lower-bound search it replaces.
    let bin = boundaries.partition_point(|&boundary| value > boundary);
    u8::try_from(bin).expect("bin index exceeds u8 range: too many boundaries")
}
pub fn nearest_centroid(value: f64, codebook: &Codebook) -> u8 {
boundary_binary_search(value, &codebook.boundaries)
}
/// Returns an owned copy of the precomputed codebook for `(bits, dim)`.
///
/// Returns `None` when `tables` has no entry for that configuration.
fn lookup_static_codebook(bits: u8, dim: usize) -> Option<Codebook> {
    tables::lookup_static_codebook_ref(bits, dim).map(|static_cb| static_cb.to_codebook())
}
/// Returns the codebook for a `bits`-wide quantizer at dimension `dim`.
///
/// A precomputed table entry is used when one exists for `(bits, dim)`.
/// Otherwise the codebook is built on the fly with [`generate_codebook`].
///
/// # Errors
/// Fails before any lookup when `bits` is not a supported width; the error is
/// built as `TurboQuantError::UnsupportedBits(bits)`.
pub fn get_codebook(bits: u8, dim: usize) -> Result<Codebook> {
    validate_bits(bits)?;
    let codebook = match lookup_static_codebook(bits, dim) {
        Some(precomputed) => precomputed,
        None => gen::generate_codebook(bits, dim),
    };
    Ok(codebook)
}
#[cfg(test)]
mod tests {
    use super::tables::*;
    use super::*;
    use crate::math::{ln_gamma, simpsons_integrate, HALF};
    use approx::assert_relative_eq;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Dimension used by the single-dimension density tests.
    const TEST_DIM: usize = 128;
    /// Number of Simpson's-rule steps used when integrating the density.
    const INTEGRATION_STEPS: usize = 2048;
    /// Dimensions swept by the density-shape tests.
    const TEST_DIMS: [usize; 3] = [64, 128, 256];
    /// Non-negative sample points; the symmetry test mirrors each one.
    const TEST_X_VALUES: [f64; 5] = [0.0, 0.1, 0.3, 0.5, 0.9];

    const TEST_BITS_2: u8 = 2;
    const TEST_BITS_3: u8 = 3;
    const TEST_BITS_4: u8 = 4;
    const TEST_CENTROIDS_4: usize = 4;
    const TEST_CENTROIDS_8: usize = 8;
    const TEST_CENTROIDS_16: usize = 16;

    /// Three thresholds splitting the line into four bins (indices 0..=3).
    const SAMPLE_BOUNDARIES: [f64; 3] = [-0.5, 0.0, 0.5];

    /// `(bits, dim)` pairs that must have a precomputed table entry.
    const KNOWN_CODEBOOK_CONFIGS: [(u8, usize); 12] = [
        (2, 32),
        (2, 64),
        (2, 128),
        (2, 256),
        (3, 32),
        (3, 64),
        (3, 128),
        (3, 256),
        (4, 32),
        (4, 64),
        (4, 128),
        (4, 256),
    ];

    /// Magnitudes strictly below this count as zero in `is_near_zero`.
    const EPSILON_ZERO: f64 = 1e-30;

    // ---------------------------------------------------------------------
    // Local helpers
    // ---------------------------------------------------------------------

    /// `ln Γ(df·HALF) − HALF·ln π − ln Γ((df − 1)·HALF)`.
    /// NOTE(review): presumably the log normalising constant of `beta_pdf`
    /// (with `HALF` = 0.5) — confirm against `gen`.
    fn beta_pdf_log_normalization(df: f64) -> f64 {
        let ln_sqrt_pi = HALF * core::f64::consts::PI.ln();
        ln_gamma(df * HALF) - ln_sqrt_pi - ln_gamma((df - 1.0) * HALF)
    }

    /// Midpoint of the interval `[a, b]`.
    fn interval_midpoint(a: f64, b: f64) -> f64 {
        (a + b) * HALF
    }

    /// Whether `value` is within `EPSILON_ZERO` of zero (exclusive).
    fn is_near_zero(value: f64) -> bool {
        value.abs() < EPSILON_ZERO
    }

    // ---------------------------------------------------------------------
    // Density
    // ---------------------------------------------------------------------

    #[test]
    fn beta_pdf_integrates_to_approximately_one() {
        let integral = simpsons_integrate(
            |x| beta_pdf(x, TEST_DIM),
            SUPPORT_MIN,
            SUPPORT_MAX,
            INTEGRATION_STEPS,
        );
        assert_relative_eq!(integral, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn beta_pdf_is_symmetric() {
        for d in TEST_DIMS {
            for x in TEST_X_VALUES {
                assert_relative_eq!(beta_pdf(x, d), beta_pdf(-x, d), epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn beta_pdf_zero_at_boundary() {
        for d in TEST_DIMS {
            for edge in [SUPPORT_MAX, SUPPORT_MIN] {
                assert_relative_eq!(beta_pdf(edge, d), 0.0, epsilon = 1e-15);
            }
        }
    }

    #[test]
    fn beta_pdf_zero_outside_support() {
        for x in [1.5, -2.0] {
            assert_relative_eq!(beta_pdf(x, TEST_DIM), 0.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn beta_pdf_zero_for_low_dimension() {
        for d in [2, 1] {
            assert_relative_eq!(beta_pdf(0.0, d), 0.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn beta_pdf_log_normalization_positive_for_high_d() {
        assert!(beta_pdf_log_normalization(TEST_DIM as f64) > 0.0);
    }

    // ---------------------------------------------------------------------
    // Bin search
    // ---------------------------------------------------------------------

    #[test]
    fn boundary_binary_search_first_bin() {
        assert_eq!(boundary_binary_search(-0.9, &SAMPLE_BOUNDARIES), 0);
    }

    #[test]
    fn boundary_binary_search_last_bin() {
        assert_eq!(boundary_binary_search(0.9, &SAMPLE_BOUNDARIES), 3);
    }

    #[test]
    fn boundary_binary_search_middle() {
        assert_eq!(boundary_binary_search(-0.1, &SAMPLE_BOUNDARIES), 1);
        assert_eq!(boundary_binary_search(0.1, &SAMPLE_BOUNDARIES), 2);
    }

    // ---------------------------------------------------------------------
    // Helper sanity checks
    // ---------------------------------------------------------------------

    #[test]
    fn interval_midpoint_basic() {
        assert_relative_eq!(interval_midpoint(0.0, 1.0), 0.5, epsilon = 1e-15);
        assert_relative_eq!(interval_midpoint(-1.0, 1.0), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn is_near_zero_true_for_tiny() {
        for tiny in [1e-31, -1e-31, 0.0] {
            assert!(is_near_zero(tiny));
        }
    }

    #[test]
    fn is_near_zero_false_for_normal() {
        for normal in [1e-10, -0.001] {
            assert!(!is_near_zero(normal));
        }
    }

    // ---------------------------------------------------------------------
    // Static tables
    // ---------------------------------------------------------------------

    #[test]
    fn lookup_known_configs_return_some() {
        for (bits, dim) in KNOWN_CODEBOOK_CONFIGS {
            assert!(
                lookup_static_codebook_ref(bits, dim).is_some(),
                "expected Some for ({bits}, {dim})"
            );
            assert!(lookup_static_codebook(bits, dim).is_some());
        }
    }

    #[test]
    fn lookup_unknown_config_returns_none() {
        assert!(lookup_static_codebook_ref(3, 512).is_none());
        assert!(lookup_static_codebook(4, 16).is_none());
    }

    #[test]
    fn to_codebook_copies_correctly() {
        let source = &CODEBOOK_3BIT_D64;
        let owned = source.to_codebook();
        assert_eq!(owned.centroids.len(), source.centroids.len());
        assert_eq!(owned.boundaries.len(), source.boundaries.len());
        for (got, want) in owned.centroids.iter().zip(source.centroids.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-15);
        }
        for (got, want) in owned.boundaries.iter().zip(source.boundaries.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-15);
        }
    }

    // ---------------------------------------------------------------------
    // Bit widths
    // ---------------------------------------------------------------------

    #[test]
    fn centroid_count_2bit() {
        assert_eq!(centroid_count(TEST_BITS_2), TEST_CENTROIDS_4);
    }

    #[test]
    fn centroid_count_3bit() {
        assert_eq!(centroid_count(TEST_BITS_3), TEST_CENTROIDS_8);
    }

    #[test]
    fn centroid_count_4bit() {
        assert_eq!(centroid_count(TEST_BITS_4), TEST_CENTROIDS_16);
    }

    #[test]
    fn validate_bits_accepts_2_3_and_4() {
        for bits in [TEST_BITS_2, TEST_BITS_3, TEST_BITS_4] {
            assert!(validate_bits(bits).is_ok());
        }
    }

    #[test]
    fn validate_bits_rejects_others() {
        for bits in [0, 1, 5] {
            assert!(validate_bits(bits).is_err());
        }
    }
}