parquetry_gen/
lib.rs

1use codegen::Scope;
2use parquet::schema::{parser::parse_message_type, types::SchemaDescriptor};
3use std::path::PathBuf;
4use std::{path::Path, sync::Arc};
5
6mod code;
7mod column_code;
8pub mod error;
9pub mod schema;
10mod test_code;
11mod types;
12mod util;
13
14use error::Error;
15use schema::{GenSchema, GenStruct};
16
/// Options controlling what the code generator emits.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Config {
    /// Derive names that always appear at the front of [`Config::derives`].
    pub base_derives: Vec<&'static str>,
    /// Pretty-print generated code (via `syn` + `prettyplease`) when rendering it.
    pub format: bool,
    /// Append `serde::Deserialize` and `serde::Serialize` to [`Config::derives`].
    pub serde_support: bool,
    /// Emit a `#[cfg(test)]` test module alongside the generated code.
    pub tests: bool,
}

impl Config {
    /// Returns the full list of derive names: every entry of `base_derives`, followed by the
    /// serde derives when `serde_support` is enabled.
    pub fn derives(&self) -> Vec<&'static str> {
        const SERDE_DERIVES: [&str; 2] = ["serde::Deserialize", "serde::Serialize"];

        let extra: &[&'static str] = if self.serde_support {
            &SERDE_DERIVES
        } else {
            &[]
        };

        self.base_derives.iter().chain(extra).copied().collect()
    }
}

impl Default for Config {
    /// The standard value-type derives, with formatting, serde support and test generation
    /// all turned on.
    fn default() -> Self {
        Config {
            base_derives: ["Clone", "Copy", "Debug", "Eq", "PartialEq"].to_vec(),
            format: true,
            serde_support: true,
            tests: true,
        }
    }
}
50
51#[derive(Debug)]
52pub struct ParsedFileSchema {
53    pub name: String,
54    pub schema: GenSchema,
55    pub descriptor: SchemaDescriptor,
56    scope: Scope,
57    absolute_path: PathBuf,
58    config: Config,
59}
60
61impl ParsedFileSchema {
62    pub fn code(&self) -> Result<String, Error> {
63        let raw_code = self.scope.to_string();
64
65        if self.config.format {
66            let file = syn::parse_file(&format!(
67                "#![cfg_attr(rustfmt, rustfmt_skip)]\n{}",
68                raw_code
69            ))?;
70            Ok(prettyplease::unparse(&file))
71        } else {
72            Ok(raw_code)
73        }
74    }
75
76    pub fn open<P: AsRef<Path>>(input: P, config: Config) -> Result<ParsedFileSchema, Error> {
77        let input = input.as_ref();
78        let schema_source = std::fs::read_to_string(input)?;
79        let (schema, descriptor) = parse_schema(&schema_source, config.clone())?;
80        let scope = schema_to_scope(&schema_source, &schema, &descriptor)?;
81
82        let name = input
83            .file_name()
84            .and_then(|file_name| file_name.to_str())
85            .and_then(|file_name| file_name.split('.').next())
86            .ok_or_else(|| Error::InvalidPath(input.to_path_buf()))?
87            .to_string();
88
89        Ok(ParsedFileSchema {
90            name,
91            schema,
92            descriptor,
93            scope,
94            absolute_path: input.canonicalize()?,
95            config,
96        })
97    }
98
99    pub fn open_dir<P: AsRef<Path>>(
100        input: P,
101        config: Config,
102        suffix: Option<&str>,
103    ) -> Result<Vec<ParsedFileSchema>, Error> {
104        let mut schemas = std::fs::read_dir(input)?
105            .map(|result| result.map_err(Error::from).map(|entry| entry.path()))
106            .filter_map(|result| {
107                result.map_or_else(
108                    |error| Some(Err(error)),
109                    |path| {
110                        if path.is_file() {
111                            match path.file_name().and_then(|file_name| file_name.to_str()) {
112                                Some(file_name) => {
113                                    if suffix
114                                        .filter(|suffix| !file_name.ends_with(suffix))
115                                        .is_none()
116                                    {
117                                        Some(Self::open(path, config.clone()))
118                                    } else {
119                                        None
120                                    }
121                                }
122                                None => Some(Err(Error::InvalidPath(path))),
123                            }
124                        } else {
125                            None
126                        }
127                    },
128                )
129            })
130            .collect::<Result<Vec<_>, _>>()?;
131        schemas.sort_by_key(|schema| (schema.name.clone(), schema.absolute_path.clone()));
132
133        Ok(schemas)
134    }
135
136    /// For use with `cargo:rerun-if-changed`
137    pub fn absolute_path_str(&self) -> Result<&str, Error> {
138        self.absolute_path
139            .as_os_str()
140            .to_str()
141            .ok_or_else(|| Error::InvalidPath(self.absolute_path.clone()))
142    }
143}
144
145pub fn parse_schema(
146    schema_source: &str,
147    config: Config,
148) -> Result<(GenSchema, SchemaDescriptor), Error> {
149    let schema_type = Arc::new(parse_message_type(schema_source)?);
150    let descriptor = SchemaDescriptor::new(schema_type);
151    let schema = GenSchema::from_schema(&descriptor, config)?;
152
153    Ok((schema, descriptor))
154}
155
/// Raw Rust source that `schema_to_scope` adds to every generated scope.
///
/// It declares `SCHEMA`, a lazily initialized `SchemaDescPtr` parsed from the generated
/// `SCHEMA_SOURCE` constant on first access.
///
/// NOTE: the emitted code uses `std::sync::LazyLock`, so the crate that compiles it needs a
/// toolchain where that type is stable (Rust 1.80+).
///
/// NOTE(review): the `unwrap()` is expected not to fire, because `ParsedFileSchema::open`
/// parses the same source before generating code. Confirm this for any other caller of
/// `schema_to_scope`.
const STATIC_SCHEMA_DEF: &str = "
    pub static SCHEMA: std::sync::LazyLock<parquet::schema::types::SchemaDescPtr> =
        std::sync::LazyLock::new(|| std::sync::Arc::new(
            parquet::schema::types::SchemaDescriptor::new(
                std::sync::Arc::new(
                    parquet::schema::parser::parse_message_type(SCHEMA_SOURCE).unwrap()
                )
            )
        ));
";
166
167fn schema_to_scope(
168    schema_source: &str,
169    schema: &GenSchema,
170    descriptor: &SchemaDescriptor,
171) -> Result<Scope, Error> {
172    let mut scope = Scope::new();
173
174    scope.raw(format!(
175        "const SCHEMA_SOURCE: &str = \"{}\";",
176        schema_source
177    ));
178    scope.raw(STATIC_SCHEMA_DEF);
179
180    for GenStruct {
181        type_name,
182        fields,
183        derives,
184    } in schema.structs()
185    {
186        let gen_struct = scope.new_struct(&type_name).vis("pub");
187        for value in &derives {
188            gen_struct.derive(value);
189        }
190
191        for gen_field in fields {
192            let field = gen_struct
193                .new_field(&gen_field.name, gen_field.type_name())
194                .vis("pub");
195
196            if let Some(attributes) = gen_field.attributes {
197                field.annotation(attributes);
198            }
199        }
200    }
201
202    column_code::add_column_info_modules(&mut scope, &schema.gen_columns());
203
204    let schema_impl = scope
205        .new_impl(&schema.type_name)
206        .impl_trait("parquetry::Schema")
207        .associate_type("SortColumn", "columns::SortColumn")
208        .associate_type(
209            "Writer<W: std::io::Write + Send>",
210            format!("{}Writer<W>", schema.type_name),
211        );
212
213    schema_impl
214        .new_fn("sort_key_value")
215        .arg_ref_self()
216        .arg("sort_key", "parquetry::sort::SortKey<Self::SortColumn>")
217        .ret("Vec<u8>")
218        .push_block(code::gen_sort_key_value_block());
219
220    schema_impl
221        .new_fn("source")
222        .ret("&'static str")
223        .line("SCHEMA_SOURCE");
224
225    schema_impl
226        .new_fn("schema")
227        .ret("parquet::schema::types::SchemaDescPtr")
228        .line("SCHEMA.clone()");
229
230    schema_impl
231        .new_fn("writer")
232        .generic("W: std::io::Write + Send")
233        .arg("writer", "W")
234        .arg("properties", "parquet::file::properties::WriterProperties")
235        .ret("Result<Self::Writer<W>, parquetry::error::Error>")
236        .push_block(code::gen_writer_block()?);
237
238    let writer_struct = scope
239        .new_struct(&format!("{}Writer", schema.type_name))
240        .vis("pub")
241        .generic("W: std::io::Write");
242
243    writer_struct.new_field("writer", "parquet::file::writer::SerializedFileWriter<W>");
244    writer_struct.new_field("workspace", code::WORKSPACE_STRUCT_NAME);
245
246    let writer_impl = scope
247        .new_impl(&format!("{}Writer<W>", schema.type_name))
248        .impl_trait(format!(
249            "parquetry::write::SchemaWrite<{}, W>",
250            schema.type_name
251        ))
252        .generic("W: std::io::Write + Send");
253
254    writer_impl
255        .new_fn("write_row_group")
256        .generic("'a")
257        .generic(format!(
258            "E: From<parquetry::error::Error>, I: Iterator<Item = Result<&'a {}, E>>",
259            schema.type_name
260        ))
261        .arg_mut_self()
262        .arg("values", "&mut I")
263        .ret("Result<parquet::file::metadata::RowGroupMetaDataPtr, E>")
264        .bound(&schema.type_name, "'a")
265        .push_block(code::gen_writer_write_row_group_block(schema)?);
266
267    writer_impl
268        .new_fn("write_item")
269        .arg_mut_self()
270        .arg("value", format!("&{}", schema.type_name))
271        .ret("Result<(), parquetry::error::Error>")
272        .line(format!(
273            "{}::add_item_to_workspace(&mut self.workspace, value)",
274            schema.type_name
275        ));
276
277    writer_impl
278        .new_fn("finish_row_group")
279        .arg_mut_self()
280        .ret("Result<parquet::file::metadata::RowGroupMetaDataPtr, parquetry::error::Error>")
281        .line(format!(
282            "{}::write_with_workspace(&mut self.writer, &mut self.workspace)",
283            schema.type_name
284        ));
285
286    writer_impl
287        .new_fn("finish")
288        .arg_self()
289        .ret("Result<parquet::format::FileMetaData, parquetry::error::Error>")
290        .line("Ok(self.writer.close()?)");
291
292    let row_conversion_impl = scope
293        .new_impl(&schema.type_name)
294        .impl_trait("TryFrom<parquet::record::Row>")
295        .associate_type("Error", "parquetry::error::Error");
296
297    row_conversion_impl
298        .new_fn("try_from")
299        .arg("row", "parquet::record::Row")
300        .ret("Result<Self, parquetry::error::Error>")
301        .push_block(code::gen_row_conversion_block(schema)?);
302
303    let base_impl = scope.new_impl(&schema.type_name);
304
305    base_impl
306        .new_fn("write_sort_key_bytes")
307        .arg_ref_self()
308        .arg(
309            "column",
310            "parquetry::sort::Sort<<Self as parquetry::Schema>::SortColumn>",
311        )
312        .arg("bytes", "&mut Vec<u8>")
313        .push_block(code::gen_write_sort_key_bytes_block(schema)?);
314
315    base_impl
316        .new_fn("write_with_workspace")
317        .generic("W: std::io::Write + Send")
318        .arg(
319            "file_writer",
320            "&mut parquet::file::writer::SerializedFileWriter<W>",
321        )
322        .arg("workspace", format!("&mut {}", code::WORKSPACE_STRUCT_NAME))
323        .ret("Result<parquet::file::metadata::RowGroupMetaDataPtr, parquetry::error::Error>")
324        .push_block(code::gen_write_with_workspace_block(descriptor.columns())?);
325
326    base_impl
327        .new_fn("fill_workspace")
328        .generic("'a")
329        .generic("E: From<parquetry::error::Error>, I: Iterator<Item = Result<&'a Self, E>>")
330        .arg("workspace", format!("&mut {}", code::WORKSPACE_STRUCT_NAME))
331        .arg("values", "I")
332        .ret("Result<usize, E>")
333        .push_block(code::gen_fill_workspace_block()?);
334
335    base_impl
336        .new_fn("add_item_to_workspace")
337        .arg("workspace", format!("&mut {}", code::WORKSPACE_STRUCT_NAME))
338        .arg("value", "&Self")
339        .ret("Result<(), parquetry::error::Error>")
340        .push_block(code::gen_add_item_to_workspace_block(schema)?);
341
342    for gen_struct in schema.structs() {
343        let base_impl = scope.new_impl(&gen_struct.type_name);
344
345        code::gen_constructor(&gen_struct, base_impl.new_fn("new"))?;
346    }
347
348    code::add_workspace_struct(&mut scope, descriptor.columns())?;
349
350    if schema.config.tests {
351        let test_module = scope.new_module("test").attr("cfg(test)");
352
353        test_code::gen_test_code(test_module, schema)?;
354    }
355
356    Ok(scope)
357}