use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use crate::errors::{ActorSendError, ErrorAction, ErrorCode, RuntimeError};
use crate::interceptor::SendMode;
use crate::mailbox::MailboxConfig;
use crate::message::{Headers, Message};
use crate::node::ActorId;
use crate::stream::{BatchConfig, BoxStream, StreamReceiver, StreamSender};
/// Structured error raised by actor code.
///
/// Carries a machine-readable `code`, a human-readable `message`, optional
/// free-form `details`, and an optional `cause`, which lets errors form a chain.
/// Serializable with serde when the crate's `serde` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ActorError {
    /// Machine-readable classification; `Display` renders it with `{:?}`.
    pub code: ErrorCode,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional extra context; `Display` shows it in parentheses after the message.
    pub details: Option<String>,
    /// Optional underlying error; `Display` appends it as ` caused by: ...`.
    pub cause: Option<Box<ActorError>>,
}
impl ActorError {
    /// Creates an error with the given `code` and `message`.
    ///
    /// `details` and `cause` start out empty; attach them with
    /// [`ActorError::with_details`] and [`ActorError::with_cause`].
    #[must_use]
    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
        Self {
            code,
            message: message.into(),
            details: None,
            cause: None,
        }
    }

    /// Shorthand for `ActorError::new(ErrorCode::Internal, message)`.
    #[must_use]
    pub fn internal(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::Internal, message)
    }

    /// Returns this error with `details` attached, replacing any previous details.
    ///
    /// `self` is consumed. Ignoring the return value drops the error entirely,
    /// so the result is `#[must_use]`.
    #[must_use = "`with_details` consumes the error and returns the updated one"]
    pub fn with_details(mut self, details: impl Into<String>) -> Self {
        self.details = Some(details.into());
        self
    }

    /// Returns this error with `cause` recorded as its underlying error,
    /// replacing any previous cause.
    ///
    /// `self` is consumed. Ignoring the return value drops the error entirely,
    /// so the result is `#[must_use]`.
    #[must_use = "`with_cause` consumes the error and returns the updated one"]
    pub fn with_cause(mut self, cause: ActorError) -> Self {
        self.cause = Some(Box::new(cause));
        self
    }
}
impl fmt::Display for ActorError {
    /// Renders the error as `[Code] message`.
    ///
    /// If set, ` (details)` is appended, followed by ` caused by: <cause>`.
    /// The cause is rendered with this same `Display` impl, so the whole
    /// chain appears on one line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            code,
            message,
            details,
            cause,
        } = self;
        write!(f, "[{:?}] {}", code, message)?;
        if let Some(extra) = details {
            write!(f, " ({})", extra)?;
        }
        match cause {
            Some(inner) => write!(f, " caused by: {}", inner),
            None => Ok(()),
        }
    }
}
/// `source()` is intentionally left at its default (`None`).
///
/// The `Display` impl above already prints the `cause` chain. Also returning
/// it from `source()` would make chain-walking reporters print every cause twice.
impl std::error::Error for ActorError {}
/// Per-actor context passed (as `&mut`) to lifecycle hooks and message handlers.
///
/// [`ActorContext::new`] starts with no send mode, empty headers and no
/// cancellation token.
#[derive(Debug)]
pub struct ActorContext {
    /// Identity of the actor this context belongs to.
    pub actor_id: ActorId,
    /// Human-readable name of the actor.
    pub actor_name: String,
    /// Send mode associated with the current delivery, if any.
    /// NOTE(review): presumably populated by the runtime per message — confirm.
    pub send_mode: Option<SendMode>,
    /// Headers attached to this context.
    /// NOTE(review): presumably those of the message being handled — confirm.
    pub headers: Headers,
    /// Token awaited by `ActorContext::cancelled`. When `None`, `cancelled()`
    /// never resolves. Crate-private; replaced via `set_cancellation_token`.
    pub(crate) cancellation_token: Option<CancellationToken>,
}
impl ActorContext {
    /// Builds a context for the given actor.
    ///
    /// Headers start empty; there is no send mode and no cancellation token.
    pub fn new(actor_id: ActorId, actor_name: String) -> Self {
        let headers = Headers::new();
        Self {
            actor_name,
            actor_id,
            headers,
            send_mode: None,
            cancellation_token: None,
        }
    }

    /// Completes once the attached cancellation token is cancelled.
    ///
    /// With no token attached, this future never completes. That makes it
    /// safe to race against other work (e.g. in a `select`) without a
    /// spurious wake-up.
    pub async fn cancelled(&self) {
        if let Some(token) = self.cancellation_token.as_ref() {
            token.cancelled().await;
        } else {
            futures::future::pending::<()>().await;
        }
    }

    /// Replaces the cancellation token observed by [`ActorContext::cancelled`].
    ///
    /// Passing `None` detaches any token.
    pub fn set_cancellation_token(&mut self, token: Option<CancellationToken>) {
        self.cancellation_token = token;
    }
}
/// Core actor trait: a `Send + 'static` state object built from
/// user-supplied arguments and dependencies.
///
/// All lifecycle hooks have default implementations, so a minimal actor only
/// needs to provide [`Actor::create`].
/// NOTE(review): when the runtime invokes `on_start` / `on_stop` / `on_error`
/// is not visible here — confirm against the runtime implementation.
#[async_trait]
pub trait Actor: Send + 'static {
    /// Construction-time arguments passed to [`Actor::create`].
    type Args: Send + 'static;
    /// Dependencies passed to [`Actor::create`] alongside `Args`.
    type Deps: Send + 'static;
    /// Builds the actor instance from its arguments and dependencies.
    fn create(args: Self::Args, deps: Self::Deps) -> Self
    where
        Self: Sized;
    /// Start-up hook with mutable access to the actor's context; the default does nothing.
    async fn on_start(&mut self, _ctx: &mut ActorContext) {}
    /// Shutdown hook; the default does nothing.
    async fn on_stop(&mut self) {}
    /// Decides how to react to an error; the default always returns `ErrorAction::Stop`.
    fn on_error(&mut self, _error: &ActorError) -> ErrorAction {
        ErrorAction::Stop
    }
}
/// Implemented by actors that can process messages of type `M`.
///
/// `handle` runs with exclusive (`&mut self`) access to the actor and produces
/// the message's associated `M::Reply`. For `tell`-style messages this is `()`.
#[async_trait]
pub trait Handler<M: Message>: Actor {
    /// Processes `msg` and returns its reply.
    async fn handle(&mut self, msg: M, ctx: &mut ActorContext) -> M::Reply;
}
/// Implemented by actors that turn one message `M` into a stream of
/// `OutputItem`s; this is the bound required by `ActorRef::expand`.
#[async_trait]
pub trait ExpandHandler<M, OutputItem: Send + 'static>: Actor
where
    M: Send + 'static,
{
    /// Handles `msg`, emitting output items through `sender`, which is owned
    /// by this call.
    /// NOTE(review): presumably the output stream ends when `sender` is
    /// dropped at the end of this call — confirm with `StreamSender`.
    async fn handle_expand(
        &mut self,
        msg: M,
        sender: StreamSender<OutputItem>,
        ctx: &mut ActorContext,
    );
}
/// Implemented by actors that consume a stream of `InputItem`s and produce a
/// single `Reply`; this is the bound required by `ActorRef::reduce`.
#[async_trait]
pub trait ReduceHandler<InputItem: Send + 'static, Reply: Send + 'static>: Actor {
    /// Drains items from `receiver` and returns the reduced result.
    async fn handle_reduce(
        &mut self,
        receiver: StreamReceiver<InputItem>,
        ctx: &mut ActorContext,
    ) -> Reply;
}
/// Implemented by actors that process a stream item by item, emitting
/// `OutputItem`s; this is the bound required by `ActorRef::transform`.
#[async_trait]
pub trait TransformHandler<InputItem: Send + 'static, OutputItem: Send + 'static>: Actor {
    /// Handles one input `item`, emitting any number of outputs through the
    /// borrowed `sender`.
    async fn handle_transform(
        &mut self,
        item: InputItem,
        sender: &StreamSender<OutputItem>,
        ctx: &mut ActorContext,
    );
    /// Final hook with one more chance to emit output through `sender`.
    /// The default does nothing.
    /// NOTE(review): presumably called once after the input stream is
    /// exhausted — confirm against the runtime.
    async fn on_transform_complete(
        &mut self,
        sender: &StreamSender<OutputItem>,
        ctx: &mut ActorContext,
    ) {
        // Explicitly consume the parameters so the default body raises no
        // unused-variable warnings.
        let _ = (sender, ctx);
    }
}
/// Future resolving to the reply of an `ask` or `reduce` request.
///
/// Wraps the receiving half of a `tokio::sync::oneshot` channel whose payload
/// is already a `Result<R, RuntimeError>`. A reply or error sent through the
/// channel is passed through unchanged. If the sending half is dropped without
/// sending, the future yields `RuntimeError::ActorNotFound`.
pub struct AskReply<R> {
    rx: oneshot::Receiver<Result<R, RuntimeError>>,
}
impl<R> AskReply<R> {
    /// Wraps `rx` so it can be awaited as a reply future.
    pub fn new(rx: oneshot::Receiver<Result<R, RuntimeError>>) -> Self {
        AskReply { rx }
    }
}
impl<R> Future for AskReply<R> {
    type Output = Result<R, RuntimeError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let receiver = Pin::new(&mut self.rx);
        receiver.poll(cx).map(|delivered| {
            // `delivered` is `Err` only when the sender was dropped without
            // sending; otherwise it already carries the final `Result`.
            delivered.unwrap_or_else(|_recv_error| {
                Err(RuntimeError::ActorNotFound(
                    "reply channel closed — actor may have stopped, panicked, or the request was cancelled".into(),
                ))
            })
        })
    }
}
/// Cloneable, thread-safe handle for communicating with an actor of type `A`.
///
/// Each send method returns `Err(ActorSendError)` when the request cannot be
/// submitted. NOTE(review): the exact failure conditions (stopped actor, full
/// mailbox, ...) are implementation-defined — confirm with implementors.
pub trait ActorRef<A: Actor>: Clone + Send + Sync + 'static {
    /// Identity of the referenced actor.
    fn id(&self) -> ActorId;
    /// Name of the referenced actor, returned as an owned `String`.
    fn name(&self) -> String;
    /// Whether the referenced actor is considered alive.
    /// NOTE(review): freshness of this answer depends on the implementation.
    fn is_alive(&self) -> bool;
    /// Number of messages waiting for the actor.
    ///
    /// The default always returns `0`; implementations that can observe
    /// their mailbox depth should override it.
    fn pending_messages(&self) -> usize {
        0
    }
    /// Requests that the actor stop.
    /// NOTE(review): whether pending messages are drained first is up to the
    /// implementation — confirm.
    fn stop(&self);
    /// Fire-and-forget send; only accepts messages whose reply type is `()`.
    fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
    where
        A: Handler<M>,
        M: Message<Reply = ()>;
    /// Request/response send.
    ///
    /// On success returns an [`AskReply`] that resolves to the handler's
    /// `M::Reply` or a `RuntimeError`. `cancel` is an optional token;
    /// presumably cancelling it abandons the request — confirm. See
    /// [`cancel_after`] for a deadline-based token.
    fn ask<M>(
        &self,
        msg: M,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<M::Reply>, ActorSendError>
    where
        A: Handler<M>,
        M: Message;
    /// Sends `msg` to an [`ExpandHandler`] and returns the stream of items it emits.
    ///
    /// `buffer` presumably sizes the output channel and `batch_config`
    /// optionally enables batching — confirm against implementations.
    fn expand<M, OutputItem>(
        &self,
        msg: M,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<BoxStream<OutputItem>, ActorSendError>
    where
        A: ExpandHandler<M, OutputItem>,
        M: Send + 'static,
        OutputItem: Send + 'static;
    /// Feeds `input` to a [`ReduceHandler`] and returns an [`AskReply`] for
    /// its single final result.
    fn reduce<InputItem, Reply>(
        &self,
        input: BoxStream<InputItem>,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<AskReply<Reply>, ActorSendError>
    where
        A: ReduceHandler<InputItem, Reply>,
        InputItem: Send + 'static,
        Reply: Send + 'static;
    /// Pipes `input` through a [`TransformHandler`] and returns the resulting
    /// output stream.
    fn transform<InputItem, OutputItem>(
        &self,
        input: BoxStream<InputItem>,
        buffer: usize,
        batch_config: Option<BatchConfig>,
        cancel: Option<CancellationToken>,
    ) -> Result<BoxStream<OutputItem>, ActorSendError>
    where
        A: TransformHandler<InputItem, OutputItem>,
        InputItem: Send + 'static,
        OutputItem: Send + 'static;
}
/// Returns a [`CancellationToken`] that is cancelled automatically once
/// `duration` has elapsed.
///
/// The deadline is tracked by a spawned Tokio task. If the token is cancelled
/// by other means before the deadline, the task notices and exits right away.
/// It does not keep a clone of the token alive until `duration` runs out.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because it uses [`tokio::spawn`].
pub fn cancel_after(duration: Duration) -> CancellationToken {
    let token = CancellationToken::new();
    // A plain clone (not `child_token()`): cancelling it cancels `token` itself.
    let timer_token = token.clone();
    tokio::spawn(async move {
        // `timeout` yields `Ok(())` as soon as the token is cancelled
        // elsewhere, and `Err(Elapsed)` when the deadline passes first.
        if tokio::time::timeout(duration, timer_token.cancelled())
            .await
            .is_err()
        {
            timer_token.cancel();
        }
    });
    token
}
/// Options controlling how an actor is spawned.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal, including `..Default::default()` update syntax.
/// Start from `SpawnConfig::default()` and assign the public fields instead.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SpawnConfig {
    /// Mailbox configuration for the spawned actor; defaults to `MailboxConfig::default()`.
    pub mailbox: MailboxConfig,
    /// Node to spawn the actor on; defaults to `None`.
    /// NOTE(review): presumably `None` means the local or runtime-chosen node — confirm.
    pub target_node: Option<crate::node::NodeId>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorAction;
    use crate::message::Message;
    use crate::node::{ActorId, NodeId};

    /// Fixed actor id (`n1/1`) shared by the context-based tests.
    fn test_actor_id() -> ActorId {
        ActorId {
            node: NodeId("n1".into()),
            local: 1,
        }
    }

    /// Context for `test_actor_id()` built through `ActorContext::new`, so the
    /// handler tests run with the same defaults production code gets.
    fn test_ctx(name: &str) -> ActorContext {
        ActorContext::new(test_actor_id(), name.to_string())
    }

    struct Counter {
        count: u64,
    }

    impl Actor for Counter {
        type Args = Self;
        type Deps = ();

        fn create(args: Self, _deps: ()) -> Self {
            args
        }
    }

    struct Increment(u64);
    impl Message for Increment {
        type Reply = ();
    }

    struct GetCount;
    impl Message for GetCount {
        type Reply = u64;
    }

    struct Reset;
    impl Message for Reset {
        type Reply = u64;
    }

    #[async_trait]
    impl Handler<Increment> for Counter {
        async fn handle(&mut self, msg: Increment, _ctx: &mut ActorContext) {
            self.count += msg.0;
        }
    }

    #[async_trait]
    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &mut ActorContext) -> u64 {
            self.count
        }
    }

    #[async_trait]
    impl Handler<Reset> for Counter {
        async fn handle(&mut self, _msg: Reset, _ctx: &mut ActorContext) -> u64 {
            let old = self.count;
            self.count = 0;
            old
        }
    }

    #[test]
    fn test_counter_actor_compiles() {
        let counter = Counter::create(Counter { count: 0 }, ());
        assert_eq!(counter.count, 0);
    }

    #[test]
    fn test_actor_default_on_error_returns_stop() {
        let mut counter = Counter { count: 0 };
        let action = counter.on_error(&ActorError::internal("test error"));
        assert_eq!(action, ErrorAction::Stop);
    }

    struct WorkerArgs {
        name: String,
    }

    struct WorkerDeps {
        multiplier: u64,
    }

    struct Worker {
        name: String,
        multiplier: u64,
    }

    impl Actor for Worker {
        type Args = WorkerArgs;
        type Deps = WorkerDeps;

        fn create(args: WorkerArgs, deps: WorkerDeps) -> Self {
            Worker {
                name: args.name,
                multiplier: deps.multiplier,
            }
        }
    }

    #[test]
    fn test_worker_actor_with_deps() {
        let worker = Worker::create(
            WorkerArgs { name: "w1".into() },
            WorkerDeps { multiplier: 10 },
        );
        assert_eq!(worker.name, "w1");
        assert_eq!(worker.multiplier, 10);
    }

    #[test]
    fn test_actor_id_display() {
        let id = ActorId {
            node: NodeId("node-1".into()),
            local: 42,
        };
        assert_eq!(format!("{}", id), "Actor(node-1/42)");
    }

    #[test]
    fn test_actor_id_equality() {
        let id1 = ActorId {
            node: NodeId("n1".into()),
            local: 1,
        };
        let id2 = ActorId {
            node: NodeId("n1".into()),
            local: 1,
        };
        let id3 = ActorId {
            node: NodeId("n1".into()),
            local: 2,
        };
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_actor_id_clone() {
        let id = test_actor_id();
        let cloned = id.clone();
        assert_eq!(id, cloned);
    }

    #[test]
    fn test_error_action_variants() {
        assert_eq!(ErrorAction::Resume, ErrorAction::Resume);
        assert_eq!(ErrorAction::Restart, ErrorAction::Restart);
        assert_eq!(ErrorAction::Stop, ErrorAction::Stop);
        assert_eq!(ErrorAction::Escalate, ErrorAction::Escalate);
        assert_ne!(ErrorAction::Resume, ErrorAction::Stop);
    }

    #[test]
    fn test_spawn_config_default() {
        let config = SpawnConfig::default();
        assert!(config.target_node.is_none());
    }

    #[test]
    fn test_spawn_config_with_target_node() {
        let config = SpawnConfig {
            target_node: Some(NodeId("node-3".into())),
            ..Default::default()
        };
        assert_eq!(config.target_node.unwrap().0, "node-3");
    }

    #[test]
    fn test_actor_context_fields() {
        // Struct-literal construction is kept on purpose: it checks that the
        // crate-private field remains constructible from within the crate.
        let ctx = ActorContext {
            actor_id: test_actor_id(),
            actor_name: "test-actor".into(),
            send_mode: None,
            headers: Headers::new(),
            cancellation_token: None,
        };
        assert_eq!(ctx.actor_name, "test-actor");
        assert_eq!(ctx.actor_id.local, 1);
    }

    #[test]
    fn test_actor_context_new_defaults() {
        let ctx = test_ctx("fresh");
        assert_eq!(ctx.actor_id, test_actor_id());
        assert_eq!(ctx.actor_name, "fresh");
        assert!(ctx.send_mode.is_none());
        assert!(ctx.cancellation_token.is_none());
    }

    #[tokio::test]
    async fn test_lifecycle_defaults_are_noop() {
        let mut counter = Counter { count: 42 };
        let mut ctx = test_ctx("counter");
        counter.on_start(&mut ctx).await;
        counter.on_stop().await;
        assert_eq!(counter.count, 42);
    }

    #[tokio::test]
    async fn test_handler_increment() {
        let mut counter = Counter { count: 0 };
        let mut ctx = test_ctx("counter");
        counter.handle(Increment(5), &mut ctx).await;
        assert_eq!(counter.count, 5);
        counter.handle(Increment(3), &mut ctx).await;
        assert_eq!(counter.count, 8);
    }

    #[tokio::test]
    async fn test_handler_get_count() {
        let mut counter = Counter { count: 42 };
        let mut ctx = test_ctx("counter");
        let count = counter.handle(GetCount, &mut ctx).await;
        assert_eq!(count, 42);
    }

    #[tokio::test]
    async fn test_handler_reset() {
        let mut counter = Counter { count: 100 };
        let mut ctx = test_ctx("counter");
        let old = counter.handle(Reset, &mut ctx).await;
        assert_eq!(old, 100);
        assert_eq!(counter.count, 0);
    }

    #[tokio::test]
    async fn test_multiple_handlers_on_same_actor() {
        let mut counter = Counter { count: 0 };
        let mut ctx = test_ctx("counter");
        counter.handle(Increment(10), &mut ctx).await;
        counter.handle(Increment(20), &mut ctx).await;
        let count = counter.handle(GetCount, &mut ctx).await;
        assert_eq!(count, 30);
        let old = counter.handle(Reset, &mut ctx).await;
        assert_eq!(old, 30);
        assert_eq!(counter.count, 0);
    }

    #[tokio::test]
    async fn test_context_cancelled_resolves_when_token_cancelled() {
        let mut ctx = test_ctx("cancellable");
        let token = CancellationToken::new();
        ctx.set_cancellation_token(Some(token.clone()));
        token.cancel();
        tokio::time::timeout(Duration::from_secs(1), ctx.cancelled())
            .await
            .expect("cancelled() should resolve once the token is cancelled");
    }

    #[tokio::test]
    async fn test_context_without_token_never_cancels() {
        let ctx = test_ctx("uncancellable");
        let outcome = tokio::time::timeout(Duration::from_millis(10), ctx.cancelled()).await;
        assert!(outcome.is_err(), "cancelled() must stay pending without a token");
    }

    #[tokio::test]
    async fn test_cancel_after_cancels_token() {
        let token = cancel_after(Duration::from_millis(5));
        tokio::time::timeout(Duration::from_secs(1), token.cancelled())
            .await
            .expect("token should be cancelled after the deadline");
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_ask_reply_forwards_reply() {
        let (tx, rx) = oneshot::channel::<Result<u64, RuntimeError>>();
        assert!(tx.send(Ok(7)).is_ok());
        assert!(matches!(AskReply::new(rx).await, Ok(7)));
    }

    #[tokio::test]
    async fn test_ask_reply_closed_channel_is_error() {
        let (tx, rx) = oneshot::channel::<Result<u64, RuntimeError>>();
        drop(tx);
        assert!(matches!(
            AskReply::new(rx).await,
            Err(RuntimeError::ActorNotFound(_))
        ));
    }

    #[test]
    fn test_handler_requires_actor_bound() {
        fn assert_handler<A: Handler<M>, M: Message>() {}
        assert_handler::<Counter, Increment>();
        assert_handler::<Counter, GetCount>();
        assert_handler::<Counter, Reset>();
    }

    #[test]
    fn test_actor_error_construction() {
        let err = ActorError::new(ErrorCode::InvalidArgument, "bad input");
        assert_eq!(err.code, ErrorCode::InvalidArgument);
        assert_eq!(err.message, "bad input");
        assert!(err.details.is_none());
        assert!(err.cause.is_none());
    }

    #[test]
    fn test_actor_error_with_details() {
        let err = ActorError::new(ErrorCode::NotFound, "user not found").with_details("user_id=42");
        assert_eq!(err.details.as_deref(), Some("user_id=42"));
    }

    #[test]
    fn test_actor_error_chain() {
        let root = ActorError::new(ErrorCode::Unavailable, "db connection failed");
        let err = ActorError::new(ErrorCode::Internal, "query failed").with_cause(root);
        assert!(err.cause.is_some());
        assert_eq!(err.cause.as_ref().unwrap().code, ErrorCode::Unavailable);
    }

    #[test]
    fn test_actor_error_display() {
        let err = ActorError::new(ErrorCode::Internal, "something broke")
            .with_details("stack: foo.rs:42");
        let display = format!("{}", err);
        assert!(display.contains("Internal"));
        assert!(display.contains("something broke"));
        assert!(display.contains("stack: foo.rs:42"));
    }

    #[test]
    fn test_actor_error_display_with_chain() {
        let root = ActorError::new(ErrorCode::Unavailable, "db down");
        let err = ActorError::new(ErrorCode::Internal, "query failed").with_cause(root);
        let display = format!("{}", err);
        assert!(display.contains("caused by"));
        assert!(display.contains("db down"));
    }

    #[test]
    fn test_error_code_variants() {
        let codes = [
            ErrorCode::Internal,
            ErrorCode::InvalidArgument,
            ErrorCode::NotFound,
            ErrorCode::Unavailable,
            ErrorCode::Timeout,
            ErrorCode::PermissionDenied,
            ErrorCode::FailedPrecondition,
            ErrorCode::ResourceExhausted,
            ErrorCode::Unimplemented,
            ErrorCode::Unknown,
            ErrorCode::Cancelled,
        ];
        assert_eq!(codes.len(), 11);
        // Every pair of distinct variants must compare unequal.
        for (i, a) in codes.iter().enumerate() {
            for (j, b) in codes.iter().enumerate() {
                if i != j {
                    assert_ne!(a, b);
                }
            }
        }
    }

    #[test]
    fn test_actor_error_internal_helper() {
        let err = ActorError::internal("oops");
        assert_eq!(err.code, ErrorCode::Internal);
        assert_eq!(err.message, "oops");
    }

    #[test]
    fn test_not_supported_error() {
        use crate::errors::NotSupportedError;
        let err = NotSupportedError {
            capability: "BoundedMailbox".into(),
            message: "ractor does not support bounded mailboxes".into(),
        };
        assert!(format!("{}", err).contains("BoundedMailbox"));
    }

    #[test]
    fn test_runtime_error_not_supported() {
        use crate::errors::NotSupportedError;
        let err = RuntimeError::NotSupported(NotSupportedError {
            capability: "PriorityMailbox".into(),
            message: "not available".into(),
        });
        assert!(format!("{}", err).contains("PriorityMailbox"));
    }
}