commonware_runtime/utils.rs
//! Utilities for interacting with any runtime: scheduling helpers, task handles, and signaling.
2
3use crate::Error;
4#[cfg(test)]
5use crate::{Runner, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9 channel::oneshot,
10 future::Shared,
11 stream::{AbortHandle, Abortable},
12 FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use std::{
16 any::Any,
17 future::Future,
18 panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
19 pin::Pin,
20 sync::{Arc, Once},
21 task::{Context, Poll},
22};
23use tracing::error;
24
/// Yield control back to the runtime.
///
/// The returned future is pending exactly once: on its first poll it wakes
/// itself and returns `Poll::Pending`, giving the executor a chance to run
/// other work; the following poll completes.
pub async fn reschedule() {
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            // Mark as yielded and inspect the previous state in a single step.
            if std::mem::replace(&mut self.done, true) {
                return Poll::Ready(());
            }
            // Request an immediate re-poll, then step aside for this round.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { done: false }.await
}
47
/// Render a panic payload as human-readable text.
///
/// Payloads that are a `&'static str` or a `String` are returned as their text;
/// any other payload falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
57
58/// Handle to a spawned task.
59pub struct Handle<T>
60where
61 T: Send + 'static,
62{
63 aborter: Option<AbortHandle>,
64 receiver: oneshot::Receiver<Result<T, Error>>,
65
66 running: Gauge,
67 once: Arc<Once>,
68}
69
70impl<T> Handle<T>
71where
72 T: Send + 'static,
73{
74 pub(crate) fn init<F>(
75 f: F,
76 running: Gauge,
77 catch_panic: bool,
78 ) -> (impl Future<Output = ()>, Self)
79 where
80 F: Future<Output = T> + Send + 'static,
81 {
82 // Increment running counter
83 running.inc();
84
85 // Initialize channels to handle result/abort
86 let once = Arc::new(Once::new());
87 let (sender, receiver) = oneshot::channel();
88 let (aborter, abort_registration) = AbortHandle::new_pair();
89
90 // Wrap the future to handle panics
91 let wrapped = {
92 let once = once.clone();
93 let running = running.clone();
94 async move {
95 // Run future
96 let result = AssertUnwindSafe(f).catch_unwind().await;
97
98 // Decrement running counter
99 once.call_once(|| {
100 running.dec();
101 });
102
103 // Handle result
104 let result = match result {
105 Ok(result) => Ok(result),
106 Err(err) => {
107 if !catch_panic {
108 resume_unwind(err);
109 }
110 let err = extract_panic_message(&*err);
111 error!(?err, "task panicked");
112 Err(Error::Exited)
113 }
114 };
115 let _ = sender.send(result);
116 }
117 };
118
119 // Make the future abortable
120 let abortable = Abortable::new(wrapped, abort_registration);
121 (
122 abortable.map(|_| ()),
123 Self {
124 aborter: Some(aborter),
125 receiver,
126
127 running,
128 once,
129 },
130 )
131 }
132
133 pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
134 where
135 F: FnOnce() -> T + Send + 'static,
136 {
137 // Increment the running tasks gauge
138 running.inc();
139
140 // Initialize channel to handle result
141 let once = Arc::new(Once::new());
142 let (sender, receiver) = oneshot::channel();
143
144 // Wrap the closure with panic handling
145 let f = {
146 let once = once.clone();
147 let running = running.clone();
148 move || {
149 // Run blocking task
150 let result = catch_unwind(AssertUnwindSafe(f));
151
152 // Decrement running counter
153 once.call_once(|| {
154 running.dec();
155 });
156
157 // Handle result
158 let result = match result {
159 Ok(value) => Ok(value),
160 Err(err) => {
161 if !catch_panic {
162 resume_unwind(err);
163 }
164 let err = extract_panic_message(&*err);
165 error!(?err, "blocking task panicked");
166 Err(Error::Exited)
167 }
168 };
169 let _ = sender.send(result);
170 }
171 };
172
173 // Return the task and handle
174 (
175 f,
176 Self {
177 aborter: None,
178 receiver,
179
180 running,
181 once,
182 },
183 )
184 }
185
186 /// Abort the task (if not blocking).
187 pub fn abort(&self) {
188 // Get aborter and abort
189 let Some(aborter) = &self.aborter else {
190 return;
191 };
192 aborter.abort();
193
194 // Decrement running counter
195 self.once.call_once(|| {
196 self.running.dec();
197 });
198 }
199}
200
201impl<T> Future for Handle<T>
202where
203 T: Send + 'static,
204{
205 type Output = Result<T, Error>;
206
207 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
208 match Pin::new(&mut self.receiver).poll(cx) {
209 Poll::Ready(Ok(Ok(value))) => {
210 self.once.call_once(|| {
211 self.running.dec();
212 });
213 Poll::Ready(Ok(value))
214 }
215 Poll::Ready(Ok(Err(err))) => {
216 self.once.call_once(|| {
217 self.running.dec();
218 });
219 Poll::Ready(Err(err))
220 }
221 Poll::Ready(Err(_)) => {
222 self.once.call_once(|| {
223 self.running.dec();
224 });
225 Poll::Ready(Err(Error::Closed))
226 }
227 Poll::Pending => Poll::Pending,
228 }
229 }
230}
231
/// A one-time broadcast that can be awaited by many tasks. It is often used for
/// coordinating shutdown across many tasks.
///
/// The signal resolves to `Ok(value)`, where `value` is the `i32` passed to
/// `Signaler::signal`, or to `Err(Canceled)` if the paired `Signaler` is dropped
/// without signaling. Every clone resolves to the same value.
///
/// Each clone adds bookkeeping to the shared state, even though the signal
/// resolves only once. To keep that overhead low, wait on a reference to a
/// `Signal` (e.g. `&mut signal`) rather than cloning it repeatedly within one
/// task (e.g. in each iteration of a loop).
pub type Signal = Shared<oneshot::Receiver<i32>>;
239
240/// Coordinates a one-time signal across many tasks.
241///
242/// # Example
243///
244/// ## Basic Usage
245///
246/// ```rust
247/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic::Executor};
248///
249/// let (executor, _, _) = Executor::default();
250/// executor.start(async move {
251/// // Setup signaler and get future
252/// let (mut signaler, signal) = Signaler::new();
253///
254/// // Signal shutdown
255/// signaler.signal(2);
256///
257/// // Wait for shutdown in task
258/// let sig = signal.await.unwrap();
259/// println!("Received signal: {}", sig);
260/// });
261/// ```
262///
263/// ## Advanced Usage
264///
265/// While `Futures::Shared` is efficient, there is still meaningful overhead
266/// to cloning it (i.e. in each iteration of a loop). To avoid
267/// a performance regression from introducing `Signaler`, it is recommended
268/// to wait on a reference to `Signal` (i.e. `&mut signal`).
269///
270/// ```rust
271/// use commonware_macros::select;
272/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic::Executor, Metrics};
273/// use futures::channel::oneshot;
274/// use std::time::Duration;
275///
276/// let (executor, context, _) = Executor::default();
277/// executor.start(async move {
278/// // Setup signaler and get future
279/// let (mut signaler, mut signal) = Signaler::new();
280///
281/// // Loop on the signal until resolved
282/// let (tx, rx) = oneshot::channel();
283/// context.with_label("waiter").spawn(|context| async move {
284/// loop {
285/// // Wait for signal or sleep
286/// select! {
287/// sig = &mut signal => {
288/// println!("Received signal: {}", sig.unwrap());
289/// break;
290/// },
291/// _ = context.sleep(Duration::from_secs(1)) => {},
292/// };
293/// }
294/// let _ = tx.send(());
295/// });
296///
297/// // Send signal
298/// signaler.signal(9);
299///
300/// // Wait for task
301/// rx.await.expect("shutdown signaled");
302/// });
303/// ```
304pub struct Signaler {
305 tx: Option<oneshot::Sender<i32>>,
306}
307
308impl Signaler {
309 /// Create a new `Signaler`.
310 ///
311 /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
312 pub fn new() -> (Self, Signal) {
313 let (tx, rx) = oneshot::channel();
314 (Self { tx: Some(tx) }, rx.shared())
315 }
316
317 /// Resolve the `Signal` for all waiters (if not already resolved).
318 pub fn signal(&mut self, value: i32) {
319 if let Some(stop_tx) = self.tx.take() {
320 let _ = stop_tx.send(value);
321 }
322 }
323}
324
/// Test task: yields to the runtime a fixed number of times, then returns `i`.
#[cfg(test)]
async fn task(i: usize) -> usize {
    const YIELDS: usize = 5;
    for _round in 0..YIELDS {
        reschedule().await;
    }
    i
}
332
/// Spawn `tasks` instances of `task` on `context`, run them with `runner`, and
/// return their outputs in completion order.
///
/// Each task returns its spawn index, so the result records the order in which the
/// runtime finished the tasks. Panics if any task fails. `tasks == 0` yields an
/// empty vector.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn every task up front so the runtime is free to interleave them.
        // A half-open range avoids the `tasks - 1` underflow when `tasks == 0`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}