#[cfg(test)]
use crate::Runner;
use crate::{Metrics, Spawner};
#[cfg(test)]
use futures::stream::{FuturesUnordered, StreamExt};
use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::{
any::Any,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
pub mod buffer;
pub mod signal;
mod handle;
pub use handle::Handle;
/// Yields control back to the executor exactly once.
///
/// The first poll wakes the current task's waker (re-scheduling it) and returns
/// `Poll::Pending`; the following poll completes with `()`.
pub async fn reschedule() {
    /// Future that is pending on its first poll and ready on every poll after that.
    struct YieldOnce {
        already_yielded: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // `YieldOnce` only holds a `bool`, so it is `Unpin` and can be unpinned freely.
            let state = self.get_mut();
            if state.already_yielded {
                return Poll::Ready(());
            }
            state.already_yielded = true;
            // Ask the executor to poll this task again before suspending it.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce {
        already_yielded: false,
    }
    .await
}
/// Converts a panic payload into a human-readable message.
///
/// Payloads that are `&str` or `String` are returned as their text; any other
/// payload type falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{err:?}"))
}
/// Shared, reference-counted handle to a rayon thread pool.
pub type ThreadPool = Arc<RThreadPool>;

/// Builds a rayon [`ThreadPool`] whose worker threads are started through `context`.
///
/// `concurrency` is forwarded unchanged to rayon's `ThreadPoolBuilder::num_threads`.
/// Every worker thread is launched with `spawn_blocking` on a context labeled
/// `"rayon-thread"`, and the worker's main loop (`thread.run()`) runs inside that task.
///
/// # Errors
///
/// Returns the [`ThreadPoolBuildError`] produced by rayon if the pool cannot be built.
pub fn create_pool<S: Spawner + Metrics>(
    context: S,
    concurrency: usize,
) -> Result<ThreadPool, ThreadPoolBuildError> {
    ThreadPoolBuilder::new()
        .num_threads(concurrency)
        .spawn_handler(move |worker| {
            let labeled = context.with_label("rayon-thread");
            // NOTE(review): the `true` flag presumably requests a dedicated thread —
            // confirm against `Spawner::spawn_blocking`. The returned handle is not kept.
            labeled.spawn_blocking(true, move |_| worker.run());
            Ok(())
        })
        .build()
        .map(Arc::new)
}
/// Asynchronous reader-writer lock.
///
/// A newtype over `async_lock::RwLock`; every method forwards directly to the
/// wrapped lock.
pub struct RwLock<T>(async_lock::RwLock<T>);

/// Shared-access guard returned by [`RwLock::read`] and [`RwLock::try_read`].
pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;

/// Exclusive-access guard returned by [`RwLock::write`] and [`RwLock::try_write`].
pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;

impl<T> RwLock<T> {
    /// Creates a lock that protects `value`.
    #[inline]
    pub const fn new(value: T) -> Self {
        Self(async_lock::RwLock::new(value))
    }

    /// Waits for shared (read) access and returns a guard for it.
    #[inline]
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        let Self(inner) = self;
        inner.read().await
    }

    /// Waits for exclusive (write) access and returns a guard for it.
    #[inline]
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        let Self(inner) = self;
        inner.write().await
    }

    /// Attempts to take shared access without waiting; `None` if unavailable right now.
    #[inline]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        let Self(inner) = self;
        inner.try_read()
    }

    /// Attempts to take exclusive access without waiting; `None` if unavailable right now.
    #[inline]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        let Self(inner) = self;
        inner.try_write()
    }

    /// Returns a mutable reference to the protected value; `&mut self` guarantees
    /// no guards are outstanding, so no locking is needed.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner.get_mut()
    }

    /// Consumes the lock and returns the protected value.
    #[inline]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner.into_inner()
    }
}
#[cfg(test)]
async fn task(i: usize) -> usize {
for _ in 0..5 {
reschedule().await;
}
i
}
/// Test helper: spawns `tasks` instances of [`task`] (indices `0..tasks`) on the given
/// deterministic runner and waits for all of them.
///
/// Returns the auditor's state string together with the task outputs, collected in
/// the order the spawned handles complete.
///
/// # Panics
///
/// Panics if any spawned task's handle resolves to an error, or if the number of
/// collected outputs differs from `tasks`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        let mut handles = FuturesUnordered::new();
        // Half-open range: `0..=tasks - 1` would underflow when `tasks == 0`.
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    #[test_traced]
    fn test_create_pool() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            let pool = create_pool(context.with_label("pool"), 4).unwrap();
            let v: Vec<_> = (0..10000).collect();
            pool.install(|| {
                assert_eq!(v.par_iter().sum::<i32>(), 10000 * 9999 / 2);
            });
        });
    }

    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let lock = RwLock::new(100);

            // Two read guards can be held at the same time.
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);
            // Exclusive access is refused while readers are outstanding.
            assert!(lock.try_write().is_none());
            drop((r1, r2));

            // Once readers are gone, a writer can mutate the value.
            let mut w = lock.write().await;
            *w += 1;
            assert_eq!(*w, 101);
            // Shared access is refused while the writer is outstanding.
            assert!(lock.try_read().is_none());
            drop(w);

            // With no guards left, both non-blocking acquisitions succeed again.
            assert!(lock.try_read().is_some());
            assert!(lock.try_write().is_some());
            assert_eq!(lock.into_inner(), 101);
        });
    }
}