#![deny(missing_docs)]
use rayon::iter::plumbing::{Folder, Reducer, UnindexedConsumer};
use rayon::iter::ParallelIterator;
use rayon::{current_num_threads, join_context};
/// An [`Iterator`] that can hand part of its remaining items off to a new
/// iterator of the same type, so that the two parts can be processed
/// independently (for example on different threads).
///
/// Use [`ParallelSpliterator::par_split`] to drive a spliterator with rayon.
pub trait Spliterator: Iterator + Sized {
    /// Tries to split off part of the remaining items into a new iterator.
    ///
    /// On success, the items yielded afterwards by `self` and by the returned
    /// iterator together are exactly the items `self` would have yielded on its
    /// own, each one going to exactly one of the two. Returns `None` when no
    /// split was made.
    ///
    /// When [`ParSpliter`] combines results, items from the returned iterator
    /// are reduced *after* the items that remain in `self`.
    fn split(&mut self) -> Option<Self>;
}
/// Extension trait that turns a [`Spliterator`] into a rayon parallel iterator.
///
/// It is implemented for every `Send` spliterator whose items are `Send`.
pub trait ParallelSpliterator: Sized {
    /// Wraps `self` in a [`ParSpliter`], which implements rayon's
    /// [`ParallelIterator`] by splitting `self` on demand.
    fn par_split(self) -> ParSpliter<Self>;
}
// Blanket implementation: any spliterator that can be sent across threads, and
// whose items can be as well, gets `par_split` without extra work.
impl<S> ParallelSpliterator for S
where
    S: Spliterator + Send,
    S::Item: Send,
{
    fn par_split(self) -> ParSpliter<S> {
        ParSpliter::<S>::new(self)
    }
}
/// A rayon [`ParallelIterator`] over the items of a [`Spliterator`].
///
/// Created by [`ParallelSpliterator::par_split`]. While it runs, the wrapped
/// iterator is split on demand and the pieces are processed through rayon's
/// `join_context`.
#[derive(Clone, Copy, Debug)]
pub struct ParSpliter<T> {
    /// The underlying spliterator whose items are produced.
    iter: T,
    /// Remaining split budget. It starts at rayon's current thread count, is
    /// halved on every successful split, and is refilled to the thread count
    /// when a piece of work has migrated to another thread. No split is
    /// attempted while it is zero.
    splits: usize,
}
impl<T: Spliterator> ParSpliter<T> {
    /// Wraps `iter` with a split budget equal to rayon's current thread count.
    fn new(iter: T) -> Self {
        let splits = current_num_threads();
        Self { iter, splits }
    }

    /// Attempts to split off another half of the underlying iterator.
    ///
    /// If `stolen` is set, the budget is first refilled to the current thread
    /// count. With an empty budget nothing is attempted. When the underlying
    /// `Spliterator::split` succeeds, the budget is halved and the new half
    /// receives the same halved budget. When it fails, the budget is untouched.
    fn split(&mut self, stolen: bool) -> Option<Self> {
        if stolen {
            self.splits = current_num_threads();
        }
        if self.splits == 0 {
            return None;
        }
        let other = self.iter.split()?;
        self.splits /= 2;
        Some(Self {
            iter: other,
            splits: self.splits,
        })
    }

    /// Feeds the items of `self` into `consumer`, splitting whenever possible.
    ///
    /// Items are folded one at a time into a folder built from a left-split
    /// piece of `consumer`. Before each item a split is attempted. On success,
    /// the rest of `self` and the split-off half are bridged in parallel via
    /// `join_context`, each passing on whether its closure migrated. The results
    /// are then reduced in the order: items folded so far, then the rest of
    /// `self`, then the split-off half. Folding also stops once the folder
    /// reports it is full.
    fn bridge<C>(&mut self, stolen: bool, consumer: C) -> C::Result
    where
        T: Send,
        C: UnindexedConsumer<T::Item>,
    {
        let mut prefix = consumer.split_off_left().into_folder();
        loop {
            if prefix.full() {
                break;
            }
            if let Some(mut other) = self.split(stolen) {
                // Reducers are taken before the consumer is split again and
                // then moved into the right-hand job.
                let outer = consumer.to_reducer();
                let inner = consumer.to_reducer();
                let left_consumer = consumer.split_off_left();
                let (left, right) = join_context(
                    |ctx| self.bridge(ctx.migrated(), left_consumer),
                    |ctx| other.bridge(ctx.migrated(), consumer),
                );
                return outer.reduce(prefix.complete(), inner.reduce(left, right));
            }
            match self.iter.next() {
                Some(item) => prefix = prefix.consume(item),
                None => break,
            }
        }
        prefix.complete()
    }
}
impl<T> ParallelIterator for ParSpliter<T>
where
    T: Spliterator + Send,
    T::Item: Send,
{
    type Item = T::Item;

    // Rayon's entry point. `stolen = false` keeps the initial split budget
    // chosen by `ParSpliter::new` instead of refilling it.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let mut root = self;
        root.bridge(false, consumer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_par_split() {
        /// Walks a binary tree numbered heap-style (the children of `n` are
        /// `2n` and `2n + 1`, stopping below `1 << 15`). It therefore yields
        /// each of `1..(1 << 16)` exactly once.
        struct AllNumbers {
            stack: Vec<u32>,
        }
        impl AllNumbers {
            fn new() -> Self {
                Self { stack: vec![1] }
            }
        }
        impl Iterator for AllNumbers {
            type Item = u32;
            fn next(&mut self) -> Option<Self::Item> {
                let n = self.stack.pop()?;
                if n < 1 << 15 {
                    self.stack.push(2 * n);
                    self.stack.push(2 * n + 1);
                }
                Some(n)
            }
        }
        impl Spliterator for AllNumbers {
            // Hands the upper half of the pending stack to a new iterator.
            fn split(&mut self) -> Option<Self> {
                let len = self.stack.len();
                if len >= 2 {
                    let split = self.stack.split_off(len / 2);
                    Some(Self { stack: split })
                } else {
                    None
                }
            }
        }

        const TOTAL: usize = (1 << 16) - 1;
        assert_eq!(AllNumbers::new().count(), TOTAL);
        assert_eq!(AllNumbers::new().par_split().count(), TOTAL);

        // `count` alone cannot tell a dropped item from a duplicated one, so
        // also check that the parallel run yields exactly the same items.
        let expected_sum: u32 = AllNumbers::new().sum();
        assert_eq!(AllNumbers::new().par_split().sum::<u32>(), expected_sum);
        let mut items: Vec<u32> = AllNumbers::new().par_split().collect();
        items.sort_unstable();
        assert_eq!(items, (1..1 << 16).collect::<Vec<u32>>());
    }
}