1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
5use pyo3::exceptions::{PyTypeError, PyValueError};
6use pyo3::intern;
7use pyo3::prelude::*;
8use pyo3::types::{PyCapsule, PyTuple, PyType};
9
10use crate::error::PyArrowResult;
11use crate::export::Arro3DataType;
12use crate::ffi::from_python::utils::import_schema_pycapsule;
13use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema;
14use crate::ffi::to_schema_pycapsule;
15use crate::PyField;
16
17struct PyTimeUnit(arrow_schema::TimeUnit);
18
19impl<'a> FromPyObject<'a> for PyTimeUnit {
20 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
21 let s: String = ob.extract()?;
22 match s.to_lowercase().as_str() {
23 "s" => Ok(Self(TimeUnit::Second)),
24 "ms" => Ok(Self(TimeUnit::Millisecond)),
25 "us" => Ok(Self(TimeUnit::Microsecond)),
26 "ns" => Ok(Self(TimeUnit::Nanosecond)),
27 _ => Err(PyValueError::new_err("Unexpected time unit")),
28 }
29 }
30}
31
32#[derive(PartialEq, Eq, Debug)]
34#[pyclass(module = "arro3.core._core", name = "DataType", subclass, frozen)]
35pub struct PyDataType(DataType);
36
37impl PyDataType {
38 pub fn new(data_type: DataType) -> Self {
40 Self(data_type)
41 }
42
43 pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
45 let schema_ptr = import_schema_pycapsule(capsule)?;
46 let data_type =
47 DataType::try_from(schema_ptr).map_err(|err| PyTypeError::new_err(err.to_string()))?;
48 Ok(Self::new(data_type))
49 }
50
51 pub fn into_inner(self) -> DataType {
53 self.0
54 }
55
56 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
58 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
59 arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
60 intern!(py, "from_arrow_pycapsule"),
61 PyTuple::new(py, vec![self.__arrow_c_schema__(py)?])?,
62 )
63 }
64
65 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
67 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
68 let capsule = to_schema_pycapsule(py, &self.0)?;
69 arro3_mod.getattr(intern!(py, "DataType"))?.call_method1(
70 intern!(py, "from_arrow_pycapsule"),
71 PyTuple::new(py, vec![capsule])?,
72 )
73 }
74
75 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
77 to_nanoarrow_schema(py, &self.__arrow_c_schema__(py)?)
78 }
79
80 pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
84 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
85 let pyarrow_field = pyarrow_mod
86 .getattr(intern!(py, "field"))?
87 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
88 Ok(pyarrow_field
89 .getattr(intern!(py, "type"))?
90 .into_pyobject(py)?
91 .into_any()
92 .unbind())
93 }
94}
95
96impl From<PyDataType> for DataType {
97 fn from(value: PyDataType) -> Self {
98 value.0
99 }
100}
101
102impl From<DataType> for PyDataType {
103 fn from(value: DataType) -> Self {
104 Self(value)
105 }
106}
107
108impl From<&DataType> for PyDataType {
109 fn from(value: &DataType) -> Self {
110 Self(value.clone())
111 }
112}
113
114impl AsRef<DataType> for PyDataType {
115 fn as_ref(&self) -> &DataType {
116 &self.0
117 }
118}
119
120impl Display for PyDataType {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(f, "arro3.core.DataType<")?;
123 self.0.fmt(f)?;
124 writeln!(f, ">")?;
125 Ok(())
126 }
127}
128
129#[allow(non_snake_case)]
130#[pymethods]
131impl PyDataType {
132 pub(crate) fn __arrow_c_schema__<'py>(
133 &'py self,
134 py: Python<'py>,
135 ) -> PyArrowResult<Bound<'py, PyCapsule>> {
136 to_schema_pycapsule(py, &self.0)
137 }
138
139 fn __eq__(&self, other: PyDataType) -> bool {
140 self.equals(other, false)
141 }
142
143 fn __repr__(&self) -> String {
144 self.to_string()
145 }
146
147 #[classmethod]
148 fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
149 input
150 }
151
152 #[classmethod]
153 #[pyo3(name = "from_arrow_pycapsule")]
154 fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
155 Self::from_arrow_pycapsule(capsule)
156 }
157
158 #[getter]
159 fn bit_width(&self) -> Option<usize> {
160 self.0.primitive_width().map(|width| width * 8)
161 }
162
163 #[pyo3(signature=(other, *, check_metadata=false))]
164 fn equals(&self, other: PyDataType, check_metadata: bool) -> bool {
165 let other = other.into_inner();
166 if check_metadata {
167 self.0 == other
168 } else {
169 self.0.equals_datatype(&other)
170 }
171 }
172
173 #[getter]
174 fn list_size(&self) -> Option<i32> {
175 match &self.0 {
176 DataType::FixedSizeList(_, list_size) => Some(*list_size),
177 _ => None,
178 }
179 }
180
181 #[getter]
182 fn num_fields(&self) -> usize {
183 match &self.0 {
184 DataType::Null
185 | DataType::Boolean
186 | DataType::Int8
187 | DataType::Int16
188 | DataType::Int32
189 | DataType::Int64
190 | DataType::UInt8
191 | DataType::UInt16
192 | DataType::UInt32
193 | DataType::UInt64
194 | DataType::Float16
195 | DataType::Float32
196 | DataType::Float64
197 | DataType::Timestamp(_, _)
198 | DataType::Date32
199 | DataType::Date64
200 | DataType::Time32(_)
201 | DataType::Time64(_)
202 | DataType::Duration(_)
203 | DataType::Interval(_)
204 | DataType::Binary
205 | DataType::FixedSizeBinary(_)
206 | DataType::LargeBinary
207 | DataType::BinaryView
208 | DataType::Utf8
209 | DataType::LargeUtf8
210 | DataType::Utf8View
211 | DataType::Decimal128(_, _)
212 | DataType::Decimal256(_, _) => 0,
213 DataType::List(_)
214 | DataType::ListView(_)
215 | DataType::FixedSizeList(_, _)
216 | DataType::LargeList(_)
217 | DataType::LargeListView(_) => 1,
218 DataType::Struct(fields) => fields.len(),
219 DataType::Union(fields, _) => fields.len(),
220 DataType::Dictionary(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => 2,
222 }
223 }
224
225 #[getter]
226 fn time_unit(&self) -> Option<&str> {
227 match &self.0 {
228 DataType::Time32(unit)
229 | DataType::Time64(unit)
230 | DataType::Timestamp(unit, _)
231 | DataType::Duration(unit) => match unit {
232 TimeUnit::Second => Some("s"),
233 TimeUnit::Millisecond => Some("ms"),
234 TimeUnit::Microsecond => Some("us"),
235 TimeUnit::Nanosecond => Some("ns"),
236 },
237 _ => None,
238 }
239 }
240
241 #[getter]
242 fn tz(&self) -> Option<&str> {
243 match &self.0 {
244 DataType::Timestamp(_, tz) => tz.as_deref(),
245 _ => None,
246 }
247 }
248
249 #[getter]
250 fn value_type(&self) -> Option<Arro3DataType> {
251 match &self.0 {
252 DataType::FixedSizeList(value_field, _)
253 | DataType::List(value_field)
254 | DataType::LargeList(value_field)
255 | DataType::ListView(value_field)
256 | DataType::LargeListView(value_field)
257 | DataType::RunEndEncoded(_, value_field) => {
258 Some(PyDataType::new(value_field.data_type().clone()).into())
259 }
260 DataType::Dictionary(_key_type, value_type) => {
261 Some(PyDataType::new(*value_type.clone()).into())
262 }
263 _ => None,
264 }
265 }
266
267 #[classmethod]
270 fn null(_: &Bound<PyType>) -> Self {
271 Self(DataType::Null)
272 }
273
274 #[classmethod]
275 fn bool(_: &Bound<PyType>) -> Self {
276 Self(DataType::Boolean)
277 }
278
279 #[classmethod]
280 fn int8(_: &Bound<PyType>) -> Self {
281 Self(DataType::Int8)
282 }
283
284 #[classmethod]
285 fn int16(_: &Bound<PyType>) -> Self {
286 Self(DataType::Int16)
287 }
288
289 #[classmethod]
290 fn int32(_: &Bound<PyType>) -> Self {
291 Self(DataType::Int32)
292 }
293
294 #[classmethod]
295 fn int64(_: &Bound<PyType>) -> Self {
296 Self(DataType::Int64)
297 }
298
299 #[classmethod]
300 fn uint8(_: &Bound<PyType>) -> Self {
301 Self(DataType::UInt8)
302 }
303
304 #[classmethod]
305 fn uint16(_: &Bound<PyType>) -> Self {
306 Self(DataType::UInt16)
307 }
308
309 #[classmethod]
310 fn uint32(_: &Bound<PyType>) -> Self {
311 Self(DataType::UInt32)
312 }
313
314 #[classmethod]
315 fn uint64(_: &Bound<PyType>) -> Self {
316 Self(DataType::UInt64)
317 }
318
319 #[classmethod]
320 fn float16(_: &Bound<PyType>) -> Self {
321 Self(DataType::Float16)
322 }
323
324 #[classmethod]
325 fn float32(_: &Bound<PyType>) -> Self {
326 Self(DataType::Float32)
327 }
328
329 #[classmethod]
330 fn float64(_: &Bound<PyType>) -> Self {
331 Self(DataType::Float64)
332 }
333
334 #[classmethod]
335 fn time32(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
336 if unit.0 == TimeUnit::Microsecond || unit.0 == TimeUnit::Nanosecond {
337 return Err(PyValueError::new_err("Unexpected timeunit for time32").into());
338 }
339
340 Ok(Self(DataType::Time32(unit.0)))
341 }
342
343 #[classmethod]
344 fn time64(_: &Bound<PyType>, unit: PyTimeUnit) -> PyArrowResult<Self> {
345 if unit.0 == TimeUnit::Second || unit.0 == TimeUnit::Millisecond {
346 return Err(PyValueError::new_err("Unexpected timeunit for time64").into());
347 }
348
349 Ok(Self(DataType::Time64(unit.0)))
350 }
351
352 #[classmethod]
353 #[pyo3(signature = (unit, *, tz=None))]
354 fn timestamp(_: &Bound<PyType>, unit: PyTimeUnit, tz: Option<String>) -> Self {
355 Self(DataType::Timestamp(unit.0, tz.map(|s| s.into())))
356 }
357
358 #[classmethod]
359 fn date32(_: &Bound<PyType>) -> Self {
360 Self(DataType::Date32)
361 }
362
363 #[classmethod]
364 fn date64(_: &Bound<PyType>) -> Self {
365 Self(DataType::Date64)
366 }
367
368 #[classmethod]
369 fn duration(_: &Bound<PyType>, unit: PyTimeUnit) -> Self {
370 Self(DataType::Duration(unit.0))
371 }
372
373 #[classmethod]
374 fn month_day_nano_interval(_: &Bound<PyType>) -> Self {
375 Self(DataType::Interval(IntervalUnit::MonthDayNano))
376 }
377
378 #[classmethod]
379 #[pyo3(signature = (length=None))]
380 fn binary(_: &Bound<PyType>, length: Option<i32>) -> Self {
381 if let Some(length) = length {
382 Self(DataType::FixedSizeBinary(length))
383 } else {
384 Self(DataType::Binary)
385 }
386 }
387
388 #[classmethod]
389 fn string(_: &Bound<PyType>) -> Self {
390 Self(DataType::Utf8)
391 }
392
393 #[classmethod]
394 fn utf8(_: &Bound<PyType>) -> Self {
395 Self(DataType::Utf8)
396 }
397
398 #[classmethod]
399 fn large_binary(_: &Bound<PyType>) -> Self {
400 Self(DataType::LargeBinary)
401 }
402
403 #[classmethod]
404 fn large_string(_: &Bound<PyType>) -> Self {
405 Self(DataType::LargeUtf8)
406 }
407
408 #[classmethod]
409 fn large_utf8(_: &Bound<PyType>) -> Self {
410 Self(DataType::LargeUtf8)
411 }
412
413 #[classmethod]
414 fn binary_view(_: &Bound<PyType>) -> Self {
415 Self(DataType::BinaryView)
416 }
417
418 #[classmethod]
419 fn string_view(_: &Bound<PyType>) -> Self {
420 Self(DataType::Utf8View)
421 }
422
423 #[classmethod]
424 fn decimal128(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
425 Self(DataType::Decimal128(precision, scale))
426 }
427
428 #[classmethod]
429 fn decimal256(_: &Bound<PyType>, precision: u8, scale: i8) -> Self {
430 Self(DataType::Decimal256(precision, scale))
431 }
432
433 #[classmethod]
434 #[pyo3(signature = (value_type, list_size=None))]
435 fn list(_: &Bound<PyType>, value_type: PyField, list_size: Option<i32>) -> Self {
436 if let Some(list_size) = list_size {
437 Self(DataType::FixedSizeList(value_type.into(), list_size))
438 } else {
439 Self(DataType::List(value_type.into()))
440 }
441 }
442
443 #[classmethod]
444 fn large_list(_: &Bound<PyType>, value_type: PyField) -> Self {
445 Self(DataType::LargeList(value_type.into()))
446 }
447
448 #[classmethod]
449 fn list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
450 Self(DataType::ListView(value_type.into()))
451 }
452
453 #[classmethod]
454 fn large_list_view(_: &Bound<PyType>, value_type: PyField) -> Self {
455 Self(DataType::LargeListView(value_type.into()))
456 }
457
458 #[classmethod]
459 fn map(_: &Bound<PyType>, key_type: PyField, item_type: PyField, keys_sorted: bool) -> Self {
460 let data_type = DataType::Map(
463 Arc::new(Field::new(
464 "entries",
465 DataType::Struct(vec![key_type.into_inner(), item_type.into_inner()].into()),
466 false, )),
468 keys_sorted,
469 );
470 Self(data_type)
471 }
472
473 #[classmethod]
474 fn r#struct(_: &Bound<PyType>, fields: Vec<PyField>) -> Self {
475 Self(DataType::Struct(
476 fields.into_iter().map(|field| field.into_inner()).collect(),
477 ))
478 }
479
480 #[classmethod]
481 fn dictionary(_: &Bound<PyType>, index_type: PyDataType, value_type: PyDataType) -> Self {
482 Self(DataType::Dictionary(
483 Box::new(index_type.into_inner()),
484 Box::new(value_type.into_inner()),
485 ))
486 }
487
488 #[classmethod]
489 fn run_end_encoded(_: &Bound<PyType>, run_end_type: PyField, value_type: PyField) -> Self {
490 Self(DataType::RunEndEncoded(
491 run_end_type.into_inner(),
492 value_type.into_inner(),
493 ))
494 }
495
496 #[staticmethod]
499 fn is_boolean(t: PyDataType) -> bool {
500 t.0 == DataType::Boolean
501 }
502
503 #[staticmethod]
504 fn is_integer(t: PyDataType) -> bool {
505 t.0.is_integer()
506 }
507
508 #[staticmethod]
509 fn is_signed_integer(t: PyDataType) -> bool {
510 t.0.is_signed_integer()
511 }
512
513 #[staticmethod]
514 fn is_unsigned_integer(t: PyDataType) -> bool {
515 t.0.is_unsigned_integer()
516 }
517
518 #[staticmethod]
519 fn is_int8(t: PyDataType) -> bool {
520 t.0 == DataType::Int8
521 }
522 #[staticmethod]
523 fn is_int16(t: PyDataType) -> bool {
524 t.0 == DataType::Int16
525 }
526 #[staticmethod]
527 fn is_int32(t: PyDataType) -> bool {
528 t.0 == DataType::Int32
529 }
530 #[staticmethod]
531 fn is_int64(t: PyDataType) -> bool {
532 t.0 == DataType::Int64
533 }
534 #[staticmethod]
535 fn is_uint8(t: PyDataType) -> bool {
536 t.0 == DataType::UInt8
537 }
538 #[staticmethod]
539 fn is_uint16(t: PyDataType) -> bool {
540 t.0 == DataType::UInt16
541 }
542 #[staticmethod]
543 fn is_uint32(t: PyDataType) -> bool {
544 t.0 == DataType::UInt32
545 }
546 #[staticmethod]
547 fn is_uint64(t: PyDataType) -> bool {
548 t.0 == DataType::UInt64
549 }
550 #[staticmethod]
551 fn is_floating(t: PyDataType) -> bool {
552 t.0.is_floating()
553 }
554 #[staticmethod]
555 fn is_float16(t: PyDataType) -> bool {
556 t.0 == DataType::Float16
557 }
558 #[staticmethod]
559 fn is_float32(t: PyDataType) -> bool {
560 t.0 == DataType::Float32
561 }
562 #[staticmethod]
563 fn is_float64(t: PyDataType) -> bool {
564 t.0 == DataType::Float64
565 }
566 #[staticmethod]
567 fn is_decimal(t: PyDataType) -> bool {
568 matches!(t.0, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
569 }
570 #[staticmethod]
571 fn is_decimal128(t: PyDataType) -> bool {
572 matches!(t.0, DataType::Decimal128(_, _))
573 }
574 #[staticmethod]
575 fn is_decimal256(t: PyDataType) -> bool {
576 matches!(t.0, DataType::Decimal256(_, _))
577 }
578
579 #[staticmethod]
580 fn is_list(t: PyDataType) -> bool {
581 matches!(t.0, DataType::List(_))
582 }
583 #[staticmethod]
584 fn is_large_list(t: PyDataType) -> bool {
585 matches!(t.0, DataType::LargeList(_))
586 }
587 #[staticmethod]
588 fn is_fixed_size_list(t: PyDataType) -> bool {
589 matches!(t.0, DataType::FixedSizeList(_, _))
590 }
591 #[staticmethod]
592 fn is_list_view(t: PyDataType) -> bool {
593 matches!(t.0, DataType::ListView(_))
594 }
595 #[staticmethod]
596 fn is_large_list_view(t: PyDataType) -> bool {
597 matches!(t.0, DataType::LargeListView(_))
598 }
599 #[staticmethod]
600 fn is_struct(t: PyDataType) -> bool {
601 matches!(t.0, DataType::Struct(_))
602 }
603 #[staticmethod]
604 fn is_union(t: PyDataType) -> bool {
605 matches!(t.0, DataType::Union(_, _))
606 }
607 #[staticmethod]
608 fn is_nested(t: PyDataType) -> bool {
609 t.0.is_nested()
610 }
611 #[staticmethod]
612 fn is_run_end_encoded(t: PyDataType) -> bool {
613 t.0.is_run_ends_type()
614 }
615 #[staticmethod]
616 fn is_temporal(t: PyDataType) -> bool {
617 t.0.is_temporal()
618 }
619 #[staticmethod]
620 fn is_timestamp(t: PyDataType) -> bool {
621 matches!(t.0, DataType::Timestamp(_, _))
622 }
623 #[staticmethod]
624 fn is_date(t: PyDataType) -> bool {
625 matches!(t.0, DataType::Date32 | DataType::Date64)
626 }
627 #[staticmethod]
628 fn is_date32(t: PyDataType) -> bool {
629 t.0 == DataType::Date32
630 }
631 #[staticmethod]
632 fn is_date64(t: PyDataType) -> bool {
633 t.0 == DataType::Date64
634 }
635 #[staticmethod]
636 fn is_time(t: PyDataType) -> bool {
637 matches!(t.0, DataType::Time32(_) | DataType::Time64(_))
638 }
639 #[staticmethod]
640 fn is_time32(t: PyDataType) -> bool {
641 matches!(t.0, DataType::Time32(_))
642 }
643 #[staticmethod]
644 fn is_time64(t: PyDataType) -> bool {
645 matches!(t.0, DataType::Time64(_))
646 }
647 #[staticmethod]
648 fn is_duration(t: PyDataType) -> bool {
649 matches!(t.0, DataType::Duration(_))
650 }
651 #[staticmethod]
652 fn is_interval(t: PyDataType) -> bool {
653 matches!(t.0, DataType::Interval(_))
654 }
655 #[staticmethod]
656 fn is_null(t: PyDataType) -> bool {
657 t.0 == DataType::Null
658 }
659 #[staticmethod]
660 fn is_binary(t: PyDataType) -> bool {
661 t.0 == DataType::Binary
662 }
663 #[staticmethod]
664 fn is_unicode(t: PyDataType) -> bool {
665 t.0 == DataType::Utf8
666 }
667 #[staticmethod]
668 fn is_string(t: PyDataType) -> bool {
669 t.0 == DataType::Utf8
670 }
671 #[staticmethod]
672 fn is_large_binary(t: PyDataType) -> bool {
673 t.0 == DataType::LargeBinary
674 }
675 #[staticmethod]
676 fn is_large_unicode(t: PyDataType) -> bool {
677 t.0 == DataType::LargeUtf8
678 }
679 #[staticmethod]
680 fn is_large_string(t: PyDataType) -> bool {
681 t.0 == DataType::LargeUtf8
682 }
683 #[staticmethod]
684 fn is_binary_view(t: PyDataType) -> bool {
685 t.0 == DataType::BinaryView
686 }
687 #[staticmethod]
688 fn is_string_view(t: PyDataType) -> bool {
689 t.0 == DataType::Utf8View
690 }
691 #[staticmethod]
692 fn is_fixed_size_binary(t: PyDataType) -> bool {
693 matches!(t.0, DataType::FixedSizeBinary(_))
694 }
695 #[staticmethod]
696 fn is_map(t: PyDataType) -> bool {
697 matches!(t.0, DataType::Map(_, _))
698 }
699 #[staticmethod]
700 fn is_dictionary(t: PyDataType) -> bool {
701 matches!(t.0, DataType::Dictionary(_, _))
702 }
703 #[staticmethod]
704 fn is_primitive(t: PyDataType) -> bool {
705 t.0.is_primitive()
706 }
707 #[staticmethod]
708 fn is_numeric(t: PyDataType) -> bool {
709 t.0.is_numeric()
710 }
711 #[staticmethod]
712 fn is_dictionary_key_type(t: PyDataType) -> bool {
713 t.0.is_dictionary_key_type()
714 }
715}