commonware_runtime/utils/
handle.rs

1use crate::{supervision::Tree, utils::extract_panic_message, Error};
2use futures::{
3    channel::oneshot,
4    future::{select, Either},
5    pin_mut,
6    stream::{AbortHandle, Abortable},
7    FutureExt as _,
8};
9use prometheus_client::metrics::gauge::Gauge;
10use std::{
11    any::Any,
12    future::Future,
13    panic::{resume_unwind, AssertUnwindSafe},
14    pin::Pin,
15    sync::{Arc, Mutex, Once},
16    task::{Context, Poll},
17};
18use tracing::error;
19
/// Handle to a spawned task.
///
/// Awaiting the handle yields the task's `Result<T, Error>`.
pub struct Handle<T>
where
    T: Send + 'static,
{
    /// Cancels the task when present; `None` for handles built by `Handle::closed`.
    abort_handle: Option<AbortHandle>,
    /// Receives the task's result once the wrapped future has sent it.
    receiver: oneshot::Receiver<Result<T, Error>>,
    /// Decrements the gauge it wraps at most once, on completion or abort.
    metric: MetricHandle,
}
29
30impl<T> Handle<T>
31where
32    T: Send + 'static,
33{
34    pub(crate) fn init<F>(
35        f: F,
36        metric: MetricHandle,
37        panicker: Panicker,
38        tree: Arc<Tree>,
39    ) -> (impl Future<Output = ()>, Self)
40    where
41        F: Future<Output = T> + Send + 'static,
42    {
43        // Initialize channels to handle result/abort
44        let (sender, receiver) = oneshot::channel();
45        let (abort_handle, abort_registration) = AbortHandle::new_pair();
46
47        // Wrap the future to handle panics
48        let wrapped = async move {
49            // Run future
50            let result = AssertUnwindSafe(f).catch_unwind().await;
51
52            // Handle result
53            let result = match result {
54                Ok(result) => Ok(result),
55                Err(panic) => {
56                    panicker.notify(panic);
57                    Err(Error::Exited)
58                }
59            };
60            let _ = sender.send(result);
61        };
62
63        // Make the future abortable
64        let metric_handle = metric.clone();
65        let abortable = Abortable::new(wrapped, abort_registration).map(move |_| {
66            // Mark the task as aborted and abort all descendants.
67            tree.abort();
68
69            // Finish the metric.
70            metric_handle.finish();
71        });
72
73        (
74            abortable,
75            Self {
76                abort_handle: Some(abort_handle),
77                receiver,
78                metric,
79            },
80        )
81    }
82
83    /// Returns a handle that resolves to [`Error::Closed`] without spawning work.
84    pub(crate) fn closed(metric: MetricHandle) -> Self {
85        // Mark the task as finished immediately so gauges remain accurate.
86        metric.finish();
87
88        // Create a receiver that will yield `Err(Error::Closed)` when awaited.
89        let (sender, receiver) = oneshot::channel();
90        drop(sender);
91
92        Self {
93            abort_handle: None,
94            receiver,
95            metric,
96        }
97    }
98
99    /// Abort the task (if not blocking).
100    pub fn abort(&self) {
101        // Get abort handle and abort the task
102        let Some(abort_handle) = &self.abort_handle else {
103            return;
104        };
105        abort_handle.abort();
106
107        // We might never poll the future again after aborting it, so run the
108        // metric cleanup right away
109        self.metric.finish();
110    }
111
112    /// Returns a helper that aborts the task and updates metrics consistently.
113    pub(crate) fn aborter(&self) -> Option<Aborter> {
114        self.abort_handle
115            .clone()
116            .map(|inner| Aborter::new(inner, self.metric.clone()))
117    }
118}
119
120impl<T> Future for Handle<T>
121where
122    T: Send + 'static,
123{
124    type Output = Result<T, Error>;
125
126    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        Pin::new(&mut self.receiver)
128            .poll(cx)
129            .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
130    }
131}
132
/// Tracks the metric state associated with a spawned task handle.
///
/// Clones share the same `finished` guard, so the gauge is decremented at
/// most once no matter how many clones call `finish`.
#[derive(Clone)]
pub(crate) struct MetricHandle {
    /// Gauge incremented in `new` and decremented by the first `finish` call.
    gauge: Gauge,
    /// One-time guard shared by all clones of this handle.
    finished: Arc<Once>,
}
139
140impl MetricHandle {
141    /// Increments the supplied gauge and returns a handle responsible for
142    /// eventually decrementing it.
143    pub(crate) fn new(gauge: Gauge) -> Self {
144        gauge.inc();
145
146        Self {
147            gauge,
148            finished: Arc::new(Once::new()),
149        }
150    }
151
152    /// Marks the task handle as completed and decrements the gauge once.
153    ///
154    /// This method is idempotent, additional calls are ignored so completion
155    /// and abort paths can invoke it independently.
156    pub(crate) fn finish(&self) {
157        let gauge = self.gauge.clone();
158        self.finished.call_once(move || {
159            gauge.dec();
160        });
161    }
162}
163
/// A panic emitted by a spawned task.
///
/// This is the boxed payload that `Panicker::notify` receives and that
/// `Panicked::interrupt` hands to `resume_unwind`.
pub type Panic = Box<dyn Any + Send + 'static>;
166
/// Notifies the runtime when a spawned task panics, so it can propagate the failure.
#[derive(Clone)]
pub(crate) struct Panicker {
    /// When `true`, panics are only logged and never forwarded.
    catch: bool,
    /// Shared one-shot sender; taken by the first forwarded panic, so later
    /// panics find `None` and are dropped.
    sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
}
173
174impl Panicker {
175    /// Creates a new [Panicker].
176    pub(crate) fn new(catch: bool) -> (Self, Panicked) {
177        let (sender, receiver) = oneshot::channel();
178        let panicker = Self {
179            catch,
180            sender: Arc::new(Mutex::new(Some(sender))),
181        };
182        let panicked = Panicked { receiver };
183        (panicker, panicked)
184    }
185
186    /// Returns whether the [Panicker] is configured to catch panics.
187    pub(crate) fn catch(&self) -> bool {
188        self.catch
189    }
190
191    /// Notifies the [Panicker] that a panic has occurred.
192    pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
193        // Log the panic
194        let err = extract_panic_message(&*panic);
195        error!(?err, "task panicked");
196
197        // If we are catching panics, just return
198        if self.catch {
199            return;
200        }
201
202        // If we've already sent a panic, ignore the new one
203        let mut sender = self.sender.lock().unwrap();
204        let Some(sender) = sender.take() else {
205            return;
206        };
207
208        // Send the panic
209        let _ = sender.send(panic);
210    }
211}
212
/// A handle that will be notified when a panic occurs.
pub(crate) struct Panicked {
    /// Receiving end of the one-shot channel created in `Panicker::new`.
    receiver: oneshot::Receiver<Panic>,
}
217
218impl Panicked {
219    /// Polls a task that should be interrupted by a panic.
220    pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
221    where
222        Fut: Future,
223    {
224        // Wait for task to complete or panic
225        let panicked = self.receiver;
226        pin_mut!(panicked);
227        pin_mut!(task);
228        match select(panicked, task).await {
229            Either::Left((panic, task)) => match panic {
230                // If there is a panic, resume the unwind
231                Ok(panic) => {
232                    resume_unwind(panic);
233                }
234                // If there can never be a panic (oneshot is closed), wait for the task to complete
235                // and return the output
236                Err(_) => task.await,
237            },
238            Either::Right((output, _)) => {
239                // Return the output
240                output
241            }
242        }
243    }
244}
245
/// Couples an [`AbortHandle`] with its metric handle so aborted tasks clean up gauges.
pub(crate) struct Aborter {
    /// Handle used to abort the task.
    inner: AbortHandle,
    /// Metric handle finished right after the abort is requested.
    metric: MetricHandle,
}
251
252impl Aborter {
253    /// Creates a new [`Aborter`] for the provided abort handle and metric handle.
254    pub(crate) fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
255        Self { inner, metric }
256    }
257
258    /// Aborts the task and records completion in the metric gauge.
259    pub(crate) fn abort(self) {
260        self.inner.abort();
261
262        // We might never poll the future again after aborting it, so run the
263        // metric cleanup right away
264        self.metric.finish();
265    }
266}
267
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    /// Line prefix used to pick out the running-tasks gauge in encoded metrics.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Returns the value of the first running-tasks line in `metrics` that
    /// carries `label` and whose trailing value parses as a `u64`.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let needle = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&needle))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// The gauge drops back to zero once a task has run to completion.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move { "done" });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            let output = task.await.expect("task failed");
            assert_eq!(output, "done");

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle alone must not change the gauge.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(task);

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// Aborting through the handle releases the gauge without further polling.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.abort();

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// A task spawned via `shared(true)` releases the gauge once it completes.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);

            let blocking_task = context.clone().shared(true).spawn(|_| async move {
                // Simulate some blocking work
                42
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            let result = blocking_task.await.expect("blocking task failed");
            assert_eq!(result, 42);

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through an `Aborter` releases the gauge without further polling.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.aborter().unwrap().abort();

            let metrics = context.encode();
            let running = running_tasks_for_label(&metrics, LABEL);
            assert_eq!(
                running,
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}