commonware_runtime/utils.rs
//! Utility functions for interacting with any runtime.

#[cfg(test)]
use crate::Runner;
use crate::{Error, Metrics, Spawner};
#[cfg(test)]
use futures::stream::{FuturesUnordered, StreamExt};
use futures::{
    channel::oneshot,
    future::Shared,
    stream::{AbortHandle, Abortable},
    FutureExt,
};
use prometheus_client::metrics::gauge::Gauge;
use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::{
    any::Any,
    future::Future,
    panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
    pin::Pin,
    sync::{Arc, Once},
    task::{Context, Poll},
};
use tracing::error;

/// Yield control back to the runtime.
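///
/// # Example
///
/// A minimal sketch of yielding inside a long-running task (assumes `reschedule`
/// is re-exported at the crate root):
///
/// ```ignore
/// use commonware_runtime::reschedule;
///
/// async fn chunked_work() {
///     for _ in 0..3 {
///         // ... do one slice of work ...
///
///         // Yield so other tasks can make progress before the next slice.
///         reschedule().await;
///     }
/// }
/// ```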
pub async fn reschedule() {
    struct Reschedule {
        yielded: bool,
    }

    impl Future for Reschedule {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                Poll::Ready(())
            } else {
                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    Reschedule { yielded: false }.await
}

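/// Best-effort extraction of a human-readable message from a panic payload.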
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    if let Some(s) = err.downcast_ref::<&str>() {
        s.to_string()
    } else if let Some(s) = err.downcast_ref::<String>() {
        s.clone()
    } else {
        format!("{:?}", err)
    }
}

/// Handle to a spawned task.
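///
/// # Example
///
/// A minimal sketch of spawning a task and awaiting its result through the returned
/// handle (the `"answer"` label is illustrative):
///
/// ```rust
/// use commonware_runtime::{deterministic::Executor, Metrics, Runner, Spawner};
///
/// let (executor, context, _) = Executor::default();
/// executor.start(async move {
///     // Spawn a task and await its handle
///     let handle = context.with_label("answer").spawn(|_| async move { 42 });
///     assert_eq!(handle.await.unwrap(), 42);
/// });
/// ```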
pub struct Handle<T>
where
    T: Send + 'static,
{
    aborter: Option<AbortHandle>,
    receiver: oneshot::Receiver<Result<T, Error>>,

    running: Gauge,
    once: Arc<Once>,
}

impl<T> Handle<T>
where
    T: Send + 'static,
{
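    /// Wrap `f` so its result is delivered through the returned `Handle`: panics are
    /// caught (and surfaced as `Error::Exited`) when `catch_panic` is true, and the
    /// `running` gauge is decremented exactly once when the task finishes or is aborted.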
    pub(crate) fn init<F>(
        f: F,
        running: Gauge,
        catch_panic: bool,
    ) -> (impl Future<Output = ()>, Self)
    where
        F: Future<Output = T> + Send + 'static,
    {
        // Increment running counter
        running.inc();

        // Initialize channels to handle result/abort
        let once = Arc::new(Once::new());
        let (sender, receiver) = oneshot::channel();
        let (aborter, abort_registration) = AbortHandle::new_pair();

        // Wrap the future to handle panics
        let wrapped = {
            let once = once.clone();
            let running = running.clone();
            async move {
                // Run future
                let result = AssertUnwindSafe(f).catch_unwind().await;

                // Decrement running counter
                once.call_once(|| {
                    running.dec();
                });

                // Handle result
                let result = match result {
                    Ok(result) => Ok(result),
                    Err(err) => {
                        if !catch_panic {
                            resume_unwind(err);
                        }
                        let err = extract_panic_message(&*err);
                        error!(?err, "task panicked");
                        Err(Error::Exited)
                    }
                };
                let _ = sender.send(result);
            }
        };

        // Make the future abortable
        let abortable = Abortable::new(wrapped, abort_registration);
        (
            abortable.map(|_| ()),
            Self {
                aborter: Some(aborter),
                receiver,

                running,
                once,
            },
        )
    }

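    /// Like `init`, but for a blocking closure: returns the wrapped closure (to be executed
    /// by the runtime) and a `Handle` that cannot be aborted.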
    pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
    where
        F: FnOnce() -> T + Send + 'static,
    {
        // Increment the running tasks gauge
        running.inc();

        // Initialize channel to handle result
        let once = Arc::new(Once::new());
        let (sender, receiver) = oneshot::channel();

        // Wrap the closure with panic handling
        let f = {
            let once = once.clone();
            let running = running.clone();
            move || {
                // Run blocking task
                let result = catch_unwind(AssertUnwindSafe(f));

                // Decrement running counter
                once.call_once(|| {
                    running.dec();
                });

                // Handle result
                let result = match result {
                    Ok(value) => Ok(value),
                    Err(err) => {
                        if !catch_panic {
                            resume_unwind(err);
                        }
                        let err = extract_panic_message(&*err);
                        error!(?err, "blocking task panicked");
                        Err(Error::Exited)
                    }
                };
                let _ = sender.send(result);
            }
        };

        // Return the task and handle
        (
            f,
            Self {
                aborter: None,
                receiver,

                running,
                once,
            },
        )
    }

    /// Abort the task (if not blocking).
    pub fn abort(&self) {
        // Get aborter and abort
        let Some(aborter) = &self.aborter else {
            return;
        };
        aborter.abort();

        // Decrement running counter
        self.once.call_once(|| {
            self.running.dec();
        });
    }
}

impl<T> Future for Handle<T>
where
    T: Send + 'static,
{
    type Output = Result<T, Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match Pin::new(&mut self.receiver).poll(cx) {
            Poll::Ready(Ok(Ok(value))) => {
                self.once.call_once(|| {
                    self.running.dec();
                });
                Poll::Ready(Ok(value))
            }
            Poll::Ready(Ok(Err(err))) => {
                self.once.call_once(|| {
                    self.running.dec();
                });
                Poll::Ready(Err(err))
            }
            Poll::Ready(Err(_)) => {
                self.once.call_once(|| {
                    self.running.dec();
                });
                Poll::Ready(Err(Error::Closed))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// A one-time broadcast that can be awaited by many tasks. It is often used to
/// coordinate shutdown across tasks.
///
/// To minimize the overhead of tracking outstanding signals (which only resolve once),
/// it is recommended to wait on a reference to a `Signal` (i.e. `&mut signal`) instead
/// of cloning it repeatedly within a given task (e.g. on each iteration of a loop).
pub type Signal = Shared<oneshot::Receiver<i32>>;

/// Coordinates a one-time signal across many tasks.
///
/// # Example
///
/// ## Basic Usage
///
/// ```rust
/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic::Executor};
///
/// let (executor, _, _) = Executor::default();
/// executor.start(async move {
///     // Setup signaler and get future
///     let (mut signaler, signal) = Signaler::new();
///
///     // Signal shutdown
///     signaler.signal(2);
///
///     // Wait for shutdown in task
///     let sig = signal.await.unwrap();
///     println!("Received signal: {}", sig);
/// });
/// ```
///
/// ## Advanced Usage
///
/// While `futures::future::Shared` is efficient, cloning it still carries meaningful
/// overhead (e.g. on each iteration of a loop). To avoid a performance regression
/// when introducing `Signaler`, it is recommended to wait on a reference to the
/// `Signal` (i.e. `&mut signal`) instead.
///
/// ```rust
/// use commonware_macros::select;
/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic::Executor, Metrics};
/// use futures::channel::oneshot;
/// use std::time::Duration;
///
/// let (executor, context, _) = Executor::default();
/// executor.start(async move {
///     // Setup signaler and get future
///     let (mut signaler, mut signal) = Signaler::new();
///
///     // Loop on the signal until resolved
///     let (tx, rx) = oneshot::channel();
///     context.with_label("waiter").spawn(|context| async move {
///         loop {
///             // Wait for signal or sleep
///             select! {
///                 sig = &mut signal => {
///                     println!("Received signal: {}", sig.unwrap());
///                     break;
///                 },
///                 _ = context.sleep(Duration::from_secs(1)) => {},
///             };
///         }
///         let _ = tx.send(());
///     });
///
///     // Send signal
///     signaler.signal(9);
///
///     // Wait for task
///     rx.await.expect("shutdown signaled");
/// });
/// ```
pub struct Signaler {
    tx: Option<oneshot::Sender<i32>>,
}

impl Signaler {
    /// Create a new `Signaler`.
    ///
    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
    pub fn new() -> (Self, Signal) {
        let (tx, rx) = oneshot::channel();
        (Self { tx: Some(tx) }, rx.shared())
    }

    /// Resolve the `Signal` for all waiters (if not already resolved).
    pub fn signal(&mut self, value: i32) {
        if let Some(stop_tx) = self.tx.take() {
            let _ = stop_tx.send(value);
        }
    }
}

/// Creates a [rayon]-compatible thread pool whose threads are spawned with
/// [Spawner::spawn_blocking].
///
/// # Arguments
/// - `context`: The runtime context implementing the [Spawner] trait.
/// - `concurrency`: The number of threads in the pool (i.e. how many tasks can execute concurrently).
///
/// # Returns
/// A `Result` containing the configured [rayon::ThreadPool], or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
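///
/// # Example
///
/// A minimal sketch mirroring the unit test below (the label, pool size, and the
/// assumption that `create_pool` is re-exported at the crate root are illustrative):
///
/// ```ignore
/// use commonware_runtime::{create_pool, tokio::Executor, Metrics, Runner};
/// use rayon::prelude::*;
///
/// let (executor, context) = Executor::default();
/// executor.start(async move {
///     // Build a pool whose threads are spawned on the runtime
///     let pool = create_pool(context.with_label("pool"), 2).unwrap();
///
///     // Run a parallel computation on the pool
///     let total: i32 = pool.install(|| (0..100).into_par_iter().sum());
///     assert_eq!(total, 4950);
/// });
/// ```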
pub fn create_pool<S: Spawner + Metrics>(
    context: S,
    concurrency: usize,
) -> Result<ThreadPool, ThreadPoolBuildError> {
    ThreadPoolBuilder::new()
        .num_threads(concurrency)
        .spawn_handler(move |thread| {
            context
                .with_label("rayon-thread")
                .spawn_blocking(move || thread.run());
            Ok(())
        })
        .build()
}

#[cfg(test)]
async fn task(i: usize) -> usize {
    for _ in 0..5 {
        reschedule().await;
    }
    i
}

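/// Test helper: spawn `tasks` copies of `task` on the provided runner and collect their
/// outputs in completion order.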
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn tasks (the runtime may schedule them in any order)
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::new();
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{tokio::Executor, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    #[test_traced]
    fn test_create_pool() {
        let (executor, context) = Executor::default();
        executor.start(async move {
            // Create a thread pool with 4 threads
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Create a vector of numbers
            let v: Vec<_> = (0..10000).collect();

            // Use the thread pool to sum the numbers
            pool.install(|| {
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }
}