1use std::sync::Arc;
16
17use databend_client::SchemaField as APISchemaField;
18
19use crate::error::{Error, Result};
20
21#[cfg(feature = "flight-sql")]
22use arrow_schema::{DataType as ArrowDataType, Field as ArrowField, SchemaRef as ArrowSchemaRef};
23
/// Arrow field-metadata key whose value names a Databend extension type.
///
/// The `TryFrom<&Arc<ArrowField>>` conversion for `Field` in this file looks
/// this key up and compares the stored value with the `ARROW_EXT_TYPE_*`
/// constants below.
#[cfg(feature = "flight-sql")]
pub(crate) const EXTENSION_KEY: &str = "Extension";
/// Extension value converted to `DataType::EmptyArray`.
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_EMPTY_ARRAY: &str = "EmptyArray";
/// Extension value converted to `DataType::EmptyMap`.
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_EMPTY_MAP: &str = "EmptyMap";
/// Extension value converted to `DataType::Variant`.
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_VARIANT: &str = "Variant";
/// Extension value converted to `DataType::Bitmap`.
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_BITMAP: &str = "Bitmap";
/// Extension value converted to `DataType::Geometry`.
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_GEOMETRY: &str = "Geometry";
/// Extension value converted to `DataType::Geography`.
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_GEOGRAPHY: &str = "Geography";
/// Extension value presumably paired with `DataType::Interval` (inferred from
/// the name; confirm against the schema the server actually sends).
#[cfg(feature = "flight-sql")]
pub(crate) const ARROW_EXT_TYPE_INTERVAL: &str = "Interval";
41
/// Numeric kinds carried by `DataType::Number`.
///
/// Each variant corresponds to the type name of the same spelling accepted by
/// the type-string conversion in this file (`"UInt8"`, `"Int64"`,
/// `"Float32"`, ...), and is printed with that spelling by `DataType`'s
/// `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NumberDataType {
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Int8,
    Int16,
    Int32,
    Int64,
    Float32,
    Float64,
}
55
/// The two parameters of a decimal type, as written in
/// `Decimal(precision, scale)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecimalSize {
    pub precision: u8,
    pub scale: u8,
}

/// Storage flavour of a decimal column together with its [`DecimalSize`].
///
/// The type-string conversion in this file selects `Decimal128` when the
/// precision is at most 38 and `Decimal256` otherwise.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecimalDataType {
    Decimal128(DecimalSize),
    Decimal256(DecimalSize),
}

impl DecimalDataType {
    /// Returns the precision/scale pair, whichever variant holds it.
    pub fn decimal_size(&self) -> &DecimalSize {
        match self {
            Self::Decimal128(size) | Self::Decimal256(size) => size,
        }
    }
}
76
77#[derive(Debug, Clone)]
78pub enum DataType {
79 Null,
80 EmptyArray,
81 EmptyMap,
82 Boolean,
83 Binary,
84 String,
85 Number(NumberDataType),
86 Decimal(DecimalDataType),
87 Timestamp,
88 Date,
89 Nullable(Box<DataType>),
90 Array(Box<DataType>),
91 Map(Box<DataType>),
92 Tuple(Vec<DataType>),
93 Variant,
94 Bitmap,
95 Geometry,
96 Geography,
97 Interval,
98 }
100
101impl DataType {
102 pub fn is_numeric(&self) -> bool {
103 match self {
104 DataType::Number(_) | DataType::Decimal(_) => true,
105 DataType::Nullable(inner) => inner.is_numeric(),
106 _ => false,
107 }
108 }
109}
110
111impl std::fmt::Display for DataType {
112 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
113 match self {
114 DataType::Null => write!(f, "Null"),
115 DataType::EmptyArray => write!(f, "EmptyArray"),
116 DataType::EmptyMap => write!(f, "EmptyMap"),
117 DataType::Boolean => write!(f, "Boolean"),
118 DataType::Binary => write!(f, "Binary"),
119 DataType::String => write!(f, "String"),
120 DataType::Number(n) => match n {
121 NumberDataType::UInt8 => write!(f, "UInt8"),
122 NumberDataType::UInt16 => write!(f, "UInt16"),
123 NumberDataType::UInt32 => write!(f, "UInt32"),
124 NumberDataType::UInt64 => write!(f, "UInt64"),
125 NumberDataType::Int8 => write!(f, "Int8"),
126 NumberDataType::Int16 => write!(f, "Int16"),
127 NumberDataType::Int32 => write!(f, "Int32"),
128 NumberDataType::Int64 => write!(f, "Int64"),
129 NumberDataType::Float32 => write!(f, "Float32"),
130 NumberDataType::Float64 => write!(f, "Float64"),
131 },
132 DataType::Decimal(d) => {
133 let size = d.decimal_size();
134 write!(f, "Decimal({}, {})", size.precision, size.scale)
135 }
136 DataType::Timestamp => write!(f, "Timestamp"),
137 DataType::Date => write!(f, "Date"),
138 DataType::Nullable(inner) => write!(f, "Nullable({})", inner),
139 DataType::Array(inner) => write!(f, "Array({})", inner),
140 DataType::Map(inner) => match inner.as_ref() {
141 DataType::Tuple(tys) => {
142 write!(f, "Map({}, {})", tys[0], tys[1])
143 }
144 _ => unreachable!(),
145 },
146 DataType::Tuple(inner) => {
147 let inner = inner
148 .iter()
149 .map(|x| x.to_string())
150 .collect::<Vec<_>>()
151 .join(", ");
152 write!(f, "Tuple({})", inner)
153 }
154 DataType::Variant => write!(f, "Variant"),
155 DataType::Bitmap => write!(f, "Bitmap"),
156 DataType::Geometry => write!(f, "Geometry"),
157 DataType::Geography => write!(f, "Geography"),
158 DataType::Interval => write!(f, "Interval"),
159 }
160 }
161}
162
163#[derive(Debug, Clone)]
164pub struct Field {
165 pub name: String,
166 pub data_type: DataType,
167}
168
169#[derive(Debug, Clone, Default)]
170pub struct Schema(Vec<Field>);
171
172pub type SchemaRef = Arc<Schema>;
173
174impl Schema {
175 pub fn fields(&self) -> &[Field] {
176 &self.0
177 }
178
179 pub fn from_vec(fields: Vec<Field>) -> Self {
180 Self(fields)
181 }
182}
183
184impl TryFrom<&TypeDesc<'_>> for DataType {
185 type Error = Error;
186
187 fn try_from(desc: &TypeDesc) -> Result<Self> {
188 if desc.nullable {
189 let mut desc = desc.clone();
190 desc.nullable = false;
191 let inner = DataType::try_from(&desc)?;
192 return Ok(DataType::Nullable(Box::new(inner)));
193 }
194 let dt = match desc.name {
195 "NULL" | "Null" => DataType::Null,
196 "Boolean" => DataType::Boolean,
197 "Binary" => DataType::Binary,
198 "String" => DataType::String,
199 "Int8" => DataType::Number(NumberDataType::Int8),
200 "Int16" => DataType::Number(NumberDataType::Int16),
201 "Int32" => DataType::Number(NumberDataType::Int32),
202 "Int64" => DataType::Number(NumberDataType::Int64),
203 "UInt8" => DataType::Number(NumberDataType::UInt8),
204 "UInt16" => DataType::Number(NumberDataType::UInt16),
205 "UInt32" => DataType::Number(NumberDataType::UInt32),
206 "UInt64" => DataType::Number(NumberDataType::UInt64),
207 "Float32" => DataType::Number(NumberDataType::Float32),
208 "Float64" => DataType::Number(NumberDataType::Float64),
209 "Decimal" => {
210 let precision = desc.args[0].name.parse::<u8>()?;
211 let scale = desc.args[1].name.parse::<u8>()?;
212
213 if precision <= 38 {
214 DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
215 precision,
216 scale,
217 }))
218 } else {
219 DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
220 precision,
221 scale,
222 }))
223 }
224 }
225 "Timestamp" => DataType::Timestamp,
226 "Date" => DataType::Date,
227 "Nullable" => {
228 if desc.args.len() != 1 {
229 return Err(Error::Parsing(
230 "Nullable type must have one argument".to_string(),
231 ));
232 }
233 let mut desc = desc.clone();
234 desc.nullable = false;
236 let inner = Self::try_from(&desc.args[0])?;
237 DataType::Nullable(Box::new(inner))
238 }
239 "Array" => {
240 if desc.args.len() != 1 {
241 return Err(Error::Parsing(
242 "Array type must have one argument".to_string(),
243 ));
244 }
245 if desc.args[0].name == "Nothing" {
246 DataType::EmptyArray
247 } else {
248 let inner = Self::try_from(&desc.args[0])?;
249 DataType::Array(Box::new(inner))
250 }
251 }
252 "Map" => {
253 if desc.args.len() == 1 && desc.args[0].name == "Nothing" {
254 DataType::EmptyMap
255 } else {
256 if desc.args.len() != 2 {
257 return Err(Error::Parsing(
258 "Map type must have two arguments".to_string(),
259 ));
260 }
261 let key_ty = Self::try_from(&desc.args[0])?;
262 let val_ty = Self::try_from(&desc.args[1])?;
263 DataType::Map(Box::new(DataType::Tuple(vec![key_ty, val_ty])))
264 }
265 }
266 "Tuple" => {
267 let mut inner = vec![];
268 for arg in &desc.args {
269 inner.push(Self::try_from(arg)?);
270 }
271 DataType::Tuple(inner)
272 }
273 "Variant" => DataType::Variant,
274 "Bitmap" => DataType::Bitmap,
275 "Geometry" => DataType::Geometry,
276 "Geography" => DataType::Geography,
277 "Interval" => DataType::Interval,
278 _ => return Err(Error::Parsing(format!("Unknown type: {:?}", desc))),
279 };
280 Ok(dt)
281 }
282}
283
284impl TryFrom<APISchemaField> for Field {
285 type Error = Error;
286
287 fn try_from(f: APISchemaField) -> Result<Self> {
288 let type_desc = parse_type_desc(&f.data_type)?;
289 let dt = DataType::try_from(&type_desc)?;
290 let field = Self {
291 name: f.name,
292 data_type: dt,
293 };
294 Ok(field)
295 }
296}
297
298impl TryFrom<Vec<APISchemaField>> for Schema {
299 type Error = Error;
300
301 fn try_from(fields: Vec<APISchemaField>) -> Result<Self> {
302 let fields = fields
303 .into_iter()
304 .map(Field::try_from)
305 .collect::<Result<Vec<_>>>()?;
306 Ok(Self(fields))
307 }
308}
309
/// Converts an Arrow field into a [`Field`].
///
/// If the field metadata contains `EXTENSION_KEY`, the stored value selects
/// the Databend extension type (empty array/map, variant, bitmap, geometry,
/// geography, interval) and the Arrow physical type is not consulted.
/// Otherwise the Arrow data type is mapped directly; list, map and struct
/// types are converted recursively through their child fields.
///
/// A nullable Arrow field is wrapped in `DataType::Nullable`, except when the
/// converted type is `DataType::Null`.
///
/// # Errors
///
/// Returns `Error::Parsing` for an unknown extension value, an unsupported
/// Arrow data type, or a decimal whose scale is negative (which
/// `DecimalSize` cannot represent).
#[cfg(feature = "flight-sql")]
impl TryFrom<&Arc<ArrowField>> for Field {
    type Error = Error;

    fn try_from(f: &Arc<ArrowField>) -> Result<Self> {
        // Arrow stores the decimal scale as a signed integer; convert it
        // checked so a negative scale is rejected instead of wrapping.
        let decimal_size = |precision: u8, scale: i8| -> Result<DecimalSize> {
            let scale = u8::try_from(scale).map_err(|_| {
                Error::Parsing(format!(
                    "Unsupported negative decimal scale for arrow field: {:?}",
                    f
                ))
            })?;
            Ok(DecimalSize { precision, scale })
        };

        let mut dt = if let Some(extend_type) = f.metadata().get(EXTENSION_KEY) {
            match extend_type.as_str() {
                ARROW_EXT_TYPE_EMPTY_ARRAY => DataType::EmptyArray,
                ARROW_EXT_TYPE_EMPTY_MAP => DataType::EmptyMap,
                ARROW_EXT_TYPE_VARIANT => DataType::Variant,
                ARROW_EXT_TYPE_BITMAP => DataType::Bitmap,
                ARROW_EXT_TYPE_GEOMETRY => DataType::Geometry,
                ARROW_EXT_TYPE_GEOGRAPHY => DataType::Geography,
                ARROW_EXT_TYPE_INTERVAL => DataType::Interval,
                _ => {
                    return Err(Error::Parsing(format!(
                        "Unsupported extension datatype for arrow field: {:?}",
                        f
                    )))
                }
            }
        } else {
            match f.data_type() {
                ArrowDataType::Null => DataType::Null,
                ArrowDataType::Boolean => DataType::Boolean,
                ArrowDataType::Int8 => DataType::Number(NumberDataType::Int8),
                ArrowDataType::Int16 => DataType::Number(NumberDataType::Int16),
                ArrowDataType::Int32 => DataType::Number(NumberDataType::Int32),
                ArrowDataType::Int64 => DataType::Number(NumberDataType::Int64),
                ArrowDataType::UInt8 => DataType::Number(NumberDataType::UInt8),
                ArrowDataType::UInt16 => DataType::Number(NumberDataType::UInt16),
                ArrowDataType::UInt32 => DataType::Number(NumberDataType::UInt32),
                ArrowDataType::UInt64 => DataType::Number(NumberDataType::UInt64),
                ArrowDataType::Float32 => DataType::Number(NumberDataType::Float32),
                ArrowDataType::Float64 => DataType::Number(NumberDataType::Float64),
                ArrowDataType::Binary
                | ArrowDataType::LargeBinary
                | ArrowDataType::FixedSizeBinary(_) => DataType::Binary,
                ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View => {
                    DataType::String
                }
                ArrowDataType::Timestamp(_, _) => DataType::Timestamp,
                ArrowDataType::Date32 => DataType::Date,
                ArrowDataType::Decimal128(p, s) => {
                    DataType::Decimal(DecimalDataType::Decimal128(decimal_size(*p, *s)?))
                }
                ArrowDataType::Decimal256(p, s) => {
                    DataType::Decimal(DecimalDataType::Decimal256(decimal_size(*p, *s)?))
                }
                ArrowDataType::List(f) | ArrowDataType::LargeList(f) => {
                    let inner_ty = Field::try_from(f)?.data_type;
                    DataType::Array(Box::new(inner_ty))
                }
                ArrowDataType::Map(f, _) => {
                    // The map's single child field describes its entries.
                    let inner_ty = Field::try_from(f)?.data_type;
                    DataType::Map(Box::new(inner_ty))
                }
                ArrowDataType::Struct(fs) => {
                    let mut inner_tys = Vec::with_capacity(fs.len());
                    for f in fs {
                        inner_tys.push(Field::try_from(f)?.data_type);
                    }
                    DataType::Tuple(inner_tys)
                }
                _ => {
                    return Err(Error::Parsing(format!(
                        "Unsupported datatype for arrow field: {:?}",
                        f
                    )))
                }
            }
        };
        // `Null` is never wrapped in `Nullable`.
        if f.is_nullable() && !matches!(dt, DataType::Null) {
            dt = DataType::Nullable(Box::new(dt));
        }
        Ok(Field {
            name: f.name().to_string(),
            data_type: dt,
        })
    }
}
400
/// Converts every field of an Arrow schema into a [`Field`], preserving
/// order and returning the first conversion error encountered.
#[cfg(feature = "flight-sql")]
impl TryFrom<ArrowSchemaRef> for Schema {
    type Error = Error;

    fn try_from(schema_ref: ArrowSchemaRef) -> Result<Self> {
        let mut converted = Vec::new();
        for arrow_field in schema_ref.fields().iter() {
            converted.push(Field::try_from(arrow_field)?);
        }
        Ok(Schema(converted))
    }
}
414
415#[derive(Debug, Clone, PartialEq, Eq)]
416struct TypeDesc<'t> {
417 name: &'t str,
418 nullable: bool,
419 args: Vec<TypeDesc<'t>>,
420}
421
422fn parse_type_desc(s: &str) -> Result<TypeDesc> {
423 let mut name = "";
424 let mut args = vec![];
425 let mut depth = 0;
426 let mut start = 0;
427 let mut nullable = false;
428 for (i, c) in s.chars().enumerate() {
429 match c {
430 '(' => {
431 if depth == 0 {
432 name = &s[start..i];
433 start = i + 1;
434 }
435 depth += 1;
436 }
437 ')' => {
438 depth -= 1;
439 if depth == 0 {
440 let s = &s[start..i];
441 if !s.is_empty() {
442 args.push(parse_type_desc(s)?);
443 }
444 start = i + 1;
445 }
446 }
447 ',' => {
448 if depth == 1 {
449 let s = &s[start..i];
450 args.push(parse_type_desc(s)?);
451 start = i + 1;
452 }
453 }
454 ' ' => {
455 if depth == 0 {
456 let s = &s[start..i];
457 if !s.is_empty() {
458 name = s;
459 }
460 start = i + 1;
461 }
462 }
463 _ => {}
464 }
465 }
466 if depth != 0 {
467 return Err(Error::Parsing(format!("Invalid type desc: {}", s)));
468 }
469 if start < s.len() {
470 let s = &s[start..];
471 if !s.is_empty() {
472 if name.is_empty() {
473 name = s;
474 } else if s == "NULL" {
475 nullable = true;
476 } else {
477 return Err(Error::Parsing(format!(
478 "Invalid type arg for {}: {}",
479 name, s
480 )));
481 }
482 }
483 }
484 Ok(TypeDesc {
485 name,
486 nullable,
487 args,
488 })
489}
490
#[cfg(test)]
mod test {
    use std::vec;

    use super::*;

    /// Builds an expected node with explicit nullability and arguments.
    fn node<'t>(name: &'t str, nullable: bool, args: Vec<TypeDesc<'t>>) -> TypeDesc<'t> {
        TypeDesc {
            name,
            nullable,
            args,
        }
    }

    /// Expected argument-less, non-nullable node.
    fn leaf(name: &str) -> TypeDesc<'_> {
        node(name, false, vec![])
    }

    /// Expected argument-less node carrying the `NULL` marker.
    fn null_leaf(name: &str) -> TypeDesc<'_> {
        node(name, true, vec![])
    }

    /// Parses each `(description, input, expected)` case and compares the
    /// result with the expected tree.
    fn assert_parses<'t>(cases: Vec<(&str, &'t str, TypeDesc<'t>)>) {
        for (desc, input, expected) in cases {
            let output = parse_type_desc(input).unwrap();
            assert_eq!(output, expected, "{}", desc);
        }
    }

    #[test]
    fn test_parse_type_desc() {
        assert_parses(vec![
            ("plain type", "String", leaf("String")),
            (
                "decimal type",
                "Decimal(42, 42)",
                node("Decimal", false, vec![leaf("42"), leaf("42")]),
            ),
            (
                "nullable type",
                "Nullable(Nothing)",
                node("Nullable", false, vec![leaf("Nothing")]),
            ),
            ("empty arg", "DateTime()", node("DateTime", false, vec![])),
            (
                "numeric arg",
                "FixedString(42)",
                node("FixedString", false, vec![leaf("42")]),
            ),
            (
                "multiple args",
                "Array(Tuple(Tuple(String, String), Tuple(String, UInt64)))",
                node(
                    "Array",
                    false,
                    vec![node(
                        "Tuple",
                        false,
                        vec![
                            node("Tuple", false, vec![leaf("String"), leaf("String")]),
                            node("Tuple", false, vec![leaf("String"), leaf("UInt64")]),
                        ],
                    )],
                ),
            ),
            (
                "map args",
                "Map(String, Array(Int64))",
                node(
                    "Map",
                    false,
                    vec![
                        leaf("String"),
                        node("Array", false, vec![leaf("Int64")]),
                    ],
                ),
            ),
            (
                "map nullable value args",
                "Nullable(Map(String, String NULL))",
                node(
                    "Nullable",
                    false,
                    vec![node(
                        "Map",
                        false,
                        vec![leaf("String"), null_leaf("String")],
                    )],
                ),
            ),
        ]);
    }

    #[test]
    fn test_parse_complex_type_with_null() {
        assert_parses(vec![(
            "complex nullable type",
            "Nullable(Tuple(String NULL, Array(Tuple(Array(Int32 NULL) NULL, Array(String NULL) NULL) NULL) NULL))",
            node(
                "Nullable",
                false,
                vec![node(
                    "Tuple",
                    false,
                    vec![
                        null_leaf("String"),
                        node(
                            "Array",
                            true,
                            vec![node(
                                "Tuple",
                                true,
                                vec![
                                    node("Array", true, vec![null_leaf("Int32")]),
                                    node("Array", true, vec![null_leaf("String")]),
                                ],
                            )],
                        ),
                    ],
                )],
            ),
        )]);
    }
}