Skip to main content

rsomics_tree_tipdist/
lib.rs

1use std::collections::HashSet;
2use std::io::{self, Write};
3
4use rsomics_phylo_tree::{NodeId, Tree};
5
6mod pyfloat;
7
8#[derive(Debug, thiserror::Error)]
9pub enum TipDistError {
10    #[error("tree has no named tips")]
11    NoTips,
12    #[error("tree contains duplicate tip name {0:?}")]
13    DuplicateTip(String),
14}
15
/// Symmetric patristic distance matrix with tip labels in skbio's postorder
/// tip order. `data` is row-major `n×n`.
///
/// Invariants: `ids.len() == len()` and `data.len() == len() * len()`. Both
/// fields are public, so code that mutates them must keep them in step.
#[derive(Debug, Clone, PartialEq)]
pub struct TipDistMatrix {
    /// Tip names; `ids[i]` labels both row `i` and column `i`.
    pub ids: Vec<String>,
    /// Row-major distances; entry `(i, j)` lives at `data[i * n + j]`.
    pub data: Vec<f64>,
    /// Number of tips, i.e. the side length of the square matrix.
    n: usize,
}
23
24impl TipDistMatrix {
25    #[must_use]
26    pub fn len(&self) -> usize {
27        self.n
28    }
29
30    #[must_use]
31    pub fn is_empty(&self) -> bool {
32        self.n == 0
33    }
34
35    #[must_use]
36    pub fn get(&self, i: usize, j: usize) -> f64 {
37        self.data[i * self.n + j]
38    }
39
40    /// Write the matrix in skbio `DistanceMatrix.write` (lsmat) TSV form:
41    /// a header row starting with a tab then the ids, then one labelled row per
42    /// tip. Values use CPython `repr` float formatting, byte-identical to skbio.
43    pub fn write_lsmat<W: Write>(&self, w: &mut W) -> io::Result<()> {
44        let mut line = String::new();
45        for id in &self.ids {
46            line.push('\t');
47            line.push_str(id);
48        }
49        line.push('\n');
50        w.write_all(line.as_bytes())?;
51
52        for i in 0..self.n {
53            line.clear();
54            line.push_str(&self.ids[i]);
55            let row = &self.data[i * self.n..(i + 1) * self.n];
56            for v in row {
57                line.push('\t');
58                pyfloat::push_repr(&mut line, *v);
59            }
60            line.push('\n');
61            w.write_all(line.as_bytes())?;
62        }
63        Ok(())
64    }
65}
66
67/// Patristic distance between every pair of named tips. With `use_length =
68/// false`, counts branches instead of summing lengths; missing branch lengths
69/// count as 0.
70///
71/// Mirrors skbio TreeNode.cophenet: one postorder sweep grows each tip's depth
72/// edge by edge, and each tip-pair distance is the sum of their two depths
73/// captured at their LCA. The float addition order matches skbio bit-for-bit.
74pub fn tip_tip_distances(tree: &Tree, use_length: bool) -> Result<TipDistMatrix, TipDistError> {
75    let mut ids = Vec::new();
76    let mut tip_index = vec![usize::MAX; tree.nodes.len()];
77    collect_tips(tree, tree.root, &mut ids, &mut tip_index);
78
79    let n = ids.len();
80    if n == 0 {
81        return Err(TipDistError::NoTips);
82    }
83    let mut seen = HashSet::with_capacity(n);
84    for name in &ids {
85        if !seen.insert(name.as_str()) {
86            return Err(TipDistError::DuplicateTip(name.clone()));
87        }
88    }
89
90    let mut data = vec![0.0_f64; n * n];
91    let mut depths = vec![0.0_f64; n];
92    let mut range = vec![(usize::MAX, usize::MAX); tree.nodes.len()];
93
94    for &id in &postorder(tree) {
95        let node = &tree.nodes[id];
96        if node.children.is_empty() {
97            let t = tip_index[id];
98            range[id] = (t, t + 1);
99            continue;
100        }
101
102        let mut clades: Vec<(usize, usize)> = Vec::with_capacity(node.children.len());
103        for &child in &node.children {
104            let (s, e) = range[child];
105            if s == usize::MAX {
106                continue;
107            }
108            let inc = if use_length {
109                tree.nodes[child].branch_length.unwrap_or(0.0)
110            } else {
111                1.0
112            };
113            for d in &mut depths[s..e] {
114                *d += inc;
115            }
116            clades.push((s, e));
117        }
118
119        for a in 0..clades.len() {
120            let (s1, e1) = clades[a];
121            for &(s2, e2) in &clades[a + 1..] {
122                for i in s1..e1 {
123                    for j in s2..e2 {
124                        let v = depths[i] + depths[j];
125                        data[i * n + j] = v;
126                        data[j * n + i] = v;
127                    }
128                }
129            }
130        }
131
132        if let (Some(&(first, _)), Some(&(_, last))) = (clades.first(), clades.last()) {
133            range[id] = (first, last);
134        }
135    }
136
137    Ok(TipDistMatrix { ids, data, n })
138}
139
140fn collect_tips(tree: &Tree, id: NodeId, ids: &mut Vec<String>, tip_index: &mut [usize]) {
141    let node = &tree.nodes[id];
142    if node.children.is_empty() {
143        if let Some(name) = &node.name {
144            tip_index[id] = ids.len();
145            ids.push(name.clone());
146        }
147        return;
148    }
149    for &child in &node.children {
150        collect_tips(tree, child, ids, tip_index);
151    }
152}
153
154fn postorder(tree: &Tree) -> Vec<NodeId> {
155    let mut order = Vec::with_capacity(tree.nodes.len());
156    let mut stack = vec![(tree.root, false)];
157    while let Some((id, visited)) = stack.pop() {
158        if visited {
159            order.push(id);
160        } else {
161            stack.push((id, true));
162            for &child in tree.nodes[id].children.iter().rev() {
163                stack.push((child, false));
164            }
165        }
166    }
167    order
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Compute the matrix for `tree` and render it as lsmat text.
    fn render(tree: &Tree, use_length: bool) -> String {
        let dm = tip_tip_distances(tree, use_length).unwrap();
        let mut buf = Vec::new();
        dm.write_lsmat(&mut buf).unwrap();
        String::from_utf8(buf).unwrap()
    }

    #[test]
    fn spec_example_matches_oracle() {
        let tree = Tree::from_newick("((a:1,b:2):0.5,(c:3,d:4):0.6);").unwrap();
        let got = render(&tree, true);
        let want = "\ta\tb\tc\td\n\
                    a\t0.0\t3.0\t5.1\t6.1\n\
                    b\t3.0\t0.0\t6.1\t7.1\n\
                    c\t5.1\t6.1\t0.0\t7.0\n\
                    d\t6.1\t7.1\t7.0\t0.0\n";
        assert_eq!(got, want);
    }

    #[test]
    fn doc_example_with_named_internals() {
        let tree = Tree::from_newick("((a:1,b:2)c:3,(d:4,e:5)f:6)root;").unwrap();
        let dm = tip_tip_distances(&tree, true).unwrap();
        assert_eq!(dm.ids, ["a", "b", "d", "e"]);
        assert_eq!(dm.get(0, 1), 3.0);
        assert_eq!(dm.get(0, 2), 14.0);
        assert_eq!(dm.get(0, 3), 15.0);
        assert_eq!(dm.get(2, 3), 9.0);
    }

    #[test]
    fn branch_counts_when_use_length_false() {
        let tree = Tree::from_newick("((a:1,b:2)c:3,(d:4,e:5)f:6)root;").unwrap();
        let dm = tip_tip_distances(&tree, false).unwrap();
        assert_eq!(dm.get(0, 1), 2.0);
        assert_eq!(dm.get(0, 2), 4.0);
        assert_eq!(dm.get(2, 3), 2.0);
    }

    #[test]
    fn missing_lengths_count_as_zero() {
        let tree = Tree::from_newick("((a,b),(c,d));").unwrap();
        let dm = tip_tip_distances(&tree, true).unwrap();
        for v in &dm.data {
            assert_eq!(*v, 0.0);
        }
    }

    #[test]
    fn duplicate_tip_names_rejected() {
        let tree = Tree::from_newick("((a:1,a:2):0.5,(c:3,d:4):0.6);").unwrap();
        assert!(matches!(
            tip_tip_distances(&tree, true),
            Err(TipDistError::DuplicateTip(_))
        ));
    }

    /// The matrix must be square over `ids`, symmetric, and zero on the
    /// diagonal (a tip is at distance 0 from itself).
    #[test]
    fn matrix_is_symmetric_with_zero_diagonal() {
        let tree = Tree::from_newick("((a:1,b:2):0.5,(c:3,d:4):0.6);").unwrap();
        let dm = tip_tip_distances(&tree, true).unwrap();
        assert_eq!(dm.len(), 4);
        assert!(!dm.is_empty());
        assert_eq!(dm.ids.len(), dm.len());
        assert_eq!(dm.data.len(), dm.len() * dm.len());
        for i in 0..dm.len() {
            assert_eq!(dm.get(i, i), 0.0);
            for j in 0..dm.len() {
                assert_eq!(dm.get(i, j), dm.get(j, i));
            }
        }
    }
}