commonware_runtime/utils/mod.rs

#[cfg(test)]
use crate::Runner;
use crate::{Metrics, Spawner};
#[cfg(test)]
use futures::stream::{FuturesUnordered, StreamExt};
use futures::task::ArcWake;
use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::{
    any::Any,
    future::Future,
    pin::Pin,
    sync::{Arc, Condvar, Mutex},
    task::{Context, Poll},
};

pub mod buffer;
pub mod signal;

mod handle;
pub use handle::Handle;
pub(crate) use handle::{Aborter, MetricHandle, Panicked, Panicker};

mod cell;
pub use cell::Cell as ContextCell;

pub(crate) mod supervision;

/// Modes for executing a spawned task.
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// Execute the task on its own dedicated thread.
    Dedicated,
    /// Execute the task on the shared executor.
    Shared(bool),
}

impl Default for Execution {
    fn default() -> Self {
        Self::Shared(false)
    }
}

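/// Cooperatively yields execution back to the scheduler, resuming on the next poll.
///
/// A minimal usage sketch (the surrounding loop is illustrative, not taken from this
/// crate): call `reschedule().await` inside long-running work so sibling tasks on the
/// same executor are not starved.
///
/// ```ignore
/// async fn sum(items: &[u64]) -> u64 {
///     let mut total = 0;
///     for chunk in items.chunks(1024) {
///         total += chunk.iter().sum::<u64>();
///         // Hand control back to the scheduler before the next chunk.
///         reschedule().await;
///     }
///     total
/// }
/// ```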
pub async fn reschedule() {
    // A future that returns `Pending` once (waking itself immediately) and then
    // completes, forcing the task back into the scheduler's queue.
    struct Reschedule {
        yielded: bool,
    }

    impl Future for Reschedule {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                Poll::Ready(())
            } else {
                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    Reschedule { yielded: false }.await
}

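/// Extracts a readable message from a panic payload, falling back to the payload's
/// `Debug` representation when it is neither a `&str` nor a `String`.
///
/// A sketch of how a payload might be fed to this helper (the `catch_unwind` call is
/// illustrative and not how the runtime itself captures panics):
///
/// ```ignore
/// let payload = std::panic::catch_unwind(|| panic!("boom")).unwrap_err();
/// assert_eq!(extract_panic_message(payload.as_ref()), "boom");
/// ```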
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>().map_or_else(
        || {
            err.downcast_ref::<String>()
                .map_or_else(|| format!("{err:?}"), |s| s.clone())
        },
        |s| s.to_string(),
    )
}

/// A thread pool for CPU-bound work, backed by rayon.
pub type ThreadPool = Arc<RThreadPool>;

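/// Creates a [`ThreadPool`] whose worker threads are spawned as dedicated tasks on the
/// provided context.
///
/// A usage sketch, assuming a runner context like the one used in this module's tests
/// (the `"pool"` label and concurrency of `4` are arbitrary):
///
/// ```ignore
/// let executor = tokio::Runner::default();
/// executor.start(|context| async move {
///     let pool = create_pool(context.with_label("pool"), 4).expect("failed to build pool");
///     let total: u64 = pool.install(|| (0..1_000u64).into_par_iter().sum());
///     assert_eq!(total, 1_000 * 999 / 2);
/// });
/// ```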
pub fn create_pool<S: Spawner + Metrics>(
    context: S,
    concurrency: usize,
) -> Result<ThreadPool, ThreadPoolBuildError> {
    let pool = ThreadPoolBuilder::new()
        .num_threads(concurrency)
        .spawn_handler(move |thread| {
            // Run each rayon worker on a dedicated task spawned from the context.
            context
                .with_label("rayon_thread")
                .dedicated()
                .spawn(move |_| async move { thread.run() });
            Ok(())
        })
        .build()?;

    Ok(Arc::new(pool))
}

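/// An asynchronous reader-writer lock, wrapping [`async_lock::RwLock`].
///
/// A minimal sketch of acquiring readers and a writer (mirroring the shape of the
/// `test_rwlock` test below; the initial value is arbitrary):
///
/// ```ignore
/// let lock = RwLock::new(0u32);
///
/// // Any number of readers may hold the lock at once.
/// let r1 = lock.read().await;
/// let r2 = lock.read().await;
/// assert_eq!(*r1 + *r2, 0);
/// drop((r1, r2));
///
/// // A writer requires exclusive access.
/// let mut w = lock.write().await;
/// *w += 1;
/// ```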
pub struct RwLock<T>(async_lock::RwLock<T>);

/// A guard that releases shared read access when dropped.
pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;

/// A guard that releases exclusive write access when dropped.
pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;

impl<T> RwLock<T> {
    /// Creates a new lock wrapping `value`.
    #[inline]
    pub const fn new(value: T) -> Self {
        Self(async_lock::RwLock::new(value))
    }

    /// Acquires shared read access, waiting while a writer holds the lock.
    #[inline]
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        self.0.read().await
    }

    /// Acquires exclusive write access, waiting until all other guards are released.
    #[inline]
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        self.0.write().await
    }

    /// Attempts to acquire read access without waiting.
    #[inline]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        self.0.try_read()
    }

    /// Attempts to acquire write access without waiting.
    #[inline]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        self.0.try_write()
    }

    /// Returns a mutable reference to the inner value (no locking is needed with `&mut self`).
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        self.0.get_mut()
    }

    /// Consumes the lock and returns the inner value.
    #[inline]
    pub fn into_inner(self) -> T {
        self.0.into_inner()
    }
}

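/// A thread parker that implements [`ArcWake`], allowing a blocked OS thread to be
/// resumed when a future's waker is invoked.
///
/// A sketch of the handoff (mirroring the blocker tests below; the spawned thread and
/// ordering are illustrative):
///
/// ```ignore
/// use futures::task::waker;
///
/// let blocker = Blocker::new();
/// let handle = {
///     let blocker = blocker.clone();
///     std::thread::spawn(move || blocker.wait())
/// };
///
/// // Waking the blocker (e.g. from a future's waker) releases the parked thread.
/// waker(blocker).wake();
/// handle.join().unwrap();
/// ```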
pub struct Blocker {
    /// Whether a wake has been delivered since the last `wait`.
    state: Mutex<bool>,
    /// Notified whenever `state` is set.
    cv: Condvar,
}

impl Blocker {
    /// Creates a new, unsignaled blocker.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        })
    }

    /// Blocks the current thread until woken, then clears the signal.
    pub fn wait(&self) {
        // Block until a wake has been delivered.
        let mut signaled = self.state.lock().unwrap();
        while !*signaled {
            signaled = self.cv.wait(signaled).unwrap();
        }

        // Consume the signal so the blocker can be reused for the next wait.
        *signaled = false;
    }
}

impl ArcWake for Blocker {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let mut signaled = arc_self.state.lock().unwrap();
        *signaled = true;

        // Release any thread currently parked in `wait`.
        arc_self.cv.notify_one();
    }
}

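/// Asserts that `label` is a valid metrics label: it must start with an ASCII letter
/// and contain only ASCII alphanumerics or underscores.
///
/// A quick sketch of what passes and what panics (the labels are arbitrary examples):
///
/// ```ignore
/// validate_label("rayon_thread"); // ok
/// validate_label("p2p_v1");       // ok
/// validate_label("1st");          // panics: must start with [a-zA-Z]
/// validate_label("has-dash");     // panics: only [a-zA-Z0-9_] allowed
/// ```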
pub fn validate_label(label: &str) {
    let mut chars = label.chars();
    assert!(
        chars.next().is_some_and(|c| c.is_ascii_alphabetic()),
        "label must start with [a-zA-Z]: {label}"
    );
    assert!(
        chars.all(|c| c.is_ascii_alphanumeric() || c == '_'),
        "label must only contain [a-zA-Z0-9_]: {label}"
    );
}

#[cfg(test)]
async fn task(i: usize) -> usize {
    for _ in 0..5 {
        reschedule().await;
    }
    i
}

#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Spawn the requested number of tasks.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect the results as tasks complete.
        let mut outputs = Vec::new();
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}


#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use futures::task::waker;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    #[test_traced]
    fn test_create_pool() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            // Build a pool with four workers tied to this context.
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Data to sum in parallel.
            let v: Vec<_> = (0..10000).collect();

            // Install the pool and verify the parallel sum.
            pool.install(|| {
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let lock = RwLock::new(100);

            // Multiple readers may hold the lock simultaneously.
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            // Release the readers before taking the write lock.
            drop((r1, r2));
            let mut w = lock.write().await;
            *w += 1;

            assert_eq!(*w, 101);
        });
    }

    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker;
        let thread_completed = completed.clone();
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }
}