// commonware_runtime/utils/handle.rs
use crate::{supervision::Tree, utils::extract_panic_message, Error};
use futures::{
    channel::oneshot,
    future::{select, Either},
    pin_mut,
    stream::{AbortHandle, Abortable},
    FutureExt as _,
};
use prometheus_client::metrics::gauge::Gauge;
use std::{
    any::Any,
    future::Future,
    panic::{resume_unwind, AssertUnwindSafe},
    pin::Pin,
    sync::{Arc, Mutex, Once, PoisonError},
    task::{Context, Poll},
};
use tracing::error;

20pub struct Handle<T>
22where
23 T: Send + 'static,
24{
25 abort_handle: Option<AbortHandle>,
26 receiver: oneshot::Receiver<Result<T, Error>>,
27 metric: MetricHandle,
28}
29
30impl<T> Handle<T>
31where
32 T: Send + 'static,
33{
34 pub(crate) fn init<F>(
35 f: F,
36 metric: MetricHandle,
37 panicker: Panicker,
38 tree: Arc<Tree>,
39 ) -> (impl Future<Output = ()>, Self)
40 where
41 F: Future<Output = T> + Send + 'static,
42 {
43 let (sender, receiver) = oneshot::channel();
45 let (abort_handle, abort_registration) = AbortHandle::new_pair();
46
47 let wrapped = async move {
49 let result = AssertUnwindSafe(f).catch_unwind().await;
51
52 let result = match result {
54 Ok(result) => Ok(result),
55 Err(panic) => {
56 panicker.notify(panic);
57 Err(Error::Exited)
58 }
59 };
60 let _ = sender.send(result);
61 };
62
63 let metric_handle = metric.clone();
65 let abortable = Abortable::new(wrapped, abort_registration).map(move |_| {
66 tree.abort();
68
69 metric_handle.finish();
71 });
72
73 (
74 abortable,
75 Self {
76 abort_handle: Some(abort_handle),
77 receiver,
78 metric,
79 },
80 )
81 }
82
83 pub(crate) fn closed(metric: MetricHandle) -> Self {
85 metric.finish();
87
88 let (sender, receiver) = oneshot::channel();
90 drop(sender);
91
92 Self {
93 abort_handle: None,
94 receiver,
95 metric,
96 }
97 }
98
99 pub fn abort(&self) {
101 let Some(abort_handle) = &self.abort_handle else {
103 return;
104 };
105 abort_handle.abort();
106
107 self.metric.finish();
110 }
111
112 pub(crate) fn aborter(&self) -> Option<Aborter> {
114 self.abort_handle
115 .clone()
116 .map(|inner| Aborter::new(inner, self.metric.clone()))
117 }
118}
119
120impl<T> Future for Handle<T>
121where
122 T: Send + 'static,
123{
124 type Output = Result<T, Error>;
125
126 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 Pin::new(&mut self.receiver)
128 .poll(cx)
129 .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
130 }
131}
132
133#[derive(Clone)]
135pub(crate) struct MetricHandle {
136 gauge: Gauge,
137 finished: Arc<Once>,
138}
139
140impl MetricHandle {
141 pub(crate) fn new(gauge: Gauge) -> Self {
144 gauge.inc();
145
146 Self {
147 gauge,
148 finished: Arc::new(Once::new()),
149 }
150 }
151
152 pub(crate) fn finish(&self) {
157 let gauge = self.gauge.clone();
158 self.finished.call_once(move || {
159 gauge.dec();
160 });
161 }
162}
163
/// Panic payload captured from a task, as produced by `catch_unwind`.
pub type Panic = Box<dyn Any + Send>;

167#[derive(Clone)]
169pub(crate) struct Panicker {
170 catch: bool,
171 sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
172}
173
174impl Panicker {
175 pub(crate) fn new(catch: bool) -> (Self, Panicked) {
177 let (sender, receiver) = oneshot::channel();
178 let panicker = Self {
179 catch,
180 sender: Arc::new(Mutex::new(Some(sender))),
181 };
182 let panicked = Panicked { receiver };
183 (panicker, panicked)
184 }
185
186 pub(crate) fn catch(&self) -> bool {
188 self.catch
189 }
190
191 pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
193 let err = extract_panic_message(&*panic);
195 error!(?err, "task panicked");
196
197 if self.catch {
199 return;
200 }
201
202 let mut sender = self.sender.lock().unwrap();
204 let Some(sender) = sender.take() else {
205 return;
206 };
207
208 let _ = sender.send(panic);
210 }
211}
212
213pub(crate) struct Panicked {
215 receiver: oneshot::Receiver<Panic>,
216}
217
218impl Panicked {
219 pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
221 where
222 Fut: Future,
223 {
224 let panicked = self.receiver;
226 pin_mut!(panicked);
227 pin_mut!(task);
228 match select(panicked, task).await {
229 Either::Left((panic, task)) => match panic {
230 Ok(panic) => {
232 resume_unwind(panic);
233 }
234 Err(_) => task.await,
237 },
238 Either::Right((output, _)) => {
239 output
241 }
242 }
243 }
244}
245
246pub(crate) struct Aborter {
248 inner: AbortHandle,
249 metric: MetricHandle,
250}
251
252impl Aborter {
253 pub(crate) fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
255 Self { inner, metric }
256 }
257
258 pub(crate) fn abort(self) {
260 self.inner.abort();
261
262 self.metric.finish();
265 }
266}
267
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    /// Prefix of every `runtime_tasks_running` sample line in encoded metrics.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Returns the `runtime_tasks_running` value whose `name` label equals
    /// `label`, or `None` if no matching line has a parsable value.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let needle = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&needle))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// Encodes `context`'s metrics and asserts that the running-task gauge for
    /// `label` equals `expected`. On failure, the message is `reason` followed
    /// by the full encoding.
    #[track_caller]
    fn assert_running(context: &impl Metrics, label: &str, expected: u64, reason: &str) {
        let metrics = context.encode();
        assert_eq!(
            running_tasks_for_label(&metrics, label),
            Some(expected),
            "{reason}: {metrics}",
        );
    }

    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move { "done" });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before completion",
            );

            let output = handle.await.expect("task failed");
            assert_eq!(output, "done");
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after completion",
            );
        });
    }

    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before dropping handle",
            );

            drop(handle);
            assert_running(
                &context,
                LABEL,
                1,
                "dropping handle should not finish metrics",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before abort",
            );

            handle.abort();
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after abort",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);

            let blocking_handle = context.clone().shared(true).spawn(|_| async move { 42 });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 while blocking task runs",
            );

            let result = blocking_handle.await.expect("blocking task failed");
            assert_eq!(result, 42);
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after blocking task completes",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before abort",
            );

            handle.aborter().unwrap().abort();
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after abort",
            );
        });
    }
}