use std::{collections::HashSet, marker::PhantomData};
use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SpawnErr, SupervisionEvent};
use tokio_stream::Stream;
#[cfg(test)]
mod tests;
/// Receives lifecycle notifications from a stream mux started with [`mux_stream`].
pub trait StreamMuxNotification: 'static + Send {
    /// Called when delivering an item to a target failed.
    ///
    /// * `target` — the value returned by [`Target::get_id`] for the failing target.
    /// * `err` — the error returned by [`Target::message_received`].
    fn target_failed(&self, target: String, err: ActorProcessingErr);
    /// Called when the mux stops processing.
    ///
    /// This happens when the underlying stream ends. It also happens when
    /// `stop_processing_target_on_failure` is set and every target has been
    /// removed because of failures.
    fn end_of_stream(&self);
}
/// A receiver of the items produced by a muxed stream.
///
/// Every target registered in [`StreamMuxConfiguration::targets`] is given its
/// own clone of each item the stream yields.
pub trait Target<S>: 'static + Send
where
    S: Stream + ractor::State,
    S::Item: Clone + ractor::Message,
{
    /// Identifier for this target.
    ///
    /// It is logged and passed to [`StreamMuxNotification::target_failed`]
    /// when delivery to this target fails.
    fn get_id(&self) -> String;

    /// Delivers one stream item to this target.
    ///
    /// The mux actor's message handler calls this synchronously, once per
    /// target per item. A slow implementation therefore delays delivery to the
    /// targets that come after it.
    ///
    /// # Errors
    /// An `Err` is reported through [`StreamMuxNotification::target_failed`].
    /// Whether the target keeps receiving items afterwards depends on
    /// [`StreamMuxConfiguration::stop_processing_target_on_failure`].
    fn message_received(&self, item: S::Item) -> Result<(), ActorProcessingErr>;
}
/// Everything needed to start a stream mux with [`mux_stream`].
pub struct StreamMuxConfiguration<S, N>
where
    S: Stream + ractor::State,
    S::Item: Clone + ractor::Message,
    N: StreamMuxNotification,
{
    /// The source stream. The mux actor hands it to a stream-pump actor it
    /// spawns during startup.
    pub stream: S,
    /// Receivers of the stream's items. Each one is given a clone of every item.
    pub targets: Vec<Box<dyn Target<S>>>,
    /// Sink for target-failure and end-of-stream notifications.
    pub callback: N,
    /// Controls what happens to a target whose `message_received` fails.
    ///
    /// When `true`, the target is removed and receives no further items, and
    /// the mux stops once no targets remain. When `false`, the target keeps
    /// receiving subsequent items.
    pub stop_processing_target_on_failure: bool,
}
pub async fn mux_stream<S, N>(
config: StreamMuxConfiguration<S, N>,
sup: Option<ActorCell>,
) -> Result<ActorCell, SpawnErr>
where
S: Stream + ractor::State,
S::Item: Clone + ractor::Message,
N: StreamMuxNotification,
{
let handler = MuxActor::<S, N> {
_s: PhantomData,
_n: PhantomData,
};
let actor = if let Some(s) = sup {
Actor::spawn_linked(None, handler, config, s).await?.0
} else {
Actor::spawn(None, handler, config).await?.0
};
Ok(actor.into())
}
/// Runtime state of [`MuxActor`].
///
/// It is built in `pre_start` from a [`StreamMuxConfiguration`] plus the
/// spawned stream pump.
struct MuxActorState<S, N>
where
    S: Stream + ractor::State,
    S::Item: Clone + ractor::Message,
    N: StreamMuxNotification,
{
    /// Targets that still receive stream items.
    targets: Vec<Box<dyn Target<S>>>,
    /// Notification sink taken from the configuration.
    callback: N,
    /// Copied from [`StreamMuxConfiguration::stop_processing_target_on_failure`].
    stop_processing_target_on_failure: bool,
    /// Cell of the stream-pump actor spawned in `pre_start`. It is never read
    /// afterwards.
    // NOTE(review): confirm whether the pump is stopped/linked when the mux
    // actor stops; nothing in this file stops it explicitly.
    _pump: ActorCell,
}
/// Actor handler for the stream mux.
///
/// It holds no data. The `PhantomData` fields only bind the `S` and `N` type
/// parameters; everything the actor works with lives in [`MuxActorState`].
struct MuxActor<S, N>
where
    S: Stream + ractor::State,
    S::Item: Clone + ractor::Message,
    N: StreamMuxNotification,
{
    _s: PhantomData<S>,
    _n: PhantomData<N>,
}
// SAFETY: `MuxActor` contains only `PhantomData` fields and never stores an `S`
// or `N` value. All such values live in `MuxActorState`, which is passed to
// the handlers separately. A shared `&MuxActor` therefore gives no access to
// any `S` or `N`, so sharing it across threads is sound.
//
// The manual impl is needed because the bounds here do not require `S` or `N`
// to be `Sync` (`StreamMuxNotification` only requires `Send`), so
// `PhantomData<S>`/`PhantomData<N>` would otherwise block the auto impl.
unsafe impl<S, N> Sync for MuxActor<S, N>
where
    S: Stream + ractor::State,
    S::Item: Clone + ractor::Message,
    N: StreamMuxNotification,
{
}
#[cfg_attr(feature = "async-trait", async_trait::async_trait)]
impl<S, N> Actor for MuxActor<S, N>
where
    S: Stream + ractor::State,
    S::Item: Clone + ractor::Message,
    N: StreamMuxNotification,
{
    /// `Some(item)` carries one stream item. `None` is treated as end-of-stream.
    type Msg = Option<S::Item>;
    type State = MuxActorState<S, N>;
    type Arguments = StreamMuxConfiguration<S, N>;

    /// Spawns the stream pump that feeds this actor and moves the
    /// configuration into the actor's state.
    ///
    /// # Errors
    /// Propagates any error from spawning the stream pump.
    async fn pre_start(
        &self,
        myself: ActorRef<Self::Msg>,
        StreamMuxConfiguration::<S, N> {
            stream,
            targets,
            callback,
            stop_processing_target_on_failure,
        }: Self::Arguments,
    ) -> Result<Self::State, ActorProcessingErr> {
        // Identity mapper: the pump's messages are delivered as-is as `Self::Msg`.
        let pump = crate::streams::spawn_stream_pump(stream, myself, |a| a, None).await?;
        tracing::debug!("Stream pump started");
        Ok(MuxActorState::<S, N> {
            _pump: pump,
            callback,
            targets,
            stop_processing_target_on_failure,
        })
    }

    /// Handles one message from the stream pump.
    ///
    /// `Some(item)`:
    /// * A clone of `item` is passed to every target, and each failure is
    ///   logged and reported via `target_failed`.
    /// * When `stop_processing_target_on_failure` is set, exactly the targets
    ///   that failed are removed.
    /// * Once no targets remain, the actor stops itself and signals
    ///   `end_of_stream`.
    ///
    /// `None`: end of stream. The actor stops itself and signals
    /// `end_of_stream`.
    async fn handle(
        &self,
        myself: ActorRef<Self::Msg>,
        message: Option<S::Item>,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        if let Some(item) = message {
            // Track failures by position rather than by id. This way two
            // targets that happen to share an id are not removed together, and
            // `get_id` is only called for targets that actually failed.
            let mut failed: HashSet<usize> = HashSet::new();
            for (position, target) in state.targets.iter().enumerate() {
                if let Err(err) = target.message_received(item.clone()) {
                    let id = target.get_id();
                    tracing::error!("Failed to send message to target {} with {err}", id);
                    state.callback.target_failed(id, err);
                    failed.insert(position);
                }
            }
            if state.stop_processing_target_on_failure {
                if !failed.is_empty() {
                    // `Vec::retain` visits each element exactly once in original
                    // order, so a running counter recovers each element's position.
                    let mut position = 0usize;
                    state.targets.retain(|_| {
                        let keep = !failed.contains(&position);
                        position += 1;
                        keep
                    });
                }
                // Checked on every message so an initially-empty target list
                // also halts processing.
                if state.targets.is_empty() {
                    tracing::debug!("Halting stream processing as no more targets exist");
                    myself.stop(None);
                    state.callback.end_of_stream();
                }
            }
        } else {
            myself.stop(Some("End of stream".to_string()));
            state.callback.end_of_stream();
            tracing::debug!("Reached end of stream");
        }
        Ok(())
    }

    /// Escalates the failure of a supervised child by failing this actor.
    ///
    /// All other supervision events are ignored.
    // NOTE(review): presumably the only supervised child is the stream pump — confirm.
    async fn handle_supervisor_evt(
        &self,
        _myself: ActorRef<Self::Msg>,
        message: SupervisionEvent,
        _state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        if let SupervisionEvent::ActorFailed(_who, what) = message {
            return Err(ractor::ActorErr::Failed(what).into());
        }
        Ok(())
    }
}