use std::collections::HashMap;
use self::ordered_float::NotNan;
use ordered_float;
use crate::agg;
use crate::taxon;
use crate::taxon::{TaxonId, TaxonList};
/// Aggregator that scores each input taxon by adding the counts of the
/// taxa found while following `ancestors` upwards, and picks the best score.
pub struct RTLCalculator {
/// Taxon at which every upward walk through `ancestors` is expected to end;
/// `aggregate` reports an error when a walk stops anywhere else.
pub root: TaxonId,
/// Lookup table indexed by taxon id. `Some(id)` is the next taxon to visit
/// on the way up, `None` ends the walk. `new` forces the entry of `root`
/// to `None`.
/// NOTE(review): presumably the direct parent of each taxon — confirm
/// against `TaxonList::ancestry`.
pub ancestors: Vec<Option<TaxonId>>,
}
impl RTLCalculator {
pub fn new(root: TaxonId, taxons: &TaxonList) -> Self {
let mut ancestors = taxons.ancestry();
ancestors[root] = None;
RTLCalculator { root, ancestors }
}
}
/// Marker impl with an empty body: opts `RTLCalculator` into
/// `agg::MultiThreadSafeAggregator` without overriding anything.
/// `aggregate` takes `&self` and only reads `root` and `ancestors`.
/// NOTE(review): the exact guarantees this trait demands are defined in
/// `crate::agg` — confirm there before adding interior mutability.
impl agg::MultiThreadSafeAggregator for RTLCalculator {}
impl agg::Aggregator for RTLCalculator {
    /// Returns the taxon with the highest accumulated score.
    ///
    /// The score of an input taxon is its own count plus the counts of every
    /// taxon in `taxons` that is reached by following `self.ancestors`
    /// upwards from it. Taxa on that path that are absent from `taxons`
    /// contribute `0.0`.
    ///
    /// When several taxa share the highest score, which one is returned
    /// depends on the iteration order of the `HashMap` and is unspecified.
    ///
    /// # Errors
    ///
    /// * `UnknownTaxon(id)` (wrapped in `agg::ErrorKind::Taxon`) when the
    ///   upward walk from an input taxon stops at `id` instead of
    ///   `self.root`. This covers ids that have no entry in the ancestor
    ///   table at all, including ids beyond its end.
    /// * `EmptyInput` when `taxons` contains no entries.
    ///
    /// # Panics
    ///
    /// Panics if an accumulated score is NaN.
    fn aggregate(&self, taxons: &HashMap<TaxonId, f32>) -> agg::Result<TaxonId> {
        // Start from the raw counts and add the ancestors' counts in place.
        let mut rtl_counts = taxons.clone();
        for (taxon, count) in rtl_counts.iter_mut() {
            let mut next = *taxon;
            // Checked lookup: an id outside the table must end the walk and
            // be reported as unknown below, not panic on an out-of-bounds
            // index.
            while let Some(&Some(ancestor)) = self.ancestors.get(next) {
                *count += *taxons.get(&ancestor).unwrap_or(&0.0);
                next = ancestor;
            }
            // Every valid walk ends at the root; anything else means the
            // taxon (or one of its ancestors) is not part of this tree.
            if next != self.root {
                bail!(agg::ErrorKind::Taxon(taxon::ErrorKind::UnknownTaxon(next)));
            }
        }
        rtl_counts
            .iter()
            // f32 is not `Ord`; `NotNan` supplies the total order needed here.
            .max_by_key(|&(_, &count)| NotNan::new(count).expect("accumulated score must not be NaN"))
            .map(|(&taxon, _)| taxon)
            .ok_or_else(|| agg::ErrorKind::EmptyInput.into())
    }
}
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::RTLCalculator;
    use crate::agg::Aggregator;
    use crate::fixtures;

    /// Builds the calculator under test from the shared fixture taxonomy.
    fn calculator() -> RTLCalculator {
        RTLCalculator::new(fixtures::ROOT, &fixtures::by_id())
    }

    /// With all inputs on a single path, the deepest taxon collects every
    /// count above it and therefore wins.
    #[test]
    fn test_all_on_same_path() {
        let rtl = calculator();
        assert_matches!(rtl.counting_aggregate(&vec![1]), Ok(1));
        assert_matches!(rtl.counting_aggregate(&vec![1, 12884]), Ok(12884));
        assert_matches!(rtl.counting_aggregate(&vec![1, 12884, 185751]), Ok(185751));
    }

    /// Many hits on taxon 1 still favour the single more specific taxon,
    /// because its score includes those hits.
    #[test]
    fn favouring_root() {
        let rtl = calculator();
        let picked = rtl.counting_aggregate(&vec![1, 1, 1, 185751, 1, 1]);
        assert_matches!(picked, Ok(185751));
    }

    /// Of two competing taxa, the one with more hits of its own wins.
    #[test]
    fn leaning_close() {
        let rtl = calculator();
        let picked = rtl.counting_aggregate(&vec![1, 1, 185752, 185751, 185751, 1]);
        assert_matches!(picked, Ok(185751));
    }

    /// With equal hits on both taxa the winner is unspecified; either is
    /// accepted.
    #[test]
    fn non_deterministic() {
        let rtl = calculator();
        let winner = rtl.counting_aggregate(&vec![1, 1, 185752, 185751, 1]).unwrap();
        assert!(vec![185751, 185752].contains(&winner));
    }
}