use std::rc::Rc;
use std::fmt::Debug;
use linear_map::LinearMap;
use timely::dataflow::*;
use timely::dataflow::operators::*;
use timely::dataflow::channels::pact::Exchange;
use timely_sort::{LSBRadixSorter, Unsigned};
use collection::Lookup;
use iterators::coalesce::Coalesce;
use ::{Collection, Data};
/// Extension methods that consolidate the records of a collection.
///
/// Records are exchanged between workers according to a partitioning key.
/// At each time, each worker emits its buffered records sorted and coalesced.
pub trait ConsolidateExt<D>
where
    D: Data,
{
    /// Consolidates the collection, using each record's hash as the
    /// partitioning key.
    fn consolidate(&self) -> Self;

    /// Consolidates the collection, using `part` to compute the partitioning
    /// key of each record's data.
    fn consolidate_by<U, F>(&self, part: F) -> Self
    where
        U: Unsigned,
        F: Fn(&D) -> U + 'static;
}
impl<G: Scope, D: Ord+Data+Debug> ConsolidateExt<D> for Collection<G, D> {
    /// Consolidates the collection, using each record's `hashed()` value as
    /// the partitioning key.
    fn consolidate(&self) -> Self {
        self.consolidate_by(|x| x.hashed())
    }

    /// Consolidates the collection, exchanging each `(data, weight)` record by
    /// `part(data)`.
    ///
    /// Records for a time are buffered in a radix sorter keyed by `part(data)`.
    /// When the notification for that time fires, the buffered records are
    /// handled one group at a time, where a group is the records sharing a key.
    /// Each group is sorted by data and emitted through `coalesce()`.
    fn consolidate_by<U: Unsigned, F: Fn(&D)->U+'static>(&self, part: F) -> Self {
        // Per-time state, keyed by `index.time()`.
        // Each entry is (sorter, records received since the last compaction, compaction threshold).
        let mut inputs = LinearMap::new();
        // `part` is shared by the exchange pact and the operator body, hence the `Rc`.
        let part1 = Rc::new(part);
        let part2 = part1.clone();

        // Exchange records by the `u64` form of their partitioning key.
        let exch = Exchange::new(move |&(ref x,_)| (*part1)(x).as_u64());
        Collection::new(self.inner.unary_notify(exch, "Consolidate", vec![], move |input, output, notificator| {
            input.for_each(|index, data| {
                // NOTE(review): with this value, `count > thresh` below can never hold.
                // The floor applied after compaction also raises `thresh` back to this value.
                // The compaction branch is therefore effectively disabled as written.
                // Lower this to re-enable bounded buffering.
                let default_threshold = usize::max_value();
                let entry = inputs.entry_or_insert(index.time(), || (LSBRadixSorter::new(), 0, default_threshold));
                let (ref mut sorter, ref mut count, ref mut thresh) = *entry;
                *count += data.len();
                // Move the batch into the sorter, keyed by `part` of the data field.
                sorter.extend(data.drain(..), &|x| (*part2)(&x.0));
                // Compact the buffered records once the number received since the
                // last compaction exceeds the threshold.
                if count > thresh {
                    *count = 0;
                    *thresh = 0;
                    let finished = sorter.finish(&|x| (*part2)(&x.0));
                    for mut block in finished {
                        // Merge each run of adjacent equal records.
                        // Weights are summed into the first record of the run, and the rest are set to 0.
                        // Only *adjacent* equal records are merged. The sorter is keyed by
                        // `part`, not by the data itself, so this is a partial
                        // consolidation. The notification handler below does the final
                        // sort-and-coalesce.
                        let mut finger = 0;
                        for i in 1..block.len() {
                            if block[finger].0 == block[i].0 {
                                block[finger].1 += block[i].1;
                                block[i].1 = 0;
                            }
                            else {
                                finger = i;
                            }
                        }
                        // Drop records whose weight is zero.
                        block.retain(|x| x.1 != 0);
                        *thresh += block.len();
                        sorter.push_batch(block, &|x| (*part2)(&x.0));
                    }
                    // The new threshold is the compacted size, floored at `default_threshold`.
                    if *thresh < default_threshold { *thresh = default_threshold; }
                }
                // Request a notification at this time. The handler below emits the buffered records.
                notificator.notify_at(index);
            });
            notificator.for_each(|index, _count, _notificator| {
                if let Some((mut sorter, _, _)) = inputs.remove_key(&index) {
                    let mut session = output.session(&index);
                    // Records of the key group currently being collected.
                    let mut buffer = vec![];
                    // Key (`part(data).as_u64()`) of the group held in `buffer`.
                    let mut current = 0;
                    let source = sorter.finish(&|x| (*part2)(&x.0));
                    // NOTE(review): assumes `finish` yields records ordered by key, so that
                    // all records sharing a key are contiguous across blocks. Confirm this.
                    for (datum, wgt) in source.into_iter().flat_map(|x| x.into_iter()) {
                        let hash = (*part2)(&datum).as_u64();
                        // The key changed: sort the finished group by data and emit it coalesced.
                        if buffer.len() > 0 && hash != current {
                            buffer.sort_by(|x: &(D,i32),y: &(D,i32)| x.0.cmp(&y.0));
                            session.give_iterator(buffer.drain(..).coalesce());
                        }
                        buffer.push((datum,wgt));
                        current = hash;
                    }
                    // Flush the last group.
                    if buffer.len() > 0 {
                        buffer.sort_by(|x: &(D,i32),y: &(D,i32)| x.0.cmp(&y.0));
                        session.give_iterator(buffer.drain(..).coalesce());
                    }
                }
            });
        }))
    }
}