Skip to main content

co_actor/
actor.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2// Copyright (C) 2026 1io BRANDGUARDIAN GmbH
3
4use crate::{
5	Response, ResponseBackPressureStream, ResponseBackPressureStreamReceiver, ResponseReceiver, ResponseStream,
6	ResponseStreamReceiver, TaskHandle, TaskOptions, TaskSpawner,
7};
8use anyhow::anyhow;
9use async_trait::async_trait;
10use co_primitives::Tags;
11use futures::{Stream, StreamExt};
12use std::{any::type_name, future::ready, ops::Deref, sync::Arc};
13use tokio::sync::{mpsc, watch};
14#[cfg(feature = "js")]
15use tokio_with_wasm::alias as tokio;
16use tracing::{Instrument, Span};
17
18#[derive(Debug, thiserror::Error)]
19pub enum ActorError {
20	#[error("Invalid actor state for that operation ({1}).")]
21	InvalidState(#[source] anyhow::Error, Tags),
22
23	#[error("Operation canceled.")]
24	Canceled,
25
26	#[error("Actor error")]
27	Actor(#[from] anyhow::Error),
28}
29
30/// Simple actor model implementation.
31/// Accepts messages which will be applied to the actor state.
32/// Actor state is different to the actual actor instance in order to allow initialization of it within the actor
33/// context.
34#[async_trait]
35pub trait Actor: Send + Sync + 'static {
36	type Message: Send + 'static;
37	type State: Send + 'static;
38	type Initialize: Send + 'static;
39
40	async fn initialize(
41		&self,
42		handle: &ActorHandle<Self::Message>,
43		tags: &Tags,
44		initialize: Self::Initialize,
45	) -> Result<Self::State, ActorError>;
46
47	async fn handle(
48		&self,
49		handle: &ActorHandle<Self::Message>,
50		message: Self::Message,
51		state: &mut Self::State,
52	) -> Result<(), ActorError>;
53
54	fn tags(&self, tags: Tags) -> Result<Tags, ActorError> {
55		Ok(tags)
56	}
57
58	/// Shutdown the actor.
59	/// This is not cancelable.
60	/// After this call no more message will be received.
61	/// Will not be executed if actor panics.
62	async fn shutdown(&self, _state: Self::State) -> Result<(), ActorError> {
63		Ok(())
64	}
65
66	fn spawner(tags: Tags, actor: Self) -> Result<ActorSpawner<Self>, ActorError>
67	where
68		Self: Send + Sync + Sized + 'static,
69	{
70		ActorSpawner::new(tags, actor)
71	}
72
73	/// Spawn actor.
74	#[track_caller]
75	fn spawn(tags: Tags, actor: Self, initialize: Self::Initialize) -> Result<ActorInstance<Self>, ActorError>
76	where
77		Self: Send + Sync + Sized + 'static,
78	{
79		Self::spawn_with(Default::default(), tags, actor, initialize)
80	}
81
82	/// Spawn actor using a task spawner.
83	#[track_caller]
84	fn spawn_with(
85		spawner: TaskSpawner,
86		tags: Tags,
87		actor: Self,
88		initialize: Self::Initialize,
89	) -> Result<ActorInstance<Self>, ActorError>
90	where
91		Self: Send + Sync + Sized + 'static,
92	{
93		Ok(Self::spawner(tags, actor)?.spawn(spawner, initialize))
94	}
95}
96
97/// Actor Spawner with early access to the handle (which allow cyclic references).
98pub struct ActorSpawner<A>
99where
100	A: Actor,
101{
102	handle: ActorHandle<A::Message>,
103	actor: A,
104	rx: tokio::sync::mpsc::UnboundedReceiver<ActorMessage<A::Message>>,
105	state_tx: tokio::sync::watch::Sender<ActorState>,
106	options: TaskOptions,
107}
108impl<A> ActorSpawner<A>
109where
110	A: Actor,
111{
112	pub fn new(tags: Tags, actor: A) -> Result<Self, ActorError> {
113		let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
114		let (state_tx, state_rx) = watch::channel(ActorState::Starting);
115		let tags = Arc::new(actor.tags(tags)?);
116		let handle = ActorHandle { tx: tx.clone(), state: state_rx.clone(), tags: tags.clone() };
117		Ok(Self { handle, actor, rx, state_tx, options: TaskOptions::new(type_name::<A>()) })
118	}
119
120	pub fn handle(&self) -> ActorHandle<A::Message> {
121		self.handle.clone()
122	}
123
124	#[track_caller]
125	pub fn spawn(self, spawner: TaskSpawner, initialize: A::Initialize) -> ActorInstance<A> {
126		let mut rx = self.rx;
127		let state_tx = self.state_tx;
128		let actor = self.actor;
129		let tags = self.handle.tags.clone();
130		let handle = self.handle;
131		let span = tracing::trace_span!("actor", ?tags, actor_type = type_name::<A>());
132		let join = spawner.spawn_options(self.options, {
133			let tags = tags.clone();
134			let handle = handle.clone();
135			let actor_span = span.clone();
136			async move {
137				// log
138				tracing::trace!("actor-initialize");
139
140				// initialize
141				let mut actor_state = actor.initialize(&handle, &tags, initialize).await.map_err(|err| {
142					tracing::error!(?err, "actor-initialize-failed");
143					err
144				})?;
145				state_tx
146					.send(ActorState::Running)
147					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
148
149				// execute
150				let weak_handle = handle.downgrade();
151				while let Some(actor_message) = rx.recv().await {
152					// handle message
153					let (message, message_span, _parent_span) = match actor_message {
154						ActorMessage::Message(message) => (message, tracing::trace_span!("actor-handle"), None),
155						ActorMessage::MessageWithSpan(message, message_span) => {
156							(message, tracing::trace_span!(parent: &message_span, "actor-handle"), Some(message_span))
157						},
158						ActorMessage::Shutdown => {
159							// log
160							tracing::trace!("actor-shutdown");
161
162							// done
163							break;
164						},
165					};
166					message_span.follows_from(&actor_span);
167
168					// get a strong handle to call the handle method - this should never fail as we should not
169					// receive any message when this fails.
170					if let Some(handle) = weak_handle.clone().upgrade() {
171						actor
172							.handle(&handle, message, &mut actor_state)
173							.instrument(message_span)
174							.await
175							.map_err(|err| {
176								tracing::error!(?err, "actor-handle-failed");
177								err
178							})?;
179					}
180				}
181
182				// state
183				state_tx
184					.send(ActorState::Stopping)
185					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
186				rx.close();
187
188				// shutdown
189				actor.shutdown(actor_state).await.map_err(|err| {
190					tracing::error!(?err, ?tags, "actor-shutdown-failed");
191					err
192				})?;
193
194				// done
195				state_tx
196					.send(ActorState::None)
197					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
198				Ok(())
199			}
200			.instrument(span)
201		});
202		ActorInstance { join, handle }
203	}
204}
205
/// Lifecycle state of an actor, published to every handle through a watch channel.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum ActorState {
	/// Initial state: the actor has been created but initialization has not completed.
	Starting = 0,

	/// Initialization completed; messages are being handled.
	Running = 1,

	/// Shutdown has been requested.
	Stopping = 2,

	/// Not running (yet or anymore).
	None = 3,
}
221
222#[derive(Debug)]
223pub enum ActorMessage<M> {
224	/// Actor shutdown requested.
225	Shutdown,
226
227	/// Actor received message.
228	#[allow(unused)]
229	Message(M),
230
231	/// Actor received message.
232	MessageWithSpan(M, tracing::Span),
233}
234
235/// The actual actor instance.
236#[derive(Debug)]
237pub struct ActorInstance<A>
238where
239	A: Actor,
240{
241	handle: ActorHandle<A::Message>,
242	join: TaskHandle<Result<(), ActorError>>,
243}
244impl<A> ActorInstance<A>
245where
246	A: Actor,
247{
248	/// Get actor handle.
249	pub fn handle(&self) -> ActorHandle<A::Message> {
250		self.handle.clone()
251	}
252
253	/// Get actor tags.
254	pub fn tags(&self) -> Tags {
255		self.handle.tags.as_ref().clone()
256	}
257
258	/// Request shutdown.
259	pub fn shutdown(&self) {
260		self.handle().shutdown();
261	}
262
263	/// Wait until the actor completes.
264	pub async fn join(self) -> Result<(), ActorError> {
265		let tags = self.tags();
266		drop(self.handle);
267		self.join.await.map_err(|e| ActorError::InvalidState(e.into(), tags))??;
268		Ok(())
269	}
270
271	/// Wait for startup to be complete and then run in background.
272	/// This will resolve when initialization is done by returning any initialization errors.
273	pub async fn initialized(self) -> Result<ActorHandle<A::Message>, ActorError> {
274		let handle = self.handle();
275		match handle.initialized().await {
276			Ok(_) => Ok(handle),
277			Err(err @ ActorError::InvalidState(_, _)) if self.handle().is_closed() => {
278				// use the orignal initialize error and forward
279				//  this will not block as the actor has been closed already
280				self.join().await?;
281				Err(err)
282			},
283			Err(err) => Err(err),
284		}
285	}
286
287	/// Get actor state.
288	pub fn state(&self) -> ActorState {
289		*self.handle.state.borrow()
290	}
291}
292
293/// Handle into an actor which can be used to send messages.
294pub struct ActorHandle<M> {
295	pub(crate) tx: mpsc::UnboundedSender<ActorMessage<M>>,
296	pub(crate) state: watch::Receiver<ActorState>,
297	pub(crate) tags: Arc<Tags>,
298}
299impl<M> std::fmt::Debug for ActorHandle<M> {
300	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301		f.debug_struct("ActorHandle")
302			.field("message_type", &type_name::<M>())
303			.field("tx_closed", &self.tx.is_closed())
304			.field("state", &self.state.borrow().deref())
305			.field("tags", &self.tags)
306			.finish()
307	}
308}
309impl<M> Clone for ActorHandle<M> {
310	fn clone(&self) -> Self {
311		Self { tx: self.tx.clone(), state: self.state.clone(), tags: self.tags.clone() }
312	}
313}
314impl<M> ActorHandle<M> {
315	/// Create a closed (disconnected) handle useful for tests.
316	pub fn new_closed() -> Self {
317		let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
318		let (_state_tx, state_rx) = watch::channel(ActorState::Stopping);
319		Self { tx, state: state_rx, tags: Arc::new(Tags::default()) }
320	}
321}
322
323impl<M> ActorHandle<M>
324where
325	M: Send + 'static,
326{
327	/// Convert to weak actor handle.
328	pub fn downgrade(self) -> WeakActorHandle<M> {
329		WeakActorHandle { state: self.state, tags: self.tags, tx: self.tx.downgrade() }
330	}
331
332	/// Get actor tags.
333	pub fn tags(&self) -> &Tags {
334		self.tags.as_ref()
335	}
336
337	/// Check if actor is running.
338	///
339	/// Running means not initializing, stopping or stopped.
340	pub fn is_running(&self) -> bool {
341		match *self.state.borrow() {
342			ActorState::Running => !self.tx.is_closed(),
343			_ => false,
344		}
345	}
346
347	/// Check if actor is closed.
348	///
349	/// Closed means not initializing or running.
350	pub fn is_closed(&self) -> bool {
351		match *self.state.borrow() {
352			ActorState::Stopping => true,
353			_ => self.tx.is_closed(),
354		}
355	}
356
357	/// Wait for startup to be complete.
358	pub async fn initialized(&self) -> Result<(), ActorError> {
359		let mut state = self.state.clone();
360		loop {
361			let actor_state = *state.borrow_and_update();
362			match actor_state {
363				ActorState::Starting => {
364					state
365						.changed()
366						.await
367						.map_err(|e| ActorError::InvalidState(e.into(), self.tags().clone()))?;
368				},
369				_ => {
370					break;
371				},
372			}
373		}
374		Ok(())
375	}
376
377	/// Wait for actor shutdown.
378	pub async fn closed(&self) -> Result<(), ActorError> {
379		actor_closed(self.state.clone(), self.tags.clone()).await
380	}
381
382	/// Request shutdown.
383	pub fn shutdown(&self) {
384		self.tx.send(ActorMessage::Shutdown).ok();
385	}
386
387	/// Dispatch message.
388	/// Will only fail when the actor already has been stopped.
389	pub fn dispatch(&self, message: impl Into<M>) -> Result<(), ActorError> {
390		self.tx
391			.send(ActorMessage::MessageWithSpan(message.into(), Span::current()))
392			.map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()))?;
393		Ok(())
394	}
395
396	/// Request with response.
397	#[tracing::instrument(level = tracing::Level::TRACE, err(Debug), skip_all, fields(message_type = type_name::<M>()))]
398	pub async fn request<T>(&self, message: impl FnOnce(Response<T>) -> M) -> Result<T, ActorError> {
399		let (responder, response) = ResponseReceiver::new();
400		self.tx
401			.send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
402			.map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()))?;
403		response.await
404	}
405
406	/// Request with response result.
407	/// If an error is returned in the result it will be wrapped in ´ActorError::Actor`.
408	#[tracing::instrument(level = tracing::Level::TRACE, err(Debug), skip_all, fields(message_type = type_name::<M>()))]
409	pub async fn try_request<T, E>(&self, message: impl FnOnce(Response<Result<T, E>>) -> M) -> Result<T, ActorError>
410	where
411		E: Into<anyhow::Error>,
412	{
413		let (responder, response) = ResponseReceiver::new();
414		self.tx
415			.send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
416			.map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()))?;
417		response
418			.await?
419			.map_err(|err| ActorError::Actor(err.into().context(anyhow!("Actor try request: {}", type_name::<M>()))))
420	}
421
422	/// Request with streaming response.
423	///
424	/// # Errors
425	/// The stream only fails if the stream request could not be sent to the actor because it's not running.
426	/// In this case [`ActorError::InvalidState`] is returned and the stream ends after it.
427	pub fn stream<T>(&self, message: impl FnOnce(ResponseStream<T>) -> M) -> impl Stream<Item = Result<T, ActorError>> {
428		let (responder, response) = ResponseStreamReceiver::new();
429		let send_result = self
430			.tx
431			.send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
432			.map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()));
433		let handle = self.clone();
434		let span = Span::current();
435		async_stream::stream! {
436			// force keep actor alive while stream is running
437			let _handle = handle;
438
439			// fail if send not worked
440			match send_result {
441				Ok(_) => {},
442				Err(err) => {
443					let _span_guard = span.enter();
444					yield Err(err);
445					return;
446				}
447			}
448
449			// forward items
450			for await item in response {
451				let _span_guard = span.enter();
452				yield Ok(item);
453			}
454		}
455	}
456
457	/// Request with streaming response.
458	/// Gracefully ends the stream when the actor is not running.
459	pub fn stream_graceful<T>(&self, message: impl FnOnce(ResponseStream<T>) -> M) -> impl Stream<Item = T> {
460		self.stream(message).filter_map(|item| ready(item.ok()))
461	}
462
463	/// Request with streaming response with back-pressure.
464	pub fn stream_backpressure<T: std::fmt::Debug>(
465		&self,
466		buffer: usize,
467		message: impl FnOnce(ResponseBackPressureStream<T>) -> M,
468	) -> impl Stream<Item = Result<T, ActorError>> {
469		let (responder, response) = ResponseBackPressureStreamReceiver::new(buffer);
470		let send_result = self
471			.tx
472			.send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
473			.map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()));
474		let handle = self.clone();
475		let span = Span::current();
476		async_stream::stream! {
477			// force keep actor alive while stream is running
478			let _handle = handle;
479
480			// fail if send not worked
481			match send_result {
482				Ok(_) => {},
483				Err(err) => {
484					let _span_guard = span.enter();
485					yield Err(err);
486					return;
487				}
488			}
489
490			// forward items
491			for await item in response {
492				match item {
493					Err(ActorError::Canceled) => {
494						break;
495					},
496					item => {
497						let _span_guard = span.enter();
498						yield item;
499					},
500				}
501			}
502		}
503	}
504}
505
506#[derive(Debug)]
507pub struct WeakActorHandle<M> {
508	tx: mpsc::WeakUnboundedSender<ActorMessage<M>>,
509	state: watch::Receiver<ActorState>,
510	tags: Arc<Tags>,
511}
512impl<M> Clone for WeakActorHandle<M> {
513	fn clone(&self) -> Self {
514		Self { tx: self.tx.clone(), state: self.state.clone(), tags: self.tags.clone() }
515	}
516}
517impl<M> WeakActorHandle<M> {
518	pub fn upgrade(self) -> Option<ActorHandle<M>> {
519		Some(ActorHandle { state: self.state, tags: self.tags, tx: self.tx.upgrade()? })
520	}
521
522	/// Wait for actor shutdown.
523	pub async fn closed(&self) -> Result<(), ActorError> {
524		actor_closed(self.state.clone(), self.tags.clone()).await
525	}
526}
527
528/// Wait for actor shutdown.
529async fn actor_closed(mut state: watch::Receiver<ActorState>, tags: Arc<Tags>) -> Result<(), ActorError> {
530	loop {
531		let actor_state = *state.borrow_and_update();
532		match actor_state {
533			ActorState::Starting | ActorState::Running => {
534				state
535					.changed()
536					.await
537					.map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
538			},
539			_ => {
540				break;
541			},
542		}
543	}
544	Ok(())
545}
546// pub trait ActorExt: Actor {
547// 	fn with_epic<E, C>(self, epic: E, context: C) -> EpicActor<Self, C>
548// 	where
549// 		E: Epic<Self::Message, Self::State, C>,
550// 	{
551// 		EpicActor { actor: self, context }
552// 	}
553// }
554
#[cfg(test)]
mod tests {
	use crate::{Actor, ActorError, ActorHandle, Response, ResponseStream, ResponseStreams};
	use async_trait::async_trait;
	use co_primitives::Tags;
	use futures::{StreamExt, TryStreamExt};
	use std::time::Duration;
	use tokio::time::timeout;

	/// Dispatch and request round trips against a counter actor that answers via `Response::respond`.
	#[tokio::test]
	async fn smoke() {
		struct Counter {}
		enum CounterMessage {
			Inc(i32),
			Get(Response<i32>),
			IncGet(i32, Response<i32>),
		}

		#[async_trait]
		impl Actor for Counter {
			type Message = CounterMessage;
			type State = i32;
			type Initialize = i32;

			async fn initialize(
				&self,
				_handle: &ActorHandle<Self::Message>,
				_tags: &Tags,
				initial: Self::Initialize,
			) -> Result<Self::State, ActorError> {
				Ok(initial)
			}

			async fn handle(
				&self,
				_handle: &ActorHandle<Self::Message>,
				message: Self::Message,
				counter: &mut Self::State,
			) -> Result<(), ActorError> {
				match message {
					CounterMessage::Inc(delta) => {
						*counter += delta;
					},
					CounterMessage::Get(reply) => {
						reply.respond(*counter);
					},
					CounterMessage::IncGet(delta, reply) => {
						*counter += delta;
						reply.respond(*counter);
					},
				}
				Ok(())
			}
		}

		let actor = Counter::spawn(Tags::default(), Counter {}, 0).unwrap();
		let handle = actor.handle();
		handle.dispatch(CounterMessage::Inc(10)).unwrap();
		handle.dispatch(CounterMessage::Inc(-5)).unwrap();
		assert_eq!(handle.request(CounterMessage::Get).await.unwrap(), 5);
		assert_eq!(handle.request(|reply| CounterMessage::IncGet(37, reply)).await.unwrap(), 42);
	}

	/// A watcher stream receives the current value first and then every later update.
	#[tokio::test]
	async fn test_stream() {
		struct Watched {}
		enum WatchedMessage {
			Inc(i32),
			Watch(ResponseStream<i32>),
		}
		struct WatchedState {
			watchers: ResponseStreams<i32>,
			value: i32,
		}

		#[async_trait]
		impl Actor for Watched {
			type Message = WatchedMessage;
			type State = WatchedState;
			type Initialize = i32;

			async fn initialize(
				&self,
				_handle: &ActorHandle<Self::Message>,
				_tags: &Tags,
				initial: Self::Initialize,
			) -> Result<Self::State, ActorError> {
				Ok(WatchedState { watchers: Default::default(), value: initial })
			}

			async fn handle(
				&self,
				_handle: &ActorHandle<Self::Message>,
				message: Self::Message,
				current: &mut Self::State,
			) -> Result<(), ActorError> {
				match message {
					WatchedMessage::Inc(delta) => {
						current.value += delta;
						current.watchers.send(current.value);
					},
					WatchedMessage::Watch(mut watcher) => {
						// Only keep the watcher if it accepted the initial value.
						if watcher.send(current.value).is_ok() {
							current.watchers.push(watcher);
						}
					},
				}
				Ok(())
			}
		}

		let actor = Watched::spawn(Tags::default(), Watched {}, 0).unwrap();
		let handle = actor.handle();
		handle.dispatch(WatchedMessage::Inc(10)).unwrap();
		handle.dispatch(WatchedMessage::Inc(-1)).unwrap();
		// The watch request is sent right here, so it is queued after the first two increments.
		let updates = handle.stream(WatchedMessage::Watch);
		handle.dispatch(WatchedMessage::Inc(-4)).unwrap();
		handle.dispatch(WatchedMessage::Inc(37)).unwrap();
		let observed: Vec<i32> = updates.take(3).try_collect().await.unwrap();
		assert_eq!(observed, vec![9, 5, 42]);
	}

	/// The actor task finishes on its own once the last handle is gone.
	#[tokio::test]
	async fn test_drop_when_no_handles() {
		struct Tracker {}
		enum TrackerMessage {
			Inc(i32),
			Get(Response<i32>),
		}
		#[async_trait]
		impl Actor for Tracker {
			type Message = TrackerMessage;
			type State = i32;
			type Initialize = i32;

			async fn initialize(
				&self,
				_handle: &ActorHandle<Self::Message>,
				_tags: &Tags,
				initial: Self::Initialize,
			) -> Result<Self::State, ActorError> {
				Ok(initial)
			}

			async fn handle(
				&self,
				_handle: &ActorHandle<Self::Message>,
				message: Self::Message,
				total: &mut Self::State,
			) -> Result<(), ActorError> {
				match message {
					TrackerMessage::Inc(delta) => {
						*total += delta;
					},
					TrackerMessage::Get(reply) => {
						reply.send(*total).ok();
					},
				}
				Ok(())
			}
		}

		let actor = Tracker::spawn(Tags::default(), Tracker {}, 1).unwrap();

		let handle = actor.handle();
		handle.dispatch(TrackerMessage::Inc(10)).unwrap();
		let total = handle.request(TrackerMessage::Get).await.unwrap();
		assert_eq!(total, 11);

		// Dropping our handle leaves only the instance's own handle, which `join` releases before waiting.
		drop(handle);
		timeout(Duration::from_millis(100), actor.join()).await.unwrap().unwrap();
	}
}