polars_python/map/
dataframe.rs

1use polars::prelude::*;
2use polars_core::frame::row::{Row, rows_to_schema_first_non_null};
3use polars_core::series::SeriesIter;
4use pyo3::IntoPyObjectExt;
5use pyo3::exceptions::PyValueError;
6use pyo3::prelude::*;
7use pyo3::pybacked::PyBackedStr;
8use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
9
10use super::*;
11use crate::PyDataFrame;
12
13/// Create iterators for all the Series in the DataFrame.
14fn get_iters(df: &DataFrame) -> Vec<SeriesIter> {
15    df.get_columns()
16        .iter()
17        .map(|s| s.as_materialized_series().iter())
18        .collect()
19}
20
21/// Create iterators for all the Series in the DataFrame, skipping the first `n` rows.
22fn get_iters_skip(df: &DataFrame, n: usize) -> Vec<std::iter::Skip<SeriesIter>> {
23    df.get_columns()
24        .iter()
25        .map(|s| s.as_materialized_series().iter().skip(n))
26        .collect()
27}
28
29// the return type is Union[PySeries, PyDataFrame] and a boolean indicating if it is a dataframe or not
30pub fn apply_lambda_unknown<'py>(
31    df: &DataFrame,
32    py: Python<'py>,
33    lambda: Bound<'py, PyAny>,
34    inference_size: usize,
35) -> PyResult<(PyObject, bool)> {
36    let mut null_count = 0;
37    let mut iters = get_iters(df);
38
39    for _ in 0..df.height() {
40        let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
41        let arg = (PyTuple::new(py, iter)?,);
42        let out = lambda.call1(arg)?;
43
44        if out.is_none() {
45            null_count += 1;
46            continue;
47        } else if out.is_instance_of::<PyBool>() {
48            let first_value = out.extract::<bool>().ok();
49            return Ok((
50                PySeries::new(
51                    apply_lambda_with_bool_out_type(df, py, lambda, null_count, first_value)?
52                        .into_series(),
53                )
54                .into_py_any(py)?,
55                false,
56            ));
57        } else if out.is_instance_of::<PyFloat>() {
58            let first_value = out.extract::<f64>().ok();
59
60            return Ok((
61                PySeries::new(
62                    apply_lambda_with_primitive_out_type::<Float64Type>(
63                        df,
64                        py,
65                        lambda,
66                        null_count,
67                        first_value,
68                    )?
69                    .into_series(),
70                )
71                .into_py_any(py)?,
72                false,
73            ));
74        } else if out.is_instance_of::<PyInt>() {
75            let first_value = out.extract::<i64>().ok();
76            return Ok((
77                PySeries::new(
78                    apply_lambda_with_primitive_out_type::<Int64Type>(
79                        df,
80                        py,
81                        lambda,
82                        null_count,
83                        first_value,
84                    )?
85                    .into_series(),
86                )
87                .into_py_any(py)?,
88                false,
89            ));
90        } else if out.is_instance_of::<PyString>() {
91            let first_value = out.extract::<PyBackedStr>().ok();
92            return Ok((
93                PySeries::new(
94                    apply_lambda_with_string_out_type(df, py, lambda, null_count, first_value)?
95                        .into_series(),
96                )
97                .into_py_any(py)?,
98                false,
99            ));
100        } else if out.hasattr("_s")? {
101            let py_pyseries = out.getattr("_s").unwrap();
102            let series = py_pyseries.extract::<PySeries>().unwrap().series;
103            let dt = series.dtype();
104            return Ok((
105                PySeries::new(
106                    apply_lambda_with_list_out_type(df, py, lambda, null_count, Some(&series), dt)?
107                        .into_series(),
108                )
109                .into_py_any(py)?,
110                false,
111            ));
112        } else if out.extract::<Wrap<Row<'static>>>().is_ok() {
113            let first_value = out.extract::<Wrap<Row<'static>>>().unwrap().0;
114            return Ok((
115                PyDataFrame::from(
116                    apply_lambda_with_rows_output(
117                        df,
118                        py,
119                        lambda,
120                        null_count,
121                        first_value,
122                        inference_size,
123                    )
124                    .map_err(PyPolarsErr::from)?,
125                )
126                .into_py_any(py)?,
127                true,
128            ));
129        } else if out.is_instance_of::<PyList>() || out.is_instance_of::<PyTuple>() {
130            return Err(PyPolarsErr::Other(
131                "A list output type is invalid. Do you mean to create polars List Series?\
132Then return a Series object."
133                    .into(),
134            )
135            .into());
136        } else {
137            return Err(PyPolarsErr::Other("Could not determine output type".into()).into());
138        }
139    }
140    Err(PyPolarsErr::Other("Could not determine output type".into()).into())
141}
142
143fn apply_iter<'py, T>(
144    df: &DataFrame,
145    py: Python<'py>,
146    lambda: Bound<'py, PyAny>,
147    init_null_count: usize,
148    skip: usize,
149) -> impl Iterator<Item = PyResult<Option<T>>>
150where
151    T: FromPyObject<'py>,
152{
153    let mut iters = get_iters_skip(df, init_null_count + skip);
154    ((init_null_count + skip)..df.height()).map(move |_| {
155        let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
156        let tpl = (PyTuple::new(py, iter).unwrap(),);
157        lambda.call1(tpl).map(|v| v.extract().ok())
158    })
159}
160
161/// Apply a lambda with a primitive output type
162pub fn apply_lambda_with_primitive_out_type<'py, D>(
163    df: &DataFrame,
164    py: Python<'py>,
165    lambda: Bound<'py, PyAny>,
166    init_null_count: usize,
167    first_value: Option<D::Native>,
168) -> PyResult<ChunkedArray<D>>
169where
170    D: PyPolarsNumericType,
171    D::Native: IntoPyObject<'py> + FromPyObject<'py>,
172{
173    let skip = usize::from(first_value.is_some());
174    if init_null_count == df.height() {
175        Ok(ChunkedArray::full_null(
176            PlSmallStr::from_static("map"),
177            df.height(),
178        ))
179    } else {
180        let iter = apply_iter(df, py, lambda, init_null_count, skip);
181        iterator_to_primitive(
182            iter,
183            init_null_count,
184            first_value,
185            PlSmallStr::from_static("map"),
186            df.height(),
187        )
188    }
189}
190
191/// Apply a lambda with a boolean output type
192pub fn apply_lambda_with_bool_out_type(
193    df: &DataFrame,
194    py: Python<'_>,
195    lambda: Bound<'_, PyAny>,
196    init_null_count: usize,
197    first_value: Option<bool>,
198) -> PyResult<ChunkedArray<BooleanType>> {
199    let skip = usize::from(first_value.is_some());
200    if init_null_count == df.height() {
201        Ok(ChunkedArray::full_null(
202            PlSmallStr::from_static("map"),
203            df.height(),
204        ))
205    } else {
206        let iter = apply_iter(df, py, lambda, init_null_count, skip);
207        iterator_to_bool(
208            iter,
209            init_null_count,
210            first_value,
211            PlSmallStr::from_static("map"),
212            df.height(),
213        )
214    }
215}
216
217/// Apply a lambda with string output type
218pub fn apply_lambda_with_string_out_type(
219    df: &DataFrame,
220    py: Python<'_>,
221    lambda: Bound<'_, PyAny>,
222    init_null_count: usize,
223    first_value: Option<PyBackedStr>,
224) -> PyResult<StringChunked> {
225    let skip = usize::from(first_value.is_some());
226    if init_null_count == df.height() {
227        Ok(ChunkedArray::full_null(
228            PlSmallStr::from_static("map"),
229            df.height(),
230        ))
231    } else {
232        let iter = apply_iter::<PyBackedStr>(df, py, lambda, init_null_count, skip);
233        iterator_to_string(
234            iter,
235            init_null_count,
236            first_value,
237            PlSmallStr::from_static("map"),
238            df.height(),
239        )
240    }
241}
242
243/// Apply a lambda with list output type
244pub fn apply_lambda_with_list_out_type(
245    df: &DataFrame,
246    py: Python<'_>,
247    lambda: Bound<'_, PyAny>,
248    init_null_count: usize,
249    first_value: Option<&Series>,
250    dt: &DataType,
251) -> PyResult<ListChunked> {
252    let skip = usize::from(first_value.is_some());
253    if init_null_count == df.height() {
254        Ok(ChunkedArray::full_null(
255            PlSmallStr::from_static("map"),
256            df.height(),
257        ))
258    } else {
259        let mut iters = get_iters_skip(df, init_null_count + skip);
260        let iter = ((init_null_count + skip)..df.height()).map(|_| {
261            let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
262            let tpl = (PyTuple::new(py, iter).unwrap(),);
263            let val = lambda.call1(tpl)?;
264            match val.getattr("_s") {
265                Ok(val) => val.extract::<PySeries>().map(|s| Some(s.series)),
266                Err(_) => {
267                    if val.is_none() {
268                        Ok(None)
269                    } else {
270                        Err(PyValueError::new_err(format!(
271                            "should return a Series, got a {val:?}"
272                        )))
273                    }
274                },
275            }
276        });
277        iterator_to_list(
278            dt,
279            iter,
280            init_null_count,
281            first_value,
282            PlSmallStr::from_static("map"),
283            df.height(),
284        )
285    }
286}
287
288pub fn apply_lambda_with_rows_output(
289    df: &DataFrame,
290    py: Python<'_>,
291    lambda: Bound<'_, PyAny>,
292    init_null_count: usize,
293    first_value: Row<'static>,
294    inference_size: usize,
295) -> PolarsResult<DataFrame> {
296    let width = first_value.0.len();
297    let null_row = Row::new(vec![AnyValue::Null; width]);
298
299    let mut row_buf = Row::default();
300
301    let skip = 1;
302    let mut iters = get_iters_skip(df, init_null_count + skip);
303    let mut row_iter = ((init_null_count + skip)..df.height()).map(|_| {
304        let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
305        let tpl = (PyTuple::new(py, iter).unwrap(),);
306
307        let return_val = lambda.call1(tpl) ?;
308        if return_val.is_none() {
309            Ok(&null_row)
310        } else {
311            let tuple = return_val.downcast::<PyTuple>().map_err(|_| polars_err!(ComputeError: format!("expected tuple, got {}", return_val.get_type().qualname().unwrap())))?;
312            row_buf.0.clear();
313            for v in tuple {
314                let v = v.extract::<Wrap<AnyValue>>().unwrap().0;
315                row_buf.0.push(v);
316            }
317            let ptr = &row_buf as *const Row;
318            // SAFETY:
319            // we know that row constructor of polars dataframe does not keep a reference
320            // to the row. Before we mutate the row buf again, the reference is dropped.
321            // we only cannot prove it to the compiler.
322            // we still to this because it save a Vec allocation in a hot loop.
323            Ok(unsafe { &*ptr })
324        }
325    });
326
327    // first rows for schema inference
328    let mut buf = Vec::with_capacity(inference_size);
329    buf.push(first_value);
330    for v in (&mut row_iter).take(inference_size) {
331        buf.push(v?.clone());
332    }
333
334    let schema = rows_to_schema_first_non_null(&buf, Some(50))?;
335
336    if init_null_count > 0 {
337        // SAFETY: we know the iterators size
338        let iter = unsafe {
339            (0..init_null_count)
340                .map(|_| Ok(&null_row))
341                .chain(buf.iter().map(Ok))
342                .chain(row_iter)
343                .trust_my_length(df.height())
344        };
345        DataFrame::try_from_rows_iter_and_schema(iter, &schema)
346    } else {
347        // SAFETY: we know the iterators size
348        let iter = unsafe {
349            buf.iter()
350                .map(Ok)
351                .chain(row_iter)
352                .trust_my_length(df.height())
353        };
354        DataFrame::try_from_rows_iter_and_schema(iter, &schema)
355    }
356}