Skip to main content

rsomics_phydiv/
lib.rs

1use std::collections::HashMap;
2use std::io::{BufRead, Write};
3
4use rsomics_common::{Result, RsomicsError};
5use rsomics_phylo_tree::{NodeId, Tree};
6
/// A feature × sample table of integer counts, as produced by [`CountTable::parse`].
///
/// `feature_ids[r]` labels row `r`, `sample_names[c]` labels column `c`, and
/// `columns[c][r]` is the count of feature `r` in sample `c`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CountTable {
    /// Feature (row) identifiers, in input order.
    pub feature_ids: Vec<String>,
    /// Sample (column) names taken from the header, in input order.
    pub sample_names: Vec<String>,
    /// Column-major: one count vector per sample, indexed by feature row.
    pub columns: Vec<Vec<u64>>,
}
13
14impl CountTable {
15    /// # Errors
16    /// Errors on a missing header, a ragged row, or a non-integer count.
17    pub fn parse<R: BufRead>(reader: R, delim: char) -> Result<CountTable> {
18        let mut lines = reader.lines();
19        let header = loop {
20            match lines.next() {
21                Some(line) => {
22                    let line = line.map_err(RsomicsError::Io)?;
23                    if line.trim().is_empty() || line.starts_with('#') {
24                        continue;
25                    }
26                    break line;
27                }
28                None => return Err(RsomicsError::InvalidInput("empty count table".into())),
29            }
30        };
31        let sample_names: Vec<String> = header
32            .split(delim)
33            .skip(1)
34            .map(|s| s.trim().to_string())
35            .collect();
36        if sample_names.is_empty() {
37            return Err(RsomicsError::InvalidInput(
38                "header has no sample columns (need feature-ID column + ≥1 sample)".into(),
39            ));
40        }
41        let n = sample_names.len();
42        let mut feature_ids = Vec::new();
43        let mut columns: Vec<Vec<u64>> = vec![Vec::new(); n];
44        for (row_idx, line) in lines.enumerate() {
45            let line = line.map_err(RsomicsError::Io)?;
46            if line.trim().is_empty() || line.starts_with('#') {
47                continue;
48            }
49            let mut fields = line.split(delim);
50            let feature = fields.next().unwrap_or("").trim().to_string();
51            let mut seen = 0usize;
52            for (col, field) in fields.enumerate() {
53                if col >= n {
54                    return Err(RsomicsError::InvalidInput(format!(
55                        "row {} (feature '{feature}') has more columns than the header",
56                        row_idx + 2
57                    )));
58                }
59                let count: u64 = field.trim().parse().map_err(|_| {
60                    RsomicsError::InvalidInput(format!(
61                        "row {} (feature '{feature}'), sample '{}': '{}' is not a non-negative integer count",
62                        row_idx + 2,
63                        sample_names[col],
64                        field.trim()
65                    ))
66                })?;
67                columns[col].push(count);
68                seen += 1;
69            }
70            if seen != n {
71                return Err(RsomicsError::InvalidInput(format!(
72                    "row {} (feature '{feature}') has {seen} count columns, header has {n}",
73                    row_idx + 2
74                )));
75            }
76            feature_ids.push(feature);
77        }
78        Ok(CountTable {
79            feature_ids,
80            sample_names,
81            columns,
82        })
83    }
84}
85
/// How the tree is treated when summing branches: a rooted analysis counts the root
/// branch and the path above the community LCA, an unrooted one does not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Rooted {
    /// Decide from the tree, following scikit-bio: rooted iff the root is bifurcating.
    Auto,
    /// Always treat the tree as rooted.
    Rooted,
    /// Always treat the tree as unrooted.
    Unrooted,
}
94
/// How strongly abundance weights each branch's contribution — the BWPD family of
/// McCoy & Matsen 2013.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Weight {
    /// θ = 0: abundance is ignored, only presence matters.
    Unweighted,
    /// θ = 1: branches are weighted by the full abundance fraction.
    Full,
    /// Partial weighting with exponent θ ∈ (0, 1).
    Theta(f64),
}
103
104impl Weight {
105    /// Parse `--weight`: `0` → unweighted, `1` → full, a float in (0,1) → θ.
106    ///
107    /// # Errors
108    /// Errors when the value is not a number in [0, 1].
109    pub fn parse(s: &str) -> Result<Weight> {
110        let v: f64 = s
111            .trim()
112            .parse()
113            .map_err(|_| RsomicsError::InvalidInput(format!("--weight '{s}' is not a number")))?;
114        if !(0.0..=1.0).contains(&v) {
115            return Err(RsomicsError::InvalidInput(
116                "--weight must be within [0, 1]".into(),
117            ));
118        }
119        Ok(if v == 0.0 {
120            Weight::Unweighted
121        } else if v == 1.0 {
122            Weight::Full
123        } else {
124            Weight::Theta(v)
125        })
126    }
127}
128
129pub struct Config {
130    pub delim: char,
131    pub rooted: Rooted,
132    pub weight: Weight,
133    pub precision: usize,
134}
135
136/// Tree resolved for phydiv: branch length per node (missing → 0.0, matching
137/// scikit-bio's `nan_length_value=0.0`), a tip-name → node map, the node ids in
138/// postorder, and whether the root is bifurcating.
139struct PhyTree {
140    branch_length: Vec<f64>,
141    children: Vec<Vec<NodeId>>,
142    tip_index: HashMap<String, NodeId>,
143    postorder: Vec<NodeId>,
144    n_nodes: usize,
145    root_bifurcating: bool,
146}
147
148impl PhyTree {
149    fn build(tree: &Tree) -> Result<PhyTree> {
150        let n_nodes = tree.nodes.len();
151        let mut branch_length = vec![0.0f64; n_nodes];
152        let mut children = vec![Vec::new(); n_nodes];
153        let mut tip_index = HashMap::new();
154        for node in &tree.nodes {
155            children[node.id] = node.children.clone();
156            if let Some(bl) = node.branch_length {
157                branch_length[node.id] = bl;
158            }
159            if node.children.is_empty() {
160                let name = node
161                    .name
162                    .as_deref()
163                    .ok_or_else(|| RsomicsError::InvalidInput("a tip has no name".into()))?;
164                if tip_index.insert(name.to_string(), node.id).is_some() {
165                    return Err(RsomicsError::InvalidInput(format!(
166                        "duplicate tip name '{name}' in the tree"
167                    )));
168                }
169            }
170        }
171
172        let mut postorder = Vec::with_capacity(n_nodes);
173        let mut stack = vec![(tree.root, false)];
174        while let Some((id, visited)) = stack.pop() {
175            if visited {
176                postorder.push(id);
177            } else {
178                stack.push((id, true));
179                for &c in &children[id] {
180                    stack.push((c, false));
181                }
182            }
183        }
184
185        Ok(PhyTree {
186            branch_length,
187            children,
188            tip_index,
189            postorder,
190            n_nodes,
191            root_bifurcating: tree.nodes[tree.root].children.len() == 2,
192        })
193    }
194
195    fn is_rooted(&self, rooted: Rooted) -> bool {
196        match rooted {
197            Rooted::Rooted => true,
198            Rooted::Unrooted => false,
199            Rooted::Auto => self.root_bifurcating,
200        }
201    }
202
203    /// Fill `cbn[id]` with the total abundance of taxa descending from each node,
204    /// accumulated up the tree in postorder.
205    fn accumulate(&self, tip_counts: &[(NodeId, u64)], cbn: &mut [f64]) {
206        cbn.iter_mut().for_each(|c| *c = 0.0);
207        for &(tip, c) in tip_counts {
208            cbn[tip] = c as f64;
209        }
210        for &id in &self.postorder {
211            if !self.children[id].is_empty() {
212                cbn[id] = self.children[id].iter().map(|&c| cbn[c]).sum();
213            }
214        }
215    }
216}
217
218/// Generalized phylogenetic diversity from per-node descendant abundances.
219/// McCoy & Matsen 2013 BWPD: unrooted drops branches above the LCA and folds each
220/// branch to its abundance balance `2·min(p, 1−p)`; θ tempers the weighting.
221fn diversity(pt: &PhyTree, cbn: &[f64], rooted: bool, weight: Weight) -> f64 {
222    let total = cbn.iter().copied().fold(0.0f64, f64::max);
223    if total == 0.0 {
224        return 0.0;
225    }
226    let mut sum = 0.0;
227    for (id, &c) in cbn.iter().enumerate() {
228        let factor = match weight {
229            Weight::Unweighted => {
230                if c > 0.0 && (rooted || c < total) {
231                    1.0
232                } else {
233                    0.0
234                }
235            }
236            _ => {
237                let mut frac = c / total;
238                if !rooted {
239                    frac = 2.0 * frac.min(1.0 - frac);
240                }
241                match weight {
242                    Weight::Theta(theta) => frac.powf(theta),
243                    _ => frac,
244                }
245            }
246        };
247        sum += pt.branch_length[id] * factor;
248    }
249    sum
250}
251
252pub fn run<R: BufRead, W: Write>(reader: R, out: &mut W, tree: &Tree, cfg: &Config) -> Result<()> {
253    let table = CountTable::parse(reader, cfg.delim)?;
254    let pt = PhyTree::build(tree)?;
255    let rooted = pt.is_rooted(cfg.rooted);
256
257    let row_tip: Vec<NodeId> = table
258        .feature_ids
259        .iter()
260        .map(|taxon| {
261            pt.tip_index.get(taxon).copied().ok_or_else(|| {
262                RsomicsError::InvalidInput(format!(
263                    "taxon '{taxon}' from the count table is not a tip in the tree"
264                ))
265            })
266        })
267        .collect::<Result<_>>()?;
268
269    writeln!(out, "sample\tphydiv").map_err(RsomicsError::Io)?;
270    let mut tip_counts = Vec::new();
271    let mut cbn = vec![0.0f64; pt.n_nodes];
272    for (col, sample) in table.sample_names.iter().enumerate() {
273        tip_counts.clear();
274        for (row, &c) in table.columns[col].iter().enumerate() {
275            if c > 0 {
276                tip_counts.push((row_tip[row], c));
277            }
278        }
279        pt.accumulate(&tip_counts, &mut cbn);
280        let value = diversity(&pt, &cbn, rooted, cfg.weight);
281        writeln!(out, "{sample}\t{value:.*}", cfg.precision).map_err(RsomicsError::Io)?;
282    }
283    Ok(())
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// One sample `u` over the tips of [`doc_tree`]: a=1, b=0, d=3, e=2.
    const T1: &str = "feature\tu\na\t1\nb\t0\nd\t3\ne\t2\n";

    /// Absolute tolerance for comparisons against reference values.
    const TOL: f64 = 1e-12;

    /// Four-tip, bifurcating-root example tree shared by most tests.
    fn doc_tree() -> Tree {
        Tree::from_newick("((a:1,b:2)c:0.5,(d:1,e:1)f:1)root;").unwrap()
    }

    /// Run the whole pipeline on a tab-delimited table and return the value reported
    /// for the first sample.
    fn phy(tree: &Tree, table: &str, rooted: Rooted, weight: Weight) -> f64 {
        let cfg = Config {
            delim: '\t',
            rooted,
            weight,
            precision: 12,
        };
        let mut buf = Vec::new();
        run(std::io::Cursor::new(table), &mut buf, tree, &cfg).unwrap();
        let text = String::from_utf8(buf).unwrap();
        // Line 0 is the header; line 1 is the first sample.
        let first_row = text.lines().nth(1).unwrap();
        let (_, value) = first_row.split_once('\t').unwrap();
        value.parse().unwrap()
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn unweighted_matches_faith() {
        let v = phy(&doc_tree(), T1, Rooted::Rooted, Weight::Unweighted);
        assert_close(v, 4.5);
    }

    #[test]
    fn rooted_full_weight() {
        let v = phy(&doc_tree(), T1, Rooted::Rooted, Weight::Full);
        assert_close(v, 1.916_666_666_666_666_5);
    }

    #[test]
    fn unrooted_full_weight() {
        let v = phy(&doc_tree(), T1, Rooted::Unrooted, Weight::Full);
        assert_close(v, 2.5);
    }

    #[test]
    fn rooted_theta_quarter() {
        let v = phy(&doc_tree(), T1, Rooted::Rooted, Weight::Theta(0.25));
        assert_close(v, 3.514_589_549_479_082);
    }

    #[test]
    fn unrooted_theta_half() {
        let v = phy(&doc_tree(), T1, Rooted::Unrooted, Weight::Theta(0.5));
        assert_close(v, 3.259_872_253_901_790_4);
    }

    #[test]
    fn auto_bifurcating_is_rooted() {
        let tree = doc_tree();
        let auto = phy(&tree, T1, Rooted::Auto, Weight::Full);
        let rooted = phy(&tree, T1, Rooted::Rooted, Weight::Full);
        assert_eq!(auto, rooted);
    }

    #[test]
    fn auto_trifurcating_is_unrooted() {
        let tree = Tree::from_newick("(a:1,b:2,c:3)root;").unwrap();
        let table = "feature\ts\na\t2\nb\t3\nc\t0\n";
        let auto = phy(&tree, table, Rooted::Auto, Weight::Full);
        let unrooted = phy(&tree, table, Rooted::Unrooted, Weight::Full);
        assert_eq!(auto, unrooted);
    }

    #[test]
    fn rooted_vs_unrooted_differ_on_subset() {
        let tree = Tree::from_newick("(((a:1,b:2)g:3,c:1.5)h:0.7,(d:1,e:1)f:1)root;").unwrap();
        let table = "feature\ts\na\t5\nb\t4\nc\t0\nd\t0\ne\t0\n";
        let rooted = phy(&tree, table, Rooted::Rooted, Weight::Unweighted);
        let unrooted = phy(&tree, table, Rooted::Unrooted, Weight::Unweighted);
        assert_close(rooted, 6.7);
        assert_close(unrooted, 3.0);
    }

    #[test]
    fn empty_sample_is_zero() {
        let table = "feature\tz\na\t0\nb\t0\nd\t0\ne\t0\n";
        let v = phy(&doc_tree(), table, Rooted::Rooted, Weight::Full);
        assert_eq!(v, 0.0);
    }

    #[test]
    fn weight_parses() {
        assert_eq!(Weight::parse("0").unwrap(), Weight::Unweighted);
        assert_eq!(Weight::parse("1").unwrap(), Weight::Full);
        assert_eq!(Weight::parse("0.25").unwrap(), Weight::Theta(0.25));
        for rejected in ["1.5", "x"] {
            assert!(Weight::parse(rejected).is_err());
        }
    }

    #[test]
    fn unknown_taxon_rejected() {
        let cfg = Config {
            delim: '\t',
            rooted: Rooted::Auto,
            weight: Weight::Unweighted,
            precision: 6,
        };
        let table = "feature\tx\na\t1\nzzz\t1\n";
        let mut sink = Vec::new();
        let outcome = run(std::io::Cursor::new(table), &mut sink, &doc_tree(), &cfg);
        assert!(outcome.is_err());
    }
}
399}