Skip to main content

rivet/format/
mod.rs

1pub mod csv;
2pub mod parquet;
3
4use arrow::datatypes::SchemaRef;
5use arrow::record_batch::RecordBatch;
6
7use crate::config::{CompressionType, FormatType};
8use crate::error::Result;
9
10/// Streaming writer: receives one RecordBatch at a time.
11pub trait FormatWriter {
12    fn write_batch(&mut self, batch: &RecordBatch) -> Result<()>;
13    fn finish(self: Box<Self>) -> Result<()>;
14    /// Approximate bytes written so far (for file-size splitting).
15    fn bytes_written(&self) -> u64;
16}
17
18pub trait Format {
19    fn create_writer(
20        &self,
21        schema: &SchemaRef,
22        writer: Box<dyn std::io::Write + Send>,
23    ) -> Result<Box<dyn FormatWriter + Send>>;
24
25    fn file_extension(&self) -> &str;
26}
27
28pub fn create_format(
29    format_type: FormatType,
30    compression: CompressionType,
31    compression_level: Option<u32>,
32    row_group_rows: Option<usize>,
33) -> Box<dyn Format> {
34    match format_type {
35        FormatType::Csv => Box::new(csv::CsvFormat),
36        FormatType::Parquet => Box::new(parquet::ParquetFormat::new(
37            compression,
38            compression_level,
39            row_group_rows,
40        )),
41    }
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::io::Write;
    use std::sync::{Arc, Mutex};

    /// Magic bytes that every complete Parquet file both starts and ends with.
    const PARQUET_MAGIC: &[u8] = b"PAR1";

    /// Cloneable in-memory sink.
    ///
    /// `Format::create_writer` takes a `Box<dyn Write + Send>`, which implies
    /// `'static`, so a borrowed `&mut Vec<u8>` cannot be passed in. Sharing the
    /// buffer through `Arc<Mutex<_>>` lets the test keep one handle and read
    /// the bytes back after the writer owning the other handle is finished.
    #[derive(Clone, Default)]
    struct SharedBuf(Arc<Mutex<Vec<u8>>>);

    impl SharedBuf {
        /// Copy of every byte written so far.
        fn contents(&self) -> Vec<u8> {
            self.0.lock().unwrap().clone()
        }
    }

    impl Write for SharedBuf {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    fn schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
    }

    fn batch(schema: &Arc<Schema>) -> arrow::record_batch::RecordBatch {
        arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int64Array::from(vec![1i64, 2]))],
        )
        .unwrap()
    }

    /// Streams `batches` through a fresh writer from `fmt`, finishes it, and
    /// returns every byte the writer emitted.
    fn render(
        fmt: &dyn Format,
        schema: &Arc<Schema>,
        batches: &[arrow::record_batch::RecordBatch],
    ) -> Vec<u8> {
        let sink = SharedBuf::default();
        let mut w = fmt.create_writer(schema, Box::new(sink.clone())).unwrap();
        for b in batches {
            w.write_batch(b).unwrap();
        }
        w.finish().unwrap();
        sink.contents()
    }

    #[test]
    fn create_format_csv_extension_and_roundtrip() {
        let schema = schema();
        let fmt = create_format(FormatType::Csv, CompressionType::None, None, None);
        assert_eq!(fmt.file_extension(), "csv");

        let out = render(&*fmt, &schema, &[batch(&schema)]);
        let text = String::from_utf8(out).expect("CSV output should be UTF-8");
        // Whether a header row is emitted depends on the CSV writer's settings,
        // so only the trailing data rows are pinned.
        let lines: Vec<&str> = text.lines().collect();
        assert!(lines.ends_with(&["1", "2"]), "unexpected CSV output: {text:?}");
    }

    #[test]
    fn create_format_parquet_extension_and_roundtrip() {
        let schema = schema();
        let fmt = create_format(FormatType::Parquet, CompressionType::Zstd, None, None);
        assert_eq!(fmt.file_extension(), "parquet");

        let out = render(&*fmt, &schema, &[batch(&schema)]);
        assert!(out.starts_with(PARQUET_MAGIC), "missing leading Parquet magic");
        assert!(out.ends_with(PARQUET_MAGIC), "missing trailing Parquet footer magic");
    }

    #[test]
    fn create_format_parquet_uncompressed_finish_ok() {
        let schema = schema();
        let fmt = create_format(FormatType::Parquet, CompressionType::None, None, None);
        // Finishing without writing any batches must still produce a
        // structurally complete file (header magic + footer).
        let out = render(&*fmt, &schema, &[]);
        assert!(out.starts_with(PARQUET_MAGIC), "missing leading Parquet magic");
        assert!(out.ends_with(PARQUET_MAGIC), "missing trailing Parquet footer magic");
    }
}