datafusion_common/
pyarrow.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Conversions between PyArrow and DataFusion types
19
20use arrow::array::{Array, ArrayData};
21use arrow::pyarrow::{FromPyArrow, ToPyArrow};
22use pyo3::exceptions::PyException;
23use pyo3::prelude::PyErr;
24use pyo3::types::{PyAnyMethods, PyList};
25use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python};
26
27use crate::{DataFusionError, ScalarValue};
28
29impl From<DataFusionError> for PyErr {
30    fn from(err: DataFusionError) -> PyErr {
31        PyException::new_err(err.to_string())
32    }
33}
34
35impl FromPyArrow for ScalarValue {
36    fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
37        let py = value.py();
38        let typ = value.getattr("type")?;
39        let val = value.call_method0("as_py")?;
40
41        // construct pyarrow array from the python value and pyarrow type
42        let factory = py.import("pyarrow")?.getattr("array")?;
43        let args = PyList::new(py, [val])?;
44        let array = factory.call1((args, typ))?;
45
46        // convert the pyarrow array to rust array using C data interface
47        let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
48        let scalar = ScalarValue::try_from_array(&array, 0)?;
49
50        Ok(scalar)
51    }
52}
53
54impl ToPyArrow for ScalarValue {
55    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
56        let array = self.to_array()?;
57        // convert to pyarrow array using C data interface
58        let pyarray = array.to_data().to_pyarrow(py)?;
59        let pyscalar = pyarray.call_method1("__getitem__", (0,))?;
60
61        Ok(pyscalar)
62    }
63}
64
65impl<'source> FromPyObject<'source> for ScalarValue {
66    fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
67        Self::from_pyarrow_bound(value)
68    }
69}
70
71impl<'source> IntoPyObject<'source> for ScalarValue {
72    type Target = PyAny;
73
74    type Output = Bound<'source, Self::Target>;
75
76    type Error = PyErr;
77
78    fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> {
79        let array = self.to_array()?;
80        // convert to pyarrow array using C data interface
81        let pyarray = array.to_data().to_pyarrow(py)?;
82        pyarray.call_method1("__getitem__", (0,))
83    }
84}
85
#[cfg(test)]
mod tests {
    use pyo3::ffi::c_str;
    use pyo3::py_run;
    use pyo3::types::PyDict;
    use pyo3::Python;

    use super::*;

    /// Initializes the Python interpreter and panics with interpreter
    /// diagnostics when `pyarrow` cannot be imported.
    fn init_python() {
        Python::initialize();
        Python::attach(|py| {
            if py.run(c_str!("import pyarrow"), None, None).is_ok() {
                return;
            }

            // The import failed: gather interpreter details for the panic message.
            let locals = PyDict::new(py);
            py.run(
                c_str!(
                    "import sys; executable = sys.executable; python_path = sys.path"
                ),
                None,
                Some(&locals),
            )
            .expect("Couldn't get python info");

            let executable: String = locals
                .get_item("executable")
                .unwrap()
                .extract()
                .unwrap();
            let python_path: Vec<String> = locals
                .get_item("python_path")
                .unwrap()
                .extract()
                .unwrap();

            panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\
                     HINT: try `pip install pyarrow`\n\
                     NOTE: On Mac OS, you must compile against a Framework Python \
                     (default in python.org installers and brew, but not pyenv)\n\
                     NOTE: On Mac OS, PYO3 might point to incorrect Python library \
                     path when using virtual environments. Try \
                     `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n");
        })
    }

    /// Each sample scalar must come back unchanged after going
    /// Rust -> PyArrow -> Rust.
    #[test]
    fn test_roundtrip() {
        init_python();

        let example_scalars = [
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Int32(Some(23)),
            ScalarValue::Float64(Some(12.34)),
            ScalarValue::from("Hello!"),
            ScalarValue::Date32(Some(1234)),
        ];

        Python::attach(|py| {
            for scalar in &example_scalars {
                let py_scalar = scalar.to_pyarrow(py).unwrap();
                let result = ScalarValue::from_pyarrow_bound(&py_scalar).unwrap();
                assert_eq!(scalar, &result);
            }
        });
    }

    /// Scalars converted with `into_pyobject` must yield the expected plain
    /// Python values from `as_py()`.
    #[test]
    fn test_py_scalar() -> PyResult<()> {
        init_python();

        Python::attach(|py| -> PyResult<()> {
            let py_float = ScalarValue::Float64(Some(12.34))
                .into_pyobject(py)?
                .call_method0("as_py")
                .unwrap();
            py_run!(py, py_float, "assert py_float == 12.34");

            let py_string = ScalarValue::Utf8(Some(String::from("Hello!")))
                .into_pyobject(py)?
                .call_method0("as_py")
                .unwrap();
            py_run!(py, py_string, "assert py_string == 'Hello!'");

            Ok(())
        })
    }
}