1use std::io::{BufRead, Write};
2
3use faer::Mat;
4use faer::linalg::solvers::Svd;
5use rsomics_common::{Result, RsomicsError};
6
7mod fmt;
8use fmt::push_pyrepr;
9
/// A dense numeric table with labelled rows and columns.
///
/// Values are stored row-major: the entry for row `i` and column `j` is
/// `data[i * col_ids.len() + j]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Row labels, one per data row.
    pub row_ids: Vec<String>,
    /// Column labels, one per value column.
    pub col_ids: Vec<String>,
    /// Flattened row-major values; expected length is
    /// `row_ids.len() * col_ids.len()`.
    pub data: Vec<f64>,
}
20
21impl Matrix {
22 pub fn parse<R: BufRead>(reader: R, delim: char) -> Result<Matrix> {
25 let mut lines = reader.lines();
26 let header = loop {
27 match lines.next() {
28 Some(line) => {
29 let line = line.map_err(RsomicsError::Io)?;
30 if line.trim().is_empty() || line.starts_with('#') {
31 continue;
32 }
33 break line;
34 }
35 None => return Err(RsomicsError::InvalidInput("empty table".into())),
36 }
37 };
38 let col_ids: Vec<String> = header
39 .split(delim)
40 .skip(1)
41 .map(|s| s.trim().to_string())
42 .collect();
43 let p = col_ids.len();
44 if p == 0 {
45 return Err(RsomicsError::InvalidInput(
46 "header has no value columns (need an empty top-left cell + ≥1 column)".into(),
47 ));
48 }
49
50 let mut row_ids = Vec::new();
51 let mut data = Vec::new();
52 for line in lines {
53 let line = line.map_err(RsomicsError::Io)?;
54 if line.trim().is_empty() || line.starts_with('#') {
55 continue;
56 }
57 let mut fields = line.split(delim);
58 let label = fields.next().unwrap_or("").trim().to_string();
59 let row_start = data.len();
60 for field in fields {
61 let v: f64 = field.trim().parse().map_err(|_| {
62 RsomicsError::InvalidInput(format!(
63 "row '{label}', column {}: '{}' is not numeric",
64 data.len() - row_start + 1,
65 field.trim()
66 ))
67 })?;
68 data.push(v);
69 }
70 let got = data.len() - row_start;
71 if got != p {
72 return Err(RsomicsError::InvalidInput(format!(
73 "row '{label}' has {got} values, expected {p}"
74 )));
75 }
76 row_ids.push(label);
77 }
78 if row_ids.is_empty() {
79 return Err(RsomicsError::InvalidInput("no data rows".into()));
80 }
81 Ok(Matrix {
82 row_ids,
83 col_ids,
84 data,
85 })
86 }
87
88 #[must_use]
89 pub fn n_rows(&self) -> usize {
90 self.row_ids.len()
91 }
92
93 #[must_use]
94 pub fn n_cols(&self) -> usize {
95 self.col_ids.len()
96 }
97
98 fn to_mat(&self) -> Mat<f64> {
99 let c = self.n_cols();
100 Mat::from_fn(self.n_rows(), c, |i, j| self.data[i * c + j])
101 }
102}
103
/// Result of an RDA ordination.
///
/// All score tables are flattened row-major: the value for row `i` on axis
/// `a` is `table[i * n_axes + a]`, where `n_axes` is `eigvals.len()` for every
/// table except `biplot_scores`, which uses `biplot_axes`.
#[derive(Debug, Clone, PartialEq)]
pub struct Ordination {
    /// Row ids of the response table.
    pub sample_ids: Vec<String>,
    /// Column ids of the response table.
    pub species_ids: Vec<String>,
    /// Column ids of the constraints table.
    pub constraint_ids: Vec<String>,
    /// One value per axis: the retained singular values of the fitted matrix
    /// followed by the retained singular values of the residual matrix.
    pub eigvals: Vec<f64>,
    /// Each entry of `eigvals` divided by the sum of `eigvals`.
    pub proportion_explained: Vec<f64>,
    /// Sample scores, `sample_ids.len() x eigvals.len()`.
    pub sample_scores: Vec<f64>,
    /// Species scores, `species_ids.len() x eigvals.len()`.
    pub species_scores: Vec<f64>,
    /// Biplot scores, `constraint_ids.len() x biplot_axes`.
    pub biplot_scores: Vec<f64>,
    /// Number of axes (columns) in `biplot_scores`.
    pub biplot_axes: usize,
    /// Sample scores computed from the fitted values on the constrained axes
    /// (residual axes reuse the residual scores),
    /// `sample_ids.len() x eigvals.len()`.
    pub sample_constraints: Vec<f64>,
}
126
/// Owned pieces of a thin singular value decomposition.
struct ThinSvd {
    /// Left singular vectors, one per column.
    u: Mat<f64>,
    /// Singular values.
    s: Vec<f64>,
    /// Transposed right singular vectors: row `k` is the k-th right singular
    /// vector.
    vt: Mat<f64>,
}
132
133fn thin_svd(m: &Mat<f64>) -> ThinSvd {
134 let svd: Svd<f64> = m.thin_svd().unwrap();
135 let sv = svd.S().column_vector();
136 let k = sv.nrows();
137 let s = (0..k).map(|i| sv[i]).collect();
138 let u = svd.U().to_owned();
139 let v = svd.V();
140 let vt = Mat::from_fn(v.ncols(), v.nrows(), |i, j| v[(j, i)]);
141 ThinSvd { u, s, vt }
142}
143
/// Counts the singular values in `s` that are numerically non-zero.
///
/// A value counts when it is strictly greater than
/// `max(s) * max(rows, cols) * f64::EPSILON`.
fn svd_rank(rows: usize, cols: usize, s: &[f64]) -> usize {
    let largest = s.iter().copied().fold(0.0_f64, f64::max);
    let longest_side = rows.max(cols) as f64;
    let threshold = largest * longest_side * f64::EPSILON;

    let mut rank = 0;
    for &sigma in s {
        if sigma > threshold {
            rank += 1;
        }
    }
    rank
}
150
151fn center_columns(m: &mut Mat<f64>) {
153 let n = m.nrows();
154 for j in 0..m.ncols() {
155 let mut mean = 0.0;
156 for i in 0..n {
157 mean += m[(i, j)];
158 }
159 mean /= n as f64;
160 for i in 0..n {
161 m[(i, j)] -= mean;
162 }
163 }
164}
165
166fn scale_columns_std(m: &mut Mat<f64>) {
168 let n = m.nrows();
169 for j in 0..m.ncols() {
170 let mut var = 0.0;
171 for i in 0..n {
172 var += m[(i, j)] * m[(i, j)];
173 }
174 let mut std = (var / n as f64).sqrt();
175 if std == 0.0 {
176 std = 1.0;
177 }
178 for i in 0..n {
179 m[(i, j)] /= std;
180 }
181 }
182}
183
184fn corr(x: &Mat<f64>, y: &Mat<f64>) -> Mat<f64> {
187 let n = x.nrows();
188 let mut xs = x.clone();
189 center_columns(&mut xs);
190 scale_columns_std(&mut xs);
191 let mut ys = y.clone();
192 center_columns(&mut ys);
193 scale_columns_std(&mut ys);
194 let p = xs.ncols();
195 let q = ys.ncols();
196 Mat::from_fn(p, q, |i, j| {
197 let mut acc = 0.0;
198 for r in 0..n {
199 acc += xs[(r, i)] * ys[(r, j)];
200 }
201 acc / n as f64
202 })
203}
204
205impl Ordination {
206 pub fn compute(
214 response: &Matrix,
215 constraints: &Matrix,
216 scaling: u8,
217 scale_y: bool,
218 ) -> Result<Ordination> {
219 let n = response.n_rows();
220 let m = constraints.n_cols();
221 if constraints.n_rows() != n {
222 return Err(RsomicsError::InvalidInput(format!(
223 "response has {n} samples but constraints have {}",
224 constraints.n_rows()
225 )));
226 }
227 if n < m {
228 return Err(RsomicsError::InvalidInput(format!(
229 "constraints cannot have fewer rows ({n}) than columns ({m})"
230 )));
231 }
232
233 let mut y = response.to_mat();
234 center_columns(&mut y);
235 if scale_y {
236 scale_columns_std(&mut y);
237 }
238 let mut x = constraints.to_mat();
239 center_columns(&mut x);
240
241 let y_hat = project_onto(&x, &y);
244
245 let svd = thin_svd(&y_hat);
246 let rank = svd_rank(y_hat.nrows(), y_hat.ncols(), &svd.s);
247 let u_axes = vt_rows_as_cols(&svd.vt, rank); let f = matmul(&y, &u_axes); let z = matmul(&y_hat, &u_axes); let y_res = &y - &y_hat;
253 let svd_res = thin_svd(&y_res);
254 let rank_res = svd_rank(y_res.nrows(), y_res.ncols(), &svd_res.s);
255 let u_res = vt_rows_as_cols(&svd_res.vt, rank_res); let f_res = matmul(&y_res, &u_res); let mut eigenvalues: Vec<f64> = svd.s[..rank].to_vec();
259 eigenvalues.extend_from_slice(&svd_res.s[..rank_res]);
260 let n_axes = eigenvalues.len();
261 let p = response.n_cols();
262
263 if scaling != 1 && scaling != 2 {
264 return Err(RsomicsError::InvalidInput(
265 "only scaling 1 or 2 is available for RDA".into(),
266 ));
267 }
268 let const_factor = eigenvalues
269 .iter()
270 .map(|&e| e * e)
271 .sum::<f64>()
272 .sqrt()
273 .sqrt();
274 let factor = |a: usize| -> f64 {
276 if scaling == 1 {
277 const_factor
278 } else {
279 eigenvalues[a] / const_factor
280 }
281 };
282
283 let mut species_scores = vec![0.0; p * n_axes];
285 for j in 0..p {
286 for a in 0..n_axes {
287 let v = if a < rank {
288 u_axes[(j, a)]
289 } else {
290 u_res[(j, a - rank)]
291 };
292 species_scores[j * n_axes + a] = v * factor(a);
293 }
294 }
295 let mut sample_scores = vec![0.0; n * n_axes];
297 let mut sample_constraints = vec![0.0; n * n_axes];
299 for i in 0..n {
300 for a in 0..n_axes {
301 let fa = factor(a);
302 let (samp, cons) = if a < rank {
303 (f[(i, a)], z[(i, a)])
304 } else {
305 let r = f_res[(i, a - rank)];
306 (r, r)
307 };
308 sample_scores[i * n_axes + a] = samp / fa;
309 sample_constraints[i * n_axes + a] = cons / fa;
310 }
311 }
312
313 let biplot = corr(&x, &svd.u);
316 let biplot_axes = biplot.ncols();
317 let mut biplot_scores = vec![0.0; m * biplot_axes];
318 for i in 0..m {
319 for a in 0..biplot_axes {
320 biplot_scores[i * biplot_axes + a] = biplot[(i, a)];
321 }
322 }
323
324 let total: f64 = eigenvalues.iter().sum();
325 let proportion_explained = eigenvalues.iter().map(|&e| e / total).collect();
326
327 Ok(Ordination {
328 sample_ids: response.row_ids.clone(),
329 species_ids: response.col_ids.clone(),
330 constraint_ids: constraints.col_ids.clone(),
331 eigvals: eigenvalues,
332 proportion_explained,
333 sample_scores,
334 species_scores,
335 biplot_scores,
336 biplot_axes,
337 sample_constraints,
338 })
339 }
340
341 pub fn write_tsv<W: Write>(&self, mut out: W) -> Result<()> {
347 let k = self.eigvals.len();
348 let mut line = String::new();
349
350 writeln!(out, "# eigenvalues").map_err(RsomicsError::Io)?;
351 write_axis_header(&mut out, k)?;
352 line.push_str("eigval");
353 for &v in &self.eigvals {
354 line.push('\t');
355 push_pyrepr(&mut line, v);
356 }
357 writeln!(out, "{line}").map_err(RsomicsError::Io)?;
358 line.clear();
359 line.push_str("proportion_explained");
360 for &v in &self.proportion_explained {
361 line.push('\t');
362 push_pyrepr(&mut line, v);
363 }
364 writeln!(out, "{line}").map_err(RsomicsError::Io)?;
365
366 write_block(
367 &mut out,
368 "# samples",
369 &self.sample_ids,
370 &self.sample_scores,
371 k,
372 )?;
373 write_block(
374 &mut out,
375 "# species",
376 &self.species_ids,
377 &self.species_scores,
378 k,
379 )?;
380 write_block(
381 &mut out,
382 "# biplot",
383 &self.constraint_ids,
384 &self.biplot_scores,
385 self.biplot_axes,
386 )?;
387 write_block(
388 &mut out,
389 "# site_constraints",
390 &self.sample_ids,
391 &self.sample_constraints,
392 k,
393 )
394 }
395}
396
397fn write_block<W: Write>(
398 out: &mut W,
399 title: &str,
400 ids: &[String],
401 scores: &[f64],
402 k: usize,
403) -> Result<()> {
404 writeln!(out, "{title}").map_err(RsomicsError::Io)?;
405 write_axis_header(out, k)?;
406 let mut line = String::new();
407 for (i, id) in ids.iter().enumerate() {
408 line.clear();
409 line.push_str(id);
410 for a in 0..k {
411 line.push('\t');
412 push_pyrepr(&mut line, scores[i * k + a]);
413 }
414 writeln!(out, "{line}").map_err(RsomicsError::Io)?;
415 }
416 Ok(())
417}
418
419fn write_axis_header<W: Write>(out: &mut W, k: usize) -> Result<()> {
420 let mut header = String::new();
421 for a in 1..=k {
422 header.push('\t');
423 header.push_str("RDA");
424 header.push_str(&a.to_string());
425 }
426 writeln!(out, "{header}").map_err(RsomicsError::Io)
427}
428
429fn project_onto(x: &Mat<f64>, y: &Mat<f64>) -> Mat<f64> {
432 let svd = thin_svd(x);
433 let rank = svd_rank(x.nrows(), x.ncols(), &svd.s);
434 let n = x.nrows();
435 let p = y.ncols();
436 let mut c = vec![0.0; rank * p];
438 for a in 0..rank {
439 for j in 0..p {
440 let mut acc = 0.0;
441 for i in 0..n {
442 acc += svd.u[(i, a)] * y[(i, j)];
443 }
444 c[a * p + j] = acc;
445 }
446 }
447 Mat::from_fn(n, p, |i, j| {
448 let mut acc = 0.0;
449 for a in 0..rank {
450 acc += svd.u[(i, a)] * c[a * p + j];
451 }
452 acc
453 })
454}
455
/// Returns the matrix product `a * b`.
fn matmul(a: &Mat<f64>, b: &Mat<f64>) -> Mat<f64> {
    a * b
}
459
460fn vt_rows_as_cols(vt: &Mat<f64>, k: usize) -> Mat<f64> {
462 Mat::from_fn(vt.ncols(), k, |i, j| vt[(j, i)])
463}
464
465pub fn run<W: Write>(
468 response: &Matrix,
469 constraints: &Matrix,
470 out: W,
471 scaling: u8,
472 scale_y: bool,
473) -> Result<()> {
474 let ord = Ordination::compute(response, constraints, scaling, scale_y)?;
475 ord.write_tsv(out)
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Response fixture: 5 samples by 3 species.
    const RESPONSE: &str = "\tSp1\tSp2\tSp3\n\
        S1\t1\t0\t2\n\
        S2\t0\t3\t1\n\
        S3\t2\t1\t0\n\
        S4\t3\t2\t1\n\
        S5\t1\t4\t2\n";

    /// Constraint fixture: the same 5 samples by 2 variables.
    const CONSTRAINTS: &str = "\tE1\tE2\n\
        S1\t1.0\t0.5\n\
        S2\t0.0\t1.0\n\
        S3\t2.0\t0.2\n\
        S4\t1.5\t0.8\n\
        S5\t0.5\t1.2\n";

    /// Parses a tab-separated fixture, panicking if it is malformed.
    fn table(text: &str) -> Matrix {
        Matrix::parse(text.as_bytes(), '\t').unwrap()
    }

    /// Runs the ordination on the two fixtures with scaling 1, unscaled.
    fn fixture_ordination() -> Ordination {
        Ordination::compute(&table(RESPONSE), &table(CONSTRAINTS), 1, false).unwrap()
    }

    #[test]
    fn parses_matrix() {
        let parsed = table(RESPONSE);
        assert_eq!(parsed.row_ids, ["S1", "S2", "S3", "S4", "S5"]);
        assert_eq!(parsed.col_ids, ["Sp1", "Sp2", "Sp3"]);
        // First value of the fourth row (S4).
        assert_eq!(parsed.data[3 * 3], 3.0);
    }

    #[test]
    fn mismatched_rows_error() {
        let y = table(RESPONSE);
        // Only two rows, against five in the response.
        let x = table("\tE1\nS1\t1\nS2\t2\n");
        assert!(Ordination::compute(&y, &x, 1, false).is_err());
    }

    #[test]
    fn proportion_sums_to_one() {
        let ord = fixture_ordination();
        let sum: f64 = ord.proportion_explained.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
    }

    #[test]
    fn axis_counts() {
        let ord = fixture_ordination();
        let axes = ord.eigvals.len();
        assert!(!ord.eigvals.is_empty());
        assert_eq!(ord.sample_scores.len(), ord.sample_ids.len() * axes);
        assert_eq!(ord.species_scores.len(), ord.species_ids.len() * axes);
        assert_eq!(
            ord.biplot_scores.len(),
            ord.constraint_ids.len() * ord.biplot_axes
        );
    }
}