commonware_runtime/utils/
handle.rs1use crate::{utils::extract_panic_message, Error};
2use futures::{
3 channel::oneshot,
4 stream::{AbortHandle, Abortable},
5 FutureExt as _,
6};
7use prometheus_client::metrics::gauge::Gauge;
8use std::{
9 future::Future,
10 panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
11 pin::Pin,
12 sync::{Arc, Once},
13 task::{Context, Poll},
14};
15use tracing::error;
16
17pub struct Handle<T>
19where
20 T: Send + 'static,
21{
22 aborter: Option<AbortHandle>,
23 receiver: oneshot::Receiver<Result<T, Error>>,
24
25 running: Gauge,
26 once: Arc<Once>,
27}
28
29impl<T> Handle<T>
30where
31 T: Send + 'static,
32{
33 pub(crate) fn init_future<F>(
34 f: F,
35 running: Gauge,
36 catch_panic: bool,
37 ) -> (impl Future<Output = ()>, Self)
38 where
39 F: Future<Output = T> + Send + 'static,
40 {
41 running.inc();
43
44 let once = Arc::new(Once::new());
46 let (sender, receiver) = oneshot::channel();
47 let (aborter, abort_registration) = AbortHandle::new_pair();
48
49 let wrapped = {
51 let once = once.clone();
52 let running = running.clone();
53 async move {
54 let result = AssertUnwindSafe(f).catch_unwind().await;
56
57 once.call_once(|| {
59 running.dec();
60 });
61
62 let result = match result {
64 Ok(result) => Ok(result),
65 Err(err) => {
66 if !catch_panic {
67 resume_unwind(err);
68 }
69 let err = extract_panic_message(&*err);
70 error!(?err, "task panicked");
71 Err(Error::Exited)
72 }
73 };
74 let _ = sender.send(result);
75 }
76 };
77
78 let abortable = Abortable::new(wrapped, abort_registration);
80 (
81 abortable.map(|_| ()),
82 Self {
83 aborter: Some(aborter),
84 receiver,
85
86 running,
87 once,
88 },
89 )
90 }
91
92 pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
93 where
94 F: FnOnce() -> T + Send + 'static,
95 {
96 running.inc();
98
99 let once = Arc::new(Once::new());
101 let (sender, receiver) = oneshot::channel();
102
103 let f = {
105 let once = once.clone();
106 let running = running.clone();
107 move || {
108 let result = catch_unwind(AssertUnwindSafe(f));
110
111 once.call_once(|| {
113 running.dec();
114 });
115
116 let result = match result {
118 Ok(value) => Ok(value),
119 Err(err) => {
120 if !catch_panic {
121 resume_unwind(err);
122 }
123 let err = extract_panic_message(&*err);
124 error!(?err, "task panicked");
125 Err(Error::Exited)
126 }
127 };
128 let _ = sender.send(result);
129 }
130 };
131
132 (
134 f,
135 Self {
136 aborter: None,
137 receiver,
138
139 running,
140 once,
141 },
142 )
143 }
144
145 pub fn abort(&self) {
147 let Some(aborter) = &self.aborter else {
149 return;
150 };
151 aborter.abort();
152
153 self.once.call_once(|| {
155 self.running.dec();
156 });
157 }
158}
159
160impl<T> Future for Handle<T>
161where
162 T: Send + 'static,
163{
164 type Output = Result<T, Error>;
165
166 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 match Pin::new(&mut self.receiver).poll(cx) {
168 Poll::Ready(Ok(Ok(value))) => {
169 self.once.call_once(|| {
170 self.running.dec();
171 });
172 Poll::Ready(Ok(value))
173 }
174 Poll::Ready(Ok(Err(err))) => {
175 self.once.call_once(|| {
176 self.running.dec();
177 });
178 Poll::Ready(Err(err))
179 }
180 Poll::Ready(Err(_)) => {
181 self.once.call_once(|| {
182 self.running.dec();
183 });
184 Poll::Ready(Err(Error::Closed))
185 }
186 Poll::Pending => Poll::Pending,
187 }
188 }
189}