commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::task::ArcWake;
9use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
10use std::{
11    any::Any,
12    future::Future,
13    pin::Pin,
14    sync::{Arc, Condvar, Mutex},
15    task::{Context, Poll},
16};
17
18pub mod buffer;
19pub mod signal;
20
21mod handle;
22pub use handle::Handle;
23pub(crate) use handle::{Aborter, MetricHandle, Panicked, Panicker};
24
25mod cell;
26pub use cell::Cell as ContextCell;
27
28pub(crate) mod supervision;
29
/// How a spawned task is scheduled by the runtime.
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// The task is given its own dedicated thread.
    Dedicated,
    /// The task shares the runtime's executor. The flag is `true` for short blocking
    /// work that should be placed on the runtime's blocking-friendly pool.
    Shared(bool),
}

impl Default for Execution {
    /// Defaults to the shared executor with the blocking flag unset.
    fn default() -> Self {
        Execution::Shared(false)
    }
}
45
/// Yield control back to the runtime.
///
/// The returned future is pending exactly once: on its first poll it wakes its own
/// waker (so the task is immediately eligible to run again) and returns
/// [Poll::Pending]; every later poll returns [Poll::Ready].
pub async fn reschedule() {
    /// Future that is pending on its first poll and ready afterwards.
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled {
                // First poll: remember that we yielded, then request another poll.
                self.polled = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { polled: false }.await
}
68
/// Render a panic payload as a human-readable message.
///
/// Payloads holding a `&str` or a `String` are returned verbatim; any other payload
/// falls back to its `Debug` rendering.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|literal| literal.to_string())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{err:?}"))
}
78
/// A clone-able wrapper around a [rayon]-compatible thread pool.
///
/// The pool is held behind an [Arc], so cloning the handle shares the same
/// underlying pool rather than building a new one.
pub type ThreadPool = Arc<RThreadPool>;
81
82/// Creates a clone-able [rayon]-compatible thread pool with [Spawner::spawn].
83///
84/// # Arguments
85/// - `context`: The runtime context implementing the [Spawner] trait.
86/// - `concurrency`: The number of tasks to execute concurrently in the pool.
87///
88/// # Returns
89/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
90pub fn create_pool<S: Spawner + Metrics>(
91    context: S,
92    concurrency: usize,
93) -> Result<ThreadPool, ThreadPoolBuildError> {
94    let pool = ThreadPoolBuilder::new()
95        .num_threads(concurrency)
96        .spawn_handler(move |thread| {
97            // Tasks spawned in a thread pool are expected to run longer than any single
98            // task and thus should be provisioned as a dedicated thread.
99            context
100                .with_label("rayon-thread")
101                .dedicated()
102                .spawn(move |_| async move { thread.run() });
103            Ok(())
104        })
105        .build()?;
106
107    Ok(Arc::new(pool))
108}
109
110/// Async reader–writer lock.
111///
112/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
113/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
114///
115/// Usage:
116/// ```rust
117/// use commonware_runtime::{Spawner, Runner, deterministic, RwLock};
118///
119/// let executor = deterministic::Runner::default();
120/// executor.start(|context| async move {
121///     // Create a new RwLock
122///     let lock = RwLock::new(2);
123///
124///     // many concurrent readers
125///     let r1 = lock.read().await;
126///     let r2 = lock.read().await;
127///     assert_eq!(*r1 + *r2, 4);
128///
129///     // exclusive writer
130///     drop((r1, r2));
131///     let mut w = lock.write().await;
132///     *w += 1;
133/// });
134/// ```
135pub struct RwLock<T>(async_lock::RwLock<T>);
136
137/// Shared guard returned by [RwLock::read].
138pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
139
140/// Exclusive guard returned by [RwLock::write].
141pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
142
143impl<T> RwLock<T> {
144    /// Create a new lock.
145    #[inline]
146    pub const fn new(value: T) -> Self {
147        Self(async_lock::RwLock::new(value))
148    }
149
150    /// Acquire a shared read guard.
151    #[inline]
152    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
153        self.0.read().await
154    }
155
156    /// Acquire an exclusive write guard.
157    #[inline]
158    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
159        self.0.write().await
160    }
161
162    /// Try to get a read guard without waiting.
163    #[inline]
164    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
165        self.0.try_read()
166    }
167
168    /// Try to get a write guard without waiting.
169    #[inline]
170    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
171        self.0.try_write()
172    }
173
174    /// Get mutable access without locking (requires `&mut self`).
175    #[inline]
176    pub fn get_mut(&mut self) -> &mut T {
177        self.0.get_mut()
178    }
179
180    /// Consume the lock, returning the inner value.
181    #[inline]
182    pub fn into_inner(self) -> T {
183        self.0.into_inner()
184    }
185}
186
/// Synchronization primitive that lets a thread block until a waker delivers a signal.
pub struct Blocker {
    /// `true` while a delivered wake-up is waiting to be consumed (it may be set
    /// before `wait` is ever called).
    state: Mutex<bool>,
    /// Condition variable the waiting thread parks on until the flag becomes `true`.
    cv: Condvar,
}

impl Blocker {
    /// Create a new [Blocker] with no pending signal.
    pub fn new() -> Arc<Self> {
        let blocker = Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        };
        Arc::new(blocker)
    }

    /// Block the current thread until a waker delivers a signal.
    ///
    /// Returns immediately if a signal is already pending. The signal is consumed,
    /// so the next call blocks again until a new one arrives.
    ///
    /// # Panics
    /// Panics if the internal mutex has been poisoned.
    pub fn wait(&self) {
        // `wait_while` re-checks the flag after every wake-up, so spurious wake-ups
        // are absorbed and we only continue once a real signal has been recorded.
        let guard = self.state.lock().unwrap();
        let mut signaled = self
            .cv
            .wait_while(guard, |signaled| !*signaled)
            .unwrap();

        // Consume the signal so the next wait parks until another wake arrives.
        *signaled = false;
    }
}
216
217impl ArcWake for Blocker {
218    fn wake_by_ref(arc_self: &Arc<Self>) {
219        let mut signaled = arc_self.state.lock().unwrap();
220        *signaled = true;
221
222        // Notify a single waiter so the blocked thread re-checks the flag.
223        arc_self.cv.notify_one();
224    }
225}
226
/// Test helper: yields to the runtime five times, then returns `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining_yields = 5;
    while remaining_yields > 0 {
        reschedule().await;
        remaining_yields -= 1;
    }
    i
}
234
/// Spawn `tasks` instances of [task] on the given deterministic runner and wait for
/// all of them to finish.
///
/// # Returns
/// The auditor state read after every task has completed, together with the task
/// outputs in the order the tasks finished.
///
/// # Panics
/// Panics if any spawned task's handle resolves to an error.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Randomly schedule tasks
        let mut handles = FuturesUnordered::new();
        // Half-open range: `0..=tasks - 1` would underflow when `tasks` is zero.
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use futures::task::waker;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// A pool created on the tokio runtime can run a rayon parallel iterator.
    #[test_traced]
    fn test_create_pool() {
        let runner = tokio::Runner::default();
        runner.start(|context| async move {
            // Pool backed by 4 runtime-spawned threads
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Input: the integers 0..10000
            let numbers: Vec<i32> = (0..10000).collect();

            // Sum in parallel inside the pool and compare with the closed form
            pool.install(|| {
                let total: i32 = numbers.par_iter().sum();
                assert_eq!(total, 10000 * 9999 / 2);
            });
        });
    }

    /// Readers may overlap; a writer gets access once every reader is gone.
    #[test_traced]
    fn test_rwlock() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let lock = RwLock::new(100);

            // Two read guards held at the same time
            {
                let first = lock.read().await;
                let second = lock.read().await;
                assert_eq!(*first + *second, 200);
            } // both read guards are released here

            // Exclusive access for the writer
            let mut writer = lock.write().await;
            *writer += 1;

            // The write is visible through the guard
            assert_eq!(*writer, 101);
        });
    }

    /// `wait` does not return until the waker fires.
    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let handle = {
            let blocker = blocker.clone();
            let started = started.clone();
            let completed = completed.clone();
            std::thread::spawn(move || {
                started.store(true, Ordering::SeqCst);
                blocker.wait();
                completed.store(true, Ordering::SeqCst);
            })
        };

        // Spin until the waiter thread is running
        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        // No wake has been sent yet, so the waiter cannot have finished
        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker.clone()).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    /// A wake delivered before `wait` is remembered and consumed by it.
    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let waiter = {
            let blocker = blocker.clone();
            let completed = completed.clone();
            std::thread::spawn(move || {
                blocker.wait();
                completed.store(true, Ordering::SeqCst);
            })
        };
        waiter.join().unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    /// The same blocker can be waited on and woken repeatedly.
    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let handle = {
            let blocker = blocker.clone();
            let completed = completed.clone();
            std::thread::spawn(move || {
                for _ in 0..2 {
                    blocker.wait();
                    completed.fetch_add(1, Ordering::SeqCst);
                }
            })
        };

        // Send one wake per round and wait for the waiter to acknowledge it
        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }
}
369}