// polars_python/interop/arrow/to_rust.rs

1use polars_core::POOL;
2use polars_core::prelude::*;
3use polars_core::utils::accumulate_dataframes_vertical_unchecked;
4use polars_core::utils::arrow::ffi;
5use pyo3::ffi::Py_uintptr_t;
6use pyo3::prelude::*;
7use pyo3::types::PyList;
8use rayon::prelude::*;
9
10use crate::error::PyPolarsErr;
11use crate::utils::EnterPolarsExt;
12
13pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
14    let mut schema = Box::new(ffi::ArrowSchema::empty());
15    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
16
17    // make the conversion through PyArrow's private API
18    obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
19    let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
20    Ok(field)
21}
22
23pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
24    field_to_rust_arrow(obj).map(|f| (&f).into())
25}
26
27// PyList<Field> which you get by calling `list(schema)`
28pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
29    obj.into_iter().map(field_to_rust).collect()
30}
31
32pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
33    // prepare a pointer to receive the Array struct
34    let mut array = Box::new(ffi::ArrowArray::empty());
35    let mut schema = Box::new(ffi::ArrowSchema::empty());
36
37    let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
38    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
39
40    // make the conversion through PyArrow's private API
41    // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
42    obj.call_method1(
43        "_export_to_c",
44        (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
45    )?;
46
47    unsafe {
48        let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;
49        let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;
50        Ok(array)
51    }
52}
53
54pub fn to_rust_df(
55    py: Python<'_>,
56    rb: &[Bound<PyAny>],
57    schema: Bound<PyAny>,
58) -> PyResult<DataFrame> {
59    let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
60        return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
61    };
62
63    let schema = ArrowSchema::from_iter(fields.iter().cloned());
64
65    // Verify that field names are not duplicated. Arrow permits duplicate field names, we do not.
66    // Required to uphold safety invariants for unsafe block below.
67    if schema.len() != fields.len() {
68        let mut field_map: PlHashMap<PlSmallStr, u64> = PlHashMap::with_capacity(fields.len());
69        fields.iter().for_each(|field| {
70            field_map
71                .entry(field.name.clone())
72                .and_modify(|c| {
73                    *c += 1;
74                })
75                .or_insert(1);
76        });
77        let duplicate_fields: Vec<_> = field_map
78            .into_iter()
79            .filter_map(|(k, v)| (v > 1).then_some(k))
80            .collect();
81
82        return Err(PyPolarsErr::Polars(PolarsError::Duplicate(
83            format!("column appears more than once; names must be unique: {duplicate_fields:?}")
84                .into(),
85        ))
86        .into());
87    }
88
89    if rb.is_empty() {
90        let columns = schema
91            .iter_values()
92            .map(|field| {
93                let field = Field::from(field);
94                Series::new_empty(field.name, &field.dtype).into_column()
95            })
96            .collect::<Vec<_>>();
97
98        // no need to check as a record batch has the same guarantees
99        return Ok(unsafe { DataFrame::new_unchecked_infer_height(columns) });
100    }
101
102    let dfs = rb
103        .iter()
104        .map(|rb| {
105            let mut run_parallel = false;
106
107            let columns = (0..schema.len())
108                .map(|i| {
109                    let array = rb.call_method1("column", (i,))?;
110                    let mut arr = array_to_rust(&array)?;
111
112                    // Only the schema contains extension type info, restore.
113                    // TODO: nested?
114                    let dtype = schema.get_at_index(i).unwrap().1.dtype();
115                    if let ArrowDataType::Extension(ext) = dtype {
116                        if *arr.dtype() == ext.inner {
117                            *arr.dtype_mut() = dtype.clone();
118                        }
119                    }
120
121                    run_parallel |= matches!(
122                        arr.dtype(),
123                        ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)
124                    );
125                    Ok(arr)
126                })
127                .collect::<PyResult<Vec<_>>>()?;
128
129            // we parallelize this part because we can have dtypes that are not zero copy
130            // for instance string -> large-utf8
131            // dict encoded to categorical
132            let columns = if run_parallel {
133                py.enter_polars(|| {
134                    POOL.install(|| {
135                        columns
136                            .into_par_iter()
137                            .enumerate()
138                            .map(|(i, arr)| {
139                                let (_, field) = schema.get_at_index(i).unwrap();
140                                let s = unsafe {
141                                    Series::_try_from_arrow_unchecked_with_md(
142                                        field.name.clone(),
143                                        vec![arr],
144                                        field.dtype(),
145                                        field.metadata.as_deref(),
146                                    )
147                                }
148                                .map_err(PyPolarsErr::from)?
149                                .into_column();
150                                Ok(s)
151                            })
152                            .collect::<PyResult<Vec<_>>>()
153                    })
154                })
155            } else {
156                columns
157                    .into_iter()
158                    .enumerate()
159                    .map(|(i, arr)| {
160                        let (_, field) = schema.get_at_index(i).unwrap();
161                        let s = unsafe {
162                            Series::_try_from_arrow_unchecked_with_md(
163                                field.name.clone(),
164                                vec![arr],
165                                field.dtype(),
166                                field.metadata.as_deref(),
167                            )
168                        }
169                        .map_err(PyPolarsErr::from)?
170                        .into_column();
171                        Ok(s)
172                    })
173                    .collect::<PyResult<Vec<_>>>()
174            }?;
175
176            // no need to check as a record batch has the same guarantees
177            Ok(unsafe { DataFrame::new_unchecked_infer_height(columns) })
178        })
179        .collect::<PyResult<Vec<_>>>()?;
180
181    Ok(accumulate_dataframes_vertical_unchecked(dfs))
182}