use std::any::Any;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread::{self, JoinHandle};
use std::vec::IntoIter;
/// Payload of a panic captured with `catch_unwind` on the worker thread, forwarded to the
/// consuming thread so it can be re-raised there.
type PanicUnwindErr = Box<dyn Any + Send + 'static>;
/// Iterator whose items are produced on a background thread that reads ahead of the consumer
/// and hands items over in `Vec` chunks.
///
/// Chunks (or a captured panic payload) arrive through `receiver`; `current_chunk` holds the
/// chunk currently being drained, and `join_handle` keeps the worker thread so it can be
/// joined once the channel disconnects.
pub struct ChunkedReadAheadIterator<T>
where
    T: Send + 'static,
{
    /// Chunks sent by the worker thread, or the payload of a panic it caught.
    receiver: Receiver<Result<Vec<T>, PanicUnwindErr>>,
    /// Handle of the worker thread; taken (left as `None`) when it is joined.
    join_handle: Option<JoinHandle<()>>,
    /// Remaining items of the most recently received chunk.
    current_chunk: IntoIter<T>,
}
impl<T> ChunkedReadAheadIterator<T>
where
    T: Send + 'static,
{
    /// Spawns a worker thread that pulls items from `inner`, groups them into `Vec`s of up
    /// to `chunk_size` items and sends them over a bounded channel that buffers at most
    /// `num_chunk_buffer_size` chunks.
    ///
    /// If `inner.next()` panics, the items gathered before the panic are delivered first and
    /// the captured panic payload is sent afterwards, so the consumer can re-raise it.
    /// The worker stops at the first `None` from `inner` and never calls `inner.next()`
    /// again, because iterators are not required to be fused.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` or `num_chunk_buffer_size` is zero, or if the worker thread
    /// cannot be spawned.
    pub fn new<I>(mut inner: I, chunk_size: usize, num_chunk_buffer_size: usize) -> Self
    where
        I: Iterator<Item = T> + Send + 'static,
    {
        assert_ne!(chunk_size, 0, "Chunk size cannot be zero!");
        assert_ne!(num_chunk_buffer_size, 0, "Number of buffered chunks cannot be zero!");
        let (sender, receiver) = sync_channel(num_chunk_buffer_size);
        let join_handle = thread::Builder::new()
            .name("chunked_read_ahead_thread".to_owned())
            .spawn(move || loop {
                let mut chunk = Vec::with_capacity(chunk_size);
                // Set once `inner` yields `None`; it must not be polled again afterwards.
                let mut exhausted = false;
                // Payload of a panic raised by `inner.next()`, if any.
                let mut panic_payload = None;
                while chunk.len() < chunk_size {
                    match catch_unwind(AssertUnwindSafe(|| inner.next())) {
                        Ok(Some(val)) => chunk.push(val),
                        Ok(None) => {
                            exhausted = true;
                            break;
                        }
                        Err(e) => {
                            panic_payload = Some(e);
                            break;
                        }
                    }
                }
                // Never send an empty chunk: the consumer would read it as end-of-iteration.
                // A failed send means the consumer was dropped, so stop producing.
                if !chunk.is_empty() && sender.send(Ok(chunk)).is_err() {
                    break;
                }
                if let Some(e) = panic_payload {
                    // Best effort: the consumer may already be gone.
                    let _ = sender.send(Err(e));
                    break;
                }
                if exhausted {
                    break;
                }
            })
            .expect("failed to spawn chunked read ahead thread");
        Self { receiver, join_handle: Some(join_handle), current_chunk: Vec::new().into_iter() }
    }
}
impl<T> Iterator for ChunkedReadAheadIterator<T>
where
    T: Send + 'static,
{
    type Item = T;

    /// Returns the next read-ahead item, blocking on the channel when the current chunk is
    /// used up.
    ///
    /// Empty chunks are skipped rather than treated as end-of-iteration, so they can never
    /// hide later items or a forwarded panic.
    ///
    /// # Panics
    ///
    /// Re-raises a panic payload forwarded by the worker, and re-raises a panic of the worker
    /// thread itself (observed when joining it after the channel disconnects).
    fn next(&mut self) -> Option<T> {
        loop {
            if let Some(item) = self.current_chunk.next() {
                return Some(item);
            }
            match self.receiver.recv() {
                Ok(Ok(next_chunk)) => self.current_chunk = next_chunk.into_iter(),
                Ok(Err(payload)) => resume_unwind(payload),
                Err(_disconnected) => {
                    // The worker dropped its sender. Join it once so a panic that escaped
                    // its `catch_unwind` still reaches the consumer; later calls see `None`.
                    if let Some(handle) = self.join_handle.take() {
                        if let Err(payload) = handle.join() {
                            resume_unwind(payload);
                        }
                    }
                    return None;
                }
            }
        }
    }
}
/// Conversion into a [`ChunkedReadAheadIterator`] that reads ahead in chunks.
// Silences clippy's pedantic `module_name_repetitions` lint for this trait's name.
#[allow(clippy::module_name_repetitions)]
pub trait IntoChunkedReadAheadIterator<T: Send + 'static> {
    /// Wraps `self` in a [`ChunkedReadAheadIterator`] that groups items into chunks of
    /// `chunk_size` and buffers up to `num_chunk_buffer_size` chunks.
    fn read_ahead(
        self,
        chunk_size: usize,
        num_chunk_buffer_size: usize,
    ) -> ChunkedReadAheadIterator<T>
    where
        Self: Send + 'static;
}
/// Blanket implementation: any iterator whose items are `Send + 'static` can call
/// `read_ahead`, provided the iterator itself is `Send + 'static` (it is handed to
/// [`ChunkedReadAheadIterator::new`], which moves it into a spawned thread).
impl<I, T: Send + 'static> IntoChunkedReadAheadIterator<T> for I
where
    I: Iterator<Item = T>,
{
    fn read_ahead(
        self,
        chunk_size: usize,
        num_chunk_buffer_size: usize,
    ) -> ChunkedReadAheadIterator<T>
    where
        Self: Send + 'static,
    {
        // Pure forwarding; size validation happens in the constructor.
        ChunkedReadAheadIterator::<T>::new(self, chunk_size, num_chunk_buffer_size)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::panic;
    use std::thread::sleep;
    use std::time::Duration;

    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(4)]
    #[case(8)]
    #[case(16)]
    fn test_wrapping_empty_iter(#[case] chunk_size: usize) {
        let test_vec: Vec<usize> = Vec::new();
        let mut chunked_iter = test_vec.into_iter().read_ahead(chunk_size, 1);
        assert_eq!(chunked_iter.next(), None);
    }

    #[rstest]
    #[case(1, 1)]
    #[case(2, 1)]
    #[case(4, 1)]
    #[case(8, 1)]
    #[case(16, 1)]
    #[case(1, 2)]
    #[case(2, 2)]
    #[case(4, 2)]
    #[case(8, 2)]
    #[case(16, 2)]
    #[case(1, 16)]
    #[case(2, 16)]
    #[case(4, 16)]
    #[case(8, 16)]
    #[case(16, 16)]
    #[case(1, 100)]
    #[case(2, 100)]
    #[case(4, 100)]
    #[case(8, 100)]
    #[case(16, 100)]
    fn test_handle_large_iterator_and_low_chunk_size(
        #[case] chunk_size: usize,
        #[case] buffer_size: usize,
    ) {
        let test_vec: Vec<usize> = (0..1_000_000).collect();
        let mut regular_iter = test_vec.clone().into_iter();
        let mut chunked_iter = test_vec.into_iter().read_ahead(chunk_size, buffer_size);
        loop {
            let expected = regular_iter.next();
            // Comparing at every step also proves both iterators end at the same time.
            assert_eq!(chunked_iter.next(), expected);
            if expected.is_none() {
                break;
            }
        }
    }

    #[test]
    fn test_low_bound_on_channel_for_blocking() {
        // A single-chunk buffer forces the producer to block on nearly every send.
        let chunked_iter = (0..100_000).read_ahead(8, 1);
        assert!(chunked_iter.eq(0..100_000));
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(4)]
    #[case(8)]
    #[case(16)]
    fn test_dropping_before_doesnt_explode(#[case] chunk_size: usize) {
        let test_vec = vec![0usize, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let chunked_iter = test_vec.into_iter().read_ahead(chunk_size, 1);
        // Give the producer time to fill the buffer before the consumer goes away.
        sleep(Duration::from_millis(10));
        drop(chunked_iter);
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(4)]
    #[case(8)]
    #[case(16)]
    fn test_dropping_half_used_iterator_doesnt_explode(#[case] chunk_size: usize) {
        let test_vec = vec![0usize, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let mut chunked_iter = test_vec.into_iter().read_ahead(chunk_size, 1);
        for _ in 0..4 {
            chunked_iter.next();
        }
        drop(chunked_iter);
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(4)]
    #[case(8)]
    #[case(16)]
    fn test_dropping_fully_used_iterator_doesnt_explode(#[case] chunk_size: usize) {
        let test_vec = vec![0usize, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let len = test_vec.len();
        let mut chunked_iter = test_vec.into_iter().read_ahead(chunk_size, 1);
        for _ in 0..len {
            chunked_iter.next();
        }
        drop(chunked_iter);
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(4)]
    #[case(8)]
    #[case(16)]
    fn test_read_ahead_results_in_same_results_as_regular_iter(#[case] chunk_size: usize) {
        let test_vec = vec![0usize, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let mut regular_iter = test_vec.clone().into_iter();
        let mut chunked_iter = test_vec.into_iter().read_ahead(chunk_size, 1);
        loop {
            let expected = regular_iter.next();
            assert_eq!(chunked_iter.next(), expected);
            if expected.is_none() {
                break;
            }
        }
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(4)]
    #[case(8)]
    #[case(16)]
    fn test_read_past_end(#[case] chunk_size: usize) {
        let mut test_iter =
            vec![0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9].into_iter().read_ahead(chunk_size, 1);
        for i in 0..20 {
            let v = test_iter.next();
            if i < 10 {
                assert_eq!(v, Some(i));
            } else {
                assert_eq!(v, None);
            }
        }
    }

    const FAIL_POINT: usize = 6;

    /// Yields `0..FAIL_POINT`, then panics on the next call.
    struct FailingIter {
        counter: usize,
    }

    impl FailingIter {
        fn new() -> Self {
            Self { counter: 0 }
        }
    }

    impl Iterator for FailingIter {
        type Item = usize;
        fn next(&mut self) -> Option<Self::Item> {
            assert!(self.counter < FAIL_POINT, "expected error message");
            let current = self.counter;
            self.counter += 1;
            Some(current)
        }
    }

    #[test]
    #[should_panic(expected = "expected error message")]
    fn test_panic_occurring_mid_chunk_returns_results_until_panic() {
        let mut test_iter = FailingIter::new().read_ahead(8, 1);
        for expected in 0..FAIL_POINT {
            let item = panic::catch_unwind(AssertUnwindSafe(|| test_iter.next()))
                .expect("different error message");
            assert_eq!(item, Some(expected));
        }
        test_iter.next();
    }

    /// Yields `0..FAIL_POINT`, then `None`; panics when dropped.
    struct ExitFailingIter {
        counter: usize,
    }

    impl ExitFailingIter {
        fn new() -> Self {
            Self { counter: 0 }
        }
    }

    impl Iterator for ExitFailingIter {
        type Item = usize;
        fn next(&mut self) -> Option<Self::Item> {
            if self.counter < FAIL_POINT {
                let current = self.counter;
                self.counter += 1;
                Some(current)
            } else {
                None
            }
        }
    }

    impl Drop for ExitFailingIter {
        fn drop(&mut self) {
            panic!("expected error message")
        }
    }

    #[test]
    #[should_panic(expected = "expected error message")]
    fn test_panic_occurring_after_iteration_raises() {
        let mut test_iter = ExitFailingIter::new().read_ahead(8, 1);
        for expected in 0..FAIL_POINT {
            let item = panic::catch_unwind(AssertUnwindSafe(|| test_iter.next()))
                .expect("different error message");
            assert_eq!(item, Some(expected));
        }
        // The worker thread panics while dropping the exhausted source; that panic must
        // surface here rather than being reported as a normal `None`.
        assert_eq!(test_iter.next(), None);
    }
}