1mod field;
4mod physical_type;
5pub mod reshape;
6mod schema;
7
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub use field::{
12 DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY,
13 DTYPE_ENUM_VALUES_NEW, Field, MAINTAIN_PL_TYPE, PARQUET_EMPTY_STRUCT, PL_KEY,
14};
15pub use physical_type::*;
16use polars_utils::pl_str::PlSmallStr;
17pub use schema::{ArrowSchema, ArrowSchemaRef};
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21use crate::array::LIST_VALUES_NAME;
22
/// Ordered string-to-string map of metadata entries.
///
/// [`get_extension`] looks up the `ARROW:extension:name` and
/// `ARROW:extension:metadata` keys in a map of this type.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// Optional extension description as returned by [`get_extension`]:
/// `Some((name, optional_metadata))` when an extension name is present,
/// `None` otherwise.
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
27
/// Logical data types describing how the values of an array are interpreted.
///
/// Several logical types share one in-memory representation; see
/// [`ArrowDataType::to_physical_type`] for the mapping to [`PhysicalType`].
/// The [`Default`] value is [`ArrowDataType::Null`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ArrowDataType {
    /// Null type; this is the default variant.
    #[default]
    Null,
    /// Boolean type.
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Signed 128-bit integer.
    Int128,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// Unsigned 128-bit integer.
    UInt128,
    /// 16-bit floating point number.
    Float16,
    /// 32-bit floating point number.
    Float32,
    /// 64-bit floating point number.
    Float64,
    /// Timestamp with the given [`TimeUnit`] resolution, physically stored as `Int64`.
    // NOTE(review): the optional string is presumably a timezone identifier — confirm.
    Timestamp(TimeUnit, Option<PlSmallStr>),
    /// Date physically stored as `Int32`.
    Date32,
    /// Date physically stored as `Int64`.
    Date64,
    /// Time value with the given [`TimeUnit`], physically stored as `Int32`.
    Time32(TimeUnit),
    /// Time value with the given [`TimeUnit`], physically stored as `Int64`.
    Time64(TimeUnit),
    /// Duration with the given [`TimeUnit`], physically stored as `Int64`.
    Duration(TimeUnit),
    /// Calendar interval; the physical representation depends on the [`IntervalUnit`].
    Interval(IntervalUnit),
    /// Binary data with physical type [`PhysicalType::Binary`].
    Binary,
    /// Fixed-size binary data.
    // NOTE(review): the `usize` is presumably the size of each value in bytes — confirm.
    FixedSizeBinary(usize),
    /// Binary data with physical type [`PhysicalType::LargeBinary`].
    LargeBinary,
    /// String data with physical type [`PhysicalType::Utf8`].
    Utf8,
    /// String data with physical type [`PhysicalType::LargeUtf8`].
    LargeUtf8,
    /// List whose child is described by the boxed [`Field`].
    List(Box<Field>),
    /// List whose child is described by the boxed [`Field`], with a fixed width.
    FixedSizeList(Box<Field>, usize),
    /// List (physical type [`PhysicalType::LargeList`]) whose child is described by the
    /// boxed [`Field`].
    LargeList(Box<Field>),
    /// Struct composed of the given child fields.
    Struct(Vec<Field>),
    /// Map type with a single child [`Field`].
    // NOTE(review): the `bool` presumably states whether keys are sorted — confirm.
    Map(Box<Field>, bool),
    /// Dictionary-encoded type: key integer type, value data type, and an `is_sorted` flag.
    Dictionary(IntegerType, Box<ArrowDataType>, bool),
    /// Decimal physically stored as `Int128`.
    // NOTE(review): the two `usize` values are presumably (precision, scale) — confirm.
    Decimal(usize, usize),
    /// Decimal physically stored as `Int32`.
    Decimal32(usize, usize),
    /// Decimal physically stored as `Int64`.
    Decimal64(usize, usize),
    /// Decimal physically stored as `Int256`.
    Decimal256(usize, usize),
    /// Extension type wrapping an inner (storage) data type; see [`ExtensionType`].
    Extension(Box<ExtensionType>),
    /// Binary data with physical type [`PhysicalType::BinaryView`].
    BinaryView,
    /// String data with physical type [`PhysicalType::Utf8View`].
    Utf8View,
    /// Unknown type; it has no physical type ([`ArrowDataType::to_physical_type`] panics).
    Unknown,
    /// Union type; see [`UnionType`].
    // Skipped by serde when the `serde` or `dsl-schema` feature is enabled; `UnionType`
    // does not derive the serde traits in this file.
    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
    Union(Box<UnionType>),
}
190
/// Payload of [`ArrowDataType::Extension`]: a named wrapper around another data type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct ExtensionType {
    /// Name of the extension.
    pub name: PlSmallStr,
    /// Wrapped data type; this is what [`ArrowDataType::to_storage`] unwraps to.
    pub inner: ArrowDataType,
    /// Optional metadata string attached to the extension.
    pub metadata: Option<PlSmallStr>,
}
199
/// Payload of [`ArrowDataType::Union`].
///
/// Unlike the other types in this module it does not derive the serde traits.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
    /// Child fields of the union.
    pub fields: Vec<Field>,
    /// Optional list of ids.
    // NOTE(review): presumably the type ids associated with `fields` — confirm.
    pub ids: Option<Vec<i32>>,
    /// Whether the union is dense or sparse.
    pub mode: UnionMode,
}
206
/// The mode of a union type: either dense or sparse.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense union.
    Dense,
    /// Sparse union.
    Sparse,
}

impl UnionMode {
    /// Builds a mode from a flag: `true` yields [`UnionMode::Sparse`],
    /// `false` yields [`UnionMode::Dense`].
    pub fn sparse(is_sparse: bool) -> Self {
        match is_sparse {
            true => Self::Sparse,
            false => Self::Dense,
        }
    }

    /// Returns `true` when the mode is [`UnionMode::Sparse`].
    pub fn is_sparse(&self) -> bool {
        *self == Self::Sparse
    }

    /// Returns `true` when the mode is [`UnionMode::Dense`].
    pub fn is_dense(&self) -> bool {
        *self == Self::Dense
    }
}
234
/// Time resolution used by [`ArrowDataType::Timestamp`], [`ArrowDataType::Time32`],
/// [`ArrowDataType::Time64`] and [`ArrowDataType::Duration`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum TimeUnit {
    /// Seconds.
    Second,
    /// Milliseconds.
    Millisecond,
    /// Microseconds.
    Microsecond,
    /// Nanoseconds.
    Nanosecond,
}
249
/// Unit of an [`ArrowDataType::Interval`]; it selects the interval's physical
/// representation (see [`ArrowDataType::to_physical_type`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum IntervalUnit {
    /// Year/month interval, physically a primitive `Int32`.
    YearMonth,
    /// Day/time interval, physically a primitive `DaysMs`.
    DayTime,
    /// Interval physically stored as a primitive `MonthDayNano`.
    MonthDayNano,
    /// Interval physically stored as a primitive `MonthDayMillis`.
    MonthDayMillis,
}
267
268impl ArrowDataType {
269 pub const IDX_DTYPE: Self = {
271 #[cfg(not(feature = "bigidx"))]
272 {
273 ArrowDataType::UInt32
274 }
275 #[cfg(feature = "bigidx")]
276 {
277 ArrowDataType::UInt64
278 }
279 };
280
281 pub fn to_physical_type(&self) -> PhysicalType {
283 use ArrowDataType::*;
284 match self {
285 Null => PhysicalType::Null,
286 Boolean => PhysicalType::Boolean,
287 Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
288 Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
289 Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
290 PhysicalType::Primitive(PrimitiveType::Int32)
291 },
292 Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
293 PhysicalType::Primitive(PrimitiveType::Int64)
294 },
295 Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
296 Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
297 Decimal32(_, _) => PhysicalType::Primitive(PrimitiveType::Int32),
298 Decimal64(_, _) => PhysicalType::Primitive(PrimitiveType::Int64),
299 Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
300 UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
301 UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
302 UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
303 UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
304 UInt128 => PhysicalType::Primitive(PrimitiveType::UInt128),
305 Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
306 Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
307 Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
308 Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
309 Interval(IntervalUnit::MonthDayNano) => {
310 PhysicalType::Primitive(PrimitiveType::MonthDayNano)
311 },
312 Interval(IntervalUnit::MonthDayMillis) => {
313 PhysicalType::Primitive(PrimitiveType::MonthDayMillis)
314 },
315 Binary => PhysicalType::Binary,
316 FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
317 LargeBinary => PhysicalType::LargeBinary,
318 Utf8 => PhysicalType::Utf8,
319 LargeUtf8 => PhysicalType::LargeUtf8,
320 BinaryView => PhysicalType::BinaryView,
321 Utf8View => PhysicalType::Utf8View,
322 List(_) => PhysicalType::List,
323 FixedSizeList(_, _) => PhysicalType::FixedSizeList,
324 LargeList(_) => PhysicalType::LargeList,
325 Struct(_) => PhysicalType::Struct,
326 Union(_) => PhysicalType::Union,
327 Map(_, _) => PhysicalType::Map,
328 Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
329 Extension(ext) => ext.inner.to_physical_type(),
330 Unknown => unimplemented!(),
331 }
332 }
333
334 pub fn underlying_physical_type(&self) -> ArrowDataType {
336 use ArrowDataType::*;
337 match self {
338 Null | Boolean | Int8 | Int16 | Int32 | Int64 | Int128 | UInt8 | UInt16 | UInt32
339 | UInt64 | UInt128 | Float16 | Float32 | Float64 | Binary | LargeBinary | Utf8
340 | LargeUtf8 | BinaryView | Utf8View | FixedSizeBinary(_) | Unknown => self.clone(),
341
342 Decimal32(_, _) | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
343 Decimal64(_, _)
344 | Date64
345 | Timestamp(_, _)
346 | Time64(_)
347 | Duration(_)
348 | Interval(IntervalUnit::DayTime) => Int64,
349 Interval(IntervalUnit::MonthDayNano | IntervalUnit::MonthDayMillis) => unimplemented!(),
350 Decimal(_, _) => Int128,
351 Decimal256(_, _) => unimplemented!(),
352 List(field) => List(Box::new(
353 field.with_dtype(field.dtype.underlying_physical_type()),
354 )),
355 LargeList(field) => LargeList(Box::new(
356 field.with_dtype(field.dtype.underlying_physical_type()),
357 )),
358 FixedSizeList(field, width) => FixedSizeList(
359 Box::new(field.with_dtype(field.dtype.underlying_physical_type())),
360 *width,
361 ),
362 Struct(fields) => Struct(
363 fields
364 .iter()
365 .map(|field| field.with_dtype(field.dtype.underlying_physical_type()))
366 .collect(),
367 ),
368 Dictionary(keys, _, _) => (*keys).into(),
369 Union(_) => unimplemented!(),
370 Map(_, _) => unimplemented!(),
371 Extension(ext) => ext.inner.underlying_physical_type(),
372 }
373 }
374
375 pub fn to_storage(&self) -> &ArrowDataType {
379 use ArrowDataType::*;
380 match self {
381 Extension(ext) => ext.inner.to_storage(),
382 _ => self,
383 }
384 }
385
386 pub fn to_storage_recursive(&self) -> ArrowDataType {
389 use ArrowDataType::*;
390 match self {
391 Extension(ext) => ext.inner.to_storage_recursive(),
392 List(field) => List(Box::new(Field {
393 dtype: field.dtype.to_storage_recursive(),
394 ..*field.clone()
395 })),
396 LargeList(field) => LargeList(Box::new(Field {
397 dtype: field.dtype.to_storage_recursive(),
398 ..*field.clone()
399 })),
400 FixedSizeList(field, width) => FixedSizeList(
401 Box::new(Field {
402 dtype: field.dtype.to_storage_recursive(),
403 ..*field.clone()
404 }),
405 *width,
406 ),
407 Struct(fields) => Struct(
408 fields
409 .iter()
410 .map(|field| Field {
411 dtype: field.dtype.to_storage_recursive(),
412 ..field.clone()
413 })
414 .collect(),
415 ),
416 Dictionary(keys, values, is_sorted) => {
417 Dictionary(*keys, Box::new(values.to_storage_recursive()), *is_sorted)
418 },
419 Union(_) => unimplemented!(),
420 Map(_, _) => unimplemented!(),
421 _ => self.clone(),
422 }
423 }
424
425 pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
426 match self {
427 ArrowDataType::List(inner) => Some(inner.dtype()),
428 ArrowDataType::LargeList(inner) => Some(inner.dtype()),
429 ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()),
430 _ => None,
431 }
432 }
433
434 pub fn is_nested(&self) -> bool {
435 use ArrowDataType as D;
436
437 matches!(
438 self,
439 D::List(_)
440 | D::LargeList(_)
441 | D::FixedSizeList(_, _)
442 | D::Struct(_)
443 | D::Union(_)
444 | D::Map(_, _)
445 | D::Dictionary(_, _, _)
446 | D::Extension(_)
447 )
448 }
449
450 pub fn is_view(&self) -> bool {
451 matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
452 }
453
454 pub fn is_numeric(&self) -> bool {
455 use ArrowDataType as D;
456 matches!(
457 self,
458 D::Int8
459 | D::Int16
460 | D::Int32
461 | D::Int64
462 | D::Int128
463 | D::UInt8
464 | D::UInt16
465 | D::UInt32
466 | D::UInt64
467 | D::UInt128
468 | D::Float16
469 | D::Float32
470 | D::Float64
471 | D::Decimal(_, _)
472 | D::Decimal32(_, _)
473 | D::Decimal64(_, _)
474 | D::Decimal256(_, _)
475 )
476 }
477
478 pub fn to_large_list(self, is_nullable: bool) -> ArrowDataType {
479 ArrowDataType::LargeList(Box::new(Field::new(LIST_VALUES_NAME, self, is_nullable)))
480 }
481
482 pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
483 ArrowDataType::FixedSizeList(
484 Box::new(Field::new(LIST_VALUES_NAME, self, is_nullable)),
485 size,
486 )
487 }
488
489 pub fn contains_dictionary(&self) -> bool {
491 use ArrowDataType as D;
492 match self {
493 D::Null
494 | D::Boolean
495 | D::Int8
496 | D::Int16
497 | D::Int32
498 | D::Int64
499 | D::Int128
500 | D::UInt8
501 | D::UInt16
502 | D::UInt32
503 | D::UInt64
504 | D::UInt128
505 | D::Float16
506 | D::Float32
507 | D::Float64
508 | D::Timestamp(_, _)
509 | D::Date32
510 | D::Date64
511 | D::Time32(_)
512 | D::Time64(_)
513 | D::Duration(_)
514 | D::Interval(_)
515 | D::Binary
516 | D::FixedSizeBinary(_)
517 | D::LargeBinary
518 | D::Utf8
519 | D::LargeUtf8
520 | D::Decimal(_, _)
521 | D::Decimal32(_, _)
522 | D::Decimal64(_, _)
523 | D::Decimal256(_, _)
524 | D::BinaryView
525 | D::Utf8View
526 | D::Unknown => false,
527 D::List(field)
528 | D::FixedSizeList(field, _)
529 | D::Map(field, _)
530 | D::LargeList(field) => field.dtype().contains_dictionary(),
531 D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
532 D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
533 D::Dictionary(_, _, _) => true,
534 D::Extension(ext) => ext.inner.contains_dictionary(),
535 }
536 }
537}
538
539impl From<IntegerType> for ArrowDataType {
540 fn from(item: IntegerType) -> Self {
541 match item {
542 IntegerType::Int8 => ArrowDataType::Int8,
543 IntegerType::Int16 => ArrowDataType::Int16,
544 IntegerType::Int32 => ArrowDataType::Int32,
545 IntegerType::Int64 => ArrowDataType::Int64,
546 IntegerType::Int128 => ArrowDataType::Int128,
547 IntegerType::UInt8 => ArrowDataType::UInt8,
548 IntegerType::UInt16 => ArrowDataType::UInt16,
549 IntegerType::UInt32 => ArrowDataType::UInt32,
550 IntegerType::UInt64 => ArrowDataType::UInt64,
551 IntegerType::UInt128 => ArrowDataType::UInt128,
552 }
553 }
554}
555
556impl From<PrimitiveType> for ArrowDataType {
557 fn from(item: PrimitiveType) -> Self {
558 match item {
559 PrimitiveType::Int8 => ArrowDataType::Int8,
560 PrimitiveType::Int16 => ArrowDataType::Int16,
561 PrimitiveType::Int32 => ArrowDataType::Int32,
562 PrimitiveType::Int64 => ArrowDataType::Int64,
563 PrimitiveType::Int128 => ArrowDataType::Int128,
564 PrimitiveType::UInt8 => ArrowDataType::UInt8,
565 PrimitiveType::UInt16 => ArrowDataType::UInt16,
566 PrimitiveType::UInt32 => ArrowDataType::UInt32,
567 PrimitiveType::UInt64 => ArrowDataType::UInt64,
568 PrimitiveType::UInt128 => ArrowDataType::UInt128,
569 PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
570 PrimitiveType::Float16 => ArrowDataType::Float16,
571 PrimitiveType::Float32 => ArrowDataType::Float32,
572 PrimitiveType::Float64 => ArrowDataType::Float64,
573 PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
574 PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
575 PrimitiveType::MonthDayMillis => ArrowDataType::Interval(IntervalUnit::MonthDayMillis),
576 }
577 }
578}
579
/// Shared, reference-counted [`ArrowSchema`].
pub type SchemaRef = Arc<ArrowSchema>;
582
583pub fn get_extension(metadata: &Metadata) -> Extension {
585 if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) {
586 let metadata = metadata
587 .get(&PlSmallStr::from_static("ARROW:extension:metadata"))
588 .cloned();
589 Some((name.clone(), metadata))
590 } else {
591 None
592 }
593}
594
/// Array type holding index values: `UInt32Array` unless the `bigidx` feature is
/// enabled (parallels [`ArrowDataType::IDX_DTYPE`]).
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Array type holding index values: `UInt64Array` because the `bigidx` feature is
/// enabled (parallels [`ArrowDataType::IDX_DTYPE`]).
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;