1use std::collections::HashMap;
2use std::io::{BufRead, Write};
3
4use rsomics_common::{Result, RsomicsError};
5use rsomics_phylo_tree::{NodeId, Tree};
6
/// A feature-by-sample table of non-negative integer counts.
///
/// The table is stored column-major: `columns[j][i]` is the count of feature
/// `feature_ids[i]` in sample `sample_names[j]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CountTable {
    /// Feature (row) identifiers, in the order the rows were read.
    pub feature_ids: Vec<String>,
    /// Sample (column) names taken from the header, first column excluded.
    pub sample_names: Vec<String>,
    /// One vector of counts per sample; each holds one entry per feature.
    pub columns: Vec<Vec<u64>>,
}
13
14impl CountTable {
15 pub fn parse<R: BufRead>(reader: R, delim: char) -> Result<CountTable> {
18 let mut lines = reader.lines();
19 let header = loop {
20 match lines.next() {
21 Some(line) => {
22 let line = line.map_err(RsomicsError::Io)?;
23 if line.trim().is_empty() || line.starts_with('#') {
24 continue;
25 }
26 break line;
27 }
28 None => return Err(RsomicsError::InvalidInput("empty count table".into())),
29 }
30 };
31 let sample_names: Vec<String> = header
32 .split(delim)
33 .skip(1)
34 .map(|s| s.trim().to_string())
35 .collect();
36 if sample_names.is_empty() {
37 return Err(RsomicsError::InvalidInput(
38 "header has no sample columns (need feature-ID column + ≥1 sample)".into(),
39 ));
40 }
41 let n = sample_names.len();
42 let mut feature_ids = Vec::new();
43 let mut columns: Vec<Vec<u64>> = vec![Vec::new(); n];
44 for (row_idx, line) in lines.enumerate() {
45 let line = line.map_err(RsomicsError::Io)?;
46 if line.trim().is_empty() || line.starts_with('#') {
47 continue;
48 }
49 let mut fields = line.split(delim);
50 let feature = fields.next().unwrap_or("").trim().to_string();
51 let mut seen = 0usize;
52 for (col, field) in fields.enumerate() {
53 if col >= n {
54 return Err(RsomicsError::InvalidInput(format!(
55 "row {} (feature '{feature}') has more columns than the header",
56 row_idx + 2
57 )));
58 }
59 let count: u64 = field.trim().parse().map_err(|_| {
60 RsomicsError::InvalidInput(format!(
61 "row {} (feature '{feature}'), sample '{}': '{}' is not a non-negative integer count",
62 row_idx + 2,
63 sample_names[col],
64 field.trim()
65 ))
66 })?;
67 columns[col].push(count);
68 seen += 1;
69 }
70 if seen != n {
71 return Err(RsomicsError::InvalidInput(format!(
72 "row {} (feature '{feature}') has {seen} count columns, header has {n}",
73 row_idx + 2
74 )));
75 }
76 feature_ids.push(feature);
77 }
78 Ok(CountTable {
79 feature_ids,
80 sample_names,
81 columns,
82 })
83 }
84}
85
/// How to decide whether the tree is treated as rooted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Rooted {
    /// Treat the tree as rooted only when its root has exactly two children.
    Auto,
    /// Always treat the tree as rooted.
    Rooted,
    /// Always treat the tree as unrooted.
    Unrooted,
}
94
/// How strongly a branch's contribution depends on abundance.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Weight {
    /// Presence/absence: a counted branch contributes its whole length.
    Unweighted,
    /// A branch contributes its length times its abundance fraction.
    Full,
    /// A branch contributes its length times the abundance fraction raised to
    /// this exponent. `Weight::parse` only produces exponents strictly between
    /// 0 and 1.
    Theta(f64),
}
103
104impl Weight {
105 pub fn parse(s: &str) -> Result<Weight> {
110 let v: f64 = s
111 .trim()
112 .parse()
113 .map_err(|_| RsomicsError::InvalidInput(format!("--weight '{s}' is not a number")))?;
114 if !(0.0..=1.0).contains(&v) {
115 return Err(RsomicsError::InvalidInput(
116 "--weight must be within [0, 1]".into(),
117 ));
118 }
119 Ok(if v == 0.0 {
120 Weight::Unweighted
121 } else if v == 1.0 {
122 Weight::Full
123 } else {
124 Weight::Theta(v)
125 })
126 }
127}
128
129pub struct Config {
130 pub delim: char,
131 pub rooted: Rooted,
132 pub weight: Weight,
133 pub precision: usize,
134}
135
136struct PhyTree {
140 branch_length: Vec<f64>,
141 children: Vec<Vec<NodeId>>,
142 tip_index: HashMap<String, NodeId>,
143 postorder: Vec<NodeId>,
144 n_nodes: usize,
145 root_bifurcating: bool,
146}
147
148impl PhyTree {
149 fn build(tree: &Tree) -> Result<PhyTree> {
150 let n_nodes = tree.nodes.len();
151 let mut branch_length = vec![0.0f64; n_nodes];
152 let mut children = vec![Vec::new(); n_nodes];
153 let mut tip_index = HashMap::new();
154 for node in &tree.nodes {
155 children[node.id] = node.children.clone();
156 if let Some(bl) = node.branch_length {
157 branch_length[node.id] = bl;
158 }
159 if node.children.is_empty() {
160 let name = node
161 .name
162 .as_deref()
163 .ok_or_else(|| RsomicsError::InvalidInput("a tip has no name".into()))?;
164 if tip_index.insert(name.to_string(), node.id).is_some() {
165 return Err(RsomicsError::InvalidInput(format!(
166 "duplicate tip name '{name}' in the tree"
167 )));
168 }
169 }
170 }
171
172 let mut postorder = Vec::with_capacity(n_nodes);
173 let mut stack = vec![(tree.root, false)];
174 while let Some((id, visited)) = stack.pop() {
175 if visited {
176 postorder.push(id);
177 } else {
178 stack.push((id, true));
179 for &c in &children[id] {
180 stack.push((c, false));
181 }
182 }
183 }
184
185 Ok(PhyTree {
186 branch_length,
187 children,
188 tip_index,
189 postorder,
190 n_nodes,
191 root_bifurcating: tree.nodes[tree.root].children.len() == 2,
192 })
193 }
194
195 fn is_rooted(&self, rooted: Rooted) -> bool {
196 match rooted {
197 Rooted::Rooted => true,
198 Rooted::Unrooted => false,
199 Rooted::Auto => self.root_bifurcating,
200 }
201 }
202
203 fn accumulate(&self, tip_counts: &[(NodeId, u64)], cbn: &mut [f64]) {
206 cbn.iter_mut().for_each(|c| *c = 0.0);
207 for &(tip, c) in tip_counts {
208 cbn[tip] = c as f64;
209 }
210 for &id in &self.postorder {
211 if !self.children[id].is_empty() {
212 cbn[id] = self.children[id].iter().map(|&c| cbn[c]).sum();
213 }
214 }
215 }
216}
217
218fn diversity(pt: &PhyTree, cbn: &[f64], rooted: bool, weight: Weight) -> f64 {
222 let total = cbn.iter().copied().fold(0.0f64, f64::max);
223 if total == 0.0 {
224 return 0.0;
225 }
226 let mut sum = 0.0;
227 for (id, &c) in cbn.iter().enumerate() {
228 let factor = match weight {
229 Weight::Unweighted => {
230 if c > 0.0 && (rooted || c < total) {
231 1.0
232 } else {
233 0.0
234 }
235 }
236 _ => {
237 let mut frac = c / total;
238 if !rooted {
239 frac = 2.0 * frac.min(1.0 - frac);
240 }
241 match weight {
242 Weight::Theta(theta) => frac.powf(theta),
243 _ => frac,
244 }
245 }
246 };
247 sum += pt.branch_length[id] * factor;
248 }
249 sum
250}
251
252pub fn run<R: BufRead, W: Write>(reader: R, out: &mut W, tree: &Tree, cfg: &Config) -> Result<()> {
253 let table = CountTable::parse(reader, cfg.delim)?;
254 let pt = PhyTree::build(tree)?;
255 let rooted = pt.is_rooted(cfg.rooted);
256
257 let row_tip: Vec<NodeId> = table
258 .feature_ids
259 .iter()
260 .map(|taxon| {
261 pt.tip_index.get(taxon).copied().ok_or_else(|| {
262 RsomicsError::InvalidInput(format!(
263 "taxon '{taxon}' from the count table is not a tip in the tree"
264 ))
265 })
266 })
267 .collect::<Result<_>>()?;
268
269 writeln!(out, "sample\tphydiv").map_err(RsomicsError::Io)?;
270 let mut tip_counts = Vec::new();
271 let mut cbn = vec![0.0f64; pt.n_nodes];
272 for (col, sample) in table.sample_names.iter().enumerate() {
273 tip_counts.clear();
274 for (row, &c) in table.columns[col].iter().enumerate() {
275 if c > 0 {
276 tip_counts.push((row_tip[row], c));
277 }
278 }
279 pt.accumulate(&tip_counts, &mut cbn);
280 let value = diversity(&pt, &cbn, rooted, cfg.weight);
281 writeln!(out, "{sample}\t{value:.*}", cfg.precision).map_err(RsomicsError::Io)?;
282 }
283 Ok(())
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// One sample `u` over the tips of `doc_tree`: a=1, b=0, d=3, e=2.
    const T1: &str = "feature\tu\na\t1\nb\t0\nd\t3\ne\t2\n";

    /// Absolute tolerance for comparing computed values with references.
    const TOL: f64 = 1e-12;

    /// Five-node reference tree used by most tests.
    fn doc_tree() -> Tree {
        Tree::from_newick("((a:1,b:2)c:0.5,(d:1,e:1)f:1)root;").unwrap()
    }

    /// Tab-delimited configuration with the given mode, weight and precision.
    fn config(rooted: Rooted, weight: Weight, precision: usize) -> Config {
        Config {
            delim: '\t',
            rooted,
            weight,
            precision,
        }
    }

    /// Runs the whole pipeline and returns the value printed for the first
    /// sample.
    fn phy(tree: &Tree, table: &str, rooted: Rooted, weight: Weight) -> f64 {
        let cfg = config(rooted, weight, 12);
        let mut out = Vec::new();
        run(std::io::Cursor::new(table), &mut out, tree, &cfg).unwrap();
        let text = String::from_utf8(out).unwrap();
        let first_sample = text.lines().nth(1).unwrap();
        let (_, value) = first_sample.split_once('\t').unwrap();
        value.parse().unwrap()
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn unweighted_matches_faith() {
        assert_close(phy(&doc_tree(), T1, Rooted::Rooted, Weight::Unweighted), 4.5);
    }

    #[test]
    fn rooted_full_weight() {
        assert_close(
            phy(&doc_tree(), T1, Rooted::Rooted, Weight::Full),
            1.916_666_666_666_666_5,
        );
    }

    #[test]
    fn unrooted_full_weight() {
        assert_close(phy(&doc_tree(), T1, Rooted::Unrooted, Weight::Full), 2.5);
    }

    #[test]
    fn rooted_theta_quarter() {
        assert_close(
            phy(&doc_tree(), T1, Rooted::Rooted, Weight::Theta(0.25)),
            3.514_589_549_479_082,
        );
    }

    #[test]
    fn unrooted_theta_half() {
        assert_close(
            phy(&doc_tree(), T1, Rooted::Unrooted, Weight::Theta(0.5)),
            3.259_872_253_901_790_4,
        );
    }

    #[test]
    fn auto_bifurcating_is_rooted() {
        let tree = doc_tree();
        let auto = phy(&tree, T1, Rooted::Auto, Weight::Full);
        let rooted = phy(&tree, T1, Rooted::Rooted, Weight::Full);
        assert_eq!(auto, rooted);
    }

    #[test]
    fn auto_trifurcating_is_unrooted() {
        let tree = Tree::from_newick("(a:1,b:2,c:3)root;").unwrap();
        let table = "feature\ts\na\t2\nb\t3\nc\t0\n";
        let auto = phy(&tree, table, Rooted::Auto, Weight::Full);
        let unrooted = phy(&tree, table, Rooted::Unrooted, Weight::Full);
        assert_eq!(auto, unrooted);
    }

    #[test]
    fn rooted_vs_unrooted_differ_on_subset() {
        let tree = Tree::from_newick("(((a:1,b:2)g:3,c:1.5)h:0.7,(d:1,e:1)f:1)root;").unwrap();
        let table = "feature\ts\na\t5\nb\t4\nc\t0\nd\t0\ne\t0\n";
        assert_close(phy(&tree, table, Rooted::Rooted, Weight::Unweighted), 6.7);
        assert_close(phy(&tree, table, Rooted::Unrooted, Weight::Unweighted), 3.0);
    }

    #[test]
    fn empty_sample_is_zero() {
        let table = "feature\tz\na\t0\nb\t0\nd\t0\ne\t0\n";
        let value = phy(&doc_tree(), table, Rooted::Rooted, Weight::Full);
        assert_eq!(value, 0.0);
    }

    #[test]
    fn weight_parses() {
        let accepted = [
            ("0", Weight::Unweighted),
            ("1", Weight::Full),
            ("0.25", Weight::Theta(0.25)),
        ];
        for (text, expected) in accepted.iter() {
            assert_eq!(Weight::parse(text).unwrap(), *expected);
        }
        for text in ["1.5", "x"].iter() {
            assert!(Weight::parse(text).is_err());
        }
    }

    #[test]
    fn unknown_taxon_rejected() {
        let cfg = config(Rooted::Auto, Weight::Unweighted, 6);
        let table = "feature\tx\na\t1\nzzz\t1\n";
        let mut out = Vec::new();
        let result = run(std::io::Cursor::new(table), &mut out, &doc_tree(), &cfg);
        assert!(result.is_err());
    }
}