raxtax/
lineage.rs

1use crate::tree::{Node, Tree};
2use itertools::Itertools;
3use logging_timer::time;
4
5use crate::utils;
6
/// Outcome of evaluating one reported lineage for a single query.
///
/// The query label and the lineage are borrowed rather than copied, so a
/// result cannot outlive the strings it points to.
#[derive(Debug, Clone, PartialEq)]
pub struct EvaluationResult<'a, 'b> {
    /// Label of the query this result belongs to.
    pub query_label: &'b String,
    /// Lineage string; the TSV formatter splits it on `,`.
    pub lineage: &'a String,
    /// Confidence values reported for this lineage, in the order they are
    /// written out by the formatting methods.
    pub confidence_values: Vec<f64>,
    /// Signal value specific to this lineage.
    pub local_signal: f64,
    /// Signal value attached to the result alongside the local one.
    pub global_signal: f64,
}
15
16impl EvaluationResult<'_, '_> {
17    pub fn get_output_string(&self) -> String {
18        format!(
19            "{}\t{}\t{}\t{:.5}\t{:.5}",
20            self.query_label,
21            self.lineage,
22            self.confidence_values
23                .iter()
24                .map(|v| format!("{1:.0$}", utils::F64_OUTPUT_ACCURACY as usize, v))
25                .join(","),
26            self.local_signal,
27            self.global_signal
28        )
29    }
30
31    pub fn get_tsv_string(&self, sequence: &String) -> String {
32        format!(
33            "{}\t{}\t{:.5}\t{:.5}\t{}",
34            self.query_label,
35            self.lineage
36                .split(',')
37                .map(std::string::ToString::to_string)
38                .interleave(self.confidence_values.iter().map(|v| format!(
39                    "{1:.0$}",
40                    utils::F64_OUTPUT_ACCURACY as usize,
41                    v
42                )))
43                .join("\t"),
44            self.local_signal,
45            self.global_signal,
46            sequence
47        )
48    }
49}
50
51pub struct Lineage<'a, 'b> {
52    query_label: &'b String,
53    tree: &'a Tree,
54    confidence_values: Vec<f64>,
55    confidence_prefix_sum: Vec<f64>,
56    confidence_vectors: Vec<(usize, Vec<f64>, Vec<f64>)>,
57    rounding_factor: f64,
58}
59
60impl<'a, 'b> Lineage<'a, 'b> {
61    pub fn new(query_label: &'b String, tree: &'a Tree, confidence_values: Vec<f64>) -> Self {
62        let mut confidence_prefix_sum = vec![0.0];
63        confidence_prefix_sum.extend(confidence_values.iter().scan(0.0, |sum, i| {
64            *sum += i;
65            Some(*sum)
66        }));
67        let rounding_factor = f64::from(10_u32.pow(utils::F64_OUTPUT_ACCURACY));
68        let expected_num_results = rounding_factor as usize / 2;
69        Self {
70            query_label,
71            tree,
72            confidence_values,
73            confidence_prefix_sum,
74            confidence_vectors: Vec::with_capacity(expected_num_results),
75            rounding_factor,
76        }
77    }
78
79    #[time("debug")]
80    pub fn evaluate(mut self) -> Vec<EvaluationResult<'a, 'b>> {
81        self.eval_recurse(&self.tree.root, &[], &[]);
82        // NOTE: This would be the correct maximum leaf confidence and ideally we would normalize with this.
83        // However, because this is already 0.99 for 100 tips, it is not worth it, as it is
84        // basically 1 for any reasonable reference lineage.
85        // let max_leaf_confidence = ((1.0 - 1.0 / self.tree.num_tips as f64).powi(2) + ((self.tree.num_tips as f64 - 1.0) / (self.tree.num_tips as f64).powi(2))).sqrt();
86        let leaf_confidence = utils::euclidean_norm(
87            self.confidence_values
88                .iter()
89                .map(|&v| (v - 1.0 / self.tree.num_tips as f64)),
90        );
91        self.confidence_vectors
92            .into_iter()
93            .sorted_by(|a, b| b.1.iter().partial_cmp(a.1.iter()).unwrap())
94            .map(|(idx, conf_values, expected_conf_values)| {
95                let start_index = match expected_conf_values.iter().find_position(|&&x| 1.0 > x) {
96                    Some((i, _)) => i,
97                    None => expected_conf_values.len() - 1,
98                };
99                let lineage_confidence = utils::euclidean_distance_l1(
100                    &conf_values[start_index..],
101                    &expected_conf_values[start_index..],
102                );
103                EvaluationResult {
104                    query_label: self.query_label,
105                    lineage: &self.tree.lineages[idx],
106                    confidence_values: conf_values,
107                    local_signal: lineage_confidence,
108                    global_signal: leaf_confidence,
109                }
110            })
111            .collect_vec()
112    }
113
114    fn get_confidence(&self, node: &Node) -> f64 {
115        self.confidence_prefix_sum[node.confidence_range.1]
116            - self.confidence_prefix_sum[node.confidence_range.0]
117    }
118
119    fn eval_recurse(
120        &mut self,
121        node: &Node,
122        confidence_prefix: &[f64],
123        expected_confidence_prefix: &[f64],
124    ) -> bool {
125        let mut no_child_significant = true;
126        let mut pushed_result = false;
127        for c in &node.children {
128            let child_conf =
129                (self.get_confidence(c) * self.rounding_factor).round() / self.rounding_factor;
130            if child_conf == 0.0 {
131                continue;
132            }
133            no_child_significant = false;
134            let mut conf_prefix = confidence_prefix.to_vec();
135            let mut expected_conf_prefix = expected_confidence_prefix.to_vec();
136            conf_prefix.push(child_conf);
137            expected_conf_prefix.push(
138                (c.confidence_range.1 - c.confidence_range.0) as f64 / self.tree.num_tips as f64,
139            );
140            let child_pushed_result = self.eval_recurse(c, &conf_prefix, &expected_conf_prefix);
141            if !child_pushed_result && self.tree.is_taxon_leaf(c) {
142                self.confidence_vectors.push((
143                    c.confidence_range.0,
144                    conf_prefix,
145                    expected_conf_prefix,
146                ));
147                pushed_result = true;
148            }
149            pushed_result |= child_pushed_result;
150        }
151        if no_child_significant && self.tree.is_inner_taxon_node(node) {
152            let mut conf_prefix = confidence_prefix.to_vec();
153            let mut expected_conf_prefix = expected_confidence_prefix.to_vec();
154            let mut current_node = node;
155            while self.tree.is_inner_taxon_node(current_node) {
156                current_node = current_node
157                    .children
158                    .iter()
159                    .max_by(|c, d| {
160                        self.get_confidence(c)
161                            .partial_cmp(&self.get_confidence(d))
162                            .unwrap()
163                    })
164                    .unwrap();
165                conf_prefix.push(1.0 / self.rounding_factor);
166                expected_conf_prefix.push(
167                    (current_node.confidence_range.1 - current_node.confidence_range.0) as f64
168                        / self.tree.num_tips as f64,
169                );
170            }
171            self.confidence_vectors.push((
172                current_node.confidence_range.0,
173                conf_prefix,
174                expected_conf_prefix,
175            ));
176            pushed_result = true;
177        }
178        pushed_result
179    }
180}
181
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::{
        lineage::{EvaluationResult, Lineage},
        tree::Tree,
    };

    /// Builds a tree from `lineages` (one all-zero dummy sequence of length 9
    /// per lineage), evaluates `confidence_values` for the query `"q"` and
    /// returns the reported lineages with their confidence values, in result
    /// order.
    fn evaluate_lineages(
        lineages: &[&str],
        confidence_values: Vec<f64>,
    ) -> Vec<(String, Vec<f64>)> {
        let sequences = vec![[0b00].repeat(9); lineages.len()];
        let lineages = lineages.iter().map(|&l| String::from(l)).collect_vec();
        let tree = Tree::new(lineages, sequences).unwrap();
        tree.print();
        let query_label = String::from("q");
        Lineage::new(&query_label, &tree, confidence_values)
            .evaluate()
            .into_iter()
            .map(
                |EvaluationResult {
                     lineage,
                     confidence_values,
                     ..
                 }| (lineage.clone(), confidence_values),
            )
            .collect_vec()
    }

    /// Lineages of equal depth, one of them duplicated.
    #[test]
    fn test_tree_construction() {
        let result = evaluate_lineages(
            &[
                "Animalia,Chordata,Mammalia,Primates,Hominidae,Homo",
                "Animalia,Chordata,Mammalia,Primates,Hominidae,Pan",
                "Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis",
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
            ],
            vec![0.1, 0.3, 0.4, 0.004, 0.004],
        );
        assert_eq!(
            result,
            vec![
                (
                    String::from("Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis"),
                    vec![0.81, 0.81, 0.81, 0.8, 0.7, 0.7],
                ),
                (
                    String::from("Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis"),
                    vec![0.81, 0.81, 0.81, 0.8, 0.1, 0.1],
                ),
                (
                    String::from("Animalia,Chordata,Mammalia,Primates,Hominidae,Pan"),
                    vec![0.81, 0.81, 0.81, 0.01, 0.01, 0.01],
                ),
            ]
        );
    }

    /// Lineages that end at different depths.
    #[test]
    fn test_variable_lineage_length() {
        let result = evaluate_lineages(
            &[
                "Animalia,Chordata,Mammalia,Primates,Hominidae,Homo,Homo_sapiens",
                "Animalia,Chordata,Mammalia,Primates,Hominidae,Pan",
                "Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis",
                "Animalia,Chordata,Mammalia,Carnivora,Doggo",
                "Animalia,Chordata,Mammalia,Mouse",
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
            ],
            vec![0.05, 0.1, 0.3, 0.4, 0.1, 0.004, 0.004],
        );
        assert_eq!(
            result,
            vec![
                (
                    String::from("Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis"),
                    vec![0.96, 0.96, 0.96, 0.85, 0.7, 0.7],
                ),
                (
                    String::from("Animalia,Chordata,Mammalia,Carnivora,Doggo"),
                    vec![0.96, 0.96, 0.96, 0.85, 0.1],
                ),
                (
                    String::from("Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis"),
                    vec![0.96, 0.96, 0.96, 0.85, 0.05, 0.05],
                ),
                (
                    String::from("Animalia,Chordata,Mammalia,Mouse"),
                    vec![0.96, 0.96, 0.96, 0.1],
                ),
                (
                    String::from("Animalia,Chordata,Mammalia,Primates,Hominidae,Pan"),
                    vec![0.96, 0.96, 0.96, 0.01, 0.01, 0.01],
                ),
            ]
        );
    }

    /// All confidences round to zero: exactly one lineage is still reported,
    /// with the minimal confidence at every level.
    #[test]
    fn test_likelihood_edge_case() {
        let result = evaluate_lineages(
            &[
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis",
                "Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis_ferrocius",
                "Animalia,Chordata,Mammalia,Carnivora,Canidae,Canis",
            ],
            vec![0.004, 0.004, 0.004],
        );
        assert_eq!(
            result,
            vec![(
                String::from("Animalia,Chordata,Mammalia,Carnivora,Felidae,Felis_ferrocius"),
                vec![0.01, 0.01, 0.01, 0.01, 0.01, 0.01],
            )]
        );
    }
}