Skip to main content

gam_geometry/manifolds/
simplex.rs

1use ndarray::{Array2, ArrayView1, ArrayView2};
2
3use crate::normalize_weights;
4
5pub fn validate_simplex_array(points: ArrayView2<'_, f64>) -> Result<(), String> {
6    let (n, d) = points.dim();
7    if n == 0 || d < 2 {
8        return Err(
9            "simplex values must have at least one row and at least two columns".to_string(),
10        );
11    }
12    if let Some(((row, col), value)) = points.indexed_iter().find(|(_, v)| !v.is_finite()) {
13        return Err(format!(
14            "simplex values must contain only finite values; got {value} at ({row}, {col})"
15        ));
16    }
17    Ok(())
18}
19
20pub fn closure(points: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
21    validate_simplex_array(points)?;
22    let (n, d) = points.dim();
23    let mut out = Array2::<f64>::zeros((n, d));
24    for row in 0..n {
25        let mut total = 0.0_f64;
26        for col in 0..d {
27            let v = points[[row, col]];
28            if v < 0.0 {
29                return Err("simplex values must be non-negative".to_string());
30            }
31            total += v;
32        }
33        if total <= 0.0 {
34            return Err("simplex rows must have positive total mass".to_string());
35        }
36        for col in 0..d {
37            out[[row, col]] = points[[row, col]] / total;
38        }
39    }
40    Ok(out)
41}
42
43fn require_positive(comp: ArrayView2<'_, f64>, label: &str) -> Result<(), String> {
44    for value in comp.iter() {
45        if *value <= 0.0 {
46            return Err(format!("{label} require strictly positive simplex values"));
47        }
48    }
49    Ok(())
50}
51
52pub fn simplex_frechet_mean(
53    points: ArrayView2<'_, f64>,
54    weights: Option<ArrayView1<'_, f64>>,
55) -> Result<Vec<f64>, String> {
56    let comp = closure(points)?;
57    require_positive(comp.view(), "simplex Fr\u{e9}chet mean")?;
58    let (n, d) = comp.dim();
59    let w = normalize_weights(n, weights)?;
60    let mut mean_log = vec![0.0_f64; d];
61    for row in 0..n {
62        for col in 0..d {
63            mean_log[col] += w[row] * comp[[row, col]].ln();
64        }
65    }
66    let mut max_v = f64::NEG_INFINITY;
67    for &v in mean_log.iter() {
68        if v > max_v {
69            max_v = v;
70        }
71    }
72    let mut total = 0.0_f64;
73    let mut out = vec![0.0_f64; d];
74    for col in 0..d {
75        let e = (mean_log[col] - max_v).exp();
76        out[col] = e;
77        total += e;
78    }
79    for value in out.iter_mut() {
80        *value /= total;
81    }
82    Ok(out)
83}
84
/// Chart used by the simplex (Aitchison) log and exp maps.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SimplexCoord {
    /// Centered log-ratio: `d` coordinates per row that sum to zero.
    Clr,
    /// Additive log-ratio: `d - 1` coordinates per row, each taken relative to
    /// a chosen reference part.
    Alr,
}
93
94/// Parse a simplex coordinate label. `"simplex"`/`"clr"` → CLR, `"alr"` → ALR.
95pub fn parse_simplex_coord(coordinates: &str) -> Result<SimplexCoord, String> {
96    match coordinates.to_ascii_lowercase().as_str() {
97        "simplex" | "clr" => Ok(SimplexCoord::Clr),
98        "alr" => Ok(SimplexCoord::Alr),
99        other => Err(format!(
100            "simplex coordinates must be 'clr' or 'alr'; got {other:?}"
101        )),
102    }
103}
104
/// Convert a Python-style reference index into a column index in `0..d`.
///
/// Negative values count back from the end, and any out-of-range value wraps
/// modulo `d`.
///
/// # Panics
/// Panics if `d == 0` (remainder by zero).
fn resolve_reference(reference: isize, d: usize) -> usize {
    reference.rem_euclid(d as isize) as usize
}
114
115/// Centered log-ratio coordinates: `clr(x)_j = ln x_j - mean_k ln x_k` after
116/// closure. Requires strictly positive compositions.
117pub fn clr(values: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
118    let comp = closure(values)?;
119    require_positive(comp.view(), "CLR coordinates")?;
120    let (n, d) = comp.dim();
121    let mut out = Array2::<f64>::zeros((n, d));
122    for row in 0..n {
123        let mut sum_log = 0.0_f64;
124        for col in 0..d {
125            let lg = comp[[row, col]].ln();
126            out[[row, col]] = lg;
127            sum_log += lg;
128        }
129        let mean = sum_log / (d as f64);
130        for col in 0..d {
131            out[[row, col]] -= mean;
132        }
133    }
134    Ok(out)
135}
136
137/// Additive log-ratio coordinates relative to `reference`: `alr(x)_j = ln x_j -
138/// ln x_ref` for `j != ref`, yielding `(d-1)` columns. Requires strictly
139/// positive compositions.
140pub fn alr(values: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
141    let comp = closure(values)?;
142    require_positive(comp.view(), "ALR coordinates")?;
143    let (n, d) = comp.dim();
144    let ref_idx = resolve_reference(reference, d);
145    let mut out = Array2::<f64>::zeros((n, d - 1));
146    for row in 0..n {
147        let log_ref = comp[[row, ref_idx]].ln();
148        let mut k = 0usize;
149        for col in 0..d {
150            if col == ref_idx {
151                continue;
152            }
153            out[[row, k]] = comp[[row, col]].ln() - log_ref;
154            k += 1;
155        }
156    }
157    Ok(out)
158}
159
160/// Inverse additive log-ratio: map `(d-1)` ALR coordinates back to the simplex
161/// via a numerically stable softmax with the reference logit pinned to zero.
162pub fn inverse_alr(coords: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
163    let (n, dm1) = coords.dim();
164    if !coords.iter().all(|v| v.is_finite()) {
165        return Err("ALR coordinates must contain only finite values".to_string());
166    }
167    let d = dm1 + 1;
168    let ref_idx = resolve_reference(reference, d);
169    let mut out = Array2::<f64>::zeros((n, d));
170    for row in 0..n {
171        let mut max_v = f64::NEG_INFINITY;
172        let mut k = 0usize;
173        for col in 0..d {
174            let v = if col == ref_idx {
175                0.0
176            } else {
177                let val = coords[[row, k]];
178                k += 1;
179                val
180            };
181            out[[row, col]] = v;
182            if v > max_v {
183                max_v = v;
184            }
185        }
186        let mut total = 0.0_f64;
187        for col in 0..d {
188            let e = (out[[row, col]] - max_v).exp();
189            out[[row, col]] = e;
190            total += e;
191        }
192        for col in 0..d {
193            out[[row, col]] /= total;
194        }
195    }
196    Ok(out)
197}
198
199/// Riemannian log map at an intrinsic simplex base point, expressed in the
200/// chosen coordinate system: the difference of the values' and base's CLR/ALR
201/// coordinates.
202pub fn simplex_log_map(
203    values: ArrayView2<'_, f64>,
204    base: ArrayView1<'_, f64>,
205    coord: SimplexCoord,
206    reference: isize,
207) -> Result<Array2<f64>, String> {
208    let comp = closure(values)?;
209    let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
210    let base_comp = closure(base2.view())?;
211    if comp.ncols() != base_comp.ncols() {
212        return Err("simplex values and base point have different dimensions".to_string());
213    }
214    require_positive(comp.view(), "simplex log map")?;
215    require_positive(base_comp.view(), "simplex log map")?;
216    match coord {
217        SimplexCoord::Clr => {
218            let values_clr = clr(values)?;
219            let base_clr = clr(base2.view())?;
220            let (n, d) = values_clr.dim();
221            let mut out = Array2::<f64>::zeros((n, d));
222            for row in 0..n {
223                for col in 0..d {
224                    out[[row, col]] = values_clr[[row, col]] - base_clr[[0, col]];
225                }
226            }
227            Ok(out)
228        }
229        SimplexCoord::Alr => {
230            let values_alr = alr(values, reference)?;
231            let base_alr = alr(base2.view(), reference)?;
232            let (n, dm1) = values_alr.dim();
233            let mut out = Array2::<f64>::zeros((n, dm1));
234            for row in 0..n {
235                for col in 0..dm1 {
236                    out[[row, col]] = values_alr[[row, col]] - base_alr[[0, col]];
237                }
238            }
239            Ok(out)
240        }
241    }
242}
243
244/// Riemannian exp map from tangent coordinates back to the simplex at `base`,
245/// inverting [`simplex_log_map`] for the matching coordinate system.
246pub fn simplex_exp_map(
247    tangent: ArrayView2<'_, f64>,
248    base: ArrayView1<'_, f64>,
249    coord: SimplexCoord,
250    reference: isize,
251) -> Result<Array2<f64>, String> {
252    let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
253    let base_comp = closure(base2.view())?;
254    let d = base_comp.ncols();
255    match coord {
256        SimplexCoord::Clr => {
257            if tangent.ncols() != d {
258                return Err("CLR tangent dimension must equal simplex dimension".to_string());
259            }
260            require_positive(base_comp.view(), "simplex exp map")?;
261            let n = tangent.nrows();
262            let mut out = Array2::<f64>::zeros((n, d));
263            for row in 0..n {
264                let mut max_v = f64::NEG_INFINITY;
265                for col in 0..d {
266                    let lg = base_comp[[0, col]].ln() + tangent[[row, col]];
267                    out[[row, col]] = lg;
268                    if lg > max_v {
269                        max_v = lg;
270                    }
271                }
272                let mut total = 0.0_f64;
273                for col in 0..d {
274                    let e = (out[[row, col]] - max_v).exp();
275                    out[[row, col]] = e;
276                    total += e;
277                }
278                for col in 0..d {
279                    out[[row, col]] /= total;
280                }
281            }
282            Ok(out)
283        }
284        SimplexCoord::Alr => {
285            if tangent.ncols() + 1 != d {
286                return Err("ALR tangent dimension must be simplex dimension minus one".to_string());
287            }
288            let base_alr = alr(base2.view(), reference)?;
289            let n = tangent.nrows();
290            let dm1 = d - 1;
291            let mut shifted = Array2::<f64>::zeros((n, dm1));
292            for row in 0..n {
293                for col in 0..dm1 {
294                    shifted[[row, col]] = base_alr[[0, col]] + tangent[[row, col]];
295                }
296            }
297            inverse_alr(shifted.view(), reference)
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Assert that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!(
            (actual - expected).abs() < tol,
            "{what}: expected {expected}, got {actual}"
        );
    }

    // ── parse_simplex_coord ───────────────────────────────────────────────────

    #[test]
    fn parse_simplex_coord_simplex_and_clr_map_to_clr() {
        for label in ["simplex", "clr"] {
            assert_eq!(parse_simplex_coord(label), Ok(SimplexCoord::Clr), "label {label:?}");
        }
    }

    #[test]
    fn parse_simplex_coord_alr_maps_to_alr() {
        assert_eq!(parse_simplex_coord("alr"), Ok(SimplexCoord::Alr));
    }

    #[test]
    fn parse_simplex_coord_case_insensitive() {
        let cases = [
            ("CLR", SimplexCoord::Clr),
            ("ALR", SimplexCoord::Alr),
            ("Simplex", SimplexCoord::Clr),
        ];
        for (label, want) in cases {
            assert_eq!(parse_simplex_coord(label), Ok(want), "label {label:?}");
        }
    }

    #[test]
    fn parse_simplex_coord_unknown_is_error() {
        for label in ["pca", ""] {
            assert!(parse_simplex_coord(label).is_err(), "label {label:?} should be rejected");
        }
    }

    // ── validate_simplex_array ────────────────────────────────────────────────

    #[test]
    fn validate_simplex_array_valid_input_passes() {
        let m = array![[0.5_f64, 0.5]];
        assert_eq!(validate_simplex_array(m.view()), Ok(()));
    }

    #[test]
    fn validate_simplex_array_no_rows_is_error() {
        let m = Array2::<f64>::zeros((0, 3));
        assert!(validate_simplex_array(m.view()).is_err());
    }

    #[test]
    fn validate_simplex_array_single_column_is_error() {
        assert!(validate_simplex_array(array![[0.5_f64]].view()).is_err());
    }

    #[test]
    fn validate_simplex_array_non_finite_is_error() {
        let m = array![[0.5_f64, f64::NAN]];
        match validate_simplex_array(m.view()) {
            Err(err) => assert!(err.contains("finite"), "error should mention finite, got: {err}"),
            Ok(()) => panic!("NaN input must be rejected"),
        }
    }

    // ── closure ───────────────────────────────────────────────────────────────

    #[test]
    fn closure_normalizes_rows_to_sum_one() {
        let m = array![[1.0_f64, 2.0, 3.0], [4.0, 4.0, 4.0]];
        let c = closure(m.view()).unwrap();
        assert_eq!(c.nrows(), 2);
        for (i, row) in c.rows().into_iter().enumerate() {
            assert_close(row.sum(), 1.0, 1e-14, &format!("row {i} sum"));
        }
    }

    #[test]
    fn closure_equal_weights_gives_uniform_composition() {
        let c = closure(array![[2.0_f64, 2.0]].view()).unwrap();
        for &v in c.iter() {
            assert_close(v, 0.5, 1e-14, "uniform part");
        }
    }

    #[test]
    fn closure_negative_value_is_error() {
        assert!(closure(array![[1.0_f64, -0.5]].view()).is_err());
    }

    #[test]
    fn closure_zero_total_mass_is_error() {
        let err = closure(array![[0.0_f64, 0.0]].view()).unwrap_err();
        assert!(err.contains("total mass") || err.contains("positive"), "got: {err}");
    }

    // ── resolve_reference ─────────────────────────────────────────────────────

    #[test]
    fn resolve_reference_positive_index() {
        for r in [1_isize, 2] {
            assert_eq!(resolve_reference(r, 3), r as usize);
        }
    }

    #[test]
    fn resolve_reference_negative_index_wraps() {
        // Python-style: -1 is the last part and -d wraps back to the first.
        for (r, want) in [(-1_isize, 2_usize), (-2, 1), (-3, 0)] {
            assert_eq!(resolve_reference(r, 3), want, "reference {r}");
        }
    }

    // ── clr known values ──────────────────────────────────────────────────────

    #[test]
    fn clr_of_uniform_composition_is_zero() {
        // clr([1/3, 1/3, 1/3]) = [0, 0, 0]
        let c = clr(array![[1.0_f64, 1.0, 1.0]].view()).unwrap();
        assert!(c.iter().all(|v| v.abs() < 1e-14), "clr of uniform should be 0, got {c}");
    }

    #[test]
    fn clr_sum_is_zero_per_row() {
        let c = clr(array![[1.0_f64, 2.0, 3.0], [4.0, 1.0, 1.0]].view()).unwrap();
        for row in c.rows() {
            assert_close(row.sum(), 0.0, 1e-12, "clr row sum");
        }
    }

    // ── alr / inverse_alr round-trip ─────────────────────────────────────────

    #[test]
    fn alr_inverse_alr_round_trip() {
        let m = array![[0.2_f64, 0.5, 0.3]];
        // reference = -1 selects the last part
        let coords = alr(m.view(), -1).unwrap();
        let recovered = inverse_alr(coords.view(), -1).unwrap();
        assert_eq!(recovered.dim(), m.dim());
        for (col, (&got, &want)) in recovered.iter().zip(m.iter()).enumerate() {
            assert_close(got, want, 1e-12, &format!("col {col}"));
        }
    }

    /// CLR exp map at a strictly interior base with a finite tangent succeeds
    /// and lands in the open simplex: every part is strictly positive and the
    /// parts sum to one.
    #[test]
    fn clr_exp_map_interior_base_lands_in_open_simplex() {
        let base: Array1<f64> = array![0.2, 0.5, 0.3];
        let tangent = array![[0.4_f64, -0.1, -0.3]];
        let out = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("interior base with finite tangent must succeed");
        assert_close(out.row(0).sum(), 1.0, 1e-12, "components must sum to one");
        assert!(
            out.iter().all(|&v| v > 0.0),
            "components must be strictly positive; got {out}"
        );
    }

    /// CLR exp map at a boundary base (a zero part: on the closed simplex but
    /// off the Aitchison manifold) must return an error rather than NaN.
    #[test]
    fn clr_exp_map_boundary_base_errors() {
        let base: Array1<f64> = array![1.0, 0.0, 0.0];
        let tangent = array![[0.1_f64, -0.05, -0.05]];
        let err = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect_err("boundary base must be rejected, not yield NaN");
        assert!(
            err.contains("strictly positive"),
            "error must explain the positivity domain; got {err}"
        );
    }

    /// A CLR log map followed by the exp map at the same interior base
    /// recovers the original interior point.
    #[test]
    fn clr_log_exp_round_trip_recovers_interior_point() {
        let base: Array1<f64> = array![0.25, 0.45, 0.30];
        let point = array![[0.1_f64, 0.6, 0.3]];
        let tangent = simplex_log_map(point.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("log map at interior base must succeed");
        let recovered = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("exp map at interior base must succeed");
        assert_eq!(recovered.dim(), point.dim());
        for (col, (&got, &want)) in recovered.iter().zip(point.iter()).enumerate() {
            assert_close(got, want, 1e-12, &format!("round-trip column {col}"));
        }
    }
}