commonware_runtime/utils/
handle.rs1use crate::{supervision::Tree, utils::extract_panic_message, Error};
2use commonware_utils::channel::oneshot;
3use futures::{
4 future::{select, Either},
5 pin_mut,
6 stream::{AbortHandle, Abortable},
7 FutureExt as _,
8};
9use prometheus_client::metrics::gauge::Gauge;
10use std::{
11 any::Any,
12 future::Future,
13 panic::{resume_unwind, AssertUnwindSafe},
14 pin::Pin,
15 sync::{Arc, Mutex, Once},
16 task::{Context, Poll},
17};
18use tracing::error;
19
20pub struct Handle<T>
22where
23 T: Send + 'static,
24{
25 abort_handle: Option<AbortHandle>,
26 receiver: oneshot::Receiver<Result<T, Error>>,
27 metric: MetricHandle,
28}
29
30impl<T> Handle<T>
31where
32 T: Send + 'static,
33{
34 pub(crate) fn init<F>(
35 f: F,
36 metric: MetricHandle,
37 panicker: Panicker,
38 tree: Arc<Tree>,
39 ) -> (impl Future<Output = ()>, Self)
40 where
41 F: Future<Output = T> + Send + 'static,
42 {
43 let (sender, receiver) = oneshot::channel();
45 let (abort_handle, abort_registration) = AbortHandle::new_pair();
46
47 let wrapped = async move {
49 let result = AssertUnwindSafe(f).catch_unwind().await;
51
52 let result = match result {
54 Ok(result) => Ok(result),
55 Err(panic) => {
56 panicker.notify(panic);
57 Err(Error::Exited)
58 }
59 };
60 let _ = sender.send(result);
61 };
62
63 let metric_handle = metric.clone();
65 let abortable = Abortable::new(wrapped, abort_registration).map(move |_| {
66 tree.abort();
68
69 metric_handle.finish();
71 });
72
73 (
74 abortable,
75 Self {
76 abort_handle: Some(abort_handle),
77 receiver,
78 metric,
79 },
80 )
81 }
82
83 pub(crate) fn closed(metric: MetricHandle) -> Self {
85 metric.finish();
87
88 let (sender, receiver) = oneshot::channel();
90 drop(sender);
91
92 Self {
93 abort_handle: None,
94 receiver,
95 metric,
96 }
97 }
98
99 pub fn abort(&self) {
101 let Some(abort_handle) = &self.abort_handle else {
103 return;
104 };
105 abort_handle.abort();
106
107 self.metric.finish();
110 }
111
112 pub(crate) fn aborter(&self) -> Option<Aborter> {
114 self.abort_handle
115 .clone()
116 .map(|inner| Aborter::new(inner, self.metric.clone()))
117 }
118}
119
120impl<T> Future for Handle<T>
121where
122 T: Send + 'static,
123{
124 type Output = Result<T, Error>;
125
126 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 Pin::new(&mut self.receiver)
128 .poll(cx)
129 .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
130 }
131}
132
133#[derive(Clone)]
135pub(crate) struct MetricHandle {
136 gauge: Gauge,
137 finished: Arc<Once>,
138}
139
140impl MetricHandle {
141 pub(crate) fn new(gauge: Gauge) -> Self {
144 gauge.inc();
145
146 Self {
147 gauge,
148 finished: Arc::new(Once::new()),
149 }
150 }
151
152 pub(crate) fn finish(&self) {
157 let gauge = self.gauge.clone();
158 self.finished.call_once(move || {
159 gauge.dec();
160 });
161 }
162}
163
/// Payload of a caught panic, as accepted by [`std::panic::resume_unwind`].
pub type Panic = Box<dyn Any + Send>;
166
167#[derive(Clone)]
169pub(crate) struct Panicker {
170 catch: bool,
171 sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
172}
173
174impl Panicker {
175 pub(crate) fn new(catch: bool) -> (Self, Panicked) {
177 let (sender, receiver) = oneshot::channel();
178 let panicker = Self {
179 catch,
180 sender: Arc::new(Mutex::new(Some(sender))),
181 };
182 let panicked = Panicked { receiver };
183 (panicker, panicked)
184 }
185
186 #[commonware_macros::stability(ALPHA)]
188 pub(crate) const fn catch(&self) -> bool {
189 self.catch
190 }
191
192 pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
194 let err = extract_panic_message(&*panic);
196 error!(?err, "task panicked");
197
198 if self.catch {
200 return;
201 }
202
203 let mut sender = self.sender.lock().unwrap();
205 let Some(sender) = sender.take() else {
206 return;
207 };
208
209 let _ = sender.send(panic);
211 }
212}
213
214pub(crate) struct Panicked {
216 receiver: oneshot::Receiver<Panic>,
217}
218
219impl Panicked {
220 pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
222 where
223 Fut: Future,
224 {
225 let panicked = self.receiver;
227 pin_mut!(panicked);
228 pin_mut!(task);
229 match select(panicked, task).await {
230 Either::Left((panic, task)) => match panic {
231 Ok(panic) => {
233 resume_unwind(panic);
234 }
235 Err(_) => task.await,
238 },
239 Either::Right((output, _)) => {
240 output
242 }
243 }
244 }
245}
246
247pub(crate) struct Aborter {
249 inner: AbortHandle,
250 metric: MetricHandle,
251}
252
253impl Aborter {
254 pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
256 Self { inner, metric }
257 }
258
259 pub(crate) fn abort(self) {
261 self.inner.abort();
262
263 self.metric.finish();
266 }
267}
268
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    /// Prefix a metrics line must start with to be considered by
    /// `running_tasks_for_label`.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Returns the value of the first running-tasks line in `metrics` that
    /// mentions `name="<label>"` and whose trailing field parses as a `u64`.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let needle = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&needle))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// The gauge is 1 while the task is pending and 0 once it has been awaited.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move { "done" });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            assert_eq!(task.await.expect("task failed"), "done");

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle of a still-pending task leaves the gauge at 1.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(task);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// Aborting through the handle drops the gauge to 0 without awaiting.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// A task spawned via `shared(true)` is counted until it has been awaited.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);

            let blocking_task = context.clone().shared(true).spawn(|_| async move { 42 });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            assert_eq!(blocking_task.await.expect("blocking task failed"), 42);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through an `Aborter` drops the gauge to 0 without awaiting.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.aborter().unwrap().abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}