#![allow(dead_code)]
use crate::{Error, Result};
/// Checks that `p` is a point on the probability simplex, up to tolerance `tol`.
///
/// A point is accepted when it is non-empty, every entry is finite and
/// `>= -tol`, and the entries sum to 1 within `tol`. Slightly negative entries
/// (down to `-tol`) are deliberately tolerated to absorb rounding noise.
///
/// # Errors
///
/// Returns `Error::Domain` if `p` is empty, `tol` is negative or non-finite,
/// any entry is non-finite, any entry is below `-tol`, or the sum differs from
/// 1 by more than `tol`.
pub fn validate_simplex(p: &[f32], tol: f32) -> Result<()> {
    if p.is_empty() {
        return Err(Error::Domain("simplex vector must be non-empty"));
    }
    if !tol.is_finite() || tol < 0.0 {
        return Err(Error::Domain("tol must be finite and >= 0"));
    }
    if p.iter().any(|&x| !x.is_finite()) {
        return Err(Error::Domain("simplex vector contains non-finite values"));
    }
    if p.iter().any(|&x| x < -tol) {
        return Err(Error::Domain("simplex vector has negative entries"));
    }
    // Accumulate in f64: a running f32 sum drifts by up to ~len * f32::EPSILON,
    // which on long vectors exceeds tight tolerances and rejects valid points.
    let s: f64 = p.iter().map(|&x| f64::from(x)).sum();
    if (s - 1.0).abs() > f64::from(tol) {
        return Err(Error::Domain(
            "simplex vector does not sum to 1 (within tol)",
        ));
    }
    Ok(())
}
/// Rescales a nonnegative vector so that its entries sum to 1.
///
/// Each output entry is `p[i] / sum(p)`, computed in `f64` and rounded to `f32`.
/// Zero entries stay zero, so the result has the same support as `p`.
///
/// # Errors
///
/// Returns `Error::Domain` if `p` is empty, contains a non-finite or negative
/// entry, or has zero total mass (e.g. all zeros).
pub fn normalize_simplex(p: &[f32]) -> Result<Vec<f32>> {
    if p.is_empty() {
        return Err(Error::Domain("simplex vector must be non-empty"));
    }
    if p.iter().any(|&x| !x.is_finite()) {
        return Err(Error::Domain("vector contains non-finite values"));
    }
    if p.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("vector must be nonnegative to normalize"));
    }
    // Sum in f64: the f32 total of finite entries can overflow to +inf (e.g. two
    // entries near f32::MAX), which would turn every quotient into 0.
    let s: f64 = p.iter().map(|&x| f64::from(x)).sum();
    if s <= 0.0 {
        return Err(Error::Domain("vector must have positive total mass"));
    }
    Ok(p.iter().map(|&x| (f64::from(x) / s) as f32).collect())
}
/// Linearly interpolates between two equal-length vectors, element by element:
/// `out[i] = (1 - t) * p0[i] + t * p1[i]`.
///
/// For finite inputs, `t = 0` reproduces `p0` and `t = 1` reproduces `p1`. The
/// inputs themselves are not validated as simplex points.
///
/// # Errors
///
/// Returns `Error::Shape` if `p0` and `p1` differ in length (checked first),
/// then `Error::Domain` if `t` is not a finite value in `[0, 1]`.
pub fn simplex_lerp(p0: &[f32], p1: &[f32], t: f32) -> Result<Vec<f32>> {
    if p0.len() != p1.len() {
        return Err(Error::Shape("p0 and p1 must have same length"));
    }
    let t_in_unit_interval = t.is_finite() && (0.0..=1.0).contains(&t);
    if !t_in_unit_interval {
        return Err(Error::Domain("t must be in [0,1] and finite"));
    }
    let w0 = 1.0 - t;
    let blended = p0
        .iter()
        .zip(p1)
        .map(|(&a, &b)| w0 * a + t * b)
        .collect();
    Ok(blended)
}
pub fn sample_dirichlet(alpha: &[f32], rng: &mut impl rand::Rng) -> Result<Vec<f32>> {
if alpha.is_empty() {
return Err(Error::Domain("alpha must be non-empty"));
}
if alpha.iter().any(|&a| !a.is_finite() || a <= 0.0) {
return Err(Error::Domain("Dirichlet alpha must be positive and finite"));
}
use rand_distr::{Distribution, Gamma};
let mut xs = vec![0.0f32; alpha.len()];
let mut s: f64 = 0.0;
for (i, &a) in alpha.iter().enumerate() {
let g = Gamma::new(a as f64, 1.0).map_err(|_| Error::Domain("invalid Gamma params"))?;
let x: f64 = g.sample(rng);
xs[i] = x as f32;
s += x;
}
if s <= 0.0 {
return Err(Error::Domain("Dirichlet sampling produced zero total mass"));
}
for v in &mut xs {
*v = (*v as f64 / s) as f32;
}
Ok(xs)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // ---- sample_dirichlet ----

    #[test]
    fn dirichlet_samples_are_on_simplex() {
        let alpha = [0.7f32, 1.3, 2.1, 0.5];
        let mut rng = ChaCha8Rng::seed_from_u64(123);
        let p = sample_dirichlet(&alpha, &mut rng).unwrap();
        assert_eq!(p.len(), alpha.len());
        validate_simplex(&p, 1e-5).unwrap();
        assert!(p.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn dirichlet_is_reproducible_for_fixed_seed() {
        let alpha = [0.7f32, 1.3, 2.1, 0.5];
        let a = sample_dirichlet(&alpha, &mut ChaCha8Rng::seed_from_u64(7)).unwrap();
        let b = sample_dirichlet(&alpha, &mut ChaCha8Rng::seed_from_u64(7)).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn dirichlet_rejects_invalid_alpha() {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let empty: [f32; 0] = [];
        assert!(sample_dirichlet(&empty, &mut rng).is_err());
        for bad in [0.0f32, -1.0, f32::NAN, f32::INFINITY] {
            assert!(sample_dirichlet(&[1.0, bad], &mut rng).is_err());
        }
    }

    // ---- validate_simplex ----

    #[test]
    fn validate_simplex_accepts_valid_point() {
        let p = [0.2f32, 0.3, 0.5];
        validate_simplex(&p, 1e-6).unwrap();
    }

    #[test]
    fn validate_simplex_tolerates_negative_within_tol() {
        let p = [-1e-7f32, 0.5, 0.5];
        validate_simplex(&p, 1e-6).unwrap();
    }

    #[test]
    fn validate_simplex_rejects_wrong_sum() {
        let p = [0.2f32, 0.3, 0.4];
        assert!(validate_simplex(&p, 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_negative() {
        let p = [-0.1f32, 0.6, 0.5];
        assert!(validate_simplex(&p, 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_empty() {
        let p: [f32; 0] = [];
        assert!(validate_simplex(&p, 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_non_finite_entries() {
        assert!(validate_simplex(&[f32::NAN, 1.0], 1e-6).is_err());
        assert!(validate_simplex(&[f32::INFINITY, 0.0], 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_invalid_tol() {
        let p = [0.5f32, 0.5];
        assert!(validate_simplex(&p, -1e-6).is_err());
        assert!(validate_simplex(&p, f32::NAN).is_err());
        assert!(validate_simplex(&p, f32::INFINITY).is_err());
    }

    // ---- normalize_simplex ----

    #[test]
    fn normalize_simplex_produces_valid_simplex() {
        let p = [1.0f32, 2.0, 3.0, 4.0];
        let n = normalize_simplex(&p).unwrap();
        validate_simplex(&n, 1e-6).unwrap();
        assert!((n[0] / n[1] - 0.5).abs() < 1e-6);
        assert!((n[2] / n[3] - 0.75).abs() < 1e-6);
    }

    #[test]
    fn normalize_simplex_rejects_negative() {
        let p = [1.0f32, -1.0, 2.0];
        assert!(normalize_simplex(&p).is_err());
    }

    #[test]
    fn normalize_simplex_rejects_all_zeros() {
        let p = [0.0f32, 0.0, 0.0];
        assert!(normalize_simplex(&p).is_err());
    }

    #[test]
    fn normalize_simplex_rejects_empty_and_non_finite() {
        let empty: [f32; 0] = [];
        assert!(normalize_simplex(&empty).is_err());
        assert!(normalize_simplex(&[1.0, f32::NAN]).is_err());
        assert!(normalize_simplex(&[1.0, f32::INFINITY]).is_err());
    }

    // ---- simplex_lerp ----

    #[test]
    fn simplex_lerp_at_boundaries() {
        let p0 = [1.0f32, 0.0, 0.0];
        let p1 = [0.0f32, 0.0, 1.0];
        let at0 = simplex_lerp(&p0, &p1, 0.0).unwrap();
        for (a, b) in at0.iter().zip(p0.iter()) {
            assert!((a - b).abs() < 1e-7);
        }
        let at1 = simplex_lerp(&p0, &p1, 1.0).unwrap();
        for (a, b) in at1.iter().zip(p1.iter()) {
            assert!((a - b).abs() < 1e-7);
        }
        let mid = simplex_lerp(&p0, &p1, 0.5).unwrap();
        assert!((mid[0] - 0.5).abs() < 1e-7);
        assert!((mid[1] - 0.0).abs() < 1e-7);
        assert!((mid[2] - 0.5).abs() < 1e-7);
    }

    #[test]
    fn simplex_lerp_stays_on_simplex() {
        let p0 = [0.3f32, 0.5, 0.2];
        let p1 = [0.1f32, 0.1, 0.8];
        for i in 0..=10 {
            let t = i as f32 / 10.0;
            let pt = simplex_lerp(&p0, &p1, t).unwrap();
            validate_simplex(&pt, 1e-5).unwrap();
        }
    }

    #[test]
    fn simplex_lerp_rejects_mismatched_lengths() {
        let p0 = [0.5f32, 0.5];
        let p1 = [0.3f32, 0.3, 0.4];
        assert!(simplex_lerp(&p0, &p1, 0.5).is_err());
        // The length check runs before the `t` check.
        assert!(matches!(simplex_lerp(&p0, &p1, 2.0), Err(Error::Shape(_))));
    }

    #[test]
    fn simplex_lerp_rejects_out_of_range_t() {
        let p0 = [0.5f32, 0.5];
        let p1 = [0.3f32, 0.7];
        assert!(simplex_lerp(&p0, &p1, -0.1).is_err());
        assert!(simplex_lerp(&p0, &p1, 1.1).is_err());
    }

    #[test]
    fn simplex_lerp_rejects_non_finite_t() {
        let p0 = [0.5f32, 0.5];
        let p1 = [0.3f32, 0.7];
        assert!(simplex_lerp(&p0, &p1, f32::NAN).is_err());
        assert!(simplex_lerp(&p0, &p1, f32::INFINITY).is_err());
        assert!(simplex_lerp(&p0, &p1, f32::NEG_INFINITY).is_err());
    }
}