1use std::sync::Arc;
16
17use crate::SchemaField as APISchemaField;
18
19use crate::error::{Error, Result};
20
21use arrow_schema::{DataType as ArrowDataType, Field as ArrowField, SchemaRef as ArrowSchemaRef};
22
/// Arrow field metadata key whose value names the extension type of the field.
///
/// When this key is present in a field's metadata, the Arrow field conversion in this
/// file picks the [`DataType`] from the metadata value instead of the physical Arrow type.
pub const EXTENSION_KEY: &str = "Extension";
/// Extension name mapped to [`DataType::EmptyArray`].
pub const ARROW_EXT_TYPE_EMPTY_ARRAY: &str = "EmptyArray";
/// Extension name mapped to [`DataType::EmptyMap`].
pub const ARROW_EXT_TYPE_EMPTY_MAP: &str = "EmptyMap";
/// Extension name mapped to [`DataType::Variant`].
pub const ARROW_EXT_TYPE_VARIANT: &str = "Variant";
/// Extension name mapped to [`DataType::Bitmap`].
pub const ARROW_EXT_TYPE_BITMAP: &str = "Bitmap";
/// Extension name mapped to [`DataType::Geometry`].
pub const ARROW_EXT_TYPE_GEOMETRY: &str = "Geometry";
/// Extension name mapped to [`DataType::Geography`].
pub const ARROW_EXT_TYPE_GEOGRAPHY: &str = "Geography";
/// Extension name mapped to [`DataType::Interval`].
pub const ARROW_EXT_TYPE_INTERVAL: &str = "Interval";
/// Extension name mapped to [`DataType::Vector`]; the field must be a fixed-size list
/// of `Float32` values for the conversion to succeed.
pub const ARROW_EXT_TYPE_VECTOR: &str = "Vector";
/// Extension name mapped to [`DataType::TimestampTz`].
pub const ARROW_EXT_TYPE_TIMESTAMP_TIMEZONE: &str = "TimestampTz";
34
/// Fixed-width integer and floating-point types carried by [`DataType::Number`].
///
/// Each variant corresponds one-to-one to the type name of the same spelling accepted
/// by the type-descriptor conversion and printed by `Display for DataType`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NumberDataType {
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Int8,
    Int16,
    Int32,
    Int64,
    Float32,
    Float64,
}
48
/// Precision and scale of a decimal type, as printed in `Decimal(precision, scale)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecimalSize {
    /// First argument of the `Decimal(..)` type descriptor.
    pub precision: u8,
    /// Second argument of the `Decimal(..)` type descriptor.
    pub scale: u8,
}
54
/// Decimal type tagged with its storage width.
///
/// The type-descriptor conversion in this file selects `Decimal128` when the precision
/// is at most 38 and `Decimal256` otherwise.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecimalDataType {
    Decimal128(DecimalSize),
    Decimal256(DecimalSize),
}
60
61impl DecimalDataType {
62 pub fn decimal_size(&self) -> &DecimalSize {
63 match self {
64 DecimalDataType::Decimal128(size) => size,
65 DecimalDataType::Decimal256(size) => size,
66 }
67 }
68}
69
/// Logical column type used by this crate's [`Schema`].
///
/// Values are produced either from a parsed type descriptor string or from an Arrow
/// field (see the `TryFrom` implementations in this file).
#[derive(Debug, Clone)]
pub enum DataType {
    /// The `NULL` / `Null` type; never wrapped in `Nullable` by the Arrow conversion.
    Null,
    /// `Array(Nothing)`: an array with no element type.
    EmptyArray,
    /// `Map(Nothing)`: a map with no key/value types.
    EmptyMap,
    Boolean,
    Binary,
    String,
    /// Integer or floating-point number.
    Number(NumberDataType),
    /// Fixed-point decimal with precision and scale.
    Decimal(DecimalDataType),
    Timestamp,
    /// Timestamp with time zone; spelled `Timestamp_Tz` in type descriptors.
    TimestampTz,
    Date,
    /// Wrapper marking the inner type as nullable.
    Nullable(Box<DataType>),
    /// Array of the inner element type.
    Array(Box<DataType>),
    /// Map whose inner type is a two-element `Tuple` of `[key, value]` types when built
    /// from a type descriptor.
    Map(Box<DataType>),
    /// Tuple of the listed element types, in order.
    Tuple(Vec<DataType>),
    Variant,
    Bitmap,
    Geometry,
    Geography,
    Interval,
    /// Vector with the given dimension.
    Vector(u64),
}
95
96impl DataType {
97 pub fn is_numeric(&self) -> bool {
98 match self {
99 DataType::Number(_) | DataType::Decimal(_) => true,
100 DataType::Nullable(inner) => inner.is_numeric(),
101 _ => false,
102 }
103 }
104}
105
106impl std::fmt::Display for DataType {
107 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
108 match self {
109 DataType::Null => write!(f, "Null"),
110 DataType::EmptyArray => write!(f, "EmptyArray"),
111 DataType::EmptyMap => write!(f, "EmptyMap"),
112 DataType::Boolean => write!(f, "Boolean"),
113 DataType::Binary => write!(f, "Binary"),
114 DataType::String => write!(f, "String"),
115 DataType::Number(n) => match n {
116 NumberDataType::UInt8 => write!(f, "UInt8"),
117 NumberDataType::UInt16 => write!(f, "UInt16"),
118 NumberDataType::UInt32 => write!(f, "UInt32"),
119 NumberDataType::UInt64 => write!(f, "UInt64"),
120 NumberDataType::Int8 => write!(f, "Int8"),
121 NumberDataType::Int16 => write!(f, "Int16"),
122 NumberDataType::Int32 => write!(f, "Int32"),
123 NumberDataType::Int64 => write!(f, "Int64"),
124 NumberDataType::Float32 => write!(f, "Float32"),
125 NumberDataType::Float64 => write!(f, "Float64"),
126 },
127 DataType::Decimal(d) => {
128 let size = d.decimal_size();
129 write!(f, "Decimal({}, {})", size.precision, size.scale)
130 }
131 DataType::Timestamp => write!(f, "Timestamp"),
132 DataType::TimestampTz => write!(f, "Timestamp_Tz"),
133 DataType::Date => write!(f, "Date"),
134 DataType::Nullable(inner) => write!(f, "Nullable({inner})"),
135 DataType::Array(inner) => write!(f, "Array({inner})"),
136 DataType::Map(inner) => match inner.as_ref() {
137 DataType::Tuple(tys) => {
138 write!(f, "Map({}, {})", tys[0], tys[1])
139 }
140 _ => unreachable!(),
141 },
142 DataType::Tuple(inner) => {
143 let inner = inner
144 .iter()
145 .map(|x| x.to_string())
146 .collect::<Vec<_>>()
147 .join(", ");
148 write!(f, "Tuple({inner})")
149 }
150 DataType::Variant => write!(f, "Variant"),
151 DataType::Bitmap => write!(f, "Bitmap"),
152 DataType::Geometry => write!(f, "Geometry"),
153 DataType::Geography => write!(f, "Geography"),
154 DataType::Interval => write!(f, "Interval"),
155 DataType::Vector(d) => write!(f, "Vector({d})"),
156 }
157 }
158}
159
/// A named column: its name together with its logical [`DataType`].
#[derive(Debug, Clone)]
pub struct Field {
    /// Column name, copied from the API schema field or the Arrow field.
    pub name: String,
    /// Logical type of the column.
    pub data_type: DataType,
}
165
166#[derive(Debug, Clone, Default)]
167pub struct Schema(Vec<Field>);
168
169pub type SchemaRef = Arc<Schema>;
170
171impl Schema {
172 pub fn fields(&self) -> &[Field] {
173 &self.0
174 }
175
176 pub fn from_vec(fields: Vec<Field>) -> Self {
177 Self(fields)
178 }
179}
180
181impl TryFrom<&TypeDesc<'_>> for DataType {
182 type Error = Error;
183
184 fn try_from(desc: &TypeDesc) -> Result<Self> {
185 if desc.nullable {
186 let mut desc = desc.clone();
187 desc.nullable = false;
188 let inner = DataType::try_from(&desc)?;
189 return Ok(DataType::Nullable(Box::new(inner)));
190 }
191 let dt = match desc.name {
192 "NULL" | "Null" => DataType::Null,
193 "Boolean" => DataType::Boolean,
194 "Binary" => DataType::Binary,
195 "String" => DataType::String,
196 "Int8" => DataType::Number(NumberDataType::Int8),
197 "Int16" => DataType::Number(NumberDataType::Int16),
198 "Int32" => DataType::Number(NumberDataType::Int32),
199 "Int64" => DataType::Number(NumberDataType::Int64),
200 "UInt8" => DataType::Number(NumberDataType::UInt8),
201 "UInt16" => DataType::Number(NumberDataType::UInt16),
202 "UInt32" => DataType::Number(NumberDataType::UInt32),
203 "UInt64" => DataType::Number(NumberDataType::UInt64),
204 "Float32" => DataType::Number(NumberDataType::Float32),
205 "Float64" => DataType::Number(NumberDataType::Float64),
206 "Decimal" => {
207 let precision = desc.args[0].name.parse::<u8>()?;
208 let scale = desc.args[1].name.parse::<u8>()?;
209
210 if precision <= 38 {
211 DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
212 precision,
213 scale,
214 }))
215 } else {
216 DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
217 precision,
218 scale,
219 }))
220 }
221 }
222 "Timestamp" => DataType::Timestamp,
223 "Date" => DataType::Date,
224 "Nullable" => {
225 if desc.args.len() != 1 {
226 return Err(Error::Decode(
227 "Nullable type must have one argument".to_string(),
228 ));
229 }
230 let mut desc = desc.clone();
231 desc.nullable = false;
233 let inner = Self::try_from(&desc.args[0])?;
234 DataType::Nullable(Box::new(inner))
235 }
236 "Array" => {
237 if desc.args.len() != 1 {
238 return Err(Error::Decode(
239 "Array type must have one argument".to_string(),
240 ));
241 }
242 if desc.args[0].name == "Nothing" {
243 DataType::EmptyArray
244 } else {
245 let inner = Self::try_from(&desc.args[0])?;
246 DataType::Array(Box::new(inner))
247 }
248 }
249 "Map" => {
250 if desc.args.len() == 1 && desc.args[0].name == "Nothing" {
251 DataType::EmptyMap
252 } else {
253 if desc.args.len() != 2 {
254 return Err(Error::Decode(
255 "Map type must have two arguments".to_string(),
256 ));
257 }
258 let key_ty = Self::try_from(&desc.args[0])?;
259 let val_ty = Self::try_from(&desc.args[1])?;
260 DataType::Map(Box::new(DataType::Tuple(vec![key_ty, val_ty])))
261 }
262 }
263 "Tuple" => {
264 let mut inner = vec![];
265 for arg in &desc.args {
266 inner.push(Self::try_from(arg)?);
267 }
268 DataType::Tuple(inner)
269 }
270 "Variant" => DataType::Variant,
271 "Bitmap" => DataType::Bitmap,
272 "Geometry" => DataType::Geometry,
273 "Geography" => DataType::Geography,
274 "Interval" => DataType::Interval,
275 "Vector" => {
276 let dimension = desc.args[0].name.parse::<u64>()?;
277 DataType::Vector(dimension)
278 }
279 "Timestamp_Tz" => DataType::TimestampTz,
280 _ => return Err(Error::Decode(format!("Unknown type: {desc:?}"))),
281 };
282 Ok(dt)
283 }
284}
285
286impl TryFrom<APISchemaField> for Field {
287 type Error = Error;
288
289 fn try_from(f: APISchemaField) -> Result<Self> {
290 let type_desc = parse_type_desc(&f.data_type)?;
291 let dt = DataType::try_from(&type_desc)?;
292 let field = Self {
293 name: f.name,
294 data_type: dt,
295 };
296 Ok(field)
297 }
298}
299
300impl TryFrom<Vec<APISchemaField>> for Schema {
301 type Error = Error;
302
303 fn try_from(fields: Vec<APISchemaField>) -> Result<Self> {
304 let fields = fields
305 .into_iter()
306 .map(Field::try_from)
307 .collect::<Result<Vec<_>>>()?;
308 Ok(Self(fields))
309 }
310}
311
312impl TryFrom<&Arc<ArrowField>> for Field {
313 type Error = Error;
314
315 fn try_from(f: &Arc<ArrowField>) -> Result<Self> {
316 let mut dt = if let Some(extend_type) = f.metadata().get(EXTENSION_KEY) {
317 match extend_type.as_str() {
318 ARROW_EXT_TYPE_EMPTY_ARRAY => DataType::EmptyArray,
319 ARROW_EXT_TYPE_EMPTY_MAP => DataType::EmptyMap,
320 ARROW_EXT_TYPE_VARIANT => DataType::Variant,
321 ARROW_EXT_TYPE_BITMAP => DataType::Bitmap,
322 ARROW_EXT_TYPE_GEOMETRY => DataType::Geometry,
323 ARROW_EXT_TYPE_GEOGRAPHY => DataType::Geography,
324 ARROW_EXT_TYPE_INTERVAL => DataType::Interval,
325 ARROW_EXT_TYPE_TIMESTAMP_TIMEZONE => DataType::TimestampTz,
326 ARROW_EXT_TYPE_VECTOR => match f.data_type() {
327 ArrowDataType::FixedSizeList(field, dimension) => {
328 let dimension = match field.data_type() {
329 ArrowDataType::Float32 => *dimension as u64,
330 _ => {
331 return Err(Error::Decode(format!(
332 "Unsupported FixedSizeList Arrow type: {:?}",
333 field.data_type()
334 )));
335 }
336 };
337 DataType::Vector(dimension)
338 }
339 arrow_type => {
340 return Err(Error::Decode(format!(
341 "Unsupported Arrow type: {arrow_type:?}",
342 )));
343 }
344 },
345 _ => {
346 return Err(Error::Decode(format!(
347 "Unsupported extension datatype for arrow field: {f:?}"
348 )))
349 }
350 }
351 } else {
352 match f.data_type() {
353 ArrowDataType::Null => DataType::Null,
354 ArrowDataType::Boolean => DataType::Boolean,
355 ArrowDataType::Int8 => DataType::Number(NumberDataType::Int8),
356 ArrowDataType::Int16 => DataType::Number(NumberDataType::Int16),
357 ArrowDataType::Int32 => DataType::Number(NumberDataType::Int32),
358 ArrowDataType::Int64 => DataType::Number(NumberDataType::Int64),
359 ArrowDataType::UInt8 => DataType::Number(NumberDataType::UInt8),
360 ArrowDataType::UInt16 => DataType::Number(NumberDataType::UInt16),
361 ArrowDataType::UInt32 => DataType::Number(NumberDataType::UInt32),
362 ArrowDataType::UInt64 => DataType::Number(NumberDataType::UInt64),
363 ArrowDataType::Float32 => DataType::Number(NumberDataType::Float32),
364 ArrowDataType::Float64 => DataType::Number(NumberDataType::Float64),
365 ArrowDataType::Binary
366 | ArrowDataType::LargeBinary
367 | ArrowDataType::FixedSizeBinary(_) => DataType::Binary,
368 ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View => {
369 DataType::String
370 }
371 ArrowDataType::Timestamp(_, _) => DataType::Timestamp,
372 ArrowDataType::Date32 => DataType::Date,
373 ArrowDataType::Decimal128(p, s) => {
374 DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
375 precision: *p,
376 scale: *s as u8,
377 }))
378 }
379 ArrowDataType::Decimal256(p, s) => {
380 DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
381 precision: *p,
382 scale: *s as u8,
383 }))
384 }
385 ArrowDataType::List(f) | ArrowDataType::LargeList(f) => {
386 let inner_field = Field::try_from(f)?;
387 let inner_ty = inner_field.data_type;
388 DataType::Array(Box::new(inner_ty))
389 }
390 ArrowDataType::Map(f, _) => {
391 let inner_field = Field::try_from(f)?;
392 let inner_ty = inner_field.data_type;
393 DataType::Map(Box::new(inner_ty))
394 }
395 ArrowDataType::Struct(fs) => {
396 let mut inner_tys = Vec::with_capacity(fs.len());
397 for f in fs {
398 let inner_field = Field::try_from(f)?;
399 let inner_ty = inner_field.data_type;
400 inner_tys.push(inner_ty);
401 }
402 DataType::Tuple(inner_tys)
403 }
404 _ => {
405 return Err(Error::Decode(format!(
406 "Unsupported datatype for arrow field: {f:?}"
407 )))
408 }
409 }
410 };
411 if f.is_nullable() && !matches!(dt, DataType::Null) {
412 dt = DataType::Nullable(Box::new(dt));
413 }
414 Ok(Field {
415 name: f.name().to_string(),
416 data_type: dt,
417 })
418 }
419}
420
421impl TryFrom<ArrowSchemaRef> for Schema {
422 type Error = Error;
423
424 fn try_from(schema_ref: ArrowSchemaRef) -> Result<Self> {
425 let fields = schema_ref
426 .fields()
427 .iter()
428 .map(Field::try_from)
429 .collect::<Result<Vec<_>>>()?;
430 Ok(Self(fields))
431 }
432}
433
/// Parsed form of a textual type descriptor such as `Map(String, Int64 NULL)`.
///
/// All string data borrows from the input passed to `parse_type_desc`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TypeDesc<'t> {
    /// Type name, or the literal text of a bare argument (e.g. `"42"` in `Decimal(42, 42)`).
    name: &'t str,
    /// `true` when the descriptor is followed by a trailing `NULL` marker.
    nullable: bool,
    /// Parenthesised, comma-separated arguments, each parsed recursively.
    args: Vec<TypeDesc<'t>>,
}
440
441fn parse_type_desc(s: &str) -> Result<TypeDesc<'_>> {
442 let mut name = "";
443 let mut args = vec![];
444 let mut depth = 0;
445 let mut start = 0;
446 let mut nullable = false;
447 for (i, c) in s.char_indices() {
448 match c {
449 '(' => {
450 if depth == 0 {
451 name = &s[start..i];
452 start = i + 1;
453 }
454 depth += 1;
455 }
456 ')' => {
457 depth -= 1;
458 if depth == 0 {
459 let s = &s[start..i];
460 if !s.is_empty() {
461 args.push(parse_type_desc(s)?);
462 }
463 start = i + 1;
464 }
465 }
466 ',' => {
467 if depth == 1 {
468 let s = &s[start..i];
469 args.push(parse_type_desc(s)?);
470 start = i + 1;
471 }
472 }
473 ' ' => {
474 if depth == 0 {
475 let s = &s[start..i];
476 if !s.is_empty() {
477 name = s;
478 }
479 start = i + 1;
480 }
481 }
482 _ => {}
483 }
484 }
485 if depth != 0 {
486 return Err(Error::Decode(format!("Invalid type desc: {s}")));
487 }
488 if start < s.len() {
489 let s = &s[start..];
490 if !s.is_empty() {
491 if name.is_empty() {
492 name = s;
493 } else if s == "NULL" {
494 nullable = true;
495 } else {
496 return Err(Error::Decode(format!("Invalid type arg for {name}: {s}")));
497 }
498 }
499 }
500 Ok(TypeDesc {
501 name,
502 nullable,
503 args,
504 })
505}
506
#[cfg(test)]
mod test {
    use std::vec;

    use super::*;

    /// Builds an expected descriptor node with explicit nullability and arguments.
    fn node<'t>(name: &'t str, nullable: bool, args: Vec<TypeDesc<'t>>) -> TypeDesc<'t> {
        TypeDesc {
            name,
            nullable,
            args,
        }
    }

    /// Builds an expected non-nullable descriptor without arguments.
    fn leaf(name: &str) -> TypeDesc<'_> {
        node(name, false, vec![])
    }

    /// Builds an expected nullable descriptor without arguments.
    fn null_leaf(name: &str) -> TypeDesc<'_> {
        node(name, true, vec![])
    }

    /// Parses every `(description, input, expected)` case and compares the result.
    fn run_cases(cases: Vec<(&str, &str, TypeDesc<'_>)>) {
        for (desc, input, expected) in cases {
            let output = parse_type_desc(input).unwrap();
            assert_eq!(output, expected, "{}", desc);
        }
    }

    #[test]
    fn test_parse_type_desc() {
        run_cases(vec![
            ("plain type", "String", leaf("String")),
            (
                "decimal type",
                "Decimal(42, 42)",
                node("Decimal", false, vec![leaf("42"), leaf("42")]),
            ),
            (
                "nullable type",
                "Nullable(Nothing)",
                node("Nullable", false, vec![leaf("Nothing")]),
            ),
            ("empty arg", "DateTime()", leaf("DateTime")),
            (
                "numeric arg",
                "FixedString(42)",
                node("FixedString", false, vec![leaf("42")]),
            ),
            (
                "multiple args",
                "Array(Tuple(Tuple(String, String), Tuple(String, UInt64)))",
                node(
                    "Array",
                    false,
                    vec![node(
                        "Tuple",
                        false,
                        vec![
                            node("Tuple", false, vec![leaf("String"), leaf("String")]),
                            node("Tuple", false, vec![leaf("String"), leaf("UInt64")]),
                        ],
                    )],
                ),
            ),
            (
                "map args",
                "Map(String, Array(Int64))",
                node(
                    "Map",
                    false,
                    vec![
                        leaf("String"),
                        node("Array", false, vec![leaf("Int64")]),
                    ],
                ),
            ),
            (
                "map nullable value args",
                "Nullable(Map(String, String NULL))",
                node(
                    "Nullable",
                    false,
                    vec![node(
                        "Map",
                        false,
                        vec![leaf("String"), null_leaf("String")],
                    )],
                ),
            ),
        ]);
    }

    #[test]
    fn test_parse_complex_type_with_null() {
        run_cases(vec![(
            "complex nullable type",
            "Nullable(Tuple(String NULL, Array(Tuple(Array(Int32 NULL) NULL, Array(String NULL) NULL) NULL) NULL))",
            node(
                "Nullable",
                false,
                vec![node(
                    "Tuple",
                    false,
                    vec![
                        null_leaf("String"),
                        node(
                            "Array",
                            true,
                            vec![node(
                                "Tuple",
                                true,
                                vec![
                                    node("Array", true, vec![null_leaf("Int32")]),
                                    node("Array", true, vec![null_leaf("String")]),
                                ],
                            )],
                        ),
                    ],
                )],
            ),
        )]);
    }
}