polars_python/map/
dataframe.rs

1use polars::prelude::*;
2use polars_core::frame::row::{Row, rows_to_schema_first_non_null};
3use polars_core::series::SeriesIter;
4use pyo3::IntoPyObjectExt;
5use pyo3::exceptions::PyValueError;
6use pyo3::prelude::*;
7use pyo3::pybacked::PyBackedStr;
8use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
9
10use super::*;
11use crate::PyDataFrame;
12
13/// Create iterators for all the Series in the DataFrame.
14fn get_iters(df: &DataFrame) -> Vec<SeriesIter<'_>> {
15    df.get_columns()
16        .iter()
17        .map(|s| s.as_materialized_series().iter())
18        .collect()
19}
20
21/// Create iterators for all the Series in the DataFrame, skipping the first `n` rows.
22fn get_iters_skip(df: &DataFrame, n: usize) -> Vec<std::iter::Skip<SeriesIter<'_>>> {
23    df.get_columns()
24        .iter()
25        .map(|s| s.as_materialized_series().iter().skip(n))
26        .collect()
27}
28
29// the return type is Union[PySeries, PyDataFrame] and a boolean indicating if it is a dataframe or not
30pub fn apply_lambda_unknown<'py>(
31    df: &DataFrame,
32    py: Python<'py>,
33    lambda: Bound<'py, PyAny>,
34    inference_size: usize,
35) -> PyResult<(PyObject, bool)> {
36    let mut null_count = 0;
37    let mut iters = get_iters(df);
38
39    for _ in 0..df.height() {
40        let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
41        let arg = (PyTuple::new(py, iter)?,);
42        let out = lambda.call1(arg)?;
43
44        if out.is_none() {
45            null_count += 1;
46            continue;
47        } else if out.is_instance_of::<PyBool>() {
48            let first_value = out.extract::<bool>().ok();
49            return Ok((
50                PySeries::new(
51                    apply_lambda_with_bool_out_type(df, py, lambda, null_count, first_value)?
52                        .into_series(),
53                )
54                .into_py_any(py)?,
55                false,
56            ));
57        } else if out.is_instance_of::<PyFloat>() {
58            let first_value = out.extract::<f64>().ok();
59
60            return Ok((
61                PySeries::new(
62                    apply_lambda_with_primitive_out_type::<Float64Type>(
63                        df,
64                        py,
65                        lambda,
66                        null_count,
67                        first_value,
68                    )?
69                    .into_series(),
70                )
71                .into_py_any(py)?,
72                false,
73            ));
74        } else if out.is_instance_of::<PyInt>() {
75            let first_value = out.extract::<i64>().ok();
76            return Ok((
77                PySeries::new(
78                    apply_lambda_with_primitive_out_type::<Int64Type>(
79                        df,
80                        py,
81                        lambda,
82                        null_count,
83                        first_value,
84                    )?
85                    .into_series(),
86                )
87                .into_py_any(py)?,
88                false,
89            ));
90        } else if out.is_instance_of::<PyString>() {
91            let first_value = out.extract::<PyBackedStr>().ok();
92            return Ok((
93                PySeries::new(
94                    apply_lambda_with_string_out_type(df, py, lambda, null_count, first_value)?
95                        .into_series(),
96                )
97                .into_py_any(py)?,
98                false,
99            ));
100        } else if out.hasattr("_s")? {
101            let py_pyseries = out.getattr("_s").unwrap();
102            let series = py_pyseries
103                .extract::<PySeries>()
104                .unwrap()
105                .series
106                .into_inner();
107            let dt = series.dtype();
108            return Ok((
109                PySeries::new(
110                    apply_lambda_with_list_out_type(df, py, lambda, null_count, Some(&series), dt)?
111                        .into_series(),
112                )
113                .into_py_any(py)?,
114                false,
115            ));
116        } else if out.extract::<Wrap<Row<'static>>>().is_ok() {
117            let first_value = out.extract::<Wrap<Row<'static>>>().unwrap().0;
118            return Ok((
119                PyDataFrame::from(
120                    apply_lambda_with_rows_output(
121                        df,
122                        py,
123                        lambda,
124                        null_count,
125                        first_value,
126                        inference_size,
127                    )
128                    .map_err(PyPolarsErr::from)?,
129                )
130                .into_py_any(py)?,
131                true,
132            ));
133        } else if out.is_instance_of::<PyList>() || out.is_instance_of::<PyTuple>() {
134            return Err(PyPolarsErr::Other(
135                "A list output type is invalid. Do you mean to create polars List Series?\
136Then return a Series object."
137                    .into(),
138            )
139            .into());
140        } else {
141            return Err(PyPolarsErr::Other("Could not determine output type".into()).into());
142        }
143    }
144    Err(PyPolarsErr::Other("Could not determine output type".into()).into())
145}
146
147fn apply_iter<'py, T>(
148    df: &DataFrame,
149    py: Python<'py>,
150    lambda: Bound<'py, PyAny>,
151    init_null_count: usize,
152    skip: usize,
153) -> impl Iterator<Item = PyResult<Option<T>>>
154where
155    T: FromPyObject<'py>,
156{
157    let mut iters = get_iters_skip(df, init_null_count + skip);
158    ((init_null_count + skip)..df.height()).map(move |_| {
159        let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
160        let tpl = (PyTuple::new(py, iter).unwrap(),);
161        lambda.call1(tpl).map(|v| v.extract().ok())
162    })
163}
164
165/// Apply a lambda with a primitive output type
166pub fn apply_lambda_with_primitive_out_type<'py, D>(
167    df: &DataFrame,
168    py: Python<'py>,
169    lambda: Bound<'py, PyAny>,
170    init_null_count: usize,
171    first_value: Option<D::Native>,
172) -> PyResult<ChunkedArray<D>>
173where
174    D: PyPolarsNumericType,
175    D::Native: IntoPyObject<'py> + FromPyObject<'py>,
176{
177    let skip = usize::from(first_value.is_some());
178    if init_null_count == df.height() {
179        Ok(ChunkedArray::full_null(
180            PlSmallStr::from_static("map"),
181            df.height(),
182        ))
183    } else {
184        let iter = apply_iter(df, py, lambda, init_null_count, skip);
185        iterator_to_primitive(
186            iter,
187            init_null_count,
188            first_value,
189            PlSmallStr::from_static("map"),
190            df.height(),
191        )
192    }
193}
194
195/// Apply a lambda with a boolean output type
196pub fn apply_lambda_with_bool_out_type(
197    df: &DataFrame,
198    py: Python<'_>,
199    lambda: Bound<'_, PyAny>,
200    init_null_count: usize,
201    first_value: Option<bool>,
202) -> PyResult<ChunkedArray<BooleanType>> {
203    let skip = usize::from(first_value.is_some());
204    if init_null_count == df.height() {
205        Ok(ChunkedArray::full_null(
206            PlSmallStr::from_static("map"),
207            df.height(),
208        ))
209    } else {
210        let iter = apply_iter(df, py, lambda, init_null_count, skip);
211        iterator_to_bool(
212            iter,
213            init_null_count,
214            first_value,
215            PlSmallStr::from_static("map"),
216            df.height(),
217        )
218    }
219}
220
221/// Apply a lambda with string output type
222pub fn apply_lambda_with_string_out_type(
223    df: &DataFrame,
224    py: Python<'_>,
225    lambda: Bound<'_, PyAny>,
226    init_null_count: usize,
227    first_value: Option<PyBackedStr>,
228) -> PyResult<StringChunked> {
229    let skip = usize::from(first_value.is_some());
230    if init_null_count == df.height() {
231        Ok(ChunkedArray::full_null(
232            PlSmallStr::from_static("map"),
233            df.height(),
234        ))
235    } else {
236        let iter = apply_iter::<PyBackedStr>(df, py, lambda, init_null_count, skip);
237        iterator_to_string(
238            iter,
239            init_null_count,
240            first_value,
241            PlSmallStr::from_static("map"),
242            df.height(),
243        )
244    }
245}
246
247/// Apply a lambda with list output type
248pub fn apply_lambda_with_list_out_type(
249    df: &DataFrame,
250    py: Python<'_>,
251    lambda: Bound<'_, PyAny>,
252    init_null_count: usize,
253    first_value: Option<&Series>,
254    dt: &DataType,
255) -> PyResult<ListChunked> {
256    let skip = usize::from(first_value.is_some());
257    if init_null_count == df.height() {
258        Ok(ChunkedArray::full_null(
259            PlSmallStr::from_static("map"),
260            df.height(),
261        ))
262    } else {
263        let mut iters = get_iters_skip(df, init_null_count + skip);
264        let iter = ((init_null_count + skip)..df.height()).map(|_| {
265            let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
266            let tpl = (PyTuple::new(py, iter).unwrap(),);
267            let val = lambda.call1(tpl)?;
268            match val.getattr("_s") {
269                Ok(val) => val
270                    .extract::<PySeries>()
271                    .map(|s| Some(s.series.into_inner())),
272                Err(_) => {
273                    if val.is_none() {
274                        Ok(None)
275                    } else {
276                        Err(PyValueError::new_err(format!(
277                            "should return a Series, got a {val:?}"
278                        )))
279                    }
280                },
281            }
282        });
283        iterator_to_list(
284            dt,
285            iter,
286            init_null_count,
287            first_value,
288            PlSmallStr::from_static("map"),
289            df.height(),
290        )
291    }
292}
293
294pub fn apply_lambda_with_rows_output(
295    df: &DataFrame,
296    py: Python<'_>,
297    lambda: Bound<'_, PyAny>,
298    init_null_count: usize,
299    first_value: Row<'static>,
300    inference_size: usize,
301) -> PolarsResult<DataFrame> {
302    let width = first_value.0.len();
303    let null_row = Row::new(vec![AnyValue::Null; width]);
304
305    let mut row_buf = Row::default();
306
307    let skip = 1;
308    let mut iters = get_iters_skip(df, init_null_count + skip);
309    let mut row_iter = ((init_null_count + skip)..df.height()).map(|_| {
310        let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
311        let tpl = (PyTuple::new(py, iter).unwrap(),);
312
313        let return_val = lambda.call1(tpl) ?;
314        if return_val.is_none() {
315            Ok(&null_row)
316        } else {
317            let tuple = return_val.downcast::<PyTuple>().map_err(|_| polars_err!(ComputeError: format!("expected tuple, got {}", return_val.get_type().qualname().unwrap())))?;
318            row_buf.0.clear();
319            for v in tuple {
320                let v = v.extract::<Wrap<AnyValue>>().unwrap().0;
321                row_buf.0.push(v);
322            }
323            let ptr = &row_buf as *const Row;
324            // SAFETY:
325            // we know that row constructor of polars dataframe does not keep a reference
326            // to the row. Before we mutate the row buf again, the reference is dropped.
327            // we only cannot prove it to the compiler.
328            // we still to this because it save a Vec allocation in a hot loop.
329            Ok(unsafe { &*ptr })
330        }
331    });
332
333    // first rows for schema inference
334    let mut buf = Vec::with_capacity(inference_size);
335    buf.push(first_value);
336    for v in (&mut row_iter).take(inference_size) {
337        buf.push(v?.clone());
338    }
339
340    let schema = rows_to_schema_first_non_null(&buf, Some(50))?;
341
342    if init_null_count > 0 {
343        // SAFETY: we know the iterators size
344        let iter = unsafe {
345            (0..init_null_count)
346                .map(|_| Ok(&null_row))
347                .chain(buf.iter().map(Ok))
348                .chain(row_iter)
349                .trust_my_length(df.height())
350        };
351        DataFrame::try_from_rows_iter_and_schema(iter, &schema)
352    } else {
353        // SAFETY: we know the iterators size
354        let iter = unsafe {
355            buf.iter()
356                .map(Ok)
357                .chain(row_iter)
358                .trust_my_length(df.height())
359        };
360        DataFrame::try_from_rows_iter_and_schema(iter, &schema)
361    }
362}