commonware_runtime/utils/
mod.rs1#[cfg(test)]
4use crate::Runner;
5use crate::{Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
9use std::{
10 any::Any,
11 future::Future,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15};
16
17pub mod buffer;
18pub mod signal;
19
20mod handle;
21pub use handle::Handle;
22
/// Yields control back to the executor exactly once, letting other ready tasks make progress.
///
/// The future produced here returns `Poll::Pending` on its first poll, after immediately
/// waking its own task so the executor will poll it again. On the next poll it completes.
pub async fn reschedule() {
    /// Future that is pending on its first poll and ready on the following one.
    struct YieldOnce {
        done_yielding: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.done_yielding {
                // Schedule ourselves to be polled again, then give up this turn.
                self.done_yielding = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce {
        done_yielding: false,
    }
    .await
}
45
/// Turns a panic payload into a human-readable message.
///
/// Payloads that are `&str` or `String` are returned verbatim. Any other payload type falls
/// back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{err:?}"))
}
55
/// Shared handle to a rayon thread pool.
///
/// The pool is wrapped in an [`Arc`] so handles can be cloned cheaply (a reference-count
/// bump) and shared between callers.
pub type ThreadPool = Arc<RThreadPool>;
58
59pub fn create_pool<S: Spawner + Metrics>(
68 context: S,
69 concurrency: usize,
70) -> Result<ThreadPool, ThreadPoolBuildError> {
71 let pool = ThreadPoolBuilder::new()
72 .num_threads(concurrency)
73 .spawn_handler(move |thread| {
74 context
77 .with_label("rayon-thread")
78 .spawn_blocking(true, move |_| thread.run());
79 Ok(())
80 })
81 .build()?;
82
83 Ok(Arc::new(pool))
84}
85
/// An async reader-writer lock.
///
/// This is a thin newtype around [`async_lock::RwLock`]: its methods forward directly to the
/// inner lock. Blocking acquisition (`read` / `write`) is `async` and must be awaited, while
/// `try_read` / `try_write` return immediately with an `Option`.
pub struct RwLock<T>(async_lock::RwLock<T>);
112
/// Guard returned by [`RwLock::read`] and [`RwLock::try_read`].
///
/// It grants shared (read-only) access to the protected value. Access is released when the
/// guard is dropped.
pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;

/// Guard returned by [`RwLock::write`] and [`RwLock::try_write`].
///
/// It grants exclusive (mutable) access to the protected value. Access is released when the
/// guard is dropped.
pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
118
119impl<T> RwLock<T> {
120 #[inline]
122 pub const fn new(value: T) -> Self {
123 Self(async_lock::RwLock::new(value))
124 }
125
126 #[inline]
128 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
129 self.0.read().await
130 }
131
132 #[inline]
134 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
135 self.0.write().await
136 }
137
138 #[inline]
140 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
141 self.0.try_read()
142 }
143
144 #[inline]
146 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
147 self.0.try_write()
148 }
149
150 #[inline]
152 pub fn get_mut(&mut self) -> &mut T {
153 self.0.get_mut()
154 }
155
156 #[inline]
158 pub fn into_inner(self) -> T {
159 self.0.into_inner()
160 }
161}
162
/// Test helper: yields to the executor five times via `reschedule`, then returns `i`.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining_yields = 5;
    while remaining_yields > 0 {
        reschedule().await;
        remaining_yields -= 1;
    }
    i
}
170
/// Test helper: spawns `tasks` instances of `task` on a deterministic runtime.
///
/// Instance `i` is given the input `i` for every `i` in `0..tasks`. The helper then awaits all
/// of them.
///
/// Returns a pair:
/// - the runtime auditor's state (`context.auditor().state()`) after every task has finished;
/// - the tasks' outputs, in the order they completed.
///
/// # Panics
///
/// Panics if any spawned task fails, or if the number of collected outputs differs from
/// `tasks`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        let mut handles = FuturesUnordered::new();
        // `0..tasks` instead of `0..=tasks - 1`: the latter underflows when `tasks == 0`.
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    #[test_traced]
    fn test_create_pool() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            // Build a small pool whose worker threads are spawned through the runtime.
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Run a parallel reduction inside the pool and compare it with the closed form.
            let v: Vec<_> = (0..10000).collect();
            pool.install(|| {
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let lock = RwLock::new(100);

            // Several readers may hold the lock at once; a writer is excluded meanwhile.
            {
                let r1 = lock.read().await;
                let r2 = lock.read().await;
                assert_eq!(*r1 + *r2, 200);
                assert!(lock.try_write().is_none());
            }

            // With the readers dropped, a writer gets exclusive access and blocks new readers.
            {
                let mut w = lock.write().await;
                *w += 1;
                assert_eq!(*w, 101);
                assert!(lock.try_read().is_none());
            }

            // With no guards outstanding, the non-waiting acquisitions succeed.
            assert_eq!(*lock.try_read().unwrap(), 101);
            *lock.try_write().unwrap() += 1;

            // Exclusive ownership allows lock-free access and unwrapping.
            let mut lock = lock;
            *lock.get_mut() += 1;
            assert_eq!(lock.into_inner(), 103);
        });
    }
}