gosh_dataset/lib.rs

// [[file:../parquet.note::530f359c][530f359c]]
// #![deny(warnings)]
// 530f359c ends here

// [[file:../parquet.note::561ea56f][561ea56f]]
// mod nested;
// 561ea56f ends here

// [[file:../parquet.note::c8eabb97][c8eabb97]]
use gut::prelude::*;
// c8eabb97 ends here

// [[file:../parquet.note::b0b749b7][b0b749b7]]
// see
// https://jorgecarleitao.github.io/arrow2/io/parquet_write.html
// https://github.com/chmp/serde_arrow
mod pq {
    use anyhow::Result;
    use serde::Serialize;

    use arrow2::{
        array::Array,
        chunk::Chunk,
        datatypes::{Field, Schema},
        io::parquet::write::{Encoding, WriteOptions},
    };

    /// Trace the Arrow schema from serialized `items` via serde_arrow and
    /// build the corresponding arrow2 columns.
    pub fn get_parquet_columns<T: Serialize + ?Sized>(items: &T) -> Result<(Schema, Chunk<Box<dyn Array>>)> {
        use serde_arrow::schema::{SchemaLike, SerdeArrowSchema, TracingOptions};

        let fields: Vec<Field> =
            SerdeArrowSchema::from_samples(items, TracingOptions::default().allow_null_fields(true).guess_dates(true))?
                .try_into()?;
        let arrays = serde_arrow::to_arrow2(&fields, items)?;

        Ok((Schema::from(fields), Chunk::new(arrays)))
    }

    /// Default options for writing: Snappy compression, Parquet version 2,
    /// no statistics.
    pub fn default_write_options() -> WriteOptions {
        use arrow2::io::parquet::write::{CompressionOptions, Version};

        WriteOptions {
            write_statistics: false,
            // SNAPPY has fast compression speeds: https://github.com/google/snappy
            // https://github.com/apache/parquet-format/blob/master/Compression.md
            compression: CompressionOptions::Snappy,
            version: Version::V2,
            data_pagesize_limit: None,
        }
    }

    /// Plain encoding for every (possibly nested) field in `schema`.
    pub fn get_encodings(schema: &Schema) -> Vec<Vec<Encoding>> {
        use arrow2::io::parquet::write::transverse;

        schema
            .fields
            .iter()
            .map(|f| transverse(&f.data_type, |_| Encoding::Plain))
            .collect()
    }

    // pub fn write_parquet_chunk(path: &str, schema: Schema, columns: Chunk<Box<dyn Array>>) -> Result<()> {
    //     use arrow2::io::parquet::write::{transverse, Encoding, FileWriter, RowGroupIterator};
    //     use std::fs::File;

    //     let options = default_write_options();
    //     // Create a new empty file
    //     let file = File::create(path)?;
    //     let mut writer = FileWriter::try_new(file, schema.clone(), options)?;

    //     let encodings = get_encodings(&schema);
    //     let iter = vec![Ok(columns)];
    //     let row_groups = RowGroupIterator::try_new(iter.into_iter(), &schema, options, encodings)?;

    //     for group in row_groups {
    //         writer.write(group?)?;
    //     }
    //     let _size = writer.end(None)?;
    //     Ok(())
    // }
}
// b0b749b7 ends here
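
// Sketch: a quick sanity check of `get_parquet_columns` on an illustrative
// struct mixing plain, optional, and fixed-size array fields (the struct and
// its values are made up for this example).
#[test]
fn test_schema_inference() -> Result<()> {
    #[derive(Debug, Serialize)]
    struct Row {
        name: String,
        value: Option<f64>, // traced as a nullable column
        xyz: [f64; 3],      // traced as a (fixed-size) list column
    }

    let rows = vec![
        Row { name: "a".into(), value: Some(1.0), xyz: [0.0; 3] },
        Row { name: "b".into(), value: None, xyz: [1.0; 3] },
    ];
    let (schema, columns) = pq::get_parquet_columns(rows.as_slice())?;
    // one field per struct member; the chunk length is the number of rows
    assert_eq!(schema.fields.len(), 3);
    assert_eq!(columns.len(), 2);
    Ok(())
}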

// [[file:../parquet.note::091d2689][091d2689]]
mod writer {
    use super::pq::*;
    use super::Result;

    use arrow2::io::parquet::write::FileWriter;
    use serde::Serialize;
    use std::fs::File;
    use std::path::Path;
    use std::path::PathBuf;

    /// A simple struct for writing a vec of `struct`s in Parquet format.
    pub struct SimpleParquetFileWriter {
        path: PathBuf,
        writer: Option<FileWriter<File>>,
    }

    impl SimpleParquetFileWriter {
        /// Construct a Parquet file writer for `path`.
        pub fn new(path: &Path) -> Self {
            Self {
                path: path.to_owned(),
                writer: None,
            }
        }

        /// Write `records` as a new row group. The schema is traced from
        /// `records` on each call, so all row groups should share the schema
        /// of the first call.
        pub fn write_row_group<T: Serialize + ?Sized>(&mut self, records: &T) -> Result<&mut Self> {
            use arrow2::io::parquet::write::RowGroupIterator;

            let (schema, columns) = get_parquet_columns(records)?;
            let options = default_write_options();

            // lazily create the underlying file writer on the first call
            if self.writer.is_none() {
                let file = File::create(&self.path)?;
                self.writer = Some(FileWriter::try_new(file, schema.clone(), options)?);
            }

            if let Some(writer) = self.writer.as_mut() {
                let encodings = get_encodings(&schema);
                let iter = vec![Ok(columns)];
                let row_groups = RowGroupIterator::try_new(iter.into_iter(), &schema, options, encodings)?;
                for group in row_groups {
                    writer.write(group?)?;
                }
            }

            Ok(self)
        }

        /// Writes the footer of the Parquet file. Must be called when
        /// finished writing; otherwise the resulting file will be invalid,
        /// as it will be missing the trailing `PAR1` magic bytes.
        pub fn close(self) -> Result<()> {
            if let Some(mut writer) = self.writer {
                let _size = writer.end(None)?;
            }
            Ok(())
        }
    }
}
// 091d2689 ends here

// [[file:../parquet.note::531d4795][531d4795]]
pub use writer::SimpleParquetFileWriter;
// 531d4795 ends here
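
// Sketch: per the Parquet format, a valid file both starts and ends with the
// 4-byte magic `PAR1`; the trailing magic is only written by `close`. The
// path below is illustrative.
#[test]
fn test_footer_magic() -> Result<()> {
    #[derive(Debug, Serialize)]
    struct Row {
        x: f64,
    }

    let path = "/tmp/magic.pq";
    let mut writer = SimpleParquetFileWriter::new(path.as_ref());
    writer.write_row_group([Row { x: 1.0 }].as_slice())?;
    writer.close()?;

    let bytes = std::fs::read(path)?;
    assert_eq!(&bytes[..4], b"PAR1");
    assert_eq!(&bytes[bytes.len() - 4..], b"PAR1");
    Ok(())
}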

// [[file:../parquet.note::74c362b5][74c362b5]]
#[test]
fn test_parquet_writer() -> Result<()> {
    use writer::SimpleParquetFileWriter;

    // define columns using a plain Rust struct with simple data types
    #[derive(Debug, Serialize)]
    struct Coord {
        x: f64,
        e: Option<f64>,
        i: [f64; 3],
        b: Vec<f64>,
    }

    // write computed row data, one row group at a time
    let mut writer = SimpleParquetFileWriter::new("/tmp/b.pq".as_ref());
    let rows = vec![
        Coord {
            x: 1.0,
            e: None,
            i: [0.0; 3],
            b: vec![1.0],
        },
        Coord {
            x: 2.0,
            e: None,
            i: [0.0; 3],
            b: vec![0.2],
        },
    ];
    writer.write_row_group(rows.as_slice())?;

    let rows = vec![
        Coord {
            x: 1.0,
            e: None,
            i: [0.0; 3],
            b: vec![1.0],
        },
        Coord {
            x: 0.5,
            e: Some(2.0),
            i: [0.0; 3],
            b: vec![0.2],
        },
    ];
    writer.write_row_group(rows.as_slice())?;
    writer.close()?;

    Ok(())
}
// 74c362b5 ends here
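
// Sketch: a round-trip check, assuming serde_arrow's `from_arrow2` (the
// read-side counterpart of the `to_arrow2` call in `get_parquet_columns`)
// can rebuild the structs from the traced fields and arrays.
#[test]
fn test_round_trip() -> Result<()> {
    use serde::Deserialize;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Point {
        x: f64,
        y: f64,
    }

    let rows = vec![Point { x: 1.0, y: 2.0 }, Point { x: 3.0, y: 4.0 }];
    let (schema, columns) = pq::get_parquet_columns(rows.as_slice())?;
    let rows_back: Vec<Point> = serde_arrow::from_arrow2(&schema.fields, columns.arrays())?;
    assert_eq!(rows, rows_back);
    Ok(())
}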

// [[file:../parquet.note::e55075ee][e55075ee]]
#[test]
fn test_parquet_gchemol() -> anyhow::Result<()> {
    use gchemol::io::formats::ExtxyzFile;
    use writer::SimpleParquetFileWriter;

    // define columns for parquet
    #[derive(Debug, Serialize)]
    struct Coords {
        frame_index: usize,
        atom_number: usize,
        position: [f64; 3],
        forces: [f64; 3],
    }

    let mut writer = SimpleParquetFileWriter::new("/tmp/cu.pq".as_ref());
    let f = "tests/files/cu.xyz";
    let mols = ExtxyzFile::read_molecules_from(f)?;
    for (i, mol) in mols.enumerate() {
        // collect the rows of one frame, then write them as one row group
        let mut rows = vec![];
        for (j, atom) in mol.atoms() {
            let forces: [f64; 3] = atom.properties.load("forces")?;
            rows.push(Coords {
                frame_index: i,
                atom_number: j,
                position: atom.position(),
                forces,
            });
        }
        writer.write_row_group(rows.as_slice())?;
    }
    writer.close()?;

    Ok(())
}
// e55075ee ends here