commonware_runtime/utils/mod.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
9use std::{
10    any::Any,
11    future::Future,
12    pin::Pin,
13    sync::Arc,
14    task::{Context, Poll},
15};
16
17pub mod buffer;
18pub mod signal;
19
20mod handle;
21pub use handle::Handle;
22
/// Yield control back to the runtime.
///
/// The inner future wakes its own task and reports `Pending` on the first poll, then
/// reports `Ready` on the next poll. This hands control back to the executor once
/// while guaranteeing the task is scheduled to run again.
pub async fn reschedule() {
    /// Future that is pending exactly once before completing.
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let this = self.get_mut();
            if !this.done {
                // Request another poll before giving control back to the executor.
                this.done = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { done: false }.await;
}
45
/// Turn a panic payload into a human-readable message.
///
/// `&str` and `String` payloads (the types produced by `panic!` with a literal or a
/// format string) are returned as-is; any other payload falls back to its `Debug` output.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_string())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{err:?}"))
}
55
56/// A clone-able wrapper around a [rayon]-compatible thread pool.
57pub type ThreadPool = Arc<RThreadPool>;
58
59/// Creates a clone-able [rayon]-compatible thread pool with [Spawner::spawn_blocking].
60///
61/// # Arguments
62/// - `context`: The runtime context implementing the [Spawner] trait.
63/// - `concurrency`: The number of tasks to execute concurrently in the pool.
64///
65/// # Returns
66/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
67pub fn create_pool<S: Spawner + Metrics>(
68    context: S,
69    concurrency: usize,
70) -> Result<ThreadPool, ThreadPoolBuildError> {
71    let pool = ThreadPoolBuilder::new()
72        .num_threads(concurrency)
73        .spawn_handler(move |thread| {
74            // Tasks spawned in a thread pool are expected to run longer than any single
75            // task and thus should be provisioned as a dedicated thread.
76            context
77                .with_label("rayon-thread")
78                .spawn_blocking(true, move |_| thread.run());
79            Ok(())
80        })
81        .build()?;
82
83    Ok(Arc::new(pool))
84}
85
86/// Async reader–writer lock.
87///
88/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
89/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
90///
91/// Usage:
92/// ```rust
93/// use commonware_runtime::{Spawner, Runner, deterministic, RwLock};
94///
95/// let executor = deterministic::Runner::default();
96/// executor.start(|context| async move {
97///     // Create a new RwLock
98///     let lock = RwLock::new(2);
99///
100///     // many concurrent readers
101///     let r1 = lock.read().await;
102///     let r2 = lock.read().await;
103///     assert_eq!(*r1 + *r2, 4);
104///
105///     // exclusive writer
106///     drop((r1, r2));
107///     let mut w = lock.write().await;
108///     *w += 1;
109/// });
110/// ```
111pub struct RwLock<T>(async_lock::RwLock<T>);
112
113/// Shared guard returned by [RwLock::read].
114pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
115
116/// Exclusive guard returned by [RwLock::write].
117pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
118
119impl<T> RwLock<T> {
120    /// Create a new lock.
121    #[inline]
122    pub const fn new(value: T) -> Self {
123        Self(async_lock::RwLock::new(value))
124    }
125
126    /// Acquire a shared read guard.
127    #[inline]
128    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
129        self.0.read().await
130    }
131
132    /// Acquire an exclusive write guard.
133    #[inline]
134    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
135        self.0.write().await
136    }
137
138    /// Try to get a read guard without waiting.
139    #[inline]
140    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
141        self.0.try_read()
142    }
143
144    /// Try to get a write guard without waiting.
145    #[inline]
146    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
147        self.0.try_write()
148    }
149
150    /// Get mutable access without locking (requires `&mut self`).
151    #[inline]
152    pub fn get_mut(&mut self) -> &mut T {
153        self.0.get_mut()
154    }
155
156    /// Consume the lock, returning the inner value.
157    #[inline]
158    pub fn into_inner(self) -> T {
159        self.0.into_inner()
160    }
161}
162
/// Test helper: yield to the runtime five times, then return `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining = 5;
    while remaining > 0 {
        reschedule().await;
        remaining -= 1;
    }
    i
}
170
/// Test helper: spawn `tasks` instances of `task` on `runner` and record completion order.
///
/// Returns the runtime auditor's state once every task has finished, together with the
/// task indices in the order their results were collected.
///
/// # Panics
/// Panics if any spawned task fails or if fewer than `tasks` results are collected.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Spawn every task up front. `0..tasks` (instead of `0..=tasks - 1`) avoids a
        // `usize` underflow when `tasks == 0`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.expect("spawned task failed"));
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    #[test_traced]
    fn test_create_pool() {
        let runner = tokio::Runner::default();
        runner.start(|context| async move {
            // Pool with four worker threads spawned through the runtime.
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Sum 0..10000 in parallel on the pool and compare with the closed form.
            let numbers: Vec<i32> = (0..10_000).collect();
            let expected = 10_000 * 9_999 / 2;
            pool.install(|| {
                let total: i32 = numbers.par_iter().sum();
                assert_eq!(total, expected);
            });
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let lock = RwLock::new(100);

            // Two readers may hold the lock simultaneously.
            let (first, second) = (lock.read().await, lock.read().await);
            assert_eq!(*first + *second, 200);

            // A writer can only proceed once every reader has been released.
            drop(first);
            drop(second);
            let mut guard = lock.write().await;
            *guard += 1;

            assert_eq!(*guard, 101);
        });
    }
}