commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{channel::oneshot, future::Shared, FutureExt};
9use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
10use std::{
11    any::Any,
12    future::Future,
13    pin::Pin,
14    sync::Arc,
15    task::{Context, Poll},
16};
17
18pub mod buffer;
19
20mod handle;
21pub use handle::Handle;
22
/// Yield control back to the runtime.
///
/// The returned future is `Pending` the first time it is polled (after waking its
/// own waker so it is polled again) and `Ready` on the following poll.
pub async fn reschedule() {
    /// Future that returns `Pending` exactly once before completing.
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled {
                // First poll: remember that we yielded and request another poll.
                self.polled = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { polled: false }.await
}
45
/// Turn a panic payload into a human-readable message.
///
/// Payloads that are a `&str` or a `String` are returned as-is; any other payload
/// falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_string())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{err:?}"))
}
55
/// A one-time broadcast that can be awaited by many tasks. It is often used for
/// coordinating shutdown across many tasks.
///
/// Awaiting a `Signal` yields a `Result` wrapping the `i32` value passed to
/// [Signaler::signal].
///
/// To minimize the overhead of tracking outstanding signals (which only return once),
/// it is recommended to wait on a reference to it (e.g. `&mut signal`) instead of
/// cloning it multiple times in a given task (e.g. in each iteration of a loop).
pub type Signal = Shared<oneshot::Receiver<i32>>;
63
/// Coordinates a one-time signal across many tasks.
///
/// # Example
///
/// ## Basic Usage
///
/// ```rust
/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic};
///
/// let executor = deterministic::Runner::default();
/// executor.start(|context| async move {
///     // Setup signaler and get future
///     let (mut signaler, signal) = Signaler::new();
///
///     // Signal shutdown
///     signaler.signal(2);
///
///     // Wait for shutdown in task
///     let sig = signal.await.unwrap();
///     println!("Received signal: {}", sig);
/// });
/// ```
///
/// ## Advanced Usage
///
/// While `futures::future::Shared` is efficient, there is still meaningful overhead
/// to cloning it (e.g. in each iteration of a loop). To avoid
/// a performance regression from introducing `Signaler`, it is recommended
/// to wait on a reference to `Signal` (e.g. `&mut signal`).
///
/// ```rust
/// use commonware_macros::select;
/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic, Metrics};
/// use futures::channel::oneshot;
/// use std::time::Duration;
///
/// let executor = deterministic::Runner::default();
/// executor.start(|context| async move {
///     // Setup signaler and get future
///     let (mut signaler, mut signal) = Signaler::new();
///
///     // Loop on the signal until resolved
///     let (tx, rx) = oneshot::channel();
///     context.with_label("waiter").spawn(|context| async move {
///         loop {
///             // Wait for signal or sleep
///             select! {
///                  sig = &mut signal => {
///                      println!("Received signal: {}", sig.unwrap());
///                      break;
///                  },
///                  _ = context.sleep(Duration::from_secs(1)) => {},
///             };
///         }
///         let _ = tx.send(());
///     });
///
///     // Send signal
///     signaler.signal(9);
///
///     // Wait for task
///     rx.await.expect("shutdown signaled");
/// });
/// ```
pub struct Signaler {
    /// Sending half of the one-shot channel; taken (left as `None`) by the first
    /// call to `signal`.
    tx: Option<oneshot::Sender<i32>>,
}
131
132impl Signaler {
133    /// Create a new `Signaler`.
134    ///
135    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
136    pub fn new() -> (Self, Signal) {
137        let (tx, rx) = oneshot::channel();
138        (Self { tx: Some(tx) }, rx.shared())
139    }
140
141    /// Resolve the `Signal` for all waiters (if not already resolved).
142    pub fn signal(&mut self, value: i32) {
143        if let Some(stop_tx) = self.tx.take() {
144            let _ = stop_tx.send(value);
145        }
146    }
147}
148
/// A clone-able wrapper around a [rayon]-compatible thread pool.
///
/// Cloning only clones the [Arc], so every clone refers to the same underlying pool.
pub type ThreadPool = Arc<RThreadPool>;
151
152/// Creates a clone-able [rayon]-compatible thread pool with [Spawner::spawn_blocking].
153///
154/// # Arguments
155/// - `context`: The runtime context implementing the [Spawner] trait.
156/// - `concurrency`: The number of tasks to execute concurrently in the pool.
157///
158/// # Returns
159/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
160pub fn create_pool<S: Spawner + Metrics>(
161    context: S,
162    concurrency: usize,
163) -> Result<ThreadPool, ThreadPoolBuildError> {
164    let pool = ThreadPoolBuilder::new()
165        .num_threads(concurrency)
166        .spawn_handler(move |thread| {
167            // Tasks spawned in a thread pool are expected to run longer than any single
168            // task and thus should be provisioned as a dedicated thread.
169            context
170                .with_label("rayon-thread")
171                .spawn_blocking(true, move |_| thread.run());
172            Ok(())
173        })
174        .build()?;
175
176    Ok(Arc::new(pool))
177}
178
/// Async reader–writer lock.
///
/// A thin newtype over [async_lock::RwLock]. `RwLock` provides both fair writer
/// acquisition and `try_read` / `try_write` without waiting (without any
/// runtime-specific dependencies).
///
/// # Example
///
/// ```rust
/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic, RwLock};
///
/// let executor = deterministic::Runner::default();
/// executor.start(|context| async move {
///     // Create a new RwLock
///     let lock = RwLock::new(2);
///
///     // many concurrent readers
///     let r1 = lock.read().await;
///     let r2 = lock.read().await;
///     assert_eq!(*r1 + *r2, 4);
///
///     // exclusive writer
///     drop((r1, r2));
///     let mut w = lock.write().await;
///     *w += 1;
/// });
/// ```
pub struct RwLock<T>(async_lock::RwLock<T>);
205
/// Shared guard returned by [RwLock::read] and [RwLock::try_read].
pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;

/// Exclusive guard returned by [RwLock::write] and [RwLock::try_write].
pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
211
212impl<T> RwLock<T> {
213    /// Create a new lock.
214    #[inline]
215    pub const fn new(value: T) -> Self {
216        Self(async_lock::RwLock::new(value))
217    }
218
219    /// Acquire a shared read guard.
220    #[inline]
221    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
222        self.0.read().await
223    }
224
225    /// Acquire an exclusive write guard.
226    #[inline]
227    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
228        self.0.write().await
229    }
230
231    /// Try to get a read guard without waiting.
232    #[inline]
233    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
234        self.0.try_read()
235    }
236
237    /// Try to get a write guard without waiting.
238    #[inline]
239    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
240        self.0.try_write()
241    }
242
243    /// Get mutable access without locking (requires `&mut self`).
244    #[inline]
245    pub fn get_mut(&mut self) -> &mut T {
246        self.0.get_mut()
247    }
248
249    /// Consume the lock, returning the inner value.
250    #[inline]
251    pub fn into_inner(self) -> T {
252        self.0.into_inner()
253    }
254}
255
/// Test helper: yield to the runtime five times, then return `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining = 5;
    while remaining > 0 {
        reschedule().await;
        remaining -= 1;
    }
    i
}
263
/// Spawn `tasks` instances of [task] on the deterministic `runner` and wait for
/// all of them to finish.
///
/// # Returns
/// The auditor state read after every task completed, together with the task
/// indices in the order in which the tasks finished.
///
/// # Panics
/// Panics if any spawned task fails (its result is unwrapped).
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Randomly schedule tasks
        //
        // The exclusive range keeps `tasks == 0` valid: `0..=tasks - 1` would
        // underflow `usize` in that case.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    /// A pool created from a tokio context can run a rayon parallel iterator.
    #[test_traced]
    fn test_create_pool() {
        let runner = tokio::Runner::default();
        runner.start(|context| async move {
            // Build a pool backed by 4 threads
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Input data and the closed-form value of its sum
            let numbers: Vec<i32> = (0..10000).collect();
            let expected = 10000 * 9999 / 2;

            // Sum the numbers inside the pool
            pool.install(|| {
                let total: i32 = numbers.par_iter().sum();
                assert_eq!(total, expected);
            });
        });
    }

    /// Readers can share the lock; a writer gets it once all readers are gone.
    #[test_traced]
    fn test_rwlock() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            // Lock protecting an integer
            let lock = RwLock::new(100);

            // Two read guards held at the same time
            let first = lock.read().await;
            let second = lock.read().await;
            assert_eq!(*first + *second, 200);

            // Release every reader before asking for exclusive access
            drop(first);
            drop(second);
            let mut guard = lock.write().await;
            *guard += 1;

            // The write is visible through the guard
            assert_eq!(*guard, 101);
        });
    }
}