mod field;
mod physical_type;
pub mod reshape;
mod schema;
use std::collections::BTreeMap;
use std::sync::Arc;
pub use field::{
DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY,
DTYPE_ENUM_VALUES_NEW, Field, MAINTAIN_PL_TYPE, PARQUET_EMPTY_STRUCT, PL_KEY,
};
pub use physical_type::*;
use polars_utils::pl_str::PlSmallStr;
pub use schema::{ArrowSchema, ArrowSchemaRef};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::array::LIST_VALUES_NAME;
/// Key-value metadata: an ordered map (`BTreeMap`) from string keys to string values.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// Optional extension information as a `(name, optional metadata)` pair.
///
/// This is the return type of [`get_extension`], which yields `None` when no
/// extension name is present in the metadata it inspects.
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
/// The logical data types a value can be described with.
///
/// Every variant maps to an in-memory representation through
/// [`ArrowDataType::to_physical_type`]; several logical types share one
/// physical type (for example `Int32`, `Date32` and `Time32` are all backed by
/// `PrimitiveType::Int32`). The `Default` value is [`ArrowDataType::Null`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ArrowDataType {
/// Null type; this is the `Default` variant.
#[default]
Null,
/// Boolean values.
Boolean,
/// Signed 8-bit integer.
Int8,
/// Signed 16-bit integer.
Int16,
/// Signed 32-bit integer.
Int32,
/// Signed 64-bit integer.
Int64,
/// Signed 128-bit integer.
Int128,
/// Unsigned 8-bit integer.
UInt8,
/// Unsigned 16-bit integer.
UInt16,
/// Unsigned 32-bit integer.
UInt32,
/// Unsigned 64-bit integer.
UInt64,
/// Unsigned 128-bit integer.
UInt128,
/// 16-bit floating point number.
Float16,
/// 32-bit floating point number.
Float32,
/// 64-bit floating point number.
Float64,
/// Timestamp in the given [`TimeUnit`], physically stored as `Int64`.
///
/// NOTE(review): the optional string is presumably the timezone — confirm
/// against the Arrow format specification.
Timestamp(TimeUnit, Option<PlSmallStr>),
/// Date physically stored as `Int32`.
Date32,
/// Date physically stored as `Int64`.
Date64,
/// Time value in the given [`TimeUnit`], physically stored as `Int32`.
Time32(TimeUnit),
/// Time value in the given [`TimeUnit`], physically stored as `Int64`.
Time64(TimeUnit),
/// Duration in the given [`TimeUnit`], physically stored as `Int64`.
Duration(TimeUnit),
/// Interval whose physical representation depends on the [`IntervalUnit`].
Interval(IntervalUnit),
/// Binary values (physical type `Binary`).
Binary,
/// Fixed-size binary values.
///
/// NOTE(review): the `usize` is presumably the size of each value in bytes — confirm.
FixedSizeBinary(usize),
/// Binary values with the `LargeBinary` physical type.
LargeBinary,
/// String values (physical type `Utf8`).
Utf8,
/// String values with the `LargeUtf8` physical type.
LargeUtf8,
/// List whose values are described by the boxed [`Field`].
List(Box<Field>),
/// Fixed-size list: the boxed [`Field`] describes the values and the
/// `usize` is the list size (width).
FixedSizeList(Box<Field>, usize),
/// List with the `LargeList` physical type; values are described by the boxed [`Field`].
LargeList(Box<Field>),
/// Struct made of the given child fields.
Struct(Vec<Field>),
/// Map whose entries are described by the boxed [`Field`].
///
/// NOTE(review): the `bool` presumably states whether keys are sorted — confirm.
Map(Box<Field>, bool),
/// Dictionary-encoded type: integer key type, value data type and an
/// `is_sorted` flag.
Dictionary(IntegerType, Box<ArrowDataType>, bool),
/// Decimal physically stored as `Int128`.
///
/// NOTE(review): the two `usize` values are presumably `(precision, scale)` — confirm.
Decimal(usize, usize),
/// Decimal physically stored as `Int32` (parameters as for `Decimal`).
Decimal32(usize, usize),
/// Decimal physically stored as `Int64` (parameters as for `Decimal`).
Decimal64(usize, usize),
/// Decimal physically stored as `Int256` (parameters as for `Decimal`).
Decimal256(usize, usize),
/// Extension type wrapping another data type; see [`ExtensionType`].
Extension(Box<ExtensionType>),
/// Binary values with the `BinaryView` physical type.
BinaryView,
/// String values with the `Utf8View` physical type.
Utf8View,
/// Unknown type; `to_physical_type` is unimplemented (panics) for it.
Unknown,
/// Union type; see [`UnionType`]. Skipped by serde when the `serde` or
/// `dsl-schema` feature is enabled.
#[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
Union(Box<UnionType>),
}
/// Description of an extension type: a named wrapper around another data type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct ExtensionType {
/// Name of the extension type.
pub name: PlSmallStr,
/// Data type wrapped by the extension; `ArrowDataType::to_storage`
/// resolves an extension type to this inner type.
pub inner: ArrowDataType,
/// Optional metadata string attached to the extension.
pub metadata: Option<PlSmallStr>,
}
/// Description of a union type.
///
/// Unlike the other type descriptions in this module, it does not derive the
/// serde traits.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
/// Child fields of the union.
pub fields: Vec<Field>,
/// Optional list of `i32` ids.
///
/// NOTE(review): presumably the type ids associated with `fields` — confirm.
pub ids: Option<Vec<i32>>,
/// Whether the union is dense or sparse.
pub mode: UnionMode,
}
/// Mode of a union type: either dense or sparse.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense union.
    Dense,
    /// Sparse union.
    Sparse,
}

impl UnionMode {
    /// Builds a [`UnionMode`] from a flag: `true` yields
    /// [`UnionMode::Sparse`], `false` yields [`UnionMode::Dense`].
    pub fn sparse(is_sparse: bool) -> Self {
        match is_sparse {
            true => Self::Sparse,
            false => Self::Dense,
        }
    }

    /// Returns `true` when the mode is [`UnionMode::Sparse`].
    pub fn is_sparse(&self) -> bool {
        *self == Self::Sparse
    }

    /// Returns `true` when the mode is [`UnionMode::Dense`].
    pub fn is_dense(&self) -> bool {
        *self == Self::Dense
    }
}
/// Time resolution carried by the `Timestamp`, `Time32`, `Time64` and
/// `Duration` variants of [`ArrowDataType`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum TimeUnit {
/// Seconds.
Second,
/// Milliseconds.
Millisecond,
/// Microseconds.
Microsecond,
/// Nanoseconds.
Nanosecond,
}
/// Unit of the `Interval` variant of [`ArrowDataType`]; it selects the
/// physical representation used by `ArrowDataType::to_physical_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum IntervalUnit {
/// Year/month interval, physically stored as `PrimitiveType::Int32`.
YearMonth,
/// Day/time interval, physically stored as `PrimitiveType::DaysMs`.
DayTime,
/// Month/day/nanosecond interval, physically stored as `PrimitiveType::MonthDayNano`.
MonthDayNano,
/// Month/day/millisecond interval, physically stored as `PrimitiveType::MonthDayMillis`.
MonthDayMillis,
}
impl ArrowDataType {
/// Data type used for index values: `UInt64` when the `bigidx` feature is
/// enabled, `UInt32` otherwise.
pub const IDX_DTYPE: Self = if cfg!(feature = "bigidx") {
    ArrowDataType::UInt64
} else {
    ArrowDataType::UInt32
};
pub fn to_physical_type(&self) -> PhysicalType {
use ArrowDataType::*;
match self {
Null => PhysicalType::Null,
Boolean => PhysicalType::Boolean,
Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
PhysicalType::Primitive(PrimitiveType::Int32)
},
Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
PhysicalType::Primitive(PrimitiveType::Int64)
},
Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
Decimal32(_, _) => PhysicalType::Primitive(PrimitiveType::Int32),
Decimal64(_, _) => PhysicalType::Primitive(PrimitiveType::Int64),
Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
UInt128 => PhysicalType::Primitive(PrimitiveType::UInt128),
Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
Interval(IntervalUnit::MonthDayNano) => {
PhysicalType::Primitive(PrimitiveType::MonthDayNano)
},
Interval(IntervalUnit::MonthDayMillis) => {
PhysicalType::Primitive(PrimitiveType::MonthDayMillis)
},
Binary => PhysicalType::Binary,
FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
LargeBinary => PhysicalType::LargeBinary,
Utf8 => PhysicalType::Utf8,
LargeUtf8 => PhysicalType::LargeUtf8,
BinaryView => PhysicalType::BinaryView,
Utf8View => PhysicalType::Utf8View,
List(_) => PhysicalType::List,
FixedSizeList(_, _) => PhysicalType::FixedSizeList,
LargeList(_) => PhysicalType::LargeList,
Struct(_) => PhysicalType::Struct,
Union(_) => PhysicalType::Union,
Map(_, _) => PhysicalType::Map,
Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
Extension(ext) => ext.inner.to_physical_type(),
Unknown => unimplemented!(),
}
}
pub fn underlying_physical_type(&self) -> ArrowDataType {
use ArrowDataType::*;
match self {
Null | Boolean | Int8 | Int16 | Int32 | Int64 | Int128 | UInt8 | UInt16 | UInt32
| UInt64 | UInt128 | Float16 | Float32 | Float64 | Binary | LargeBinary | Utf8
| LargeUtf8 | BinaryView | Utf8View | FixedSizeBinary(_) | Unknown => self.clone(),
Decimal32(_, _) | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
Decimal64(_, _)
| Date64
| Timestamp(_, _)
| Time64(_)
| Duration(_)
| Interval(IntervalUnit::DayTime) => Int64,
Interval(IntervalUnit::MonthDayNano | IntervalUnit::MonthDayMillis) => unimplemented!(),
Decimal(_, _) => Int128,
Decimal256(_, _) => unimplemented!(),
List(field) => List(Box::new(
field.with_dtype(field.dtype.underlying_physical_type()),
)),
LargeList(field) => LargeList(Box::new(
field.with_dtype(field.dtype.underlying_physical_type()),
)),
FixedSizeList(field, width) => FixedSizeList(
Box::new(field.with_dtype(field.dtype.underlying_physical_type())),
*width,
),
Struct(fields) => Struct(
fields
.iter()
.map(|field| field.with_dtype(field.dtype.underlying_physical_type()))
.collect(),
),
Dictionary(keys, _, _) => (*keys).into(),
Union(_) => unimplemented!(),
Map(_, _) => unimplemented!(),
Extension(ext) => ext.inner.underlying_physical_type(),
}
}
pub fn to_storage(&self) -> &ArrowDataType {
use ArrowDataType::*;
match self {
Extension(ext) => ext.inner.to_storage(),
_ => self,
}
}
pub fn to_storage_recursive(&self) -> ArrowDataType {
use ArrowDataType::*;
match self {
Extension(ext) => ext.inner.to_storage_recursive(),
List(field) => List(Box::new(Field {
dtype: field.dtype.to_storage_recursive(),
..*field.clone()
})),
LargeList(field) => LargeList(Box::new(Field {
dtype: field.dtype.to_storage_recursive(),
..*field.clone()
})),
FixedSizeList(field, width) => FixedSizeList(
Box::new(Field {
dtype: field.dtype.to_storage_recursive(),
..*field.clone()
}),
*width,
),
Struct(fields) => Struct(
fields
.iter()
.map(|field| Field {
dtype: field.dtype.to_storage_recursive(),
..field.clone()
})
.collect(),
),
Dictionary(keys, values, is_sorted) => {
Dictionary(*keys, Box::new(values.to_storage_recursive()), *is_sorted)
},
Union(_) => unimplemented!(),
Map(_, _) => unimplemented!(),
_ => self.clone(),
}
}
/// Returns the data type of the values of a `List`, `LargeList` or
/// `FixedSizeList`, and `None` for every other type.
pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
    let child = match self {
        ArrowDataType::List(child)
        | ArrowDataType::LargeList(child)
        | ArrowDataType::FixedSizeList(child, _) => child,
        _ => return None,
    };
    Some(child.dtype())
}
pub fn is_nested(&self) -> bool {
use ArrowDataType as D;
matches!(
self,
D::List(_)
| D::LargeList(_)
| D::FixedSizeList(_, _)
| D::Struct(_)
| D::Union(_)
| D::Map(_, _)
| D::Dictionary(_, _, _)
| D::Extension(_)
)
}
pub fn is_view(&self) -> bool {
matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
}
pub fn is_numeric(&self) -> bool {
use ArrowDataType as D;
matches!(
self,
D::Int8
| D::Int16
| D::Int32
| D::Int64
| D::Int128
| D::UInt8
| D::UInt16
| D::UInt32
| D::UInt64
| D::UInt128
| D::Float16
| D::Float32
| D::Float64
| D::Decimal(_, _)
| D::Decimal32(_, _)
| D::Decimal64(_, _)
| D::Decimal256(_, _)
)
}
/// Consumes this type and wraps it as the value type of a `LargeList`.
///
/// The child field is named `LIST_VALUES_NAME` and its nullability is set
/// from `is_nullable`.
pub fn to_large_list(self, is_nullable: bool) -> ArrowDataType {
    let values = Field::new(LIST_VALUES_NAME, self, is_nullable);
    ArrowDataType::LargeList(Box::new(values))
}
/// Consumes this type and wraps it as the value type of a `FixedSizeList`
/// of the given `size`.
///
/// The child field is named `LIST_VALUES_NAME` and its nullability is set
/// from `is_nullable`.
pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
    let values = Field::new(LIST_VALUES_NAME, self, is_nullable);
    ArrowDataType::FixedSizeList(Box::new(values), size)
}
pub fn contains_dictionary(&self) -> bool {
use ArrowDataType as D;
match self {
D::Null
| D::Boolean
| D::Int8
| D::Int16
| D::Int32
| D::Int64
| D::Int128
| D::UInt8
| D::UInt16
| D::UInt32
| D::UInt64
| D::UInt128
| D::Float16
| D::Float32
| D::Float64
| D::Timestamp(_, _)
| D::Date32
| D::Date64
| D::Time32(_)
| D::Time64(_)
| D::Duration(_)
| D::Interval(_)
| D::Binary
| D::FixedSizeBinary(_)
| D::LargeBinary
| D::Utf8
| D::LargeUtf8
| D::Decimal(_, _)
| D::Decimal32(_, _)
| D::Decimal64(_, _)
| D::Decimal256(_, _)
| D::BinaryView
| D::Utf8View
| D::Unknown => false,
D::List(field)
| D::FixedSizeList(field, _)
| D::Map(field, _)
| D::LargeList(field) => field.dtype().contains_dictionary(),
D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
D::Dictionary(_, _, _) => true,
D::Extension(ext) => ext.inner.contains_dictionary(),
}
}
}
/// Converts an [`IntegerType`] into the [`ArrowDataType`] variant of the same name.
impl From<IntegerType> for ArrowDataType {
    fn from(int_type: IntegerType) -> Self {
        match int_type {
            // Signed integers.
            IntegerType::Int8 => Self::Int8,
            IntegerType::Int16 => Self::Int16,
            IntegerType::Int32 => Self::Int32,
            IntegerType::Int64 => Self::Int64,
            IntegerType::Int128 => Self::Int128,
            // Unsigned integers.
            IntegerType::UInt8 => Self::UInt8,
            IntegerType::UInt16 => Self::UInt16,
            IntegerType::UInt32 => Self::UInt32,
            IntegerType::UInt64 => Self::UInt64,
            IntegerType::UInt128 => Self::UInt128,
        }
    }
}
impl From<PrimitiveType> for ArrowDataType {
fn from(item: PrimitiveType) -> Self {
match item {
PrimitiveType::Int8 => ArrowDataType::Int8,
PrimitiveType::Int16 => ArrowDataType::Int16,
PrimitiveType::Int32 => ArrowDataType::Int32,
PrimitiveType::Int64 => ArrowDataType::Int64,
PrimitiveType::Int128 => ArrowDataType::Int128,
PrimitiveType::UInt8 => ArrowDataType::UInt8,
PrimitiveType::UInt16 => ArrowDataType::UInt16,
PrimitiveType::UInt32 => ArrowDataType::UInt32,
PrimitiveType::UInt64 => ArrowDataType::UInt64,
PrimitiveType::UInt128 => ArrowDataType::UInt128,
PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
PrimitiveType::Float16 => ArrowDataType::Float16,
PrimitiveType::Float32 => ArrowDataType::Float32,
PrimitiveType::Float64 => ArrowDataType::Float64,
PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
PrimitiveType::MonthDayMillis => ArrowDataType::Interval(IntervalUnit::MonthDayMillis),
}
}
}
/// Shared, reference-counted (`Arc`) handle to an [`ArrowSchema`].
pub type SchemaRef = Arc<ArrowSchema>;
/// Extracts extension information from `metadata`.
///
/// Returns `None` when the `ARROW:extension:name` key is absent. Otherwise
/// returns the extension name together with the value stored under
/// `ARROW:extension:metadata`, if any. Both values are cloned.
pub fn get_extension(metadata: &Metadata) -> Extension {
    let name_key = PlSmallStr::from_static("ARROW:extension:name");
    // Without a name there is no extension to report.
    let name = metadata.get(&name_key)?;

    let metadata_key = PlSmallStr::from_static("ARROW:extension:metadata");
    let extension_metadata = metadata.get(&metadata_key).cloned();

    Some((name.clone(), extension_metadata))
}
/// Array type used for index values when the `bigidx` feature is disabled.
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Array type used for index values when the `bigidx` feature is enabled.
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;