commonware_runtime/utils/handle.rs

1use crate::{utils::extract_panic_message, Error};
2use futures::{
3    channel::oneshot,
4    stream::{AbortHandle, Abortable},
5    FutureExt as _,
6};
7use prometheus_client::metrics::gauge::Gauge;
8use std::{
9    future::Future,
10    panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
11    pin::Pin,
12    sync::{Arc, Mutex, Once},
13    task::{Context, Poll},
14};
15use tracing::error;
16
17/// Handle to a spawned task.
18pub struct Handle<T>
19where
20    T: Send + 'static,
21{
22    aborter: Option<AbortHandle>,
23    receiver: oneshot::Receiver<Result<T, Error>>,
24
25    running: Gauge,
26    once: Arc<Once>,
27}
28
29impl<T> Handle<T>
30where
31    T: Send + 'static,
32{
33    pub(crate) fn init_future<F>(
34        f: F,
35        running: Gauge,
36        catch_panic: bool,
37        children: Arc<Mutex<Vec<AbortHandle>>>,
38    ) -> (impl Future<Output = ()>, Self)
39    where
40        F: Future<Output = T> + Send + 'static,
41    {
42        // Increment running counter
43        running.inc();
44
45        // Initialize channels to handle result/abort
46        let once = Arc::new(Once::new());
47        let (sender, receiver) = oneshot::channel();
48        let (aborter, abort_registration) = AbortHandle::new_pair();
49
50        // Wrap the future to handle panics
51        let wrapped = {
52            let once = once.clone();
53            let running = running.clone();
54            async move {
55                // Run future
56                let result = AssertUnwindSafe(f).catch_unwind().await;
57
58                // Decrement running counter
59                once.call_once(|| {
60                    running.dec();
61                });
62
63                // Handle result
64                let result = match result {
65                    Ok(result) => Ok(result),
66                    Err(err) => {
67                        if !catch_panic {
68                            resume_unwind(err);
69                        }
70                        let err = extract_panic_message(&*err);
71                        error!(?err, "task panicked");
72                        Err(Error::Exited)
73                    }
74                };
75                let _ = sender.send(result);
76            }
77        };
78
79        // Make the future abortable
80        let abortable = Abortable::new(wrapped, abort_registration);
81        (
82            abortable.map(move |_| {
83                // Abort all children
84                for handle in children.lock().unwrap().drain(..) {
85                    handle.abort();
86                }
87            }),
88            Self {
89                aborter: Some(aborter),
90                receiver,
91
92                running,
93                once,
94            },
95        )
96    }
97
98    pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
99    where
100        F: FnOnce() -> T + Send + 'static,
101    {
102        // Increment the running tasks gauge
103        running.inc();
104
105        // Initialize channel to handle result
106        let once = Arc::new(Once::new());
107        let (sender, receiver) = oneshot::channel();
108
109        // Wrap the closure with panic handling
110        let f = {
111            let once = once.clone();
112            let running = running.clone();
113            move || {
114                // Run blocking task
115                let result = catch_unwind(AssertUnwindSafe(f));
116
117                // Decrement running counter
118                once.call_once(|| {
119                    running.dec();
120                });
121
122                // Handle result
123                let result = match result {
124                    Ok(value) => Ok(value),
125                    Err(err) => {
126                        if !catch_panic {
127                            resume_unwind(err);
128                        }
129                        let err = extract_panic_message(&*err);
130                        error!(?err, "task panicked");
131                        Err(Error::Exited)
132                    }
133                };
134                let _ = sender.send(result);
135            }
136        };
137
138        // Return the task and handle
139        (
140            f,
141            Self {
142                aborter: None,
143                receiver,
144
145                running,
146                once,
147            },
148        )
149    }
150
151    /// Abort the task (if not blocking).
152    pub fn abort(&self) {
153        // Get aborter and abort
154        let Some(aborter) = &self.aborter else {
155            return;
156        };
157        aborter.abort();
158
159        // Decrement running counter
160        self.once.call_once(|| {
161            self.running.dec();
162        });
163    }
164
165    pub(crate) fn abort_handle(&self) -> Option<AbortHandle> {
166        self.aborter.clone()
167    }
168}
169
170impl<T> Future for Handle<T>
171where
172    T: Send + 'static,
173{
174    type Output = Result<T, Error>;
175
176    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177        match Pin::new(&mut self.receiver).poll(cx) {
178            Poll::Ready(Ok(Ok(value))) => {
179                self.once.call_once(|| {
180                    self.running.dec();
181                });
182                Poll::Ready(Ok(value))
183            }
184            Poll::Ready(Ok(Err(err))) => {
185                self.once.call_once(|| {
186                    self.running.dec();
187                });
188                Poll::Ready(Err(err))
189            }
190            Poll::Ready(Err(_)) => {
191                self.once.call_once(|| {
192                    self.running.dec();
193                });
194                Poll::Ready(Err(Error::Closed))
195            }
196            Poll::Pending => Poll::Pending,
197        }
198    }
199}