use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use crate::{
ChannelReceiverStream, Element, MutableNode, ReadyNotifier, RunMode, StreamPeekRef, UpStreams,
channel::{ChannelSender, channel_pair},
};
use tinyvec::TinyVec;
/// Lifecycle of the background producer thread that feeds a [`ReceiverStream`].
enum State<T: Element + Send> {
/// Not started yet: the producer closure, later run on its own thread with the
/// channel sender and the shared stop flag.
Func(Box<dyn Fn(ChannelSender<T>, Arc<AtomicBool>) -> anyhow::Result<()> + Send + 'static>),
/// Producer thread spawned; the handle is used to poll and join it.
JoinHandle(JoinHandle<anyhow::Result<()>>),
/// Producer thread exited cleanly after the stream was reported finished.
Done,
/// Placeholder used while swapping states; also the resting state once the thread
/// has been joined after a failure, or by `stop`.
Empty,
}
impl<T: Element + Send> State<T> {
    /// Spawns the producer thread, moving `Func` to `JoinHandle`.
    ///
    /// In any other state this is a no-op that leaves the state untouched; the
    /// `channel_sender` and `stop` arguments are simply dropped.
    pub fn start(&mut self, channel_sender: ChannelSender<T>, stop: Arc<AtomicBool>) {
        match std::mem::replace(self, State::Empty) {
            State::Func(f) => {
                let handle = thread::spawn(move || f(channel_sender, stop));
                *self = State::JoinHandle(handle);
            }
            // Put the previous state back. Leaving `Empty` here would, for a
            // `JoinHandle`, drop the handle: the thread would be detached and its
            // result (including any error) lost.
            other => *self = other,
        }
    }

    /// Polls the producer thread without blocking.
    ///
    /// `finished` tells whether the stream has already reached its end; only then is a
    /// clean exit of the thread acceptable. Such an exit joins the thread and moves the
    /// state to `Done`, after which further calls return `Ok(())`.
    ///
    /// # Errors
    ///
    /// * the thread panicked, or returned an error (propagated as-is);
    /// * the thread returned `Ok` while `finished` is false;
    /// * no thread is running (never started, or already joined).
    ///
    /// In the first two cases the thread is joined and the state is left `Empty`.
    pub fn check_running(&mut self, finished: bool) -> anyhow::Result<()> {
        match self {
            State::JoinHandle(handle) if handle.is_finished() => {
                let State::JoinHandle(handle) = std::mem::replace(self, State::Empty) else {
                    unreachable!("state matched JoinHandle on the previous line");
                };
                match handle.join() {
                    Err(e) => Err(anyhow::anyhow!("Receiver thread panicked: {e:?}")),
                    Ok(Err(e)) => Err(e),
                    Ok(Ok(())) if finished => {
                        *self = State::Done;
                        Ok(())
                    }
                    Ok(Ok(())) => Err(anyhow::anyhow!("Receiver thread exited unexpectedly")),
                }
            }
            // Still running, or already exited cleanly.
            State::JoinHandle(_) | State::Done => Ok(()),
            State::Func(_) | State::Empty => Err(anyhow::anyhow!("Receiver thread not running")),
        }
    }

    /// Joins the producer thread if one was spawned and leaves the state `Empty`.
    ///
    /// This blocks until the producer closure returns; it does not signal the thread
    /// itself.
    ///
    /// # Errors
    ///
    /// Returns an error if the thread panicked, or the error the thread returned.
    pub fn stop(&mut self) -> anyhow::Result<()> {
        match std::mem::replace(self, State::Empty) {
            State::JoinHandle(handle) => handle
                .join()
                .map_err(|e| anyhow::anyhow!("Thread panicked: {e:?}"))?,
            // Nothing to join: never started, already joined, or finished cleanly.
            State::Func(_) | State::Done | State::Empty => Ok(()),
        }
    }
}
/// Stream node whose values come from a user closure running on a background thread
/// and are delivered to the graph through a channel.
pub(crate) struct ReceiverStream<T: Element + Send> {
/// Receiving side of the channel; the node's upstreams, cycling and output delegate to it.
inner: ChannelReceiverStream<T>,
/// Sending side, held until `setup` hands it to the producer thread.
sender: Option<ChannelSender<T>>,
/// Producer thread lifecycle.
state: State<T>,
/// Set to true in `stop`; a clone is passed to the producer closure, which may poll it.
stop: Arc<AtomicBool>,
/// When true, `start` fails unless the graph runs in `RunMode::RealTime`.
assert_realtime: bool,
/// Producer failure detected during `cycle`, returned by the following `cycle`.
pending_err: Option<anyhow::Error>,
/// Graph ready-notifier, set in `setup` only in real-time mode; notified after a
/// producer failure is stored so another cycle runs.
notifier: Option<ReadyNotifier>,
}
impl<T: Element + Send> MutableNode for ReceiverStream<T> {
fn upstreams(&self) -> UpStreams {
self.inner.upstreams()
}
fn cycle(&mut self, state: &mut crate::GraphState) -> anyhow::Result<bool> {
if let Some(e) = self.pending_err.take() {
return Err(e);
}
let cycle_result = self.inner.cycle(state)?;
if let Err(thread_err) = self.state.check_running(self.inner.finished()) {
self.pending_err = Some(thread_err);
if let Some(notifier) = &self.notifier {
let _ = notifier.notify();
}
}
Ok(cycle_result)
}
fn setup(&mut self, state: &mut crate::GraphState) -> anyhow::Result<()> {
let mut sender = self
.sender
.take()
.ok_or_else(|| anyhow::anyhow!("missing sender"))?;
if state.run_mode() == RunMode::RealTime {
let notifier = state.ready_notifier();
sender.set_notifier(notifier.clone());
self.notifier = Some(notifier);
}
self.state.start(sender, self.stop.clone());
self.inner.setup(state)
}
fn start(&mut self, state: &mut crate::GraphState) -> anyhow::Result<()> {
if self.assert_realtime && state.run_mode() != RunMode::RealTime {
anyhow::bail!("ReceiverStream only supports real-time mode");
}
self.inner.start(state)
}
fn stop(&mut self, state: &mut crate::GraphState) -> anyhow::Result<()> {
self.stop.store(true, Ordering::Relaxed);
self.state.stop()?;
self.inner.stop(state)
}
fn teardown(&mut self, state: &mut crate::GraphState) -> anyhow::Result<()> {
self.inner.teardown(state)
}
}
impl<T: Element + Send> StreamPeekRef<TinyVec<[T; 1]>> for ReceiverStream<T> {
    /// Borrows the current output by delegating to the inner channel stream's `peek_ref`.
    fn peek_ref(&self) -> &TinyVec<[T; 1]> {
        let receiver = &self.inner;
        receiver.peek_ref()
    }
}
impl<T: Element + Send> ReceiverStream<T> {
    /// Builds a node whose values are produced by `f`.
    ///
    /// `f` is not run here: `setup` later calls it on a dedicated thread with the
    /// channel sender and a stop flag that becomes true in `stop`. When
    /// `assert_realtime` is set, `start` rejects any run mode other than real-time.
    pub(crate) fn new(
        f: impl Fn(ChannelSender<T>, Arc<AtomicBool>) -> anyhow::Result<()> + Send + 'static,
        assert_realtime: bool,
    ) -> Self {
        let (tx, rx) = channel_pair(None);
        Self {
            inner: ChannelReceiverStream::new(rx, None, None),
            sender: Some(tx),
            stop: Arc::new(AtomicBool::new(false)),
            state: State::Func(Box::new(f)),
            assert_realtime,
            pending_err: None,
            notifier: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::Message;
    use crate::nodes::NodeOperators;
    use crate::{IntoStream, RunFor, RunMode};
    use std::time::Duration;

    /// Runs `f` on a new thread, spins until that thread has exited, and wraps its
    /// handle so `check_running` observes an already-finished producer.
    fn finished_state(f: impl FnOnce() -> anyhow::Result<()> + Send + 'static) -> State<u64> {
        let handle = thread::spawn(f);
        while !handle.is_finished() {
            thread::yield_now();
        }
        State::JoinHandle(handle)
    }

    #[test]
    fn clean_exit_after_end_of_stream_is_ok() {
        let mut state = finished_state(|| Ok(()));
        let first_check = state.check_running(true);
        assert!(first_check.is_ok());
        assert!(matches!(state, State::Done));
        // Once `Done`, repeated checks keep succeeding.
        let second_check = state.check_running(true);
        assert!(second_check.is_ok());
    }

    #[test]
    fn silent_exit_without_end_of_stream_is_error() {
        let mut state = finished_state(|| Ok(()));
        let outcome = state.check_running(false);
        assert!(outcome.is_err());
    }

    #[test]
    fn thread_error_propagates() {
        let mut state = finished_state(|| Err(anyhow::anyhow!("boom")));
        let message = state.check_running(true).unwrap_err().to_string();
        assert!(message.contains("boom"));
    }

    #[test]
    fn clean_end_of_stream_does_not_error_the_graph() {
        let node = ReceiverStream::new(
            |sender, _stop| {
                sender.send_message(Message::RealtimeValue(1u64))?;
                sender.send_message(Message::EndOfStream)?;
                Ok(())
            },
            true,
        );
        let stream = node.into_stream();
        let outcome = stream.count().run(
            RunMode::RealTime,
            RunFor::Duration(Duration::from_millis(100)),
        );
        outcome.expect("clean end-of-stream should not error");
    }
}