1mod field;
4mod physical_type;
5pub mod reshape;
6mod schema;
7
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub use field::{
12 DTYPE_CATEGORICAL_LEGACY, DTYPE_CATEGORICAL_NEW, DTYPE_ENUM_VALUES_LEGACY,
13 DTYPE_ENUM_VALUES_NEW, Field,
14};
15pub use physical_type::*;
16use polars_utils::pl_str::PlSmallStr;
17pub use schema::{ArrowSchema, ArrowSchemaRef};
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21use crate::array::LIST_VALUES_NAME;
22
/// Key-value metadata: an ordered map ([`BTreeMap`]) from string keys to string values.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// An optional extension descriptor of the form `(name, optional metadata)`, as returned by
/// [`get_extension`].
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
27
/// The logical data types an array can have.
///
/// Several logical types share one in-memory representation; the mapping from a logical type to
/// its [`PhysicalType`] is given by [`ArrowDataType::to_physical_type`].
///
/// The default value is [`ArrowDataType::Null`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ArrowDataType {
    /// The null type. This is the `Default` variant.
    #[default]
    Null,
    /// Boolean type.
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Signed 128-bit integer.
    Int128,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 16-bit floating point number.
    Float16,
    /// 32-bit floating point number.
    Float32,
    /// 64-bit floating point number.
    Float64,
    /// A timestamp with the given [`TimeUnit`] resolution, physically stored as `Int64`.
    ///
    /// NOTE(review): the optional string is presumably a timezone identifier — confirm against
    /// the Arrow specification.
    Timestamp(TimeUnit, Option<PlSmallStr>),
    /// A date physically stored as `Int32`.
    ///
    /// NOTE(review): presumably days since the UNIX epoch — confirm.
    Date32,
    /// A date physically stored as `Int64`.
    ///
    /// NOTE(review): presumably milliseconds since the UNIX epoch — confirm.
    Date64,
    /// A time value with the given [`TimeUnit`], physically stored as `Int32`.
    Time32(TimeUnit),
    /// A time value with the given [`TimeUnit`], physically stored as `Int64`.
    Time64(TimeUnit),
    /// A duration with the given [`TimeUnit`], physically stored as `Int64`.
    Duration(TimeUnit),
    /// An interval whose physical representation depends on its [`IntervalUnit`].
    Interval(IntervalUnit),
    /// Variable-length binary data.
    Binary,
    /// Fixed-size binary data.
    ///
    /// NOTE(review): the `usize` is presumably the byte width of each value — confirm.
    FixedSizeBinary(usize),
    /// Variable-length binary data (large variant).
    ///
    /// NOTE(review): presumably uses 64-bit offsets where `Binary` uses 32-bit — confirm.
    LargeBinary,
    /// Variable-length string data.
    Utf8,
    /// Variable-length string data (large variant).
    ///
    /// NOTE(review): presumably uses 64-bit offsets where `Utf8` uses 32-bit — confirm.
    LargeUtf8,
    /// A list whose values are described by the boxed child [`Field`].
    List(Box<Field>),
    /// A list with a fixed number of values per slot: the child [`Field`] and the size.
    FixedSizeList(Box<Field>, usize),
    /// A list (large variant) whose values are described by the boxed child [`Field`].
    LargeList(Box<Field>),
    /// A struct made of the given child [`Field`]s.
    Struct(Vec<Field>),
    /// A map type described by a single child [`Field`].
    ///
    /// NOTE(review): the `bool` presumably states whether keys are sorted — confirm.
    Map(Box<Field>, bool),
    /// A dictionary-encoded type: the integer key type followed by a boxed data type.
    ///
    /// NOTE(review): the boxed type is presumably the type of the dictionary values and the
    /// `bool` presumably states whether the dictionary is sorted — confirm.
    Dictionary(IntegerType, Box<ArrowDataType>, bool),
    /// A decimal physically stored as `Int128`.
    ///
    /// NOTE(review): the two `usize`s are presumably `(precision, scale)` — confirm.
    Decimal(usize, usize),
    /// A decimal physically stored as `Int32`.
    ///
    /// NOTE(review): the two `usize`s are presumably `(precision, scale)` — confirm.
    Decimal32(usize, usize),
    /// A decimal physically stored as `Int64`.
    ///
    /// NOTE(review): the two `usize`s are presumably `(precision, scale)` — confirm.
    Decimal64(usize, usize),
    /// A decimal physically stored as `Int256`.
    ///
    /// NOTE(review): the two `usize`s are presumably `(precision, scale)` — confirm.
    Decimal256(usize, usize),
    /// An extension type wrapping another data type; see [`ExtensionType`].
    Extension(Box<ExtensionType>),
    /// Binary data in the "view" layout (physical type `BinaryView`).
    BinaryView,
    /// String data in the "view" layout (physical type `Utf8View`).
    Utf8View,
    /// An unknown type. [`ArrowDataType::to_physical_type`] is unimplemented for it and panics.
    Unknown,
    /// A union type; see [`UnionType`].
    ///
    /// Marked `serde(skip)` when either the `serde` or the `dsl-schema` feature is enabled.
    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
    Union(Box<UnionType>),
}
187
/// The payload of [`ArrowDataType::Extension`]: a named wrapper around another data type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct ExtensionType {
    /// The name of the extension.
    pub name: PlSmallStr,
    /// The wrapped data type. Queries such as `to_physical_type` and `to_logical_type` on the
    /// enclosing [`ArrowDataType`] are forwarded to this type.
    pub inner: ArrowDataType,
    /// Optional metadata attached to the extension.
    pub metadata: Option<PlSmallStr>,
}
196
/// The payload of [`ArrowDataType::Union`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
    /// The child fields of the union.
    pub fields: Vec<Field>,
    /// Optional ids.
    ///
    /// NOTE(review): presumably the type id associated with each entry of `fields` — confirm.
    pub ids: Option<Vec<i32>>,
    /// Whether the union is dense or sparse.
    pub mode: UnionMode,
}
203
/// The mode of a union type: either dense or sparse.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense union.
    Dense,
    /// Sparse union.
    Sparse,
}
213
214impl UnionMode {
215 pub fn sparse(is_sparse: bool) -> Self {
218 if is_sparse { Self::Sparse } else { Self::Dense }
219 }
220
221 pub fn is_sparse(&self) -> bool {
223 matches!(self, Self::Sparse)
224 }
225
226 pub fn is_dense(&self) -> bool {
228 matches!(self, Self::Dense)
229 }
230}
231
/// The resolution of a time-based data type (`Timestamp`, `Time32`, `Time64`, `Duration`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum TimeUnit {
    /// Seconds.
    Second,
    /// Milliseconds.
    Millisecond,
    /// Microseconds.
    Microsecond,
    /// Nanoseconds.
    Nanosecond,
}
246
/// The unit of an [`ArrowDataType::Interval`]; it determines the interval's physical type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum IntervalUnit {
    /// Year/month interval; physically a primitive `Int32`.
    YearMonth,
    /// Day/time interval; physically the primitive `DaysMs`.
    DayTime,
    /// Month/day/nanosecond interval; physically the primitive `MonthDayNano`.
    MonthDayNano,
}
260
261impl ArrowDataType {
262 pub const IDX_DTYPE: Self = {
264 #[cfg(not(feature = "bigidx"))]
265 {
266 ArrowDataType::UInt32
267 }
268 #[cfg(feature = "bigidx")]
269 {
270 ArrowDataType::UInt64
271 }
272 };
273
274 pub fn to_physical_type(&self) -> PhysicalType {
276 use ArrowDataType::*;
277 match self {
278 Null => PhysicalType::Null,
279 Boolean => PhysicalType::Boolean,
280 Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
281 Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
282 Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
283 PhysicalType::Primitive(PrimitiveType::Int32)
284 },
285 Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
286 PhysicalType::Primitive(PrimitiveType::Int64)
287 },
288 Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
289 Decimal32(_, _) => PhysicalType::Primitive(PrimitiveType::Int32),
290 Decimal64(_, _) => PhysicalType::Primitive(PrimitiveType::Int64),
291 Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
292 UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
293 UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
294 UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
295 UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
296 Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
297 Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
298 Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
299 Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
300 Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
301 Interval(IntervalUnit::MonthDayNano) => {
302 PhysicalType::Primitive(PrimitiveType::MonthDayNano)
303 },
304 Binary => PhysicalType::Binary,
305 FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
306 LargeBinary => PhysicalType::LargeBinary,
307 Utf8 => PhysicalType::Utf8,
308 LargeUtf8 => PhysicalType::LargeUtf8,
309 BinaryView => PhysicalType::BinaryView,
310 Utf8View => PhysicalType::Utf8View,
311 List(_) => PhysicalType::List,
312 FixedSizeList(_, _) => PhysicalType::FixedSizeList,
313 LargeList(_) => PhysicalType::LargeList,
314 Struct(_) => PhysicalType::Struct,
315 Union(_) => PhysicalType::Union,
316 Map(_, _) => PhysicalType::Map,
317 Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
318 Extension(ext) => ext.inner.to_physical_type(),
319 Unknown => unimplemented!(),
320 }
321 }
322
323 pub fn underlying_physical_type(&self) -> ArrowDataType {
325 use ArrowDataType::*;
326 match self {
327 Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
328 Date64
329 | Timestamp(_, _)
330 | Time64(_)
331 | Duration(_)
332 | Interval(IntervalUnit::DayTime) => Int64,
333 Interval(IntervalUnit::MonthDayNano) => unimplemented!(),
334 Binary => Binary,
335 List(field) => List(Box::new(Field {
336 dtype: field.dtype.underlying_physical_type(),
337 ..*field.clone()
338 })),
339 LargeList(field) => LargeList(Box::new(Field {
340 dtype: field.dtype.underlying_physical_type(),
341 ..*field.clone()
342 })),
343 FixedSizeList(field, width) => FixedSizeList(
344 Box::new(Field {
345 dtype: field.dtype.underlying_physical_type(),
346 ..*field.clone()
347 }),
348 *width,
349 ),
350 Struct(fields) => Struct(
351 fields
352 .iter()
353 .map(|field| Field {
354 dtype: field.dtype.underlying_physical_type(),
355 ..field.clone()
356 })
357 .collect(),
358 ),
359 Dictionary(keys, _, _) => (*keys).into(),
360 Union(_) => unimplemented!(),
361 Map(_, _) => unimplemented!(),
362 Extension(ext) => ext.inner.underlying_physical_type(),
363 _ => self.clone(),
364 }
365 }
366
367 pub fn to_logical_type(&self) -> &ArrowDataType {
371 use ArrowDataType::*;
372 match self {
373 Extension(ext) => ext.inner.to_logical_type(),
374 _ => self,
375 }
376 }
377
378 pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
379 match self {
380 ArrowDataType::List(inner) => Some(inner.dtype()),
381 ArrowDataType::LargeList(inner) => Some(inner.dtype()),
382 ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()),
383 _ => None,
384 }
385 }
386
387 pub fn is_nested(&self) -> bool {
388 use ArrowDataType as D;
389
390 matches!(
391 self,
392 D::List(_)
393 | D::LargeList(_)
394 | D::FixedSizeList(_, _)
395 | D::Struct(_)
396 | D::Union(_)
397 | D::Map(_, _)
398 | D::Dictionary(_, _, _)
399 | D::Extension(_)
400 )
401 }
402
403 pub fn is_view(&self) -> bool {
404 matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
405 }
406
407 pub fn is_numeric(&self) -> bool {
408 use ArrowDataType as D;
409 matches!(
410 self,
411 D::Int8
412 | D::Int16
413 | D::Int32
414 | D::Int64
415 | D::Int128
416 | D::UInt8
417 | D::UInt16
418 | D::UInt32
419 | D::UInt64
420 | D::Float32
421 | D::Float64
422 | D::Decimal(_, _)
423 | D::Decimal32(_, _)
424 | D::Decimal64(_, _)
425 | D::Decimal256(_, _)
426 )
427 }
428
429 pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
430 ArrowDataType::FixedSizeList(
431 Box::new(Field::new(LIST_VALUES_NAME, self, is_nullable)),
432 size,
433 )
434 }
435
436 pub fn contains_dictionary(&self) -> bool {
438 use ArrowDataType as D;
439 match self {
440 D::Null
441 | D::Boolean
442 | D::Int8
443 | D::Int16
444 | D::Int32
445 | D::Int64
446 | D::UInt8
447 | D::UInt16
448 | D::UInt32
449 | D::UInt64
450 | D::Int128
451 | D::Float16
452 | D::Float32
453 | D::Float64
454 | D::Timestamp(_, _)
455 | D::Date32
456 | D::Date64
457 | D::Time32(_)
458 | D::Time64(_)
459 | D::Duration(_)
460 | D::Interval(_)
461 | D::Binary
462 | D::FixedSizeBinary(_)
463 | D::LargeBinary
464 | D::Utf8
465 | D::LargeUtf8
466 | D::Decimal(_, _)
467 | D::Decimal32(_, _)
468 | D::Decimal64(_, _)
469 | D::Decimal256(_, _)
470 | D::BinaryView
471 | D::Utf8View
472 | D::Unknown => false,
473 D::List(field)
474 | D::FixedSizeList(field, _)
475 | D::Map(field, _)
476 | D::LargeList(field) => field.dtype().contains_dictionary(),
477 D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
478 D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
479 D::Dictionary(_, _, _) => true,
480 D::Extension(ext) => ext.inner.contains_dictionary(),
481 }
482 }
483}
484
485impl From<IntegerType> for ArrowDataType {
486 fn from(item: IntegerType) -> Self {
487 match item {
488 IntegerType::Int8 => ArrowDataType::Int8,
489 IntegerType::Int16 => ArrowDataType::Int16,
490 IntegerType::Int32 => ArrowDataType::Int32,
491 IntegerType::Int64 => ArrowDataType::Int64,
492 IntegerType::Int128 => ArrowDataType::Int128,
493 IntegerType::UInt8 => ArrowDataType::UInt8,
494 IntegerType::UInt16 => ArrowDataType::UInt16,
495 IntegerType::UInt32 => ArrowDataType::UInt32,
496 IntegerType::UInt64 => ArrowDataType::UInt64,
497 }
498 }
499}
500
501impl From<PrimitiveType> for ArrowDataType {
502 fn from(item: PrimitiveType) -> Self {
503 match item {
504 PrimitiveType::Int8 => ArrowDataType::Int8,
505 PrimitiveType::Int16 => ArrowDataType::Int16,
506 PrimitiveType::Int32 => ArrowDataType::Int32,
507 PrimitiveType::Int64 => ArrowDataType::Int64,
508 PrimitiveType::UInt8 => ArrowDataType::UInt8,
509 PrimitiveType::UInt16 => ArrowDataType::UInt16,
510 PrimitiveType::UInt32 => ArrowDataType::UInt32,
511 PrimitiveType::UInt64 => ArrowDataType::UInt64,
512 PrimitiveType::Int128 => ArrowDataType::Int128,
513 PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
514 PrimitiveType::Float16 => ArrowDataType::Float16,
515 PrimitiveType::Float32 => ArrowDataType::Float32,
516 PrimitiveType::Float64 => ArrowDataType::Float64,
517 PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
518 PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
519 PrimitiveType::UInt128 => unimplemented!(),
520 }
521 }
522}
523
/// An [`ArrowSchema`] behind an [`Arc`].
pub type SchemaRef = Arc<ArrowSchema>;
526
527pub fn get_extension(metadata: &Metadata) -> Extension {
529 if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) {
530 let metadata = metadata
531 .get(&PlSmallStr::from_static("ARROW:extension:metadata"))
532 .cloned();
533 Some((name.clone(), metadata))
534 } else {
535 None
536 }
537}
538
/// Array type selected by the `bigidx` feature: `UInt32Array` without it.
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Array type selected by the `bigidx` feature: `UInt64Array` with it.
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;