use std::path::PathBuf;
use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use parquet::arrow::ArrowWriter;
use parquet::basic::Compression;
use parquet::file::properties::WriterProperties;
use datasynth_core::error::{SynthError, SynthResult};
use datasynth_core::traits::{StreamEvent, StreamingSink};
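/// Streaming sink that buffers items and writes them to a Parquet file one
/// row group at a time.
///
/// The underlying Arrow writer is created lazily on the first non-empty
/// flush, using the schema of the first batch produced by [`ToParquetBatch`].
///
/// A minimal usage sketch (`MyRecord` stands in for any type implementing
/// [`ToParquetBatch`]; the completion summary comes from the upstream
/// generator):
///
/// ```ignore
/// let mut sink = ParquetStreamingSink::<MyRecord>::with_defaults("out.parquet".into())?;
/// sink.process(StreamEvent::Data(my_record))?;
/// sink.process(StreamEvent::Complete(summary))?;
/// ```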
pub struct ParquetStreamingSink<T: ToParquetBatch + Send> {
writer: Option<ArrowWriter<std::fs::File>>,
items_written: u64,
buffer: Vec<T>,
row_group_size: usize,
path: PathBuf,
schema: Option<Arc<Schema>>,
writer_created: bool,
}
impl<T: ToParquetBatch + Send> ParquetStreamingSink<T> {
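    /// Creates a sink that writes to `path`, flushing a row group every
    /// `row_group_size` buffered items.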
pub fn new(path: PathBuf, row_group_size: usize) -> SynthResult<Self> {
Ok(Self {
writer: None,
items_written: 0,
buffer: Vec::with_capacity(row_group_size),
row_group_size,
path,
schema: None,
writer_created: false,
})
}
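    /// Creates a sink with a default row group size of 10,000 rows.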
pub fn with_defaults(path: PathBuf) -> SynthResult<Self> {
Self::new(path, 10000)
}
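    /// Returns the output file path.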
pub fn path(&self) -> &PathBuf {
&self.path
}
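    /// Lazily creates the Arrow writer with the given schema; subsequent
    /// calls are no-ops.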
fn ensure_writer(&mut self, schema: Arc<Schema>) -> SynthResult<()> {
if self.writer_created {
return Ok(());
}
let file = std::fs::File::create(&self.path)?;
let props = WriterProperties::builder()
.set_compression(Compression::SNAPPY)
            .set_max_row_group_size(self.row_group_size)
.build();
let writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))
.map_err(|e| SynthError::generation(format!("Failed to create Parquet writer: {e}")))?;
self.writer = Some(writer);
self.schema = Some(schema);
self.writer_created = true;
Ok(())
}
fn flush_buffer(&mut self) -> SynthResult<()> {
if self.buffer.is_empty() {
return Ok(());
}
        // `T::schema()` supplies the declared schema for batch construction; the
        // writer is then created from the batch's actual schema, so implementations
        // that build a dynamic schema still work.
        let declared_schema = Arc::new(T::schema());
        let batch = T::to_batch(&self.buffer, declared_schema)?;
        self.ensure_writer(batch.schema())?;
if let Some(writer) = &mut self.writer {
writer.write(&batch).map_err(|e| {
SynthError::generation(format!("Failed to write Parquet batch: {e}"))
})?;
}
self.buffer.clear();
Ok(())
}
}
impl<T: ToParquetBatch + Send> StreamingSink<T> for ParquetStreamingSink<T> {
fn process(&mut self, event: StreamEvent<T>) -> SynthResult<()> {
match event {
StreamEvent::Data(item) => {
self.buffer.push(item);
self.items_written += 1;
if self.buffer.len() >= self.row_group_size {
self.flush_buffer()?;
}
}
StreamEvent::Complete(_summary) => {
self.flush_buffer()?;
if let Some(writer) = self.writer.take() {
writer.close().map_err(|e| {
SynthError::generation(format!("Failed to close Parquet writer: {e}"))
})?;
}
}
StreamEvent::BatchComplete { .. } => {
self.flush_buffer()?;
}
StreamEvent::Progress(_) | StreamEvent::Error(_) => {}
}
Ok(())
}
fn flush(&mut self) -> SynthResult<()> {
self.flush_buffer()?;
if let Some(writer) = &mut self.writer {
writer.flush().map_err(|e| {
SynthError::generation(format!("Failed to flush Parquet writer: {e}"))
})?;
}
Ok(())
}
fn close(mut self) -> SynthResult<()> {
self.flush_buffer()?;
if let Some(writer) = self.writer.take() {
writer.close().map_err(|e| {
SynthError::generation(format!("Failed to close Parquet writer: {e}"))
})?;
}
Ok(())
}
fn items_processed(&self) -> u64 {
self.items_written
}
}
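/// Conversion of a slice of generated items into an Arrow [`RecordBatch`]
/// that [`ParquetStreamingSink`] can write.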
pub trait ToParquetBatch {
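    /// Returns the Arrow schema describing a single item.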
fn schema() -> Schema;
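    /// Converts `items` into a [`RecordBatch`]. Implementations may use the
    /// provided `schema` or build their own (e.g. a dynamic schema derived
    /// from the items, as the test record below does).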
fn to_batch(items: &[Self], schema: Arc<Schema>) -> SynthResult<RecordBatch>
where
Self: Sized;
}
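/// Test-only record with caller-supplied field names and string values, used
/// to exercise the sink in the unit tests below.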
#[cfg(test)]
#[derive(Debug, Clone)]
pub struct GenericParquetRecord {
pub field_names: Vec<String>,
pub values: Vec<String>,
}
#[cfg(test)]
impl GenericParquetRecord {
pub fn new(field_names: Vec<String>, values: Vec<String>) -> Self {
Self {
field_names,
values,
}
}
}
#[cfg(test)]
impl ToParquetBatch for GenericParquetRecord {
fn schema() -> Schema {
use arrow::datatypes::{DataType, Field};
Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("type", DataType::Utf8, true),
Field::new("data", DataType::Utf8, true),
])
}
fn to_batch(items: &[Self], schema: Arc<Schema>) -> SynthResult<RecordBatch> {
use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Field};
if items.is_empty() {
return RecordBatch::try_new_with_options(
schema,
vec![],
&arrow::array::RecordBatchOptions::new().with_row_count(Some(0)),
)
.map_err(|e| SynthError::generation(format!("Failed to create empty batch: {}", e)));
}
let field_names = &items[0].field_names;
let num_fields = field_names.len();
let mut arrays: Vec<ArrayRef> = Vec::with_capacity(num_fields);
for field_idx in 0..num_fields {
let values: Vec<&str> = items
.iter()
.map(|item| item.values.get(field_idx).map(|s| s.as_str()).unwrap_or(""))
.collect();
arrays.push(Arc::new(StringArray::from(values)));
}
let fields: Vec<Field> = field_names
.iter()
.map(|name| Field::new(name, DataType::Utf8, true))
.collect();
let dynamic_schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(dynamic_schema, arrays)
.map_err(|e| SynthError::generation(format!("Failed to create record batch: {}", e)))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use datasynth_core::traits::StreamSummary;
use tempfile::tempdir;
#[test]
fn test_parquet_streaming_sink_basic() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.parquet");
let mut sink =
ParquetStreamingSink::<GenericParquetRecord>::new(path.clone(), 100).unwrap();
let record = GenericParquetRecord::new(
vec!["id".to_string(), "name".to_string()],
vec!["1".to_string(), "test".to_string()],
);
sink.process(StreamEvent::Data(record)).unwrap();
sink.process(StreamEvent::Complete(StreamSummary::new(1, 100)))
.unwrap();
assert!(path.exists());
assert!(std::fs::metadata(&path).unwrap().len() > 0);
}
#[test]
fn test_parquet_streaming_sink_row_group_flush() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.parquet");
let mut sink = ParquetStreamingSink::<GenericParquetRecord>::new(path.clone(), 5).unwrap();
for i in 0..12 {
let record = GenericParquetRecord::new(
vec!["id".to_string(), "value".to_string()],
vec![i.to_string(), format!("value_{}", i)],
);
sink.process(StreamEvent::Data(record)).unwrap();
}
sink.process(StreamEvent::Complete(StreamSummary::new(12, 100)))
.unwrap();
assert_eq!(sink.items_processed(), 12);
}
#[test]
fn test_parquet_items_processed() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.parquet");
let mut sink = ParquetStreamingSink::<GenericParquetRecord>::new(path, 100).unwrap();
for i in 0..25 {
let record = GenericParquetRecord::new(vec!["id".to_string()], vec![i.to_string()]);
sink.process(StreamEvent::Data(record)).unwrap();
}
assert_eq!(sink.items_processed(), 25);
}
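    #[test]
    fn test_parquet_streaming_sink_flush_and_close() {
        // Exercises `flush` and `close` directly instead of relying on
        // `StreamEvent::Complete` to finalize the file.
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.parquet");
        let mut sink =
            ParquetStreamingSink::<GenericParquetRecord>::new(path.clone(), 100).unwrap();
        for i in 0..3 {
            let record = GenericParquetRecord::new(vec!["id".to_string()], vec![i.to_string()]);
            sink.process(StreamEvent::Data(record)).unwrap();
        }
        sink.flush().unwrap();
        sink.close().unwrap();
        assert!(path.exists());
        assert!(std::fs::metadata(&path).unwrap().len() > 0);
    }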
}