pyo3_arrow/
datatypes.rs

1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
5use pyo3::exceptions::{PyTypeError, PyValueError};
6use pyo3::intern;
7use pyo3::prelude::*;
8use pyo3::types::{PyCapsule, PyTuple, PyType};
9
10use crate::error::PyArrowResult;
11use crate::export::Arro3DataType;
12use crate::ffi::from_python::utils::import_schema_pycapsule;
13use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema;
14use crate::ffi::to_schema_pycapsule;
15use crate::PyField;
16
17struct PyTimeUnit(arrow_schema::TimeUnit);
18
19impl<'a> FromPyObject<'a> for PyTimeUnit {
20    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
21        let s: String = ob.extract()?;
22        match s.to_lowercase().as_str() {
23            "s" => Ok(Self(TimeUnit::Second)),
24            "ms" => Ok(Self(TimeUnit::Millisecond)),
25            "us" => Ok(Self(TimeUnit::Microsecond)),
26            "ns" => Ok(Self(TimeUnit::Nanosecond)),
27            _ => Err(PyValueError::new_err("Unexpected time unit")),
28        }
29    }
30}
31
32/// A Python-facing wrapper around [DataType].
33#[derive(PartialEq, Eq, Debug)]
34#[pyclass(module = "arro3.core._core", name = "DataType", subclass, frozen)]
35pub struct PyDataType(DataType);
36
37impl PyDataType {
38    /// Construct a new PyDataType around a [DataType].
39    pub fn new(data_type: DataType) -> Self {
40        Self(data_type)
41    }
42
43    /// Create from a raw Arrow C Schema capsule
44    pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
45        let schema_ptr = import_schema_pycapsule(capsule)?;
46        let data_type =
47            DataType::try_from(schema_ptr).map_err(|err| PyTypeError::new_err(err.to_string()))?;
48        Ok(Self::new(data_type))
49    }
50
51    /// Consume this and return its inner part.
52    pub fn into_inner(self) -> DataType {
53        self.0
54    }
55
56    /// Export this to a Python `arro3.core.DataType`.
57    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
58        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
59        arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
60            intern!(py, "from_arrow_pycapsule"),
61            PyTuple::new(py, vec![self.__arrow_c_schema__(py)?])?,
62        )
63    }
64
65    /// Export this to a Python `arro3.core.DataType`.
66    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
67        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
68        let capsule = to_schema_pycapsule(py, &self.0)?;
69        arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
70            intern!(py, "from_arrow_pycapsule"),
71            PyTuple::new(py, vec![capsule])?,
72        )
73    }
74
75    /// Export this to a Python `nanoarrow.Schema`.
76    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
77        to_nanoarrow_schema(py, &self.__arrow_c_schema__(py)?)
78    }
79
80    /// Export to a pyarrow.DataType
81    ///
82    /// Requires pyarrow >=14
83    pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
84        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
85        let pyarrow_field = pyarrow_mod
86            .getattr(intern!(py, "field"))?
87            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
88        Ok(pyarrow_field
89            .getattr(intern!(py, "type"))?
90            .into_pyobject(py)?
91            .into_any()
92            .unbind())
93    }
94}
95
96impl From<PyDataType> for DataType {
97    fn from(value: PyDataType) -> Self {
98        value.0
99    }
100}
101
102impl From<DataType> for PyDataType {
103    fn from(value: DataType) -> Self {
104        Self(value)
105    }
106}
107
108impl From<&DataType> for PyDataType {
109    fn from(value: &DataType) -> Self {
110        Self(value.clone())
111    }
112}
113
114impl AsRef<DataType> for PyDataType {
115    fn as_ref(&self) -> &DataType {
116        &self.0
117    }
118}
119
120impl Display for PyDataType {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(f, "arro3.core.DataType<")?;
123        self.0.fmt(f)?;
124        writeln!(f, ">")?;
125        Ok(())
126    }
127}
128
129#[allow(non_snake_case)]
130#[pymethods]
131impl PyDataType {
132    pub(crate) fn __arrow_c_schema__<'py>(
133        &'py self,
134        py: Python<'py>,
135    ) -> PyArrowResult<Bound<'py, PyCapsule>> {
136        to_schema_pycapsule(py, &self.0)
137    }
138
139    fn __eq__(&self, other: PyDataType) -> bool {
140        self.equals(other, false)
141    }
142
143    fn __repr__(&self) -> String {
144        self.to_string()
145    }
146
147    #[classmethod]
148    fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
149        input
150    }
151
152    #[classmethod]
153    #[pyo3(name = "from_arrow_pycapsule")]
154    fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
155        Self::from_arrow_pycapsule(capsule)
156    }
157
158    #[getter]
159    fn bit_width(&self) -> Option<usize> {
160        self.0.primitive_width().map(|width| width * 8)
161    }
162
163    #[pyo3(signature=(other, *, check_metadata=false))]
164    fn equals(&self, other: PyDataType, check_metadata: bool) -> bool {
165        let other = other.into_inner();
166        if check_metadata {
167            self.0 == other
168        } else {
169            self.0.equals_datatype(&other)
170        }
171    }
172
173    #[getter]
174    fn list_size(&self) -> Option<i32> {
175        match &self.0 {
176            DataType::FixedSizeList(_, list_size) => Some(*list_size),
177            _ => None,
178        }
179    }
180
181    #[getter]
182    fn num_fields(&self) -> usize {
183        match &self.0 {
184            DataType::Null
185            | DataType::Boolean
186            | DataType::Int8
187            | DataType::Int16
188            | DataType::Int32
189            | DataType::Int64
190            | DataType::UInt8
191            | DataType::UInt16
192            | DataType::UInt32
193            | DataType::UInt64
194            | DataType::Float16
195            | DataType::Float32
196            | DataType::Float64
197            | DataType::Timestamp(_, _)
198            | DataType::Date32
199            | DataType::Date64
200            | DataType::Time32(_)
201            | DataType::Time64(_)
202            | DataType::Duration(_)
203            | DataType::Interval(_)
204            | DataType::Binary
205            | DataType::FixedSizeBinary(_)
206            | DataType::LargeBinary
207            | DataType::BinaryView
208            | DataType::Utf8
209            | DataType::LargeUtf8
210            | DataType::Utf8View
211            | DataType::Decimal128(_, _)
212            | DataType::Decimal256(_, _) => 0,
213            DataType::List(_)
214            | DataType::ListView(_)
215            | DataType::FixedSizeList(_, _)
216            | DataType::LargeList(_)
217            | DataType::LargeListView(_) => 1,
218            DataType::Struct(fields) => fields.len(),
219            DataType::Union(fields, _) => fields.len(),
220            // Is this accurate?
221            DataType::Dictionary(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => 2,
222        }
223    }
224
225    #[getter]
226    fn time_unit(&self) -> Option<&str> {
227        match &self.0 {
228            DataType::Time32(unit)
229            | DataType::Time64(unit)
230            | DataType::Timestamp(unit, _)
231            | DataType::Duration(unit) => match unit {
232                TimeUnit::Second => Some("s"),
233                TimeUnit::Millisecond => Some("ms"),
234                TimeUnit::Microsecond => Some("us"),
235                TimeUnit::Nanosecond => Some("ns"),
236            },
237            _ => None,
238        }
239    }
240
241    #[getter]
242    fn tz(&self) -> Option<&str> {
243        match &self.0 {
244            DataType::Timestamp(_, tz) => tz.as_deref(),
245            _ => None,
246        }
247    }
248
249    #[getter]
250    fn value_type(&self) -> Option<Arro3DataType> {
251        match &self.0 {
252            DataType::FixedSizeList(value_field, _)
253            | DataType::List(value_field)
254            | DataType::LargeList(value_field)
255            | DataType::ListView(value_field)
256            | DataType::LargeListView(value_field)
257            | DataType::RunEndEncoded(_, value_field) => {
258                Some(PyDataType::new(value_field.data_type().clone()).into())
259            }
260            DataType::Dictionary(_key_type, value_type) => {
261                Some(PyDataType::new(*value_type.clone()).into())
262            }
263            _ => None,
264        }
265    }
266
267    ///////////////////// Constructors
268
269    #[classmethod]
270    fn null(_: &Bound<PyType>) -> Self {
271        Self(DataType::Null)
272    }
273
274    #[classmethod]
275    fn bool(_: &Bound<PyType>) -> Self {
276        Self(DataType::Boolean)
277    }
278
279    #[classmethod]
280    fn int8(_: &Bound<PyType>) -> Self {
281        Self(DataType::Int8)
282    }
283
284    #[classmethod]
285    fn int16(_: &Bound<PyType>) -> Self {
286        Self(DataType::Int16)
287    }
288
289    #[classmethod]
290    fn int32(_: &Bound<PyType>) -> Self {
291        Self(DataType::Int32)
292    }
293
294    #[classmethod]
295    fn int64(_: &Bound<PyType>) -> Self {
296        Self(DataType::Int64)
297    }
298
299    #[classmethod]
300    fn uint8(_: &Bound<PyType>) -> Self {
301        Self(DataType::UInt8)
302    }
303
304    #[classmethod]
305    fn uint16(_: &Bound<PyType>) -> Self {
306        Self(DataType::UInt16)
307    }
308
309    #[classmethod]
310    fn uint32(_: &Bound<PyType>) -> Self {
311        Self(DataType::UInt32)
312    }
313
314    #[classmethod]
315    fn uint64(_: &Bound<PyType>) -> Self {
316        Self(DataType::UInt64)
317    }
318
319    #[classmethod]
320    fn float16(_: &Bound<PyType>) -> Self {
321        Self(DataType::Float16)
322    }
323
324    #[classmethod]
325    fn float32(_: &Bound<PyType>) -> Self {
326        Self(DataType::Float32)
327    }
328
329    #[classmethod]
330    fn float64(_: &Bound<PyType>) -> Self {
331        Self(DataType::Float64)
332    }
333
334    #[classmethod]
335    fn time32(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
336        if unit.0 == TimeUnit::Microsecond || unit.0 == TimeUnit::Nanosecond {
337            return Err(PyValueError::new_err("Unexpected timeunit for time32").into());
338        }
339
340        Ok(Self(DataType::Time32(unit.0)))
341    }
342
343    #[classmethod]
344    fn time64(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
345        if unit.0 == TimeUnit::Second || unit.0 == TimeUnit::Millisecond {
346            return Err(PyValueError::new_err("Unexpected timeunit for time64").into());
347        }
348
349        Ok(Self(DataType::Time64(unit.0)))
350    }
351
352    #[classmethod]
353    #[pyo3(signature = (unit, *, tz=None))]
354    fn timestamp(_: &Bound<PyType>, unit: PyTimeUnit, tz: Option<String>) -> Self {
355        Self(DataType::Timestamp(unit.0, tz.map(|s| s.into())))
356    }
357
358    #[classmethod]
359    fn date32(_: &Bound<PyType>) -> Self {
360        Self(DataType::Date32)
361    }
362
363    #[classmethod]
364    fn date64(_: &Bound<PyType>) -> Self {
365        Self(DataType::Date64)
366    }
367
368    #[classmethod]
369    fn duration(_: &Bound<PyType>, unit: PyTimeUnit) -> Self {
370        Self(DataType::Duration(unit.0))
371    }
372
373    #[classmethod]
374    fn month_day_nano_interval(_: &Bound<PyType>) -> Self {
375        Self(DataType::Interval(IntervalUnit::MonthDayNano))
376    }
377
378    #[classmethod]
379    #[pyo3(signature = (length=None))]
380    fn binary(_: &Bound<PyType>, length: Option<i32>) -> Self {
381        if let Some(length) = length {
382            Self(DataType::FixedSizeBinary(length))
383        } else {
384            Self(DataType::Binary)
385        }
386    }
387
388    #[classmethod]
389    fn string(_: &Bound<PyType>) -> Self {
390        Self(DataType::Utf8)
391    }
392
393    #[classmethod]
394    fn utf8(_: &Bound<PyType>) -> Self {
395        Self(DataType::Utf8)
396    }
397
398    #[classmethod]
399    fn large_binary(_: &Bound<PyType>) -> Self {
400        Self(DataType::LargeBinary)
401    }
402
403    #[classmethod]
404    fn large_string(_: &Bound<PyType>) -> Self {
405        Self(DataType::LargeUtf8)
406    }
407
408    #[classmethod]
409    fn large_utf8(_: &Bound<PyType>) -> Self {
410        Self(DataType::LargeUtf8)
411    }
412
413    #[classmethod]
414    fn binary_view(_: &Bound<PyType>) -> Self {
415        Self(DataType::BinaryView)
416    }
417
418    #[classmethod]
419    fn string_view(_: &Bound<PyType>) -> Self {
420        Self(DataType::Utf8View)
421    }
422
423    #[classmethod]
424    fn decimal128(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
425        Self(DataType::Decimal128(precision, scale))
426    }
427
428    #[classmethod]
429    fn decimal256(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
430        Self(DataType::Decimal256(precision, scale))
431    }
432
433    #[classmethod]
434    #[pyo3(signature = (value_type, list_size=None))]
435    fn list(_: &Bound<PyType>, value_type: PyField, list_size: Option<i32>) -> Self {
436        if let Some(list_size) = list_size {
437            Self(DataType::FixedSizeList(value_type.into(), list_size))
438        } else {
439            Self(DataType::List(value_type.into()))
440        }
441    }
442
443    #[classmethod]
444    fn large_list(_: &Bound<PyType>, value_type: PyField) -> Self {
445        Self(DataType::LargeList(value_type.into()))
446    }
447
448    #[classmethod]
449    fn list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
450        Self(DataType::ListView(value_type.into()))
451    }
452
453    #[classmethod]
454    fn large_list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
455        Self(DataType::LargeListView(value_type.into()))
456    }
457
458    #[classmethod]
459    fn map(_: &Bound<PyType>, key_type: PyField, item_type: PyField, keys_sorted: bool) -> Self {
460        // Note: copied from source of `Field::new_map`
461        // https://github.com/apache/arrow-rs/blob/bf9ce475df82d362631099d491d3454d64d50217/arrow-schema/src/field.rs#L251-L258
462        let data_type = DataType::Map(
463            Arc::new(Field::new(
464                "entries",
465                DataType::Struct(vec![key_type.into_inner(), item_type.into_inner()].into()),
466                false, // The inner map field is always non-nullable (arrow-rs #1697),
467            )),
468            keys_sorted,
469        );
470        Self(data_type)
471    }
472
473    #[classmethod]
474    fn r#struct(_: &Bound<PyType>, fields: Vec<PyField>) -> Self {
475        Self(DataType::Struct(
476            fields.into_iter().map(|field| field.into_inner()).collect(),
477        ))
478    }
479
480    #[classmethod]
481    fn dictionary(_: &Bound<PyType>, index_type: PyDataType, value_type: PyDataType) -> Self {
482        Self(DataType::Dictionary(
483            Box::new(index_type.into_inner()),
484            Box::new(value_type.into_inner()),
485        ))
486    }
487
488    #[classmethod]
489    fn run_end_encoded(_: &Bound<PyType>, run_end_type: PyField, value_type: PyField) -> Self {
490        Self(DataType::RunEndEncoded(
491            run_end_type.into_inner(),
492            value_type.into_inner(),
493        ))
494    }
495
496    ///////////////////// Type checking
497
498    #[staticmethod]
499    fn is_boolean(t: PyDataType) -> bool {
500        t.0 == DataType::Boolean
501    }
502
503    #[staticmethod]
504    fn is_integer(t: PyDataType) -> bool {
505        t.0.is_integer()
506    }
507
508    #[staticmethod]
509    fn is_signed_integer(t: PyDataType) -> bool {
510        t.0.is_signed_integer()
511    }
512
513    #[staticmethod]
514    fn is_unsigned_integer(t: PyDataType) -> bool {
515        t.0.is_unsigned_integer()
516    }
517
518    #[staticmethod]
519    fn is_int8(t: PyDataType) -> bool {
520        t.0 == DataType::Int8
521    }
522    #[staticmethod]
523    fn is_int16(t: PyDataType) -> bool {
524        t.0 == DataType::Int16
525    }
526    #[staticmethod]
527    fn is_int32(t: PyDataType) -> bool {
528        t.0 == DataType::Int32
529    }
530    #[staticmethod]
531    fn is_int64(t: PyDataType) -> bool {
532        t.0 == DataType::Int64
533    }
534    #[staticmethod]
535    fn is_uint8(t: PyDataType) -> bool {
536        t.0 == DataType::UInt8
537    }
538    #[staticmethod]
539    fn is_uint16(t: PyDataType) -> bool {
540        t.0 == DataType::UInt16
541    }
542    #[staticmethod]
543    fn is_uint32(t: PyDataType) -> bool {
544        t.0 == DataType::UInt32
545    }
546    #[staticmethod]
547    fn is_uint64(t: PyDataType) -> bool {
548        t.0 == DataType::UInt64
549    }
550    #[staticmethod]
551    fn is_floating(t: PyDataType) -> bool {
552        t.0.is_floating()
553    }
554    #[staticmethod]
555    fn is_float16(t: PyDataType) -> bool {
556        t.0 == DataType::Float16
557    }
558    #[staticmethod]
559    fn is_float32(t: PyDataType) -> bool {
560        t.0 == DataType::Float32
561    }
562    #[staticmethod]
563    fn is_float64(t: PyDataType) -> bool {
564        t.0 == DataType::Float64
565    }
566    #[staticmethod]
567    fn is_decimal(t: PyDataType) -> bool {
568        matches!(t.0, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
569    }
570    #[staticmethod]
571    fn is_decimal128(t: PyDataType) -> bool {
572        matches!(t.0, DataType::Decimal128(_, _))
573    }
574    #[staticmethod]
575    fn is_decimal256(t: PyDataType) -> bool {
576        matches!(t.0, DataType::Decimal256(_, _))
577    }
578
579    #[staticmethod]
580    fn is_list(t: PyDataType) -> bool {
581        matches!(t.0, DataType::List(_))
582    }
583    #[staticmethod]
584    fn is_large_list(t: PyDataType) -> bool {
585        matches!(t.0, DataType::LargeList(_))
586    }
587    #[staticmethod]
588    fn is_fixed_size_list(t: PyDataType) -> bool {
589        matches!(t.0, DataType::FixedSizeList(_, _))
590    }
591    #[staticmethod]
592    fn is_list_view(t: PyDataType) -> bool {
593        matches!(t.0, DataType::ListView(_))
594    }
595    #[staticmethod]
596    fn is_large_list_view(t: PyDataType) -> bool {
597        matches!(t.0, DataType::LargeListView(_))
598    }
599    #[staticmethod]
600    fn is_struct(t: PyDataType) -> bool {
601        matches!(t.0, DataType::Struct(_))
602    }
603    #[staticmethod]
604    fn is_union(t: PyDataType) -> bool {
605        matches!(t.0, DataType::Union(_, _))
606    }
607    #[staticmethod]
608    fn is_nested(t: PyDataType) -> bool {
609        t.0.is_nested()
610    }
611    #[staticmethod]
612    fn is_run_end_encoded(t: PyDataType) -> bool {
613        t.0.is_run_ends_type()
614    }
615    #[staticmethod]
616    fn is_temporal(t: PyDataType) -> bool {
617        t.0.is_temporal()
618    }
619    #[staticmethod]
620    fn is_timestamp(t: PyDataType) -> bool {
621        matches!(t.0, DataType::Timestamp(_, _))
622    }
623    #[staticmethod]
624    fn is_date(t: PyDataType) -> bool {
625        matches!(t.0, DataType::Date32 | DataType::Date64)
626    }
627    #[staticmethod]
628    fn is_date32(t: PyDataType) -> bool {
629        t.0 == DataType::Date32
630    }
631    #[staticmethod]
632    fn is_date64(t: PyDataType) -> bool {
633        t.0 == DataType::Date64
634    }
635    #[staticmethod]
636    fn is_time(t: PyDataType) -> bool {
637        matches!(t.0, DataType::Time32(_) | DataType::Time64(_))
638    }
639    #[staticmethod]
640    fn is_time32(t: PyDataType) -> bool {
641        matches!(t.0, DataType::Time32(_))
642    }
643    #[staticmethod]
644    fn is_time64(t: PyDataType) -> bool {
645        matches!(t.0, DataType::Time64(_))
646    }
647    #[staticmethod]
648    fn is_duration(t: PyDataType) -> bool {
649        matches!(t.0, DataType::Duration(_))
650    }
651    #[staticmethod]
652    fn is_interval(t: PyDataType) -> bool {
653        matches!(t.0, DataType::Interval(_))
654    }
655    #[staticmethod]
656    fn is_null(t: PyDataType) -> bool {
657        t.0 == DataType::Null
658    }
659    #[staticmethod]
660    fn is_binary(t: PyDataType) -> bool {
661        t.0 == DataType::Binary
662    }
663    #[staticmethod]
664    fn is_unicode(t: PyDataType) -> bool {
665        t.0 == DataType::Utf8
666    }
667    #[staticmethod]
668    fn is_string(t: PyDataType) -> bool {
669        t.0 == DataType::Utf8
670    }
671    #[staticmethod]
672    fn is_large_binary(t: PyDataType) -> bool {
673        t.0 == DataType::LargeBinary
674    }
675    #[staticmethod]
676    fn is_large_unicode(t: PyDataType) -> bool {
677        t.0 == DataType::LargeUtf8
678    }
679    #[staticmethod]
680    fn is_large_string(t: PyDataType) -> bool {
681        t.0 == DataType::LargeUtf8
682    }
683    #[staticmethod]
684    fn is_binary_view(t: PyDataType) -> bool {
685        t.0 == DataType::BinaryView
686    }
687    #[staticmethod]
688    fn is_string_view(t: PyDataType) -> bool {
689        t.0 == DataType::Utf8View
690    }
691    #[staticmethod]
692    fn is_fixed_size_binary(t: PyDataType) -> bool {
693        matches!(t.0, DataType::FixedSizeBinary(_))
694    }
695    #[staticmethod]
696    fn is_map(t: PyDataType) -> bool {
697        matches!(t.0, DataType::Map(_, _))
698    }
699    #[staticmethod]
700    fn is_dictionary(t: PyDataType) -> bool {
701        matches!(t.0, DataType::Dictionary(_, _))
702    }
703    #[staticmethod]
704    fn is_primitive(t: PyDataType) -> bool {
705        t.0.is_primitive()
706    }
707    #[staticmethod]
708    fn is_numeric(t: PyDataType) -> bool {
709        t.0.is_numeric()
710    }
711    #[staticmethod]
712    fn is_dictionary_key_type(t: PyDataType) -> bool {
713        t.0.is_dictionary_key_type()
714    }
715}