// actman/lib.rs

use std::fmt;

use tokio::{sync::mpsc, task::JoinHandle};

3#[async_trait::async_trait]
4pub trait Actor {
5    type Message;
6
7    async fn run(self, state: State<Self>);
8}
9
10pub struct State<A>
11where
12    A: Actor + ?Sized,
13{
14    pub message_receiver: mpsc::Receiver<A::Message>,
15    pub control_receiver: mpsc::Receiver<Control>,
16}
17
18pub struct Handle<A: Actor> {
19    message_sender: mpsc::Sender<A::Message>,
20}
21
22impl<A: Actor> Clone for Handle<A> {
23    fn clone(&self) -> Self {
24        Self {
25            message_sender: self.message_sender.clone(),
26        }
27    }
28}
29
30impl<A: Actor> Handle<A> {
31    pub async fn send(&self, message: A::Message) {
32        if let Err(e) = self.message_sender.send(message).await {
33            tracing::error!("Failed to send message to handle: {:?}", e);
34        }
35    }
36}
37
/// Out-of-band signals the [`Runner`] sends to an actor on its control channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Asks the actor to stop. [`Runner::shutdown`] sends this once to every
    /// actor before waiting for the actor tasks to finish.
    Shutdown,
}

42pub struct Runner {
43    control_senders: Vec<mpsc::Sender<Control>>,
44    join_handles: Vec<JoinHandle<()>>,
45}
46
47impl Runner {
48    const BACKPRESSURE: usize = 10;
49
50    #[allow(clippy::new_without_default)]
51    pub fn new() -> Self {
52        Self {
53            control_senders: Vec::new(),
54            join_handles: Vec::new(),
55        }
56    }
57
58    pub fn run<A>(&mut self, actor: A) -> Handle<A>
59    where
60        A: Actor + Send + 'static,
61        A::Message: Send + 'static,
62    {
63        let (message_sender, message_receiver) = mpsc::channel(Self::BACKPRESSURE);
64        let (control_sender, control_receiver) = mpsc::channel(1);
65        let join_handle = tokio::spawn(async move {
66            actor
67                .run(State {
68                    message_receiver,
69                    control_receiver,
70                })
71                .await;
72        });
73        self.control_senders.push(control_sender);
74        self.join_handles.push(join_handle);
75
76        Handle::<A> { message_sender }
77    }
78
79    pub async fn shutdown(self) {
80        for control_sender in self.control_senders {
81            if let Err(e) = control_sender.send(Control::Shutdown).await {
82                tracing::error!("Failed to send shutdown control message: {:?}", e);
83            }
84        }
85        for join_handle in self.join_handles {
86            if let Err(e) = join_handle.await {
87                tracing::error!("Failed to join actor: {:?}", e);
88            }
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use tokio::sync::oneshot;

    use super::*;

    /// Actor that keeps a running count and reports it after each increment.
    struct Counter {
        count: usize,
    }

    impl Counter {
        fn new() -> Self {
            Self { count: 0 }
        }
    }

    enum CounterMessage {
        /// Increments the count and replies with the new value on `response`.
        IncrementAndGet { response: oneshot::Sender<usize> },
    }

    #[async_trait::async_trait]
    impl Actor for Counter {
        type Message = CounterMessage;

        async fn run(mut self, mut state: State<Self>) {
            loop {
                tokio::select! {
                    Some(ctrl) = state.control_receiver.recv() => {
                        match ctrl {
                            Control::Shutdown => break,
                        }
                    }
                    Some(message) = state.message_receiver.recv() => {
                        match message {
                            CounterMessage::IncrementAndGet { response } => {
                                self.count += 1;
                                // The requester may have stopped waiting; the
                                // increment stands either way.
                                let _ = response.send(self.count);
                            }
                        }
                    }
                    // Both channels are closed: nothing can reach this actor any more.
                    else => break,
                }
            }
        }
    }

    /// Actor that relays every message it receives to a [`Counter`].
    struct CounterForwarder {
        counter_handle: Handle<Counter>,
    }

    impl CounterForwarder {
        fn new(counter_handle: Handle<Counter>) -> Self {
            Self { counter_handle }
        }
    }

    #[async_trait::async_trait]
    impl Actor for CounterForwarder {
        type Message = CounterMessage;

        // `self` is only read (`Handle::send` takes `&self`), so no `mut` is needed.
        async fn run(self, mut state: State<Self>) {
            loop {
                tokio::select! {
                    Some(ctrl) = state.control_receiver.recv() => {
                        match ctrl {
                            Control::Shutdown => break,
                        }
                    }
                    Some(message) = state.message_receiver.recv() => {
                        self.counter_handle.send(message).await;
                    }
                    // Both channels are closed: nothing can reach this actor any more.
                    else => break,
                }
            }
        }
    }

    /// Sends one `IncrementAndGet` through `handle` and returns the reply.
    async fn increment_and_get<A>(handle: &Handle<A>) -> usize
    where
        A: Actor<Message = CounterMessage>,
    {
        let (response, receiver) = oneshot::channel();
        handle
            .send(CounterMessage::IncrementAndGet { response })
            .await;
        receiver
            .await
            .expect("counter dropped the response sender without replying")
    }

    #[tokio::test]
    async fn actors() {
        let mut runner = Runner::new();
        let counter_handle = runner.run(Counter::new());
        let counter_forwarder_handle = runner.run(CounterForwarder::new(counter_handle.clone()));

        // Going through the forwarder and going direct must hit the same counter.
        assert_eq!(increment_and_get(&counter_forwarder_handle).await, 1);
        assert_eq!(increment_and_get(&counter_handle).await, 2);

        runner.shutdown().await;
    }
}