Skip to main content

rsomics_rda/
lib.rs

1use std::io::{BufRead, Write};
2
3use faer::Mat;
4use faer::linalg::solvers::Svd;
5use rsomics_common::{Result, RsomicsError};
6
7mod fmt;
8use fmt::push_pyrepr;
9
/// A numeric matrix read from a labelled TSV: an empty top-left cell, then
/// column IDs as the header, then one row per sample (row ID + tab-separated
/// values). Used for both the response (samples × species) and the constraint
/// (samples × variables) tables.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Row labels, one per data row, in input order.
    pub row_ids: Vec<String>,
    /// Column labels taken from the header, in input order.
    pub col_ids: Vec<String>,
    /// Row-major `n_rows × n_cols`.
    pub data: Vec<f64>,
}
20
21impl Matrix {
22    /// # Errors
23    /// Errors on a missing header, a ragged body, or a non-numeric cell.
24    pub fn parse<R: BufRead>(reader: R, delim: char) -> Result<Matrix> {
25        let mut lines = reader.lines();
26        let header = loop {
27            match lines.next() {
28                Some(line) => {
29                    let line = line.map_err(RsomicsError::Io)?;
30                    if line.trim().is_empty() || line.starts_with('#') {
31                        continue;
32                    }
33                    break line;
34                }
35                None => return Err(RsomicsError::InvalidInput("empty table".into())),
36            }
37        };
38        let col_ids: Vec<String> = header
39            .split(delim)
40            .skip(1)
41            .map(|s| s.trim().to_string())
42            .collect();
43        let p = col_ids.len();
44        if p == 0 {
45            return Err(RsomicsError::InvalidInput(
46                "header has no value columns (need an empty top-left cell + ≥1 column)".into(),
47            ));
48        }
49
50        let mut row_ids = Vec::new();
51        let mut data = Vec::new();
52        for line in lines {
53            let line = line.map_err(RsomicsError::Io)?;
54            if line.trim().is_empty() || line.starts_with('#') {
55                continue;
56            }
57            let mut fields = line.split(delim);
58            let label = fields.next().unwrap_or("").trim().to_string();
59            let row_start = data.len();
60            for field in fields {
61                let v: f64 = field.trim().parse().map_err(|_| {
62                    RsomicsError::InvalidInput(format!(
63                        "row '{label}', column {}: '{}' is not numeric",
64                        data.len() - row_start + 1,
65                        field.trim()
66                    ))
67                })?;
68                data.push(v);
69            }
70            let got = data.len() - row_start;
71            if got != p {
72                return Err(RsomicsError::InvalidInput(format!(
73                    "row '{label}' has {got} values, expected {p}"
74                )));
75            }
76            row_ids.push(label);
77        }
78        if row_ids.is_empty() {
79            return Err(RsomicsError::InvalidInput("no data rows".into()));
80        }
81        Ok(Matrix {
82            row_ids,
83            col_ids,
84            data,
85        })
86    }
87
88    #[must_use]
89    pub fn n_rows(&self) -> usize {
90        self.row_ids.len()
91    }
92
93    #[must_use]
94    pub fn n_cols(&self) -> usize {
95        self.col_ids.len()
96    }
97
98    fn to_mat(&self) -> Mat<f64> {
99        let c = self.n_cols();
100        Mat::from_fn(self.n_rows(), c, |i, j| self.data[i * c + j])
101    }
102}
103
/// Result of a Redundancy Analysis. Eigenvalues, proportion explained, and the
/// site/species scores plus biplot and site-constraint scores follow
/// `skbio.stats.ordination.rda`. The leading axes are canonical (constrained
/// by the explanatory variables); the remaining axes are the unconstrained PCA
/// of the residuals. The number of canonical axes is not stored separately.
///
/// "`n_axes`" below means `eigvals.len()`.
#[derive(Debug, Clone, PartialEq)]
pub struct Ordination {
    /// Row IDs of the response table.
    pub sample_ids: Vec<String>,
    /// Column IDs of the response table.
    pub species_ids: Vec<String>,
    /// Column IDs of the constraint table.
    pub constraint_ids: Vec<String>,
    /// Canonical axes first, then residual axes.
    pub eigvals: Vec<f64>,
    /// Each entry of `eigvals` divided by their sum.
    pub proportion_explained: Vec<f64>,
    /// Row-major `n_samples × n_axes`.
    pub sample_scores: Vec<f64>,
    /// Row-major `n_species × n_axes`.
    pub species_scores: Vec<f64>,
    /// Correlations between the constraints and the left singular vectors of
    /// the fitted values: row-major `n_constraints × biplot_axes`. There is one
    /// column per left singular vector of the thin SVD, so `biplot_axes` need
    /// not equal `n_axes`.
    pub biplot_scores: Vec<f64>,
    /// Number of columns in `biplot_scores`.
    pub biplot_axes: usize,
    /// Row-major `n_samples × n_axes`.
    pub sample_constraints: Vec<f64>,
}
126
127struct ThinSvd {
128    u: Mat<f64>,
129    s: Vec<f64>,
130    vt: Mat<f64>,
131}
132
133fn thin_svd(m: &Mat<f64>) -> ThinSvd {
134    let svd: Svd<f64> = m.thin_svd().unwrap();
135    let sv = svd.S().column_vector();
136    let k = sv.nrows();
137    let s = (0..k).map(|i| sv[i]).collect();
138    let u = svd.U().to_owned();
139    let v = svd.V();
140    let vt = Mat::from_fn(v.ncols(), v.nrows(), |i, j| v[(j, i)]);
141    ThinSvd { u, s, vt }
142}
143
/// Numerical rank implied by the singular values `s` of a `rows × cols`
/// matrix: the count of values strictly above
/// `max(s) * max(rows, cols) * f64::EPSILON`, which is numpy's default
/// `matrix_rank` tolerance. An empty `s` gives rank 0.
fn svd_rank(rows: usize, cols: usize, s: &[f64]) -> usize {
    let largest = s.iter().copied().fold(0.0_f64, f64::max);
    let threshold = largest * rows.max(cols) as f64 * f64::EPSILON;
    s.iter().filter(|&&sigma| sigma > threshold).count()
}
150
151/// Column-centre `m` in place (with_mean), matching skbio `scale(with_std=False)`.
152fn center_columns(m: &mut Mat<f64>) {
153    let n = m.nrows();
154    for j in 0..m.ncols() {
155        let mut mean = 0.0;
156        for i in 0..n {
157            mean += m[(i, j)];
158        }
159        mean /= n as f64;
160        for i in 0..n {
161            m[(i, j)] -= mean;
162        }
163    }
164}
165
166/// Column-scale `m` to unit population std (ddof=0), zero std left as 1.
167fn scale_columns_std(m: &mut Mat<f64>) {
168    let n = m.nrows();
169    for j in 0..m.ncols() {
170        let mut var = 0.0;
171        for i in 0..n {
172            var += m[(i, j)] * m[(i, j)];
173        }
174        let mut std = (var / n as f64).sqrt();
175        if std == 0.0 {
176            std = 1.0;
177        }
178        for i in 0..n {
179            m[(i, j)] /= std;
180        }
181    }
182}
183
184/// Correlation between columns of `x` and `y` (both centred + scaled to unit
185/// population std, then `x' y / n`). skbio `corr`.
186fn corr(x: &Mat<f64>, y: &Mat<f64>) -> Mat<f64> {
187    let n = x.nrows();
188    let mut xs = x.clone();
189    center_columns(&mut xs);
190    scale_columns_std(&mut xs);
191    let mut ys = y.clone();
192    center_columns(&mut ys);
193    scale_columns_std(&mut ys);
194    let p = xs.ncols();
195    let q = ys.ncols();
196    Mat::from_fn(p, q, |i, j| {
197        let mut acc = 0.0;
198        for r in 0..n {
199            acc += xs[(r, i)] * ys[(r, j)];
200        }
201        acc / n as f64
202    })
203}
204
205impl Ordination {
206    /// RDA per Legendre & Legendre 1998 §11.1: regress centred `y` on centred
207    /// `x` (SVD least squares), SVD the fitted values for the canonical axes,
208    /// SVD the residuals for the unconstrained axes, then apply scaling 1 or 2.
209    ///
210    /// # Errors
211    /// Errors when the two tables disagree on sample count or when `x` has more
212    /// columns than rows (an under-determined regression), matching skbio.
213    pub fn compute(
214        response: &Matrix,
215        constraints: &Matrix,
216        scaling: u8,
217        scale_y: bool,
218    ) -> Result<Ordination> {
219        let n = response.n_rows();
220        let m = constraints.n_cols();
221        if constraints.n_rows() != n {
222            return Err(RsomicsError::InvalidInput(format!(
223                "response has {n} samples but constraints have {}",
224                constraints.n_rows()
225            )));
226        }
227        if n < m {
228            return Err(RsomicsError::InvalidInput(format!(
229                "constraints cannot have fewer rows ({n}) than columns ({m})"
230            )));
231        }
232
233        let mut y = response.to_mat();
234        center_columns(&mut y);
235        if scale_y {
236            scale_columns_std(&mut y);
237        }
238        let mut x = constraints.to_mat();
239        center_columns(&mut x);
240
241        // Y_hat = X B with B the minimum-norm least-squares solution (SVD), so
242        // Y_hat is the projection of Y onto the column space of X.
243        let y_hat = project_onto(&x, &y);
244
245        let svd = thin_svd(&y_hat);
246        let rank = svd_rank(y_hat.nrows(), y_hat.ncols(), &svd.s);
247        let u_axes = vt_rows_as_cols(&svd.vt, rank); // p × rank
248
249        let f = matmul(&y, &u_axes); // n × rank, sample scores
250        let z = matmul(&y_hat, &u_axes); // n × rank, fitted sample scores
251
252        let y_res = &y - &y_hat;
253        let svd_res = thin_svd(&y_res);
254        let rank_res = svd_rank(y_res.nrows(), y_res.ncols(), &svd_res.s);
255        let u_res = vt_rows_as_cols(&svd_res.vt, rank_res); // p × rank_res
256        let f_res = matmul(&y_res, &u_res); // n × rank_res
257
258        let mut eigenvalues: Vec<f64> = svd.s[..rank].to_vec();
259        eigenvalues.extend_from_slice(&svd_res.s[..rank_res]);
260        let n_axes = eigenvalues.len();
261        let p = response.n_cols();
262
263        if scaling != 1 && scaling != 2 {
264            return Err(RsomicsError::InvalidInput(
265                "only scaling 1 or 2 is available for RDA".into(),
266            ));
267        }
268        let const_factor = eigenvalues
269            .iter()
270            .map(|&e| e * e)
271            .sum::<f64>()
272            .sqrt()
273            .sqrt();
274        // scaling 1: a single factor; scaling 2: a per-axis factor.
275        let factor = |a: usize| -> f64 {
276            if scaling == 1 {
277                const_factor
278            } else {
279                eigenvalues[a] / const_factor
280            }
281        };
282
283        // species scores = [U | U_res] * factor
284        let mut species_scores = vec![0.0; p * n_axes];
285        for j in 0..p {
286            for a in 0..n_axes {
287                let v = if a < rank {
288                    u_axes[(j, a)]
289                } else {
290                    u_res[(j, a - rank)]
291                };
292                species_scores[j * n_axes + a] = v * factor(a);
293            }
294        }
295        // sample scores = [F | F_res] / factor
296        let mut sample_scores = vec![0.0; n * n_axes];
297        // site constraints = [Z | F_res] / factor
298        let mut sample_constraints = vec![0.0; n * n_axes];
299        for i in 0..n {
300            for a in 0..n_axes {
301                let fa = factor(a);
302                let (samp, cons) = if a < rank {
303                    (f[(i, a)], z[(i, a)])
304                } else {
305                    let r = f_res[(i, a - rank)];
306                    (r, r)
307                };
308                sample_scores[i * n_axes + a] = samp / fa;
309                sample_constraints[i * n_axes + a] = cons / fa;
310            }
311        }
312
313        // biplot scores = corr(X, left singular vectors of Y_hat); spans the
314        // thin-SVD width of Y_hat, not the full set of canonical+residual axes.
315        let biplot = corr(&x, &svd.u);
316        let biplot_axes = biplot.ncols();
317        let mut biplot_scores = vec![0.0; m * biplot_axes];
318        for i in 0..m {
319            for a in 0..biplot_axes {
320                biplot_scores[i * biplot_axes + a] = biplot[(i, a)];
321            }
322        }
323
324        let total: f64 = eigenvalues.iter().sum();
325        let proportion_explained = eigenvalues.iter().map(|&e| e / total).collect();
326
327        Ok(Ordination {
328            sample_ids: response.row_ids.clone(),
329            species_ids: response.col_ids.clone(),
330            constraint_ids: constraints.col_ids.clone(),
331            eigvals: eigenvalues,
332            proportion_explained,
333            sample_scores,
334            species_scores,
335            biplot_scores,
336            biplot_axes,
337            sample_constraints,
338        })
339    }
340
341    /// Write the ordination as a flat TSV with `# eigenvalues`, `# samples`,
342    /// `# species`, `# biplot`, and `# site_constraints` blocks, axes `RDA1..`.
343    ///
344    /// # Errors
345    /// Propagates write errors.
346    pub fn write_tsv<W: Write>(&self, mut out: W) -> Result<()> {
347        let k = self.eigvals.len();
348        let mut line = String::new();
349
350        writeln!(out, "# eigenvalues").map_err(RsomicsError::Io)?;
351        write_axis_header(&mut out, k)?;
352        line.push_str("eigval");
353        for &v in &self.eigvals {
354            line.push('\t');
355            push_pyrepr(&mut line, v);
356        }
357        writeln!(out, "{line}").map_err(RsomicsError::Io)?;
358        line.clear();
359        line.push_str("proportion_explained");
360        for &v in &self.proportion_explained {
361            line.push('\t');
362            push_pyrepr(&mut line, v);
363        }
364        writeln!(out, "{line}").map_err(RsomicsError::Io)?;
365
366        write_block(
367            &mut out,
368            "# samples",
369            &self.sample_ids,
370            &self.sample_scores,
371            k,
372        )?;
373        write_block(
374            &mut out,
375            "# species",
376            &self.species_ids,
377            &self.species_scores,
378            k,
379        )?;
380        write_block(
381            &mut out,
382            "# biplot",
383            &self.constraint_ids,
384            &self.biplot_scores,
385            self.biplot_axes,
386        )?;
387        write_block(
388            &mut out,
389            "# site_constraints",
390            &self.sample_ids,
391            &self.sample_constraints,
392            k,
393        )
394    }
395}
396
397fn write_block<W: Write>(
398    out: &mut W,
399    title: &str,
400    ids: &[String],
401    scores: &[f64],
402    k: usize,
403) -> Result<()> {
404    writeln!(out, "{title}").map_err(RsomicsError::Io)?;
405    write_axis_header(out, k)?;
406    let mut line = String::new();
407    for (i, id) in ids.iter().enumerate() {
408        line.clear();
409        line.push_str(id);
410        for a in 0..k {
411            line.push('\t');
412            push_pyrepr(&mut line, scores[i * k + a]);
413        }
414        writeln!(out, "{line}").map_err(RsomicsError::Io)?;
415    }
416    Ok(())
417}
418
419fn write_axis_header<W: Write>(out: &mut W, k: usize) -> Result<()> {
420    let mut header = String::new();
421    for a in 1..=k {
422        header.push('\t');
423        header.push_str("RDA");
424        header.push_str(&a.to_string());
425    }
426    writeln!(out, "{header}").map_err(RsomicsError::Io)
427}
428
429/// Project the columns of `y` onto the column space of `x` via the SVD
430/// (minimum-norm least squares), giving `X (X⁺ Y) = U_x U_x' Y`.
431fn project_onto(x: &Mat<f64>, y: &Mat<f64>) -> Mat<f64> {
432    let svd = thin_svd(x);
433    let rank = svd_rank(x.nrows(), x.ncols(), &svd.s);
434    let n = x.nrows();
435    let p = y.ncols();
436    // c = U_x' Y over the kept left singular vectors.
437    let mut c = vec![0.0; rank * p];
438    for a in 0..rank {
439        for j in 0..p {
440            let mut acc = 0.0;
441            for i in 0..n {
442                acc += svd.u[(i, a)] * y[(i, j)];
443            }
444            c[a * p + j] = acc;
445        }
446    }
447    Mat::from_fn(n, p, |i, j| {
448        let mut acc = 0.0;
449        for a in 0..rank {
450            acc += svd.u[(i, a)] * c[a * p + j];
451        }
452        acc
453    })
454}
455
456fn matmul(a: &Mat<f64>, b: &Mat<f64>) -> Mat<f64> {
457    a * b
458}
459
460/// First `k` rows of `vt` returned as columns (i.e. `vt[:k].T`).
461fn vt_rows_as_cols(vt: &Mat<f64>, k: usize) -> Mat<f64> {
462    Mat::from_fn(vt.ncols(), k, |i, j| vt[(j, i)])
463}
464
465/// # Errors
466/// Propagates parse, compute, and write errors.
467pub fn run<W: Write>(
468    response: &Matrix,
469    constraints: &Matrix,
470    out: W,
471    scaling: u8,
472    scale_y: bool,
473) -> Result<()> {
474    let ord = Ordination::compute(response, constraints, scaling, scale_y)?;
475    ord.write_tsv(out)
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Five samples × three species.
    const RESPONSE: &str = "\tSp1\tSp2\tSp3\n\
         S1\t1\t0\t2\n\
         S2\t0\t3\t1\n\
         S3\t2\t1\t0\n\
         S4\t3\t2\t1\n\
         S5\t1\t4\t2\n";

    /// Five samples × two explanatory variables.
    const CONSTRAINTS: &str = "\tE1\tE2\n\
         S1\t1.0\t0.5\n\
         S2\t0.0\t1.0\n\
         S3\t2.0\t0.2\n\
         S4\t1.5\t0.8\n\
         S5\t0.5\t1.2\n";

    /// Parse a tab-delimited fixture, panicking on any parse error.
    fn table(text: &str) -> Matrix {
        Matrix::parse(text.as_bytes(), '\t').unwrap()
    }

    /// Scaling-1, unscaled-response ordination of the two fixtures.
    fn fitted() -> Ordination {
        Ordination::compute(&table(RESPONSE), &table(CONSTRAINTS), 1, false).unwrap()
    }

    #[test]
    fn parses_matrix() {
        let parsed = table(RESPONSE);
        assert_eq!(parsed.row_ids, ["S1", "S2", "S3", "S4", "S5"]);
        assert_eq!(parsed.col_ids, ["Sp1", "Sp2", "Sp3"]);
        // First value of the fourth row (S4).
        assert_eq!(parsed.data[3 * 3], 3.0);
    }

    #[test]
    fn mismatched_rows_error() {
        let five_samples = table(RESPONSE);
        let two_samples = table("\tE1\nS1\t1\nS2\t2\n");
        assert!(Ordination::compute(&five_samples, &two_samples, 1, false).is_err());
    }

    #[test]
    fn proportion_sums_to_one() {
        let total: f64 = fitted().proportion_explained.iter().sum();
        assert!((total - 1.0).abs() < 1e-12);
    }

    /// Every score table is sized consistently with its ID list and axis count.
    #[test]
    fn axis_counts() {
        let ord = fitted();
        let n_axes = ord.eigvals.len();
        assert!(!ord.eigvals.is_empty());
        assert_eq!(ord.sample_scores.len(), ord.sample_ids.len() * n_axes);
        assert_eq!(ord.species_scores.len(), ord.species_ids.len() * n_axes);
        assert_eq!(
            ord.biplot_scores.len(),
            ord.constraint_ids.len() * ord.biplot_axes
        );
    }
}
542}