commonware_runtime/utils.rs
1//! Utility functions for interacting with any runtime.
2
3use crate::Error;
4#[cfg(test)]
5use crate::{Runner, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9 channel::oneshot,
10 future::Shared,
11 stream::{AbortHandle, Abortable},
12 FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use std::{
16 any::Any,
17 future::Future,
18 panic::{resume_unwind, AssertUnwindSafe},
19 pin::Pin,
20 sync::{Arc, Once},
21 task::{Context, Poll},
22};
23use tracing::error;
24
/// Yield control back to the runtime.
///
/// The inner future is `Pending` on its first poll and `Ready` on every poll after that.
/// Before returning `Pending`, it wakes its own task so the executor will poll it again.
pub async fn reschedule() {
    /// Future that suspends exactly once before completing.
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.done {
                // Ask to be polled again, then hand control back to the executor.
                self.done = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { done: false }.await
}
47
/// Converts a panic payload into a human-readable message.
///
/// `&str` and `String` payloads are returned as their text. Any other payload type
/// falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| msg.to_string())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
57
58/// Handle to a spawned task.
59pub struct Handle<T>
60where
61 T: Send + 'static,
62{
63 aborter: AbortHandle,
64 receiver: oneshot::Receiver<Result<T, Error>>,
65
66 running: Gauge,
67 once: Arc<Once>,
68}
69
70impl<T> Handle<T>
71where
72 T: Send + 'static,
73{
74 pub(crate) fn init<F>(
75 f: F,
76 running: Gauge,
77 catch_panic: bool,
78 ) -> (impl Future<Output = ()>, Self)
79 where
80 F: Future<Output = T> + Send + 'static,
81 {
82 // Increment running counter
83 running.inc();
84
85 // Initialize channels to handle result/abort
86 let once = Arc::new(Once::new());
87 let (sender, receiver) = oneshot::channel();
88 let (aborter, abort_registration) = AbortHandle::new_pair();
89
90 // Wrap the future to handle panics
91 let wrapped = {
92 let once = once.clone();
93 let running = running.clone();
94 async move {
95 // Run future
96 let result = AssertUnwindSafe(f).catch_unwind().await;
97
98 // Decrement running counter
99 once.call_once(|| {
100 running.dec();
101 });
102
103 // Handle result
104 let result = match result {
105 Ok(result) => Ok(result),
106 Err(err) => {
107 if !catch_panic {
108 resume_unwind(err);
109 }
110 let err = extract_panic_message(&*err);
111 error!(?err, "task panicked");
112 Err(Error::Exited)
113 }
114 };
115 let _ = sender.send(result);
116 }
117 };
118
119 // Make the future abortable
120 let abortable = Abortable::new(wrapped, abort_registration);
121 (
122 abortable.map(|_| ()),
123 Self {
124 aborter,
125 receiver,
126
127 running,
128 once,
129 },
130 )
131 }
132
133 pub fn abort(&self) {
134 // Stop task
135 self.aborter.abort();
136
137 // Decrement running counter
138 self.once.call_once(|| {
139 self.running.dec();
140 });
141 }
142}
143
144impl<T> Future for Handle<T>
145where
146 T: Send + 'static,
147{
148 type Output = Result<T, Error>;
149
150 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
151 Pin::new(&mut self.receiver)
152 .poll(cx)
153 .map(|res| res.map_err(|_| Error::Closed).and_then(|r| r))
154 }
155}
156
/// A one-time broadcast that can be awaited by many tasks. It is often used for
/// coordinating shutdown across many tasks.
///
/// A `Signal` is created by [`Signaler::new`]. It resolves to `Ok(value)` once
/// [`Signaler::signal`] is called with `value`, or to `Err(Canceled)` if the `Signaler`
/// is dropped without signaling. Every clone observes the same outcome.
///
/// Cloning a `Signal` is not free. Prefer awaiting a mutable reference to it
/// (`&mut signal`) over cloning it repeatedly within a task, for example on every
/// iteration of a loop.
pub type Signal = Shared<oneshot::Receiver<i32>>;
164
/// Coordinates a one-time signal across many tasks.
///
/// Calling [`Signaler::signal`] resolves every clone of the paired [`Signal`] with the
/// provided value. Only the first call has an effect. If the `Signaler` is dropped
/// without signaling, the `Signal` resolves to `Err(Canceled)` instead.
///
/// # Example
///
/// ## Basic Usage
///
/// ```rust
/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic::Executor};
///
/// let (executor, _, _) = Executor::default();
/// executor.start(async move {
///     // Setup signaler and get future
///     let (mut signaler, signal) = Signaler::new();
///
///     // Signal shutdown
///     signaler.signal(2);
///
///     // Wait for shutdown in task
///     let sig = signal.await.unwrap();
///     println!("Received signal: {}", sig);
/// });
/// ```
///
/// ## Advanced Usage
///
/// While `futures::future::Shared` is efficient, there is still meaningful overhead
/// to cloning it (e.g. in each iteration of a loop). To avoid a performance
/// regression from introducing `Signaler`, wait on a mutable reference to the
/// `Signal` (`&mut signal`) instead of cloning it.
///
/// ```rust
/// use commonware_macros::select;
/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic::Executor};
/// use futures::channel::oneshot;
/// use std::time::Duration;
///
/// let (executor, context, _) = Executor::default();
/// executor.start(async move {
///     // Setup signaler and get future
///     let (mut signaler, mut signal) = Signaler::new();
///
///     // Loop on the signal until resolved
///     let (tx, rx) = oneshot::channel();
///     context.spawn("task", {
///         let context = context.clone();
///         async move {
///             loop {
///                 // Wait for signal or sleep
///                 select! {
///                     sig = &mut signal => {
///                         println!("Received signal: {}", sig.unwrap());
///                         break;
///                     },
///                     _ = context.sleep(Duration::from_secs(1)) => {},
///                 };
///             }
///             let _ = tx.send(());
///         }
///     });
///
///     // Send signal
///     signaler.signal(9);
///
///     // Wait for task
///     rx.await.expect("shutdown signaled");
/// });
/// ```
pub struct Signaler {
    /// Sending half of the underlying channel; taken (left as `None`) by the first `signal`.
    tx: Option<oneshot::Sender<i32>>,
}
235
236impl Signaler {
237 /// Create a new `Signaler`.
238 ///
239 /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
240 pub fn new() -> (Self, Signal) {
241 let (tx, rx) = oneshot::channel();
242 (Self { tx: Some(tx) }, rx.shared())
243 }
244
245 /// Resolve the `Signal` for all waiters (if not already resolved).
246 pub fn signal(&mut self, value: i32) {
247 if let Some(stop_tx) = self.tx.take() {
248 let _ = stop_tx.send(value);
249 }
250 }
251}
252
/// Test helper: yields to the runtime five times, then returns `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    // Yield repeatedly so concurrently spawned tasks get a chance to interleave.
    let mut yields_left = 5;
    while yields_left > 0 {
        reschedule().await;
        yields_left -= 1;
    }
    i
}
260
/// Test helper: spawns `tasks` instances of `task` (labelled `"test"`) on `context`,
/// drives them to completion with `runner`, and returns their outputs in the order the
/// tasks completed.
///
/// `tasks == 0` is allowed and yields an empty vector.
///
/// # Panics
///
/// Panics if any spawned task resolves to an error.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn every task up front. A single `0..tasks` range avoids the `tasks - 1`
        // underflow the previous split loop hit when `tasks == 0`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.spawn("test", task(i)));
        }

        // `FuturesUnordered` yields handles as they complete, so this records completion order.
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}
279}