use std::sync::{Arc, Mutex};
/// Opaque handle for a node stored in a `DataflowGraph`.
///
/// The wrapped number is the position at which the node was inserted.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NodeId(usize);
/// A stage in a push/pull dataflow pipeline.
///
/// Both operations take `&self`, so implementors rely on interior mutability;
/// the `Send + Sync` requirement lets a node be shared across threads.
pub trait DataflowNode<T>
where
    Self: Send + Sync,
{
    /// Hands `value` to the node for processing or storage.
    fn push(&self, value: T);

    /// Removes and returns the next value of type `T` the node can emit.
    ///
    /// Returns `None` when nothing is available; nodes that never emit `T`
    /// values (sinks, for example) always return `None`.
    fn pull(&self) -> Option<T>;
}
/// Lock-protected state of a [`Source`]: values waiting to be pulled, oldest first.
struct SourceInner<T> {
    buffer: std::collections::VecDeque<T>,
}

/// Entry node of a pipeline that hands out queued values in FIFO order.
///
/// Values are supplied up front with [`Source::from_iter`] or appended later
/// with [`Source::push_value`] / `DataflowNode::push`; `DataflowNode::pull`
/// removes them from the front. The queue sits behind a mutex, so every method
/// takes `&self`.
pub struct Source<T> {
    inner: Arc<Mutex<SourceInner<T>>>,
}

impl<T: Send + Sync + 'static> Source<T> {
    /// Creates a source pre-loaded with every item of `iter`, in iteration order.
    ///
    /// Accepts anything iterable: a range, a `Vec`, an iterator, and so on.
    // Deliberately named like `FromIterator::from_iter`; the lint asking for a
    // trait impl instead is silenced on purpose.
    #[allow(clippy::should_implement_trait)]
    pub fn from_iter(iter: impl IntoIterator<Item = T>) -> Self {
        Self::with_buffer(iter.into_iter().collect())
    }

    /// Creates a source with no pending values.
    pub fn empty() -> Self {
        Self::with_buffer(std::collections::VecDeque::new())
    }

    /// Wraps an already-built queue in the shared, lockable state.
    fn with_buffer(buffer: std::collections::VecDeque<T>) -> Self {
        Source {
            inner: Arc::new(Mutex::new(SourceInner { buffer })),
        }
    }

    /// Appends `value` to the back of the queue.
    ///
    /// If the mutex is poisoned the value is dropped silently.
    pub fn push_value(&self, value: T) {
        if let Ok(mut g) = self.inner.lock() {
            g.buffer.push_back(value);
        }
    }

    /// Number of values still waiting to be pulled (0 if the mutex is poisoned).
    pub fn len(&self) -> usize {
        self.inner.lock().map(|g| g.buffer.len()).unwrap_or(0)
    }

    /// `true` when no values are waiting to be pulled.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T: Send + Sync + 'static> Default for Source<T> {
    /// Equivalent to [`Source::empty`].
    fn default() -> Self {
        Self::empty()
    }
}

impl<T: Send + Sync + 'static> DataflowNode<T> for Source<T> {
    /// Same as [`Source::push_value`]: appends to the back of the queue.
    fn push(&self, value: T) {
        self.push_value(value);
    }

    /// Removes the oldest queued value; `None` when empty or poisoned.
    fn pull(&self) -> Option<T> {
        self.inner.lock().ok()?.buffer.pop_front()
    }
}
/// Lock-protected state of a [`Sink`].
struct SinkInner<T> {
    /// Values received while no callback is installed, in arrival order.
    collected: Vec<T>,
}

/// Terminal dataflow node.
///
/// It either collects every pushed value, for later [`Sink::drain`] or
/// [`Sink::collect`], or forwards each value to a user callback.
pub struct Sink<T> {
    inner: Arc<Mutex<SinkInner<T>>>,
    // The callback is fixed at construction and never mutated, so it lives
    // outside the mutex. User code therefore never runs while the lock is held.
    // A panicking callback cannot poison the lock (which would silently drop
    // every later value), and a callback that calls back into the sink cannot
    // self-deadlock. Because the callback is `Sync`, concurrent pushes may
    // invoke it concurrently.
    callback: Option<Box<dyn Fn(T) + Send + Sync + 'static>>,
}

impl<T: Send + Sync + 'static> Sink<T> {
    /// Creates a sink that stores every pushed value.
    pub fn new() -> Self {
        Self::with_callback(None)
    }

    /// Creates a sink that calls `f` on every pushed value instead of storing it.
    pub fn for_each(f: impl Fn(T) + Send + Sync + 'static) -> Self {
        let callback: Box<dyn Fn(T) + Send + Sync + 'static> = Box::new(f);
        Self::with_callback(Some(callback))
    }

    /// Shared constructor: empty collection plus the optional callback.
    fn with_callback(callback: Option<Box<dyn Fn(T) + Send + Sync + 'static>>) -> Self {
        Sink {
            inner: Arc::new(Mutex::new(SinkInner {
                collected: Vec::new(),
            })),
            callback,
        }
    }

    /// Removes and returns every collected value, leaving the sink empty.
    ///
    /// Returns an empty `Vec` if the mutex is poisoned.
    pub fn drain(&self) -> Vec<T> {
        self.inner
            .lock()
            .map(|mut g| std::mem::take(&mut g.collected))
            .unwrap_or_default()
    }

    /// Returns a copy of the collected values without removing them.
    ///
    /// Returns an empty `Vec` if the mutex is poisoned.
    pub fn collect(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.inner
            .lock()
            .map(|g| g.collected.clone())
            .unwrap_or_default()
    }
}

impl<T: Send + Sync + 'static> Default for Sink<T> {
    /// Equivalent to [`Sink::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Send + Sync + 'static> DataflowNode<T> for Sink<T> {
    /// Forwards `value` to the callback if there is one; otherwise stores it.
    fn push(&self, value: T) {
        match &self.callback {
            Some(cb) => cb(value),
            None => {
                if let Ok(mut g) = self.inner.lock() {
                    g.collected.push(value);
                }
            }
        }
    }

    /// Always `None`: a sink never emits values downstream.
    fn pull(&self) -> Option<T> {
        None
    }
}
/// Transformation node that applies a function to every value pushed in.
///
/// Results are queued in push order and retrieved with [`Map::pull_out`].
/// The `DataflowNode<T>` `pull` cannot yield a `T` and always returns `None`.
pub struct Map<T, U> {
    transform: Arc<dyn Fn(T) -> U + Send + Sync + 'static>,
    ready: Arc<Mutex<std::collections::VecDeque<U>>>,
}

impl<T: Send + Sync + 'static, U: Send + Sync + 'static> Map<T, U> {
    /// Builds a map node that applies `f` to each incoming value.
    pub fn new(f: impl Fn(T) -> U + Send + Sync + 'static) -> Self {
        let transform: Arc<dyn Fn(T) -> U + Send + Sync + 'static> = Arc::new(f);
        let ready = Arc::new(Mutex::new(std::collections::VecDeque::new()));
        Self { transform, ready }
    }

    /// Removes and returns the oldest transformed value.
    ///
    /// Returns `None` when nothing is queued or the queue's mutex is poisoned.
    pub fn pull_out(&self) -> Option<U> {
        let mut queue = self.ready.lock().ok()?;
        queue.pop_front()
    }
}

impl<T: Send + Sync + 'static, U: Send + Sync + 'static> DataflowNode<T> for Map<T, U> {
    /// Applies the function, then appends the result to the output queue.
    fn push(&self, value: T) {
        // The user function runs before the lock is taken, never under it.
        let mapped = (self.transform)(value);
        if let Ok(mut queue) = self.ready.lock() {
            queue.push_back(mapped);
        }
    }

    /// Always `None`: mapped values have type `U`; use `pull_out` instead.
    fn pull(&self) -> Option<T> {
        None
    }
}
/// Predicate node that forwards only the values its predicate accepts.
///
/// Accepted values are queued in arrival order and retrieved with
/// `DataflowNode::pull`; rejected values are dropped.
pub struct Filter<T> {
    keep: Arc<dyn Fn(&T) -> bool + Send + Sync + 'static>,
    passed: Arc<Mutex<std::collections::VecDeque<T>>>,
}

impl<T: Send + Sync + 'static> Filter<T> {
    /// Builds a filter that keeps values for which `pred` returns `true`.
    pub fn new(pred: impl Fn(&T) -> bool + Send + Sync + 'static) -> Self {
        Self {
            keep: Arc::new(pred),
            passed: Arc::new(Mutex::new(std::collections::VecDeque::new())),
        }
    }
}

impl<T: Send + Sync + 'static> DataflowNode<T> for Filter<T> {
    /// Queues `value` if the predicate accepts it; otherwise drops it.
    fn push(&self, value: T) {
        if !(self.keep)(&value) {
            return;
        }
        if let Ok(mut queue) = self.passed.lock() {
            queue.push_back(value);
        }
    }

    /// Removes the oldest accepted value; `None` if empty or poisoned.
    fn pull(&self) -> Option<T> {
        let mut queue = self.passed.lock().ok()?;
        queue.pop_front()
    }
}
/// Pairs values arriving on two independent inputs, first-in-first-out.
///
/// The n-th value given to [`Zip::push_left`] is paired with the n-th value
/// given to [`Zip::push_right`]. Completed pairs are queued for
/// [`Zip::pull_pair`]. A value whose partner has not arrived yet waits in its
/// side's queue.
pub struct Zip<T, U> {
    left: Arc<Mutex<std::collections::VecDeque<T>>>,
    right: Arc<Mutex<std::collections::VecDeque<U>>>,
    output: Arc<Mutex<std::collections::VecDeque<(T, U)>>>,
}

impl<T: Send + Sync + 'static, U: Send + Sync + 'static> Zip<T, U> {
    /// Creates a zip with empty left, right and output queues.
    pub fn new() -> Self {
        Zip {
            left: Arc::new(Mutex::new(std::collections::VecDeque::new())),
            right: Arc::new(Mutex::new(std::collections::VecDeque::new())),
            output: Arc::new(Mutex::new(std::collections::VecDeque::new())),
        }
    }

    /// Enqueues `value` on the left input, then emits any completed pairs.
    pub fn push_left(&self, value: T) {
        if let Ok(mut l) = self.left.lock() {
            l.push_back(value);
        }
        self.try_pair();
    }

    /// Enqueues `value` on the right input, then emits any completed pairs.
    pub fn push_right(&self, value: U) {
        if let Ok(mut r) = self.right.lock() {
            r.push_back(value);
        }
        self.try_pair();
    }

    /// Moves every complete (left, right) pair into the output queue.
    ///
    /// Locks are always taken in the order left, right, output. All three are
    /// held while pairs are formed and stored, so pairs reach `output` in
    /// exactly the order they were formed even when several threads push
    /// concurrently. If any lock is poisoned, nothing is popped, so no value
    /// is lost.
    fn try_pair(&self) {
        let mut left = match self.left.lock() {
            Ok(g) => g,
            Err(_) => return,
        };
        let mut right = match self.right.lock() {
            Ok(g) => g,
            Err(_) => return,
        };
        let mut out = match self.output.lock() {
            Ok(g) => g,
            Err(_) => return,
        };
        let ready = left.len().min(right.len());
        out.extend(left.drain(..ready).zip(right.drain(..ready)));
    }

    /// Removes and returns the oldest completed pair, if any.
    pub fn pull_pair(&self) -> Option<(T, U)> {
        self.output.lock().ok()?.pop_front()
    }
}

impl<T: Send + Sync + 'static, U: Send + Sync + 'static> Default for Zip<T, U> {
    /// Equivalent to [`Zip::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Groups incoming values into batches of a fixed size, in arrival order.
///
/// Completed batches are retrieved with [`Buffer::pull_batch`]. Values that do
/// not yet fill a batch stay pending until more arrive or [`Buffer::flush`] is
/// called.
pub struct Buffer<T> {
    batch_size: usize,
    input: Arc<Mutex<std::collections::VecDeque<T>>>,
    output: Arc<Mutex<std::collections::VecDeque<Vec<T>>>>,
}

impl<T: Send + Sync + 'static> Buffer<T> {
    /// Creates a buffer that emits batches of `batch_size` values.
    ///
    /// A `batch_size` of 0 is treated as 1.
    pub fn new(batch_size: usize) -> Self {
        Buffer {
            batch_size: batch_size.max(1),
            input: Arc::new(Mutex::new(std::collections::VecDeque::new())),
            output: Arc::new(Mutex::new(std::collections::VecDeque::new())),
        }
    }

    /// Moves every complete batch from the pending input to the output queue.
    fn flush_if_ready(&self) {
        self.move_batches(false);
    }

    /// Emits all complete batches, then any remaining pending values as a
    /// final batch shorter than `batch_size`.
    ///
    /// Does nothing if no values are pending.
    pub fn flush(&self) {
        self.move_batches(true);
    }

    /// Shared batching routine.
    ///
    /// Both locks are held for the whole move, always in the order input then
    /// output. Batches therefore reach the output in the order their values
    /// arrived, even with concurrent pushers. No values are drained from
    /// `input` unless they can be stored; if either lock is poisoned this is a
    /// no-op.
    fn move_batches(&self, include_partial: bool) {
        let mut pending = match self.input.lock() {
            Ok(g) => g,
            Err(_) => return,
        };
        let mut ready = match self.output.lock() {
            Ok(g) => g,
            Err(_) => return,
        };
        while pending.len() >= self.batch_size {
            ready.push_back(pending.drain(..self.batch_size).collect());
        }
        if include_partial && !pending.is_empty() {
            ready.push_back(pending.drain(..).collect());
        }
    }

    /// Removes and returns the oldest completed batch, if any.
    pub fn pull_batch(&self) -> Option<Vec<T>> {
        self.output.lock().ok()?.pop_front()
    }

    /// Number of completed batches waiting to be pulled (0 if poisoned).
    pub fn batch_count(&self) -> usize {
        self.output.lock().map(|g| g.len()).unwrap_or(0)
    }
}
/// Lets a `Buffer` receive values through the generic node interface.
///
/// Output is only available as whole batches via `Buffer::pull_batch`, so the
/// element-level `pull` never yields anything.
impl<T: Send + Sync + 'static> DataflowNode<T> for Buffer<T> {
    /// Queues `value`, then moves every complete batch to the output.
    fn push(&self, value: T) {
        {
            // Release the input lock before flushing, which locks it again.
            let locked = self.input.lock();
            if let Ok(mut pending) = locked {
                pending.push_back(value);
            }
        }
        self.flush_if_ready();
    }

    /// Always `None`: batched output is retrieved with `Buffer::pull_batch`.
    fn pull(&self) -> Option<T> {
        None
    }
}
/// Closed set of node kinds a `DataflowGraph` can store.
///
/// Each variant is pinned to a concrete element type: `i32` pipelines, plus
/// `f64` sources and sinks.
enum AnyNode {
    // Producers.
    SourceI32(Arc<Source<i32>>),
    SourceF64(Arc<Source<f64>>),
    // i32 transformations.
    MapI32(Arc<Map<i32, i32>>),
    FilterI32(Arc<Filter<i32>>),
    // Consumers.
    SinkI32(Arc<Sink<i32>>),
    SinkF64(Arc<Sink<f64>>),
    BufferI32(Arc<Buffer<i32>>),
}
/// A directed graph of dataflow nodes; values move along its edges when
/// `DataflowGraph::run` is called.
// No `Debug` impl: stored nodes wrap closures (e.g. `Map`'s function) that
// are not `Debug`.
#[allow(missing_debug_implementations)]
pub struct DataflowGraph {
    /// Directed `(from, to)` connections, in the order they were added.
    edges: Vec<(NodeId, NodeId)>,
    /// Nodes in insertion order; a node's index is the value inside its `NodeId`.
    nodes: Vec<AnyNode>,
}
impl DataflowGraph {
pub fn new() -> Self {
DataflowGraph {
nodes: Vec::new(),
edges: Vec::new(),
}
}
pub fn add_source(&mut self, src: Source<i32>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::SourceI32(Arc::new(src)));
id
}
pub fn add_map(&mut self, map: Map<i32, i32>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::MapI32(Arc::new(map)));
id
}
pub fn add_filter(&mut self, filter: Filter<i32>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::FilterI32(Arc::new(filter)));
id
}
pub fn add_sink(&mut self, sink: Sink<i32>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::SinkI32(Arc::new(sink)));
id
}
pub fn add_buffer(&mut self, buf: Buffer<i32>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::BufferI32(Arc::new(buf)));
id
}
pub fn add_source_f64(&mut self, src: Source<f64>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::SourceF64(Arc::new(src)));
id
}
pub fn add_sink_f64(&mut self, sink: Sink<f64>) -> NodeId {
let id = NodeId(self.nodes.len());
self.nodes.push(AnyNode::SinkF64(Arc::new(sink)));
id
}
pub fn connect(&mut self, src: NodeId, dst: NodeId) {
self.edges.push((src, dst));
}
pub fn run(&self) {
let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); self.nodes.len()];
for &(NodeId(src), NodeId(dst)) in &self.edges {
if src < adjacency.len() {
adjacency[src].push(dst);
}
}
let mut changed = true;
while changed {
changed = false;
for (src_idx, node) in self.nodes.iter().enumerate() {
match node {
AnyNode::SourceI32(src) => {
while let Some(v) = src.pull() {
changed = true;
self.propagate_i32(v, &adjacency[src_idx]);
}
}
AnyNode::MapI32(map) => {
while let Some(v) = map.pull_out() {
changed = true;
self.propagate_i32(v, &adjacency[src_idx]);
}
}
AnyNode::FilterI32(flt) => {
while let Some(v) = flt.pull() {
changed = true;
self.propagate_i32(v, &adjacency[src_idx]);
}
}
AnyNode::SourceF64(src) => {
while let Some(v) = src.pull() {
changed = true;
self.propagate_f64(v, &adjacency[src_idx]);
}
}
_ => {}
}
}
}
}
fn propagate_i32(&self, value: i32, dst_indices: &[usize]) {
for &dst in dst_indices {
match self.nodes.get(dst) {
Some(AnyNode::MapI32(map)) => map.push(value),
Some(AnyNode::FilterI32(flt)) => flt.push(value),
Some(AnyNode::SinkI32(sink)) => sink.push(value),
Some(AnyNode::BufferI32(buf)) => buf.push(value),
_ => {}
}
}
}
fn propagate_f64(&self, value: f64, dst_indices: &[usize]) {
for &dst in dst_indices {
if let Some(AnyNode::SinkF64(sink)) = self.nodes.get(dst) {
sink.push(value)
}
}
}
pub fn collect_sink(&self, id: NodeId) -> Vec<i32> {
match self.nodes.get(id.0) {
Some(AnyNode::SinkI32(sink)) => sink.drain(),
_ => Vec::new(),
}
}
pub fn collect_sink_f64(&self, id: NodeId) -> Vec<f64> {
match self.nodes.get(id.0) {
Some(AnyNode::SinkF64(sink)) => sink.drain(),
_ => Vec::new(),
}
}
pub fn collect_buffer(&self, id: NodeId) -> Vec<Vec<i32>> {
match self.nodes.get(id.0) {
Some(AnyNode::BufferI32(buf)) => {
let mut batches = Vec::new();
while let Some(b) = buf.pull_batch() {
batches.push(b);
}
batches
}
_ => Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `graph` to completion, then drains the `i32` sink `sink`.
    fn run_and_drain(graph: &DataflowGraph, sink: NodeId) -> Vec<i32> {
        graph.run();
        graph.collect_sink(sink)
    }

    #[test]
    fn test_dataflow_map() {
        let mut graph = DataflowGraph::new();
        let source = graph.add_source(Source::from_iter(0..5i32));
        let tripler = graph.add_map(Map::new(|x: i32| x * 3));
        let sink = graph.add_sink(Sink::new());
        graph.connect(source, tripler);
        graph.connect(tripler, sink);
        assert_eq!(run_and_drain(&graph, sink), vec![0, 3, 6, 9, 12]);
    }

    #[test]
    fn test_dataflow_filter() {
        let mut graph = DataflowGraph::new();
        let source = graph.add_source(Source::from_iter(0..10i32));
        let evens = graph.add_filter(Filter::new(|x: &i32| x % 2 == 0));
        let sink = graph.add_sink(Sink::new());
        graph.connect(source, evens);
        graph.connect(evens, sink);
        assert_eq!(run_and_drain(&graph, sink), vec![0, 2, 4, 6, 8]);
    }

    #[test]
    fn test_dataflow_source_sink() {
        let mut graph = DataflowGraph::new();
        let source = graph.add_source(Source::from_iter(1..=5i32));
        let sink = graph.add_sink(Sink::new());
        graph.connect(source, sink);
        assert_eq!(run_and_drain(&graph, sink), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_dataflow_buffer() {
        let mut graph = DataflowGraph::new();
        let source = graph.add_source(Source::from_iter(0..9i32));
        let batcher = graph.add_buffer(Buffer::new(3));
        graph.connect(source, batcher);
        graph.run();
        let expected = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(graph.collect_buffer(batcher), expected);
    }

    #[test]
    fn test_dataflow_zip() {
        let zip: Zip<i32, i32> = Zip::new();
        zip.push_left(1);
        zip.push_left(2);
        zip.push_right(10);
        zip.push_right(20);
        zip.push_left(3);
        zip.push_right(30);
        let pairs: Vec<(i32, i32)> = std::iter::from_fn(|| zip.pull_pair()).collect();
        assert_eq!(pairs, vec![(1, 10), (2, 20), (3, 30)]);
    }

    #[test]
    fn test_source_manual_push() {
        let src: Source<i32> = Source::empty();
        for v in [5, 6].iter() {
            src.push_value(*v);
        }
        let pulled: Vec<Option<i32>> = (0..3).map(|_| src.pull()).collect();
        assert_eq!(pulled, vec![Some(5), Some(6), None]);
    }

    #[test]
    fn test_sink_collect() {
        let sink: Sink<i32> = Sink::new();
        for v in 1..=3 {
            sink.push(v);
        }
        assert_eq!(sink.collect(), vec![1, 2, 3]);
    }

    #[test]
    fn test_dataflow_map_filter_pipeline() {
        let mut graph = DataflowGraph::new();
        let source = graph.add_source(Source::from_iter(0..10i32));
        let doubler = graph.add_map(Map::new(|x: i32| x * 2));
        let big_only = graph.add_filter(Filter::new(|x: &i32| *x > 8));
        let sink = graph.add_sink(Sink::new());
        graph.connect(source, doubler);
        graph.connect(doubler, big_only);
        graph.connect(big_only, sink);
        assert_eq!(run_and_drain(&graph, sink), vec![10, 12, 14, 16, 18]);
    }
}