use std::collections::HashMap;
use std::ops::Add;
use self::ordered_float::NotNan;
use ordered_float;
use crate::agg;
use crate::taxon::{TaxonId, TaxonList};
use crate::tree::Tree;
/// Aggregator that walks down a weighted taxon tree from the root, following
/// the heaviest child for as long as that child holds a large enough share of
/// its parent's weight.
pub struct MixCalculator {
    /// Minimum `child weight / parent weight` ratio required to keep
    /// descending; the descent stops as soon as the ratio falls below this.
    factor: f32,
    /// Taxon handed to `Tree::new` as the root of the tree being built.
    root: TaxonId,
    /// Ancestry table obtained from `TaxonList::ancestry` and handed to
    /// `Tree::new` (presumably the parent of each taxon, `None` where there
    /// is none; confirm against `TaxonList`).
    parents: Vec<Option<TaxonId>>,
}
impl MixCalculator {
pub fn new(root: TaxonId, taxonomy: &TaxonList, factor: f32) -> Self {
MixCalculator {
factor,
root,
parents: taxonomy.ancestry(),
}
}
}
// Marker implementation with no methods of its own: opts `MixCalculator` into
// `agg::MultiThreadSafeAggregator`.
// NOTE(review): judging by the trait name this presumably allows the
// aggregator to be shared across threads; confirm against the trait
// definition in `agg`.
impl agg::MultiThreadSafeAggregator for MixCalculator {}
impl agg::Aggregator for MixCalculator {
    /// Picks a single taxon for a set of weighted taxa.
    ///
    /// A tree is built from `taxons` below `self.root`, then collapsed and
    /// aggregated with addition (presumably so each node's value is the sum
    /// over its subtree; confirm against `Tree`). Starting at the top of that
    /// tree, the walk repeatedly moves to the child with the largest value,
    /// and stops when the current node has no children or when
    /// `child.value / node.value` is below `self.factor`. The taxon of the
    /// node where the walk stops is returned.
    ///
    /// # Errors
    ///
    /// Fails with `agg::ErrorKind::EmptyInput` when `taxons` is empty, and
    /// propagates any error from `Tree::new`.
    ///
    /// # Panics
    ///
    /// Panics if a child value is NaN, because of the `NotNan` conversion
    /// used for ordering.
    fn aggregate(&self, taxons: &HashMap<TaxonId, f32>) -> agg::Result<TaxonId> {
        if taxons.is_empty() {
            bail!(agg::ErrorKind::EmptyInput);
        }

        let tree = Tree::new(self.root, &self.parents, taxons)?
            .collapse(&Add::add)
            .aggregate(&Add::add);

        let mut current = &tree;
        loop {
            // `f32` is not `Ord`, so the values are wrapped in `NotNan` to compare them.
            let heaviest = current
                .children
                .iter()
                .max_by_key(|child| NotNan::new(child.value).unwrap());

            match heaviest {
                // No children left: the walk ends here.
                None => break,
                Some(child) => {
                    // Stop once the heaviest child no longer carries enough
                    // of the current node's weight.
                    if child.value / current.value < self.factor {
                        break;
                    }
                    current = child;
                }
            }
        }

        Ok(current.root)
    }
}
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::MixCalculator;
    use crate::agg::Aggregator;
    use crate::fixtures;

    // Factor 0.0: the ratio check in `aggregate` only stops the descent when
    // the ratio is below zero, so the walk is expected to end at a leaf.
    #[test]
    fn test_full_rtl() {
        let calc = MixCalculator::new(fixtures::ROOT, &fixtures::by_id(), 0.0);
        assert_matches!(calc.counting_aggregate(&vec![12884, 185751]), Ok(185751));
        assert_matches!(calc.counting_aggregate(&vec![12884, 185751, 185752, 185752]), Ok(185752));

        // Either of the two taxa is accepted as the outcome for this input.
        let picked = calc.counting_aggregate(&vec![1, 1, 10239, 10239, 12884, 185751, 185752]).unwrap();
        assert!(vec![185751, 185752].contains(&picked));
    }

    // Factor 1.0: the descent only continues while the heaviest child's value
    // is not smaller than the value of the node above it.
    #[test]
    fn test_full_lca() {
        let calc = MixCalculator::new(fixtures::ROOT, &fixtures::by_id(), 1.0);
        assert_matches!(calc.counting_aggregate(&vec![12884, 185751]), Ok(185751));
        assert_matches!(calc.counting_aggregate(&vec![12884, 185751, 185752, 185752]), Ok(12884));
        assert_matches!(calc.counting_aggregate(&vec![1, 1, 10239, 10239, 10239, 12884, 185751, 185752]), Ok(1));
    }

    // Factor 0.66: an intermediate threshold between the two extremes above.
    #[test]
    fn test_two_thirds() {
        let calc = MixCalculator::new(fixtures::ROOT, &fixtures::by_id(), 0.66);
        assert_matches!(calc.counting_aggregate(&vec![12884, 185751]), Ok(185751));
        // NOTE(review): identical to the assertion above; possibly a copy-paste leftover.
        assert_matches!(calc.counting_aggregate(&vec![12884, 185751]), Ok(185751));
        assert_matches!(calc.counting_aggregate(&vec![1, 12884, 12884, 185751]), Ok(185751));
        assert_matches!(calc.counting_aggregate(&vec![1, 12884, 10239, 185751, 185751, 185752]), Ok(12884));
    }
}