commonware_runtime/utils/
handle.rs1use crate::{supervision::Tree, utils::extract_panic_message, Error};
2use commonware_utils::{
3 channel::oneshot,
4 sync::{Mutex, Once},
5};
6use futures::{
7 future::{select, Either},
8 pin_mut,
9 stream::{AbortHandle, Abortable},
10 FutureExt as _,
11};
12use prometheus_client::metrics::gauge::Gauge;
13use std::{
14 any::Any,
15 future::Future,
16 panic::{resume_unwind, AssertUnwindSafe},
17 pin::Pin,
18 sync::Arc,
19 task::{Context, Poll},
20};
21use tracing::error;
22
23pub struct Handle<T>
25where
26 T: Send + 'static,
27{
28 abort_handle: Option<AbortHandle>,
29 receiver: oneshot::Receiver<Result<T, Error>>,
30 metric: MetricHandle,
31}
32
33impl<T> Handle<T>
34where
35 T: Send + 'static,
36{
37 pub(crate) fn init<F>(
38 f: F,
39 metric: MetricHandle,
40 panicker: Panicker,
41 tree: Arc<Tree>,
42 ) -> (impl Future<Output = ()>, Self)
43 where
44 F: Future<Output = T> + Send + 'static,
45 {
46 let (sender, receiver) = oneshot::channel();
48 let (abort_handle, abort_registration) = AbortHandle::new_pair();
49
50 let wrapped = async move {
52 let result = AssertUnwindSafe(f).catch_unwind().await;
54
55 let result = match result {
57 Ok(result) => Ok(result),
58 Err(panic) => {
59 panicker.notify(panic);
60 Err(Error::Exited)
61 }
62 };
63 let _ = sender.send(result);
64 };
65
66 let metric_handle = metric.clone();
68 let abortable = Abortable::new(wrapped, abort_registration).map(move |_| {
69 tree.abort();
71
72 metric_handle.finish();
74 });
75
76 (
77 abortable,
78 Self {
79 abort_handle: Some(abort_handle),
80 receiver,
81 metric,
82 },
83 )
84 }
85
86 pub(crate) fn closed(metric: MetricHandle) -> Self {
88 metric.finish();
90
91 let (sender, receiver) = oneshot::channel();
93 drop(sender);
94
95 Self {
96 abort_handle: None,
97 receiver,
98 metric,
99 }
100 }
101
102 pub fn abort(&self) {
104 let Some(abort_handle) = &self.abort_handle else {
106 return;
107 };
108 abort_handle.abort();
109
110 self.metric.finish();
113 }
114
115 pub(crate) fn aborter(&self) -> Option<Aborter> {
117 self.abort_handle
118 .clone()
119 .map(|inner| Aborter::new(inner, self.metric.clone()))
120 }
121}
122
123impl<T> Future for Handle<T>
124where
125 T: Send + 'static,
126{
127 type Output = Result<T, Error>;
128
129 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130 Pin::new(&mut self.receiver)
131 .poll(cx)
132 .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
133 }
134}
135
136#[derive(Clone)]
138pub(crate) struct MetricHandle {
139 gauge: Gauge,
140 finished: Arc<Once>,
141}
142
143impl MetricHandle {
144 pub(crate) fn new(gauge: Gauge) -> Self {
147 gauge.inc();
148
149 Self {
150 gauge,
151 finished: Arc::new(Once::new()),
152 }
153 }
154
155 pub(crate) fn finish(&self) {
160 let gauge = self.gauge.clone();
161 self.finished.call_once(move || {
162 gauge.dec();
163 });
164 }
165}
166
/// Payload of a caught panic: a boxed, sendable, type-erased value.
///
/// The object lifetime defaults to `'static` for a boxed trait object, so this is
/// the same type as `Box<dyn Any + Send + 'static>`.
pub type Panic = Box<dyn Any + Send>;
169
170#[derive(Clone)]
172pub(crate) struct Panicker {
173 catch: bool,
174 sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
175}
176
177impl Panicker {
178 pub(crate) fn new(catch: bool) -> (Self, Panicked) {
180 let (sender, receiver) = oneshot::channel();
181 let panicker = Self {
182 catch,
183 sender: Arc::new(Mutex::new(Some(sender))),
184 };
185 let panicked = Panicked { receiver };
186 (panicker, panicked)
187 }
188
189 #[commonware_macros::stability(ALPHA)]
191 pub(crate) const fn catch(&self) -> bool {
192 self.catch
193 }
194
195 pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
197 let err = extract_panic_message(&*panic);
199 error!(?err, "task panicked");
200
201 if self.catch {
203 return;
204 }
205
206 let mut sender = self.sender.lock();
208 let Some(sender) = sender.take() else {
209 return;
210 };
211
212 let _ = sender.send(panic);
214 }
215}
216
217pub(crate) struct Panicked {
219 receiver: oneshot::Receiver<Panic>,
220}
221
222impl Panicked {
223 pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
225 where
226 Fut: Future,
227 {
228 let panicked = self.receiver;
230 pin_mut!(panicked);
231 pin_mut!(task);
232 match select(panicked, task).await {
233 Either::Left((panic, task)) => match panic {
234 Ok(panic) => {
236 resume_unwind(panic);
237 }
238 Err(_) => task.await,
241 },
242 Either::Right((output, _)) => {
243 output
245 }
246 }
247 }
248}
249
250pub(crate) struct Aborter {
252 inner: AbortHandle,
253 metric: MetricHandle,
254}
255
256impl Aborter {
257 pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
259 Self { inner, metric }
260 }
261
262 pub(crate) fn abort(self) {
264 self.inner.abort();
265
266 self.metric.finish();
269 }
270}
271
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics, Runner, Spawner};
    use futures::future;

    /// Start of every encoded metric line that reports running tasks.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Finds the running-tasks value for `label` in encoded `metrics`.
    ///
    /// Considers lines that start with [`METRIC_PREFIX`] and contain
    /// `name="<label>"`, and returns the first one whose text after the last space
    /// parses as a `u64`. Returns `None` when no such line exists.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let label_fragment = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&label_fragment))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// The gauge reads 1 while a spawned task is outstanding and 0 once it has been awaited.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move { "done" });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            assert_eq!(task.await.expect("task failed"), "done");

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle of a still-pending task leaves the gauge at 1.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(task);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// Aborting through the handle brings the gauge back to 0 without awaiting anything.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// A task spawned via `shared(true)` is counted while outstanding and released
    /// once awaited.
    // NOTE(review): `shared(true)` is taken to be the blocking variant, judging by the
    // test name and messages — confirm against `Spawner`.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);

            let task = context.clone().shared(true).spawn(|_| async move { 42 });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            assert_eq!(task.await.expect("blocking task failed"), 42);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through an `Aborter` obtained from the handle also brings the gauge
    /// back to 0 immediately.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|context| async move {
            let context = context.with_label(LABEL);
            let task = context.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.aborter().unwrap().abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}