1use convert_case::{Case, Casing};
2use parquet::{
3 basic::{LogicalType, Repetition},
4 schema::types::{ColumnDescPtr, SchemaDescriptor, Type},
5};
6use std::collections::HashSet;
7use std::ops::Range;
8
9use crate::types::TypeMapping;
10
11use super::{error::Error, Config};
12
13#[derive(Clone, Debug)]
14pub struct GenSchema {
15 pub type_name: String,
16 pub gen_fields: Vec<GenField>,
17 pub config: Config,
18}
19
20#[derive(Clone, Debug)]
21pub struct GenField {
22 pub name: String,
23 pub base_type_name: String,
24 pub attributes: Option<String>,
25 pub optional: bool,
26 pub gen_type: GenType,
27}
28
29#[derive(Clone, Debug)]
30pub enum GenType {
31 Column(GenColumn),
32 Struct {
33 gen_fields: Vec<GenField>,
34 def_depth: usize,
35 rep_depth: usize,
36 },
37 List {
38 element_optional: bool,
39 element_gen_type: Box<GenType>,
40 element_struct_name: String,
41 def_depth: usize,
42 rep_depth: usize,
43 },
44}
45
46#[derive(Clone, Debug)]
47pub struct GenStruct {
48 pub type_name: String,
49 pub fields: Vec<GenField>,
50 pub derives: Vec<&'static str>,
51}
52
53#[derive(Clone, Debug)]
54pub struct GenColumn {
55 pub index: usize,
56 pub rust_path: Vec<(String, bool)>,
57 pub descriptor: ColumnDescPtr,
58 pub mapping: TypeMapping,
59}
60
61impl GenStruct {
62 fn new(
63 type_name: &str,
64 fields: Vec<GenField>,
65 base_derives: &[&'static str],
66 disallowed_derives: HashSet<&str>,
67 ) -> Self {
68 let derives = base_derives
69 .iter()
70 .cloned()
71 .filter(|value| !disallowed_derives.contains(value))
72 .collect::<Vec<_>>();
73
74 Self {
75 type_name: type_name.to_string(),
76 fields,
77 derives,
78 }
79 }
80}
81
82impl GenSchema {
83 pub fn from_schema(schema: &SchemaDescriptor, config: Config) -> Result<Self, Error> {
84 if let GenField {
85 base_type_name,
86 gen_type: GenType::Struct { gen_fields, .. },
87 ..
88 } = GenField::from_type(
89 &config,
90 schema.root_schema(),
91 schema.columns(),
92 0,
93 "",
94 vec![],
95 0,
96 0,
97 )?
98 .0
99 {
100 Ok(Self {
101 type_name: base_type_name,
102 gen_fields,
103 config,
104 })
105 } else {
106 Err(Error::InvalidRootSchema(schema.root_schema().clone()))
107 }
108 }
109
110 pub fn field_names(&self) -> Vec<&str> {
111 self.gen_fields
112 .iter()
113 .map(|gen_field| gen_field.name.as_str())
114 .collect()
115 }
116
117 pub fn structs(&self) -> Vec<GenStruct> {
118 let disallowed_derives = self
119 .gen_fields
120 .iter()
121 .flat_map(|gen_field| gen_field.gen_type.disallowed_derives())
122 .collect();
123
124 let mut structs = vec![GenStruct::new(
125 &self.type_name,
126 self.gen_fields.clone(),
127 &self.config.derives(),
128 disallowed_derives,
129 )];
130
131 for gen_field in &self.gen_fields {
132 gen_field.gen_type.structs(
133 &gen_field.base_type_name,
134 &self.config.derives(),
135 &mut structs,
136 );
137 }
138
139 structs
140 }
141
142 pub fn gen_columns(&self) -> Vec<GenColumn> {
143 let mut gen_columns = vec![];
144
145 for gen_field in &self.gen_fields {
146 gen_field.gen_type.gen_columns(&mut gen_columns);
147 }
148
149 gen_columns
150 }
151}
152
153impl GenField {
154 pub fn type_name(&self) -> String {
155 if self.optional {
156 format!("Option<{}>", self.base_type_name)
157 } else {
158 self.base_type_name.to_string()
159 }
160 }
161
162 fn field_name(source_name: &str) -> String {
163 source_name.to_string()
164 }
165
166 fn field_type_name(source_name: &str) -> String {
167 source_name.to_case(Case::Pascal)
168 }
169
170 fn from_type(
171 config: &Config,
172 tp: &Type,
173 columns: &[ColumnDescPtr],
174 current_column_index: usize,
175 name: &str,
176 rust_path: Vec<(String, bool)>,
177 def_depth: usize,
178 rep_depth: usize,
179 ) -> Result<(Self, usize), Error> {
180 match tp {
181 Type::PrimitiveType {
182 basic_info,
183 physical_type,
184 type_length,
185 ..
186 } => {
187 if basic_info.repetition() == Repetition::REPEATED {
189 Err(Error::UnsupportedRepetition(basic_info.name().to_string()))
190 } else {
191 let column = columns[current_column_index].clone();
192 let mapping = super::types::TypeMapping::from_types(
193 column.logical_type(),
194 *physical_type,
195 *type_length,
196 )?;
197 let optional = basic_info.repetition() == Repetition::OPTIONAL;
198
199 Ok((
200 Self {
201 name: name.to_string(),
202 base_type_name: mapping.rust_type_name().to_string(),
203 attributes: mapping.attributes(config.serde_support, optional),
204 optional,
205 gen_type: GenType::Column(GenColumn {
206 index: current_column_index,
207 rust_path,
208 descriptor: column,
209 mapping,
210 }),
211 },
212 current_column_index + 1,
213 ))
214 }
215 }
216 Type::GroupType { basic_info, fields } => {
217 let name = Self::field_name(basic_info.name());
218 let optional =
219 basic_info.has_repetition() && basic_info.repetition() == Repetition::OPTIONAL;
220 let new_def_depth = def_depth + if optional { 1 } else { 0 };
221
222 if let Some(element_type) =
223 super::util::supported_logical_list_element_type(basic_info, fields)
224 {
225 let (element_gen_field, new_current_column_index) = Self::from_type(
226 config,
227 &element_type,
228 columns,
229 current_column_index,
230 &Self::field_name(element_type.get_basic_info().name()),
231 rust_path,
232 new_def_depth + 1,
233 rep_depth + 1,
234 )?;
235
236 let element_struct_name =
237 Self::field_type_name(&format!("{}_element", basic_info.name()));
238
239 let element_type_name = match element_gen_field.gen_type {
240 GenType::Column { .. } => element_gen_field.type_name(),
241 GenType::Struct { .. } => {
242 if element_gen_field.optional {
243 format!("Option<{}>", element_struct_name)
244 } else {
245 element_struct_name.clone()
246 }
247 }
248 GenType::List { .. } => element_gen_field.type_name(),
249 };
250
251 Ok((
252 Self {
253 name,
254 base_type_name: format!("Vec<{}>", element_type_name),
255 attributes: None,
256 optional,
257 gen_type: GenType::List {
258 def_depth: new_def_depth + 1,
259 rep_depth: rep_depth + 1,
260 element_optional: element_gen_field.optional,
261 element_gen_type: Box::new(element_gen_field.gen_type),
262 element_struct_name,
263 },
264 },
265 new_current_column_index,
266 ))
267 } else if basic_info.logical_type() == Some(LogicalType::List)
268 || (basic_info.has_repetition()
269 && basic_info.repetition() == Repetition::REPEATED)
270 {
271 Err(Error::UnsupportedRepetition(basic_info.name().to_string()))
272 } else {
273 let mut gen_fields = vec![];
274 let mut new_current_column_index = current_column_index;
275
276 for field in fields {
277 let name = Self::field_name(field.get_basic_info().name());
278 let mut rust_path = rust_path.clone();
279 rust_path.push((name.clone(), field.is_optional()));
280 let (gen_field, column_index) = Self::from_type(
281 config,
282 field,
283 columns,
284 new_current_column_index,
285 &name,
286 rust_path,
287 new_def_depth,
288 rep_depth,
289 )?;
290 new_current_column_index = column_index;
291 gen_fields.push(gen_field);
292 }
293
294 Ok((
295 Self {
296 name,
297 base_type_name: Self::field_type_name(basic_info.name()),
298 attributes: None,
299 optional,
300 gen_type: GenType::Struct {
301 gen_fields,
302 def_depth: new_def_depth,
303 rep_depth,
304 },
305 },
306 new_current_column_index,
307 ))
308 }
309 }
310 }
311 }
312}
313
314impl GenType {
315 pub fn column_indices(&self) -> Range<usize> {
316 match self {
317 GenType::Column(GenColumn { index, .. }) => *index..*index + 1,
318 GenType::Struct { gen_fields, .. } => {
319 let mut start = usize::MAX;
320 let mut end = usize::MIN;
321
322 for gen_field in gen_fields {
323 let range = gen_field.gen_type.column_indices();
324 start = start.min(range.start);
325 end = end.max(range.end);
326 }
327 start..end
328 }
329 GenType::List {
330 element_gen_type, ..
331 } => element_gen_type.column_indices(),
332 }
333 }
334
335 pub fn repeated_column_indices(&self) -> Vec<usize> {
336 match self {
337 GenType::Column(GenColumn {
338 index, descriptor, ..
339 }) => {
340 if descriptor.max_rep_level() > 0 {
341 vec![*index]
342 } else {
343 vec![]
344 }
345 }
346 GenType::Struct { gen_fields, .. } => {
347 let mut indices = vec![];
348
349 for gen_field in gen_fields {
350 indices.extend(gen_field.gen_type.repeated_column_indices());
351 }
352
353 indices.sort();
354 indices.dedup();
355 indices
356 }
357 GenType::List {
358 element_gen_type, ..
359 } => element_gen_type.repeated_column_indices(),
360 }
361 }
362
363 fn disallowed_derives(&self) -> HashSet<&'static str> {
364 let mut values = HashSet::new();
365
366 match self {
367 GenType::Column(GenColumn { mapping, .. }) => {
368 values.extend(&mapping.disallowed_derives());
369 }
370 GenType::Struct { gen_fields, .. } => {
371 for gen_field in gen_fields {
372 values.extend(gen_field.gen_type.disallowed_derives());
373 }
374 }
375 GenType::List {
376 element_gen_type, ..
377 } => {
378 values.insert("Copy");
379 values.extend(element_gen_type.disallowed_derives());
380 }
381 }
382
383 values
384 }
385
386 fn structs(&self, type_name: &str, base_derives: &[&'static str], acc: &mut Vec<GenStruct>) {
387 match self {
388 GenType::Column { .. } => {}
389 GenType::Struct { gen_fields, .. } => {
390 acc.push(GenStruct::new(
391 type_name,
392 gen_fields.clone(),
393 base_derives,
394 self.disallowed_derives(),
395 ));
396
397 for GenField {
398 base_type_name,
399 gen_type,
400 ..
401 } in gen_fields
402 {
403 gen_type.structs(base_type_name, base_derives, acc);
404 }
405 }
406 GenType::List {
407 element_gen_type,
408 element_struct_name,
409 ..
410 } => element_gen_type.structs(element_struct_name, base_derives, acc),
411 }
412 }
413
414 fn gen_columns(&self, acc: &mut Vec<GenColumn>) {
415 match self {
416 GenType::Column(column) => {
417 acc.push(column.clone());
418 }
419 GenType::Struct { gen_fields, .. } => {
420 for gen_field in gen_fields {
421 gen_field.gen_type.gen_columns(acc);
422 }
423 }
424 GenType::List {
425 element_gen_type, ..
426 } => {
427 element_gen_type.gen_columns(acc);
428 }
429 }
430 }
431}
432
433impl GenColumn {
434 pub fn variant_name(&self) -> String {
435 self.rust_path.last().unwrap().0.to_case(Case::Pascal)
436 }
437
438 pub fn is_sort_column(&self) -> bool {
439 self.descriptor.max_rep_level() == 0
440 }
441}