parquetry_gen/
schema.rs

1use convert_case::{Case, Casing};
2use parquet::{
3    basic::{LogicalType, Repetition},
4    schema::types::{ColumnDescPtr, SchemaDescriptor, Type},
5};
6use std::collections::HashSet;
7use std::ops::Range;
8
9use crate::types::TypeMapping;
10
11use super::{error::Error, Config};
12
/// Code-generation view of a Parquet schema: the root record type plus its
/// top-level fields and the configuration used to generate code for it.
#[derive(Clone, Debug)]
pub struct GenSchema {
    /// Rust type name of the root record (the root schema name in PascalCase).
    pub type_name: String,
    /// Top-level fields of the root record, in schema order.
    pub gen_fields: Vec<GenField>,
    /// Generator configuration this schema was built with.
    pub config: Config,
}
19
/// A single field of a generated Rust struct, together with how it maps onto
/// the Parquet schema.
#[derive(Clone, Debug)]
pub struct GenField {
    /// Field name as it appears in the Parquet schema.
    pub name: String,
    /// Rust type of the field without any `Option<...>` wrapper.
    pub base_type_name: String,
    /// Optional attribute text for the field; for leaf columns this comes from
    /// the type mapping (which is given the `serde_support` flag), otherwise `None`.
    pub attributes: Option<String>,
    /// Whether the field is optional (its full type is then `Option<base_type_name>`).
    pub optional: bool,
    /// Structural kind of the field: leaf column, nested struct, or list.
    pub gen_type: GenType,
}
28
/// Structural shape of a generated field.
#[derive(Clone, Debug)]
pub enum GenType {
    /// A leaf value backed by a single Parquet column.
    Column(GenColumn),
    /// A nested group, generated as its own struct.
    Struct {
        /// Fields of the nested struct, in schema order.
        gen_fields: Vec<GenField>,
        /// Definition levels accumulated from the root down to this group
        /// (one per optional ancestor, including this group, plus list nesting).
        def_depth: usize,
        /// Repetition levels accumulated from the root down to this group.
        rep_depth: usize,
    },
    /// A logical (annotated) list, generated as `Vec<...>`.
    List {
        /// Whether the list elements are optional.
        element_optional: bool,
        /// Shape of the list element.
        element_gen_type: Box<GenType>,
        /// Rust type name used for the element when the element is a struct.
        element_struct_name: String,
        /// Definition levels accumulated down to the list's repeated level.
        def_depth: usize,
        /// Repetition levels accumulated down to the list's repeated level.
        rep_depth: usize,
    },
}
45
/// A Rust struct definition to be emitted by the generator.
#[derive(Clone, Debug)]
pub struct GenStruct {
    /// Name of the struct type.
    pub type_name: String,
    /// Fields of the struct, in schema order.
    pub fields: Vec<GenField>,
    /// Derive macro names to apply: the configured derives minus any that a
    /// field type cannot support.
    pub derives: Vec<&'static str>,
}
52
/// A leaf Parquet column together with its Rust type mapping.
#[derive(Clone, Debug)]
pub struct GenColumn {
    /// Index of this column in the schema descriptor's leaf-column list.
    pub index: usize,
    /// Field-name path from the root record to the field owning this column
    /// (for list elements, the list field itself); the `bool` records whether
    /// each field on the path is optional.
    pub rust_path: Vec<(String, bool)>,
    /// Parquet column descriptor.
    pub descriptor: ColumnDescPtr,
    /// Mapping between the Parquet column type and the generated Rust type.
    pub mapping: TypeMapping,
}
60
61impl GenStruct {
62    fn new(
63        type_name: &str,
64        fields: Vec<GenField>,
65        base_derives: &[&'static str],
66        disallowed_derives: HashSet<&str>,
67    ) -> Self {
68        let derives = base_derives
69            .iter()
70            .cloned()
71            .filter(|value| !disallowed_derives.contains(value))
72            .collect::<Vec<_>>();
73
74        Self {
75            type_name: type_name.to_string(),
76            fields,
77            derives,
78        }
79    }
80}
81
82impl GenSchema {
83    pub fn from_schema(schema: &SchemaDescriptor, config: Config) -> Result<Self, Error> {
84        if let GenField {
85            base_type_name,
86            gen_type: GenType::Struct { gen_fields, .. },
87            ..
88        } = GenField::from_type(
89            &config,
90            schema.root_schema(),
91            schema.columns(),
92            0,
93            "",
94            vec![],
95            0,
96            0,
97        )?
98        .0
99        {
100            Ok(Self {
101                type_name: base_type_name,
102                gen_fields,
103                config,
104            })
105        } else {
106            Err(Error::InvalidRootSchema(schema.root_schema().clone()))
107        }
108    }
109
110    pub fn field_names(&self) -> Vec<&str> {
111        self.gen_fields
112            .iter()
113            .map(|gen_field| gen_field.name.as_str())
114            .collect()
115    }
116
117    pub fn structs(&self) -> Vec<GenStruct> {
118        let disallowed_derives = self
119            .gen_fields
120            .iter()
121            .flat_map(|gen_field| gen_field.gen_type.disallowed_derives())
122            .collect();
123
124        let mut structs = vec![GenStruct::new(
125            &self.type_name,
126            self.gen_fields.clone(),
127            &self.config.derives(),
128            disallowed_derives,
129        )];
130
131        for gen_field in &self.gen_fields {
132            gen_field.gen_type.structs(
133                &gen_field.base_type_name,
134                &self.config.derives(),
135                &mut structs,
136            );
137        }
138
139        structs
140    }
141
142    pub fn gen_columns(&self) -> Vec<GenColumn> {
143        let mut gen_columns = vec![];
144
145        for gen_field in &self.gen_fields {
146            gen_field.gen_type.gen_columns(&mut gen_columns);
147        }
148
149        gen_columns
150    }
151}
152
153impl GenField {
154    pub fn type_name(&self) -> String {
155        if self.optional {
156            format!("Option<{}>", self.base_type_name)
157        } else {
158            self.base_type_name.to_string()
159        }
160    }
161
162    fn field_name(source_name: &str) -> String {
163        source_name.to_string()
164    }
165
166    fn field_type_name(source_name: &str) -> String {
167        source_name.to_case(Case::Pascal)
168    }
169
170    fn from_type(
171        config: &Config,
172        tp: &Type,
173        columns: &[ColumnDescPtr],
174        current_column_index: usize,
175        name: &str,
176        rust_path: Vec<(String, bool)>,
177        def_depth: usize,
178        rep_depth: usize,
179    ) -> Result<(Self, usize), Error> {
180        match tp {
181            Type::PrimitiveType {
182                basic_info,
183                physical_type,
184                type_length,
185                ..
186            } => {
187                // We currently only support annotated lists
188                if basic_info.repetition() == Repetition::REPEATED {
189                    Err(Error::UnsupportedRepetition(basic_info.name().to_string()))
190                } else {
191                    let column = columns[current_column_index].clone();
192                    let mapping = super::types::TypeMapping::from_types(
193                        column.logical_type(),
194                        *physical_type,
195                        *type_length,
196                    )?;
197                    let optional = basic_info.repetition() == Repetition::OPTIONAL;
198
199                    Ok((
200                        Self {
201                            name: name.to_string(),
202                            base_type_name: mapping.rust_type_name().to_string(),
203                            attributes: mapping.attributes(config.serde_support, optional),
204                            optional,
205                            gen_type: GenType::Column(GenColumn {
206                                index: current_column_index,
207                                rust_path,
208                                descriptor: column,
209                                mapping,
210                            }),
211                        },
212                        current_column_index + 1,
213                    ))
214                }
215            }
216            Type::GroupType { basic_info, fields } => {
217                let name = Self::field_name(basic_info.name());
218                let optional =
219                    basic_info.has_repetition() && basic_info.repetition() == Repetition::OPTIONAL;
220                let new_def_depth = def_depth + if optional { 1 } else { 0 };
221
222                if let Some(element_type) =
223                    super::util::supported_logical_list_element_type(basic_info, fields)
224                {
225                    let (element_gen_field, new_current_column_index) = Self::from_type(
226                        config,
227                        &element_type,
228                        columns,
229                        current_column_index,
230                        &Self::field_name(element_type.get_basic_info().name()),
231                        rust_path,
232                        new_def_depth + 1,
233                        rep_depth + 1,
234                    )?;
235
236                    let element_struct_name =
237                        Self::field_type_name(&format!("{}_element", basic_info.name()));
238
239                    let element_type_name = match element_gen_field.gen_type {
240                        GenType::Column { .. } => element_gen_field.type_name(),
241                        GenType::Struct { .. } => {
242                            if element_gen_field.optional {
243                                format!("Option<{}>", element_struct_name)
244                            } else {
245                                element_struct_name.clone()
246                            }
247                        }
248                        GenType::List { .. } => element_gen_field.type_name(),
249                    };
250
251                    Ok((
252                        Self {
253                            name,
254                            base_type_name: format!("Vec<{}>", element_type_name),
255                            attributes: None,
256                            optional,
257                            gen_type: GenType::List {
258                                def_depth: new_def_depth + 1,
259                                rep_depth: rep_depth + 1,
260                                element_optional: element_gen_field.optional,
261                                element_gen_type: Box::new(element_gen_field.gen_type),
262                                element_struct_name,
263                            },
264                        },
265                        new_current_column_index,
266                    ))
267                } else if basic_info.logical_type() == Some(LogicalType::List)
268                    || (basic_info.has_repetition()
269                        && basic_info.repetition() == Repetition::REPEATED)
270                {
271                    Err(Error::UnsupportedRepetition(basic_info.name().to_string()))
272                } else {
273                    let mut gen_fields = vec![];
274                    let mut new_current_column_index = current_column_index;
275
276                    for field in fields {
277                        let name = Self::field_name(field.get_basic_info().name());
278                        let mut rust_path = rust_path.clone();
279                        rust_path.push((name.clone(), field.is_optional()));
280                        let (gen_field, column_index) = Self::from_type(
281                            config,
282                            field,
283                            columns,
284                            new_current_column_index,
285                            &name,
286                            rust_path,
287                            new_def_depth,
288                            rep_depth,
289                        )?;
290                        new_current_column_index = column_index;
291                        gen_fields.push(gen_field);
292                    }
293
294                    Ok((
295                        Self {
296                            name,
297                            base_type_name: Self::field_type_name(basic_info.name()),
298                            attributes: None,
299                            optional,
300                            gen_type: GenType::Struct {
301                                gen_fields,
302                                def_depth: new_def_depth,
303                                rep_depth,
304                            },
305                        },
306                        new_current_column_index,
307                    ))
308                }
309            }
310        }
311    }
312}
313
314impl GenType {
315    pub fn column_indices(&self) -> Range<usize> {
316        match self {
317            GenType::Column(GenColumn { index, .. }) => *index..*index + 1,
318            GenType::Struct { gen_fields, .. } => {
319                let mut start = usize::MAX;
320                let mut end = usize::MIN;
321
322                for gen_field in gen_fields {
323                    let range = gen_field.gen_type.column_indices();
324                    start = start.min(range.start);
325                    end = end.max(range.end);
326                }
327                start..end
328            }
329            GenType::List {
330                element_gen_type, ..
331            } => element_gen_type.column_indices(),
332        }
333    }
334
335    pub fn repeated_column_indices(&self) -> Vec<usize> {
336        match self {
337            GenType::Column(GenColumn {
338                index, descriptor, ..
339            }) => {
340                if descriptor.max_rep_level() > 0 {
341                    vec![*index]
342                } else {
343                    vec![]
344                }
345            }
346            GenType::Struct { gen_fields, .. } => {
347                let mut indices = vec![];
348
349                for gen_field in gen_fields {
350                    indices.extend(gen_field.gen_type.repeated_column_indices());
351                }
352
353                indices.sort();
354                indices.dedup();
355                indices
356            }
357            GenType::List {
358                element_gen_type, ..
359            } => element_gen_type.repeated_column_indices(),
360        }
361    }
362
363    fn disallowed_derives(&self) -> HashSet<&'static str> {
364        let mut values = HashSet::new();
365
366        match self {
367            GenType::Column(GenColumn { mapping, .. }) => {
368                values.extend(&mapping.disallowed_derives());
369            }
370            GenType::Struct { gen_fields, .. } => {
371                for gen_field in gen_fields {
372                    values.extend(gen_field.gen_type.disallowed_derives());
373                }
374            }
375            GenType::List {
376                element_gen_type, ..
377            } => {
378                values.insert("Copy");
379                values.extend(element_gen_type.disallowed_derives());
380            }
381        }
382
383        values
384    }
385
386    fn structs(&self, type_name: &str, base_derives: &[&'static str], acc: &mut Vec<GenStruct>) {
387        match self {
388            GenType::Column { .. } => {}
389            GenType::Struct { gen_fields, .. } => {
390                acc.push(GenStruct::new(
391                    type_name,
392                    gen_fields.clone(),
393                    base_derives,
394                    self.disallowed_derives(),
395                ));
396
397                for GenField {
398                    base_type_name,
399                    gen_type,
400                    ..
401                } in gen_fields
402                {
403                    gen_type.structs(base_type_name, base_derives, acc);
404                }
405            }
406            GenType::List {
407                element_gen_type,
408                element_struct_name,
409                ..
410            } => element_gen_type.structs(element_struct_name, base_derives, acc),
411        }
412    }
413
414    fn gen_columns(&self, acc: &mut Vec<GenColumn>) {
415        match self {
416            GenType::Column(column) => {
417                acc.push(column.clone());
418            }
419            GenType::Struct { gen_fields, .. } => {
420                for gen_field in gen_fields {
421                    gen_field.gen_type.gen_columns(acc);
422                }
423            }
424            GenType::List {
425                element_gen_type, ..
426            } => {
427                element_gen_type.gen_columns(acc);
428            }
429        }
430    }
431}
432
433impl GenColumn {
434    pub fn variant_name(&self) -> String {
435        self.rust_path.last().unwrap().0.to_case(Case::Pascal)
436    }
437
438    pub fn is_sort_column(&self) -> bool {
439        self.descriptor.max_rep_level() == 0
440    }
441}