1use ndarray::{Array2, ArrayView1, ArrayView2};
2
3use crate::normalize_weights;
4
5pub fn validate_simplex_array(points: ArrayView2<'_, f64>) -> Result<(), String> {
6 let (n, d) = points.dim();
7 if n == 0 || d < 2 {
8 return Err(
9 "simplex values must have at least one row and at least two columns".to_string(),
10 );
11 }
12 if let Some(((row, col), value)) = points.indexed_iter().find(|(_, v)| !v.is_finite()) {
13 return Err(format!(
14 "simplex values must contain only finite values; got {value} at ({row}, {col})"
15 ));
16 }
17 Ok(())
18}
19
20pub fn closure(points: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
21 validate_simplex_array(points)?;
22 let (n, d) = points.dim();
23 let mut out = Array2::<f64>::zeros((n, d));
24 for row in 0..n {
25 let mut total = 0.0_f64;
26 for col in 0..d {
27 let v = points[[row, col]];
28 if v < 0.0 {
29 return Err("simplex values must be non-negative".to_string());
30 }
31 total += v;
32 }
33 if total <= 0.0 {
34 return Err("simplex rows must have positive total mass".to_string());
35 }
36 for col in 0..d {
37 out[[row, col]] = points[[row, col]] / total;
38 }
39 }
40 Ok(out)
41}
42
43fn require_positive(comp: ArrayView2<'_, f64>, label: &str) -> Result<(), String> {
44 for value in comp.iter() {
45 if *value <= 0.0 {
46 return Err(format!("{label} require strictly positive simplex values"));
47 }
48 }
49 Ok(())
50}
51
52pub fn simplex_frechet_mean(
53 points: ArrayView2<'_, f64>,
54 weights: Option<ArrayView1<'_, f64>>,
55) -> Result<Vec<f64>, String> {
56 let comp = closure(points)?;
57 require_positive(comp.view(), "simplex Fr\u{e9}chet mean")?;
58 let (n, d) = comp.dim();
59 let w = normalize_weights(n, weights)?;
60 let mut mean_log = vec![0.0_f64; d];
61 for row in 0..n {
62 for col in 0..d {
63 mean_log[col] += w[row] * comp[[row, col]].ln();
64 }
65 }
66 let mut max_v = f64::NEG_INFINITY;
67 for &v in mean_log.iter() {
68 if v > max_v {
69 max_v = v;
70 }
71 }
72 let mut total = 0.0_f64;
73 let mut out = vec![0.0_f64; d];
74 for col in 0..d {
75 let e = (mean_log[col] - max_v).exp();
76 out[col] = e;
77 total += e;
78 }
79 for value in out.iter_mut() {
80 *value /= total;
81 }
82 Ok(out)
83}
84
/// Coordinate system used to express simplex points as real vectors.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SimplexCoord {
    /// Centered log-ratio coordinates (one coordinate per component).
    Clr,
    /// Additive log-ratio coordinates relative to a reference component
    /// (one coordinate fewer than the number of components).
    Alr,
}

/// Parses a coordinate-system name, ignoring ASCII case.
///
/// `"simplex"` and `"clr"` select [`SimplexCoord::Clr`]; `"alr"` selects
/// [`SimplexCoord::Alr`].
///
/// # Errors
///
/// Any other string yields an error message that quotes `coordinates`
/// exactly as the caller supplied it.
pub fn parse_simplex_coord(coordinates: &str) -> Result<SimplexCoord, String> {
    // Case-insensitive comparison avoids allocating a lowercased copy and
    // keeps the caller's original spelling available for the error message.
    let is = |name: &str| coordinates.eq_ignore_ascii_case(name);
    if is("simplex") || is("clr") {
        Ok(SimplexCoord::Clr)
    } else if is("alr") {
        Ok(SimplexCoord::Alr)
    } else {
        Err(format!(
            "simplex coordinates must be 'clr' or 'alr'; got {coordinates:?}"
        ))
    }
}
104
/// Maps a possibly negative component index onto `0..d`.
///
/// The index wraps modulo `d`, so `-1` selects the last component and
/// values at or beyond `d` wrap around as well. `d` must be non-zero;
/// the remainder operation panics otherwise.
fn resolve_reference(reference: isize, d: usize) -> usize {
    let modulus = d as isize;
    // Euclidean remainder is always in `0..modulus` for a positive modulus.
    reference.rem_euclid(modulus) as usize
}
114
115pub fn clr(values: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
118 let comp = closure(values)?;
119 require_positive(comp.view(), "CLR coordinates")?;
120 let (n, d) = comp.dim();
121 let mut out = Array2::<f64>::zeros((n, d));
122 for row in 0..n {
123 let mut sum_log = 0.0_f64;
124 for col in 0..d {
125 let lg = comp[[row, col]].ln();
126 out[[row, col]] = lg;
127 sum_log += lg;
128 }
129 let mean = sum_log / (d as f64);
130 for col in 0..d {
131 out[[row, col]] -= mean;
132 }
133 }
134 Ok(out)
135}
136
137pub fn alr(values: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
141 let comp = closure(values)?;
142 require_positive(comp.view(), "ALR coordinates")?;
143 let (n, d) = comp.dim();
144 let ref_idx = resolve_reference(reference, d);
145 let mut out = Array2::<f64>::zeros((n, d - 1));
146 for row in 0..n {
147 let log_ref = comp[[row, ref_idx]].ln();
148 let mut k = 0usize;
149 for col in 0..d {
150 if col == ref_idx {
151 continue;
152 }
153 out[[row, k]] = comp[[row, col]].ln() - log_ref;
154 k += 1;
155 }
156 }
157 Ok(out)
158}
159
160pub fn inverse_alr(coords: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
163 let (n, dm1) = coords.dim();
164 if !coords.iter().all(|v| v.is_finite()) {
165 return Err("ALR coordinates must contain only finite values".to_string());
166 }
167 let d = dm1 + 1;
168 let ref_idx = resolve_reference(reference, d);
169 let mut out = Array2::<f64>::zeros((n, d));
170 for row in 0..n {
171 let mut max_v = f64::NEG_INFINITY;
172 let mut k = 0usize;
173 for col in 0..d {
174 let v = if col == ref_idx {
175 0.0
176 } else {
177 let val = coords[[row, k]];
178 k += 1;
179 val
180 };
181 out[[row, col]] = v;
182 if v > max_v {
183 max_v = v;
184 }
185 }
186 let mut total = 0.0_f64;
187 for col in 0..d {
188 let e = (out[[row, col]] - max_v).exp();
189 out[[row, col]] = e;
190 total += e;
191 }
192 for col in 0..d {
193 out[[row, col]] /= total;
194 }
195 }
196 Ok(out)
197}
198
199pub fn simplex_log_map(
203 values: ArrayView2<'_, f64>,
204 base: ArrayView1<'_, f64>,
205 coord: SimplexCoord,
206 reference: isize,
207) -> Result<Array2<f64>, String> {
208 let comp = closure(values)?;
209 let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
210 let base_comp = closure(base2.view())?;
211 if comp.ncols() != base_comp.ncols() {
212 return Err("simplex values and base point have different dimensions".to_string());
213 }
214 require_positive(comp.view(), "simplex log map")?;
215 require_positive(base_comp.view(), "simplex log map")?;
216 match coord {
217 SimplexCoord::Clr => {
218 let values_clr = clr(values)?;
219 let base_clr = clr(base2.view())?;
220 let (n, d) = values_clr.dim();
221 let mut out = Array2::<f64>::zeros((n, d));
222 for row in 0..n {
223 for col in 0..d {
224 out[[row, col]] = values_clr[[row, col]] - base_clr[[0, col]];
225 }
226 }
227 Ok(out)
228 }
229 SimplexCoord::Alr => {
230 let values_alr = alr(values, reference)?;
231 let base_alr = alr(base2.view(), reference)?;
232 let (n, dm1) = values_alr.dim();
233 let mut out = Array2::<f64>::zeros((n, dm1));
234 for row in 0..n {
235 for col in 0..dm1 {
236 out[[row, col]] = values_alr[[row, col]] - base_alr[[0, col]];
237 }
238 }
239 Ok(out)
240 }
241 }
242}
243
244pub fn simplex_exp_map(
247 tangent: ArrayView2<'_, f64>,
248 base: ArrayView1<'_, f64>,
249 coord: SimplexCoord,
250 reference: isize,
251) -> Result<Array2<f64>, String> {
252 let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
253 let base_comp = closure(base2.view())?;
254 let d = base_comp.ncols();
255 if let Some(((row, col), value)) = tangent.indexed_iter().find(|(_, v)| !v.is_finite()) {
256 return Err(format!(
257 "simplex exp map tangent must contain only finite values; got {value} at ({row}, {col})"
258 ));
259 }
260 match coord {
261 SimplexCoord::Clr => {
262 if tangent.ncols() != d {
263 return Err("CLR tangent dimension must equal simplex dimension".to_string());
264 }
265 require_positive(base_comp.view(), "simplex exp map")?;
266 let n = tangent.nrows();
267 let mut out = Array2::<f64>::zeros((n, d));
268 for row in 0..n {
269 let mut max_v = f64::NEG_INFINITY;
270 for col in 0..d {
271 let lg = base_comp[[0, col]].ln() + tangent[[row, col]];
272 out[[row, col]] = lg;
273 if lg > max_v {
274 max_v = lg;
275 }
276 }
277 let mut total = 0.0_f64;
278 for col in 0..d {
279 let e = (out[[row, col]] - max_v).exp();
280 out[[row, col]] = e;
281 total += e;
282 }
283 for col in 0..d {
284 out[[row, col]] /= total;
285 }
286 }
287 Ok(out)
288 }
289 SimplexCoord::Alr => {
290 if tangent.ncols() + 1 != d {
291 return Err("ALR tangent dimension must be simplex dimension minus one".to_string());
292 }
293 let base_alr = alr(base2.view(), reference)?;
294 let n = tangent.nrows();
295 let dm1 = d - 1;
296 let mut shifted = Array2::<f64>::zeros((n, dm1));
297 for row in 0..n {
298 for col in 0..dm1 {
299 shifted[[row, col]] = base_alr[[0, col]] + tangent[[row, col]];
300 }
301 }
302 inverse_alr(shifted.view(), reference)
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, array};

    // ---- parse_simplex_coord ----

    #[test]
    fn parse_simplex_coord_simplex_and_clr_map_to_clr() {
        for name in ["simplex", "clr"] {
            assert_eq!(parse_simplex_coord(name).unwrap(), SimplexCoord::Clr);
        }
    }

    #[test]
    fn parse_simplex_coord_alr_maps_to_alr() {
        let parsed = parse_simplex_coord("alr").unwrap();
        assert_eq!(parsed, SimplexCoord::Alr);
    }

    #[test]
    fn parse_simplex_coord_case_insensitive() {
        let cases = [
            ("CLR", SimplexCoord::Clr),
            ("ALR", SimplexCoord::Alr),
            ("Simplex", SimplexCoord::Clr),
        ];
        for (name, expected) in cases {
            assert_eq!(parse_simplex_coord(name).unwrap(), expected);
        }
    }

    #[test]
    fn parse_simplex_coord_unknown_is_error() {
        for name in ["pca", ""] {
            assert!(parse_simplex_coord(name).is_err());
        }
    }

    // ---- validate_simplex_array ----

    #[test]
    fn validate_simplex_array_valid_input_passes() {
        let points = array![[0.5_f64, 0.5]];
        let outcome = validate_simplex_array(points.view());
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_simplex_array_no_rows_is_error() {
        use ndarray::Array2;
        let empty: Array2<f64> = Array2::zeros((0, 3));
        let outcome = validate_simplex_array(empty.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_simplex_array_single_column_is_error() {
        let narrow = array![[0.5_f64]];
        let outcome = validate_simplex_array(narrow.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_simplex_array_non_finite_is_error() {
        let points = array![[0.5_f64, f64::NAN]];
        let err = validate_simplex_array(points.view()).unwrap_err();
        assert!(err.contains("finite"), "error should mention finite, got: {err}");
    }

    // ---- closure ----

    #[test]
    fn closure_normalizes_rows_to_sum_one() {
        let raw = array![[1.0_f64, 2.0, 3.0], [4.0, 4.0, 4.0]];
        let closed = closure(raw.view()).unwrap();
        for idx in 0..2 {
            let sum = closed.row(idx).sum();
            assert!((sum - 1.0).abs() < 1e-14, "row {idx} sum: {sum}");
        }
    }

    #[test]
    fn closure_equal_weights_gives_uniform_composition() {
        let raw = array![[2.0_f64, 2.0]];
        let closed = closure(raw.view()).unwrap();
        for col in 0..2 {
            assert!((closed[[0, col]] - 0.5).abs() < 1e-14);
        }
    }

    #[test]
    fn closure_negative_value_is_error() {
        let raw = array![[1.0_f64, -0.5]];
        let outcome = closure(raw.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn closure_zero_total_mass_is_error() {
        let raw = array![[0.0_f64, 0.0]];
        let err = closure(raw.view()).unwrap_err();
        let mentions_mass = err.contains("total mass") || err.contains("positive");
        assert!(mentions_mass, "got: {err}");
    }

    // ---- resolve_reference ----

    #[test]
    fn resolve_reference_positive_index() {
        for (reference, expected) in [(1_isize, 1_usize), (2, 2)] {
            assert_eq!(resolve_reference(reference, 3), expected);
        }
    }

    #[test]
    fn resolve_reference_negative_index_wraps() {
        for (reference, expected) in [(-1_isize, 2_usize), (-2, 1), (-3, 0)] {
            assert_eq!(resolve_reference(reference, 3), expected);
        }
    }

    // ---- clr ----

    #[test]
    fn clr_of_uniform_composition_is_zero() {
        let uniform = array![[1.0_f64, 1.0, 1.0]];
        let coords = clr(uniform.view()).unwrap();
        for v in coords.iter() {
            assert!(v.abs() < 1e-14, "clr of uniform should be 0, got {v}");
        }
    }

    #[test]
    fn clr_sum_is_zero_per_row() {
        let raw = array![[1.0_f64, 2.0, 3.0], [4.0, 1.0, 1.0]];
        let coords = clr(raw.view()).unwrap();
        for row in coords.rows() {
            let sum = row.sum();
            assert!(sum.abs() < 1e-12, "clr row must sum to zero, got {sum}");
        }
    }

    // ---- alr / inverse_alr ----

    #[test]
    fn alr_inverse_alr_round_trip() {
        let original = array![[0.2_f64, 0.5, 0.3]];
        let coords = alr(original.view(), -1).unwrap();
        let recovered = inverse_alr(coords.view(), -1).unwrap();
        let pairs = recovered.row(0).to_vec().into_iter().zip(original.row(0).to_vec());
        for (col, (got, want)) in pairs.enumerate() {
            assert!((got - want).abs() < 1e-12, "col {col}: {} vs {}", got, want);
        }
    }

    // ---- simplex_exp_map / simplex_log_map ----

    #[test]
    fn clr_exp_map_interior_base_lands_in_open_simplex() {
        let base: Array1<f64> = array![0.2, 0.5, 0.3];
        let tangent = array![[0.4_f64, -0.1, -0.3]];
        let landed = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("interior base with finite tangent must succeed");
        let sum: f64 = landed.row(0).sum();
        assert!((sum - 1.0).abs() < 1e-12, "components must sum to one");
        for v in landed.iter() {
            assert!(*v > 0.0, "components must be strictly positive; got {v}");
        }
    }

    #[test]
    fn clr_exp_map_boundary_base_errors() {
        let base: Array1<f64> = array![1.0, 0.0, 0.0];
        let tangent = array![[0.1_f64, -0.05, -0.05]];
        let err = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect_err("boundary base must be rejected, not yield NaN");
        let explains_domain = err.contains("strictly positive");
        assert!(
            explains_domain,
            "error must explain the positivity domain; got {err}"
        );
    }

    #[test]
    fn clr_exp_map_nonfinite_tangent_errors() {
        let base: Array1<f64> = array![0.25, 0.45, 0.30];

        let clr_tangent = array![[f64::INFINITY, 0.0_f64, 0.0]];
        let err = simplex_exp_map(clr_tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect_err("non-finite CLR tangent must be rejected, not yield NaN");
        let explains_finiteness = err.contains("finite");
        assert!(
            explains_finiteness,
            "error must explain the finiteness requirement; got {err}"
        );

        let alr_tangent = array![[f64::INFINITY, 0.0_f64]];
        simplex_exp_map(alr_tangent.view(), base.view(), SimplexCoord::Alr, 0)
            .expect_err("non-finite ALR tangent must be rejected");
    }

    #[test]
    fn clr_log_exp_round_trip_recovers_interior_point() {
        let base: Array1<f64> = array![0.25, 0.45, 0.30];
        let point = array![[0.1_f64, 0.6, 0.3]];
        let tangent = simplex_log_map(point.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("log map at interior base must succeed");
        let recovered = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("exp map at interior base must succeed");
        let pairs = recovered.row(0).to_vec().into_iter().zip(point.row(0).to_vec());
        for (col, (got, want)) in pairs.enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "round-trip must recover input at column {col}: {} vs {}",
                got,
                want
            );
        }
    }
}