1use std::collections::HashMap;
19use std::fmt::Display;
20use std::sync::Arc;
21
22use snafu::{ensure, OptionExt};
23
24use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu};
25use crate::projection::ProjectionMask;
26use crate::proto;
27
28use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode};
29
30#[derive(Debug, Clone)]
42pub struct RootDataType {
43 children: Vec<NamedColumn>,
44}
45
46impl RootDataType {
47 pub fn column_index(&self) -> usize {
49 0
50 }
51
52 pub fn children(&self) -> &[NamedColumn] {
54 &self.children
55 }
56
57 pub fn create_arrow_schema(&self, user_metadata: &HashMap<String, String>) -> Schema {
59 let fields = self
60 .children
61 .iter()
62 .map(|col| {
63 let dt = col.data_type().to_arrow_data_type();
64 Field::new(col.name(), dt, true)
65 })
66 .collect::<Vec<_>>();
67 Schema::new_with_metadata(fields, user_metadata.clone())
68 }
69
70 pub fn project(&self, mask: &ProjectionMask) -> Self {
72 let children = self
74 .children
75 .iter()
76 .filter(|col| mask.is_index_projected(col.data_type().column_index()))
77 .map(|col| col.to_owned())
78 .collect::<Vec<_>>();
79 Self { children }
80 }
81
82 pub(crate) fn from_proto(types: &[proto::Type]) -> Result<Self> {
84 ensure!(!types.is_empty(), NoTypesSnafu {});
85 let children = parse_struct_children_from_proto(types, 0)?;
86 Ok(Self { children })
87 }
88}
89
90impl Display for RootDataType {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "ROOT")?;
93 for child in &self.children {
94 write!(f, "\n {child}")?;
95 }
96 Ok(())
97 }
98}
99
100#[derive(Debug, Clone)]
101pub struct NamedColumn {
102 name: String,
103 data_type: DataType,
104}
105
106impl NamedColumn {
107 pub fn name(&self) -> &str {
108 self.name.as_str()
109 }
110
111 pub fn data_type(&self) -> &DataType {
112 &self.data_type
113 }
114}
115
116impl Display for NamedColumn {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(f, "{} {}", self.name(), self.data_type())
119 }
120}
121
122fn parse_struct_children_from_proto(
125 types: &[proto::Type],
126 column_index: usize,
127) -> Result<Vec<NamedColumn>> {
128 assert!(column_index < types.len());
130 let ty = &types[column_index];
131 assert!(ty.kind() == proto::r#type::Kind::Struct);
132 ensure!(
133 ty.subtypes.len() == ty.field_names.len(),
134 UnexpectedSnafu {
135 msg: format!(
136 "Struct type for column index {column_index} must have matching lengths for subtypes and field names lists"
137 )
138 }
139 );
140 let children = ty
141 .subtypes
142 .iter()
143 .zip(ty.field_names.iter())
144 .map(|(&index, name)| {
145 let index = index as usize;
146 let name = name.to_owned();
147 let data_type = DataType::from_proto(types, index)?;
148 Ok(NamedColumn { name, data_type })
149 })
150 .collect::<Result<Vec<_>>>()?;
151 Ok(children)
152}
153
154#[derive(Debug, Clone)]
159pub enum DataType {
160 Boolean { column_index: usize },
162 Byte { column_index: usize },
164 Short { column_index: usize },
166 Int { column_index: usize },
168 Long { column_index: usize },
170 Float { column_index: usize },
172 Double { column_index: usize },
174 String { column_index: usize },
176 Varchar {
178 column_index: usize,
179 max_length: u32,
180 },
181 Char {
183 column_index: usize,
184 max_length: u32,
185 },
186 Binary { column_index: usize },
188 Decimal {
190 column_index: usize,
191 precision: u32,
193 scale: u32,
194 },
195 Timestamp { column_index: usize },
201 TimestampWithLocalTimezone { column_index: usize },
207 Date { column_index: usize },
210 Struct {
213 column_index: usize,
214 children: Vec<NamedColumn>,
215 },
216 List {
219 column_index: usize,
220 child: Box<DataType>,
221 },
222 Map {
225 column_index: usize,
226 key: Box<DataType>,
227 value: Box<DataType>,
228 },
229 Union {
235 column_index: usize,
236 variants: Vec<DataType>,
237 },
238}
239
240impl DataType {
241 pub fn column_index(&self) -> usize {
244 match self {
245 DataType::Boolean { column_index } => *column_index,
246 DataType::Byte { column_index } => *column_index,
247 DataType::Short { column_index } => *column_index,
248 DataType::Int { column_index } => *column_index,
249 DataType::Long { column_index } => *column_index,
250 DataType::Float { column_index } => *column_index,
251 DataType::Double { column_index } => *column_index,
252 DataType::String { column_index } => *column_index,
253 DataType::Varchar { column_index, .. } => *column_index,
254 DataType::Char { column_index, .. } => *column_index,
255 DataType::Binary { column_index } => *column_index,
256 DataType::Decimal { column_index, .. } => *column_index,
257 DataType::Timestamp { column_index } => *column_index,
258 DataType::TimestampWithLocalTimezone { column_index } => *column_index,
259 DataType::Date { column_index } => *column_index,
260 DataType::Struct { column_index, .. } => *column_index,
261 DataType::List { column_index, .. } => *column_index,
262 DataType::Map { column_index, .. } => *column_index,
263 DataType::Union { column_index, .. } => *column_index,
264 }
265 }
266
267 pub fn children_indices(&self) -> Vec<usize> {
269 match self {
270 DataType::Boolean { .. }
271 | DataType::Byte { .. }
272 | DataType::Short { .. }
273 | DataType::Int { .. }
274 | DataType::Long { .. }
275 | DataType::Float { .. }
276 | DataType::Double { .. }
277 | DataType::String { .. }
278 | DataType::Varchar { .. }
279 | DataType::Char { .. }
280 | DataType::Binary { .. }
281 | DataType::Decimal { .. }
282 | DataType::Timestamp { .. }
283 | DataType::TimestampWithLocalTimezone { .. }
284 | DataType::Date { .. } => vec![],
285 DataType::Struct { children, .. } => children
286 .iter()
287 .flat_map(|col| col.data_type().children_indices())
288 .collect(),
289 DataType::List { child, .. } => child.all_indices(),
290 DataType::Map { key, value, .. } => {
291 let mut indices = key.children_indices();
292 indices.extend(value.children_indices());
293 indices
294 }
295 DataType::Union { variants, .. } => variants
296 .iter()
297 .flat_map(|dt| dt.children_indices())
298 .collect(),
299 }
300 }
301
302 pub fn all_indices(&self) -> Vec<usize> {
304 let mut indices = vec![self.column_index()];
305 indices.extend(self.children_indices());
306 indices
307 }
308
309 fn from_proto(types: &[proto::Type], column_index: usize) -> Result<Self> {
310 use proto::r#type::Kind;
311
312 let ty = types.get(column_index).context(UnexpectedSnafu {
313 msg: format!("Column index out of bounds: {column_index}"),
314 })?;
315 let dt = match ty.kind() {
316 Kind::Boolean => Self::Boolean { column_index },
317 Kind::Byte => Self::Byte { column_index },
318 Kind::Short => Self::Short { column_index },
319 Kind::Int => Self::Int { column_index },
320 Kind::Long => Self::Long { column_index },
321 Kind::Float => Self::Float { column_index },
322 Kind::Double => Self::Double { column_index },
323 Kind::String => Self::String { column_index },
324 Kind::Binary => Self::Binary { column_index },
325 Kind::Timestamp => Self::Timestamp { column_index },
326 Kind::List => {
327 ensure!(
328 ty.subtypes.len() == 1,
329 UnexpectedSnafu {
330 msg: format!(
331 "List type for column index {} must have 1 sub type, found {}",
332 column_index,
333 ty.subtypes.len()
334 )
335 }
336 );
337 let child = ty.subtypes[0] as usize;
338 let child = Box::new(Self::from_proto(types, child)?);
339 Self::List {
340 column_index,
341 child,
342 }
343 }
344 Kind::Map => {
345 ensure!(
346 ty.subtypes.len() == 2,
347 UnexpectedSnafu {
348 msg: format!(
349 "Map type for column index {} must have 2 sub types, found {}",
350 column_index,
351 ty.subtypes.len()
352 )
353 }
354 );
355 let key = ty.subtypes[0] as usize;
356 let key = Box::new(Self::from_proto(types, key)?);
357 let value = ty.subtypes[1] as usize;
358 let value = Box::new(Self::from_proto(types, value)?);
359 Self::Map {
360 column_index,
361 key,
362 value,
363 }
364 }
365 Kind::Struct => {
366 let children = parse_struct_children_from_proto(types, column_index)?;
367 Self::Struct {
368 column_index,
369 children,
370 }
371 }
372 Kind::Union => {
373 ensure!(
375 ty.subtypes.len() <= 127,
376 UnexpectedSnafu {
377 msg: format!(
378 "Union type for column index {} cannot exceed 127 variants, found {}",
379 column_index,
380 ty.subtypes.len()
381 )
382 }
383 );
384 let variants = ty
385 .subtypes
386 .iter()
387 .map(|&index| {
388 let index = index as usize;
389 Self::from_proto(types, index)
390 })
391 .collect::<Result<Vec<_>>>()?;
392 Self::Union {
393 column_index,
394 variants,
395 }
396 }
397 Kind::Decimal => Self::Decimal {
398 column_index,
399 precision: ty.precision(),
400 scale: ty.scale(),
401 },
402 Kind::Date => Self::Date { column_index },
403 Kind::Varchar => Self::Varchar {
404 column_index,
405 max_length: ty.maximum_length(),
406 },
407 Kind::Char => Self::Char {
408 column_index,
409 max_length: ty.maximum_length(),
410 },
411 Kind::TimestampInstant => Self::TimestampWithLocalTimezone { column_index },
412 };
413 Ok(dt)
414 }
415
416 pub fn to_arrow_data_type(&self) -> ArrowDataType {
417 match self {
418 DataType::Boolean { .. } => ArrowDataType::Boolean,
419 DataType::Byte { .. } => ArrowDataType::Int8,
420 DataType::Short { .. } => ArrowDataType::Int16,
421 DataType::Int { .. } => ArrowDataType::Int32,
422 DataType::Long { .. } => ArrowDataType::Int64,
423 DataType::Float { .. } => ArrowDataType::Float32,
424 DataType::Double { .. } => ArrowDataType::Float64,
425 DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
426 ArrowDataType::Utf8
427 }
428 DataType::Binary { .. } => ArrowDataType::Binary,
429 DataType::Decimal {
430 precision, scale, ..
431 } => ArrowDataType::Decimal128(*precision as u8, *scale as i8), DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
433 DataType::TimestampWithLocalTimezone { .. } => {
434 ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
435 }
436 DataType::Date { .. } => ArrowDataType::Date32,
437 DataType::Struct { children, .. } => {
438 let children = children
439 .iter()
440 .map(|col| {
441 let dt = col.data_type().to_arrow_data_type();
442 Field::new(col.name(), dt, true)
443 })
444 .collect();
445 ArrowDataType::Struct(children)
446 }
447 DataType::List { child, .. } => {
448 let child = child.to_arrow_data_type();
449 ArrowDataType::new_list(child, true)
450 }
451 DataType::Map { key, value, .. } => {
452 let key = key.to_arrow_data_type();
457 let key = Field::new("keys", key, false);
458 let value = value.to_arrow_data_type();
459 let value = Field::new("values", value, true);
460
461 let dt = ArrowDataType::Struct(vec![key, value].into());
462 let dt = Arc::new(Field::new("entries", dt, false));
463 ArrowDataType::Map(dt, false)
464 }
465 DataType::Union { variants, .. } => {
466 let fields = variants
467 .iter()
468 .enumerate()
469 .map(|(index, variant)| {
470 let index = index as u8 as i8;
474 let arrow_dt = variant.to_arrow_data_type();
475 let field = Arc::new(Field::new(format!("_union_{index}"), arrow_dt, true));
478 (index, field)
479 })
480 .collect();
481 ArrowDataType::Union(fields, UnionMode::Sparse)
482 }
483 }
484 }
485}
486
487impl Display for DataType {
488 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489 match self {
490 DataType::Boolean { column_index: _ } => write!(f, "BOOLEAN"),
491 DataType::Byte { column_index: _ } => write!(f, "BYTE"),
492 DataType::Short { column_index: _ } => write!(f, "SHORT"),
493 DataType::Int { column_index: _ } => write!(f, "INTEGER"),
494 DataType::Long { column_index: _ } => write!(f, "LONG"),
495 DataType::Float { column_index: _ } => write!(f, "FLOAT"),
496 DataType::Double { column_index: _ } => write!(f, "DOUBLE"),
497 DataType::String { column_index: _ } => write!(f, "STRING"),
498 DataType::Varchar {
499 column_index: _,
500 max_length,
501 } => write!(f, "VARCHAR({max_length})"),
502 DataType::Char {
503 column_index: _,
504 max_length,
505 } => write!(f, "CHAR({max_length})"),
506 DataType::Binary { column_index: _ } => write!(f, "BINARY"),
507 DataType::Decimal {
508 column_index: _,
509 precision,
510 scale,
511 } => write!(f, "DECIMAL({precision}, {scale})"),
512 DataType::Timestamp { column_index: _ } => write!(f, "TIMESTAMP"),
513 DataType::TimestampWithLocalTimezone { column_index: _ } => {
514 write!(f, "TIMESTAMP INSTANT")
515 }
516 DataType::Date { column_index: _ } => write!(f, "DATE"),
517 DataType::Struct {
518 column_index: _,
519 children,
520 } => {
521 write!(f, "STRUCT")?;
522 for child in children {
523 write!(f, "\n {child}")?;
524 }
525 Ok(())
526 }
527 DataType::List {
528 column_index: _,
529 child,
530 } => write!(f, "LIST\n {child}"),
531 DataType::Map {
532 column_index: _,
533 key,
534 value,
535 } => write!(f, "MAP\n {key}\n {value}"),
536 DataType::Union {
537 column_index: _,
538 variants,
539 } => {
540 write!(f, "UNION")?;
541 for variant in variants {
542 write!(f, "\n {variant}")?;
543 }
544 Ok(())
545 }
546 }
547 }
548}