polars_python/interop/arrow/
to_rust.rs1use polars_core::POOL;
2use polars_core::prelude::*;
3use polars_core::utils::accumulate_dataframes_vertical_unchecked;
4use polars_core::utils::arrow::ffi;
5use pyo3::ffi::Py_uintptr_t;
6use pyo3::prelude::*;
7use pyo3::types::PyList;
8use rayon::prelude::*;
9
10use crate::error::PyPolarsErr;
11use crate::utils::EnterPolarsExt;
12
13pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
14 let mut schema = Box::new(ffi::ArrowSchema::empty());
15 let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
16
17 obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
19 let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
20 Ok(field)
21}
22
23pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
24 field_to_rust_arrow(obj).map(|f| (&f).into())
25}
26
27pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
29 obj.into_iter().map(field_to_rust).collect()
30}
31
32pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
33 let mut array = Box::new(ffi::ArrowArray::empty());
35 let mut schema = Box::new(ffi::ArrowSchema::empty());
36
37 let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
38 let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
39
40 obj.call_method1(
43 "_export_to_c",
44 (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
45 )?;
46
47 unsafe {
48 let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;
49 let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;
50 Ok(array)
51 }
52}
53
54pub fn to_rust_df(
55 py: Python<'_>,
56 rb: &[Bound<PyAny>],
57 schema: Bound<PyAny>,
58) -> PyResult<DataFrame> {
59 let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
60 return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
61 };
62
63 let schema = ArrowSchema::from_iter(fields.iter().cloned());
64
65 if schema.len() != fields.len() {
68 let mut field_map: PlHashMap<PlSmallStr, u64> = PlHashMap::with_capacity(fields.len());
69 fields.iter().for_each(|field| {
70 field_map
71 .entry(field.name.clone())
72 .and_modify(|c| {
73 *c += 1;
74 })
75 .or_insert(1);
76 });
77 let duplicate_fields: Vec<_> = field_map
78 .into_iter()
79 .filter_map(|(k, v)| (v > 1).then_some(k))
80 .collect();
81
82 return Err(PyPolarsErr::Polars(PolarsError::Duplicate(
83 format!("column appears more than once; names must be unique: {duplicate_fields:?}")
84 .into(),
85 ))
86 .into());
87 }
88
89 if rb.is_empty() {
90 let columns = schema
91 .iter_values()
92 .map(|field| {
93 let field = Field::from(field);
94 Series::new_empty(field.name, &field.dtype).into_column()
95 })
96 .collect::<Vec<_>>();
97
98 return Ok(unsafe { DataFrame::new_unchecked_infer_height(columns) });
100 }
101
102 let dfs = rb
103 .iter()
104 .map(|rb| {
105 let mut run_parallel = false;
106
107 let columns = (0..schema.len())
108 .map(|i| {
109 let array = rb.call_method1("column", (i,))?;
110 let mut arr = array_to_rust(&array)?;
111
112 let dtype = schema.get_at_index(i).unwrap().1.dtype();
115 if let ArrowDataType::Extension(ext) = dtype {
116 if *arr.dtype() == ext.inner {
117 *arr.dtype_mut() = dtype.clone();
118 }
119 }
120
121 run_parallel |= matches!(
122 arr.dtype(),
123 ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)
124 );
125 Ok(arr)
126 })
127 .collect::<PyResult<Vec<_>>>()?;
128
129 let columns = if run_parallel {
133 py.enter_polars(|| {
134 POOL.install(|| {
135 columns
136 .into_par_iter()
137 .enumerate()
138 .map(|(i, arr)| {
139 let (_, field) = schema.get_at_index(i).unwrap();
140 let s = unsafe {
141 Series::_try_from_arrow_unchecked_with_md(
142 field.name.clone(),
143 vec![arr],
144 field.dtype(),
145 field.metadata.as_deref(),
146 )
147 }
148 .map_err(PyPolarsErr::from)?
149 .into_column();
150 Ok(s)
151 })
152 .collect::<PyResult<Vec<_>>>()
153 })
154 })
155 } else {
156 columns
157 .into_iter()
158 .enumerate()
159 .map(|(i, arr)| {
160 let (_, field) = schema.get_at_index(i).unwrap();
161 let s = unsafe {
162 Series::_try_from_arrow_unchecked_with_md(
163 field.name.clone(),
164 vec![arr],
165 field.dtype(),
166 field.metadata.as_deref(),
167 )
168 }
169 .map_err(PyPolarsErr::from)?
170 .into_column();
171 Ok(s)
172 })
173 .collect::<PyResult<Vec<_>>>()
174 }?;
175
176 Ok(unsafe { DataFrame::new_unchecked_infer_height(columns) })
178 })
179 .collect::<PyResult<Vec<_>>>()?;
180
181 Ok(accumulate_dataframes_vertical_unchecked(dfs))
182}