1mod field;
4mod physical_type;
5pub mod reshape;
6mod schema;
7
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub use field::{
12 DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY,
13 DTYPE_ENUM_VALUES_NEW, Field, MAINTAIN_PL_TYPE, PARQUET_EMPTY_STRUCT, PL_KEY,
14};
15pub use physical_type::*;
16use polars_utils::pl_str::PlSmallStr;
17pub use schema::{ArrowSchema, ArrowSchemaRef};
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21use crate::array::LIST_VALUES_NAME;
22
/// Key/value metadata map with `PlSmallStr` keys and values, ordered by key.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// An optional `(name, metadata)` pair describing an Arrow extension, as produced by
/// `get_extension`: `None` when no extension name is present, otherwise the name plus the
/// optional extension metadata string.
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
27
/// The logical data types supported by this Arrow implementation.
///
/// Several logical types share a physical representation; see
/// `ArrowDataType::to_physical_type` for the exact mapping.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ArrowDataType {
    /// Null type. This is the `Default` dtype.
    #[default]
    Null,
    /// Boolean values (physical type `Boolean`).
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Signed 128-bit integer.
    Int128,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// Unsigned 128-bit integer.
    UInt128,
    /// 16-bit float. Note that `is_numeric` does not include this variant.
    Float16,
    /// 32-bit float.
    Float32,
    /// 64-bit float.
    Float64,
    /// Timestamp stored as `Int64` in the given `TimeUnit`. The optional string presumably
    /// holds a timezone name (per the Arrow spec); confirm against readers and writers.
    Timestamp(TimeUnit, Option<PlSmallStr>),
    /// Date stored as `Int32`.
    Date32,
    /// Date stored as `Int64`.
    Date64,
    /// Time of day stored as `Int32` in the given unit.
    Time32(TimeUnit),
    /// Time of day stored as `Int64` in the given unit.
    Time64(TimeUnit),
    /// Duration stored as `Int64` in the given unit.
    Duration(TimeUnit),
    /// Interval; the unit also selects the physical representation.
    Interval(IntervalUnit),
    /// Variable-length binary (physical type `Binary`).
    Binary,
    /// Fixed-length binary; the `usize` is presumably the byte width of each value.
    FixedSizeBinary(usize),
    /// Variable-length binary with its own physical type (`LargeBinary`).
    LargeBinary,
    /// Variable-length string (physical type `Utf8`).
    Utf8,
    /// Variable-length string with its own physical type (`LargeUtf8`).
    LargeUtf8,
    /// Variable-length list whose child is described by the boxed `Field`.
    List(Box<Field>),
    /// List with a fixed number of elements per entry (the `usize`).
    FixedSizeList(Box<Field>, usize),
    /// Variable-length list with its own physical type (`LargeList`).
    LargeList(Box<Field>),
    /// Struct with the given child fields.
    Struct(Vec<Field>),
    /// Map whose entries are described by the boxed `Field`. The `bool` presumably flags
    /// whether keys are sorted (per the Arrow spec); confirm.
    Map(Box<Field>, bool),
    /// Dictionary-encoded values: key integer type, value dtype, and a flag that is presumably
    /// "is ordered" (per the Arrow spec); confirm.
    Dictionary(IntegerType, Box<ArrowDataType>, bool),
    /// Decimal stored as `Int128`. The two `usize`s are presumably (precision, scale).
    Decimal(usize, usize),
    /// Decimal stored as `Int32`.
    Decimal32(usize, usize),
    /// Decimal stored as `Int64`.
    Decimal64(usize, usize),
    /// Decimal stored as `Int256`.
    Decimal256(usize, usize),
    /// Extension type; physical and logical type are taken from its inner dtype.
    Extension(Box<ExtensionType>),
    /// View-encoded binary (`is_view` returns `true`).
    BinaryView,
    /// View-encoded string (`is_view` returns `true`).
    Utf8View,
    /// Unknown dtype. `to_physical_type` panics on it.
    Unknown,
    /// Union of several child fields. It is skipped by the serde/schema derives when those
    /// features are enabled.
    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
    Union(Box<UnionType>),
}
189
/// A named extension dtype layered on top of a storage dtype.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct ExtensionType {
    /// Extension name. Presumably the value carried under `ARROW:extension:name`; confirm.
    pub name: PlSmallStr,
    /// Storage dtype. `to_physical_type` and `to_logical_type` delegate to it.
    pub inner: ArrowDataType,
    /// Optional extension metadata. Presumably the value carried under
    /// `ARROW:extension:metadata`; confirm.
    pub metadata: Option<PlSmallStr>,
}
198
/// Description of a union dtype.
///
/// This struct derives no serde/schema traits; `ArrowDataType::Union` is marked
/// `serde(skip)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
    /// The union's member fields.
    pub fields: Vec<Field>,
    /// Optional type ids. Presumably one per entry of `fields` — TODO confirm.
    pub ids: Option<Vec<i32>>,
    /// Dense or sparse layout.
    pub mode: UnionMode,
}
205
/// Layout mode of a union dtype.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense union layout.
    Dense,
    /// Sparse union layout.
    Sparse,
}
215
216impl UnionMode {
217 pub fn sparse(is_sparse: bool) -> Self {
220 if is_sparse { Self::Sparse } else { Self::Dense }
221 }
222
223 pub fn is_sparse(&self) -> bool {
225 matches!(self, Self::Sparse)
226 }
227
228 pub fn is_dense(&self) -> bool {
230 matches!(self, Self::Dense)
231 }
232}
233
/// Resolution used by `Timestamp`, `Time32`, `Time64` and `Duration` dtypes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum TimeUnit {
    /// Seconds.
    Second,
    /// Milliseconds.
    Millisecond,
    /// Microseconds.
    Microsecond,
    /// Nanoseconds.
    Nanosecond,
}
248
/// Unit of an `ArrowDataType::Interval`. It also selects the interval's physical
/// representation in `ArrowDataType::to_physical_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum IntervalUnit {
    /// Year/month interval, stored as a primitive `Int32`.
    YearMonth,
    /// Day/time interval, stored as the `DaysMs` primitive.
    DayTime,
    /// Month/day/nanosecond interval, stored as the `MonthDayNano` primitive.
    MonthDayNano,
    /// Month/day/millisecond interval, stored as the `MonthDayMillis` primitive.
    MonthDayMillis,
}
266
267impl ArrowDataType {
268 pub const IDX_DTYPE: Self = {
270 #[cfg(not(feature = "bigidx"))]
271 {
272 ArrowDataType::UInt32
273 }
274 #[cfg(feature = "bigidx")]
275 {
276 ArrowDataType::UInt64
277 }
278 };
279
280 pub fn to_physical_type(&self) -> PhysicalType {
282 use ArrowDataType::*;
283 match self {
284 Null => PhysicalType::Null,
285 Boolean => PhysicalType::Boolean,
286 Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
287 Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
288 Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
289 PhysicalType::Primitive(PrimitiveType::Int32)
290 },
291 Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
292 PhysicalType::Primitive(PrimitiveType::Int64)
293 },
294 Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
295 Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
296 Decimal32(_, _) => PhysicalType::Primitive(PrimitiveType::Int32),
297 Decimal64(_, _) => PhysicalType::Primitive(PrimitiveType::Int64),
298 Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
299 UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
300 UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
301 UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
302 UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
303 UInt128 => PhysicalType::Primitive(PrimitiveType::UInt128),
304 Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
305 Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
306 Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
307 Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
308 Interval(IntervalUnit::MonthDayNano) => {
309 PhysicalType::Primitive(PrimitiveType::MonthDayNano)
310 },
311 Interval(IntervalUnit::MonthDayMillis) => {
312 PhysicalType::Primitive(PrimitiveType::MonthDayMillis)
313 },
314 Binary => PhysicalType::Binary,
315 FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
316 LargeBinary => PhysicalType::LargeBinary,
317 Utf8 => PhysicalType::Utf8,
318 LargeUtf8 => PhysicalType::LargeUtf8,
319 BinaryView => PhysicalType::BinaryView,
320 Utf8View => PhysicalType::Utf8View,
321 List(_) => PhysicalType::List,
322 FixedSizeList(_, _) => PhysicalType::FixedSizeList,
323 LargeList(_) => PhysicalType::LargeList,
324 Struct(_) => PhysicalType::Struct,
325 Union(_) => PhysicalType::Union,
326 Map(_, _) => PhysicalType::Map,
327 Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
328 Extension(ext) => ext.inner.to_physical_type(),
329 Unknown => unimplemented!(),
330 }
331 }
332
333 pub fn underlying_physical_type(&self) -> ArrowDataType {
335 use ArrowDataType::*;
336 match self {
337 Decimal32(_, _) | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
338 Decimal64(_, _)
339 | Date64
340 | Timestamp(_, _)
341 | Time64(_)
342 | Duration(_)
343 | Interval(IntervalUnit::DayTime) => Int64,
344 Interval(IntervalUnit::MonthDayNano) => unimplemented!(),
345 Binary => Binary,
346 Decimal(_, _) => Int128,
347 Decimal256(_, _) => unimplemented!(),
348 List(field) => List(Box::new(Field {
349 dtype: field.dtype.underlying_physical_type(),
350 ..*field.clone()
351 })),
352 LargeList(field) => LargeList(Box::new(Field {
353 dtype: field.dtype.underlying_physical_type(),
354 ..*field.clone()
355 })),
356 FixedSizeList(field, width) => FixedSizeList(
357 Box::new(Field {
358 dtype: field.dtype.underlying_physical_type(),
359 ..*field.clone()
360 }),
361 *width,
362 ),
363 Struct(fields) => Struct(
364 fields
365 .iter()
366 .map(|field| Field {
367 dtype: field.dtype.underlying_physical_type(),
368 ..field.clone()
369 })
370 .collect(),
371 ),
372 Dictionary(keys, _, _) => (*keys).into(),
373 Union(_) => unimplemented!(),
374 Map(_, _) => unimplemented!(),
375 Extension(ext) => ext.inner.underlying_physical_type(),
376 _ => self.clone(),
377 }
378 }
379
380 pub fn to_logical_type(&self) -> &ArrowDataType {
384 use ArrowDataType::*;
385 match self {
386 Extension(ext) => ext.inner.to_logical_type(),
387 _ => self,
388 }
389 }
390
391 pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
392 match self {
393 ArrowDataType::List(inner) => Some(inner.dtype()),
394 ArrowDataType::LargeList(inner) => Some(inner.dtype()),
395 ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()),
396 _ => None,
397 }
398 }
399
400 pub fn is_nested(&self) -> bool {
401 use ArrowDataType as D;
402
403 matches!(
404 self,
405 D::List(_)
406 | D::LargeList(_)
407 | D::FixedSizeList(_, _)
408 | D::Struct(_)
409 | D::Union(_)
410 | D::Map(_, _)
411 | D::Dictionary(_, _, _)
412 | D::Extension(_)
413 )
414 }
415
416 pub fn is_view(&self) -> bool {
417 matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
418 }
419
420 pub fn is_numeric(&self) -> bool {
421 use ArrowDataType as D;
422 matches!(
423 self,
424 D::Int8
425 | D::Int16
426 | D::Int32
427 | D::Int64
428 | D::Int128
429 | D::UInt8
430 | D::UInt16
431 | D::UInt32
432 | D::UInt64
433 | D::UInt128
434 | D::Float32
435 | D::Float64
436 | D::Decimal(_, _)
437 | D::Decimal32(_, _)
438 | D::Decimal64(_, _)
439 | D::Decimal256(_, _)
440 )
441 }
442
443 pub fn to_large_list(self, is_nullable: bool) -> ArrowDataType {
444 ArrowDataType::LargeList(Box::new(Field::new(LIST_VALUES_NAME, self, is_nullable)))
445 }
446
447 pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
448 ArrowDataType::FixedSizeList(
449 Box::new(Field::new(LIST_VALUES_NAME, self, is_nullable)),
450 size,
451 )
452 }
453
454 pub fn contains_dictionary(&self) -> bool {
456 use ArrowDataType as D;
457 match self {
458 D::Null
459 | D::Boolean
460 | D::Int8
461 | D::Int16
462 | D::Int32
463 | D::Int64
464 | D::Int128
465 | D::UInt8
466 | D::UInt16
467 | D::UInt32
468 | D::UInt64
469 | D::UInt128
470 | D::Float16
471 | D::Float32
472 | D::Float64
473 | D::Timestamp(_, _)
474 | D::Date32
475 | D::Date64
476 | D::Time32(_)
477 | D::Time64(_)
478 | D::Duration(_)
479 | D::Interval(_)
480 | D::Binary
481 | D::FixedSizeBinary(_)
482 | D::LargeBinary
483 | D::Utf8
484 | D::LargeUtf8
485 | D::Decimal(_, _)
486 | D::Decimal32(_, _)
487 | D::Decimal64(_, _)
488 | D::Decimal256(_, _)
489 | D::BinaryView
490 | D::Utf8View
491 | D::Unknown => false,
492 D::List(field)
493 | D::FixedSizeList(field, _)
494 | D::Map(field, _)
495 | D::LargeList(field) => field.dtype().contains_dictionary(),
496 D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
497 D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
498 D::Dictionary(_, _, _) => true,
499 D::Extension(ext) => ext.inner.contains_dictionary(),
500 }
501 }
502}
503
504impl From<IntegerType> for ArrowDataType {
505 fn from(item: IntegerType) -> Self {
506 match item {
507 IntegerType::Int8 => ArrowDataType::Int8,
508 IntegerType::Int16 => ArrowDataType::Int16,
509 IntegerType::Int32 => ArrowDataType::Int32,
510 IntegerType::Int64 => ArrowDataType::Int64,
511 IntegerType::Int128 => ArrowDataType::Int128,
512 IntegerType::UInt8 => ArrowDataType::UInt8,
513 IntegerType::UInt16 => ArrowDataType::UInt16,
514 IntegerType::UInt32 => ArrowDataType::UInt32,
515 IntegerType::UInt64 => ArrowDataType::UInt64,
516 IntegerType::UInt128 => ArrowDataType::UInt128,
517 }
518 }
519}
520
521impl From<PrimitiveType> for ArrowDataType {
522 fn from(item: PrimitiveType) -> Self {
523 match item {
524 PrimitiveType::Int8 => ArrowDataType::Int8,
525 PrimitiveType::Int16 => ArrowDataType::Int16,
526 PrimitiveType::Int32 => ArrowDataType::Int32,
527 PrimitiveType::Int64 => ArrowDataType::Int64,
528 PrimitiveType::Int128 => ArrowDataType::Int128,
529 PrimitiveType::UInt8 => ArrowDataType::UInt8,
530 PrimitiveType::UInt16 => ArrowDataType::UInt16,
531 PrimitiveType::UInt32 => ArrowDataType::UInt32,
532 PrimitiveType::UInt64 => ArrowDataType::UInt64,
533 PrimitiveType::UInt128 => ArrowDataType::UInt128,
534 PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
535 PrimitiveType::Float16 => ArrowDataType::Float16,
536 PrimitiveType::Float32 => ArrowDataType::Float32,
537 PrimitiveType::Float64 => ArrowDataType::Float64,
538 PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
539 PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
540 PrimitiveType::MonthDayMillis => ArrowDataType::Interval(IntervalUnit::MonthDayMillis),
541 }
542 }
543}
544
/// Shared, reference-counted handle to an [`ArrowSchema`].
pub type SchemaRef = Arc<ArrowSchema>;
547
548pub fn get_extension(metadata: &Metadata) -> Extension {
550 if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) {
551 let metadata = metadata
552 .get(&PlSmallStr::from_static("ARROW:extension:metadata"))
553 .cloned();
554 Some((name.clone(), metadata))
555 } else {
556 None
557 }
558}
559
/// Array type used for index values: `UInt32Array` by default. It uses the same `bigidx`
/// switch as `ArrowDataType::IDX_DTYPE`.
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Array type used for index values when the `bigidx` feature is enabled.
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;