commonware_runtime/
utils.rs

1//! Utility functions for interacting with any runtime.
2
3use crate::Error;
4#[cfg(test)]
5use crate::{Runner, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9    channel::oneshot,
10    future::Shared,
11    stream::{AbortHandle, Abortable},
12    FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use std::{
16    any::Any,
17    future::Future,
18    panic::{resume_unwind, AssertUnwindSafe},
19    pin::Pin,
20    sync::{Arc, Once},
21    task::{Context, Poll},
22};
23use tracing::error;
24
/// Yield control back to the runtime.
///
/// The returned future is pending the first time it is polled (after asking
/// to be woken again right away) and completes on the next poll.
pub async fn reschedule() {
    /// Future that is pending on its first poll and ready afterwards.
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Mark the future as polled and inspect what the flag was before.
            if std::mem::replace(&mut self.polled, true) {
                return Poll::Ready(());
            }

            // First poll: request an immediate wake-up so we get polled again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { polled: false }.await
}
47
/// Turn a panic payload into a printable message.
///
/// Payloads that are a `&str` or a `String` are returned as-is; any other
/// payload falls back to its `Debug` rendering.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_string())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
57
58/// Handle to a spawned task.
59pub struct Handle<T>
60where
61    T: Send + 'static,
62{
63    aborter: AbortHandle,
64    receiver: oneshot::Receiver<Result<T, Error>>,
65
66    running: Gauge,
67    once: Arc<Once>,
68}
69
70impl<T> Handle<T>
71where
72    T: Send + 'static,
73{
74    pub(crate) fn init<F>(
75        f: F,
76        running: Gauge,
77        catch_panic: bool,
78    ) -> (impl Future<Output = ()>, Self)
79    where
80        F: Future<Output = T> + Send + 'static,
81    {
82        // Increment running counter
83        running.inc();
84
85        // Initialize channels to handle result/abort
86        let once = Arc::new(Once::new());
87        let (sender, receiver) = oneshot::channel();
88        let (aborter, abort_registration) = AbortHandle::new_pair();
89
90        // Wrap the future to handle panics
91        let wrapped = {
92            let once = once.clone();
93            let running = running.clone();
94            async move {
95                // Run future
96                let result = AssertUnwindSafe(f).catch_unwind().await;
97
98                // Decrement running counter
99                once.call_once(|| {
100                    running.dec();
101                });
102
103                // Handle result
104                let result = match result {
105                    Ok(result) => Ok(result),
106                    Err(err) => {
107                        if !catch_panic {
108                            resume_unwind(err);
109                        }
110                        let err = extract_panic_message(&*err);
111                        error!(?err, "task panicked");
112                        Err(Error::Exited)
113                    }
114                };
115                let _ = sender.send(result);
116            }
117        };
118
119        // Make the future abortable
120        let abortable = Abortable::new(wrapped, abort_registration);
121        (
122            abortable.map(|_| ()),
123            Self {
124                aborter,
125                receiver,
126
127                running,
128                once,
129            },
130        )
131    }
132
133    pub fn abort(&self) {
134        // Stop task
135        self.aborter.abort();
136
137        // Decrement running counter
138        self.once.call_once(|| {
139            self.running.dec();
140        });
141    }
142}
143
144impl<T> Future for Handle<T>
145where
146    T: Send + 'static,
147{
148    type Output = Result<T, Error>;
149
150    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
151        Pin::new(&mut self.receiver)
152            .poll(cx)
153            .map(|res| res.map_err(|_| Error::Closed).and_then(|r| r))
154    }
155}
156
157/// A one-time broadcast that can be awaited by many tasks. It is often used for
158/// coordinating shutdown across many tasks.
159///
160/// To minimize the overhead of tracking outstanding signals (which only return once),
161/// it is recommended to wait on a reference to it (i.e. `&mut signal`) instead of
162/// cloning it multiple times in a given task (i.e. in each iteration of a loop).
163pub type Signal = Shared<oneshot::Receiver<i32>>;
164
165/// Coordinates a one-time signal across many tasks.
166///
167/// # Example
168///
169/// ## Basic Usage
170///
171/// ```rust
172/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic::Executor};
173///
174/// let (executor, _, _) = Executor::default();
175/// executor.start(async move {
176///     // Setup signaler and get future
177///     let (mut signaler, signal) = Signaler::new();
178///
179///     // Signal shutdown
180///     signaler.signal(2);
181///
182///     // Wait for shutdown in task
183///     let sig = signal.await.unwrap();
184///     println!("Received signal: {}", sig);
185/// });
186/// ```
187///
188/// ## Advanced Usage
189///
190/// While `Futures::Shared` is efficient, there is still meaningful overhead
191/// to cloning it (i.e. in each iteration of a loop). To avoid
192/// a performance regression from introducing `Signaler`, it is recommended
193/// to wait on a reference to `Signal` (i.e. `&mut signal`).
194///
195/// ```rust
196/// use commonware_macros::select;
197/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic::Executor, Metrics};
198/// use futures::channel::oneshot;
199/// use std::time::Duration;
200///
201/// let (executor, context, _) = Executor::default();
202/// executor.start(async move {
203///     // Setup signaler and get future
204///     let (mut signaler, mut signal) = Signaler::new();
205///
206///     // Loop on the signal until resolved
207///     let (tx, rx) = oneshot::channel();
208///     context.with_label("waiter").spawn(|context| async move {
209///         loop {
210///             // Wait for signal or sleep
211///             select! {
212///                  sig = &mut signal => {
213///                      println!("Received signal: {}", sig.unwrap());
214///                      break;
215///                  },
216///                  _ = context.sleep(Duration::from_secs(1)) => {},
217///             };
218///         }
219///         let _ = tx.send(());
220///     });
221///
222///     // Send signal
223///     signaler.signal(9);
224///
225///     // Wait for task
226///     rx.await.expect("shutdown signaled");
227/// });
228/// ```
229pub struct Signaler {
230    tx: Option<oneshot::Sender<i32>>,
231}
232
233impl Signaler {
234    /// Create a new `Signaler`.
235    ///
236    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
237    pub fn new() -> (Self, Signal) {
238        let (tx, rx) = oneshot::channel();
239        (Self { tx: Some(tx) }, rx.shared())
240    }
241
242    /// Resolve the `Signal` for all waiters (if not already resolved).
243    pub fn signal(&mut self, value: i32) {
244        if let Some(stop_tx) = self.tx.take() {
245            let _ = stop_tx.send(value);
246        }
247    }
248}
249
/// Test helper: yields to the runtime five times, then returns `i`.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining = 5;
    while remaining > 0 {
        reschedule().await;
        remaining -= 1;
    }
    i
}
257
/// Test helper: spawns `tasks` instances of [`task`] on `context`, drives them
/// with `runner`, and returns their outputs in the order they completed.
///
/// Passing `tasks == 0` spawns nothing and returns an empty vector.
///
/// # Panics
///
/// Panics if any spawned task resolves to an error.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn every task up front.
        //
        // Use the half-open range `0..tasks`: the previous `0..=tasks - 1`
        // underflowed for `tasks == 0` (panic in debug builds, a
        // `0..=usize::MAX` loop in release builds).
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect outputs in completion order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}