datasynth_output/streaming/parquet_sink.rs
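
//! Streaming Parquet sink: buffers generated items and writes them to a
//! Parquet file in row groups, creating the underlying Arrow writer lazily
//! from the schema of the first flushed batch.
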
use std::path::PathBuf;
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use parquet::arrow::ArrowWriter;
use parquet::basic::Compression;
use parquet::file::properties::WriterProperties;

use datasynth_core::error::{SynthError, SynthResult};
use datasynth_core::traits::{StreamEvent, StreamingSink};
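
/// Streaming sink that writes items implementing [`ToParquetBatch`] to a
/// Parquet file, flushing buffered items as row groups of `row_group_size`
/// rows.
///
/// A minimal usage sketch (`MyRecord` and `records` are hypothetical; any
/// type implementing [`ToParquetBatch`] works):
///
/// ```ignore
/// let mut sink =
///     ParquetStreamingSink::<MyRecord>::with_defaults("out.parquet".into())?;
/// for record in records {
///     sink.process(StreamEvent::Data(record))?;
/// }
/// sink.close()?;
/// ```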
pub struct ParquetStreamingSink<T: ToParquetBatch + Send> {
    /// Underlying Arrow writer, created lazily on the first flush.
    writer: Option<ArrowWriter<std::fs::File>>,
    /// Number of items received via `StreamEvent::Data`.
    items_written: u64,
    /// Items buffered until a full row group can be flushed.
    buffer: Vec<T>,
    /// Maximum number of rows per row group (also the flush threshold).
    row_group_size: usize,
    /// Output file path.
    path: PathBuf,
    /// Schema captured when the writer is created.
    schema: Option<Arc<Schema>>,
    /// Whether the writer has been created (guards `ensure_writer`).
    writer_created: bool,
}

impl<T: ToParquetBatch + Send> ParquetStreamingSink<T> {
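    /// Creates a sink that writes to `path`, flushing a row group every
    /// `row_group_size` items.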
    pub fn new(path: PathBuf, row_group_size: usize) -> SynthResult<Self> {
        Ok(Self {
            writer: None,
            items_written: 0,
            buffer: Vec::with_capacity(row_group_size),
            row_group_size,
            path,
            schema: None,
            writer_created: false,
        })
    }

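    /// Creates a sink with a default row group size of 10,000 rows.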
    pub fn with_defaults(path: PathBuf) -> SynthResult<Self> {
        Self::new(path, 10000)
    }

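    /// Returns the output file path.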
    pub fn path(&self) -> &PathBuf {
        &self.path
    }

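    /// Lazily creates the underlying Arrow writer the first time a batch is
    /// flushed, using the schema of that batch.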
    fn ensure_writer(&mut self, schema: Arc<Schema>) -> SynthResult<()> {
        if self.writer_created {
            return Ok(());
        }

        let file = std::fs::File::create(&self.path)?;

        let props = WriterProperties::builder()
            .set_compression(Compression::SNAPPY)
            .set_max_row_group_size(self.row_group_size)
            .build();

        let writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))
            .map_err(|e| SynthError::generation(format!("Failed to create Parquet writer: {e}")))?;

        self.writer = Some(writer);
        self.schema = Some(schema);
        self.writer_created = true;
        Ok(())
    }

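    /// Converts the buffered items into a `RecordBatch` and writes it to the
    /// Parquet file, creating the writer on the first call.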
    fn flush_buffer(&mut self) -> SynthResult<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        let dummy_schema = Arc::new(T::schema());
        let batch = T::to_batch(&self.buffer, Arc::clone(&dummy_schema))?;

        // The writer is created from the batch's actual schema, which may
        // differ from `T::schema()` for dynamically shaped implementations.
        self.ensure_writer(batch.schema())?;

        if let Some(writer) = &mut self.writer {
            writer.write(&batch).map_err(|e| {
                SynthError::generation(format!("Failed to write Parquet batch: {e}"))
            })?;
        }

        self.buffer.clear();
        Ok(())
    }
}

impl<T: ToParquetBatch + Send> StreamingSink<T> for ParquetStreamingSink<T> {
    fn process(&mut self, event: StreamEvent<T>) -> SynthResult<()> {
        match event {
            StreamEvent::Data(item) => {
                self.buffer.push(item);
                self.items_written += 1;

                // Flush a full row group as soon as the buffer reaches the threshold.
                if self.buffer.len() >= self.row_group_size {
                    self.flush_buffer()?;
                }
            }
            StreamEvent::Complete(_summary) => {
                // Write any remaining items and finalize the file.
                self.flush_buffer()?;
                if let Some(writer) = self.writer.take() {
                    writer.close().map_err(|e| {
                        SynthError::generation(format!("Failed to close Parquet writer: {e}"))
                    })?;
                }
            }
            StreamEvent::BatchComplete { .. } => {
                self.flush_buffer()?;
            }
            StreamEvent::Progress(_) | StreamEvent::Error(_) => {}
        }
        Ok(())
    }

    fn flush(&mut self) -> SynthResult<()> {
        self.flush_buffer()?;
        if let Some(writer) = &mut self.writer {
            writer.flush().map_err(|e| {
                SynthError::generation(format!("Failed to flush Parquet writer: {e}"))
            })?;
        }
        Ok(())
    }

    fn close(mut self) -> SynthResult<()> {
        self.flush_buffer()?;
        if let Some(writer) = self.writer.take() {
            writer.close().map_err(|e| {
                SynthError::generation(format!("Failed to close Parquet writer: {e}"))
            })?;
        }
        Ok(())
    }

    fn items_processed(&self) -> u64 {
        self.items_written
    }
}

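/// Conversion of items into Arrow `RecordBatch`es for Parquet output.
///
/// A minimal sketch of an implementor (the `Event` type and its single field
/// are hypothetical, not part of this crate):
///
/// ```ignore
/// use arrow::array::{ArrayRef, StringArray};
/// use arrow::datatypes::{DataType, Field};
///
/// struct Event { id: String }
///
/// impl ToParquetBatch for Event {
///     fn schema() -> Schema {
///         Schema::new(vec![Field::new("id", DataType::Utf8, false)])
///     }
///
///     fn to_batch(items: &[Self], schema: Arc<Schema>) -> SynthResult<RecordBatch> {
///         let ids: ArrayRef = Arc::new(StringArray::from(
///             items.iter().map(|e| e.id.as_str()).collect::<Vec<_>>(),
///         ));
///         RecordBatch::try_new(schema, vec![ids])
///             .map_err(|e| SynthError::generation(e.to_string()))
///     }
/// }
/// ```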
pub trait ToParquetBatch {
    /// Returns the default Arrow schema for this type.
    fn schema() -> Schema;

    /// Converts a slice of items into a single `RecordBatch`.
    fn to_batch(items: &[Self], schema: Arc<Schema>) -> SynthResult<RecordBatch>
    where
        Self: Sized;
}

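/// Test-only record whose field names and values are plain strings, used to
/// exercise the sink with an arbitrary set of columns.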
#[cfg(test)]
#[derive(Debug, Clone)]
pub struct GenericParquetRecord {
    /// Column names, in order.
    pub field_names: Vec<String>,
    /// String values aligned with `field_names`.
    pub values: Vec<String>,
}

#[cfg(test)]
impl GenericParquetRecord {
    pub fn new(field_names: Vec<String>, values: Vec<String>) -> Self {
        Self {
            field_names,
            values,
        }
    }
}

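// The test implementation builds its batch schema from each record's
// `field_names` rather than from `schema()`, so the written file reflects
// the actual columns of the buffered records.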
#[cfg(test)]
impl ToParquetBatch for GenericParquetRecord {
    fn schema() -> Schema {
        use arrow::datatypes::{DataType, Field};
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("type", DataType::Utf8, true),
            Field::new("data", DataType::Utf8, true),
        ])
    }

    fn to_batch(items: &[Self], schema: Arc<Schema>) -> SynthResult<RecordBatch> {
        use arrow::array::{ArrayRef, StringArray};
        use arrow::datatypes::{DataType, Field};

        if items.is_empty() {
            return RecordBatch::try_new_with_options(
                schema,
                vec![],
                &arrow::array::RecordBatchOptions::new().with_row_count(Some(0)),
            )
            .map_err(|e| SynthError::generation(format!("Failed to create empty batch: {}", e)));
        }

        let field_names = &items[0].field_names;
        let num_fields = field_names.len();

        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(num_fields);

        for field_idx in 0..num_fields {
            let values: Vec<&str> = items
                .iter()
                .map(|item| item.values.get(field_idx).map(|s| s.as_str()).unwrap_or(""))
                .collect();
            arrays.push(Arc::new(StringArray::from(values)));
        }

        let fields: Vec<Field> = field_names
            .iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect();
        let dynamic_schema = Arc::new(Schema::new(fields));

        RecordBatch::try_new(dynamic_schema, arrays)
            .map_err(|e| SynthError::generation(format!("Failed to create record batch: {}", e)))
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::traits::StreamSummary;
    use tempfile::tempdir;

    #[test]
    fn test_parquet_streaming_sink_basic() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.parquet");

        let mut sink =
            ParquetStreamingSink::<GenericParquetRecord>::new(path.clone(), 100).unwrap();

        let record = GenericParquetRecord::new(
            vec!["id".to_string(), "name".to_string()],
            vec!["1".to_string(), "test".to_string()],
        );

        sink.process(StreamEvent::Data(record)).unwrap();
        sink.process(StreamEvent::Complete(StreamSummary::new(1, 100)))
            .unwrap();

        assert!(path.exists());
        assert!(std::fs::metadata(&path).unwrap().len() > 0);
    }

    #[test]
    fn test_parquet_streaming_sink_row_group_flush() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.parquet");

        let mut sink = ParquetStreamingSink::<GenericParquetRecord>::new(path.clone(), 5).unwrap();

        for i in 0..12 {
            let record = GenericParquetRecord::new(
                vec!["id".to_string(), "value".to_string()],
                vec![i.to_string(), format!("value_{}", i)],
            );
            sink.process(StreamEvent::Data(record)).unwrap();
        }

        sink.process(StreamEvent::Complete(StreamSummary::new(12, 100)))
            .unwrap();

        assert_eq!(sink.items_processed(), 12);
    }

    #[test]
    fn test_parquet_items_processed() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.parquet");

        let mut sink = ParquetStreamingSink::<GenericParquetRecord>::new(path, 100).unwrap();

        for i in 0..25 {
            let record = GenericParquetRecord::new(vec!["id".to_string()], vec![i.to_string()]);
            sink.process(StreamEvent::Data(record)).unwrap();
        }

        assert_eq!(sink.items_processed(), 25);
    }
}