commonware_runtime/
utils.rs

1//! Utility functions for interacting with any runtime.
2
3use crate::Error;
4#[cfg(test)]
5use crate::{Runner, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9    channel::oneshot,
10    future::Shared,
11    stream::{AbortHandle, Abortable},
12    FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use std::{
16    any::Any,
17    future::Future,
18    panic::{resume_unwind, AssertUnwindSafe},
19    pin::Pin,
20    sync::{Arc, Once},
21    task::{Context, Poll},
22};
23use tracing::error;
24
/// Yield control back to the runtime.
///
/// The inner future wakes its own task and returns `Pending` on its first poll, then
/// returns `Ready` on the next poll. This lets the executor run other work before the
/// caller resumes.
pub async fn reschedule() {
    /// Future that reports `Pending` exactly once before completing.
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.polled {
                self.polled = true;
                // Ask to be polled again right away; we only wanted to give up this turn.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { polled: false }.await
}
47
/// Turn a panic payload into a human-readable message.
///
/// `panic!` payloads are usually a `&str` (literal message) or a `String` (formatted
/// message); anything else falls back to the payload's `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
56}
57
58/// Handle to a spawned task.
59pub struct Handle<T>
60where
61    T: Send + 'static,
62{
63    aborter: AbortHandle,
64    receiver: oneshot::Receiver<Result<T, Error>>,
65
66    running: Gauge,
67    once: Arc<Once>,
68}
69
70impl<T> Handle<T>
71where
72    T: Send + 'static,
73{
74    pub(crate) fn init<F>(
75        f: F,
76        running: Gauge,
77        catch_panic: bool,
78    ) -> (impl Future<Output = ()>, Self)
79    where
80        F: Future<Output = T> + Send + 'static,
81    {
82        // Increment running counter
83        running.inc();
84
85        // Initialize channels to handle result/abort
86        let once = Arc::new(Once::new());
87        let (sender, receiver) = oneshot::channel();
88        let (aborter, abort_registration) = AbortHandle::new_pair();
89
90        // Wrap the future to handle panics
91        let wrapped = {
92            let once = once.clone();
93            let running = running.clone();
94            async move {
95                // Run future
96                let result = AssertUnwindSafe(f).catch_unwind().await;
97
98                // Decrement running counter
99                once.call_once(|| {
100                    running.dec();
101                });
102
103                // Handle result
104                let result = match result {
105                    Ok(result) => Ok(result),
106                    Err(err) => {
107                        if !catch_panic {
108                            resume_unwind(err);
109                        }
110                        let err = extract_panic_message(&*err);
111                        error!(?err, "task panicked");
112                        Err(Error::Exited)
113                    }
114                };
115                let _ = sender.send(result);
116            }
117        };
118
119        // Make the future abortable
120        let abortable = Abortable::new(wrapped, abort_registration);
121        (
122            abortable.map(|_| ()),
123            Self {
124                aborter,
125                receiver,
126
127                running,
128                once,
129            },
130        )
131    }
132
133    pub fn abort(&self) {
134        // Stop task
135        self.aborter.abort();
136
137        // Decrement running counter
138        self.once.call_once(|| {
139            self.running.dec();
140        });
141    }
142}
143
144impl<T> Future for Handle<T>
145where
146    T: Send + 'static,
147{
148    type Output = Result<T, Error>;
149
150    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
151        Pin::new(&mut self.receiver)
152            .poll(cx)
153            .map(|res| res.map_err(|_| Error::Closed).and_then(|r| r))
154    }
155}
156
157/// A one-time broadcast that can be awaited by many tasks. It is often used for
158/// coordinating shutdown across many tasks.
159///
160/// To minimize the overhead of tracking outstanding signals (which only return once),
161/// it is recommended to wait on a reference to it (i.e. `&mut signal`) instead of
162/// cloning it multiple times in a given task (i.e. in each iteration of a loop).
163pub type Signal = Shared<oneshot::Receiver<i32>>;
164
165/// Coordinates a one-time signal across many tasks.
166///
167/// # Example
168///
169/// ## Basic Usage
170///
171/// ```rust
172/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic::Executor};
173///
174/// let (executor, _, _) = Executor::default();
175/// executor.start(async move {
176///     // Setup signaler and get future
177///     let (mut signaler, signal) = Signaler::new();
178///
179///     // Signal shutdown
180///     signaler.signal(2);
181///
182///     // Wait for shutdown in task
183///     let sig = signal.await.unwrap();
184///     println!("Received signal: {}", sig);
185/// });
186/// ```
187///
188/// ## Advanced Usage
189///
190/// While `Futures::Shared` is efficient, there is still meaningful overhead
191/// to cloning it (i.e. in each iteration of a loop). To avoid
192/// a performance regression from introducing `Signaler`, it is recommended
193/// to wait on a reference to `Signal` (i.e. `&mut signal`).
194///
195/// ```rust
196/// use commonware_macros::select;
197/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic::Executor};
198/// use futures::channel::oneshot;
199/// use std::time::Duration;
200///
201/// let (executor, context, _) = Executor::default();
202/// executor.start(async move {
203///     // Setup signaler and get future
204///     let (mut signaler, mut signal) = Signaler::new();
205///
206///     // Loop on the signal until resolved
207///     let (tx, rx) = oneshot::channel();
208///     context.spawn("task", {
209///         let context = context.clone();
210///         async move {
211///             loop {
212///                 // Wait for signal or sleep
213///                 select! {
214///                      sig = &mut signal => {
215///                          println!("Received signal: {}", sig.unwrap());
216///                          break;
217///                      },   
218///                      _ = context.sleep(Duration::from_secs(1)) => {},
219///                 };
220///             }
221///             let _ = tx.send(());
222///         }
223///     });
224///
225///     // Send signal
226///     signaler.signal(9);
227///
228///     // Wait for task
229///     rx.await.expect("shutdown signaled");
230/// });
231/// ```
232pub struct Signaler {
233    tx: Option<oneshot::Sender<i32>>,
234}
235
236impl Signaler {
237    /// Create a new `Signaler`.
238    ///
239    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
240    pub fn new() -> (Self, Signal) {
241        let (tx, rx) = oneshot::channel();
242        (Self { tx: Some(tx) }, rx.shared())
243    }
244
245    /// Resolve the `Signal` for all waiters (if not already resolved).
246    pub fn signal(&mut self, value: i32) {
247        if let Some(stop_tx) = self.tx.take() {
248            let _ = stop_tx.send(value);
249        }
250    }
251}
252
/// Test task: yields to the runtime five times, then returns its input unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining = 5;
    while remaining > 0 {
        reschedule().await;
        remaining -= 1;
    }
    i
}
260
/// Spawn `tasks` test tasks (with inputs `0..tasks`) on `context`, drive them to
/// completion on `runner`, and return their outputs in the order they completed.
///
/// Returns an empty `Vec` when `tasks` is zero.
///
/// # Panics
///
/// Panics if any spawned task resolves to an error.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn every task up front; the runtime decides the order in which they complete.
        // Iterating `0..tasks` directly avoids the `tasks - 1` underflow when `tasks == 0`.
        let mut handles: FuturesUnordered<_> =
            (0..tasks).map(|i| context.spawn("test", task(i))).collect();

        // Record outputs in completion order.
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}