Skip to main content

datafusion_python/
pyarrow_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Conversions between PyArrow and DataFusion types
19
20use std::sync::Arc;
21
22use arrow::array::{Array, ArrayData, ArrayRef, ListArray, make_array};
23use arrow::buffer::OffsetBuffer;
24use arrow::datatypes::Field;
25use arrow::pyarrow::{FromPyArrow, ToPyArrow};
26use datafusion::common::exec_err;
27use datafusion::scalar::ScalarValue;
28use pyo3::types::{PyAnyMethods, PyList};
29use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
30
31use crate::common::data_type::PyScalarValue;
32use crate::errors::PyDataFusionError;
33
34/// Helper function to turn an Array into a ScalarValue. If ``as_list_array`` is true,
35/// the array will be turned into a ``ListArray``. Otherwise, we extract the first value
36/// from the array.
37fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult<PyScalarValue> {
38    if as_list_array {
39        let field = Arc::new(Field::new_list_field(
40            array.data_type().clone(),
41            array.nulls().is_some(),
42        ));
43        let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
44        let list_array = ListArray::new(field, offsets, array, None);
45        Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
46    } else {
47        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
48        Ok(PyScalarValue(scalar))
49    }
50}
51
52/// Helper function to take any Python object that contains an Arrow PyCapsule
53/// interface and attempt to extract a scalar value from it. If `as_list_array`
54/// is true, the array will be turned into a ``ListArray``. Otherwise, we extract
55/// the first value from the array.
56fn pyobj_extract_scalar_via_capsule(
57    value: &Bound<'_, PyAny>,
58    as_list_array: bool,
59) -> PyResult<PyScalarValue> {
60    let array_data = ArrayData::from_pyarrow_bound(value)?;
61    let array = make_array(array_data);
62
63    array_to_scalar_value(array, as_list_array)
64}
65
66impl FromPyArrow for PyScalarValue {
67    fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
68        let py = value.py();
69        let pyarrow_mod = py.import("pyarrow");
70
71        // Is it a PyArrow object?
72        if let Ok(pa) = pyarrow_mod.as_ref() {
73            let scalar_type = pa.getattr("Scalar")?;
74            if value.is_instance(&scalar_type)? {
75                let typ = value.getattr("type")?;
76
77                // construct pyarrow array from the python value and pyarrow type
78                let factory = py.import("pyarrow")?.getattr("array")?;
79                let args = PyList::new(py, [value])?;
80                let array = factory.call1((args, typ))?;
81
82                return pyobj_extract_scalar_via_capsule(&array, false);
83            }
84
85            let array_type = pa.getattr("Array")?;
86            if value.is_instance(&array_type)? {
87                return pyobj_extract_scalar_via_capsule(value, true);
88            }
89        }
90
91        // Is it a NanoArrow scalar?
92        if let Ok(na) = py.import("nanoarrow") {
93            let scalar_type = py.import("nanoarrow.array")?.getattr("Scalar")?;
94            if value.is_instance(&scalar_type)? {
95                return pyobj_extract_scalar_via_capsule(value, false);
96            }
97            let array_type = na.getattr("Array")?;
98            if value.is_instance(&array_type)? {
99                return pyobj_extract_scalar_via_capsule(value, true);
100            }
101        }
102
103        // Is it a arro3 scalar?
104        if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
105            let scalar_type = arro3.getattr("Scalar")?;
106            if value.is_instance(&scalar_type)? {
107                return pyobj_extract_scalar_via_capsule(value, false);
108            }
109            let array_type = arro3.getattr("Array")?;
110            if value.is_instance(&array_type)? {
111                return pyobj_extract_scalar_via_capsule(value, true);
112            }
113        }
114
115        // Does it have a PyCapsule interface but isn't one of our known libraries?
116        // If so do our "best guess". Try checking type name, and if that fails
117        // return a single value if the length is 1 and return a List value otherwise
118        if value.hasattr("__arrow_c_array__")? {
119            let type_name = value.get_type().repr()?;
120            if type_name.contains("Scalar")? {
121                return pyobj_extract_scalar_via_capsule(value, false);
122            }
123            if type_name.contains("Array")? {
124                return pyobj_extract_scalar_via_capsule(value, true);
125            }
126
127            let array_data = ArrayData::from_pyarrow_bound(value)?;
128            let array = make_array(array_data);
129
130            let as_array_list = array.len() != 1;
131            return array_to_scalar_value(array, as_array_list);
132        }
133
134        // Last attempt - try to create a PyArrow scalar from a plain Python object
135        if let Ok(pa) = pyarrow_mod.as_ref() {
136            let scalar = pa.call_method1("scalar", (value,))?;
137
138            PyScalarValue::from_pyarrow_bound(&scalar)
139        } else {
140            exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)?
141        }
142    }
143}
144
145impl<'source> FromPyObject<'source> for PyScalarValue {
146    fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
147        Self::from_pyarrow_bound(value)
148    }
149}
150
151pub fn scalar_to_pyarrow<'py>(
152    scalar: &ScalarValue,
153    py: Python<'py>,
154) -> PyResult<Bound<'py, PyAny>> {
155    let array = scalar.to_array().map_err(PyDataFusionError::from)?;
156    // convert to pyarrow array using C data interface
157    let pyarray = array.to_data().to_pyarrow(py)?;
158    let pyscalar = pyarray.call_method1("__getitem__", (0,))?;
159
160    Ok(pyscalar)
161}