commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::task::ArcWake;
9use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
10use std::{
11    any::Any,
12    future::Future,
13    pin::Pin,
14    sync::{Arc, Condvar, Mutex},
15    task::{Context, Poll},
16};
17
18pub mod buffer;
19pub mod signal;
20
21mod handle;
22pub use handle::Handle;
23pub(crate) use handle::{Aborter, MetricHandle, Panicked, Panicker};
24
25mod cell;
26pub use cell::Cell as ContextCell;
27
28pub(crate) mod supervision;
29
/// Where a spawned task is executed.
#[derive(Clone, Copy, Debug)]
pub enum Execution {
    /// The task gets a thread of its own.
    Dedicated,
    /// The task runs on the shared executor. A value of `true` flags short blocking work
    /// that should be placed on the runtime's blocking-friendly pool.
    Shared(bool),
}

impl Default for Execution {
    /// Defaults to the shared executor with the blocking flag cleared.
    fn default() -> Self {
        Execution::Shared(false)
    }
}
45
/// Yield control back to the runtime.
///
/// The first poll asks the task's waker to reschedule it immediately and reports
/// `Pending`; every later poll reports `Ready`. This gives other tasks a chance to run
/// before the caller continues.
pub async fn reschedule() {
    let mut has_yielded = false;
    std::future::poll_fn(move |cx: &mut Context<'_>| {
        if has_yielded {
            return Poll::Ready(());
        }
        has_yielded = true;
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
68
/// Turns a panic payload into a human-readable message.
///
/// `&str` and `String` payloads (the types `panic!` normally produces) are returned verbatim.
/// Any other payload falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    if let Some(message) = err.downcast_ref::<&str>() {
        return (*message).to_owned();
    }
    match err.downcast_ref::<String>() {
        Some(message) => message.clone(),
        None => format!("{err:?}"),
    }
}
78
79/// A clone-able wrapper around a [rayon]-compatible thread pool.
80pub type ThreadPool = Arc<RThreadPool>;
81
82/// Creates a clone-able [rayon]-compatible thread pool with [Spawner::spawn].
83///
84/// # Arguments
85/// - `context`: The runtime context implementing the [Spawner] trait.
86/// - `concurrency`: The number of tasks to execute concurrently in the pool.
87///
88/// # Returns
89/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
90pub fn create_pool<S: Spawner + Metrics>(
91    context: S,
92    concurrency: usize,
93) -> Result<ThreadPool, ThreadPoolBuildError> {
94    let pool = ThreadPoolBuilder::new()
95        .num_threads(concurrency)
96        .spawn_handler(move |thread| {
97            // Tasks spawned in a thread pool are expected to run longer than any single
98            // task and thus should be provisioned as a dedicated thread.
99            context
100                .with_label("rayon_thread")
101                .dedicated()
102                .spawn(move |_| async move { thread.run() });
103            Ok(())
104        })
105        .build()?;
106
107    Ok(Arc::new(pool))
108}
109
110/// Async reader–writer lock.
111///
112/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
113/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
114///
115/// Usage:
116/// ```rust
117/// use commonware_runtime::{Spawner, Runner, deterministic, RwLock};
118///
119/// let executor = deterministic::Runner::default();
120/// executor.start(|context| async move {
121///     // Create a new RwLock
122///     let lock = RwLock::new(2);
123///
124///     // many concurrent readers
125///     let r1 = lock.read().await;
126///     let r2 = lock.read().await;
127///     assert_eq!(*r1 + *r2, 4);
128///
129///     // exclusive writer
130///     drop((r1, r2));
131///     let mut w = lock.write().await;
132///     *w += 1;
133/// });
134/// ```
135pub struct RwLock<T>(async_lock::RwLock<T>);
136
137/// Shared guard returned by [RwLock::read].
138pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
139
140/// Exclusive guard returned by [RwLock::write].
141pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
142
143impl<T> RwLock<T> {
144    /// Create a new lock.
145    #[inline]
146    pub const fn new(value: T) -> Self {
147        Self(async_lock::RwLock::new(value))
148    }
149
150    /// Acquire a shared read guard.
151    #[inline]
152    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
153        self.0.read().await
154    }
155
156    /// Acquire an exclusive write guard.
157    #[inline]
158    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
159        self.0.write().await
160    }
161
162    /// Try to get a read guard without waiting.
163    #[inline]
164    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
165        self.0.try_read()
166    }
167
168    /// Try to get a write guard without waiting.
169    #[inline]
170    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
171        self.0.try_write()
172    }
173
174    /// Get mutable access without locking (requires `&mut self`).
175    #[inline]
176    pub fn get_mut(&mut self) -> &mut T {
177        self.0.get_mut()
178    }
179
180    /// Consume the lock, returning the inner value.
181    #[inline]
182    pub fn into_inner(self) -> T {
183        self.0.into_inner()
184    }
185}
186
/// Lets a thread park until a waker signals it.
///
/// A signal that arrives before [Blocker::wait] is called is remembered. The next call to
/// `wait` then returns immediately instead of missing the wake-up.
pub struct Blocker {
    /// `true` once a wake-up has been delivered and not yet consumed by `wait`.
    state: Mutex<bool>,
    /// Parks the waiting thread until `state` becomes `true`.
    cv: Condvar,
}

impl Blocker {
    /// Construct an unsignaled [Blocker] behind an [Arc], so it can be shared with whoever
    /// delivers the signal.
    pub fn new() -> Arc<Self> {
        let blocker = Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        };
        Arc::new(blocker)
    }

    /// Park the calling thread until a signal is delivered, then consume that signal.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) {
        // `wait_while` re-checks the predicate after every wake-up, so spurious wake-ups
        // simply put the thread back to sleep.
        let guard = self.state.lock().unwrap();
        let mut signaled = self.cv.wait_while(guard, |signaled| !*signaled).unwrap();

        // Consume the signal so that the next `wait` parks until a fresh one arrives.
        *signaled = false;
    }
}
216
217impl ArcWake for Blocker {
218    fn wake_by_ref(arc_self: &Arc<Self>) {
219        let mut signaled = arc_self.state.lock().unwrap();
220        *signaled = true;
221
222        // Notify a single waiter so the blocked thread re-checks the flag.
223        arc_self.cv.notify_one();
224    }
225}
226
/// Asserts that `label` is a valid Prometheus metric name, i.e. that it matches
/// `[a-zA-Z][a-zA-Z0-9_]*`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `label` is empty.
/// - Its first character is not an ASCII letter.
/// - Any later character is something other than an ASCII letter, digit, or `_`.
pub fn validate_label(label: &str) {
    let mut rest = label.chars();
    let first = rest.next();
    assert!(
        matches!(first, Some(c) if c.is_ascii_alphabetic()),
        "label must start with [a-zA-Z]: {label}"
    );

    let is_allowed = |c: char| c == '_' || c.is_ascii_alphanumeric();
    assert!(
        rest.all(is_allowed),
        "label must only contain [a-zA-Z0-9_]: {label}"
    );
}
244
/// Test helper: yields to the scheduler five times, then returns its own index `i`.
///
/// The repeated yields give the scheduler several chances to interleave this task with
/// others spawned alongside it.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut yields_left = 5;
    while yields_left > 0 {
        reschedule().await;
        yields_left -= 1;
    }
    i
}
252
/// Test helper: spawns `tasks` tasks on `runner` and records the order in which they finish.
///
/// Each task returns its own index after yielding several times. The indices are collected
/// in completion order.
///
/// Returns the auditor state at the end of the run together with that completion order.
/// `tasks == 0` is allowed and yields an empty order.
///
/// # Panics
///
/// Panics if any spawned task fails, or if the number of collected results differs from
/// `tasks`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Randomly schedule tasks. The half-open range spawns nothing when `tasks == 0`,
        // instead of underflowing `tasks - 1`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use futures::task::waker;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Builds a 4-thread pool on the tokio runtime and checks that a parallel sum computed
    /// inside it matches the closed-form result.
    #[test_traced]
    fn test_create_pool() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            // Create a thread pool with 4 threads
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Create a vector of numbers
            let v: Vec<_> = (0..10000).collect();

            // Use the thread pool to sum the numbers
            pool.install(|| {
                // 0 + 1 + ... + 9999 == 10000 * 9999 / 2
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }

    /// Exercises two simultaneous read guards followed by an exclusive write on [RwLock].
    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Create a new RwLock
            let lock = RwLock::new(100);

            // Two read guards held at the same time
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            // exclusive writer
            drop((r1, r2)); // all readers must go away
            let mut w = lock.write().await;
            *w += 1;

            // Check the value
            assert_eq!(*w, 101);
        });
    }

    /// `wait` must not return until a waker built from the [Blocker] is woken.
    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        // Spin until the helper thread is running (it may not have reached `wait` yet).
        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        // No wake has been sent yet, so `wait` cannot have returned.
        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    /// A wake delivered before `wait` is called must be remembered, not lost.
    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker;
        let thread_completed = completed.clone();
        // If the early signal were dropped, this join would hang forever.
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    /// After a signal is consumed, the same [Blocker] parks again until the next wake.
    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            // The signal is a single boolean flag, so wait for the thread to consume this
            // wake before sending the next; otherwise two wakes could collapse into one.
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }
}
387}