1use futures::{
2 channel::oneshot::{channel, Receiver, Sender},
3 future::FusedFuture,
4};
5use log::{debug, warn};
6use std::fmt::{Debug, Formatter};
7
8type TerminatorConnection = (Sender<()>, Receiver<()>);
9
10pub struct Terminator {
13 component_name: &'static str,
14 parent_exit: Receiver<()>,
15 parent_connection: Option<TerminatorConnection>,
16 offspring_connections: Vec<(&'static str, (Sender<()>, TerminatorConnection))>,
17 returned_result: Option<Result<(), ()>>,
18}
19
20impl Debug for Terminator {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("Terminator")
23 .field("component name", &self.component_name)
24 .field(
25 "offspring connection count",
26 &self.offspring_connections.len(),
27 )
28 .finish()
29 }
30}
31
32impl Terminator {
33 fn new(
34 parent_exit: Receiver<()>,
35 parent_connection: Option<TerminatorConnection>,
36 component_name: &'static str,
37 ) -> Self {
38 Self {
39 component_name,
40 parent_exit,
41 parent_connection,
42 offspring_connections: Vec::new(),
43 returned_result: None,
44 }
45 }
46
47 pub fn create_root(exit: Receiver<()>, name: &'static str) -> Self {
49 Self::new(exit, None, name)
50 }
51
52 pub async fn get_exit(&mut self) -> Result<(), ()> {
56 if let Some(returned) = self.returned_result {
57 return returned;
58 }
59 self.returned_result
60 .insert((&mut self.parent_exit).await.map_err(|_| ()))
61 .to_owned()
62 }
63
64 pub fn add_offspring_connection(&mut self, name: &'static str) -> Terminator {
66 let (exit_send, exit_recv) = channel();
67 let (sender, offspring_recv) = channel();
68 let (offspring_sender, recv) = channel();
69
70 let endpoint = (sender, recv);
71 let offspring_endpoint = (offspring_sender, offspring_recv);
72
73 self.offspring_connections
74 .push((name, (exit_send, endpoint)));
75 Terminator::new(exit_recv, Some(offspring_endpoint), name)
76 }
77
78 pub async fn terminate_sync(self) {
80 if !self.parent_exit.is_terminated() {
81 debug!(
82 target: self.component_name,
83 "Terminator has not recieved exit from parent: synchronization canceled.",
84 );
85 return;
86 }
87
88 debug!(
89 target: self.component_name,
90 "Terminator preparing for shutdown.",
91 );
92
93 let mut offspring_senders = Vec::new();
94 let mut offspring_receivers = Vec::new();
95
96 for (name, (exit, connection)) in self.offspring_connections {
98 if exit.send(()).is_err() {
99 debug!(target: self.component_name, "{} already stopped.", name);
100 }
101
102 let (sender, receiver) = connection;
103 offspring_senders.push((sender, name));
104 offspring_receivers.push((receiver, name));
105 }
106
107 for (receiver, name) in offspring_receivers {
109 if receiver.await.is_err() {
110 debug!(
111 target: self.component_name,
112 "Terminator failed to receive from {}.",
113 name,
114 );
115 }
116 }
117
118 debug!(
119 target: self.component_name,
120 "Terminator gathered notifications from descendants.",
121 );
122
123 if let Some((sender, receiver)) = self.parent_connection {
126 if sender.send(()).is_err() {
127 debug!(
128 target: self.component_name,
129 "Terminator failed to notify parent component.",
130 );
131 } else {
132 debug!(
133 target: self.component_name,
134 "Terminator notified parent component.",
135 );
136 }
137
138 if receiver.await.is_err() {
139 debug!(
140 target: self.component_name,
141 "Terminator failed to receive from parent component."
142 );
143 } else {
144 debug!(
145 target: self.component_name,
146 "Terminator recieved shutdown permission from parent component."
147 );
148 }
149 }
150
151 for (sender, name) in offspring_senders {
153 if sender.send(()).is_err() {
154 debug!(
155 target: self.component_name,
156 "Terminator failed to notify {}.",
157 name,
158 );
159 }
160 }
161
162 debug!(
163 target: self.component_name,
164 "Terminator sent permits to descendants: ready to exit.",
165 );
166 }
167}
168
169pub async fn handle_task_termination<T>(task_handle: T, target: &'static str, name: &'static str)
170where
171 T: FusedFuture<Output = Result<(), ()>>,
172{
173 if !task_handle.is_terminated() {
174 if let Err(()) = task_handle.await {
175 warn!(
176 target: target,
177 "{} task stopped with an error", name
178 );
179 }
180 debug!(target: target, "{} stopped.", name);
181 }
182}
183
#[cfg(test)]
mod tests {
    use futures::{channel::oneshot, pin_mut, select, FutureExt};

    // `super` rather than `crate`: the tests are a child module of this file, so
    // the path resolves whether or not this file is the crate root.
    use super::Terminator;

    /// Leaf component: waits for the exit signal (or for its sender to be
    /// dropped), then performs the shutdown handshake.
    async fn leaf(mut terminator: Terminator) {
        let _ = terminator.get_exit().await;
        terminator.terminate_sync().await;
    }

    /// Internal component whose two leaves run as spawned tokio tasks.
    ///
    /// With `with_crash` it returns right after spawning. Dropping its
    /// terminator simulates a crash that the leaves have to cope with.
    async fn internal_1(mut terminator: Terminator, with_crash: bool) {
        let leaf_1 = leaf(terminator.add_offspring_connection("leaf"));
        let leaf_2 = leaf(terminator.add_offspring_connection("leaf"));

        // `tokio::spawn` drives each task to completion itself, so unlike the
        // `select!`-polled futures elsewhere these need no `.fuse()`.
        let leaf_handle_1 = tokio::spawn(leaf_1);
        let leaf_handle_2 = tokio::spawn(leaf_2);

        if with_crash {
            return;
        }

        let _ = terminator.get_exit().await;
        terminator.terminate_sync().await;

        let _ = leaf_handle_1.await;
        let _ = leaf_handle_2.await;
    }

    /// Internal component with two inline leaves and a nested `internal_1`.
    ///
    /// The first `select!` checks which event happened first. A child finishing
    /// first is only acceptable in the crash run, and the exit signal arriving
    /// first only in the normal run. The loop then drives every remaining
    /// future, including this component's own handshake, until all complete.
    async fn internal_2(mut terminator: Terminator, with_crash: bool) {
        let leaf_handle_1 = leaf(terminator.add_offspring_connection("leaf")).fuse();
        let leaf_handle_2 = leaf(terminator.add_offspring_connection("leaf")).fuse();
        let internal_handle = internal_1(
            terminator.add_offspring_connection("internal_1"),
            with_crash,
        )
        .fuse();

        pin_mut!(leaf_handle_1);
        pin_mut!(leaf_handle_2);
        pin_mut!(internal_handle);

        select! {
            _ = leaf_handle_1 => assert!(with_crash, "leaf crashed when it wasn't supposed to"),
            _ = leaf_handle_2 => assert!(with_crash, "leaf crashed when it wasn't supposed to"),
            _ = internal_handle => assert!(with_crash, "internal_1 crashed when it wasn't supposed to"),
            _ = terminator.get_exit().fuse() => assert!(!with_crash, "exited when we expected internal crash"),
        }

        let terminator_handle = terminator.terminate_sync().fuse();
        pin_mut!(terminator_handle);

        loop {
            select! {
                _ = leaf_handle_1 => {},
                _ = leaf_handle_2 => {},
                _ = internal_handle => {},
                _ = terminator_handle => {},
                complete => break,
            }
        }
    }

    /// Root component with one leaf and a nested `internal_2`.
    ///
    /// It uses the same first-event check and drain loop as `internal_2`.
    async fn root_component(mut terminator: Terminator, with_crash: bool) {
        let leaf_handle = leaf(terminator.add_offspring_connection("leaf")).fuse();
        let internal_handle = internal_2(
            terminator.add_offspring_connection("internal_2"),
            with_crash,
        )
        .fuse();

        pin_mut!(leaf_handle);
        pin_mut!(internal_handle);

        select! {
            _ = leaf_handle => assert!(with_crash, "leaf crashed when it wasn't supposed to"),
            _ = internal_handle => assert!(with_crash, "internal_2 crashed when it wasn't supposed to"),
            _ = terminator.get_exit().fuse() => assert!(!with_crash, "exited when we expected internal crash"),
        }

        let terminator_handle = terminator.terminate_sync().fuse();
        pin_mut!(terminator_handle);

        loop {
            select! {
                _ = leaf_handle => {},
                _ = internal_handle => {},
                _ = terminator_handle => {},
                complete => break,
            }
        }
    }

    /// The exit signal is sent before the tree starts; every component shuts down normally.
    #[tokio::test]
    async fn simple_exit() {
        let (exit_tx, exit_rx) = oneshot::channel();
        let terminator = Terminator::create_root(exit_rx, "root");
        exit_tx.send(()).expect("should send");
        root_component(terminator, false).await;
    }

    /// `internal_1` crashes. `_exit_tx` stays alive, so the root never sees an exit signal.
    #[tokio::test]
    async fn component_crash() {
        let (_exit_tx, exit_rx) = oneshot::channel();
        let terminator = Terminator::create_root(exit_rx, "root");
        root_component(terminator, true).await;
    }
}