commonware_runtime/
utils.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Error, Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9    channel::oneshot,
10    future::Shared,
11    stream::{AbortHandle, Abortable},
12    FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
16use std::{
17    any::Any,
18    future::Future,
19    panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
20    pin::Pin,
21    sync::{Arc, Once},
22    task::{Context, Poll},
23};
24use tracing::error;
25
/// Yield control back to the runtime.
///
/// The returned future is pending on its first poll and immediately wakes its
/// own waker, so the executor can run other ready tasks before polling this one
/// again. Every poll after the first completes with `()`.
pub async fn reschedule() {
    /// Future that is pending exactly once, then ready.
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.done {
                return Poll::Ready(());
            }

            // Mark the yield as taken, then ask to be polled again right away.
            self.done = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { done: false }.await
}
48
/// Turns a panic payload into a human-readable message.
///
/// Panics raised with a literal (`&str`) or a formatted message (`String`)
/// yield that text. Any other payload falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| msg.to_string())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
58
59/// Handle to a spawned task.
60pub struct Handle<T>
61where
62    T: Send + 'static,
63{
64    aborter: Option<AbortHandle>,
65    receiver: oneshot::Receiver<Result<T, Error>>,
66
67    running: Gauge,
68    once: Arc<Once>,
69}
70
71impl<T> Handle<T>
72where
73    T: Send + 'static,
74{
75    pub(crate) fn init<F>(
76        f: F,
77        running: Gauge,
78        catch_panic: bool,
79    ) -> (impl Future<Output = ()>, Self)
80    where
81        F: Future<Output = T> + Send + 'static,
82    {
83        // Increment running counter
84        running.inc();
85
86        // Initialize channels to handle result/abort
87        let once = Arc::new(Once::new());
88        let (sender, receiver) = oneshot::channel();
89        let (aborter, abort_registration) = AbortHandle::new_pair();
90
91        // Wrap the future to handle panics
92        let wrapped = {
93            let once = once.clone();
94            let running = running.clone();
95            async move {
96                // Run future
97                let result = AssertUnwindSafe(f).catch_unwind().await;
98
99                // Decrement running counter
100                once.call_once(|| {
101                    running.dec();
102                });
103
104                // Handle result
105                let result = match result {
106                    Ok(result) => Ok(result),
107                    Err(err) => {
108                        if !catch_panic {
109                            resume_unwind(err);
110                        }
111                        let err = extract_panic_message(&*err);
112                        error!(?err, "task panicked");
113                        Err(Error::Exited)
114                    }
115                };
116                let _ = sender.send(result);
117            }
118        };
119
120        // Make the future abortable
121        let abortable = Abortable::new(wrapped, abort_registration);
122        (
123            abortable.map(|_| ()),
124            Self {
125                aborter: Some(aborter),
126                receiver,
127
128                running,
129                once,
130            },
131        )
132    }
133
134    pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
135    where
136        F: FnOnce() -> T + Send + 'static,
137    {
138        // Increment the running tasks gauge
139        running.inc();
140
141        // Initialize channel to handle result
142        let once = Arc::new(Once::new());
143        let (sender, receiver) = oneshot::channel();
144
145        // Wrap the closure with panic handling
146        let f = {
147            let once = once.clone();
148            let running = running.clone();
149            move || {
150                // Run blocking task
151                let result = catch_unwind(AssertUnwindSafe(f));
152
153                // Decrement running counter
154                once.call_once(|| {
155                    running.dec();
156                });
157
158                // Handle result
159                let result = match result {
160                    Ok(value) => Ok(value),
161                    Err(err) => {
162                        if !catch_panic {
163                            resume_unwind(err);
164                        }
165                        let err = extract_panic_message(&*err);
166                        error!(?err, "blocking task panicked");
167                        Err(Error::Exited)
168                    }
169                };
170                let _ = sender.send(result);
171            }
172        };
173
174        // Return the task and handle
175        (
176            f,
177            Self {
178                aborter: None,
179                receiver,
180
181                running,
182                once,
183            },
184        )
185    }
186
187    /// Abort the task (if not blocking).
188    pub fn abort(&self) {
189        // Get aborter and abort
190        let Some(aborter) = &self.aborter else {
191            return;
192        };
193        aborter.abort();
194
195        // Decrement running counter
196        self.once.call_once(|| {
197            self.running.dec();
198        });
199    }
200}
201
202impl<T> Future for Handle<T>
203where
204    T: Send + 'static,
205{
206    type Output = Result<T, Error>;
207
208    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
209        match Pin::new(&mut self.receiver).poll(cx) {
210            Poll::Ready(Ok(Ok(value))) => {
211                self.once.call_once(|| {
212                    self.running.dec();
213                });
214                Poll::Ready(Ok(value))
215            }
216            Poll::Ready(Ok(Err(err))) => {
217                self.once.call_once(|| {
218                    self.running.dec();
219                });
220                Poll::Ready(Err(err))
221            }
222            Poll::Ready(Err(_)) => {
223                self.once.call_once(|| {
224                    self.running.dec();
225                });
226                Poll::Ready(Err(Error::Closed))
227            }
228            Poll::Pending => Poll::Pending,
229        }
230    }
231}
232
/// A one-time broadcast that can be awaited by many tasks, commonly used to
/// coordinate shutdown.
///
/// A `Signal` is a shared future over the receiving half of a one-shot channel
/// carrying an `i32`. It resolves once, and every clone observes the same
/// outcome. Because each clone adds tracking overhead, prefer awaiting a mutable
/// reference (`&mut signal`) within a task over cloning it repeatedly (e.g. in
/// every iteration of a loop).
pub type Signal = Shared<oneshot::Receiver<i32>>;
240
241/// Coordinates a one-time signal across many tasks.
242///
243/// # Example
244///
245/// ## Basic Usage
246///
247/// ```rust
248/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic};
249///
250/// let executor = deterministic::Runner::default();
251/// executor.start(|context| async move {
252///     // Setup signaler and get future
253///     let (mut signaler, signal) = Signaler::new();
254///
255///     // Signal shutdown
256///     signaler.signal(2);
257///
258///     // Wait for shutdown in task
259///     let sig = signal.await.unwrap();
260///     println!("Received signal: {}", sig);
261/// });
262/// ```
263///
264/// ## Advanced Usage
265///
266/// While `Futures::Shared` is efficient, there is still meaningful overhead
267/// to cloning it (i.e. in each iteration of a loop). To avoid
268/// a performance regression from introducing `Signaler`, it is recommended
269/// to wait on a reference to `Signal` (i.e. `&mut signal`).
270///
271/// ```rust
272/// use commonware_macros::select;
273/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic, Metrics};
274/// use futures::channel::oneshot;
275/// use std::time::Duration;
276///
277/// let executor = deterministic::Runner::default();
278/// executor.start(|context| async move {
279///     // Setup signaler and get future
280///     let (mut signaler, mut signal) = Signaler::new();
281///
282///     // Loop on the signal until resolved
283///     let (tx, rx) = oneshot::channel();
284///     context.with_label("waiter").spawn(|context| async move {
285///         loop {
286///             // Wait for signal or sleep
287///             select! {
288///                  sig = &mut signal => {
289///                      println!("Received signal: {}", sig.unwrap());
290///                      break;
291///                  },
292///                  _ = context.sleep(Duration::from_secs(1)) => {},
293///             };
294///         }
295///         let _ = tx.send(());
296///     });
297///
298///     // Send signal
299///     signaler.signal(9);
300///
301///     // Wait for task
302///     rx.await.expect("shutdown signaled");
303/// });
304/// ```
305pub struct Signaler {
306    tx: Option<oneshot::Sender<i32>>,
307}
308
309impl Signaler {
310    /// Create a new `Signaler`.
311    ///
312    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
313    pub fn new() -> (Self, Signal) {
314        let (tx, rx) = oneshot::channel();
315        (Self { tx: Some(tx) }, rx.shared())
316    }
317
318    /// Resolve the `Signal` for all waiters (if not already resolved).
319    pub fn signal(&mut self, value: i32) {
320        if let Some(stop_tx) = self.tx.take() {
321            let _ = stop_tx.send(value);
322        }
323    }
324}
325
326/// Creates a [rayon]-compatible thread pool with [Spawner::spawn_blocking].
327///
328/// # Arguments
329/// - `context`: The runtime context implementing the [Spawner] trait.
330/// - `concurrency`: The number of tasks to execute concurrently in the pool.
331///
332/// # Returns
333/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
334pub fn create_pool<S: Spawner + Metrics>(
335    context: S,
336    concurrency: usize,
337) -> Result<ThreadPool, ThreadPoolBuildError> {
338    ThreadPoolBuilder::new()
339        .num_threads(concurrency)
340        .spawn_handler(move |thread| {
341            context
342                .with_label("rayon-thread")
343                .spawn_blocking(move || thread.run());
344            Ok(())
345        })
346        .build()
347}
348
349/// Async reader–writer lock.
350///
351/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
352/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
353///
354/// Usage:
355/// ```rust
356/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic, RwLock};
357///
358/// let executor = deterministic::Runner::default();
359/// executor.start(|context| async move {
360///     // Create a new RwLock
361///     let lock = RwLock::new(2);
362///
363///     // many concurrent readers
364///     let r1 = lock.read().await;
365///     let r2 = lock.read().await;
366///     assert_eq!(*r1 + *r2, 4);
367///
368///     // exclusive writer
369///     drop((r1, r2));
370///     let mut w = lock.write().await;
371///     *w += 1;
372/// });
373/// ```
374pub struct RwLock<T>(async_lock::RwLock<T>);
375
376/// Shared guard returned by [`RwLock::read`].
377pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
378
379/// Exclusive guard returned by [`RwLock::write`].
380pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
381
382impl<T> RwLock<T> {
383    /// Create a new lock.
384    #[inline]
385    pub const fn new(value: T) -> Self {
386        Self(async_lock::RwLock::new(value))
387    }
388
389    /// Acquire a shared read guard.
390    #[inline]
391    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
392        self.0.read().await
393    }
394
395    /// Acquire an exclusive write guard.
396    #[inline]
397    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
398        self.0.write().await
399    }
400
401    /// Try to get a read guard without waiting.
402    #[inline]
403    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
404        self.0.try_read()
405    }
406
407    /// Try to get a write guard without waiting.
408    #[inline]
409    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
410        self.0.try_write()
411    }
412
413    /// Get mutable access without locking (requires `&mut self`).
414    #[inline]
415    pub fn get_mut(&mut self) -> &mut T {
416        self.0.get_mut()
417    }
418
419    /// Consume the lock, returning the inner value.
420    #[inline]
421    pub fn into_inner(self) -> T {
422        self.0.into_inner()
423    }
424}
425
/// Test helper: yields to the runtime five times, then returns `i`.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining = 5;
    while remaining > 0 {
        reschedule().await;
        remaining -= 1;
    }
    i
}
433
/// Test helper: spawns `tasks` instances of [`task`] (ids `0..tasks`) on
/// `runner` and collects their outputs in completion order.
///
/// Returns the deterministic auditor's state string along with the outputs.
/// `tasks == 0` is valid and produces an empty output vector.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Randomly schedule tasks. `0..tasks` (rather than `0..=tasks - 1`)
        // avoids underflow when `tasks == 0`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    #[test_traced]
    fn test_create_pool() {
        let runner = tokio::Runner::default();
        runner.start(|context| async move {
            // Four worker threads, each backed by a blocking runtime task
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            let numbers: Vec<i32> = (0..10000).collect();
            let expected = 10000 * 9999 / 2;

            // Sum inside the pool so the parallel iterator runs on its workers
            let total = pool.install(|| numbers.par_iter().sum::<i32>());
            assert_eq!(total, expected);
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let lock = RwLock::new(100);

            // Several readers may hold the lock simultaneously
            {
                let first = lock.read().await;
                let second = lock.read().await;
                assert_eq!(*first + *second, 200);
            }

            // Once every reader is released, a writer gets exclusive access
            let mut writer = lock.write().await;
            *writer += 1;
            assert_eq!(*writer, 101);
        });
    }
}