use differential_dataflow::{
collection::{AsCollection, Collection},
input::{Input as _, InputSession},
operators::{
arrange::{
agent::TraceAgent,
arrangement::{ArrangeByKey, Arranged},
},
count::CountTotal,
join::JoinCore,
reduce::ReduceCore,
threshold::ThresholdTotal,
},
trace::implementations::{ord::OrdValBatch, spine_fueled_neu::Spine},
ExchangeData, Hashable,
};
use std::{
collections::BTreeMap, marker::PhantomData, rc::Rc, sync::mpsc::Receiver, time::Duration,
};
use timely::{
communication::allocator::Thread,
dataflow::{
operators::{capture::Event, probe::Handle, Capture, Map},
scopes::Child,
},
worker::Worker,
};
/// The timely scope every [`Flow`] and [`Grouped`] in this module lives in: a child scope of
/// a timely worker using the `Thread` allocator, with plain `usize` timestamps.
pub type Scope<'a> = Child<'a, Worker<Thread>, usize>;
/// Handle for feeding updates into a [`Flow`] created by `Flow::new` or `Flow::new_limited`.
///
/// Wraps a differential `InputSession` (timestamps `usize`, diffs `isize`) together with the
/// optional look-back window recorded when the flow was created.
pub struct Input<T: ExchangeData>(InputSession<usize, T, isize>, Option<Duration>);

impl<T: ExchangeData> Input<T> {
    /// Moves the session's current input time to `time`.
    pub fn advance_to(&mut self, time: usize) {
        let Self(session, _) = self;
        session.advance_to(time);
    }

    /// Forwards to the session's `flush`, pushing buffered updates into the dataflow.
    pub fn flush(&mut self) {
        let Self(session, _) = self;
        session.flush();
    }

    /// Adds one copy of `value` (diff `+1`) at the current input time.
    pub fn insert(&mut self, value: T) {
        let Self(session, _) = self;
        session.insert(value);
    }

    /// Removes one copy of `value` (diff `-1`) at the current input time.
    pub fn remove(&mut self, value: T) {
        let Self(session, _) = self;
        session.remove(value);
    }

    /// The look-back window passed to `Flow::new_limited`, or `None` for `Flow::new`.
    pub fn look_back(&self) -> Option<Duration> {
        let Self(_, look_back) = self;
        *look_back
    }
}
/// Receiving end of a flow captured with `Flow::output`.
pub struct Output<T: ExchangeData>(Receiver<Event<usize, (T, usize, isize)>>);

impl<T: ExchangeData> Output<T> {
    /// Drains every event currently available on the channel without blocking and yields the
    /// captured updates as `(value, diff)` pairs.
    ///
    /// Timestamps are dropped and non-message (progress) events are skipped. Each message
    /// batch is streamed lazily instead of being re-collected into a fresh `Vec` per event.
    pub fn msgs<'a>(&'a mut self) -> impl Iterator<Item = (T, isize)> + 'a {
        self.0
            .try_iter()
            .filter_map(|event| match event {
                Event::Messages(_, msgs) => Some(msgs),
                _ => None,
            })
            .flatten()
            .map(|(msg, _time, mult)| (msg, mult))
    }
}
/// Progress probe attached to a [`Flow`] via `Flow::probe`.
pub struct Probe(Handle<usize>);

impl Probe {
    /// Returns `true` while the probe's frontier still contains a timestamp strictly earlier
    /// than `time`, i.e. while updates before `time` may still be in flight.
    pub fn less_than(&self, time: usize) -> bool {
        let Probe(handle) = self;
        handle.less_than(&time)
    }
}
/// Type-level marker recording whether a [`Flow`] pipeline may carry operator state.
pub trait NeedsState {
    /// `true` for markers describing pipelines that may carry state.
    fn needs_state() -> bool;
}

/// Marker for flows built only from `Fn` closures and operators that keep no state.
pub struct Stateless {}

/// Marker for flows where some operator (an `FnMut` closure, an arrangement, a
/// `monotonic_*` operator, ...) may carry state between records.
pub struct Stateful {}

impl NeedsState for Stateless {
    fn needs_state() -> bool {
        false
    }
}

impl NeedsState for Stateful {
    fn needs_state() -> bool {
        true
    }
}
pub struct Flow<'a, T: ExchangeData, St: NeedsState>(
Collection<Child<'a, Worker<Thread>, usize>, T, isize>,
PhantomData<St>,
);
impl<'a, T: ExchangeData> Flow<'a, T, Stateless> {
pub fn new(scope: &mut Child<'a, Worker<Thread>, usize>) -> (Input<T>, Flow<'a, T, Stateless>) {
let (input, collection) = scope.new_collection();
(Input(input, None), Flow(collection, PhantomData))
}
pub fn new_limited(
scope: &mut Child<'a, Worker<Thread>, usize>,
look_back: Duration,
) -> (Input<T>, Flow<'a, T, Stateless>) {
let (input, collection) = scope.new_collection();
(Input(input, Some(look_back)), Flow(collection, PhantomData))
}
}
impl<'a, T: ExchangeData, St: NeedsState> Flow<'a, T, St> {
pub fn filter(&self, f: impl Fn(&T) -> bool + 'static) -> Self {
Self(self.0.filter(f), PhantomData)
}
pub fn filter_mut(&self, f: impl FnMut(&T) -> bool + 'static) -> Flow<'a, T, Stateful> {
Flow(self.0.filter(f), PhantomData)
}
pub fn map<U: ExchangeData>(&self, f: impl Fn(T) -> U + 'static) -> Flow<'a, U, St> {
Flow(self.0.map(f), PhantomData)
}
pub fn map_mut<U: ExchangeData>(
&self,
f: impl FnMut(T) -> U + 'static,
) -> Flow<'a, U, Stateful> {
Flow(self.0.map(f), PhantomData)
}
pub fn map_in_place(&self, f: impl Fn(&mut T) + 'static) -> Self {
Self(self.0.map_in_place(f), PhantomData)
}
pub fn map_in_place_mut(&self, f: impl FnMut(&mut T) + 'static) -> Flow<'a, T, Stateful> {
Flow(self.0.map_in_place(f), PhantomData)
}
pub fn flat_map<U, I>(&self, f: impl Fn(T) -> I + 'static) -> Flow<'a, U, St>
where
U: ExchangeData,
I: IntoIterator<Item = U>,
{
Flow(self.0.flat_map(f), PhantomData)
}
pub fn flat_map_mut<U, I>(&self, f: impl FnMut(T) -> I + 'static) -> Flow<'a, U, Stateful>
where
U: ExchangeData,
I: IntoIterator<Item = U>,
{
Flow(self.0.flat_map(f), PhantomData)
}
pub fn monotonic_max_by<K: ExchangeData>(
&self,
f: impl Fn(&T) -> K + 'static,
) -> Flow<'a, T, Stateful> {
let mut highest = BTreeMap::new();
Flow(
self.0
.inner
.flat_map(move |(mut data, time, delta)| {
let key = f(&data);
if let Some(max) = highest.get_mut(&key) {
if &data > max {
std::mem::swap(&mut data, max);
vec![(data, time, -1), (max.clone(), time, 1)]
} else {
assert!(
&data != max || delta >= 0,
"cannot remove max element {:?} from monotonic_max_by",
data,
);
vec![]
}
} else {
highest.insert(key, data.clone());
vec![(data, time, 1)]
}
})
.as_collection(),
PhantomData,
)
}
pub fn monotonic_representative_by<K: ExchangeData>(
&self,
f: impl Fn(&T) -> K + 'static,
) -> Flow<'a, T, Stateful> {
let mut repr = BTreeMap::<K, (T, isize)>::new();
Flow(
self.0
.inner
.flat_map(move |(data, time, delta)| {
let key = f(&data);
if let Some(repr) = repr.get_mut(&key) {
let (prev, mult) = repr;
if prev == &data {
*mult += delta;
assert!(
*mult != 0,
"cannot remove representative {:?} from collection",
data
);
vec![]
} else {
vec![]
}
} else {
repr.insert(key, (data.clone(), 1));
vec![(data, time, 1)]
}
})
.as_collection(),
PhantomData,
)
}
pub fn negate(&self) -> Self {
Self(self.0.negate(), PhantomData)
}
pub fn group_by<K: ExchangeData + Hashable>(
&self,
mut f: impl FnMut(&T) -> K + 'static,
) -> Grouped<'a, K, T> {
Grouped(
self.0.map(move |t| (f(&t), t)).arrange_by_key(),
PhantomData,
)
}
pub fn inspect(&self, f: impl Fn(&(T, usize, isize)) + 'static) -> Self {
Self(self.0.inspect(f), PhantomData)
}
pub fn inspect_mut(
&self,
f: impl FnMut(&(T, usize, isize)) + 'static,
) -> Flow<'a, T, Stateful> {
Flow(self.0.inspect(f), PhantomData)
}
pub fn probe(&self) -> Probe {
Probe(self.0.probe())
}
pub fn output(&self) -> Output<T> {
Output(self.0.inner.capture())
}
}
impl<'a, T: ExchangeData> Flow<'a, T, Stateless> {
pub fn concat<St: NeedsState>(&self, other: &Flow<'a, T, St>) -> Flow<'a, T, St> {
Flow(self.0.concat(&other.0), PhantomData)
}
pub fn concat_many<St: NeedsState>(
&self,
others: impl IntoIterator<Item = Flow<'a, T, St>>,
) -> Flow<'a, T, St> {
Flow(
self.0.concatenate(others.into_iter().map(|x| x.0)),
PhantomData,
)
}
}
impl<'a, T: ExchangeData> Flow<'a, T, Stateful> {
pub fn concat<St: NeedsState>(&self, other: &Flow<'a, T, St>) -> Flow<'a, T, Stateful> {
Flow(self.0.concat(&other.0), PhantomData)
}
pub fn concat_many<St: NeedsState>(
&self,
others: impl IntoIterator<Item = Flow<'a, T, St>>,
) -> Flow<'a, T, Stateful> {
Flow(
self.0.concatenate(others.into_iter().map(|x| x.0)),
PhantomData,
)
}
}
impl<'a, T: ExchangeData + Hashable, St: NeedsState> Flow<'a, T, St> {
    /// Deduplicates the flow so that each present record appears with multiplicity one.
    pub fn distinct(&self) -> Flow<'a, T, Stateful> {
        let deduplicated = self.0.distinct_total();
        Flow(deduplicated, PhantomData)
    }

    /// Rewrites each record's accumulated multiplicity `r` to `f(record, r)`.
    pub fn threshold(
        &self,
        mut f: impl FnMut(&T, isize) -> isize + 'static,
    ) -> Flow<'a, T, Stateful> {
        let thresholded = self
            .0
            .threshold_total(move |record, count| f(record, *count));
        Flow(thresholded, PhantomData)
    }

    /// Pairs each record with its accumulated multiplicity.
    pub fn count(&self) -> Flow<'a, (T, isize), Stateful> {
        let counted = self.0.count_total();
        Flow(counted, PhantomData)
    }
}
impl<'a, K: ExchangeData + Hashable, V: ExchangeData, St: NeedsState> Flow<'a, (K, V), St> {
    /// Arranges a flow of `(key, value)` pairs by key, for use with the [`Grouped`] operators.
    pub fn group(&self) -> Grouped<'a, K, V> {
        let arranged = self.0.arrange_by_key();
        Grouped(arranged, PhantomData)
    }
}
#[allow(clippy::type_complexity)]
pub struct Grouped<'a, K, V>(
Arranged<
Child<'a, Worker<Thread>, usize>,
TraceAgent<Spine<K, V, usize, isize, Rc<OrdValBatch<K, V, usize, isize>>>>,
>,
PhantomData<(K, V)>,
)
where
K: ExchangeData + Hashable,
V: ExchangeData;
impl<'a, K, V> Grouped<'a, K, V>
where
    K: ExchangeData + Hashable,
    V: ExchangeData,
{
    /// Joins with `other` on key, emitting every item of `f(key, v, v2)` for each matching
    /// pair of values.
    pub fn join<V2, L, D, I>(&self, other: &Grouped<'a, K, V2>, f: L) -> Flow<'a, D, Stateful>
    where
        V2: ExchangeData,
        D: ExchangeData,
        I: IntoIterator<Item = D>,
        L: FnMut(&K, &V, &V2) -> I + 'static,
    {
        Flow(self.0.join_core(&other.0, f), PhantomData)
    }

    /// Like [`Grouped::join`], but `f` produces exactly one output per matching pair.
    pub fn join_single<V2, L, D>(
        &self,
        other: &Grouped<'a, K, V2>,
        mut f: L,
    ) -> Flow<'a, D, Stateful>
    where
        V2: ExchangeData,
        D: ExchangeData,
        L: FnMut(&K, &V, &V2) -> D + 'static,
    {
        Flow(
            self.0
                .join_core(&other.0, move |k, v, v2| std::iter::once(f(k, v, v2))),
            PhantomData,
        )
    }

    /// Replaces each key's values with the `(value, diff)` pairs `f` pushes into its output
    /// buffer, given the key and its current `(value, multiplicity)` list.
    pub fn reduce<V2, L>(&self, f: L) -> Grouped<'a, K, V2>
    where
        V2: ExchangeData,
        L: FnMut(&K, &[(&V, isize)], &mut Vec<(V2, isize)>) + 'static,
    {
        Grouped(self.0.reduce_abelian("Reduce", f), PhantomData)
    }

    /// Rewrites the multiplicity `m` of every value in a key's group to `f(key, value, m)`.
    pub fn threshold(&self, mut f: impl FnMut(&K, &V, isize) -> isize + 'static) -> Self {
        self.reduce(move |k, i, o| o.extend(i.iter().map(|(v, m)| ((**v).clone(), f(k, v, *m)))))
    }

    /// Keeps every value of every group exactly once.
    pub fn distinct(&self) -> Self {
        self.threshold(|_, _, _| 1)
    }

    /// Replaces each key's values with the sum of their multiplicities.
    pub fn count(&self) -> Grouped<'a, K, isize> {
        self.reduce(|_, i, o| o.push((i.iter().map(|x| x.1).sum(), 1)))
    }

    /// Keeps, per key, the smallest value.
    ///
    /// Relies on `reduce` presenting each group's values in ascending order.
    pub fn min(&self) -> Self {
        self.reduce(|_, i, o| {
            let (smallest, _) = i.first().expect("reduce is only invoked on non-empty groups");
            o.push(((*smallest).clone(), 1))
        })
    }

    /// Keeps, per key, the largest value.
    ///
    /// Relies on `reduce` presenting each group's values in ascending order.
    pub fn max(&self) -> Self {
        self.reduce(|_, i, o| {
            let (largest, _) = i.last().expect("reduce is only invoked on non-empty groups");
            o.push(((*largest).clone(), 1))
        })
    }

    /// Keeps, per key, the value with the largest `f(value)`. On ties the last such value in
    /// the group wins, as with `Iterator::max_by_key`.
    ///
    /// Values are compared by reference and only the winner is cloned. The `Clone` bound on
    /// `F` is kept for API compatibility.
    pub fn max_by<T, F>(&self, f: F) -> Self
    where
        F: Fn(&V) -> T + 'static + Clone,
        T: Ord,
    {
        self.reduce(move |_, i, o| {
            let best = i
                .iter()
                .map(|(v, _)| *v)
                .max_by_key(|v| f(*v))
                .expect("reduce is only invoked on non-empty groups");
            o.push((best.clone(), 1))
        })
    }

    /// Drops the keys, yielding a flow of the grouped values.
    pub fn ungroup(&self) -> Flow<'a, V, Stateful> {
        self.ungroup_with(|_, v| v.clone())
    }

    /// Turns every `(key, value)` pair back into a flow record via `f`.
    pub fn ungroup_with<T: ExchangeData>(
        &self,
        f: impl FnMut(&K, &V) -> T + 'static,
    ) -> Flow<'a, T, Stateful> {
        Flow(self.0.as_collection(f), PhantomData)
    }

    /// Turns the arrangement back into a flow of `(key, value)` pairs.
    pub fn ungroup_both(&self) -> Flow<'a, (K, V), Stateful> {
        self.ungroup_with(|k, v| (k.clone(), v.clone()))
    }

    /// Re-keys every `(key, value)` pair with `f` and arranges the result by the new key.
    pub fn regroup<K2, V2, L>(&self, f: L) -> Grouped<'a, K2, V2>
    where
        K2: ExchangeData + Hashable,
        V2: ExchangeData,
        L: FnMut(&K, &V) -> (K2, V2) + 'static,
    {
        self.ungroup_with(f).group()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::machine::{Inputs, Machine};
    use anyhow::Result;

    /// Lets the test `Machine` drive an `Input<i32>`: a clock tick advances the input's time
    /// and flushes it, and every fed element is inserted once.
    impl Inputs for Input<i32> {
        type Elem = i32;

        fn advance_clock(&mut self, time: usize) {
            self.advance_to(time);
            self.flush();
        }

        fn feed(&mut self, input: &Self::Elem) -> Result<()> {
            let value = *input;
            self.insert(value);
            Ok(())
        }
    }

    #[test]
    fn monotonic_max_by() {
        // Records are keyed by their residue mod 5; only strictly larger records displace a
        // key's current maximum.
        let mut machine = Machine::new(|scope| {
            let (input, flow) = Flow::<i32, _>::new(scope);
            let maxima = flow.monotonic_max_by(|x| *x % 5);
            (input, maxima)
        });
        machine.assert(&[1], &[(1, 1)]);
        machine.assert(&[1], &[]);
        machine.assert(&[11, 2], &[(1, -1), (2, 1), (11, 1)]);
        machine.assert(&[6, 7], &[(2, -1), (7, 1)]);
    }

    #[test]
    fn monotonic_representative_by() {
        // The first record seen for each residue mod 5 becomes the fixed representative.
        let mut machine = Machine::new(|scope| {
            let (input, flow) = Flow::<i32, _>::new(scope);
            let representatives = flow.monotonic_representative_by(|x| *x % 5);
            (input, representatives)
        });
        machine.assert(&[1], &[(1, 1)]);
        machine.assert(&[1], &[]);
        machine.assert(&[11, 2], &[(2, 1)]);
        machine.assert(&[6, 7], &[]);
    }
}