commonware_runtime/utils/
handle.rs1use crate::{utils::extract_panic_message, Error};
2use futures::{
3 channel::oneshot,
4 stream::{AbortHandle, Abortable},
5 FutureExt as _,
6};
7use prometheus_client::metrics::gauge::Gauge;
8use std::{
9 future::Future,
10 panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
11 pin::Pin,
12 sync::{Arc, Mutex, Once},
13 task::{Context, Poll},
14};
15use tracing::error;
16
17pub struct Handle<T>
19where
20 T: Send + 'static,
21{
22 aborter: Option<AbortHandle>,
23 receiver: oneshot::Receiver<Result<T, Error>>,
24
25 running: Gauge,
26 once: Arc<Once>,
27}
28
29impl<T> Handle<T>
30where
31 T: Send + 'static,
32{
33 pub(crate) fn init_future<F>(
34 f: F,
35 running: Gauge,
36 catch_panic: bool,
37 children: Arc<Mutex<Vec<AbortHandle>>>,
38 ) -> (impl Future<Output = ()>, Self)
39 where
40 F: Future<Output = T> + Send + 'static,
41 {
42 running.inc();
44
45 let once = Arc::new(Once::new());
47 let (sender, receiver) = oneshot::channel();
48 let (aborter, abort_registration) = AbortHandle::new_pair();
49
50 let wrapped = {
52 let once = once.clone();
53 let running = running.clone();
54 async move {
55 let result = AssertUnwindSafe(f).catch_unwind().await;
57
58 once.call_once(|| {
60 running.dec();
61 });
62
63 let result = match result {
65 Ok(result) => Ok(result),
66 Err(err) => {
67 if !catch_panic {
68 resume_unwind(err);
69 }
70 let err = extract_panic_message(&*err);
71 error!(?err, "task panicked");
72 Err(Error::Exited)
73 }
74 };
75 let _ = sender.send(result);
76 }
77 };
78
79 let abortable = Abortable::new(wrapped, abort_registration);
81 (
82 abortable.map(move |_| {
83 for handle in children.lock().unwrap().drain(..) {
85 handle.abort();
86 }
87 }),
88 Self {
89 aborter: Some(aborter),
90 receiver,
91
92 running,
93 once,
94 },
95 )
96 }
97
98 pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
99 where
100 F: FnOnce() -> T + Send + 'static,
101 {
102 running.inc();
104
105 let once = Arc::new(Once::new());
107 let (sender, receiver) = oneshot::channel();
108
109 let f = {
111 let once = once.clone();
112 let running = running.clone();
113 move || {
114 let result = catch_unwind(AssertUnwindSafe(f));
116
117 once.call_once(|| {
119 running.dec();
120 });
121
122 let result = match result {
124 Ok(value) => Ok(value),
125 Err(err) => {
126 if !catch_panic {
127 resume_unwind(err);
128 }
129 let err = extract_panic_message(&*err);
130 error!(?err, "task panicked");
131 Err(Error::Exited)
132 }
133 };
134 let _ = sender.send(result);
135 }
136 };
137
138 (
140 f,
141 Self {
142 aborter: None,
143 receiver,
144
145 running,
146 once,
147 },
148 )
149 }
150
151 pub fn abort(&self) {
153 let Some(aborter) = &self.aborter else {
155 return;
156 };
157 aborter.abort();
158
159 self.once.call_once(|| {
161 self.running.dec();
162 });
163 }
164
165 pub(crate) fn abort_handle(&self) -> Option<AbortHandle> {
166 self.aborter.clone()
167 }
168}
169
170impl<T> Future for Handle<T>
171where
172 T: Send + 'static,
173{
174 type Output = Result<T, Error>;
175
176 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177 match Pin::new(&mut self.receiver).poll(cx) {
178 Poll::Ready(Ok(Ok(value))) => {
179 self.once.call_once(|| {
180 self.running.dec();
181 });
182 Poll::Ready(Ok(value))
183 }
184 Poll::Ready(Ok(Err(err))) => {
185 self.once.call_once(|| {
186 self.running.dec();
187 });
188 Poll::Ready(Err(err))
189 }
190 Poll::Ready(Err(_)) => {
191 self.once.call_once(|| {
192 self.running.dec();
193 });
194 Poll::Ready(Err(Error::Closed))
195 }
196 Poll::Pending => Poll::Pending,
197 }
198 }
199}