1use std::collections::HashMap;
19use std::fmt::Display;
20use std::sync::Arc;
21
22use snafu::{ensure, OptionExt};
23
24use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu};
25use crate::projection::ProjectionMask;
26use crate::proto;
27
28use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode};
29
30#[derive(Debug, Clone)]
42pub struct RootDataType {
43 children: Vec<NamedColumn>,
44}
45
46impl RootDataType {
47 pub fn column_index(&self) -> usize {
49 0
50 }
51
52 pub fn children(&self) -> &[NamedColumn] {
54 &self.children
55 }
56
57 pub fn create_arrow_schema(&self, user_metadata: &HashMap<String, String>) -> Schema {
59 let fields = self
60 .children
61 .iter()
62 .map(|col| {
63 let dt = col.data_type().to_arrow_data_type();
64 Field::new(col.name(), dt, true)
65 })
66 .collect::<Vec<_>>();
67 Schema::new_with_metadata(fields, user_metadata.clone())
68 }
69
70 pub fn project(&self, mask: &ProjectionMask) -> Self {
72 let children = self
74 .children
75 .iter()
76 .filter(|col| mask.is_index_projected(col.data_type().column_index()))
77 .map(|col| col.to_owned())
78 .collect::<Vec<_>>();
79 Self { children }
80 }
81
82 pub(crate) fn from_proto(types: &[proto::Type]) -> Result<Self> {
84 ensure!(!types.is_empty(), NoTypesSnafu {});
85 let children = parse_struct_children_from_proto(types, 0)?;
86 Ok(Self { children })
87 }
88}
89
90impl Display for RootDataType {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "ROOT")?;
93 for child in &self.children {
94 write!(f, "\n {child}")?;
95 }
96 Ok(())
97 }
98}
99
100#[derive(Debug, Clone)]
101pub struct NamedColumn {
102 name: String,
103 data_type: DataType,
104}
105
106impl NamedColumn {
107 pub fn name(&self) -> &str {
108 self.name.as_str()
109 }
110
111 pub fn data_type(&self) -> &DataType {
112 &self.data_type
113 }
114}
115
116impl Display for NamedColumn {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(f, "{} {}", self.name(), self.data_type())
119 }
120}
121
122fn parse_struct_children_from_proto(
125 types: &[proto::Type],
126 column_index: usize,
127) -> Result<Vec<NamedColumn>> {
128 assert!(column_index < types.len());
130 let ty = &types[column_index];
131 assert!(ty.kind() == proto::r#type::Kind::Struct);
132 ensure!(
133 ty.subtypes.len() == ty.field_names.len(),
134 UnexpectedSnafu {
135 msg: format!(
136 "Struct type for column index {} must have matching lengths for subtypes and field names lists",
137 column_index,
138 )
139 }
140 );
141 let children = ty
142 .subtypes
143 .iter()
144 .zip(ty.field_names.iter())
145 .map(|(&index, name)| {
146 let index = index as usize;
147 let name = name.to_owned();
148 let data_type = DataType::from_proto(types, index)?;
149 Ok(NamedColumn { name, data_type })
150 })
151 .collect::<Result<Vec<_>>>()?;
152 Ok(children)
153}
154
155#[derive(Debug, Clone)]
160pub enum DataType {
161 Boolean { column_index: usize },
163 Byte { column_index: usize },
165 Short { column_index: usize },
167 Int { column_index: usize },
169 Long { column_index: usize },
171 Float { column_index: usize },
173 Double { column_index: usize },
175 String { column_index: usize },
177 Varchar {
179 column_index: usize,
180 max_length: u32,
181 },
182 Char {
184 column_index: usize,
185 max_length: u32,
186 },
187 Binary { column_index: usize },
189 Decimal {
191 column_index: usize,
192 precision: u32,
194 scale: u32,
195 },
196 Timestamp { column_index: usize },
202 TimestampWithLocalTimezone { column_index: usize },
208 Date { column_index: usize },
211 Struct {
214 column_index: usize,
215 children: Vec<NamedColumn>,
216 },
217 List {
220 column_index: usize,
221 child: Box<DataType>,
222 },
223 Map {
226 column_index: usize,
227 key: Box<DataType>,
228 value: Box<DataType>,
229 },
230 Union {
236 column_index: usize,
237 variants: Vec<DataType>,
238 },
239}
240
241impl DataType {
242 pub fn column_index(&self) -> usize {
245 match self {
246 DataType::Boolean { column_index } => *column_index,
247 DataType::Byte { column_index } => *column_index,
248 DataType::Short { column_index } => *column_index,
249 DataType::Int { column_index } => *column_index,
250 DataType::Long { column_index } => *column_index,
251 DataType::Float { column_index } => *column_index,
252 DataType::Double { column_index } => *column_index,
253 DataType::String { column_index } => *column_index,
254 DataType::Varchar { column_index, .. } => *column_index,
255 DataType::Char { column_index, .. } => *column_index,
256 DataType::Binary { column_index } => *column_index,
257 DataType::Decimal { column_index, .. } => *column_index,
258 DataType::Timestamp { column_index } => *column_index,
259 DataType::TimestampWithLocalTimezone { column_index } => *column_index,
260 DataType::Date { column_index } => *column_index,
261 DataType::Struct { column_index, .. } => *column_index,
262 DataType::List { column_index, .. } => *column_index,
263 DataType::Map { column_index, .. } => *column_index,
264 DataType::Union { column_index, .. } => *column_index,
265 }
266 }
267
268 pub fn children_indices(&self) -> Vec<usize> {
270 match self {
271 DataType::Boolean { .. }
272 | DataType::Byte { .. }
273 | DataType::Short { .. }
274 | DataType::Int { .. }
275 | DataType::Long { .. }
276 | DataType::Float { .. }
277 | DataType::Double { .. }
278 | DataType::String { .. }
279 | DataType::Varchar { .. }
280 | DataType::Char { .. }
281 | DataType::Binary { .. }
282 | DataType::Decimal { .. }
283 | DataType::Timestamp { .. }
284 | DataType::TimestampWithLocalTimezone { .. }
285 | DataType::Date { .. } => vec![],
286 DataType::Struct { children, .. } => children
287 .iter()
288 .flat_map(|col| col.data_type().children_indices())
289 .collect(),
290 DataType::List { child, .. } => child.all_indices(),
291 DataType::Map { key, value, .. } => {
292 let mut indices = key.children_indices();
293 indices.extend(value.children_indices());
294 indices
295 }
296 DataType::Union { variants, .. } => variants
297 .iter()
298 .flat_map(|dt| dt.children_indices())
299 .collect(),
300 }
301 }
302
303 pub fn all_indices(&self) -> Vec<usize> {
305 let mut indices = vec![self.column_index()];
306 indices.extend(self.children_indices());
307 indices
308 }
309
310 fn from_proto(types: &[proto::Type], column_index: usize) -> Result<Self> {
311 use proto::r#type::Kind;
312
313 let ty = types.get(column_index).context(UnexpectedSnafu {
314 msg: format!("Column index out of bounds: {column_index}"),
315 })?;
316 let dt = match ty.kind() {
317 Kind::Boolean => Self::Boolean { column_index },
318 Kind::Byte => Self::Byte { column_index },
319 Kind::Short => Self::Short { column_index },
320 Kind::Int => Self::Int { column_index },
321 Kind::Long => Self::Long { column_index },
322 Kind::Float => Self::Float { column_index },
323 Kind::Double => Self::Double { column_index },
324 Kind::String => Self::String { column_index },
325 Kind::Binary => Self::Binary { column_index },
326 Kind::Timestamp => Self::Timestamp { column_index },
327 Kind::List => {
328 ensure!(
329 ty.subtypes.len() == 1,
330 UnexpectedSnafu {
331 msg: format!(
332 "List type for column index {} must have 1 sub type, found {}",
333 column_index,
334 ty.subtypes.len()
335 )
336 }
337 );
338 let child = ty.subtypes[0] as usize;
339 let child = Box::new(Self::from_proto(types, child)?);
340 Self::List {
341 column_index,
342 child,
343 }
344 }
345 Kind::Map => {
346 ensure!(
347 ty.subtypes.len() == 2,
348 UnexpectedSnafu {
349 msg: format!(
350 "Map type for column index {} must have 2 sub types, found {}",
351 column_index,
352 ty.subtypes.len()
353 )
354 }
355 );
356 let key = ty.subtypes[0] as usize;
357 let key = Box::new(Self::from_proto(types, key)?);
358 let value = ty.subtypes[1] as usize;
359 let value = Box::new(Self::from_proto(types, value)?);
360 Self::Map {
361 column_index,
362 key,
363 value,
364 }
365 }
366 Kind::Struct => {
367 let children = parse_struct_children_from_proto(types, column_index)?;
368 Self::Struct {
369 column_index,
370 children,
371 }
372 }
373 Kind::Union => {
374 ensure!(
376 ty.subtypes.len() <= 127,
377 UnexpectedSnafu {
378 msg: format!(
379 "Union type for column index {} cannot exceed 127 variants, found {}",
380 column_index,
381 ty.subtypes.len()
382 )
383 }
384 );
385 let variants = ty
386 .subtypes
387 .iter()
388 .map(|&index| {
389 let index = index as usize;
390 Self::from_proto(types, index)
391 })
392 .collect::<Result<Vec<_>>>()?;
393 Self::Union {
394 column_index,
395 variants,
396 }
397 }
398 Kind::Decimal => Self::Decimal {
399 column_index,
400 precision: ty.precision(),
401 scale: ty.scale(),
402 },
403 Kind::Date => Self::Date { column_index },
404 Kind::Varchar => Self::Varchar {
405 column_index,
406 max_length: ty.maximum_length(),
407 },
408 Kind::Char => Self::Char {
409 column_index,
410 max_length: ty.maximum_length(),
411 },
412 Kind::TimestampInstant => Self::TimestampWithLocalTimezone { column_index },
413 };
414 Ok(dt)
415 }
416
417 pub fn to_arrow_data_type(&self) -> ArrowDataType {
418 match self {
419 DataType::Boolean { .. } => ArrowDataType::Boolean,
420 DataType::Byte { .. } => ArrowDataType::Int8,
421 DataType::Short { .. } => ArrowDataType::Int16,
422 DataType::Int { .. } => ArrowDataType::Int32,
423 DataType::Long { .. } => ArrowDataType::Int64,
424 DataType::Float { .. } => ArrowDataType::Float32,
425 DataType::Double { .. } => ArrowDataType::Float64,
426 DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
427 ArrowDataType::Utf8
428 }
429 DataType::Binary { .. } => ArrowDataType::Binary,
430 DataType::Decimal {
431 precision, scale, ..
432 } => ArrowDataType::Decimal128(*precision as u8, *scale as i8), DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
434 DataType::TimestampWithLocalTimezone { .. } => {
435 ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
436 }
437 DataType::Date { .. } => ArrowDataType::Date32,
438 DataType::Struct { children, .. } => {
439 let children = children
440 .iter()
441 .map(|col| {
442 let dt = col.data_type().to_arrow_data_type();
443 Field::new(col.name(), dt, true)
444 })
445 .collect();
446 ArrowDataType::Struct(children)
447 }
448 DataType::List { child, .. } => {
449 let child = child.to_arrow_data_type();
450 ArrowDataType::new_list(child, true)
451 }
452 DataType::Map { key, value, .. } => {
453 let key = key.to_arrow_data_type();
458 let key = Field::new("keys", key, false);
459 let value = value.to_arrow_data_type();
460 let value = Field::new("values", value, true);
461
462 let dt = ArrowDataType::Struct(vec![key, value].into());
463 let dt = Arc::new(Field::new("entries", dt, false));
464 ArrowDataType::Map(dt, false)
465 }
466 DataType::Union { variants, .. } => {
467 let fields = variants
468 .iter()
469 .enumerate()
470 .map(|(index, variant)| {
471 let index = index as u8 as i8;
475 let arrow_dt = variant.to_arrow_data_type();
476 let field = Arc::new(Field::new(format!("_union_{index}"), arrow_dt, true));
479 (index, field)
480 })
481 .collect();
482 ArrowDataType::Union(fields, UnionMode::Sparse)
483 }
484 }
485 }
486}
487
488impl Display for DataType {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 match self {
491 DataType::Boolean { column_index: _ } => write!(f, "BOOLEAN"),
492 DataType::Byte { column_index: _ } => write!(f, "BYTE"),
493 DataType::Short { column_index: _ } => write!(f, "SHORT"),
494 DataType::Int { column_index: _ } => write!(f, "INTEGER"),
495 DataType::Long { column_index: _ } => write!(f, "LONG"),
496 DataType::Float { column_index: _ } => write!(f, "FLOAT"),
497 DataType::Double { column_index: _ } => write!(f, "DOUBLE"),
498 DataType::String { column_index: _ } => write!(f, "STRING"),
499 DataType::Varchar {
500 column_index: _,
501 max_length,
502 } => write!(f, "VARCHAR({max_length})"),
503 DataType::Char {
504 column_index: _,
505 max_length,
506 } => write!(f, "CHAR({max_length})"),
507 DataType::Binary { column_index: _ } => write!(f, "BINARY"),
508 DataType::Decimal {
509 column_index: _,
510 precision,
511 scale,
512 } => write!(f, "DECIMAL({precision}, {scale})"),
513 DataType::Timestamp { column_index: _ } => write!(f, "TIMESTAMP"),
514 DataType::TimestampWithLocalTimezone { column_index: _ } => {
515 write!(f, "TIMESTAMP INSTANT")
516 }
517 DataType::Date { column_index: _ } => write!(f, "DATE"),
518 DataType::Struct {
519 column_index: _,
520 children,
521 } => {
522 write!(f, "STRUCT")?;
523 for child in children {
524 write!(f, "\n {child}")?;
525 }
526 Ok(())
527 }
528 DataType::List {
529 column_index: _,
530 child,
531 } => write!(f, "LIST\n {child}"),
532 DataType::Map {
533 column_index: _,
534 key,
535 value,
536 } => write!(f, "MAP\n {key}\n {value}"),
537 DataType::Union {
538 column_index: _,
539 variants,
540 } => {
541 write!(f, "UNION")?;
542 for variant in variants {
543 write!(f, "\n {variant}")?;
544 }
545 Ok(())
546 }
547 }
548 }
549}