use std::io::{BufRead, Write};
use faer::Mat;
use faer::linalg::solvers::Svd;
use rsomics_common::{Result, RsomicsError};
mod fmt;
use fmt::push_pyrepr;
/// A labelled, dense numeric table stored in row-major order.
///
/// The value for row `i` and column `j` lives at
/// `data[i * col_ids.len() + j]`, so `data` is expected to hold
/// `row_ids.len() * col_ids.len()` entries.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Row labels, one per data row, in input order.
    pub row_ids: Vec<String>,
    /// Column labels, one per value column.
    pub col_ids: Vec<String>,
    /// Row-major cell values.
    pub data: Vec<f64>,
}
impl Matrix {
    /// Parses a delimited table with a header line and one label per data row.
    ///
    /// Blank lines and lines starting with `#` are skipped everywhere. The
    /// first remaining line is the header: its first cell is ignored and the
    /// remaining cells (trimmed) become the column ids. Every following line
    /// supplies a row label followed by exactly one numeric value per column.
    ///
    /// # Errors
    ///
    /// * `RsomicsError::Io` if reading a line fails.
    /// * `RsomicsError::InvalidInput` if the input has no header, the header
    ///   has no value columns, a cell is not a finite number, a row has the
    ///   wrong number of values, or there are no data rows.
    pub fn parse<R: BufRead>(reader: R, delim: char) -> Result<Matrix> {
        let mut lines = reader.lines();

        // The header is the first line that is neither blank nor a comment.
        let header = loop {
            match lines.next() {
                Some(line) => {
                    let line = line.map_err(RsomicsError::Io)?;
                    if !Self::is_skippable(&line) {
                        break line;
                    }
                }
                None => return Err(RsomicsError::InvalidInput("empty table".into())),
            }
        };

        let col_ids: Vec<String> = header
            .split(delim)
            .skip(1)
            .map(|s| s.trim().to_string())
            .collect();
        let p = col_ids.len();
        if p == 0 {
            return Err(RsomicsError::InvalidInput(
                "header has no value columns (need an empty top-left cell + ≥1 column)".into(),
            ));
        }

        let mut row_ids = Vec::new();
        let mut data = Vec::new();
        for line in lines {
            let line = line.map_err(RsomicsError::Io)?;
            if Self::is_skippable(&line) {
                continue;
            }

            let mut fields = line.split(delim);
            let label = fields.next().unwrap_or("").trim();
            let row_start = data.len();
            for field in fields {
                let text = field.trim();
                // 1-based index of the value column, used in error messages.
                let column = data.len() - row_start + 1;
                let value: f64 = text.parse().map_err(|_| {
                    RsomicsError::InvalidInput(format!(
                        "row '{label}', column {column}: '{text}' is not numeric"
                    ))
                })?;
                // `f64::from_str` accepts "nan", "inf" and "infinity"; such
                // values would poison column centering and the later SVD, so
                // reject them here where the offending cell can be named.
                if !value.is_finite() {
                    return Err(RsomicsError::InvalidInput(format!(
                        "row '{label}', column {column}: '{text}' is not a finite number"
                    )));
                }
                data.push(value);
            }

            let got = data.len() - row_start;
            if got != p {
                return Err(RsomicsError::InvalidInput(format!(
                    "row '{label}' has {got} values, expected {p}"
                )));
            }
            row_ids.push(label.to_string());
        }

        if row_ids.is_empty() {
            return Err(RsomicsError::InvalidInput("no data rows".into()));
        }
        Ok(Matrix {
            row_ids,
            col_ids,
            data,
        })
    }

    /// Returns `true` for lines the parser ignores: blank lines and lines
    /// whose first character is `#`.
    fn is_skippable(line: &str) -> bool {
        line.trim().is_empty() || line.starts_with('#')
    }

    /// Number of rows (length of `row_ids`).
    #[must_use]
    pub fn n_rows(&self) -> usize {
        self.row_ids.len()
    }

    /// Number of value columns (length of `col_ids`).
    #[must_use]
    pub fn n_cols(&self) -> usize {
        self.col_ids.len()
    }

    /// Copies the row-major `data` into an `n_rows × n_cols` faer matrix.
    ///
    /// Panics on an out-of-bounds index if `data` holds fewer than
    /// `n_rows * n_cols` values.
    fn to_mat(&self) -> Mat<f64> {
        let c = self.n_cols();
        Mat::from_fn(self.n_rows(), c, |i, j| self.data[i * c + j])
    }
}
/// Result of an RDA ordination as produced by `Ordination::compute`.
///
/// All score tables are flat, row-major vectors. Unless noted otherwise each
/// row has `eigvals.len()` entries (one per axis).
#[derive(Debug, Clone, PartialEq)]
pub struct Ordination {
    /// Row ids of the response table.
    pub sample_ids: Vec<String>,
    /// Column ids of the response table.
    pub species_ids: Vec<String>,
    /// Column ids of the constraints table.
    pub constraint_ids: Vec<String>,
    /// Retained singular values of the fitted matrix, followed by the
    /// retained singular values of the residual matrix.
    pub eigvals: Vec<f64>,
    /// Each entry of `eigvals` divided by the sum of `eigvals`.
    pub proportion_explained: Vec<f64>,
    /// Sample scores, one row per sample.
    pub sample_scores: Vec<f64>,
    /// Species scores, one row per response column.
    pub species_scores: Vec<f64>,
    /// Biplot scores, one row per constraint column with `biplot_axes`
    /// entries per row.
    pub biplot_scores: Vec<f64>,
    /// Number of columns in `biplot_scores`.
    pub biplot_axes: usize,
    /// Sample scores computed from the fitted values on the constrained
    /// axes (residual axes repeat the values in `sample_scores`).
    pub sample_constraints: Vec<f64>,
}
/// Owned factors of a thin singular value decomposition, as assembled by
/// `thin_svd`.
struct ThinSvd {
/// Copy of the decomposition's `U()` factor.
u: Mat<f64>,
/// Singular values, copied element by element from the `S()` diagonal.
s: Vec<f64>,
/// Transpose of the decomposition's `V()` factor: row `a` is column `a` of `V`.
vt: Mat<f64>,
}
/// Runs faer's thin SVD on `m` and returns owned copies of its factors.
///
/// The `V` factor is returned transposed so that callers can read singular
/// vectors row by row.
///
/// # Panics
///
/// Panics if faer reports a decomposition failure (the result is unwrapped).
fn thin_svd(m: &Mat<f64>) -> ThinSvd {
    let decomposition: Svd<f64> = m.thin_svd().unwrap();

    // Build Vᵀ explicitly: entry (row, col) of the result is V[(col, row)].
    let right = decomposition.V();
    let vt = Mat::from_fn(right.ncols(), right.nrows(), |row, col| right[(col, row)]);

    let u = decomposition.U().to_owned();

    let diagonal = decomposition.S().column_vector();
    let count = diagonal.nrows();
    let mut s = Vec::with_capacity(count);
    for idx in 0..count {
        s.push(diagonal[idx]);
    }

    ThinSvd { u, s, vt }
}
/// Counts the singular values in `s` that exceed a relative tolerance.
///
/// The tolerance is `max(s) * max(rows, cols) * f64::EPSILON`, where the
/// maximum of `s` is taken with a floor of `0.0`. An empty or all-zero `s`
/// therefore yields rank `0`.
fn svd_rank(rows: usize, cols: usize, s: &[f64]) -> usize {
    let largest = s.iter().copied().fold(0.0_f64, f64::max);
    let longest_side = rows.max(cols) as f64;
    let threshold = largest * longest_side * f64::EPSILON;

    let mut rank = 0;
    for &value in s {
        if value > threshold {
            rank += 1;
        }
    }
    rank
}
/// Subtracts each column's arithmetic mean from that column, in place.
fn center_columns(m: &mut Mat<f64>) {
    let n_rows = m.nrows();
    let n_cols = m.ncols();
    for col in 0..n_cols {
        // Accumulate top to bottom starting from 0.0.
        let total = (0..n_rows).fold(0.0_f64, |acc, row| acc + m[(row, col)]);
        let mean = total / n_rows as f64;
        for row in 0..n_rows {
            m[(row, col)] -= mean;
        }
    }
}
fn scale_columns_std(m: &mut Mat<f64>) {
let n = m.nrows();
for j in 0..m.ncols() {
let mut var = 0.0;
for i in 0..n {
var += m[(i, j)] * m[(i, j)];
}
let mut std = (var / n as f64).sqrt();
if std == 0.0 {
std = 1.0;
}
for i in 0..n {
m[(i, j)] /= std;
}
}
}
/// Cross-correlation between the columns of `x` and the columns of `y`.
///
/// Both inputs are copied, centered and scaled column-wise; entry `(i, j)` of
/// the result is the dot product of standardized column `i` of `x` with
/// standardized column `j` of `y`, divided by the number of rows of `x`.
fn corr(x: &Mat<f64>, y: &Mat<f64>) -> Mat<f64> {
    // Work on copies so the callers' matrices stay untouched.
    let standardize = |source: &Mat<f64>| -> Mat<f64> {
        let mut copy = source.clone();
        center_columns(&mut copy);
        scale_columns_std(&mut copy);
        copy
    };
    let xs = standardize(x);
    let ys = standardize(y);

    let n_rows = x.nrows();
    Mat::from_fn(xs.ncols(), ys.ncols(), |i, j| {
        let dot = (0..n_rows).fold(0.0_f64, |acc, r| acc + xs[(r, i)] * ys[(r, j)]);
        dot / n_rows as f64
    })
}
impl Ordination {
    /// Runs the ordination of `response` constrained by `constraints`.
    ///
    /// Steps, in order:
    /// 1. center the response columns (and scale them when `scale_y` is set)
    ///    and center the constraint columns;
    /// 2. fit the response on the constraints with `project_onto`;
    /// 3. take a thin SVD of the fitted matrix and of the residual matrix and
    ///    keep the singular values/vectors above the `svd_rank` tolerance;
    /// 4. scale species and sample scores according to `scaling`
    ///    (`1` or `2`);
    /// 5. compute biplot scores as `corr` between the centered constraints
    ///    and the left singular vectors of the fitted matrix.
    ///
    /// # Errors
    ///
    /// Returns `RsomicsError::InvalidInput` if the two tables have a
    /// different number of rows, if there are fewer rows than constraint
    /// columns, or if `scaling` is neither `1` nor `2`.
    ///
    /// # Panics
    ///
    /// Panics if one of the underlying decompositions fails (see `thin_svd`).
    pub fn compute(
        response: &Matrix,
        constraints: &Matrix,
        scaling: u8,
        scale_y: bool,
    ) -> Result<Ordination> {
        let n = response.n_rows();
        let m = constraints.n_cols();
        if constraints.n_rows() != n {
            return Err(RsomicsError::InvalidInput(format!(
                "response has {n} samples but constraints have {}",
                constraints.n_rows()
            )));
        }
        if n < m {
            return Err(RsomicsError::InvalidInput(format!(
                "constraints cannot have fewer rows ({n}) than columns ({m})"
            )));
        }
        // Reject an unsupported scaling before doing any numerical work, so
        // the caller gets this error without paying for (or panicking in)
        // the projection and the two decompositions below.
        if scaling != 1 && scaling != 2 {
            return Err(RsomicsError::InvalidInput(
                "only scaling 1 or 2 is available for RDA".into(),
            ));
        }

        let mut y = response.to_mat();
        center_columns(&mut y);
        if scale_y {
            scale_columns_std(&mut y);
        }
        let mut x = constraints.to_mat();
        center_columns(&mut x);

        // Constrained part: decomposition of the fitted values.
        let y_hat = project_onto(&x, &y);
        let svd = thin_svd(&y_hat);
        let rank = svd_rank(y_hat.nrows(), y_hat.ncols(), &svd.s);
        let u_axes = vt_rows_as_cols(&svd.vt, rank);
        let f = matmul(&y, &u_axes);
        let z = matmul(&y_hat, &u_axes);

        // Unconstrained part: decomposition of the residuals.
        let y_res = &y - &y_hat;
        let svd_res = thin_svd(&y_res);
        let rank_res = svd_rank(y_res.nrows(), y_res.ncols(), &svd_res.s);
        let u_res = vt_rows_as_cols(&svd_res.vt, rank_res);
        let f_res = matmul(&y_res, &u_res);

        // Constrained axes first, residual axes after them.
        let mut eigenvalues: Vec<f64> = svd.s[..rank].to_vec();
        eigenvalues.extend_from_slice(&svd_res.s[..rank_res]);
        let n_axes = eigenvalues.len();
        let p = response.n_cols();

        // Fourth root of the sum of squared eigenvalues.
        let const_factor = eigenvalues
            .iter()
            .map(|&e| e * e)
            .sum::<f64>()
            .sqrt()
            .sqrt();
        // Per-axis factor: species scores are multiplied by it, sample
        // scores are divided by it.
        let factor = |a: usize| -> f64 {
            if scaling == 1 {
                const_factor
            } else {
                eigenvalues[a] / const_factor
            }
        };

        let mut species_scores = vec![0.0; p * n_axes];
        for j in 0..p {
            for a in 0..n_axes {
                let v = if a < rank {
                    u_axes[(j, a)]
                } else {
                    u_res[(j, a - rank)]
                };
                species_scores[j * n_axes + a] = v * factor(a);
            }
        }

        let mut sample_scores = vec![0.0; n * n_axes];
        let mut sample_constraints = vec![0.0; n * n_axes];
        for i in 0..n {
            for a in 0..n_axes {
                let fa = factor(a);
                let (samp, cons) = if a < rank {
                    (f[(i, a)], z[(i, a)])
                } else {
                    // Residual axes share the same value in both tables.
                    let r = f_res[(i, a - rank)];
                    (r, r)
                };
                sample_scores[i * n_axes + a] = samp / fa;
                sample_constraints[i * n_axes + a] = cons / fa;
            }
        }

        let biplot = corr(&x, &svd.u);
        let biplot_axes = biplot.ncols();
        let mut biplot_scores = vec![0.0; m * biplot_axes];
        for i in 0..m {
            for a in 0..biplot_axes {
                biplot_scores[i * biplot_axes + a] = biplot[(i, a)];
            }
        }

        let total: f64 = eigenvalues.iter().sum();
        let proportion_explained = eigenvalues.iter().map(|&e| e / total).collect();

        Ok(Ordination {
            sample_ids: response.row_ids.clone(),
            species_ids: response.col_ids.clone(),
            constraint_ids: constraints.col_ids.clone(),
            eigvals: eigenvalues,
            proportion_explained,
            sample_scores,
            species_scores,
            biplot_scores,
            biplot_axes,
            sample_constraints,
        })
    }

    /// Writes the ordination as tab-separated text.
    ///
    /// The output consists of five sections, each introduced by a `#` title
    /// line and an axis header: eigenvalues (with the explained proportions),
    /// samples, species, biplot and site constraints.
    ///
    /// # Errors
    ///
    /// Returns `RsomicsError::Io` if writing to `out` fails.
    pub fn write_tsv<W: Write>(&self, mut out: W) -> Result<()> {
        let k = self.eigvals.len();
        let mut line = String::new();

        writeln!(out, "# eigenvalues").map_err(RsomicsError::Io)?;
        write_axis_header(&mut out, k)?;

        line.push_str("eigval");
        for &v in &self.eigvals {
            line.push('\t');
            push_pyrepr(&mut line, v);
        }
        writeln!(out, "{line}").map_err(RsomicsError::Io)?;

        line.clear();
        line.push_str("proportion_explained");
        for &v in &self.proportion_explained {
            line.push('\t');
            push_pyrepr(&mut line, v);
        }
        writeln!(out, "{line}").map_err(RsomicsError::Io)?;

        write_block(
            &mut out,
            "# samples",
            &self.sample_ids,
            &self.sample_scores,
            k,
        )?;
        write_block(
            &mut out,
            "# species",
            &self.species_ids,
            &self.species_scores,
            k,
        )?;
        // The biplot table has its own width, independent of `k`.
        write_block(
            &mut out,
            "# biplot",
            &self.constraint_ids,
            &self.biplot_scores,
            self.biplot_axes,
        )?;
        write_block(
            &mut out,
            "# site_constraints",
            &self.sample_ids,
            &self.sample_constraints,
            k,
        )
    }
}
/// Writes one titled score table: the title line, an axis header for `k`
/// axes, then one line per id with its `k` tab-separated scores.
///
/// `scores` is read row-major with `k` values per id.
///
/// # Errors
///
/// Returns `RsomicsError::Io` if writing to `out` fails.
fn write_block<W: Write>(
    out: &mut W,
    title: &str,
    ids: &[String],
    scores: &[f64],
    k: usize,
) -> Result<()> {
    writeln!(out, "{title}").map_err(RsomicsError::Io)?;
    write_axis_header(out, k)?;

    // One buffer reused for every row.
    let mut row = String::new();
    for (index, label) in ids.iter().enumerate() {
        row.clear();
        row.push_str(label);
        let base = index * k;
        for offset in 0..k {
            row.push('\t');
            push_pyrepr(&mut row, scores[base + offset]);
        }
        writeln!(out, "{row}").map_err(RsomicsError::Io)?;
    }
    Ok(())
}
/// Writes the axis header line: a leading tab followed by `RDA1`, `RDA2`, …
/// up to `RDA<k>`, tab-separated. With `k == 0` an empty line is written.
///
/// # Errors
///
/// Returns `RsomicsError::Io` if writing to `out` fails.
fn write_axis_header<W: Write>(out: &mut W, k: usize) -> Result<()> {
    let header: String = (1..=k).map(|axis| format!("\tRDA{axis}")).collect();
    writeln!(out, "{header}").map_err(RsomicsError::Io)
}
fn project_onto(x: &Mat<f64>, y: &Mat<f64>) -> Mat<f64> {
let svd = thin_svd(x);
let rank = svd_rank(x.nrows(), x.ncols(), &svd.s);
let n = x.nrows();
let p = y.ncols();
let mut c = vec![0.0; rank * p];
for a in 0..rank {
for j in 0..p {
let mut acc = 0.0;
for i in 0..n {
acc += svd.u[(i, a)] * y[(i, j)];
}
c[a * p + j] = acc;
}
}
Mat::from_fn(n, p, |i, j| {
let mut acc = 0.0;
for a in 0..rank {
acc += svd.u[(i, a)] * c[a * p + j];
}
acc
})
}
/// Thin wrapper around faer's `*` operator applied to the two matrix
/// references `a` and `b`; returns the resulting owned matrix.
fn matmul(a: &Mat<f64>, b: &Mat<f64>) -> Mat<f64> {
a * b
}
/// Returns the first `k` rows of `vt` as the columns of a new matrix, i.e.
/// the transpose of the leading `k`-row slice of `vt`.
fn vt_rows_as_cols(vt: &Mat<f64>, k: usize) -> Mat<f64> {
    let n_features = vt.ncols();
    Mat::from_fn(n_features, k, |feature, axis| vt[(axis, feature)])
}
/// Computes the ordination for `response` and `constraints` and writes it to
/// `out` as tab-separated text.
///
/// # Errors
///
/// Propagates any error from `Ordination::compute` or
/// `Ordination::write_tsv`.
pub fn run<W: Write>(
    response: &Matrix,
    constraints: &Matrix,
    out: W,
    scaling: u8,
    scale_y: bool,
) -> Result<()> {
    Ordination::compute(response, constraints, scaling, scale_y)?.write_tsv(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Response fixture: 5 samples × 3 species, tab-delimited.
    const RESPONSE: &str = "\tSp1\tSp2\tSp3\n\
        S1\t1\t0\t2\n\
        S2\t0\t3\t1\n\
        S3\t2\t1\t0\n\
        S4\t3\t2\t1\n\
        S5\t1\t4\t2\n";

    /// Constraint fixture: the same 5 samples × 2 explanatory variables.
    const CONSTRAINTS: &str = "\tE1\tE2\n\
        S1\t1.0\t0.5\n\
        S2\t0.0\t1.0\n\
        S3\t2.0\t0.2\n\
        S4\t1.5\t0.8\n\
        S5\t0.5\t1.2\n";

    /// Parses a tab-delimited fixture, panicking on any parse error.
    fn table(text: &str) -> Matrix {
        Matrix::parse(text.as_bytes(), '\t').unwrap()
    }

    /// Runs the ordination on the two fixtures with scaling 1 and no
    /// response scaling.
    fn fixture_ordination() -> Ordination {
        let y = table(RESPONSE);
        let x = table(CONSTRAINTS);
        Ordination::compute(&y, &x, 1, false).unwrap()
    }

    #[test]
    fn parses_matrix() {
        let m = table(RESPONSE);
        assert_eq!(m.row_ids, ["S1", "S2", "S3", "S4", "S5"]);
        assert_eq!(m.col_ids, ["Sp1", "Sp2", "Sp3"]);
        // First value of the fourth row (S4).
        assert_eq!(m.data[3 * 3], 3.0);
    }

    #[test]
    fn mismatched_rows_error() {
        let y = table(RESPONSE);
        let x = table("\tE1\nS1\t1\nS2\t2\n");
        let outcome = Ordination::compute(&y, &x, 1, false);
        assert!(outcome.is_err());
    }

    #[test]
    fn proportion_sums_to_one() {
        let o = fixture_ordination();
        let s: f64 = o.proportion_explained.iter().sum();
        assert!((s - 1.0).abs() < 1e-12);
    }

    #[test]
    fn axis_counts() {
        let o = fixture_ordination();
        let n_axes = o.eigvals.len();
        assert!(!o.eigvals.is_empty());
        assert_eq!(o.sample_scores.len(), o.sample_ids.len() * n_axes);
        assert_eq!(o.species_scores.len(), o.species_ids.len() * n_axes);
        assert_eq!(
            o.biplot_scores.len(),
            o.constraint_ids.len() * o.biplot_axes
        );
    }
}