1use std::io::BufRead;
2
3use rsomics_common::{Result, RsomicsError};
4
/// A square matrix of pairwise distances with one label per row/column.
///
/// Values produced by [`DistanceMatrix::read`] satisfy `n == ids.len()` and
/// `data.len() == n * n`. The fields are public, so a hand-built value is not
/// guaranteed to uphold those relations.
#[derive(Debug, Clone, PartialEq)]
pub struct DistanceMatrix {
    /// Labels taken from the header line, in column order.
    pub ids: Vec<String>,
    /// Matrix values in row-major order: the entry for row `i`, column `j`
    /// is stored at `data[i * n + j]`.
    pub data: Vec<f64>,
    /// Number of rows (and columns) of the matrix.
    pub n: usize,
}
11
12impl DistanceMatrix {
13 pub fn read<R: BufRead>(reader: R, source: &str) -> Result<Self> {
14 let mut lines = reader.lines();
15
16 let header = lines
17 .next()
18 .ok_or_else(|| RsomicsError::InvalidInput(format!("{source}: empty matrix")))?
19 .map_err(RsomicsError::Io)?;
20
21 let ids: Vec<String> = header
23 .split('\t')
24 .skip(1)
25 .map(str::trim)
26 .map(str::to_owned)
27 .collect();
28 let n = ids.len();
29 if n < 3 {
30 return Err(RsomicsError::InvalidInput(format!(
31 "{source}: need at least 3 ids, found {n}"
32 )));
33 }
34
35 let mut data = vec![0.0f64; n * n];
36 let mut row = 0;
37 for line in lines {
38 let line = line.map_err(RsomicsError::Io)?;
39 if line.is_empty() {
40 continue;
41 }
42 if row >= n {
43 return Err(RsomicsError::InvalidInput(format!(
44 "{source}: more rows than the {n} ids in the header"
45 )));
46 }
47 let mut fields = line.split('\t');
48 let rid = fields.next().unwrap_or("").trim();
49 if rid != ids[row] {
50 return Err(RsomicsError::InvalidInput(format!(
51 "{source}: row {row} id '{rid}' != header id '{}'",
52 ids[row]
53 )));
54 }
55 let mut col = 0;
56 for f in fields {
57 if col >= n {
58 return Err(RsomicsError::InvalidInput(format!(
59 "{source}: row {row} has more than {n} values"
60 )));
61 }
62 data[row * n + col] = f.trim().parse::<f64>().map_err(|_| {
63 RsomicsError::InvalidInput(format!(
64 "{source}: row {row} col {col}: not a number: '{f}'"
65 ))
66 })?;
67 col += 1;
68 }
69 if col != n {
70 return Err(RsomicsError::InvalidInput(format!(
71 "{source}: row {row} has {col} values, expected {n}"
72 )));
73 }
74 row += 1;
75 }
76 if row != n {
77 return Err(RsomicsError::InvalidInput(format!(
78 "{source}: {row} data rows, expected {n}"
79 )));
80 }
81
82 Ok(DistanceMatrix { ids, data, n })
83 }
84
85 pub fn condensed(&self) -> Vec<f64> {
87 let n = self.n;
88 let mut v = Vec::with_capacity(n * (n - 1) / 2);
89 for i in 0..n {
90 for j in (i + 1)..n {
91 v.push(self.data[i * n + j]);
92 }
93 }
94 v
95 }
96}