polars_python/interop/arrow/to_rust.rs

use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_core::utils::arrow::ffi;
use pyo3::ffi::Py_uintptr_t;
use pyo3::prelude::*;
use pyo3::types::PyList;
use rayon::prelude::*;

use crate::error::PyPolarsErr;
use crate::utils::EnterPolarsExt;

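/// Import an Arrow field/schema from a Python object that implements the Arrow C data
/// interface (i.e. exposes `_export_to_c`), returning the Rust-side `ArrowField` with
/// BigQuery-style extension metadata stripped via `normalize_arrow_fields`.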
pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
    let mut schema = Box::new(ffi::ArrowSchema::empty());
    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;

    obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
    let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
    Ok(normalize_arrow_fields(&field))
}

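/// For a struct field, rebuild any child field whose dtype is a `google:sqlType:*`
/// extension type using its inner (standard) Arrow dtype, so downstream conversion sees
/// plain Arrow types; all other fields are returned unchanged.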
fn normalize_arrow_fields(field: &ArrowField) -> ArrowField {
    match field {
        ArrowField {
            dtype: ArrowDataType::Struct(fields),
            ..
        } => {
            let mut normalized = false;
            let normalized_fields: Vec<_> = fields
                .iter()
                .map(|f| {
                    if let ArrowDataType::Extension(ext_type) = &f.dtype {
                        if ext_type.name.starts_with("google:sqlType:") {
                            normalized = true;
                            return ArrowField::new(
                                f.name.clone(),
                                ext_type.inner.clone(),
                                f.is_nullable,
                            );
                        }
                    }
                    f.clone()
                })
                .collect();

            if normalized {
                ArrowField::new(
                    field.name.clone(),
                    ArrowDataType::Struct(normalized_fields),
                    field.is_nullable,
                )
            } else {
                field.clone()
            }
        },
        _ => field.clone(),
    }
}

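/// Convert an Arrow field exported from Python into a Polars `Field`.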
pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
    field_to_rust_arrow(obj).map(|f| (&f).into())
}

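/// Convert a `PyList` of PyArrow fields (e.g. obtained from `list(schema)`) into a
/// Polars `Schema`.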
pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
    obj.into_iter().map(field_to_rust).collect()
}

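/// Import an Arrow array from a Python object via the Arrow C data interface: the object
/// exports its array and schema structs through `_export_to_c`, and both are then
/// imported into Rust-owned memory as an `ArrayRef`.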
pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
    let mut array = Box::new(ffi::ArrowArray::empty());
    let mut schema = Box::new(ffi::ArrowSchema::empty());

    let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;

    obj.call_method1(
        "_export_to_c",
        (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
    )?;

    unsafe {
        let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;
        let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;
        Ok(array)
    }
}

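/// Build a Polars `DataFrame` from a slice of Arrow record batches (`rb`) and a top-level
/// struct schema, both passed in as Python objects. Duplicate column names are rejected,
/// and batches containing string or dictionary columns are converted on the rayon pool.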
pub fn to_rust_df(
    py: Python<'_>,
    rb: &[Bound<PyAny>],
    schema: Bound<PyAny>,
) -> PyResult<DataFrame> {
    let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
        return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
    };

    let schema = ArrowSchema::from_iter(fields.iter().cloned());

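    // `ArrowSchema::from_iter` keys fields by name, so a length mismatch with the original
    // field list means at least one column name occurs more than once; report the offenders.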
    if schema.len() != fields.len() {
        let mut field_map: PlHashMap<PlSmallStr, u64> = PlHashMap::with_capacity(fields.len());
        fields.iter().for_each(|field| {
            field_map
                .entry(field.name.clone())
                .and_modify(|c| {
                    *c += 1;
                })
                .or_insert(1);
        });
        let duplicate_fields: Vec<_> = field_map
            .into_iter()
            .filter_map(|(k, v)| (v > 1).then_some(k))
            .collect();

        return Err(PyPolarsErr::Polars(PolarsError::Duplicate(
            format!(
                "column appears more than once; names must be unique: {:?}",
                duplicate_fields
            )
            .into(),
        ))
        .into());
    }

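    // No record batches: return an empty DataFrame with one empty column per schema field.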
    if rb.is_empty() {
        let columns = schema
            .iter_values()
            .map(|field| {
                let field = Field::from(field);
                Series::new_empty(field.name, &field.dtype).into_column()
            })
            .collect::<Vec<_>>();

        return Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });
    }

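    // Convert each record batch into its own DataFrame; the per-batch frames are stacked
    // vertically at the end.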
    let dfs = rb
        .iter()
        .map(|rb| {
            let mut run_parallel = false;

            let columns = (0..schema.len())
                .map(|i| {
                    let array = rb.call_method1("column", (i,))?;
                    let arr = array_to_rust(&array)?;
                    run_parallel |= matches!(
                        arr.dtype(),
                        ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)
                    );
                    Ok(arr)
                })
                .collect::<PyResult<Vec<_>>>()?;

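            // Batches containing Utf8 or Dictionary columns do the most conversion work,
            // so those are converted in parallel on the Polars rayon pool inside
            // `py.enter_polars`; all other batches are converted sequentially.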
            let columns = if run_parallel {
                py.enter_polars(|| {
                    POOL.install(|| {
                        columns
                            .into_par_iter()
                            .enumerate()
                            .map(|(i, arr)| {
                                let (_, field) = schema.get_at_index(i).unwrap();
                                let s = unsafe {
                                    Series::_try_from_arrow_unchecked_with_md(
                                        field.name.clone(),
                                        vec![arr],
                                        field.dtype(),
                                        field.metadata.as_deref(),
                                    )
                                }
                                .map_err(PyPolarsErr::from)?
                                .into_column();
                                Ok(s)
                            })
                            .collect::<PyResult<Vec<_>>>()
                    })
                })
            } else {
                columns
                    .into_iter()
                    .enumerate()
                    .map(|(i, arr)| {
                        let (_, field) = schema.get_at_index(i).unwrap();
                        let s = unsafe {
                            Series::_try_from_arrow_unchecked_with_md(
                                field.name.clone(),
                                vec![arr],
                                field.dtype(),
                                field.metadata.as_deref(),
                            )
                        }
                        .map_err(PyPolarsErr::from)?
                        .into_column();
                        Ok(s)
                    })
                    .collect::<PyResult<Vec<_>>>()
            }?;

            Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) })
        })
        .collect::<PyResult<Vec<_>>>()?;

    Ok(accumulate_dataframes_vertical_unchecked(dfs))
}