1use crate::{actor::ActorMessage, ActorError, ActorHandle, ActorState, LocalTaskHandle, LocalTaskSpawner};
5use co_primitives::Tags;
6use std::{any::type_name, sync::Arc};
7use tokio::sync::{mpsc, watch};
8#[cfg(feature = "js")]
9use tokio_with_wasm::alias as tokio;
10use tracing::Instrument;
11
12#[allow(async_fn_in_trait)]
17pub trait LocalActor: 'static {
18 type Message: Send + 'static;
19 type State: 'static;
20 type Initialize: 'static;
21
22 async fn initialize(
23 &self,
24 handle: &ActorHandle<Self::Message>,
25 tags: &Tags,
26 initialize: Self::Initialize,
27 ) -> Result<Self::State, ActorError>;
28
29 async fn handle(
30 &self,
31 handle: &ActorHandle<Self::Message>,
32 message: Self::Message,
33 state: &mut Self::State,
34 ) -> Result<(), ActorError>;
35
36 fn tags(&self, tags: Tags) -> Result<Tags, ActorError> {
37 Ok(tags)
38 }
39
40 async fn shutdown(&self, _state: Self::State) -> Result<(), ActorError> {
45 Ok(())
46 }
47
48 fn spawner(tags: Tags, actor: Self) -> Result<LocalActorSpawner<Self>, ActorError>
49 where
50 Self: Sized + 'static,
51 {
52 LocalActorSpawner::new(tags, actor)
53 }
54
55 #[track_caller]
57 fn spawn_with(
58 spawner: impl LocalTaskSpawner,
59 tags: Tags,
60 actor: Self,
61 initialize: Self::Initialize,
62 ) -> Result<LocalActorInstance<Self>, ActorError>
63 where
64 Self: Sized + 'static,
65 {
66 Ok(Self::spawner(tags, actor)?.spawn_local(spawner, initialize))
67 }
68}
69
70pub struct LocalActorSpawner<A>
72where
73 A: LocalActor,
74{
75 handle: ActorHandle<A::Message>,
76 actor: A,
77 rx: mpsc::UnboundedReceiver<ActorMessage<A::Message>>,
78 state_tx: tokio::sync::watch::Sender<ActorState>,
79}
80impl<A> LocalActorSpawner<A>
81where
82 A: LocalActor,
83{
84 pub fn new(tags: Tags, actor: A) -> Result<Self, ActorError> {
85 let (tx, rx) = mpsc::unbounded_channel();
86 let (state_tx, state_rx) = watch::channel(ActorState::Starting);
87 let tags = Arc::new(actor.tags(tags)?);
88 let handle = ActorHandle { tx: tx.clone(), state: state_rx.clone(), tags: tags.clone() };
89 Ok(Self { handle, actor, rx, state_tx })
90 }
91
92 pub fn handle(&self) -> ActorHandle<A::Message> {
93 self.handle.clone()
94 }
95
96 #[track_caller]
97 pub fn spawn_local(self, spawner: impl LocalTaskSpawner, initialize: A::Initialize) -> LocalActorInstance<A> {
98 let mut rx = self.rx;
99 let state_tx = self.state_tx;
100 let actor = self.actor;
101 let tags = self.handle.tags.clone();
102 let handle = self.handle;
103 let span = tracing::trace_span!("actor", ?tags, actor_type = type_name::<A>());
104 let join = spawner.spawn_local({
105 let tags = tags.clone();
106 let handle = handle.clone();
107 let actor_span = span.clone();
108 async move {
109 tracing::trace!(?tags, "actor-initialize");
111
112 let mut actor_state = actor.initialize(&handle, &tags, initialize).await.map_err(|err| {
114 tracing::error!(?err, ?tags, "actor-initialize-failed");
115 err
116 })?;
117 state_tx
118 .send(ActorState::Running)
119 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
120
121 let weak_handle = handle.downgrade();
123 while let Some(actor_message) = rx.recv().await {
124 let (message, message_span, _parent_span) = match actor_message {
126 ActorMessage::Message(message) => (message, tracing::trace_span!("actor-handle"), None),
127 ActorMessage::MessageWithSpan(message, message_span) => {
128 (message, tracing::trace_span!(parent: &message_span, "actor-handle"), Some(message_span))
129 },
130 ActorMessage::Shutdown => {
131 tracing::trace!("actor-shutdown");
133
134 break;
136 },
137 };
138 message_span.follows_from(&actor_span);
139
140 if let Some(handle) = weak_handle.clone().upgrade() {
143 actor
144 .handle(&handle, message, &mut actor_state)
145 .instrument(message_span)
146 .await
147 .map_err(|err| {
148 tracing::error!(?err, ?tags, "actor-handle-failed");
149 err
150 })?;
151 }
152 }
153
154 state_tx
156 .send(ActorState::Stopping)
157 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
158 rx.close();
159
160 actor.shutdown(actor_state).await.map_err(|err| {
162 tracing::error!(?err, ?tags, "actor-shutdown-failed");
163 err
164 })?;
165
166 state_tx
168 .send(ActorState::None)
169 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
170 Ok(())
171 }
172 .instrument(span)
173 });
174 LocalActorInstance { handle, join }
175 }
176}
177
178pub struct LocalActorInstance<A>
180where
181 A: LocalActor,
182{
183 handle: ActorHandle<A::Message>,
184 join: LocalTaskHandle<Result<(), ActorError>>,
185}
186
187impl<A: std::fmt::Debug> std::fmt::Debug for LocalActorInstance<A>
188where
189 A: LocalActor,
190{
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.debug_struct("LocalActorInstance").field("handle", &self.handle).finish()
193 }
194}
195impl<A> LocalActorInstance<A>
196where
197 A: LocalActor,
198{
199 pub fn handle(&self) -> ActorHandle<A::Message> {
201 self.handle.clone()
202 }
203
204 pub fn tags(&self) -> Tags {
206 self.handle.tags.as_ref().clone()
207 }
208
209 pub fn shutdown(&self) {
211 self.handle().shutdown();
212 }
213
214 pub async fn join(self) -> Result<(), ActorError> {
216 let tags = self.tags();
217 drop(self.handle);
218 self.join.await.map_err(|e| ActorError::InvalidState(e.into(), tags))??;
219 Ok(())
220 }
221
222 pub fn state(&self) -> ActorState {
224 *self.handle.state.borrow()
225 }
226}