Skip to main content

polars_python/dataframe/
map.rs

1use polars::frame::row::{Row, rows_to_schema_first_non_null};
2use polars_core::utils::CustomIterTools;
3use pyo3::IntoPyObjectExt;
4use pyo3::prelude::*;
5use pyo3::types::PyTuple;
6
7use super::*;
8use crate::error::PyPolarsErr;
9use crate::prelude::*;
10#[cfg(feature = "object")]
11use crate::series::construction::series_from_objects;
12use crate::{PySeries, raise_err};
13
14#[pymethods]
15impl PyDataFrame {
16    #[pyo3(signature = (lambda, output_type, inference_size))]
17    pub fn map_rows(
18        &self,
19        py: Python<'_>,
20        lambda: Bound<PyAny>,
21        output_type: Option<Wrap<DataType>>,
22        inference_size: usize,
23    ) -> PyResult<(Py<PyAny>, bool)> {
24        let df = self.df.read();
25        let height = df.height();
26        let col_series: Vec<_> = df
27            .columns()
28            .iter()
29            .map(|s| s.as_materialized_series().clone())
30            .collect();
31        let mut iters: Vec<_> = col_series.iter().map(|c| c.iter()).collect();
32        drop(df); // Release lock before calling lambda.
33
34        let lambda_result_iter = (0..height).map(move |_| {
35            let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
36            let tpl = (PyTuple::new(py, iter).unwrap(),);
37            lambda.call1(tpl)
38        });
39
40        // Simple case: return type set.
41        if let Some(output_type) = &output_type {
42            // If the output type is Object we should not go through AnyValue.
43            #[cfg(feature = "object")]
44            if let DataType::Object(_) = output_type.0 {
45                let objects = lambda_result_iter
46                    .map(|res| {
47                        Ok(ObjectValue {
48                            inner: res?.unbind(),
49                        })
50                    })
51                    .collect::<PyResult<Vec<_>>>()?;
52                let s = series_from_objects(py, PlSmallStr::from_static("map"), objects);
53                return Ok((PySeries::from(s).into_py_any(py)?, false));
54            }
55
56            let avs = lambda_result_iter
57                .map(|res| res?.extract::<Wrap<AnyValue>>().map(|w| w.0))
58                .collect::<PyResult<Vec<AnyValue>>>()?;
59            let s = Series::from_any_values_and_dtype(
60                PlSmallStr::from_static("map"),
61                &avs,
62                &output_type.0,
63                true,
64            )
65            .map_err(PyPolarsErr::from)?;
66            return Ok((PySeries::from(s).into_py_any(py)?, false));
67        }
68
69        // Disambiguate returning a DataFrame vs Series by checking the
70        // first non-null output value.
71        let mut peek_iter = lambda_result_iter.peekable();
72        let mut null_count = 0;
73        while let Some(ret) = peek_iter.peek() {
74            if let Ok(v) = ret
75                && v.is_none()
76            {
77                null_count += 1;
78                peek_iter.next();
79            } else {
80                break;
81            }
82        }
83
84        let first_val = match peek_iter.peek() {
85            Some(Ok(v)) => v,
86            Some(Err(e)) => return Err(e.clone_ref(py)),
87            None => {
88                let msg = "The output type of the 'map_rows' function cannot be determined.\n\
89                All returned values are None, consider setting the 'return_dtype'.";
90                raise_err!(msg, ComputeError)
91            },
92        };
93
94        if let Ok(first_row) = first_val.cast::<PyTuple>() {
95            let width = first_row.len();
96            let out_df = collect_lambda_ret_with_rows_output(
97                height,
98                width,
99                null_count,
100                inference_size,
101                peek_iter,
102            )
103            .map_err(PyPolarsErr::from)?;
104            Ok((PyDataFrame::from(out_df).into_py_any(py)?, true))
105        } else {
106            let avs = peek_iter
107                .map(|res| res?.extract::<Wrap<AnyValue>>().map(|w| w.0))
108                .collect::<PyResult<Vec<AnyValue>>>()?;
109            let s = Series::from_any_values(PlSmallStr::from_static("map"), &avs, true)
110                .map_err(PyPolarsErr::from)?;
111
112            let out = if null_count > 0 {
113                let mut tmp = Series::full_null(s.name().clone(), null_count, s.dtype());
114                tmp.append_owned(s).map_err(PyPolarsErr::from)?;
115                tmp
116            } else {
117                s
118            };
119            Ok((PySeries::from(out).into_py_any(py)?, false))
120        }
121    }
122}
123
124fn collect_lambda_ret_with_rows_output<'py>(
125    height: usize,
126    width: usize,
127    init_null_count: usize,
128    inference_size: usize,
129    ret_iter: impl Iterator<Item = PyResult<Bound<'py, PyAny>>>,
130) -> PolarsResult<DataFrame> {
131    let null_row = Row::new(vec![AnyValue::Null; width]);
132
133    let mut row_buf = Row::default();
134    let mut row_iter = ret_iter.map(|retval| {
135        let retval = retval?;
136        if retval.is_none() {
137            Ok(&null_row)
138        } else {
139            let tuple = retval.cast::<PyTuple>().map_err(|_| polars_err!(ComputeError: format!("expected tuple, got {}", retval.get_type().qualname().unwrap())))?;
140            row_buf.0.clear();
141            for v in tuple {
142                let v = v.extract::<Wrap<AnyValue>>().unwrap().0;
143                row_buf.0.push(v);
144            }
145            let ptr = &row_buf as *const Row;
146            // SAFETY:
147            // we know that row constructor of polars dataframe does not keep a reference
148            // to the row. Before we mutate the row buf again, the reference is dropped.
149            // we only cannot prove it to the compiler.
150            // we still to this because it save a Vec allocation in a hot loop.
151            Ok(unsafe { &*ptr })
152        }
153    });
154
155    // First rows for schema inference.
156    let mut buf = Vec::with_capacity(inference_size);
157    for v in (&mut row_iter).take(inference_size) {
158        buf.push(v?.clone());
159    }
160
161    let schema = rows_to_schema_first_non_null(&buf, Some(50))?;
162
163    if init_null_count > 0 {
164        // SAFETY: we know the iterators size.
165        let iter = unsafe {
166            (0..init_null_count)
167                .map(|_| Ok(&null_row))
168                .chain(buf.iter().map(Ok))
169                .chain(row_iter)
170                .trust_my_length(height)
171        };
172        DataFrame::try_from_rows_iter_and_schema(iter, &schema)
173    } else {
174        // SAFETY: we know the iterators size.
175        let iter = unsafe { buf.iter().map(Ok).chain(row_iter).trust_my_length(height) };
176        DataFrame::try_from_rows_iter_and_schema(iter, &schema)
177    }
178}