use std::{
borrow::BorrowMut,
sync::{Arc, Mutex},
thread,
time::Duration,
};
use serde::Deserialize;
use crate::{
dataflow::{
graph::{default_graph, StreamSetupHook},
Data, Message,
},
scheduler::channel_manager::ChannelManager,
};
use super::{
errors::{ReadError, TryReadError},
OperatorStream, ReadStream, Stream, StreamId,
};
/// Reads the messages of the stream identified by `id` through a lazily created
/// `ReadStream`.
///
/// The `ReadStream` is only created after two things have happened:
/// - a `ChannelManager` has been stored in `channel_manager_option` (by the closure
///   returned from `get_setup_hook`);
/// - a receive endpoint for `id` has been taken from that `ChannelManager`.
pub struct ExtractStream<D>
where
    for<'a> D: Data + Deserialize<'a>,
{
    /// ID of the stream whose messages are read.
    id: StreamId,
    /// `None` until a receive endpoint has been taken from the channel manager.
    read_stream_option: Option<ReadStream<D>>,
    /// Slot shared (through the outer `Arc`) with the setup hook, which fills it
    /// with the `ChannelManager` once one is available.
    channel_manager_option: Arc<Mutex<Option<Arc<Mutex<ChannelManager>>>>>,
}
impl<D> ExtractStream<D>
where
for<'a> D: Data + Deserialize<'a>,
{
pub fn new(stream: &OperatorStream<D>) -> Self {
tracing::debug!(
"Initializing an ExtractStream with the ReadStream {} (ID: {})",
stream.name(),
stream.id(),
);
let id = stream.id();
let extract_stream = Self {
id,
read_stream_option: None,
channel_manager_option: Arc::new(Mutex::new(None)),
};
default_graph::add_extract_stream(&extract_stream);
extract_stream
}
pub fn is_closed(&self) -> bool {
self.read_stream_option
.as_ref()
.map(ReadStream::is_closed)
.unwrap_or(true)
}
pub fn try_read(&mut self) -> Result<Message<D>, TryReadError> {
if let Some(read_stream) = self.read_stream_option.borrow_mut() {
read_stream.try_read()
} else {
if let Some(channel_manager) = &*self.channel_manager_option.lock().unwrap() {
match channel_manager.lock().unwrap().take_recv_endpoint(self.id) {
Ok(recv_endpoint) => {
let mut read_stream = ReadStream::new(
self.id,
&default_graph::get_stream_name(&self.id),
Some(recv_endpoint),
);
let result = read_stream.try_read();
self.read_stream_option.replace(read_stream);
return result;
}
Err(msg) => tracing::error!(
"ExtractStream {} (ID: {}): error getting endpoint from \
channel manager \"{}\"",
self.name(),
self.id(),
msg
),
}
}
Err(TryReadError::Disconnected)
}
}
pub fn read(&mut self) -> Result<Message<D>, ReadError> {
loop {
let result = self.try_read();
if self.read_stream_option.is_some() {
match result {
Ok(msg) => return Ok(msg),
Err(TryReadError::Disconnected) => return Err(ReadError::Disconnected),
Err(TryReadError::Empty) => (),
Err(TryReadError::SerializationError) => {
return Err(ReadError::SerializationError)
}
Err(TryReadError::Closed) => return Err(ReadError::Closed),
};
} else {
thread::sleep(Duration::from_millis(100));
}
}
}
pub fn id(&self) -> StreamId {
self.id
}
pub fn name(&self) -> String {
default_graph::get_stream_name(&self.id)
}
pub(crate) fn get_setup_hook(&self) -> impl StreamSetupHook {
let channel_manager_option_copy = Arc::clone(&self.channel_manager_option);
move |channel_manager: Arc<Mutex<ChannelManager>>| {
channel_manager_option_copy
.lock()
.unwrap()
.replace(channel_manager);
}
}
}
// SAFETY: NOTE(review): the justification for this manual impl is not visible in this
// file.
// - The `channel_manager_option` field (`Arc<Mutex<...>>`) is already `Send` whenever
//   `ChannelManager: Send`. The impl therefore presumably exists because
//   `ReadStream<D>` (or `StreamId`) is not automatically `Send`.
// - The impl does not itself require `D: Send`; presumably `Data` implies it.
// Confirm both points, and that moving a `ReadStream<D>` to another thread is sound,
// before relying on this.
unsafe impl<D> Send for ExtractStream<D> where for<'a> D: Data + Deserialize<'a> {}