use crate::{actor::ActorMessage, ActorError, ActorHandle, ActorState, LocalTaskHandle, LocalTaskSpawner};
use co_primitives::Tags;
use std::{any::type_name, sync::Arc};
use tokio::sync::{mpsc, watch};
#[cfg(feature = "js")]
use tokio_with_wasm::alias as tokio;
use tracing::Instrument;
/// An actor whose task is driven through a [`LocalTaskSpawner`].
///
/// Implementors supply [`initialize`](LocalActor::initialize), which builds the per-instance
/// state, and [`handle`](LocalActor::handle), which processes one message at a time against
/// that state. [`tags`](LocalActor::tags) and [`shutdown`](LocalActor::shutdown) have
/// pass-through / no-op defaults.
// Callers cannot add bounds such as `Send` to the futures returned by these `async fn`s.
// That is accepted deliberately: the actor task is spawned via `LocalTaskSpawner::spawn_local`.
#[allow(async_fn_in_trait)]
pub trait LocalActor: 'static {
    /// Message type delivered to [`handle`](LocalActor::handle) through the actor's channel.
    type Message: Send + 'static;
    /// Per-instance state: produced by `initialize`, mutated by `handle`, consumed by `shutdown`.
    type State: 'static;
    /// Input consumed by `initialize` when the actor task starts.
    type Initialize: 'static;

    /// Builds the actor state. Runs once inside the spawned task, before any message is handled.
    ///
    /// # Errors
    /// An error aborts the actor task before it reaches the running state.
    async fn initialize(
        &self,
        handle: &ActorHandle<Self::Message>,
        tags: &Tags,
        initialize: Self::Initialize,
    ) -> Result<Self::State, ActorError>;

    /// Processes a single message against the actor state.
    ///
    /// # Errors
    /// An error terminates the actor task; `shutdown` is not called in that case.
    async fn handle(
        &self,
        handle: &ActorHandle<Self::Message>,
        message: Self::Message,
        state: &mut Self::State,
    ) -> Result<(), ActorError>;

    /// Hook to adjust the tags before the actor's handle is created.
    /// The default hands `tags` back untouched.
    fn tags(&self, tags: Tags) -> Result<Tags, ActorError> {
        Ok(tags)
    }

    /// Called with the final state once the message loop has ended normally.
    /// The default discards the state.
    async fn shutdown(&self, _state: Self::State) -> Result<(), ActorError> {
        Ok(())
    }

    /// Wires up channels and a handle for `actor` without starting its task.
    fn spawner(tags: Tags, actor: Self) -> Result<LocalActorSpawner<Self>, ActorError>
    where
        Self: Sized + 'static,
    {
        LocalActorSpawner::<Self>::new(tags, actor)
    }

    /// Convenience for [`spawner`](LocalActor::spawner) followed by
    /// [`LocalActorSpawner::spawn_local`].
    ///
    /// # Errors
    /// Propagates errors from building the spawner (e.g. from [`tags`](LocalActor::tags)).
    #[track_caller]
    fn spawn_with(
        spawner: impl LocalTaskSpawner,
        tags: Tags,
        actor: Self,
        initialize: Self::Initialize,
    ) -> Result<LocalActorInstance<Self>, ActorError>
    where
        Self: Sized + 'static,
    {
        let prepared = Self::spawner(tags, actor)?;
        let instance = prepared.spawn_local(spawner, initialize);
        Ok(instance)
    }
}
pub struct LocalActorSpawner<A>
where
A: LocalActor,
{
handle: ActorHandle<A::Message>,
actor: A,
rx: mpsc::UnboundedReceiver<ActorMessage<A::Message>>,
state_tx: tokio::sync::watch::Sender<ActorState>,
}
impl<A> LocalActorSpawner<A>
where
A: LocalActor,
{
pub fn new(tags: Tags, actor: A) -> Result<Self, ActorError> {
let (tx, rx) = mpsc::unbounded_channel();
let (state_tx, state_rx) = watch::channel(ActorState::Starting);
let tags = Arc::new(actor.tags(tags)?);
let handle = ActorHandle { tx: tx.clone(), state: state_rx.clone(), tags: tags.clone() };
Ok(Self { handle, actor, rx, state_tx })
}
pub fn handle(&self) -> ActorHandle<A::Message> {
self.handle.clone()
}
#[track_caller]
pub fn spawn_local(self, spawner: impl LocalTaskSpawner, initialize: A::Initialize) -> LocalActorInstance<A> {
let mut rx = self.rx;
let state_tx = self.state_tx;
let actor = self.actor;
let tags = self.handle.tags.clone();
let handle = self.handle;
let span = tracing::trace_span!("actor", ?tags, actor_type = type_name::<A>());
let join = spawner.spawn_local({
let tags = tags.clone();
let handle = handle.clone();
let actor_span = span.clone();
async move {
tracing::trace!(?tags, "actor-initialize");
let mut actor_state = actor.initialize(&handle, &tags, initialize).await.map_err(|err| {
tracing::error!(?err, ?tags, "actor-initialize-failed");
err
})?;
state_tx
.send(ActorState::Running)
.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
let weak_handle = handle.downgrade();
while let Some(actor_message) = rx.recv().await {
let (message, message_span, _parent_span) = match actor_message {
ActorMessage::Message(message) => (message, tracing::trace_span!("actor-handle"), None),
ActorMessage::MessageWithSpan(message, message_span) => {
(message, tracing::trace_span!(parent: &message_span, "actor-handle"), Some(message_span))
},
ActorMessage::Shutdown => {
tracing::trace!("actor-shutdown");
break;
},
};
message_span.follows_from(&actor_span);
if let Some(handle) = weak_handle.clone().upgrade() {
actor
.handle(&handle, message, &mut actor_state)
.instrument(message_span)
.await
.map_err(|err| {
tracing::error!(?err, ?tags, "actor-handle-failed");
err
})?;
}
}
state_tx
.send(ActorState::Stopping)
.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
rx.close();
actor.shutdown(actor_state).await.map_err(|err| {
tracing::error!(?err, ?tags, "actor-shutdown-failed");
err
})?;
state_tx
.send(ActorState::None)
.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
Ok(())
}
.instrument(span)
});
LocalActorInstance { handle, join }
}
}
/// A spawned [`LocalActor`] task paired with a handle to it.
pub struct LocalActorInstance<A: LocalActor> {
    /// Handle to the actor's message channel, lifecycle watch and tags.
    handle: ActorHandle<A::Message>,
    /// Join handle of the spawned task; resolves to the task's final result.
    join: LocalTaskHandle<Result<(), ActorError>>,
}
/// Formats the instance through its actor handle; the task's join handle is omitted.
// No `A: Debug` bound: the actor value is not stored in the instance, so only the handle's
// `Debug` impl is needed. The old bound needlessly excluded non-`Debug` actors.
impl<A> std::fmt::Debug for LocalActorInstance<A>
where
    A: LocalActor,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalActorInstance")
            .field("handle", &self.handle)
            .finish_non_exhaustive()
    }
}
impl<A: LocalActor> LocalActorInstance<A> {
    /// Returns a clone of the actor's handle.
    pub fn handle(&self) -> ActorHandle<A::Message> {
        self.handle.clone()
    }

    /// Returns an owned copy of the actor's tags.
    pub fn tags(&self) -> Tags {
        (*self.handle.tags).clone()
    }

    /// Requests shutdown by calling `ActorHandle::shutdown` on a cloned handle.
    pub fn shutdown(&self) {
        let handle = self.handle();
        handle.shutdown();
    }

    /// Waits for the actor task to finish and returns its result.
    ///
    /// This instance's handle, which holds a message sender, is dropped before awaiting.
    ///
    /// # Errors
    /// `ActorError::InvalidState` if the task could not be joined; otherwise whatever error the
    /// actor task itself returned.
    pub async fn join(self) -> Result<(), ActorError> {
        let Self { handle, join } = self;
        let tags = (*handle.tags).clone();
        drop(handle);
        join.await.map_err(|e| ActorError::InvalidState(e.into(), tags))?
    }

    /// Returns the most recently published lifecycle state.
    pub fn state(&self) -> ActorState {
        let current = self.handle.state.borrow();
        *current
    }
}