gam 0.2.3

Generalized penalized likelihood engine
use ndarray::{Array2, ArrayView1, ArrayView2};

/// Checks that `points` has a usable shape and contains only finite numbers.
///
/// # Errors
/// Returns an error if there are no rows or fewer than two columns. Otherwise, returns an error
/// naming the first non-finite entry (NaN or ±infinity), scanning row by row, together with its
/// `(row, col)` position.
fn validate_simplex_array(points: ArrayView2<'_, f64>) -> Result<(), String> {
    let (rows, cols) = points.dim();
    if rows < 1 || cols < 2 {
        return Err(
            "simplex values must have at least one row and at least two columns".to_string(),
        );
    }
    for (r, row_view) in points.outer_iter().enumerate() {
        for (c, value) in row_view.iter().enumerate() {
            if !value.is_finite() {
                return Err(format!(
                    "simplex values must contain only finite values; got {value} at ({r}, {c})"
                ));
            }
        }
    }
    Ok(())
}

/// Turns optional per-row weights into a probability vector of length `n`.
///
/// * `None` yields `1 / n` repeated `n` times (an empty vector when `n == 0`).
/// * `Some(w)` is validated and divided by its sum, so the result sums to 1 up to rounding.
///
/// # Errors
/// Returns an error if `w.len() != n`, if any weight is NaN, infinite or negative, or if the
/// weights sum to zero.
fn normalize_weights(n: usize, weights: Option<ArrayView1<'_, f64>>) -> Result<Vec<f64>, String> {
    const INVALID: &str = "weights must be finite, non-negative, and have positive total";
    let w = match weights {
        None => return Ok(vec![1.0 / n as f64; n]),
        Some(w) => w,
    };
    if w.len() != n {
        return Err("weights length must match the number of rows".to_string());
    }
    if w.iter().any(|v| !v.is_finite() || *v < 0.0) {
        return Err(INVALID.to_string());
    }
    let total = w.iter().fold(0.0_f64, |acc, &v| acc + v);
    if total <= 0.0 {
        return Err(INVALID.to_string());
    }
    if total.is_infinite() {
        // Every weight is finite, but their sum overflowed past `f64::MAX`. Dividing by an
        // infinite total would collapse every weight to 0.0. Rescale by the largest weight
        // first: the rescaled values are at most 1.0, so their sum is at most `n`.
        let largest = w.iter().fold(0.0_f64, |m, &v| m.max(v));
        let scaled_total = w.iter().fold(0.0_f64, |acc, &v| acc + v / largest);
        return Ok(w.iter().map(|&v| v / largest / scaled_total).collect());
    }
    Ok(w.iter().map(|&v| v / total).collect())
}

/// Closes each row of `points` so that it sums to 1 (the compositional "closure" operation).
///
/// The input is first checked with `validate_simplex_array`. Rows are then processed in order.
/// Within a row, entries must be non-negative, and the row must have a strictly positive sum.
///
/// # Errors
/// Returns the first validation error found, scanning row by row. Possible errors: a bad shape,
/// a non-finite entry, a negative entry, or a row whose entries sum to zero.
fn closure(points: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    validate_simplex_array(points)?;
    let mut out = Array2::<f64>::zeros(points.dim());
    for (src, mut dst) in points.outer_iter().zip(out.outer_iter_mut()) {
        let mut total = 0.0_f64;
        for &v in src.iter() {
            if v < 0.0 {
                return Err("simplex values must be non-negative".to_string());
            }
            total += v;
        }
        if total <= 0.0 {
            return Err("simplex rows must have positive total mass".to_string());
        }
        if total.is_finite() {
            for (o, &v) in dst.iter_mut().zip(src.iter()) {
                *o = v / total;
            }
        } else {
            // All entries are finite (validated above), but their sum overflowed. Dividing by
            // the row maximum first keeps the sum in range. For example, a row of two
            // `f64::MAX` entries then closes to [0.5, 0.5] instead of [0, 0].
            let max = src.iter().fold(0.0_f64, |m, &v| m.max(v));
            let scaled_total = src.iter().fold(0.0_f64, |acc, &v| acc + v / max);
            for (o, &v) in dst.iter_mut().zip(src.iter()) {
                *o = v / max / scaled_total;
            }
        }
    }
    Ok(out)
}

/// Rejects a composition matrix that has any entry less than or equal to zero.
///
/// A NaN entry is not rejected here, because a NaN comparison is never `<= 0.0`.
///
/// # Errors
/// Returns an error if any entry of `comp` is `<= 0.0`.
fn require_positive(comp: ArrayView2<'_, f64>) -> Result<(), String> {
    let has_non_positive = comp.iter().any(|&entry| entry <= 0.0);
    if has_non_positive {
        Err("simplex Fr\u{e9}chet mean require strictly positive simplex values".to_string())
    } else {
        Ok(())
    }
}

/// Computes the weighted Fréchet mean of simplex points as a closed geometric mean.
///
/// The computation has four steps:
/// 1. Each row of `points` is closed to sum to 1.
/// 2. The weighted mean of the natural log of each column is taken.
/// 3. The result is exponentiated, after subtracting the largest log-mean for numerical
///    stability.
/// 4. The result is closed again to sum to 1.
///
/// When `weights` is `None`, all rows are weighted equally. The returned vector has one entry
/// per column of `points`.
///
/// # Errors
/// Returns an error if `points` is malformed, or if any closed component is not strictly
/// positive. Malformed means: no rows, fewer than two columns, non-finite or negative entries,
/// or a row with zero total. It also returns an error if `weights` has the wrong length, has
/// invalid values, or sums to zero. These checks run in that order.
pub fn simplex_frechet_mean(
    points: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<Vec<f64>, String> {
    let comp = closure(points)?;
    require_positive(comp.view())?;
    let row_weights = normalize_weights(comp.nrows(), weights)?;

    // Weighted average of log-components, accumulated row by row for each column.
    let mut log_center = vec![0.0_f64; comp.ncols()];
    for (weight, row) in row_weights.iter().zip(comp.outer_iter()) {
        for (acc, &x) in log_center.iter_mut().zip(row.iter()) {
            *acc += weight * x.ln();
        }
    }

    // Shift by the largest log-mean so the biggest exponent is exactly 0 (avoids underflow).
    let peak = log_center
        .iter()
        .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best });
    let unnormalized: Vec<f64> = log_center.iter().map(|&v| (v - peak).exp()).collect();
    let mass = unnormalized.iter().fold(0.0_f64, |acc, &e| acc + e);
    Ok(unnormalized.into_iter().map(|e| e / mass).collect())
}