rsomics_tree_tipdist/
lib.rs1use std::collections::HashSet;
2use std::io::{self, Write};
3
4use rsomics_phylo_tree::{NodeId, Tree};
5
6mod pyfloat;
7
8#[derive(Debug, thiserror::Error)]
9pub enum TipDistError {
10 #[error("tree has no named tips")]
11 NoTips,
12 #[error("tree contains duplicate tip name {0:?}")]
13 DuplicateTip(String),
14}
15
/// Square matrix of pairwise distances between named tips.
///
/// Values are stored row-major in `data`: the distance between `ids[i]` and
/// `ids[j]` lives at `data[i * n + j]`, where `n` is the side length.
#[derive(Debug, Clone)]
pub struct TipDistMatrix {
    /// Tip names labelling both the rows and the columns, in matrix order.
    pub ids: Vec<String>,
    /// Row-major `n * n` distance values.
    pub data: Vec<f64>,
    /// Side length of the matrix (number of tips).
    n: usize,
}
23
24impl TipDistMatrix {
25 #[must_use]
26 pub fn len(&self) -> usize {
27 self.n
28 }
29
30 #[must_use]
31 pub fn is_empty(&self) -> bool {
32 self.n == 0
33 }
34
35 #[must_use]
36 pub fn get(&self, i: usize, j: usize) -> f64 {
37 self.data[i * self.n + j]
38 }
39
40 pub fn write_lsmat<W: Write>(&self, w: &mut W) -> io::Result<()> {
44 let mut line = String::new();
45 for id in &self.ids {
46 line.push('\t');
47 line.push_str(id);
48 }
49 line.push('\n');
50 w.write_all(line.as_bytes())?;
51
52 for i in 0..self.n {
53 line.clear();
54 line.push_str(&self.ids[i]);
55 let row = &self.data[i * self.n..(i + 1) * self.n];
56 for v in row {
57 line.push('\t');
58 pyfloat::push_repr(&mut line, *v);
59 }
60 line.push('\n');
61 w.write_all(line.as_bytes())?;
62 }
63 Ok(())
64 }
65}
66
67pub fn tip_tip_distances(tree: &Tree, use_length: bool) -> Result<TipDistMatrix, TipDistError> {
75 let mut ids = Vec::new();
76 let mut tip_index = vec![usize::MAX; tree.nodes.len()];
77 collect_tips(tree, tree.root, &mut ids, &mut tip_index);
78
79 let n = ids.len();
80 if n == 0 {
81 return Err(TipDistError::NoTips);
82 }
83 let mut seen = HashSet::with_capacity(n);
84 for name in &ids {
85 if !seen.insert(name.as_str()) {
86 return Err(TipDistError::DuplicateTip(name.clone()));
87 }
88 }
89
90 let mut data = vec![0.0_f64; n * n];
91 let mut depths = vec![0.0_f64; n];
92 let mut range = vec![(usize::MAX, usize::MAX); tree.nodes.len()];
93
94 for &id in &postorder(tree) {
95 let node = &tree.nodes[id];
96 if node.children.is_empty() {
97 let t = tip_index[id];
98 range[id] = (t, t + 1);
99 continue;
100 }
101
102 let mut clades: Vec<(usize, usize)> = Vec::with_capacity(node.children.len());
103 for &child in &node.children {
104 let (s, e) = range[child];
105 if s == usize::MAX {
106 continue;
107 }
108 let inc = if use_length {
109 tree.nodes[child].branch_length.unwrap_or(0.0)
110 } else {
111 1.0
112 };
113 for d in &mut depths[s..e] {
114 *d += inc;
115 }
116 clades.push((s, e));
117 }
118
119 for a in 0..clades.len() {
120 let (s1, e1) = clades[a];
121 for &(s2, e2) in &clades[a + 1..] {
122 for i in s1..e1 {
123 for j in s2..e2 {
124 let v = depths[i] + depths[j];
125 data[i * n + j] = v;
126 data[j * n + i] = v;
127 }
128 }
129 }
130 }
131
132 if let (Some(&(first, _)), Some(&(_, last))) = (clades.first(), clades.last()) {
133 range[id] = (first, last);
134 }
135 }
136
137 Ok(TipDistMatrix { ids, data, n })
138}
139
140fn collect_tips(tree: &Tree, id: NodeId, ids: &mut Vec<String>, tip_index: &mut [usize]) {
141 let node = &tree.nodes[id];
142 if node.children.is_empty() {
143 if let Some(name) = &node.name {
144 tip_index[id] = ids.len();
145 ids.push(name.clone());
146 }
147 return;
148 }
149 for &child in &node.children {
150 collect_tips(tree, child, ids, tip_index);
151 }
152}
153
154fn postorder(tree: &Tree) -> Vec<NodeId> {
155 let mut order = Vec::with_capacity(tree.nodes.len());
156 let mut stack = vec![(tree.root, false)];
157 while let Some((id, visited)) = stack.pop() {
158 if visited {
159 order.push(id);
160 } else {
161 stack.push((id, true));
162 for &child in tree.nodes[id].children.iter().rev() {
163 stack.push((child, false));
164 }
165 }
166 }
167 order
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a Newick string, failing the test on malformed input.
    fn parse(newick: &str) -> Tree {
        Tree::from_newick(newick).unwrap()
    }

    /// Builds the distance matrix for `tree` and serialises it as lsmat text.
    fn lsmat_string(tree: &Tree, use_length: bool) -> String {
        let matrix = tip_tip_distances(tree, use_length).unwrap();
        let mut bytes: Vec<u8> = Vec::new();
        matrix.write_lsmat(&mut bytes).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn spec_example_matches_oracle() {
        let expected = concat!(
            "\ta\tb\tc\td\n",
            "a\t0.0\t3.0\t5.1\t6.1\n",
            "b\t3.0\t0.0\t6.1\t7.1\n",
            "c\t5.1\t6.1\t0.0\t7.0\n",
            "d\t6.1\t7.1\t7.0\t0.0\n",
        );
        let tree = parse("((a:1,b:2):0.5,(c:3,d:4):0.6);");
        assert_eq!(lsmat_string(&tree, true), expected);
    }

    #[test]
    fn doc_example_with_named_internals() {
        let dm = tip_tip_distances(&parse("((a:1,b:2)c:3,(d:4,e:5)f:6)root;"), true).unwrap();
        assert_eq!(dm.ids, ["a", "b", "d", "e"]);
        let cases = [(0, 1, 3.0), (0, 2, 14.0), (0, 3, 15.0), (2, 3, 9.0)];
        for &(i, j, want) in &cases {
            assert_eq!(dm.get(i, j), want, "distance ({}, {})", i, j);
        }
    }

    #[test]
    fn branch_counts_when_use_length_false() {
        let dm = tip_tip_distances(&parse("((a:1,b:2)c:3,(d:4,e:5)f:6)root;"), false).unwrap();
        let cases = [(0, 1, 2.0), (0, 2, 4.0), (2, 3, 2.0)];
        for &(i, j, want) in &cases {
            assert_eq!(dm.get(i, j), want, "edge count ({}, {})", i, j);
        }
    }

    #[test]
    fn missing_lengths_count_as_zero() {
        let dm = tip_tip_distances(&parse("((a,b),(c,d));"), true).unwrap();
        assert!(dm.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn duplicate_tip_names_rejected() {
        let outcome = tip_tip_distances(&parse("((a:1,a:2):0.5,(c:3,d:4):0.6);"), true);
        assert!(matches!(outcome, Err(TipDistError::DuplicateTip(_))));
    }
}