1use codegen::Scope;
2use parquet::schema::{parser::parse_message_type, types::SchemaDescriptor};
3use std::path::PathBuf;
4use std::{path::Path, sync::Arc};
5
6mod code;
7mod column_code;
8pub mod error;
9pub mod schema;
10mod test_code;
11mod types;
12mod util;
13
14use error::Error;
15use schema::{GenSchema, GenStruct};
16
/// Configuration controlling how Rust code is generated from a Parquet schema.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Config {
    // Derive macro paths applied to every generated struct (see `derives`).
    pub base_derives: Vec<&'static str>,
    // When true, generated code is formatted with `prettyplease` (see
    // `ParsedFileSchema::code`).
    pub format: bool,
    // When true, `serde::Deserialize` / `serde::Serialize` are added to the
    // derive list.
    pub serde_support: bool,
    // When true, a `#[cfg(test)]` module is generated for the schema.
    pub tests: bool,
}
24
25impl Config {
26 pub fn derives(&self) -> Vec<&'static str> {
27 let mut derives = self.base_derives.clone();
28
29 if self.serde_support {
30 derives.push("serde::Deserialize");
31 derives.push("serde::Serialize");
32 }
33
34 derives
35 }
36}
37
38impl Default for Config {
39 fn default() -> Self {
40 let base_derives = vec!["Clone", "Copy", "Debug", "Eq", "PartialEq"];
41
42 Self {
43 base_derives,
44 format: true,
45 serde_support: true,
46 tests: true,
47 }
48 }
49}
50
/// A schema parsed from a Parquet message-type file, together with the
/// generated code scope and the configuration used to produce it.
#[derive(Debug)]
pub struct ParsedFileSchema {
    // File stem of the source file (text before the first `.`); see `open`.
    pub name: String,
    pub schema: GenSchema,
    pub descriptor: SchemaDescriptor,
    // Generated code, rendered (and optionally formatted) by `code`.
    scope: Scope,
    // Canonicalized path of the source schema file.
    absolute_path: PathBuf,
    config: Config,
}
60
61impl ParsedFileSchema {
62 pub fn code(&self) -> Result<String, Error> {
63 let raw_code = self.scope.to_string();
64
65 if self.config.format {
66 let file = syn::parse_file(&format!(
67 "#![cfg_attr(rustfmt, rustfmt_skip)]\n{}",
68 raw_code
69 ))?;
70 Ok(prettyplease::unparse(&file))
71 } else {
72 Ok(raw_code)
73 }
74 }
75
76 pub fn open<P: AsRef<Path>>(input: P, config: Config) -> Result<ParsedFileSchema, Error> {
77 let input = input.as_ref();
78 let schema_source = std::fs::read_to_string(input)?;
79 let (schema, descriptor) = parse_schema(&schema_source, config.clone())?;
80 let scope = schema_to_scope(&schema_source, &schema, &descriptor)?;
81
82 let name = input
83 .file_name()
84 .and_then(|file_name| file_name.to_str())
85 .and_then(|file_name| file_name.split('.').next())
86 .ok_or_else(|| Error::InvalidPath(input.to_path_buf()))?
87 .to_string();
88
89 Ok(ParsedFileSchema {
90 name,
91 schema,
92 descriptor,
93 scope,
94 absolute_path: input.canonicalize()?,
95 config,
96 })
97 }
98
99 pub fn open_dir<P: AsRef<Path>>(
100 input: P,
101 config: Config,
102 suffix: Option<&str>,
103 ) -> Result<Vec<ParsedFileSchema>, Error> {
104 let mut schemas = std::fs::read_dir(input)?
105 .map(|result| result.map_err(Error::from).map(|entry| entry.path()))
106 .filter_map(|result| {
107 result.map_or_else(
108 |error| Some(Err(error)),
109 |path| {
110 if path.is_file() {
111 match path.file_name().and_then(|file_name| file_name.to_str()) {
112 Some(file_name) => {
113 if suffix
114 .filter(|suffix| !file_name.ends_with(suffix))
115 .is_none()
116 {
117 Some(Self::open(path, config.clone()))
118 } else {
119 None
120 }
121 }
122 None => Some(Err(Error::InvalidPath(path))),
123 }
124 } else {
125 None
126 }
127 },
128 )
129 })
130 .collect::<Result<Vec<_>, _>>()?;
131 schemas.sort_by_key(|schema| (schema.name.clone(), schema.absolute_path.clone()));
132
133 Ok(schemas)
134 }
135
136 pub fn absolute_path_str(&self) -> Result<&str, Error> {
138 self.absolute_path
139 .as_os_str()
140 .to_str()
141 .ok_or_else(|| Error::InvalidPath(self.absolute_path.clone()))
142 }
143}
144
145pub fn parse_schema(
146 schema_source: &str,
147 config: Config,
148) -> Result<(GenSchema, SchemaDescriptor), Error> {
149 let schema_type = Arc::new(parse_message_type(schema_source)?);
150 let descriptor = SchemaDescriptor::new(schema_type);
151 let schema = GenSchema::from_schema(&descriptor, config)?;
152
153 Ok((schema, descriptor))
154}
155
// Emitted verbatim into generated code: lazily parses the emitted
// `SCHEMA_SOURCE` constant into a shared `SchemaDescPtr` the first time the
// generated `SCHEMA` static is accessed.
const STATIC_SCHEMA_DEF: &str = "
    pub static SCHEMA: std::sync::LazyLock<parquet::schema::types::SchemaDescPtr> =
        std::sync::LazyLock::new(|| std::sync::Arc::new(
            parquet::schema::types::SchemaDescriptor::new(
                std::sync::Arc::new(
                    parquet::schema::parser::parse_message_type(SCHEMA_SOURCE).unwrap()
                )
            )
        ));
";
166
167fn schema_to_scope(
168 schema_source: &str,
169 schema: &GenSchema,
170 descriptor: &SchemaDescriptor,
171) -> Result<Scope, Error> {
172 let mut scope = Scope::new();
173
174 scope.raw(format!(
175 "const SCHEMA_SOURCE: &str = \"{}\";",
176 schema_source
177 ));
178 scope.raw(STATIC_SCHEMA_DEF);
179
180 for GenStruct {
181 type_name,
182 fields,
183 derives,
184 } in schema.structs()
185 {
186 let gen_struct = scope.new_struct(&type_name).vis("pub");
187 for value in &derives {
188 gen_struct.derive(value);
189 }
190
191 for gen_field in fields {
192 let field = gen_struct
193 .new_field(&gen_field.name, gen_field.type_name())
194 .vis("pub");
195
196 if let Some(attributes) = gen_field.attributes {
197 field.annotation(attributes);
198 }
199 }
200 }
201
202 column_code::add_column_info_modules(&mut scope, &schema.gen_columns());
203
204 let schema_impl = scope
205 .new_impl(&schema.type_name)
206 .impl_trait("parquetry::Schema")
207 .associate_type("SortColumn", "columns::SortColumn")
208 .associate_type(
209 "Writer<W: std::io::Write + Send>",
210 format!("{}Writer<W>", schema.type_name),
211 );
212
213 schema_impl
214 .new_fn("sort_key_value")
215 .arg_ref_self()
216 .arg("sort_key", "parquetry::sort::SortKey<Self::SortColumn>")
217 .ret("Vec<u8>")
218 .push_block(code::gen_sort_key_value_block());
219
220 schema_impl
221 .new_fn("source")
222 .ret("&'static str")
223 .line("SCHEMA_SOURCE");
224
225 schema_impl
226 .new_fn("schema")
227 .ret("parquet::schema::types::SchemaDescPtr")
228 .line("SCHEMA.clone()");
229
230 schema_impl
231 .new_fn("writer")
232 .generic("W: std::io::Write + Send")
233 .arg("writer", "W")
234 .arg("properties", "parquet::file::properties::WriterProperties")
235 .ret("Result<Self::Writer<W>, parquetry::error::Error>")
236 .push_block(code::gen_writer_block()?);
237
238 let writer_struct = scope
239 .new_struct(&format!("{}Writer", schema.type_name))
240 .vis("pub")
241 .generic("W: std::io::Write");
242
243 writer_struct.new_field("writer", "parquet::file::writer::SerializedFileWriter<W>");
244 writer_struct.new_field("workspace", code::WORKSPACE_STRUCT_NAME);
245
246 let writer_impl = scope
247 .new_impl(&format!("{}Writer<W>", schema.type_name))
248 .impl_trait(format!(
249 "parquetry::write::SchemaWrite<{}, W>",
250 schema.type_name
251 ))
252 .generic("W: std::io::Write + Send");
253
254 writer_impl
255 .new_fn("write_row_group")
256 .generic("'a")
257 .generic(format!(
258 "E: From<parquetry::error::Error>, I: Iterator<Item = Result<&'a {}, E>>",
259 schema.type_name
260 ))
261 .arg_mut_self()
262 .arg("values", "&mut I")
263 .ret("Result<parquet::file::metadata::RowGroupMetaDataPtr, E>")
264 .bound(&schema.type_name, "'a")
265 .push_block(code::gen_writer_write_row_group_block(schema)?);
266
267 writer_impl
268 .new_fn("write_item")
269 .arg_mut_self()
270 .arg("value", format!("&{}", schema.type_name))
271 .ret("Result<(), parquetry::error::Error>")
272 .line(format!(
273 "{}::add_item_to_workspace(&mut self.workspace, value)",
274 schema.type_name
275 ));
276
277 writer_impl
278 .new_fn("finish_row_group")
279 .arg_mut_self()
280 .ret("Result<parquet::file::metadata::RowGroupMetaDataPtr, parquetry::error::Error>")
281 .line(format!(
282 "{}::write_with_workspace(&mut self.writer, &mut self.workspace)",
283 schema.type_name
284 ));
285
286 writer_impl
287 .new_fn("finish")
288 .arg_self()
289 .ret("Result<parquet::format::FileMetaData, parquetry::error::Error>")
290 .line("Ok(self.writer.close()?)");
291
292 let row_conversion_impl = scope
293 .new_impl(&schema.type_name)
294 .impl_trait("TryFrom<parquet::record::Row>")
295 .associate_type("Error", "parquetry::error::Error");
296
297 row_conversion_impl
298 .new_fn("try_from")
299 .arg("row", "parquet::record::Row")
300 .ret("Result<Self, parquetry::error::Error>")
301 .push_block(code::gen_row_conversion_block(schema)?);
302
303 let base_impl = scope.new_impl(&schema.type_name);
304
305 base_impl
306 .new_fn("write_sort_key_bytes")
307 .arg_ref_self()
308 .arg(
309 "column",
310 "parquetry::sort::Sort<<Self as parquetry::Schema>::SortColumn>",
311 )
312 .arg("bytes", "&mut Vec<u8>")
313 .push_block(code::gen_write_sort_key_bytes_block(schema)?);
314
315 base_impl
316 .new_fn("write_with_workspace")
317 .generic("W: std::io::Write + Send")
318 .arg(
319 "file_writer",
320 "&mut parquet::file::writer::SerializedFileWriter<W>",
321 )
322 .arg("workspace", format!("&mut {}", code::WORKSPACE_STRUCT_NAME))
323 .ret("Result<parquet::file::metadata::RowGroupMetaDataPtr, parquetry::error::Error>")
324 .push_block(code::gen_write_with_workspace_block(descriptor.columns())?);
325
326 base_impl
327 .new_fn("fill_workspace")
328 .generic("'a")
329 .generic("E: From<parquetry::error::Error>, I: Iterator<Item = Result<&'a Self, E>>")
330 .arg("workspace", format!("&mut {}", code::WORKSPACE_STRUCT_NAME))
331 .arg("values", "I")
332 .ret("Result<usize, E>")
333 .push_block(code::gen_fill_workspace_block()?);
334
335 base_impl
336 .new_fn("add_item_to_workspace")
337 .arg("workspace", format!("&mut {}", code::WORKSPACE_STRUCT_NAME))
338 .arg("value", "&Self")
339 .ret("Result<(), parquetry::error::Error>")
340 .push_block(code::gen_add_item_to_workspace_block(schema)?);
341
342 for gen_struct in schema.structs() {
343 let base_impl = scope.new_impl(&gen_struct.type_name);
344
345 code::gen_constructor(&gen_struct, base_impl.new_fn("new"))?;
346 }
347
348 code::add_workspace_struct(&mut scope, descriptor.columns())?;
349
350 if schema.config.tests {
351 let test_module = scope.new_module("test").attr("cfg(test)");
352
353 test_code::gen_test_code(test_module, schema)?;
354 }
355
356 Ok(scope)
357}