1use std::collections::{HashMap, HashSet};
19use std::fmt::Display;
20use std::sync::Arc;
21
22use snafu::{ensure, OptionExt};
23
24use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu};
25use crate::projection::ProjectionMask;
26use crate::proto;
27
28use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode};
29
/// Resolution used when timestamp columns are converted to Arrow types.
///
/// The selected variant decides which Arrow `TimeUnit` the resulting
/// timestamp data type carries.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TimestampPrecision {
    /// Expose timestamps with microsecond resolution.
    Microsecond,
    /// Expose timestamps with nanosecond resolution. This is the default.
    #[default]
    Nanosecond,
}
39
40#[derive(Debug, Clone)]
42pub struct ArrowSchemaOptions {
43 timestamp_precision: TimestampPrecision,
44}
45
46impl Default for ArrowSchemaOptions {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl ArrowSchemaOptions {
53 pub fn new() -> Self {
56 Self {
57 timestamp_precision: TimestampPrecision::default(),
58 }
59 }
60
61 pub fn with_timestamp_precision(mut self, precision: TimestampPrecision) -> Self {
69 self.timestamp_precision = precision;
70 self
71 }
72
73 fn timestamp_precision(&self) -> TimestampPrecision {
75 self.timestamp_precision
76 }
77}
78
79#[derive(Debug, Clone)]
91pub struct RootDataType {
92 children: Vec<NamedColumn>,
93 all_children: HashSet<usize>,
94}
95
96impl RootDataType {
97 pub fn column_index(&self) -> usize {
99 0
100 }
101
102 pub fn children(&self) -> &[NamedColumn] {
104 &self.children
105 }
106
107 pub fn contains_column_index(&self, index: usize) -> bool {
110 self.all_children.contains(&index)
111 }
112
113 pub fn create_arrow_schema(&self, user_metadata: &HashMap<String, String>) -> Schema {
115 self.create_arrow_schema_with_options(user_metadata, ArrowSchemaOptions::new())
116 }
117
118 pub fn create_arrow_schema_with_options(
120 &self,
121 user_metadata: &HashMap<String, String>,
122 options: ArrowSchemaOptions,
123 ) -> Schema {
124 let fields = self
125 .children
126 .iter()
127 .map(|col| {
128 let dt = col
129 .data_type()
130 .to_arrow_data_type_with_options(options.clone());
131 Field::new(col.name(), dt, true)
132 })
133 .collect::<Vec<_>>();
134 Schema::new_with_metadata(fields, user_metadata.clone())
135 }
136
137 pub fn project(&self, mask: &ProjectionMask) -> Self {
139 let children = self
141 .children
142 .iter()
143 .filter(|col| mask.is_index_projected(col.data_type().column_index()))
144 .map(|col| col.to_owned())
145 .collect::<Vec<_>>();
146 let all_children = get_all_children_indices_set(&children);
147 Self {
148 children,
149 all_children,
150 }
151 }
152
153 pub(crate) fn from_proto(types: &[proto::Type]) -> Result<Self> {
155 ensure!(!types.is_empty(), NoTypesSnafu {});
156 let children = parse_struct_children_from_proto(types, 0)?;
157 let all_children = get_all_children_indices_set(&children);
158 Ok(Self {
159 children,
160 all_children,
161 })
162 }
163}
164
165impl Display for RootDataType {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 write!(f, "ROOT")?;
168 for child in &self.children {
169 write!(f, "\n {child}")?;
170 }
171 Ok(())
172 }
173}
174
175fn get_all_children_indices_set(columns: &[NamedColumn]) -> HashSet<usize> {
176 let mut set = HashSet::new();
177 set.insert(0);
178 set.extend(columns.iter().flat_map(|c| c.data_type().all_indices()));
179 set
180}
181
182#[derive(Debug, Clone)]
183pub struct NamedColumn {
184 name: String,
185 data_type: DataType,
186}
187
188impl NamedColumn {
189 pub fn name(&self) -> &str {
190 self.name.as_str()
191 }
192
193 pub fn data_type(&self) -> &DataType {
194 &self.data_type
195 }
196}
197
198impl Display for NamedColumn {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 write!(f, "{} {}", self.name(), self.data_type())
201 }
202}
203
204fn parse_struct_children_from_proto(
207 types: &[proto::Type],
208 column_index: usize,
209) -> Result<Vec<NamedColumn>> {
210 assert!(column_index < types.len());
212 let ty = &types[column_index];
213 assert!(ty.kind() == proto::r#type::Kind::Struct);
214 ensure!(
215 ty.subtypes.len() == ty.field_names.len(),
216 UnexpectedSnafu {
217 msg: format!(
218 "Struct type for column index {column_index} must have matching lengths for subtypes and field names lists"
219 )
220 }
221 );
222 let children = ty
223 .subtypes
224 .iter()
225 .zip(ty.field_names.iter())
226 .map(|(&index, name)| {
227 let index = index as usize;
228 let name = name.to_owned();
229 let data_type = DataType::from_proto(types, index)?;
230 Ok(NamedColumn { name, data_type })
231 })
232 .collect::<Result<Vec<_>>>()?;
233 Ok(children)
234}
235
236#[derive(Debug, Clone)]
241pub enum DataType {
242 Boolean { column_index: usize },
244 Byte { column_index: usize },
246 Short { column_index: usize },
248 Int { column_index: usize },
250 Long { column_index: usize },
252 Float { column_index: usize },
254 Double { column_index: usize },
256 String { column_index: usize },
258 Varchar {
260 column_index: usize,
261 max_length: u32,
262 },
263 Char {
265 column_index: usize,
266 max_length: u32,
267 },
268 Binary { column_index: usize },
270 Decimal {
272 column_index: usize,
273 precision: u32,
275 scale: u32,
276 },
277 Timestamp { column_index: usize },
283 TimestampWithLocalTimezone { column_index: usize },
289 Date { column_index: usize },
292 Struct {
295 column_index: usize,
296 children: Vec<NamedColumn>,
297 },
298 List {
301 column_index: usize,
302 child: Box<DataType>,
303 },
304 Map {
307 column_index: usize,
308 key: Box<DataType>,
309 value: Box<DataType>,
310 },
311 Union {
317 column_index: usize,
318 variants: Vec<DataType>,
319 },
320}
321
322impl DataType {
323 pub fn column_index(&self) -> usize {
326 match self {
327 DataType::Boolean { column_index } => *column_index,
328 DataType::Byte { column_index } => *column_index,
329 DataType::Short { column_index } => *column_index,
330 DataType::Int { column_index } => *column_index,
331 DataType::Long { column_index } => *column_index,
332 DataType::Float { column_index } => *column_index,
333 DataType::Double { column_index } => *column_index,
334 DataType::String { column_index } => *column_index,
335 DataType::Varchar { column_index, .. } => *column_index,
336 DataType::Char { column_index, .. } => *column_index,
337 DataType::Binary { column_index } => *column_index,
338 DataType::Decimal { column_index, .. } => *column_index,
339 DataType::Timestamp { column_index } => *column_index,
340 DataType::TimestampWithLocalTimezone { column_index } => *column_index,
341 DataType::Date { column_index } => *column_index,
342 DataType::Struct { column_index, .. } => *column_index,
343 DataType::List { column_index, .. } => *column_index,
344 DataType::Map { column_index, .. } => *column_index,
345 DataType::Union { column_index, .. } => *column_index,
346 }
347 }
348
349 pub fn children_indices(&self) -> Vec<usize> {
351 match self {
352 DataType::Boolean { .. }
353 | DataType::Byte { .. }
354 | DataType::Short { .. }
355 | DataType::Int { .. }
356 | DataType::Long { .. }
357 | DataType::Float { .. }
358 | DataType::Double { .. }
359 | DataType::String { .. }
360 | DataType::Varchar { .. }
361 | DataType::Char { .. }
362 | DataType::Binary { .. }
363 | DataType::Decimal { .. }
364 | DataType::Timestamp { .. }
365 | DataType::TimestampWithLocalTimezone { .. }
366 | DataType::Date { .. } => vec![],
367 DataType::Struct { children, .. } => children
368 .iter()
369 .flat_map(|col| col.data_type().all_indices())
370 .collect(),
371 DataType::List { child, .. } => child.all_indices(),
372 DataType::Map { key, value, .. } => {
373 let mut indices = key.all_indices();
374 indices.extend(value.all_indices());
375 indices
376 }
377 DataType::Union { variants, .. } => {
378 variants.iter().flat_map(|dt| dt.all_indices()).collect()
379 }
380 }
381 }
382
383 pub fn all_indices(&self) -> Vec<usize> {
385 let mut indices = vec![self.column_index()];
386 indices.extend(self.children_indices());
387 indices
388 }
389
390 fn from_proto(types: &[proto::Type], column_index: usize) -> Result<Self> {
391 use proto::r#type::Kind;
392
393 let ty = types.get(column_index).context(UnexpectedSnafu {
394 msg: format!("Column index out of bounds: {column_index}"),
395 })?;
396 let dt = match ty.kind() {
397 Kind::Boolean => Self::Boolean { column_index },
398 Kind::Byte => Self::Byte { column_index },
399 Kind::Short => Self::Short { column_index },
400 Kind::Int => Self::Int { column_index },
401 Kind::Long => Self::Long { column_index },
402 Kind::Float => Self::Float { column_index },
403 Kind::Double => Self::Double { column_index },
404 Kind::String => Self::String { column_index },
405 Kind::Binary => Self::Binary { column_index },
406 Kind::Timestamp => Self::Timestamp { column_index },
407 Kind::List => {
408 ensure!(
409 ty.subtypes.len() == 1,
410 UnexpectedSnafu {
411 msg: format!(
412 "List type for column index {} must have 1 sub type, found {}",
413 column_index,
414 ty.subtypes.len()
415 )
416 }
417 );
418 let child = ty.subtypes[0] as usize;
419 let child = Box::new(Self::from_proto(types, child)?);
420 Self::List {
421 column_index,
422 child,
423 }
424 }
425 Kind::Map => {
426 ensure!(
427 ty.subtypes.len() == 2,
428 UnexpectedSnafu {
429 msg: format!(
430 "Map type for column index {} must have 2 sub types, found {}",
431 column_index,
432 ty.subtypes.len()
433 )
434 }
435 );
436 let key = ty.subtypes[0] as usize;
437 let key = Box::new(Self::from_proto(types, key)?);
438 let value = ty.subtypes[1] as usize;
439 let value = Box::new(Self::from_proto(types, value)?);
440 Self::Map {
441 column_index,
442 key,
443 value,
444 }
445 }
446 Kind::Struct => {
447 let children = parse_struct_children_from_proto(types, column_index)?;
448 Self::Struct {
449 column_index,
450 children,
451 }
452 }
453 Kind::Union => {
454 ensure!(
456 ty.subtypes.len() <= 127,
457 UnexpectedSnafu {
458 msg: format!(
459 "Union type for column index {} cannot exceed 127 variants, found {}",
460 column_index,
461 ty.subtypes.len()
462 )
463 }
464 );
465 let variants = ty
466 .subtypes
467 .iter()
468 .map(|&index| {
469 let index = index as usize;
470 Self::from_proto(types, index)
471 })
472 .collect::<Result<Vec<_>>>()?;
473 Self::Union {
474 column_index,
475 variants,
476 }
477 }
478 Kind::Decimal => Self::Decimal {
479 column_index,
480 precision: ty.precision(),
481 scale: ty.scale(),
482 },
483 Kind::Date => Self::Date { column_index },
484 Kind::Varchar => Self::Varchar {
485 column_index,
486 max_length: ty.maximum_length(),
487 },
488 Kind::Char => Self::Char {
489 column_index,
490 max_length: ty.maximum_length(),
491 },
492 Kind::TimestampInstant => Self::TimestampWithLocalTimezone { column_index },
493 };
494 Ok(dt)
495 }
496
497 pub fn to_arrow_data_type(&self) -> ArrowDataType {
499 self.to_arrow_data_type_with_options(ArrowSchemaOptions::new())
500 }
501
502 pub fn to_arrow_data_type_with_options(&self, options: ArrowSchemaOptions) -> ArrowDataType {
504 let timestamp_precision = options.timestamp_precision();
505 let time_unit = match timestamp_precision {
506 TimestampPrecision::Microsecond => TimeUnit::Microsecond,
507 TimestampPrecision::Nanosecond => TimeUnit::Nanosecond,
508 };
509
510 match self {
511 DataType::Boolean { .. } => ArrowDataType::Boolean,
512 DataType::Byte { .. } => ArrowDataType::Int8,
513 DataType::Short { .. } => ArrowDataType::Int16,
514 DataType::Int { .. } => ArrowDataType::Int32,
515 DataType::Long { .. } => ArrowDataType::Int64,
516 DataType::Float { .. } => ArrowDataType::Float32,
517 DataType::Double { .. } => ArrowDataType::Float64,
518 DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
519 ArrowDataType::Utf8
520 }
521 DataType::Binary { .. } => ArrowDataType::Binary,
522 DataType::Decimal {
523 precision, scale, ..
524 } => ArrowDataType::Decimal128(*precision as u8, *scale as i8), DataType::Timestamp { .. } => ArrowDataType::Timestamp(time_unit, None),
526 DataType::TimestampWithLocalTimezone { .. } => {
527 ArrowDataType::Timestamp(time_unit, Some("UTC".into()))
528 }
529 DataType::Date { .. } => ArrowDataType::Date32,
530 DataType::Struct { children, .. } => {
531 let children = children
532 .iter()
533 .map(|col| {
534 let dt = col
535 .data_type()
536 .to_arrow_data_type_with_options(options.clone());
537 Field::new(col.name(), dt, true)
538 })
539 .collect();
540 ArrowDataType::Struct(children)
541 }
542 DataType::List { child, .. } => {
543 let child = child.to_arrow_data_type_with_options(options);
544 ArrowDataType::new_list(child, true)
545 }
546 DataType::Map { key, value, .. } => {
547 let key = key.to_arrow_data_type_with_options(options.clone());
552 let key = Field::new("keys", key, false);
553 let value = value.to_arrow_data_type_with_options(options);
554 let value = Field::new("values", value, true);
555
556 let dt = ArrowDataType::Struct(vec![key, value].into());
557 let dt = Arc::new(Field::new("entries", dt, false));
558 ArrowDataType::Map(dt, false)
559 }
560 DataType::Union { variants, .. } => {
561 let fields = variants
562 .iter()
563 .enumerate()
564 .map(|(index, variant)| {
565 let index = index as u8 as i8;
569 let arrow_dt = variant.to_arrow_data_type_with_options(options.clone());
570 let field = Arc::new(Field::new(format!("_union_{index}"), arrow_dt, true));
573 (index, field)
574 })
575 .collect();
576 ArrowDataType::Union(fields, UnionMode::Sparse)
577 }
578 }
579 }
580}
581
582impl Display for DataType {
583 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584 match self {
585 DataType::Boolean { column_index: _ } => write!(f, "BOOLEAN"),
586 DataType::Byte { column_index: _ } => write!(f, "BYTE"),
587 DataType::Short { column_index: _ } => write!(f, "SHORT"),
588 DataType::Int { column_index: _ } => write!(f, "INTEGER"),
589 DataType::Long { column_index: _ } => write!(f, "LONG"),
590 DataType::Float { column_index: _ } => write!(f, "FLOAT"),
591 DataType::Double { column_index: _ } => write!(f, "DOUBLE"),
592 DataType::String { column_index: _ } => write!(f, "STRING"),
593 DataType::Varchar {
594 column_index: _,
595 max_length,
596 } => write!(f, "VARCHAR({max_length})"),
597 DataType::Char {
598 column_index: _,
599 max_length,
600 } => write!(f, "CHAR({max_length})"),
601 DataType::Binary { column_index: _ } => write!(f, "BINARY"),
602 DataType::Decimal {
603 column_index: _,
604 precision,
605 scale,
606 } => write!(f, "DECIMAL({precision}, {scale})"),
607 DataType::Timestamp { column_index: _ } => write!(f, "TIMESTAMP"),
608 DataType::TimestampWithLocalTimezone { column_index: _ } => {
609 write!(f, "TIMESTAMP INSTANT")
610 }
611 DataType::Date { column_index: _ } => write!(f, "DATE"),
612 DataType::Struct {
613 column_index: _,
614 children,
615 } => {
616 write!(f, "STRUCT")?;
617 for child in children {
618 write!(f, "\n {child}")?;
619 }
620 Ok(())
621 }
622 DataType::List {
623 column_index: _,
624 child,
625 } => write!(f, "LIST\n {child}"),
626 DataType::Map {
627 column_index: _,
628 key,
629 value,
630 } => write!(f, "MAP\n {key}\n {value}"),
631 DataType::Union {
632 column_index: _,
633 variants,
634 } => {
635 write!(f, "UNION")?;
636 for variant in variants {
637 write!(f, "\n {variant}")?;
638 }
639 Ok(())
640 }
641 }
642 }
643}