pyo3_arrow/scalar/
mod.rs

1mod temporal;
2
3use std::fmt::Display;
4use std::str::FromStr;
5use std::sync::Arc;
6
7use arrow_array::cast::AsArray;
8use arrow_array::types::*;
9use arrow_array::{Array, ArrayRef, Datum, UnionArray};
10use arrow_cast::cast;
11use arrow_cast::pretty::pretty_format_columns_with_options;
12use arrow_schema::{ArrowError, DataType, Field, FieldRef, TimeUnit};
13use indexmap::IndexMap;
14use pyo3::prelude::*;
15use pyo3::types::{PyCapsule, PyList, PyTuple, PyType};
16use pyo3::{intern, IntoPyObjectExt};
17
18use crate::error::PyArrowResult;
19use crate::export::{Arro3DataType, Arro3Field, Arro3Scalar};
20use crate::ffi::to_array_pycapsules;
21use crate::scalar::temporal::{as_datetime_with_timezone, PyArrowTz};
22use crate::utils::default_repr_options;
23use crate::{PyArray, PyField};
24
/// A Python-facing Arrow scalar
#[derive(Debug)]
#[pyclass(module = "arro3.core._core", name = "Scalar", subclass, frozen)]
pub struct PyScalar {
    /// Arrow array holding the scalar's value. The checked constructor (`try_new`) rejects any
    /// array whose length is not 1.
    array: ArrayRef,
    /// Field (name, data type, nullability) describing `array`.
    field: FieldRef,
}
32
33impl PyScalar {
34    /// Create a new PyScalar without any checks
35    ///
36    /// # Safety
37    ///
38    /// - The array's DataType must match the field's DataType
39    /// - The array must have length 1
40    pub unsafe fn new_unchecked(array: ArrayRef, field: FieldRef) -> Self {
41        Self { array, field }
42    }
43
44    /// Create a new PyArray from an [ArrayRef], inferring its data type automatically.
45    pub fn try_from_array_ref(array: ArrayRef) -> PyArrowResult<Self> {
46        let field = Field::new("", array.data_type().clone(), true);
47        Self::try_new(array, Arc::new(field))
48    }
49
50    /// Create a new PyScalar
51    ///
52    /// This will error if the arrays' data type does not match the field's data type or if the
53    /// length of the array is not 1.
54    pub fn try_new(array: ArrayRef, field: FieldRef) -> PyArrowResult<Self> {
55        // Do usual array validation
56        let (array, field) = PyArray::try_new(array, field)?.into_inner();
57        if array.len() != 1 {
58            return Err(ArrowError::SchemaError(
59                "Expected array of length 1 for scalar".to_string(),
60            )
61            .into());
62        }
63
64        Ok(Self { array, field })
65    }
66
67    /// Import from raw Arrow capsules
68    pub fn try_from_arrow_pycapsule(
69        schema_capsule: &Bound<PyCapsule>,
70        array_capsule: &Bound<PyCapsule>,
71    ) -> PyArrowResult<Self> {
72        let (array, field) =
73            PyArray::from_arrow_pycapsule(schema_capsule, array_capsule)?.into_inner();
74        Self::try_new(array, field)
75    }
76
77    /// Access the underlying [ArrayRef].
78    pub fn array(&self) -> &ArrayRef {
79        &self.array
80    }
81
82    /// Access the underlying [FieldRef].
83    pub fn field(&self) -> &FieldRef {
84        &self.field
85    }
86
87    /// Consume self to access the underlying [ArrayRef] and [FieldRef].
88    pub fn into_inner(self) -> (ArrayRef, FieldRef) {
89        (self.array, self.field)
90    }
91
92    /// Export to an arro3.core.Scalar.
93    ///
94    /// This requires that you depend on arro3-core from your Python package.
95    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
96        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
97        arro3_mod.getattr(intern!(py, "Scalar"))?.call_method1(
98            intern!(py, "from_arrow_pycapsule"),
99            self.__arrow_c_array__(py, None)?,
100        )
101    }
102
103    /// Export to an arro3.core.Scalar.
104    ///
105    /// This requires that you depend on arro3-core from your Python package.
106    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
107        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
108        let capsules = to_array_pycapsules(py, self.field.clone(), &self.array, None)?;
109        arro3_mod
110            .getattr(intern!(py, "Scalar"))?
111            .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
112    }
113}
114
115impl Display for PyScalar {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        write!(f, "arro3.core.Scalar<")?;
118        self.array.data_type().fmt(f)?;
119        writeln!(f, ">")?;
120
121        pretty_format_columns_with_options(
122            self.field.name(),
123            std::slice::from_ref(&self.array),
124            &default_repr_options(),
125        )
126        .map_err(|_| std::fmt::Error)?
127        .fmt(f)?;
128
129        Ok(())
130    }
131}
132
133impl Datum for PyScalar {
134    fn get(&self) -> (&dyn Array, bool) {
135        (self.array.as_ref(), true)
136    }
137}
138
139#[pymethods]
140impl PyScalar {
141    #[new]
142    #[pyo3(signature = (obj, /, r#type = None, *))]
143    fn init(py: Python, obj: &Bound<PyAny>, r#type: Option<PyField>) -> PyArrowResult<Self> {
144        if obj.hasattr(intern!(py, "__arrow_c_array__"))?
145            || obj.hasattr(intern!(py, "__arrow_c_stream__"))?
146        {
147            return Ok(obj.extract::<PyScalar>()?);
148        }
149
150        let obj = PyList::new(py, vec![obj])?;
151        let array = PyArray::init(py, &obj, r#type)?;
152        let (array, field) = array.into_inner();
153        Self::try_new(array, field)
154    }
155
156    #[pyo3(signature = (requested_schema=None))]
157    fn __arrow_c_array__<'py>(
158        &'py self,
159        py: Python<'py>,
160        requested_schema: Option<Bound<'py, PyCapsule>>,
161    ) -> PyArrowResult<Bound<'py, PyTuple>> {
162        to_array_pycapsules(py, self.field.clone(), &self.array, requested_schema)
163    }
164
165    fn __eq__(&self, py: Python, other: Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
166        if let Ok(other) = other.extract::<PyScalar>() {
167            let eq = self.array == other.array && self.field == other.field;
168            eq.into_py_any(py)
169        } else {
170            // If other is not an Arrow scalar, cast self to a Python object, and then call its
171            // `__eq__` method.
172            let self_py = self.as_py(py)?;
173            self_py.call_method1(py, intern!(py, "__eq__"), PyTuple::new(py, vec![other])?)
174        }
175    }
176
177    fn __repr__(&self) -> String {
178        self.to_string()
179    }
180
181    #[classmethod]
182    fn from_arrow(_cls: &Bound<PyType>, input: PyScalar) -> Self {
183        input
184    }
185
186    #[classmethod]
187    #[pyo3(name = "from_arrow_pycapsule")]
188    fn from_arrow_pycapsule_py(
189        _cls: &Bound<PyType>,
190        schema_capsule: &Bound<PyCapsule>,
191        array_capsule: &Bound<PyCapsule>,
192    ) -> PyArrowResult<Self> {
193        Self::try_from_arrow_pycapsule(schema_capsule, array_capsule)
194    }
195
196    pub(crate) fn as_py(&self, py: Python) -> PyArrowResult<Py<PyAny>> {
197        if self.array.is_null(0) {
198            return Ok(py.None());
199        }
200
201        let arr = self.array.as_ref();
202        let result = match self.array.data_type() {
203            DataType::Null => py.None(),
204            DataType::Boolean => arr.as_boolean().value(0).into_py_any(py)?,
205            DataType::Int8 => arr.as_primitive::<Int8Type>().value(0).into_py_any(py)?,
206            DataType::Int16 => arr.as_primitive::<Int16Type>().value(0).into_py_any(py)?,
207            DataType::Int32 => arr.as_primitive::<Int32Type>().value(0).into_py_any(py)?,
208            DataType::Int64 => arr.as_primitive::<Int64Type>().value(0).into_py_any(py)?,
209            DataType::UInt8 => arr.as_primitive::<UInt8Type>().value(0).into_py_any(py)?,
210            DataType::UInt16 => arr.as_primitive::<UInt16Type>().value(0).into_py_any(py)?,
211            DataType::UInt32 => arr.as_primitive::<UInt32Type>().value(0).into_py_any(py)?,
212            DataType::UInt64 => arr.as_primitive::<UInt64Type>().value(0).into_py_any(py)?,
213            DataType::Float16 => {
214                f32::from(arr.as_primitive::<Float16Type>().value(0)).into_py_any(py)?
215            }
216            DataType::Float32 => arr.as_primitive::<Float32Type>().value(0).into_py_any(py)?,
217            DataType::Float64 => arr.as_primitive::<Float64Type>().value(0).into_py_any(py)?,
218            DataType::Timestamp(time_unit, tz) => {
219                if let Some(tz) = tz {
220                    let tz = PyArrowTz::from_str(tz)?;
221                    match time_unit {
222                        TimeUnit::Second => {
223                            let value = arr.as_primitive::<TimestampSecondType>().value(0);
224                            as_datetime_with_timezone::<TimestampSecondType>(value, tz)
225                                .into_py_any(py)?
226                        }
227                        TimeUnit::Millisecond => {
228                            let value = arr.as_primitive::<TimestampMillisecondType>().value(0);
229                            as_datetime_with_timezone::<TimestampMillisecondType>(value, tz)
230                                .into_py_any(py)?
231                        }
232                        TimeUnit::Microsecond => {
233                            let value = arr.as_primitive::<TimestampMicrosecondType>().value(0);
234                            as_datetime_with_timezone::<TimestampMicrosecondType>(value, tz)
235                                .into_py_any(py)?
236                        }
237                        TimeUnit::Nanosecond => {
238                            let value = arr.as_primitive::<TimestampNanosecondType>().value(0);
239                            as_datetime_with_timezone::<TimestampNanosecondType>(value, tz)
240                                .into_py_any(py)?
241                        }
242                    }
243                } else {
244                    match time_unit {
245                        TimeUnit::Second => arr
246                            .as_primitive::<TimestampSecondType>()
247                            .value_as_datetime(0)
248                            .into_py_any(py)?,
249                        TimeUnit::Millisecond => arr
250                            .as_primitive::<TimestampMillisecondType>()
251                            .value_as_datetime(0)
252                            .into_py_any(py)?,
253                        TimeUnit::Microsecond => arr
254                            .as_primitive::<TimestampMicrosecondType>()
255                            .value_as_datetime(0)
256                            .into_py_any(py)?,
257                        TimeUnit::Nanosecond => arr
258                            .as_primitive::<TimestampNanosecondType>()
259                            .value_as_datetime(0)
260                            .into_py_any(py)?,
261                    }
262                }
263            }
264            DataType::Date32 => arr
265                .as_primitive::<Date32Type>()
266                .value_as_date(0)
267                .into_py_any(py)?,
268            DataType::Date64 => arr
269                .as_primitive::<Date64Type>()
270                .value_as_date(0)
271                .into_py_any(py)?,
272            DataType::Time32(time_unit) => match time_unit {
273                TimeUnit::Second => arr
274                    .as_primitive::<Time32SecondType>()
275                    .value_as_time(0)
276                    .into_py_any(py)?,
277                TimeUnit::Millisecond => arr
278                    .as_primitive::<Time32MillisecondType>()
279                    .value_as_time(0)
280                    .into_py_any(py)?,
281                _ => unreachable!(),
282            },
283            DataType::Time64(time_unit) => match time_unit {
284                TimeUnit::Microsecond => arr
285                    .as_primitive::<Time64MicrosecondType>()
286                    .value_as_time(0)
287                    .into_py_any(py)?,
288                TimeUnit::Nanosecond => arr
289                    .as_primitive::<Time64NanosecondType>()
290                    .value_as_time(0)
291                    .into_py_any(py)?,
292
293                _ => unreachable!(),
294            },
295            DataType::Duration(time_unit) => match time_unit {
296                TimeUnit::Second => arr
297                    .as_primitive::<DurationSecondType>()
298                    .value_as_duration(0)
299                    .into_py_any(py)?,
300                TimeUnit::Millisecond => arr
301                    .as_primitive::<DurationMillisecondType>()
302                    .value_as_duration(0)
303                    .into_py_any(py)?,
304                TimeUnit::Microsecond => arr
305                    .as_primitive::<DurationMicrosecondType>()
306                    .value_as_duration(0)
307                    .into_py_any(py)?,
308                TimeUnit::Nanosecond => arr
309                    .as_primitive::<DurationNanosecondType>()
310                    .value_as_duration(0)
311                    .into_py_any(py)?,
312            },
313            DataType::Interval(_) => {
314                // https://github.com/apache/arrow-rs/blob/6c59b7637592e4b67b18762b8313f91086c0d5d8/arrow-array/src/temporal_conversions.rs#L219
315                todo!("interval is not yet fully documented [ARROW-3097]")
316            }
317            DataType::Binary => arr.as_binary::<i32>().value(0).into_py_any(py)?,
318            DataType::FixedSizeBinary(_) => arr.as_fixed_size_binary().value(0).into_py_any(py)?,
319            DataType::LargeBinary => arr.as_binary::<i64>().value(0).into_py_any(py)?,
320            DataType::BinaryView => arr.as_binary_view().value(0).into_py_any(py)?,
321            DataType::Utf8 => arr.as_string::<i32>().value(0).into_py_any(py)?,
322            DataType::LargeUtf8 => arr.as_string::<i64>().value(0).into_py_any(py)?,
323            DataType::Utf8View => arr.as_string_view().value(0).into_py_any(py)?,
324            DataType::List(inner_field) => {
325                let inner_array = arr.as_list::<i32>().value(0);
326                list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
327            }
328            DataType::LargeList(inner_field) => {
329                let inner_array = arr.as_list::<i64>().value(0);
330                list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
331            }
332            DataType::FixedSizeList(inner_field, _list_size) => {
333                let inner_array = arr.as_fixed_size_list().value(0);
334                list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
335            }
336            DataType::ListView(_inner_field) => {
337                todo!("as_list_view does not yet exist");
338                // let inner_array = arr.as_list_view::<i32>().value(0);
339                // list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
340            }
341            DataType::LargeListView(_inner_field) => {
342                todo!("as_list_view does not yet exist");
343                // let inner_array = arr.as_list_view::<i64>().value(0);
344                // list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
345            }
346            DataType::Struct(inner_fields) => {
347                let struct_array = arr.as_struct();
348                let mut dict_py_objects = IndexMap::with_capacity(inner_fields.len());
349                for (inner_field, column) in inner_fields.iter().zip(struct_array.columns()) {
350                    let scalar =
351                        unsafe { PyScalar::new_unchecked(column.clone(), inner_field.clone()) };
352                    dict_py_objects.insert(inner_field.name(), scalar.as_py(py)?);
353                }
354                dict_py_objects.into_py_any(py)?
355            }
356            DataType::Union(_, _) => {
357                let array = arr.as_any().downcast_ref::<UnionArray>().unwrap();
358                let scalar = PyScalar::try_from_array_ref(array.value(0))?;
359                scalar.as_py(py)?
360            }
361            DataType::Dictionary(_, _) => {
362                let array = arr.as_any_dictionary();
363                let keys = array.keys();
364                let key = match keys.data_type() {
365                    DataType::Int8 => keys.as_primitive::<Int8Type>().value(0) as usize,
366                    DataType::Int16 => keys.as_primitive::<Int16Type>().value(0) as usize,
367                    DataType::Int32 => keys.as_primitive::<Int32Type>().value(0) as usize,
368                    DataType::Int64 => keys.as_primitive::<Int64Type>().value(0) as usize,
369                    DataType::UInt8 => keys.as_primitive::<UInt8Type>().value(0) as usize,
370                    DataType::UInt16 => keys.as_primitive::<UInt16Type>().value(0) as usize,
371                    DataType::UInt32 => keys.as_primitive::<UInt32Type>().value(0) as usize,
372                    DataType::UInt64 => keys.as_primitive::<UInt64Type>().value(0) as usize,
373                    // Above are the valid dictionary key types
374                    // https://docs.rs/arrow/latest/arrow/datatypes/trait.ArrowDictionaryKeyType.html
375                    _ => unreachable!(),
376                };
377                let value = array.values().slice(key, 1);
378                PyScalar::try_from_array_ref(value)?.as_py(py)?
379            }
380            DataType::Decimal32(precision, scale) => {
381                let decimal_mod = py.import(intern!(py, "decimal"))?;
382                let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
383
384                let array = arr.as_primitive::<Decimal32Type>();
385                let s = Decimal32Type::format_decimal(array.value(0), *precision, *scale);
386                decimal_class.call1((s,))?.unbind()
387            }
388            DataType::Decimal64(precision, scale) => {
389                let decimal_mod = py.import(intern!(py, "decimal"))?;
390                let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
391
392                let array = arr.as_primitive::<Decimal64Type>();
393                let s = Decimal64Type::format_decimal(array.value(0), *precision, *scale);
394                decimal_class.call1((s,))?.unbind()
395            }
396            DataType::Decimal128(precision, scale) => {
397                let decimal_mod = py.import(intern!(py, "decimal"))?;
398                let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
399
400                let array = arr.as_primitive::<Decimal128Type>();
401                let s = Decimal128Type::format_decimal(array.value(0), *precision, *scale);
402                decimal_class.call1((s,))?.unbind()
403            }
404            DataType::Decimal256(precision, scale) => {
405                let decimal_mod = py.import(intern!(py, "decimal"))?;
406                let decimal_class = decimal_mod.getattr(intern!(py, "Decimal"))?;
407
408                let array = arr.as_primitive::<Decimal256Type>();
409                let s = Decimal256Type::format_decimal(array.value(0), *precision, *scale);
410                decimal_class.call1((s,))?.unbind()
411            }
412            DataType::Map(_, _) => {
413                let array = arr.as_map();
414                let struct_arr = array.value(0);
415                let key_arr = struct_arr.column_by_name("key").unwrap();
416                let value_arr = struct_arr.column_by_name("value").unwrap();
417
418                let mut entries = Vec::with_capacity(struct_arr.len());
419                for i in 0..struct_arr.len() {
420                    let py_key = PyScalar::try_from_array_ref(key_arr.slice(i, 1))?.as_py(py)?;
421                    let py_value =
422                        PyScalar::try_from_array_ref(value_arr.slice(i, 1))?.as_py(py)?;
423                    entries.push(PyTuple::new(py, vec![py_key, py_value])?);
424                }
425
426                entries.into_py_any(py)?
427            }
428            DataType::RunEndEncoded(_, _) => {
429                todo!()
430            }
431        };
432        Ok(result)
433    }
434
435    fn cast(&self, target_type: PyField) -> PyArrowResult<Arro3Scalar> {
436        let new_field = target_type.into_inner();
437        let new_array = cast(&self.array, new_field.data_type())?;
438        Ok(PyScalar::try_new(new_array, new_field).unwrap().into())
439    }
440
441    #[getter]
442    #[pyo3(name = "field")]
443    fn py_field(&self) -> Arro3Field {
444        self.field.clone().into()
445    }
446
447    #[getter]
448    fn is_valid(&self) -> bool {
449        self.array.is_valid(0)
450    }
451
452    #[getter]
453    fn r#type(&self) -> Arro3DataType {
454        self.field.data_type().clone().into()
455    }
456}
457
458fn list_values_to_py(
459    py: Python,
460    inner_array: ArrayRef,
461    inner_field: &Arc<Field>,
462) -> PyArrowResult<Vec<Py<PyAny>>> {
463    let mut list_py_objects = Vec::with_capacity(inner_array.len());
464    for i in 0..inner_array.len() {
465        let scalar =
466            unsafe { PyScalar::new_unchecked(inner_array.slice(i, 1), inner_field.clone()) };
467        list_py_objects.push(scalar.as_py(py)?);
468    }
469    Ok(list_py_objects)
470}