1mod field;
4mod physical_type;
5pub mod reshape;
6mod schema;
7
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub use field::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field};
12pub use physical_type::*;
13use polars_utils::pl_str::PlSmallStr;
14pub use schema::{ArrowSchema, ArrowSchemaRef};
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17
/// Free-form key/value annotations: a map from string keys to string values.
/// Being a [`BTreeMap`], its entries are kept ordered by key.
pub type Metadata = BTreeMap<PlSmallStr, PlSmallStr>;
/// Optional extension description: the extension name paired with its optional
/// metadata string. `None` means no extension is present.
pub(crate) type Extension = Option<(PlSmallStr, Option<PlSmallStr>)>;
22
/// The set of logical data types a value can be described with.
///
/// Several logical types share one physical representation; the mapping is given by
/// `ArrowDataType::to_physical_type`. The `Default` value is [`ArrowDataType::Null`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArrowDataType {
    /// Null type. This is the `Default` variant.
    #[default]
    Null,
    /// Boolean type.
    Boolean,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Signed 128-bit integer.
    Int128,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 16-bit floating point number.
    Float16,
    /// 32-bit floating point number.
    Float32,
    /// 64-bit floating point number.
    Float64,
    /// Timestamp counted in the given [`TimeUnit`]; physically a 64-bit integer.
    ///
    /// NOTE(review): the optional string is presumably a timezone — confirm.
    Timestamp(TimeUnit, Option<PlSmallStr>),
    /// Date that is physically a 32-bit integer.
    Date32,
    /// Date that is physically a 64-bit integer.
    Date64,
    /// Time in the given [`TimeUnit`]; physically a 32-bit integer.
    Time32(TimeUnit),
    /// Time in the given [`TimeUnit`]; physically a 64-bit integer.
    Time64(TimeUnit),
    /// Duration counted in the given [`TimeUnit`]; physically a 64-bit integer.
    Duration(TimeUnit),
    /// Interval whose physical representation depends on the [`IntervalUnit`].
    Interval(IntervalUnit),
    /// Binary data (physical type `Binary`).
    Binary,
    /// Binary data of a fixed size (physical type `FixedSizeBinary`).
    ///
    /// NOTE(review): the `usize` is presumably the size of each value in bytes — confirm.
    FixedSizeBinary(usize),
    /// Binary data (physical type `LargeBinary`).
    LargeBinary,
    /// String data (physical type `Utf8`).
    Utf8,
    /// String data (physical type `LargeUtf8`).
    LargeUtf8,
    /// List whose elements are described by the boxed [`Field`].
    List(Box<Field>),
    /// List with a fixed number of elements (the `usize`) described by the boxed [`Field`].
    FixedSizeList(Box<Field>, usize),
    /// List (physical type `LargeList`) whose elements are described by the boxed [`Field`].
    LargeList(Box<Field>),
    /// Struct made of the given fields.
    Struct(Vec<Field>),
    /// Map described by the boxed [`Field`].
    ///
    /// NOTE(review): the `bool` presumably states whether the keys are sorted — confirm.
    Map(Box<Field>, bool),
    /// Dictionary-encoded type: the integer key type followed by the boxed data type.
    ///
    /// NOTE(review): the boxed type is presumably the type of the dictionary values and
    /// the `bool` presumably states whether the dictionary is sorted — confirm.
    Dictionary(IntegerType, Box<ArrowDataType>, bool),
    /// Decimal that is physically a 128-bit integer.
    ///
    /// NOTE(review): the fields are presumably `(precision, scale)` — confirm.
    Decimal(usize, usize),
    /// Decimal that is physically a 256-bit integer.
    ///
    /// NOTE(review): the fields are presumably `(precision, scale)` — confirm.
    Decimal256(usize, usize),
    /// Extension type; its physical type is that of the wrapped inner type.
    Extension(Box<ExtensionType>),
    /// Binary data (physical type `BinaryView`).
    BinaryView,
    /// String data (physical type `Utf8View`).
    Utf8View,
    /// A data type that is not known; it has no physical type implemented.
    Unknown,
    /// Union of several fields. Skipped by serde when the `serde` feature is enabled.
    #[cfg_attr(feature = "serde", serde(skip))]
    Union(Box<UnionType>),
}
177
/// Description of an extension type: a named wrapper around an inner [`ArrowDataType`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ExtensionType {
    /// Name of the extension.
    pub name: PlSmallStr,
    /// The wrapped data type; physical-type queries on an extension defer to it.
    pub inner: ArrowDataType,
    /// Optional metadata string attached to the extension.
    pub metadata: Option<PlSmallStr>,
}
185
/// Payload of the `Union` data type.
///
/// Note that this type does not derive the serde traits.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnionType {
    /// The member fields of the union.
    pub fields: Vec<Field>,
    /// Optional list of ids.
    ///
    /// NOTE(review): presumably the type ids identifying each member field — confirm.
    pub ids: Option<Vec<i32>>,
    /// Whether the union is dense or sparse.
    pub mode: UnionMode,
}
192
/// The mode of a union: either dense or sparse.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UnionMode {
    /// Dense union.
    Dense,
    /// Sparse union.
    Sparse,
}
202
203impl UnionMode {
204 pub fn sparse(is_sparse: bool) -> Self {
207 if is_sparse { Self::Sparse } else { Self::Dense }
208 }
209
210 pub fn is_sparse(&self) -> bool {
212 matches!(self, Self::Sparse)
213 }
214
215 pub fn is_dense(&self) -> bool {
217 matches!(self, Self::Dense)
218 }
219}
220
/// The unit in which a time-like value is counted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TimeUnit {
    /// Seconds.
    Second,
    /// Milliseconds.
    Millisecond,
    /// Microseconds.
    Microsecond,
    /// Nanoseconds.
    Nanosecond,
}
234
/// The unit of an `Interval` data type; it determines the interval's physical representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IntervalUnit {
    /// Interval that is physically a 32-bit integer.
    ///
    /// NOTE(review): presumably a number of whole months — confirm.
    YearMonth,
    /// Interval that is physically the `DaysMs` primitive.
    ///
    /// NOTE(review): presumably a pair of (days, milliseconds) — confirm.
    DayTime,
    /// Interval that is physically the `MonthDayNano` primitive.
    ///
    /// NOTE(review): presumably a triple of (months, days, nanoseconds) — confirm.
    MonthDayNano,
}
247
248impl ArrowDataType {
249 pub const IDX_DTYPE: Self = {
251 #[cfg(not(feature = "bigidx"))]
252 {
253 ArrowDataType::UInt32
254 }
255 #[cfg(feature = "bigidx")]
256 {
257 ArrowDataType::UInt64
258 }
259 };
260
261 pub fn to_physical_type(&self) -> PhysicalType {
263 use ArrowDataType::*;
264 match self {
265 Null => PhysicalType::Null,
266 Boolean => PhysicalType::Boolean,
267 Int8 => PhysicalType::Primitive(PrimitiveType::Int8),
268 Int16 => PhysicalType::Primitive(PrimitiveType::Int16),
269 Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
270 PhysicalType::Primitive(PrimitiveType::Int32)
271 },
272 Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => {
273 PhysicalType::Primitive(PrimitiveType::Int64)
274 },
275 Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128),
276 Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256),
277 UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8),
278 UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
279 UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
280 UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
281 Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
282 Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
283 Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
284 Int128 => PhysicalType::Primitive(PrimitiveType::Int128),
285 Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
286 Interval(IntervalUnit::MonthDayNano) => {
287 PhysicalType::Primitive(PrimitiveType::MonthDayNano)
288 },
289 Binary => PhysicalType::Binary,
290 FixedSizeBinary(_) => PhysicalType::FixedSizeBinary,
291 LargeBinary => PhysicalType::LargeBinary,
292 Utf8 => PhysicalType::Utf8,
293 LargeUtf8 => PhysicalType::LargeUtf8,
294 BinaryView => PhysicalType::BinaryView,
295 Utf8View => PhysicalType::Utf8View,
296 List(_) => PhysicalType::List,
297 FixedSizeList(_, _) => PhysicalType::FixedSizeList,
298 LargeList(_) => PhysicalType::LargeList,
299 Struct(_) => PhysicalType::Struct,
300 Union(_) => PhysicalType::Union,
301 Map(_, _) => PhysicalType::Map,
302 Dictionary(key, _, _) => PhysicalType::Dictionary(*key),
303 Extension(ext) => ext.inner.to_physical_type(),
304 Unknown => unimplemented!(),
305 }
306 }
307
308 pub fn underlying_physical_type(&self) -> ArrowDataType {
310 use ArrowDataType::*;
311 match self {
312 Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32,
313 Date64
314 | Timestamp(_, _)
315 | Time64(_)
316 | Duration(_)
317 | Interval(IntervalUnit::DayTime) => Int64,
318 Interval(IntervalUnit::MonthDayNano) => unimplemented!(),
319 Binary => Binary,
320 List(field) => List(Box::new(Field {
321 dtype: field.dtype.underlying_physical_type(),
322 ..*field.clone()
323 })),
324 LargeList(field) => LargeList(Box::new(Field {
325 dtype: field.dtype.underlying_physical_type(),
326 ..*field.clone()
327 })),
328 FixedSizeList(field, width) => FixedSizeList(
329 Box::new(Field {
330 dtype: field.dtype.underlying_physical_type(),
331 ..*field.clone()
332 }),
333 *width,
334 ),
335 Struct(fields) => Struct(
336 fields
337 .iter()
338 .map(|field| Field {
339 dtype: field.dtype.underlying_physical_type(),
340 ..field.clone()
341 })
342 .collect(),
343 ),
344 Dictionary(keys, _, _) => (*keys).into(),
345 Union(_) => unimplemented!(),
346 Map(_, _) => unimplemented!(),
347 Extension(ext) => ext.inner.underlying_physical_type(),
348 _ => self.clone(),
349 }
350 }
351
352 pub fn to_logical_type(&self) -> &ArrowDataType {
356 use ArrowDataType::*;
357 match self {
358 Extension(ext) => ext.inner.to_logical_type(),
359 _ => self,
360 }
361 }
362
363 pub fn inner_dtype(&self) -> Option<&ArrowDataType> {
364 match self {
365 ArrowDataType::List(inner) => Some(inner.dtype()),
366 ArrowDataType::LargeList(inner) => Some(inner.dtype()),
367 ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()),
368 _ => None,
369 }
370 }
371
372 pub fn is_nested(&self) -> bool {
373 use ArrowDataType as D;
374
375 matches!(
376 self,
377 D::List(_)
378 | D::LargeList(_)
379 | D::FixedSizeList(_, _)
380 | D::Struct(_)
381 | D::Union(_)
382 | D::Map(_, _)
383 | D::Dictionary(_, _, _)
384 | D::Extension(_)
385 )
386 }
387
388 pub fn is_view(&self) -> bool {
389 matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
390 }
391
392 pub fn is_numeric(&self) -> bool {
393 use ArrowDataType as D;
394 matches!(
395 self,
396 D::Int8
397 | D::Int16
398 | D::Int32
399 | D::Int64
400 | D::Int128
401 | D::UInt8
402 | D::UInt16
403 | D::UInt32
404 | D::UInt64
405 | D::Float32
406 | D::Float64
407 | D::Decimal(_, _)
408 | D::Decimal256(_, _)
409 )
410 }
411
412 pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
413 ArrowDataType::FixedSizeList(
414 Box::new(Field::new(
415 PlSmallStr::from_static("item"),
416 self,
417 is_nullable,
418 )),
419 size,
420 )
421 }
422
423 pub fn contains_dictionary(&self) -> bool {
425 use ArrowDataType as D;
426 match self {
427 D::Null
428 | D::Boolean
429 | D::Int8
430 | D::Int16
431 | D::Int32
432 | D::Int64
433 | D::UInt8
434 | D::UInt16
435 | D::UInt32
436 | D::UInt64
437 | D::Int128
438 | D::Float16
439 | D::Float32
440 | D::Float64
441 | D::Timestamp(_, _)
442 | D::Date32
443 | D::Date64
444 | D::Time32(_)
445 | D::Time64(_)
446 | D::Duration(_)
447 | D::Interval(_)
448 | D::Binary
449 | D::FixedSizeBinary(_)
450 | D::LargeBinary
451 | D::Utf8
452 | D::LargeUtf8
453 | D::Decimal(_, _)
454 | D::Decimal256(_, _)
455 | D::BinaryView
456 | D::Utf8View
457 | D::Unknown => false,
458 D::List(field)
459 | D::FixedSizeList(field, _)
460 | D::Map(field, _)
461 | D::LargeList(field) => field.dtype().contains_dictionary(),
462 D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()),
463 D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()),
464 D::Dictionary(_, _, _) => true,
465 D::Extension(ext) => ext.inner.contains_dictionary(),
466 }
467 }
468}
469
470impl From<IntegerType> for ArrowDataType {
471 fn from(item: IntegerType) -> Self {
472 match item {
473 IntegerType::Int8 => ArrowDataType::Int8,
474 IntegerType::Int16 => ArrowDataType::Int16,
475 IntegerType::Int32 => ArrowDataType::Int32,
476 IntegerType::Int64 => ArrowDataType::Int64,
477 IntegerType::Int128 => ArrowDataType::Int128,
478 IntegerType::UInt8 => ArrowDataType::UInt8,
479 IntegerType::UInt16 => ArrowDataType::UInt16,
480 IntegerType::UInt32 => ArrowDataType::UInt32,
481 IntegerType::UInt64 => ArrowDataType::UInt64,
482 }
483 }
484}
485
486impl From<PrimitiveType> for ArrowDataType {
487 fn from(item: PrimitiveType) -> Self {
488 match item {
489 PrimitiveType::Int8 => ArrowDataType::Int8,
490 PrimitiveType::Int16 => ArrowDataType::Int16,
491 PrimitiveType::Int32 => ArrowDataType::Int32,
492 PrimitiveType::Int64 => ArrowDataType::Int64,
493 PrimitiveType::UInt8 => ArrowDataType::UInt8,
494 PrimitiveType::UInt16 => ArrowDataType::UInt16,
495 PrimitiveType::UInt32 => ArrowDataType::UInt32,
496 PrimitiveType::UInt64 => ArrowDataType::UInt64,
497 PrimitiveType::Int128 => ArrowDataType::Int128,
498 PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32),
499 PrimitiveType::Float16 => ArrowDataType::Float16,
500 PrimitiveType::Float32 => ArrowDataType::Float32,
501 PrimitiveType::Float64 => ArrowDataType::Float64,
502 PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime),
503 PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano),
504 PrimitiveType::UInt128 => unimplemented!(),
505 }
506 }
507}
508
/// Shared, atomically reference-counted handle to an [`ArrowSchema`].
pub type SchemaRef = Arc<ArrowSchema>;
511
512pub fn get_extension(metadata: &Metadata) -> Extension {
514 if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) {
515 let metadata = metadata
516 .get(&PlSmallStr::from_static("ARROW:extension:metadata"))
517 .cloned();
518 Some((name.clone(), metadata))
519 } else {
520 None
521 }
522}
523
/// Index array type: `UInt32Array` unless the `bigidx` feature is enabled.
#[cfg(not(feature = "bigidx"))]
pub type IdxArr = super::array::UInt32Array;
/// Index array type: `UInt64Array` because the `bigidx` feature is enabled.
#[cfg(feature = "bigidx")]
pub type IdxArr = super::array::UInt64Array;