commonware_runtime/utils/
handle.rs1use crate::{
2 telemetry::metrics::raw::Gauge,
3 utils::{extract_panic_message, supervision::Tree},
4 Error,
5};
6use commonware_utils::{
7 channel::oneshot,
8 sync::{Mutex, Once},
9};
10use futures::{
11 future::{select, Either},
12 pin_mut,
13 stream::{AbortHandle, Abortable, Aborted},
14 FutureExt as _,
15};
16use std::{
17 any::Any,
18 future::Future,
19 panic::{resume_unwind, AssertUnwindSafe},
20 pin::Pin,
21 sync::Arc,
22 task::{Context, Poll},
23};
24use tracing::error;
25
26pub struct Handle<T>
28where
29 T: Send + 'static,
30{
31 abort_handle: Option<AbortHandle>,
32 receiver: oneshot::Receiver<Result<T, Error>>,
33 metric: MetricHandle,
34}
35
36impl<T> Handle<T>
37where
38 T: Send + 'static,
39{
40 #[inline(always)]
41 pub(crate) fn init<F>(
42 f: F,
43 metric: MetricHandle,
44 panicker: Panicker,
45 tree: Arc<Tree>,
46 ) -> (impl Future<Output = ()>, Self)
47 where
48 F: Future<Output = T> + Send + 'static,
49 {
50 let (sender, receiver) = oneshot::channel();
52 let (abort_handle, abort_registration) = AbortHandle::new_pair();
53
54 let metric_handle = metric.clone();
60 let task = async move {
61 let result =
63 Abortable::new(AssertUnwindSafe(f).catch_unwind(), abort_registration).await;
64
65 match result {
67 Ok(Ok(result)) => {
68 let _ = sender.send(Ok(result));
69 }
70 Ok(Err(panic)) => {
71 panicker.notify(panic);
72 let _ = sender.send(Err(Error::Exited));
73 }
74 Err(Aborted) => {}
75 }
76
77 tree.abort();
79
80 metric_handle.finish();
82 };
83
84 (
85 task,
86 Self {
87 abort_handle: Some(abort_handle),
88 receiver,
89 metric,
90 },
91 )
92 }
93
94 pub(crate) fn closed(metric: MetricHandle) -> Self {
96 metric.finish();
98
99 let (sender, receiver) = oneshot::channel();
101 drop(sender);
102
103 Self {
104 abort_handle: None,
105 receiver,
106 metric,
107 }
108 }
109
110 pub fn abort(&self) {
112 let Some(abort_handle) = &self.abort_handle else {
114 return;
115 };
116 abort_handle.abort();
117
118 self.metric.finish();
121 }
122
123 pub(crate) fn aborter(&self) -> Option<Aborter> {
125 self.abort_handle
126 .clone()
127 .map(|inner| Aborter::new(inner, self.metric.clone()))
128 }
129}
130
131impl<T> Future for Handle<T>
132where
133 T: Send + 'static,
134{
135 type Output = Result<T, Error>;
136
137 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 Pin::new(&mut self.receiver)
139 .poll(cx)
140 .map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
141 }
142}
143
144#[derive(Clone)]
146pub(crate) struct MetricHandle {
147 gauge: Gauge,
148 finished: Arc<Once>,
149}
150
151impl MetricHandle {
152 pub(crate) fn new(gauge: Gauge) -> Self {
155 gauge.inc();
156
157 Self {
158 gauge,
159 finished: Arc::new(Once::new()),
160 }
161 }
162
163 pub(crate) fn finish(&self) {
168 let gauge = self.gauge.clone();
169 self.finished.call_once(move || {
170 gauge.dec();
171 });
172 }
173}
174
/// Payload of a caught panic, as produced by `catch_unwind` and accepted by
/// `resume_unwind`.
pub type Panic = Box<dyn Any + Send + 'static>;
177
178#[derive(Clone)]
180pub(crate) struct Panicker {
181 catch: bool,
182 sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
183}
184
185impl Panicker {
186 pub(crate) fn new(catch: bool) -> (Self, Panicked) {
188 let (sender, receiver) = oneshot::channel();
189 let panicker = Self {
190 catch,
191 sender: Arc::new(Mutex::new(Some(sender))),
192 };
193 let panicked = Panicked { receiver };
194 (panicker, panicked)
195 }
196
197 #[commonware_macros::stability(ALPHA)]
199 pub(crate) const fn catch(&self) -> bool {
200 self.catch
201 }
202
203 pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
205 let err = extract_panic_message(&*panic);
207 error!(?err, "task panicked");
208
209 if self.catch {
211 return;
212 }
213
214 let mut sender = self.sender.lock();
216 let Some(sender) = sender.take() else {
217 return;
218 };
219
220 let _ = sender.send(panic);
222 }
223}
224
225pub(crate) struct Panicked {
227 receiver: oneshot::Receiver<Panic>,
228}
229
230impl Panicked {
231 pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
233 where
234 Fut: Future,
235 {
236 let panicked = self.receiver;
238 pin_mut!(panicked);
239 pin_mut!(task);
240 match select(panicked, task).await {
241 Either::Left((panic, task)) => match panic {
242 Ok(panic) => {
244 resume_unwind(panic);
245 }
246 Err(_) => task.await,
249 },
250 Either::Right((output, _)) => {
251 output
253 }
254 }
255 }
256}
257
258pub(crate) struct Aborter {
260 inner: AbortHandle,
261 metric: MetricHandle,
262}
263
264impl Aborter {
265 pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
267 Self { inner, metric }
268 }
269
270 pub(crate) fn abort(self) {
272 self.inner.abort();
273
274 self.metric.finish();
277 }
278}
279
#[cfg(test)]
mod tests {
    use crate::{deterministic, Metrics as _, Runner, Spawner, Supervisor as _};
    use futures::future;

    /// Start of every encoded sample line of the running-tasks gauge.
    const METRIC_PREFIX: &str = "runtime_tasks_running{";

    /// Extracts the running-tasks gauge value for `label` from encoded metrics.
    ///
    /// Returns the value of the first line that starts with [`METRIC_PREFIX`],
    /// contains `name="<label>"`, and ends in a space-separated `u64`.
    fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
        let needle = format!("name=\"{label}\"");
        metrics
            .lines()
            .filter(|line| line.starts_with(METRIC_PREFIX) && line.contains(&needle))
            .find_map(|line| {
                let (_, value) = line.rsplit_once(' ')?;
                value.trim().parse::<u64>().ok()
            })
    }

    /// The gauge is 1 while the task is pending and 0 once its handle resolved.
    #[test]
    fn tasks_running_decreased_after_completion() {
        const LABEL: &str = "tasks_running_after_completion";

        deterministic::Runner::default().start(|context| async move {
            let task = context.child(LABEL).spawn(|_| async move { "done" });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before completion: {metrics}",
            );

            assert_eq!(task.await.expect("task failed"), "done");

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after completion: {metrics}",
            );
        });
    }

    /// Dropping the handle leaves the gauge untouched.
    #[test]
    fn tasks_running_unchanged_when_handle_dropped() {
        const LABEL: &str = "tasks_running_unchanged";

        deterministic::Runner::default().start(|context| async move {
            let task = context.child(LABEL).spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before dropping handle: {metrics}",
            );

            drop(task);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "dropping handle should not finish metrics: {metrics}",
            );
        });
    }

    /// Aborting through the handle drops the gauge to 0 right away.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_handle() {
        const LABEL: &str = "tasks_running_abort_via_handle";

        deterministic::Runner::default().start(|context| async move {
            let task = context.child(LABEL).spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }

    /// Same as the completion test, for a task spawned with `shared(true)`.
    #[test]
    fn tasks_running_decreased_after_blocking_completion() {
        const LABEL: &str = "tasks_running_after_blocking_completion";

        deterministic::Runner::default().start(|context| async move {
            let blocking_task = context
                .child(LABEL)
                .shared(true)
                .spawn(|_| async move { 42 });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
            );

            assert_eq!(blocking_task.await.expect("blocking task failed"), 42);

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
            );
        });
    }

    /// Aborting through a detached aborter drops the gauge to 0 right away.
    #[test]
    fn tasks_running_decreased_immediately_on_abort_via_aborter() {
        const LABEL: &str = "tasks_running_abort_via_aborter";

        deterministic::Runner::default().start(|context| async move {
            let task = context.child(LABEL).spawn(|_| async move {
                future::pending::<()>().await;
            });

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(1),
                "expected tasks_running gauge to be 1 before abort: {metrics}",
            );

            task.aborter().unwrap().abort();

            let metrics = context.encode();
            assert_eq!(
                running_tasks_for_label(&metrics, LABEL),
                Some(0),
                "expected tasks_running gauge to return to 0 after abort: {metrics}",
            );
        });
    }
}