1mod field;
4mod physical_type;
5pub mod reshape;
6mod schema;
7
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub use field::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field};
12pub use physical_type::*;
13use polars_utils::pl_str::PlSmallStr;
14pub use schema::{ArrowSchema, ArrowSchemaRef};
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17
/// String key/value metadata map. Backed by a `BTreeMap`, so iteration is in key order.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// An extension name paired with its optional extension metadata, as returned by
/// [`get_extension`]; `None` when no extension name is present.
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
22
/// The logical data types supported by this Arrow implementation.
///
/// Every variant (except [`ArrowDataType::Unknown`]) maps to a [`PhysicalType`] via
/// [`ArrowDataType::to_physical_type`], which decides how its values are laid out in memory.
/// The `Default` value is [`ArrowDataType::Null`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArrowDataType {
    /// The null type (physical type `Null`). This is the `Default` variant.
    #[default]
    Null,
    /// Boolean values (physical type `Boolean`).
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Signed 128-bit integer.
    Int128,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 16-bit floating point.
    Float16,
    /// 32-bit floating point.
    Float32,
    /// 64-bit floating point.
    Float64,
    /// Timestamp stored as a 64-bit signed integer in the given [`TimeUnit`].
    /// The optional string is presumably a timezone name — TODO confirm.
    Timestamp(TimeUnit, Option<PlSmallStr>),
    /// Date stored as a 32-bit signed integer.
    Date32,
    /// Date stored as a 64-bit signed integer.
    Date64,
    /// Time of day stored as a 32-bit signed integer in the given [`TimeUnit`].
    Time32(TimeUnit),
    /// Time of day stored as a 64-bit signed integer in the given [`TimeUnit`].
    Time64(TimeUnit),
    /// Duration stored as a 64-bit signed integer in the given [`TimeUnit`].
    Duration(TimeUnit),
    /// Interval; its [`IntervalUnit`] also selects the physical representation.
    Interval(IntervalUnit),
    /// Variable-length binary (physical type `Binary`).
    Binary,
    /// Binary values that all have the same size, given by the `usize`.
    FixedSizeBinary(usize),
    /// "Large" counterpart of [`ArrowDataType::Binary`] (physical type `LargeBinary`).
    LargeBinary,
    /// UTF-8 string (physical type `Utf8`).
    Utf8,
    /// "Large" counterpart of [`ArrowDataType::Utf8`] (physical type `LargeUtf8`).
    LargeUtf8,
    /// Variable-length list whose items are described by the boxed [`Field`].
    List(Box<Field>),
    /// List of exactly `usize` items, each described by the boxed [`Field`].
    FixedSizeList(Box<Field>, usize),
    /// "Large" counterpart of [`ArrowDataType::List`] (physical type `LargeList`).
    LargeList(Box<Field>),
    /// Struct with one child [`Field`] per member.
    Struct(Vec<Field>),
    /// Map whose entries are described by the boxed [`Field`].
    /// The `bool` is presumably a "keys are sorted" flag — TODO confirm.
    Map(Box<Field>, bool),
    /// Dictionary-encoded values: keys of the given [`IntegerType`] (which also determine the
    /// physical type) referencing values of the boxed type.
    /// The `bool` is presumably an "is ordered/sorted" flag — TODO confirm.
    Dictionary(IntegerType, Box<ArrowDataType>, bool),
    /// Decimal stored as a 128-bit integer. The two `usize`s are presumably precision and
    /// scale — TODO confirm.
    Decimal(usize, usize),
    /// Decimal stored as a 256-bit integer. The two `usize`s are presumably precision and
    /// scale — TODO confirm.
    Decimal256(usize, usize),
    /// Extension type; physical and logical type queries resolve to its inner type.
    Extension(Box<ExtensionType>),
    /// View-based binary (physical type `BinaryView`).
    BinaryView,
    /// View-based UTF-8 string (physical type `Utf8View`).
    Utf8View,
    /// Placeholder for a type that is not known; `to_physical_type` panics on it.
    Unknown,
    /// Union of several member types. Skipped by serde (de)serialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    Union(Box<UnionType>),
}
177
/// An extension type layered over an existing [`ArrowDataType`].
///
/// [`ArrowDataType::to_physical_type`], [`ArrowDataType::to_logical_type`] and
/// [`ArrowDataType::underlying_physical_type`] all resolve an extension to `inner`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ExtensionType {
    /// Extension name (presumably the `ARROW:extension:name` value read by `get_extension`).
    pub name: PlSmallStr,
    /// The data type this extension wraps.
    pub inner: ArrowDataType,
    /// Optional extension metadata (presumably the `ARROW:extension:metadata` value).
    pub metadata: Option<PlSmallStr>,
}
185
/// Definition of a union type: its member fields, optional type ids, and layout mode.
///
/// Unlike the other types in this module it has no serde derive; the
/// [`ArrowDataType::Union`] variant that holds it is skipped during (de)serialization.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
    /// The member fields of the union.
    pub fields: Vec<Field>,
    /// Optional type ids, presumably one per entry of `fields` — TODO confirm.
    pub ids: Option<Vec<i32>>,
    /// Whether the union uses the dense or the sparse layout.
    pub mode: UnionMode,
}
192
/// Layout mode of a union type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense layout.
    Dense,
    /// Sparse layout.
    Sparse,
}

impl UnionMode {
    /// Builds a mode from a flag: `true` gives [`UnionMode::Sparse`], `false` gives
    /// [`UnionMode::Dense`].
    pub fn sparse(is_sparse: bool) -> Self {
        match is_sparse {
            true => Self::Sparse,
            false => Self::Dense,
        }
    }

    /// Returns `true` exactly when the mode is [`UnionMode::Sparse`].
    pub fn is_sparse(&self) -> bool {
        *self == Self::Sparse
    }

    /// Returns `true` exactly when the mode is [`UnionMode::Dense`].
    pub fn is_dense(&self) -> bool {
        // Only two modes exist, so "dense" is simply "not sparse".
        !self.is_sparse()
    }
}
220
/// Time resolution used by the temporal types `Timestamp`, `Time32`, `Time64` and `Duration`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TimeUnit {
    /// Second resolution.
    Second,
    /// Millisecond resolution.
    Millisecond,
    /// Microsecond resolution.
    Microsecond,
    /// Nanosecond resolution.
    Nanosecond,
}
234
/// Unit of an [`ArrowDataType::Interval`]; it also selects the interval's physical type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IntervalUnit {
    /// Year/month interval; stored as a 32-bit signed integer.
    YearMonth,
    /// Day/time interval; stored as the `DaysMs` primitive.
    DayTime,
    /// Month/day/nanosecond interval; stored as the `MonthDayNano` primitive.
    MonthDayNano,
}
247
248impl ArrowDataType {
249 pub fn to_physical_type(&self) -> PhysicalType {
251 use ArrowDataType::*;
252 match self {
253 Null => PhysicalType::Null,
254 Boolean => PhysicalType::Boolean,
255 Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
256 Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
257 Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
258 PhysicalType::Primitive(PrimitiveType::Int32)
259 },
260 Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
261 PhysicalType::Primitive(PrimitiveType::Int64)
262 },
263 Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
264 Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
265 UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
266 UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
267 UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
268 UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
269 Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
270 Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
271 Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
272 Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
273 Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
274 Interval(IntervalUnit::MonthDayNano) => {
275 PhysicalType::Primitive(PrimitiveType::MonthDayNano)
276 },
277 Binary => PhysicalType::Binary,
278 FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
279 LargeBinary => PhysicalType::LargeBinary,
280 Utf8 => PhysicalType::Utf8,
281 LargeUtf8 => PhysicalType::LargeUtf8,
282 BinaryView => PhysicalType::BinaryView,
283 Utf8View => PhysicalType::Utf8View,
284 List(_) => PhysicalType::List,
285 FixedSizeList(_, _) => PhysicalType::FixedSizeList,
286 LargeList(_) => PhysicalType::LargeList,
287 Struct(_) => PhysicalType::Struct,
288 Union(_) => PhysicalType::Union,
289 Map(_, _) => PhysicalType::Map,
290 Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
291 Extension(ext) => ext.inner.to_physical_type(),
292 Unknown => unimplemented!(),
293 }
294 }
295
296 pub fn underlying_physical_type(&self) -> ArrowDataType {
298 use ArrowDataType::*;
299 match self {
300 Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
301 Date64
302 | Timestamp(_, _)
303 | Time64(_)
304 | Duration(_)
305 | Interval(IntervalUnit::DayTime) => Int64,
306 Interval(IntervalUnit::MonthDayNano) => unimplemented!(),
307 Binary => Binary,
308 List(field) => List(Box::new(Field {
309 dtype: field.dtype.underlying_physical_type(),
310 ..*field.clone()
311 })),
312 LargeList(field) => LargeList(Box::new(Field {
313 dtype: field.dtype.underlying_physical_type(),
314 ..*field.clone()
315 })),
316 FixedSizeList(field, width) => FixedSizeList(
317 Box::new(Field {
318 dtype: field.dtype.underlying_physical_type(),
319 ..*field.clone()
320 }),
321 *width,
322 ),
323 Struct(fields) => Struct(
324 fields
325 .iter()
326 .map(|field| Field {
327 dtype: field.dtype.underlying_physical_type(),
328 ..field.clone()
329 })
330 .collect(),
331 ),
332 Dictionary(keys, _, _) => (*keys).into(),
333 Union(_) => unimplemented!(),
334 Map(_, _) => unimplemented!(),
335 Extension(ext) => ext.inner.underlying_physical_type(),
336 _ => self.clone(),
337 }
338 }
339
340 pub fn to_logical_type(&self) -> &ArrowDataType {
344 use ArrowDataType::*;
345 match self {
346 Extension(ext) => ext.inner.to_logical_type(),
347 _ => self,
348 }
349 }
350
351 pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
352 match self {
353 ArrowDataType::List(inner) => Some(inner.dtype()),
354 ArrowDataType::LargeList(inner) => Some(inner.dtype()),
355 ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()),
356 _ => None,
357 }
358 }
359
360 pub fn is_nested(&self) -> bool {
361 use ArrowDataType as D;
362
363 matches!(
364 self,
365 D::List(_)
366 | D::LargeList(_)
367 | D::FixedSizeList(_, _)
368 | D::Struct(_)
369 | D::Union(_)
370 | D::Map(_, _)
371 | D::Dictionary(_, _, _)
372 | D::Extension(_)
373 )
374 }
375
376 pub fn is_view(&self) -> bool {
377 matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
378 }
379
380 pub fn is_numeric(&self) -> bool {
381 use ArrowDataType as D;
382 matches!(
383 self,
384 D::Int8
385 | D::Int16
386 | D::Int32
387 | D::Int64
388 | D::Int128
389 | D::UInt8
390 | D::UInt16
391 | D::UInt32
392 | D::UInt64
393 | D::Float32
394 | D::Float64
395 | D::Decimal(_, _)
396 | D::Decimal256(_, _)
397 )
398 }
399
400 pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
401 ArrowDataType::FixedSizeList(
402 Box::new(Field::new(
403 PlSmallStr::from_static("item"),
404 self,
405 is_nullable,
406 )),
407 size,
408 )
409 }
410
411 pub fn contains_dictionary(&self) -> bool {
413 use ArrowDataType as D;
414 match self {
415 D::Null
416 | D::Boolean
417 | D::Int8
418 | D::Int16
419 | D::Int32
420 | D::Int64
421 | D::UInt8
422 | D::UInt16
423 | D::UInt32
424 | D::UInt64
425 | D::Int128
426 | D::Float16
427 | D::Float32
428 | D::Float64
429 | D::Timestamp(_, _)
430 | D::Date32
431 | D::Date64
432 | D::Time32(_)
433 | D::Time64(_)
434 | D::Duration(_)
435 | D::Interval(_)
436 | D::Binary
437 | D::FixedSizeBinary(_)
438 | D::LargeBinary
439 | D::Utf8
440 | D::LargeUtf8
441 | D::Decimal(_, _)
442 | D::Decimal256(_, _)
443 | D::BinaryView
444 | D::Utf8View
445 | D::Unknown => false,
446 D::List(field)
447 | D::FixedSizeList(field, _)
448 | D::Map(field, _)
449 | D::LargeList(field) => field.dtype().contains_dictionary(),
450 D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
451 D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
452 D::Dictionary(_, _, _) => true,
453 D::Extension(ext) => ext.inner.contains_dictionary(),
454 }
455 }
456}
457
458impl From<IntegerType> for ArrowDataType {
459 fn from(item: IntegerType) -> Self {
460 match item {
461 IntegerType::Int8 => ArrowDataType::Int8,
462 IntegerType::Int16 => ArrowDataType::Int16,
463 IntegerType::Int32 => ArrowDataType::Int32,
464 IntegerType::Int64 => ArrowDataType::Int64,
465 IntegerType::Int128 => ArrowDataType::Int128,
466 IntegerType::UInt8 => ArrowDataType::UInt8,
467 IntegerType::UInt16 => ArrowDataType::UInt16,
468 IntegerType::UInt32 => ArrowDataType::UInt32,
469 IntegerType::UInt64 => ArrowDataType::UInt64,
470 }
471 }
472}
473
474impl From<PrimitiveType> for ArrowDataType {
475 fn from(item: PrimitiveType) -> Self {
476 match item {
477 PrimitiveType::Int8 => ArrowDataType::Int8,
478 PrimitiveType::Int16 => ArrowDataType::Int16,
479 PrimitiveType::Int32 => ArrowDataType::Int32,
480 PrimitiveType::Int64 => ArrowDataType::Int64,
481 PrimitiveType::UInt8 => ArrowDataType::UInt8,
482 PrimitiveType::UInt16 => ArrowDataType::UInt16,
483 PrimitiveType::UInt32 => ArrowDataType::UInt32,
484 PrimitiveType::UInt64 => ArrowDataType::UInt64,
485 PrimitiveType::Int128 => ArrowDataType::Int128,
486 PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
487 PrimitiveType::Float16 => ArrowDataType::Float16,
488 PrimitiveType::Float32 => ArrowDataType::Float32,
489 PrimitiveType::Float64 => ArrowDataType::Float64,
490 PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
491 PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
492 PrimitiveType::UInt128 => unimplemented!(),
493 }
494 }
495}
496
/// Shared, reference-counted handle to an [`ArrowSchema`].
pub type SchemaRef = Arc<ArrowSchema>;
499
500pub fn get_extension(metadata: &Metadata) -> Extension {
502 if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) {
503 let metadata = metadata
504 .get(&PlSmallStr::from_static("ARROW:extension:metadata"))
505 .cloned();
506 Some((name.clone(), metadata))
507 } else {
508 None
509 }
510}
511
/// Array type for index (`Idx`) values: `UInt32Array` unless the `bigidx` feature is enabled.
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Array type for index (`Idx`) values: `UInt64Array` when the `bigidx` feature is enabled.
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;