polars_python/interop/arrow/to_rust.rs

1use polars_core::POOL;
2use polars_core::prelude::*;
3use polars_core::utils::accumulate_dataframes_vertical_unchecked;
4use polars_core::utils::arrow::ffi;
5use pyo3::ffi::Py_uintptr_t;
6use pyo3::prelude::*;
7use pyo3::types::PyList;
8use rayon::prelude::*;
9
10use crate::error::PyPolarsErr;
11use crate::utils::EnterPolarsExt;
12
13pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
14    let mut schema = Box::new(ffi::ArrowSchema::empty());
15    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
16
17    // make the conversion through PyArrow's private API
18    obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
19    let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
20    Ok(normalize_arrow_fields(&field))
21}
22
23fn normalize_arrow_fields(field: &ArrowField) -> ArrowField {
24    // normalize fields with extension dtypes that are otherwise standard dtypes associated
25    // with (for us) irrelevant metadata; recreate the field using the inner (standard) dtype
26    match field {
27        ArrowField {
28            dtype: ArrowDataType::Struct(fields),
29            ..
30        } => {
31            let mut normalized = false;
32            let normalized_fields: Vec<_> = fields
33                .iter()
34                .map(|f| {
35                    // note: google bigquery column data is returned as a standard arrow dtype, but the
36                    // sql type it was loaded from is associated as metadata (resulting in an extension dtype)
37                    if let ArrowDataType::Extension(ext_type) = &f.dtype {
38                        if ext_type.name.starts_with("google:sqlType:") {
39                            normalized = true;
40                            return ArrowField::new(
41                                f.name.clone(),
42                                ext_type.inner.clone(),
43                                f.is_nullable,
44                            );
45                        }
46                    }
47                    f.clone()
48                })
49                .collect();
50
51            if normalized {
52                ArrowField::new(
53                    field.name.clone(),
54                    ArrowDataType::Struct(normalized_fields),
55                    field.is_nullable,
56                )
57            } else {
58                field.clone()
59            }
60        },
61        _ => field.clone(),
62    }
63}
64
65pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
66    field_to_rust_arrow(obj).map(|f| (&f).into())
67}
68
69// PyList<Field> which you get by calling `list(schema)`
70pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
71    obj.into_iter().map(field_to_rust).collect()
72}
73
74pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
75    // prepare a pointer to receive the Array struct
76    let mut array = Box::new(ffi::ArrowArray::empty());
77    let mut schema = Box::new(ffi::ArrowSchema::empty());
78
79    let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
80    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
81
82    // make the conversion through PyArrow's private API
83    // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
84    obj.call_method1(
85        "_export_to_c",
86        (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
87    )?;
88
89    unsafe {
90        let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;
91        let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;
92        Ok(array)
93    }
94}
95
96pub fn to_rust_df(
97    py: Python<'_>,
98    rb: &[Bound<PyAny>],
99    schema: Bound<PyAny>,
100) -> PyResult<DataFrame> {
101    let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
102        return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
103    };
104
105    let schema = ArrowSchema::from_iter(fields.iter().cloned());
106
107    // Verify that field names are not duplicated. Arrow permits duplicate field names, we do not.
108    // Required to uphold safety invariants for unsafe block below.
109    if schema.len() != fields.len() {
110        let mut field_map: PlHashMap<PlSmallStr, u64> = PlHashMap::with_capacity(fields.len());
111        fields.iter().for_each(|field| {
112            field_map
113                .entry(field.name.clone())
114                .and_modify(|c| {
115                    *c += 1;
116                })
117                .or_insert(1);
118        });
119        let duplicate_fields: Vec<_> = field_map
120            .into_iter()
121            .filter_map(|(k, v)| (v > 1).then_some(k))
122            .collect();
123
124        return Err(PyPolarsErr::Polars(PolarsError::Duplicate(
125            format!(
126                "column appears more than once; names must be unique: {:?}",
127                duplicate_fields
128            )
129            .into(),
130        ))
131        .into());
132    }
133
134    if rb.is_empty() {
135        let columns = schema
136            .iter_values()
137            .map(|field| {
138                let field = Field::from(field);
139                Series::new_empty(field.name, &field.dtype).into_column()
140            })
141            .collect::<Vec<_>>();
142
143        // no need to check as a record batch has the same guarantees
144        return Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });
145    }
146
147    let dfs = rb
148        .iter()
149        .map(|rb| {
150            let mut run_parallel = false;
151
152            let columns = (0..schema.len())
153                .map(|i| {
154                    let array = rb.call_method1("column", (i,))?;
155                    let arr = array_to_rust(&array)?;
156                    run_parallel |= matches!(
157                        arr.dtype(),
158                        ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)
159                    );
160                    Ok(arr)
161                })
162                .collect::<PyResult<Vec<_>>>()?;
163
164            // we parallelize this part because we can have dtypes that are not zero copy
165            // for instance string -> large-utf8
166            // dict encoded to categorical
167            let columns = if run_parallel {
168                py.enter_polars(|| {
169                    POOL.install(|| {
170                        columns
171                            .into_par_iter()
172                            .enumerate()
173                            .map(|(i, arr)| {
174                                let (_, field) = schema.get_at_index(i).unwrap();
175                                let s = unsafe {
176                                    Series::_try_from_arrow_unchecked_with_md(
177                                        field.name.clone(),
178                                        vec![arr],
179                                        field.dtype(),
180                                        field.metadata.as_deref(),
181                                    )
182                                }
183                                .map_err(PyPolarsErr::from)?
184                                .into_column();
185                                Ok(s)
186                            })
187                            .collect::<PyResult<Vec<_>>>()
188                    })
189                })
190            } else {
191                columns
192                    .into_iter()
193                    .enumerate()
194                    .map(|(i, arr)| {
195                        let (_, field) = schema.get_at_index(i).unwrap();
196                        let s = unsafe {
197                            Series::_try_from_arrow_unchecked_with_md(
198                                field.name.clone(),
199                                vec![arr],
200                                field.dtype(),
201                                field.metadata.as_deref(),
202                            )
203                        }
204                        .map_err(PyPolarsErr::from)?
205                        .into_column();
206                        Ok(s)
207                    })
208                    .collect::<PyResult<Vec<_>>>()
209            }?;
210
211            // no need to check as a record batch has the same guarantees
212            Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) })
213        })
214        .collect::<PyResult<Vec<_>>>()?;
215
216    Ok(accumulate_dataframes_vertical_unchecked(dfs))
217}