Skip to main content

commonware_runtime/utils/
handle.rs

1use crate::{
2    telemetry::metrics::raw::Gauge,
3    utils::{extract_panic_message, supervision::Tree},
4    Error,
5};
6use commonware_utils::{
7    channel::oneshot,
8    sync::{Mutex, Once},
9};
10use futures::{
11    future::{select, Either},
12    pin_mut,
13    stream::{AbortHandle, Abortable, Aborted},
14    FutureExt as _,
15};
16use std::{
17    any::Any,
18    future::Future,
19    panic::{resume_unwind, AssertUnwindSafe},
20    pin::Pin,
21    sync::Arc,
22    task::{Context, Poll},
23};
24use tracing::error;
25
26/// Handle to a spawned task.
27pub struct Handle<T>
28where
29    T: Send + 'static,
30{
31    abort_handle: Option<AbortHandle>,
32    receiver: oneshot::Receiver<Result<T, Error>>,
33    metric: MetricHandle,
34}
35
36impl<T> Handle<T>
37where
38    T: Send + 'static,
39{
40    #[inline(always)]
41    pub(crate) fn init<F>(
42        f: F,
43        metric: MetricHandle,
44        panicker: Panicker,
45        tree: Arc<Tree>,
46    ) -> (impl Future<Output = ()>, Self)
47    where
48        F: Future<Output = T> + Send + 'static,
49    {
50        // Initialize channels to handle result/abort
51        let (sender, receiver) = oneshot::channel();
52        let (abort_handle, abort_registration) = AbortHandle::new_pair();
53
54        // Wrap the future with panic catching, abort support, and cleanup.
55        //
56        // Everything is done in a single async block (and the function is marked
57        // #[inline(always)]) so that stack usage is `size_of(F) + constant` rather than
58        // `N * size_of(F)` (which is what a combinator chain produces in debug builds).
59        let metric_handle = metric.clone();
60        let task = async move {
61            // Run future with panic catching and abort support
62            let result =
63                Abortable::new(AssertUnwindSafe(f).catch_unwind(), abort_registration).await;
64
65            // Handle result
66            match result {
67                Ok(Ok(result)) => {
68                    let _ = sender.send(Ok(result));
69                }
70                Ok(Err(panic)) => {
71                    panicker.notify(panic);
72                    let _ = sender.send(Err(Error::Exited));
73                }
74                Err(Aborted) => {}
75            }
76
77            // Mark the task as aborted and abort all descendants.
78            tree.abort();
79
80            // Finish the metric.
81            metric_handle.finish();
82        };
83
84        (
85            task,
86            Self {
87                abort_handle: Some(abort_handle),
88                receiver,
89                metric,
90            },
91        )
92    }
93
94    /// Returns a handle that resolves to [`Error::Closed`] without spawning work.
95    pub(crate) fn closed(metric: MetricHandle) -> Self {
96        // Mark the task as finished immediately so gauges remain accurate.
97        metric.finish();
98
99        // Create a receiver that will yield `Err(Error::Closed)` when awaited.
100        let (sender, receiver) = oneshot::channel();
101        drop(sender);
102
103        Self {
104            abort_handle: None,
105            receiver,
106            metric,
107        }
108    }
109
110    /// Abort the task (if not blocking).
111    pub fn abort(&self) {
112        // Get abort handle and abort the task
113        let Some(abort_handle) = &self.abort_handle else {
114            return;
115        };
116        abort_handle.abort();
117
118        // We might never poll the future again after aborting it, so run the
119        // metric cleanup right away
120        self.metric.finish();
121    }
122
123    /// Returns a helper that aborts the task and updates metrics consistently.
124    pub(crate) fn aborter(&self) -> Option<Aborter> {
125        self.abort_handle
126            .clone()
127            .map(|inner| Aborter::new(inner, self.metric.clone()))
128    }
129}
130
131impl<T> Future for Handle<T>
132where
133    T: Send + 'static,
134{
135    type Output = Result<T, Error>;
136
137    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138        Pin::new(&mut self.receiver)
139            .poll(cx)
140            .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
141    }
142}
143
144/// Tracks the metric state associated with a spawned task handle.
145#[derive(Clone)]
146pub(crate) struct MetricHandle {
147    gauge: Gauge,
148    finished: Arc<Once>,
149}
150
151impl MetricHandle {
152    /// Increments the supplied gauge and returns a handle responsible for
153    /// eventually decrementing it.
154    pub(crate) fn new(gauge: Gauge) -> Self {
155        gauge.inc();
156
157        Self {
158            gauge,
159            finished: Arc::new(Once::new()),
160        }
161    }
162
163    /// Marks the task handle as completed and decrements the gauge once.
164    ///
165    /// This method is idempotent, additional calls are ignored so completion
166    /// and abort paths can invoke it independently.
167    pub(crate) fn finish(&self) {
168        let gauge = self.gauge.clone();
169        self.finished.call_once(move || {
170            gauge.dec();
171        });
172    }
173}
174
/// The payload of a panic raised inside a spawned task.
pub type Panic = Box<dyn Any + Send>;
177
178/// Notifies the runtime when a spawned task panics, so it can propagate the failure.
179#[derive(Clone)]
180pub(crate) struct Panicker {
181    catch: bool,
182    sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
183}
184
185impl Panicker {
186    /// Creates a new [Panicker].
187    pub(crate) fn new(catch: bool) -> (Self, Panicked) {
188        let (sender, receiver) = oneshot::channel();
189        let panicker = Self {
190            catch,
191            sender: Arc::new(Mutex::new(Some(sender))),
192        };
193        let panicked = Panicked { receiver };
194        (panicker, panicked)
195    }
196
197    /// Returns whether the [Panicker] is configured to catch panics.
198    #[commonware_macros::stability(ALPHA)]
199    pub(crate) const fn catch(&self) -> bool {
200        self.catch
201    }
202
203    /// Notifies the [Panicker] that a panic has occurred.
204    pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
205        // Log the panic
206        let err = extract_panic_message(&*panic);
207        error!(?err, "task panicked");
208
209        // If we are catching panics, just return
210        if self.catch {
211            return;
212        }
213
214        // If we've already sent a panic, ignore the new one
215        let mut sender = self.sender.lock();
216        let Some(sender) = sender.take() else {
217            return;
218        };
219
220        // Send the panic
221        let _ = sender.send(panic);
222    }
223}
224
225/// A handle that will be notified when a panic occurs.
226pub(crate) struct Panicked {
227    receiver: oneshot::Receiver<Panic>,
228}
229
230impl Panicked {
231    /// Polls a task that should be interrupted by a panic.
232    pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
233    where
234        Fut: Future,
235    {
236        // Wait for task to complete or panic
237        let panicked = self.receiver;
238        pin_mut!(panicked);
239        pin_mut!(task);
240        match select(panicked, task).await {
241            Either::Left((panic, task)) => match panic {
242                // If there is a panic, resume the unwind
243                Ok(panic) => {
244                    resume_unwind(panic);
245                }
246                // If there can never be a panic (oneshot is closed), wait for the task to complete
247                // and return the output
248                Err(_) => task.await,
249            },
250            Either::Right((output, _)) => {
251                // Return the output
252                output
253            }
254        }
255    }
256}
257
258/// Couples an [`AbortHandle`] with its metric handle so aborted tasks clean up gauges.
259pub(crate) struct Aborter {
260    inner: AbortHandle,
261    metric: MetricHandle,
262}
263
264impl Aborter {
265    /// Creates a new [`Aborter`] for the provided abort handle and metric handle.
266    pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
267        Self { inner, metric }
268    }
269
270    /// Aborts the task and records completion in the metric gauge.
271    pub(crate) fn abort(self) {
272        self.inner.abort();
273
274        // We might never poll the future again after aborting it, so run the
275        // metric cleanup right away
276        self.metric.finish();
277    }
278}
279
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics as _, Runner, Spawner, Supervisor as _};
    use futures::future;

    /// Prefix of the encoded metric lines that carry the running-task gauge.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Scans encoded `metrics` for the running-task gauge line tagged with
    /// `label` and returns the first value that parses as a `u64`.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let needle = format!("name=\"{label}\"");
        for line in metrics.lines() {
            if !line.starts_with(METRIC_PREFIX) || !line.contains(&needle) {
                continue;
            }

            // The gauge value is whatever follows the last space on the line.
            let parsed = line
                .rsplit_once(' ')
                .and_then(|(_, value)| value.trim().parse::<u64>().ok());
            if parsed.is_some() {
                return parsed;
            }
        }
        None
    }

    /// The gauge reads 1 once the task is spawned and 0 after it completes.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let handle = context.child(LABEL).spawn(|_| async move { "done" });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            let output = handle.await.expect("task failed");
            assert_eq!(output, "done");

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle of a pending task leaves the gauge at 1.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let handle = context.child(LABEL).spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(handle);

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// Aborting through the handle drops the gauge to 0 right away.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let handle = context.child(LABEL).spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            handle.abort();

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// A task spawned with `shared(true)` shows 1 while running and 0 once awaited.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let blocking_handle = context.child(LABEL).shared(true).spawn(|_| async move {
                // Simulate some blocking work
                42
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            let result = blocking_handle.await.expect("blocking task failed");
            assert_eq!(result, 42);

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through an `Aborter` drops the gauge to 0 right away.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let handle = context.child(LABEL).spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            let aborter = handle.aborter().unwrap();
            aborter.abort();

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}
440}