1use crate::actor::Actor;
2use crate::message::{Envelope, MessageFor};
3use anyhow::Error;
4use async_trait::async_trait;
5use crb_core::{mpsc, watch};
6use crb_runtime::kit::{
7 Context, Controller, Failures, InteractiveRuntime, InteractiveTask, Interruptor,
8 ManagedContext, Runtime, Task,
9};
10
11pub struct DoActor<A: Actor> {
12 pub actor: A,
13 pub context: A::Context,
14 pub failures: Failures,
15}
16
17impl<A: Actor> DoActor<A> {
18 pub fn new(actor: A) -> Self
19 where
20 A::Context: Default,
21 {
22 Self {
23 actor,
24 context: A::Context::default(),
25 failures: Failures::default(),
26 }
27 }
28}
29
30impl<A: Actor> Task<A> for DoActor<A> {}
31impl<A: Actor> InteractiveTask<A> for DoActor<A> {}
32
33#[async_trait]
34impl<A: Actor> InteractiveRuntime for DoActor<A> {
35 type Context = A::Context;
36
37 fn address(&self) -> <Self::Context as Context>::Address {
38 self.context.address().clone()
39 }
40}
41
42#[async_trait]
43impl<A: Actor> Runtime for DoActor<A> {
44 fn get_interruptor(&mut self) -> Interruptor {
45 self.context.controller().interruptor.clone()
46 }
47
48 async fn routine(&mut self) {
49 let result = self.actor.initialize(&mut self.context).await;
50 self.failures.put(result);
51
52 while self.context.session().controller().is_active() {
53 let result = self.actor.event(&mut self.context).await;
54 self.failures.put(result);
55 }
56
57 let result = self.actor.finalize(&mut self.context).await;
58 self.failures.put(result);
59
60 let result = self
61 .context
62 .session()
63 .joint
64 .status_tx
65 .send(ActorStatus::Done)
66 .map_err(|_| Error::msg("Can't set actor's status to `Done`"));
67 self.failures.put(result);
68 }
69}
70
/// Lifecycle status of an actor, published over a `watch` channel.
///
/// A session's status channel is created with [`ActorStatus::Active`].
/// `DoActor::routine` switches it to [`ActorStatus::Done`] after the actor
/// has been finalized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActorStatus {
    /// The actor has not finished its lifecycle yet.
    Active,
    /// The actor has finished (after `finalize`).
    Done,
}

impl ActorStatus {
    /// Returns `true` only for [`ActorStatus::Done`].
    ///
    /// Usable directly as a predicate, e.g. `wait_for(ActorStatus::is_done)`.
    pub fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
82
83pub struct AddressJoint<A> {
84 msg_rx: mpsc::UnboundedReceiver<Envelope<A>>,
85 status_tx: watch::Sender<ActorStatus>,
86}
87
88impl<A> AddressJoint<A> {
89 pub async fn next_envelope(&mut self) -> Option<Envelope<A>> {
90 self.msg_rx.recv().await
91 }
92}
93
94pub struct ActorSession<A> {
95 joint: AddressJoint<A>,
96 controller: Controller,
97 address: Address<A>,
98}
99
100impl<A> Default for ActorSession<A> {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106impl<A> ActorSession<A> {
107 pub fn new() -> Self {
108 let (msg_tx, msg_rx) = mpsc::unbounded_channel();
109 let (status_tx, status_rx) = watch::channel(ActorStatus::Active);
110 let controller = Controller::default();
111 let address = Address { msg_tx, status_rx };
112 let joint = AddressJoint { msg_rx, status_tx };
113 Self {
114 joint,
115 controller,
116 address,
117 }
118 }
119
120 pub fn joint(&mut self) -> &mut AddressJoint<A> {
121 &mut self.joint
122 }
123}
124
125impl<T> Context for ActorSession<T> {
126 type Address = Address<T>;
127
128 fn address(&self) -> &Self::Address {
129 &self.address
130 }
131}
132
133impl<T> ManagedContext for ActorSession<T> {
134 fn controller(&mut self) -> &mut Controller {
135 &mut self.controller
136 }
137
138 fn shutdown(&mut self) {
139 self.joint.msg_rx.close();
140 }
141}
142
143pub trait ActorContext<T>: Context<Address = Address<T>> + ManagedContext {
144 fn session(&mut self) -> &mut ActorSession<T>;
145}
146
147impl<T> ActorContext<T> for ActorSession<T> {
148 fn session(&mut self) -> &mut ActorSession<T> {
149 self
150 }
151}
152
153pub struct Address<A: ?Sized> {
154 msg_tx: mpsc::UnboundedSender<Envelope<A>>,
155 status_rx: watch::Receiver<ActorStatus>,
156}
157
158impl<A: Actor> Address<A> {
159 pub fn send(&self, msg: impl MessageFor<A>) -> Result<(), Error> {
160 self.msg_tx
161 .send(Box::new(msg))
162 .map_err(|_| Error::msg("Can't send the message to the actor"))
163 }
164
165 pub async fn join(&mut self) -> Result<(), Error> {
166 self.status_rx.wait_for(ActorStatus::is_done).await?;
167 Ok(())
168 }
169}
170
171impl<A> Clone for Address<A> {
172 fn clone(&self) -> Self {
173 Self {
174 msg_tx: self.msg_tx.clone(),
175 status_rx: self.status_rx.clone(),
176 }
177 }
178}