use gut::prelude::*;

mod pq {
    use anyhow::Result;
    use serde::Serialize;

    use arrow2::{
        array::Array,
        chunk::Chunk,
        datatypes::{Field, Schema},
        io::parquet::write::{Encoding, WriteOptions},
    };

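    /// Infer an arrow2 `Schema` from the serializable `items` (serde_arrow schema tracing)
    /// and serialize them into a single `Chunk` of arrays.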
    pub fn get_parquet_columns<T: Serialize + ?Sized>(items: &T) -> Result<(Schema, Chunk<Box<dyn Array>>)> {
        use serde_arrow::schema::{SchemaLike, SerdeArrowSchema, TracingOptions};

        let fields: Vec<Field> =
            SerdeArrowSchema::from_samples(items, TracingOptions::default().allow_null_fields(true).guess_dates(true))?
                .try_into()?;
        let arrays = serde_arrow::to_arrow2(&fields, items)?;

        Ok((Schema::from(fields), Chunk::new(arrays)))
    }

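    /// Write options used for every file: Snappy compression, parquet format v2,
    /// no column statistics and no data page size limit.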
    pub fn default_write_options() -> WriteOptions {
        use arrow2::io::parquet::write::{CompressionOptions, Version};

        WriteOptions {
            write_statistics: false,
            compression: CompressionOptions::Snappy,
            version: Version::V2,
            data_pagesize_limit: None,
        }
    }

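    /// Use plain encoding for every column, traversing nested data types with `transverse`.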
    pub fn get_encodings(schema: &Schema) -> Vec<Vec<Encoding>> {
        use arrow2::io::parquet::write::transverse;

        schema
            .fields
            .iter()
            .map(|f| transverse(&f.data_type, |_| Encoding::Plain))
            .collect()
    }
}

mod writer {
    use super::pq::*;
    use super::Result;

    use arrow2::io::parquet::write::FileWriter;
    use serde::Serialize;
    use std::fs::File;
    use std::path::Path;
    use std::path::PathBuf;

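    /// Writes serializable records to a parquet file, one row group per call to
    /// `write_row_group`. The output file and its schema are created lazily on the first write.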
    pub struct SimpleParquetFileWriter {
        path: PathBuf,
        writer: Option<FileWriter<File>>,
    }

    impl SimpleParquetFileWriter {
        pub fn new(path: &Path) -> Self {
            Self {
                path: path.to_owned(),
                writer: None,
            }
        }

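        /// Serialize `records` into arrow arrays and append them to the file as one new
        /// row group. The file is created from the first batch; later batches are expected
        /// to serialize to the same schema.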
        pub fn write_row_group<T: Serialize + ?Sized>(&mut self, records: &T) -> Result<&mut Self> {
            use anyhow::ensure;
            use arrow2::io::parquet::write::RowGroupIterator;

            let (schema, columns) = get_parquet_columns(records)?;
            let options = default_write_options();

            // lazily create the output file on the first batch of records
            if self.writer.is_none() {
                let file = File::create(&self.path)?;
                self.writer = FileWriter::try_new(file, schema.clone(), options).ok();
            }
            ensure!(self.writer.is_some(), "failed to create parquet writer for {:?}", self.path);

            if let Some(writer) = self.writer.as_mut() {
                let encodings = get_encodings(&schema);
                let iter = vec![Ok(columns)];
                let row_groups = RowGroupIterator::try_new(iter.into_iter(), &schema, options, encodings)?;
                for group in row_groups {
                    writer.write(group?)?;
                }
            }

            Ok(self)
        }

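        /// Finalize the parquet file by writing the footer. Does nothing if no row group
        /// was ever written.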
        pub fn close(self) -> Result<()> {
            if let Some(mut writer) = self.writer {
                let _size = writer.end(None)?;
            }
            Ok(())
        }
    }
}
pub use writer::SimpleParquetFileWriter;

#[test]
fn test_parquet_writer() -> Result<()> {
    use writer::SimpleParquetFileWriter;

    #[derive(Debug, Serialize)]
    struct Coord {
        x: f64,
        e: Option<f64>,
        i: [f64; 3],
        b: Vec<f64>,
    }

    let mut writer = SimpleParquetFileWriter::new("/tmp/b.pq".as_ref());
    let rows = vec![
        Coord {
            x: 1.0,
            e: None,
            i: [0.0; 3],
            b: vec![1.0],
        },
        Coord {
            x: 2.0,
            e: None,
            i: [0.0; 3],
            b: vec![0.2],
        },
    ];
    writer.write_row_group(rows.as_slice())?;

    let rows = vec![
        Coord {
            x: 1.0,
            e: None,
            i: [0.0; 3],
            b: vec![1.0],
        },
        Coord {
            x: 0.5,
            e: Some(2.0),
            i: [0.0; 3],
            b: vec![0.2],
        },
    ];
    writer.write_row_group(rows.as_slice())?;
    writer.close()?;

    Ok(())
}

#[test]
fn test_parquet_gchemol() -> anyhow::Result<()> {
    use gchemol::io::formats::ExtxyzFile;
    use writer::SimpleParquetFileWriter;

    #[derive(Debug, Serialize)]
    struct Coords {
        frame_index: usize,
        atom_number: usize,
        position: [f64; 3],
        forces: [f64; 3],
    }

    let mut writer = SimpleParquetFileWriter::new("/tmp/cu.pq".as_ref());
    let f = "tests/files/cu.xyz";
    let mut rows = vec![];
    let mols = ExtxyzFile::read_molecules_from(f)?;
    for (i, mol) in mols.enumerate() {
        for (j, atom) in mol.atoms() {
            let forces: [f64; 3] = atom.properties.load("forces")?;
            rows.push(Coords {
                frame_index: i,
                atom_number: j,
                position: atom.position(),
                forces,
            });
        }
        writer.write_row_group(rows.as_slice())?;
        // clear the buffer so each frame ends up in its own row group
        rows.clear();
    }
    writer.close()?;

    Ok(())
}