use flume::{Receiver, Sender};
use futures_core::Stream;
use std::any::Any;
use std::{
cmp,
collections::HashMap,
future::Future,
sync::{Arc, RwLock},
time::Duration,
};
use tokio::{
task::{self, JoinSet},
time,
};
pub mod mini;
mod exit;
mod node;
mod pubsub;
mod req_res;
mod streams;
mod watch;
pub use exit::*;
pub use node::*;
pub use pubsub::PubSubError;
pub use req_res::*;
pub use streams::{SourceSet, Sources};
use crate::pubsub::PubSub;
use crate::watch::{NoWatch, OnErrTerminate, WatchFn};
/// Behaviour of an actor driven by this crate's runtime.
///
/// Each run of an actor goes through `init` → `sources` → a message loop that
/// calls `handle` → `exit`. Depending on its `Supervision`, a failed run may be
/// followed by a new run (restart) that reuses the same `Ctx`.
// The default method bodies below ignore their parameters.
#[allow(unused_variables)]
pub trait Actor: Sized + Send + 'static {
    /// Data given to the actor at spawn time; read it with `Ctx::props`.
    /// It lives in the `Ctx`, so it is kept across restarts.
    type Props: Send + 'static;
    /// Messages accepted by the actor. Mailbox messages, successful `Ctx::task`
    /// outputs and items from `sources` all arrive at `handle` as this type.
    type Msg: Send + 'static;
    /// Error returned by `init`, `sources`, `handle` and `Ctx::task` futures.
    type Err: Send + Sync + 'static;
    /// Builds the actor. Called at the start of every run, including restarts.
    /// An `Err` ends the run; the actor's `Supervision` decides whether to restart.
    fn init(ctx: &mut Ctx<Self>) -> impl Future<Output = Result<Self, Self::Err>> + Send;
    /// Called once at the end of every run, after the actor's children have been
    /// stopped. `this` is `None` at least when `init` failed (no instance exists);
    /// `reason` says why the run ended. The default does nothing.
    fn exit(
        this: Option<Self>,
        reason: ExitReason<Self>,
        ctx: &mut Ctx<Self>,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
    /// Extra message streams polled alongside the mailbox for the whole run.
    /// Called once per run, right after a successful `init`; an `Err` ends the
    /// run like an `init` failure. The default is an empty `SourceSet`.
    fn sources(
        &self,
        ctx: &Ctx<Self>,
    ) -> impl Future<Output = Result<impl Sources<Self>, Self::Err>> + Send {
        async { Ok(SourceSet::new()) }
    }
    /// Handles one message (from the mailbox, a finished `Ctx::task`, or a source).
    /// An `Err` is ignored under `Supervision::Resume`; otherwise it ends the run.
    /// The default accepts and discards every message.
    fn handle(
        &mut self,
        msg: Self::Msg,
        ctx: &mut Ctx<Self>,
    ) -> impl Future<Output = Result<(), Self::Err>> + Send {
        async { Ok(()) }
    }
}
/// Cloneable reference to an actor, used to send it messages and lifecycle
/// commands (stop / restart).
pub struct Handle<Msg> {
    /// Sending side of the actor's mailbox.
    msg_tx: Sender<Msg>,
    /// Sending side of the actor's control channel. It also serves as the
    /// parent channel its children use to report `ChildTerminated`.
    proc_msg_tx: Sender<ProcMsg>,
}
impl<Msg> Clone for Handle<Msg> {
fn clone(&self) -> Self {
Self {
msg_tx: self.msg_tx.clone(),
proc_msg_tx: self.proc_msg_tx.clone(),
}
}
}
impl<Msg> Handle<Msg> {
    /// Asks the actor to stop without restarting.
    ///
    /// Fire-and-forget: the receiving end of the acknowledgement channel is
    /// dropped immediately, so this returns without waiting for the actor to finish.
    pub fn stop(&self) {
        let (ack_tx, _) = flume::unbounded();
        let _ = self
            .proc_msg_tx
            .send(ProcMsg::FromHandle(ProcAction::Stop(ack_tx)));
    }

    /// Asks the actor to restart. Delivery failure (actor gone) is ignored.
    pub fn restart(&self) {
        let _ = self
            .proc_msg_tx
            .send(ProcMsg::FromHandle(ProcAction::Restart));
    }

    /// Returns `true` while the receiving end of the actor's mailbox still exists.
    pub fn is_alive(&self) -> bool {
        !self.msg_tx.is_disconnected()
    }

    /// Enqueues `msg` in the actor's mailbox. Returns `false` if the mailbox is closed.
    pub fn send<M: Into<Msg>>(&self, msg: M) -> bool {
        self.msg_tx.send(msg.into()).is_ok()
    }

    /// Sends `msg` after `duration` from a detached Tokio task.
    ///
    /// Must be called inside a Tokio runtime (`tokio::task::spawn` panics
    /// otherwise). A failed delivery is silently dropped.
    pub fn send_in<M>(&self, msg: M, duration: Duration)
    where
        Msg: 'static + Send,
        M: 'static + Send + Into<Msg>,
    {
        let msg_tx = self.msg_tx.clone();
        task::spawn(async move {
            time::sleep(duration).await;
            let _ = msg_tx.send(msg.into());
        });
    }

    /// Sends `req` wrapped via `From<Request<Req, Res>>` and awaits the response.
    ///
    /// # Errors
    /// Returns the `ReqErr` produced by the response channel if no response is received.
    pub async fn req<Req, Res>(&self, req: Req) -> Result<Res, ReqErr>
    where
        Msg: From<Request<Req, Res>>,
    {
        self.reqw(Msg::from, req).await
    }

    /// Like [`Handle::req`], but `to_req` builds the message from the request.
    ///
    /// `to_req` is called exactly once, so any `FnOnce` is accepted (for example
    /// a closure that moves captured data into the message).
    ///
    /// # Errors
    /// Returns the `ReqErr` produced by the response channel if no response is received.
    pub async fn reqw<F, Req, Res>(&self, to_req: F, req: Req) -> Result<Res, ReqErr>
    where
        F: FnOnce(Request<Req, Res>) -> Msg,
    {
        let (req, res) = req_res(req);
        self.send(to_req(req));
        res.recv().await
    }

    /// Like [`Handle::req`], but gives up after `timeout`.
    ///
    /// # Errors
    /// Returns the `ReqErr` produced by the response channel on failure or timeout.
    pub async fn req_timeout<Req, Res>(&self, req: Req, timeout: Duration) -> Result<Res, ReqErr>
    where
        Msg: From<Request<Req, Res>>,
    {
        self.reqw_timeout(Msg::from, req, timeout).await
    }

    /// Like [`Handle::reqw`], but gives up after `timeout`.
    ///
    /// # Errors
    /// Returns the `ReqErr` produced by the response channel on failure or timeout.
    pub async fn reqw_timeout<F, Req, Res>(
        &self,
        to_req: F,
        req: Req,
        timeout: Duration,
    ) -> Result<Res, ReqErr>
    where
        F: FnOnce(Request<Req, Res>) -> Msg,
    {
        let (req, res) = req_res(req);
        self.send(to_req(req));
        res.recv_timeout(timeout).await
    }
}
/// Per-actor runtime context passed to the `Actor` callbacks.
///
/// A restarted actor is respawned with the same `Ctx`, so its props, mailbox,
/// restart count and registry key carry over from one run to the next.
pub struct Ctx<P>
where
    P: Actor,
{
    /// Id of this actor among its parent's children (reported in `ChildTerminated`).
    id: u64,
    /// Props the actor was spawned with.
    props: P::Props,
    /// Handle to this actor itself.
    handle: Handle<P::Msg>,
    /// Receiving side of the mailbox.
    msg_rx: Receiver<P::Msg>,
    /// Parent's control channel; `None` when there is no parent.
    parent_proc_msg_tx: Option<Sender<ProcMsg>>,
    /// Receiving side of this actor's control channel.
    proc_msg_rx: Receiver<ProcMsg>,
    /// Control senders of the current children, keyed by child id.
    children_proc_msg_tx: HashMap<u64, Sender<ProcMsg>>,
    /// Policy applied when this actor fails.
    supervision: Supervision,
    /// Child counter: incremented on spawn, decremented on `ChildTerminated`,
    /// reset by `stop_children`; also used to derive new child ids.
    total_children: u64,
    /// Futures started with `Ctx::task`; their outputs are fed back to `handle`.
    tasks: JoinSet<Result<P::Msg, P::Err>>,
    /// Number of runs that ended with an error so far.
    restarts: u64,
    /// Name this actor's handle is registered under, if any.
    registry_key: Option<String>,
    /// Shared name → boxed `Handle` registry.
    registry: Arc<RwLock<HashMap<String, Box<dyn Any + Send + Sync>>>>,
    /// Shared pub/sub bus.
    pubsub: Arc<RwLock<PubSub>>,
    /// `(topic, subscriber id)` pairs removed from the bus when the actor exits.
    subscription_ids: Vec<(String, u64)>,
}
impl<P> Ctx<P>
where
    P: Actor,
{
    /// Returns the props this actor was spawned with.
    pub fn props(&self) -> &P::Props {
        &self.props
    }

    /// Returns a handle to this actor itself.
    pub fn this(&self) -> &Handle<P::Msg> {
        &self.handle
    }

    /// Discards every message currently waiting in the mailbox.
    pub fn clear_mailbox(&self) {
        drop(self.msg_rx.drain());
    }

    /// Starts configuring a child actor of type `Child` with the given props.
    pub fn actor<'a, Child>(&'a mut self, props: Child::Props) -> SpawnBuilder<'a, P, Child>
    where
        Child: Actor,
    {
        SpawnBuilder::new(self, props)
    }

    /// Asks every current child to restart, without waiting for them.
    pub fn restart_children(&self) {
        self.children_proc_msg_tx.values().for_each(|child_tx| {
            let _ = child_tx.send(ProcMsg::FromParent(ProcAction::Restart));
        });
    }

    /// Stops every child, waits until each has acknowledged, then forgets them all.
    pub async fn stop_children(&mut self) {
        let pending_acks: Vec<_> = self
            .children_proc_msg_tx
            .values()
            .map(|child_tx| {
                let (ack_tx, ack_rx) = flume::unbounded();
                let _ = child_tx.send(ProcMsg::FromParent(ProcAction::Stop(ack_tx)));
                ack_rx
            })
            .collect();
        // If a child is already gone, the send fails and the rejected message
        // (with its ack sender) is dropped, so the receive errors out instead of
        // hanging. Either outcome counts as "stopped".
        for ack_rx in pending_acks {
            let _ = ack_rx.recv_async().await;
        }
        self.children_proc_msg_tx.clear();
        self.total_children = 0;
    }

    /// Runs `f` as a background task of this actor; its output is passed to `handle`.
    pub fn task<F>(&mut self, f: F)
    where
        F: Future<Output = Result<P::Msg, P::Err>> + Send + 'static,
    {
        self.tasks.spawn(f);
    }

    /// Looks up the handle registered under the type name of actor `A`.
    pub fn get_handle_for<A: Actor>(&self) -> Result<Handle<A::Msg>, RegistryError> {
        self.get_handle::<A::Msg>(std::any::type_name::<A>())
    }

    /// Looks up the handle registered under `name` whose message type is `Msg`.
    ///
    /// # Errors
    /// `PoisonErr` if the registry lock is poisoned; `NotFound` if nothing is
    /// registered under `name` or the entry has a different message type.
    pub fn get_handle<Msg: Send + 'static>(
        &self,
        name: &str,
    ) -> Result<Handle<Msg>, RegistryError> {
        let registry = self.registry.read().map_err(|_| RegistryError::PoisonErr)?;
        match registry
            .get(name)
            .and_then(|entry| entry.downcast_ref::<Handle<Msg>>())
        {
            Some(found) => Ok(found.clone()),
            None => Err(RegistryError::NotFound(name.to_string())),
        }
    }

    /// Sends `msg` to the actor registered under the type name of `A`.
    pub fn send<A: Actor>(&self, msg: impl Into<A::Msg>) -> Result<(), RegistryError> {
        self.send_to::<A::Msg>(std::any::type_name::<A>(), msg)
    }

    /// Sends `msg` to the actor registered under `name`.
    ///
    /// Returns `Ok(())` once a matching handle is found, whether or not the
    /// mailbox accepted the message.
    pub fn send_to<Msg: Send + 'static>(
        &self,
        name: &str,
        msg: impl Into<Msg>,
    ) -> Result<(), RegistryError> {
        let registry = self.registry.read().map_err(|_| RegistryError::PoisonErr)?;
        let target = registry
            .get(name)
            .and_then(|entry| entry.downcast_ref::<Handle<Msg>>())
            .ok_or_else(|| RegistryError::NotFound(name.to_string()))?;
        target.send(msg);
        Ok(())
    }
}
/// Control message on an actor's control channel. The run loop polls this
/// channel before the mailbox, tasks and sources (`biased` select).
// NOTE(review): the variants no longer share a common prefix/suffix, so this
// allow may be stale — verify with clippy before removing.
#[allow(clippy::enum_variant_names)]
#[derive(Debug)]
enum ProcMsg {
    /// A child with this id has terminated and will not be restarted.
    ChildTerminated {
        child_id: u64,
    },
    /// Lifecycle command sent by the actor's parent.
    FromParent(ProcAction),
    /// Lifecycle command sent through a `Handle`.
    FromHandle(ProcAction),
}
/// Lifecycle command carried by `ProcMsg::FromParent` / `ProcMsg::FromHandle`.
#[derive(Debug)]
enum ProcAction {
    /// End the current run and start a new one with no backoff delay.
    Restart,
    /// End the run for good; `()` is sent on the sender after the actor's
    /// children, `exit` and pub/sub cleanup have finished.
    Stop(Sender<()>),
}
/// Runs one incarnation ("run") of an actor on a new Tokio task.
///
/// After an optional `delay` (restart backoff), the task drives the actor
/// through `init` → `sources` → the message loop → `exit`. Control messages are
/// polled first (`biased`), then the mailbox, finished tasks and sources.
///
/// When the run ends:
/// - children are stopped;
/// - `exit` receives the actor instance whenever one was created;
/// - pub/sub subscriptions are removed.
///
/// If the actor is not restarting, its registry name is released *before* any
/// stop acknowledgement is sent. A parent that awaits `stop_children` can
/// therefore re-register the same name immediately. Finally, the actor is
/// either respawned with the same `Ctx` or reported to its parent as terminated.
fn spawn<Child, W>(mut ctx: Ctx<Child>, delay: Option<Duration>, watch: W)
where
    Child: Actor,
    W: OnErrTerminate<Child::Err>,
{
    tokio::spawn(async move {
        if let Some(d) = delay.filter(|d| !d.is_zero()) {
            time::sleep(d).await;
        }
        let mut restart = Restart::No;
        let mut exit_reason = None;
        let mut actor_created = None;
        let mut stop_ack_tx = None;
        match Child::init(&mut ctx).await {
            Err(e) => {
                exit_reason = Some(ExitReason::Err(e));
                restart = Restart::from_supervision(ctx.supervision, ctx.restarts);
            }
            Ok(mut actor) => match actor.sources(&ctx).await {
                Err(e) => {
                    exit_reason = Some(ExitReason::Err(e));
                    restart = Restart::from_supervision(ctx.supervision, ctx.restarts);
                    actor_created = Some(actor);
                }
                Ok(mut sources) => {
                    // Shared failure path for `handle` errors and failed tasks:
                    // `Resume` skips the error, anything else ends the loop.
                    macro_rules! on_err {
                        ($e:expr) => {
                            if let Supervision::Resume = ctx.supervision {
                                continue;
                            }
                            restart = Restart::from_supervision(ctx.supervision, ctx.restarts);
                            exit_reason = Some(ExitReason::Err($e));
                            break;
                        };
                    }
                    loop {
                        tokio::select! {
                            biased;
                            proc_msg = ctx.proc_msg_rx.recv_async() => {
                                match proc_msg {
                                    Err(_) => break,
                                    Ok(ProcMsg::FromHandle(ProcAction::Stop(tx))) => {
                                        exit_reason = Some(ExitReason::Handle);
                                        stop_ack_tx = Some(tx);
                                        break
                                    },
                                    Ok(ProcMsg::FromParent(ProcAction::Stop(tx))) => {
                                        exit_reason = exit_reason.or(Some(ExitReason::Parent));
                                        stop_ack_tx = Some(tx);
                                        break
                                    },
                                    Ok(ProcMsg::FromParent(ProcAction::Restart)) => {
                                        exit_reason = exit_reason.or(Some(ExitReason::Parent));
                                        restart = Restart::In(Duration::ZERO);
                                        break;
                                    }
                                    Ok(ProcMsg::FromHandle(ProcAction::Restart)) => {
                                        exit_reason = exit_reason.or(Some(ExitReason::Handle));
                                        restart = Restart::In(Duration::ZERO);
                                        break;
                                    }
                                    Ok(ProcMsg::ChildTerminated { child_id }) => {
                                        if ctx.children_proc_msg_tx.remove(&child_id).is_some() {
                                            ctx.total_children -= 1;
                                        }
                                    }
                                }
                            }
                            recvd = ctx.msg_rx.recv_async() => {
                                match recvd {
                                    Err(_) => break,
                                    Ok(msg) => {
                                        if let Err(e) = actor.handle(msg, &mut ctx).await {
                                            on_err!(e);
                                        };
                                    }
                                }
                            }
                            // A task that panicked or was cancelled (`Err(JoinError)`)
                            // does not match this pattern and is dropped silently.
                            Some(Ok(msg)) = ctx.tasks.join_next() => {
                                match msg {
                                    Err(e) => {
                                        on_err!(e);
                                    }
                                    Ok(msg) => {
                                        if let Err(e) = actor.handle(msg, &mut ctx).await {
                                            on_err!(e);
                                        };
                                    }
                                }
                            }
                            Some(msg) = std::future::poll_fn(|cx| std::pin::Pin::new(&mut sources).poll_next(cx)) => {
                                if let Err(e) = actor.handle(msg, &mut ctx).await {
                                    on_err!(e);
                                };
                            }
                        }
                    }
                    // Every way out of the loop (stop, restart, error, closed
                    // channel) hands the live instance to `exit`. Release the
                    // sources first.
                    drop(sources);
                    actor_created = Some(actor);
                }
            },
        }
        // Children never outlive a run of their parent, restarts included.
        ctx.stop_children().await;
        let exit_reason = exit_reason.unwrap_or(ExitReason::Handle);
        if let ExitReason::Err(_) = &exit_reason {
            ctx.restarts += 1;
        }
        if let (Restart::No, ExitReason::Err(e)) = (&restart, &exit_reason) {
            watch.on_err_terminate(e);
        }
        Child::exit(actor_created, exit_reason, &mut ctx).await;
        if !ctx.subscription_ids.is_empty() {
            if let Ok(mut bus) = ctx.pubsub.write() {
                for (topic, sub_id) in ctx.subscription_ids.drain(..) {
                    if let Some(entry) = bus.topics.get_mut(&topic) {
                        entry.subscribers.retain(|s| s.id != sub_id);
                        if entry.subscribers.is_empty() {
                            bus.topics.remove(&topic);
                        }
                    }
                }
            }
        }
        let restart_delay = match restart {
            Restart::In(duration) => Some(duration),
            Restart::No => None,
        };
        if restart_delay.is_none() {
            // Release the name before acknowledging a stop. Otherwise a parent
            // that awaits the ack and re-spawns under the same name could race
            // this cleanup and get `NameTaken`. This applies to every terminating
            // actor, not only ones with a parent.
            if let Some(key) = ctx.registry_key.take() {
                if let Ok(mut reg) = ctx.registry.write() {
                    reg.remove(&key);
                }
            }
        }
        if let Some(tx) = stop_ack_tx {
            let _ = tx.send(());
        }
        match restart_delay {
            Some(duration) => spawn::<Child, W>(ctx, Some(duration), watch),
            None => {
                if let Some(parent_tx) = &ctx.parent_proc_msg_tx {
                    let _ = parent_tx.send(ProcMsg::ChildTerminated { child_id: ctx.id });
                }
            }
        }
    });
}
/// What happens when an actor fails. A failure is an `Err` from `init`,
/// `sources`, `handle`, or a `Ctx::task` future.
#[derive(Debug, Clone, Copy)]
pub enum Supervision {
    /// Terminate the actor on any failure.
    Stop,
    /// Skip errors from `handle` and tasks and keep running. A failing `init`
    /// or `sources` still terminates the actor.
    Resume,
    /// Restart the actor after `backoff` until `max` is reached.
    ///
    /// NOTE(review): failures are counted from 1, and the failure whose number
    /// equals `Limit::Amount(n)` is not restarted. `Amount(n)` therefore allows at
    /// most `n - 1` restarts — confirm this is intended.
    Restart { max: Limit, backoff: Backoff },
}
/// Delay before restarting a failed actor under `Supervision::Restart`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Backoff {
    /// Restart immediately.
    None,
    /// Always wait this long.
    Static(Duration),
    /// Wait `step` × (failure number, counting the current failure). The result
    /// is capped at `max` and then raised to at least `min`.
    Incremental {
        min: Duration,
        max: Duration,
        step: Duration,
    },
}
/// Restart cap for `Supervision::Restart`, compared against failure counts via
/// `PartialEq<u64>`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Limit {
    /// Unlimited: never equal to any count.
    None,
    /// Stop restarting when the failure number reaches this value.
    Amount(u64),
}
impl From<u64> for Limit {
fn from(value: u64) -> Self {
match value {
0 => Limit::None,
v => Limit::Amount(v),
}
}
}
impl PartialEq<u64> for Limit {
fn eq(&self, other: &u64) -> bool {
match self {
Limit::None => false,
Limit::Amount(n) => n == other,
}
}
}
/// Errors from the shared actor-name registry.
#[derive(Debug, Clone)]
pub enum RegistryError {
    /// A handle is already registered under this name.
    NameTaken(String),
    /// Nothing is registered under this name, or the entry's message type differs.
    NotFound(String),
    /// The registry lock was poisoned.
    PoisonErr,
}
impl std::fmt::Display for RegistryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RegistryError::NameTaken(name) => write!(f, "registry name already taken: {name}"),
RegistryError::NotFound(name) => write!(f, "no actor registered under: {name}"),
RegistryError::PoisonErr => write!(f, "registry lock poisoned"),
}
}
}
impl std::error::Error for RegistryError {}
/// Builder returned by `Ctx::actor` for configuring and spawning a child actor.
pub struct SpawnBuilder<'a, Parent, Child, W = NoWatch>
where
    Parent: Actor,
    Child: Actor,
{
    /// Context of the parent the child will belong to.
    ctx: &'a mut Ctx<Parent>,
    /// Props handed to the child.
    props: Child::Props,
    /// Supervision policy for the child.
    supervision: Supervision,
    /// Called when the child terminates with an error and will not be restarted.
    watch: W,
    /// Registry name for the child; set by `spawn_named`.
    registry_key: Option<String>,
}
impl<'a, Parent, Child> SpawnBuilder<'a, Parent, Child, NoWatch>
where
Parent: Actor,
Child: Actor,
{
fn new(ctx: &'a mut Ctx<Parent>, props: Child::Props) -> Self {
Self {
ctx,
props,
supervision: Supervision::Restart {
max: Limit::None,
backoff: Backoff::None,
},
watch: NoWatch,
registry_key: None,
}
}
}
impl<'a, Parent, Child, W> SpawnBuilder<'a, Parent, Child, W>
where
    Parent: Actor,
    Child: Actor,
    W: OnErrTerminate<Child::Err>,
{
    /// Sets the supervision policy applied when the child fails.
    pub fn supervision(mut self, supervision: Supervision) -> Self {
        self.supervision = supervision;
        self
    }

    /// Installs `f` as the child's terminal-error watcher.
    ///
    /// `f` maps the child's error to a parent message. The watcher is given the
    /// parent's mailbox sender and is invoked when the child ends with an error
    /// that will not be retried.
    pub fn watch<F>(self, f: F) -> SpawnBuilder<'a, Parent, Child, WatchFn<F, Parent::Msg>>
    where
        F: Fn(&Child::Err) -> Parent::Msg + Send + 'static,
    {
        let parent_msg_tx = self.ctx.handle.msg_tx.clone();
        SpawnBuilder {
            ctx: self.ctx,
            props: self.props,
            supervision: self.supervision,
            watch: WatchFn { f, parent_msg_tx },
            registry_key: self.registry_key,
        }
    }

    /// Spawns the child on a new Tokio task and returns its handle.
    pub fn spawn(self) -> Handle<Child::Msg> {
        let (msg_tx, msg_rx) = flume::unbounded();
        let (proc_msg_tx, proc_msg_rx) = flume::unbounded();
        let handle = Handle {
            msg_tx,
            proc_msg_tx,
        };

        // `total_children` goes down whenever a child terminates, so
        // `total_children + 1` may still belong to a live sibling. Reusing it
        // would overwrite that sibling's control sender, and the sibling would no
        // longer be reached by `stop_children`/`restart_children`. Probe forward
        // to the first id not currently in use.
        let mut id = self.ctx.total_children + 1;
        while self.ctx.children_proc_msg_tx.contains_key(&id) {
            id += 1;
        }
        self.ctx.total_children += 1;
        self.ctx
            .children_proc_msg_tx
            .insert(id, handle.proc_msg_tx.clone());

        let ctx: Ctx<Child> = Ctx {
            id,
            props: self.props,
            handle: handle.clone(),
            msg_rx,
            parent_proc_msg_tx: Some(self.ctx.handle.proc_msg_tx.clone()),
            proc_msg_rx,
            children_proc_msg_tx: HashMap::new(),
            total_children: 0,
            supervision: self.supervision,
            restarts: 0,
            tasks: JoinSet::new(),
            registry_key: self.registry_key,
            registry: self.ctx.registry.clone(),
            pubsub: self.ctx.pubsub.clone(),
            subscription_ids: Vec::new(),
        };
        spawn::<Child, W>(ctx, None, self.watch);
        handle
    }

    /// Spawns the child and registers it under the type name of `Child`.
    ///
    /// # Errors
    /// See [`SpawnBuilder::spawn_named`].
    pub fn spawn_registered(self) -> Result<Handle<Child::Msg>, RegistryError> {
        let key = std::any::type_name::<Child>();
        self.spawn_named(key)
    }

    /// Spawns the child and registers its handle under `name`.
    ///
    /// # Errors
    /// `NameTaken` if `name` is already registered (nothing is spawned);
    /// `PoisonErr` if the registry lock is poisoned.
    pub fn spawn_named(
        mut self,
        name: impl Into<String>,
    ) -> Result<Handle<Child::Msg>, RegistryError> {
        let name = name.into();
        let registry = self.ctx.registry.clone();
        // Hold the write lock across the spawn so checking and claiming the
        // name happen atomically.
        let mut reg = registry.write().map_err(|_| RegistryError::PoisonErr)?;
        if reg.contains_key(&name) {
            return Err(RegistryError::NameTaken(name));
        }
        self.registry_key = Some(name.clone());
        let handle = self.spawn();
        reg.insert(name, Box::new(handle.clone()));
        Ok(handle)
    }
}
/// Restart decision computed from an actor's `Supervision` when a run ends.
#[derive(Debug)]
enum Restart {
    /// Do not restart; the actor terminates.
    No,
    /// Respawn the actor after this delay.
    In(Duration),
}
impl Restart {
    /// Decides whether, and after how long, to restart an actor that just failed.
    ///
    /// `current_restarts` is the number of failures recorded before this one,
    /// so this is failure number `current_restarts + 1`.
    ///
    /// - `Stop` and `Resume` never restart.
    /// - `Restart` stops restarting once the failure number equals `max`.
    ///   Otherwise it waits according to `backoff`.
    fn from_supervision(supervision: Supervision, current_restarts: u64) -> Self {
        let attempt = current_restarts.saturating_add(1);
        match supervision {
            Supervision::Stop | Supervision::Resume => Restart::No,
            Supervision::Restart { max, .. } if max == attempt => Restart::No,
            Supervision::Restart { backoff, .. } => Restart::In(match backoff {
                Backoff::None => Duration::ZERO,
                Backoff::Static(duration) => duration,
                Backoff::Incremental { min, max, step } => {
                    // Exact integer multiply that saturates. `Duration::mul_f64`
                    // rounds through f64 and panics on overflow, which would kill
                    // the actor task.
                    let factor = u32::try_from(attempt).unwrap_or(u32::MAX);
                    let wait = cmp::min(max, step.saturating_mul(factor));
                    cmp::max(min, wait)
                }
            }),
        }
    }
}