polars_python/map/
mod.rs

1pub mod dataframe;
2pub mod lazy;
3pub mod series;
4
5use std::collections::BTreeMap;
6
7use arrow::bitmap::BitmapBuilder;
8use polars::chunked_array::builder::get_list_builder;
9use polars::prelude::*;
10use polars_core::POOL;
11use polars_core::utils::CustomIterTools;
12use polars_utils::pl_str::PlSmallStr;
13use pyo3::prelude::*;
14use pyo3::pybacked::PyBackedStr;
15use pyo3::types::PyDict;
16use rayon::prelude::*;
17
18use crate::error::PyPolarsErr;
19use crate::prelude::ObjectValue;
20use crate::utils::EnterPolarsExt;
21use crate::{PySeries, Wrap};
22
23pub trait PyPolarsNumericType: PolarsNumericType {}
24
25impl PyPolarsNumericType for UInt8Type {}
26impl PyPolarsNumericType for UInt16Type {}
27impl PyPolarsNumericType for UInt32Type {}
28impl PyPolarsNumericType for UInt64Type {}
29impl PyPolarsNumericType for Int8Type {}
30impl PyPolarsNumericType for Int16Type {}
31impl PyPolarsNumericType for Int32Type {}
32impl PyPolarsNumericType for Int64Type {}
33impl PyPolarsNumericType for Int128Type {}
34impl PyPolarsNumericType for Float32Type {}
35impl PyPolarsNumericType for Float64Type {}
36
37fn iterator_to_struct<'py>(
38    py: Python<'py>,
39    it: impl Iterator<Item = PyResult<Option<Bound<'py, PyAny>>>>,
40    init_null_count: usize,
41    first_value: AnyValue<'static>,
42    name: PlSmallStr,
43    capacity: usize,
44) -> PyResult<PySeries> {
45    let (vals, flds) = match &first_value {
46        av @ AnyValue::Struct(_, _, flds) => (av._iter_struct_av().collect::<Vec<_>>(), &**flds),
47        AnyValue::StructOwned(payload) => (payload.0.clone(), &*payload.1),
48        _ => {
49            return Err(crate::exceptions::ComputeError::new_err(format!(
50                "expected struct got {first_value:?}",
51            )));
52        },
53    };
54
55    // Every item in the struct is kept as its own buffer of AnyValues.
56    // So a struct with 2 items: {a, b} will have:
57    // [
58    //      [ a values ]
59    //      [ b values ]
60    // ]
61    let mut struct_fields: BTreeMap<PlSmallStr, Vec<AnyValue>> = BTreeMap::new();
62
63    // As a BTreeMap sorts its keys, we also need to track the original
64    // order of the field names.
65    let mut field_names_ordered: Vec<PlSmallStr> = Vec::with_capacity(flds.len());
66
67    // Use the first value and the known null count to initialize the buffers
68    // if we find a new key later on, we make a new entry in the BTree.
69    for (value, fld) in vals.into_iter().zip(flds) {
70        let mut buf = Vec::with_capacity(capacity);
71        buf.extend((0..init_null_count).map(|_| AnyValue::Null));
72        buf.push(value);
73        field_names_ordered.push(fld.name().clone());
74        struct_fields.insert(fld.name().clone(), buf);
75    }
76
77    let mut validity = BitmapBuilder::with_capacity(capacity);
78    validity.extend_constant(init_null_count, false);
79    validity.push(true);
80
81    for dict in it {
82        match dict? {
83            None => {
84                validity.push(false);
85                for field_items in struct_fields.values_mut() {
86                    field_items.push(AnyValue::Null);
87                }
88            },
89            Some(dict) => {
90                validity.push(true);
91                let dict = dict.downcast::<PyDict>()?;
92                let current_len = struct_fields
93                    .values()
94                    .next()
95                    .map(|buf| buf.len())
96                    .unwrap_or(0);
97
98                // We ignore the keys of the rest of the dicts,
99                // the first item determines the output name.
100                for (key, val) in dict.iter() {
101                    let key = key.str().unwrap().extract::<PyBackedStr>().unwrap();
102                    let item = val.extract::<Wrap<AnyValue>>()?;
103                    if let Some(buf) = struct_fields.get_mut(&*key) {
104                        buf.push(item.0);
105                    } else {
106                        let mut buf = Vec::with_capacity(capacity);
107                        buf.extend((0..init_null_count + current_len).map(|_| AnyValue::Null));
108                        buf.push(item.0);
109                        let key: PlSmallStr = (&*key).into();
110                        field_names_ordered.push(key.clone());
111                        struct_fields.insert(key, buf);
112                    };
113                }
114
115                // Add nulls to keys that were not in the dict.
116                if dict.len() < struct_fields.len() {
117                    let current_len = current_len + 1;
118                    for buf in struct_fields.values_mut() {
119                        if buf.len() < current_len {
120                            buf.push(AnyValue::Null)
121                        }
122                    }
123                }
124            },
125        }
126    }
127
128    let fields = py.enter_polars_ok(|| {
129        POOL.install(|| {
130            field_names_ordered
131                .par_iter()
132                .map(|name| Series::new(name.clone(), struct_fields.get(name).unwrap()))
133                .collect::<Vec<_>>()
134        })
135    })?;
136
137    Ok(
138        StructChunked::from_series(name, fields[0].len(), fields.iter())
139            .unwrap()
140            .with_outer_validity(validity.into_opt_validity())
141            .into_series()
142            .into(),
143    )
144}
145
146fn iterator_to_primitive<T>(
147    it: impl Iterator<Item = PyResult<Option<T::Native>>>,
148    init_null_count: usize,
149    first_value: Option<T::Native>,
150    name: PlSmallStr,
151    capacity: usize,
152) -> PyResult<ChunkedArray<T>>
153where
154    T: PyPolarsNumericType,
155{
156    let mut error = None;
157    // SAFETY: we know the iterators len.
158    let ca: ChunkedArray<T> = unsafe {
159        if init_null_count > 0 {
160            (0..init_null_count)
161                .map(|_| Ok(None))
162                .chain(std::iter::once(Ok(first_value)))
163                .chain(it)
164                .trust_my_length(capacity)
165                .map(|v| catch_err(&mut error, v))
166                .collect_trusted()
167        } else if first_value.is_some() {
168            std::iter::once(Ok(first_value))
169                .chain(it)
170                .trust_my_length(capacity)
171                .map(|v| catch_err(&mut error, v))
172                .collect_trusted()
173        } else {
174            it.map(|v| catch_err(&mut error, v)).collect()
175        }
176    };
177    debug_assert_eq!(ca.len(), capacity);
178
179    if let Some(err) = error {
180        let _ = err?;
181    }
182    Ok(ca.with_name(name))
183}
184
185fn iterator_to_bool(
186    it: impl Iterator<Item = PyResult<Option<bool>>>,
187    init_null_count: usize,
188    first_value: Option<bool>,
189    name: PlSmallStr,
190    capacity: usize,
191) -> PyResult<ChunkedArray<BooleanType>> {
192    let mut error = None;
193    // SAFETY: we know the iterators len.
194    let ca: BooleanChunked = unsafe {
195        if init_null_count > 0 {
196            (0..init_null_count)
197                .map(|_| Ok(None))
198                .chain(std::iter::once(Ok(first_value)))
199                .chain(it)
200                .trust_my_length(capacity)
201                .map(|v| catch_err(&mut error, v))
202                .collect_trusted()
203        } else if first_value.is_some() {
204            std::iter::once(Ok(first_value))
205                .chain(it)
206                .trust_my_length(capacity)
207                .map(|v| catch_err(&mut error, v))
208                .collect_trusted()
209        } else {
210            it.map(|v| catch_err(&mut error, v)).collect()
211        }
212    };
213    if let Some(err) = error {
214        let _ = err?;
215    }
216    debug_assert_eq!(ca.len(), capacity);
217    Ok(ca.with_name(name))
218}
219
/// Collects an `ObjectChunked<ObjectValue>` named `name`.
///
/// The array is made of `init_null_count` leading nulls, then `first_value`,
/// then every item of `it`. When there are no leading nulls and `first_value`
/// is `None`, only the items of `it` are collected.
///
/// Items of `it` that are `Err` are stored as null while collecting; the first
/// such error is returned once collection has finished.
#[cfg(feature = "object")]
fn iterator_to_object(
    it: impl Iterator<Item = PyResult<Option<ObjectValue>>>,
    init_null_count: usize,
    first_value: Option<ObjectValue>,
    name: PlSmallStr,
    capacity: usize,
) -> PyResult<ObjectChunked<ObjectValue>> {
    let mut first_failure = None;

    let collected: ObjectChunked<ObjectValue> = match (init_null_count, first_value) {
        // Nothing to prepend: collect the iterator as-is.
        (0, None) => it.map(|item| catch_err(&mut first_failure, item)).collect(),
        // No leading nulls, but a first value to prepend.
        // SAFETY: relies on `capacity` being the exact number of items yielded
        // by the chained iterator.
        (0, head) => unsafe {
            std::iter::once(Ok(head))
                .chain(it)
                .map(|item| catch_err(&mut first_failure, item))
                .trust_my_length(capacity)
                .collect_trusted()
        },
        // Leading nulls, then the first value, then the rest.
        // SAFETY: relies on `capacity` being the exact number of items yielded
        // by the chained iterator.
        (leading_nulls, head) => unsafe {
            (0..leading_nulls)
                .map(|_| Ok(None))
                .chain(std::iter::once(Ok(head)))
                .chain(it)
                .map(|item| catch_err(&mut first_failure, item))
                .trust_my_length(capacity)
                .collect_trusted()
        },
    };

    // The error is reported before the length is checked.
    if let Some(Err(py_err)) = first_failure {
        return Err(py_err);
    }
    debug_assert_eq!(collected.len(), capacity);
    Ok(collected.with_name(name))
}
255
256fn catch_err<K>(error: &mut Option<PyResult<Option<K>>>, result: PyResult<Option<K>>) -> Option<K> {
257    match result {
258        Ok(item) => item,
259        err => {
260            if error.is_none() {
261                *error = Some(err);
262            }
263            None
264        },
265    }
266}
267
268fn iterator_to_string<S: AsRef<str>>(
269    it: impl Iterator<Item = PyResult<Option<S>>>,
270    init_null_count: usize,
271    first_value: Option<S>,
272    name: PlSmallStr,
273    capacity: usize,
274) -> PyResult<StringChunked> {
275    let mut error = None;
276    // SAFETY: we know the iterators len.
277    let ca: StringChunked = unsafe {
278        if init_null_count > 0 {
279            (0..init_null_count)
280                .map(|_| Ok(None))
281                .chain(std::iter::once(Ok(first_value)))
282                .trust_my_length(capacity)
283                .map(|v| catch_err(&mut error, v))
284                .collect_trusted()
285        } else if first_value.is_some() {
286            std::iter::once(Ok(first_value))
287                .chain(it)
288                .trust_my_length(capacity)
289                .map(|v| catch_err(&mut error, v))
290                .collect_trusted()
291        } else {
292            it.map(|v| catch_err(&mut error, v)).collect()
293        }
294    };
295    debug_assert_eq!(ca.len(), capacity);
296    if let Some(err) = error {
297        let _ = err?;
298    }
299    Ok(ca.with_name(name))
300}
301
302fn iterator_to_list(
303    dt: &DataType,
304    it: impl Iterator<Item = PyResult<Option<Series>>>,
305    init_null_count: usize,
306    first_value: Option<&Series>,
307    name: PlSmallStr,
308    capacity: usize,
309) -> PyResult<ListChunked> {
310    let mut builder = get_list_builder(dt, capacity * 5, capacity, name);
311    for _ in 0..init_null_count {
312        builder.append_null()
313    }
314    if first_value.is_some() {
315        builder
316            .append_opt_series(first_value)
317            .map_err(PyPolarsErr::from)?;
318    }
319    for opt_val in it {
320        match opt_val? {
321            None => builder.append_null(),
322            Some(s) => {
323                if s.is_empty() && s.dtype() != dt {
324                    builder
325                        .append_series(&Series::full_null(PlSmallStr::EMPTY, 0, dt))
326                        .unwrap()
327                } else {
328                    builder.append_series(&s).map_err(PyPolarsErr::from)?
329                }
330            },
331        }
332    }
333    Ok(builder.finish())
334}