1use crate::{
5 Response, ResponseBackPressureStream, ResponseBackPressureStreamReceiver, ResponseReceiver, ResponseStream,
6 ResponseStreamReceiver, TaskHandle, TaskOptions, TaskSpawner,
7};
8use anyhow::anyhow;
9use async_trait::async_trait;
10use co_primitives::Tags;
11use futures::{Stream, StreamExt};
12use std::{any::type_name, future::ready, ops::Deref, sync::Arc};
13use tokio::sync::{mpsc, watch};
14#[cfg(feature = "js")]
15use tokio_with_wasm::alias as tokio;
16use tracing::{Instrument, Span};
17
18#[derive(Debug, thiserror::Error)]
19pub enum ActorError {
20 #[error("Invalid actor state for that operation ({1}).")]
21 InvalidState(#[source] anyhow::Error, Tags),
22
23 #[error("Operation canceled.")]
24 Canceled,
25
26 #[error("Actor error")]
27 Actor(#[from] anyhow::Error),
28}
29
30#[async_trait]
35pub trait Actor: Send + Sync + 'static {
36 type Message: Send + 'static;
37 type State: Send + 'static;
38 type Initialize: Send + 'static;
39
40 async fn initialize(
41 &self,
42 handle: &ActorHandle<Self::Message>,
43 tags: &Tags,
44 initialize: Self::Initialize,
45 ) -> Result<Self::State, ActorError>;
46
47 async fn handle(
48 &self,
49 handle: &ActorHandle<Self::Message>,
50 message: Self::Message,
51 state: &mut Self::State,
52 ) -> Result<(), ActorError>;
53
54 fn tags(&self, tags: Tags) -> Result<Tags, ActorError> {
55 Ok(tags)
56 }
57
58 async fn shutdown(&self, _state: Self::State) -> Result<(), ActorError> {
63 Ok(())
64 }
65
66 fn spawner(tags: Tags, actor: Self) -> Result<ActorSpawner<Self>, ActorError>
67 where
68 Self: Send + Sync + Sized + 'static,
69 {
70 ActorSpawner::new(tags, actor)
71 }
72
73 #[track_caller]
75 fn spawn(tags: Tags, actor: Self, initialize: Self::Initialize) -> Result<ActorInstance<Self>, ActorError>
76 where
77 Self: Send + Sync + Sized + 'static,
78 {
79 Self::spawn_with(Default::default(), tags, actor, initialize)
80 }
81
82 #[track_caller]
84 fn spawn_with(
85 spawner: TaskSpawner,
86 tags: Tags,
87 actor: Self,
88 initialize: Self::Initialize,
89 ) -> Result<ActorInstance<Self>, ActorError>
90 where
91 Self: Send + Sync + Sized + 'static,
92 {
93 Ok(Self::spawner(tags, actor)?.spawn(spawner, initialize))
94 }
95}
96
97pub struct ActorSpawner<A>
99where
100 A: Actor,
101{
102 handle: ActorHandle<A::Message>,
103 actor: A,
104 rx: tokio::sync::mpsc::UnboundedReceiver<ActorMessage<A::Message>>,
105 state_tx: tokio::sync::watch::Sender<ActorState>,
106 options: TaskOptions,
107}
108impl<A> ActorSpawner<A>
109where
110 A: Actor,
111{
112 pub fn new(tags: Tags, actor: A) -> Result<Self, ActorError> {
113 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
114 let (state_tx, state_rx) = watch::channel(ActorState::Starting);
115 let tags = Arc::new(actor.tags(tags)?);
116 let handle = ActorHandle { tx: tx.clone(), state: state_rx.clone(), tags: tags.clone() };
117 Ok(Self { handle, actor, rx, state_tx, options: TaskOptions::new(type_name::<A>()) })
118 }
119
120 pub fn handle(&self) -> ActorHandle<A::Message> {
121 self.handle.clone()
122 }
123
124 #[track_caller]
125 pub fn spawn(self, spawner: TaskSpawner, initialize: A::Initialize) -> ActorInstance<A> {
126 let mut rx = self.rx;
127 let state_tx = self.state_tx;
128 let actor = self.actor;
129 let tags = self.handle.tags.clone();
130 let handle = self.handle;
131 let span = tracing::trace_span!("actor", ?tags, actor_type = type_name::<A>());
132 let join = spawner.spawn_options(self.options, {
133 let tags = tags.clone();
134 let handle = handle.clone();
135 let actor_span = span.clone();
136 async move {
137 tracing::trace!("actor-initialize");
139
140 let mut actor_state = actor.initialize(&handle, &tags, initialize).await.map_err(|err| {
142 tracing::error!(?err, "actor-initialize-failed");
143 err
144 })?;
145 state_tx
146 .send(ActorState::Running)
147 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
148
149 let weak_handle = handle.downgrade();
151 while let Some(actor_message) = rx.recv().await {
152 let (message, message_span, _parent_span) = match actor_message {
154 ActorMessage::Message(message) => (message, tracing::trace_span!("actor-handle"), None),
155 ActorMessage::MessageWithSpan(message, message_span) => {
156 (message, tracing::trace_span!(parent: &message_span, "actor-handle"), Some(message_span))
157 },
158 ActorMessage::Shutdown => {
159 tracing::trace!("actor-shutdown");
161
162 break;
164 },
165 };
166 message_span.follows_from(&actor_span);
167
168 if let Some(handle) = weak_handle.clone().upgrade() {
171 actor
172 .handle(&handle, message, &mut actor_state)
173 .instrument(message_span)
174 .await
175 .map_err(|err| {
176 tracing::error!(?err, "actor-handle-failed");
177 err
178 })?;
179 }
180 }
181
182 state_tx
184 .send(ActorState::Stopping)
185 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
186 rx.close();
187
188 actor.shutdown(actor_state).await.map_err(|err| {
190 tracing::error!(?err, ?tags, "actor-shutdown-failed");
191 err
192 })?;
193
194 state_tx
196 .send(ActorState::None)
197 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
198 Ok(())
199 }
200 .instrument(span)
201 });
202 ActorInstance { join, handle }
203 }
204}
205
/// Lifecycle phase of an actor, as published on its state channel.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ActorState {
    /// Initial value of the state channel; `initialize` has not completed yet.
    Starting = 0,

    /// Published after `initialize` succeeded; messages are being processed.
    Running = 1,

    /// Published after the message loop ended, before `shutdown` runs.
    Stopping = 2,

    /// Published after `shutdown` completed successfully.
    None = 3,
}
221
222#[derive(Debug)]
223pub enum ActorMessage<M> {
224 Shutdown,
226
227 #[allow(unused)]
229 Message(M),
230
231 MessageWithSpan(M, tracing::Span),
233}
234
235#[derive(Debug)]
237pub struct ActorInstance<A>
238where
239 A: Actor,
240{
241 handle: ActorHandle<A::Message>,
242 join: TaskHandle<Result<(), ActorError>>,
243}
244impl<A> ActorInstance<A>
245where
246 A: Actor,
247{
248 pub fn handle(&self) -> ActorHandle<A::Message> {
250 self.handle.clone()
251 }
252
253 pub fn tags(&self) -> Tags {
255 self.handle.tags.as_ref().clone()
256 }
257
258 pub fn shutdown(&self) {
260 self.handle().shutdown();
261 }
262
263 pub async fn join(self) -> Result<(), ActorError> {
265 let tags = self.tags();
266 drop(self.handle);
267 self.join.await.map_err(|e| ActorError::InvalidState(e.into(), tags))??;
268 Ok(())
269 }
270
271 pub async fn initialized(self) -> Result<ActorHandle<A::Message>, ActorError> {
274 let handle = self.handle();
275 match handle.initialized().await {
276 Ok(_) => Ok(handle),
277 Err(err @ ActorError::InvalidState(_, _)) if self.handle().is_closed() => {
278 self.join().await?;
281 Err(err)
282 },
283 Err(err) => Err(err),
284 }
285 }
286
287 pub fn state(&self) -> ActorState {
289 *self.handle.state.borrow()
290 }
291}
292
293pub struct ActorHandle<M> {
295 pub(crate) tx: mpsc::UnboundedSender<ActorMessage<M>>,
296 pub(crate) state: watch::Receiver<ActorState>,
297 pub(crate) tags: Arc<Tags>,
298}
299impl<M> std::fmt::Debug for ActorHandle<M> {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 f.debug_struct("ActorHandle")
302 .field("message_type", &type_name::<M>())
303 .field("tx_closed", &self.tx.is_closed())
304 .field("state", &self.state.borrow().deref())
305 .field("tags", &self.tags)
306 .finish()
307 }
308}
309impl<M> Clone for ActorHandle<M> {
310 fn clone(&self) -> Self {
311 Self { tx: self.tx.clone(), state: self.state.clone(), tags: self.tags.clone() }
312 }
313}
314impl<M> ActorHandle<M> {
315 pub fn new_closed() -> Self {
317 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
318 let (_state_tx, state_rx) = watch::channel(ActorState::Stopping);
319 Self { tx, state: state_rx, tags: Arc::new(Tags::default()) }
320 }
321}
322
323impl<M> ActorHandle<M>
324where
325 M: Send + 'static,
326{
327 pub fn downgrade(self) -> WeakActorHandle<M> {
329 WeakActorHandle { state: self.state, tags: self.tags, tx: self.tx.downgrade() }
330 }
331
332 pub fn tags(&self) -> &Tags {
334 self.tags.as_ref()
335 }
336
337 pub fn is_running(&self) -> bool {
341 match *self.state.borrow() {
342 ActorState::Running => !self.tx.is_closed(),
343 _ => false,
344 }
345 }
346
347 pub fn is_closed(&self) -> bool {
351 match *self.state.borrow() {
352 ActorState::Stopping => true,
353 _ => self.tx.is_closed(),
354 }
355 }
356
357 pub async fn initialized(&self) -> Result<(), ActorError> {
359 let mut state = self.state.clone();
360 loop {
361 let actor_state = *state.borrow_and_update();
362 match actor_state {
363 ActorState::Starting => {
364 state
365 .changed()
366 .await
367 .map_err(|e| ActorError::InvalidState(e.into(), self.tags().clone()))?;
368 },
369 _ => {
370 break;
371 },
372 }
373 }
374 Ok(())
375 }
376
377 pub async fn closed(&self) -> Result<(), ActorError> {
379 actor_closed(self.state.clone(), self.tags.clone()).await
380 }
381
382 pub fn shutdown(&self) {
384 self.tx.send(ActorMessage::Shutdown).ok();
385 }
386
387 pub fn dispatch(&self, message: impl Into<M>) -> Result<(), ActorError> {
390 self.tx
391 .send(ActorMessage::MessageWithSpan(message.into(), Span::current()))
392 .map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()))?;
393 Ok(())
394 }
395
396 #[tracing::instrument(level = tracing::Level::TRACE, err(Debug), skip_all, fields(message_type = type_name::<M>()))]
398 pub async fn request<T>(&self, message: impl FnOnce(Response<T>) -> M) -> Result<T, ActorError> {
399 let (responder, response) = ResponseReceiver::new();
400 self.tx
401 .send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
402 .map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()))?;
403 response.await
404 }
405
406 #[tracing::instrument(level = tracing::Level::TRACE, err(Debug), skip_all, fields(message_type = type_name::<M>()))]
409 pub async fn try_request<T, E>(&self, message: impl FnOnce(Response<Result<T, E>>) -> M) -> Result<T, ActorError>
410 where
411 E: Into<anyhow::Error>,
412 {
413 let (responder, response) = ResponseReceiver::new();
414 self.tx
415 .send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
416 .map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()))?;
417 response
418 .await?
419 .map_err(|err| ActorError::Actor(err.into().context(anyhow!("Actor try request: {}", type_name::<M>()))))
420 }
421
422 pub fn stream<T>(&self, message: impl FnOnce(ResponseStream<T>) -> M) -> impl Stream<Item = Result<T, ActorError>> {
428 let (responder, response) = ResponseStreamReceiver::new();
429 let send_result = self
430 .tx
431 .send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
432 .map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()));
433 let handle = self.clone();
434 let span = Span::current();
435 async_stream::stream! {
436 let _handle = handle;
438
439 match send_result {
441 Ok(_) => {},
442 Err(err) => {
443 let _span_guard = span.enter();
444 yield Err(err);
445 return;
446 }
447 }
448
449 for await item in response {
451 let _span_guard = span.enter();
452 yield Ok(item);
453 }
454 }
455 }
456
457 pub fn stream_graceful<T>(&self, message: impl FnOnce(ResponseStream<T>) -> M) -> impl Stream<Item = T> {
460 self.stream(message).filter_map(|item| ready(item.ok()))
461 }
462
463 pub fn stream_backpressure<T: std::fmt::Debug>(
465 &self,
466 buffer: usize,
467 message: impl FnOnce(ResponseBackPressureStream<T>) -> M,
468 ) -> impl Stream<Item = Result<T, ActorError>> {
469 let (responder, response) = ResponseBackPressureStreamReceiver::new(buffer);
470 let send_result = self
471 .tx
472 .send(ActorMessage::MessageWithSpan(message(responder), Span::current()))
473 .map_err(|_| ActorError::InvalidState(anyhow!("Actor not running."), self.tags().clone()));
474 let handle = self.clone();
475 let span = Span::current();
476 async_stream::stream! {
477 let _handle = handle;
479
480 match send_result {
482 Ok(_) => {},
483 Err(err) => {
484 let _span_guard = span.enter();
485 yield Err(err);
486 return;
487 }
488 }
489
490 for await item in response {
492 match item {
493 Err(ActorError::Canceled) => {
494 break;
495 },
496 item => {
497 let _span_guard = span.enter();
498 yield item;
499 },
500 }
501 }
502 }
503 }
504}
505
506#[derive(Debug)]
507pub struct WeakActorHandle<M> {
508 tx: mpsc::WeakUnboundedSender<ActorMessage<M>>,
509 state: watch::Receiver<ActorState>,
510 tags: Arc<Tags>,
511}
512impl<M> Clone for WeakActorHandle<M> {
513 fn clone(&self) -> Self {
514 Self { tx: self.tx.clone(), state: self.state.clone(), tags: self.tags.clone() }
515 }
516}
517impl<M> WeakActorHandle<M> {
518 pub fn upgrade(self) -> Option<ActorHandle<M>> {
519 Some(ActorHandle { state: self.state, tags: self.tags, tx: self.tx.upgrade()? })
520 }
521
522 pub async fn closed(&self) -> Result<(), ActorError> {
524 actor_closed(self.state.clone(), self.tags.clone()).await
525 }
526}
527
528async fn actor_closed(mut state: watch::Receiver<ActorState>, tags: Arc<Tags>) -> Result<(), ActorError> {
530 loop {
531 let actor_state = *state.borrow_and_update();
532 match actor_state {
533 ActorState::Starting | ActorState::Running => {
534 state
535 .changed()
536 .await
537 .map_err(|e| ActorError::InvalidState(e.into(), tags.as_ref().clone()))?;
538 },
539 _ => {
540 break;
541 },
542 }
543 }
544 Ok(())
545}
#[cfg(test)]
mod tests {
    use crate::{Actor, ActorError, ActorHandle, Response, ResponseStream, ResponseStreams};
    use async_trait::async_trait;
    use co_primitives::Tags;
    use futures::{StreamExt, TryStreamExt};
    use std::time::Duration;
    use tokio::time::timeout;

    /// Dispatching and requesting against a simple counter actor.
    #[tokio::test]
    async fn smoke() {
        struct Test {}
        enum TestMessage {
            Inc(i32),
            Get(Response<i32>),
            IncGet(i32, Response<i32>),
        }

        #[async_trait]
        impl Actor for Test {
            type Message = TestMessage;
            type State = i32;
            type Initialize = i32;

            async fn initialize(
                &self,
                _handle: &ActorHandle<Self::Message>,
                _tags: &Tags,
                initialize: Self::Initialize,
            ) -> Result<Self::State, ActorError> {
                Ok(initialize)
            }

            async fn handle(
                &self,
                _handle: &ActorHandle<Self::Message>,
                message: Self::Message,
                state: &mut Self::State,
            ) -> Result<(), ActorError> {
                match message {
                    TestMessage::Inc(delta) => *state += delta,
                    TestMessage::Get(reply) => {
                        reply.respond(*state);
                    },
                    TestMessage::IncGet(delta, reply) => {
                        *state += delta;
                        reply.respond(*state);
                    },
                }
                Ok(())
            }
        }

        let instance = Test::spawn(Tags::default(), Test {}, 0).unwrap();
        let counter = instance.handle();
        counter.dispatch(TestMessage::Inc(10)).unwrap();
        counter.dispatch(TestMessage::Inc(-5)).unwrap();
        assert_eq!(counter.request(TestMessage::Get).await.unwrap(), 5);
        let after_inc = counter.request(|reply| TestMessage::IncGet(37, reply)).await.unwrap();
        assert_eq!(after_inc, 42);
    }

    /// A watcher stream receives the current value and then every update.
    #[tokio::test]
    async fn test_stream() {
        struct Test {}
        enum TestMessage {
            Inc(i32),
            Watch(ResponseStream<i32>),
        }
        struct TestState {
            watchers: ResponseStreams<i32>,
            value: i32,
        }

        #[async_trait]
        impl Actor for Test {
            type Message = TestMessage;
            type State = TestState;
            type Initialize = i32;

            async fn initialize(
                &self,
                _handle: &ActorHandle<Self::Message>,
                _tags: &Tags,
                initialize: Self::Initialize,
            ) -> Result<Self::State, ActorError> {
                Ok(TestState { watchers: Default::default(), value: initialize })
            }

            async fn handle(
                &self,
                _handle: &ActorHandle<Self::Message>,
                message: Self::Message,
                state: &mut Self::State,
            ) -> Result<(), ActorError> {
                match message {
                    TestMessage::Inc(delta) => {
                        state.value += delta;
                        state.watchers.send(state.value);
                    },
                    TestMessage::Watch(mut watcher) => {
                        // Only keep watchers that accepted the initial value.
                        if watcher.send(state.value).is_ok() {
                            state.watchers.push(watcher);
                        }
                    },
                }
                Ok(())
            }
        }

        let instance = Test::spawn(Tags::default(), Test {}, 0).unwrap();
        let counter = instance.handle();
        counter.dispatch(TestMessage::Inc(10)).unwrap();
        counter.dispatch(TestMessage::Inc(-1)).unwrap();
        let updates = counter.stream(TestMessage::Watch);
        counter.dispatch(TestMessage::Inc(-4)).unwrap();
        counter.dispatch(TestMessage::Inc(37)).unwrap();
        let seen: Vec<i32> = updates.take(3).try_collect().await.unwrap();
        assert_eq!(seen, vec![9, 5, 42]);
    }

    /// The actor task finishes once the last outside handle is dropped.
    #[tokio::test]
    async fn test_drop_when_no_handles() {
        struct Test {}
        enum TestMessage {
            Inc(i32),
            Get(Response<i32>),
        }
        #[async_trait]
        impl Actor for Test {
            type Message = TestMessage;
            type State = i32;
            type Initialize = i32;

            async fn initialize(
                &self,
                _handle: &ActorHandle<Self::Message>,
                _tags: &Tags,
                initialize: Self::Initialize,
            ) -> Result<Self::State, ActorError> {
                Ok(initialize)
            }

            async fn handle(
                &self,
                _handle: &ActorHandle<Self::Message>,
                message: Self::Message,
                state: &mut Self::State,
            ) -> Result<(), ActorError> {
                match message {
                    TestMessage::Inc(delta) => *state += delta,
                    TestMessage::Get(reply) => {
                        reply.send(*state).ok();
                    },
                }
                Ok(())
            }
        }

        let instance = Test::spawn(Tags::default(), Test {}, 1).unwrap();

        let counter = instance.handle();
        counter.dispatch(TestMessage::Inc(10)).unwrap();
        assert_eq!(counter.request(TestMessage::Get).await.unwrap(), 11);

        drop(counter);
        // `join` drops the instance's own handle, so the task must end promptly.
        let joined = timeout(Duration::from_millis(100), instance.join()).await;
        joined.unwrap().unwrap();
    }
}