use arrow::record_batch::RecordBatch;
use crate::deserialize::{ArRowDeserialize, DeserializationError};
/// Iterator yielding individual rows of type `T` decoded from a stream of
/// [`RecordBatch`]es.
///
/// Record batches are pulled from `reader` one at a time, decoded into an
/// internal buffer, and then handed out row by row.
pub struct RowIterator<R, T>
where
    R: Iterator<Item = RecordBatch>,
    T: ArRowDeserialize + Clone,
{
    /// Upstream source of record batches.
    reader: R,
    /// Buffer holding the rows decoded from the most recently read batch.
    batch: Vec<T>,
    /// Position in `batch` of the next row to yield.
    index: usize,
    /// Number of rows reported as decoded for the current batch; once `index`
    /// reaches this value a new batch has to be read.
    decoded_items: usize,
}
impl<R: Iterator<Item = RecordBatch>, T: ArRowDeserialize + Clone> RowIterator<R, T> {
    /// Builds a row iterator on top of `reader` and eagerly loads its first batch.
    ///
    /// The first batch (if there is one) is pulled right away and its schema is
    /// validated with `T::check_schema`, so a mismatch is reported here rather
    /// than on the first call to `next()`.
    ///
    /// # Errors
    ///
    /// Returns [`DeserializationError::MismatchedColumnDataType`] when the
    /// schema check rejects the first batch, or the error reported by
    /// `T::read_from_record_batch` if decoding that batch fails.
    pub fn new(reader: R) -> Result<RowIterator<R, T>, DeserializationError> {
        let mut rows = Self {
            decoded_items: 0,
            index: 0,
            batch: Vec::new(),
            reader,
        };
        // Only a failure matters here; the "ended" flag is deliberately ignored.
        let _ended = rows.read_batch(true)?;
        Ok(rows)
    }

    /// Pulls the next record batch from the reader and decodes it into `self.batch`.
    ///
    /// `self.index` is always rewound to zero. When `check_schema` is set, the
    /// batch's schema is validated against `T` before decoding.
    ///
    /// Returns `Ok(true)` when the reader had no further batch to offer, and
    /// `Ok(false)` when a batch was decoded.
    fn read_batch(&mut self, check_schema: bool) -> Result<bool, DeserializationError> {
        self.index = 0;

        let record_batch = match self.reader.next() {
            None => return Ok(true),
            Some(next_batch) => next_batch,
        };

        if check_schema {
            if let Err(mismatch) = T::check_schema(&record_batch.schema()) {
                return Err(DeserializationError::MismatchedColumnDataType(mismatch));
            }
        }

        // Make room for one slot per row; existing slots are kept and any
        // additional ones are filled with default values before decoding.
        let row_count = record_batch.num_rows();
        self.batch.resize(row_count, T::default());

        let decoded = T::read_from_record_batch(record_batch, &mut self.batch)?;
        self.decoded_items = decoded;
        Ok(false)
    }
}
impl<R: Iterator<Item = RecordBatch>, T: ArRowDeserialize + Clone> Iterator for RowIterator<R, T> {
    type Item = T;

    /// Returns a clone of the next decoded row, reading further record batches
    /// from the underlying reader as needed.
    ///
    /// Batches that decode to zero rows are skipped. Returns `None` once the
    /// reader has no more batches; calling `next()` again afterwards polls the
    /// reader again instead of replaying rows of the last batch.
    ///
    /// # Panics
    ///
    /// Panics if reading or decoding a subsequent batch fails, because the
    /// `Iterator` signature leaves no way to report the error.
    fn next(&mut self) -> Option<T> {
        // A loop rather than a single `if`: a freshly read batch may contain no
        // decoded rows, in which case the following batch must be tried instead
        // of ending the iteration early.
        while self.index == self.decoded_items {
            let ended = self.read_batch(false).expect("ArRowDeserialize::read_from_array() call from RowIterator::next() returns a deserialization error");
            if ended {
                // `read_batch` rewound `index` to 0 but left `decoded_items` at
                // the previous batch's count; clear it so that a later call
                // does not re-yield the stale rows still sitting in `batch`.
                self.decoded_items = 0;
                return None;
            }
        }
        let item = self.batch.get(self.index).cloned();
        self.index += 1;
        item
    }
}