// commonware_runtime/utils/handle.rs

1use crate::{supervision::Tree, utils::extract_panic_message, Error};
2use commonware_utils::{
3    channel::oneshot,
4    sync::{Mutex, Once},
5};
6use futures::{
7    future::{select, Either},
8    pin_mut,
9    stream::{AbortHandle, Abortable},
10    FutureExt as _,
11};
12use prometheus_client::metrics::gauge::Gauge;
13use std::{
14    any::Any,
15    future::Future,
16    panic::{resume_unwind, AssertUnwindSafe},
17    pin::Pin,
18    sync::Arc,
19    task::{Context, Poll},
20};
21use tracing::error;
22
/// Handle to a spawned task.
///
/// Awaiting the handle resolves to whatever is delivered over `receiver`, or
/// to [`Error::Closed`] when the receiver reports an error instead of a value.
pub struct Handle<T>
where
    T: Send + 'static,
{
    /// Abort control for the wrapped future; `None` for handles built by
    /// `closed`, which have no task to abort.
    abort_handle: Option<AbortHandle>,
    /// Receives the task's output, or `Error::Exited` if the task panicked.
    receiver: oneshot::Receiver<Result<T, Error>>,
    /// Gauge bookkeeping, finished when the task resolves or is aborted.
    metric: MetricHandle,
}
32
33impl<T> Handle<T>
34where
35    T: Send + 'static,
36{
37    pub(crate) fn init<F>(
38        f: F,
39        metric: MetricHandle,
40        panicker: Panicker,
41        tree: Arc<Tree>,
42    ) -> (impl Future<Output = ()>, Self)
43    where
44        F: Future<Output = T> + Send + 'static,
45    {
46        // Initialize channels to handle result/abort
47        let (sender, receiver) = oneshot::channel();
48        let (abort_handle, abort_registration) = AbortHandle::new_pair();
49
50        // Wrap the future to handle panics
51        let wrapped = async move {
52            // Run future
53            let result = AssertUnwindSafe(f).catch_unwind().await;
54
55            // Handle result
56            let result = match result {
57                Ok(result) => Ok(result),
58                Err(panic) => {
59                    panicker.notify(panic);
60                    Err(Error::Exited)
61                }
62            };
63            let _ = sender.send(result);
64        };
65
66        // Make the future abortable
67        let metric_handle = metric.clone();
68        let abortable = Abortable::new(wrapped, abort_registration).map(move |_| {
69            // Mark the task as aborted and abort all descendants.
70            tree.abort();
71
72            // Finish the metric.
73            metric_handle.finish();
74        });
75
76        (
77            abortable,
78            Self {
79                abort_handle: Some(abort_handle),
80                receiver,
81                metric,
82            },
83        )
84    }
85
86    /// Returns a handle that resolves to [`Error::Closed`] without spawning work.
87    pub(crate) fn closed(metric: MetricHandle) -> Self {
88        // Mark the task as finished immediately so gauges remain accurate.
89        metric.finish();
90
91        // Create a receiver that will yield `Err(Error::Closed)` when awaited.
92        let (sender, receiver) = oneshot::channel();
93        drop(sender);
94
95        Self {
96            abort_handle: None,
97            receiver,
98            metric,
99        }
100    }
101
102    /// Abort the task (if not blocking).
103    pub fn abort(&self) {
104        // Get abort handle and abort the task
105        let Some(abort_handle) = &self.abort_handle else {
106            return;
107        };
108        abort_handle.abort();
109
110        // We might never poll the future again after aborting it, so run the
111        // metric cleanup right away
112        self.metric.finish();
113    }
114
115    /// Returns a helper that aborts the task and updates metrics consistently.
116    pub(crate) fn aborter(&self) -> Option<Aborter> {
117        self.abort_handle
118            .clone()
119            .map(|inner| Aborter::new(inner, self.metric.clone()))
120    }
121}
122
123impl<T> Future for Handle<T>
124where
125    T: Send + 'static,
126{
127    type Output = Result<T, Error>;
128
129    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130        Pin::new(&mut self.receiver)
131            .poll(cx)
132            .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
133    }
134}
135
/// Tracks the metric state associated with a spawned task handle.
///
/// Clones share the same `finished` guard through the `Arc`, so the gauge
/// decrement in `finish` is coordinated across every clone.
#[derive(Clone)]
pub(crate) struct MetricHandle {
    /// Gauge incremented in `new` and decremented in `finish`.
    gauge: Gauge,
    /// Guard shared by all clones that gates the decrement in `finish`.
    finished: Arc<Once>,
}
142
143impl MetricHandle {
144    /// Increments the supplied gauge and returns a handle responsible for
145    /// eventually decrementing it.
146    pub(crate) fn new(gauge: Gauge) -> Self {
147        gauge.inc();
148
149        Self {
150            gauge,
151            finished: Arc::new(Once::new()),
152        }
153    }
154
155    /// Marks the task handle as completed and decrements the gauge once.
156    ///
157    /// This method is idempotent, additional calls are ignored so completion
158    /// and abort paths can invoke it independently.
159    pub(crate) fn finish(&self) {
160        let gauge = self.gauge.clone();
161        self.finished.call_once(move || {
162            gauge.dec();
163        });
164    }
165}
166
/// A panic emitted by a spawned task.
///
/// This is the boxed payload carried over the panic channel and later handed
/// to `resume_unwind` by `Panicked::interrupt`.
pub type Panic = Box<dyn Any + Send + 'static>;
169
/// Notifies the runtime when a spawned task panics, so it can propagate the failure.
#[derive(Clone)]
pub(crate) struct Panicker {
    /// When `true`, reported panics are only logged and never forwarded.
    catch: bool,
    /// Sending half of the panic channel, shared by all clones; it is taken
    /// (leaving `None`) by the first panic that gets forwarded.
    sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
}
176
177impl Panicker {
178    /// Creates a new [Panicker].
179    pub(crate) fn new(catch: bool) -> (Self, Panicked) {
180        let (sender, receiver) = oneshot::channel();
181        let panicker = Self {
182            catch,
183            sender: Arc::new(Mutex::new(Some(sender))),
184        };
185        let panicked = Panicked { receiver };
186        (panicker, panicked)
187    }
188
189    /// Returns whether the [Panicker] is configured to catch panics.
190    #[commonware_macros::stability(ALPHA)]
191    pub(crate) const fn catch(&self) -> bool {
192        self.catch
193    }
194
195    /// Notifies the [Panicker] that a panic has occurred.
196    pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
197        // Log the panic
198        let err = extract_panic_message(&*panic);
199        error!(?err, "task panicked");
200
201        // If we are catching panics, just return
202        if self.catch {
203            return;
204        }
205
206        // If we've already sent a panic, ignore the new one
207        let mut sender = self.sender.lock();
208        let Some(sender) = sender.take() else {
209            return;
210        };
211
212        // Send the panic
213        let _ = sender.send(panic);
214    }
215}
216
/// A handle that will be notified when a panic occurs.
pub(crate) struct Panicked {
    /// Receiving half of the panic channel created in [`Panicker::new`].
    receiver: oneshot::Receiver<Panic>,
}
221
222impl Panicked {
223    /// Polls a task that should be interrupted by a panic.
224    pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
225    where
226        Fut: Future,
227    {
228        // Wait for task to complete or panic
229        let panicked = self.receiver;
230        pin_mut!(panicked);
231        pin_mut!(task);
232        match select(panicked, task).await {
233            Either::Left((panic, task)) => match panic {
234                // If there is a panic, resume the unwind
235                Ok(panic) => {
236                    resume_unwind(panic);
237                }
238                // If there can never be a panic (oneshot is closed), wait for the task to complete
239                // and return the output
240                Err(_) => task.await,
241            },
242            Either::Right((output, _)) => {
243                // Return the output
244                output
245            }
246        }
247    }
248}
249
/// Couples an [`AbortHandle`] with its metric handle so aborted tasks clean up gauges.
pub(crate) struct Aborter {
    /// Abort control for the task.
    inner: AbortHandle,
    /// Metric bookkeeping finished as part of [`Aborter::abort`].
    metric: MetricHandle,
}
255
256impl Aborter {
257    /// Creates a new [`Aborter`] for the provided abort handle and metric handle.
258    pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
259        Self { inner, metric }
260    }
261
262    /// Aborts the task and records completion in the metric gauge.
263    pub(crate) fn abort(self) {
264        self.inner.abort();
265
266        // We might never poll the future again after aborting it, so run the
267        // metric cleanup right away
268        self.metric.finish();
269    }
270}
271
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    /// Line prefix of the running-tasks gauge in the encoded metrics output.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Extracts the running-tasks gauge value for `label` from encoded `metrics`.
    ///
    /// Returns `None` when no gauge line for the label carries a parseable value.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let wanted = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&wanted))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// The gauge reads 1 while a task is outstanding and 0 once it has been awaited.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|root| async move {
            let context = root.with_label(LABEL);
            let task = context.clone().spawn(|_| async move { "done" });

            // Spawned but not yet awaited
            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            let output = task.await.expect("task failed");
            assert_eq!(output, "done");

            // Completed
            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle of a pending task leaves the gauge untouched.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|root| async move {
            let context = root.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(task);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// Aborting through the handle drops the gauge to 0 right away.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|root| async move {
            let context = root.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// A task spawned via `shared(true)` is counted until its handle has been awaited.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|root| async move {
            let context = root.with_label(LABEL);

            let blocking_task = context.clone().shared(true).spawn(|_| async move {
                // Simulate some blocking work
                42
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            let result = blocking_task.await.expect("blocking task failed");
            assert_eq!(result, 42);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through an `Aborter` drops the gauge to 0 right away.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|root| async move {
            let context = root.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            let aborter = task.aborter().unwrap();
            aborter.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}