commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{channel::oneshot, future::Shared, FutureExt};
9use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
10use std::{
11    any::Any,
12    future::Future,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
17pub mod buffer;
18
19mod handle;
20pub use handle::Handle;
21
/// Yield control back to the runtime.
///
/// The future produced here is pending exactly once. On its first poll it
/// immediately re-wakes itself and returns `Pending`, handing control back to
/// the executor. On the following poll it completes.
pub async fn reschedule() {
    struct YieldOnce {
        polled_before: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled_before {
                // First poll: ask to be polled again, then step aside.
                self.polled_before = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce {
        polled_before: false,
    }
    .await
}
44
/// Convert a panic payload (the `dyn Any + Send` value carried by an unwinding
/// panic) into a human-readable string.
///
/// `&str` and `String` payloads, which `panic!` produces for literal and
/// formatted messages, are returned verbatim. Any other payload falls back to
/// its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
54
55/// A one-time broadcast that can be awaited by many tasks. It is often used for
56/// coordinating shutdown across many tasks.
57///
58/// To minimize the overhead of tracking outstanding signals (which only return once),
59/// it is recommended to wait on a reference to it (i.e. `&mut signal`) instead of
60/// cloning it multiple times in a given task (i.e. in each iteration of a loop).
61pub type Signal = Shared<oneshot::Receiver<i32>>;
62
63/// Coordinates a one-time signal across many tasks.
64///
65/// # Example
66///
67/// ## Basic Usage
68///
69/// ```rust
70/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic};
71///
72/// let executor = deterministic::Runner::default();
73/// executor.start(|context| async move {
74///     // Setup signaler and get future
75///     let (mut signaler, signal) = Signaler::new();
76///
77///     // Signal shutdown
78///     signaler.signal(2);
79///
80///     // Wait for shutdown in task
81///     let sig = signal.await.unwrap();
82///     println!("Received signal: {}", sig);
83/// });
84/// ```
85///
86/// ## Advanced Usage
87///
88/// While `Futures::Shared` is efficient, there is still meaningful overhead
89/// to cloning it (i.e. in each iteration of a loop). To avoid
90/// a performance regression from introducing `Signaler`, it is recommended
91/// to wait on a reference to `Signal` (i.e. `&mut signal`).
92///
93/// ```rust
94/// use commonware_macros::select;
95/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic, Metrics};
96/// use futures::channel::oneshot;
97/// use std::time::Duration;
98///
99/// let executor = deterministic::Runner::default();
100/// executor.start(|context| async move {
101///     // Setup signaler and get future
102///     let (mut signaler, mut signal) = Signaler::new();
103///
104///     // Loop on the signal until resolved
105///     let (tx, rx) = oneshot::channel();
106///     context.with_label("waiter").spawn(|context| async move {
107///         loop {
108///             // Wait for signal or sleep
109///             select! {
110///                  sig = &mut signal => {
111///                      println!("Received signal: {}", sig.unwrap());
112///                      break;
113///                  },
114///                  _ = context.sleep(Duration::from_secs(1)) => {},
115///             };
116///         }
117///         let _ = tx.send(());
118///     });
119///
120///     // Send signal
121///     signaler.signal(9);
122///
123///     // Wait for task
124///     rx.await.expect("shutdown signaled");
125/// });
126/// ```
127pub struct Signaler {
128    tx: Option<oneshot::Sender<i32>>,
129}
130
131impl Signaler {
132    /// Create a new `Signaler`.
133    ///
134    /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
135    pub fn new() -> (Self, Signal) {
136        let (tx, rx) = oneshot::channel();
137        (Self { tx: Some(tx) }, rx.shared())
138    }
139
140    /// Resolve the `Signal` for all waiters (if not already resolved).
141    pub fn signal(&mut self, value: i32) {
142        if let Some(stop_tx) = self.tx.take() {
143            let _ = stop_tx.send(value);
144        }
145    }
146}
147
148/// Creates a [rayon]-compatible thread pool with [Spawner::spawn_blocking].
149///
150/// # Arguments
151/// - `context`: The runtime context implementing the [Spawner] trait.
152/// - `concurrency`: The number of tasks to execute concurrently in the pool.
153///
154/// # Returns
155/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
156pub fn create_pool<S: Spawner + Metrics>(
157    context: S,
158    concurrency: usize,
159) -> Result<ThreadPool, ThreadPoolBuildError> {
160    ThreadPoolBuilder::new()
161        .num_threads(concurrency)
162        .spawn_handler(move |thread| {
163            // Tasks spawned in a thread pool are expected to run longer than any single
164            // task and thus should be provisioned as a dedicated thread.
165            context
166                .with_label("rayon-thread")
167                .spawn_blocking(true, move |_| thread.run());
168            Ok(())
169        })
170        .build()
171}
172
173/// Async reader–writer lock.
174///
175/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
176/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
177///
178/// Usage:
179/// ```rust
180/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic, RwLock};
181///
182/// let executor = deterministic::Runner::default();
183/// executor.start(|context| async move {
184///     // Create a new RwLock
185///     let lock = RwLock::new(2);
186///
187///     // many concurrent readers
188///     let r1 = lock.read().await;
189///     let r2 = lock.read().await;
190///     assert_eq!(*r1 + *r2, 4);
191///
192///     // exclusive writer
193///     drop((r1, r2));
194///     let mut w = lock.write().await;
195///     *w += 1;
196/// });
197/// ```
198pub struct RwLock<T>(async_lock::RwLock<T>);
199
200/// Shared guard returned by [RwLock::read].
201pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
202
203/// Exclusive guard returned by [RwLock::write].
204pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
205
206impl<T> RwLock<T> {
207    /// Create a new lock.
208    #[inline]
209    pub const fn new(value: T) -> Self {
210        Self(async_lock::RwLock::new(value))
211    }
212
213    /// Acquire a shared read guard.
214    #[inline]
215    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
216        self.0.read().await
217    }
218
219    /// Acquire an exclusive write guard.
220    #[inline]
221    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
222        self.0.write().await
223    }
224
225    /// Try to get a read guard without waiting.
226    #[inline]
227    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
228        self.0.try_read()
229    }
230
231    /// Try to get a write guard without waiting.
232    #[inline]
233    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
234        self.0.try_write()
235    }
236
237    /// Get mutable access without locking (requires `&mut self`).
238    #[inline]
239    pub fn get_mut(&mut self) -> &mut T {
240        self.0.get_mut()
241    }
242
243    /// Consume the lock, returning the inner value.
244    #[inline]
245    pub fn into_inner(self) -> T {
246        self.0.into_inner()
247    }
248}
249
/// Test helper: yields to the runtime five times, then returns `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining_yields = 5;
    while remaining_yields > 0 {
        reschedule().await;
        remaining_yields -= 1;
    }
    i
}
257
/// Test helper: spawn `tasks` instances of `task` (ids `0..tasks`) on `runner` and
/// record the order in which they complete.
///
/// Returns the runtime auditor's state together with the task ids in completion order.
/// `tasks == 0` is accepted and yields an empty output list.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Spawn every task up front; the runner's scheduler decides completion order.
        // `0..tasks` (rather than `0..=tasks - 1`) avoids underflow when `tasks == 0`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Record ids in the order the tasks finish.
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.expect("spawned task failed"));
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    #[test_traced]
    fn test_create_pool() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            // Create a thread pool with 4 threads
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Create a vector of numbers
            let v: Vec<_> = (0..10000).collect();

            // Use the thread pool to sum the numbers
            pool.install(|| {
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Create a new RwLock
            let lock = RwLock::new(100);

            // many concurrent readers
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            // exclusive writer
            drop((r1, r2)); // all readers must go away
            let mut w = lock.write().await;
            *w += 1;

            // Check the value
            assert_eq!(*w, 101);
        });
    }

    #[test_traced]
    fn test_signaler_first_value_wins() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut signaler, signal) = Signaler::new();
            let waiter = signal.clone();

            // Only the first value is delivered; the second call is a no-op.
            signaler.signal(1);
            signaler.signal(2);

            // Every clone of the signal observes the same value.
            assert_eq!(signal.await.unwrap(), 1);
            assert_eq!(waiter.await.unwrap(), 1);
        });
    }

    #[test_traced]
    fn test_signaler_dropped_without_signal() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (signaler, signal) = Signaler::new();

            // Dropping the sender without signaling cancels all waiters.
            drop(signaler);
            assert!(signal.await.is_err());
        });
    }

    #[test]
    fn test_extract_panic_message() {
        assert_eq!(extract_panic_message(&"boom"), "boom");
        assert_eq!(extract_panic_message(&String::from("bang")), "bang");
    }
}