1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow::datatypes::DataType;
5use arrow_schema::{Field, IntervalUnit, TimeUnit};
6use pyo3::exceptions::{PyTypeError, PyValueError};
7use pyo3::intern;
8use pyo3::prelude::*;
9use pyo3::types::{PyCapsule, PyTuple, PyType};
10
11use crate::error::PyArrowResult;
12use crate::export::Arro3DataType;
13use crate::ffi::from_python::utils::import_schema_pycapsule;
14use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema;
15use crate::ffi::to_schema_pycapsule;
16use crate::PyField;
17
18struct PyTimeUnit(arrow_schema::TimeUnit);
19
20impl<'a> FromPyObject<'a> for PyTimeUnit {
21 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
22 let s: String = ob.extract()?;
23 match s.to_lowercase().as_str() {
24 "s" => Ok(Self(TimeUnit::Second)),
25 "ms" => Ok(Self(TimeUnit::Millisecond)),
26 "us" => Ok(Self(TimeUnit::Microsecond)),
27 "ns" => Ok(Self(TimeUnit::Nanosecond)),
28 _ => Err(PyValueError::new_err("Unexpected time unit")),
29 }
30 }
31}
32
33#[derive(PartialEq, Eq, Debug)]
35#[pyclass(module = "arro3.core._core", name = "DataType", subclass, frozen)]
36pub struct PyDataType(DataType);
37
38impl PyDataType {
39 pub fn new(data_type: DataType) -> Self {
41 Self(data_type)
42 }
43
44 pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
46 let schema_ptr = import_schema_pycapsule(capsule)?;
47 let data_type =
48 DataType::try_from(schema_ptr).map_err(|err| PyTypeError::new_err(err.to_string()))?;
49 Ok(Self::new(data_type))
50 }
51
52 pub fn into_inner(self) -> DataType {
54 self.0
55 }
56
57 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
59 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
60 arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
61 intern!(py, "from_arrow_pycapsule"),
62 PyTuple::new(py, vec![self.__arrow_c_schema__(py)?])?,
63 )
64 }
65
66 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
68 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
69 let capsule = to_schema_pycapsule(py, &self.0)?;
70 arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
71 intern!(py, "from_arrow_pycapsule"),
72 PyTuple::new(py, vec![capsule])?,
73 )
74 }
75
76 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
78 to_nanoarrow_schema(py, &self.__arrow_c_schema__(py)?)
79 }
80
81 pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
85 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
86 let pyarrow_field = pyarrow_mod
87 .getattr(intern!(py, "field"))?
88 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
89 Ok(pyarrow_field
90 .getattr(intern!(py, "type"))?
91 .into_pyobject(py)?
92 .into_any()
93 .unbind())
94 }
95}
96
97impl From<PyDataType> for DataType {
98 fn from(value: PyDataType) -> Self {
99 value.0
100 }
101}
102
103impl From<DataType> for PyDataType {
104 fn from(value: DataType) -> Self {
105 Self(value)
106 }
107}
108
109impl From<&DataType> for PyDataType {
110 fn from(value: &DataType) -> Self {
111 Self(value.clone())
112 }
113}
114
115impl AsRef<DataType> for PyDataType {
116 fn as_ref(&self) -> &DataType {
117 &self.0
118 }
119}
120
121impl Display for PyDataType {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 write!(f, "arro3.core.DataType<")?;
124 self.0.fmt(f)?;
125 writeln!(f, ">")?;
126 Ok(())
127 }
128}
129
130#[allow(non_snake_case)]
131#[pymethods]
132impl PyDataType {
133 pub(crate) fn __arrow_c_schema__<'py>(
134 &'py self,
135 py: Python<'py>,
136 ) -> PyArrowResult<Bound<'py, PyCapsule>> {
137 to_schema_pycapsule(py, &self.0)
138 }
139
140 fn __eq__(&self, other: PyDataType) -> bool {
141 self.equals(other, false)
142 }
143
144 fn __repr__(&self) -> String {
145 self.to_string()
146 }
147
148 #[classmethod]
149 fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
150 input
151 }
152
153 #[classmethod]
154 #[pyo3(name = "from_arrow_pycapsule")]
155 fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
156 Self::from_arrow_pycapsule(capsule)
157 }
158
159 #[getter]
160 fn bit_width(&self) -> Option<usize> {
161 self.0.primitive_width().map(|width| width * 8)
162 }
163
164 #[pyo3(signature=(other, *, check_metadata=false))]
165 fn equals(&self, other: PyDataType, check_metadata: bool) -> bool {
166 let other = other.into_inner();
167 if check_metadata {
168 self.0 == other
169 } else {
170 self.0.equals_datatype(&other)
171 }
172 }
173
174 #[getter]
175 fn list_size(&self) -> Option<i32> {
176 match &self.0 {
177 DataType::FixedSizeList(_, list_size) => Some(*list_size),
178 _ => None,
179 }
180 }
181
182 #[getter]
183 fn num_fields(&self) -> usize {
184 match &self.0 {
185 DataType::Null
186 | DataType::Boolean
187 | DataType::Int8
188 | DataType::Int16
189 | DataType::Int32
190 | DataType::Int64
191 | DataType::UInt8
192 | DataType::UInt16
193 | DataType::UInt32
194 | DataType::UInt64
195 | DataType::Float16
196 | DataType::Float32
197 | DataType::Float64
198 | DataType::Timestamp(_, _)
199 | DataType::Date32
200 | DataType::Date64
201 | DataType::Time32(_)
202 | DataType::Time64(_)
203 | DataType::Duration(_)
204 | DataType::Interval(_)
205 | DataType::Binary
206 | DataType::FixedSizeBinary(_)
207 | DataType::LargeBinary
208 | DataType::BinaryView
209 | DataType::Utf8
210 | DataType::LargeUtf8
211 | DataType::Utf8View
212 | DataType::Decimal128(_, _)
213 | DataType::Decimal256(_, _) => 0,
214 DataType::List(_)
215 | DataType::ListView(_)
216 | DataType::FixedSizeList(_, _)
217 | DataType::LargeList(_)
218 | DataType::LargeListView(_) => 1,
219 DataType::Struct(fields) => fields.len(),
220 DataType::Union(fields, _) => fields.len(),
221 DataType::Dictionary(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => 2,
223 }
224 }
225
226 #[getter]
227 fn time_unit(&self) -> Option<&str> {
228 match &self.0 {
229 DataType::Time32(unit)
230 | DataType::Time64(unit)
231 | DataType::Timestamp(unit, _)
232 | DataType::Duration(unit) => match unit {
233 TimeUnit::Second => Some("s"),
234 TimeUnit::Millisecond => Some("ms"),
235 TimeUnit::Microsecond => Some("us"),
236 TimeUnit::Nanosecond => Some("ns"),
237 },
238 _ => None,
239 }
240 }
241
242 #[getter]
243 fn tz(&self) -> Option<&str> {
244 match &self.0 {
245 DataType::Timestamp(_, tz) => tz.as_deref(),
246 _ => None,
247 }
248 }
249
250 #[getter]
251 fn value_type(&self) -> Option<Arro3DataType> {
252 match &self.0 {
253 DataType::FixedSizeList(value_field, _)
254 | DataType::List(value_field)
255 | DataType::LargeList(value_field)
256 | DataType::ListView(value_field)
257 | DataType::LargeListView(value_field)
258 | DataType::RunEndEncoded(_, value_field) => {
259 Some(PyDataType::new(value_field.data_type().clone()).into())
260 }
261 DataType::Dictionary(_key_type, value_type) => {
262 Some(PyDataType::new(*value_type.clone()).into())
263 }
264 _ => None,
265 }
266 }
267
268 #[classmethod]
271 fn null(_: &Bound<PyType>) -> Self {
272 Self(DataType::Null)
273 }
274
275 #[classmethod]
276 fn bool(_: &Bound<PyType>) -> Self {
277 Self(DataType::Boolean)
278 }
279
280 #[classmethod]
281 fn int8(_: &Bound<PyType>) -> Self {
282 Self(DataType::Int8)
283 }
284
285 #[classmethod]
286 fn int16(_: &Bound<PyType>) -> Self {
287 Self(DataType::Int16)
288 }
289
290 #[classmethod]
291 fn int32(_: &Bound<PyType>) -> Self {
292 Self(DataType::Int32)
293 }
294
295 #[classmethod]
296 fn int64(_: &Bound<PyType>) -> Self {
297 Self(DataType::Int64)
298 }
299
300 #[classmethod]
301 fn uint8(_: &Bound<PyType>) -> Self {
302 Self(DataType::UInt8)
303 }
304
305 #[classmethod]
306 fn uint16(_: &Bound<PyType>) -> Self {
307 Self(DataType::UInt16)
308 }
309
310 #[classmethod]
311 fn uint32(_: &Bound<PyType>) -> Self {
312 Self(DataType::UInt32)
313 }
314
315 #[classmethod]
316 fn uint64(_: &Bound<PyType>) -> Self {
317 Self(DataType::UInt64)
318 }
319
320 #[classmethod]
321 fn float16(_: &Bound<PyType>) -> Self {
322 Self(DataType::Float16)
323 }
324
325 #[classmethod]
326 fn float32(_: &Bound<PyType>) -> Self {
327 Self(DataType::Float32)
328 }
329
330 #[classmethod]
331 fn float64(_: &Bound<PyType>) -> Self {
332 Self(DataType::Float64)
333 }
334
335 #[classmethod]
336 fn time32(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
337 if unit.0 == TimeUnit::Microsecond || unit.0 == TimeUnit::Nanosecond {
338 return Err(PyValueError::new_err("Unexpected timeunit for time32").into());
339 }
340
341 Ok(Self(DataType::Time32(unit.0)))
342 }
343
344 #[classmethod]
345 fn time64(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
346 if unit.0 == TimeUnit::Second || unit.0 == TimeUnit::Millisecond {
347 return Err(PyValueError::new_err("Unexpected timeunit for time64").into());
348 }
349
350 Ok(Self(DataType::Time64(unit.0)))
351 }
352
353 #[classmethod]
354 #[pyo3(signature = (unit, *, tz=None))]
355 fn timestamp(_: &Bound<PyType>, unit: PyTimeUnit, tz: Option<String>) -> Self {
356 Self(DataType::Timestamp(unit.0, tz.map(|s| s.into())))
357 }
358
359 #[classmethod]
360 fn date32(_: &Bound<PyType>) -> Self {
361 Self(DataType::Date32)
362 }
363
364 #[classmethod]
365 fn date64(_: &Bound<PyType>) -> Self {
366 Self(DataType::Date64)
367 }
368
369 #[classmethod]
370 fn duration(_: &Bound<PyType>, unit: PyTimeUnit) -> Self {
371 Self(DataType::Duration(unit.0))
372 }
373
374 #[classmethod]
375 fn month_day_nano_interval(_: &Bound<PyType>) -> Self {
376 Self(DataType::Interval(IntervalUnit::MonthDayNano))
377 }
378
379 #[classmethod]
380 #[pyo3(signature = (length=None))]
381 fn binary(_: &Bound<PyType>, length: Option<i32>) -> Self {
382 if let Some(length) = length {
383 Self(DataType::FixedSizeBinary(length))
384 } else {
385 Self(DataType::Binary)
386 }
387 }
388
389 #[classmethod]
390 fn string(_: &Bound<PyType>) -> Self {
391 Self(DataType::Utf8)
392 }
393
394 #[classmethod]
395 fn utf8(_: &Bound<PyType>) -> Self {
396 Self(DataType::Utf8)
397 }
398
399 #[classmethod]
400 fn large_binary(_: &Bound<PyType>) -> Self {
401 Self(DataType::LargeBinary)
402 }
403
404 #[classmethod]
405 fn large_string(_: &Bound<PyType>) -> Self {
406 Self(DataType::LargeUtf8)
407 }
408
409 #[classmethod]
410 fn large_utf8(_: &Bound<PyType>) -> Self {
411 Self(DataType::LargeUtf8)
412 }
413
414 #[classmethod]
415 fn binary_view(_: &Bound<PyType>) -> Self {
416 Self(DataType::BinaryView)
417 }
418
419 #[classmethod]
420 fn string_view(_: &Bound<PyType>) -> Self {
421 Self(DataType::Utf8View)
422 }
423
424 #[classmethod]
425 fn decimal128(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
426 Self(DataType::Decimal128(precision, scale))
427 }
428
429 #[classmethod]
430 fn decimal256(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
431 Self(DataType::Decimal256(precision, scale))
432 }
433
434 #[classmethod]
435 #[pyo3(signature = (value_type, list_size=None))]
436 fn list(_: &Bound<PyType>, value_type: PyField, list_size: Option<i32>) -> Self {
437 if let Some(list_size) = list_size {
438 Self(DataType::FixedSizeList(value_type.into(), list_size))
439 } else {
440 Self(DataType::List(value_type.into()))
441 }
442 }
443
444 #[classmethod]
445 fn large_list(_: &Bound<PyType>, value_type: PyField) -> Self {
446 Self(DataType::LargeList(value_type.into()))
447 }
448
449 #[classmethod]
450 fn list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
451 Self(DataType::ListView(value_type.into()))
452 }
453
454 #[classmethod]
455 fn large_list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
456 Self(DataType::LargeListView(value_type.into()))
457 }
458
459 #[classmethod]
460 fn map(_: &Bound<PyType>, key_type: PyField, item_type: PyField, keys_sorted: bool) -> Self {
461 let data_type = DataType::Map(
464 Arc::new(Field::new(
465 "entries",
466 DataType::Struct(vec![key_type.into_inner(), item_type.into_inner()].into()),
467 false, )),
469 keys_sorted,
470 );
471 Self(data_type)
472 }
473
474 #[classmethod]
475 fn r#struct(_: &Bound<PyType>, fields: Vec<PyField>) -> Self {
476 Self(DataType::Struct(
477 fields.into_iter().map(|field| field.into_inner()).collect(),
478 ))
479 }
480
481 #[classmethod]
482 fn dictionary(_: &Bound<PyType>, index_type: PyDataType, value_type: PyDataType) -> Self {
483 Self(DataType::Dictionary(
484 Box::new(index_type.into_inner()),
485 Box::new(value_type.into_inner()),
486 ))
487 }
488
489 #[classmethod]
490 fn run_end_encoded(_: &Bound<PyType>, run_end_type: PyField, value_type: PyField) -> Self {
491 Self(DataType::RunEndEncoded(
492 run_end_type.into_inner(),
493 value_type.into_inner(),
494 ))
495 }
496
497 #[staticmethod]
500 fn is_boolean(t: PyDataType) -> bool {
501 t.0 == DataType::Boolean
502 }
503
504 #[staticmethod]
505 fn is_integer(t: PyDataType) -> bool {
506 t.0.is_integer()
507 }
508
509 #[staticmethod]
510 fn is_signed_integer(t: PyDataType) -> bool {
511 t.0.is_signed_integer()
512 }
513
514 #[staticmethod]
515 fn is_unsigned_integer(t: PyDataType) -> bool {
516 t.0.is_unsigned_integer()
517 }
518
519 #[staticmethod]
520 fn is_int8(t: PyDataType) -> bool {
521 t.0 == DataType::Int8
522 }
523 #[staticmethod]
524 fn is_int16(t: PyDataType) -> bool {
525 t.0 == DataType::Int16
526 }
527 #[staticmethod]
528 fn is_int32(t: PyDataType) -> bool {
529 t.0 == DataType::Int32
530 }
531 #[staticmethod]
532 fn is_int64(t: PyDataType) -> bool {
533 t.0 == DataType::Int64
534 }
535 #[staticmethod]
536 fn is_uint8(t: PyDataType) -> bool {
537 t.0 == DataType::UInt8
538 }
539 #[staticmethod]
540 fn is_uint16(t: PyDataType) -> bool {
541 t.0 == DataType::UInt16
542 }
543 #[staticmethod]
544 fn is_uint32(t: PyDataType) -> bool {
545 t.0 == DataType::UInt32
546 }
547 #[staticmethod]
548 fn is_uint64(t: PyDataType) -> bool {
549 t.0 == DataType::UInt64
550 }
551 #[staticmethod]
552 fn is_floating(t: PyDataType) -> bool {
553 t.0.is_floating()
554 }
555 #[staticmethod]
556 fn is_float16(t: PyDataType) -> bool {
557 t.0 == DataType::Float16
558 }
559 #[staticmethod]
560 fn is_float32(t: PyDataType) -> bool {
561 t.0 == DataType::Float32
562 }
563 #[staticmethod]
564 fn is_float64(t: PyDataType) -> bool {
565 t.0 == DataType::Float64
566 }
567 #[staticmethod]
568 fn is_decimal(t: PyDataType) -> bool {
569 matches!(t.0, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
570 }
571 #[staticmethod]
572 fn is_decimal128(t: PyDataType) -> bool {
573 matches!(t.0, DataType::Decimal128(_, _))
574 }
575 #[staticmethod]
576 fn is_decimal256(t: PyDataType) -> bool {
577 matches!(t.0, DataType::Decimal256(_, _))
578 }
579
580 #[staticmethod]
581 fn is_list(t: PyDataType) -> bool {
582 matches!(t.0, DataType::List(_))
583 }
584 #[staticmethod]
585 fn is_large_list(t: PyDataType) -> bool {
586 matches!(t.0, DataType::LargeList(_))
587 }
588 #[staticmethod]
589 fn is_fixed_size_list(t: PyDataType) -> bool {
590 matches!(t.0, DataType::FixedSizeList(_, _))
591 }
592 #[staticmethod]
593 fn is_list_view(t: PyDataType) -> bool {
594 matches!(t.0, DataType::ListView(_))
595 }
596 #[staticmethod]
597 fn is_large_list_view(t: PyDataType) -> bool {
598 matches!(t.0, DataType::LargeListView(_))
599 }
600 #[staticmethod]
601 fn is_struct(t: PyDataType) -> bool {
602 matches!(t.0, DataType::Struct(_))
603 }
604 #[staticmethod]
605 fn is_union(t: PyDataType) -> bool {
606 matches!(t.0, DataType::Union(_, _))
607 }
608 #[staticmethod]
609 fn is_nested(t: PyDataType) -> bool {
610 t.0.is_nested()
611 }
612 #[staticmethod]
613 fn is_run_end_encoded(t: PyDataType) -> bool {
614 t.0.is_run_ends_type()
615 }
616 #[staticmethod]
617 fn is_temporal(t: PyDataType) -> bool {
618 t.0.is_temporal()
619 }
620 #[staticmethod]
621 fn is_timestamp(t: PyDataType) -> bool {
622 matches!(t.0, DataType::Timestamp(_, _))
623 }
624 #[staticmethod]
625 fn is_date(t: PyDataType) -> bool {
626 matches!(t.0, DataType::Date32 | DataType::Date64)
627 }
628 #[staticmethod]
629 fn is_date32(t: PyDataType) -> bool {
630 t.0 == DataType::Date32
631 }
632 #[staticmethod]
633 fn is_date64(t: PyDataType) -> bool {
634 t.0 == DataType::Date64
635 }
636 #[staticmethod]
637 fn is_time(t: PyDataType) -> bool {
638 matches!(t.0, DataType::Time32(_) | DataType::Time64(_))
639 }
640 #[staticmethod]
641 fn is_time32(t: PyDataType) -> bool {
642 matches!(t.0, DataType::Time32(_))
643 }
644 #[staticmethod]
645 fn is_time64(t: PyDataType) -> bool {
646 matches!(t.0, DataType::Time64(_))
647 }
648 #[staticmethod]
649 fn is_duration(t: PyDataType) -> bool {
650 matches!(t.0, DataType::Duration(_))
651 }
652 #[staticmethod]
653 fn is_interval(t: PyDataType) -> bool {
654 matches!(t.0, DataType::Interval(_))
655 }
656 #[staticmethod]
657 fn is_null(t: PyDataType) -> bool {
658 t.0 == DataType::Null
659 }
660 #[staticmethod]
661 fn is_binary(t: PyDataType) -> bool {
662 t.0 == DataType::Binary
663 }
664 #[staticmethod]
665 fn is_unicode(t: PyDataType) -> bool {
666 t.0 == DataType::Utf8
667 }
668 #[staticmethod]
669 fn is_string(t: PyDataType) -> bool {
670 t.0 == DataType::Utf8
671 }
672 #[staticmethod]
673 fn is_large_binary(t: PyDataType) -> bool {
674 t.0 == DataType::LargeBinary
675 }
676 #[staticmethod]
677 fn is_large_unicode(t: PyDataType) -> bool {
678 t.0 == DataType::LargeUtf8
679 }
680 #[staticmethod]
681 fn is_large_string(t: PyDataType) -> bool {
682 t.0 == DataType::LargeUtf8
683 }
684 #[staticmethod]
685 fn is_binary_view(t: PyDataType) -> bool {
686 t.0 == DataType::BinaryView
687 }
688 #[staticmethod]
689 fn is_string_view(t: PyDataType) -> bool {
690 t.0 == DataType::Utf8View
691 }
692 #[staticmethod]
693 fn is_fixed_size_binary(t: PyDataType) -> bool {
694 matches!(t.0, DataType::FixedSizeBinary(_))
695 }
696 #[staticmethod]
697 fn is_map(t: PyDataType) -> bool {
698 matches!(t.0, DataType::Map(_, _))
699 }
700 #[staticmethod]
701 fn is_dictionary(t: PyDataType) -> bool {
702 matches!(t.0, DataType::Dictionary(_, _))
703 }
704 #[staticmethod]
705 fn is_primitive(t: PyDataType) -> bool {
706 t.0.is_primitive()
707 }
708 #[staticmethod]
709 fn is_numeric(t: PyDataType) -> bool {
710 t.0.is_numeric()
711 }
712 #[staticmethod]
713 fn is_dictionary_key_type(t: PyDataType) -> bool {
714 t.0.is_dictionary_key_type()
715 }
716}