use crate::backend::{dot, squared_l2_norm, ExecutionBackend};
use crate::error::{Result, TurboQuantError};
/// `normalize` rejects inputs whose Euclidean norm is below this threshold.
const ZERO_NORM_EPSILON: f64 = 1e-12;
/// Largest deviation `|‖x‖₂ − 1|` that `validate_unit_vector` accepts as a unit vector.
pub(crate) const UNIT_NORM_TOLERANCE: f64 = 1e-6;
/// Euclidean (L2) norm of `x`, computed on the default execution backend.
///
/// An empty slice has norm `0.0`.
pub fn norm(x: &[f64]) -> f64 {
    let backend = ExecutionBackend::default();
    let sum_of_squares = squared_l2_norm(backend, x);
    sum_of_squares.sqrt()
}
/// Returns a copy of `x` scaled to unit Euclidean length.
///
/// # Errors
///
/// Returns [`TurboQuantError::ZeroVector`], carrying the computed norm, when:
/// * the norm is below `ZERO_NORM_EPSILON` (including the empty slice), or
/// * `x` contains NaN or infinite entries.
///
/// A finite vector whose sum of squares overflows `f64` (entries larger than
/// roughly `1e154`) is not an error. It is rescaled by its largest magnitude
/// before normalising.
pub fn normalize(x: &[f64]) -> Result<Vec<f64>> {
    let n = norm(x);
    if n.is_finite() && n >= ZERO_NORM_EPSILON {
        return Ok(x.iter().map(|v| v / n).collect());
    }
    // With every entry finite, a non-finite norm means the sum of squares
    // overflowed (assuming `squared_l2_norm` is a plain sum of squares).
    // Dividing by the largest magnitude first puts every entry in [-1, 1],
    // so the norm of the scaled copy is representable.
    if !n.is_finite() && x.iter().all(|v| v.is_finite()) {
        let scale = x.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        let scaled: Vec<f64> = x.iter().map(|v| v / scale).collect();
        let scaled_norm = norm(&scaled);
        return Ok(scaled.into_iter().map(|v| v / scaled_norm).collect());
    }
    Err(TurboQuantError::ZeroVector(n))
}
/// Checks that every element of `x` is finite (neither NaN nor ±infinity).
///
/// # Errors
///
/// Returns `TurboQuantError::InvalidValue` describing the first non-finite element.
/// Its `context` is `"{context}[{index}]"` and its `value` is the offending number.
pub(crate) fn validate_finite_vector(x: &[f64], context: &str) -> Result<()> {
    match x.iter().position(|v| !v.is_finite()) {
        None => Ok(()),
        Some(index) => Err(TurboQuantError::InvalidValue {
            context: format!("{context}[{index}]"),
            value: x[index],
        }),
    }
}
/// Checks that `x` is finite and its Euclidean norm is within
/// `UNIT_NORM_TOLERANCE` of `1.0`.
///
/// # Errors
///
/// * Propagates the error from `validate_finite_vector` if any entry is non-finite.
/// * Returns `TurboQuantError::NotUnitVector` carrying the measured norm when it
///   deviates from `1.0` by more than the tolerance.
pub(crate) fn validate_unit_vector(x: &[f64], context: &str) -> Result<()> {
    validate_finite_vector(x, context)?;
    let length = norm(x);
    let deviation = (length - 1.0).abs();
    if deviation > UNIT_NORM_TOLERANCE {
        Err(TurboQuantError::NotUnitVector(length))
    } else {
        Ok(())
    }
}
/// Dot product `Σ xᵢ·yᵢ`, computed on the default execution backend.
///
/// Two empty slices give `0.0`.
///
/// # Panics
///
/// Panics with a "length mismatch" message when `x` and `y` differ in length.
pub fn inner_product(x: &[f64], y: &[f64]) -> f64 {
    let (left, right) = (x.len(), y.len());
    assert_eq!(
        left,
        right,
        "inner_product: length mismatch ({} vs {})",
        left,
        right
    );
    let backend = ExecutionBackend::default();
    dot(backend, x, y)
}
/// Density at `x` of the distribution `Γ(d/2) / (√π·Γ((d−1)/2)) · (1 − x²)^((d−3)/2)`
/// on `(-1, 1)`, with `d = dim`.
///
/// That formula is the marginal density of one coordinate of a point drawn
/// uniformly from the unit sphere in `R^dim`.
///
/// * `dim < 2`: returns `0.0`.
/// * `2 <= dim <= 50`: exact formula; `0.0` for `|x| >= 1`.
/// * `dim > 50`: Gaussian approximation `N(0, 1/dim)`, evaluated for any `x`.
pub fn beta_pdf(x: f64, dim: usize) -> f64 {
    use std::f64::consts::PI;

    let d = dim as f64;
    match dim {
        0 | 1 => 0.0,
        2..=50 if x.abs() >= 1.0 => 0.0,
        2..=50 => {
            let kernel = (1.0 - x * x).powf((d - 3.0) / 2.0);
            // Normalising constant in log space to keep the gamma ratio stable.
            let log_norm = lgamma(d / 2.0) - 0.5 * PI.ln() - lgamma((d - 1.0) / 2.0);
            log_norm.exp() * kernel
        }
        _ => {
            let variance = 1.0 / d;
            let std_dev = variance.sqrt();
            (-x * x / (2.0 * variance)).exp() / (std_dev * (2.0 * PI).sqrt())
        }
    }
}
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// Uses the Lanczos approximation (g = 7, nine coefficients) for `x >= 0.5`.
/// Arguments in `(0, 0.5)` go through the reflection formula
/// `ln Γ(x) = ln π − ln sin(πx) − ln Γ(1 − x)`.
/// Non-positive arguments are not supported and return `+inf`.
fn lgamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    const LANCZOS_G: f64 = 7.0;
    // The coefficients are the published Lanczos values, kept digit-for-digit.
    #[allow(clippy::excessive_precision)]
    const LANCZOS_COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        return f64::INFINITY;
    }
    if x < 0.5 {
        return PI.ln() - (PI * x).sin().ln() - lgamma(1.0 - x);
    }

    let z = x - 1.0;
    let t = z + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS_COEFFS[0], |acc, (k, &coef)| {
            acc + coef / (z + k as f64 + 1.0)
        });
    0.5 * (2.0 * PI).ln() + series.ln() + (z + 0.5) * t.ln() - t
}
/// Maps a probability `u` to the matching quantile of the `beta_pdf` distribution
/// for dimension `dim`.
///
/// `u` is presumably expected in `[0, 1]`; no validation is done here.
///
/// * `dim <= 50`: inverts the numerically integrated CDF by bisection.
/// * `dim > 50`: uses the Gaussian approximation `σ·Φ⁻¹(u)` with `σ = √(1/dim)`.
pub fn sample_beta_marginal(dim: usize, u: f64) -> f64 {
    if dim <= 50 {
        return numerical_icdf(u, dim);
    }
    let sigma = (1.0 / dim as f64).sqrt();
    sigma * normal_icdf(u)
}
/// Quantile function (inverse CDF) of the standard normal distribution, `Φ⁻¹(p)`.
///
/// Uses P. J. Acklam's piecewise rational approximation:
/// * a central rational function in `p − 0.5` on `[P_LOW, 1 − P_LOW]`;
/// * a tail rational function in `q = sqrt(−2 ln p)` outside that range.
///
/// Acklam reports a relative error below about `1.15e-9`.
///
/// Edge cases:
/// * `p == 0.0` returns `f64::NEG_INFINITY` and `p == 1.0` returns `f64::INFINITY`.
///   Evaluating the tail formula there would give `inf / inf = NaN`.
/// * `p` outside `[0, 1]`, or NaN, returns NaN.
// The coefficients are kept exactly as published.
#[allow(clippy::excessive_precision)]
pub fn normal_icdf(p: f64) -> f64 {
    // Central-region numerator coefficients.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    // Central-region denominator coefficients.
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    // Tail-region numerator coefficients.
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    // Tail-region denominator coefficients.
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    // Break-point between the tail and central approximations.
    const P_LOW: f64 = 0.02425;

    // `contains` is false for NaN, so NaN input also maps to NaN here.
    if !(0.0..=1.0).contains(&p) {
        return f64::NAN;
    }
    if p == 0.0 {
        return f64::NEG_INFINITY;
    }
    if p == 1.0 {
        return f64::INFINITY;
    }

    // Lower-tail rational approximation in q = sqrt(-2 ln(tail probability)).
    // Shared by both tails; the upper tail negates it by symmetry.
    let tail = |q: f64| {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    let p_high = 1.0 - P_LOW;
    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= p_high {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Φ⁻¹(p) = −Φ⁻¹(1 − p).
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
/// Inverts `numerical_cdf` for dimension `dim` by bisection on `[-1, 1]`.
///
/// Runs 64 halvings and returns the midpoint of the final bracket.
/// If `u` lies above every attainable CDF value, the result approaches `1.0`.
/// If it lies below every attainable value, the result approaches `-1.0`.
fn numerical_icdf(u: f64, dim: usize) -> f64 {
    const HALVINGS: usize = 64;
    let (lo, hi) = (0..HALVINGS).fold((-1.0_f64, 1.0_f64), |(lo, hi), _| {
        let mid = (lo + hi) / 2.0;
        if numerical_cdf(mid, dim) < u {
            (mid, hi)
        } else {
            (lo, mid)
        }
    });
    (lo + hi) / 2.0
}
/// Approximates the CDF of `beta_pdf(·, dim)` at `x`.
///
/// Integrates the density with composite Simpson's rule over 200 panels
/// (an even count, as Simpson's rule requires).
/// The integration range is `[-0.9999, min(x, 0.9999)]`, which keeps it away
/// from the endpoints ±1. Returns `0.0` when that interval is empty.
fn numerical_cdf(x: f64, dim: usize) -> f64 {
    const PANELS: usize = 200;
    const LOWER: f64 = -0.9999;

    let upper = x.min(0.9999);
    if upper <= LOWER {
        return 0.0;
    }
    let step = (upper - LOWER) / PANELS as f64;
    let endpoints = beta_pdf(LOWER, dim) + beta_pdf(upper, dim);
    // Interior Simpson weights alternate 4, 2, 4, ... starting at k = 1.
    let weighted_sum = (1..PANELS).fold(endpoints, |acc, k| {
        let weight = if k % 2 == 0 { 2.0 } else { 4.0 };
        acc + weight * beta_pdf(LOWER + k as f64 * step, dim)
    });
    weighted_sum * step / 3.0
}
// Unit tests for the vector helpers (norm, normalize, inner_product) and the
// distribution helpers (beta_pdf, normal_icdf, sample_beta_marginal) above.
#[cfg(test)]
mod tests {
use super::*;
// 3-4-5 right triangle: norm of [3, 4] is exactly 5.
#[test]
fn test_norm() {
let x = vec![3.0, 4.0];
assert!((norm(&x) - 5.0).abs() < 1e-10);
}
// Normalizing [3, 4] gives [0.6, 0.8] with unit norm.
#[test]
fn test_normalize() {
let x = vec![3.0, 4.0];
let n = normalize(&x).unwrap();
assert!((norm(&n) - 1.0).abs() < 1e-10);
assert!((n[0] - 0.6).abs() < 1e-10);
assert!((n[1] - 0.8).abs() < 1e-10);
}
// 1*4 + 2*5 + 3*6 = 32.
#[test]
fn test_inner_product() {
let x = vec![1.0, 2.0, 3.0];
let y = vec![4.0, 5.0, 6.0];
assert!((inner_product(&x, &y) - 32.0).abs() < 1e-10);
}
// Midpoint-rule integral of the exact (dim <= 50) density over (-1, 1);
// the 0.05 tolerance is deliberately loose.
#[test]
fn test_beta_pdf_integrates_to_one() {
let dim = 10usize;
let n = 1000usize;
let h = 2.0 / n as f64;
let mut sum = 0.0;
for i in 0..n {
let x = -1.0 + (i as f64 + 0.5) * h;
sum += beta_pdf(x, dim) * h;
}
assert!((sum - 1.0).abs() < 0.05, "integral = {}", sum);
}
#[test]
fn test_normalize_zero_vector() {
let x = vec![0.0, 0.0, 0.0];
let result = normalize(&x);
assert!(result.is_err());
assert!(
matches!(result, Err(TurboQuantError::ZeroVector(_))),
"Expected ZeroVector error"
);
}
#[test]
fn test_norm_empty_vector() {
let x: Vec<f64> = vec![];
assert!((norm(&x) - 0.0).abs() < 1e-15);
}
#[test]
fn test_inner_product_empty() {
let x: Vec<f64> = vec![];
let y: Vec<f64> = vec![];
assert!((inner_product(&x, &y) - 0.0).abs() < 1e-15);
}
// beta_pdf returns 0 for dimensions below 2.
#[test]
fn test_beta_pdf_dim_1() {
assert_eq!(beta_pdf(0.5, 1), 0.0);
assert_eq!(beta_pdf(0.5, 0), 0.0);
}
// For dim <= 50 the density is 0 at |x| >= 1.
#[test]
fn test_beta_pdf_at_boundary() {
assert_eq!(beta_pdf(1.0, 10), 0.0);
assert_eq!(beta_pdf(-1.0, 10), 0.0);
}
// normalize checks the computed norm rather than the elements, so NaN input
// surfaces as a ZeroVector error.
#[test]
fn test_normalize_nan_input() {
let x = vec![f64::NAN, 1.0, 2.0];
let result = normalize(&x);
assert!(result.is_err(), "normalize should reject NaN input");
assert!(
matches!(result, Err(TurboQuantError::ZeroVector(_))),
"Expected ZeroVector error for NaN input"
);
}
// Norm is about 1e-13, below ZERO_NORM_EPSILON (1e-12).
#[test]
fn test_normalize_near_zero_vector() {
let x = vec![1e-13, 1e-14, 1e-15];
assert!(normalize(&x).is_err());
}
#[test]
#[should_panic(expected = "length mismatch")]
fn test_inner_product_length_mismatch() {
let x = vec![1.0, 2.0, 3.0];
let y = vec![1.0, 2.0];
inner_product(&x, &y);
}
// For dim > 50 beta_pdf is the N(0, 1/dim) density, whose peak is sqrt(dim / 2π).
#[test]
fn test_beta_pdf_large_dim_gaussian() {
let dim = 100usize;
let x = 0.0;
let pdf = beta_pdf(x, dim);
let expected = (dim as f64 / (2.0 * std::f64::consts::PI)).sqrt();
assert!(
(pdf - expected).abs() / expected < 0.1,
"pdf={}, expected={}",
pdf,
expected
);
}
// Φ⁻¹(0.5) = 0 and Φ⁻¹(0.8413) ≈ 1 ≈ −Φ⁻¹(0.1587).
#[test]
fn test_normal_icdf_basic_quantiles() {
assert!((normal_icdf(0.5)).abs() < 1e-6, "median should be 0");
let z = normal_icdf(0.8413);
assert!((z - 1.0).abs() < 0.01, "z={}, expected ~1.0", z);
let z = normal_icdf(0.1587);
assert!((z + 1.0).abs() < 0.01, "z={}, expected ~-1.0", z);
}
// 0.001 and 0.999 fall in the lower and upper tail branches respectively.
#[test]
fn test_normal_icdf_tails() {
let z_low = normal_icdf(0.001);
assert!(z_low < -2.5, "z_low={}, expected < -2.5", z_low);
let z_high = normal_icdf(0.999);
assert!(z_high > 2.5, "z_high={}, expected > 2.5", z_high);
assert!(
(z_low + z_high).abs() < 0.01,
"tails not symmetric: {} + {} = {}",
z_low,
z_high,
z_low + z_high
);
}
// dim = 128 exercises the Gaussian branch of sample_beta_marginal.
#[test]
fn test_sample_beta_marginal_large_dim() {
let dim = 128;
let mid = sample_beta_marginal(dim, 0.5);
assert!(mid.abs() < 0.01, "median should be near 0, got {}", mid);
let lo = sample_beta_marginal(dim, 0.01);
let hi = sample_beta_marginal(dim, 0.99);
assert!(lo < 0.0, "low quantile should be negative: {}", lo);
assert!(hi > 0.0, "high quantile should be positive: {}", hi);
assert!(
(lo + hi).abs() < 0.01,
"not symmetric: {} + {} = {}",
lo,
hi,
lo + hi
);
}
// dim = 10 exercises the numerical bisection branch of sample_beta_marginal.
#[test]
fn test_sample_beta_marginal_small_dim() {
let dim = 10;
let mid = sample_beta_marginal(dim, 0.5);
assert!(mid.abs() < 0.1, "median should be near 0, got {}", mid);
let lo = sample_beta_marginal(dim, 0.05);
let hi = sample_beta_marginal(dim, 0.95);
assert!(lo < hi, "quantiles should be ordered: {} < {}", lo, hi);
}
#[test]
fn test_normalize_infinity_input() {
let x = vec![f64::INFINITY, 1.0, 2.0];
let result = normalize(&x);
assert!(result.is_err(), "normalize should reject infinite input");
assert!(
matches!(result, Err(TurboQuantError::ZeroVector(_))),
"Expected ZeroVector error for infinite input"
);
}
#[test]
fn test_normalize_neg_infinity_input() {
let x = vec![1.0, f64::NEG_INFINITY, 2.0];
let result = normalize(&x);
assert!(result.is_err(), "normalize should reject -inf input");
}
// dim = 2 is the smallest dimension handled by the exact formula.
#[test]
fn test_beta_pdf_dim_2() {
let pdf = beta_pdf(0.0, 2);
assert!(pdf > 0.0, "dim=2 pdf at 0 should be positive: {}", pdf);
}
// The density depends on x only through x², so f(x) == f(-x) in both branches.
#[test]
fn test_beta_pdf_symmetry() {
for dim in [5, 10, 20, 100] {
for &x in &[0.1, 0.3, 0.5, 0.8] {
let pos = beta_pdf(x, dim);
let neg = beta_pdf(-x, dim);
assert!(
(pos - neg).abs() < 1e-10,
"dim={}, x={}: f({})={} != f(-{})={}",
dim,
x,
x,
pos,
x,
neg
);
}
}
}
}