// commonware_runtime/utils/handle.rs
use crate::{supervision::Tree, utils::extract_panic_message, Error};
use commonware_utils::{
    channel::oneshot,
    sync::{Mutex, Once},
};
use futures::{
    future::{select, Either},
    pin_mut,
    stream::{AbortHandle, Abortable, Aborted},
    FutureExt as _,
};
use prometheus_client::metrics::gauge::Gauge;
use std::{
    any::Any,
    future::Future,
    panic::{resume_unwind, AssertUnwindSafe},
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use tracing::error;

23pub struct Handle<T>
25where
26 T: Send + 'static,
27{
28 abort_handle: Option<AbortHandle>,
29 receiver: oneshot::Receiver<Result<T, Error>>,
30 metric: MetricHandle,
31}
32
33impl<T> Handle<T>
34where
35 T: Send + 'static,
36{
37 #[inline(always)]
38 pub(crate) fn init<F>(
39 f: F,
40 metric: MetricHandle,
41 panicker: Panicker,
42 tree: Arc<Tree>,
43 ) -> (impl Future<Output = ()>, Self)
44 where
45 F: Future<Output = T> + Send + 'static,
46 {
47 let (sender, receiver) = oneshot::channel();
49 let (abort_handle, abort_registration) = AbortHandle::new_pair();
50
51 let metric_handle = metric.clone();
57 let task = async move {
58 let result =
60 Abortable::new(AssertUnwindSafe(f).catch_unwind(), abort_registration).await;
61
62 match result {
64 Ok(Ok(result)) => {
65 let _ = sender.send(Ok(result));
66 }
67 Ok(Err(panic)) => {
68 panicker.notify(panic);
69 let _ = sender.send(Err(Error::Exited));
70 }
71 Err(Aborted) => {}
72 }
73
74 tree.abort();
76
77 metric_handle.finish();
79 };
80
81 (
82 task,
83 Self {
84 abort_handle: Some(abort_handle),
85 receiver,
86 metric,
87 },
88 )
89 }
90
91 pub(crate) fn closed(metric: MetricHandle) -> Self {
93 metric.finish();
95
96 let (sender, receiver) = oneshot::channel();
98 drop(sender);
99
100 Self {
101 abort_handle: None,
102 receiver,
103 metric,
104 }
105 }
106
107 pub fn abort(&self) {
109 let Some(abort_handle) = &self.abort_handle else {
111 return;
112 };
113 abort_handle.abort();
114
115 self.metric.finish();
118 }
119
120 pub(crate) fn aborter(&self) -> Option<Aborter> {
122 self.abort_handle
123 .clone()
124 .map(|inner| Aborter::new(inner, self.metric.clone()))
125 }
126}
127
128impl<T> Future for Handle<T>
129where
130 T: Send + 'static,
131{
132 type Output = Result<T, Error>;
133
134 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 Pin::new(&mut self.receiver)
136 .poll(cx)
137 .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
138 }
139}
140
141#[derive(Clone)]
143pub(crate) struct MetricHandle {
144 gauge: Gauge,
145 finished: Arc<Once>,
146}
147
148impl MetricHandle {
149 pub(crate) fn new(gauge: Gauge) -> Self {
152 gauge.inc();
153
154 Self {
155 gauge,
156 finished: Arc::new(Once::new()),
157 }
158 }
159
160 pub(crate) fn finish(&self) {
165 let gauge = self.gauge.clone();
166 self.finished.call_once(move || {
167 gauge.dec();
168 });
169 }
170}
171
172pub type Panic = Box<dyn Any + Send + 'static>;
174
175#[derive(Clone)]
177pub(crate) struct Panicker {
178 catch: bool,
179 sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
180}
181
182impl Panicker {
183 pub(crate) fn new(catch: bool) -> (Self, Panicked) {
185 let (sender, receiver) = oneshot::channel();
186 let panicker = Self {
187 catch,
188 sender: Arc::new(Mutex::new(Some(sender))),
189 };
190 let panicked = Panicked { receiver };
191 (panicker, panicked)
192 }
193
194 #[commonware_macros::stability(ALPHA)]
196 pub(crate) const fn catch(&self) -> bool {
197 self.catch
198 }
199
200 pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
202 let err = extract_panic_message(&*panic);
204 error!(?err, "task panicked");
205
206 if self.catch {
208 return;
209 }
210
211 let mut sender = self.sender.lock();
213 let Some(sender) = sender.take() else {
214 return;
215 };
216
217 let _ = sender.send(panic);
219 }
220}
221
222pub(crate) struct Panicked {
224 receiver: oneshot::Receiver<Panic>,
225}
226
227impl Panicked {
228 pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
230 where
231 Fut: Future,
232 {
233 let panicked = self.receiver;
235 pin_mut!(panicked);
236 pin_mut!(task);
237 match select(panicked, task).await {
238 Either::Left((panic, task)) => match panic {
239 Ok(panic) => {
241 resume_unwind(panic);
242 }
243 Err(_) => task.await,
246 },
247 Either::Right((output, _)) => {
248 output
250 }
251 }
252 }
253}
254
255pub(crate) struct Aborter {
257 inner: AbortHandle,
258 metric: MetricHandle,
259}
260
261impl Aborter {
262 pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
264 Self { inner, metric }
265 }
266
267 pub(crate) fn abort(self) {
269 self.inner.abort();
270
271 self.metric.finish();
274 }
275}
276
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    /// Prefix shared by every exposition line of the `runtime_tasks_running` gauge family.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Returns the `runtime_tasks_running` value for the series whose `name` label is exactly
    /// `label`, or `None` if `metrics` has no such series with a `u64` value.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let name_label = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&name_label))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// Encodes the metrics of `context` and asserts that the gauge for `label` reads
    /// `expected`, failing with `"{message}: {metrics}"`.
    fn assert_running(context: &impl Metrics, label: &str, expected: u64, message: &str) {
        let metrics = context.encode();
        assert_eq!(
            running_tasks_for_label(&metrics, label),
            Some(expected),
            "{message}: {metrics}",
        );
    }

    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move { "done" });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before completion",
            );

            assert_eq!(handle.await.expect("task failed"), "done");
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after completion",
            );
        });
    }

    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before dropping handle",
            );

            drop(handle);
            assert_running(
                &context,
                LABEL,
                1,
                "dropping handle should not finish metrics",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before abort",
            );

            handle.abort();
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after abort",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);

            let blocking_handle = context.clone().shared(true).spawn(|_| async move { 42 });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 while blocking task runs",
            );

            assert_eq!(blocking_handle.await.expect("blocking task failed"), 42);
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after blocking task completes",
            );
        });
    }

    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let handle = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_running(
                &context,
                LABEL,
                1,
                "expected tasks_running gauge to be 1 before abort",
            );

            handle.aborter().unwrap().abort();
            assert_running(
                &context,
                LABEL,
                0,
                "expected tasks_running gauge to return to 0 after abort",
            );
        });
    }
}