mod pool;
use std::{
any::Any,
error, fmt,
fmt::{Debug, Formatter},
mem,
sync::{Arc, Mutex},
time,
time::Duration,
};
use crossbeam_channel::{Receiver, RecvTimeoutError as CcRecvTimeoutError};
use rayon::{ThreadPool, ThreadPoolBuilder};
use tokio::{sync::Semaphore, task};
use crate::actor::{
context::CancellationToken, system::native::pool::PoolActorHandle, timers::scheduler::SchedulerHandle,
traits::Actor,
};
/// Sizing parameters used by `ActorSystem::new`.
#[derive(Clone, Debug)]
pub struct ActorSystemConfig {
    /// Worker-thread count handed to the rayon `ThreadPoolBuilder::num_threads`.
    pub pool_threads: usize,
    /// Number of semaphore permits that bound concurrent `compute` / `execute` calls.
    pub max_in_flight: usize,
}
/// State shared by every clone of an `ActorSystem` handle.
///
/// NOTE(review): struct fields are dropped in declaration order; keep that in mind before
/// reordering them (e.g. `scheduler` was created from a clone of `pool` in `new`).
struct ActorSystemInner {
/// Rayon pool used by `install` / `compute`; `scope()` children share the same `Arc`.
pool: Arc<ThreadPool>,
/// Bounds concurrent `compute` / `execute` calls; `scope()` children share the same `Arc`.
permits: Arc<Semaphore>,
/// Cancelled by `shutdown`; every system and every `scope()` child gets its own token.
cancel: CancellationToken,
/// Built from `pool` in `new`; `scope()` children receive `scheduler.shared()`.
scheduler: SchedulerHandle,
/// Callbacks registered via `register_waker`; `shutdown` takes and invokes them once.
wakers: Mutex<Vec<Arc<dyn Fn() + Send + Sync>>>,
/// Type-erased values registered via `register_keepalive`; `shutdown` drops them.
keepalive: Mutex<Vec<Box<dyn Any + Send + Sync>>>,
/// Completion receivers registered via `register_done_rx`; consumed by `join_timeout`.
done_rxs: Mutex<Vec<Receiver<()>>>,
}
/// Cheaply cloneable handle to an actor runtime.
///
/// All clones point at the same shared state, so `shutdown` / `join` through any clone
/// affects every clone. `scope()` creates a child with its own cancellation token and
/// registries while sharing the thread pool and in-flight permits.
#[derive(Clone)]
pub struct ActorSystem {
    /// Shared state; cloning the handle only bumps this reference count.
    inner: Arc<ActorSystemInner>,
}
impl ActorSystem {
    /// Builds a new actor system from `config`.
    ///
    /// Creates a dedicated rayon pool with `config.pool_threads` threads (named
    /// `actor-pool-{i}`), a semaphore with `config.max_in_flight` permits that bounds
    /// concurrent [`compute`](Self::compute) / [`execute`](Self::execute) calls, a fresh
    /// cancellation token, and a scheduler bound to the pool.
    ///
    /// NOTE(review): with `max_in_flight == 0`, `compute`/`execute` can never acquire a
    /// permit (nothing in this file adds permits later).
    ///
    /// # Panics
    ///
    /// Panics if the rayon thread pool cannot be built.
    pub fn new(config: ActorSystemConfig) -> Self {
        let pool = Arc::new(
            ThreadPoolBuilder::new()
                .num_threads(config.pool_threads)
                .thread_name(|i| format!("actor-pool-{i}"))
                .build()
                .expect("failed to build rayon pool"),
        );
        let scheduler = SchedulerHandle::new(pool.clone());
        Self {
            inner: Arc::new(ActorSystemInner {
                pool,
                permits: Arc::new(Semaphore::new(config.max_in_flight)),
                cancel: CancellationToken::new(),
                scheduler,
                wakers: Mutex::new(Vec::new()),
                keepalive: Mutex::new(Vec::new()),
                done_rxs: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Creates a child system that shares this system's rayon pool and in-flight permits.
    ///
    /// The child has its own cancellation token and its own waker / keepalive /
    /// done-receiver registries, so `shutdown` and `join` on the child only touch what was
    /// registered with the child. Its token is a fresh `CancellationToken::new()`, not one
    /// derived from the parent, so shutting down the parent does not by itself cancel the
    /// child. The scheduler comes from `SchedulerHandle::shared()` (presumably a handle onto
    /// the parent's scheduler — verify).
    pub fn scope(&self) -> Self {
        Self {
            inner: Arc::new(ActorSystemInner {
                pool: self.inner.pool.clone(),
                permits: self.inner.permits.clone(),
                cancel: CancellationToken::new(),
                scheduler: self.inner.scheduler.shared(),
                wakers: Mutex::new(Vec::new()),
                keepalive: Mutex::new(Vec::new()),
                done_rxs: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Returns a clone of this system's cancellation token.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.inner.cancel.clone()
    }

    /// Returns `true` once [`shutdown`](Self::shutdown) has cancelled this system's token.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancel.is_cancelled()
    }

    /// Cancels this system and releases everything registered for shutdown.
    ///
    /// Steps:
    /// 1. Cancels the token.
    /// 2. Invokes every registered waker once.
    /// 3. Drops all keepalive values.
    ///
    /// Both the wakers and the keepalive values are moved out of their mutexes before being
    /// called or dropped. As a result, a waker or a keepalive `Drop` impl that calls back
    /// into `register_*` cannot deadlock on the non-reentrant `Mutex`.
    ///
    /// This does not wait for actors to stop; use [`join`](Self::join) for that.
    ///
    /// NOTE: a waker registered after this method has run is stored but is not invoked
    /// unless `shutdown` is called again.
    pub fn shutdown(&self) {
        self.inner.cancel.cancel();
        let wakers = mem::take(&mut *self.inner.wakers.lock().unwrap());
        for waker in &wakers {
            waker();
        }
        drop(wakers);
        // The lock guard is a temporary of this statement, so the cells are dropped below
        // with the mutex already released.
        let keepalive = mem::take(&mut *self.inner.keepalive.lock().unwrap());
        drop(keepalive);
    }

    /// Registers a callback that [`shutdown`](Self::shutdown) will invoke once.
    pub(crate) fn register_waker(&self, f: Arc<dyn Fn() + Send + Sync>) {
        self.inner.wakers.lock().unwrap().push(f);
    }

    /// Keeps `cell` alive until [`shutdown`](Self::shutdown) (or until the system drops).
    pub(crate) fn register_keepalive(&self, cell: Box<dyn Any + Send + Sync>) {
        self.inner.keepalive.lock().unwrap().push(cell);
    }

    /// Registers a completion receiver that [`join_timeout`](Self::join_timeout) waits on.
    pub(crate) fn register_done_rx(&self, rx: Receiver<()>) {
        self.inner.done_rxs.lock().unwrap().push(rx);
    }

    /// Waits up to five seconds for all registered actors to finish.
    ///
    /// Equivalent to `join_timeout(Duration::from_secs(5))`.
    ///
    /// # Errors
    ///
    /// See [`join_timeout`](Self::join_timeout).
    pub fn join(&self) -> Result<(), JoinError> {
        self.join_timeout(Duration::from_secs(5))
    }

    /// Waits up to `timeout` in total for every registered done-receiver to finish.
    ///
    /// A receiver counts as finished when it yields `()` or when its sender has been
    /// dropped (disconnected). Finished receivers are consumed. A `timeout` too large to
    /// add to the current instant is treated as "no deadline".
    ///
    /// # Errors
    ///
    /// Returns [`JoinError`] if the deadline passes before all receivers finish. The
    /// receiver that timed out and every receiver not yet awaited are put back into the
    /// registry. A later `join`/`join_timeout` therefore keeps waiting on them instead of
    /// reporting success.
    pub fn join_timeout(&self, timeout: Duration) -> Result<(), JoinError> {
        // `Instant + Duration` panics on overflow (e.g. `Duration::MAX`); `None` here means
        // "wait without a deadline".
        let deadline = time::Instant::now().checked_add(timeout);
        let mut pending = mem::take(&mut *self.inner.done_rxs.lock().unwrap()).into_iter();
        while let Some(rx) = pending.next() {
            let outcome = match deadline {
                Some(at) => rx.recv_timeout(at.saturating_duration_since(time::Instant::now())),
                None => rx.recv().map_err(|_| CcRecvTimeoutError::Disconnected),
            };
            match outcome {
                // A dropped sender means the actor side is gone, which also counts as done.
                Ok(()) | Err(CcRecvTimeoutError::Disconnected) => {}
                Err(CcRecvTimeoutError::Timeout) => {
                    // Put unfinished receivers back so they are not silently forgotten.
                    let mut registry = self.inner.done_rxs.lock().unwrap();
                    registry.push(rx);
                    registry.extend(pending);
                    return Err(JoinError::new("timed out waiting for actors to stop"));
                }
            }
        }
        Ok(())
    }

    /// Returns this system's scheduler handle.
    pub fn scheduler(&self) -> &SchedulerHandle {
        &self.inner.scheduler
    }

    /// Spawns `actor` under `name` via `pool::spawn_on_pool` and returns its handle.
    pub fn spawn<A: Actor>(&self, name: &str, actor: A) -> ActorHandle<A::Message>
    where
        A::State: Send,
    {
        pool::spawn_on_pool(self, name, actor)
    }

    /// Runs `f` inside this system's rayon pool (`ThreadPool::install`) and returns its
    /// result. The calling thread blocks until `f` completes.
    pub fn install<R, F>(&self, f: F) -> R
    where
        R: Send,
        F: FnOnce() -> R + Send,
    {
        self.inner.pool.install(f)
    }

    /// Runs CPU-bound `f` on the rayon pool without blocking the async executor.
    ///
    /// First waits for an in-flight permit. It then moves the work onto Tokio's blocking
    /// thread pool, which enters the rayon pool via `install`. The permit is held until `f`
    /// returns.
    ///
    /// # Errors
    ///
    /// Returns the Tokio [`task::JoinError`] if the blocking task panics or is cancelled.
    ///
    /// # Panics
    ///
    /// Panics if the permit semaphore has been closed (nothing in this file closes it).
    /// It also panics if `task::spawn_blocking` is called outside a Tokio runtime.
    pub async fn compute<R, F>(&self, f: F) -> Result<R, task::JoinError>
    where
        R: Send + 'static,
        F: FnOnce() -> R + Send + 'static,
    {
        let permit = Arc::clone(&self.inner.permits)
            .acquire_owned()
            .await
            .expect("semaphore closed");
        // Only the pool is needed on the blocking thread. Cloning just it (rather than all
        // of `inner`) avoids pinning the system's registries for the duration of `f`.
        let pool = Arc::clone(&self.inner.pool);
        task::spawn_blocking(move || {
            let _permit = permit;
            pool.install(f)
        })
        .await
    }

    /// Runs blocking `f` on Tokio's blocking thread pool (not the rayon pool).
    ///
    /// The call holds one in-flight permit while `f` runs.
    ///
    /// # Errors
    ///
    /// Returns the Tokio [`task::JoinError`] if the blocking task panics or is cancelled.
    ///
    /// # Panics
    ///
    /// Panics if the permit semaphore has been closed (nothing in this file closes it).
    /// It also panics if `task::spawn_blocking` is called outside a Tokio runtime.
    pub async fn execute<R, F>(&self, f: F) -> Result<R, task::JoinError>
    where
        R: Send + 'static,
        F: FnOnce() -> R + Send + 'static,
    {
        let permit = Arc::clone(&self.inner.permits)
            .acquire_owned()
            .await
            .expect("semaphore closed");
        task::spawn_blocking(move || {
            let _permit = permit;
            f()
        })
        .await
    }

    /// Returns the shared rayon pool backing this system.
    pub(crate) fn pool(&self) -> &Arc<ThreadPool> {
        &self.inner.pool
    }
}
/// Debug output shows only whether the system has been cancelled; other internals are
/// elided (`..`).
impl Debug for ActorSystem {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let cancelled = self.is_cancelled();
        let mut builder = f.debug_struct("ActorSystem");
        builder.field("cancelled", &cancelled);
        builder.finish_non_exhaustive()
    }
}
/// Handle returned by `ActorSystem::spawn`; an alias for the pool-backed handle type.
pub type ActorHandle<Msg> = PoolActorHandle<Msg>;
/// Error reported when waiting for actors to finish does not succeed.
#[derive(Debug)]
pub struct JoinError {
    message: String,
}

impl JoinError {
    /// Builds an error carrying `message` as its human-readable detail.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        JoinError { message }
    }
}

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("actor join failed: ")?;
        f.write_str(&self.message)
    }
}

impl error::Error for JoinError {}
#[cfg(test)]
mod tests {
    use std::sync;

    use super::*;
    use crate::{
        SharedRuntimeConfig,
        actor::{context::Context, traits::Directive},
    };

    /// Builds a system from the default shared runtime configuration.
    fn default_system() -> ActorSystem {
        ActorSystem::new(SharedRuntimeConfig::default().actor_system_config())
    }

    /// Test actor whose state is a running count.
    struct CounterActor;

    #[derive(Debug)]
    enum CounterMsg {
        /// Add one to the count.
        Inc,
        /// Reply with the current count on the given channel.
        Get(sync::mpsc::Sender<i64>),
        /// Ask the actor to stop.
        Stop,
    }

    impl Actor for CounterActor {
        type State = i64;
        type Message = CounterMsg;

        fn init(&self, _ctx: &Context<Self::Message>) -> Self::State {
            0
        }

        fn handle(
            &self,
            state: &mut Self::State,
            msg: Self::Message,
            _ctx: &Context<Self::Message>,
        ) -> Directive {
            match msg {
                CounterMsg::Stop => Directive::Stop,
                CounterMsg::Inc => {
                    *state += 1;
                    Directive::Continue
                }
                CounterMsg::Get(reply) => {
                    // A send error only means the receiver is gone; ignore it.
                    let _ = reply.send(*state);
                    Directive::Continue
                }
            }
        }
    }

    #[test]
    fn test_spawn_and_send() {
        let system = default_system();
        let handle = system.spawn("counter", CounterActor);
        let counter = handle.actor_ref().clone();
        for _ in 0..3 {
            counter.send(CounterMsg::Inc).unwrap();
        }
        let (reply_tx, reply_rx) = sync::mpsc::channel();
        counter.send(CounterMsg::Get(reply_tx)).unwrap();
        assert_eq!(reply_rx.recv().unwrap(), 3);
        counter.send(CounterMsg::Stop).unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_install() {
        let system = default_system();
        assert_eq!(system.install(|| 42), 42);
    }

    #[tokio::test]
    async fn test_compute() {
        let system = default_system();
        assert_eq!(system.compute(|| 42).await.unwrap(), 42);
    }

    #[test]
    fn test_shutdown_join() {
        let system = default_system();
        (0..5).for_each(|i| {
            system.spawn(&format!("counter-{i}"), CounterActor);
        });
        system.shutdown();
        system.join().unwrap();
    }
}