1use polars::prelude::*;
2use polars_core::frame::row::{Row, rows_to_schema_first_non_null};
3use polars_core::series::SeriesIter;
4use pyo3::IntoPyObjectExt;
5use pyo3::exceptions::PyValueError;
6use pyo3::prelude::*;
7use pyo3::pybacked::PyBackedStr;
8use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
9
10use super::*;
11use crate::PyDataFrame;
12
13fn get_iters(df: &DataFrame) -> Vec<SeriesIter> {
15 df.get_columns()
16 .iter()
17 .map(|s| s.as_materialized_series().iter())
18 .collect()
19}
20
21fn get_iters_skip(df: &DataFrame, n: usize) -> Vec<std::iter::Skip<SeriesIter>> {
23 df.get_columns()
24 .iter()
25 .map(|s| s.as_materialized_series().iter().skip(n))
26 .collect()
27}
28
29pub fn apply_lambda_unknown<'py>(
31 df: &DataFrame,
32 py: Python<'py>,
33 lambda: Bound<'py, PyAny>,
34 inference_size: usize,
35) -> PyResult<(PyObject, bool)> {
36 let mut null_count = 0;
37 let mut iters = get_iters(df);
38
39 for _ in 0..df.height() {
40 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
41 let arg = (PyTuple::new(py, iter)?,);
42 let out = lambda.call1(arg)?;
43
44 if out.is_none() {
45 null_count += 1;
46 continue;
47 } else if out.is_instance_of::<PyBool>() {
48 let first_value = out.extract::<bool>().ok();
49 return Ok((
50 PySeries::new(
51 apply_lambda_with_bool_out_type(df, py, lambda, null_count, first_value)?
52 .into_series(),
53 )
54 .into_py_any(py)?,
55 false,
56 ));
57 } else if out.is_instance_of::<PyFloat>() {
58 let first_value = out.extract::<f64>().ok();
59
60 return Ok((
61 PySeries::new(
62 apply_lambda_with_primitive_out_type::<Float64Type>(
63 df,
64 py,
65 lambda,
66 null_count,
67 first_value,
68 )?
69 .into_series(),
70 )
71 .into_py_any(py)?,
72 false,
73 ));
74 } else if out.is_instance_of::<PyInt>() {
75 let first_value = out.extract::<i64>().ok();
76 return Ok((
77 PySeries::new(
78 apply_lambda_with_primitive_out_type::<Int64Type>(
79 df,
80 py,
81 lambda,
82 null_count,
83 first_value,
84 )?
85 .into_series(),
86 )
87 .into_py_any(py)?,
88 false,
89 ));
90 } else if out.is_instance_of::<PyString>() {
91 let first_value = out.extract::<PyBackedStr>().ok();
92 return Ok((
93 PySeries::new(
94 apply_lambda_with_string_out_type(df, py, lambda, null_count, first_value)?
95 .into_series(),
96 )
97 .into_py_any(py)?,
98 false,
99 ));
100 } else if out.hasattr("_s")? {
101 let py_pyseries = out.getattr("_s").unwrap();
102 let series = py_pyseries.extract::<PySeries>().unwrap().series;
103 let dt = series.dtype();
104 return Ok((
105 PySeries::new(
106 apply_lambda_with_list_out_type(df, py, lambda, null_count, Some(&series), dt)?
107 .into_series(),
108 )
109 .into_py_any(py)?,
110 false,
111 ));
112 } else if out.extract::<Wrap<Row<'static>>>().is_ok() {
113 let first_value = out.extract::<Wrap<Row<'static>>>().unwrap().0;
114 return Ok((
115 PyDataFrame::from(
116 apply_lambda_with_rows_output(
117 df,
118 py,
119 lambda,
120 null_count,
121 first_value,
122 inference_size,
123 )
124 .map_err(PyPolarsErr::from)?,
125 )
126 .into_py_any(py)?,
127 true,
128 ));
129 } else if out.is_instance_of::<PyList>() || out.is_instance_of::<PyTuple>() {
130 return Err(PyPolarsErr::Other(
131 "A list output type is invalid. Do you mean to create polars List Series?\
132Then return a Series object."
133 .into(),
134 )
135 .into());
136 } else {
137 return Err(PyPolarsErr::Other("Could not determine output type".into()).into());
138 }
139 }
140 Err(PyPolarsErr::Other("Could not determine output type".into()).into())
141}
142
143fn apply_iter<'py, T>(
144 df: &DataFrame,
145 py: Python<'py>,
146 lambda: Bound<'py, PyAny>,
147 init_null_count: usize,
148 skip: usize,
149) -> impl Iterator<Item = PyResult<Option<T>>>
150where
151 T: FromPyObject<'py>,
152{
153 let mut iters = get_iters_skip(df, init_null_count + skip);
154 ((init_null_count + skip)..df.height()).map(move |_| {
155 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
156 let tpl = (PyTuple::new(py, iter).unwrap(),);
157 lambda.call1(tpl).map(|v| v.extract().ok())
158 })
159}
160
161pub fn apply_lambda_with_primitive_out_type<'py, D>(
163 df: &DataFrame,
164 py: Python<'py>,
165 lambda: Bound<'py, PyAny>,
166 init_null_count: usize,
167 first_value: Option<D::Native>,
168) -> PyResult<ChunkedArray<D>>
169where
170 D: PyPolarsNumericType,
171 D::Native: IntoPyObject<'py> + FromPyObject<'py>,
172{
173 let skip = usize::from(first_value.is_some());
174 if init_null_count == df.height() {
175 Ok(ChunkedArray::full_null(
176 PlSmallStr::from_static("map"),
177 df.height(),
178 ))
179 } else {
180 let iter = apply_iter(df, py, lambda, init_null_count, skip);
181 iterator_to_primitive(
182 iter,
183 init_null_count,
184 first_value,
185 PlSmallStr::from_static("map"),
186 df.height(),
187 )
188 }
189}
190
191pub fn apply_lambda_with_bool_out_type(
193 df: &DataFrame,
194 py: Python<'_>,
195 lambda: Bound<'_, PyAny>,
196 init_null_count: usize,
197 first_value: Option<bool>,
198) -> PyResult<ChunkedArray<BooleanType>> {
199 let skip = usize::from(first_value.is_some());
200 if init_null_count == df.height() {
201 Ok(ChunkedArray::full_null(
202 PlSmallStr::from_static("map"),
203 df.height(),
204 ))
205 } else {
206 let iter = apply_iter(df, py, lambda, init_null_count, skip);
207 iterator_to_bool(
208 iter,
209 init_null_count,
210 first_value,
211 PlSmallStr::from_static("map"),
212 df.height(),
213 )
214 }
215}
216
217pub fn apply_lambda_with_string_out_type(
219 df: &DataFrame,
220 py: Python<'_>,
221 lambda: Bound<'_, PyAny>,
222 init_null_count: usize,
223 first_value: Option<PyBackedStr>,
224) -> PyResult<StringChunked> {
225 let skip = usize::from(first_value.is_some());
226 if init_null_count == df.height() {
227 Ok(ChunkedArray::full_null(
228 PlSmallStr::from_static("map"),
229 df.height(),
230 ))
231 } else {
232 let iter = apply_iter::<PyBackedStr>(df, py, lambda, init_null_count, skip);
233 iterator_to_string(
234 iter,
235 init_null_count,
236 first_value,
237 PlSmallStr::from_static("map"),
238 df.height(),
239 )
240 }
241}
242
243pub fn apply_lambda_with_list_out_type(
245 df: &DataFrame,
246 py: Python<'_>,
247 lambda: Bound<'_, PyAny>,
248 init_null_count: usize,
249 first_value: Option<&Series>,
250 dt: &DataType,
251) -> PyResult<ListChunked> {
252 let skip = usize::from(first_value.is_some());
253 if init_null_count == df.height() {
254 Ok(ChunkedArray::full_null(
255 PlSmallStr::from_static("map"),
256 df.height(),
257 ))
258 } else {
259 let mut iters = get_iters_skip(df, init_null_count + skip);
260 let iter = ((init_null_count + skip)..df.height()).map(|_| {
261 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
262 let tpl = (PyTuple::new(py, iter).unwrap(),);
263 let val = lambda.call1(tpl)?;
264 match val.getattr("_s") {
265 Ok(val) => val.extract::<PySeries>().map(|s| Some(s.series)),
266 Err(_) => {
267 if val.is_none() {
268 Ok(None)
269 } else {
270 Err(PyValueError::new_err(format!(
271 "should return a Series, got a {val:?}"
272 )))
273 }
274 },
275 }
276 });
277 iterator_to_list(
278 dt,
279 iter,
280 init_null_count,
281 first_value,
282 PlSmallStr::from_static("map"),
283 df.height(),
284 )
285 }
286}
287
288pub fn apply_lambda_with_rows_output(
289 df: &DataFrame,
290 py: Python<'_>,
291 lambda: Bound<'_, PyAny>,
292 init_null_count: usize,
293 first_value: Row<'static>,
294 inference_size: usize,
295) -> PolarsResult<DataFrame> {
296 let width = first_value.0.len();
297 let null_row = Row::new(vec![AnyValue::Null; width]);
298
299 let mut row_buf = Row::default();
300
301 let skip = 1;
302 let mut iters = get_iters_skip(df, init_null_count + skip);
303 let mut row_iter = ((init_null_count + skip)..df.height()).map(|_| {
304 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
305 let tpl = (PyTuple::new(py, iter).unwrap(),);
306
307 let return_val = lambda.call1(tpl) ?;
308 if return_val.is_none() {
309 Ok(&null_row)
310 } else {
311 let tuple = return_val.downcast::<PyTuple>().map_err(|_| polars_err!(ComputeError: format!("expected tuple, got {}", return_val.get_type().qualname().unwrap())))?;
312 row_buf.0.clear();
313 for v in tuple {
314 let v = v.extract::<Wrap<AnyValue>>().unwrap().0;
315 row_buf.0.push(v);
316 }
317 let ptr = &row_buf as *const Row;
318 Ok(unsafe { &*ptr })
324 }
325 });
326
327 let mut buf = Vec::with_capacity(inference_size);
329 buf.push(first_value);
330 for v in (&mut row_iter).take(inference_size) {
331 buf.push(v?.clone());
332 }
333
334 let schema = rows_to_schema_first_non_null(&buf, Some(50))?;
335
336 if init_null_count > 0 {
337 let iter = unsafe {
339 (0..init_null_count)
340 .map(|_| Ok(&null_row))
341 .chain(buf.iter().map(Ok))
342 .chain(row_iter)
343 .trust_my_length(df.height())
344 };
345 DataFrame::try_from_rows_iter_and_schema(iter, &schema)
346 } else {
347 let iter = unsafe {
349 buf.iter()
350 .map(Ok)
351 .chain(row_iter)
352 .trust_my_length(df.height())
353 };
354 DataFrame::try_from_rows_iter_and_schema(iter, &schema)
355 }
356}