Skip to main content

commonware_runtime/utils/
handle.rs

1use crate::{supervision::Tree, utils::extract_panic_message, Error};
2use commonware_utils::{
3    channel::oneshot,
4    sync::{Mutex, Once},
5};
6use futures::{
7    future::{select, Either},
8    pin_mut,
9    stream::{AbortHandle, Abortable, Aborted},
10    FutureExt as _,
11};
12use prometheus_client::metrics::gauge::Gauge;
13use std::{
14    any::Any,
15    future::Future,
16    panic::{resume_unwind, AssertUnwindSafe},
17    pin::Pin,
18    sync::Arc,
19    task::{Context, Poll},
20};
21use tracing::error;
22
/// Handle to a spawned task.
///
/// Awaiting the handle yields the task's output, or an [`Error`]:
/// [`Error::Exited`] if the task panicked, and [`Error::Closed`] if the task
/// ended without reporting a result (its result sender was dropped unsent).
pub struct Handle<T>
where
    T: Send + 'static,
{
    // `Some` for handles created by `init`; `None` for handles created by `closed`.
    abort_handle: Option<AbortHandle>,
    // Receives the task's reported result; a dropped sender surfaces as `Error::Closed`.
    receiver: oneshot::Receiver<Result<T, Error>>,
    // Finished (gauge decremented once) on completion or abort.
    metric: MetricHandle,
}

impl<T> Handle<T>
where
    T: Send + 'static,
{
    /// Wraps `f` into a task future and returns that future together with a
    /// [`Handle`] to it.
    ///
    /// The returned future runs `f` with panic catching and abort support:
    /// - on success, the output is sent to the handle;
    /// - on panic, the panic is passed to `panicker` and the handle receives
    ///   [`Error::Exited`];
    /// - on abort, nothing is sent, so the handle resolves to [`Error::Closed`].
    ///
    /// Whichever way `f` ends, once the returned future completes it calls
    /// `tree.abort()` and finishes the metric.
    #[inline(always)]
    pub(crate) fn init<F>(
        f: F,
        metric: MetricHandle,
        panicker: Panicker,
        tree: Arc<Tree>,
    ) -> (impl Future<Output = ()>, Self)
    where
        F: Future<Output = T> + Send + 'static,
    {
        // Initialize channels to handle result/abort
        let (sender, receiver) = oneshot::channel();
        let (abort_handle, abort_registration) = AbortHandle::new_pair();

        // Wrap the future with panic catching, abort support, and cleanup.
        //
        // Everything is done in a single async block (and the function is marked
        // #[inline(always)]) so that stack usage is `size_of(F) + constant` rather than
        // `N * size_of(F)` (which is what a combinator chain produces in debug builds).
        //
        // The task gets its own clone of the metric; the original stays in the
        // returned handle so `abort` can finish it without waiting for the task.
        let metric_handle = metric.clone();
        let task = async move {
            // Run future with panic catching and abort support
            let result =
                Abortable::new(AssertUnwindSafe(f).catch_unwind(), abort_registration).await;

            // Handle result
            match result {
                Ok(Ok(result)) => {
                    // Send failures are deliberately ignored.
                    let _ = sender.send(Ok(result));
                }
                Ok(Err(panic)) => {
                    panicker.notify(panic);
                    let _ = sender.send(Err(Error::Exited));
                }
                // Aborted: `sender` is dropped unsent, so the handle resolves to
                // `Error::Closed`.
                Err(Aborted) => {}
            }

            // Mark the task as aborted and abort all descendants.
            tree.abort();

            // Finish the metric (idempotent, so an earlier `abort` is not double-counted).
            metric_handle.finish();
        };

        (
            task,
            Self {
                abort_handle: Some(abort_handle),
                receiver,
                metric,
            },
        )
    }

    /// Returns a handle that resolves to [`Error::Closed`] without spawning work.
    ///
    /// The metric is finished immediately and the handle has no abort handle,
    /// so `Handle::abort` is a no-op and `Handle::aborter` returns `None`.
    pub(crate) fn closed(metric: MetricHandle) -> Self {
        // Mark the task as finished immediately so gauges remain accurate.
        metric.finish();

        // Create a receiver that will yield `Err(Error::Closed)` when awaited.
        let (sender, receiver) = oneshot::channel();
        drop(sender);

        Self {
            abort_handle: None,
            receiver,
            metric,
        }
    }

    /// Abort the task.
    ///
    /// Does nothing for a handle without an abort handle (such as one returned
    /// by `Handle::closed`). Otherwise it signals the task to abort and finishes
    /// the metric right away. Finishing is idempotent, so the task's own cleanup
    /// does not decrement the gauge a second time.
    pub fn abort(&self) {
        // Get abort handle and abort the task
        let Some(abort_handle) = &self.abort_handle else {
            return;
        };
        abort_handle.abort();

        // We might never poll the future again after aborting it, so run the
        // metric cleanup right away
        self.metric.finish();
    }

    /// Returns a helper that aborts the task and updates metrics consistently.
    ///
    /// Returns `None` when the handle has no abort handle (as with handles
    /// returned by `Handle::closed`).
    pub(crate) fn aborter(&self) -> Option<Aborter> {
        self.abort_handle
            .clone()
            .map(|inner| Aborter::new(inner, self.metric.clone()))
    }
}
127
128impl<T> Future for Handle<T>
129where
130    T: Send + 'static,
131{
132    type Output = Result<T, Error>;
133
134    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135        Pin::new(&mut self.receiver)
136            .poll(cx)
137            .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
138    }
139}
140
141/// Tracks the metric state associated with a spawned task handle.
142#[derive(Clone)]
143pub(crate) struct MetricHandle {
144    gauge: Gauge,
145    finished: Arc<Once>,
146}
147
148impl MetricHandle {
149    /// Increments the supplied gauge and returns a handle responsible for
150    /// eventually decrementing it.
151    pub(crate) fn new(gauge: Gauge) -> Self {
152        gauge.inc();
153
154        Self {
155            gauge,
156            finished: Arc::new(Once::new()),
157        }
158    }
159
160    /// Marks the task handle as completed and decrements the gauge once.
161    ///
162    /// This method is idempotent, additional calls are ignored so completion
163    /// and abort paths can invoke it independently.
164    pub(crate) fn finish(&self) {
165        let gauge = self.gauge.clone();
166        self.finished.call_once(move || {
167            gauge.dec();
168        });
169    }
170}
171
172/// A panic emitted by a spawned task.
173pub type Panic = Box<dyn Any + Send + 'static>;
174
175/// Notifies the runtime when a spawned task panics, so it can propagate the failure.
176#[derive(Clone)]
177pub(crate) struct Panicker {
178    catch: bool,
179    sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
180}
181
182impl Panicker {
183    /// Creates a new [Panicker].
184    pub(crate) fn new(catch: bool) -> (Self, Panicked) {
185        let (sender, receiver) = oneshot::channel();
186        let panicker = Self {
187            catch,
188            sender: Arc::new(Mutex::new(Some(sender))),
189        };
190        let panicked = Panicked { receiver };
191        (panicker, panicked)
192    }
193
194    /// Returns whether the [Panicker] is configured to catch panics.
195    #[commonware_macros::stability(ALPHA)]
196    pub(crate) const fn catch(&self) -> bool {
197        self.catch
198    }
199
200    /// Notifies the [Panicker] that a panic has occurred.
201    pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
202        // Log the panic
203        let err = extract_panic_message(&*panic);
204        error!(?err, "task panicked");
205
206        // If we are catching panics, just return
207        if self.catch {
208            return;
209        }
210
211        // If we've already sent a panic, ignore the new one
212        let mut sender = self.sender.lock();
213        let Some(sender) = sender.take() else {
214            return;
215        };
216
217        // Send the panic
218        let _ = sender.send(panic);
219    }
220}
221
222/// A handle that will be notified when a panic occurs.
223pub(crate) struct Panicked {
224    receiver: oneshot::Receiver<Panic>,
225}
226
227impl Panicked {
228    /// Polls a task that should be interrupted by a panic.
229    pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
230    where
231        Fut: Future,
232    {
233        // Wait for task to complete or panic
234        let panicked = self.receiver;
235        pin_mut!(panicked);
236        pin_mut!(task);
237        match select(panicked, task).await {
238            Either::Left((panic, task)) => match panic {
239                // If there is a panic, resume the unwind
240                Ok(panic) => {
241                    resume_unwind(panic);
242                }
243                // If there can never be a panic (oneshot is closed), wait for the task to complete
244                // and return the output
245                Err(_) => task.await,
246            },
247            Either::Right((output, _)) => {
248                // Return the output
249                output
250            }
251        }
252    }
253}
254
255/// Couples an [`AbortHandle`] with its metric handle so aborted tasks clean up gauges.
256pub(crate) struct Aborter {
257    inner: AbortHandle,
258    metric: MetricHandle,
259}
260
261impl Aborter {
262    /// Creates a new [`Aborter`] for the provided abort handle and metric handle.
263    pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
264        Self { inner, metric }
265    }
266
267    /// Aborts the task and records completion in the metric gauge.
268    pub(crate) fn abort(self) {
269        self.inner.abort();
270
271        // We might never poll the future again after aborting it, so run the
272        // metric cleanup right away
273        self.metric.finish();
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::{Handle, MetricHandle};
    use crate::{deterministic, Error, Metrics, Runner, Spawner};
    use futures::future;
    use prometheus_client::metrics::gauge::Gauge;

    /// Prefix of every encoded sample of the running-tasks gauge.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Returns the running-tasks value for `label` from encoded `metrics`, or
    /// `None` if no parsable sample for that label exists.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let label_fragment = format!("name=\"{label}\"");
        metrics.lines().find_map(|line| {
            if line.starts_with(METRIC_PREFIX) && line.contains(&label_fragment) {
                line.rsplit_once(' ')
                    .and_then(|(_, value)| value.trim().parse::<u64>().ok())
            } else {
                None
            }
        })
    }

    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move { "done" });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            let output = handle.await.expect("task failed");
            assert_eq!(output, "done");

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(handle);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            handle.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);

            let blocking_handle = context.clone().shared(true).spawn(|_| async move {
                // Simulate some blocking work
                42
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            let result = blocking_handle.await.expect("blocking task failed");
            assert_eq!(result, 42);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            let aborter = handle.aborter().unwrap();
            aborter.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// Aborting through both the handle and an aborter (and repeatedly) must
    /// decrement the gauge exactly once. A second decrement would make it
    /// negative, and the value would then fail to parse as `u64`.
    #[test]
    fn tasks_running_not_decremented_twice_on_repeated_abort() {
        const LABEL: &str = "tasks_running_repeated_abort";

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            let aborter = handle.aborter().unwrap();

            handle.abort();
            aborter.abort();
            handle.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "repeated aborts must decrement tasks_running exactly once: {metrics}",
            );
        });
    }

    /// A handle created via `Handle::closed` has no task: its metric is finished
    /// immediately, aborting it is a no-op, and awaiting it yields `Error::Closed`.
    #[test]
    fn closed_handle_resolves_to_closed_without_task() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let gauge: Gauge = Gauge::default();
            let handle = Handle::<()>::closed(MetricHandle::new(gauge.clone()));
            assert_eq!(gauge.get(), 0, "closed handle must finish its metric immediately");

            assert!(handle.aborter().is_none());
            handle.abort();
            assert_eq!(gauge.get(), 0, "abort on a closed handle must not touch the gauge");

            assert!(matches!(handle.await, Err(Error::Closed)));
        });
    }
}