1use std::io::{BufRead, Write};
2
3use rayon::prelude::*;
4use rsomics_common::{Result, RsomicsError};
5
6pub mod dm;
7mod rng;
8
9pub use dm::DistanceMatrix;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq)]
12pub enum Method {
13 Pearson,
14 Spearman,
15}
16
17impl Method {
18 pub fn parse(s: &str) -> Result<Self> {
19 match s {
20 "pearson" => Ok(Method::Pearson),
21 "spearman" => Ok(Method::Spearman),
22 other => Err(RsomicsError::InvalidInput(format!(
23 "invalid method '{other}' (pearson|spearman)"
24 ))),
25 }
26 }
27 pub fn name(self) -> &'static str {
28 match self {
29 Method::Pearson => "pearson",
30 Method::Spearman => "spearman",
31 }
32 }
33}
34
35#[derive(Clone, Copy, Debug, PartialEq, Eq)]
36pub enum Alternative {
37 TwoSided,
38 Greater,
39 Less,
40}
41
42impl Alternative {
43 pub fn parse(s: &str) -> Result<Self> {
44 match s {
45 "two-sided" => Ok(Alternative::TwoSided),
46 "greater" => Ok(Alternative::Greater),
47 "less" => Ok(Alternative::Less),
48 other => Err(RsomicsError::InvalidInput(format!(
49 "invalid alternative '{other}' (two-sided|greater|less)"
50 ))),
51 }
52 }
53 pub fn name(self) -> &'static str {
54 match self {
55 Alternative::TwoSided => "two-sided",
56 Alternative::Greater => "greater",
57 Alternative::Less => "less",
58 }
59 }
60}
61
62pub struct MantelResult {
63 pub r: f64,
64 pub p_value: f64,
65 pub n: usize,
66 pub method: Method,
67 pub permutations: usize,
68 pub alternative: Alternative,
69}
70
71pub fn mantel(
78 x_data: &[f64],
79 y_data: &[f64],
80 n: usize,
81 method: Method,
82 permutations: usize,
83 alternative: Alternative,
84 seed: u64,
85) -> MantelResult {
86 let (x_flat, y_flat) = match method {
87 Method::Pearson => (
88 dm::DistanceMatrix::condensed(x_data, n),
89 dm::DistanceMatrix::condensed(y_data, n),
90 ),
91 Method::Spearman => (
92 rankdata(&dm::DistanceMatrix::condensed(x_data, n)),
93 rankdata(&dm::DistanceMatrix::condensed(y_data, n)),
94 ),
95 };
96
97 let x_full = match method {
101 Method::Pearson => x_data.to_vec(),
102 Method::Spearman => square_from_condensed(&x_flat, n),
103 };
104
105 let xmean = mean(&x_flat);
109 let normx = norm_centered(&x_flat, xmean);
110 let ym = normalize(&y_flat);
111 let r = match (&ym, normx) {
112 (Some(ymn), Some(nx)) => dot_centered(&x_flat, xmean, nx, ymn).clamp(-1.0, 1.0),
113 _ => f64::NAN,
114 };
115
116 let p_value = if permutations == 0 || r.is_nan() {
117 f64::NAN
118 } else {
119 let ymn = ym.unwrap();
120 let nx = normx.unwrap();
121 let count_extreme: usize = (0..permutations)
122 .into_par_iter()
123 .map(|k| {
124 let perm = rng::permutation(n, seed, k as u64);
125 let stat = permuted_stat(&x_full, n, &perm, xmean, nx, &ymn).clamp(-1.0, 1.0);
126 match alternative {
127 Alternative::TwoSided => usize::from(stat.abs() >= r.abs()),
128 Alternative::Greater => usize::from(stat >= r),
129 Alternative::Less => usize::from(stat <= r),
130 }
131 })
132 .sum();
133 (count_extreme + 1) as f64 / (permutations + 1) as f64
134 };
135
136 MantelResult {
137 r,
138 p_value,
139 n,
140 method,
141 permutations,
142 alternative,
143 }
144}
145
146fn normalize(v: &[f64]) -> Option<Vec<f64>> {
148 let m = mean(v);
149 let mut out: Vec<f64> = v.iter().map(|&x| x - m).collect();
150 let norm = out.iter().map(|&x| x * x).sum::<f64>().sqrt();
151 if norm == 0.0 {
152 return None;
153 }
154 for x in &mut out {
155 *x /= norm;
156 }
157 Some(out)
158}
159
/// Arithmetic mean of `v`; NaN for an empty slice (0 / 0).
fn mean(v: &[f64]) -> f64 {
    let count = v.len() as f64;
    let total: f64 = v.iter().sum();
    total / count
}
163
/// Euclidean norm of `v` after subtracting `m` from every element.
///
/// Returns `None` only when the norm is exactly zero; a NaN norm (NaN input) is returned
/// as `Some(NaN)`.
fn norm_centered(v: &[f64], m: f64) -> Option<f64> {
    let sum_sq: f64 = v
        .iter()
        .map(|&x| {
            let d = x - m;
            d * d
        })
        .sum();
    let norm = sum_sq.sqrt();
    if norm == 0.0 {
        None
    } else {
        Some(norm)
    }
}
168
/// Correlation of `x` with an already centred, unit-length `ym_norm`.
///
/// Centres `x` with `xmean`, takes the dot product with `ym_norm` (pairing elements up to
/// the shorter slice) and divides by `normx`, the centred norm of `x`.
fn dot_centered(x: &[f64], xmean: f64, normx: f64, ym_norm: &[f64]) -> f64 {
    let numerator: f64 = x
        .iter()
        .zip(ym_norm.iter())
        .map(|(xv, yv)| (xv - xmean) * yv)
        .sum();
    numerator / normx
}
176
/// Correlation between a relabelled `x` and the fixed, normalised `y`.
///
/// Object `i` of `x` is replaced by object `perm[i]`. The upper triangle of the relabelled
/// `n`×`n` row-major matrix `x_full` is then walked in the same `i < j` order as the
/// condensed vectors. Each entry is centred with `xmean`, dotted with `ym_norm`, and the
/// sum is divided by `normx`.
///
/// `perm` is expected (per its use with `rng::permutation`) to be a permutation of `0..n`,
/// and `ym_norm` to hold `n * (n - 1) / 2` entries; out-of-range indices panic.
fn permuted_stat(
    x_full: &[f64],
    n: usize,
    perm: &[usize],
    xmean: f64,
    normx: f64,
    ym_norm: &[f64],
) -> f64 {
    let mut sum = 0.0;
    let mut pair = 0;
    for row in 0..n {
        let row_start = perm[row] * n;
        for col in row + 1..n {
            let centered = x_full[row_start + perm[col]] - xmean;
            sum += centered * ym_norm[pair];
            pair += 1;
        }
    }
    sum / normx
}
198
/// Ranks the elements of `v` starting from 1. Tied values get the mean of the ranks they
/// occupy, e.g. `[1, 2, 2, 3]` → `[1, 2.5, 2.5, 4]`.
///
/// If any element is NaN there is no meaningful ordering, so every rank is NaN. A
/// correlation computed from these ranks then comes out NaN, just as NaN distances do on
/// the Pearson path.
fn rankdata(v: &[f64]) -> Vec<f64> {
    if v.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; v.len()];
    }
    let mut order: Vec<usize> = (0..v.len()).collect();
    // With NaN excluded, `total_cmp` agrees with numeric order. It only separates -0.0 from
    // 0.0, which still end up adjacent and are merged by the `==` tie check below. Order
    // within a tie is irrelevant because all tied entries receive the same rank.
    order.sort_unstable_by(|&a, &b| v[a].total_cmp(&v[b]));

    let mut ranks = vec![0.0f64; v.len()];
    let mut start = 0;
    while start < order.len() {
        let value = v[order[start]];
        let end = start
            + order[start..]
                .iter()
                .take_while(|&&idx| v[idx] == value)
                .count();
        // Sorted positions start..end hold 1-based ranks start+1 ..= end; their mean is
        // (start + 1 + end) / 2.
        let avg = (start + 1 + end) as f64 / 2.0;
        for &idx in &order[start..end] {
            ranks[idx] = avg;
        }
        start = end;
    }
    ranks
}
219
/// Expands a condensed upper triangle (row-major, `i < j`) into a symmetric `n`×`n`
/// row-major matrix with a zero diagonal.
///
/// `cond` must hold at least `n * (n - 1) / 2` values; a shorter slice panics.
fn square_from_condensed(cond: &[f64], n: usize) -> Vec<f64> {
    let mut square = vec![0.0f64; n * n];
    let mut idx = 0;
    for i in 0..n {
        for j in i + 1..n {
            let d = cond[idx];
            square[i * n + j] = d;
            square[j * n + i] = d;
            idx += 1;
        }
    }
    square
}
232
233pub fn write_result<W: Write>(out: &mut W, res: &MantelResult) -> Result<()> {
234 writeln!(
235 out,
236 "method\tstatistic\tp_value\tn\tpermutations\talternative"
237 )
238 .map_err(RsomicsError::Io)?;
239 writeln!(
240 out,
241 "{}\t{:.12}\t{}\t{}\t{}\t{}",
242 res.method.name(),
243 res.r,
244 fmt_p(res.p_value),
245 res.n,
246 res.permutations,
247 res.alternative.name(),
248 )
249 .map_err(RsomicsError::Io)?;
250 Ok(())
251}
252
/// Formats a value with 12 decimal places, spelling NaN (an undefined value) as `nan`.
fn fmt_p(p: f64) -> String {
    if !p.is_nan() {
        return format!("{:.12}", p);
    }
    "nan".to_owned()
}
260
261pub fn read_matrix<R: BufRead>(reader: R, source: &str) -> Result<DistanceMatrix> {
262 DistanceMatrix::read(reader, source)
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row-major `n`×`n` matrix from row slices.
    fn square(rows: &[&[f64]]) -> (Vec<f64>, usize) {
        let n = rows.len();
        let mut d = vec![0.0; n * n];
        for (i, r) in rows.iter().enumerate() {
            for (j, &v) in r.iter().enumerate() {
                d[i * n + j] = v;
            }
        }
        (d, n)
    }

    /// The 3×3 example matrices from the scikit-bio `mantel` documentation.
    fn doc_example() -> (Vec<f64>, Vec<f64>, usize) {
        let (x, n) = square(&[&[0.0, 1.0, 2.0], &[1.0, 0.0, 3.0], &[2.0, 3.0, 0.0]]);
        let (y, _) = square(&[&[0.0, 2.0, 7.0], &[2.0, 0.0, 6.0], &[7.0, 6.0, 0.0]]);
        (x, y, n)
    }

    #[test]
    fn skbio_doc_example_pearson() {
        let (x, y, n) = doc_example();
        let res = mantel(&x, &y, n, Method::Pearson, 0, Alternative::TwoSided, 1);
        assert!((res.r - 0.7559289460184544).abs() < 1e-12, "r={}", res.r);
        assert!(res.p_value.is_nan());
    }

    #[test]
    fn spearman_matches_hand_computed_value() {
        // Ranks: x -> [1, 2, 3], y ([2, 7, 6]) -> [1, 3, 2]; their Pearson correlation is 0.5.
        let (x, y, n) = doc_example();
        let res = mantel(&x, &y, n, Method::Spearman, 0, Alternative::TwoSided, 1);
        assert!((res.r - 0.5).abs() < 1e-12, "r={}", res.r);
        assert!(res.p_value.is_nan());
    }

    #[test]
    fn permutation_p_value_is_bounded_and_seed_reproducible() {
        let (x, y, n) = doc_example();
        for alt in [Alternative::TwoSided, Alternative::Greater, Alternative::Less] {
            let a = mantel(&x, &y, n, Method::Pearson, 99, alt, 42);
            let b = mantel(&x, &y, n, Method::Pearson, 99, alt, 42);
            assert_eq!(a.p_value, b.p_value);
            assert!(a.p_value >= 0.01 && a.p_value <= 1.0, "p={}", a.p_value);
            // p = (count + 1) / 100 must be a multiple of 1/100.
            let scaled = a.p_value * 100.0;
            assert!((scaled - scaled.round()).abs() < 1e-9, "p={}", a.p_value);
        }
    }

    #[test]
    fn constant_matrix_gives_nan_statistic_and_p_value() {
        let (x, n) = square(&[&[0.0, 1.0, 1.0], &[1.0, 0.0, 1.0], &[1.0, 1.0, 0.0]]);
        let (_, y, _) = doc_example();
        let res = mantel(&x, &y, n, Method::Pearson, 9, Alternative::TwoSided, 7);
        assert!(res.r.is_nan());
        assert!(res.p_value.is_nan());
    }

    #[test]
    fn names_round_trip_through_parse() {
        for m in [Method::Pearson, Method::Spearman] {
            assert!(matches!(Method::parse(m.name()), Ok(parsed) if parsed == m));
        }
        for a in [Alternative::TwoSided, Alternative::Greater, Alternative::Less] {
            assert!(matches!(Alternative::parse(a.name()), Ok(parsed) if parsed == a));
        }
        assert!(Method::parse("kendall").is_err());
        assert!(Alternative::parse("both").is_err());
    }

    #[test]
    fn write_result_formats_tsv() {
        let (x, y, n) = doc_example();
        let res = mantel(&x, &y, n, Method::Pearson, 0, Alternative::TwoSided, 1);
        let mut buf = Vec::new();
        assert!(write_result(&mut buf, &res).is_ok());
        assert_eq!(
            String::from_utf8(buf).unwrap(),
            "method\tstatistic\tp_value\tn\tpermutations\talternative\n\
             pearson\t0.755928946018\tnan\t3\t0\ttwo-sided\n"
        );
    }

    #[test]
    fn fmt_p_spells_nan_lowercase() {
        assert_eq!(fmt_p(f64::NAN), "nan");
        assert_eq!(fmt_p(0.25), "0.250000000000");
    }

    #[test]
    fn rankdata_ties_averaged() {
        assert_eq!(rankdata(&[1.0, 2.0, 2.0, 3.0]), vec![1.0, 2.5, 2.5, 4.0]);
    }

    #[test]
    fn condensed_upper_triangle() {
        let (x, _, n) = doc_example();
        assert_eq!(DistanceMatrix::condensed(&x, n), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn square_from_condensed_inverts_condensed() {
        let (x, _, n) = doc_example();
        assert_eq!(square_from_condensed(&DistanceMatrix::condensed(&x, n), n), x);
    }

    #[test]
    fn identity_permutation_reproduces_observed_stat() {
        let (x, y, n) = doc_example();
        let xf = DistanceMatrix::condensed(&x, n);
        let yf = DistanceMatrix::condensed(&y, n);
        let xmean = mean(&xf);
        let normx = norm_centered(&xf, xmean).unwrap();
        let ymn = normalize(&yf).unwrap();
        let id: Vec<usize> = (0..n).collect();
        let permuted = permuted_stat(&x, n, &id, xmean, normx, &ymn);
        let observed = dot_centered(&xf, xmean, normx, &ymn);
        assert!((permuted - observed).abs() < 1e-12);
    }
}