Skip to main content

commonware_runtime/utils/
handle.rs

1use crate::{supervision::Tree, utils::extract_panic_message, Error};
2use commonware_utils::channel::oneshot;
3use futures::{
4    future::{select, Either},
5    pin_mut,
6    stream::{AbortHandle, Abortable},
7    FutureExt as _,
8};
9use prometheus_client::metrics::gauge::Gauge;
10use std::{
11    any::Any,
12    future::Future,
13    panic::{resume_unwind, AssertUnwindSafe},
14    pin::Pin,
15    sync::{Arc, Mutex, Once},
16    task::{Context, Poll},
17};
18use tracing::error;
19
20/// Handle to a spawned task.
21pub struct Handle<T>
22where
23    T: Send + 'static,
24{
25    abort_handle: Option<AbortHandle>,
26    receiver: oneshot::Receiver<Result<T, Error>>,
27    metric: MetricHandle,
28}
29
30impl<T> Handle<T>
31where
32    T: Send + 'static,
33{
34    pub(crate) fn init<F>(
35        f: F,
36        metric: MetricHandle,
37        panicker: Panicker,
38        tree: Arc<Tree>,
39    ) -> (impl Future<Output = ()>, Self)
40    where
41        F: Future<Output = T> + Send + 'static,
42    {
43        // Initialize channels to handle result/abort
44        let (sender, receiver) = oneshot::channel();
45        let (abort_handle, abort_registration) = AbortHandle::new_pair();
46
47        // Wrap the future to handle panics
48        let wrapped = async move {
49            // Run future
50            let result = AssertUnwindSafe(f).catch_unwind().await;
51
52            // Handle result
53            let result = match result {
54                Ok(result) => Ok(result),
55                Err(panic) => {
56                    panicker.notify(panic);
57                    Err(Error::Exited)
58                }
59            };
60            let _ = sender.send(result);
61        };
62
63        // Make the future abortable
64        let metric_handle = metric.clone();
65        let abortable = Abortable::new(wrapped, abort_registration).map(move |_| {
66            // Mark the task as aborted and abort all descendants.
67            tree.abort();
68
69            // Finish the metric.
70            metric_handle.finish();
71        });
72
73        (
74            abortable,
75            Self {
76                abort_handle: Some(abort_handle),
77                receiver,
78                metric,
79            },
80        )
81    }
82
83    /// Returns a handle that resolves to [`Error::Closed`] without spawning work.
84    pub(crate) fn closed(metric: MetricHandle) -> Self {
85        // Mark the task as finished immediately so gauges remain accurate.
86        metric.finish();
87
88        // Create a receiver that will yield `Err(Error::Closed)` when awaited.
89        let (sender, receiver) = oneshot::channel();
90        drop(sender);
91
92        Self {
93            abort_handle: None,
94            receiver,
95            metric,
96        }
97    }
98
99    /// Abort the task (if not blocking).
100    pub fn abort(&self) {
101        // Get abort handle and abort the task
102        let Some(abort_handle) = &self.abort_handle else {
103            return;
104        };
105        abort_handle.abort();
106
107        // We might never poll the future again after aborting it, so run the
108        // metric cleanup right away
109        self.metric.finish();
110    }
111
112    /// Returns a helper that aborts the task and updates metrics consistently.
113    pub(crate) fn aborter(&self) -> Option<Aborter> {
114        self.abort_handle
115            .clone()
116            .map(|inner| Aborter::new(inner, self.metric.clone()))
117    }
118}
119
120impl<T> Future for Handle<T>
121where
122    T: Send + 'static,
123{
124    type Output = Result<T, Error>;
125
126    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        Pin::new(&mut self.receiver)
128            .poll(cx)
129            .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
130    }
131}
132
133/// Tracks the metric state associated with a spawned task handle.
134#[derive(Clone)]
135pub(crate) struct MetricHandle {
136    gauge: Gauge,
137    finished: Arc<Once>,
138}
139
140impl MetricHandle {
141    /// Increments the supplied gauge and returns a handle responsible for
142    /// eventually decrementing it.
143    pub(crate) fn new(gauge: Gauge) -> Self {
144        gauge.inc();
145
146        Self {
147            gauge,
148            finished: Arc::new(Once::new()),
149        }
150    }
151
152    /// Marks the task handle as completed and decrements the gauge once.
153    ///
154    /// This method is idempotent, additional calls are ignored so completion
155    /// and abort paths can invoke it independently.
156    pub(crate) fn finish(&self) {
157        let gauge = self.gauge.clone();
158        self.finished.call_once(move || {
159            gauge.dec();
160        });
161    }
162}
163
/// Payload of a panic raised by a spawned task, as produced by `catch_unwind`.
pub type Panic = Box<dyn Any + Send + 'static>;
166
167/// Notifies the runtime when a spawned task panics, so it can propagate the failure.
168#[derive(Clone)]
169pub(crate) struct Panicker {
170    catch: bool,
171    sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
172}
173
174impl Panicker {
175    /// Creates a new [Panicker].
176    pub(crate) fn new(catch: bool) -> (Self, Panicked) {
177        let (sender, receiver) = oneshot::channel();
178        let panicker = Self {
179            catch,
180            sender: Arc::new(Mutex::new(Some(sender))),
181        };
182        let panicked = Panicked { receiver };
183        (panicker, panicked)
184    }
185
186    /// Returns whether the [Panicker] is configured to catch panics.
187    #[commonware_macros::stability(ALPHA)]
188    pub(crate) const fn catch(&self) -> bool {
189        self.catch
190    }
191
192    /// Notifies the [Panicker] that a panic has occurred.
193    pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
194        // Log the panic
195        let err = extract_panic_message(&*panic);
196        error!(?err, "task panicked");
197
198        // If we are catching panics, just return
199        if self.catch {
200            return;
201        }
202
203        // If we've already sent a panic, ignore the new one
204        let mut sender = self.sender.lock().unwrap();
205        let Some(sender) = sender.take() else {
206            return;
207        };
208
209        // Send the panic
210        let _ = sender.send(panic);
211    }
212}
213
214/// A handle that will be notified when a panic occurs.
215pub(crate) struct Panicked {
216    receiver: oneshot::Receiver<Panic>,
217}
218
219impl Panicked {
220    /// Polls a task that should be interrupted by a panic.
221    pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
222    where
223        Fut: Future,
224    {
225        // Wait for task to complete or panic
226        let panicked = self.receiver;
227        pin_mut!(panicked);
228        pin_mut!(task);
229        match select(panicked, task).await {
230            Either::Left((panic, task)) => match panic {
231                // If there is a panic, resume the unwind
232                Ok(panic) => {
233                    resume_unwind(panic);
234                }
235                // If there can never be a panic (oneshot is closed), wait for the task to complete
236                // and return the output
237                Err(_) => task.await,
238            },
239            Either::Right((output, _)) => {
240                // Return the output
241                output
242            }
243        }
244    }
245}
246
247/// Couples an [`AbortHandle`] with its metric handle so aborted tasks clean up gauges.
248pub(crate) struct Aborter {
249    inner: AbortHandle,
250    metric: MetricHandle,
251}
252
253impl Aborter {
254    /// Creates a new [`Aborter`] for the provided abort handle and metric handle.
255    pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
256        Self { inner, metric }
257    }
258
259    /// Aborts the task and records completion in the metric gauge.
260    pub(crate) fn abort(self) {
261        self.inner.abort();
262
263        // We might never poll the future again after aborting it, so run the
264        // metric cleanup right away
265        self.metric.finish();
266    }
267}
268
// Tests checking that the `runtime_tasks_running` gauge tracked by `MetricHandle`
// goes up when a task is spawned and comes back down on completion or abort.
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    // Start of every encoded `runtime_tasks_running` sample line.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Returns the value of the first `runtime_tasks_running` line in `metrics`
    /// whose labels contain `name="<label>"`, parsed as `u64`.
    ///
    /// Returns `None` if no such line exists. It also returns `None` when the first
    /// matching line's trailing value does not parse.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        // Include the quotes so one label cannot match as a prefix of another.
        let label_fragment = format!("name=\"{label}\"");
        metrics.lines().find_map(|line| {
            if line.starts_with(METRIC_PREFIX) && line.contains(&label_fragment) {
                // The sample value is the last space-separated token.
                line.rsplit_once(' ')
                    .and_then(|(_, value)| value.trim().parse::<u64>().ok())
            } else {
                None
            }
        })
    }

    /// The gauge is 1 while a spawned task is outstanding and 0 after awaiting it.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move { "done" });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            let output = handle.await.expect("task failed");
            assert_eq!(output, "done");

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle of a never-finishing task must not finish its metric;
    /// the gauge stays at 1.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(handle);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// `Handle::abort` drops the gauge to 0 immediately, without waiting for the
    /// aborted task to be polled again.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            handle.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// A task spawned through `shared(true)` (presumably the blocking/shared
    /// spawn path; confirm against `Spawner`) also returns the gauge to 0 once
    /// awaited.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);

            let blocking_handle = context.clone().shared(true).spawn(|_| async move {
                // Simulate some blocking work
                42
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            let result = blocking_handle.await.expect("blocking task failed");
            assert_eq!(result, 42);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through an `Aborter` obtained from the handle drops the gauge to
    /// 0 immediately, matching `Handle::abort`.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            // A spawned (non-closed) handle always has an abort handle.
            let aborter = handle.aborter().unwrap();
            aborter.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}