use std::marker::PhantomData;
use std::path::PathBuf;
use std::sync::Arc;
use arrow::datatypes::{FieldRef, Schema};
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;
use serde_arrow::schema::{SchemaLike, TracingOptions};
use crate::Uri;
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ParquetError {
#[error(transparent)]
Io(#[from] std::io::Error),
#[error("parquet write error: {0}")]
Write(#[source] Box<dyn std::error::Error + Send + Sync>),
#[error("parquet read error: {0}")]
Read(#[source] Box<dyn std::error::Error + Send + Sync>),
#[error("out of bounds: {start}..{end} (rows {rows})")]
OutOfBounds {
start: usize,
end: usize,
rows: usize,
},
#[error(transparent)]
Cache(#[from] crate::CacheError),
}
/// Typed writer that serialises batches of `T` into a single Parquet file.
///
/// The Arrow schema is traced from `T` when the writer is created. When the
/// writer is dropped, the underlying Arrow writer is closed and any error
/// raised at that point is discarded.
#[derive(Debug)]
pub struct ParquetWriter<T> {
    /// Underlying Arrow writer: `Some` while the file is open, taken (left as
    /// `None`) once the file has been finalised.
    writer: Option<ArrowWriter<std::fs::File>>,
    /// Number of rows written so far.
    rows: usize,
    /// Number of top-level fields in the traced schema.
    cols: usize,
    /// Ties the writer to the row type `T` without storing any `T`.
    _marker: PhantomData<T>,
}
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> ParquetWriter<T> {
    /// Creates (or truncates) the Parquet file at `path`, using a schema traced from `T`.
    ///
    /// # Errors
    ///
    /// * [`ParquetError::Io`] with kind `InvalidInput` if `path` is not a local
    ///   path, or any I/O error raised while creating the file.
    /// * [`ParquetError::Write`] if the schema cannot be traced from `T` or the
    ///   Arrow writer cannot be initialised.
    pub fn new(path: impl Into<Uri>) -> Result<Self, ParquetError> {
        let uri: Uri = path.into();
        let local_path = uri.as_path().ok_or_else(|| {
            ParquetError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "not a local path",
            ))
        })?;
        // Trace the schema before touching the filesystem, so an untraceable `T`
        // fails without creating or truncating the target file.
        let fields = Self::trace_fields()?;
        let cols = fields.len();
        let schema = Arc::new(Schema::new(fields));
        let file = std::fs::File::create(local_path)?;
        let props = WriterProperties::builder().build();
        let writer = ArrowWriter::try_new(file, schema, Some(props))
            .map_err(|e| ParquetError::Write(Box::new(e)))?;
        Ok(ParquetWriter {
            writer: Some(writer),
            rows: 0,
            cols,
            _marker: PhantomData,
        })
    }

    /// Serialises `batch` and appends it to the file.
    ///
    /// The Arrow writer is flushed after every call. Presumably this is so each
    /// batch is written out as its own row group and buffered memory stays
    /// bounded; confirm this is the intended layout.
    ///
    /// An empty `batch` is a no-op.
    ///
    /// # Errors
    ///
    /// [`ParquetError::Write`] if tracing, encoding, writing, or flushing fails.
    /// The row count is only advanced when the batch was written successfully.
    pub fn write_batch(&mut self, batch: Vec<T>) -> Result<(), ParquetError> {
        if batch.is_empty() {
            return Ok(());
        }
        let fields = Self::trace_fields()?;
        let record_batch = serde_arrow::to_record_batch(&fields, &batch)
            .map_err(|e| ParquetError::Write(Box::new(e)))?;
        if let Some(writer) = self.writer.as_mut() {
            writer
                .write(&record_batch)
                .map_err(|e| ParquetError::Write(Box::new(e)))?;
            writer
                .flush()
                .map_err(|e| ParquetError::Write(Box::new(e)))?;
            self.rows += record_batch.num_rows();
        }
        Ok(())
    }

    /// Finalises the file and reports any error raised while doing so.
    ///
    /// Dropping a `ParquetWriter` also closes the underlying writer, but `Drop`
    /// has to discard the result. Call this instead when the caller needs to
    /// know the file was written completely.
    ///
    /// # Errors
    ///
    /// [`ParquetError::Write`] if closing the underlying Arrow writer fails.
    pub fn close(mut self) -> Result<(), ParquetError> {
        match self.writer.take() {
            Some(writer) => writer
                .close()
                .map(|_| ())
                .map_err(|e| ParquetError::Write(Box::new(e))),
            None => Ok(()),
        }
    }

    /// Returns `(rows written so far, number of top-level columns)`.
    #[must_use]
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Traces the Arrow fields for `T` using default tracing options.
    fn trace_fields() -> Result<Vec<FieldRef>, ParquetError> {
        Vec::<FieldRef>::from_type::<T>(TracingOptions::default())
            .map_err(|e| ParquetError::Write(Box::new(e)))
    }
}
impl<T> Drop for ParquetWriter<T> {
    /// Closes the underlying Arrow writer if it is still open.
    ///
    /// `drop` cannot return an error, so the result of closing is deliberately
    /// discarded: finalisation here is best-effort.
    fn drop(&mut self) {
        let pending = self.writer.take();
        if let Some(inner) = pending {
            let _ = inner.close();
        }
    }
}
/// Typed reader that deserialises the rows of a Parquet file into `T`.
///
/// Only the file's shape is captured on construction. Row data is read from
/// `path` each time a read method is called.
#[derive(Debug)]
pub struct ParquetReader<T> {
    /// Local path of the Parquet file, re-opened on every read.
    path: PathBuf,
    /// Total row count, summed over all row groups when the reader was opened.
    rows: usize,
    /// Number of top-level fields in the file's Arrow schema.
    cols: usize,
    /// Ties the reader to the row type `T` without storing any `T`.
    _marker: PhantomData<T>,
}
impl<T: for<'de> serde::Deserialize<'de>> ParquetReader<T> {
    /// Opens a Parquet file and records its shape without decoding any rows.
    ///
    /// `uri` is converted into a [`Uri`] and passed through `force_cache()`.
    /// Presumably this materialises remote data locally; confirm against `Uri`.
    /// The resulting URI must map to a local path.
    ///
    /// # Errors
    ///
    /// * [`ParquetError::Cache`] if `force_cache()` fails.
    /// * [`ParquetError::Io`] with kind `NotFound` if the URI has no local path,
    ///   or any I/O error raised while opening the file.
    /// * [`ParquetError::Read`] if the file metadata cannot be parsed or a row
    ///   group reports a row count that does not fit in `usize`.
    pub fn from(uri: impl Into<Uri>) -> Result<Self, ParquetError> {
        let uri: Uri = uri.into();
        let uri = uri.force_cache()?;
        let path = uri.as_path().ok_or_else(|| {
            ParquetError::Io(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not a local file",
            ))
        })?;
        let builder = Self::open_builder(&path)?;
        // Convert each row group's count with a checked conversion, so a corrupt
        // (e.g. negative) value is reported instead of silently wrapping.
        let rows = builder
            .metadata()
            .row_groups()
            .iter()
            .map(|rg| usize::try_from(rg.num_rows()))
            .sum::<Result<usize, _>>()
            .map_err(|e| ParquetError::Read(Box::new(e)))?;
        let cols = builder.schema().fields().len();
        Ok(ParquetReader {
            path,
            rows,
            cols,
            _marker: PhantomData,
        })
    }

    /// Returns `(total rows, number of top-level columns)`.
    #[must_use]
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Returns `true` when the file's metadata reports zero rows.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.rows == 0
    }

    /// Reads and deserialises every row in the file.
    ///
    /// # Errors
    ///
    /// [`ParquetError::Io`] if the file cannot be re-opened, and
    /// [`ParquetError::Read`] if decoding or deserialisation fails.
    pub fn read_all(&self) -> Result<Vec<T>, ParquetError> {
        let mut result = Vec::with_capacity(self.rows);
        for batch in self.open_batches()? {
            let batch = batch.map_err(|e| ParquetError::Read(Box::new(e)))?;
            let items: Vec<T> = serde_arrow::from_record_batch(&batch)
                .map_err(|e| ParquetError::Read(Box::new(e)))?;
            result.extend(items);
        }
        Ok(result)
    }

    /// Reads and deserialises the rows in `range` (start inclusive, end exclusive).
    ///
    /// Batches that lie entirely before `range.start` are not deserialised, and
    /// reading stops as soon as `range.end` has been reached.
    ///
    /// # Errors
    ///
    /// * [`ParquetError::OutOfBounds`] if `range.start > range.end` or
    ///   `range.end` exceeds the row count.
    /// * [`ParquetError::Io`] / [`ParquetError::Read`] as for [`Self::read_all`].
    pub fn read_range(
        &self,
        range: std::ops::Range<usize>,
    ) -> Result<Vec<T>, ParquetError> {
        let std::ops::Range { start, end } = range;
        // Reject reversed ranges as well: computing `end - start` for them would underflow.
        if start > end || end > self.rows {
            return Err(ParquetError::OutOfBounds {
                start,
                end,
                rows: self.rows,
            });
        }
        if start == end {
            return Ok(Vec::new());
        }

        let mut result = Vec::with_capacity(end - start);
        let mut batches = self.open_batches()?;
        // Index (in the whole file) of the first row of the next batch.
        let mut offset = 0;
        while offset < end {
            let batch = match batches.next() {
                Some(batch) => batch.map_err(|e| ParquetError::Read(Box::new(e)))?,
                None => break,
            };
            let batch_end = offset + batch.num_rows();
            if batch_end > start {
                let items: Vec<T> = serde_arrow::from_record_batch(&batch)
                    .map_err(|e| ParquetError::Read(Box::new(e)))?;
                // `offset < end` and `start <= end`, so neither subtraction underflows.
                let skip = start.saturating_sub(offset);
                let take = end - offset - skip;
                result.extend(items.into_iter().skip(skip).take(take));
            }
            offset = batch_end;
        }
        Ok(result)
    }

    /// Opens `path` and parses its Parquet metadata.
    fn open_builder(
        path: impl AsRef<std::path::Path>,
    ) -> Result<
        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder<std::fs::File>,
        ParquetError,
    > {
        let file = std::fs::File::open(path)?;
        parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
            .map_err(|e| ParquetError::Read(Box::new(e)))
    }

    /// Re-opens the file and returns a streaming iterator over its record batches.
    fn open_batches(
        &self,
    ) -> Result<parquet::arrow::arrow_reader::ParquetRecordBatchReader, ParquetError> {
        Self::open_builder(&self.path)?
            .build()
            .map_err(|e| ParquetError::Read(Box::new(e)))
    }
}