use deserialize::{CheckableKind, OrcDeserialize, OrcStruct};
use errors::OpenOrcError;
use reader::{Reader, RowReaderOptions};
use std::convert::TryInto;
use std::marker::PhantomData;
use std::num::NonZeroU64;
use std::sync::Arc;
use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback};
use rayon::prelude::*;
use row_iterator::RowIterator;
/// Parallel (rayon) iterator over the rows of a [`Reader`], yielding values
/// of type `T`.
///
/// The iterator covers the half-open row range `start..end`. Instances built
/// through `new_with_options` start at row 0 and end at the reader's row
/// count.
pub struct ParallelRowIterator<T>
where
    T: OrcDeserialize + Clone,
{
    /// Shared handle to the reader the rows are read from.
    reader: Arc<Reader>,
    /// Options handed to every row reader / row iterator created from `reader`.
    row_reader_options: RowReaderOptions,
    /// Batch size forwarded to each underlying `RowIterator`.
    batch_size: NonZeroU64,
    /// Index of the first row covered by this iterator (inclusive).
    start: usize,
    /// Index one past the last row covered by this iterator (exclusive).
    end: usize,
    /// No `T` is stored; this only ties the item type to the struct.
    marker: PhantomData<T>,
}
impl<T> ParallelRowIterator<T>
where
    T: OrcDeserialize + OrcStruct + CheckableKind + Clone,
{
    /// Builds a parallel iterator over `reader` that selects the columns
    /// named by `T::columns()`.
    ///
    /// This is a convenience wrapper: it starts from
    /// `RowReaderOptions::default()`, restricts it with `include_names`, and
    /// delegates to `new_with_options`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `new_with_options` returns.
    pub fn new(
        reader: Arc<Reader>,
        batch_size: NonZeroU64,
    ) -> Result<ParallelRowIterator<T>, OpenOrcError> {
        Self::new_with_options(
            reader,
            batch_size,
            RowReaderOptions::default().include_names(T::columns()),
        )
    }
}
impl<T> ParallelRowIterator<T>
where
    T: OrcDeserialize + Clone,
{
    /// Builds a parallel iterator over every row of `reader`, using the
    /// caller-supplied row reader `options`.
    ///
    /// A row reader is opened once up front so that the kind it selects can
    /// be validated against `T` before any iteration starts.
    ///
    /// # Errors
    ///
    /// * `OpenOrcError::OrcError` if the row reader cannot be created.
    /// * `OpenOrcError::KindError` if `T::check_kind` rejects the selected kind.
    ///
    /// # Panics
    ///
    /// Panics if the reader's row count does not fit in a `usize`.
    pub fn new_with_options(
        reader: Arc<Reader>,
        batch_size: NonZeroU64,
        options: RowReaderOptions,
    ) -> Result<ParallelRowIterator<T>, OpenOrcError> {
        // Only used for validation; the producers open their own readers later.
        let row_reader = reader
            .row_reader(&options)
            .map_err(OpenOrcError::OrcError)?;
        T::check_kind(&row_reader.selected_kind()).map_err(OpenOrcError::KindError)?;

        let total_rows: usize = reader
            .row_count()
            .try_into()
            .expect("row count overflows usize");

        Ok(ParallelRowIterator {
            reader,
            row_reader_options: options,
            batch_size,
            start: 0,
            end: total_rows,
            marker: PhantomData,
        })
    }
}
impl<T> ParallelIterator for ParallelRowIterator<T>
where
    T: OrcDeserialize + Clone + Send + Sync,
{
    type Item = T;

    /// Drives the consumer through rayon's `bridge`, i.e. through the
    /// indexed producer supplied by this type's `IndexedParallelIterator`
    /// implementation.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        bridge(self, consumer)
    }

    /// The length is always known: it is the size of the `start..end` range.
    fn opt_len(&self) -> Option<usize> {
        let remaining = self.end - self.start;
        Some(remaining)
    }
}
impl<T> IndexedParallelIterator for ParallelRowIterator<T>
where
    T: OrcDeserialize + Clone + Send + Sync,
{
    /// Hands the callback a `RowProducer` spanning this iterator's whole
    /// `start..end` range.
    ///
    /// The producer only borrows `self`, so `self` stays alive for the
    /// duration of the callback.
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: ProducerCallback<Self::Item>,
    {
        let producer = RowProducer {
            iter: &self,
            start: self.start,
            end: self.end,
        };
        callback.callback(producer)
    }

    /// Drives the consumer through rayon's `bridge`.
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: Consumer<Self::Item>,
    {
        bridge(self, consumer)
    }

    /// Number of rows in the `start..end` range.
    fn len(&self) -> usize {
        self.end - self.start
    }
}
/// Rayon producer covering the half-open row range `start..end` of a
/// borrowed [`ParallelRowIterator`].
struct RowProducer<'a, T>
where
    T: OrcDeserialize + Clone + Send + Sync + 'a,
{
    /// Iterator whose reader, batch size and options are used to read rows.
    iter: &'a ParallelRowIterator<T>,
    /// First row produced (inclusive).
    start: usize,
    /// One past the last row produced (exclusive).
    end: usize,
}
impl<'a, T> Producer for RowProducer<'a, T>
where
    T: OrcDeserialize + Clone + Send + Sync,
{
    type Item = T;
    type IntoIter = std::iter::Take<RowIterator<T>>;

    /// Turns this producer into a sequential iterator over its rows.
    ///
    /// A new `RowIterator` is created from the parent iterator's reader,
    /// batch size and options, moved to row `start` with `seek`, and then
    /// limited to `end - start` items with `take`.
    ///
    /// # Panics
    ///
    /// * if `start > end`;
    /// * if `start` cannot be converted to the type expected by `seek`;
    /// * if the `RowIterator` cannot be created. `Producer::into_iter` has no
    ///   way to return an error, so the failure is reported as a panic.
    fn into_iter(self) -> Self::IntoIter {
        let RowProducer { iter, start, end } = self;
        assert!(start <= end);

        let row_count = end - start;
        let first_row = start
            .try_into()
            .expect("RowProducer::start overflows u64");

        let rows = RowIterator::new_with_options(
            &iter.reader,
            iter.batch_size,
            &iter.row_reader_options,
        )
        .expect("Could not create RowIterator");

        rows.seek(first_row).take(row_count)
    }

    /// Splits the range into `start..start + index` and `start + index..end`.
    ///
    /// Both halves keep borrowing the same parent iterator.
    fn split_at(self, index: usize) -> (Self, Self) {
        let mid = self.start + index;
        let left = RowProducer {
            iter: self.iter,
            start: self.start,
            end: mid,
        };
        let right = RowProducer {
            iter: self.iter,
            start: mid,
            end: self.end,
        };
        (left, right)
    }
}