Skip to main content

rsomics_mantel/
lib.rs

1use std::io::{BufRead, Write};
2
3use rayon::prelude::*;
4use rsomics_common::{Result, RsomicsError};
5
6pub mod dm;
7mod rng;
8
9pub use dm::DistanceMatrix;
10
/// Correlation used to compare the two matrices' condensed distances.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Method {
    /// Pearson product-moment correlation of the raw distances.
    Pearson,
    /// Spearman rank correlation: Pearson correlation of the average ranks.
    Spearman,
}
16
17impl Method {
18    pub fn parse(s: &str) -> Result<Self> {
19        match s {
20            "pearson" => Ok(Method::Pearson),
21            "spearman" => Ok(Method::Spearman),
22            other => Err(RsomicsError::InvalidInput(format!(
23                "invalid method '{other}' (pearson|spearman)"
24            ))),
25        }
26    }
27    pub fn name(self) -> &'static str {
28        match self {
29            Method::Pearson => "pearson",
30            Method::Spearman => "spearman",
31        }
32    }
33}
34
/// Which permuted statistics count as "at least as extreme" as the observed one.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Alternative {
    /// A permuted statistic counts when its magnitude is at least the observed magnitude.
    TwoSided,
    /// A permuted statistic counts when it is greater than or equal to the observed one.
    Greater,
    /// A permuted statistic counts when it is less than or equal to the observed one.
    Less,
}
41
42impl Alternative {
43    pub fn parse(s: &str) -> Result<Self> {
44        match s {
45            "two-sided" => Ok(Alternative::TwoSided),
46            "greater" => Ok(Alternative::Greater),
47            "less" => Ok(Alternative::Less),
48            other => Err(RsomicsError::InvalidInput(format!(
49                "invalid alternative '{other}' (two-sided|greater|less)"
50            ))),
51        }
52    }
53    pub fn name(self) -> &'static str {
54        match self {
55            Alternative::TwoSided => "two-sided",
56            Alternative::Greater => "greater",
57            Alternative::Less => "less",
58        }
59    }
60}
61
62pub struct MantelResult {
63    pub r: f64,
64    pub p_value: f64,
65    pub n: usize,
66    pub method: Method,
67    pub permutations: usize,
68    pub alternative: Alternative,
69}
70
71/// Run the Mantel test. `y_data` is already reordered onto `x`'s id order.
72///
73/// The correlation coefficient is deterministic and matches scikit-bio's
74/// `mantel()` to floating-point tolerance. The p-value is a permutation
75/// estimate computed with a seeded RNG; it is Monte-Carlo and does not
76/// reproduce numpy's PCG64 permutation stream bit-for-bit.
77pub fn mantel(
78    x_data: &[f64],
79    y_data: &[f64],
80    n: usize,
81    method: Method,
82    permutations: usize,
83    alternative: Alternative,
84    seed: u64,
85) -> MantelResult {
86    let (x_flat, y_flat) = match method {
87        Method::Pearson => (
88            dm::DistanceMatrix::condensed(x_data, n),
89            dm::DistanceMatrix::condensed(y_data, n),
90        ),
91        Method::Spearman => (
92            rankdata(&dm::DistanceMatrix::condensed(x_data, n)),
93            rankdata(&dm::DistanceMatrix::condensed(y_data, n)),
94        ),
95    };
96
97    // The permutation acts on the full matrix; for Spearman that is the
98    // rank-transformed matrix. Rebuild a full square from the ranked condensed
99    // form so permutation semantics stay identical to skbio.
100    let x_full = match method {
101        Method::Pearson => x_data.to_vec(),
102        Method::Spearman => square_from_condensed(&x_flat, n),
103    };
104
105    // x's condensed mean and norm are permutation-invariant — the permutation
106    // only reorders the same upper-triangle entries — so each permuted statistic
107    // is a single allocation-free pass: dot(x_perm - xmean, ym_normalized)/normx.
108    let xmean = mean(&x_flat);
109    let normx = norm_centered(&x_flat, xmean);
110    let ym = normalize(&y_flat);
111    let r = match (&ym, normx) {
112        (Some(ymn), Some(nx)) => dot_centered(&x_flat, xmean, nx, ymn).clamp(-1.0, 1.0),
113        _ => f64::NAN,
114    };
115
116    let p_value = if permutations == 0 || r.is_nan() {
117        f64::NAN
118    } else {
119        let ymn = ym.unwrap();
120        let nx = normx.unwrap();
121        let count_extreme: usize = (0..permutations)
122            .into_par_iter()
123            .map(|k| {
124                let perm = rng::permutation(n, seed, k as u64);
125                let stat = permuted_stat(&x_full, n, &perm, xmean, nx, &ymn).clamp(-1.0, 1.0);
126                match alternative {
127                    Alternative::TwoSided => usize::from(stat.abs() >= r.abs()),
128                    Alternative::Greater => usize::from(stat >= r),
129                    Alternative::Less => usize::from(stat <= r),
130                }
131            })
132            .sum();
133        (count_extreme + 1) as f64 / (permutations + 1) as f64
134    };
135
136    MantelResult {
137        r,
138        p_value,
139        n,
140        method,
141        permutations,
142        alternative,
143    }
144}
145
/// Shift `v` to zero mean, then scale it to unit Euclidean length.
///
/// Returns `None` when the centred vector has zero length (constant or empty
/// input), because there is nothing to scale by.
fn normalize(v: &[f64]) -> Option<Vec<f64>> {
    let centre = v.iter().sum::<f64>() / v.len() as f64;
    let centred: Vec<f64> = v.iter().map(|&x| x - centre).collect();
    let length = centred.iter().map(|&d| d * d).sum::<f64>().sqrt();
    if length == 0.0 {
        None
    } else {
        Some(centred.into_iter().map(|d| d / length).collect())
    }
}
159
/// Arithmetic mean of `v`; NaN for an empty slice (0 / 0).
fn mean(v: &[f64]) -> f64 {
    let total: f64 = v.iter().sum();
    let count = v.len() as f64;
    total / count
}
163
/// Euclidean norm of `v - m`; `None` when that norm is exactly zero.
fn norm_centered(v: &[f64], m: f64) -> Option<f64> {
    let sum_sq: f64 = v
        .iter()
        .map(|&x| {
            let d = x - m;
            d * d
        })
        .sum();
    let s = sum_sq.sqrt();
    if s == 0.0 {
        None
    } else {
        Some(s)
    }
}
168
/// Pearson statistic of `x` (centred by `xmean`, scaled by `normx`) against
/// the already-normalized `ym_norm`. Pairs stop at the shorter slice.
fn dot_centered(x: &[f64], xmean: f64, normx: f64, ym_norm: &[f64]) -> f64 {
    let numerator: f64 = ym_norm
        .iter()
        .zip(x)
        .map(|(&yv, &xv)| (xv - xmean) * yv)
        .sum();
    numerator / normx
}
176
/// Pearson statistic of the row-major `n x n` matrix `x_full`, with rows and
/// columns relabelled by `perm`, against the already-normalized `ym_norm`.
///
/// Walks the permuted upper triangle row by row, in the same order as the
/// condensed form, so entry `k` pairs with `ym_norm[k]`. Allocation-free.
fn permuted_stat(
    x_full: &[f64],
    n: usize,
    perm: &[usize],
    xmean: f64,
    normx: f64,
    ym_norm: &[f64],
) -> f64 {
    let mut numerator = 0.0;
    let mut pair = 0;
    for (row, &src_row) in perm[..n].iter().enumerate() {
        let offset = src_row * n;
        for &src_col in &perm[row + 1..n] {
            numerator += (x_full[offset + src_col] - xmean) * ym_norm[pair];
            pair += 1;
        }
    }
    numerator / normx
}
198
/// 1-based average rank of each element, matching scipy `rankdata`'s default
/// `method="average"`: tied values all receive the mean of their positions.
///
/// If any element is NaN, every rank is NaN, mirroring scipy's default
/// `nan_policy="propagate"`. A downstream correlation then comes out NaN
/// instead of the comparator panicking.
fn rankdata(v: &[f64]) -> Vec<f64> {
    if v.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; v.len()];
    }
    let mut order: Vec<usize> = (0..v.len()).collect();
    // Input is NaN-free, so `total_cmp` agrees with numeric order. It only
    // additionally puts -0.0 before 0.0; those stay adjacent and still form
    // one tie group below, because -0.0 == 0.0.
    order.sort_by(|&a, &b| v[a].total_cmp(&v[b]));

    let mut ranks = vec![0.0f64; v.len()];
    let mut start = 0;
    while start < order.len() {
        let tied = v[order[start]];
        let end = start + order[start..].iter().take_while(|&&idx| v[idx] == tied).count();
        // Sorted positions start..end hold 1-based ranks start+1..=end; average them.
        let avg = (start + 1 + end) as f64 / 2.0;
        for &idx in &order[start..end] {
            ranks[idx] = avg;
        }
        start = end;
    }
    ranks
}
219
/// Expand an upper-triangle condensed vector into a symmetric row-major
/// `n x n` matrix with a zero diagonal.
fn square_from_condensed(cond: &[f64], n: usize) -> Vec<f64> {
    let mut full = vec![0.0f64; n * n];
    let upper_pairs = (0..n).flat_map(|row| (row + 1..n).map(move |col| (row, col)));
    for (k, (row, col)) in upper_pairs.enumerate() {
        let value = cond[k];
        full[row * n + col] = value;
        full[col * n + row] = value;
    }
    full
}
232
233pub fn write_result<W: Write>(out: &mut W, res: &MantelResult) -> Result<()> {
234    writeln!(
235        out,
236        "method\tstatistic\tp_value\tn\tpermutations\talternative"
237    )
238    .map_err(RsomicsError::Io)?;
239    writeln!(
240        out,
241        "{}\t{:.12}\t{}\t{}\t{}\t{}",
242        res.method.name(),
243        res.r,
244        fmt_p(res.p_value),
245        res.n,
246        res.permutations,
247        res.alternative.name(),
248    )
249    .map_err(RsomicsError::Io)?;
250    Ok(())
251}
252
/// Format a float with 12 decimal places, or as `nan` when it is NaN.
fn fmt_p(p: f64) -> String {
    if !p.is_nan() {
        return format!("{p:.12}");
    }
    String::from("nan")
}
260
261pub fn read_matrix<R: BufRead>(reader: R, source: &str) -> Result<DistanceMatrix> {
262    DistanceMatrix::read(reader, source)
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Flatten explicit rows into a row-major `n x n` matrix.
    fn square(rows: &[&[f64]]) -> (Vec<f64>, usize) {
        let n = rows.len();
        let mut d = vec![0.0; n * n];
        for (i, r) in rows.iter().enumerate() {
            for (j, &v) in r.iter().enumerate() {
                d[i * n + j] = v;
            }
        }
        (d, n)
    }

    /// The `x` / `y` matrices of the scikit-bio `mantel()` doc example.
    fn doc_example() -> (Vec<f64>, Vec<f64>, usize) {
        let (x, n) = square(&[&[0.0, 1.0, 2.0], &[1.0, 0.0, 3.0], &[2.0, 3.0, 0.0]]);
        let (y, _) = square(&[&[0.0, 2.0, 7.0], &[2.0, 0.0, 6.0], &[7.0, 6.0, 0.0]]);
        (x, y, n)
    }

    #[test]
    fn skbio_doc_example_pearson() {
        let (x, y, n) = doc_example();
        let res = mantel(&x, &y, n, Method::Pearson, 0, Alternative::TwoSided, 1);
        assert!((res.r - 0.7559289460184544).abs() < 1e-12, "r={}", res.r);
        assert!(res.p_value.is_nan());
    }

    #[test]
    fn skbio_doc_example_spearman() {
        // Condensed ranks: x -> [1, 2, 3], y ([2, 7, 6]) -> [1, 3, 2];
        // their Pearson correlation is exactly 0.5.
        let (x, y, n) = doc_example();
        let res = mantel(&x, &y, n, Method::Spearman, 0, Alternative::TwoSided, 1);
        assert!((res.r - 0.5).abs() < 1e-12, "r={}", res.r);
        assert!(res.p_value.is_nan());
    }

    #[test]
    fn constant_matrix_gives_nan() {
        let (x, n) = square(&[&[0.0, 1.0, 1.0], &[1.0, 0.0, 1.0], &[1.0, 1.0, 0.0]]);
        let (_, y, _) = doc_example();
        for method in [Method::Pearson, Method::Spearman] {
            let res = mantel(&x, &y, n, method, 9, Alternative::TwoSided, 1);
            assert!(res.r.is_nan(), "r={}", res.r);
            assert!(res.p_value.is_nan(), "p={}", res.p_value);
        }
    }

    #[test]
    fn permutation_p_value_is_bounded() {
        let (x, y, n) = doc_example();
        let permutations = 99;
        let floor = 1.0 / (permutations + 1) as f64;
        for alternative in [Alternative::TwoSided, Alternative::Greater, Alternative::Less] {
            let res = mantel(&x, &y, n, Method::Pearson, permutations, alternative, 7);
            assert!(
                res.p_value >= floor && res.p_value <= 1.0,
                "p={}",
                res.p_value
            );
            // The observed statistic does not depend on the permutations drawn.
            assert!((res.r - 0.7559289460184544).abs() < 1e-12, "r={}", res.r);
        }
    }

    #[test]
    fn rankdata_ties_averaged() {
        assert_eq!(rankdata(&[1.0, 2.0, 2.0, 3.0]), vec![1.0, 2.5, 2.5, 4.0]);
    }

    #[test]
    fn condensed_upper_triangle() {
        let (x, _, n) = doc_example();
        assert_eq!(DistanceMatrix::condensed(&x, n), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn identity_permutation_reproduces_observed_stat() {
        let (x, y, n) = doc_example();
        let xf = DistanceMatrix::condensed(&x, n);
        let yf = DistanceMatrix::condensed(&y, n);
        let xmean = mean(&xf);
        let normx = norm_centered(&xf, xmean).unwrap();
        let ymn = normalize(&yf).unwrap();
        let id: Vec<usize> = (0..n).collect();
        let permuted = permuted_stat(&x, n, &id, xmean, normx, &ymn);
        let observed = dot_centered(&xf, xmean, normx, &ymn);
        assert!((permuted - observed).abs() < 1e-12);
    }
}