commonware_runtime/
utils.rs

1//! Utility functions for interacting with any runtime.
2
3use crate::Error;
4#[cfg(test)]
5use crate::{Runner, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9    channel::oneshot,
10    future::Shared,
11    stream::{AbortHandle, Abortable},
12    FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use std::{
16    any::Any,
17    future::Future,
18    panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
19    pin::Pin,
20    sync::{Arc, Once},
21    task::{Context, Poll},
22};
23use tracing::error;
24
/// Yield control back to the runtime.
///
/// The returned future is `Pending` on its first poll (after waking its own
/// task so the executor schedules it again) and `Ready` on the next poll,
/// giving other tasks a chance to run in between.
pub async fn reschedule() {
    let mut already_yielded = false;
    std::future::poll_fn(move |cx| {
        if already_yielded {
            return Poll::Ready(());
        }
        already_yielded = true;
        // Ask to be polled again right away; we only want to step aside once.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
47
/// Render a panic payload as a human-readable message.
///
/// `panic!` payloads are usually a `&'static str` (literal message) or a
/// `String` (formatted message); anything else falls back to its `Debug` form.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
57
58/// Handle to a spawned task.
59pub struct Handle<T>
60where
61    T: Send + 'static,
62{
63    aborter: Option<AbortHandle>,
64    receiver: oneshot::Receiver<Result<T, Error>>,
65
66    running: Gauge,
67    once: Arc<Once>,
68}
69
70impl<T> Handle<T>
71where
72    T: Send + 'static,
73{
74    pub(crate) fn init<F>(
75        f: F,
76        running: Gauge,
77        catch_panic: bool,
78    ) -> (impl Future<Output = ()>, Self)
79    where
80        F: Future<Output = T> + Send + 'static,
81    {
82        // Increment running counter
83        running.inc();
84
85        // Initialize channels to handle result/abort
86        let once = Arc::new(Once::new());
87        let (sender, receiver) = oneshot::channel();
88        let (aborter, abort_registration) = AbortHandle::new_pair();
89
90        // Wrap the future to handle panics
91        let wrapped = {
92            let once = once.clone();
93            let running = running.clone();
94            async move {
95                // Run future
96                let result = AssertUnwindSafe(f).catch_unwind().await;
97
98                // Decrement running counter
99                once.call_once(|| {
100                    running.dec();
101                });
102
103                // Handle result
104                let result = match result {
105                    Ok(result) => Ok(result),
106                    Err(err) => {
107                        if !catch_panic {
108                            resume_unwind(err);
109                        }
110                        let err = extract_panic_message(&*err);
111                        error!(?err, "task panicked");
112                        Err(Error::Exited)
113                    }
114                };
115                let _ = sender.send(result);
116            }
117        };
118
119        // Make the future abortable
120        let abortable = Abortable::new(wrapped, abort_registration);
121        (
122            abortable.map(|_| ()),
123            Self {
124                aborter: Some(aborter),
125                receiver,
126
127                running,
128                once,
129            },
130        )
131    }
132
133    pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
134    where
135        F: FnOnce() -> T + Send + 'static,
136    {
137        // Increment the running tasks gauge
138        running.inc();
139
140        // Initialize channel to handle result
141        let once = Arc::new(Once::new());
142        let (sender, receiver) = oneshot::channel();
143
144        // Wrap the closure with panic handling
145        let f = {
146            let once = once.clone();
147            let running = running.clone();
148            move || {
149                // Run blocking task
150                let result = catch_unwind(AssertUnwindSafe(f));
151
152                // Decrement running counter
153                once.call_once(|| {
154                    running.dec();
155                });
156
157                // Handle result
158                let result = match result {
159                    Ok(value) => Ok(value),
160                    Err(err) => {
161                        if !catch_panic {
162                            resume_unwind(err);
163                        }
164                        let err = extract_panic_message(&*err);
165                        error!(?err, "blocking task panicked");
166                        Err(Error::Exited)
167                    }
168                };
169                let _ = sender.send(result);
170            }
171        };
172
173        // Return the task and handle
174        (
175            f,
176            Self {
177                aborter: None,
178                receiver,
179
180                running,
181                once,
182            },
183        )
184    }
185
186    /// Abort the task (if not blocking).
187    pub fn abort(&self) {
188        // Get aborter and abort
189        let Some(aborter) = &self.aborter else {
190            return;
191        };
192        aborter.abort();
193
194        // Decrement running counter
195        self.once.call_once(|| {
196            self.running.dec();
197        });
198    }
199}
200
201impl<T> Future for Handle<T>
202where
203    T: Send + 'static,
204{
205    type Output = Result<T, Error>;
206
207    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
208        match Pin::new(&mut self.receiver).poll(cx) {
209            Poll::Ready(Ok(Ok(value))) => {
210                self.once.call_once(|| {
211                    self.running.dec();
212                });
213                Poll::Ready(Ok(value))
214            }
215            Poll::Ready(Ok(Err(err))) => {
216                self.once.call_once(|| {
217                    self.running.dec();
218                });
219                Poll::Ready(Err(err))
220            }
221            Poll::Ready(Err(_)) => {
222                self.once.call_once(|| {
223                    self.running.dec();
224                });
225                Poll::Ready(Err(Error::Closed))
226            }
227            Poll::Pending => Poll::Pending,
228        }
229    }
230}
231
232/// A one-time broadcast that can be awaited by many tasks. It is often used for
233/// coordinating shutdown across many tasks.
234///
235/// To minimize the overhead of tracking outstanding signals (which only return once),
236/// it is recommended to wait on a reference to it (i.e. `&mut signal`) instead of
237/// cloning it multiple times in a given task (i.e. in each iteration of a loop).
238pub type Signal = Shared<oneshot::Receiver<i32>>;
239
240/// Coordinates a one-time signal across many tasks.
241///
242/// # Example
243///
244/// ## Basic Usage
245///
246/// ```rust
247/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic::Executor};
248///
249/// let (executor, _, _) = Executor::default();
250/// executor.start(async move {
251///     // Setup signaler and get future
252///     let (mut signaler, signal) = Signaler::new();
253///
254///     // Signal shutdown
255///     signaler.signal(2);
256///
257///     // Wait for shutdown in task
258///     let sig = signal.await.unwrap();
259///     println!("Received signal: {}", sig);
260/// });
261/// ```
262///
263/// ## Advanced Usage
264///
265/// While `Futures::Shared` is efficient, there is still meaningful overhead
266/// to cloning it (i.e. in each iteration of a loop). To avoid
267/// a performance regression from introducing `Signaler`, it is recommended
268/// to wait on a reference to `Signal` (i.e. `&mut signal`).
269///
270/// ```rust
271/// use commonware_macros::select;
272/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic::Executor, Metrics};
273/// use futures::channel::oneshot;
274/// use std::time::Duration;
275///
276/// let (executor, context, _) = Executor::default();
277/// executor.start(async move {
278///     // Setup signaler and get future
279///     let (mut signaler, mut signal) = Signaler::new();
280///
281///     // Loop on the signal until resolved
282///     let (tx, rx) = oneshot::channel();
283///     context.with_label("waiter").spawn(|context| async move {
284///         loop {
285///             // Wait for signal or sleep
286///             select! {
287///                  sig = &mut signal => {
288///                      println!("Received signal: {}", sig.unwrap());
289///                      break;
290///                  },
291///                  _ = context.sleep(Duration::from_secs(1)) => {},
292///             };
293///         }
294///         let _ = tx.send(());
295///     });
296///
297///     // Send signal
298///     signaler.signal(9);
299///
300///     // Wait for task
301///     rx.await.expect("shutdown signaled");
302/// });
303/// ```
304pub struct Signaler {
305    tx: Option<oneshot::Sender<i32>>,
306}
307
308impl Signaler {
309    /// Create a new `Signaler`.
310    ///
311    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
312    pub fn new() -> (Self, Signal) {
313        let (tx, rx) = oneshot::channel();
314        (Self { tx: Some(tx) }, rx.shared())
315    }
316
317    /// Resolve the `Signal` for all waiters (if not already resolved).
318    pub fn signal(&mut self, value: i32) {
319        if let Some(stop_tx) = self.tx.take() {
320            let _ = stop_tx.send(value);
321        }
322    }
323}
324
/// Test task: yields to the runtime a fixed number of times, then returns `i`.
#[cfg(test)]
async fn task(i: usize) -> usize {
    // Yielding repeatedly gives the scheduler room to interleave tasks.
    const YIELDS: usize = 5;
    for _ in 1..=YIELDS {
        reschedule().await;
    }
    i
}
332
/// Spawn `tasks` test tasks (numbered `0..tasks`) on `context`, drive them to
/// completion with `runner`, and return their outputs in completion order.
///
/// Passing `tasks == 0` spawns nothing and returns an empty vector.
///
/// # Panics
///
/// Panics if any task's handle resolves to an `Err`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn every task; `FuturesUnordered` yields results as they complete.
        // Use `0..tasks` (not `0..=tasks - 1`, which underflows when `tasks == 0`).
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}