1use std::{
5 any::TypeId,
6 sync::mpsc::{channel, Receiver, RecvError, Sender, TryRecvError},
7 thread::spawn,
8};
9
10use log::{error, warn};
11use tokio::runtime::Builder;
12use tokio::sync::mpsc as tokio_mpsc;
13
14use crate::{
15 BoundedMailbox, BoundedTaskMailbox, Envelope, Mailbox, Service, ShutdownServiceMessage,
16 StatsAggregator, TaskService,
17};
18
/// Handle to a [`Service`] whose messages are processed on a dedicated thread.
///
/// Created by `ServiceThread::spawn_with` (service built by the caller) or
/// `ServiceThread::spawn_with_init_fn` (service built on the new thread).
pub struct ServiceThread<S: Service> {
    /// Receives the service's final outcome once its processing loop ends.
    pub join_handle: ServiceJoinHandle,
    /// Mailbox used to send messages to the service; clone it via `mbox`.
    pub mailbox: Mailbox<S>,
}
24
25impl<S: Service + Send + 'static> ServiceThread<S> {
26 pub fn spawn_with(service: S) -> Self {
27 let (mailbox, receiver) = Mailbox::create();
28 let (handle_tx, handle_rx) = channel();
29 let join_handle = ServiceJoinHandle::new(handle_rx);
30
31 spawn(move || run(service, receiver, handle_tx));
32
33 ServiceThread {
34 join_handle,
35 mailbox,
36 }
37 }
38
39 pub fn mbox(&self) -> Mailbox<S> {
40 self.mailbox.clone()
41 }
42}
43
44impl<S: Service + 'static> ServiceThread<S> {
45 pub fn spawn_with_init_fn<F: FnOnce() -> S + Send + 'static>(init_fn: F) -> Self {
46 let (mailbox, receiver) = Mailbox::create();
47 let (handle_tx, handle_rx) = channel();
48 let join_handle = ServiceJoinHandle::new(handle_rx);
49
50 spawn(move || {
51 let service = init_fn();
52 run(service, receiver, handle_tx)
53 });
54
55 ServiceThread {
56 join_handle,
57 mailbox,
58 }
59 }
60}
61
/// Handle to a [`Service`] running on a dedicated thread and fed through a
/// [`BoundedMailbox`] whose capacity is chosen at spawn time.
pub struct BoundedServiceThread<S: Service> {
    /// Receives the service's final outcome once its processing loop ends.
    pub join_handle: ServiceJoinHandle,
    /// Capacity-limited mailbox used to send messages to the service.
    pub mailbox: BoundedMailbox<S>,
}
66
67impl<S: Service + Send + 'static> BoundedServiceThread<S> {
68 pub fn spawn_with(service: S, channel_size: usize) -> Self {
69 let (mailbox, receiver) = BoundedMailbox::create(channel_size);
70 let (handle_tx, handle_rx) = channel();
71 let join_handle = ServiceJoinHandle::new(handle_rx);
72
73 spawn(move || run(service, receiver, handle_tx));
74
75 BoundedServiceThread {
76 join_handle,
77 mailbox,
78 }
79 }
80
81 pub fn mbox(&self) -> BoundedMailbox<S> {
82 self.mailbox.clone()
83 }
84}
85
86impl<S: Service + 'static> BoundedServiceThread<S> {
87 pub fn spawn_with_init_fn<F: FnOnce() -> S + Send + 'static>(
88 init_fn: F,
89 channel_size: usize,
90 ) -> Self {
91 let (mailbox, receiver) = BoundedMailbox::create(channel_size);
92 let (handle_tx, handle_rx) = channel();
93 let join_handle = ServiceJoinHandle::new(handle_rx);
94
95 spawn(move || {
96 let service = init_fn();
97 run(service, receiver, handle_tx)
98 });
99
100 BoundedServiceThread {
101 join_handle,
102 mailbox,
103 }
104 }
105}
106
107fn run<S: Service>(
108 mut service: S,
109 receiver: Receiver<Envelope<S>>,
110 join_handle_tx: Sender<Result<StatsAggregator, &'static str>>,
111) {
112 let mut stats_aggregator = StatsAggregator::new();
113 for mut envelope in receiver {
114 let type_id = envelope.message_type_id();
115 match envelope.deliver_to(&mut service) {
116 Err(_e) => {
117 if let Err(e) = join_handle_tx.send(Err("Message delivery failed")) {
119 error!("ssf delivery failed: {e}");
120 }
121 return;
122 }
123 Ok(stats) => {
124 stats_aggregator.add(&stats);
125 }
126 }
127 if type_id == Some(TypeId::of::<ShutdownServiceMessage>()) {
128 break;
129 }
130 }
131
132 drop(service);
134 if let Err(e) = join_handle_tx.send(Ok(stats_aggregator)) {
135 error!("ssf delivery failed: {e}");
136 }
137}
138
/// Handle to a [`TaskService`] running on a dedicated thread inside a
/// current-thread Tokio runtime, fed through a [`BoundedTaskMailbox`].
pub struct BoundedTaskServiceThread<S: TaskService> {
    /// Receives the service's final outcome once its processing loop ends.
    pub join_handle: ServiceJoinHandle,
    /// Capacity-limited mailbox used to send messages to the task service.
    pub mailbox: BoundedTaskMailbox<S>,
}
143
144impl<S: TaskService + Send + 'static> BoundedTaskServiceThread<S> {
145 pub fn spawn_with(service: S, channel_size: usize) -> Self {
146 let (mailbox, receiver) = BoundedTaskMailbox::create(channel_size);
147 let (handle_tx, handle_rx) = channel();
148 let join_handle = ServiceJoinHandle::new(handle_rx);
149
150 spawn(move || {
151 let runtime = match Builder::new_current_thread().enable_io().build() {
152 Ok(runtime) => runtime,
153 Err(e) => {
154 error!("Failed to build task service runtime: {}", e);
155 if let Err(send_err) =
156 handle_tx.send(Err("Failed to start task service runtime"))
157 {
158 error!(
159 "Failed to send task service failure notification: {}",
160 send_err
161 );
162 }
163 return;
164 }
165 };
166 runtime.block_on(async_run(service, receiver, handle_tx));
167 });
168
169 BoundedTaskServiceThread {
170 join_handle,
171 mailbox,
172 }
173 }
174
175 pub fn mbox(&self) -> BoundedTaskMailbox<S> {
176 self.mailbox.clone()
177 }
178}
179
180impl<S: TaskService + 'static> BoundedTaskServiceThread<S> {
181 pub fn spawn_with_init_fn<I>(init_fn: I, channel_size: usize) -> Self
182 where
183 I: FnOnce() -> S + Send + 'static,
184 {
185 let (mailbox, receiver) = BoundedTaskMailbox::create(channel_size);
186 let (handle_tx, handle_rx) = channel();
187 let join_handle = ServiceJoinHandle::new(handle_rx);
188
189 spawn(move || {
190 let service = init_fn();
191 let runtime = Builder::new_current_thread().enable_io().build();
192 match runtime {
193 Ok(runtime) => runtime.block_on(async_run(service, receiver, handle_tx)),
194 Err(e) => error!("Failed to spawn service: {}", e),
195 }
196 });
197
198 BoundedTaskServiceThread {
199 join_handle,
200 mailbox,
201 }
202 }
203}
204
205async fn async_run<S>(
206 mut service: S,
207 mut receiver: tokio_mpsc::Receiver<Envelope<S>>,
208 join_handle_tx: Sender<Result<StatsAggregator, &'static str>>,
209) where
210 S: TaskService,
211{
212 let mut stats_aggregator = StatsAggregator::new();
213
214 if let Err(e) = service.init().await {
215 error!("Failed to initialize task: {}", e);
216 return;
217 }
218
219 loop {
220 tokio::select! {
221 Some(mut envelope) = receiver.recv() => {
222 let type_id = envelope.message_type_id();
223 match envelope.deliver_to(&mut service) {
224 Err(_e) => {
225 if let Err(e) = join_handle_tx.send(Err("Message delivery failed")) {
227 error!("ssf delivery failed: {e}");
228 }
229 return;
230 }
231 Ok(stats) => {
232 stats_aggregator.add(&stats);
233 }
234 }
235 if type_id == Some(TypeId::of::<ShutdownServiceMessage>()) {
236 break;
237 }
238 },
239 result = service.run_task() => {
240 if let Err(e) = result {
241 warn!("Service task failed: {}", e);
242 }
243 }
244 };
245 }
246
247 drop(service);
249 if let Err(e) = join_handle_tx.send(Ok(stats_aggregator)) {
250 error!("ssf delivery failed: {e}");
251 }
252}
253
/// Waits for the final outcome reported by a service thread.
///
/// The thread sends `Ok(stats)` when it stops normally, or `Err(message)` when
/// it fails.
pub struct ServiceJoinHandle {
    // Receiving end of the channel the service thread reports its outcome on.
    rx: Receiver<Result<StatsAggregator, &'static str>>,
}
257
258impl ServiceJoinHandle {
259 pub fn new(rx: Receiver<Result<StatsAggregator, &'static str>>) -> Self {
260 Self { rx }
261 }
262
263 pub fn join(&mut self) -> Result<StatsAggregator, ServiceJoinHandleError> {
264 self.rx
265 .recv()?
266 .map_err(ServiceJoinHandleError::ServiceFailed)
267 }
268
269 pub fn try_join(&mut self) -> Result<StatsAggregator, ServiceJoinHandleError> {
270 self.rx
271 .try_recv()?
272 .map_err(ServiceJoinHandleError::ServiceFailed)
273 }
274}
275
/// Reason why [`ServiceJoinHandle::join`] / [`ServiceJoinHandle::try_join`]
/// did not return the service's stats.
#[derive(Debug, PartialEq, Eq)]
pub enum ServiceJoinHandleError {
    /// The result channel is disconnected and empty. Either the service thread
    /// went away without reporting, or its outcome was already taken.
    ServiceStopped,
    /// Only from `try_join`: the service has not reported an outcome yet.
    ServiceRunning,
    /// The service reported a failure; the payload is its static message.
    ServiceFailed(&'static str),
}
282
283impl From<RecvError> for ServiceJoinHandleError {
284 fn from(_value: RecvError) -> Self {
285 Self::ServiceStopped
287 }
288}
289
290impl From<TryRecvError> for ServiceJoinHandleError {
291 fn from(value: TryRecvError) -> Self {
292 match value {
293 TryRecvError::Empty => Self::ServiceRunning,
294 TryRecvError::Disconnected => Self::ServiceStopped,
295 }
296 }
297}
298
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_join_handle_error_conversion() {
        assert_eq!(
            ServiceJoinHandleError::from(RecvError),
            ServiceJoinHandleError::ServiceStopped
        );
        assert_eq!(
            ServiceJoinHandleError::from(TryRecvError::Empty),
            ServiceJoinHandleError::ServiceRunning
        );
        assert_eq!(
            ServiceJoinHandleError::from(TryRecvError::Disconnected),
            ServiceJoinHandleError::ServiceStopped
        );
    }

    #[test]
    fn test_try_join() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        assert!(matches!(
            join_handle.try_join(),
            Err(ServiceJoinHandleError::ServiceRunning)
        ));

        tx.send(Ok(StatsAggregator::new())).unwrap();
        assert!(join_handle.try_join().is_ok());
    }

    #[test]
    fn test_join() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        tx.send(Ok(StatsAggregator::new())).unwrap();

        assert!(join_handle.join().is_ok());
    }

    #[test]
    fn test_join_dropped() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        drop(tx);
        assert!(matches!(
            join_handle.join(),
            Err(ServiceJoinHandleError::ServiceStopped)
        ));
    }

    /// A reported failure must surface as `ServiceFailed` carrying the same message.
    #[test]
    fn test_join_service_failed() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        tx.send(Err("Message delivery failed")).unwrap();
        assert_eq!(
            join_handle.join().err(),
            Some(ServiceJoinHandleError::ServiceFailed(
                "Message delivery failed"
            ))
        );
    }

    #[test]
    fn test_try_join_service_failed() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        tx.send(Err("Failed to start task service runtime")).unwrap();
        assert_eq!(
            join_handle.try_join().err(),
            Some(ServiceJoinHandleError::ServiceFailed(
                "Failed to start task service runtime"
            ))
        );
    }

    #[test]
    fn test_try_join_dropped() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        drop(tx);
        assert_eq!(
            join_handle.try_join().err(),
            Some(ServiceJoinHandleError::ServiceStopped)
        );
    }

    /// Once the outcome has been taken and the sender is gone, a second join
    /// reports `ServiceStopped`.
    #[test]
    fn test_join_twice() {
        let (tx, rx) = channel();
        let mut join_handle = ServiceJoinHandle::new(rx);

        tx.send(Ok(StatsAggregator::new())).unwrap();
        drop(tx);

        assert!(join_handle.join().is_ok());
        assert_eq!(
            join_handle.join().err(),
            Some(ServiceJoinHandleError::ServiceStopped)
        );
    }
}