memfault_ssf/
service_manager.rs

//
// Copyright (c) Memfault, Inc.
// See License.txt for details
4use std::{thread::sleep, time::Duration};
5
6use log::warn;
7
8use crate::{
9    BoundedMailbox, BoundedServiceThread, BoundedTaskMailbox, BoundedTaskServiceThread, Mailbox,
10    MsgMailbox, Service, ServiceJoinHandle, ServiceJoinHandleError, ServiceThread,
11    ShutdownServiceMessage, StatsAggregator, TaskService,
12};
13
14#[derive(Default)]
15pub struct ServiceManager {
16    shutdown_handles: Vec<ShutdownHandle>,
17}
18
19impl ServiceManager {
20    const SHUTDOWN_TIMEOUT_MS: u64 = 1000;
21    const SHUTDOWN_RETRIES: u64 = 10;
22    const SHUTDOWN_LOOP_TIMEOUT_MS: Duration =
23        Duration::from_millis(Self::SHUTDOWN_TIMEOUT_MS / Self::SHUTDOWN_RETRIES);
24
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    pub fn spawn_service_thread<S: Service + Send + 'static>(&mut self, service: S) -> Mailbox<S> {
30        let service_thread = ServiceThread::spawn_with(service);
31
32        let service_mailbox = service_thread.mailbox.clone();
33
34        let shutdown_handle = ShutdownHandle::from(service_thread);
35        self.shutdown_handles.push(shutdown_handle);
36
37        service_mailbox
38    }
39
40    pub fn spawn_bounded_service_thread<S: Service + Send + 'static>(
41        &mut self,
42        service: S,
43        channel_size: usize,
44    ) -> BoundedMailbox<S> {
45        let service_thread = BoundedServiceThread::spawn_with(service, channel_size);
46        let service_mailbox = service_thread.mailbox.clone();
47
48        let shutdown_handle = ShutdownHandle::from(service_thread);
49        self.shutdown_handles.push(shutdown_handle);
50
51        service_mailbox
52    }
53
54    pub fn spawn_bounded_task_service_thread_with_fn<
55        S: TaskService + 'static,
56        I: FnOnce() -> S + Send + 'static,
57    >(
58        &mut self,
59        init_fn: I,
60        channel_size: usize,
61    ) -> BoundedTaskMailbox<S> {
62        let service_thread = BoundedTaskServiceThread::spawn_with_init_fn(init_fn, channel_size);
63
64        let service_mailbox = service_thread.mailbox.clone();
65
66        let shutdown_handle = ShutdownHandle::from(service_thread);
67        self.shutdown_handles.push(shutdown_handle);
68
69        service_mailbox
70    }
71
72    pub fn spawn_bounded_task_service_thread<S: TaskService + Send + 'static>(
73        &mut self,
74        service: S,
75        channel_size: usize,
76    ) -> BoundedTaskMailbox<S> {
77        let service_thread = BoundedTaskServiceThread::spawn_with(service, channel_size);
78
79        let service_mailbox = service_thread.mailbox.clone();
80
81        let shutdown_handle = ShutdownHandle::from(service_thread);
82        self.shutdown_handles.push(shutdown_handle);
83
84        service_mailbox
85    }
86
87    pub fn stop(&mut self) -> Vec<StatsAggregator> {
88        self.shutdown_handles
89            .iter()
90            .map(|handle| handle.mbox.send_and_forget(ShutdownServiceMessage {}))
91            .filter_map(|res| res.err())
92            .for_each(|e| warn!("Failed to shutdown service: {e}"));
93
94        let mut join_handles = self
95            .shutdown_handles
96            .iter_mut()
97            .map(|handle| &mut handle.join_handle)
98            .collect::<Vec<_>>();
99        let mut stats_aggregators = Vec::with_capacity(join_handles.len());
100        for _ in 0..Self::SHUTDOWN_RETRIES {
101            join_handles.retain_mut(|jh| match jh.try_join() {
102                Ok(stats) => {
103                    stats_aggregators.push(stats);
104                    false
105                }
106                Err(ServiceJoinHandleError::ServiceFailed(msg)) => {
107                    warn!("Service failed while stopping: {msg}");
108                    false
109                }
110                Err(ServiceJoinHandleError::ServiceRunning) => true,
111                Err(ServiceJoinHandleError::ServiceStopped) => false,
112            });
113
114            if join_handles.is_empty() {
115                break;
116            }
117
118            sleep(Self::SHUTDOWN_LOOP_TIMEOUT_MS);
119        }
120
121        stats_aggregators
122    }
123}
124
125pub struct ShutdownHandle {
126    mbox: MsgMailbox<ShutdownServiceMessage>,
127    join_handle: ServiceJoinHandle,
128}
129
130impl<S: Service + 'static> From<ServiceThread<S>> for ShutdownHandle {
131    fn from(service: ServiceThread<S>) -> Self {
132        Self {
133            mbox: service.mailbox.into(),
134            join_handle: service.join_handle,
135        }
136    }
137}
138
139impl<S: Service + 'static> From<BoundedServiceThread<S>> for ShutdownHandle {
140    fn from(service: BoundedServiceThread<S>) -> Self {
141        Self {
142            mbox: service.mailbox.into(),
143            join_handle: service.join_handle,
144        }
145    }
146}
147
148impl<S: TaskService + 'static> From<BoundedTaskServiceThread<S>> for ShutdownHandle {
149    fn from(service: BoundedTaskServiceThread<S>) -> Self {
150        Self {
151            mbox: service.mailbox.into(),
152            join_handle: service.join_handle,
153        }
154    }
155}
156
#[cfg(test)]
mod test {
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    use super::*;

    /// `stop` must drop every service the manager spawned, on both a plain and a bounded
    /// service thread.
    #[test]
    fn test_system_stop() {
        let mut manager = ServiceManager::default();

        let (plain_running, plain_service) = TestService::with_flag();
        let (bounded_running, bounded_service) = TestService::with_flag();

        manager.spawn_service_thread(plain_service);
        manager.spawn_bounded_service_thread(bounded_service, 128);

        manager.stop();

        for running in [&plain_running, &bounded_running] {
            assert!(!running.load(Ordering::SeqCst));
        }
    }

    /// Minimal service that clears a shared flag when it is dropped.
    struct TestService {
        is_running: Arc<AtomicBool>,
    }

    impl TestService {
        /// Returns a shared "running" flag (initially `true`) and a service bound to it.
        fn with_flag() -> (Arc<AtomicBool>, Self) {
            let is_running = Arc::new(AtomicBool::new(true));
            (Arc::clone(&is_running), Self { is_running })
        }
    }

    impl Service for TestService {
        fn name(&self) -> &str {
            "TestService"
        }
    }

    impl Drop for TestService {
        fn drop(&mut self) {
            self.is_running.store(false, Ordering::SeqCst);
        }
    }
}