use crate::MAX_SIZE_FOR_THREAD;
use crossbeam::channel;
use crossbeam::channel::Receiver;
use num_cpus;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
/// Extension trait that converts a collection of tasks into a [`ParIterAsync`].
///
/// Type parameters:
/// * `R`  - result produced for each task on success.
/// * `T`  - a single task (the item type of `TL`).
/// * `TL` - the task list; anything iterable that can be sent to another thread.
/// * `F`  - the function applied to each task; returning `Err(())` signals failure.
pub trait IntoParallelIteratorAsync<R, T, TL, F>
where
    R: Send,
    T: Send + 'static,
    TL: IntoIterator<Item = T> + Send + 'static,
    F: Fn(T) -> Result<R, ()> + Clone + Send + 'static,
{
    /// Consumes `self` and returns an iterator over the results of applying
    /// `func` to every item.
    fn into_par_iter_async(self, func: F) -> ParIterAsync<R>;
}
/// Blanket implementation: every sendable `IntoIterator` gains
/// `into_par_iter_async`.
impl<R, T, TL, F> IntoParallelIteratorAsync<R, T, TL, F> for TL
where
    R: Send + 'static,
    T: Send + 'static,
    TL: IntoIterator<Item = T> + Send + 'static,
    F: Fn(T) -> Result<R, ()> + Clone + Send + 'static,
{
    /// Hands the task list and the per-task function to [`ParIterAsync::new`].
    fn into_par_iter_async(self, func: F) -> ParIterAsync<R> {
        let tasks = self;
        ParIterAsync::new(tasks, func)
    }
}
/// Iterator over results that are produced by background worker threads.
///
/// Dropping the value stops the workers (`kill`) and then waits for all
/// spawned threads (`join`).
pub struct ParIterAsync<R> {
/// Receiving end of the bounded channel into which workers send their results.
output_receiver: Receiver<R>,
/// Join handles of the worker threads plus the dispatcher thread; taken
/// (set to `None`) when the threads are joined.
worker_thread: Option<Vec<JoinHandle<()>>>,
/// Flag shared with the workers; once `true`, workers stop picking up tasks.
iterator_stopper: Arc<AtomicBool>,
/// Set by `kill`; once `true`, `next` returns `None` without reading the channel.
is_killed: bool,
/// Number of worker threads that were spawned (the dispatcher is not counted).
worker_count: usize,
}
impl<R> ParIterAsync<R>
where
    R: Send + 'static,
{
    /// Spawns the threads that evaluate `task_executor` over `tasks` and
    /// returns an iterator over the successful results.
    ///
    /// One dispatcher thread feeds `tasks` into a bounded task channel, and
    /// one worker thread per CPU (as reported by `num_cpus::get()`) pulls
    /// tasks from it, runs `task_executor` and sends each `Ok` value into a
    /// bounded output channel. Both channels hold up to
    /// `MAX_SIZE_FOR_THREAD * cpus` items. Because several workers send into
    /// the same output channel, results are not guaranteed to arrive in
    /// input order.
    ///
    /// When `task_executor` returns `Err(())` for a task, the shared stop
    /// flag is set and that worker exits; no value is emitted for the task.
    /// The flag is only checked before a worker waits for its next task, so
    /// a worker that is already waiting may still process one more task
    /// after the flag was set.
    pub fn new<T, TL, F>(tasks: TL, task_executor: F) -> Self
    where
        F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
        T: Send + 'static,
        TL: Send + IntoIterator<Item = T> + 'static,
    {
        let cpus = num_cpus::get();
        let channel_capacity = MAX_SIZE_FOR_THREAD * cpus;
        let iterator_stopper = Arc::new(AtomicBool::new(false));
        let stopper_clone = iterator_stopper.clone();

        // Dispatcher: forwards tasks until the input is exhausted or every
        // worker has dropped its task receiver (send then fails).
        let (dispatcher, task_receiver) = channel::bounded(channel_capacity);
        let work_dispatcher = thread::spawn(move || {
            for t in tasks {
                if dispatcher.send(t).is_err() {
                    break;
                }
            }
        });

        let (output_sender, output_receiver) = channel::bounded(channel_capacity);
        let worker_task = move || {
            loop {
                if iterator_stopper.load(Ordering::SeqCst) {
                    break;
                }
                match get_task(&task_receiver) {
                    // The dispatcher is done and the task channel is empty.
                    None => break,
                    Some(task) => match task_executor(task) {
                        Ok(blk) => {
                            // A failed send means the receiving side is gone;
                            // nobody can consume further results, so stop all
                            // workers instead of panicking in this thread.
                            if output_sender.send(blk).is_err() {
                                iterator_stopper.store(true, Ordering::SeqCst);
                                break;
                            }
                        }
                        Err(_) => {
                            iterator_stopper.store(true, Ordering::SeqCst);
                            break;
                        }
                    },
                }
            }
        };

        // Every worker runs its own clone of the closure. The original
        // closure (and the sender it owns) is dropped when `new` returns, so
        // the output channel disconnects once all workers have finished.
        let mut worker_handles = Vec::with_capacity(cpus + 1);
        for _ in 0..cpus {
            worker_handles.push(thread::spawn(worker_task.clone()));
        }
        worker_handles.push(work_dispatcher);

        ParIterAsync {
            output_receiver,
            worker_thread: Some(worker_handles),
            iterator_stopper: stopper_clone,
            is_killed: false,
            worker_count: cpus,
        }
    }
}
impl<R> ParIterAsync<R> {
    /// Asks the workers to stop and marks this iterator as finished.
    ///
    /// Only the first call has an effect. It sets the shared stop flag and
    /// then makes `worker_count` non-blocking receive attempts on the output
    /// channel, discarding whatever is received (presumably to free capacity
    /// so that workers blocked on the full bounded channel can finish their
    /// send and observe the flag).
    pub fn kill(&mut self) {
        if self.is_killed {
            return;
        }
        self.iterator_stopper.store(true, Ordering::SeqCst);
        // Always make exactly `worker_count` attempts; an empty channel on
        // one attempt does not end the drain early.
        (0..self.worker_count).for_each(|_| {
            drop(self.output_receiver.try_recv());
        });
        self.is_killed = true;
    }
}
/// Blocks until a task is available on `tasks`.
///
/// Returns `None` when the receive fails, i.e. the channel is empty and all
/// senders have been dropped.
#[inline(always)]
fn get_task<T>(tasks: &channel::Receiver<T>) -> Option<T>
where
    T: Send,
{
    match tasks.recv() {
        Ok(task) => Some(task),
        Err(_) => None,
    }
}
impl<R> Iterator for ParIterAsync<R> {
    type Item = R;

    /// Blocks for the next result from the workers.
    ///
    /// Returns `None` once the iterator has been killed, or when the output
    /// channel reports an error on receive (in which case the iterator is
    /// killed before returning).
    fn next(&mut self) -> Option<Self::Item> {
        if self.is_killed {
            return None;
        }
        let received = self.output_receiver.recv().ok();
        if received.is_none() {
            self.kill();
        }
        received
    }
}
impl<R> ParIterAsync<R> {
    /// Waits for every spawned thread whose handle is stored in
    /// `worker_thread` to finish.
    ///
    /// All handles are joined even if one of the threads panicked, so no
    /// thread is left detached. Afterwards the first panic payload is
    /// re-raised on the calling thread, unless the calling thread is already
    /// unwinding: `join` runs from `Drop`, and a second panic during
    /// unwinding would abort the process.
    ///
    /// If the handles were already taken by an earlier call, this is a no-op.
    fn join(&mut self) {
        let handles = match self.worker_thread.take() {
            Some(handles) => handles,
            None => return,
        };

        let mut first_panic = None;
        for handle in handles {
            if let Err(payload) = handle.join() {
                if first_panic.is_none() {
                    first_panic = Some(payload);
                }
            }
        }

        if let Some(payload) = first_panic {
            if !std::thread::panicking() {
                std::panic::resume_unwind(payload);
            }
        }
    }
}
impl<R> Drop for ParIterAsync<R> {
    /// Stops the workers first, then waits for the spawned threads, so that
    /// no thread outlives the iterator.
    fn drop(&mut self) {
        Self::kill(self);
        Self::join(self);
    }
}
#[cfg(test)]
mod test_par_iter_async {
    #[cfg(feature = "bench")]
    extern crate test;
    use crate::IntoParallelIteratorAsync;
    use std::collections::HashSet;
    #[cfg(feature = "bench")]
    use test::Bencher;

    /// A task that fails must never contribute its value to the output.
    #[test]
    fn par_iter_test_exception() {
        for _ in 0..100 {
            let digits = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
            let index_count = digits.len();
            let results: HashSet<i32> = (0..index_count)
                .into_par_iter_async(move |idx| match digits[idx] {
                    5 => Err(()),
                    value => Ok(value),
                })
                .collect();
            assert!(!results.contains(&5));
        }
    }

    /// Same check as above with three chained stages, the middle one failing.
    #[test]
    fn par_iter_chained_exception() {
        for _ in 0..100 {
            let first_stage: Vec<i32> = (0..10000).collect();
            let second_stage = first_stage.clone();
            let third_stage = first_stage.clone();
            let index_count = first_stage.len();
            let results: HashSet<i32> = (0..index_count)
                .into_par_iter_async(move |idx| Ok(first_stage[idx]))
                .into_par_iter_async(move |idx| match second_stage[idx as usize] {
                    1000 => Err(()),
                    value => Ok(value),
                })
                .into_par_iter_async(move |idx| Ok(third_stage[idx as usize]))
                .collect();
            assert!(!results.contains(&1000));
        }
    }

    /// Abandoning the iterator early must not hang when it is dropped.
    #[test]
    fn test_break() {
        for _ in 0..10000 {
            let stream = (0..2000).into_par_iter_async(|a| Ok(a));
            for i in stream {
                if i == 1000 {
                    break;
                }
            }
        }
    }

    #[cfg(feature = "bench")]
    #[bench]
    fn bench_into_par_iter_async(b: &mut Bencher) {
        b.iter(|| {
            (0..1_000_000)
                .into_par_iter_async(|a| Ok(a))
                .for_each(|_| {})
        });
    }
}