// polars_python/interop/arrow/to_rust.rs

1use polars_core::prelude::*;
2use polars_core::utils::accumulate_dataframes_vertical_unchecked;
3use polars_core::utils::arrow::ffi;
4use polars_core::POOL;
5use pyo3::ffi::Py_uintptr_t;
6use pyo3::prelude::*;
7use pyo3::types::PyList;
8use rayon::prelude::*;
9
10use crate::error::PyPolarsErr;
11
12pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
13    let mut schema = Box::new(ffi::ArrowSchema::empty());
14    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
15
16    // make the conversion through PyArrow's private API
17    obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
18    let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
19    Ok(field.clone())
20}
21
22pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
23    field_to_rust_arrow(obj).map(|f| (&f).into())
24}
25
26// PyList<Field> which you get by calling `list(schema)`
27pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
28    obj.into_iter().map(field_to_rust).collect()
29}
30
/// Import a single Arrow array (data + field) from a PyArrow object via the
/// Arrow C Data Interface.
///
/// # Errors
/// Returns an error if the Python call fails or if either the schema or the
/// array cannot be imported.
pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
    // prepare a pointer to receive the Array struct
    let mut array = Box::new(ffi::ArrowArray::empty());
    let mut schema = Box::new(ffi::ArrowSchema::empty());

    let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;

    // make the conversion through PyArrow's private API
    // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
    obj.call_method1(
        "_export_to_c",
        (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
    )?;

    // SAFETY: both structs were just populated by `_export_to_c`, so they
    // obey the Arrow C Data Interface contract. `*array` moves the exported
    // array into the import, which takes over its release callback.
    unsafe {
        let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;
        let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;
        Ok(array)
    }
}
52
53pub fn to_rust_df(py: Python, rb: &[Bound<PyAny>], schema: Bound<PyAny>) -> PyResult<DataFrame> {
54    let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
55        return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
56    };
57    let schema = ArrowSchema::from_iter(fields);
58
59    if rb.is_empty() {
60        let columns = schema
61            .iter_values()
62            .map(|field| {
63                let field = Field::from(field);
64                Series::new_empty(field.name, &field.dtype).into_column()
65            })
66            .collect::<Vec<_>>();
67
68        // no need to check as a record batch has the same guarantees
69        return Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });
70    }
71
72    let dfs = rb
73        .iter()
74        .map(|rb| {
75            let mut run_parallel = false;
76
77            let columns = (0..schema.len())
78                .map(|i| {
79                    let array = rb.call_method1("column", (i,))?;
80                    let arr = array_to_rust(&array)?;
81                    run_parallel |= matches!(
82                        arr.dtype(),
83                        ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)
84                    );
85                    Ok(arr)
86                })
87                .collect::<PyResult<Vec<_>>>()?;
88
89            // we parallelize this part because we can have dtypes that are not zero copy
90            // for instance string -> large-utf8
91            // dict encoded to categorical
92            let columns = if run_parallel {
93                py.allow_threads(|| {
94                    POOL.install(|| {
95                        columns
96                            .into_par_iter()
97                            .enumerate()
98                            .map(|(i, arr)| {
99                                let (_, field) = schema.get_at_index(i).unwrap();
100                                let s = unsafe {
101                                    Series::_try_from_arrow_unchecked_with_md(
102                                        field.name.clone(),
103                                        vec![arr],
104                                        field.dtype(),
105                                        field.metadata.as_deref(),
106                                    )
107                                }
108                                .map_err(PyPolarsErr::from)?
109                                .into_column();
110                                Ok(s)
111                            })
112                            .collect::<PyResult<Vec<_>>>()
113                    })
114                })
115            } else {
116                columns
117                    .into_iter()
118                    .enumerate()
119                    .map(|(i, arr)| {
120                        let (_, field) = schema.get_at_index(i).unwrap();
121                        let s = unsafe {
122                            Series::_try_from_arrow_unchecked_with_md(
123                                field.name.clone(),
124                                vec![arr],
125                                field.dtype(),
126                                field.metadata.as_deref(),
127                            )
128                        }
129                        .map_err(PyPolarsErr::from)?
130                        .into_column();
131                        Ok(s)
132                    })
133                    .collect::<PyResult<Vec<_>>>()
134            }?;
135
136            // no need to check as a record batch has the same guarantees
137            Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) })
138        })
139        .collect::<PyResult<Vec<_>>>()?;
140
141    Ok(accumulate_dataframes_vertical_unchecked(dfs))
142}