use std::marker::PhantomData;
use itertools::Itertools;
use crate::pipeline::Node;
/// A pipeline node that splits each input batch across its child nodes in proportion to how
/// much data each child reports remaining, concatenating the children's outputs.
pub struct BalancedSelector<I, O> {
    /// Child nodes, in the order they were added.
    nodes: Vec<Box<dyn Node<Vec<I>, Output = Vec<O>> + Send>>,
    /// Ties the `I`/`O` type parameters to the struct without storing values of those types.
    _phantom: PhantomData<(I, O)>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would demand `I: Default` and
// `O: Default`, although an empty selector never holds an `I` or an `O`.
impl<I, O> Default for BalancedSelector<I, O> {
    fn default() -> Self {
        Self {
            nodes: Vec::new(),
            _phantom: PhantomData,
        }
    }
}
impl<I, O> BalancedSelector<I, O> {
    /// Appends `node` as the next child and returns the selector (builder style).
    ///
    /// Children are kept in insertion order, which is the order in which they receive their
    /// share of each input batch.
    #[must_use = "add_node returns the updated selector; dropping it discards the node"]
    pub fn add_node<N: Node<Vec<I>, Output = Vec<O>> + 'static + Send>(mut self, node: N) -> Self {
        self.nodes.push(Box::new(node));
        self
    }
}
impl<I, O> Node<Vec<I>> for BalancedSelector<I, O> {
    type Output = Vec<O>;

    /// Splits `input` into contiguous chunks, one per child, sized in proportion to each child's
    /// `data_remaining(usize::MAX)`, and concatenates the children's outputs in child order.
    ///
    /// Every input item goes to exactly one child: chunk boundaries are taken from the cumulative
    /// weight measured against the *original* input length, and the last child absorbs any
    /// rounding remainder. Children whose chunk is empty are not called. If no child reports any
    /// remaining data, the input is split evenly instead of being discarded. Returns an empty
    /// vector when `input` is empty or there are no children.
    fn process(&mut self, input: Vec<I>) -> Self::Output {
        let len = input.len();
        if len == 0 || self.nodes.is_empty() {
            return Vec::new();
        }

        // Accumulate weights as f64: a child that reports `usize::MAX` (e.g. one echoing
        // `before`) would overflow a `usize` sum as soon as there are two of them.
        let mut weights: Vec<f64> = self
            .nodes
            .iter()
            .map(|node| node.data_remaining(usize::MAX) as f64)
            .collect();
        let mut total_weight: f64 = weights.iter().sum();
        if total_weight <= 0.0 {
            // No child reports remaining data; split evenly rather than silently dropping input.
            weights.iter_mut().for_each(|w| *w = 1.0);
            total_weight = weights.len() as f64;
        }

        let last = self.nodes.len() - 1;
        let mut items = input.into_iter();
        let mut output = Vec::new();
        let mut cumulative_weight = 0.0;
        let mut assigned = 0;
        for (index, (node, weight)) in self.nodes.iter_mut().zip(weights).enumerate() {
            cumulative_weight += weight;
            // Boundary from the cumulative share of the original length, so truncation never
            // compounds from one child to the next; the last child takes whatever is left.
            let end = if index == last {
                len
            } else {
                ((cumulative_weight / total_weight * len as f64) as usize).min(len)
            };
            let count = end.saturating_sub(assigned);
            if count == 0 {
                continue;
            }
            assigned += count;
            output.extend(node.process(items.by_ref().take(count).collect()));
        }
        output
    }

    /// Sum of every child's `data_remaining(before)`, saturating at `usize::MAX` instead of
    /// overflowing.
    fn data_remaining(&self, before: usize) -> usize {
        self.nodes
            .iter()
            .fold(0usize, |total, node| total.saturating_add(node.data_remaining(before)))
    }

    /// Resets every child node.
    fn reset(&mut self) {
        for node in &mut self.nodes {
            node.reset();
        }
    }
}