use std::collections::HashMap;
use serde::Deserialize;
use crate::{
dataflow::{
stream::{ExtractStream, IngestStream, OperatorStream, Stream, StreamId},
Data, LoopStream,
},
OperatorConfig, OperatorId,
};
use super::{
job_graph::JobGraph, AbstractOperator, AbstractStream, AbstractStreamT, OperatorRunner,
StreamSetupHook,
};
/// The user-facing description of a dataflow: the operators and streams
/// registered through the `add_*` methods, which `compile` turns into a
/// [`JobGraph`].
pub struct AbstractGraph {
/// Registered operators, keyed by operator ID.
operators: HashMap<OperatorId, AbstractOperator>,
/// Type-erased descriptions of every stream added to the graph (operator
/// write streams, ingest streams and loop streams), keyed by stream ID.
streams: HashMap<StreamId, Box<dyn AbstractStreamT>>,
/// Setup hooks obtained from the registered ingest streams, keyed by stream ID.
ingest_streams: HashMap<StreamId, Box<dyn StreamSetupHook>>,
/// Setup hooks obtained from the registered extract streams, keyed by stream ID.
extract_streams: HashMap<StreamId, Box<dyn StreamSetupHook>>,
/// Maps each loop stream's ID to the ID of the stream it has been connected
/// to; the value stays `None` until `connect_loop` is called for that stream.
loop_streams: HashMap<StreamId, Option<StreamId>>,
}
impl AbstractGraph {
/// Creates a graph with no operators and no streams registered.
pub fn new() -> Self {
    let operators = HashMap::new();
    let streams = HashMap::new();
    let ingest_streams = HashMap::new();
    let extract_streams = HashMap::new();
    let loop_streams = HashMap::new();
    Self {
        operators,
        streams,
        ingest_streams,
        extract_streams,
        loop_streams,
    }
}
/// Registers an operator together with the streams it writes to.
///
/// The IDs of the supplied read and write streams are recorded on the
/// operator in left-then-right order; streams passed as `None` are skipped.
///
/// Each supplied write stream is also added to the graph's stream table under
/// a default name derived from the operator's name:
/// - `<name>-write-stream` for the left write stream when it is the only
///   write stream,
/// - `<name>-left-write-stream` for the left write stream otherwise,
/// - `<name>-right-write-stream` for the right write stream.
///
/// An operator already stored under the same `config.id` is replaced.
pub(crate) fn add_operator<F, T, U, V, W>(
    &mut self,
    config: OperatorConfig,
    runner: F,
    left_read_stream: Option<&dyn Stream<T>>,
    right_read_stream: Option<&dyn Stream<U>>,
    left_write_stream: Option<&OperatorStream<V>>,
    right_write_stream: Option<&OperatorStream<W>>,
) where
    F: OperatorRunner,
    for<'a> T: Data + Deserialize<'a>,
    for<'a> U: Data + Deserialize<'a>,
    for<'a> V: Data + Deserialize<'a>,
    for<'a> W: Data + Deserialize<'a>,
{
    // Collect the IDs of whichever streams are present, left before right.
    let read_streams: Vec<_> = left_read_stream
        .map(|s| s.id())
        .into_iter()
        .chain(right_read_stream.map(|s| s.id()))
        .collect();
    let write_streams: Vec<_> = left_write_stream
        .map(|s| s.id())
        .into_iter()
        .chain(right_write_stream.map(|s| s.id()))
        .collect();

    // Add write streams to the graph under their default names.
    if let Some(ls) = left_write_stream {
        let stream_name = if write_streams.len() == 1 {
            format!("{}-write-stream", config.get_name())
        } else {
            // Mirrors the "<name>-right-write-stream" naming used below.
            format!("{}-left-write-stream", config.get_name())
        };
        let abstract_stream = AbstractStream::<V>::new(ls.id(), stream_name);
        self.streams.insert(ls.id(), Box::new(abstract_stream));
    }
    if let Some(rs) = right_write_stream {
        let stream_name = format!("{}-right-write-stream", config.get_name());
        let abstract_stream = AbstractStream::<W>::new(rs.id(), stream_name);
        self.streams.insert(rs.id(), Box::new(abstract_stream));
    }

    // Read the ID before `config` is moved into the operator.
    let operator_id = config.id;
    let abstract_operator = AbstractOperator {
        id: operator_id,
        runner: Box::new(runner),
        config,
        read_streams,
        write_streams,
    };
    self.operators.insert(operator_id, abstract_operator);
}
/// Registers an ingest stream: adds it to the stream table under the default
/// name `ingest-stream-<n>` (where `<n>` is the number of ingest streams
/// currently registered) and stores its setup hook.
pub(crate) fn add_ingest_stream<D>(&mut self, ingest_stream: &IngestStream<D>)
where
    for<'a> D: Data + Deserialize<'a>,
{
    let stream_id = ingest_stream.id();
    // The index must be taken before the hook for this stream is inserted.
    let default_name = format!("ingest-stream-{}", self.ingest_streams.len());
    self.streams.insert(
        stream_id,
        Box::new(AbstractStream::<D>::new(stream_id, default_name)),
    );
    self.ingest_streams
        .insert(stream_id, Box::new(ingest_stream.get_setup_hook()));
}
/// Stores the setup hook of `extract_stream` under the stream's ID.
///
/// Unlike ingest and loop streams, no entry is added to the stream table.
pub(crate) fn add_extract_stream<D>(&mut self, extract_stream: &ExtractStream<D>)
where
    for<'a> D: Data + Deserialize<'a>,
{
    let hook = extract_stream.get_setup_hook();
    let stream_id = extract_stream.id();
    self.extract_streams.insert(stream_id, Box::new(hook));
}
/// Registers a loop stream: adds it to the stream table under the default
/// name `loop-stream-<n>` (where `<n>` is the number of loop streams currently
/// registered) and marks it as not yet connected.
pub(crate) fn add_loop_stream<D>(&mut self, loop_stream: &LoopStream<D>)
where
    for<'a> D: Data + Deserialize<'a>,
{
    let stream_id = loop_stream.id();
    let default_name = format!("loop-stream-{}", self.loop_streams.len());
    self.streams.insert(
        stream_id,
        Box::new(AbstractStream::<D>::new(stream_id, default_name)),
    );
    // `None` means "no stream connected yet"; `connect_loop` fills it in.
    self.loop_streams.insert(stream_id, None);
}
/// Records that `loop_stream` is fed by `stream`.
///
/// Does nothing if `loop_stream` is not registered in this graph. Connecting
/// an already-connected loop stream overwrites the previous connection.
pub(crate) fn connect_loop<D>(
    &mut self,
    loop_stream: &LoopStream<D>,
    stream: &OperatorStream<D>,
) where
    for<'a> D: Data + Deserialize<'a>,
{
    let loop_id = loop_stream.id();
    if let Some(connected) = self.loop_streams.get_mut(&loop_id) {
        *connected = Some(stream.id());
    }
}
/// Returns the name of the stream registered under `stream_id`.
///
/// # Panics
///
/// Panics if no stream with that ID is present in the graph's stream table.
pub(crate) fn get_stream_name(&self, stream_id: &StreamId) -> String {
    self.streams
        .get(stream_id)
        // A missing entry is a caller bug; say which ID was not found.
        .unwrap_or_else(|| {
            panic!(
                "get_stream_name: no stream with ID {} is registered in the graph",
                stream_id
            )
        })
        .name()
}
/// Replaces the name of the stream registered under `stream_id` with `name`.
///
/// # Panics
///
/// Panics if no stream with that ID is present in the graph's stream table.
pub(crate) fn set_stream_name(&mut self, stream_id: &StreamId, name: String) {
    self.streams
        .get_mut(stream_id)
        // A missing entry is a caller bug; say which ID was not found.
        .unwrap_or_else(|| {
            panic!(
                "set_stream_name: no stream with ID {} is registered in the graph",
                stream_id
            )
        })
        .set_name(name);
}
/// Resolves `stream_id` through the loop-stream table.
///
/// - If `stream_id` is a registered loop stream, returns the ID of the stream
///   it is connected to, or `None` if it has not been connected yet.
/// - Otherwise returns `Some(*stream_id)` unchanged.
pub(crate) fn resolve_stream_id(&self, stream_id: &StreamId) -> Option<StreamId> {
    self.loop_streams
        .get(stream_id)
        .map_or(Some(*stream_id), |connected| *connected)
}
/// Builds a [`JobGraph`] from the current contents of this graph.
///
/// The job graph receives:
/// - clones of all operators, with every read-stream ID that refers to a loop
///   stream replaced by the ID of the stream that loop is connected to,
/// - clones of all streams except the loop streams themselves,
/// - clones of the ingest and extract setup hooks.
///
/// The graph keeps all of its own entries, so it can be compiled again.
///
/// # Panics
///
/// Panics if any registered loop stream has not been connected.
pub(crate) fn compile(&mut self) -> JobGraph {
    // Refuse to compile while any loop stream is still dangling.
    if let Some((loop_stream_id, _)) = self
        .loop_streams
        .iter()
        .find(|(_, connected)| connected.is_none())
    {
        panic!("LoopStream {} is not connected to another loop. Call `LoopStream::connect_loop` to fix.", loop_stream_id);
    }

    // Loop streams are placeholders and are left out of the compiled graph.
    let streams: Vec<_> = self
        .streams
        .iter()
        .filter_map(|(id, stream)| {
            if self.loop_streams.contains_key(id) {
                None
            } else {
                Some(stream.box_clone())
            }
        })
        .collect();

    // Each hook is taken out of the map, cloned while owned, and put back.
    // NOTE(review): presumably `box_clone` has to be called on an owned box
    // rather than on a borrowed map entry — confirm before simplifying.
    let ingest_ids: Vec<_> = self.ingest_streams.keys().copied().collect();
    let ingest_streams: HashMap<_, _> = ingest_ids
        .into_iter()
        .map(|id| {
            let hook = self.ingest_streams.remove(&id).unwrap();
            let hook_copy = hook.box_clone();
            self.ingest_streams.insert(id, hook);
            (id, hook_copy)
        })
        .collect();

    let extract_ids: Vec<_> = self.extract_streams.keys().copied().collect();
    let extract_streams: HashMap<_, _> = extract_ids
        .into_iter()
        .map(|id| {
            let hook = self.extract_streams.remove(&id).unwrap();
            let hook_copy = hook.box_clone();
            self.extract_streams.insert(id, hook);
            (id, hook_copy)
        })
        .collect();

    // Point operators that read from a loop stream at the connected stream.
    let mut operators: Vec<_> = self.operators.values().cloned().collect();
    for operator in operators.iter_mut() {
        for read_id in operator.read_streams.iter_mut() {
            if self.loop_streams.contains_key(&*read_id) {
                // Cannot be `None`: all loop streams were checked above.
                *read_id = self.resolve_stream_id(&*read_id).unwrap();
            }
        }
    }

    JobGraph::new(operators, streams, ingest_streams, extract_streams)
}
/// Returns a copy of this graph: operators and loop-stream connections are
/// cloned, and streams and setup hooks are duplicated via `box_clone`.
///
/// Takes `&mut self` because each setup hook is briefly removed from its map
/// while it is cloned and then re-inserted; the graph's contents are the same
/// before and after the call.
pub(crate) fn clone(&mut self) -> Self {
    // Each hook is taken out of the map, cloned while owned, and put back.
    // NOTE(review): presumably `box_clone` has to be called on an owned box
    // rather than on a borrowed map entry — confirm before simplifying.
    let ingest_ids: Vec<_> = self.ingest_streams.keys().copied().collect();
    let ingest_streams: HashMap<_, _> = ingest_ids
        .into_iter()
        .map(|id| {
            let hook = self.ingest_streams.remove(&id).unwrap();
            let hook_copy = hook.box_clone();
            self.ingest_streams.insert(id, hook);
            (id, hook_copy)
        })
        .collect();

    let extract_ids: Vec<_> = self.extract_streams.keys().copied().collect();
    let extract_streams: HashMap<_, _> = extract_ids
        .into_iter()
        .map(|id| {
            let hook = self.extract_streams.remove(&id).unwrap();
            let hook_copy = hook.box_clone();
            self.extract_streams.insert(id, hook);
            (id, hook_copy)
        })
        .collect();

    let streams: HashMap<_, _> = self
        .streams
        .iter()
        .map(|(id, stream)| (*id, stream.box_clone()))
        .collect();

    Self {
        operators: self.operators.clone(),
        streams,
        ingest_streams,
        extract_streams,
        loop_streams: self.loop_streams.clone(),
    }
}
}