hyperi_rustlib/concurrency/
actor.rs1use std::future::Future;
31use std::time::Duration;
32
33use tokio::sync::mpsc;
34use tokio::task::JoinHandle;
35use tokio::time::{MissedTickBehavior, interval};
36use tokio_util::sync::CancellationToken;
37
38use super::error::ActorError;
39
/// Tuning knobs for an actor spawned with [`ActorHandle::spawn`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActorConfig {
    /// Maximum number of commands buffered in the actor's queue.
    ///
    /// When the queue is full, [`ActorHandle::send`] waits for space and
    /// [`ActorHandle::try_send`] fails with [`ActorError::Full`].
    pub queue_capacity: usize,

    /// Period between [`Actor::on_idle`] callbacks.
    ///
    /// The timer is not reset by command traffic, but a queued command is
    /// always serviced before a pending idle tick.
    pub idle_interval: Duration,
}

impl Default for ActorConfig {
    /// 1024 queued commands and one idle callback per minute.
    fn default() -> Self {
        Self {
            queue_capacity: 1024,
            // Same value as `Duration::from_mins(1)`, without requiring the
            // newest toolchain.
            idle_interval: Duration::from_secs(60),
        }
    }
}
61
/// Behaviour hosted by a task spawned via [`ActorHandle::spawn`].
///
/// The actor owns its state exclusively: every hook takes `&mut self`, and the
/// hosting loop awaits each hook to completion before starting the next one,
/// so implementations need no internal locking.
pub trait Actor: Send + 'static {
    /// Message type accepted through the actor's [`ActorHandle`].
    type Command: Send + 'static;

    /// Processes one command taken from the actor's queue.
    fn handle(&mut self, cmd: Self::Command) -> impl Future<Output = ()> + Send;

    /// Periodic hook driven by the configured idle interval.
    ///
    /// Does nothing unless overridden.
    fn on_idle(&mut self) -> impl Future<Output = ()> + Send {
        std::future::ready(())
    }

    /// Final hook, awaited before the actor task exits normally (on
    /// cancellation or once every handle has been dropped).
    ///
    /// Does nothing unless overridden.
    fn on_shutdown(&mut self) -> impl Future<Output = ()> + Send {
        std::future::ready(())
    }
}
86
87#[derive(Debug, Clone)]
91pub struct ActorHandle<Cmd: Send + 'static> {
92 tx: mpsc::Sender<Cmd>,
93}
94
95pub struct ActorJoinHandle {
97 join: JoinHandle<()>,
98}
99
100impl<Cmd: Send + 'static> ActorHandle<Cmd> {
101 pub fn spawn<A: Actor<Command = Cmd>>(
104 actor: A,
105 config: ActorConfig,
106 shutdown: CancellationToken,
107 ) -> (Self, ActorJoinHandle) {
108 let (tx, rx) = mpsc::channel(config.queue_capacity);
109 let join = tokio::spawn(actor_loop(actor, rx, config, shutdown));
110 (Self { tx }, ActorJoinHandle { join })
111 }
112
113 pub async fn send(&self, cmd: Cmd) -> Result<(), ActorError> {
115 self.tx.send(cmd).await.map_err(|_| ActorError::Closed)
116 }
117
118 pub fn try_send(&self, cmd: Cmd) -> Result<(), ActorError> {
123 self.tx.try_send(cmd).map_err(|e| match e {
124 mpsc::error::TrySendError::Full(_) => ActorError::Full,
125 mpsc::error::TrySendError::Closed(_) => ActorError::Closed,
126 })
127 }
128}
129
130impl ActorJoinHandle {
131 pub async fn join(self) -> Result<(), tokio::task::JoinError> {
133 self.join.await
134 }
135}
136
137async fn actor_loop<A: Actor>(
138 mut actor: A,
139 mut rx: mpsc::Receiver<A::Command>,
140 config: ActorConfig,
141 shutdown: CancellationToken,
142) {
143 let mut idle = interval(config.idle_interval);
144 idle.set_missed_tick_behavior(MissedTickBehavior::Delay);
145 idle.tick().await;
147
148 loop {
149 tokio::select! {
150 biased;
151 () = shutdown.cancelled() => {
152 actor.on_shutdown().await;
153 return;
154 }
155 cmd = rx.recv() => if let Some(c) = cmd {
156 actor.handle(c).await;
157 } else {
158 actor.on_shutdown().await;
160 return;
161 },
162 _ = idle.tick() => {
163 actor.on_idle().await;
164 }
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use tokio::sync::oneshot;

    /// Commands accepted by the [`Counter`] test actor.
    enum Cmd {
        Increment,
        Read(oneshot::Sender<u32>),
    }

    /// Test actor whose whole state is a single counter.
    struct Counter {
        value: u32,
    }

    impl Actor for Counter {
        type Command = Cmd;

        async fn handle(&mut self, cmd: Cmd) {
            match cmd {
                Cmd::Increment => self.value += 1,
                Cmd::Read(reply) => {
                    // A dropped receiver only means the reader stopped waiting.
                    let _ = reply.send(self.value);
                }
            }
        }
    }

    /// Test actor that counts `on_shutdown` invocations. Shared by both
    /// shutdown-path tests instead of being redefined in each.
    struct ShutdownObserver {
        called: Arc<AtomicU32>,
    }

    impl Actor for ShutdownObserver {
        type Command = ();

        async fn handle(&mut self, _cmd: ()) {}

        async fn on_shutdown(&mut self) {
            self.called.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn actor_handles_commands_in_order() {
        let shutdown = CancellationToken::new();
        let (handle, join) = ActorHandle::spawn(
            Counter { value: 0 },
            ActorConfig::default(),
            shutdown.clone(),
        );
        for _ in 0..10 {
            handle.send(Cmd::Increment).await.expect("send ok");
        }
        // The read is queued behind all ten increments, so it must see them.
        let (tx, rx) = oneshot::channel();
        handle.send(Cmd::Read(tx)).await.expect("send ok");
        assert_eq!(rx.await.expect("reply"), 10);
        shutdown.cancel();
        join.join().await.expect("clean exit");
    }

    #[tokio::test]
    async fn try_send_returns_full_when_saturated() {
        // `handle` parks until `release` is notified, so the queue cannot
        // drain while the test is filling it.
        struct SlowCounter {
            value: u32,
            release: Arc<tokio::sync::Notify>,
        }
        impl Actor for SlowCounter {
            type Command = u32;
            async fn handle(&mut self, _cmd: u32) {
                self.release.notified().await;
                self.value += 1;
            }
        }
        let release = Arc::new(tokio::sync::Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = ActorConfig {
            queue_capacity: 4,
            ..ActorConfig::default()
        };
        let (handle, _join) = ActorHandle::spawn(
            SlowCounter {
                value: 0,
                release: Arc::clone(&release),
            },
            cfg,
            shutdown.clone(),
        );
        // At most one command can be inside `handle` plus four queued, so the
        // bulk of these 20 non-blocking sends must be rejected as `Full`.
        let mut full_count = 0;
        for i in 0..20 {
            match handle.try_send(i) {
                Ok(()) => {}
                Err(ActorError::Full) => full_count += 1,
                Err(e) => panic!("unexpected: {e}"),
            }
        }
        assert!(full_count >= 10, "got {full_count} Full errors");
        shutdown.cancel();
        release.notify_waiters();
    }

    #[tokio::test]
    async fn on_shutdown_called_once() {
        let called = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let (_handle, join) = ActorHandle::spawn(
            ShutdownObserver {
                called: Arc::clone(&called),
            },
            ActorConfig::default(),
            shutdown.clone(),
        );
        shutdown.cancel();
        join.join().await.expect("clean exit");
        assert_eq!(called.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dropping_all_handles_exits_gracefully() {
        let called = Arc::new(AtomicU32::new(0));
        // Kept alive but never cancelled: only dropping the handle may stop
        // the actor here.
        let shutdown = CancellationToken::new();
        let (handle, join) = ActorHandle::spawn(
            ShutdownObserver {
                called: Arc::clone(&called),
            },
            ActorConfig::default(),
            shutdown.clone(),
        );
        drop(handle);
        join.join().await.expect("clean exit");
        assert_eq!(called.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn idle_tick_fires_when_no_commands() {
        struct IdleCounter {
            ticks: Arc<AtomicU32>,
        }
        impl Actor for IdleCounter {
            type Command = ();
            async fn handle(&mut self, _cmd: ()) {}
            async fn on_idle(&mut self) {
                self.ticks.fetch_add(1, Ordering::SeqCst);
            }
        }
        let ticks = Arc::new(AtomicU32::new(0));
        let shutdown = CancellationToken::new();
        let cfg = ActorConfig {
            queue_capacity: 16,
            idle_interval: Duration::from_millis(20),
        };
        let (_handle, _join) = ActorHandle::spawn(
            IdleCounter {
                ticks: Arc::clone(&ticks),
            },
            cfg,
            shutdown.clone(),
        );
        // Wall-clock based: the startup tick is consumed, so ticks land near
        // 20/40/60/80/100 ms (~5 expected); the range absorbs scheduler jitter.
        tokio::time::sleep(Duration::from_millis(110)).await;
        shutdown.cancel();
        let n = ticks.load(Ordering::SeqCst);
        assert!((4..=7).contains(&n), "got {n} idle ticks, expected 4-7");
    }
}