use rayon::iter::IntoParallelIterator;
use std::{
any::Any,
fmt::{self, Debug},
};
/// Runs a rayon parallel-iterator computation and awaits its result asynchronously.
///
/// `t` is turned into a parallel iterator and handed to `closure` inside a job
/// submitted with [`rayon::spawn`]. Whatever `closure` returns is sent back over
/// a [`tokio::sync::oneshot`] channel that the returned future awaits.
///
/// If the returned future is dropped before the job finishes, the job's result
/// is discarded: the failed send is deliberately ignored.
///
/// # Errors
///
/// Returns [`Panicked`] holding the panic payload when `t.into_par_iter()` or
/// `closure` panics.
///
/// # Panics
///
/// Panics if the channel's sender is dropped without a result being sent, which
/// would mean the spawned job never ran to completion.
pub async fn par_iter<T, R, F>(t: T, closure: F) -> Result<R, Panicked>
where
    T: IntoParallelIterator + Send + 'static,
    R: Send + 'static,
    F: FnOnce(<T as IntoParallelIterator>::Iter) -> R + Send + 'static,
{
    let (tx, rx) = tokio::sync::oneshot::channel();
    rayon::spawn(move || {
        // `t` and `closure` are moved into the guarded closure and consumed there,
        // so this function never touches them again after a panic.
        let pass = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
            closure(t.into_par_iter())
        }))
        .map_err(|payload| Panicked { payload });
        // Sending fails only when the receiver is gone, i.e. nobody is waiting
        // for the result any more; there is nothing useful to do about it.
        let _ = tx.send(pass);
    });
    // The job above always sends exactly one value (panics are caught and sent
    // as `Err`), so a receive error indicates a broken invariant.
    rx.await
        .expect("rayon job was dropped before sending its result")
}
/// Error returned by [`par_iter`] when the offloaded work panicked.
pub struct Panicked {
    /// The payload captured from the panic by `std::panic::catch_unwind`.
    pub payload: Box<dyn Any + Send + 'static>,
}

/// Formats as `panicked`, followed by the panic message when the payload is a
/// `&'static str` or a `String`.
///
/// Payloads of any other type cannot be rendered and fall back to the bare
/// `panicked` text.
impl Debug for Panicked {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Try the string literal form first, then an owned `String`.
        let message: Option<&str> = match self.payload.downcast_ref::<&'static str>() {
            Some(text) => Some(*text),
            None => self.payload.downcast_ref::<String>().map(String::as_str),
        };
        match message {
            Some(message) => write!(f, "panicked: {}", message),
            None => write!(f, "panicked"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::iter::ParallelIterator;
    use std::{
        thread,
        time::{Duration, Instant},
    };

    /// A simple reduction comes back through the future unchanged.
    #[tokio::test]
    async fn test_happy_path() {
        let input = vec![1, 2, 3];
        let total: usize = par_iter(input, |iter| iter.sum()).await.unwrap();
        assert_eq!(total, 6);
    }

    /// While the parallel work blocks its threads for a second, a 50 ms async
    /// timer on the same task must still win the race, and the work must then
    /// complete when awaited afterwards.
    #[tokio::test]
    async fn test_actually_async() {
        let input = vec![1usize, 2];
        let work = par_iter(input, |iter| {
            iter.map(|_| thread::sleep(Duration::from_secs(1))).count()
        });
        tokio::pin!(work);
        let timer = tokio::time::sleep(Duration::from_millis(50));
        tokio::pin!(timer);
        let started = Instant::now();
        tokio::select! {
            // Poll the parallel work first so the timer only wins if the work
            // is genuinely still pending.
            biased;
            _ = &mut work => {
                panic!("Shouldn't make it here")
            }
            _ = timer => {
                eprintln!("made it here sleep");
                assert!(started.elapsed() >= Duration::from_millis(50));
                assert!(started.elapsed().as_millis() <= 75);
            }
        };
        let outcome = work.await;
        assert_eq!(outcome.unwrap(), 2);
        assert!(started.elapsed() >= Duration::from_secs(1));
        assert!(started.elapsed() < Duration::from_secs(2));
    }

    /// A panic raised inside the parallel iterator is reported as `Panicked`
    /// with the original payload.
    #[tokio::test]
    async fn test_panic_in_iter() {
        let input = vec![1usize, 2, 3];
        let failure = par_iter(input, |iter| iter.map(|_| panic!("gus")).count())
            .await
            .unwrap_err();
        let message = failure.payload.downcast_ref::<&'static str>().copied();
        assert_eq!(message, Some("gus"));
    }

    /// A panic raised directly by the closure is reported the same way.
    #[tokio::test]
    async fn test_panic_in_closure() {
        let input = vec![1usize, 2, 3];
        let failure = par_iter(input, |_| panic!("gus2")).await.unwrap_err();
        let message = failure.payload.downcast_ref::<&'static str>().copied();
        assert_eq!(message, Some("gus2"));
    }
}