1use std::sync::Arc;
16
17use crate::SchemaField as APISchemaField;
18
19use crate::error::{Error, Result};
20
21use arrow_schema::{DataType as ArrowDataType, Field as ArrowField, SchemaRef as ArrowSchemaRef};
22
/// Arrow field-metadata key whose value names the extension type of the field.
pub const EXTENSION_KEY: &str = "Extension";

// Extension type names recognised under `EXTENSION_KEY` when converting Arrow
// fields; each one maps to the `DataType` variant noted on it.

/// Extension name mapped to [`DataType::EmptyArray`].
pub const ARROW_EXT_TYPE_EMPTY_ARRAY: &str = "EmptyArray";
/// Extension name mapped to [`DataType::EmptyMap`].
pub const ARROW_EXT_TYPE_EMPTY_MAP: &str = "EmptyMap";
/// Extension name mapped to [`DataType::Variant`].
pub const ARROW_EXT_TYPE_VARIANT: &str = "Variant";
/// Extension name mapped to [`DataType::Bitmap`].
pub const ARROW_EXT_TYPE_BITMAP: &str = "Bitmap";
/// Extension name mapped to [`DataType::Geometry`].
pub const ARROW_EXT_TYPE_GEOMETRY: &str = "Geometry";
/// Extension name mapped to [`DataType::Geography`].
pub const ARROW_EXT_TYPE_GEOGRAPHY: &str = "Geography";
/// Extension name mapped to [`DataType::Interval`].
pub const ARROW_EXT_TYPE_INTERVAL: &str = "Interval";
/// Extension name mapped to [`DataType::Vector`] (stored as a `FixedSizeList`).
pub const ARROW_EXT_TYPE_VECTOR: &str = "Vector";
/// Extension name mapped to [`DataType::TimestampTz`].
pub const ARROW_EXT_TYPE_TIMESTAMP_TIMEZONE: &str = "TimestampTz";
34
/// Primitive numeric column types (fixed-width integers and IEEE floats).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NumberDataType {
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// 32-bit floating point.
    Float32,
    /// 64-bit floating point.
    Float64,
}
48
/// Precision (total digits) and scale (fractional digits) of a decimal type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DecimalSize {
    /// Total number of significant decimal digits.
    pub precision: u8,
    /// Number of digits after the decimal point.
    pub scale: u8,
}

/// Decimal type tagged with the integer width used to store it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecimalDataType {
    /// Stored in a 64-bit integer.
    Decimal64(DecimalSize),
    /// Stored in a 128-bit integer.
    Decimal128(DecimalSize),
    /// Stored in a 256-bit integer.
    Decimal256(DecimalSize),
}

impl DecimalDataType {
    /// Returns the precision/scale pair, whatever the storage width.
    pub fn decimal_size(&self) -> &DecimalSize {
        match self {
            Self::Decimal64(size) | Self::Decimal128(size) | Self::Decimal256(size) => size,
        }
    }
}
71
72#[derive(Debug, Clone)]
73pub enum DataType {
74 Null,
75 EmptyArray,
76 EmptyMap,
77 Boolean,
78 Binary,
79 String,
80 Number(NumberDataType),
81 Decimal(DecimalDataType),
82 Timestamp,
83 TimestampTz,
84 Date,
85 Nullable(Box<DataType>),
86 Array(Box<DataType>),
87 Map(Box<DataType>),
88 Tuple(Vec<DataType>),
89 Variant,
90 Bitmap,
91 Geometry,
92 Geography,
93 Interval,
94 Vector(u64),
95 }
97
98impl DataType {
99 pub fn is_numeric(&self) -> bool {
100 match self {
101 DataType::Number(_) | DataType::Decimal(_) => true,
102 DataType::Nullable(inner) => inner.is_numeric(),
103 _ => false,
104 }
105 }
106}
107
108impl std::fmt::Display for DataType {
109 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
110 match self {
111 DataType::Null => write!(f, "Null"),
112 DataType::EmptyArray => write!(f, "EmptyArray"),
113 DataType::EmptyMap => write!(f, "EmptyMap"),
114 DataType::Boolean => write!(f, "Boolean"),
115 DataType::Binary => write!(f, "Binary"),
116 DataType::String => write!(f, "String"),
117 DataType::Number(n) => match n {
118 NumberDataType::UInt8 => write!(f, "UInt8"),
119 NumberDataType::UInt16 => write!(f, "UInt16"),
120 NumberDataType::UInt32 => write!(f, "UInt32"),
121 NumberDataType::UInt64 => write!(f, "UInt64"),
122 NumberDataType::Int8 => write!(f, "Int8"),
123 NumberDataType::Int16 => write!(f, "Int16"),
124 NumberDataType::Int32 => write!(f, "Int32"),
125 NumberDataType::Int64 => write!(f, "Int64"),
126 NumberDataType::Float32 => write!(f, "Float32"),
127 NumberDataType::Float64 => write!(f, "Float64"),
128 },
129 DataType::Decimal(d) => {
130 let size = d.decimal_size();
131 write!(f, "Decimal({}, {})", size.precision, size.scale)
132 }
133 DataType::Timestamp => write!(f, "Timestamp"),
134 DataType::TimestampTz => write!(f, "Timestamp_Tz"),
135 DataType::Date => write!(f, "Date"),
136 DataType::Nullable(inner) => write!(f, "Nullable({inner})"),
137 DataType::Array(inner) => write!(f, "Array({inner})"),
138 DataType::Map(inner) => match inner.as_ref() {
139 DataType::Tuple(tys) => {
140 write!(f, "Map({}, {})", tys[0], tys[1])
141 }
142 _ => unreachable!(),
143 },
144 DataType::Tuple(inner) => {
145 let inner = inner
146 .iter()
147 .map(|x| x.to_string())
148 .collect::<Vec<_>>()
149 .join(", ");
150 write!(f, "Tuple({inner})")
151 }
152 DataType::Variant => write!(f, "Variant"),
153 DataType::Bitmap => write!(f, "Bitmap"),
154 DataType::Geometry => write!(f, "Geometry"),
155 DataType::Geography => write!(f, "Geography"),
156 DataType::Interval => write!(f, "Interval"),
157 DataType::Vector(d) => write!(f, "Vector({d})"),
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
163pub struct Field {
164 pub name: String,
165 pub data_type: DataType,
166}
167
168#[derive(Debug, Clone, Default)]
169pub struct Schema(Vec<Field>);
170
171pub type SchemaRef = Arc<Schema>;
172
173impl Schema {
174 pub fn fields(&self) -> &[Field] {
175 &self.0
176 }
177
178 pub fn from_vec(fields: Vec<Field>) -> Self {
179 Self(fields)
180 }
181}
182
183impl TryFrom<&TypeDesc<'_>> for DataType {
184 type Error = Error;
185
186 fn try_from(desc: &TypeDesc) -> Result<Self> {
187 if desc.nullable {
188 let mut desc = desc.clone();
189 desc.nullable = false;
190 let inner = DataType::try_from(&desc)?;
191 return Ok(DataType::Nullable(Box::new(inner)));
192 }
193 let dt = match desc.name {
194 "NULL" | "Null" => DataType::Null,
195 "Boolean" => DataType::Boolean,
196 "Binary" => DataType::Binary,
197 "String" => DataType::String,
198 "Int8" => DataType::Number(NumberDataType::Int8),
199 "Int16" => DataType::Number(NumberDataType::Int16),
200 "Int32" => DataType::Number(NumberDataType::Int32),
201 "Int64" => DataType::Number(NumberDataType::Int64),
202 "UInt8" => DataType::Number(NumberDataType::UInt8),
203 "UInt16" => DataType::Number(NumberDataType::UInt16),
204 "UInt32" => DataType::Number(NumberDataType::UInt32),
205 "UInt64" => DataType::Number(NumberDataType::UInt64),
206 "Float32" => DataType::Number(NumberDataType::Float32),
207 "Float64" => DataType::Number(NumberDataType::Float64),
208 "Decimal" => {
209 let precision = desc.args[0].name.parse::<u8>()?;
210 let scale = desc.args[1].name.parse::<u8>()?;
211
212 if precision <= 38 {
213 DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
214 precision,
215 scale,
216 }))
217 } else {
218 DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
219 precision,
220 scale,
221 }))
222 }
223 }
224 "Timestamp" => DataType::Timestamp,
225 "Date" => DataType::Date,
226 "Nullable" => {
227 if desc.args.len() != 1 {
228 return Err(Error::Decode(
229 "Nullable type must have one argument".to_string(),
230 ));
231 }
232 let mut desc = desc.clone();
233 desc.nullable = false;
235 let inner = Self::try_from(&desc.args[0])?;
236 DataType::Nullable(Box::new(inner))
237 }
238 "Array" => {
239 if desc.args.len() != 1 {
240 return Err(Error::Decode(
241 "Array type must have one argument".to_string(),
242 ));
243 }
244 if desc.args[0].name == "Nothing" {
245 DataType::EmptyArray
246 } else {
247 let inner = Self::try_from(&desc.args[0])?;
248 DataType::Array(Box::new(inner))
249 }
250 }
251 "Map" => {
252 if desc.args.len() == 1 && desc.args[0].name == "Nothing" {
253 DataType::EmptyMap
254 } else {
255 if desc.args.len() != 2 {
256 return Err(Error::Decode(
257 "Map type must have two arguments".to_string(),
258 ));
259 }
260 let key_ty = Self::try_from(&desc.args[0])?;
261 let val_ty = Self::try_from(&desc.args[1])?;
262 DataType::Map(Box::new(DataType::Tuple(vec![key_ty, val_ty])))
263 }
264 }
265 "Tuple" => {
266 let mut inner = vec![];
267 for arg in &desc.args {
268 inner.push(Self::try_from(arg)?);
269 }
270 DataType::Tuple(inner)
271 }
272 "Variant" => DataType::Variant,
273 "Bitmap" => DataType::Bitmap,
274 "Geometry" => DataType::Geometry,
275 "Geography" => DataType::Geography,
276 "Interval" => DataType::Interval,
277 "Vector" => {
278 let dimension = desc.args[0].name.parse::<u64>()?;
279 DataType::Vector(dimension)
280 }
281 "Timestamp_Tz" => DataType::TimestampTz,
282 _ => return Err(Error::Decode(format!("Unknown type: {desc:?}"))),
283 };
284 Ok(dt)
285 }
286}
287
288impl TryFrom<APISchemaField> for Field {
289 type Error = Error;
290
291 fn try_from(f: APISchemaField) -> Result<Self> {
292 let type_desc = parse_type_desc(&f.data_type)?;
293 let dt = DataType::try_from(&type_desc)?;
294 let field = Self {
295 name: f.name,
296 data_type: dt,
297 };
298 Ok(field)
299 }
300}
301
302impl TryFrom<Vec<APISchemaField>> for Schema {
303 type Error = Error;
304
305 fn try_from(fields: Vec<APISchemaField>) -> Result<Self> {
306 let fields = fields
307 .into_iter()
308 .map(Field::try_from)
309 .collect::<Result<Vec<_>>>()?;
310 Ok(Self(fields))
311 }
312}
313
314impl TryFrom<&Arc<ArrowField>> for Field {
315 type Error = Error;
316
317 fn try_from(f: &Arc<ArrowField>) -> Result<Self> {
318 let mut dt = if let Some(extend_type) = f.metadata().get(EXTENSION_KEY) {
319 match extend_type.as_str() {
320 ARROW_EXT_TYPE_EMPTY_ARRAY => DataType::EmptyArray,
321 ARROW_EXT_TYPE_EMPTY_MAP => DataType::EmptyMap,
322 ARROW_EXT_TYPE_VARIANT => DataType::Variant,
323 ARROW_EXT_TYPE_BITMAP => DataType::Bitmap,
324 ARROW_EXT_TYPE_GEOMETRY => DataType::Geometry,
325 ARROW_EXT_TYPE_GEOGRAPHY => DataType::Geography,
326 ARROW_EXT_TYPE_INTERVAL => DataType::Interval,
327 ARROW_EXT_TYPE_TIMESTAMP_TIMEZONE => DataType::TimestampTz,
328 ARROW_EXT_TYPE_VECTOR => match f.data_type() {
329 ArrowDataType::FixedSizeList(field, dimension) => {
330 let dimension = match field.data_type() {
331 ArrowDataType::Float32 => *dimension as u64,
332 _ => {
333 return Err(Error::Decode(format!(
334 "Unsupported FixedSizeList Arrow type: {:?}",
335 field.data_type()
336 )));
337 }
338 };
339 DataType::Vector(dimension)
340 }
341 arrow_type => {
342 return Err(Error::Decode(format!(
343 "Unsupported Arrow type: {arrow_type:?}",
344 )));
345 }
346 },
347 _ => {
348 return Err(Error::Decode(format!(
349 "Unsupported extension datatype for arrow field: {f:?}"
350 )))
351 }
352 }
353 } else {
354 match f.data_type() {
355 ArrowDataType::Null => DataType::Null,
356 ArrowDataType::Boolean => DataType::Boolean,
357 ArrowDataType::Int8 => DataType::Number(NumberDataType::Int8),
358 ArrowDataType::Int16 => DataType::Number(NumberDataType::Int16),
359 ArrowDataType::Int32 => DataType::Number(NumberDataType::Int32),
360 ArrowDataType::Int64 => DataType::Number(NumberDataType::Int64),
361 ArrowDataType::UInt8 => DataType::Number(NumberDataType::UInt8),
362 ArrowDataType::UInt16 => DataType::Number(NumberDataType::UInt16),
363 ArrowDataType::UInt32 => DataType::Number(NumberDataType::UInt32),
364 ArrowDataType::UInt64 => DataType::Number(NumberDataType::UInt64),
365 ArrowDataType::Float32 => DataType::Number(NumberDataType::Float32),
366 ArrowDataType::Float64 => DataType::Number(NumberDataType::Float64),
367 ArrowDataType::Binary
368 | ArrowDataType::LargeBinary
369 | ArrowDataType::FixedSizeBinary(_) => DataType::Binary,
370 ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View => {
371 DataType::String
372 }
373 ArrowDataType::Timestamp(_, _) => DataType::Timestamp,
374 ArrowDataType::Date32 => DataType::Date,
375 ArrowDataType::Decimal64(p, s) => {
376 DataType::Decimal(DecimalDataType::Decimal64(DecimalSize {
377 precision: *p,
378 scale: *s as u8,
379 }))
380 }
381 ArrowDataType::Decimal128(p, s) => {
382 DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
383 precision: *p,
384 scale: *s as u8,
385 }))
386 }
387 ArrowDataType::Decimal256(p, s) => {
388 DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
389 precision: *p,
390 scale: *s as u8,
391 }))
392 }
393 ArrowDataType::List(f) | ArrowDataType::LargeList(f) => {
394 let inner_field = Field::try_from(f)?;
395 let inner_ty = inner_field.data_type;
396 DataType::Array(Box::new(inner_ty))
397 }
398 ArrowDataType::Map(f, _) => {
399 let inner_field = Field::try_from(f)?;
400 let inner_ty = inner_field.data_type;
401 DataType::Map(Box::new(inner_ty))
402 }
403 ArrowDataType::Struct(fs) => {
404 let mut inner_tys = Vec::with_capacity(fs.len());
405 for f in fs {
406 let inner_field = Field::try_from(f)?;
407 let inner_ty = inner_field.data_type;
408 inner_tys.push(inner_ty);
409 }
410 DataType::Tuple(inner_tys)
411 }
412 _ => {
413 return Err(Error::Decode(format!(
414 "Unsupported datatype for arrow field: {f:?}"
415 )))
416 }
417 }
418 };
419 if f.is_nullable() && !matches!(dt, DataType::Null) {
420 dt = DataType::Nullable(Box::new(dt));
421 }
422 Ok(Field {
423 name: f.name().to_string(),
424 data_type: dt,
425 })
426 }
427}
428
429impl TryFrom<ArrowSchemaRef> for Schema {
430 type Error = Error;
431
432 fn try_from(schema_ref: ArrowSchemaRef) -> Result<Self> {
433 let fields = schema_ref
434 .fields()
435 .iter()
436 .map(Field::try_from)
437 .collect::<Result<Vec<_>>>()?;
438 Ok(Self(fields))
439 }
440}
441
/// Parsed form of a textual type description such as `Map(String, Int32 NULL)`.
#[derive(Clone, Debug, Eq, PartialEq)]
struct TypeDesc<'a> {
    /// Type name (e.g. `Decimal`), or a bare literal argument such as `38`.
    name: &'a str,
    /// Set when the description carries a trailing ` NULL` marker.
    nullable: bool,
    /// Parenthesised arguments, each parsed recursively.
    args: Vec<TypeDesc<'a>>,
}
448
449fn parse_type_desc(s: &str) -> Result<TypeDesc<'_>> {
450 let mut name = "";
451 let mut args = vec![];
452 let mut depth = 0;
453 let mut start = 0;
454 let mut nullable = false;
455 for (i, c) in s.char_indices() {
456 match c {
457 '(' => {
458 if depth == 0 {
459 name = &s[start..i];
460 start = i + 1;
461 }
462 depth += 1;
463 }
464 ')' => {
465 depth -= 1;
466 if depth == 0 {
467 let s = &s[start..i];
468 if !s.is_empty() {
469 args.push(parse_type_desc(s)?);
470 }
471 start = i + 1;
472 }
473 }
474 ',' => {
475 if depth == 1 {
476 let s = &s[start..i];
477 args.push(parse_type_desc(s)?);
478 start = i + 1;
479 }
480 }
481 ' ' => {
482 if depth == 0 {
483 let s = &s[start..i];
484 if !s.is_empty() {
485 name = s;
486 }
487 start = i + 1;
488 }
489 }
490 _ => {}
491 }
492 }
493 if depth != 0 {
494 return Err(Error::Decode(format!("Invalid type desc: {s}")));
495 }
496 if start < s.len() {
497 let s = &s[start..];
498 if !s.is_empty() {
499 if name.is_empty() {
500 name = s;
501 } else if s == "NULL" {
502 nullable = true;
503 } else {
504 return Err(Error::Decode(format!("Invalid type arg for {name}: {s}")));
505 }
506 }
507 }
508 Ok(TypeDesc {
509 name,
510 nullable,
511 args,
512 })
513}
514
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a non-nullable descriptor node with the given arguments.
    fn node<'t>(name: &'t str, args: Vec<TypeDesc<'t>>) -> TypeDesc<'t> {
        TypeDesc {
            name,
            nullable: false,
            args,
        }
    }

    /// Builds a nullable descriptor node (the trailing ` NULL` form).
    fn null_node<'t>(name: &'t str, args: Vec<TypeDesc<'t>>) -> TypeDesc<'t> {
        TypeDesc {
            name,
            nullable: true,
            args,
        }
    }

    /// Builds a non-nullable descriptor with no arguments.
    fn leaf(name: &str) -> TypeDesc<'_> {
        node(name, vec![])
    }

    /// Parses each `(description, input, expected)` case and compares the result.
    fn check<'t>(cases: Vec<(&'t str, &'t str, TypeDesc<'t>)>) {
        for (desc, input, expected) in cases {
            let parsed = parse_type_desc(input).unwrap();
            assert_eq!(parsed, expected, "{}", desc);
        }
    }

    #[test]
    fn test_parse_type_desc() {
        check(vec![
            ("plain type", "String", leaf("String")),
            (
                "decimal type",
                "Decimal(42, 42)",
                node("Decimal", vec![leaf("42"), leaf("42")]),
            ),
            (
                "nullable type",
                "Nullable(Nothing)",
                node("Nullable", vec![leaf("Nothing")]),
            ),
            ("empty arg", "DateTime()", leaf("DateTime")),
            (
                "numeric arg",
                "FixedString(42)",
                node("FixedString", vec![leaf("42")]),
            ),
            (
                "multiple args",
                "Array(Tuple(Tuple(String, String), Tuple(String, UInt64)))",
                node(
                    "Array",
                    vec![node(
                        "Tuple",
                        vec![
                            node("Tuple", vec![leaf("String"), leaf("String")]),
                            node("Tuple", vec![leaf("String"), leaf("UInt64")]),
                        ],
                    )],
                ),
            ),
            (
                "map args",
                "Map(String, Array(Int64))",
                node(
                    "Map",
                    vec![leaf("String"), node("Array", vec![leaf("Int64")])],
                ),
            ),
            (
                "map nullable value args",
                "Nullable(Map(String, String NULL))",
                node(
                    "Nullable",
                    vec![node(
                        "Map",
                        vec![leaf("String"), null_node("String", vec![])],
                    )],
                ),
            ),
        ]);
    }

    #[test]
    fn test_parse_complex_type_with_null() {
        check(vec![(
            "complex nullable type",
            "Nullable(Tuple(String NULL, Array(Tuple(Array(Int32 NULL) NULL, Array(String NULL) NULL) NULL) NULL))",
            node(
                "Nullable",
                vec![node(
                    "Tuple",
                    vec![
                        null_node("String", vec![]),
                        null_node(
                            "Array",
                            vec![null_node(
                                "Tuple",
                                vec![
                                    null_node("Array", vec![null_node("Int32", vec![])]),
                                    null_node("Array", vec![null_node("String", vec![])]),
                                ],
                            )],
                        ),
                    ],
                )],
            ),
        )]);
    }
}