pyo3_arrow/
datatypes.rs

1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow::datatypes::DataType;
5use arrow_schema::{Field, IntervalUnit, TimeUnit};
6use pyo3::exceptions::{PyTypeError, PyValueError};
7use pyo3::intern;
8use pyo3::prelude::*;
9use pyo3::types::{PyCapsule, PyTuple, PyType};
10
11use crate::error::PyArrowResult;
12use crate::export::Arro3DataType;
13use crate::ffi::from_python::utils::import_schema_pycapsule;
14use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema;
15use crate::ffi::to_schema_pycapsule;
16use crate::PyField;
17
18struct PyTimeUnit(arrow_schema::TimeUnit);
19
20impl<'a> FromPyObject<'a> for PyTimeUnit {
21    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
22        let s: String = ob.extract()?;
23        match s.to_lowercase().as_str() {
24            "s" => Ok(Self(TimeUnit::Second)),
25            "ms" => Ok(Self(TimeUnit::Millisecond)),
26            "us" => Ok(Self(TimeUnit::Microsecond)),
27            "ns" => Ok(Self(TimeUnit::Nanosecond)),
28            _ => Err(PyValueError::new_err("Unexpected time unit")),
29        }
30    }
31}
32
33/// A Python-facing wrapper around [DataType].
34#[derive(PartialEq, Eq, Debug)]
35#[pyclass(module = "arro3.core._core", name = "DataType", subclass, frozen)]
36pub struct PyDataType(DataType);
37
38impl PyDataType {
39    /// Construct a new PyDataType around a [DataType].
40    pub fn new(data_type: DataType) -> Self {
41        Self(data_type)
42    }
43
44    /// Create from a raw Arrow C Schema capsule
45    pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
46        let schema_ptr = import_schema_pycapsule(capsule)?;
47        let data_type =
48            DataType::try_from(schema_ptr).map_err(|err| PyTypeError::new_err(err.to_string()))?;
49        Ok(Self::new(data_type))
50    }
51
52    /// Consume this and return its inner part.
53    pub fn into_inner(self) -> DataType {
54        self.0
55    }
56
57    /// Export this to a Python `arro3.core.DataType`.
58    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
59        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
60        arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
61            intern!(py, "from_arrow_pycapsule"),
62            PyTuple::new(py, vec![self.__arrow_c_schema__(py)?])?,
63        )
64    }
65
66    /// Export this to a Python `arro3.core.DataType`.
67    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
68        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
69        let capsule = to_schema_pycapsule(py, &self.0)?;
70        arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
71            intern!(py, "from_arrow_pycapsule"),
72            PyTuple::new(py, vec![capsule])?,
73        )
74    }
75
76    /// Export this to a Python `nanoarrow.Schema`.
77    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
78        to_nanoarrow_schema(py, &self.__arrow_c_schema__(py)?)
79    }
80
81    /// Export to a pyarrow.DataType
82    ///
83    /// Requires pyarrow >=14
84    pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
85        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
86        let pyarrow_field = pyarrow_mod
87            .getattr(intern!(py, "field"))?
88            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
89        Ok(pyarrow_field
90            .getattr(intern!(py, "type"))?
91            .into_pyobject(py)?
92            .into_any()
93            .unbind())
94    }
95}
96
97impl From<PyDataType> for DataType {
98    fn from(value: PyDataType) -> Self {
99        value.0
100    }
101}
102
103impl From<DataType> for PyDataType {
104    fn from(value: DataType) -> Self {
105        Self(value)
106    }
107}
108
109impl From<&DataType> for PyDataType {
110    fn from(value: &DataType) -> Self {
111        Self(value.clone())
112    }
113}
114
115impl AsRef<DataType> for PyDataType {
116    fn as_ref(&self) -> &DataType {
117        &self.0
118    }
119}
120
121impl Display for PyDataType {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        write!(f, "arro3.core.DataType<")?;
124        self.0.fmt(f)?;
125        writeln!(f, ">")?;
126        Ok(())
127    }
128}
129
130#[allow(non_snake_case)]
131#[pymethods]
132impl PyDataType {
133    pub(crate) fn __arrow_c_schema__<'py>(
134        &'py self,
135        py: Python<'py>,
136    ) -> PyArrowResult<Bound<'py, PyCapsule>> {
137        to_schema_pycapsule(py, &self.0)
138    }
139
140    fn __eq__(&self, other: PyDataType) -> bool {
141        self.equals(other, false)
142    }
143
144    fn __repr__(&self) -> String {
145        self.to_string()
146    }
147
148    #[classmethod]
149    fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
150        input
151    }
152
153    #[classmethod]
154    #[pyo3(name = "from_arrow_pycapsule")]
155    fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
156        Self::from_arrow_pycapsule(capsule)
157    }
158
159    #[getter]
160    fn bit_width(&self) -> Option<usize> {
161        self.0.primitive_width().map(|width| width * 8)
162    }
163
164    #[pyo3(signature=(other, *, check_metadata=false))]
165    fn equals(&self, other: PyDataType, check_metadata: bool) -> bool {
166        let other = other.into_inner();
167        if check_metadata {
168            self.0 == other
169        } else {
170            self.0.equals_datatype(&other)
171        }
172    }
173
174    #[getter]
175    fn list_size(&self) -> Option<i32> {
176        match &self.0 {
177            DataType::FixedSizeList(_, list_size) => Some(*list_size),
178            _ => None,
179        }
180    }
181
182    #[getter]
183    fn num_fields(&self) -> usize {
184        match &self.0 {
185            DataType::Null
186            | DataType::Boolean
187            | DataType::Int8
188            | DataType::Int16
189            | DataType::Int32
190            | DataType::Int64
191            | DataType::UInt8
192            | DataType::UInt16
193            | DataType::UInt32
194            | DataType::UInt64
195            | DataType::Float16
196            | DataType::Float32
197            | DataType::Float64
198            | DataType::Timestamp(_, _)
199            | DataType::Date32
200            | DataType::Date64
201            | DataType::Time32(_)
202            | DataType::Time64(_)
203            | DataType::Duration(_)
204            | DataType::Interval(_)
205            | DataType::Binary
206            | DataType::FixedSizeBinary(_)
207            | DataType::LargeBinary
208            | DataType::BinaryView
209            | DataType::Utf8
210            | DataType::LargeUtf8
211            | DataType::Utf8View
212            | DataType::Decimal128(_, _)
213            | DataType::Decimal256(_, _) => 0,
214            DataType::List(_)
215            | DataType::ListView(_)
216            | DataType::FixedSizeList(_, _)
217            | DataType::LargeList(_)
218            | DataType::LargeListView(_) => 1,
219            DataType::Struct(fields) => fields.len(),
220            DataType::Union(fields, _) => fields.len(),
221            // Is this accurate?
222            DataType::Dictionary(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => 2,
223        }
224    }
225
226    #[getter]
227    fn time_unit(&self) -> Option<&str> {
228        match &self.0 {
229            DataType::Time32(unit)
230            | DataType::Time64(unit)
231            | DataType::Timestamp(unit, _)
232            | DataType::Duration(unit) => match unit {
233                TimeUnit::Second => Some("s"),
234                TimeUnit::Millisecond => Some("ms"),
235                TimeUnit::Microsecond => Some("us"),
236                TimeUnit::Nanosecond => Some("ns"),
237            },
238            _ => None,
239        }
240    }
241
242    #[getter]
243    fn tz(&self) -> Option<&str> {
244        match &self.0 {
245            DataType::Timestamp(_, tz) => tz.as_deref(),
246            _ => None,
247        }
248    }
249
250    #[getter]
251    fn value_type(&self) -> Option<Arro3DataType> {
252        match &self.0 {
253            DataType::FixedSizeList(value_field, _)
254            | DataType::List(value_field)
255            | DataType::LargeList(value_field)
256            | DataType::ListView(value_field)
257            | DataType::LargeListView(value_field)
258            | DataType::RunEndEncoded(_, value_field) => {
259                Some(PyDataType::new(value_field.data_type().clone()).into())
260            }
261            DataType::Dictionary(_key_type, value_type) => {
262                Some(PyDataType::new(*value_type.clone()).into())
263            }
264            _ => None,
265        }
266    }
267
268    ///////////////////// Constructors
269
270    #[classmethod]
271    fn null(_: &Bound<PyType>) -> Self {
272        Self(DataType::Null)
273    }
274
275    #[classmethod]
276    fn bool(_: &Bound<PyType>) -> Self {
277        Self(DataType::Boolean)
278    }
279
280    #[classmethod]
281    fn int8(_: &Bound<PyType>) -> Self {
282        Self(DataType::Int8)
283    }
284
285    #[classmethod]
286    fn int16(_: &Bound<PyType>) -> Self {
287        Self(DataType::Int16)
288    }
289
290    #[classmethod]
291    fn int32(_: &Bound<PyType>) -> Self {
292        Self(DataType::Int32)
293    }
294
295    #[classmethod]
296    fn int64(_: &Bound<PyType>) -> Self {
297        Self(DataType::Int64)
298    }
299
300    #[classmethod]
301    fn uint8(_: &Bound<PyType>) -> Self {
302        Self(DataType::UInt8)
303    }
304
305    #[classmethod]
306    fn uint16(_: &Bound<PyType>) -> Self {
307        Self(DataType::UInt16)
308    }
309
310    #[classmethod]
311    fn uint32(_: &Bound<PyType>) -> Self {
312        Self(DataType::UInt32)
313    }
314
315    #[classmethod]
316    fn uint64(_: &Bound<PyType>) -> Self {
317        Self(DataType::UInt64)
318    }
319
320    #[classmethod]
321    fn float16(_: &Bound<PyType>) -> Self {
322        Self(DataType::Float16)
323    }
324
325    #[classmethod]
326    fn float32(_: &Bound<PyType>) -> Self {
327        Self(DataType::Float32)
328    }
329
330    #[classmethod]
331    fn float64(_: &Bound<PyType>) -> Self {
332        Self(DataType::Float64)
333    }
334
335    #[classmethod]
336    fn time32(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
337        if unit.0 == TimeUnit::Microsecond || unit.0 == TimeUnit::Nanosecond {
338            return Err(PyValueError::new_err("Unexpected timeunit for time32").into());
339        }
340
341        Ok(Self(DataType::Time32(unit.0)))
342    }
343
344    #[classmethod]
345    fn time64(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
346        if unit.0 == TimeUnit::Second || unit.0 == TimeUnit::Millisecond {
347            return Err(PyValueError::new_err("Unexpected timeunit for time64").into());
348        }
349
350        Ok(Self(DataType::Time64(unit.0)))
351    }
352
353    #[classmethod]
354    #[pyo3(signature = (unit, *, tz=None))]
355    fn timestamp(_: &Bound<PyType>, unit: PyTimeUnit, tz: Option<String>) -> Self {
356        Self(DataType::Timestamp(unit.0, tz.map(|s| s.into())))
357    }
358
359    #[classmethod]
360    fn date32(_: &Bound<PyType>) -> Self {
361        Self(DataType::Date32)
362    }
363
364    #[classmethod]
365    fn date64(_: &Bound<PyType>) -> Self {
366        Self(DataType::Date64)
367    }
368
369    #[classmethod]
370    fn duration(_: &Bound<PyType>, unit: PyTimeUnit) -> Self {
371        Self(DataType::Duration(unit.0))
372    }
373
374    #[classmethod]
375    fn month_day_nano_interval(_: &Bound<PyType>) -> Self {
376        Self(DataType::Interval(IntervalUnit::MonthDayNano))
377    }
378
379    #[classmethod]
380    #[pyo3(signature = (length=None))]
381    fn binary(_: &Bound<PyType>, length: Option<i32>) -> Self {
382        if let Some(length) = length {
383            Self(DataType::FixedSizeBinary(length))
384        } else {
385            Self(DataType::Binary)
386        }
387    }
388
389    #[classmethod]
390    fn string(_: &Bound<PyType>) -> Self {
391        Self(DataType::Utf8)
392    }
393
394    #[classmethod]
395    fn utf8(_: &Bound<PyType>) -> Self {
396        Self(DataType::Utf8)
397    }
398
399    #[classmethod]
400    fn large_binary(_: &Bound<PyType>) -> Self {
401        Self(DataType::LargeBinary)
402    }
403
404    #[classmethod]
405    fn large_string(_: &Bound<PyType>) -> Self {
406        Self(DataType::LargeUtf8)
407    }
408
409    #[classmethod]
410    fn large_utf8(_: &Bound<PyType>) -> Self {
411        Self(DataType::LargeUtf8)
412    }
413
414    #[classmethod]
415    fn binary_view(_: &Bound<PyType>) -> Self {
416        Self(DataType::BinaryView)
417    }
418
419    #[classmethod]
420    fn string_view(_: &Bound<PyType>) -> Self {
421        Self(DataType::Utf8View)
422    }
423
424    #[classmethod]
425    fn decimal128(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
426        Self(DataType::Decimal128(precision, scale))
427    }
428
429    #[classmethod]
430    fn decimal256(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
431        Self(DataType::Decimal256(precision, scale))
432    }
433
434    #[classmethod]
435    #[pyo3(signature = (value_type, list_size=None))]
436    fn list(_: &Bound<PyType>, value_type: PyField, list_size: Option<i32>) -> Self {
437        if let Some(list_size) = list_size {
438            Self(DataType::FixedSizeList(value_type.into(), list_size))
439        } else {
440            Self(DataType::List(value_type.into()))
441        }
442    }
443
444    #[classmethod]
445    fn large_list(_: &Bound<PyType>, value_type: PyField) -> Self {
446        Self(DataType::LargeList(value_type.into()))
447    }
448
449    #[classmethod]
450    fn list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
451        Self(DataType::ListView(value_type.into()))
452    }
453
454    #[classmethod]
455    fn large_list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
456        Self(DataType::LargeListView(value_type.into()))
457    }
458
459    #[classmethod]
460    fn map(_: &Bound<PyType>, key_type: PyField, item_type: PyField, keys_sorted: bool) -> Self {
461        // Note: copied from source of `Field::new_map`
462        // https://github.com/apache/arrow-rs/blob/bf9ce475df82d362631099d491d3454d64d50217/arrow-schema/src/field.rs#L251-L258
463        let data_type = DataType::Map(
464            Arc::new(Field::new(
465                "entries",
466                DataType::Struct(vec![key_type.into_inner(), item_type.into_inner()].into()),
467                false, // The inner map field is always non-nullable (arrow-rs #1697),
468            )),
469            keys_sorted,
470        );
471        Self(data_type)
472    }
473
474    #[classmethod]
475    fn r#struct(_: &Bound<PyType>, fields: Vec<PyField>) -> Self {
476        Self(DataType::Struct(
477            fields.into_iter().map(|field| field.into_inner()).collect(),
478        ))
479    }
480
481    #[classmethod]
482    fn dictionary(_: &Bound<PyType>, index_type: PyDataType, value_type: PyDataType) -> Self {
483        Self(DataType::Dictionary(
484            Box::new(index_type.into_inner()),
485            Box::new(value_type.into_inner()),
486        ))
487    }
488
489    #[classmethod]
490    fn run_end_encoded(_: &Bound<PyType>, run_end_type: PyField, value_type: PyField) -> Self {
491        Self(DataType::RunEndEncoded(
492            run_end_type.into_inner(),
493            value_type.into_inner(),
494        ))
495    }
496
497    ///////////////////// Type checking
498
499    #[staticmethod]
500    fn is_boolean(t: PyDataType) -> bool {
501        t.0 == DataType::Boolean
502    }
503
504    #[staticmethod]
505    fn is_integer(t: PyDataType) -> bool {
506        t.0.is_integer()
507    }
508
509    #[staticmethod]
510    fn is_signed_integer(t: PyDataType) -> bool {
511        t.0.is_signed_integer()
512    }
513
514    #[staticmethod]
515    fn is_unsigned_integer(t: PyDataType) -> bool {
516        t.0.is_unsigned_integer()
517    }
518
519    #[staticmethod]
520    fn is_int8(t: PyDataType) -> bool {
521        t.0 == DataType::Int8
522    }
523    #[staticmethod]
524    fn is_int16(t: PyDataType) -> bool {
525        t.0 == DataType::Int16
526    }
527    #[staticmethod]
528    fn is_int32(t: PyDataType) -> bool {
529        t.0 == DataType::Int32
530    }
531    #[staticmethod]
532    fn is_int64(t: PyDataType) -> bool {
533        t.0 == DataType::Int64
534    }
535    #[staticmethod]
536    fn is_uint8(t: PyDataType) -> bool {
537        t.0 == DataType::UInt8
538    }
539    #[staticmethod]
540    fn is_uint16(t: PyDataType) -> bool {
541        t.0 == DataType::UInt16
542    }
543    #[staticmethod]
544    fn is_uint32(t: PyDataType) -> bool {
545        t.0 == DataType::UInt32
546    }
547    #[staticmethod]
548    fn is_uint64(t: PyDataType) -> bool {
549        t.0 == DataType::UInt64
550    }
551    #[staticmethod]
552    fn is_floating(t: PyDataType) -> bool {
553        t.0.is_floating()
554    }
555    #[staticmethod]
556    fn is_float16(t: PyDataType) -> bool {
557        t.0 == DataType::Float16
558    }
559    #[staticmethod]
560    fn is_float32(t: PyDataType) -> bool {
561        t.0 == DataType::Float32
562    }
563    #[staticmethod]
564    fn is_float64(t: PyDataType) -> bool {
565        t.0 == DataType::Float64
566    }
567    #[staticmethod]
568    fn is_decimal(t: PyDataType) -> bool {
569        matches!(t.0, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
570    }
571    #[staticmethod]
572    fn is_decimal128(t: PyDataType) -> bool {
573        matches!(t.0, DataType::Decimal128(_, _))
574    }
575    #[staticmethod]
576    fn is_decimal256(t: PyDataType) -> bool {
577        matches!(t.0, DataType::Decimal256(_, _))
578    }
579
580    #[staticmethod]
581    fn is_list(t: PyDataType) -> bool {
582        matches!(t.0, DataType::List(_))
583    }
584    #[staticmethod]
585    fn is_large_list(t: PyDataType) -> bool {
586        matches!(t.0, DataType::LargeList(_))
587    }
588    #[staticmethod]
589    fn is_fixed_size_list(t: PyDataType) -> bool {
590        matches!(t.0, DataType::FixedSizeList(_, _))
591    }
592    #[staticmethod]
593    fn is_list_view(t: PyDataType) -> bool {
594        matches!(t.0, DataType::ListView(_))
595    }
596    #[staticmethod]
597    fn is_large_list_view(t: PyDataType) -> bool {
598        matches!(t.0, DataType::LargeListView(_))
599    }
600    #[staticmethod]
601    fn is_struct(t: PyDataType) -> bool {
602        matches!(t.0, DataType::Struct(_))
603    }
604    #[staticmethod]
605    fn is_union(t: PyDataType) -> bool {
606        matches!(t.0, DataType::Union(_, _))
607    }
608    #[staticmethod]
609    fn is_nested(t: PyDataType) -> bool {
610        t.0.is_nested()
611    }
612    #[staticmethod]
613    fn is_run_end_encoded(t: PyDataType) -> bool {
614        t.0.is_run_ends_type()
615    }
616    #[staticmethod]
617    fn is_temporal(t: PyDataType) -> bool {
618        t.0.is_temporal()
619    }
620    #[staticmethod]
621    fn is_timestamp(t: PyDataType) -> bool {
622        matches!(t.0, DataType::Timestamp(_, _))
623    }
624    #[staticmethod]
625    fn is_date(t: PyDataType) -> bool {
626        matches!(t.0, DataType::Date32 | DataType::Date64)
627    }
628    #[staticmethod]
629    fn is_date32(t: PyDataType) -> bool {
630        t.0 == DataType::Date32
631    }
632    #[staticmethod]
633    fn is_date64(t: PyDataType) -> bool {
634        t.0 == DataType::Date64
635    }
636    #[staticmethod]
637    fn is_time(t: PyDataType) -> bool {
638        matches!(t.0, DataType::Time32(_) | DataType::Time64(_))
639    }
640    #[staticmethod]
641    fn is_time32(t: PyDataType) -> bool {
642        matches!(t.0, DataType::Time32(_))
643    }
644    #[staticmethod]
645    fn is_time64(t: PyDataType) -> bool {
646        matches!(t.0, DataType::Time64(_))
647    }
648    #[staticmethod]
649    fn is_duration(t: PyDataType) -> bool {
650        matches!(t.0, DataType::Duration(_))
651    }
652    #[staticmethod]
653    fn is_interval(t: PyDataType) -> bool {
654        matches!(t.0, DataType::Interval(_))
655    }
656    #[staticmethod]
657    fn is_null(t: PyDataType) -> bool {
658        t.0 == DataType::Null
659    }
660    #[staticmethod]
661    fn is_binary(t: PyDataType) -> bool {
662        t.0 == DataType::Binary
663    }
664    #[staticmethod]
665    fn is_unicode(t: PyDataType) -> bool {
666        t.0 == DataType::Utf8
667    }
668    #[staticmethod]
669    fn is_string(t: PyDataType) -> bool {
670        t.0 == DataType::Utf8
671    }
672    #[staticmethod]
673    fn is_large_binary(t: PyDataType) -> bool {
674        t.0 == DataType::LargeBinary
675    }
676    #[staticmethod]
677    fn is_large_unicode(t: PyDataType) -> bool {
678        t.0 == DataType::LargeUtf8
679    }
680    #[staticmethod]
681    fn is_large_string(t: PyDataType) -> bool {
682        t.0 == DataType::LargeUtf8
683    }
684    #[staticmethod]
685    fn is_binary_view(t: PyDataType) -> bool {
686        t.0 == DataType::BinaryView
687    }
688    #[staticmethod]
689    fn is_string_view(t: PyDataType) -> bool {
690        t.0 == DataType::Utf8View
691    }
692    #[staticmethod]
693    fn is_fixed_size_binary(t: PyDataType) -> bool {
694        matches!(t.0, DataType::FixedSizeBinary(_))
695    }
696    #[staticmethod]
697    fn is_map(t: PyDataType) -> bool {
698        matches!(t.0, DataType::Map(_, _))
699    }
700    #[staticmethod]
701    fn is_dictionary(t: PyDataType) -> bool {
702        matches!(t.0, DataType::Dictionary(_, _))
703    }
704    #[staticmethod]
705    fn is_primitive(t: PyDataType) -> bool {
706        t.0.is_primitive()
707    }
708    #[staticmethod]
709    fn is_numeric(t: PyDataType) -> bool {
710        t.0.is_numeric()
711    }
712    #[staticmethod]
713    fn is_dictionary_key_type(t: PyDataType) -> bool {
714        t.0.is_dictionary_key_type()
715    }
716}