1use super::{
19 LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
20 TypeSignature,
21};
22use crate::error::{_internal_err, Result};
23use arrow::compute::can_cast_types;
24use arrow::datatypes::{
25 DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType,
26 Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
27};
28use std::{fmt::Display, sync::Arc};
29
/// The set of "native" logical types.
///
/// Each variant may be backed by one or more Arrow physical [`DataType`]s; the
/// exact mapping is given by the `From<DataType>` implementation. For example,
/// `Utf8`, `LargeUtf8` and `Utf8View` all map to [`NativeType::String`].
/// Dictionary- and run-end-encoded data types have no variant of their own:
/// they map to the native type of their value type.
///
/// Note: `PartialOrd`/`Ord` are derived, so ordering follows the declaration
/// order of the variants below; reordering variants changes comparisons.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum NativeType {
    /// The null type.
    Null,
    /// Boolean values.
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 16-bit floating point.
    Float16,
    /// 32-bit floating point.
    Float32,
    /// 64-bit floating point.
    Float64,
    /// Timestamp with a time unit and an optional timezone name
    /// (displayed e.g. as `Timestamp(ns, "UTC")`).
    Timestamp(TimeUnit, Option<Arc<str>>),
    /// Calendar date; backed by Arrow `Date32` or `Date64`.
    Date,
    /// Time of day in the given unit; backed by Arrow `Time32` or `Time64`.
    Time(TimeUnit),
    /// Elapsed time in the given unit.
    Duration(TimeUnit),
    /// Interval of the given Arrow interval kind.
    Interval(IntervalUnit),
    /// Variable-length opaque bytes; backed by Arrow `Binary`, `LargeBinary`
    /// or `BinaryView`.
    Binary,
    /// Opaque binary values of a fixed width (the Arrow `FixedSizeBinary` size).
    FixedSizeBinary(i32),
    /// UTF-8 text; backed by Arrow `Utf8`, `LargeUtf8` or `Utf8View`.
    String,
    /// Variable-length list of the given element field; backed by Arrow
    /// `List`, `ListView`, `LargeList` or `LargeListView`.
    List(LogicalFieldRef),
    /// List of the given element field where every value has the given length.
    FixedSizeList(LogicalFieldRef, i32),
    /// Struct with the given child fields.
    Struct(LogicalFields),
    /// Union of the given type-id-tagged child fields. The Arrow union mode is
    /// not part of the native type.
    Union(LogicalUnionFields),
    /// Decimal with `(precision, scale)`; backed by any of Arrow `Decimal32`,
    /// `Decimal64`, `Decimal128` or `Decimal256`.
    Decimal(u8, i8),
    /// Map described by its single entries field. Arrow's "keys sorted" flag
    /// is not part of the native type.
    Map(LogicalFieldRef),
}
186
187fn format_logical_field(
190 f: &mut std::fmt::Formatter<'_>,
191 field: &LogicalField,
192) -> std::fmt::Result {
193 let non_null = if field.nullable { "" } else { "non-null " };
194 write!(f, "{:?}: {non_null}{}", field.name, field.logical_type)
195}
196
197impl Display for NativeType {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 match self {
201 Self::Null => write!(f, "Null"),
202 Self::Boolean => write!(f, "Boolean"),
203 Self::Int8 => write!(f, "Int8"),
204 Self::Int16 => write!(f, "Int16"),
205 Self::Int32 => write!(f, "Int32"),
206 Self::Int64 => write!(f, "Int64"),
207 Self::UInt8 => write!(f, "UInt8"),
208 Self::UInt16 => write!(f, "UInt16"),
209 Self::UInt32 => write!(f, "UInt32"),
210 Self::UInt64 => write!(f, "UInt64"),
211 Self::Float16 => write!(f, "Float16"),
212 Self::Float32 => write!(f, "Float32"),
213 Self::Float64 => write!(f, "Float64"),
214 Self::Timestamp(unit, Some(tz)) => write!(f, "Timestamp({unit}, {tz:?})"),
215 Self::Timestamp(unit, None) => write!(f, "Timestamp({unit})"),
216 Self::Date => write!(f, "Date"),
217 Self::Time(unit) => write!(f, "Time({unit})"),
218 Self::Duration(unit) => write!(f, "Duration({unit})"),
219 Self::Interval(unit) => write!(f, "Interval({unit:?})"),
220 Self::Binary => write!(f, "Binary"),
221 Self::FixedSizeBinary(size) => write!(f, "FixedSizeBinary({size})"),
222 Self::String => write!(f, "String"),
223 Self::List(field) => {
224 let non_null = if field.nullable { "" } else { "non-null " };
225 write!(f, "List({non_null}{})", field.logical_type)
226 }
227 Self::FixedSizeList(field, size) => {
228 let non_null = if field.nullable { "" } else { "non-null " };
229 write!(
230 f,
231 "FixedSizeList({size} x {non_null}{})",
232 field.logical_type
233 )
234 }
235 Self::Struct(fields) => {
236 write!(f, "Struct(")?;
237 for (i, field) in fields.iter().enumerate() {
238 if i > 0 {
239 write!(f, ", ")?;
240 }
241 format_logical_field(f, field)?;
242 }
243 write!(f, ")")
244 }
245 Self::Union(fields) => {
246 write!(f, "Union(")?;
247 for (i, (type_id, field)) in fields.iter().enumerate() {
248 if i > 0 {
249 write!(f, ", ")?;
250 }
251 write!(f, "{type_id}: (")?;
252 format_logical_field(f, field)?;
253 write!(f, ")")?;
254 }
255 write!(f, ")")
256 }
257 Self::Decimal(precision, scale) => write!(f, "Decimal({precision}, {scale})"),
258 Self::Map(field) => {
259 let non_null = if field.nullable { "" } else { "non-null " };
260 write!(f, "Map({non_null}{})", field.logical_type)
261 }
262 }
263 }
264}
265
impl LogicalType for NativeType {
    /// A native type is its own native representation.
    fn native(&self) -> &NativeType {
        self
    }

    /// The signature of a native type simply wraps the type itself.
    fn signature(&self) -> TypeSignature<'_> {
        TypeSignature::Native(self)
    }

    /// Chooses the Arrow [`DataType`] that a value whose physical type is
    /// `origin` should be cast to in order to represent this native type.
    ///
    /// Where a native type has several physical representations, the choice
    /// tries to stay close to `origin`. For example, a `Date64` origin stays
    /// `Date64` for [`NativeType::Date`], and a `LargeBinary` origin becomes
    /// `LargeUtf8` for [`NativeType::String`]. Nested types recurse into their
    /// child fields, taking child names and nullability from `self`.
    ///
    /// # Errors
    ///
    /// Returns an internal error when no default cast exists for the
    /// `(self, origin)` pair. Examples: a struct from a non-struct, non-null
    /// origin; a struct or union from one with a different number of fields;
    /// a map from a non-map, non-null origin; or a string/binary from an
    /// origin that cannot be cast to any string/binary flavour.
    fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
        // Bring Arrow's `DataType` variants into scope. Below, bare names such
        // as `Null`, `List(..)` or `Struct(..)` are `DataType` variants, while
        // `Self::…` are `NativeType` variants.
        use DataType::*;

        // Builds the physical field for one child of a nested cast. The name
        // and nullability come from the target logical field `to`; the data
        // type is derived from the matching origin field `from`.
        fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> {
            Ok(Arc::new(Field::new(
                to.name.clone(),
                to.logical_type.default_cast_for(from.data_type())?,
                to.nullable,
            )))
        }

        // NOTE: arm order matters; several arms are guarded, and earlier
        // arms take precedence over the fallbacks that follow them.
        Ok(match (self, origin) {
            // Types with a single physical representation ignore `origin`.
            (Self::Null, _) => Null,
            (Self::Boolean, _) => Boolean,
            (Self::Int8, _) => Int8,
            (Self::Int16, _) => Int16,
            (Self::Int32, _) => Int32,
            (Self::Int64, _) => Int64,
            (Self::UInt8, _) => UInt8,
            (Self::UInt16, _) => UInt16,
            (Self::UInt32, _) => UInt32,
            (Self::UInt64, _) => UInt64,
            (Self::Float16, _) => Float16,
            (Self::Float32, _) => Float32,
            (Self::Float64, _) => Float64,
            // Decimals use the narrowest physical width whose maximum
            // precision can hold `p`.
            // NOTE(review): the origin's decimal width is not consulted, so a
            // `Decimal128(5, 2)` origin maps to `Decimal32(5, 2)`. Confirm
            // this is intended.
            (Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => {
                Decimal32(*p, *s)
            }
            (Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => {
                Decimal64(*p, *s)
            }
            (Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => {
                Decimal128(*p, *s)
            }
            (Self::Decimal(p, s), _) => Decimal256(*p, *s),
            (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
            // Keep an origin that is already a date; otherwise use Date32.
            (Self::Date, Date32 | Date64) => origin.to_owned(),
            (Self::Date, _) => Date32,
            // Seconds/milliseconds live in Time32; micro/nanoseconds in Time64.
            (Self::Time(tu), _) => match tu {
                TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu),
                TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*tu),
            },
            (Self::Duration(tu), _) => Duration(*tu),
            (Self::Interval(iu), _) => Interval(*iu),
            // Binary mirrors the origin's layout: large strings become large
            // binary, string views become binary views, and binary origins are
            // kept. Otherwise the first of BinaryView / LargeBinary / Binary
            // that `origin` can be cast to is chosen.
            (Self::Binary, LargeUtf8) => LargeBinary,
            (Self::Binary, Utf8View) => BinaryView,
            (Self::Binary, Binary | LargeBinary | BinaryView) => origin.to_owned(),
            (Self::Binary, data_type) if can_cast_types(data_type, &BinaryView) => {
                BinaryView
            }
            (Self::Binary, data_type) if can_cast_types(data_type, &LargeBinary) => {
                LargeBinary
            }
            (Self::Binary, data_type) if can_cast_types(data_type, &Binary) => Binary,
            (Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size),
            // String mirrors Binary: same layout preservation, then the first
            // castable of Utf8View / LargeUtf8 / Utf8.
            (Self::String, LargeBinary) => LargeUtf8,
            (Self::String, BinaryView) => Utf8View,
            (Self::String, Utf8 | LargeUtf8 | Utf8View) => origin.to_owned(),
            (Self::String, data_type) if can_cast_types(data_type, &Utf8View) => Utf8View,
            (Self::String, data_type) if can_cast_types(data_type, &LargeUtf8) => {
                LargeUtf8
            }
            (Self::String, data_type) if can_cast_types(data_type, &Utf8) => Utf8,
            // List-like origins keep their list flavour (a FixedSizeList
            // origin becomes a List), and the element type is derived from the
            // origin's element field.
            (Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => {
                List(default_field_cast(to_field, from_field)?)
            }
            (Self::List(to_field), LargeList(from_field)) => {
                LargeList(default_field_cast(to_field, from_field)?)
            }
            (Self::List(to_field), ListView(from_field)) => {
                ListView(default_field_cast(to_field, from_field)?)
            }
            (Self::List(to_field), LargeListView(from_field)) => {
                LargeListView(default_field_cast(to_field, from_field)?)
            }
            // Any other origin is treated as the element type of a new List.
            (Self::List(field), _) => List(Arc::new(Field::new(
                field.name.clone(),
                field.logical_type.default_cast_for(origin)?,
                field.nullable,
            ))),
            (
                Self::FixedSizeList(to_field, to_size),
                FixedSizeList(from_field, from_size),
            ) if from_size == to_size => {
                FixedSizeList(default_field_cast(to_field, from_field)?, *to_size)
            }
            // Variable-length list origins adopt the target's fixed size.
            (
                Self::FixedSizeList(to_field, size),
                List(from_field)
                | LargeList(from_field)
                | ListView(from_field)
                | LargeListView(from_field),
            ) => FixedSizeList(default_field_cast(to_field, from_field)?, *size),
            // Any other origin (including a differently sized FixedSizeList)
            // is treated as the element type.
            (Self::FixedSizeList(field, size), _) => FixedSizeList(
                Arc::new(Field::new(
                    field.name.clone(),
                    field.logical_type.default_cast_for(origin)?,
                    field.nullable,
                )),
                *size,
            ),
            // Struct children are paired by position, not by name; the names
            // come from the target type.
            (Self::Struct(to_fields), Struct(from_fields))
                if from_fields.len() == to_fields.len() =>
            {
                Struct(
                    from_fields
                        .iter()
                        .zip(to_fields.iter())
                        .map(|(from, to)| default_field_cast(to, from))
                        .collect::<Result<Fields>>()?,
                )
            }
            // A null origin yields a struct whose children are each cast
            // from Null.
            (Self::Struct(to_fields), Null) => Struct(
                to_fields
                    .iter()
                    .map(|field| {
                        Ok(Arc::new(Field::new(
                            field.name.clone(),
                            field.logical_type.default_cast_for(&Null)?,
                            field.nullable,
                        )))
                    })
                    .collect::<Result<Fields>>()?,
            ),
            // The origin's "keys sorted" flag is preserved.
            (Self::Map(to_field), Map(from_field, sorted)) => {
                Map(default_field_cast(to_field, from_field)?, *sorted)
            }
            // From Null, the entries are cast from Null and the map is
            // marked as not sorted.
            (Self::Map(field), Null) => Map(
                Arc::new(Field::new(
                    field.name.clone(),
                    field.logical_type.default_cast_for(&Null)?,
                    field.nullable,
                )),
                false,
            ),
            // Union children are paired by position. Type ids come from the
            // target; the union mode comes from the origin.
            (Self::Union(to_fields), Union(from_fields, mode))
                if from_fields.len() == to_fields.len() =>
            {
                Union(
                    from_fields
                        .iter()
                        .zip(to_fields.iter())
                        .map(|((_, from), (i, to))| {
                            Ok((*i, default_field_cast(to, from)?))
                        })
                        .collect::<Result<UnionFields>>()?,
                    *mode,
                )
            }
            _ => {
                return _internal_err!(
                    "Unavailable default cast for native type {} from physical type {}",
                    self,
                    origin
                );
            }
        })
    }
}
445
446impl From<&DataType> for NativeType {
451 fn from(value: &DataType) -> Self {
452 value.clone().into()
453 }
454}
455
456impl From<DataType> for NativeType {
457 fn from(value: DataType) -> Self {
458 use NativeType::*;
459 match value {
460 DataType::Null => Null,
461 DataType::Boolean => Boolean,
462 DataType::Int8 => Int8,
463 DataType::Int16 => Int16,
464 DataType::Int32 => Int32,
465 DataType::Int64 => Int64,
466 DataType::UInt8 => UInt8,
467 DataType::UInt16 => UInt16,
468 DataType::UInt32 => UInt32,
469 DataType::UInt64 => UInt64,
470 DataType::Float16 => Float16,
471 DataType::Float32 => Float32,
472 DataType::Float64 => Float64,
473 DataType::Timestamp(tu, tz) => Timestamp(tu, tz),
474 DataType::Date32 | DataType::Date64 => Date,
475 DataType::Time32(tu) | DataType::Time64(tu) => Time(tu),
476 DataType::Duration(tu) => Duration(tu),
477 DataType::Interval(iu) => Interval(iu),
478 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary,
479 DataType::FixedSizeBinary(size) => FixedSizeBinary(size),
480 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => String,
481 DataType::List(field)
482 | DataType::ListView(field)
483 | DataType::LargeList(field)
484 | DataType::LargeListView(field) => List(Arc::new(field.as_ref().into())),
485 DataType::FixedSizeList(field, size) => {
486 FixedSizeList(Arc::new(field.as_ref().into()), size)
487 }
488 DataType::Struct(fields) => Struct(LogicalFields::from(&fields)),
489 DataType::Union(union_fields, _) => {
490 Union(LogicalUnionFields::from(&union_fields))
491 }
492 DataType::Decimal32(p, s)
493 | DataType::Decimal64(p, s)
494 | DataType::Decimal128(p, s)
495 | DataType::Decimal256(p, s) => Decimal(p, s),
496 DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
497 DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
498 DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
499 }
500 }
501}
502
503impl NativeType {
504 #[inline]
505 pub fn is_numeric(&self) -> bool {
506 self.is_integer() || self.is_float() || self.is_decimal()
507 }
508
509 #[inline]
510 pub fn is_integer(&self) -> bool {
511 use NativeType::*;
512 matches!(
513 self,
514 UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
515 )
516 }
517
518 #[inline]
519 pub fn is_timestamp(&self) -> bool {
520 matches!(self, NativeType::Timestamp(_, _))
521 }
522
523 #[inline]
524 pub fn is_date(&self) -> bool {
525 *self == NativeType::Date
526 }
527
528 #[inline]
529 pub fn is_time(&self) -> bool {
530 matches!(self, NativeType::Time(_))
531 }
532
533 #[inline]
534 pub fn is_interval(&self) -> bool {
535 matches!(self, NativeType::Interval(_))
536 }
537
538 #[inline]
539 pub fn is_duration(&self) -> bool {
540 matches!(self, NativeType::Duration(_))
541 }
542
543 #[inline]
544 pub fn is_binary(&self) -> bool {
545 matches!(self, NativeType::Binary | NativeType::FixedSizeBinary(_))
546 }
547
548 #[inline]
549 pub fn is_null(&self) -> bool {
550 *self == NativeType::Null
551 }
552
553 #[inline]
554 pub fn is_decimal(&self) -> bool {
555 matches!(self, Self::Decimal(_, _))
556 }
557
558 #[inline]
559 pub fn is_float(&self) -> bool {
560 matches!(self, Self::Float16 | Self::Float32 | Self::Float64)
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::LogicalField;
    use insta::assert_snapshot;

    #[test]
    fn test_native_type_display() {
        assert_snapshot!(NativeType::Null, @"Null");
        assert_snapshot!(NativeType::Boolean, @"Boolean");
        assert_snapshot!(NativeType::Int8, @"Int8");
        assert_snapshot!(NativeType::Int16, @"Int16");
        assert_snapshot!(NativeType::Int32, @"Int32");
        assert_snapshot!(NativeType::Int64, @"Int64");
        assert_snapshot!(NativeType::UInt8, @"UInt8");
        assert_snapshot!(NativeType::UInt16, @"UInt16");
        assert_snapshot!(NativeType::UInt32, @"UInt32");
        assert_snapshot!(NativeType::UInt64, @"UInt64");
        assert_snapshot!(NativeType::Float16, @"Float16");
        assert_snapshot!(NativeType::Float32, @"Float32");
        assert_snapshot!(NativeType::Float64, @"Float64");
        assert_snapshot!(NativeType::Date, @"Date");
        assert_snapshot!(NativeType::Binary, @"Binary");
        assert_snapshot!(NativeType::String, @"String");
        assert_snapshot!(NativeType::FixedSizeBinary(16), @"FixedSizeBinary(16)");
        assert_snapshot!(NativeType::Decimal(10, 2), @"Decimal(10, 2)");
    }

    #[test]
    fn test_native_type_display_timestamp() {
        assert_snapshot!(
            NativeType::Timestamp(TimeUnit::Second, None),
            @"Timestamp(s)"
        );
        assert_snapshot!(
            NativeType::Timestamp(TimeUnit::Millisecond, None),
            @"Timestamp(ms)"
        );
        assert_snapshot!(
            NativeType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))),
            @r#"Timestamp(ns, "UTC")"#
        );
    }

    #[test]
    fn test_native_type_display_time_duration_interval() {
        assert_snapshot!(NativeType::Time(TimeUnit::Microsecond), @"Time(µs)");
        assert_snapshot!(NativeType::Duration(TimeUnit::Nanosecond), @"Duration(ns)");
        assert_snapshot!(NativeType::Interval(IntervalUnit::YearMonth), @"Interval(YearMonth)");
        assert_snapshot!(NativeType::Interval(IntervalUnit::MonthDayNano), @"Interval(MonthDayNano)");
    }

    #[test]
    fn test_native_type_display_nested() {
        let list = NativeType::List(Arc::new(LogicalField::from(&Field::new(
            "item",
            DataType::Int32,
            true,
        ))));
        assert_snapshot!(list, @"List(Int32)");

        let fixed_list = NativeType::FixedSizeList(
            Arc::new(LogicalField::from(&Field::new(
                "item",
                DataType::Float64,
                false,
            ))),
            3,
        );
        assert_snapshot!(fixed_list, @"FixedSizeList(3 x non-null Float64)");

        let struct_type = NativeType::Struct(LogicalFields::from(&Fields::from(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ])));
        assert_snapshot!(struct_type, @r#"Struct("name": non-null String, "age": Int32)"#);

        let map = NativeType::Map(Arc::new(LogicalField::from(&Field::new(
            "entries",
            DataType::Utf8,
            false,
        ))));
        assert_snapshot!(map, @"Map(non-null String)");
    }

    /// Several physical types collapse onto one native type, and encoded
    /// types map to the native type of their values.
    #[test]
    fn test_native_type_from_data_type() {
        assert_eq!(NativeType::from(&DataType::LargeUtf8), NativeType::String);
        assert_eq!(NativeType::from(&DataType::BinaryView), NativeType::Binary);
        assert_eq!(NativeType::from(&DataType::Date64), NativeType::Date);
        assert_eq!(
            NativeType::from(DataType::Time64(TimeUnit::Nanosecond)),
            NativeType::Time(TimeUnit::Nanosecond)
        );
        assert_eq!(
            NativeType::from(DataType::Decimal256(40, 3)),
            NativeType::Decimal(40, 3)
        );
        assert_eq!(
            NativeType::from(DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Utf8),
            )),
            NativeType::String
        );
    }

    /// Scalar default casts: temporal widths, string/binary flavour
    /// preservation and decimal width selection by precision.
    #[test]
    fn test_default_cast_for_scalars() {
        assert_eq!(
            NativeType::Date.default_cast_for(&DataType::Date64).unwrap(),
            DataType::Date64
        );
        assert_eq!(
            NativeType::Date.default_cast_for(&DataType::Utf8).unwrap(),
            DataType::Date32
        );
        assert_eq!(
            NativeType::Time(TimeUnit::Millisecond)
                .default_cast_for(&DataType::Null)
                .unwrap(),
            DataType::Time32(TimeUnit::Millisecond)
        );
        assert_eq!(
            NativeType::Time(TimeUnit::Nanosecond)
                .default_cast_for(&DataType::Null)
                .unwrap(),
            DataType::Time64(TimeUnit::Nanosecond)
        );

        assert_eq!(
            NativeType::String
                .default_cast_for(&DataType::LargeBinary)
                .unwrap(),
            DataType::LargeUtf8
        );
        assert_eq!(
            NativeType::String.default_cast_for(&DataType::Utf8View).unwrap(),
            DataType::Utf8View
        );
        assert_eq!(
            NativeType::Binary.default_cast_for(&DataType::Utf8View).unwrap(),
            DataType::BinaryView
        );
        assert_eq!(
            NativeType::Binary
                .default_cast_for(&DataType::LargeBinary)
                .unwrap(),
            DataType::LargeBinary
        );

        assert_eq!(
            NativeType::Decimal(DECIMAL32_MAX_PRECISION, 2)
                .default_cast_for(&DataType::Null)
                .unwrap(),
            DataType::Decimal32(DECIMAL32_MAX_PRECISION, 2)
        );
        assert_eq!(
            NativeType::Decimal(DECIMAL64_MAX_PRECISION, 2)
                .default_cast_for(&DataType::Null)
                .unwrap(),
            DataType::Decimal64(DECIMAL64_MAX_PRECISION, 2)
        );
        assert_eq!(
            NativeType::Decimal(DECIMAL128_MAX_PRECISION, 2)
                .default_cast_for(&DataType::Null)
                .unwrap(),
            DataType::Decimal128(DECIMAL128_MAX_PRECISION, 2)
        );
        assert_eq!(
            NativeType::Decimal(DECIMAL128_MAX_PRECISION + 1, 2)
                .default_cast_for(&DataType::Null)
                .unwrap(),
            DataType::Decimal256(DECIMAL128_MAX_PRECISION + 1, 2)
        );
    }

    /// Nested default casts keep the origin's container flavour, re-derive
    /// child types, and fail when no cast is available.
    #[test]
    fn test_default_cast_for_nested() {
        let list = NativeType::List(Arc::new(LogicalField::from(&Field::new(
            "item",
            DataType::Int32,
            true,
        ))));
        assert_eq!(
            list.default_cast_for(&DataType::LargeList(Arc::new(Field::new(
                "item",
                DataType::Int64,
                true,
            ))))
            .unwrap(),
            DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true)))
        );
        // A non-list origin becomes the element type of a new List.
        assert_eq!(
            list.default_cast_for(&DataType::Int64).unwrap(),
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)))
        );

        let struct_type = NativeType::Struct(LogicalFields::from(&Fields::from(vec![
            Field::new("a", DataType::Int32, true),
        ])));
        assert_eq!(
            struct_type.default_cast_for(&DataType::Null).unwrap(),
            DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, true)]))
        );
        assert!(struct_type.default_cast_for(&DataType::Int32).is_err());
    }

    #[test]
    fn test_native_type_predicates() {
        assert!(NativeType::UInt16.is_integer());
        assert!(NativeType::UInt16.is_numeric());
        assert!(NativeType::Float16.is_float());
        assert!(NativeType::Decimal(10, 2).is_numeric());
        assert!(!NativeType::String.is_numeric());
        assert!(NativeType::FixedSizeBinary(4).is_binary());
        assert!(NativeType::Null.is_null());
        assert!(NativeType::Date.is_date());
        assert!(NativeType::Timestamp(TimeUnit::Second, None).is_timestamp());
        assert!(!NativeType::Date.is_timestamp());
    }
}