use std::{
any::Any,
boxed,
collections::HashMap,
future::Future,
io,
pin::Pin,
sync::{
Arc, Mutex, RwLock,
atomic::{AtomicU64, Ordering},
},
};
#[allow(unused_imports)]
use futures_util::{
FutureExt,
future::{self, Either},
pin_mut,
};
/// Sending half of the channels used throughout this module (alias for `flume::Sender`).
pub type Sender<T> = flume::Sender<T>;
/// Receiving half of the channels used throughout this module (alias for `flume::Receiver`).
pub type Receiver<T> = flume::Receiver<T>;
/// Re-export of `futures_executor::block_on`; `Handle::spawn_blocking` uses it to drive
/// each action's future on the actor thread.
pub use futures_executor::block_on;
/// Error types a `Handle` can produce on its own.
///
/// Besides the errors returned by user actions, a handle must report failures that
/// originate in the actor machinery itself: the actor having stopped, or a panic caught
/// inside an action or the run loop. Those are described by a human-readable message,
/// and this trait turns that message into `Self`.
pub trait ActorError: Sized + Send + 'static {
/// Builds an error value from an internally generated, human-readable message.
fn from_actor_message(msg: String) -> Self;
}
/// `std::io::Error` as an actor error: each message becomes an
/// `io::ErrorKind::Other` error whose `Display` output is the message itself.
impl ActorError for io::Error {
    fn from_actor_message(msg: String) -> Self {
        Self::other(msg)
    }
}
/// Lifecycle state of an actor as observed through its handle.
///
/// A freshly spawned actor is `Running`. It becomes `Stopped` once its run loop has
/// exited, whether normally or by panicking, and never goes back.
///
/// The enum has no fields, so it is `Copy` and `Hash`. Callers can pass it by value and
/// use it as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ActorState {
    /// The run loop has not exited yet; new calls are accepted.
    #[default]
    Running,
    /// The run loop has exited; new calls are rejected.
    Stopped,
}
/// `anyhow::Error` as an actor error: each message becomes an ad-hoc error whose
/// `Display` output is the message itself.
#[cfg(feature = "anyhow")]
impl ActorError for anyhow::Error {
    fn from_actor_message(msg: String) -> Self {
        anyhow::Error::msg(msg)
    }
}
/// Plain `String` errors: the internally generated message is used verbatim.
impl ActorError for String {
    fn from_actor_message(msg: String) -> Self {
        msg
    }
}
/// Boxed trait-object errors: the message is wrapped in an `io::Error` of kind `Other`,
/// so callers can still downcast the box to `io::Error`.
impl ActorError for Box<dyn std::error::Error + Send + Sync> {
    fn from_actor_message(msg: String) -> Self {
        Self::from(io::Error::other(msg))
    }
}
/// Unboxed, type-erased future that is `Send` and may borrow data for `'a`.
pub type PreBoxActorFut<'a, T> = dyn Future<Output = T> + Send + 'a;
/// Pinned, heap-allocated `PreBoxActorFut`; the future type returned by actor closures.
pub type ActorFut<'a, T> = Pin<boxed::Box<PreBoxActorFut<'a, T>>>;
/// A queued unit of work. Given exclusive access to the actor for `'a`, it produces a
/// future (which may borrow the actor) that the run loop awaits to completion.
pub type Action<A> = Box<dyn for<'a> FnOnce(&'a mut A) -> ActorFut<'a, ()> + Send + 'static>;
/// What `Handle::base_call` hands back after enqueueing a call, or the error explaining
/// why nothing was enqueued.
type BaseCallResult<R, E> = Result<
(
// Receives the action's result (or the error built from a panic in the action).
Receiver<Result<R, E>>,
// Receives `()` when the actor stops before the call was answered.
Receiver<()>,
// Key of this call's entry in the pending-cancel map.
u64,
// Caller location embedded in error messages.
&'static std::panic::Location<'static>,
),
E,
>;
/// Cancellation senders for calls still awaiting a reply, keyed by call id. Shared by
/// every handle clone and by the actor task, which drains it on exit.
type PendingCancelMap = Arc<Mutex<HashMap<u64, Sender<()>>>>;
/// Cancels every call still waiting for a reply and empties the pending table.
///
/// Called when an actor's run loop exits. Each waiting caller selects on its cancel
/// channel, so the `()` sent here makes it return an "actor stopped" error instead of
/// waiting for a reply that may never come.
///
/// A poisoned mutex is recovered rather than skipped. The map is only ever inserted
/// into, removed from, or drained, so its contents stay usable. Skipping it would leave
/// the registered callers without their cancellation signal.
fn fail_pending_calls(pending: &PendingCancelMap) {
    let mut pending = pending
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    for (_, cancel_tx) in pending.drain() {
        // A closed channel means the caller already stopped waiting; nothing to do.
        let _ = cancel_tx.send(());
    }
}
/// Type-erases a future that already yields `Result<T, E>` into an `ActorFut`.
///
/// Support function for the `act!` macro; the future is boxed and pinned unchanged.
#[doc(hidden)]
pub fn into_actor_fut_res<'a, Fut, T, E>(fut: Fut) -> ActorFut<'a, Result<T, E>>
where
    Fut: Future<Output = Result<T, E>> + Send + 'a,
    T: Send + 'a,
{
    Box::into_pin(Box::new(fut))
}
/// Type-erases an infallible future into an `ActorFut` whose output is always `Ok`.
///
/// Support function for the `act_ok!` macro: the future's value is wrapped in `Ok`.
#[doc(hidden)]
pub fn into_actor_fut_ok<'a, Fut, T, E>(fut: Fut) -> ActorFut<'a, Result<T, E>>
where
    Fut: Future<Output = T> + Send + 'a,
    T: Send + 'a,
    E: ActorError,
{
    let wrapped = async move {
        let value = fut.await;
        Ok(value)
    };
    Box::into_pin(Box::new(wrapped))
}
/// Builds an actor closure from `actor => EXPR`, where `EXPR` is a future yielding
/// `Result<T, E>`.
///
/// `act!(a => EXPR)` expands to `move |a| into_actor_fut_res((EXPR))`, so the future is
/// boxed and pinned into an `ActorFut` suitable for `Handle::call`/`call_blocking`.
///
/// NOTE(review): a `{ ... }` block also matches `$expr:expr`, so the first arm appears to
/// take every block and the `$body:block` arm looks unreachable. Confirm before editing.
#[macro_export]
macro_rules! act {
($actor:ident => $expr:expr) => {{ move |$actor| $crate::into_actor_fut_res(($expr)) }};
($actor:ident => $body:block) => {{ move |$actor| $crate::into_actor_fut_res($body) }};
}
/// Like `act!`, but for infallible futures: `EXPR` yields a plain `T`, which is wrapped
/// in `Ok`.
///
/// `act_ok!(a => EXPR)` expands to `move |a| into_actor_fut_ok((EXPR))`.
///
/// NOTE(review): as with `act!`, the `$body:block` arm appears shadowed by the
/// `$expr:expr` arm. Confirm before editing.
#[macro_export]
macro_rules! act_ok {
($actor:ident => $expr:expr) => {{ move |$actor| $crate::into_actor_fut_ok(($expr)) }};
($actor:ident => $body:block) => {{ move |$actor| $crate::into_actor_fut_ok($body) }};
}
/// Extracts a readable message from a caught panic payload.
///
/// `panic!` with a literal produces a `&'static str` payload, and a formatted `panic!`
/// produces a `String`. Any other payload type yields `"unknown panic"`.
fn panic_payload_message(panic_payload: Box<dyn Any + Send>) -> String {
    match panic_payload.downcast::<&str>() {
        Ok(text) => (*text).to_string(),
        Err(not_str) => match not_str.downcast::<String>() {
            Ok(owned) => *owned,
            Err(_) => "unknown panic".to_string(),
        },
    }
}
/// Converts a panic caught around an actor's whole run loop into an `E`.
///
/// The resulting message is `"panic in actor loop: <panic message>"`.
fn actor_loop_panic<E: ActorError>(panic_payload: Box<dyn Any + Send>) -> E {
    let detail = panic_payload_message(panic_payload);
    E::from_actor_message(format!("panic in actor loop: {detail}"))
}
/// Cloneable handle for sending work to a spawned actor of type `A`.
///
/// All clones share the same mailbox, lifecycle state and pending-call table, and clones
/// compare equal when they refer to the same actor. `E` is the error type used for
/// failures the handle reports itself (see `ActorError`).
#[derive(Debug)]
pub struct Handle<A, E>
where
A: Send + 'static,
E: ActorError,
{
/// Mailbox sender; set to `None` by `shutdown`.
tx: Arc<Mutex<Option<Sender<Action<A>>>>>,
/// Lifecycle state; the actor task sets it to `Stopped` when its loop exits.
state: Arc<RwLock<ActorState>>,
/// Cancellation senders for in-flight calls, keyed by call id.
pending: PendingCancelMap,
/// Source of call ids for `pending`.
next_call_id: Arc<AtomicU64>,
/// Disconnects when the actor task drops its end; awaited by `wait_stopped*`.
stopped_rx: Receiver<()>,
_phantom: std::marker::PhantomData<E>,
}
/// Cheap clone: the new handle shares the same mailbox, state, pending-call table,
/// call-id counter and stop signal as `self`.
///
/// Written by hand because `#[derive(Clone)]` would needlessly require `A: Clone` and
/// `E: Clone`.
impl<A, E> Clone for Handle<A, E>
where
    A: Send + 'static,
    E: ActorError,
{
    fn clone(&self) -> Self {
        let Self {
            tx,
            state,
            pending,
            next_call_id,
            stopped_rx,
            _phantom: _,
        } = self;
        Self {
            tx: Arc::clone(tx),
            state: Arc::clone(state),
            pending: Arc::clone(pending),
            next_call_id: Arc::clone(next_call_id),
            stopped_rx: stopped_rx.clone(),
            _phantom: std::marker::PhantomData,
        }
    }
}
/// Two handles are equal exactly when they control the same actor instance. Identity is
/// decided by the address of the shared state allocation, which every clone shares.
impl<A, E> PartialEq for Handle<A, E>
where
    A: Send + 'static,
    E: ActorError,
{
    fn eq(&self, other: &Self) -> bool {
        let mine = Arc::as_ptr(&self.state);
        let theirs = Arc::as_ptr(&other.state);
        std::ptr::eq(mine, theirs)
    }
}
// Handle equality is pointer identity of the shared state (`Arc::ptr_eq`), which is
// reflexive, symmetric and transitive, so it is a full equivalence relation.
impl<A, E> Eq for Handle<A, E>
where
A: Send + 'static,
E: ActorError,
{
}
impl<A, E> Handle<A, E>
where
A: Send + 'static,
E: ActorError,
{
pub fn state(&self) -> ActorState {
self.state.read().expect("poisned lock").clone()
}
#[cfg(all(feature = "tokio", not(feature = "async-std")))]
pub fn spawn(actor: A) -> (Self, tokio::task::JoinHandle<Result<(), E>>)
{
let (tx, rx) = flume::unbounded::<Action<A>>();
let state = Arc::new(RwLock::new(ActorState::default()));
let pending = Arc::new(Mutex::new(HashMap::new()));
let next_call_id = Arc::new(AtomicU64::new(0));
let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
let join_handle = {
let state = Arc::clone(&state);
let pending = Arc::clone(&pending);
tokio::task::spawn(async move {
let _stopped_signal = stopped_tx;
let mut actor = actor;
let res = std::panic::AssertUnwindSafe(async {
while let Ok(action) = rx.recv_async().await {
action(&mut actor).await;
}
Ok::<(), E>(())
})
.catch_unwind()
.await;
if let Ok(mut st) = state.write() {
*st = ActorState::Stopped;
}
fail_pending_calls(&pending);
match res {
Ok(result) => result,
Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
}
})
};
(
Self {
tx: Arc::new(Mutex::new(Some(tx))),
state,
pending,
next_call_id,
stopped_rx,
_phantom: std::marker::PhantomData,
},
join_handle,
)
}
#[cfg(all(feature = "tokio", not(feature = "async-std")))]
pub fn spawn_with<F, Fut>(actor: A, run: F) -> (Self, tokio::task::JoinHandle<Result<(), E>>)
where
F: FnOnce(A, Receiver<Action<A>>) -> Fut + Send + 'static,
Fut: Future<Output = Result<(), E>> + Send,
{
let (tx, rx) = flume::unbounded();
let state = Arc::new(RwLock::new(ActorState::default()));
let pending = Arc::new(Mutex::new(HashMap::new()));
let next_call_id = Arc::new(AtomicU64::new(0));
let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
let join_handle = {
let state = Arc::clone(&state);
let pending = Arc::clone(&pending);
tokio::task::spawn(async move {
let _stopped_signal = stopped_tx;
let res = std::panic::AssertUnwindSafe(run(actor, rx))
.catch_unwind()
.await;
if let Ok(mut st) = state.write() {
*st = ActorState::Stopped;
}
fail_pending_calls(&pending);
match res {
Ok(result) => result,
Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
}
})
};
(
Self {
tx: Arc::new(Mutex::new(Some(tx))),
state,
pending,
next_call_id,
stopped_rx,
_phantom: std::marker::PhantomData,
},
join_handle,
)
}
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
pub fn spawn(actor: A) -> (Self, async_std::task::JoinHandle<Result<(), E>>)
{
let (tx, rx) = flume::unbounded::<Action<A>>();
let state = Arc::new(RwLock::new(ActorState::default()));
let pending = Arc::new(Mutex::new(HashMap::new()));
let next_call_id = Arc::new(AtomicU64::new(0));
let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
let join_handle = {
let state = Arc::clone(&state);
let pending = Arc::clone(&pending);
async_std::task::spawn(async move {
let _stopped_signal = stopped_tx;
let mut actor = actor;
let res = std::panic::AssertUnwindSafe(async {
while let Ok(action) = rx.recv_async().await {
action(&mut actor).await;
}
Ok::<(), E>(())
})
.catch_unwind()
.await;
if let Ok(mut st) = state.write() {
*st = ActorState::Stopped;
}
fail_pending_calls(&pending);
match res {
Ok(result) => result,
Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
}
})
};
(
Self {
tx: Arc::new(Mutex::new(Some(tx))),
state,
pending,
next_call_id,
stopped_rx,
_phantom: std::marker::PhantomData,
},
join_handle,
)
}
#[cfg(all(feature = "async-std", not(feature = "tokio")))]
pub fn spawn_with<F, Fut>(actor: A, run: F) -> (Self, async_std::task::JoinHandle<Result<(), E>>)
where
F: FnOnce(A, Receiver<Action<A>>) -> Fut + Send + 'static,
Fut: Future<Output = Result<(), E>> + Send,
{
let (tx, rx) = flume::unbounded();
let state = Arc::new(RwLock::new(ActorState::default()));
let pending = Arc::new(Mutex::new(HashMap::new()));
let next_call_id = Arc::new(AtomicU64::new(0));
let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
let join_handle = {
let state = Arc::clone(&state);
let pending = Arc::clone(&pending);
async_std::task::spawn(async move {
let _stopped_signal = stopped_tx;
let res = std::panic::AssertUnwindSafe(run(actor, rx))
.catch_unwind()
.await;
if let Ok(mut st) = state.write() {
*st = ActorState::Stopped;
}
fail_pending_calls(&pending);
match res {
Ok(result) => result,
Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
}
})
};
(
Self {
tx: Arc::new(Mutex::new(Some(tx))),
state,
pending,
next_call_id,
stopped_rx,
_phantom: std::marker::PhantomData,
},
join_handle,
)
}
pub fn spawn_blocking(actor: A) -> (Self, std::thread::JoinHandle<Result<(), E>>)
{
let (tx, rx) = flume::unbounded::<Action<A>>();
let state = Arc::new(RwLock::new(ActorState::default()));
let pending = Arc::new(Mutex::new(HashMap::new()));
let next_call_id = Arc::new(AtomicU64::new(0));
let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
let join_handle = {
let state = Arc::clone(&state);
let pending = Arc::clone(&pending);
std::thread::spawn(move || {
let _stopped_signal = stopped_tx;
let mut actor = actor;
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
while let Ok(action) = rx.recv() {
block_on(action(&mut actor));
}
Ok::<(), E>(())
}));
if let Ok(mut st) = state.write() {
*st = ActorState::Stopped;
}
fail_pending_calls(&pending);
match res {
Ok(result) => result,
Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
}
})
};
(
Self {
tx: Arc::new(Mutex::new(Some(tx))),
state,
pending,
next_call_id,
stopped_rx,
_phantom: std::marker::PhantomData,
},
join_handle,
)
}
pub fn spawn_blocking_with<F>(actor: A, run: F) -> (Self, std::thread::JoinHandle<Result<(), E>>)
where
F: FnOnce(A, Receiver<Action<A>>) -> Result<(), E> + Send + 'static,
{
let (tx, rx) = flume::unbounded();
let state = Arc::new(RwLock::new(ActorState::default()));
let pending = Arc::new(Mutex::new(HashMap::new()));
let next_call_id = Arc::new(AtomicU64::new(0));
let (stopped_tx, stopped_rx) = flume::bounded::<()>(1);
let join_handle = {
let state = Arc::clone(&state);
let pending = Arc::clone(&pending);
std::thread::spawn(move || {
let _stopped_signal = stopped_tx;
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
run(actor, rx)
}));
if let Ok(mut st) = state.write() {
*st = ActorState::Stopped;
}
fail_pending_calls(&pending);
match res {
Ok(result) => result,
Err(panic_payload) => Err(actor_loop_panic(panic_payload)),
}
})
};
(
Self {
tx: Arc::new(Mutex::new(Some(tx))),
state,
pending,
next_call_id,
stopped_rx,
_phantom: std::marker::PhantomData,
},
join_handle,
)
}
fn base_call<R, F>(&self, f: F) -> BaseCallResult<R, E>
where
F: for<'a> FnOnce(&'a mut A) -> ActorFut<'a, Result<R, E>> + Send + 'static,
R: Send + 'static,
{
if self.state() != ActorState::Running {
return Err(E::from_actor_message(
"actor stopped (call attempted while actor state is not running)".to_string(),
));
}
let (rtx, rrx) = flume::unbounded();
let (cancel_tx, cancel_rx) = flume::bounded::<()>(1);
let loc = std::panic::Location::caller();
let call_id = self.next_call_id.fetch_add(1, Ordering::Relaxed);
self.pending
.lock()
.expect("poisoned lock")
.insert(call_id, cancel_tx);
let action: Action<A> = Box::new(move |actor: &mut A| {
Box::pin(async move {
let panic_result = std::panic::AssertUnwindSafe(async move { f(actor).await })
.catch_unwind()
.await;
let res = match panic_result {
Ok(action_result) => action_result,
Err(panic_payload) => {
let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic".to_string()
};
Err(E::from_actor_message(format!(
"panic in actor call at {}:{}: {}",
loc.file(),
loc.line(),
msg
)))
}
};
let _ = rtx.send(res);
})
});
let sent = {
let tx_guard = self.tx.lock().expect("poisoned lock");
tx_guard
.as_ref()
.map_or(false, |tx| tx.send(action).is_ok())
};
if !sent {
if let Ok(mut pending) = self.pending.lock() {
pending.remove(&call_id);
}
return Err(E::from_actor_message(format!(
"actor stopped (call send at {}:{})",
loc.file(),
loc.line()
)));
}
Ok((rrx, cancel_rx, call_id, loc))
}
pub fn call_blocking<R, F>(&self, f: F) -> Result<R, E>
where
F: for<'a> FnOnce(&'a mut A) -> ActorFut<'a, Result<R, E>> + Send + 'static,
R: Send + 'static,
{
enum BlockingWaitResult<T, E> {
Result(Result<Result<T, E>, flume::RecvError>),
Canceled(Result<(), flume::RecvError>),
}
let (rrx, cancel_rx, call_id, loc) = self.base_call(f)?;
let out = match flume::Selector::new()
.recv(&rrx, BlockingWaitResult::Result)
.recv(&cancel_rx, BlockingWaitResult::Canceled)
.wait()
{
BlockingWaitResult::Result(msg) => msg.map_err(|_| {
E::from_actor_message(format!(
"actor stopped (call recv at {}:{})",
loc.file(),
loc.line()
))
})?,
BlockingWaitResult::Canceled(Ok(())) => Err(E::from_actor_message(format!(
"actor stopped (call canceled at {}:{})",
loc.file(),
loc.line()
))),
BlockingWaitResult::Canceled(Err(_)) => Err(E::from_actor_message(format!(
"actor stopped (call recv at {}:{})",
loc.file(),
loc.line()
))),
};
if let Ok(mut pending) = self.pending.lock() {
pending.remove(&call_id);
}
out
}
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn call<R, F>(&self, f: F) -> Result<R, E>
where
F: for<'a> FnOnce(&'a mut A) -> ActorFut<'a, Result<R, E>> + Send + 'static,
R: Send + 'static,
{
let (rrx, cancel_rx, call_id, loc) = self.base_call(f)?;
let recv_fut = rrx.recv_async();
let cancel_fut = cancel_rx.recv_async();
pin_mut!(recv_fut, cancel_fut);
let out = match future::select(recv_fut, cancel_fut).await {
Either::Left((msg, _)) => msg.map_err(|_| {
E::from_actor_message(format!(
"actor stopped (call recv at {}:{})",
loc.file(),
loc.line()
))
})?,
Either::Right((Ok(_), _)) => Err(E::from_actor_message(format!(
"actor stopped (call canceled at {}:{})",
loc.file(),
loc.line()
))),
Either::Right((Err(_), _)) => Err(E::from_actor_message(format!(
"actor stopped (call recv at {}:{})",
loc.file(),
loc.line()
))),
};
if let Ok(mut pending) = self.pending.lock() {
pending.remove(&call_id);
}
out
}
pub fn shutdown(&self) {
if let Ok(mut tx) = self.tx.lock() {
tx.take();
}
}
pub fn wait_stopped_blocking(&self) {
if self.state() == ActorState::Stopped {
return;
}
let _ = self.stopped_rx.recv();
}
#[cfg(any(feature = "tokio", feature = "async-std"))]
pub async fn wait_stopped(&self) {
if self.state() == ActorState::Stopped {
return;
}
let _ = self.stopped_rx.recv_async().await;
}
}