use crate::{
ActorPath, Error,
handler::HandleHelper,
runner::{InnerAction, InnerSender, StopHandle, StopSender},
supervision::SupervisionStrategy,
system::SystemRef,
};
use tokio::sync::{broadcast::Receiver as EventReceiver, mpsc, oneshot};
use async_trait::async_trait;
use serde::{Serialize, de::DeserializeOwned};
use tracing::Span;
use std::{collections::HashMap, fmt::Debug, time::Duration};
/// Per-actor context passed to lifecycle hooks and message handlers.
///
/// Bundles the owning actor's path, a handle to the actor system, the
/// channels used to reach the actor's runner, and the stop handles of the
/// children created through [`ActorContext::create_child`].
pub struct ActorContext<A: Actor + Handler<A>> {
    // Field order is also drop order; keep it stable.
    /// Stop channel of the owning actor, used by [`ActorContext::stop`].
    stop: StopSender,
    /// Path of the actor that owns this context.
    path: ActorPath,
    /// Actor system handle (lookup, creation and removal of actors).
    system: SystemRef,
    /// Last error recorded through `set_error` / `emit_fail`, if any.
    error: Option<Error>,
    /// Passed to `create_actor_path` for every child created from this context.
    error_sender: ChildErrorSender,
    /// Channel to this actor's runner (events, errors, failures, child stops).
    inner_sender: InnerSender<A>,
    /// Stop handles of the children created via `create_child`, keyed by path.
    child_senders: HashMap<ActorPath, StopHandle>,
    /// This actor's tracing span; handed to children as their parent span.
    span: tracing::Span,
}
impl<A> ActorContext<A>
where
A: Actor + Handler<A>,
{
pub(crate) fn new(
stop: StopSender,
path: ActorPath,
system: SystemRef,
error_sender: ChildErrorSender,
inner_sender: InnerSender<A>,
span: Span,
) -> Self {
Self {
span,
stop,
path,
system,
error: None,
error_sender,
inner_sender,
child_senders: HashMap::new(),
}
}
pub(crate) async fn restart(
&mut self,
actor: &mut A,
error: Option<&Error>,
) -> Result<(), Error>
where
A: Actor,
{
tracing::warn!(error = ?error, "Actor restarting");
let result = actor.pre_restart(self, error).await;
if let Err(ref e) = result {
tracing::error!(error = %e, "Actor restart failed");
}
result
}
pub async fn reference(&self) -> Result<ActorRef<A>, Error> {
self.system.get_actor(&self.path).await
}
pub const fn path(&self) -> &ActorPath {
&self.path
}
pub const fn system(&self) -> &SystemRef {
&self.system
}
pub async fn get_parent<P: Actor + Handler<P>>(
&self,
) -> Result<ActorRef<P>, Error> {
self.system.get_actor(&self.path.parent()).await
}
pub(crate) async fn stop_childs(&mut self) {
let child_count = self.child_senders.len();
if child_count > 0 {
tracing::debug!(child_count, "Stopping child actors");
}
let mut receivers = Vec::with_capacity(child_count);
for (path, handle) in std::mem::take(&mut self.child_senders) {
let (stop_sender, stop_receiver) = oneshot::channel();
if handle.sender().send(Some(stop_sender)).await.is_ok() {
receivers.push((path, handle.timeout(), stop_receiver));
}
}
for (path, timeout, receiver) in receivers {
if let Some(timeout) = timeout {
if tokio::time::timeout(timeout, receiver).await.is_err() {
tracing::warn!(
child = %path,
timeout_ms = timeout.as_millis(),
"Timed out waiting for child actor shutdown acknowledgement"
);
}
} else {
let _ = receiver.await;
}
}
}
pub(crate) async fn remove_actor(&self) {
self.system.remove_actor(&self.path).await;
}
pub async fn stop(&self, sender: Option<oneshot::Sender<()>>) {
let _ = self.stop.send(sender).await;
}
pub async fn publish_event(&self, event: A::Event) -> Result<(), Error> {
self.inner_sender
.send(InnerAction::Event(event))
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to publish event");
Error::SendEvent {
reason: e.to_string(),
}
})
}
pub async fn emit_error(&mut self, error: Error) -> Result<(), Error> {
tracing::warn!(error = %error, "Emitting error");
self.inner_sender
.send(InnerAction::Error(error))
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to emit error");
Error::Send {
reason: e.to_string(),
}
})
}
pub async fn emit_fail(&mut self, error: Error) -> Result<(), Error> {
tracing::error!(error = %error, "Actor failing");
self.set_error(error.clone());
self.inner_sender
.send(InnerAction::Fail(error.clone()))
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to emit fail");
Error::Send {
reason: e.to_string(),
}
})
}
pub async fn create_child<C, I>(
&mut self,
name: &str,
actor_init: I,
) -> Result<ActorRef<C>, Error>
where
C: Actor + Handler<C>,
I: crate::IntoActor<C>,
{
tracing::debug!(child_name = %name, "Creating child actor");
let actor = actor_init.into_actor();
let path = self.path.clone() / name;
let result = self
.system
.create_actor_path(
path.clone(),
actor,
Some(self.error_sender.clone()),
C::get_span(name, Some(self.span.clone())),
)
.await;
match result {
Ok((actor_ref, stop_sender)) => {
let child_path = path.clone();
self.child_senders.insert(
path,
StopHandle::new(stop_sender.clone(), C::stop_timeout()),
);
let inner_sender = self.inner_sender.clone();
tokio::spawn(async move {
stop_sender.closed().await;
let _ = inner_sender
.send(InnerAction::ChildStopped(child_path))
.await;
});
tracing::debug!(child_name = %name, "Child actor created");
Ok(actor_ref)
}
Err(e) => {
tracing::debug!(child_name = %name, error = %e, "Failed to create child actor");
Err(e)
}
}
}
pub(crate) fn remove_closed_child(&mut self, child_path: &ActorPath) {
let should_remove = self
.child_senders
.get(child_path)
.map(StopHandle::is_closed)
.unwrap_or(false);
if should_remove {
self.child_senders.remove(child_path);
}
}
pub async fn get_child<C>(&self, name: &str) -> Result<ActorRef<C>, Error>
where
C: Actor + Handler<C>,
{
let path = self.path.clone() / name;
self.system.get_actor(&path).await
}
pub(crate) fn error(&self) -> Option<Error> {
self.error.clone()
}
pub(crate) fn set_error(&mut self, error: Error) {
self.error = Some(error);
}
pub(crate) fn clean_error(&mut self) {
self.error = None;
}
}
/// Lifecycle states an actor can be in.
///
/// NOTE(review): the component that moves actors between these states is
/// not part of this module — confirm the exact transitions in the runner.
// Variant order defines the discriminants; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ActorLifecycle {
    Created,
    Started,
    Restarted,
    Failed,
    Stopped,
    Terminated,
}
/// Decision a parent takes about a faulted child.
///
/// Returned by [`Handler::on_child_fault`] (which defaults to `Stop`) and sent
/// back through the `sender` of [`ChildError::Fault`].
/// NOTE(review): how `Restart` and `Delegate` are applied is decided by the
/// runner — confirm there.
#[derive(Debug, Clone)]
pub enum ChildAction {
    Stop,
    Restart,
    Delegate,
}
/// Receiving half of the channel carrying [`ChildError`] reports.
pub type ChildErrorReceiver = mpsc::Receiver<ChildError>;
/// Sending half of the channel carrying [`ChildError`] reports; a clone is
/// passed along when a child is created via `ActorContext::create_child`.
pub type ChildErrorSender = mpsc::Sender<ChildError>;
/// Report sent over a [`ChildErrorSender`] (presumably from a child to its
/// parent — confirm against the runner).
pub enum ChildError {
    /// An error report that carries no reply channel.
    Error {
        error: Error,
    },
    /// A fault report; the receiver answers on `sender` with the
    /// [`ChildAction`] to take.
    Fault {
        error: Error,
        sender: oneshot::Sender<ChildAction>,
    },
}
#[async_trait]
pub trait Actor: Send + Sync + Sized + 'static + Handler<Self> {
type Message: Message;
type Event: Event;
type Response: Response;
fn get_span(id: &str, parent_span: Option<Span>) -> tracing::Span;
fn drain_timeout() -> std::time::Duration {
std::time::Duration::from_secs(5)
}
fn startup_timeout() -> Option<Duration> {
None
}
fn stop_timeout() -> Option<Duration> {
None
}
fn supervision_strategy() -> SupervisionStrategy {
SupervisionStrategy::Stop
}
async fn pre_start(
&mut self,
_context: &mut ActorContext<Self>,
) -> Result<(), Error> {
Ok(())
}
async fn pre_restart(
&mut self,
ctx: &mut ActorContext<Self>,
_error: Option<&Error>,
) -> Result<(), Error> {
self.pre_start(ctx).await
}
async fn pre_stop(
&mut self,
_ctx: &mut ActorContext<Self>,
) -> Result<(), Error> {
Ok(())
}
async fn post_stop(
&mut self,
_ctx: &mut ActorContext<Self>,
) -> Result<(), Error> {
Ok(())
}
fn from_response(_response: Self::Response) -> Result<Self::Event, Error> {
Err(Error::Functional {
description: "Not implemented".to_string(),
})
}
}
/// Payload an actor publishes to its subscribers.
///
/// Events are delivered through a `tokio::sync::broadcast` receiver (see
/// `ActorRef::subscribe`), which is why `Clone` is required; the serde bounds
/// make them serialisable.
pub trait Event: Serialize + DeserializeOwned + Debug + Clone + Send + Sync + 'static {}

/// Input accepted by [`Handler::handle_message`].
pub trait Message: Clone + Send + Sync + 'static {
    /// Whether this message is critical; `false` unless overridden.
    /// NOTE(review): how critical messages are treated is decided elsewhere.
    fn is_critical(&self) -> bool {
        false
    }
}

/// Value produced by [`Handler::handle_message`] and returned by `ActorRef::ask`.
pub trait Response: Send + Sync + 'static {}

// The unit type doubles as a no-payload message, event and response.
impl Message for () {}
impl Event for () {}
impl Response for () {}
#[async_trait]
pub trait Handler<A: Actor + Handler<A>>: Send + Sync {
async fn handle_message(
&mut self,
sender: ActorPath,
msg: A::Message,
ctx: &mut ActorContext<A>,
) -> Result<A::Response, Error>;
async fn on_event(&mut self, _event: A::Event, _ctx: &mut ActorContext<A>) {
}
async fn on_child_error(
&mut self,
error: Error,
_ctx: &mut ActorContext<A>,
) {
tracing::error!(error = %error, "Child actor error");
}
async fn on_child_fault(
&mut self,
error: Error,
_ctx: &mut ActorContext<A>,
) -> ChildAction {
tracing::error!(error = %error, "Child actor fault, stopping child");
ChildAction::Stop
}
}
/// Cloneable handle for communicating with an actor of type `A`.
///
/// Bundles the actor's path, its message sender, its stop channel, and a
/// broadcast receiver from which event subscriptions are derived.
pub struct ActorRef<A>
where
    A: Actor + Handler<A>,
{
    // Field order is also drop order; keep it stable.
    /// Path of the referenced actor.
    path: ActorPath,
    /// Message channel helper used by `tell` / `ask`.
    sender: HandleHelper<A>,
    /// Broadcast receiver; this module only uses it through `resubscribe`
    /// (in `subscribe` and `clone`).
    event_receiver: EventReceiver<<A as Actor>::Event>,
    /// Stop channel used by `ask_stop` / `tell_stop`.
    stop_sender: StopSender,
}
impl<A> ActorRef<A>
where
A: Actor + Handler<A>,
{
pub const fn new(
path: ActorPath,
sender: HandleHelper<A>,
stop_sender: StopSender,
event_receiver: EventReceiver<<A as Actor>::Event>,
) -> Self {
Self {
path,
sender,
stop_sender,
event_receiver,
}
}
pub async fn tell(&self, message: A::Message) -> Result<(), Error> {
self.sender.tell(self.path(), message).await
}
pub async fn ask(&self, message: A::Message) -> Result<A::Response, Error> {
self.sender.ask(self.path(), message).await
}
pub async fn ask_timeout(
&self,
message: A::Message,
timeout: std::time::Duration,
) -> Result<A::Response, Error> {
tokio::time::timeout(timeout, self.sender.ask(self.path(), message))
.await
.map_err(|_| Error::Timeout {
ms: timeout.as_millis(),
})?
}
pub async fn ask_stop(&self) -> Result<(), Error> {
tracing::debug!("Stopping actor");
let (response_sender, response_receiver) = oneshot::channel();
if self.stop_sender.send(Some(response_sender)).await.is_err() {
Ok(())
} else {
response_receiver.await.map_err(|error| {
tracing::error!(error = %error, "Failed to confirm actor stop");
Error::Send {
reason: error.to_string(),
}
})
}
}
pub async fn tell_stop(&self) {
let _ = self.stop_sender.send(None).await;
}
pub fn path(&self) -> ActorPath {
self.path.clone()
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
pub async fn closed(&self) {
self.sender.close().await;
}
pub fn subscribe(&self) -> EventReceiver<<A as Actor>::Event> {
self.event_receiver.resubscribe()
}
}
impl<A> Clone for ActorRef<A>
where
    A: Actor + Handler<A>,
{
    /// Clones the handle.
    ///
    /// The path, message sender and stop sender are cloned. The event
    /// receiver is obtained via `resubscribe`, so the clone only sees events
    /// broadcast after it was made.
    fn clone(&self) -> Self {
        let event_receiver = self.event_receiver.resubscribe();
        Self::new(
            self.path.clone(),
            self.sender.clone(),
            self.stop_sender.clone(),
            event_receiver,
        )
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use test_log::test;
    use crate::sink::{Sink, Subscriber};
    use serde::{Deserialize, Serialize};
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;
    use tracing::info_span;

    /// Root test actor that adds each message's value to a running total.
    #[derive(Debug, Clone)]
    struct TestActor {
        counter: usize,
    }

    impl crate::NotPersistentActor for TestActor {}

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestMessage(usize);

    impl Message for TestMessage {}

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestResponse(usize);

    impl Response for TestResponse {}

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestEvent(usize);

    impl Event for TestEvent {}

    #[async_trait]
    impl Actor for TestActor {
        type Message = TestMessage;
        type Event = TestEvent;
        type Response = TestResponse;

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("TestActor", id = %id)
        }
    }

    #[async_trait]
    impl Handler<TestActor> for TestActor {
        // Adds `msg.0` to the counter, publishes the new total as an event and
        // returns it. Also checks that the actor is a root actor (no parent).
        async fn handle_message(
            &mut self,
            _sender: ActorPath,
            msg: TestMessage,
            ctx: &mut ActorContext<TestActor>,
        ) -> Result<TestResponse, Error> {
            assert!(
                ctx.get_parent::<TestActor>().await.is_err(),
                "Is not a root actor"
            );
            self.counter += msg.0;
            ctx.publish_event(TestEvent(self.counter)).await?;
            Ok(TestResponse(self.counter))
        }
    }

    /// Sink subscriber that checks every observed total is positive.
    pub struct TestSubscriber;

    #[async_trait]
    impl Subscriber<TestEvent> for TestSubscriber {
        async fn notify(&self, event: TestEvent) {
            assert!(event.0 > 0);
        }
    }

    #[test(tokio::test)]
    async fn test_actor() {
        let (event_sender, _event_receiver) = mpsc::channel(100);
        let system = SystemRef::new(
            event_sender,
            CancellationToken::new(),
            CancellationToken::new(),
        );
        let actor = TestActor { counter: 0 };
        let actor_ref = system.create_root_actor("test", actor).await.unwrap();
        let sink = Sink::new(actor_ref.subscribe(), TestSubscriber);
        system.run_sink(sink).await;

        // Subscribe before sending anything: a resubscribed broadcast receiver
        // only observes events published after it was created. Subscribing
        // after `tell` would race with the actor and could miss the first event.
        let mut recv = actor_ref.subscribe();

        actor_ref.tell(TestMessage(10)).await.unwrap();
        let response = actor_ref.ask(TestMessage(10)).await.unwrap();
        assert_eq!(response.0, 20);

        let event = recv.recv().await.unwrap();
        assert_eq!(event.0, 10);
        let event = recv.recv().await.unwrap();
        assert_eq!(event.0, 20);

        actor_ref.ask_stop().await.unwrap();
        // Presumably gives background tasks (e.g. the sink) time to finish.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
    }
}