1use std::collections::{HashMap, HashSet};
19use std::fmt::Display;
20use std::sync::Arc;
21
22use snafu::{ensure, OptionExt};
23
24use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu};
25use crate::projection::ProjectionMask;
26use crate::proto;
27
28use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode};
29
30#[derive(Debug, Clone)]
42pub struct RootDataType {
43 children: Vec<NamedColumn>,
44 all_children: HashSet<usize>,
45}
46
47impl RootDataType {
48 pub fn column_index(&self) -> usize {
50 0
51 }
52
53 pub fn children(&self) -> &[NamedColumn] {
55 &self.children
56 }
57
58 pub fn contains_column_index(&self, index: usize) -> bool {
61 self.all_children.contains(&index)
62 }
63
64 pub fn create_arrow_schema(&self, user_metadata: &HashMap<String, String>) -> Schema {
66 let fields = self
67 .children
68 .iter()
69 .map(|col| {
70 let dt = col.data_type().to_arrow_data_type();
71 Field::new(col.name(), dt, true)
72 })
73 .collect::<Vec<_>>();
74 Schema::new_with_metadata(fields, user_metadata.clone())
75 }
76
77 pub fn project(&self, mask: &ProjectionMask) -> Self {
79 let children = self
81 .children
82 .iter()
83 .filter(|col| mask.is_index_projected(col.data_type().column_index()))
84 .map(|col| col.to_owned())
85 .collect::<Vec<_>>();
86 let all_children = get_all_children_indices_set(&children);
87 Self {
88 children,
89 all_children,
90 }
91 }
92
93 pub(crate) fn from_proto(types: &[proto::Type]) -> Result<Self> {
95 ensure!(!types.is_empty(), NoTypesSnafu {});
96 let children = parse_struct_children_from_proto(types, 0)?;
97 let all_children = get_all_children_indices_set(&children);
98 Ok(Self {
99 children,
100 all_children,
101 })
102 }
103}
104
105impl Display for RootDataType {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(f, "ROOT")?;
108 for child in &self.children {
109 write!(f, "\n {child}")?;
110 }
111 Ok(())
112 }
113}
114
115fn get_all_children_indices_set(columns: &[NamedColumn]) -> HashSet<usize> {
116 let mut set = HashSet::new();
117 set.insert(0);
118 set.extend(columns.iter().flat_map(|c| c.data_type().all_indices()));
119 set
120}
121
122#[derive(Debug, Clone)]
123pub struct NamedColumn {
124 name: String,
125 data_type: DataType,
126}
127
128impl NamedColumn {
129 pub fn name(&self) -> &str {
130 self.name.as_str()
131 }
132
133 pub fn data_type(&self) -> &DataType {
134 &self.data_type
135 }
136}
137
138impl Display for NamedColumn {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 write!(f, "{} {}", self.name(), self.data_type())
141 }
142}
143
144fn parse_struct_children_from_proto(
147 types: &[proto::Type],
148 column_index: usize,
149) -> Result<Vec<NamedColumn>> {
150 assert!(column_index < types.len());
152 let ty = &types[column_index];
153 assert!(ty.kind() == proto::r#type::Kind::Struct);
154 ensure!(
155 ty.subtypes.len() == ty.field_names.len(),
156 UnexpectedSnafu {
157 msg: format!(
158 "Struct type for column index {column_index} must have matching lengths for subtypes and field names lists"
159 )
160 }
161 );
162 let children = ty
163 .subtypes
164 .iter()
165 .zip(ty.field_names.iter())
166 .map(|(&index, name)| {
167 let index = index as usize;
168 let name = name.to_owned();
169 let data_type = DataType::from_proto(types, index)?;
170 Ok(NamedColumn { name, data_type })
171 })
172 .collect::<Result<Vec<_>>>()?;
173 Ok(children)
174}
175
176#[derive(Debug, Clone)]
181pub enum DataType {
182 Boolean { column_index: usize },
184 Byte { column_index: usize },
186 Short { column_index: usize },
188 Int { column_index: usize },
190 Long { column_index: usize },
192 Float { column_index: usize },
194 Double { column_index: usize },
196 String { column_index: usize },
198 Varchar {
200 column_index: usize,
201 max_length: u32,
202 },
203 Char {
205 column_index: usize,
206 max_length: u32,
207 },
208 Binary { column_index: usize },
210 Decimal {
212 column_index: usize,
213 precision: u32,
215 scale: u32,
216 },
217 Timestamp { column_index: usize },
223 TimestampWithLocalTimezone { column_index: usize },
229 Date { column_index: usize },
232 Struct {
235 column_index: usize,
236 children: Vec<NamedColumn>,
237 },
238 List {
241 column_index: usize,
242 child: Box<DataType>,
243 },
244 Map {
247 column_index: usize,
248 key: Box<DataType>,
249 value: Box<DataType>,
250 },
251 Union {
257 column_index: usize,
258 variants: Vec<DataType>,
259 },
260}
261
262impl DataType {
263 pub fn column_index(&self) -> usize {
266 match self {
267 DataType::Boolean { column_index } => *column_index,
268 DataType::Byte { column_index } => *column_index,
269 DataType::Short { column_index } => *column_index,
270 DataType::Int { column_index } => *column_index,
271 DataType::Long { column_index } => *column_index,
272 DataType::Float { column_index } => *column_index,
273 DataType::Double { column_index } => *column_index,
274 DataType::String { column_index } => *column_index,
275 DataType::Varchar { column_index, .. } => *column_index,
276 DataType::Char { column_index, .. } => *column_index,
277 DataType::Binary { column_index } => *column_index,
278 DataType::Decimal { column_index, .. } => *column_index,
279 DataType::Timestamp { column_index } => *column_index,
280 DataType::TimestampWithLocalTimezone { column_index } => *column_index,
281 DataType::Date { column_index } => *column_index,
282 DataType::Struct { column_index, .. } => *column_index,
283 DataType::List { column_index, .. } => *column_index,
284 DataType::Map { column_index, .. } => *column_index,
285 DataType::Union { column_index, .. } => *column_index,
286 }
287 }
288
289 pub fn children_indices(&self) -> Vec<usize> {
291 match self {
292 DataType::Boolean { .. }
293 | DataType::Byte { .. }
294 | DataType::Short { .. }
295 | DataType::Int { .. }
296 | DataType::Long { .. }
297 | DataType::Float { .. }
298 | DataType::Double { .. }
299 | DataType::String { .. }
300 | DataType::Varchar { .. }
301 | DataType::Char { .. }
302 | DataType::Binary { .. }
303 | DataType::Decimal { .. }
304 | DataType::Timestamp { .. }
305 | DataType::TimestampWithLocalTimezone { .. }
306 | DataType::Date { .. } => vec![],
307 DataType::Struct { children, .. } => children
308 .iter()
309 .flat_map(|col| col.data_type().all_indices())
310 .collect(),
311 DataType::List { child, .. } => child.all_indices(),
312 DataType::Map { key, value, .. } => {
313 let mut indices = key.all_indices();
314 indices.extend(value.all_indices());
315 indices
316 }
317 DataType::Union { variants, .. } => {
318 variants.iter().flat_map(|dt| dt.all_indices()).collect()
319 }
320 }
321 }
322
323 pub fn all_indices(&self) -> Vec<usize> {
325 let mut indices = vec![self.column_index()];
326 indices.extend(self.children_indices());
327 indices
328 }
329
330 fn from_proto(types: &[proto::Type], column_index: usize) -> Result<Self> {
331 use proto::r#type::Kind;
332
333 let ty = types.get(column_index).context(UnexpectedSnafu {
334 msg: format!("Column index out of bounds: {column_index}"),
335 })?;
336 let dt = match ty.kind() {
337 Kind::Boolean => Self::Boolean { column_index },
338 Kind::Byte => Self::Byte { column_index },
339 Kind::Short => Self::Short { column_index },
340 Kind::Int => Self::Int { column_index },
341 Kind::Long => Self::Long { column_index },
342 Kind::Float => Self::Float { column_index },
343 Kind::Double => Self::Double { column_index },
344 Kind::String => Self::String { column_index },
345 Kind::Binary => Self::Binary { column_index },
346 Kind::Timestamp => Self::Timestamp { column_index },
347 Kind::List => {
348 ensure!(
349 ty.subtypes.len() == 1,
350 UnexpectedSnafu {
351 msg: format!(
352 "List type for column index {} must have 1 sub type, found {}",
353 column_index,
354 ty.subtypes.len()
355 )
356 }
357 );
358 let child = ty.subtypes[0] as usize;
359 let child = Box::new(Self::from_proto(types, child)?);
360 Self::List {
361 column_index,
362 child,
363 }
364 }
365 Kind::Map => {
366 ensure!(
367 ty.subtypes.len() == 2,
368 UnexpectedSnafu {
369 msg: format!(
370 "Map type for column index {} must have 2 sub types, found {}",
371 column_index,
372 ty.subtypes.len()
373 )
374 }
375 );
376 let key = ty.subtypes[0] as usize;
377 let key = Box::new(Self::from_proto(types, key)?);
378 let value = ty.subtypes[1] as usize;
379 let value = Box::new(Self::from_proto(types, value)?);
380 Self::Map {
381 column_index,
382 key,
383 value,
384 }
385 }
386 Kind::Struct => {
387 let children = parse_struct_children_from_proto(types, column_index)?;
388 Self::Struct {
389 column_index,
390 children,
391 }
392 }
393 Kind::Union => {
394 ensure!(
396 ty.subtypes.len() <= 127,
397 UnexpectedSnafu {
398 msg: format!(
399 "Union type for column index {} cannot exceed 127 variants, found {}",
400 column_index,
401 ty.subtypes.len()
402 )
403 }
404 );
405 let variants = ty
406 .subtypes
407 .iter()
408 .map(|&index| {
409 let index = index as usize;
410 Self::from_proto(types, index)
411 })
412 .collect::<Result<Vec<_>>>()?;
413 Self::Union {
414 column_index,
415 variants,
416 }
417 }
418 Kind::Decimal => Self::Decimal {
419 column_index,
420 precision: ty.precision(),
421 scale: ty.scale(),
422 },
423 Kind::Date => Self::Date { column_index },
424 Kind::Varchar => Self::Varchar {
425 column_index,
426 max_length: ty.maximum_length(),
427 },
428 Kind::Char => Self::Char {
429 column_index,
430 max_length: ty.maximum_length(),
431 },
432 Kind::TimestampInstant => Self::TimestampWithLocalTimezone { column_index },
433 };
434 Ok(dt)
435 }
436
437 pub fn to_arrow_data_type(&self) -> ArrowDataType {
438 match self {
439 DataType::Boolean { .. } => ArrowDataType::Boolean,
440 DataType::Byte { .. } => ArrowDataType::Int8,
441 DataType::Short { .. } => ArrowDataType::Int16,
442 DataType::Int { .. } => ArrowDataType::Int32,
443 DataType::Long { .. } => ArrowDataType::Int64,
444 DataType::Float { .. } => ArrowDataType::Float32,
445 DataType::Double { .. } => ArrowDataType::Float64,
446 DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
447 ArrowDataType::Utf8
448 }
449 DataType::Binary { .. } => ArrowDataType::Binary,
450 DataType::Decimal {
451 precision, scale, ..
452 } => ArrowDataType::Decimal128(*precision as u8, *scale as i8), DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
454 DataType::TimestampWithLocalTimezone { .. } => {
455 ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
456 }
457 DataType::Date { .. } => ArrowDataType::Date32,
458 DataType::Struct { children, .. } => {
459 let children = children
460 .iter()
461 .map(|col| {
462 let dt = col.data_type().to_arrow_data_type();
463 Field::new(col.name(), dt, true)
464 })
465 .collect();
466 ArrowDataType::Struct(children)
467 }
468 DataType::List { child, .. } => {
469 let child = child.to_arrow_data_type();
470 ArrowDataType::new_list(child, true)
471 }
472 DataType::Map { key, value, .. } => {
473 let key = key.to_arrow_data_type();
478 let key = Field::new("keys", key, false);
479 let value = value.to_arrow_data_type();
480 let value = Field::new("values", value, true);
481
482 let dt = ArrowDataType::Struct(vec![key, value].into());
483 let dt = Arc::new(Field::new("entries", dt, false));
484 ArrowDataType::Map(dt, false)
485 }
486 DataType::Union { variants, .. } => {
487 let fields = variants
488 .iter()
489 .enumerate()
490 .map(|(index, variant)| {
491 let index = index as u8 as i8;
495 let arrow_dt = variant.to_arrow_data_type();
496 let field = Arc::new(Field::new(format!("_union_{index}"), arrow_dt, true));
499 (index, field)
500 })
501 .collect();
502 ArrowDataType::Union(fields, UnionMode::Sparse)
503 }
504 }
505 }
506}
507
508impl Display for DataType {
509 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
510 match self {
511 DataType::Boolean { column_index: _ } => write!(f, "BOOLEAN"),
512 DataType::Byte { column_index: _ } => write!(f, "BYTE"),
513 DataType::Short { column_index: _ } => write!(f, "SHORT"),
514 DataType::Int { column_index: _ } => write!(f, "INTEGER"),
515 DataType::Long { column_index: _ } => write!(f, "LONG"),
516 DataType::Float { column_index: _ } => write!(f, "FLOAT"),
517 DataType::Double { column_index: _ } => write!(f, "DOUBLE"),
518 DataType::String { column_index: _ } => write!(f, "STRING"),
519 DataType::Varchar {
520 column_index: _,
521 max_length,
522 } => write!(f, "VARCHAR({max_length})"),
523 DataType::Char {
524 column_index: _,
525 max_length,
526 } => write!(f, "CHAR({max_length})"),
527 DataType::Binary { column_index: _ } => write!(f, "BINARY"),
528 DataType::Decimal {
529 column_index: _,
530 precision,
531 scale,
532 } => write!(f, "DECIMAL({precision}, {scale})"),
533 DataType::Timestamp { column_index: _ } => write!(f, "TIMESTAMP"),
534 DataType::TimestampWithLocalTimezone { column_index: _ } => {
535 write!(f, "TIMESTAMP INSTANT")
536 }
537 DataType::Date { column_index: _ } => write!(f, "DATE"),
538 DataType::Struct {
539 column_index: _,
540 children,
541 } => {
542 write!(f, "STRUCT")?;
543 for child in children {
544 write!(f, "\n {child}")?;
545 }
546 Ok(())
547 }
548 DataType::List {
549 column_index: _,
550 child,
551 } => write!(f, "LIST\n {child}"),
552 DataType::Map {
553 column_index: _,
554 key,
555 value,
556 } => write!(f, "MAP\n {key}\n {value}"),
557 DataType::Union {
558 column_index: _,
559 variants,
560 } => {
561 write!(f, "UNION")?;
562 for variant in variants {
563 write!(f, "\n {variant}")?;
564 }
565 Ok(())
566 }
567 }
568 }
569}