#![cfg_attr(feature = "bench", feature(test))]
mod iter_async;
use crossbeam::channel::{bounded, Receiver};
use crossbeam::sync::{Parker, Unparker};
use crossbeam::utils::Backoff;
pub use iter_async::*;
use num_cpus;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicIsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use std::time::Duration;
/// Capacity of each worker's bounded output channel; bounds how many
/// completed-but-not-yet-consumed results a single worker may hold.
const MAX_SIZE_FOR_THREAD: usize = 128;
/// Capacity of the bounded task-dispatch channel feeding all workers.
const BUFFER_SIZE: usize = 64;
/// Extension trait converting any `Send` iterable into a parallel iterator
/// whose results are yielded in the original input order.
pub trait IntoParallelIteratorSync<R, T, TL, F>
where
// Executor mapping each task to a result; returning `Err(())` aborts the
// whole parallel iteration early.
F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
T: Send + 'static,
TL: Send + IntoIterator<Item = T> + 'static,
<TL as IntoIterator>::IntoIter: Send + 'static,
R: Send,
{
/// Run `func` over the items of `self` on worker threads while preserving
/// input order in the yielded results.
fn into_par_iter_sync(self, func: F) -> ParIterSync<R>;
}
/// Blanket implementation: every sendable `IntoIterator` gets
/// `into_par_iter_sync`, which simply delegates to `ParIterSync::new`.
impl<R, T, TL, F> IntoParallelIteratorSync<R, T, TL, F> for TL
where
F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
T: Send + 'static,
TL: Send + IntoIterator<Item = T> + 'static,
<TL as IntoIterator>::IntoIter: Send + 'static,
R: Send + 'static,
{
fn into_par_iter_sync(self, func: F) -> ParIterSync<R> {
ParIterSync::new(self, func)
}
}
/// Reader-side handle to the task registry: one slot per possible in-flight
/// task, each holding the claiming worker's index (or -1 when empty).
struct TaskRegistry {
// Shared slots; value is the worker thread id, or -1 if unset.
inner: Arc<Vec<AtomicIsize>>,
// One parker per slot so the consumer can sleep until a slot is filled.
parkers: Vec<Parker>,
}
/// Let the reader handle be indexed like the slot vector it wraps.
impl Deref for TaskRegistry {
    type Target = Vec<AtomicIsize>;

    fn deref(&self) -> &Self::Target {
        // `&Arc<Vec<_>>` coerces to `&Vec<_>` at the return site.
        &self.inner
    }
}
/// Writer-side handle to the task registry, cloned into each worker thread.
struct TaskRegistryWrite {
// Same slots as `TaskRegistry::inner`; the extra `Arc` strong counts are
// how the reader detects that writers are still alive.
inner: Arc<Vec<AtomicIsize>>,
// One unparker per slot, used to wake a consumer parked on that slot.
unparkers: Vec<Unparker>
}
/// Let the writer handle be indexed like the slot vector it wraps.
impl Deref for TaskRegistryWrite {
    type Target = Vec<AtomicIsize>;

    fn deref(&self) -> &Self::Target {
        // `&Arc<Vec<_>>` coerces to `&Vec<_>` at the return site.
        &self.inner
    }
}
impl Drop for TaskRegistryWrite {
    /// Wake every parked consumer when a writer handle goes away, so a
    /// blocked `lookup` can re-check for disconnection.
    fn drop(&mut self) {
        self.unparkers.iter().for_each(|u| u.unpark());
    }
}
impl TaskRegistry {
/// Build a registry with `size` slots, all empty (-1), each paired with
/// its own `Parker` for consumer-side blocking.
fn new(size: usize) -> TaskRegistry {
TaskRegistry {
inner: Arc::new((0..size).map(|_| AtomicIsize::new(-1)).collect()),
parkers: (0..size).map(|_| Parker::new()).collect()
}
}
/// Block until the slot for `task_id` is filled, returning the worker
/// index that claimed the task and resetting the slot to -1. Returns
/// `None` once every writer handle has been dropped and the slot is
/// still empty (iteration finished or aborted).
#[inline(always)]
pub(crate) fn lookup(&self, task_id: usize) -> Option<isize> {
let registry_len = self.len();
let pos = TaskRegistry::id_to_key(task_id, registry_len);
let backoff = Backoff::new();
loop {
if !self.is_disconnected() {
// Writers may still fill this slot: consume the value if set,
// otherwise wait for it.
let thread_num = self[pos].swap(-1, Ordering::SeqCst);
if thread_num >= 0 {
return Some(thread_num);
} else {
if backoff.is_completed() {
// Spinning got us nowhere; park instead. The timeout guards
// against a missed wakeup (e.g. the unpark token was already
// consumed by an earlier iteration), forcing a periodic
// re-check of the slot and the disconnect state.
self.parkers[pos].park_timeout(Duration::from_millis(500));
} else {
backoff.snooze();
}
}
} else {
// All writers dropped: one final check of the slot, then give up.
let thread_num = self[pos].swap(-1, Ordering::SeqCst);
return if thread_num >= 0 {
Some(thread_num)
} else {
None
};
}
}
}
/// Map a monotonically increasing task id onto a slot index; ids wrap
/// around the fixed-size registry.
#[inline(always)]
fn id_to_key(task_id: usize, registry_len: usize) -> usize {
task_id % registry_len
}
/// Create a writer handle sharing the same slots. Writers hold
/// `Unparker`s so they can wake the consumer parked on a slot.
fn to_write(&self) -> TaskRegistryWrite {
TaskRegistryWrite {
inner: self.inner.clone(),
unparkers: self.parkers.iter().map(|p| p.unparker().clone()).collect(),
}
}
/// True once this reader holds the only `Arc`, i.e. every
/// `TaskRegistryWrite` (one per worker) has been dropped.
#[inline(always)]
fn is_disconnected(&self) -> bool {
Arc::strong_count(&self.inner) == 1
}
}
impl TaskRegistryWrite {
    /// Record that `task_id` was claimed by worker `thread_id`, then wake
    /// any consumer parked on that slot.
    #[inline(always)]
    pub(crate) fn register(&self, task_id: usize, thread_id: isize) {
        let slot = TaskRegistry::id_to_key(task_id, self.len());
        // The consumer must have drained this slot (reset it to -1) before
        // a task id can wrap around to it again.
        debug_assert_eq!(self[slot].load(Ordering::SeqCst), -1);
        // Store first, then unpark: the consumer re-reads the slot after
        // waking, so it will observe the value.
        self[slot].store(thread_id, Ordering::SeqCst);
        self.unparkers[slot].unpark();
    }
}
/// Order-preserving parallel iterator handle returned by
/// `into_par_iter_sync`.
pub struct ParIterSync<R> {
// One bounded result channel per worker thread.
output_receivers: Vec<Receiver<R>>,
// Reader side of the task-id -> worker-thread mapping.
task_registry: TaskRegistry,
// Worker handles plus the dispatcher thread; taken by `join`.
worker_thread: Option<Vec<JoinHandle<()>>>,
// Shared flag telling workers to stop (set on `Err` or by `kill`).
iterator_stopper: Arc<AtomicBool>,
// Set once `kill` ran; `next` then always returns `None`.
is_killed: bool,
// Sequence number of the next task to yield (input order).
current: usize,
}
impl<R> ParIterSync<R>
where
R: Send + 'static,
{
/// Spawn one dispatcher thread plus one worker per CPU and return the
/// consuming iterator handle. Each worker pulls `(task, id)` pairs from
/// the shared queue, records the claim in the registry, runs
/// `task_executor`, and pushes results into its own bounded channel.
/// An `Err(())` from the executor raises the stop flag for all workers.
pub fn new<T, TL, F>(tasks: TL, task_executor: F) -> Self
where
F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
T: Send + 'static,
TL: Send + IntoIterator<Item = T> + 'static,
<TL as IntoIterator>::IntoIter: Send + 'static,
{
let cpus = num_cpus::get();
let iterator_stopper = Arc::new(AtomicBool::new(false));
// Slot count covers the worst case of registered-but-unconsumed tasks:
// per worker, one executing plus MAX_SIZE_FOR_THREAD buffered results.
let task_registry: TaskRegistry = TaskRegistry::new((1 + MAX_SIZE_FOR_THREAD) * cpus);
let (dispatcher, task_receiver) = bounded(BUFFER_SIZE);
// Dispatcher: feeds (task, sequence-id) pairs; stops as soon as a send
// fails, which happens once all workers have dropped their receivers.
let sender_thread = thread::spawn(move || {
for (task_id, t) in tasks.into_iter().enumerate() {
if dispatcher.send((t, task_id)).is_err() {
break;
}
}
});
let mut handles = Vec::with_capacity(cpus + 1);
let mut output_receivers = Vec::with_capacity(cpus);
for thread_number in 0..cpus as isize {
let (output_sender, output_receiver) = bounded(MAX_SIZE_FOR_THREAD);
let task_receiver = task_receiver.clone();
// Each worker gets its own writer handle; dropping it on exit is what
// lets the consumer detect completion (see `TaskRegistry::is_disconnected`).
let task_registry = task_registry.to_write();
let iterator_stopper = iterator_stopper.clone();
let task_executor = task_executor.clone();
let handle = thread::spawn(move || {
loop {
// Checked between tasks so `kill` / an executor error stops us.
if iterator_stopper.load(Ordering::SeqCst) {
break;
}
match get_task(&task_receiver, &task_registry, thread_number) {
None => break,
Some(task) => match task_executor(task) {
Ok(blk) => {
output_sender.send(blk).unwrap();
}
Err(_) => {
// Executor failure aborts the whole iteration.
iterator_stopper.fetch_or(true, Ordering::SeqCst);
break;
}
},
}
}
});
output_receivers.push(output_receiver);
handles.push(handle);
}
// Dispatcher is joined last, after the workers that drain its channel.
handles.push(sender_thread);
ParIterSync {
output_receivers,
task_registry,
worker_thread: Some(handles),
iterator_stopper,
is_killed: false,
current: 0,
}
}
}
impl<R> ParIterSync<R> {
    /// Abort the parallel iteration: raise the stop flag and nudge each
    /// worker that may be blocked on a full output channel. Idempotent.
    pub fn kill(&mut self) {
        if self.is_killed {
            return;
        }
        self.is_killed = true;
        self.iterator_stopper.fetch_or(true, Ordering::SeqCst);
        // Popping at most one buffered result per channel unblocks a worker
        // stuck in `send`, letting it observe the stop flag and exit.
        self.output_receivers.iter().for_each(|rx| {
            let _ = rx.try_recv();
        });
    }
}
/// Pull the next task from the shared queue, recording in `registry` which
/// worker (`thread_number`) claimed it. Returns `None` once the dispatcher
/// has hung up and the queue is drained.
#[inline(always)]
fn get_task<T>(
    tasks: &Receiver<(T, usize)>,
    registry: &TaskRegistryWrite,
    thread_number: isize,
) -> Option<T>
where
    T: Send,
{
    let (task, task_id) = tasks.recv().ok()?;
    // Register before returning so the consumer can find this task's
    // result channel as soon as the worker starts executing it.
    registry.register(task_id, thread_number);
    Some(task)
}
impl<R> Iterator for ParIterSync<R> {
type Item = R;
/// Yield results in input order. The registry tells us which worker ran
/// task `current`; since each worker claims and sends in its own FIFO
/// order, the head of that worker's channel is exactly task `current`.
fn next(&mut self) -> Option<Self::Item> {
if self.is_killed {
return None;
}
// Blocks until task `current` has been claimed, or returns `None`
// once all workers have exited without claiming it.
match self.task_registry.lookup(self.current) {
None => None,
Some(thread_num) => {
match self.output_receivers[thread_num as usize].recv() {
Ok(block) => {
self.current += 1;
Some(block)
}
Err(_) => {
// Worker hung up mid-task (executor error): stop everything.
self.kill();
None
}
}
}
}
}
}
impl<R> ParIterSync<R> {
    /// Wait for the dispatcher and all worker threads to finish.
    ///
    /// Safe to call more than once: the handles are `take`n on the first
    /// call, so subsequent calls are no-ops. (The previous
    /// `take().unwrap()` would panic on a second call.)
    ///
    /// # Panics
    /// Propagates a panic from any joined thread, as before.
    fn join(&mut self) {
        if let Some(handles) = self.worker_thread.take() {
            for handle in handles {
                handle.join().unwrap()
            }
        }
    }
}
impl<R> Drop for ParIterSync<R> {
/// Stop and reap the background threads even when the consumer drops the
/// iterator early (e.g. after a `break` in a `for` loop).
fn drop(&mut self) {
self.kill();
self.join();
}
}
#[cfg(test)]
mod test_par_iter {
#[cfg(feature = "bench")]
extern crate test;
use crate::IntoParallelIteratorSync;
#[cfg(feature = "bench")]
use test::Bencher;
/// Helper executor that fails (returns `Err(())`) exactly at value 1000.
fn error_at_1000(test_vec: &Vec<i32>, a: i32) -> Result<i32, ()> {
let n = test_vec.get(a as usize).unwrap().to_owned();
if n == 1000 {
Err(())
} else {
Ok(n)
}
}
// An executor error must truncate the output at the failing task while
// keeping all earlier results, in order.
#[test]
fn par_iter_test_exception() {
for _ in 0..100 {
let resource_captured = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
let results_expected = vec![3, 1, 4, 1];
let results: Vec<i32> = (0..resource_captured.len())
.into_par_iter_sync(move |a| {
let n = resource_captured.get(a).unwrap().to_owned();
if n == 5 {
Err(())
} else {
Ok(n)
}
})
.collect();
assert_eq!(results, results_expected)
}
}
// Error raised in the middle stage of a three-stage chain propagates.
#[test]
fn par_iter_chained_exception() {
for _ in 0..100 {
let resource_captured: Vec<i32> = (0..10000).collect();
let resource_captured_1 = resource_captured.clone();
let resource_captured_2 = resource_captured.clone();
let results_expected: Vec<i32> = (0..1000).collect();
let results: Vec<i32> = (0..resource_captured.len())
.into_par_iter_sync(move |a| Ok(resource_captured.get(a).unwrap().to_owned()))
.into_par_iter_sync(move |a| error_at_1000(&resource_captured_1, a))
.into_par_iter_sync(move |a| {
Ok(resource_captured_2.get(a as usize).unwrap().to_owned())
})
.collect();
assert_eq!(results, results_expected)
}
}
// Error raised in the last stage of the chain propagates.
#[test]
fn par_iter_chained_exception_1() {
for _ in 0..100 {
let resource_captured: Vec<i32> = (0..10000).collect();
let resource_captured_1 = resource_captured.clone();
let resource_captured_2 = resource_captured.clone();
let results_expected: Vec<i32> = (0..1000).collect();
let results: Vec<i32> = (0..resource_captured.len())
.into_par_iter_sync(move |a| Ok(resource_captured.get(a).unwrap().to_owned()))
.into_par_iter_sync(move |a| {
Ok(resource_captured_2.get(a as usize).unwrap().to_owned())
})
.into_par_iter_sync(move |a| error_at_1000(&resource_captured_1, a))
.collect();
assert_eq!(results, results_expected)
}
}
// Error raised in the first stage of the chain propagates.
#[test]
fn par_iter_chained_exception_2() {
for _ in 0..100 {
let resource_captured: Vec<i32> = (0..10000).collect();
let resource_captured_1 = resource_captured.clone();
let resource_captured_2 = resource_captured.clone();
let results_expected: Vec<i32> = (0..1000).collect();
let results: Vec<i32> = (0..resource_captured.len())
.into_par_iter_sync(move |a| error_at_1000(&resource_captured_1, a as i32))
.into_par_iter_sync(move |a| {
Ok(resource_captured.get(a as usize).unwrap().to_owned())
})
.into_par_iter_sync(move |a| {
Ok(resource_captured_2.get(a as usize).unwrap().to_owned())
})
.collect();
assert_eq!(results, results_expected)
}
}
// Breaking out of the consuming loop must cleanly stop the background
// threads (exercises `Drop` -> `kill` + `join`).
#[test]
fn test_break() {
for _ in 0..100 {
let mut count = 0;
for i in (0..20000).into_par_iter_sync(|a| Ok(a)) {
if i == 10000 {
break;
}
count += 1;
}
assert_eq!(count, 10000)
}
}
// Results come back in exact input order even for a large task count
// (exercises registry wrap-around).
#[test]
fn test_large_iter() {
for _ in 0..10 {
let mut count = 0;
for i in (0..1_000_000).into_par_iter_sync(|i| Ok(i)) {
assert_eq!(i, count);
count += 1;
}
assert_eq!(count, 1_000_000)
}
}
#[cfg(feature = "bench")]
#[bench]
fn bench_into_par_iter_sync(b: &mut Bencher) {
b.iter(|| {
(0..1_000_000)
.into_par_iter_sync(|a| Ok(a))
.for_each(|_| {})
});
}
}