use std::collections::HashMap;
use std::marker::PhantomData;
use ractor::Actor;
use ractor::ActorId;
use ractor::ActorProcessingErr;
use ractor::ActorRef;
use ractor::RpcReplyPort;
#[cfg(test)]
mod tests;
/// A destination that a [`Broadcaster`] can forward messages of type `T` to.
///
/// The broadcaster interacts with a target only through this trait: it asks
/// for the target's id and hands it a clone of each broadcast message.
pub trait BroadcastTarget<T>: 'static + Send
where
T: ractor::Message + Clone,
{
/// Returns the [`ActorId`] identifying this target.
///
/// The broadcaster stores targets keyed by this id, so registering a second
/// target that reports the same id replaces the first.
fn id(&self) -> ActorId;
/// Delivers the message `t` to this target.
///
/// # Errors
///
/// Returns an error when the message could not be delivered. The
/// broadcaster logs such an error and, unless it was configured with
/// `continue_with_dead_targets`, returns it from its message handler.
fn send(&self, t: T) -> Result<(), ActorProcessingErr>;
}
/// Actor that forwards each broadcast message of type `T` to every registered
/// [`BroadcastTarget`].
///
/// The struct itself carries no data; the registered targets are kept in the
/// actor's state ([`BroadcasterState`]).
pub struct Broadcaster<T>
where
T: ractor::Message + Clone,
{
// Marker only: ties the type to `T` without storing a value of it.
_t: PhantomData<T>,
}
impl<T> Default for Broadcaster<T>
where
T: ractor::Message + Clone,
{
fn default() -> Self {
Self::new()
}
}
impl<T> Broadcaster<T>
where
T: ractor::Message + Clone,
{
pub fn new() -> Self {
Self { _t: PhantomData }
}
pub fn get_unit_mapped_target(who: ActorRef<T>) -> Box<dyn BroadcastTarget<T>>
where
T: ractor::Message + Clone,
{
struct IdTarget<T2>
where
T2: ractor::Message + Clone,
{
ar: ActorRef<T2>,
}
impl<T2> BroadcastTarget<T2> for IdTarget<T2>
where
T2: ractor::Message + Clone,
{
fn id(&self) -> ActorId {
self.ar.get_id()
}
fn send(&self, t: T2) -> Result<(), ActorProcessingErr> {
self.ar.cast(t)?;
Ok(())
}
}
Box::new(IdTarget::<T> { ar: who })
}
}
/// Start-up configuration handed to a [`Broadcaster`] as its actor arguments.
pub struct BroadcasterConfig<T>
where
    T: ractor::Message + Clone,
{
    /// Targets registered before the first message is handled.
    pub initial_targets: Vec<Box<dyn BroadcastTarget<T>>>,
    /// When `true`, a target whose send fails is logged and skipped; when
    /// `false`, the first failed send is returned as an error from the
    /// broadcaster's message handler.
    pub continue_with_dead_targets: bool,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add an
// implicit `T: Default` bound, although no field needs a default `T`. This
// impl yields the same value (no targets, `false`) for every message type.
impl<T> Default for BroadcasterConfig<T>
where
    T: ractor::Message + Clone,
{
    fn default() -> Self {
        Self {
            initial_targets: Vec::new(),
            continue_with_dead_targets: false,
        }
    }
}
// SAFETY: `Broadcaster<T>` has a single field, `PhantomData<T>`, so it never
// holds a value of type `T`. Sharing `&Broadcaster<T>` between threads therefore
// cannot expose any `T`, which keeps this impl sound even when `T` is not `Sync`.
unsafe impl<T> Sync for Broadcaster<T> where T: ractor::Message + Clone {}
/// Messages handled by the [`Broadcaster`] actor.
pub enum BroadcasterMessage<T>
where
T: ractor::Message + Clone,
{
/// Send a clone of the payload to every registered target.
Broadcast(T),
/// Register a target under the id it reports, replacing any target already
/// stored under that id.
AddTarget(Box<dyn BroadcastTarget<T>>),
/// Unregister the target stored under the given id; an unknown id is ignored.
RemoveTarget(ActorId),
/// Reply with the ids of all currently registered targets, in no particular
/// order.
ListTargets(RpcReplyPort<Vec<ActorId>>),
}
/// Runtime state of a [`Broadcaster`] actor, built in `pre_start` from the
/// start-up configuration.
#[doc(hidden)]
pub struct BroadcasterState<T>
where
T: ractor::Message + Clone,
{
/// Registered targets, keyed by the id each target reports.
targets: HashMap<ActorId, Box<dyn BroadcastTarget<T>>>,
/// When `true`, a failed send is logged and the broadcast continues; when
/// `false`, the error is returned from the message handler.
continue_with_dead_targets: bool,
}
#[cfg_attr(feature = "async-trait", async_trait::async_trait)]
impl<T> Actor for Broadcaster<T>
where
T: ractor::Message + Clone,
{
type Msg = BroadcasterMessage<T>;
type State = BroadcasterState<T>;
type Arguments = BroadcasterConfig<T>;
async fn pre_start(
&self,
_: ActorRef<Self::Msg>,
BroadcasterConfig {
continue_with_dead_targets,
initial_targets,
}: Self::Arguments,
) -> Result<Self::State, ActorProcessingErr> {
Ok(BroadcasterState {
targets: initial_targets
.into_iter()
.map(|target| (target.id(), target))
.collect(),
continue_with_dead_targets,
})
}
async fn handle(
&self,
_: ActorRef<Self::Msg>,
message: Self::Msg,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
BroadcasterMessage::Broadcast(t) => {
for (who, target) in state.targets.iter() {
if let Err(e) = target.send(t.clone()) {
tracing::error!("Error forwarding message to target {who}: {e}");
if !state.continue_with_dead_targets {
return Err(e);
}
} else {
tracing::debug!("Broadcast message to {who}");
}
}
}
BroadcasterMessage::AddTarget(target) => {
state.targets.insert(target.id(), target);
}
BroadcasterMessage::RemoveTarget(target) => {
state.targets.remove(&target);
}
BroadcasterMessage::ListTargets(reply) => {
let ids = state.targets.keys().cloned().collect::<Vec<_>>();
let _ = reply.send(ids);
}
}
Ok(())
}
}