use crate::rv::dist::Categorical;
use serde::{Deserialize, Serialize};
use std::ops::Index;
use thiserror::Error;
/// A point on the probability simplex: a vector of coordinates that
/// [`SimplexPoint::new`] has checked to be finite, non-negative, and to sum
/// to one (within `1e-10`).
///
/// [`SimplexPoint::new_unchecked`] skips that validation.
#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct SimplexPoint(Vec<f64>);
/// Reasons [`SimplexPoint::new`] can reject a coordinate vector.
#[derive(Clone, Debug, PartialEq, Error)]
pub enum SimplexPointError {
    /// The coordinate at index `ix` (value `coord`) is negative, NaN, or
    /// infinite.
    #[error("simplex coordinate {ix} is invalid with value: {coord}")]
    InvalidCoordinate { ix: usize, coord: f64 },
    /// Every coordinate is valid, but their total `sum` differs from one by
    /// more than `1e-10`.
    #[error("The simplex coordinates do not sum to one ({sum})")]
    DoesNotSumToOne { sum: f64 },
}
impl SimplexPoint {
    /// Create a validated simplex point.
    ///
    /// # Errors
    ///
    /// - [`SimplexPointError::InvalidCoordinate`] for the first coordinate
    ///   that is negative, NaN, or infinite.
    /// - [`SimplexPointError::DoesNotSumToOne`] if all coordinates are valid
    ///   but their sum differs from `1.0` by more than `1e-10`. An empty
    ///   vector sums to `0.0` and is rejected here.
    pub fn new(point: Vec<f64>) -> Result<Self, SimplexPointError> {
        let sum: f64 = point
            .iter()
            .enumerate()
            .try_fold(0.0, |sum, (ix, &coord)| {
                if coord.is_finite() && coord >= 0.0 {
                    Ok(sum + coord)
                } else {
                    Err(SimplexPointError::InvalidCoordinate { ix, coord })
                }
            })?;

        if (sum - 1.0).abs() > 1e-10 {
            Err(SimplexPointError::DoesNotSumToOne { sum })
        } else {
            Ok(SimplexPoint(point))
        }
    }

    /// Wrap `point` without any validation.
    ///
    /// The caller is responsible for the coordinates being finite,
    /// non-negative, and summing to one. [`SimplexPoint::draw`] panics if
    /// no coordinate is positive.
    pub fn new_unchecked(point: Vec<f64>) -> Self {
        SimplexPoint(point)
    }

    /// The coordinates of the point.
    pub fn point(&self) -> &Vec<f64> {
        &self.0
    }

    /// The number of coordinates in the point.
    pub fn ndims(&self) -> usize {
        self.0.len()
    }

    /// Convert to a [`Categorical`] built from the natural logs of the
    /// coordinates via `Categorical::from_ln_weights`.
    ///
    /// # Panics
    ///
    /// Panics if `Categorical::from_ln_weights` rejects the log-weights.
    pub fn to_categorical(&self) -> Categorical {
        // NOTE(review): a zero coordinate becomes ln(0) = -inf; confirm that
        // `Categorical::from_ln_weights` accepts -inf log-weights.
        let ln_weights = self.point().iter().map(|&w| w.ln()).collect();
        Categorical::from_ln_weights(ln_weights)
            .expect("simplex coordinates should form valid categorical log-weights")
    }

    /// Draw an index, choosing index `i` with probability equal to
    /// coordinate `i`. This inverts the cumulative sum of the coordinates at
    /// a single uniform draw from `rng`.
    ///
    /// # Panics
    ///
    /// Panics if no coordinate is positive (e.g. an empty point built with
    /// [`SimplexPoint::new_unchecked`]).
    pub fn draw<R: rand::Rng>(&self, rng: &mut R) -> usize {
        let u: f64 = rng.gen();
        let mut sum_p = 0.0;
        for (ix, &p) in self.0.iter().enumerate() {
            sum_p += p;
            if u < sum_p {
                return ix;
            }
        }
        // `new` accepts totals up to 1e-10 below one, and float accumulation
        // can also end just under one. So `u` may legitimately be at least
        // the final cumulative sum. Give that sliver of mass to the last
        // coordinate with positive probability instead of panicking.
        self.0
            .iter()
            .rposition(|&p| p > 0.0)
            .unwrap_or_else(|| panic!("The simplex coords {:?} do not sum to 1", self.0))
    }
}
impl Index<u8> for SimplexPoint {
    type Output = f64;

    /// Borrow the coordinate at a `u8` position, widened losslessly to
    /// `usize`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is past the last coordinate.
    fn index(&self, index: u8) -> &f64 {
        let position = usize::from(index);
        &self.0[position]
    }
}
impl Index<usize> for SimplexPoint {
    type Output = f64;

    /// Borrow the coordinate at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is past the last coordinate.
    fn index(&self, index: usize) -> &f64 {
        let coords: &[f64] = &self.0;
        &coords[index]
    }
}
/// Map a vector of uniform draws to a point on the simplex.
///
/// The last entry of `uvec` is discarded and replaced by `1.0`. The entries
/// are then sorted, and each one is replaced by its gap to the previous
/// sorted value; the first keeps its distance from `0.0`. The gaps telescope
/// to the largest sorted value, so the result sums to one (up to rounding)
/// whenever every entry lies in `[0, 1]`. Entries outside that range yield
/// coordinates that are negative or do not sum to one.
///
/// # Panics
///
/// Panics if `uvec` is empty, or if any entry other than the last is NaN.
pub fn uvec_to_simplex(mut uvec: Vec<f64>) -> SimplexPoint {
    assert!(
        !uvec.is_empty(),
        "uvec_to_simplex requires a non-empty vector"
    );
    let n = uvec.len();
    uvec[n - 1] = 1.0;
    uvec.sort_by(|a, b| {
        a.partial_cmp(b)
            .expect("uvec_to_simplex: entries must not be NaN")
    });
    // Convert the sorted values into consecutive gaps in place. Walking right
    // to left keeps `uvec[i - 1]` holding its sorted value when it is
    // subtracted from `uvec[i]`.
    for i in (1..n).rev() {
        uvec[i] -= uvec[i - 1];
    }
    SimplexPoint::new_unchecked(uvec)
}