1use polars::prelude::*;
2use polars_core::frame::row::{Row, rows_to_schema_first_non_null};
3use polars_core::series::SeriesIter;
4use pyo3::IntoPyObjectExt;
5use pyo3::exceptions::PyValueError;
6use pyo3::prelude::*;
7use pyo3::pybacked::PyBackedStr;
8use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
9
10use super::*;
11use crate::PyDataFrame;
12
13fn get_iters(df: &DataFrame) -> Vec<SeriesIter<'_>> {
15 df.get_columns()
16 .iter()
17 .map(|s| s.as_materialized_series().iter())
18 .collect()
19}
20
21fn get_iters_skip(df: &DataFrame, n: usize) -> Vec<std::iter::Skip<SeriesIter<'_>>> {
23 df.get_columns()
24 .iter()
25 .map(|s| s.as_materialized_series().iter().skip(n))
26 .collect()
27}
28
29pub fn apply_lambda_unknown<'py>(
31 df: &DataFrame,
32 py: Python<'py>,
33 lambda: Bound<'py, PyAny>,
34 inference_size: usize,
35) -> PyResult<(PyObject, bool)> {
36 let mut null_count = 0;
37 let mut iters = get_iters(df);
38
39 for _ in 0..df.height() {
40 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
41 let arg = (PyTuple::new(py, iter)?,);
42 let out = lambda.call1(arg)?;
43
44 if out.is_none() {
45 null_count += 1;
46 continue;
47 } else if out.is_instance_of::<PyBool>() {
48 let first_value = out.extract::<bool>().ok();
49 return Ok((
50 PySeries::new(
51 apply_lambda_with_bool_out_type(df, py, lambda, null_count, first_value)?
52 .into_series(),
53 )
54 .into_py_any(py)?,
55 false,
56 ));
57 } else if out.is_instance_of::<PyFloat>() {
58 let first_value = out.extract::<f64>().ok();
59
60 return Ok((
61 PySeries::new(
62 apply_lambda_with_primitive_out_type::<Float64Type>(
63 df,
64 py,
65 lambda,
66 null_count,
67 first_value,
68 )?
69 .into_series(),
70 )
71 .into_py_any(py)?,
72 false,
73 ));
74 } else if out.is_instance_of::<PyInt>() {
75 let first_value = out.extract::<i64>().ok();
76 return Ok((
77 PySeries::new(
78 apply_lambda_with_primitive_out_type::<Int64Type>(
79 df,
80 py,
81 lambda,
82 null_count,
83 first_value,
84 )?
85 .into_series(),
86 )
87 .into_py_any(py)?,
88 false,
89 ));
90 } else if out.is_instance_of::<PyString>() {
91 let first_value = out.extract::<PyBackedStr>().ok();
92 return Ok((
93 PySeries::new(
94 apply_lambda_with_string_out_type(df, py, lambda, null_count, first_value)?
95 .into_series(),
96 )
97 .into_py_any(py)?,
98 false,
99 ));
100 } else if out.hasattr("_s")? {
101 let py_pyseries = out.getattr("_s").unwrap();
102 let series = py_pyseries
103 .extract::<PySeries>()
104 .unwrap()
105 .series
106 .into_inner();
107 let dt = series.dtype();
108 return Ok((
109 PySeries::new(
110 apply_lambda_with_list_out_type(df, py, lambda, null_count, Some(&series), dt)?
111 .into_series(),
112 )
113 .into_py_any(py)?,
114 false,
115 ));
116 } else if out.extract::<Wrap<Row<'static>>>().is_ok() {
117 let first_value = out.extract::<Wrap<Row<'static>>>().unwrap().0;
118 return Ok((
119 PyDataFrame::from(
120 apply_lambda_with_rows_output(
121 df,
122 py,
123 lambda,
124 null_count,
125 first_value,
126 inference_size,
127 )
128 .map_err(PyPolarsErr::from)?,
129 )
130 .into_py_any(py)?,
131 true,
132 ));
133 } else if out.is_instance_of::<PyList>() || out.is_instance_of::<PyTuple>() {
134 return Err(PyPolarsErr::Other(
135 "A list output type is invalid. Do you mean to create polars List Series?\
136Then return a Series object."
137 .into(),
138 )
139 .into());
140 } else {
141 return Err(PyPolarsErr::Other("Could not determine output type".into()).into());
142 }
143 }
144 Err(PyPolarsErr::Other("Could not determine output type".into()).into())
145}
146
147fn apply_iter<'py, T>(
148 df: &DataFrame,
149 py: Python<'py>,
150 lambda: Bound<'py, PyAny>,
151 init_null_count: usize,
152 skip: usize,
153) -> impl Iterator<Item = PyResult<Option<T>>>
154where
155 T: FromPyObject<'py>,
156{
157 let mut iters = get_iters_skip(df, init_null_count + skip);
158 ((init_null_count + skip)..df.height()).map(move |_| {
159 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
160 let tpl = (PyTuple::new(py, iter).unwrap(),);
161 lambda.call1(tpl).map(|v| v.extract().ok())
162 })
163}
164
165pub fn apply_lambda_with_primitive_out_type<'py, D>(
167 df: &DataFrame,
168 py: Python<'py>,
169 lambda: Bound<'py, PyAny>,
170 init_null_count: usize,
171 first_value: Option<D::Native>,
172) -> PyResult<ChunkedArray<D>>
173where
174 D: PyPolarsNumericType,
175 D::Native: IntoPyObject<'py> + FromPyObject<'py>,
176{
177 let skip = usize::from(first_value.is_some());
178 if init_null_count == df.height() {
179 Ok(ChunkedArray::full_null(
180 PlSmallStr::from_static("map"),
181 df.height(),
182 ))
183 } else {
184 let iter = apply_iter(df, py, lambda, init_null_count, skip);
185 iterator_to_primitive(
186 iter,
187 init_null_count,
188 first_value,
189 PlSmallStr::from_static("map"),
190 df.height(),
191 )
192 }
193}
194
195pub fn apply_lambda_with_bool_out_type(
197 df: &DataFrame,
198 py: Python<'_>,
199 lambda: Bound<'_, PyAny>,
200 init_null_count: usize,
201 first_value: Option<bool>,
202) -> PyResult<ChunkedArray<BooleanType>> {
203 let skip = usize::from(first_value.is_some());
204 if init_null_count == df.height() {
205 Ok(ChunkedArray::full_null(
206 PlSmallStr::from_static("map"),
207 df.height(),
208 ))
209 } else {
210 let iter = apply_iter(df, py, lambda, init_null_count, skip);
211 iterator_to_bool(
212 iter,
213 init_null_count,
214 first_value,
215 PlSmallStr::from_static("map"),
216 df.height(),
217 )
218 }
219}
220
221pub fn apply_lambda_with_string_out_type(
223 df: &DataFrame,
224 py: Python<'_>,
225 lambda: Bound<'_, PyAny>,
226 init_null_count: usize,
227 first_value: Option<PyBackedStr>,
228) -> PyResult<StringChunked> {
229 let skip = usize::from(first_value.is_some());
230 if init_null_count == df.height() {
231 Ok(ChunkedArray::full_null(
232 PlSmallStr::from_static("map"),
233 df.height(),
234 ))
235 } else {
236 let iter = apply_iter::<PyBackedStr>(df, py, lambda, init_null_count, skip);
237 iterator_to_string(
238 iter,
239 init_null_count,
240 first_value,
241 PlSmallStr::from_static("map"),
242 df.height(),
243 )
244 }
245}
246
247pub fn apply_lambda_with_list_out_type(
249 df: &DataFrame,
250 py: Python<'_>,
251 lambda: Bound<'_, PyAny>,
252 init_null_count: usize,
253 first_value: Option<&Series>,
254 dt: &DataType,
255) -> PyResult<ListChunked> {
256 let skip = usize::from(first_value.is_some());
257 if init_null_count == df.height() {
258 Ok(ChunkedArray::full_null(
259 PlSmallStr::from_static("map"),
260 df.height(),
261 ))
262 } else {
263 let mut iters = get_iters_skip(df, init_null_count + skip);
264 let iter = ((init_null_count + skip)..df.height()).map(|_| {
265 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
266 let tpl = (PyTuple::new(py, iter).unwrap(),);
267 let val = lambda.call1(tpl)?;
268 match val.getattr("_s") {
269 Ok(val) => val
270 .extract::<PySeries>()
271 .map(|s| Some(s.series.into_inner())),
272 Err(_) => {
273 if val.is_none() {
274 Ok(None)
275 } else {
276 Err(PyValueError::new_err(format!(
277 "should return a Series, got a {val:?}"
278 )))
279 }
280 },
281 }
282 });
283 iterator_to_list(
284 dt,
285 iter,
286 init_null_count,
287 first_value,
288 PlSmallStr::from_static("map"),
289 df.height(),
290 )
291 }
292}
293
294pub fn apply_lambda_with_rows_output(
295 df: &DataFrame,
296 py: Python<'_>,
297 lambda: Bound<'_, PyAny>,
298 init_null_count: usize,
299 first_value: Row<'static>,
300 inference_size: usize,
301) -> PolarsResult<DataFrame> {
302 let width = first_value.0.len();
303 let null_row = Row::new(vec![AnyValue::Null; width]);
304
305 let mut row_buf = Row::default();
306
307 let skip = 1;
308 let mut iters = get_iters_skip(df, init_null_count + skip);
309 let mut row_iter = ((init_null_count + skip)..df.height()).map(|_| {
310 let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
311 let tpl = (PyTuple::new(py, iter).unwrap(),);
312
313 let return_val = lambda.call1(tpl) ?;
314 if return_val.is_none() {
315 Ok(&null_row)
316 } else {
317 let tuple = return_val.downcast::<PyTuple>().map_err(|_| polars_err!(ComputeError: format!("expected tuple, got {}", return_val.get_type().qualname().unwrap())))?;
318 row_buf.0.clear();
319 for v in tuple {
320 let v = v.extract::<Wrap<AnyValue>>().unwrap().0;
321 row_buf.0.push(v);
322 }
323 let ptr = &row_buf as *const Row;
324 Ok(unsafe { &*ptr })
330 }
331 });
332
333 let mut buf = Vec::with_capacity(inference_size);
335 buf.push(first_value);
336 for v in (&mut row_iter).take(inference_size) {
337 buf.push(v?.clone());
338 }
339
340 let schema = rows_to_schema_first_non_null(&buf, Some(50))?;
341
342 if init_null_count > 0 {
343 let iter = unsafe {
345 (0..init_null_count)
346 .map(|_| Ok(&null_row))
347 .chain(buf.iter().map(Ok))
348 .chain(row_iter)
349 .trust_my_length(df.height())
350 };
351 DataFrame::try_from_rows_iter_and_schema(iter, &schema)
352 } else {
353 let iter = unsafe {
355 buf.iter()
356 .map(Ok)
357 .chain(row_iter)
358 .trust_my_length(df.height())
359 };
360 DataFrame::try_from_rows_iter_and_schema(iter, &schema)
361 }
362}