mod gen;
use crate::error::{require, Result, TurboQuantError};
use crate::math::{ln_gamma, HALF};
use crate::packed::is_valid_bits;
pub use gen::generate_codebook;
/// Upper end of the interval on which [`beta_pdf`] can be non-zero: it returns
/// `0.0` whenever `1 - x^2 <= 0`, i.e. for `|x| >= 1`.
pub(crate) const SUPPORT_MAX: f64 = 1.0;
/// Lower end of that interval, mirroring [`SUPPORT_MAX`].
pub(crate) const SUPPORT_MIN: f64 = -SUPPORT_MAX;
/// Smallest dimension for which [`beta_pdf`] evaluates the density; for any
/// smaller `d` it returns `0.0`.
const MIN_DIMENSION_FOR_PDF: usize = 3;
/// Offset subtracted from the dimension in the kernel exponent `(d - 3) / 2`
/// used by [`beta_pdf`].
const KERNEL_EXPONENT_OFFSET: f64 = 3.0;
/// A scalar quantization codebook: reconstruction levels plus the decision
/// boundaries that separate them.
///
/// `boundaries[i]` separates bin `i` from bin `i + 1` (see
/// [`nearest_centroid`] for how a value is assigned to a bin). In the built-in
/// tables `centroids` holds `2^bits` ascending values and `boundaries` holds one
/// fewer, each boundary being the midpoint of its two neighbouring centroids.
///
/// `PartialEq` is derived so codebooks can be compared directly (e.g. a
/// generated codebook against a precomputed one).
#[derive(Debug, Clone, PartialEq)]
pub struct Codebook {
    /// Reconstruction value for each quantization bin, in bin order.
    pub centroids: Vec<f64>,
    /// Ascending decision thresholds between consecutive bins.
    pub boundaries: Vec<f64>,
}

/// Borrowed, compile-time counterpart of [`Codebook`].
///
/// It points at `'static` tables, so precomputed codebooks cost nothing until
/// [`StaticCodebook::to_codebook`] copies them into an owned value.
struct StaticCodebook {
    centroids: &'static [f64],
    boundaries: &'static [f64],
}

impl StaticCodebook {
    /// Copies both borrowed tables into a freshly allocated [`Codebook`].
    fn to_codebook(&self) -> Codebook {
        Codebook {
            centroids: self.centroids.to_vec(),
            boundaries: self.boundaries.to_vec(),
        }
    }
}
// Precomputed codebooks for the (bits, dim) pairs that are served without
// on-demand generation; `lookup_static_codebook_ref` maps each pair to one of
// these statics. Each entry only borrows the const tables defined at the bottom
// of this file, so nothing is allocated until `StaticCodebook::to_codebook`
// copies a table out.

// 2-bit (4 centroids / 3 boundaries).
static CODEBOOK_2BIT_D64: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_2B_D64,
    boundaries: &BOUNDARIES_2B_D64,
};
static CODEBOOK_2BIT_D128: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_2B_D128,
    boundaries: &BOUNDARIES_2B_D128,
};
static CODEBOOK_2BIT_D256: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_2B_D256,
    boundaries: &BOUNDARIES_2B_D256,
};
// 3-bit (8 centroids / 7 boundaries).
static CODEBOOK_3BIT_D64: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_3B_D64,
    boundaries: &BOUNDARIES_3B_D64,
};
static CODEBOOK_3BIT_D128: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_3B_D128,
    boundaries: &BOUNDARIES_3B_D128,
};
static CODEBOOK_3BIT_D256: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_3B_D256,
    boundaries: &BOUNDARIES_3B_D256,
};
// 4-bit (16 centroids / 15 boundaries).
static CODEBOOK_4BIT_D64: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_4B_D64,
    boundaries: &BOUNDARIES_4B_D64,
};
static CODEBOOK_4BIT_D128: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_4B_D128,
    boundaries: &BOUNDARIES_4B_D128,
};
static CODEBOOK_4BIT_D256: StaticCodebook = StaticCodebook {
    centroids: &CENTROIDS_4B_D256,
    boundaries: &BOUNDARIES_4B_D256,
};
/// Checks that `bits` is a supported code width.
///
/// The decision is delegated to `is_valid_bits`; the failure value handed to
/// `require` is `TurboQuantError::UnsupportedBits(bits)`.
fn validate_bits(bits: u8) -> Result<()> {
    let supported = is_valid_bits(bits);
    require(supported, TurboQuantError::UnsupportedBits(bits))
}
/// Number of reconstruction levels addressable by a `bits`-bit code, i.e.
/// `2^bits`.
pub(crate) fn centroid_count(bits: u8) -> usize {
    let shift = u32::from(bits);
    1usize << shift
}
/// Density of `x` under
/// `Γ(d/2) / (√π · Γ((d-1)/2)) · (1 - x²)^((d-3)/2)`.
///
/// Returns `0.0` when `d < MIN_DIMENSION_FOR_PDF` or when `1 - x² <= 0`
/// (that is, `|x| >= 1`). The normalising constant is evaluated in log space
/// with `ln_gamma` and exponentiated at the end.
pub fn beta_pdf(x: f64, d: usize) -> f64 {
    let gap = 1.0 - x * x;
    // Both degenerate cases collapse to a zero density.
    if d < MIN_DIMENSION_FOR_PDF || gap <= 0.0 {
        return 0.0;
    }

    let dim = d as f64;
    let normalizer = {
        let ln_sqrt_pi = HALF * core::f64::consts::PI.ln();
        let lg_numerator = ln_gamma(dim * HALF);
        let lg_denominator = ln_gamma((dim - 1.0) * HALF);
        (lg_numerator - ln_sqrt_pi - lg_denominator).exp()
    };
    let shape = (dim - KERNEL_EXPONENT_OFFSET) * HALF;
    normalizer * gap.powf(shape)
}
/// Returns the index of the quantization bin that `value` falls into.
///
/// `boundaries` must be sorted in ascending order. The result is the number of
/// boundaries strictly less than `value`, which implies the following:
///
/// - A value lying exactly on a boundary goes to the lower of the two adjacent
///   bins.
/// - Values at or below the first boundary map to bin `0`, and values above the
///   last boundary map to bin `boundaries.len()`.
/// - A NaN `value` compares false against every boundary and maps to bin `0`.
/// - An empty slice always yields `0`.
///
/// The index is returned as `u8`, so at most 255 boundaries (256 bins) can be
/// represented. Longer slices trip a debug assertion, and in release builds the
/// index is truncated as before.
fn boundary_binary_search(value: f64, boundaries: &[f64]) -> u8 {
    debug_assert!(
        boundaries.len() <= usize::from(u8::MAX),
        "bin index must fit in u8"
    );
    // Lower-bound binary search: the first index where `value > edge` stops
    // holding. This is the same search the former hand-written loop performed.
    let bin = boundaries.partition_point(|&edge| value > edge);
    // Cannot truncate when the debug assertion above holds.
    bin as u8
}
pub fn nearest_centroid(value: f64, codebook: &Codebook) -> u8 {
boundary_binary_search(value, &codebook.boundaries)
}
/// Finds the precomputed table for `(bits, dim)`, if one is compiled in.
///
/// Tables exist for `bits` in `2..=4` combined with `dim` in `{64, 128, 256}`;
/// every other pair yields `None`.
fn lookup_static_codebook_ref(bits: u8, dim: usize) -> Option<&'static StaticCodebook> {
    // One row per bit width; columns follow the supported dimensions in
    // ascending order.
    let row: [&'static StaticCodebook; 3] = match bits {
        2 => [&CODEBOOK_2BIT_D64, &CODEBOOK_2BIT_D128, &CODEBOOK_2BIT_D256],
        3 => [&CODEBOOK_3BIT_D64, &CODEBOOK_3BIT_D128, &CODEBOOK_3BIT_D256],
        4 => [&CODEBOOK_4BIT_D64, &CODEBOOK_4BIT_D128, &CODEBOOK_4BIT_D256],
        _ => return None,
    };
    let column = match dim {
        64 => 0,
        128 => 1,
        256 => 2,
        _ => return None,
    };
    Some(row[column])
}
fn lookup_static_codebook(bits: u8, dim: usize) -> Option<Codebook> {
let sc = lookup_static_codebook_ref(bits, dim)?;
Some(sc.to_codebook())
}
/// Returns the quantization codebook for `bits`-bit codes in dimension `dim`.
///
/// When a precomputed table exists for `(bits, dim)` it is copied out.
/// Otherwise the codebook is built on demand by [`generate_codebook`].
///
/// # Errors
///
/// Fails with the error produced by `validate_bits` (built from
/// `TurboQuantError::UnsupportedBits`) when `bits` is rejected. No table lookup
/// or generation is attempted in that case.
pub fn get_codebook(bits: u8, dim: usize) -> Result<Codebook> {
    validate_bits(bits)?;
    let codebook = match lookup_static_codebook(bits, dim) {
        Some(precomputed) => precomputed,
        None => gen::generate_codebook(bits, dim),
    };
    Ok(codebook)
}
// 2-bit tables (4 centroids, 3 boundaries) for dim = 64, 128 and 256.
// Each boundary is the midpoint of its two neighbouring centroids, except the
// central one, which is stored as exact 0.0 (the computed midpoint differs from
// zero only in the last few ulps). The 2-bit centroids are symmetric about zero
// only up to their last digits.
// NOTE(review): presumably emitted by `gen::generate_codebook` for the same
// (bits, dim); confirm before regenerating by hand.
const CENTROIDS_2B_D64: [f64; 4] = [
    -0.18749689292196112,
    -0.05651489047635318,
    0.05651489047635313,
    0.18749689292196103,
];
const BOUNDARIES_2B_D64: [f64; 3] = [-0.12200589169915715, 0.0, 0.12200589169915708];
const CENTROIDS_2B_D128: [f64; 4] = [
    -0.13304154846077318,
    -0.03999160906335877,
    0.03999160906335891,
    0.13304154846077348,
];
const BOUNDARIES_2B_D128: [f64; 3] = [-0.08651657876206598, 0.0, 0.086_516_578_762_066_2];
const CENTROIDS_2B_D256: [f64; 4] = [
    -0.09423779913129633,
    -0.02828860721372146,
    0.02828860721372166,
    0.09423779913129664,
];
const BOUNDARIES_2B_D256: [f64; 3] = [-0.061_263_203_172_508_9, 0.0, 0.06126320317250915];
// 3-bit tables (8 centroids, 7 boundaries) for dim = 64, 128 and 256.
// Centroids and boundaries are exactly symmetric about zero, and every
// boundary is the midpoint of its two neighbouring centroids.
// NOTE(review): presumably emitted by `gen::generate_codebook` for the same
// (bits, dim); confirm before regenerating by hand.
const CENTROIDS_3B_D64: [f64; 8] = [
    -0.26391407457137683,
    -0.16616801009118487,
    -0.093_832_375_844_160_5,
    -0.03046922045737837,
    0.03046922045737837,
    0.093_832_375_844_160_5,
    0.16616801009118487,
    0.26391407457137683,
];
const BOUNDARIES_3B_D64: [f64; 7] = [
    -0.21504104233128085,
    -0.13000019296767268,
    -0.06215079815076943,
    0.0,
    0.06215079815076943,
    0.13000019296767268,
    0.21504104233128085,
];
const CENTROIDS_3B_D128: [f64; 8] = [
    -0.18839728518004373,
    -0.11813986946554235,
    -0.06658568378325364,
    -0.02160433847349997,
    0.02160433847349997,
    0.06658568378325364,
    0.11813986946554235,
    0.18839728518004373,
];
const BOUNDARIES_3B_D128: [f64; 7] = [
    -0.15326857732279303,
    -0.09236277662439799,
    -0.044_095_011_128_376_8,
    0.0,
    0.044_095_011_128_376_8,
    0.09236277662439799,
    0.15326857732279303,
];
const CENTROIDS_3B_D256: [f64; 8] = [
    -0.13385436276083063,
    -0.083_765_531_459_768_9,
    -0.04716676527922715,
    -0.01529750782483941,
    0.01529750782483941,
    0.04716676527922715,
    0.083_765_531_459_768_9,
    0.13385436276083063,
];
const BOUNDARIES_3B_D256: [f64; 7] = [
    -0.10880994711029976,
    -0.06546614836949802,
    -0.03123213655203328,
    0.0,
    0.03123213655203328,
    0.06546614836949802,
    0.10880994711029976,
];
// 4-bit tables (16 centroids, 15 boundaries) for dim = 64, 128 and 256.
// Centroids and boundaries are exactly symmetric about zero, and every
// boundary is the midpoint of its two neighbouring centroids.
// NOTE(review): presumably emitted by `gen::generate_codebook` for the same
// (bits, dim); confirm before regenerating by hand.
const CENTROIDS_4B_D64: [f64; 16] = [
    -0.33092994168409773,
    -0.25307088610074774,
    -0.19901983361887085,
    -0.15508179062990365,
    -0.11662310388676207,
    -0.08141753279040376,
    -0.04815672368589858,
    -0.015_941_930_352_081_4,
    0.015_941_930_352_081_4,
    0.04815672368589858,
    0.08141753279040376,
    0.11662310388676207,
    0.15508179062990365,
    0.19901983361887085,
    0.25307088610074774,
    0.33092994168409773,
];
const BOUNDARIES_4B_D64: [f64; 15] = [
    -0.292_000_413_892_422_7,
    -0.22604535985980928,
    -0.17705081212438725,
    -0.13585244725833287,
    -0.09902031833858291,
    -0.06478712823815116,
    -0.03204932701898999,
    0.0,
    0.03204932701898999,
    0.06478712823815116,
    0.09902031833858291,
    0.13585244725833287,
    0.17705081212438725,
    0.22604535985980928,
    0.292_000_413_892_422_7,
];
const CENTROIDS_4B_D128: [f64; 16] = [
    -0.23777655506958537,
    -0.18096588552769086,
    -0.14193912272806147,
    -0.11041538921898804,
    -0.08293881469006784,
    -0.05785765497830671,
    -0.03420549908335103,
    -0.01132093590150223,
    0.01132093590150223,
    0.03420549908335103,
    0.05785765497830671,
    0.08293881469006784,
    0.11041538921898804,
    0.14193912272806147,
    0.18096588552769086,
    0.23777655506958537,
];
const BOUNDARIES_4B_D128: [f64; 15] = [
    -0.209_371_220_298_638_1,
    -0.16145250412787615,
    -0.12617725597352475,
    -0.09667710195452794,
    -0.07039823483418728,
    -0.04603157703082887,
    -0.02276321749242663,
    0.0,
    0.02276321749242663,
    0.04603157703082887,
    0.07039823483418728,
    0.09667710195452794,
    0.12617725597352475,
    0.16145250412787615,
    0.209_371_220_298_638_1,
];
const CENTROIDS_4B_D256: [f64; 16] = [
    -0.16949853314441155,
    -0.12868871755030106,
    -0.10080108457584613,
    -0.07834675699488723,
    -0.05881658417438018,
    -0.04101444098641885,
    -0.02424206232116148,
    -0.00802245010411462,
    0.00802245010411462,
    0.02424206232116148,
    0.04101444098641885,
    0.05881658417438018,
    0.07834675699488723,
    0.10080108457584613,
    0.12868871755030106,
    0.16949853314441155,
];
const BOUNDARIES_4B_D256: [f64; 15] = [
    -0.14909362534735632,
    -0.114_744_901_063_073_6,
    -0.08957392078536669,
    -0.06858167058463371,
    -0.04991551258039952,
    -0.03262825165379016,
    -0.01613225621263805,
    0.0,
    0.01613225621263805,
    0.03262825165379016,
    0.04991551258039952,
    0.06858167058463371,
    0.08957392078536669,
    0.114_744_901_063_073_6,
    0.14909362534735632,
];
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{ln_gamma, simpsons_integrate};
    use approx::assert_relative_eq;

    const TEST_DIM: usize = 128;
    const INTEGRATION_STEPS: usize = 2048;
    const TEST_CENTROIDS_4: usize = 4;
    const TEST_CENTROIDS_8: usize = 8;
    const TEST_CENTROIDS_16: usize = 16;
    const TEST_BITS_2: u8 = 2;
    const TEST_BITS_3: u8 = 3;
    const TEST_BITS_4: u8 = 4;
    const TEST_DIMS: [usize; 3] = [64, 128, 256];
    const TEST_X_VALUES: [f64; 5] = [0.0, 0.1, 0.3, 0.5, 0.9];
    /// Every `(bits, dim)` pair that has a precomputed table.
    const KNOWN_CODEBOOK_CONFIGS: [(u8, usize); 9] = [
        (2, 64),
        (2, 128),
        (2, 256),
        (3, 64),
        (3, 128),
        (3, 256),
        (4, 64),
        (4, 128),
        (4, 256),
    ];
    /// Shared fixture for the `boundary_binary_search` tests. This is an array
    /// rather than a `Vec` because the function only needs a slice.
    const SAMPLE_BOUNDARIES: [f64; 3] = [-0.5, 0.0, 0.5];
    /// Allowed gap between a stored boundary and the midpoint recomputed from
    /// its neighbouring centroids.
    const MIDPOINT_TOLERANCE: f64 = 1e-12;
    const EPSILON_ZERO: f64 = 1e-30;

    /// Log of the normalising constant `Γ(d/2) / (√π · Γ((d-1)/2))`, mirroring
    /// the computation inside `beta_pdf`.
    fn beta_pdf_log_normalization(df: f64) -> f64 {
        let half_df = df * HALF;
        let half_df_minus_one = (df - 1.0) * HALF;
        let half_ln_pi = HALF * core::f64::consts::PI.ln();
        ln_gamma(half_df) - half_ln_pi - ln_gamma(half_df_minus_one)
    }

    /// Midpoint of the interval `[a, b]`.
    fn interval_midpoint(a: f64, b: f64) -> f64 {
        (a + b) * HALF
    }

    /// True when `value` is strictly within `EPSILON_ZERO` of zero.
    fn is_near_zero(value: f64) -> bool {
        value.abs() < EPSILON_ZERO
    }

    #[test]
    fn beta_pdf_integrates_to_approximately_one() {
        let d = TEST_DIM;
        let integral = simpsons_integrate(
            |x| beta_pdf(x, d),
            SUPPORT_MIN,
            SUPPORT_MAX,
            INTEGRATION_STEPS,
        );
        assert_relative_eq!(integral, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn beta_pdf_is_symmetric() {
        for d in TEST_DIMS {
            for &x in &TEST_X_VALUES {
                assert_relative_eq!(beta_pdf(x, d), beta_pdf(-x, d), epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn beta_pdf_at_zero_matches_log_normalization() {
        // At x = 0 the kernel (1 - x^2)^((d-3)/2) is exactly 1, so the density
        // reduces to the normalising constant.
        for d in TEST_DIMS {
            let expected = beta_pdf_log_normalization(d as f64).exp();
            assert_relative_eq!(beta_pdf(0.0, d), expected, epsilon = 1e-12);
        }
    }

    #[test]
    fn beta_pdf_zero_at_boundary() {
        for d in TEST_DIMS {
            assert_relative_eq!(beta_pdf(SUPPORT_MAX, d), 0.0, epsilon = 1e-15);
            assert_relative_eq!(beta_pdf(SUPPORT_MIN, d), 0.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn beta_pdf_zero_outside_support() {
        assert_relative_eq!(beta_pdf(1.5, TEST_DIM), 0.0, epsilon = 1e-15);
        assert_relative_eq!(beta_pdf(-2.0, TEST_DIM), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn beta_pdf_zero_for_low_dimension() {
        assert_relative_eq!(beta_pdf(0.0, 2), 0.0, epsilon = 1e-15);
        assert_relative_eq!(beta_pdf(0.0, 1), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn boundary_binary_search_first_bin() {
        assert_eq!(boundary_binary_search(-0.9, &SAMPLE_BOUNDARIES), 0);
    }

    #[test]
    fn boundary_binary_search_last_bin() {
        assert_eq!(boundary_binary_search(0.9, &SAMPLE_BOUNDARIES), 3);
    }

    #[test]
    fn boundary_binary_search_middle() {
        assert_eq!(boundary_binary_search(-0.1, &SAMPLE_BOUNDARIES), 1);
        assert_eq!(boundary_binary_search(0.1, &SAMPLE_BOUNDARIES), 2);
    }

    #[test]
    fn boundary_binary_search_ties_go_to_lower_bin() {
        assert_eq!(boundary_binary_search(-0.5, &SAMPLE_BOUNDARIES), 0);
        assert_eq!(boundary_binary_search(0.0, &SAMPLE_BOUNDARIES), 1);
        assert_eq!(boundary_binary_search(0.5, &SAMPLE_BOUNDARIES), 2);
    }

    #[test]
    fn interval_midpoint_basic() {
        assert_relative_eq!(interval_midpoint(0.0, 1.0), 0.5, epsilon = 1e-15);
        assert_relative_eq!(interval_midpoint(-1.0, 1.0), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn is_near_zero_true_for_tiny() {
        assert!(is_near_zero(1e-31));
        assert!(is_near_zero(-1e-31));
        assert!(is_near_zero(0.0));
    }

    #[test]
    fn is_near_zero_false_for_normal() {
        assert!(!is_near_zero(1e-10));
        assert!(!is_near_zero(-0.001));
    }

    #[test]
    fn lookup_known_configs_return_some() {
        for &(bits, dim) in &KNOWN_CODEBOOK_CONFIGS {
            assert!(
                lookup_static_codebook_ref(bits, dim).is_some(),
                "expected Some for ({bits}, {dim})"
            );
            assert!(lookup_static_codebook(bits, dim).is_some());
        }
    }

    #[test]
    fn lookup_unknown_config_returns_none() {
        assert!(lookup_static_codebook_ref(3, 512).is_none());
        assert!(lookup_static_codebook(4, 32).is_none());
    }

    #[test]
    fn static_codebooks_are_well_formed() {
        for &(bits, dim) in &KNOWN_CODEBOOK_CONFIGS {
            let sc = lookup_static_codebook_ref(bits, dim).expect("known config");
            let n = centroid_count(bits);
            assert_eq!(sc.centroids.len(), n, "centroid count for ({bits}, {dim})");
            assert_eq!(sc.boundaries.len(), n - 1, "boundary count for ({bits}, {dim})");
            assert!(
                sc.centroids.windows(2).all(|pair| pair[0] < pair[1]),
                "centroids must be strictly increasing for ({bits}, {dim})"
            );
            for (i, &boundary) in sc.boundaries.iter().enumerate() {
                let mid = interval_midpoint(sc.centroids[i], sc.centroids[i + 1]);
                assert_relative_eq!(boundary, mid, epsilon = MIDPOINT_TOLERANCE);
            }
        }
    }

    #[test]
    fn nearest_centroid_maps_each_static_centroid_to_its_own_index() {
        for &(bits, dim) in &KNOWN_CODEBOOK_CONFIGS {
            let cb = lookup_static_codebook(bits, dim).expect("known config");
            for (i, &centroid) in cb.centroids.iter().enumerate() {
                assert_eq!(
                    usize::from(nearest_centroid(centroid, &cb)),
                    i,
                    "({bits}, {dim}) centroid {i}"
                );
            }
        }
    }

    #[test]
    fn get_codebook_returns_static_table_for_known_configs() {
        for &(bits, dim) in &KNOWN_CODEBOOK_CONFIGS {
            let cb = match get_codebook(bits, dim) {
                Ok(cb) => cb,
                Err(_) => panic!("get_codebook({bits}, {dim}) should succeed"),
            };
            let sc = lookup_static_codebook_ref(bits, dim).expect("known config");
            assert_eq!(cb.centroids.as_slice(), sc.centroids);
            assert_eq!(cb.boundaries.as_slice(), sc.boundaries);
        }
    }

    #[test]
    fn get_codebook_rejects_unsupported_bits() {
        assert!(get_codebook(0, TEST_DIM).is_err());
        assert!(get_codebook(1, TEST_DIM).is_err());
        assert!(get_codebook(5, TEST_DIM).is_err());
    }

    #[test]
    fn centroid_count_2bit() {
        assert_eq!(centroid_count(TEST_BITS_2), TEST_CENTROIDS_4);
    }

    #[test]
    fn centroid_count_3bit() {
        assert_eq!(centroid_count(TEST_BITS_3), TEST_CENTROIDS_8);
    }

    #[test]
    fn centroid_count_4bit() {
        assert_eq!(centroid_count(TEST_BITS_4), TEST_CENTROIDS_16);
    }

    #[test]
    fn to_codebook_copies_correctly() {
        let sc = &CODEBOOK_3BIT_D64;
        let cb = sc.to_codebook();
        assert_eq!(cb.centroids.len(), sc.centroids.len());
        assert_eq!(cb.boundaries.len(), sc.boundaries.len());
        for (a, b) in cb.centroids.iter().zip(sc.centroids.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-15);
        }
        for (a, b) in cb.boundaries.iter().zip(sc.boundaries.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn validate_bits_accepts_2_3_and_4() {
        assert!(validate_bits(TEST_BITS_2).is_ok());
        assert!(validate_bits(TEST_BITS_3).is_ok());
        assert!(validate_bits(TEST_BITS_4).is_ok());
    }

    #[test]
    fn validate_bits_rejects_others() {
        assert!(validate_bits(0).is_err());
        assert!(validate_bits(1).is_err());
        assert!(validate_bits(5).is_err());
    }

    #[test]
    fn beta_pdf_log_normalization_positive_for_high_d() {
        let ln_norm = beta_pdf_log_normalization(TEST_DIM as f64);
        assert!(ln_norm > 0.0);
    }
}