use ndarray::{Array2, ArrayView1, ArrayView2};
use crate::geometry::normalize_weights;
/// Checks that `points` can be interpreted as a batch of compositions (one per row).
///
/// # Errors
/// Returns `Err` when the array has no rows or fewer than two columns, or when any
/// entry is NaN or infinite; in the latter case the message reports the first
/// offending value and its `(row, col)` position.
pub fn validate_simplex_array(points: ArrayView2<'_, f64>) -> Result<(), String> {
    let (rows, cols) = points.dim();
    let has_valid_shape = rows > 0 && cols >= 2;
    if !has_valid_shape {
        return Err(
            "simplex values must have at least one row and at least two columns".to_string(),
        );
    }
    for ((row, col), value) in points.indexed_iter() {
        if !value.is_finite() {
            return Err(format!(
                "simplex values must contain only finite values; got {value} at ({row}, {col})"
            ));
        }
    }
    Ok(())
}
/// Rescales every row of `points` so that it sums to one (the simplex "closure").
///
/// Input is first checked with [`validate_simplex_array`]. Zero entries are allowed
/// here; callers that need strictly positive parts check that separately.
///
/// # Errors
/// Returns `Err` if validation fails, if any entry is negative, or if a row has
/// zero total mass.
pub fn closure(points: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    validate_simplex_array(points)?;
    let (n, d) = points.dim();
    let mut out = Array2::<f64>::zeros((n, d));
    for row in 0..n {
        let mut total = 0.0_f64;
        let mut max_v = 0.0_f64;
        for col in 0..d {
            let v = points[[row, col]];
            if v < 0.0 {
                return Err("simplex values must be non-negative".to_string());
            }
            total += v;
            max_v = max_v.max(v);
        }
        if total <= 0.0 {
            return Err("simplex rows must have positive total mass".to_string());
        }
        if total.is_finite() {
            for col in 0..d {
                out[[row, col]] = points[[row, col]] / total;
            }
        } else {
            // Every entry is finite, yet the sum overflowed to +inf (e.g. [1e308, 1e308]).
            // Dividing by +inf would silently zero the whole row, so rescale by the row
            // maximum first; the rescaled sum lies in [1, d] and cannot overflow.
            // max_v > 0 here because total > 0.
            let mut scaled_total = 0.0_f64;
            for col in 0..d {
                let s = points[[row, col]] / max_v;
                out[[row, col]] = s;
                scaled_total += s;
            }
            for col in 0..d {
                out[[row, col]] /= scaled_total;
            }
        }
    }
    Ok(out)
}
/// Fails with an error mentioning `label` if any entry of `comp` is zero or negative.
fn require_positive(comp: ArrayView2<'_, f64>, label: &str) -> Result<(), String> {
    let has_non_positive = comp.iter().any(|&value| value <= 0.0);
    if has_non_positive {
        Err(format!("{label} require strictly positive simplex values"))
    } else {
        Ok(())
    }
}
/// Weighted Fréchet mean of the compositions in `points`.
///
/// Each row is closed and must be strictly positive. The mean is computed as the
/// closed weighted geometric mean: `exp(Σ_i w_i · ln x_i)` per part, then renormalised
/// to sum to one. The result has one entry per column of `points`.
///
/// `weights` is passed through `normalize_weights` from `crate::geometry`.
/// NOTE(review): presumably it yields `n` weights summing to one (uniform when `None`);
/// confirm against that helper.
///
/// # Errors
/// Propagates errors from closure, the positivity check, and weight normalisation.
pub fn simplex_frechet_mean(
    points: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<Vec<f64>, String> {
    let comp = closure(points)?;
    require_positive(comp.view(), "simplex Fr\u{e9}chet mean")?;
    let (rows, parts) = comp.dim();
    let w = normalize_weights(rows, weights)?;

    // Weighted average of the log-components, accumulated row by row.
    let mut centre = vec![0.0_f64; parts];
    for r in 0..rows {
        let weight = w[r];
        for (c, acc) in centre.iter_mut().enumerate() {
            *acc += weight * comp[[r, c]].ln();
        }
    }

    // Shift by the largest log before exponentiating so exp() cannot overflow.
    let peak = centre
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |best, v| if v > best { v } else { best });
    let mut mean: Vec<f64> = centre.iter().map(|&v| (v - peak).exp()).collect();
    let mass = mean.iter().fold(0.0_f64, |acc, &e| acc + e);
    mean.iter_mut().for_each(|v| *v /= mass);
    Ok(mean)
}
/// Coordinate system used for tangent vectors on the simplex.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SimplexCoord {
    /// Centred log-ratio coordinates (one value per part).
    Clr,
    /// Additive log-ratio coordinates relative to a reference part (one value fewer).
    Alr,
}

/// Parses a coordinate-system name, case-insensitively.
///
/// `"clr"` and its alias `"simplex"` map to [`SimplexCoord::Clr`]; `"alr"` maps to
/// [`SimplexCoord::Alr`].
///
/// # Errors
/// Returns `Err` for any other name; the message quotes the lower-cased input.
pub fn parse_simplex_coord(coordinates: &str) -> Result<SimplexCoord, String> {
    let key = coordinates.to_ascii_lowercase();
    if key == "simplex" || key == "clr" {
        Ok(SimplexCoord::Clr)
    } else if key == "alr" {
        Ok(SimplexCoord::Alr)
    } else {
        Err(format!(
            "simplex coordinates must be 'clr' or 'alr'; got {key:?}"
        ))
    }
}
/// Maps a possibly-negative reference index onto `0..d`, wrapping modulo `d`
/// (so `-1` selects the last part).
///
/// # Panics
/// Panics if `d == 0` (remainder by zero).
fn resolve_reference(reference: isize, d: usize) -> usize {
    // For a positive divisor, Euclidean remainder is always in 0..d.
    reference.rem_euclid(d as isize) as usize
}
/// Centred log-ratio transform of each row of `values`.
///
/// Rows are closed first and must be strictly positive. Each output row is
/// `ln(x) - mean(ln(x))`, so it has the same number of columns as the input.
///
/// # Errors
/// Propagates closure errors and rejects compositions containing zero parts.
pub fn clr(values: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let comp = closure(values)?;
    require_positive(comp.view(), "CLR coordinates")?;
    let parts = comp.ncols();
    let mut logs = comp.mapv(f64::ln);
    for r in 0..logs.nrows() {
        let centre = (0..parts).map(|c| logs[[r, c]]).fold(0.0_f64, |acc, v| acc + v)
            / (parts as f64);
        for c in 0..parts {
            logs[[r, c]] -= centre;
        }
    }
    Ok(logs)
}
/// Additive log-ratio transform of each row of `values` against a reference part.
///
/// Rows are closed first and must be strictly positive. `reference` selects the
/// denominator part and wraps modulo the number of parts (so `-1` is the last part).
/// Each output row holds `ln(x_j / x_ref)` for every `j != ref`, in column order,
/// giving one column fewer than the input.
///
/// # Errors
/// Propagates closure errors and rejects compositions containing zero parts.
pub fn alr(values: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
    let comp = closure(values)?;
    require_positive(comp.view(), "ALR coordinates")?;
    let (rows, parts) = comp.dim();
    let anchor = resolve_reference(reference, parts);
    let mut ratios = Array2::<f64>::zeros((rows, parts - 1));
    for r in 0..rows {
        let anchor_log = comp[[r, anchor]].ln();
        for (k, c) in (0..parts).filter(|&c| c != anchor).enumerate() {
            ratios[[r, k]] = comp[[r, c]].ln() - anchor_log;
        }
    }
    Ok(ratios)
}
/// Inverts [`alr`]: maps ALR coordinates back to closed compositions.
///
/// A zero log-ratio is reinserted at the (wrapped) `reference` position, and each row
/// is passed through a max-shifted softmax so the output sums to one. The output has
/// one more column than `coords`. No shape validation is done here: an empty input
/// yields an empty output, and zero input columns yield a single column of ones.
///
/// # Errors
/// Returns `Err` if `coords` contains NaN or infinite values.
pub fn inverse_alr(coords: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
    if coords.iter().any(|v| !v.is_finite()) {
        return Err("ALR coordinates must contain only finite values".to_string());
    }
    let (rows, free) = coords.dim();
    let parts = free + 1;
    let anchor = resolve_reference(reference, parts);
    let mut result = Array2::<f64>::zeros((rows, parts));
    for r in 0..rows {
        // Rebuild the full logit row. The reference slot is an implicit 0, and ALR
        // columns at or after it shift one position to the right.
        for c in 0..parts {
            result[[r, c]] = if c == anchor {
                0.0
            } else if c < anchor {
                coords[[r, c]]
            } else {
                coords[[r, c - 1]]
            };
        }
        // Subtract the row maximum before exponentiating to avoid overflow.
        let peak = (0..parts)
            .map(|c| result[[r, c]])
            .fold(f64::NEG_INFINITY, |best, v| if v > best { v } else { best });
        let mut mass = 0.0_f64;
        for c in 0..parts {
            let e = (result[[r, c]] - peak).exp();
            result[[r, c]] = e;
            mass += e;
        }
        for c in 0..parts {
            result[[r, c]] /= mass;
        }
    }
    Ok(result)
}
/// Log map of the compositions in `values` at the base point `base`.
///
/// Both inputs are closed and must be strictly positive and of equal dimension. The
/// result is the row-wise difference between the coordinates of each value and the
/// coordinates of `base`, in either CLR (same width as the input) or ALR (one column
/// fewer, using `reference` as the denominator part).
///
/// # Errors
/// Returns `Err` on invalid or non-positive compositions, or on a dimension mismatch
/// between `values` and `base`.
pub fn simplex_log_map(
    values: ArrayView2<'_, f64>,
    base: ArrayView1<'_, f64>,
    coord: SimplexCoord,
    reference: isize,
) -> Result<Array2<f64>, String> {
    let closed_values = closure(values)?;
    // Promote the base point to a one-row matrix so it goes through the same checks.
    let base_row = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
    let closed_base = closure(base_row.view())?;
    if closed_values.ncols() != closed_base.ncols() {
        return Err("simplex values and base point have different dimensions".to_string());
    }
    require_positive(closed_values.view(), "simplex log map")?;
    require_positive(closed_base.view(), "simplex log map")?;

    let (mut coords, base_coords) = match coord {
        SimplexCoord::Clr => (clr(values)?, clr(base_row.view())?),
        SimplexCoord::Alr => (alr(values, reference)?, alr(base_row.view(), reference)?),
    };
    // The single base row broadcasts across every row of `coords`.
    coords -= &base_coords.row(0);
    Ok(coords)
}
pub fn simplex_exp_map(
tangent: ArrayView2<'_, f64>,
base: ArrayView1<'_, f64>,
coord: SimplexCoord,
reference: isize,
) -> Result<Array2<f64>, String> {
let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
let base_comp = closure(base2.view())?;
let d = base_comp.ncols();
match coord {
SimplexCoord::Clr => {
if tangent.ncols() != d {
return Err("CLR tangent dimension must equal simplex dimension".to_string());
}
let n = tangent.nrows();
let mut out = Array2::<f64>::zeros((n, d));
for row in 0..n {
let mut max_v = f64::NEG_INFINITY;
for col in 0..d {
let lg = base_comp[[0, col]].ln() + tangent[[row, col]];
out[[row, col]] = lg;
if lg > max_v {
max_v = lg;
}
}
let mut total = 0.0_f64;
for col in 0..d {
let e = (out[[row, col]] - max_v).exp();
out[[row, col]] = e;
total += e;
}
for col in 0..d {
out[[row, col]] /= total;
}
}
Ok(out)
}
SimplexCoord::Alr => {
if tangent.ncols() + 1 != d {
return Err("ALR tangent dimension must be simplex dimension minus one".to_string());
}
let base_alr = alr(base2.view(), reference)?;
let n = tangent.nrows();
let dm1 = d - 1;
let mut shifted = Array2::<f64>::zeros((n, dm1));
for row in 0..n {
for col in 0..dm1 {
shifted[[row, col]] = base_alr[[0, col]] + tangent[[row, col]];
}
}
inverse_alr(shifted.view(), reference)
}
}
}