// raxtax/lineage.rs

use itertools::Itertools;
use logging_timer::time;

use crate::tree::{Node, Tree};
use crate::utils;

7#[derive(Debug, Clone)]
8pub struct EvaluationResult<'a, 'b> {
9    pub query_label: &'b String,
10    pub lineage: &'a String,
11    pub confidence_values: Vec<f64>,
12    pub local_signal: f64,
13    pub global_signal: f64,
14}
15
16impl EvaluationResult<'_, '_> {
17    pub fn get_output_string(&self) -> String {
18        format!(
19            "{}\t{}\t{}\t{:.5}\t{:.5}",
20            self.query_label,
21            self.lineage,
22            self.confidence_values
23                .iter()
24                .map(|v| format!("{1:.0$}", utils::F64_OUTPUT_ACCURACY as usize, v))
25                .join(","),
26            self.local_signal,
27            self.global_signal
28        )
29    }
30
31    pub fn get_tsv_string(&self, sequence: &String) -> String {
32        format!(
33            "{}\t{}\t{:.5}\t{:.5}\t{}",
34            self.query_label,
35            self.lineage
36                .split(',')
37                .map(std::string::ToString::to_string)
38                .interleave(self.confidence_values.iter().map(|v| format!(
39                    "{1:.0$}",
40                    utils::F64_OUTPUT_ACCURACY as usize,
41                    v
42                )))
43                .join("\t"),
44            self.local_signal,
45            self.global_signal,
46            sequence
47        )
48    }
49}
50
51pub struct Lineage<'a, 'b> {
52    query_label: &'b String,
53    tree: &'a Tree,
54    confidence_values: Vec<f64>,
55    confidence_prefix_sum: Vec<f64>,
56    confidence_vectors: Vec<(usize, Vec<f64>, Vec<f64>)>,
57    rounding_factor: f64,
58    binning: bool,
59}
60
61impl<'a, 'b> Lineage<'a, 'b> {
62    pub fn new(
63        query_label: &'b String,
64        tree: &'a Tree,
65        confidence_values: Vec<f64>,
66        binning: bool,
67    ) -> Self {
68        let mut confidence_prefix_sum = vec![0.0];
69        confidence_prefix_sum.extend(confidence_values.iter().scan(0.0, |sum, i| {
70            *sum += i;
71            Some(*sum)
72        }));
73        let rounding_factor = f64::from(10_u32.pow(utils::F64_OUTPUT_ACCURACY));
74        let expected_num_results = rounding_factor as usize / 2;
75        Self {
76            query_label,
77            tree,
78            confidence_values,
79            confidence_prefix_sum,
80            confidence_vectors: Vec::with_capacity(expected_num_results),
81            rounding_factor,
82            binning,
83        }
84    }
85
86    #[time("debug")]
87    pub fn evaluate(mut self) -> (Vec<EvaluationResult<'a, 'b>>, Option<(String, f64)>) {
88        self.eval_recurse(&self.tree.root, &[], &[]);
89        // NOTE: This would be the correct maximum leaf confidence and ideally we would normalize with this.
90        // However, because this is already 0.99 for 100 tips, it is not worth it, as it is
91        // basically 1 for any reasonable reference lineage.
92        // let max_leaf_confidence = ((1.0 - 1.0 / self.tree.num_tips as f64).powi(2) + ((self.tree.num_tips as f64 - 1.0) / (self.tree.num_tips as f64).powi(2))).sqrt();
93        let leaf_confidence = utils::euclidean_norm(
94            self.confidence_values
95                .iter()
96                .map(|&v| (v - 1.0 / self.tree.num_tips as f64)),
97        );
98        let mut best_bin_idx = None;
99        let results = self
100            .confidence_vectors
101            .into_iter()
102            .sorted_by(|a, b| b.1.iter().partial_cmp(a.1.iter()).unwrap())
103            .map(|(idx, conf_values, expected_conf_values)| {
104                let start_index = match expected_conf_values.iter().find_position(|&&x| 1.0 > x) {
105                    Some((i, _)) => i,
106                    None => expected_conf_values.len() - 1,
107                };
108                let lineage_confidence = utils::euclidean_distance_l1(
109                    &conf_values[start_index..],
110                    &expected_conf_values[start_index..],
111                );
112                if self.binning && best_bin_idx.is_none() {
113                    best_bin_idx = self.tree.lineage_idx_to_bin_idx[idx];
114                }
115                EvaluationResult {
116                    query_label: self.query_label,
117                    lineage: &self.tree.lineages[idx],
118                    confidence_values: conf_values,
119                    local_signal: lineage_confidence,
120                    global_signal: leaf_confidence,
121                }
122            })
123            .collect_vec();
124        let bin_result = match best_bin_idx {
125            Some(bin_idx) => Some((
126                self.tree.bins[bin_idx].clone(),
127                self.tree.bin_idx_to_lineage_idxs[bin_idx]
128                    .iter()
129                    .map(|&idx| self.confidence_values[idx])
130                    .sum(),
131            )),
132            None => None,
133        };
134        if !self.binning {
135            assert!(bin_result.is_none())
136        }
137        (results, bin_result)
138    }
139
140    fn get_confidence(&self, node: &Node) -> f64 {
141        self.confidence_prefix_sum[node.confidence_range.1]
142            - self.confidence_prefix_sum[node.confidence_range.0]
143    }
144
145    fn eval_recurse(
146        &mut self,
147        node: &Node,
148        confidence_prefix: &[f64],
149        expected_confidence_prefix: &[f64],
150    ) -> bool {
151        let mut no_child_significant = true;
152        let mut pushed_result = false;
153        for c in &node.children {
154            let child_conf =
155                (self.get_confidence(c) * self.rounding_factor).round() / self.rounding_factor;
156            if child_conf == 0.0 {
157                continue;
158            }
159            no_child_significant = false;
160            let mut conf_prefix = confidence_prefix.to_vec();
161            let mut expected_conf_prefix = expected_confidence_prefix.to_vec();
162            conf_prefix.push(child_conf);
163            expected_conf_prefix.push(
164                (c.confidence_range.1 - c.confidence_range.0) as f64 / self.tree.num_tips as f64,
165            );
166            let child_pushed_result = self.eval_recurse(c, &conf_prefix, &expected_conf_prefix);
167            if !child_pushed_result && self.tree.is_taxon_leaf(c) {
168                let max_idx = if self.binning {
169                    self.confidence_values[c.confidence_range.0..c.confidence_range.1]
170                        .iter()
171                        .position_max_by(|&&a, &b| a.partial_cmp(b).unwrap())
172                        .unwrap()
173                        + c.confidence_range.0
174                } else {
175                    c.confidence_range.0
176                };
177                self.confidence_vectors
178                    .push((max_idx, conf_prefix, expected_conf_prefix));
179                pushed_result = true;
180            }
181            pushed_result |= child_pushed_result;
182        }
183        if no_child_significant && self.tree.is_inner_taxon_node(node) {
184            let mut conf_prefix = confidence_prefix.to_vec();
185            let mut expected_conf_prefix = expected_confidence_prefix.to_vec();
186            let mut current_node = node;
187            while self.tree.is_inner_taxon_node(current_node) {
188                current_node = current_node
189                    .children
190                    .iter()
191                    .max_by(|c, d| {
192                        self.get_confidence(c)
193                            .partial_cmp(&self.get_confidence(d))
194                            .unwrap()
195                    })
196                    .unwrap();
197                conf_prefix.push(1.0 / self.rounding_factor);
198                expected_conf_prefix.push(
199                    (current_node.confidence_range.1 - current_node.confidence_range.0) as f64
200                        / self.tree.num_tips as f64,
201                );
202            }
203            self.confidence_vectors.push((
204                current_node.confidence_range.0,
205                conf_prefix,
206                expected_conf_prefix,
207            ));
208            pushed_result = true;
209        }
210        pushed_result
211    }
212}
213
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use statrs::assert_almost_eq;

    use crate::{
        lineage::{EvaluationResult, Lineage},
        tree::Tree,
    };

    /// Reference lineages of differing depth shared by several tests; the
    /// last two entries are identical on purpose.
    fn variable_length_lineages() -> Vec<String> {
        [
            "Animalia,Chordata,Mammalia,Primates,Hominidae,Homo,Homo_sapiens",
            "Animalia,Chordata,Mammalia,Primates,Hominidae,Pan",
            "Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis",
            "Animalia,Chordata,Mammalia,Carnivora,Doggo",
            "Animalia,Chordata,Mammalia,Mouse",
            "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
            "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
        ]
        .iter()
        .map(|lineage| lineage.to_string())
        .collect()
    }

    /// Reduces evaluation results to the `(lineage, confidence values)` pairs
    /// the assertions compare against.
    fn lineage_confidences<'a>(
        results: Vec<EvaluationResult<'a, '_>>,
    ) -> Vec<(&'a String, Vec<f64>)> {
        results
            .into_iter()
            .map(
                |EvaluationResult {
                     lineage,
                     confidence_values,
                     ..
                 }| (lineage, confidence_values),
            )
            .collect_vec()
    }

    #[test]
    fn test_tree_construction() {
        let lineages = vec![
            String::from("Animalia,Chordata,Mammalia,Primates,Hominidae,Homo"),
            "Animalia,Chordata,Mammalia,Primates,Hominidae,Pan".into(),
            "Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis".into(),
            "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis".into(),
            "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis".into(),
        ];
        let bins = vec![None; 5];
        let sequences = vec![[0b00].repeat(9); 5];
        let tree = Tree::new(lineages.into_iter().zip(bins).collect_vec(), sequences).unwrap();
        let confidence_values = vec![0.1, 0.3, 0.4, 0.004, 0.004];
        tree.print();
        let query_label = String::from("q");
        let lineage = Lineage::new(&query_label, &tree, confidence_values, false);
        let result = lineage.evaluate();
        assert_eq!(
            lineage_confidences(result.0),
            vec![
                (
                    &String::from("Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis"),
                    vec![0.81, 0.81, 0.81, 0.8, 0.7, 0.7,],
                ),
                (
                    &"Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis".into(),
                    vec![0.81, 0.81, 0.81, 0.8, 0.1, 0.1,],
                ),
                (
                    &"Animalia,Chordata,Mammalia,Primates,Hominidae,Pan".into(),
                    vec![0.81, 0.81, 0.81, 0.01, 0.01, 0.01,],
                ),
            ]
        );
    }

    #[test]
    fn test_variable_lineage_length() {
        let lineages = variable_length_lineages();
        let bins = vec![None; 7];
        let sequences = vec![[0b00].repeat(9); 7];
        let tree = Tree::new(lineages.into_iter().zip(bins).collect_vec(), sequences).unwrap();
        let confidence_values = vec![0.05, 0.1, 0.3, 0.4, 0.1, 0.004, 0.004];
        tree.print();
        let query_label = String::from("q");
        let lineage = Lineage::new(&query_label, &tree, confidence_values, false);
        let result = lineage.evaluate();
        assert_eq!(
            lineage_confidences(result.0),
            vec![
                (
                    &String::from("Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis"),
                    vec![0.96, 0.96, 0.96, 0.85, 0.7, 0.7,],
                ),
                (
                    &"Animalia,Chordata,Mammalia,Carnivora,Doggo".into(),
                    vec![0.96, 0.96, 0.96, 0.85, 0.1,],
                ),
                (
                    &"Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis".into(),
                    vec![0.96, 0.96, 0.96, 0.85, 0.05, 0.05,],
                ),
                (
                    &"Animalia,Chordata,Mammalia,Mouse".into(),
                    vec![0.96, 0.96, 0.96, 0.1],
                ),
                (
                    &"Animalia,Chordata,Mammalia,Primates,Hominidae,Pan".into(),
                    vec![0.96, 0.96, 0.96, 0.01, 0.01, 0.01,],
                ),
            ]
        );
    }

    #[test]
    fn test_likelihood_edge_case() {
        let lineages = vec![
            (
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis".into(),
                None,
            ),
            (
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis_ferrocius".into(),
                None,
            ),
            (
                "Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis".into(),
                None,
            ),
        ];
        let sequences = vec![[0b00].repeat(9); 3];
        let tree = Tree::new(lineages, sequences).unwrap();
        // Every single value rounds to zero at the output accuracy.
        let confidence_values = vec![0.004, 0.004, 0.004];
        tree.print();
        let query_label = String::from("q");
        let lineage = Lineage::new(&query_label, &tree, confidence_values, false);
        let result = lineage.evaluate();
        assert_eq!(
            lineage_confidences(result.0),
            vec![(
                &String::from("Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis_ferrocius"),
                vec![0.01, 0.01, 0.01, 0.01, 0.01, 0.01,],
            ),]
        );
    }

    #[test]
    fn test_taxonomic_binning() {
        let lineages = variable_length_lineages();
        let bins = vec![
            Some(String::from("BIN1")),
            Some("BIN1".into()),
            Some("BIN2".into()),
            Some("BIN2".into()),
            Some("BIN3".into()),
            None,
            None,
        ];
        let sequences = vec![[0b00].repeat(9); 7];
        let tree = Tree::new(lineages.into_iter().zip(bins).collect_vec(), sequences).unwrap();
        let confidence_values = vec![0.05, 0.1, 0.3, 0.4, 0.1, 0.004, 0.004];
        tree.print();
        let query_label = String::from("q");
        let lineage = Lineage::new(&query_label, &tree, confidence_values, true);
        let result = lineage.evaluate().1.unwrap();
        assert_eq!(result.0, String::from("BIN2"));
        assert_almost_eq!(result.1, 0.15, 1e-7);
    }

    #[test]
    fn test_taxonomic_binning_weird_bins() {
        let lineages = variable_length_lineages();
        let bins = vec![
            Some(String::from("BIN1")),
            Some("BIN2".into()),
            None,
            Some("BIN3".into()),
            Some("BIN3".into()),
            Some("BIN4".into()),
            Some("BIN3".into()),
        ];
        let sequences = vec![[0b00].repeat(9); 7];
        let tree = Tree::new(lineages.into_iter().zip(bins).collect_vec(), sequences).unwrap();
        let confidence_values = vec![0.05, 0.1, 0.3, 0.4, 0.1, 0.004, 0.004];
        tree.print();
        let query_label = String::from("q");
        let lineage = Lineage::new(&query_label, &tree, confidence_values, true);
        let result = lineage.evaluate().1.unwrap();
        assert_eq!(result.0, String::from("BIN4"));
        assert_almost_eq!(result.1, 0.4, 1e-7);
    }
}
460}