pyo3_arrow/
scalar.rs

1use std::fmt::Display;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use arrow_array::cast::AsArray;
6use arrow_array::timezone::Tz;
7use arrow_array::types::*;
8use arrow_array::{Array, ArrayRef, Datum, UnionArray};
9use arrow_cast::cast;
10use arrow_schema::{ArrowError, DataType, Field, FieldRef, TimeUnit};
11use indexmap::IndexMap;
12use pyo3::prelude::*;
13use pyo3::types::{PyCapsule, PyList, PyTuple, PyType};
14use pyo3::{intern, IntoPyObjectExt};
15
16use crate::error::PyArrowResult;
17use crate::export::{Arro3DataType, Arro3Field, Arro3Scalar};
18use crate::ffi::to_array_pycapsules;
19use crate::{PyArray, PyField};
20
/// A Python-facing Arrow scalar
///
/// A scalar is stored as a single-element Arrow array plus the field that describes it.
/// The checked constructors (e.g. `PyScalar::try_new`) enforce that the array has length 1
/// and that its data type matches the field; `PyScalar::new_unchecked` leaves both
/// invariants to the caller.
#[derive(Debug)]
#[pyclass(module = "arro3.core._core", name = "Scalar", subclass, frozen)]
pub struct PyScalar {
    /// The length-1 array holding the scalar's value.
    array: ArrayRef,
    /// The field (name, data type, nullability, metadata) describing `array`.
    field: FieldRef,
}
28
29impl PyScalar {
30    /// Create a new PyScalar without any checks
31    ///
32    /// # Safety
33    ///
34    /// - The array's DataType must match the field's DataType
35    /// - The array must have length 1
36    pub unsafe fn new_unchecked(array: ArrayRef, field: FieldRef) -> Self {
37        Self { array, field }
38    }
39
40    /// Create a new PyArray from an [ArrayRef], inferring its data type automatically.
41    pub fn try_from_array_ref(array: ArrayRef) -> PyArrowResult<Self> {
42        let field = Field::new("", array.data_type().clone(), true);
43        Self::try_new(array, Arc::new(field))
44    }
45
46    /// Create a new PyScalar
47    ///
48    /// This will error if the arrays' data type does not match the field's data type or if the
49    /// length of the array is not 1.
50    pub fn try_new(array: ArrayRef, field: FieldRef) -> PyArrowResult<Self> {
51        // Do usual array validation
52        let (array, field) = PyArray::try_new(array, field)?.into_inner();
53        if array.len() != 1 {
54            return Err(ArrowError::SchemaError(
55                "Expected array of length 1 for scalar".to_string(),
56            )
57            .into());
58        }
59
60        Ok(Self { array, field })
61    }
62
63    /// Import from raw Arrow capsules
64    pub fn try_from_arrow_pycapsule(
65        schema_capsule: &Bound<PyCapsule>,
66        array_capsule: &Bound<PyCapsule>,
67    ) -> PyArrowResult<Self> {
68        let (array, field) =
69            PyArray::from_arrow_pycapsule(schema_capsule, array_capsule)?.into_inner();
70        Self::try_new(array, field)
71    }
72
73    /// Access the underlying [ArrayRef].
74    pub fn array(&self) -> &ArrayRef {
75        &self.array
76    }
77
78    /// Access the underlying [FieldRef].
79    pub fn field(&self) -> &FieldRef {
80        &self.field
81    }
82
83    /// Consume self to access the underlying [ArrayRef] and [FieldRef].
84    pub fn into_inner(self) -> (ArrayRef, FieldRef) {
85        (self.array, self.field)
86    }
87
88    /// Export to an arro3.core.Scalar.
89    ///
90    /// This requires that you depend on arro3-core from your Python package.
91    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
92        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
93        arro3_mod.getattr(intern!(py, "Scalar"))?.call_method1(
94            intern!(py, "from_arrow_pycapsule"),
95            self.__arrow_c_array__(py, None)?,
96        )
97    }
98
99    /// Export to an arro3.core.Scalar.
100    ///
101    /// This requires that you depend on arro3-core from your Python package.
102    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
103        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
104        let capsules = to_array_pycapsules(py, self.field.clone(), &self.array, None)?;
105        arro3_mod
106            .getattr(intern!(py, "Scalar"))?
107            .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
108    }
109}
110
111impl Display for PyScalar {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        write!(f, "arro3.core.Scalar<")?;
114        self.array.data_type().fmt(f)?;
115        writeln!(f, ">")?;
116        Ok(())
117    }
118}
119
120impl Datum for PyScalar {
121    fn get(&self) -> (&dyn Array, bool) {
122        (self.array.as_ref(), true)
123    }
124}
125
126#[pymethods]
127impl PyScalar {
128    #[new]
129    #[pyo3(signature = (obj, /, r#type = None, *))]
130    fn init(py: Python, obj: &Bound<PyAny>, r#type: Option<PyField>) -> PyArrowResult<Self> {
131        if let Ok(data) = obj.extract::<PyScalar>() {
132            return Ok(data);
133        }
134
135        let obj = PyList::new(py, vec![obj])?;
136        let array = PyArray::init(&obj, r#type)?;
137        let (array, field) = array.into_inner();
138        Self::try_new(array, field)
139    }
140
141    #[pyo3(signature = (requested_schema=None))]
142    fn __arrow_c_array__<'py>(
143        &'py self,
144        py: Python<'py>,
145        requested_schema: Option<Bound<'py, PyCapsule>>,
146    ) -> PyArrowResult<Bound<'py, PyTuple>> {
147        to_array_pycapsules(py, self.field.clone(), &self.array, requested_schema)
148    }
149
150    fn __eq__(&self, py: Python, other: Bound<'_, PyAny>) -> PyResult<PyObject> {
151        if let Ok(other) = other.extract::<PyScalar>() {
152            let eq = self.array == other.array && self.field == other.field;
153            eq.into_py_any(py)
154            // Ok(eq.into_pyobject(py)?.into_any().unbind())
155        } else {
156            // If other is not an Arrow scalar, cast self to a Python object, and then call its
157            // `__eq__` method.
158            let self_py = self.as_py(py)?;
159            self_py.call_method1(py, intern!(py, "__eq__"), PyTuple::new(py, vec![other])?)
160        }
161    }
162
163    fn __repr__(&self) -> String {
164        self.to_string()
165    }
166
167    #[classmethod]
168    fn from_arrow(_cls: &Bound<PyType>, input: PyScalar) -> Self {
169        input
170    }
171
172    #[classmethod]
173    #[pyo3(name = "from_arrow_pycapsule")]
174    fn from_arrow_pycapsule_py(
175        _cls: &Bound<PyType>,
176        schema_capsule: &Bound<PyCapsule>,
177        array_capsule: &Bound<PyCapsule>,
178    ) -> PyArrowResult<Self> {
179        Self::try_from_arrow_pycapsule(schema_capsule, array_capsule)
180    }
181
182    pub(crate) fn as_py(&self, py: Python) -> PyArrowResult<PyObject> {
183        if self.array.is_null(0) {
184            return Ok(py.None());
185        }
186
187        let arr = self.array.as_ref();
188        let result = match self.array.data_type() {
189            DataType::Null => py.None(),
190            DataType::Boolean => arr.as_boolean().value(0).into_py_any(py)?,
191            DataType::Int8 => arr.as_primitive::<Int8Type>().value(0).into_py_any(py)?,
192            DataType::Int16 => arr.as_primitive::<Int16Type>().value(0).into_py_any(py)?,
193            DataType::Int32 => arr.as_primitive::<Int32Type>().value(0).into_py_any(py)?,
194            DataType::Int64 => arr.as_primitive::<Int64Type>().value(0).into_py_any(py)?,
195            DataType::UInt8 => arr.as_primitive::<UInt8Type>().value(0).into_py_any(py)?,
196            DataType::UInt16 => arr.as_primitive::<UInt16Type>().value(0).into_py_any(py)?,
197            DataType::UInt32 => arr.as_primitive::<UInt32Type>().value(0).into_py_any(py)?,
198            DataType::UInt64 => arr.as_primitive::<UInt64Type>().value(0).into_py_any(py)?,
199            DataType::Float16 => {
200                f32::from(arr.as_primitive::<Float16Type>().value(0)).into_py_any(py)?
201            }
202            DataType::Float32 => arr.as_primitive::<Float32Type>().value(0).into_py_any(py)?,
203            DataType::Float64 => arr.as_primitive::<Float64Type>().value(0).into_py_any(py)?,
204            DataType::Timestamp(time_unit, tz) => {
205                if let Some(tz) = tz {
206                    let tz = Tz::from_str(tz)?;
207                    match time_unit {
208                        TimeUnit::Second => arr
209                            .as_primitive::<TimestampSecondType>()
210                            .value_as_datetime_with_tz(0, tz)
211                            .map(|dt| dt.fixed_offset())
212                            .into_py_any(py)?,
213                        TimeUnit::Millisecond => arr
214                            .as_primitive::<TimestampMillisecondType>()
215                            .value_as_datetime_with_tz(0, tz)
216                            .map(|dt| dt.fixed_offset())
217                            .into_py_any(py)?,
218                        TimeUnit::Microsecond => arr
219                            .as_primitive::<TimestampMicrosecondType>()
220                            .value_as_datetime_with_tz(0, tz)
221                            .map(|dt| dt.fixed_offset())
222                            .into_py_any(py)?,
223                        TimeUnit::Nanosecond => arr
224                            .as_primitive::<TimestampNanosecondType>()
225                            .value_as_datetime_with_tz(0, tz)
226                            .map(|dt| dt.fixed_offset())
227                            .into_py_any(py)?,
228                    }
229                } else {
230                    match time_unit {
231                        TimeUnit::Second => arr
232                            .as_primitive::<TimestampSecondType>()
233                            .value_as_datetime(0)
234                            .into_py_any(py)?,
235                        TimeUnit::Millisecond => arr
236                            .as_primitive::<TimestampMillisecondType>()
237                            .value_as_datetime(0)
238                            .into_py_any(py)?,
239                        TimeUnit::Microsecond => arr
240                            .as_primitive::<TimestampMicrosecondType>()
241                            .value_as_datetime(0)
242                            .into_py_any(py)?,
243                        TimeUnit::Nanosecond => arr
244                            .as_primitive::<TimestampNanosecondType>()
245                            .value_as_datetime(0)
246                            .into_py_any(py)?,
247                    }
248                }
249            }
250            DataType::Date32 => arr
251                .as_primitive::<Date32Type>()
252                .value_as_date(0)
253                .into_py_any(py)?,
254            DataType::Date64 => arr
255                .as_primitive::<Date64Type>()
256                .value_as_date(0)
257                .into_py_any(py)?,
258            DataType::Time32(time_unit) => match time_unit {
259                TimeUnit::Second => arr
260                    .as_primitive::<Time32SecondType>()
261                    .value_as_time(0)
262                    .into_py_any(py)?,
263                TimeUnit::Millisecond => arr
264                    .as_primitive::<Time32MillisecondType>()
265                    .value_as_time(0)
266                    .into_py_any(py)?,
267                _ => unreachable!(),
268            },
269            DataType::Time64(time_unit) => match time_unit {
270                TimeUnit::Microsecond => arr
271                    .as_primitive::<Time64MicrosecondType>()
272                    .value_as_time(0)
273                    .into_py_any(py)?,
274                TimeUnit::Nanosecond => arr
275                    .as_primitive::<Time64NanosecondType>()
276                    .value_as_time(0)
277                    .into_py_any(py)?,
278
279                _ => unreachable!(),
280            },
281            DataType::Duration(time_unit) => match time_unit {
282                TimeUnit::Second => arr
283                    .as_primitive::<DurationSecondType>()
284                    .value_as_duration(0)
285                    .into_py_any(py)?,
286                TimeUnit::Millisecond => arr
287                    .as_primitive::<DurationMillisecondType>()
288                    .value_as_duration(0)
289                    .into_py_any(py)?,
290                TimeUnit::Microsecond => arr
291                    .as_primitive::<DurationMicrosecondType>()
292                    .value_as_duration(0)
293                    .into_py_any(py)?,
294                TimeUnit::Nanosecond => arr
295                    .as_primitive::<DurationNanosecondType>()
296                    .value_as_duration(0)
297                    .into_py_any(py)?,
298            },
299            DataType::Interval(_) => {
300                // https://github.com/apache/arrow-rs/blob/6c59b7637592e4b67b18762b8313f91086c0d5d8/arrow-array/src/temporal_conversions.rs#L219
301                todo!("interval is not yet fully documented [ARROW-3097]")
302            }
303            DataType::Binary => arr.as_binary::<i32>().value(0).into_py_any(py)?,
304            DataType::FixedSizeBinary(_) => arr.as_fixed_size_binary().value(0).into_py_any(py)?,
305            DataType::LargeBinary => arr.as_binary::<i64>().value(0).into_py_any(py)?,
306            DataType::BinaryView => arr.as_binary_view().value(0).into_py_any(py)?,
307            DataType::Utf8 => arr.as_string::<i32>().value(0).into_py_any(py)?,
308            DataType::LargeUtf8 => arr.as_string::<i64>().value(0).into_py_any(py)?,
309            DataType::Utf8View => arr.as_string_view().value(0).into_py_any(py)?,
310            DataType::List(inner_field) => {
311                let inner_array = arr.as_list::<i32>().value(0);
312                list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
313            }
314            DataType::LargeList(inner_field) => {
315                let inner_array = arr.as_list::<i64>().value(0);
316                list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
317            }
318            DataType::FixedSizeList(inner_field, _list_size) => {
319                let inner_array = arr.as_fixed_size_list().value(0);
320                list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
321            }
322            DataType::ListView(_inner_field) => {
323                todo!("as_list_view does not yet exist");
324                // let inner_array = arr.as_list_view::<i32>().value(0);
325                // list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
326            }
327            DataType::LargeListView(_inner_field) => {
328                todo!("as_list_view does not yet exist");
329                // let inner_array = arr.as_list_view::<i64>().value(0);
330                // list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
331            }
332            DataType::Struct(inner_fields) => {
333                let struct_array = arr.as_struct();
334                let mut dict_py_objects: IndexMap<&str, PyObject> =
335                    IndexMap::with_capacity(inner_fields.len());
336                for (inner_field, column) in inner_fields.iter().zip(struct_array.columns()) {
337                    let scalar =
338                        unsafe { PyScalar::new_unchecked(column.clone(), inner_field.clone()) };
339                    dict_py_objects.insert(inner_field.name(), scalar.as_py(py)?);
340                }
341                dict_py_objects.into_py_any(py)?
342            }
343            DataType::Union(_, _) => {
344                let array = arr.as_any().downcast_ref::<UnionArray>().unwrap();
345                let scalar = PyScalar::try_from_array_ref(array.value(0))?;
346                scalar.as_py(py)?
347            }
348            DataType::Dictionary(_, _) => {
349                let array = arr.as_any_dictionary();
350                let keys = array.keys();
351                let key = match keys.data_type() {
352                    DataType::Int8 => keys.as_primitive::<Int8Type>().value(0) as usize,
353                    DataType::Int16 => keys.as_primitive::<Int16Type>().value(0) as usize,
354                    DataType::Int32 => keys.as_primitive::<Int32Type>().value(0) as usize,
355                    DataType::Int64 => keys.as_primitive::<Int64Type>().value(0) as usize,
356                    DataType::UInt8 => keys.as_primitive::<UInt8Type>().value(0) as usize,
357                    DataType::UInt16 => keys.as_primitive::<UInt16Type>().value(0) as usize,
358                    DataType::UInt32 => keys.as_primitive::<UInt32Type>().value(0) as usize,
359                    DataType::UInt64 => keys.as_primitive::<UInt64Type>().value(0) as usize,
360                    // Above are the valid dictionary key types
361                    // https://docs.rs/arrow/latest/arrow/datatypes/trait.ArrowDictionaryKeyType.html
362                    _ => unreachable!(),
363                };
364                let value = array.values().slice(key, 1);
365                PyScalar::try_from_array_ref(value)?.as_py(py)?
366            }
367
368            // TODO: decimal support.
369            //
370            // We should implement this by constructing a tuple object to pass into the
371            // decimal.Decimal constructor.
372            //
373            // From the docs: https://docs.python.org/3/library/decimal.html#decimal.Decimal
374            //
375            // If value is a tuple, it should have three components, a sign (0 for positive or 1
376            // for negative), a tuple of digits, and an integer exponent. For example, Decimal((0,
377            // (1, 4, 1, 4), -3)) returns Decimal('1.414').
378            DataType::Decimal128(_, _) => {
379                // let array = arr.as_primitive::<Decimal128Type>();
380                todo!()
381            }
382            DataType::Decimal256(_, _) => {
383                // let array = arr.as_primitive::<Decimal256Type>();
384                todo!()
385            }
386            DataType::Map(_, _) => {
387                let array = arr.as_map();
388                let struct_arr = array.value(0);
389                let key_arr = struct_arr.column_by_name("key").unwrap();
390                let value_arr = struct_arr.column_by_name("value").unwrap();
391
392                let mut entries = Vec::with_capacity(struct_arr.len());
393                for i in 0..struct_arr.len() {
394                    let py_key = PyScalar::try_from_array_ref(key_arr.slice(i, 1))?.as_py(py)?;
395                    let py_value =
396                        PyScalar::try_from_array_ref(value_arr.slice(i, 1))?.as_py(py)?;
397                    entries.push(PyTuple::new(py, vec![py_key, py_value])?);
398                }
399
400                entries.into_py_any(py)?
401            }
402            DataType::RunEndEncoded(_, _) => {
403                todo!()
404            }
405        };
406        Ok(result)
407    }
408
409    fn cast(&self, target_type: PyField) -> PyArrowResult<Arro3Scalar> {
410        let new_field = target_type.into_inner();
411        let new_array = cast(&self.array, new_field.data_type())?;
412        Ok(PyScalar::try_new(new_array, new_field).unwrap().into())
413    }
414
415    #[getter]
416    #[pyo3(name = "field")]
417    fn py_field(&self) -> Arro3Field {
418        self.field.clone().into()
419    }
420
421    #[getter]
422    fn is_valid(&self) -> bool {
423        self.array.is_valid(0)
424    }
425
426    #[getter]
427    fn r#type(&self) -> Arro3DataType {
428        self.field.data_type().clone().into()
429    }
430}
431
432fn list_values_to_py(
433    py: Python,
434    inner_array: ArrayRef,
435    inner_field: &Arc<Field>,
436) -> PyArrowResult<Vec<PyObject>> {
437    let mut list_py_objects = Vec::with_capacity(inner_array.len());
438    for i in 0..inner_array.len() {
439        let scalar =
440            unsafe { PyScalar::new_unchecked(inner_array.slice(i, 1), inner_field.clone()) };
441        list_py_objects.push(scalar.as_py(py)?);
442    }
443    Ok(list_py_objects)
444}