1use tokio::{sync::mpsc, task::JoinHandle};
2
3#[async_trait::async_trait]
4pub trait Actor {
5 type Message;
6
7 async fn run(self, state: State<Self>);
8}
9
10pub struct State<A>
11where
12 A: Actor + ?Sized,
13{
14 pub message_receiver: mpsc::Receiver<A::Message>,
15 pub control_receiver: mpsc::Receiver<Control>,
16}
17
18pub struct Handle<A: Actor> {
19 message_sender: mpsc::Sender<A::Message>,
20}
21
22impl<A: Actor> Clone for Handle<A> {
23 fn clone(&self) -> Self {
24 Self {
25 message_sender: self.message_sender.clone(),
26 }
27 }
28}
29
30impl<A: Actor> Handle<A> {
31 pub async fn send(&self, message: A::Message) {
32 if let Err(e) = self.message_sender.send(message).await {
33 tracing::error!("Failed to send message to handle: {:?}", e);
34 }
35 }
36}
37
/// Lifecycle signals that the [`Runner`] delivers on each actor's control channel.
///
/// This is a plain fieldless enum, so it is cheap to copy, compare, and print in logs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Ask the actor to leave its run loop so that the runner can join its task.
    Shutdown,
}
41
42pub struct Runner {
43 control_senders: Vec<mpsc::Sender<Control>>,
44 join_handles: Vec<JoinHandle<()>>,
45}
46
47impl Runner {
48 const BACKPRESSURE: usize = 10;
49
50 #[allow(clippy::new_without_default)]
51 pub fn new() -> Self {
52 Self {
53 control_senders: Vec::new(),
54 join_handles: Vec::new(),
55 }
56 }
57
58 pub fn run<A>(&mut self, actor: A) -> Handle<A>
59 where
60 A: Actor + Send + 'static,
61 A::Message: Send + 'static,
62 {
63 let (message_sender, message_receiver) = mpsc::channel(Self::BACKPRESSURE);
64 let (control_sender, control_receiver) = mpsc::channel(1);
65 let join_handle = tokio::spawn(async move {
66 actor
67 .run(State {
68 message_receiver,
69 control_receiver,
70 })
71 .await;
72 });
73 self.control_senders.push(control_sender);
74 self.join_handles.push(join_handle);
75
76 Handle::<A> { message_sender }
77 }
78
79 pub async fn shutdown(self) {
80 for control_sender in self.control_senders {
81 if let Err(e) = control_sender.send(Control::Shutdown).await {
82 tracing::error!("Failed to send shutdown control message: {:?}", e);
83 }
84 }
85 for join_handle in self.join_handles {
86 if let Err(e) = join_handle.await {
87 tracing::error!("Failed to join actor: {:?}", e);
88 }
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    use tokio::sync::oneshot;

    use super::*;

    /// Actor that keeps a running total and replies with it after every increment.
    struct Counter {
        count: usize,
    }

    impl Counter {
        fn new() -> Self {
            Counter { count: 0 }
        }
    }

    /// Requests understood by `Counter` and relayed unchanged by `CounterForwarder`.
    enum CounterMessage {
        /// Bump the counter, then report the new value through `response`.
        IncrementAndGet { response: oneshot::Sender<usize> },
    }

    #[async_trait::async_trait]
    impl Actor for Counter {
        type Message = CounterMessage;

        async fn run(mut self, mut state: State<Self>) {
            loop {
                tokio::select! {
                    Some(signal) = state.control_receiver.recv() => match signal {
                        Control::Shutdown => break,
                    },
                    Some(request) = state.message_receiver.recv() => match request {
                        CounterMessage::IncrementAndGet { response } => {
                            self.count += 1;
                            // The requester may have given up waiting; a lost reply is fine.
                            let _ = response.send(self.count);
                        }
                    },
                    // Both channels are closed, so nothing more can arrive.
                    else => break,
                }
            }
        }
    }

    /// Actor that forwards every message it receives to a `Counter`.
    struct CounterForwarder {
        counter_handle: Handle<Counter>,
    }

    impl CounterForwarder {
        fn new(counter_handle: Handle<Counter>) -> Self {
            CounterForwarder { counter_handle }
        }
    }

    #[async_trait::async_trait]
    impl Actor for CounterForwarder {
        type Message = CounterMessage;

        async fn run(self, mut state: State<Self>) {
            loop {
                tokio::select! {
                    Some(signal) = state.control_receiver.recv() => match signal {
                        Control::Shutdown => break,
                    },
                    Some(request) = state.message_receiver.recv() => {
                        self.counter_handle.send(request).await;
                    }
                    else => break,
                }
            }
        }
    }

    /// A request sent through the forwarder reaches the counter, and the reply comes back.
    #[tokio::test]
    async fn actors() {
        let mut runner = Runner::new();
        let counter = runner.run(Counter::new());
        let forwarder = runner.run(CounterForwarder::new(counter.clone()));

        let (reply_tx, reply_rx) = oneshot::channel();
        forwarder
            .send(CounterMessage::IncrementAndGet { response: reply_tx })
            .await;
        let observed = reply_rx.await.unwrap();
        assert_eq!(observed, 1);

        runner.shutdown().await;
    }
}