1use ndarray::{Array2, ArrayView1, ArrayView2};
2
3use crate::normalize_weights;
4
5pub fn validate_simplex_array(points: ArrayView2<'_, f64>) -> Result<(), String> {
6 let (n, d) = points.dim();
7 if n == 0 || d < 2 {
8 return Err(
9 "simplex values must have at least one row and at least two columns".to_string(),
10 );
11 }
12 if let Some(((row, col), value)) = points.indexed_iter().find(|(_, v)| !v.is_finite()) {
13 return Err(format!(
14 "simplex values must contain only finite values; got {value} at ({row}, {col})"
15 ));
16 }
17 Ok(())
18}
19
20pub fn closure(points: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
21 validate_simplex_array(points)?;
22 let (n, d) = points.dim();
23 let mut out = Array2::<f64>::zeros((n, d));
24 for row in 0..n {
25 let mut total = 0.0_f64;
26 for col in 0..d {
27 let v = points[[row, col]];
28 if v < 0.0 {
29 return Err("simplex values must be non-negative".to_string());
30 }
31 total += v;
32 }
33 if total <= 0.0 {
34 return Err("simplex rows must have positive total mass".to_string());
35 }
36 for col in 0..d {
37 out[[row, col]] = points[[row, col]] / total;
38 }
39 }
40 Ok(out)
41}
42
43fn require_positive(comp: ArrayView2<'_, f64>, label: &str) -> Result<(), String> {
44 for value in comp.iter() {
45 if *value <= 0.0 {
46 return Err(format!("{label} require strictly positive simplex values"));
47 }
48 }
49 Ok(())
50}
51
52pub fn simplex_frechet_mean(
53 points: ArrayView2<'_, f64>,
54 weights: Option<ArrayView1<'_, f64>>,
55) -> Result<Vec<f64>, String> {
56 let comp = closure(points)?;
57 require_positive(comp.view(), "simplex Fr\u{e9}chet mean")?;
58 let (n, d) = comp.dim();
59 let w = normalize_weights(n, weights)?;
60 let mut mean_log = vec![0.0_f64; d];
61 for row in 0..n {
62 for col in 0..d {
63 mean_log[col] += w[row] * comp[[row, col]].ln();
64 }
65 }
66 let mut max_v = f64::NEG_INFINITY;
67 for &v in mean_log.iter() {
68 if v > max_v {
69 max_v = v;
70 }
71 }
72 let mut total = 0.0_f64;
73 let mut out = vec![0.0_f64; d];
74 for col in 0..d {
75 let e = (mean_log[col] - max_v).exp();
76 out[col] = e;
77 total += e;
78 }
79 for value in out.iter_mut() {
80 *value /= total;
81 }
82 Ok(out)
83}
84
/// Coordinate system used for tangent vectors on the simplex.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SimplexCoord {
    /// Centred log-ratio coordinates: one coordinate per component.
    Clr,
    /// Additive log-ratio coordinates: one fewer coordinate, relative to a
    /// reference component.
    Alr,
}

impl std::str::FromStr for SimplexCoord {
    type Err = String;

    /// Same rules as [`parse_simplex_coord`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        parse_simplex_coord(s)
    }
}

/// Parses a coordinate-system name, ignoring ASCII case.
///
/// `"clr"` and its alias `"simplex"` map to [`SimplexCoord::Clr`]; `"alr"` maps to
/// [`SimplexCoord::Alr`].
///
/// # Errors
/// Any other string is rejected. The message quotes the input exactly as given,
/// not a case-folded copy.
pub fn parse_simplex_coord(coordinates: &str) -> Result<SimplexCoord, String> {
    if coordinates.eq_ignore_ascii_case("simplex") || coordinates.eq_ignore_ascii_case("clr") {
        Ok(SimplexCoord::Clr)
    } else if coordinates.eq_ignore_ascii_case("alr") {
        Ok(SimplexCoord::Alr)
    } else {
        Err(format!(
            "simplex coordinates must be 'clr' or 'alr'; got {coordinates:?}"
        ))
    }
}
104
/// Maps a possibly negative component index onto `0..d` by Euclidean remainder.
///
/// Negative values count from the end (`-1` is the last component). Values at or
/// beyond `d` wrap around modulo `d`. Panics if `d == 0`.
fn resolve_reference(reference: isize, d: usize) -> usize {
    reference.rem_euclid(d as isize) as usize
}
114
115pub fn clr(values: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
118 let comp = closure(values)?;
119 require_positive(comp.view(), "CLR coordinates")?;
120 let (n, d) = comp.dim();
121 let mut out = Array2::<f64>::zeros((n, d));
122 for row in 0..n {
123 let mut sum_log = 0.0_f64;
124 for col in 0..d {
125 let lg = comp[[row, col]].ln();
126 out[[row, col]] = lg;
127 sum_log += lg;
128 }
129 let mean = sum_log / (d as f64);
130 for col in 0..d {
131 out[[row, col]] -= mean;
132 }
133 }
134 Ok(out)
135}
136
137pub fn alr(values: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
141 let comp = closure(values)?;
142 require_positive(comp.view(), "ALR coordinates")?;
143 let (n, d) = comp.dim();
144 let ref_idx = resolve_reference(reference, d);
145 let mut out = Array2::<f64>::zeros((n, d - 1));
146 for row in 0..n {
147 let log_ref = comp[[row, ref_idx]].ln();
148 let mut k = 0usize;
149 for col in 0..d {
150 if col == ref_idx {
151 continue;
152 }
153 out[[row, k]] = comp[[row, col]].ln() - log_ref;
154 k += 1;
155 }
156 }
157 Ok(out)
158}
159
160pub fn inverse_alr(coords: ArrayView2<'_, f64>, reference: isize) -> Result<Array2<f64>, String> {
163 let (n, dm1) = coords.dim();
164 if !coords.iter().all(|v| v.is_finite()) {
165 return Err("ALR coordinates must contain only finite values".to_string());
166 }
167 let d = dm1 + 1;
168 let ref_idx = resolve_reference(reference, d);
169 let mut out = Array2::<f64>::zeros((n, d));
170 for row in 0..n {
171 let mut max_v = f64::NEG_INFINITY;
172 let mut k = 0usize;
173 for col in 0..d {
174 let v = if col == ref_idx {
175 0.0
176 } else {
177 let val = coords[[row, k]];
178 k += 1;
179 val
180 };
181 out[[row, col]] = v;
182 if v > max_v {
183 max_v = v;
184 }
185 }
186 let mut total = 0.0_f64;
187 for col in 0..d {
188 let e = (out[[row, col]] - max_v).exp();
189 out[[row, col]] = e;
190 total += e;
191 }
192 for col in 0..d {
193 out[[row, col]] /= total;
194 }
195 }
196 Ok(out)
197}
198
199pub fn simplex_log_map(
203 values: ArrayView2<'_, f64>,
204 base: ArrayView1<'_, f64>,
205 coord: SimplexCoord,
206 reference: isize,
207) -> Result<Array2<f64>, String> {
208 let comp = closure(values)?;
209 let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
210 let base_comp = closure(base2.view())?;
211 if comp.ncols() != base_comp.ncols() {
212 return Err("simplex values and base point have different dimensions".to_string());
213 }
214 require_positive(comp.view(), "simplex log map")?;
215 require_positive(base_comp.view(), "simplex log map")?;
216 match coord {
217 SimplexCoord::Clr => {
218 let values_clr = clr(values)?;
219 let base_clr = clr(base2.view())?;
220 let (n, d) = values_clr.dim();
221 let mut out = Array2::<f64>::zeros((n, d));
222 for row in 0..n {
223 for col in 0..d {
224 out[[row, col]] = values_clr[[row, col]] - base_clr[[0, col]];
225 }
226 }
227 Ok(out)
228 }
229 SimplexCoord::Alr => {
230 let values_alr = alr(values, reference)?;
231 let base_alr = alr(base2.view(), reference)?;
232 let (n, dm1) = values_alr.dim();
233 let mut out = Array2::<f64>::zeros((n, dm1));
234 for row in 0..n {
235 for col in 0..dm1 {
236 out[[row, col]] = values_alr[[row, col]] - base_alr[[0, col]];
237 }
238 }
239 Ok(out)
240 }
241 }
242}
243
244pub fn simplex_exp_map(
247 tangent: ArrayView2<'_, f64>,
248 base: ArrayView1<'_, f64>,
249 coord: SimplexCoord,
250 reference: isize,
251) -> Result<Array2<f64>, String> {
252 let base2 = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
253 let base_comp = closure(base2.view())?;
254 let d = base_comp.ncols();
255 match coord {
256 SimplexCoord::Clr => {
257 if tangent.ncols() != d {
258 return Err("CLR tangent dimension must equal simplex dimension".to_string());
259 }
260 require_positive(base_comp.view(), "simplex exp map")?;
261 let n = tangent.nrows();
262 let mut out = Array2::<f64>::zeros((n, d));
263 for row in 0..n {
264 let mut max_v = f64::NEG_INFINITY;
265 for col in 0..d {
266 let lg = base_comp[[0, col]].ln() + tangent[[row, col]];
267 out[[row, col]] = lg;
268 if lg > max_v {
269 max_v = lg;
270 }
271 }
272 let mut total = 0.0_f64;
273 for col in 0..d {
274 let e = (out[[row, col]] - max_v).exp();
275 out[[row, col]] = e;
276 total += e;
277 }
278 for col in 0..d {
279 out[[row, col]] /= total;
280 }
281 }
282 Ok(out)
283 }
284 SimplexCoord::Alr => {
285 if tangent.ncols() + 1 != d {
286 return Err("ALR tangent dimension must be simplex dimension minus one".to_string());
287 }
288 let base_alr = alr(base2.view(), reference)?;
289 let n = tangent.nrows();
290 let dm1 = d - 1;
291 let mut shifted = Array2::<f64>::zeros((n, dm1));
292 for row in 0..n {
293 for col in 0..dm1 {
294 shifted[[row, col]] = base_alr[[0, col]] + tangent[[row, col]];
295 }
296 }
297 inverse_alr(shifted.view(), reference)
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, array};

    /// A strictly interior base and a finite CLR tangent must map to a point in
    /// the open simplex.
    #[test]
    fn clr_exp_map_interior_base_lands_in_open_simplex() {
        let base: Array1<f64> = array![0.2, 0.5, 0.3];
        let tangent = array![[0.4_f64, -0.1, -0.3]];
        let mapped = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("interior base with finite tangent must succeed");

        let row_total: f64 = mapped.row(0).sum();
        assert!((row_total - 1.0).abs() < 1e-12, "components must sum to one");
        for &v in mapped.iter() {
            assert!(v > 0.0, "components must be strictly positive; got {v}");
        }
    }

    /// A base point on the simplex boundary has a zero component and must be
    /// rejected with an explanatory error.
    #[test]
    fn clr_exp_map_boundary_base_errors() {
        let base: Array1<f64> = array![1.0, 0.0, 0.0];
        let tangent = array![[0.1_f64, -0.05, -0.05]];
        let outcome = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0);
        let err = outcome.expect_err("boundary base must be rejected, not yield NaN");
        let explains_domain = err.contains("strictly positive");
        assert!(explains_domain, "error must explain the positivity domain; got {err}");
    }

    /// Log map followed by exp map at the same interior base recovers the point.
    #[test]
    fn clr_log_exp_round_trip_recovers_interior_point() {
        let base: Array1<f64> = array![0.25, 0.45, 0.30];
        let point = array![[0.1_f64, 0.6, 0.3]];
        let tangent = simplex_log_map(point.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("log map at interior base must succeed");
        let recovered = simplex_exp_map(tangent.view(), base.view(), SimplexCoord::Clr, 0)
            .expect("exp map at interior base must succeed");

        let pairs = recovered.row(0).iter().zip(point.row(0).iter()).enumerate();
        for (col, (&got, &want)) in pairs {
            assert!(
                (got - want).abs() < 1e-12,
                "round-trip must recover input at column {col}: {} vs {}",
                got,
                want
            );
        }
    }
}