use arrow::array::RecordBatchReader;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use indicatif::ProgressBar;
use crate::Error;
use crate::FileType;
use crate::cli::DisplayOutputFormat;
use crate::pipeline::DisplaySlice;
use crate::pipeline::Producer;
use crate::pipeline::ProgressRecordBatchReader;
use crate::pipeline::RecordBatchReaderSource;
use crate::pipeline::SelectSpec;
use crate::pipeline::Step;
use crate::pipeline::VecRecordBatchReaderSource;
use crate::pipeline::block_on_pipeline_future;
use crate::pipeline::count_rows;
use crate::pipeline::display::apply_select_and_display;
use crate::pipeline::orc::OrcRecordBatchReader;
use crate::pipeline::read::ReadArgs;
use crate::pipeline::sample_from_reader;
use crate::pipeline::schema::get_schema_fields;
use crate::pipeline::schema::print_schema_fields;
use crate::pipeline::schema::schema_fields_from_arrow;
use crate::pipeline::tail_batches;
use crate::pipeline::write::write_record_batches_from_reader;
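/// Single-use [`Producer`] that hands out a pre-built [`RecordBatchReader`].
/// `get` consumes the inner reader and errors if it is called a second time.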
struct RecordBatchReaderHolder {
reader: Option<Box<dyn RecordBatchReader + 'static>>,
}
#[async_trait(?Send)]
impl Producer<dyn RecordBatchReader + 'static> for RecordBatchReaderHolder {
async fn get(&mut self) -> crate::Result<Box<dyn RecordBatchReader + 'static>> {
std::mem::take(&mut self.reader)
.ok_or_else(|| Error::GenericError("Reader already taken".to_string()))
}
}
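/// Pipeline step that projects the incoming reader down to the columns
/// resolved from `select`.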
pub struct RecordBatchSelect {
pub select: SelectSpec,
}
#[async_trait(?Send)]
impl Step for RecordBatchSelect {
type Input = RecordBatchReaderSource;
type Output = RecordBatchReaderSource;
async fn execute(self, mut input: Self::Input) -> crate::Result<Self::Output> {
let reader = input.get().await?;
let schema = reader.schema();
let column_names = self.select.resolve_names(&schema)?;
let indices: Vec<usize> = column_names
.iter()
.map(|col| {
schema
.index_of(col)
.map_err(|e| Error::GenericError(format!("Column '{col}' not found: {e}")))
})
.collect::<crate::Result<Vec<_>>>()?;
let projected_schema = schema.project(&indices)?;
let projected_reader = SelectColumnRecordBatchReader {
reader,
schema: std::sync::Arc::new(projected_schema),
indices,
};
Ok(Box::new(RecordBatchReaderHolder {
reader: Some(Box::new(projected_reader)),
}))
}
}
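/// [`RecordBatchReader`] adapter that projects every batch from the inner
/// reader onto a fixed set of column indices.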
pub struct SelectColumnRecordBatchReader {
reader: Box<dyn RecordBatchReader>,
schema: SchemaRef,
indices: Vec<usize>,
}
impl RecordBatchReader for SelectColumnRecordBatchReader {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
impl Iterator for SelectColumnRecordBatchReader {
type Item = std::result::Result<RecordBatch, ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
self.reader
.next()
.map(|batch| batch.and_then(|b| b.project(&self.indices)))
}
}
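/// Builds a [`RecordBatchSelect`] step when a column selection is present.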
pub fn parse_select_step(select: Option<SelectSpec>) -> Option<RecordBatchSelect> {
select.map(|columns| RecordBatchSelect { select: columns })
}
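/// [`RecordBatchReader`] adapter that skips the first `offset_remaining` rows,
/// slicing the batch that straddles the offset boundary.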
pub(crate) struct SkipRowsRecordBatchReader {
reader: Box<dyn RecordBatchReader + 'static>,
offset_remaining: usize,
}
impl RecordBatchReader for SkipRowsRecordBatchReader {
fn schema(&self) -> SchemaRef {
self.reader.schema()
}
}
impl Iterator for SkipRowsRecordBatchReader {
type Item = std::result::Result<RecordBatch, ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let batch = self.reader.next()?;
let batch = match batch {
Ok(b) => b,
Err(e) => return Some(Err(e)),
};
let batch_rows = batch.num_rows();
if batch_rows == 0 {
continue;
}
if self.offset_remaining == 0 {
return Some(Ok(batch));
}
if batch_rows <= self.offset_remaining {
self.offset_remaining -= batch_rows;
continue;
}
let start = self.offset_remaining;
self.offset_remaining = 0;
return Some(Ok(batch.slice(start, batch_rows - start)));
}
}
}
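/// [`RecordBatchReader`] adapter that yields at most `remaining` rows, slicing
/// the final batch if it would overshoot the limit.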
pub(crate) struct TakeRowsRecordBatchReader {
reader: Box<dyn RecordBatchReader + 'static>,
remaining: usize,
}
impl RecordBatchReader for TakeRowsRecordBatchReader {
fn schema(&self) -> SchemaRef {
self.reader.schema()
}
}
impl Iterator for TakeRowsRecordBatchReader {
type Item = std::result::Result<RecordBatch, ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
return None;
}
loop {
match self.reader.next() {
None => return None,
Some(Err(e)) => return Some(Err(e)),
Some(Ok(batch)) => {
let rows = batch.num_rows();
if rows == 0 {
continue;
}
if rows <= self.remaining {
self.remaining -= rows;
return Some(Ok(batch));
}
let slice = batch.slice(0, self.remaining);
self.remaining = 0;
return Some(Ok(slice));
}
}
}
}
}
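/// Wraps `reader` so that the first `offset` rows are skipped and, when a
/// `limit` is given, at most that many rows are yielded afterwards.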
pub(crate) fn apply_offset_limit(
reader: Box<dyn RecordBatchReader + 'static>,
offset: usize,
limit: Option<usize>,
) -> Box<dyn RecordBatchReader + 'static> {
let mut r = reader;
if offset > 0 {
r = Box::new(SkipRowsRecordBatchReader {
reader: r,
offset_remaining: offset,
});
}
if let Some(n) = limit {
r = Box::new(TakeRowsRecordBatchReader {
reader: r,
remaining: n,
});
}
r
}
pub struct RecordBatchHead {
pub n: usize,
}
#[async_trait(?Send)]
impl Step for RecordBatchHead {
type Input = RecordBatchReaderSource;
type Output = RecordBatchReaderSource;
async fn execute(self, mut input: Self::Input) -> crate::Result<Self::Output> {
let reader = input.get().await?;
let wrapped = Box::new(TakeRowsRecordBatchReader {
reader,
remaining: self.n,
});
Ok(Box::new(RecordBatchReaderHolder {
reader: Some(wrapped),
}))
}
}
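/// Step that keeps only the last `n` rows. All batches are collected into
/// memory before the tail is taken.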
pub struct RecordBatchTail {
pub n: usize,
}
#[async_trait(?Send)]
impl Step for RecordBatchTail {
type Input = RecordBatchReaderSource;
type Output = RecordBatchReaderSource;
async fn execute(self, mut input: Self::Input) -> crate::Result<Self::Output> {
let reader = input.get().await?;
let batches: Vec<RecordBatch> = reader
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(Error::ArrowError)?;
let batches = tail_batches(batches, self.n);
Ok(Box::new(VecRecordBatchReaderSource::new(batches)))
}
}
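/// Step that samples `n` rows from the stream, using the total row count of
/// the ORC file at `input_path` to choose the sample.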
pub struct RecordBatchSample {
pub input_path: String,
pub n: usize,
}
#[async_trait(?Send)]
impl Step for RecordBatchSample {
type Input = RecordBatchReaderSource;
type Output = RecordBatchReaderSource;
async fn execute(self, mut input: Self::Input) -> crate::Result<Self::Output> {
let total_rows = crate::get_total_rows_result(&self.input_path, FileType::Orc)?;
let reader = input.get().await?;
let sampled = sample_from_reader(reader, total_rows, self.n);
Ok(Box::new(VecRecordBatchReaderSource::new(sampled)))
}
}
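/// Terminal action of a [`RecordBatchPipeline`]: write the batches to a file,
/// display them, print the schema, or count the rows.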
pub enum RecordBatchSink {
Write {
output_path: String,
output_file_type: FileType,
json_pretty: bool,
progress: Option<ProgressBar>,
},
Display {
output_format: DisplayOutputFormat,
csv_stdout_headers: bool,
},
Schema {
output_format: DisplayOutputFormat,
sparse: bool,
},
Count,
}
async fn orc_source_after_select(
input_path: String,
select: Option<SelectSpec>,
) -> crate::Result<RecordBatchReaderSource> {
let read_args = ReadArgs::new(input_path, FileType::Orc);
let mut source: RecordBatchReaderSource = Box::new(OrcRecordBatchReader { args: read_args });
if let Some(select_step) = parse_select_step(select) {
source = select_step.execute(source).await?;
}
Ok(source)
}
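/// Opens `input_path` as ORC, applies the optional column selection, then the
/// optional head/tail/sample slice.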
async fn orc_source_after_select_and_slice(
input_path: String,
select: Option<SelectSpec>,
slice: Option<DisplaySlice>,
) -> crate::Result<RecordBatchReaderSource> {
let mut source = orc_source_after_select(input_path.clone(), select).await?;
if let Some(slice) = slice {
source = match slice {
DisplaySlice::Head(n) => RecordBatchHead { n }.execute(source).await?,
DisplaySlice::Tail(n) => RecordBatchTail { n }.execute(source).await?,
DisplaySlice::Sample(n) => RecordBatchSample { input_path, n }.execute(source).await?,
};
}
Ok(source)
}
pub struct RecordBatchPipeline {
pub(crate) input_path: String,
pub(crate) input_file_type: FileType,
pub(crate) select: Option<SelectSpec>,
pub(crate) slice: Option<DisplaySlice>,
pub(crate) sparse: bool,
pub(crate) sink: RecordBatchSink,
}
impl RecordBatchPipeline {
pub fn execute(&mut self) -> crate::Result<()> {
if self.input_file_type != FileType::Orc {
return Err(Error::GenericError(format!(
"RecordBatchPipeline only supports ORC input, got {}",
self.input_file_type
)));
}
let input_path = self.input_path.clone();
let select = self.select.clone();
let slice = self.slice;
let sparse = self.sparse;
let sink = match &self.sink {
RecordBatchSink::Write {
output_path,
output_file_type,
json_pretty,
progress,
} => RecordBatchSink::Write {
output_path: output_path.clone(),
output_file_type: *output_file_type,
json_pretty: *json_pretty,
progress: progress.clone(),
},
RecordBatchSink::Display {
output_format,
csv_stdout_headers,
} => RecordBatchSink::Display {
output_format: *output_format,
csv_stdout_headers: *csv_stdout_headers,
},
RecordBatchSink::Schema {
output_format,
sparse: schema_sparse,
} => RecordBatchSink::Schema {
output_format: *output_format,
sparse: *schema_sparse,
},
RecordBatchSink::Count => RecordBatchSink::Count,
};
let fut = async move {
match sink {
RecordBatchSink::Schema {
output_format,
sparse: schema_sparse,
} => {
if select.is_none() {
let fields = get_schema_fields(&input_path, FileType::Orc, None)
.map_err(|e| Error::GenericError(e.to_string()))?;
print_schema_fields(&fields, output_format, schema_sparse)
.map_err(|e| Error::GenericError(e.to_string()))?;
} else {
let mut source =
orc_source_after_select(input_path.clone(), select).await?;
let reader = source.get().await?;
let fields = schema_fields_from_arrow(reader.schema().as_ref());
print_schema_fields(&fields, output_format, schema_sparse)
.map_err(|e| Error::GenericError(e.to_string()))?;
}
Ok::<(), Error>(())
}
RecordBatchSink::Count => {
if select.is_none() {
let total = count_rows(&input_path, FileType::Orc, None).await?;
println!("{total}");
} else {
let mut source =
orc_source_after_select(input_path.clone(), select).await?;
let reader = source.get().await?;
let mut total = 0usize;
for batch in reader {
total += batch.map_err(Error::ArrowError)?.num_rows();
}
println!("{total}");
}
Ok::<(), Error>(())
}
RecordBatchSink::Write {
output_path,
output_file_type,
json_pretty,
progress,
} => {
let mut source =
orc_source_after_select_and_slice(input_path.clone(), select, slice)
.await?;
let reader = source.get().await?;
if let Some(pb) = progress {
let mut wrapped = ProgressRecordBatchReader {
inner: reader,
progress: pb,
};
write_record_batches_from_reader(
&mut wrapped,
output_path.as_str(),
output_file_type,
sparse,
json_pretty,
)?;
} else {
let mut reader = reader;
write_record_batches_from_reader(
&mut *reader,
output_path.as_str(),
output_file_type,
sparse,
json_pretty,
)?;
}
Ok::<(), Error>(())
}
RecordBatchSink::Display {
output_format,
csv_stdout_headers,
} => {
let source =
orc_source_after_select_and_slice(input_path.clone(), select, slice)
.await?;
apply_select_and_display(
source,
None,
output_format,
sparse,
csv_stdout_headers,
)
.await?;
Ok::<(), Error>(())
}
}
};
block_on_pipeline_future(fut)
}
}
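/// Destination that consumes record batches one at a time and is finalized
/// with [`BatchWriteSink::finish`].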
pub trait BatchWriteSink {
fn write_batch(&mut self, batch: &RecordBatch) -> crate::Result<()>;
fn finish(self) -> crate::Result<()>;
}
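/// Drains `reader`, building a sink for `path` and the reader's schema via
/// `build_sink`, writing every batch into it, and finishing the sink.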
pub fn write_record_batches_with_sink<S, BuildSink>(
path: &str,
reader: &mut dyn RecordBatchReader,
build_sink: BuildSink,
) -> crate::Result<()>
where
S: BatchWriteSink,
BuildSink: FnOnce(&str, SchemaRef) -> crate::Result<S>,
{
let schema = reader.schema();
let mut sink = build_sink(path, schema)?;
for batch in reader {
let batch = batch.map_err(Error::ArrowError)?;
sink.write_batch(&batch)?;
}
sink.finish()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::FileType;
use crate::pipeline::ColumnSpec;
use crate::pipeline::RecordBatchReaderSource;
use crate::pipeline::SelectItem;
use crate::pipeline::SelectSpec;
use crate::pipeline::parquet::RecordBatchParquetReader;
use crate::pipeline::read::ReadArgs;
#[test]
fn test_parse_select_step_none() {
assert!(parse_select_step(None).is_none());
}
#[test]
fn test_parse_select_step_some() {
let select = SelectSpec::from_cli_args(&Some(vec!["one".to_string(), "two".to_string()]));
let step = parse_select_step(select).expect("should return some");
assert_eq!(step.select.len(), 2);
assert_eq!(
step.select[0],
SelectItem::Column(ColumnSpec::Exact("one".into()))
);
assert_eq!(
step.select[1],
SelectItem::Column(ColumnSpec::Exact("two".into()))
);
}
#[test]
fn test_parse_select_step_comma_separated() {
let select = SelectSpec::from_cli_args(&Some(vec!["one, two".to_string()]));
let step = parse_select_step(select).expect("should return some");
assert_eq!(step.select.len(), 2);
assert_eq!(
step.select[0],
SelectItem::Column(ColumnSpec::Exact("one".into()))
);
assert_eq!(
step.select[1],
SelectItem::Column(ColumnSpec::Exact("two".into()))
);
}
#[test]
fn test_parse_select_step_empty_returns_none() {
let select = SelectSpec::from_cli_args(&Some(vec![" , ".to_string()]));
assert!(parse_select_step(select).is_none());
}
#[tokio::test(flavor = "multi_thread")]
async fn test_select_columns() {
let args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet);
let parquet_step = RecordBatchParquetReader { args };
let source: RecordBatchReaderSource = Box::new(parquet_step);
let select_step = RecordBatchSelect {
select: SelectSpec {
columns: vec![
SelectItem::Column(ColumnSpec::Exact("two".to_string())),
SelectItem::Column(ColumnSpec::Exact("four".to_string())),
],
group_by: None,
},
};
let mut projected_source = select_step
.execute(source)
.await
.expect("Failed to execute select columns");
let mut projected_reader = projected_source
.get()
.await
.expect("Failed to get record batch reader");
let projected_schema = projected_reader.schema();
assert_eq!(projected_schema.fields().len(), 2);
assert_eq!(projected_schema.field(0).name(), "two");
assert_eq!(projected_schema.field(1).name(), "four");
let batch_result = projected_reader.next().unwrap();
let projected_batch = batch_result.unwrap();
let batch_rows = projected_batch.num_rows();
assert_eq!(projected_batch.num_columns(), 2);
assert_eq!(projected_batch.column(0).len(), batch_rows);
assert_eq!(projected_batch.column(1).len(), batch_rows);
}
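// Illustrative sketch: exercises apply_offset_limit over a small in-memory
// reader built with arrow's RecordBatchIterator; the column name "a" and the
// row values are arbitrary test data, not project fixtures.
#[test]
fn test_apply_offset_limit_slices_rows() {
use arrow::array::{Int32Array, RecordBatchIterator, RecordBatchReader};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use std::sync::Arc;
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
)
.expect("failed to build test batch");
let reader: Box<dyn RecordBatchReader + 'static> =
Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema));
// Skip the first two rows, then keep at most two of the remaining three.
let limited = apply_offset_limit(reader, 2, Some(2));
let rows: usize = limited.map(|b| b.expect("batch error").num_rows()).sum();
assert_eq!(rows, 2);
}
// Illustrative sketch: a minimal BatchWriteSink that only counts rows, showing
// the build_sink / write_batch / finish contract of write_record_batches_with_sink.
// CountingSink and the "unused" path are hypothetical test helpers.
#[test]
fn test_write_record_batches_with_sink_counts_rows() {
use arrow::array::{Int32Array, RecordBatchIterator};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use std::sync::Arc;
struct CountingSink {
expected: usize,
rows: usize,
}
impl BatchWriteSink for CountingSink {
fn write_batch(&mut self, batch: &RecordBatch) -> crate::Result<()> {
self.rows += batch.num_rows();
Ok(())
}
fn finish(self) -> crate::Result<()> {
assert_eq!(self.rows, self.expected);
Ok(())
}
}
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![7, 8, 9]))],
)
.expect("failed to build test batch");
let mut reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
write_record_batches_with_sink("unused", &mut reader, |_path, _schema| {
Ok(CountingSink { expected: 3, rows: 0 })
})
.expect("sink pipeline failed");
}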
}