memfault_ssf/
service_thread.rs

1//
2// Copyright (c) Memfault, Inc.
3// See License.txt for details
4use std::{
5    any::TypeId,
6    sync::mpsc::{channel, Receiver, RecvError, Sender, TryRecvError},
7    thread::spawn,
8};
9
10use log::{error, warn};
11use tokio::runtime::Builder;
12use tokio::sync::mpsc as tokio_mpsc;
13
14use crate::{
15    BoundedMailbox, BoundedTaskMailbox, Envelope, Mailbox, Service, ShutdownServiceMessage,
16    StatsAggregator, TaskService,
17};
18
19/// Run a service inside a dedicated thread using a mpsc::channel to send/receive messages
20pub struct ServiceThread<S: Service> {
21    pub join_handle: ServiceJoinHandle,
22    pub mailbox: Mailbox<S>,
23}
24
25impl<S: Service + Send + 'static> ServiceThread<S> {
26    pub fn spawn_with(service: S) -> Self {
27        let (mailbox, receiver) = Mailbox::create();
28        let (handle_tx, handle_rx) = channel();
29        let join_handle = ServiceJoinHandle::new(handle_rx);
30
31        spawn(move || run(service, receiver, handle_tx));
32
33        ServiceThread {
34            join_handle,
35            mailbox,
36        }
37    }
38
39    pub fn mbox(&self) -> Mailbox<S> {
40        self.mailbox.clone()
41    }
42}
43
44impl<S: Service + 'static> ServiceThread<S> {
45    pub fn spawn_with_init_fn<F: FnOnce() -> S + Send + 'static>(init_fn: F) -> Self {
46        let (mailbox, receiver) = Mailbox::create();
47        let (handle_tx, handle_rx) = channel();
48        let join_handle = ServiceJoinHandle::new(handle_rx);
49
50        spawn(move || {
51            let service = init_fn();
52            run(service, receiver, handle_tx)
53        });
54
55        ServiceThread {
56            join_handle,
57            mailbox,
58        }
59    }
60}
61
62pub struct BoundedServiceThread<S: Service> {
63    pub join_handle: ServiceJoinHandle,
64    pub mailbox: BoundedMailbox<S>,
65}
66
67impl<S: Service + Send + 'static> BoundedServiceThread<S> {
68    pub fn spawn_with(service: S, channel_size: usize) -> Self {
69        let (mailbox, receiver) = BoundedMailbox::create(channel_size);
70        let (handle_tx, handle_rx) = channel();
71        let join_handle = ServiceJoinHandle::new(handle_rx);
72
73        spawn(move || run(service, receiver, handle_tx));
74
75        BoundedServiceThread {
76            join_handle,
77            mailbox,
78        }
79    }
80
81    pub fn mbox(&self) -> BoundedMailbox<S> {
82        self.mailbox.clone()
83    }
84}
85
86impl<S: Service + 'static> BoundedServiceThread<S> {
87    pub fn spawn_with_init_fn<F: FnOnce() -> S + Send + 'static>(
88        init_fn: F,
89        channel_size: usize,
90    ) -> Self {
91        let (mailbox, receiver) = BoundedMailbox::create(channel_size);
92        let (handle_tx, handle_rx) = channel();
93        let join_handle = ServiceJoinHandle::new(handle_rx);
94
95        spawn(move || {
96            let service = init_fn();
97            run(service, receiver, handle_tx)
98        });
99
100        BoundedServiceThread {
101            join_handle,
102            mailbox,
103        }
104    }
105}
106
107fn run<S: Service>(
108    mut service: S,
109    receiver: Receiver<Envelope<S>>,
110    join_handle_tx: Sender<Result<StatsAggregator, &'static str>>,
111) {
112    let mut stats_aggregator = StatsAggregator::new();
113    for mut envelope in receiver {
114        let type_id = envelope.message_type_id();
115        match envelope.deliver_to(&mut service) {
116            Err(_e) => {
117                // Delivery failed - probably "attempt to deliver twice" - should never happen.
118                if let Err(e) = join_handle_tx.send(Err("Message delivery failed")) {
119                    error!("ssf delivery failed: {e}");
120                }
121                return;
122            }
123            Ok(stats) => {
124                stats_aggregator.add(&stats);
125            }
126        }
127        if type_id == Some(TypeId::of::<ShutdownServiceMessage>()) {
128            break;
129        }
130    }
131
132    // drop service, and send message indicating the the service thread is closed
133    drop(service);
134    if let Err(e) = join_handle_tx.send(Ok(stats_aggregator)) {
135        error!("ssf delivery failed: {e}");
136    }
137}
138
139pub struct BoundedTaskServiceThread<S: TaskService> {
140    pub join_handle: ServiceJoinHandle,
141    pub mailbox: BoundedTaskMailbox<S>,
142}
143
144impl<S: TaskService + Send + 'static> BoundedTaskServiceThread<S> {
145    pub fn spawn_with(service: S, channel_size: usize) -> Self {
146        let (mailbox, receiver) = BoundedTaskMailbox::create(channel_size);
147        let (handle_tx, handle_rx) = channel();
148        let join_handle = ServiceJoinHandle::new(handle_rx);
149
150        spawn(move || {
151            let runtime = match Builder::new_current_thread().enable_io().build() {
152                Ok(runtime) => runtime,
153                Err(e) => {
154                    error!("Failed to build task service runtime: {}", e);
155                    if let Err(send_err) =
156                        handle_tx.send(Err("Failed to start task service runtime"))
157                    {
158                        error!(
159                            "Failed to send task service failure notification: {}",
160                            send_err
161                        );
162                    }
163                    return;
164                }
165            };
166            runtime.block_on(async_run(service, receiver, handle_tx));
167        });
168
169        BoundedTaskServiceThread {
170            join_handle,
171            mailbox,
172        }
173    }
174
175    pub fn mbox(&self) -> BoundedTaskMailbox<S> {
176        self.mailbox.clone()
177    }
178}
179
180impl<S: TaskService + 'static> BoundedTaskServiceThread<S> {
181    pub fn spawn_with_init_fn<I>(init_fn: I, channel_size: usize) -> Self
182    where
183        I: FnOnce() -> S + Send + 'static,
184    {
185        let (mailbox, receiver) = BoundedTaskMailbox::create(channel_size);
186        let (handle_tx, handle_rx) = channel();
187        let join_handle = ServiceJoinHandle::new(handle_rx);
188
189        spawn(move || {
190            let service = init_fn();
191            let runtime = Builder::new_current_thread().enable_io().build();
192            match runtime {
193                Ok(runtime) => runtime.block_on(async_run(service, receiver, handle_tx)),
194                Err(e) => error!("Failed to spawn service: {}", e),
195            }
196        });
197
198        BoundedTaskServiceThread {
199            join_handle,
200            mailbox,
201        }
202    }
203}
204
205async fn async_run<S>(
206    mut service: S,
207    mut receiver: tokio_mpsc::Receiver<Envelope<S>>,
208    join_handle_tx: Sender<Result<StatsAggregator, &'static str>>,
209) where
210    S: TaskService,
211{
212    let mut stats_aggregator = StatsAggregator::new();
213
214    if let Err(e) = service.init().await {
215        error!("Failed to initialize task: {}", e);
216        return;
217    }
218
219    loop {
220        tokio::select! {
221            Some(mut envelope) = receiver.recv() => {
222                let type_id = envelope.message_type_id();
223                match envelope.deliver_to(&mut service) {
224                    Err(_e) => {
225                        // Delivery failed - probably "attempt to deliver twice" - should never happen.
226                        if let Err(e) = join_handle_tx.send(Err("Message delivery failed")) {
227                            error!("ssf delivery failed: {e}");
228                        }
229                        return;
230                    }
231                    Ok(stats) => {
232                        stats_aggregator.add(&stats);
233                    }
234                }
235                if type_id == Some(TypeId::of::<ShutdownServiceMessage>()) {
236                    break;
237                }
238            },
239            result = service.run_task() => {
240                if let Err(e) = result {
241                    warn!("Service task failed: {}", e);
242                }
243            }
244        };
245    }
246
247    // drop service, and send message indicating the the service thread is closed
248    drop(service);
249    if let Err(e) = join_handle_tx.send(Ok(stats_aggregator)) {
250        error!("ssf delivery failed: {e}");
251    }
252}
253
254pub struct ServiceJoinHandle {
255    rx: Receiver<Result<StatsAggregator, &'static str>>,
256}
257
258impl ServiceJoinHandle {
259    pub fn new(rx: Receiver<Result<StatsAggregator, &'static str>>) -> Self {
260        Self { rx }
261    }
262
263    pub fn join(&mut self) -> Result<StatsAggregator, ServiceJoinHandleError> {
264        self.rx
265            .recv()?
266            .map_err(ServiceJoinHandleError::ServiceFailed)
267    }
268
269    pub fn try_join(&mut self) -> Result<StatsAggregator, ServiceJoinHandleError> {
270        self.rx
271            .try_recv()?
272            .map_err(ServiceJoinHandleError::ServiceFailed)
273    }
274}
275
/// Reasons why joining a service thread did not produce its stats.
#[derive(Debug, PartialEq, Eq)]
pub enum ServiceJoinHandleError {
    /// The result channel was closed without a result being available.
    ServiceStopped,
    /// No result has been sent yet.
    ServiceRunning,
    /// The service thread reported a failure with the given message.
    ServiceFailed(&'static str),
}

impl From<RecvError> for ServiceJoinHandleError {
    /// A blocking `recv()` only fails once the sender has been dropped.
    fn from(_: RecvError) -> Self {
        ServiceJoinHandleError::ServiceStopped
    }
}

impl From<TryRecvError> for ServiceJoinHandleError {
    /// An empty channel means the service has not reported yet; a disconnected
    /// one means the sender has been dropped.
    fn from(value: TryRecvError) -> Self {
        match value {
            TryRecvError::Disconnected => ServiceJoinHandleError::ServiceStopped,
            TryRecvError::Empty => ServiceJoinHandleError::ServiceRunning,
        }
    }
}
298
#[cfg(test)]
mod test {
    use super::*;

    /// Channel receive errors map onto the matching join-handle errors.
    #[test]
    fn test_join_handle_error_conversion() {
        let from_recv: ServiceJoinHandleError = RecvError.into();
        assert_eq!(from_recv, ServiceJoinHandleError::ServiceStopped);

        let from_empty: ServiceJoinHandleError = TryRecvError::Empty.into();
        assert_eq!(from_empty, ServiceJoinHandleError::ServiceRunning);

        let from_disconnected: ServiceJoinHandleError = TryRecvError::Disconnected.into();
        assert_eq!(from_disconnected, ServiceJoinHandleError::ServiceStopped);
    }

    /// `try_join` reports "running" until a result is sent, then returns it.
    #[test]
    fn test_try_join() {
        let (result_tx, result_rx) = channel();
        let mut handle = ServiceJoinHandle::new(result_rx);

        // Nothing has been sent yet and the sender is still alive.
        let pending = handle.try_join();
        assert!(matches!(
            pending,
            Err(ServiceJoinHandleError::ServiceRunning)
        ));

        result_tx.send(Ok(StatsAggregator::new())).unwrap();
        let finished = handle.try_join();
        assert!(finished.is_ok());
    }

    /// `join` returns the result that was sent on the channel.
    #[test]
    fn test_join() {
        let (result_tx, result_rx) = channel();
        let mut handle = ServiceJoinHandle::new(result_rx);

        result_tx.send(Ok(StatsAggregator::new())).unwrap();

        let finished = handle.join();
        assert!(finished.is_ok());
    }

    /// `join` reports "stopped" when the sender is dropped without a result.
    #[test]
    fn test_join_dropped() {
        let (result_tx, result_rx) = channel();
        let mut handle = ServiceJoinHandle::new(result_rx);

        drop(result_tx);

        let outcome = handle.join();
        assert!(matches!(
            outcome,
            Err(ServiceJoinHandleError::ServiceStopped)
        ));
    }
}
352        ));
353    }
354}