1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use async_trait::async_trait;
8use tokio::sync::oneshot;
9use tokio_util::sync::CancellationToken;
10
11use crate::errors::{ActorSendError, ErrorAction, ErrorCode, RuntimeError};
12use crate::interceptor::SendMode;
13use crate::mailbox::MailboxConfig;
14use crate::message::{Headers, Message};
15use crate::node::ActorId;
16use crate::stream::{BatchConfig, BoxStream, StreamReceiver, StreamSender};
17
18#[derive(Debug, Clone)]
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25pub struct ActorError {
26 pub code: ErrorCode,
28 pub message: String,
30 pub details: Option<String>,
32 pub cause: Option<Box<ActorError>>,
34}
35
36impl ActorError {
37 pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
39 Self {
40 code,
41 message: message.into(),
42 details: None,
43 cause: None,
44 }
45 }
46
47 pub fn internal(message: impl Into<String>) -> Self {
49 Self::new(ErrorCode::Internal, message)
50 }
51
52 pub fn with_details(mut self, details: impl Into<String>) -> Self {
54 self.details = Some(details.into());
55 self
56 }
57
58 pub fn with_cause(mut self, cause: ActorError) -> Self {
60 self.cause = Some(Box::new(cause));
61 self
62 }
63}
64
65impl fmt::Display for ActorError {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 write!(f, "[{:?}] {}", self.code, self.message)?;
68 if let Some(ref details) = self.details {
69 write!(f, " ({})", details)?;
70 }
71 if let Some(ref cause) = self.cause {
72 write!(f, " caused by: {}", cause)?;
73 }
74 Ok(())
75 }
76}
77
78impl std::error::Error for ActorError {}
79
80#[derive(Debug)]
87pub struct ActorContext {
88 pub actor_id: ActorId,
90 pub actor_name: String,
92 pub send_mode: Option<SendMode>,
95 pub headers: Headers,
98 pub(crate) cancellation_token: Option<CancellationToken>,
100}
101
102impl ActorContext {
103 pub fn new(actor_id: ActorId, actor_name: String) -> Self {
108 Self {
109 actor_id,
110 actor_name,
111 send_mode: None,
112 headers: Headers::new(),
113 cancellation_token: None,
114 }
115 }
116
117 pub async fn cancelled(&self) {
130 match &self.cancellation_token {
131 Some(token) => token.cancelled().await,
132 None => futures::future::pending().await,
133 }
134 }
135
136 pub fn set_cancellation_token(&mut self, token: Option<CancellationToken>) {
139 self.cancellation_token = token;
140 }
141}
142
/// Core actor abstraction: a `Send + 'static` unit of state built once from
/// `Args` and `Deps`, with overridable lifecycle and error hooks.
///
/// Message handling is added separately by implementing [`Handler`] (or the
/// streaming handler traits) for concrete message types.
#[async_trait]
pub trait Actor: Send + 'static {
    /// Construction arguments passed to [`Actor::create`].
    type Args: Send + 'static;

    /// Dependencies passed to [`Actor::create`] alongside `Args`.
    type Deps: Send + 'static;

    /// Builds the actor instance from its arguments and dependencies.
    fn create(args: Self::Args, deps: Self::Deps) -> Self
    where
        Self: Sized;

    /// Start-up hook with access to the actor's context. The default does nothing.
    /// The runtime decides exactly when it runs (presumably before the first
    /// message is handled — confirm against the runtime implementation).
    async fn on_start(&mut self, _ctx: &mut ActorContext) {}

    /// Shutdown hook. The default does nothing.
    async fn on_stop(&mut self) {}

    /// Decides how to react to `_error`. The default returns [`ErrorAction::Stop`].
    fn on_error(&mut self, _error: &ActorError) -> ErrorAction {
        ErrorAction::Stop
    }
}
174
/// Request/response handling of message type `M` for an [`Actor`].
///
/// An actor may implement `Handler<M>` for several message types; each handler
/// returns that message's associated `M::Reply`.
#[async_trait]
pub trait Handler<M: Message>: Actor {
    /// Handles `msg` with mutable access to the actor and its context, producing the reply.
    async fn handle(&mut self, msg: M, ctx: &mut ActorContext) -> M::Reply;
}
182
/// One-to-many handling: a single message `M` produces a stream of `OutputItem`s.
///
/// Driven by [`ActorRef::expand`], which returns the emitted items as a stream.
#[async_trait]
pub trait ExpandHandler<M, OutputItem: Send + 'static>: Actor
where
    M: Send + 'static,
{
    /// Handles `msg`, emitting output items through `sender`; `sender` is taken
    /// by value, so the implementation owns it for the duration of the call.
    async fn handle_expand(
        &mut self,
        msg: M,
        sender: StreamSender<OutputItem>,
        ctx: &mut ActorContext,
    );
}
204
/// Many-to-one handling: a stream of `InputItem`s is consumed to produce a single `Reply`.
///
/// Driven by [`ActorRef::reduce`], whose result resolves to the returned `Reply`.
#[async_trait]
pub trait ReduceHandler<InputItem: Send + 'static, Reply: Send + 'static>: Actor {
    /// Reads input items from `receiver` and returns the aggregated reply.
    async fn handle_reduce(
        &mut self,
        receiver: StreamReceiver<InputItem>,
        ctx: &mut ActorContext,
    ) -> Reply;
}
223
/// Stream-to-stream handling: each `InputItem` is handled individually and may
/// emit any number of `OutputItem`s.
///
/// Driven by [`ActorRef::transform`], which returns the output items as a stream.
#[async_trait]
pub trait TransformHandler<InputItem: Send + 'static, OutputItem: Send + 'static>: Actor {
    /// Handles one input `item`, emitting zero or more outputs through `sender`
    /// (borrowed, so the same sender is reused across items).
    async fn handle_transform(
        &mut self,
        item: InputItem,
        sender: &StreamSender<OutputItem>,
        ctx: &mut ActorContext,
    );

    /// Hook for emitting any final outputs. Presumably invoked once after the input
    /// stream is exhausted — confirm against the runtime. The default does nothing.
    async fn on_transform_complete(
        &mut self,
        sender: &StreamSender<OutputItem>,
        ctx: &mut ActorContext,
    ) {
        // Discard the arguments explicitly; this avoids unused-variable warnings
        // while keeping descriptive parameter names in the trait signature.
        let _ = (sender, ctx);
    }
}
253
254pub struct AskReply<R> {
259 rx: oneshot::Receiver<Result<R, RuntimeError>>,
260}
261
262impl<R> AskReply<R> {
263 pub fn new(rx: oneshot::Receiver<Result<R, RuntimeError>>) -> Self {
267 Self { rx }
268 }
269}
270
271impl<R> Future for AskReply<R> {
272 type Output = Result<R, RuntimeError>;
273
274 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
275 match Pin::new(&mut self.rx).poll(cx) {
276 Poll::Ready(Ok(Ok(reply))) => Poll::Ready(Ok(reply)),
277 Poll::Ready(Ok(Err(error))) => Poll::Ready(Err(error)),
278 Poll::Ready(Err(_)) => Poll::Ready(Err(RuntimeError::ActorNotFound(
279 "reply channel closed — actor may have stopped, panicked, or the request was cancelled".into(),
280 ))),
281 Poll::Pending => Poll::Pending,
282 }
283 }
284}
285
286pub trait ActorRef<A: Actor>: Clone + Send + Sync + 'static {
291 fn id(&self) -> ActorId;
293
294 fn name(&self) -> String;
296
297 fn is_alive(&self) -> bool;
299
300 fn pending_messages(&self) -> usize {
307 0
308 }
309
310 fn stop(&self);
312
313 fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
316 where
317 A: Handler<M>,
318 M: Message<Reply = ()>;
319
320 fn ask<M>(
327 &self,
328 msg: M,
329 cancel: Option<CancellationToken>,
330 ) -> Result<AskReply<M::Reply>, ActorSendError>
331 where
332 A: Handler<M>,
333 M: Message;
334
335 fn expand<M, OutputItem>(
343 &self,
344 msg: M,
345 buffer: usize,
346 batch_config: Option<BatchConfig>,
347 cancel: Option<CancellationToken>,
348 ) -> Result<BoxStream<OutputItem>, ActorSendError>
349 where
350 A: ExpandHandler<M, OutputItem>,
351 M: Send + 'static,
352 OutputItem: Send + 'static;
353
354 fn reduce<InputItem, Reply>(
367 &self,
368 input: BoxStream<InputItem>,
369 buffer: usize,
370 batch_config: Option<BatchConfig>,
371 cancel: Option<CancellationToken>,
372 ) -> Result<AskReply<Reply>, ActorSendError>
373 where
374 A: ReduceHandler<InputItem, Reply>,
375 InputItem: Send + 'static,
376 Reply: Send + 'static;
377
378 fn transform<InputItem, OutputItem>(
399 &self,
400 input: BoxStream<InputItem>,
401 buffer: usize,
402 batch_config: Option<BatchConfig>,
403 cancel: Option<CancellationToken>,
404 ) -> Result<BoxStream<OutputItem>, ActorSendError>
405 where
406 A: TransformHandler<InputItem, OutputItem>,
407 InputItem: Send + 'static,
408 OutputItem: Send + 'static;
409}
410
411pub fn cancel_after(duration: Duration) -> CancellationToken {
413 let token = CancellationToken::new();
414 let child = token.clone();
415 tokio::spawn(async move {
416 tokio::time::sleep(duration).await;
417 child.cancel();
418 });
419 token
420}
421
/// Options controlling how an actor is spawned.
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate cannot
/// build it with a struct literal (not even with `..Default::default()`). Such
/// code should start from `SpawnConfig::default()` and assign the public fields.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SpawnConfig {
    /// Mailbox configuration for the spawned actor (`MailboxConfig::default()` by default).
    pub mailbox: MailboxConfig,
    /// Node to spawn the actor on; `None` by default. Presumably `None` means the
    /// local node — confirm against the spawner implementation.
    pub target_node: Option<crate::node::NodeId>,
}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorAction;
    use crate::message::Message;
    use crate::node::{ActorId, NodeId};

    /// Minimal actor holding a running total; it is its own construction argument.
    struct Counter {
        count: u64,
    }

    impl Actor for Counter {
        type Args = Self;
        type Deps = ();

        fn create(args: Self, _deps: ()) -> Self {
            args
        }
    }

    /// Adds the wrapped amount to the counter; no reply.
    struct Increment(u64);
    impl Message for Increment {
        type Reply = ();
    }

    /// Reads the current total.
    struct GetCount;
    impl Message for GetCount {
        type Reply = u64;
    }

    /// Zeroes the counter and replies with the value it held before.
    struct Reset;
    impl Message for Reset {
        type Reply = u64;
    }

    #[async_trait]
    impl Handler<Increment> for Counter {
        async fn handle(&mut self, msg: Increment, _ctx: &mut ActorContext) {
            self.count += msg.0;
        }
    }

    #[async_trait]
    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &mut ActorContext) -> u64 {
            self.count
        }
    }

    #[async_trait]
    impl Handler<Reset> for Counter {
        async fn handle(&mut self, _msg: Reset, _ctx: &mut ActorContext) -> u64 {
            std::mem::replace(&mut self.count, 0)
        }
    }

    /// Context for actor `n1/1` named "counter": no send mode, fresh headers and
    /// no cancellation token — the same values the handler tests need.
    fn counter_ctx() -> ActorContext {
        ActorContext::new(
            ActorId {
                node: NodeId("n1".into()),
                local: 1,
            },
            "counter".to_string(),
        )
    }

    #[test]
    fn test_counter_actor_compiles() {
        let counter = Counter::create(Counter { count: 0 }, ());
        assert_eq!(counter.count, 0);
    }

    #[test]
    fn test_actor_default_on_error_returns_stop() {
        let error = ActorError::internal("test error");
        let mut counter = Counter { count: 0 };
        assert_eq!(counter.on_error(&error), ErrorAction::Stop);
    }

    struct WorkerArgs {
        name: String,
    }
    struct WorkerDeps {
        multiplier: u64,
    }
    struct Worker {
        name: String,
        multiplier: u64,
    }

    impl Actor for Worker {
        type Args = WorkerArgs;
        type Deps = WorkerDeps;

        fn create(WorkerArgs { name }: WorkerArgs, WorkerDeps { multiplier }: WorkerDeps) -> Self {
            Worker { name, multiplier }
        }
    }

    #[test]
    fn test_worker_actor_with_deps() {
        let args = WorkerArgs { name: "w1".into() };
        let deps = WorkerDeps { multiplier: 10 };
        let worker = Worker::create(args, deps);
        assert_eq!(worker.name, "w1");
        assert_eq!(worker.multiplier, 10);
    }

    #[test]
    fn test_actor_id_display() {
        let id = ActorId {
            node: NodeId("node-1".into()),
            local: 42,
        };
        assert_eq!(id.to_string(), "Actor(node-1/42)");
    }

    #[test]
    fn test_actor_id_equality() {
        let on_n1 = |local| ActorId {
            node: NodeId("n1".into()),
            local,
        };
        let (id1, id2, id3) = (on_n1(1), on_n1(1), on_n1(2));
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_actor_id_clone() {
        let id = ActorId {
            node: NodeId("n1".into()),
            local: 1,
        };
        let copy = id.clone();
        assert_eq!(id, copy);
    }

    #[test]
    fn test_error_action_variants() {
        assert_eq!(ErrorAction::Resume, ErrorAction::Resume);
        assert_eq!(ErrorAction::Restart, ErrorAction::Restart);
        assert_eq!(ErrorAction::Stop, ErrorAction::Stop);
        assert_eq!(ErrorAction::Escalate, ErrorAction::Escalate);
        assert_ne!(ErrorAction::Resume, ErrorAction::Stop);
    }

    #[test]
    fn test_spawn_config_default() {
        assert!(SpawnConfig::default().target_node.is_none());
    }

    #[test]
    fn test_spawn_config_with_target_node() {
        let config = SpawnConfig {
            target_node: Some(NodeId("node-3".into())),
            ..Default::default()
        };
        let node = config.target_node.unwrap();
        assert_eq!(node.0, "node-3");
    }

    #[test]
    fn test_actor_context_fields() {
        // Deliberately uses a struct literal to exercise every field directly.
        let ctx = ActorContext {
            actor_id: ActorId {
                node: NodeId("n1".into()),
                local: 1,
            },
            actor_name: "test-actor".into(),
            send_mode: None,
            headers: Headers::new(),
            cancellation_token: None,
        };
        assert_eq!(ctx.actor_name, "test-actor");
        assert_eq!(ctx.actor_id.local, 1);
    }

    #[tokio::test]
    async fn test_lifecycle_defaults_are_noop() {
        let mut ctx = counter_ctx();
        let mut counter = Counter { count: 42 };
        counter.on_start(&mut ctx).await;
        counter.on_stop().await;
        assert_eq!(counter.count, 42);
    }

    #[tokio::test]
    async fn test_handler_increment() {
        let mut ctx = counter_ctx();
        let mut counter = Counter { count: 0 };
        for (step, expected_total) in [(5, 5), (3, 8)] {
            counter.handle(Increment(step), &mut ctx).await;
            assert_eq!(counter.count, expected_total);
        }
    }

    #[tokio::test]
    async fn test_handler_get_count() {
        let mut ctx = counter_ctx();
        let mut counter = Counter { count: 42 };
        let observed = counter.handle(GetCount, &mut ctx).await;
        assert_eq!(observed, 42);
    }

    #[tokio::test]
    async fn test_handler_reset() {
        let mut ctx = counter_ctx();
        let mut counter = Counter { count: 100 };
        let previous = counter.handle(Reset, &mut ctx).await;
        assert_eq!(previous, 100);
        assert_eq!(counter.count, 0);
    }

    #[tokio::test]
    async fn test_multiple_handlers_on_same_actor() {
        let mut ctx = counter_ctx();
        let mut counter = Counter { count: 0 };

        for amount in [10, 20] {
            counter.handle(Increment(amount), &mut ctx).await;
        }

        let total = counter.handle(GetCount, &mut ctx).await;
        assert_eq!(total, 30);

        let previous = counter.handle(Reset, &mut ctx).await;
        assert_eq!(previous, 30);
        assert_eq!(counter.count, 0);
    }

    #[test]
    fn test_handler_requires_actor_bound() {
        fn assert_handler<A: Handler<M>, M: Message>() {}
        assert_handler::<Counter, Increment>();
        assert_handler::<Counter, GetCount>();
        assert_handler::<Counter, Reset>();
    }

    #[test]
    fn test_actor_error_construction() {
        let err = ActorError::new(ErrorCode::InvalidArgument, "bad input");
        assert_eq!(err.code, ErrorCode::InvalidArgument);
        assert_eq!(err.message, "bad input");
        assert!(err.details.is_none());
        assert!(err.cause.is_none());
    }

    #[test]
    fn test_actor_error_with_details() {
        let err = ActorError::new(ErrorCode::NotFound, "user not found").with_details("user_id=42");
        assert_eq!(err.details.as_deref(), Some("user_id=42"));
    }

    #[test]
    fn test_actor_error_chain() {
        let root = ActorError::new(ErrorCode::Unavailable, "db connection failed");
        let err = ActorError::new(ErrorCode::Internal, "query failed").with_cause(root);
        assert!(matches!(
            err.cause.as_deref(),
            Some(inner) if inner.code == ErrorCode::Unavailable
        ));
    }

    #[test]
    fn test_actor_error_display() {
        let err = ActorError::new(ErrorCode::Internal, "something broke")
            .with_details("stack: foo.rs:42");
        let rendered = err.to_string();
        for needle in ["Internal", "something broke", "stack: foo.rs:42"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_actor_error_display_with_chain() {
        let root = ActorError::new(ErrorCode::Unavailable, "db down");
        let err = ActorError::new(ErrorCode::Internal, "query failed").with_cause(root);
        let rendered = err.to_string();
        assert!(rendered.contains("caused by"));
        assert!(rendered.contains("db down"));
    }

    #[test]
    fn test_error_code_variants() {
        let codes = [
            ErrorCode::Internal,
            ErrorCode::InvalidArgument,
            ErrorCode::NotFound,
            ErrorCode::Unavailable,
            ErrorCode::Timeout,
            ErrorCode::PermissionDenied,
            ErrorCode::FailedPrecondition,
            ErrorCode::ResourceExhausted,
            ErrorCode::Unimplemented,
            ErrorCode::Unknown,
            ErrorCode::Cancelled,
        ];
        assert_eq!(codes.len(), 11);
        // Every ordered pair of distinct positions must hold distinct codes.
        for (i, a) in codes.iter().enumerate() {
            for (j, b) in codes.iter().enumerate() {
                if i == j {
                    continue;
                }
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn test_actor_error_internal_helper() {
        let err = ActorError::internal("oops");
        assert_eq!(err.code, ErrorCode::Internal);
        assert_eq!(err.message, "oops");
    }

    #[test]
    fn test_not_supported_error() {
        use crate::errors::NotSupportedError;
        let err = NotSupportedError {
            capability: "BoundedMailbox".into(),
            message: "ractor does not support bounded mailboxes".into(),
        };
        assert!(err.to_string().contains("BoundedMailbox"));
    }

    #[test]
    fn test_runtime_error_not_supported() {
        use crate::errors::NotSupportedError;
        let err = RuntimeError::NotSupported(NotSupportedError {
            capability: "PriorityMailbox".into(),
            message: "not available".into(),
        });
        assert!(err.to_string().contains("PriorityMailbox"));
    }
}