use async_trait::async_trait;
use futures::FutureExt;
use std::future::Future;
use std::sync::OnceLock;
use std::{
panic::AssertUnwindSafe,
sync::{Arc, Mutex},
time::Instant,
};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
pub use tactix_macros::Message;
/// Root of the actor hierarchy; every actor launched with [`Actor::start`] is its child.
pub struct ActorSystem;

/// Message asking the [`ActorSystem`] to stop.
#[derive(Message)]
pub struct Shutdown;

impl Actor for ActorSystem {}

impl Handler<Shutdown> for ActorSystem {
    /// Cancels the system's token. Top-level actors were given child tokens of it, so they
    /// are asked to stop as well.
    async fn handle(&mut self, _request: Shutdown, ctx: &Ctx<Self>) {
        ctx.stop()
    }
}
/// Lazily-initialised context of the process-wide [`ActorSystem`].
static ACTOR_SYSTEM: OnceLock<Ctx<ActorSystem>> = OnceLock::new();

impl ActorSystem {
    /// Returns the context of the global actor system, starting it on first use.
    ///
    /// # Panics
    ///
    /// The first call spawns the system's task with `tokio::spawn`, so it panics if it is
    /// made outside a Tokio runtime.
    pub fn global() -> &'static Ctx<ActorSystem> {
        ACTOR_SYSTEM.get_or_init(|| {
            start_actor(
                // `ActorSystem` is a stateless unit struct, so the factory can hand out a
                // fresh instance every time. The system is supervised with the default
                // `Restart` strategy, and a restart invokes the factory again; a one-shot
                // factory would panic inside the system task at that point.
                || ActorSystem,
                CancellationToken::new(),
                // The system has no parent: its receiver is dropped immediately, so any
                // escalation sent here is simply discarded.
                mpsc::unbounded_channel().0,
                SupervisionStrategy::default(),
            )
        })
    }

    /// Returns an address of the global actor system, starting it if necessary.
    pub fn addr() -> Addr<ActorSystem> {
        Self::global().address()
    }

    /// Asks the global actor system to stop and waits until its task has exited.
    ///
    /// The system stops and awaits its children before its task finishes, so when this
    /// returns every actor started with [`Actor::start`] has been asked to stop and awaited.
    pub async fn shutdown() {
        let addr = Self::addr();
        addr.tell(Shutdown);
        addr.wait_until_stopped().await;
    }
}
/// Boxed, type-erased message as stored in an actor's mailbox.
type PointerToActorMessage<A> = Box<dyn ActorMessage<A>>;

/// A unit of state that processes the messages in its mailbox one at a time.
///
/// Every lifecycle hook except [`Actor::child_escalated`] defaults to a no-op, so an
/// implementation can be as small as `impl Actor for MyActor {}`.
pub trait Actor: Send + Sync + Sized + 'static {
    /// Runs `self` as a top-level child of [`ActorSystem::global`] and returns its address.
    ///
    /// The actor is supervised with [`SupervisionStrategy::NoRestart`], so the one-shot
    /// factory built here is never asked for a second instance.
    fn start(self) -> Addr<Self> {
        let mut instance = Some(self);
        let factory = move || instance.take().expect("Factory can only be accessed once!");
        ActorSystem::global().spawn_with_config(factory, SupervisionStrategy::NoRestart)
    }

    /// Called every time an instance is built, including just before `restarted`.
    fn started(&self, _ctx: &Ctx<Self>) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called when the actor stops gracefully. It is not called when the actor is
    /// restarted or gives up after a failure.
    fn stopped(&self, _ctx: &Ctx<Self>) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called after `started` on a restarted instance. `_restarts` is the number of
    /// restarts counted in the current supervision window.
    fn restarted(&self, _restarts: u64, _ctx: &Ctx<Self>) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Decides how to react when a child escalates a failure.
    ///
    /// Returning `Some(interrupt)` ends the message loop with that interrupt, and `None`
    /// ignores the escalation. The default treats it as a failure of this actor.
    fn child_escalated(&self, _ctx: &Ctx<Self>) -> impl Future<Output = Option<Interrupt>> + Send {
        async { Some(Interrupt::RestartToEscalate) }
    }
}
/// Reason an actor's message loop ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interrupt {
    /// Stop gracefully: children are stopped and the `stopped` hook runs.
    Stop,
    /// Treat as a failure: children are stopped and the actor's supervision strategy
    /// decides whether it is restarted, stopped, or the failure escalated to its parent.
    RestartToEscalate,
}
/// How an actor reacts when its message loop ends with [`Interrupt::RestartToEscalate`]
/// (for example after a handler panic).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupervisionStrategy {
    /// End the actor's task. The `stopped` hook is not run and the parent is not notified.
    NoRestart,
    /// Build a fresh instance with the actor's factory, a bounded number of times.
    Restart {
        /// Length of the restart-counting window, in seconds.
        window: u64,
        /// Restart budget for one window. Once it runs out, the failure is escalated to
        /// the parent and the actor's task ends.
        max_restarts: u64,
    },
}

impl Default for SupervisionStrategy {
    /// `Restart { window: 5, max_restarts: 3 }`.
    fn default() -> Self {
        Self::Restart {
            window: 5,
            max_restarts: 3,
        }
    }
}
/// Spawns the task that drives one supervised actor and returns its context.
///
/// The task builds an instance with `factory` and runs its `started` hook (plus
/// `restarted` after a restart). It then processes mailbox messages one at a time until
/// one of these happens:
/// - `cancel` fires (graceful [`Interrupt::Stop`], the `stopped` hook runs);
/// - a handler panics (treated as [`Interrupt::RestartToEscalate`]);
/// - a child escalates and `child_escalated` returns an interrupt.
///
/// All children are stopped before the interrupt is acted on. Failures follow
/// `restart_config`. When the restart budget of the current window is exceeded, `()` is
/// sent on `escalate_to_parent` and the task ends. The returned context's `stopped` token
/// is cancelled whenever the task ends, including when the task is dropped.
fn start_actor<A, F>(
    mut factory: F,
    cancel: CancellationToken,
    escalate_to_parent: mpsc::UnboundedSender<()>,
    restart_config: SupervisionStrategy,
) -> Ctx<A>
where
    A: Actor + Sync,
    F: FnMut() -> A + Send + 'static,
{
    let (tx, mut rx) = mpsc::unbounded_channel::<PointerToActorMessage<A>>();
    let (child_escalations, mut child_escalations_rx) = mpsc::unbounded_channel();
    let stopped = CancellationToken::new();
    let ctx = Ctx::<A> {
        addr: Addr {
            tx,
            stopped: stopped.clone(),
        },
        children: Arc::new(Mutex::new(Vec::new())),
        cancel,
        stopped: stopped.clone(),
        child_escalations,
    };
    let ctx_loop = ctx.clone();
    tokio::spawn(async move {
        let mut is_restart = false;
        let ctx = ctx_loop;
        // Cancels `stopped` however this task ends: normal exit, panic, or being dropped.
        let _stopped_guard = stopped.drop_guard();
        // Restarts counted in the current window, and when that window began.
        let mut restarts = 0u64;
        let mut first_restart = Instant::now();
        loop {
            let mut actor = factory();
            actor.started(&ctx).await;
            if is_restart {
                actor.restarted(restarts, &ctx).await;
            }
            let code = loop {
                tokio::select! {
                    // Poll branches in order so a stop request wins over queued messages.
                    biased;
                    _ = ctx.cancel.cancelled() => {
                        break Interrupt::Stop;
                    }
                    Some(_) = child_escalations_rx.recv() => {
                        if let Some(interrupt) = actor.child_escalated(&ctx).await {
                            break interrupt;
                        }
                    }
                    msg = rx.recv() => {
                        // NOTE(review): `ctx`, held by this task, owns a mailbox sender, so
                        // `recv` cannot yield `None` while the loop runs; dropping every
                        // external `Addr` does not stop the actor.
                        let Some(mut msg) = msg else {
                            break Interrupt::Stop;
                        };
                        // A panicking handler must not take the task down. It is treated
                        // as a failure and handed to the supervision strategy.
                        if let Err(panic) = AssertUnwindSafe(msg.process(&mut actor, &ctx))
                            .catch_unwind()
                            .await
                        {
                            let msg = panic
                                .downcast_ref::<&str>()
                                .copied()
                                .or_else(|| panic.downcast_ref::<String>().map(|s| s.as_str()))
                                .unwrap_or("<non-string panic>");
                            eprintln!(
                                "ACTOR PANIC!\n actor:{}\n reason: {}\n restarting...",
                                std::any::type_name::<A>(),
                                msg
                            );
                            break Interrupt::RestartToEscalate;
                        }
                    }
                }
            };
            // Children never outlive the current instance, whatever the interrupt.
            ctx.stop_all_children().await;
            match code {
                Interrupt::Stop => {
                    actor.stopped(&ctx).await;
                    break;
                }
                Interrupt::RestartToEscalate => match &restart_config {
                    SupervisionStrategy::NoRestart => {
                        break;
                    }
                    SupervisionStrategy::Restart {
                        window,
                        max_restarts,
                    } => {
                        is_restart = true;
                        // Open a new window if the current one has expired.
                        if first_restart.elapsed().as_secs() > *window {
                            restarts = 0;
                            first_restart = Instant::now();
                        }
                        restarts += 1;
                        // `restarts` already counts the restart about to happen. Escalate
                        // only once it exceeds the budget, so that `max_restarts` restarts
                        // are actually performed per window, as the message below reports.
                        if restarts > *max_restarts {
                            eprintln!(
                                "Actor restarted {} times in {}s, escalating.",
                                max_restarts, window
                            );
                            let _ = escalate_to_parent.send(());
                            break;
                        }
                    }
                },
            }
        }
    });
    ctx
}
/// Handle an actor uses to reach its own address, spawn children, and stop itself.
///
/// Clones share the same children list and cancellation tokens.
pub struct Ctx<A: Actor> {
    /// Address of the actor this context belongs to.
    addr: Addr<A>,
    /// Children spawned through this context, in spawn order.
    children: Arc<Mutex<Vec<Arc<dyn Stoppable + Send + Sync>>>>,
    /// Cancelling this asks the actor to stop; children get child tokens of it.
    cancel: CancellationToken,
    /// Cancelled once the actor's task has exited.
    stopped: CancellationToken,
    /// Handed to children so they can escalate failures to this actor.
    child_escalations: mpsc::UnboundedSender<()>,
}

// Written by hand because `#[derive(Clone)]` would add an unwanted `A: Clone` bound.
impl<A: Actor> Clone for Ctx<A> {
    fn clone(&self) -> Self {
        let Self {
            addr,
            children,
            cancel,
            stopped,
            child_escalations,
        } = self;
        Self {
            addr: addr.clone(),
            children: Arc::clone(children),
            cancel: cancel.clone(),
            stopped: stopped.clone(),
            child_escalations: child_escalations.clone(),
        }
    }
}
impl<A: Actor> Ctx<A> {
#[must_use]
pub fn address(&self) -> Addr<A> {
self.addr.clone()
}
pub fn spawn<B, F>(&self, factory: F) -> Addr<B>
where
F: FnMut() -> B + Send + 'static,
B: Actor,
{
self.spawn_with_config(factory, SupervisionStrategy::default())
}
pub fn spawn_with_config<B, F>(&self, factory: F, config: SupervisionStrategy) -> Addr<B>
where
F: FnMut() -> B + Send + 'static,
B: Actor,
{
let child = start_actor(
factory,
self.cancel.child_token(),
self.child_escalations.clone(),
config,
);
self.children.lock().unwrap().push(Arc::new(child.clone()));
child.address()
}
}
/// Type-erased control surface for stopping an actor and waiting for it.
#[async_trait]
pub trait Stoppable {
    /// Requests that the actor stop; does not wait for it.
    fn stop(&self);

    /// Stops every registered child and waits for each to finish.
    async fn stop_all_children(&self);

    /// Waits until the actor's task has exited.
    async fn wait_until_stopped(&self);
}
/// A value that can be sent to an actor.
pub trait Message
where
    Self: Send + 'static,
{
    /// Reply produced by the handler and returned to `ask` callers.
    type Response: Send;
}
/// Implemented by an actor once for each message type `M` it accepts.
pub trait Handler<M: Message>: Actor {
    /// Processes `msg` with exclusive access to the actor's state and returns the reply.
    fn handle(&mut self, msg: M, ctx: &Ctx<Self>) -> impl Future<Output = M::Response> + Send;
}
/// Object-safe form of a queued message, so one mailbox can hold many message types.
#[async_trait]
pub trait ActorMessage<A>: Send
where
    A: Actor,
{
    /// Runs the message against `actor` and delivers any reply.
    async fn process(&mut self, actor: &mut A, ctx: &Ctx<A>);
}
/// A message paired with an optional reply channel, as queued in a mailbox.
pub struct Envelope<M: Message> {
    /// The payload; taken out when the envelope is processed.
    pub msg: Option<M>,
    /// Where the handler's reply is sent; `None` for fire-and-forget messages.
    pub tx: Option<oneshot::Sender<M::Response>>,
}

impl<M: Message> Envelope<M> {
    /// Builds a boxed envelope, ready to be coerced into a mailbox entry.
    pub fn new(msg: Option<M>, tx: Option<oneshot::Sender<M::Response>>) -> Box<Self> {
        let envelope = Envelope { msg, tx };
        Box::new(envelope)
    }
}
#[async_trait]
impl<A, M> ActorMessage<A> for Envelope<M>
where
    A: Actor + Handler<M>,
    M: Message,
{
    /// Hands the payload to `act`'s handler and forwards the reply to the asker, if any.
    ///
    /// The payload is taken out of the envelope, so processing it a second time does
    /// nothing.
    async fn process(&mut self, act: &mut A, ctx: &Ctx<A>) {
        let Some(payload) = self.msg.take() else {
            return;
        };
        let response = act.handle(payload, ctx).await;
        if let Some(reply_to) = self.tx.take() {
            // The asker may have stopped waiting; a closed reply channel is not an error.
            let _ = reply_to.send(response);
        }
    }
}
/// Cloneable handle used to send messages to an actor of type `A`.
pub struct Addr<A: Actor> {
    /// Sending half of the actor's mailbox.
    tx: mpsc::UnboundedSender<PointerToActorMessage<A>>,
    /// Cancelled once the actor's task has exited.
    stopped: CancellationToken,
}

impl<A: Actor> Addr<A> {
    /// Resolves once the actor's task has exited, immediately if it already has.
    pub async fn wait_until_stopped(&self) {
        let stopped = &self.stopped;
        stopped.cancelled().await
    }
}

// Written by hand because `#[derive(Clone)]` would add an unwanted `A: Clone` bound.
impl<A: Actor> Clone for Addr<A> {
    fn clone(&self) -> Self {
        let Self { tx, stopped } = self;
        Self {
            tx: tx.clone(),
            stopped: stopped.clone(),
        }
    }
}
#[async_trait]
pub trait Sender<M>
where
M: Message,
{
async fn ask(&self, msg: M) -> M::Response;
fn tell(&self, msg: M);
fn recipient(self) -> Recipient<M>
where
Self: Sized + Send + Sync + 'static,
{
Recipient { tx: Box::new(self) }
}
}
#[async_trait]
impl<M: Message, A: Actor + Handler<M>> Sender<M> for Addr<A> {
    /// Queues `msg` with a one-shot reply channel and awaits the reply.
    ///
    /// # Panics
    ///
    /// Panics with "actor dropped before responding" if the reply sender is dropped
    /// unused, e.g. because the mailbox is closed or the handler panicked.
    async fn ask(&self, msg: M) -> M::Response {
        let (reply_tx, reply_rx) = oneshot::channel();
        let envelope = Envelope::new(Some(msg), Some(reply_tx));
        // A closed mailbox drops the envelope and `reply_tx` with it; the `expect` below
        // reports that case.
        let _ = self.tx.send(envelope);
        reply_rx.await.expect("actor dropped before responding")
    }

    /// Queues `msg` with no reply channel. The message is discarded if the mailbox is
    /// closed.
    fn tell(&self, msg: M) {
        self.tx.send(Envelope::new(Some(msg), None)).ok();
    }
}
/// Type-erased sender for messages of type `M`, independent of the receiving actor's type.
pub struct Recipient<M: Message> {
    tx: Box<dyn Sender<M> + Send + Sync + 'static>,
}

impl<M: Message> Recipient<M> {
    /// Wraps an already-boxed sender.
    pub fn new(tx: Box<dyn Sender<M> + Send + Sync + 'static>) -> Self {
        Self { tx }
    }
}

#[async_trait]
impl<M: Message> Sender<M> for Recipient<M> {
    /// Forwards to the wrapped sender's `ask`.
    async fn ask(&self, msg: M) -> M::Response {
        let inner: &(dyn Sender<M> + Send + Sync) = &*self.tx;
        inner.ask(msg).await
    }

    /// Forwards to the wrapped sender's `tell`.
    fn tell(&self, msg: M) {
        let inner: &(dyn Sender<M> + Send + Sync) = &*self.tx;
        inner.tell(msg)
    }
}
#[async_trait]
impl<A> Stoppable for Ctx<A>
where
    A: Actor,
{
    /// Cancels this actor's token; its children's tokens are derived from it.
    fn stop(&self) {
        self.cancel.cancel();
    }

    /// Waits until the actor's task has exited.
    async fn wait_until_stopped(&self) {
        self.stopped.cancelled().await;
    }

    /// Stops the registered children one at a time, newest first, and awaits each.
    ///
    /// Stopping in reverse spawn order means a child spawned later, which may hold the
    /// address of an earlier sibling, shuts down before the siblings it might depend on.
    /// This matters when the parent is restarting: the children's tokens are not yet
    /// cancelled then, so they are stopped strictly one by one. The list is emptied, so a
    /// restarted actor starts with no children.
    async fn stop_all_children(&self) {
        // Take the children out first so the std mutex guard is released before any `.await`.
        let children: Vec<_> = self.children.lock().unwrap().drain(..).collect();
        for child in children.into_iter().rev() {
            child.stop();
            child.wait_until_stopped().await;
        }
    }
}
#[cfg(test)]
mod simple_tests {
    use crate::{Actor, Addr, Ctx, Handler, Message, Sender, Stoppable};
    use std::time::Instant;

    // Throughput smoke test: two tasks concurrently `tell` increments and decrements to a
    // single counter actor. The actor handles one message at a time, so equal numbers of
    // each must leave the count at zero. Also prints a rough messages-per-second figure.
    #[tokio::test(flavor = "multi_thread")]
    async fn simple_counter() -> anyhow::Result<()> {
        struct Counter {
            count: i64,
        }
        impl Actor for Counter {}
        #[derive(Message)]
        struct Increment;
        #[derive(Message)]
        struct Decrement;
        #[derive(Message)]
        #[response(i64)]
        struct GetCount;
        impl Handler<Increment> for Counter {
            async fn handle(&mut self, _: Increment, _: &Ctx<Self>) {
                self.count += 1;
            }
        }
        impl Handler<Decrement> for Counter {
            async fn handle(&mut self, _: Decrement, _: &Ctx<Self>) {
                self.count -= 1;
            }
        }
        impl Handler<GetCount> for Counter {
            async fn handle(&mut self, _: GetCount, _: &Ctx<Self>) -> i64 {
                self.count
            }
        }
        let counter = Counter { count: 0 }.start();
        let mut handles = vec![];
        let t1 = counter.clone();
        let t2 = counter.clone();
        // NOTE(review): 10M messages may make this slow in debug builds; consider a smaller
        // count if test time matters.
        let total = 10_000_000;
        let start = Instant::now();
        // Producer 1: half of the messages are increments.
        handles.push(tokio::task::spawn(async move {
            for _ in 0..(total / 2) {
                t1.tell(Increment);
            }
            Ok::<_, anyhow::Error>(())
        }));
        // Producer 2: the other half are decrements.
        handles.push(tokio::task::spawn(async move {
            for _ in 0..(total / 2) {
                t2.tell(Decrement);
            }
            Ok::<_, anyhow::Error>(())
        }));
        for handle in handles {
            handle.await??;
        }
        // Both producers have finished enqueueing, so this request is queued behind every
        // increment and decrement and sees their combined effect.
        let count = counter.ask(GetCount).await;
        assert_eq!(count, 0);
        let finished = start.elapsed();
        let msg_per_sec = total as f64 / finished.as_secs_f64();
        println!("{:.1} million msg/sec", msg_per_sec / 1_000_000.0);
        Ok(())
    }

    // Supervision test: `Counter` keeps no state of its own (the value lives in its sibling
    // `Db`), so when a panicking handler forces a restart the count is preserved.
    #[tokio::test]
    async fn restart_counter() -> anyhow::Result<()> {
        struct Db {
            value: i64,
        }
        impl Actor for Db {
            async fn stopped(&self, _: &Ctx<Self>) {
                println!("Db stopped");
            }
        }
        #[derive(Message)]
        #[response(i64)]
        struct DbGet;
        #[derive(Message)]
        struct DbSet(i64);
        impl Handler<DbGet> for Db {
            async fn handle(&mut self, _: DbGet, _: &Ctx<Self>) -> i64 {
                self.value
            }
        }
        impl Handler<DbSet> for Db {
            async fn handle(&mut self, msg: DbSet, _: &Ctx<Self>) {
                self.value = msg.0;
            }
        }
        struct Counter {
            db: Addr<Db>,
        }
        impl Actor for Counter {
            async fn stopped(&self, _: &Ctx<Self>) {
                println!("Counter stopped");
            }
        }
        #[derive(Message)]
        #[response(i64)]
        struct Increment;
        #[derive(Message)]
        #[response(i64)]
        struct GetCount;
        // Handled by making the handler panic, which triggers a restart.
        #[derive(Message)]
        struct Poison;
        #[derive(Message)]
        struct Stop;
        impl Handler<Increment> for Counter {
            // Read-modify-write against `Db`; the `DbSet` is queued in `Db`'s mailbox
            // before any later `DbGet` from this actor, so increments do not race.
            async fn handle(&mut self, _: Increment, _: &Ctx<Self>) -> i64 {
                let count = self.db.ask(DbGet).await + 1;
                self.db.tell(DbSet(count));
                count
            }
        }
        impl Handler<GetCount> for Counter {
            async fn handle(&mut self, _: GetCount, _: &Ctx<Self>) -> i64 {
                self.db.ask(DbGet).await
            }
        }
        impl Handler<Poison> for Counter {
            async fn handle(&mut self, _: Poison, _: &Ctx<Self>) {
                panic!("poisoned!");
            }
        }
        struct Root {}
        impl Actor for Root {
            async fn stopped(&self, _: &Ctx<Self>) {
                println!("Root stopped");
            }
        }
        impl Handler<Stop> for Root {
            async fn handle(&mut self, _: Stop, ctx: &Ctx<Self>) {
                println!("in Stop handler");
                ctx.stop();
            }
        }
        #[derive(Message)]
        #[response(Addr<Counter>)]
        struct GetCounter;
        impl Handler<GetCounter> for Root {
            // Spawns `Db` and `Counter` as children of `Root`, using the default (restarting)
            // supervision strategy. The `Counter` factory clones the same `Db` address, so
            // every restarted instance talks to the same `Db`.
            async fn handle(&mut self, _: GetCounter, ctx: &Ctx<Self>) -> Addr<Counter> {
                let db = ctx.spawn(|| Db { value: 0 });
                let counter = ctx.spawn(move || Counter { db: db.clone() });
                counter
            }
        }
        let root = Root {}.start();
        let counter = root.ask(GetCounter).await;
        for _ in 0..5 {
            counter.tell(Increment);
        }
        assert_eq!(counter.ask(GetCount).await, 5);
        // The panic is caught by the runtime and `Counter` is rebuilt from its factory;
        // the mailbox survives, so the following requests are still delivered.
        counter.tell(Poison);
        counter.ask(Increment).await;
        let count = counter.ask(GetCount).await;
        assert_eq!(count, 6, "state survives because it lives in the db actor");
        // Stopping `Root` also stops and awaits its children before `Root` finishes.
        root.tell(Stop);
        println!("just called stop!");
        root.wait_until_stopped().await;
        println!("goodbye");
        Ok(())
    }
}
#[cfg(test)]
mod bank_tests {
    use std::time::Duration;
    use crate::{Actor, Ctx, Handler, Message, Sender};
    use tokio::time::sleep;

    // Messages are implemented by hand here (rather than with `#[derive(Message)]`) to show
    // the plain trait form.
    struct Deposit(u64);
    impl Message for Deposit {
        type Response = ();
    }
    struct Withdraw(u64);
    impl Message for Withdraw {
        type Response = Result<(), String>;
    }
    struct GetBalance;
    impl Message for GetBalance {
        type Response = u64;
    }
    // Replies with (balance, total_deposits, total_withdrawals).
    struct GetAccountInfo;
    impl Message for GetAccountInfo {
        type Response = (u64, u64, u64);
    }
    #[derive(Clone)]
    struct BankAccount {
        balance: u64,
        total_deposits: u64,
        total_withdrawals: u64,
    }
    impl BankAccount {
        fn new(initial_balance: u64) -> Self {
            Self {
                balance: initial_balance,
                total_deposits: 0,
                total_withdrawals: 0,
            }
        }
    }
    impl Actor for BankAccount {}
    impl Handler<Deposit> for BankAccount {
        // The sleep widens the window in which a concurrent update could interleave if the
        // actor did not serialise its messages.
        async fn handle(&mut self, msg: Deposit, _: &Ctx<Self>) {
            tokio::time::sleep(Duration::from_millis(6)).await;
            self.balance += msg.0;
            self.total_deposits += msg.0;
            println!("Deposit: {}. New balance: {}", msg.0, self.balance);
        }
    }
    impl Handler<Withdraw> for BankAccount {
        // Check-then-act with a sleep in between: only safe because the actor processes one
        // message at a time.
        async fn handle(&mut self, msg: Withdraw, _: &Ctx<Self>) -> Result<(), String> {
            if self.balance >= msg.0 {
                tokio::time::sleep(Duration::from_millis(10)).await;
                self.balance -= msg.0;
                self.total_withdrawals += msg.0;
                println!("Withdrawal: {}. New balance: {}", msg.0, self.balance);
                Ok(())
            } else {
                Err(format!(
                    "Insufficient funds. Current balance: {}",
                    self.balance
                ))
            }
        }
    }
    impl Handler<GetAccountInfo> for BankAccount {
        async fn handle(&mut self, _msg: GetAccountInfo, _: &Ctx<Self>) -> (u64, u64, u64) {
            println!("GetAccountInfo!");
            (self.balance, self.total_deposits, self.total_withdrawals)
        }
    }
    impl Handler<GetBalance> for BankAccount {
        async fn handle(&mut self, _msg: GetBalance, _: &Ctx<Self>) -> u64 {
            self.balance
        }
    }

    // Two tasks interleave deposits and withdrawals against one account. The initial
    // balance (1000) covers all five withdrawals (5 * 200) even before any deposit, so every
    // withdrawal succeeds in any ordering and the expected totals are order-independent.
    #[tokio::test]
    async fn test_bank_account_race_condition() {
        let initial_balance = 1000;
        let account = BankAccount::new(initial_balance).start();
        let deposit_amount = 100;
        let withdraw_amount = 200;
        let num_operations = 5;
        let deposit_task = tokio::spawn({
            let account = account.clone();
            async move {
                for _ in 0..num_operations {
                    account.tell(Deposit(deposit_amount));
                    sleep(Duration::from_millis(3)).await;
                }
            }
        });
        let withdraw_task = tokio::spawn({
            let account = account.clone();
            async move {
                for _ in 0..num_operations {
                    account.tell(Withdraw(withdraw_amount));
                    sleep(Duration::from_millis(9)).await;
                }
            }
        });
        let _ = tokio::join!(deposit_task, withdraw_task);
        // Both producers have finished, so this request is queued behind every deposit and
        // withdrawal and observes all of them.
        let (final_balance, total_deposits, total_withdrawals) = account.ask(GetAccountInfo).await;
        let expected_deposits = deposit_amount * num_operations;
        let expected_withdrawals = withdraw_amount * num_operations;
        let expected_balance = initial_balance + expected_deposits - expected_withdrawals;
        assert_eq!(
            total_deposits, expected_deposits,
            "Total deposits don't match expected value"
        );
        assert_eq!(
            total_withdrawals, expected_withdrawals,
            "Total withdrawals don't match expected value"
        );
        assert_eq!(
            final_balance, expected_balance,
            "Final balance doesn't match expected value"
        );
        assert_eq!(
            final_balance,
            initial_balance + total_deposits - total_withdrawals,
            "Balance is inconsistent with recorded deposits and withdrawals"
        );
    }
}