use std::fs::File;
use std::path::Path;
use async_trait::async_trait;
use datafusion::prelude::AvroReadOptions;
use datafusion::prelude::CsvReadOptions;
use datafusion::prelude::DataFrame;
use datafusion::prelude::NdJsonReadOptions;
use datafusion::prelude::ParquetReadOptions;
use datafusion::prelude::SessionContext;
use orc_rust::ArrowReaderBuilder;
use crate::Error;
use crate::FileType;
use crate::Result;
use crate::errors::PipelineExecutionError;
use crate::errors::PipelinePlanningError;
use crate::pipeline::Producer;
use crate::pipeline::Step;
use crate::pipeline::dataframe::DataFrameSource;
/// Arguments describing a single file-read operation.
#[derive(Clone)]
pub struct ReadArgs {
    // Path to the input file on the local filesystem.
    pub path: String,
    // Format of the input; drives reader selection in `read`.
    pub file_type: FileType,
    // CSV only: whether the first row is a header.
    // `None` falls back to the reader default (`true` in `read_to_dataframe`).
    pub csv_has_header: Option<bool>,
    // Optional row limit. NOTE(review): not consumed anywhere in this file —
    // presumably applied by a later pipeline step; confirm.
    pub limit: Option<usize>,
    // Optional row offset. NOTE(review): also not consumed in this file.
    pub offset: Option<usize>,
}
impl ReadArgs {
    /// Build a `ReadArgs` for `path` and `file_type` with every optional
    /// knob (`csv_has_header`, `limit`, `offset`) left unset.
    pub fn new(path: impl Into<String>, file_type: FileType) -> Self {
        let path = path.into();
        Self {
            csv_has_header: None,
            limit: None,
            offset: None,
            file_type,
            path,
        }
    }
}
pub(crate) fn expect_file_type(args: &ReadArgs, expected: FileType) -> Result<()> {
if args.file_type != expected {
return Err(Error::GenericError(format!(
"read args file type mismatch: expected {expected}, got {}",
args.file_type
)));
}
Ok(())
}
/// Outcome of a read: either a DataFusion-backed source or a raw ORC reader.
// Variants differ greatly in size; kept unboxed deliberately per the allow.
#[allow(clippy::large_enum_variant)]
pub enum ReadResult {
    /// A registered DataFusion DataFrame wrapped as a pipeline source.
    DataFrame(DataFrameSource),
    /// Lazy ORC reader builder over an open file handle.
    OrcReaderBuilder(ArrowReaderBuilder<File>),
}
// Manual Debug: `ArrowReaderBuilder` does not expose a useful Debug, so the
// ORC variant is rendered as its address only.
impl std::fmt::Debug for ReadResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ReadResult::DataFrame(source) => write!(f, "DataFrame({source:?})"),
            // Pass the `&ArrowReaderBuilder` itself so `{:p}` prints the
            // builder's address. The previous `&builder` formatted a
            // `&&ArrowReaderBuilder`, i.e. the address of the short-lived
            // local reference binding — a meaningless stack address.
            ReadResult::OrcReaderBuilder(builder) => write!(f, "OrcReaderBuilder({builder:p})"),
        }
    }
}
/// Dispatch `args` to the matching reader.
///
/// DataFusion-native formats go through [`read_to_dataframe`]; ORC uses the
/// dedicated record-batch reader; anything else is a planning error.
pub async fn read(args: &ReadArgs) -> Result<ReadResult> {
    if args.file_type.supports_datafusion_file_read() {
        return read_to_dataframe(&args.path, args.file_type, args.csv_has_header).await;
    }
    if args.file_type == FileType::Orc {
        return read_to_record_batches(args);
    }
    Err(Error::PipelinePlanningError(
        PipelinePlanningError::UnsupportedInputFileType(args.file_type.to_string()),
    ))
}
pub async fn read_to_dataframe(
input_path: &str,
file_type: FileType,
csv_has_header: Option<bool>,
) -> Result<ReadResult> {
let ctx = SessionContext::new();
let df = match file_type {
FileType::Parquet => {
ctx.read_parquet(input_path, ParquetReadOptions::default())
.await?
}
FileType::Avro => {
ctx.read_avro(input_path, AvroReadOptions::default())
.await?
}
FileType::Json => {
ctx.read_json(input_path, NdJsonReadOptions::default())
.await?
}
FileType::Csv => {
let csv_options = CsvReadOptions::new().has_header(csv_has_header.unwrap_or(true));
ctx.read_csv(input_path, csv_options).await?
}
_ => {
return Err(Error::PipelineExecutionError(
PipelineExecutionError::UnsupportedInputFileType(file_type),
));
}
};
let table_provider = df.into_temporary_view();
let basename = Path::new(input_path)
.file_stem()
.unwrap()
.to_string_lossy()
.to_string();
ctx.register_table(&basename, table_provider)?;
let registered_df = ctx.table(&basename).await?;
Ok(ReadResult::DataFrame(DataFrameSource::new(registered_df)))
}
/// Pipeline step/producer that reads DataFusion-supported formats
/// (Parquet, Avro, JSON, CSV) as configured by `args`.
pub struct DataframeFormatReader {
    // Read configuration; consumed afresh on every execute()/get() call.
    pub args: ReadArgs,
}
#[async_trait(?Send)]
impl Step for DataframeFormatReader {
    type Input = ();
    type Output = DataFrameSource;

    /// Read the configured file once and yield it as a [`DataFrameSource`].
    async fn execute(self, _input: Self::Input) -> Result<Self::Output> {
        let args = &self.args;
        match read_to_dataframe(&args.path, args.file_type, args.csv_has_header).await? {
            ReadResult::DataFrame(source) => Ok(source),
            // read_to_dataframe only ever produces the DataFrame variant.
            ReadResult::OrcReaderBuilder(_) => unreachable!(),
        }
    }
}
#[async_trait(?Send)]
impl Producer<DataFrame> for DataframeFormatReader {
    /// Re-read the configured file and hand out the resulting `DataFrame`.
    async fn get(&mut self) -> Result<Box<DataFrame>> {
        let args = &self.args;
        let read_result =
            read_to_dataframe(&args.path, args.file_type, args.csv_has_header).await?;
        match read_result {
            ReadResult::DataFrame(mut source) => source.get().await,
            // read_to_dataframe only ever produces the DataFrame variant.
            ReadResult::OrcReaderBuilder(_) => unreachable!(),
        }
    }
}
pub fn read_to_record_batches(args: &ReadArgs) -> Result<ReadResult> {
match args.file_type {
FileType::Orc => {
let file = std::fs::File::open(&args.path).map_err(Error::IoError)?;
let builder = ArrowReaderBuilder::try_new(file).map_err(Error::OrcError)?;
Ok(ReadResult::OrcReaderBuilder(builder))
}
_ => Err(Error::PipelineExecutionError(
PipelineExecutionError::UnsupportedInputFileType(args.file_type),
)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every DataFusion-backed format should come back as a queryable DataFrame.
    #[tokio::test]
    async fn test_read_datafusion_formats_yield_dataframe() {
        let cases = [
            ("fixtures/table.parquet", FileType::Parquet),
            ("fixtures/userdata5.avro", FileType::Avro),
            ("fixtures/table.csv", FileType::Csv),
            ("fixtures/table.json", FileType::Json),
        ];
        for (path, file_type) in cases {
            let args = ReadArgs::new(path, file_type);
            let result = match read(&args).await {
                Ok(result) => result,
                Err(e) => panic!("read {path} ({file_type}): {e}"),
            };
            let ReadResult::DataFrame(mut source) = result else {
                panic!("expected DataFrame for {path}, got {result:?}");
            };
            if let Err(e) = source.get().await {
                panic!("get DataFrame {path}: {e}");
            }
        }
    }

    /// ORC bypasses DataFusion and yields a raw reader builder instead.
    #[tokio::test]
    async fn test_read_orc() {
        let args = ReadArgs::new("fixtures/userdata.orc", FileType::Orc);
        let result = read(&args).await.expect("read ORC");
        assert!(matches!(result, ReadResult::OrcReaderBuilder(_)));
    }
}