use crate::{
ActorError, Addr, BareContext, Capacity, Control, Priority, RegistryEntry, System, SystemHandle,
};
use futures_util::{StreamExt, select_biased};
use log::{debug, trace};
use std::{any::type_name, fmt, future, thread};
use tokio::runtime::LocalRuntime;
/// An actor whose lifecycle hooks and message handler are `async`.
///
/// Each spawned async actor gets a dedicated OS thread that drives it on a single-threaded
/// tokio [`LocalRuntime`]. Messages are processed one at a time: the next message is only taken
/// once the previous `handle()` future has completed.
// `async_fn_in_trait` warns that callers cannot require `Send` futures from these methods.
// That is acceptable here: in this module the futures are only polled by
// `LocalRuntime::block_on` on the actor's own thread, so they never cross threads.
#[allow(async_fn_in_trait)]
pub trait AsyncActor {
    /// Type of the messages this actor handles. Messages are sent from other threads.
    type Message: Send + 'static;
    /// Error type returned by the hooks; it must implement `Display` so it can be reported.
    type Error: fmt::Display;

    /// Normal-priority queue capacity used by [`AsyncActor::addr()`].
    const DEFAULT_CAPACITY_NORMAL: usize = 5;
    /// High-priority queue capacity used by [`AsyncActor::addr()`].
    const DEFAULT_CAPACITY_HIGH: usize = 5;

    /// Human-readable actor name, used for the actor thread's name and in log lines.
    ///
    /// Defaults to the full type name of `Self`.
    fn name() -> &'static str {
        type_name::<Self>()
    }

    /// Classifies `message` into a delivery priority. [`Priority::High`] messages are handled
    /// before [`Priority::Normal`] ones. Defaults to [`Priority::Normal`] for everything.
    fn priority(_message: &Self::Message) -> Priority {
        Priority::Normal
    }

    /// Runs once, after the actor is constructed and before any message is handled.
    ///
    /// Returning an error reports it, and the actor never starts processing messages.
    async fn started(&mut self, _context: &BareContext<Self::Message>) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Handles a single message.
    ///
    /// Returning an error reports it and ends the actor's message loop; `stopped()` is then not
    /// called.
    async fn handle(
        &mut self,
        context: &BareContext<Self::Message>,
        message: Self::Message,
    ) -> Result<(), Self::Error>;

    /// Runs when the actor receives a stop request; an error here is reported.
    async fn stopped(&mut self, _context: &BareContext<Self::Message>) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Creates an address whose queues use the `DEFAULT_CAPACITY_*` constants.
    fn addr() -> Addr<Self::Message> {
        Self::addr_with_capacity(Capacity {
            high: Self::DEFAULT_CAPACITY_HIGH,
            normal: Self::DEFAULT_CAPACITY_NORMAL,
        })
    }

    /// Creates an address with explicit queue `capacity`, routed by [`AsyncActor::priority()`].
    fn addr_with_capacity(capacity: impl Into<Capacity>) -> Addr<Self::Message> {
        let name = Self::name();
        Addr::new(capacity, name, Self::priority)
    }
}
/// First stage of spawning an [`AsyncActor`].
///
/// Holds the target system and the actor factory until an address is chosen.
#[must_use = "You must call .with_addr(), .with_capacity(), or .with_default_capacity() to \
              configure this builder"]
pub struct AsyncSpawnBuilderWithoutAddress<'a, A: AsyncActor, F: IntoFuture<Output = A>> {
    system: &'a mut System,
    factory: F,
}

impl<'a, A: AsyncActor, F: IntoFuture<Output = A>> AsyncSpawnBuilderWithoutAddress<'a, A, F> {
    /// Uses an already-created `addr` for the actor.
    pub fn with_addr(self, addr: Addr<A::Message>) -> AsyncSpawnBuilderWithAddress<'a, A, F> {
        AsyncSpawnBuilderWithAddress { addr, spawn_builder: self }
    }

    /// Creates a fresh address with the given queue `capacity`.
    pub fn with_capacity(
        self,
        capacity: impl Into<Capacity>,
    ) -> AsyncSpawnBuilderWithAddress<'a, A, F> {
        self.with_addr(A::addr_with_capacity(capacity))
    }

    /// Creates a fresh address with the actor's default capacities ([`AsyncActor::addr()`]).
    pub fn with_default_capacity(self) -> AsyncSpawnBuilderWithAddress<'a, A, F> {
        self.with_addr(A::addr())
    }
}
#[must_use = "You must call .spawn() to run the actor"]
pub struct AsyncSpawnBuilderWithAddress<'a, A: AsyncActor, F: IntoFuture<Output = A>> {
spawn_builder: AsyncSpawnBuilderWithoutAddress<'a, A, F>,
addr: Addr<A::Message>,
}
impl<A: AsyncActor, F: IntoFuture<Output = A> + Send + 'static>
AsyncSpawnBuilderWithAddress<'_, A, F>
{
pub fn spawn(self) -> Result<Addr<A::Message>, ActorError> {
let builder = self.spawn_builder;
builder.system.spawn_async_fn_with_addr(builder.factory, self.addr.clone())?;
Ok(self.addr)
}
}
impl System {
    /// Begins spawning an already-constructed async `actor`.
    ///
    /// The returned builder still needs an address (`.with_addr()`, `.with_capacity()` or
    /// `.with_default_capacity()`) before `.spawn()` can be called.
    pub fn prepare_async<A>(
        &mut self,
        actor: A,
    ) -> AsyncSpawnBuilderWithoutAddress<'_, A, future::Ready<A>>
    where
        A: AsyncActor,
    {
        // A ready value is simply a factory that resolves immediately.
        let factory = future::ready(actor);
        self.prepare_async_factory(factory)
    }

    /// Begins spawning an async actor that is produced by awaiting `factory`.
    ///
    /// The factory is awaited on the actor's own thread, before `started()` is called.
    pub fn prepare_async_factory<A, F>(
        &mut self,
        factory: F,
    ) -> AsyncSpawnBuilderWithoutAddress<'_, A, F>
    where
        A: AsyncActor,
        F: IntoFuture<Output = A>,
    {
        AsyncSpawnBuilderWithoutAddress { factory, system: self }
    }

    /// Spawns `actor` with its default queue capacities and returns its address.
    ///
    /// # Errors
    ///
    /// Fails when the system is no longer running or the actor thread cannot be created.
    pub fn spawn_async<A>(&mut self, actor: A) -> Result<Addr<A::Message>, ActorError>
    where
        A: AsyncActor + Send + 'static,
    {
        let builder = self.prepare_async(actor).with_default_capacity();
        builder.spawn()
    }

    /// Spawns the OS thread that builds the actor from `factory` and serves `addr`.
    ///
    /// The thread is named after [`AsyncActor::name()`] and runs a single-threaded
    /// [`LocalRuntime`]. Its join handle is pushed to the system registry together with a clone
    /// of the actor's control sender.
    fn spawn_async_fn_with_addr<F, A>(
        &mut self,
        factory: F,
        addr: Addr<A::Message>,
    ) -> Result<(), ActorError>
    where
        F: IntoFuture<Output = A> + Send + 'static,
        A: AsyncActor,
    {
        // The read guard is deliberately kept alive until the registry push at the very end.
        // NOTE(review): presumably this keeps a concurrent shutdown (write lock) from missing
        // the new thread — confirm against the shutdown implementation.
        let state = self.handle.system_state.read();
        if !state.is_running() {
            return Err(ActorError::SystemStopped { actor_name: A::name() });
        }

        let system_handle = self.handle.clone();
        let context =
            BareContext { system_handle: system_handle.clone(), myself: addr.recipient.clone() };
        let control_tx = addr.control_tx.clone();

        let actor_thread_body = move || {
            let runtime = match LocalRuntime::new() {
                Ok(runtime) => runtime,
                Err(e) => {
                    Self::report_error_shutdown(
                        &system_handle,
                        A::name(),
                        "creating async runtime",
                        e,
                    );
                    return;
                },
            };
            runtime.block_on(async {
                // Construct the actor on this thread, then run `started()` before any message.
                let mut actor = factory.await;
                let started = actor.started(&context).await;
                if let Err(error) = started {
                    Self::report_error_shutdown(&system_handle, A::name(), "started()", error);
                    return;
                }
                debug!("[{}] started async actor: {}", system_handle.name, A::name());
                Self::run_async_actor_select_loop(actor, addr, &context, &system_handle).await
            })
        };

        let thread_handle = thread::Builder::new()
            .name(A::name().into())
            .spawn(actor_thread_body)
            .map_err(|_| ActorError::SpawnFailed { actor_name: A::name() })?;

        self.handle
            .registry
            .lock()
            .push(RegistryEntry::BackgroundThread(control_tx, thread_handle));
        Ok(())
    }

    /// Drives `actor` until it receives a stop request or `handle()` fails.
    ///
    /// Queues are polled in strict priority order: control, then high, then normal priority.
    /// `stopped()` runs only after a `Stop` request. A `handle()` error is reported and ends the
    /// loop without calling `stopped()`.
    async fn run_async_actor_select_loop<A>(
        mut actor: A,
        addr: Addr<A::Message>,
        context: &BareContext<A::Message>,
        system_handle: &SystemHandle,
    ) where
        A: AsyncActor,
    {
        // What one poll of the three queues produced.
        enum Incoming<M> {
            Control(Control),
            Message(M),
        }

        let mut control_rx = addr.control_rx.into_stream();
        let mut priority_rx = addr.priority_rx.into_stream();
        let mut message_rx = addr.message_rx.into_stream();

        loop {
            // `select_biased!` prefers earlier branches when several are ready, which gives
            // control messages precedence over high-priority ones, and those over normal ones.
            let incoming = select_biased!(
                control = control_rx.next() => {
                    Incoming::Control(control.expect("We keep control_tx alive through addr."))
                },
                high = priority_rx.next() => {
                    Incoming::Message(high.expect("We keep priority_tx alive through addr."))
                },
                normal = message_rx.next() => {
                    Incoming::Message(normal.expect("We keep message_tx alive through addr."))
                },
            );

            let message = match incoming {
                Incoming::Control(Control::Stop) => break,
                Incoming::Message(message) => message,
            };

            trace!("[{}] message received by {}", system_handle.name, A::name());
            if let Err(error) = actor.handle(context, message).await {
                Self::report_error_shutdown(system_handle, A::name(), "handle()", error);
                return;
            }
        }

        if let Err(error) = actor.stopped(context).await {
            Self::report_error_shutdown(system_handle, A::name(), "stopped()", error);
        }
        debug!("[{}] stopped actor: {}", system_handle.name, A::name());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Actor, Context, Recipient};
    use anyhow::Error;
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    /// Async actor under test. It forwards a "started" event and every message it handles to
    /// `recorder`, and spawns delayed tokio tasks for the `Delayed*` messages.
    struct AsyncTestActor {
        recorder: Recipient<TestMessage>,
    }

    impl AsyncActor for AsyncTestActor {
        type Error = Error;
        type Message = TestMessage;

        fn priority(message: &TestMessage) -> Priority {
            match message {
                TestMessage::HighPrio(_) => Priority::High,
                _ => Priority::Normal,
            }
        }

        async fn started(&mut self, _: &BareContext<TestMessage>) -> Result<(), Error> {
            debug!("AsyncActor started hook");
            self.recorder.send(TestMessage::Event("started"))?;
            Ok(())
        }

        async fn handle(
            &mut self,
            context: &BareContext<TestMessage>,
            message: TestMessage,
        ) -> Result<(), Error> {
            self.recorder.send(message.clone())?;
            if message == TestMessage::DelayedTask {
                // Checks that tasks spawned on the actor's runtime keep running after handle().
                let recorder = self.recorder.clone();
                tokio::spawn(async move {
                    debug!("delayed task started");
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    recorder.send(TestMessage::Event("delayed task finished"))?;
                    debug!("delayed task finished");
                    Ok::<(), Error>(())
                });
            }
            if message == TestMessage::DelayedShutdown {
                // Shuts the system down late enough for the delayed task above to report first.
                let system_handle = context.system_handle.clone();
                tokio::spawn(async move {
                    debug!("delayed shutdown started");
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    debug!("delayed shutdown shutting down now");
                    system_handle.shutdown()
                });
            }
            Ok(())
        }

        async fn stopped(&mut self, _: &BareContext<TestMessage>) -> Result<(), Error> {
            trace!("AsyncActor stopped hook");
            Ok(())
        }
    }

    #[derive(Debug, Clone, PartialEq, Eq)]
    enum TestMessage {
        Event(&'static str),
        HighPrio(usize),
        NormalPrio(usize),
        DelayedTask,
        DelayedShutdown,
    }

    /// Synchronous actor that appends every received message to a shared vector.
    struct SyncRecorder {
        received: Arc<Mutex<Vec<TestMessage>>>,
    }

    impl Actor for SyncRecorder {
        type Context = Context<Self::Message>;
        type Error = Error;
        type Message = TestMessage;

        fn handle(
            &mut self,
            _context: &mut Self::Context,
            message: Self::Message,
        ) -> Result<(), Self::Error> {
            self.received.lock().expect("lock should not be poisoned").push(message);
            Ok(())
        }
    }

    #[test]
    fn async_priorities() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace"))
            .try_init()
            .ok();
        let mut system = System::new("async priorities");
        let received = Arc::new(Mutex::new(Vec::new()));
        let recorder_actor = SyncRecorder { received: Arc::clone(&received) };
        let recorder_addr = system.prepare(recorder_actor).with_capacity(10).spawn().unwrap();

        // Queue every message *before* the async actor is spawned. If the actor thread were
        // already running, it could take the normal-priority messages before the high-priority
        // ones were even sent, which would make the order asserted below racy.
        let async_addr = AsyncTestActor::addr_with_capacity(10);
        async_addr.send(TestMessage::DelayedTask).unwrap();
        async_addr.send(TestMessage::DelayedShutdown).unwrap();
        async_addr.send(TestMessage::NormalPrio(1)).unwrap();
        async_addr.send(TestMessage::NormalPrio(2)).unwrap();
        async_addr.send(TestMessage::HighPrio(3)).unwrap();
        async_addr.send(TestMessage::HighPrio(4)).unwrap();

        let async_actor = AsyncTestActor { recorder: recorder_addr.recipient() };
        let _async_addr =
            system.prepare_async(async_actor).with_addr(async_addr).spawn().unwrap();
        system.run().unwrap();

        let received = Arc::into_inner(received)
            .expect("arc has a single reference at this point")
            .into_inner()
            .expect("Mutex should not be poisoned");
        assert_eq!(
            received,
            [
                TestMessage::Event("started"),
                TestMessage::HighPrio(3),
                TestMessage::HighPrio(4),
                TestMessage::DelayedTask,
                TestMessage::DelayedShutdown,
                TestMessage::NormalPrio(1),
                TestMessage::NormalPrio(2),
                TestMessage::Event("delayed task finished")
            ]
        );
    }

    #[test]
    fn async_error() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace"))
            .try_init()
            .ok();

        /// Actor whose `handle()` always fails, used to exercise the error shutdown path.
        struct ErroringActor;

        impl AsyncActor for ErroringActor {
            type Error = String;
            type Message = ();

            async fn handle(&mut self, _c: &BareContext<()>, _m: ()) -> Result<(), String> {
                Err(String::from("Raising an error"))
            }
        }

        let mut system = System::new("async error");
        let addr = system.spawn_async(ErroringActor).unwrap();
        addr.send(()).unwrap();
        system.run().unwrap();
    }
}