commonware_runtime/utils/handle.rs

1use crate::{utils::extract_panic_message, Error};
2use futures::{
3    channel::oneshot,
4    stream::{AbortHandle, Abortable},
5    FutureExt as _,
6};
7use prometheus_client::metrics::gauge::Gauge;
8use std::{
9    future::Future,
10    panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
11    pin::Pin,
12    sync::{Arc, Once},
13    task::{Context, Poll},
14};
15use tracing::error;
16
17/// Handle to a spawned task.
18pub struct Handle<T>
19where
20    T: Send + 'static,
21{
22    aborter: Option<AbortHandle>,
23    receiver: oneshot::Receiver<Result<T, Error>>,
24
25    running: Gauge,
26    once: Arc<Once>,
27}
28
29impl<T> Handle<T>
30where
31    T: Send + 'static,
32{
33    pub(crate) fn init_future<F>(
34        f: F,
35        running: Gauge,
36        catch_panic: bool,
37    ) -> (impl Future<Output = ()>, Self)
38    where
39        F: Future<Output = T> + Send + 'static,
40    {
41        // Increment running counter
42        running.inc();
43
44        // Initialize channels to handle result/abort
45        let once = Arc::new(Once::new());
46        let (sender, receiver) = oneshot::channel();
47        let (aborter, abort_registration) = AbortHandle::new_pair();
48
49        // Wrap the future to handle panics
50        let wrapped = {
51            let once = once.clone();
52            let running = running.clone();
53            async move {
54                // Run future
55                let result = AssertUnwindSafe(f).catch_unwind().await;
56
57                // Decrement running counter
58                once.call_once(|| {
59                    running.dec();
60                });
61
62                // Handle result
63                let result = match result {
64                    Ok(result) => Ok(result),
65                    Err(err) => {
66                        if !catch_panic {
67                            resume_unwind(err);
68                        }
69                        let err = extract_panic_message(&*err);
70                        error!(?err, "task panicked");
71                        Err(Error::Exited)
72                    }
73                };
74                let _ = sender.send(result);
75            }
76        };
77
78        // Make the future abortable
79        let abortable = Abortable::new(wrapped, abort_registration);
80        (
81            abortable.map(|_| ()),
82            Self {
83                aborter: Some(aborter),
84                receiver,
85
86                running,
87                once,
88            },
89        )
90    }
91
92    pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
93    where
94        F: FnOnce() -> T + Send + 'static,
95    {
96        // Increment the running tasks gauge
97        running.inc();
98
99        // Initialize channel to handle result
100        let once = Arc::new(Once::new());
101        let (sender, receiver) = oneshot::channel();
102
103        // Wrap the closure with panic handling
104        let f = {
105            let once = once.clone();
106            let running = running.clone();
107            move || {
108                // Run blocking task
109                let result = catch_unwind(AssertUnwindSafe(f));
110
111                // Decrement running counter
112                once.call_once(|| {
113                    running.dec();
114                });
115
116                // Handle result
117                let result = match result {
118                    Ok(value) => Ok(value),
119                    Err(err) => {
120                        if !catch_panic {
121                            resume_unwind(err);
122                        }
123                        let err = extract_panic_message(&*err);
124                        error!(?err, "task panicked");
125                        Err(Error::Exited)
126                    }
127                };
128                let _ = sender.send(result);
129            }
130        };
131
132        // Return the task and handle
133        (
134            f,
135            Self {
136                aborter: None,
137                receiver,
138
139                running,
140                once,
141            },
142        )
143    }
144
145    /// Abort the task (if not blocking).
146    pub fn abort(&self) {
147        // Get aborter and abort
148        let Some(aborter) = &self.aborter else {
149            return;
150        };
151        aborter.abort();
152
153        // Decrement running counter
154        self.once.call_once(|| {
155            self.running.dec();
156        });
157    }
158}
159
160impl<T> Future for Handle<T>
161where
162    T: Send + 'static,
163{
164    type Output = Result<T, Error>;
165
166    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167        match Pin::new(&mut self.receiver).poll(cx) {
168            Poll::Ready(Ok(Ok(value))) => {
169                self.once.call_once(|| {
170                    self.running.dec();
171                });
172                Poll::Ready(Ok(value))
173            }
174            Poll::Ready(Ok(Err(err))) => {
175                self.once.call_once(|| {
176                    self.running.dec();
177                });
178                Poll::Ready(Err(err))
179            }
180            Poll::Ready(Err(_)) => {
181                self.once.call_once(|| {
182                    self.running.dec();
183                });
184                Poll::Ready(Err(Error::Closed))
185            }
186            Poll::Pending => Poll::Pending,
187        }
188    }
189}