commonware_runtime/utils/mod.rs

#[cfg(test)]
use crate::Runner;
use crate::{Metrics, Spawner};
#[cfg(test)]
use futures::stream::{FuturesUnordered, StreamExt};
use futures::task::ArcWake;
use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::{
    any::Any,
    future::Future,
    pin::Pin,
    sync::{Arc, Condvar, Mutex},
    task::{Context, Poll},
};

pub mod buffer;
pub mod signal;

mod handle;
pub use handle::Handle;
pub(crate) use handle::{Aborter, MetricHandle, Panicked, Panicker};

mod cell;
pub use cell::Cell as ContextCell;

pub(crate) mod supervision;

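/// Selects how the runtime executes a spawned task.
///
/// Defaults to the shared executor. A trivial sketch of constructing a mode
/// (how the runtime consumes the value is outside this module):
///
/// ```ignore
/// let mode = Execution::default();
/// assert!(matches!(mode, Execution::Shared(false)));
/// ```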
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// Run the task on its own dedicated thread.
    Dedicated,
    /// Run the task on the shared executor.
    Shared(bool),
}

impl Default for Execution {
    fn default() -> Self {
        Self::Shared(false)
    }
}

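/// Yields control back to the scheduler exactly once before resuming.
///
/// A minimal usage sketch, mirroring the `task` helper in this module's tests
/// (the `deterministic::Runner` here is the one those tests use):
///
/// ```ignore
/// deterministic::Runner::default().start(|_| async move {
///     // Let other ready tasks make progress before continuing.
///     reschedule().await;
/// });
/// ```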
pub async fn reschedule() {
    // A future that returns `Pending` on its first poll (after scheduling its
    // own wakeup) and `Ready` on the second, forcing one trip through the scheduler.
    struct Reschedule {
        yielded: bool,
    }

    impl Future for Reschedule {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                Poll::Ready(())
            } else {
                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    Reschedule { yielded: false }.await
}

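/// Extracts a human-readable message from a panic payload.
///
/// A small illustrative sketch (not taken from this module's tests):
///
/// ```ignore
/// use std::any::Any;
///
/// let payload: Box<dyn Any + Send> = Box::new("boom");
/// assert_eq!(extract_panic_message(payload.as_ref()), "boom");
/// ```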
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    if let Some(s) = err.downcast_ref::<&str>() {
        s.to_string()
    } else if let Some(s) = err.downcast_ref::<String>() {
        s.clone()
    } else {
        format!("{err:?}")
    }
}

/// A reference-counted handle to a rayon thread pool.
pub type ThreadPool = Arc<RThreadPool>;

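/// Creates a rayon thread pool with `concurrency` worker threads, each running
/// as a dedicated task spawned from the provided runtime context.
///
/// A hedged usage sketch based on `test_create_pool` below (the `"pool"` label
/// and worker count are arbitrary):
///
/// ```ignore
/// let pool = create_pool(context.with_label("pool"), 4)?;
/// pool.install(|| {
///     // rayon parallel iterators now run on the runtime-managed workers.
/// });
/// ```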
pub fn create_pool<S: Spawner + Metrics>(
    context: S,
    concurrency: usize,
) -> Result<ThreadPool, ThreadPoolBuildError> {
    let pool = ThreadPoolBuilder::new()
        .num_threads(concurrency)
        .spawn_handler(move |thread| {
            // Run each rayon worker on a dedicated, labeled runtime thread.
            context
                .with_label("rayon-thread")
                .dedicated()
                .spawn(move |_| async move { thread.run() });
            Ok(())
        })
        .build()?;

    Ok(Arc::new(pool))
}

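/// An async reader-writer lock backed by [`async_lock::RwLock`].
///
/// A minimal usage sketch, mirroring `test_rwlock` below (the initial value is
/// arbitrary):
///
/// ```ignore
/// let lock = RwLock::new(100);
/// {
///     let r = lock.read().await; // any number of readers may hold the lock
///     assert_eq!(*r, 100);
/// }
/// let mut w = lock.write().await; // writers get exclusive access
/// *w += 1;
/// ```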
pub struct RwLock<T>(async_lock::RwLock<T>);

/// Guard returned by [`RwLock::read`].
pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;

/// Guard returned by [`RwLock::write`].
pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;

impl<T> RwLock<T> {
    /// Creates a new lock wrapping `value`.
    #[inline]
    pub const fn new(value: T) -> Self {
        Self(async_lock::RwLock::new(value))
    }

    /// Acquires a shared read guard, waiting until one is available.
    #[inline]
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        self.0.read().await
    }

    /// Acquires an exclusive write guard, waiting until one is available.
    #[inline]
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        self.0.write().await
    }

    /// Attempts to acquire a read guard without waiting.
    #[inline]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        self.0.try_read()
    }

    /// Attempts to acquire a write guard without waiting.
    #[inline]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        self.0.try_write()
    }

    /// Returns a mutable reference to the inner value (no locking is needed with `&mut self`).
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        self.0.get_mut()
    }

    /// Consumes the lock and returns the inner value.
    #[inline]
    pub fn into_inner(self) -> T {
        self.0.into_inner()
    }
}

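/// A thread-parking waker: [`Blocker::wait`] blocks the calling thread until the
/// blocker is woken through its [`ArcWake`] implementation.
///
/// A minimal sketch, mirroring the blocker tests in this module
/// (`futures::task::waker` converts the `Arc<Blocker>` into a `Waker`):
///
/// ```ignore
/// use futures::task::waker;
///
/// let blocker = Blocker::new();
/// waker(blocker.clone()).wake(); // signal first...
/// blocker.wait();                // ...so this call returns immediately
/// ```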
pub struct Blocker {
    /// Whether a wake signal has been delivered and not yet consumed.
    state: Mutex<bool>,
    /// Notified whenever `state` is set.
    cv: Condvar,
}

impl Blocker {
    /// Creates a new, unsignaled `Blocker`.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        })
    }

    /// Blocks the current thread until the blocker is woken, then clears the
    /// signal so the blocker can be reused.
    pub fn wait(&self) {
        let mut signaled = self.state.lock().unwrap();
        while !*signaled {
            signaled = self.cv.wait(signaled).unwrap();
        }

        // Consume the signal so a subsequent `wait` blocks again.
        *signaled = false;
    }
}

impl ArcWake for Blocker {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let mut signaled = arc_self.state.lock().unwrap();
        *signaled = true;

        // Wake the thread parked in `wait`, if any.
        arc_self.cv.notify_one();
    }
}

#[cfg(test)]
async fn task(i: usize) -> usize {
    for _ in 0..5 {
        reschedule().await;
    }
    i
}

#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Spawn the requested number of tasks.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect the results as tasks complete.
        let mut outputs = Vec::new();
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use futures::task::waker;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    #[test_traced]
    fn test_create_pool() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            // Create a pool with 4 worker threads.
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            let v: Vec<_> = (0..10000).collect();

            // Sum the values in parallel on the pool.
            pool.install(|| {
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let lock = RwLock::new(100);

            // Multiple readers may hold the lock at once.
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            // Release the read guards before taking the write guard.
            drop((r1, r2));
            let mut w = lock.write().await;
            *w += 1;

            assert_eq!(*w, 101);
        });
    }

    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker.clone()).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }
}