Skip to main content

co_actor/
actor_local.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2// Copyright (C) 2026 1io BRANDGUARDIAN GmbH
3
4use crate::{actor::ActorMessage, ActorError, ActorHandle, ActorState, LocalTaskHandle, LocalTaskSpawner};
5use co_primitives::Tags;
6use std::{any::type_name, sync::Arc};
7use tokio::sync::{mpsc, watch};
8#[cfg(feature = "js")]
9use tokio_with_wasm::alias as tokio;
10use tracing::Instrument;
11
12/// A LocalActor will not moved between threads.
13/// This is sometimes necessary when interfacing with external code.
14/// This trait allows to implement such behaviour with same public interface as a normal [`Actor`] ([`ActorHandle`]).
15/// For new code that dont have this requirement is usually better to use [`Actor`] as it allows to use multithreading.
16#[allow(async_fn_in_trait)]
17pub trait LocalActor: 'static {
18	type Message: Send + 'static;
19	type State: 'static;
20	type Initialize: 'static;
21
22	async fn initialize(
23		&self,
24		handle: &ActorHandle<Self::Message>,
25		tags: &Tags,
26		initialize: Self::Initialize,
27	) -> Result<Self::State, ActorError>;
28
29	async fn handle(
30		&self,
31		handle: &ActorHandle<Self::Message>,
32		message: Self::Message,
33		state: &mut Self::State,
34	) -> Result<(), ActorError>;
35
36	fn tags(&self, tags: Tags) -> Result<Tags, ActorError> {
37		Ok(tags)
38	}
39
40	/// Shutdown the actor.
41	/// This is not cancelable.
42	/// After this call no more message will be received.
43	/// Will not be executed if actor panics.
44	async fn shutdown(&self, _state: Self::State) -> Result<(), ActorError> {
45		Ok(())
46	}
47
48	fn spawner(tags: Tags, actor: Self) -> Result<LocalActorSpawner<Self>, ActorError>
49	where
50		Self: Sized + 'static,
51	{
52		LocalActorSpawner::new(tags, actor)
53	}
54
55	/// Spawn actor using a task spawner.
56	#[track_caller]
57	fn spawn_with(
58		spawner: impl LocalTaskSpawner,
59		tags: Tags,
60		actor: Self,
61		initialize: Self::Initialize,
62	) -> Result<LocalActorInstance<Self>, ActorError>
63	where
64		Self: Sized + 'static,
65	{
66		Ok(Self::spawner(tags, actor)?.spawn_local(spawner, initialize))
67	}
68}
69
70/// Actor Spawner with early access to the handle (which allow cyclic references).
71pub struct LocalActorSpawner<A>
72where
73	A: LocalActor,
74{
75	handle: ActorHandle<A::Message>,
76	actor: A,
77	rx: mpsc::UnboundedReceiver<ActorMessage<A::Message>>,
78	state_tx: tokio::sync::watch::Sender<ActorState>,
79}
80impl<A> LocalActorSpawner<A>
81where
82	A: LocalActor,
83{
84	pub fn new(tags: Tags, actor: A) -> Result<Self, ActorError> {
85		let (tx, rx) = mpsc::unbounded_channel();
86		let (state_tx, state_rx) = watch::channel(ActorState::Starting);
87		let tags = Arc::new(actor.tags(tags)?);
88		let handle = ActorHandle { tx: tx.clone(), state: state_rx.clone(), tags: tags.clone() };
89		Ok(Self { handle, actor, rx, state_tx })
90	}
91
92	pub fn handle(&self) -> ActorHandle<A::Message> {
93		self.handle.clone()
94	}
95
96	#[track_caller]
97	pub fn spawn_local(self, spawner: impl LocalTaskSpawner, initialize: A::Initialize) -> LocalActorInstance<A> {
98		let mut rx = self.rx;
99		let state_tx = self.state_tx;
100		let actor = self.actor;
101		let tags = self.handle.tags.clone();
102		let handle = self.handle;
103		let span = tracing::trace_span!("actor", ?tags, actor_type = type_name::<A>());
104		let join = spawner.spawn_local({
105			let tags = tags.clone();
106			let handle = handle.clone();
107			let actor_span = span.clone();
108			async move {
109				// log
110				tracing::trace!(?tags, "actor-initialize");
111
112				// initialize
113				let mut actor_state = actor.initialize(&handle, &tags, initialize).await.map_err(|err| {
114					tracing::error!(?err, ?tags, "actor-initialize-failed");
115					err
116				})?;
117				state_tx
118					.send(ActorState::Running)
119					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
120
121				// execute
122				let weak_handle = handle.downgrade();
123				while let Some(actor_message) = rx.recv().await {
124					// handle message
125					let (message, message_span, _parent_span) = match actor_message {
126						ActorMessage::Message(message) => (message, tracing::trace_span!("actor-handle"), None),
127						ActorMessage::MessageWithSpan(message, message_span) => {
128							(message, tracing::trace_span!(parent: &message_span, "actor-handle"), Some(message_span))
129						},
130						ActorMessage::Shutdown => {
131							// log
132							tracing::trace!("actor-shutdown");
133
134							// done
135							break;
136						},
137					};
138					message_span.follows_from(&actor_span);
139
140					// get a strong handle to call the handle method - this should never fail as we should not
141					// receive any message when this fails.
142					if let Some(handle) = weak_handle.clone().upgrade() {
143						actor
144							.handle(&handle, message, &mut actor_state)
145							.instrument(message_span)
146							.await
147							.map_err(|err| {
148								tracing::error!(?err, ?tags, "actor-handle-failed");
149								err
150							})?;
151					}
152				}
153
154				// state
155				state_tx
156					.send(ActorState::Stopping)
157					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
158				rx.close();
159
160				// shutdown
161				actor.shutdown(actor_state).await.map_err(|err| {
162					tracing::error!(?err, ?tags, "actor-shutdown-failed");
163					err
164				})?;
165
166				// done
167				state_tx
168					.send(ActorState::None)
169					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
170				Ok(())
171			}
172			.instrument(span)
173		});
174		LocalActorInstance { handle, join }
175	}
176}
177
178/// The actual actor instance.
179pub struct LocalActorInstance<A>
180where
181	A: LocalActor,
182{
183	handle: ActorHandle<A::Message>,
184	join: LocalTaskHandle<Result<(), ActorError>>,
185}
186
187impl<A: std::fmt::Debug> std::fmt::Debug for LocalActorInstance<A>
188where
189	A: LocalActor,
190{
191	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192		f.debug_struct("LocalActorInstance").field("handle", &self.handle).finish()
193	}
194}
195impl<A> LocalActorInstance<A>
196where
197	A: LocalActor,
198{
199	/// Get actor handle.
200	pub fn handle(&self) -> ActorHandle<A::Message> {
201		self.handle.clone()
202	}
203
204	/// Get actor tags.
205	pub fn tags(&self) -> Tags {
206		self.handle.tags.as_ref().clone()
207	}
208
209	/// Request shutdown.
210	pub fn shutdown(&self) {
211		self.handle().shutdown();
212	}
213
214	/// Wait until the actor completes.
215	pub async fn join(self) -> Result<(), ActorError> {
216		let tags = self.tags();
217		drop(self.handle);
218		self.join.await.map_err(|e| ActorError::InvalidState(e.into(), tags))??;
219		Ok(())
220	}
221
222	/// Get actor state.
223	pub fn state(&self) -> ActorState {
224		*self.handle.state.borrow()
225	}
226}