1use std::mem::{ManuallyDrop, MaybeUninit};
2
3use polars::prelude::*;
4use polars_ffi::version_0::SeriesExport;
5use pyo3::prelude::*;
6use pyo3::types::PyList;
7
8use crate::py_modules::{pl_series, polars, polars_rs};
9use crate::series::PySeries;
10use crate::{PyExpr, Wrap};
11
12pub(crate) trait ToSeries {
13 fn to_series(
14 &self,
15 py: Python<'_>,
16 py_polars_module: &Py<PyModule>,
17 name: &str,
18 ) -> PolarsResult<Series>;
19}
20
21impl ToSeries for PyObject {
22 fn to_series(
23 &self,
24 py: Python<'_>,
25 py_polars_module: &Py<PyModule>,
26 name: &str,
27 ) -> PolarsResult<Series> {
28 let py_pyseries = match self.getattr(py, "_s") {
29 Ok(s) => s,
30 _ => {
32 let res = py_polars_module
33 .getattr(py, "Series")
34 .unwrap()
35 .call1(py, (name, PyList::new(py, [self]).unwrap()));
36
37 match res {
38 Ok(python_s) => python_s.getattr(py, "_s").unwrap(),
39 Err(_) => {
40 polars_bail!(ComputeError:
41 "expected a something that could convert to a `Series` but got: {}",
42 self.bind(py).get_type()
43 )
44 },
45 }
46 },
47 };
48 let s = match py_pyseries.extract::<PySeries>(py) {
49 Ok(pyseries) => pyseries.series,
50 Err(_) => {
53 let mut export: MaybeUninit<SeriesExport> = MaybeUninit::uninit();
54 py_pyseries
55 .call_method1(py, "_export", (&raw mut export as usize,))
56 .unwrap();
57 unsafe {
58 let export = export.assume_init();
59 polars_ffi::version_0::import_series(export)?
60 }
61 },
62 };
63 Ok(s)
64 }
65}
66
67pub(crate) fn call_lambda_with_series(
68 py: Python<'_>,
69 s: &Series,
70 lambda: &PyObject,
71) -> PyResult<PyObject> {
72 let pypolars = polars(py).bind(py);
73
74 let pyseries = PySeries::new(s.clone());
76 let mut python_series_wrapper = pypolars
78 .getattr("wrap_s")
79 .unwrap()
80 .call1((pyseries,))
81 .unwrap();
82
83 if !python_series_wrapper
84 .getattr("_s")
85 .unwrap()
86 .is_instance(polars_rs(py).getattr(py, "PySeries").unwrap().bind(py))
87 .unwrap()
88 {
89 let mut export = ManuallyDrop::new(polars_ffi::version_0::export_series(s));
90 let plseries = pl_series(py).bind(py);
91
92 let s_location = &raw mut export;
93 python_series_wrapper = plseries
94 .getattr("_import")
95 .unwrap()
96 .call1((s_location as usize,))
97 .unwrap()
98 }
99 lambda.call1(py, (python_series_wrapper,))
100}
101
102pub(crate) fn binary_lambda(
104 lambda: &PyObject,
105 a: Series,
106 b: Series,
107) -> PolarsResult<Option<Series>> {
108 Python::with_gil(|py| {
109 let pypolars = polars(py).bind(py);
111 let pyseries_a = PySeries::new(a);
113 let pyseries_b = PySeries::new(b);
114
115 let python_series_wrapper_a = pypolars
117 .getattr("wrap_s")
118 .unwrap()
119 .call1((pyseries_a,))
120 .unwrap();
121 let python_series_wrapper_b = pypolars
122 .getattr("wrap_s")
123 .unwrap()
124 .call1((pyseries_b,))
125 .unwrap();
126
127 let result_series_wrapper =
129 match lambda.call1(py, (python_series_wrapper_a, python_series_wrapper_b)) {
130 Ok(pyobj) => pyobj,
131 Err(e) => polars_bail!(
132 ComputeError: "custom python function failed: {}", e.value(py),
133 ),
134 };
135 let pyseries = if let Ok(expr) = result_series_wrapper.getattr(py, "_pyexpr") {
136 let pyexpr = expr.extract::<PyExpr>(py).unwrap();
137 let expr = pyexpr.inner;
138 let df = DataFrame::empty();
139 let out = df
140 .lazy()
141 .select([expr])
142 .with_predicate_pushdown(false)
143 .with_projection_pushdown(false)
144 .collect()?;
145
146 let s = out.select_at_idx(0).unwrap().clone();
147 PySeries::new(s.take_materialized_series())
148 } else {
149 return Some(result_series_wrapper.to_series(py, pypolars.as_unbound(), ""))
150 .transpose();
151 };
152
153 Ok(Some(pyseries.series))
155 })
156}
157
158pub fn map_single(
159 pyexpr: &PyExpr,
160 lambda: PyObject,
161 output_type: Option<Wrap<DataType>>,
162 agg_list: bool,
163 is_elementwise: bool,
164 returns_scalar: bool,
165) -> PyExpr {
166 let output_type = output_type.map(|wrap| wrap.0);
167
168 let func =
169 python_dsl::PythonUdfExpression::new(lambda, output_type, is_elementwise, returns_scalar);
170 pyexpr.inner.clone().map_python(func, agg_list).into()
171}
172
173pub(crate) fn call_lambda_with_columns_slice(
174 py: Python<'_>,
175 s: &[Column],
176 lambda: &PyObject,
177 pypolars: &Py<PyModule>,
178) -> PyObject {
179 let pypolars = pypolars.bind(py);
180
181 let iter = s.iter().map(|s| {
183 let ps = PySeries::new(s.as_materialized_series().clone());
184
185 pypolars.getattr("wrap_s").unwrap().call1((ps,)).unwrap()
187 });
188 let wrapped_s = PyList::new(py, iter).unwrap();
189
190 match lambda.call1(py, (wrapped_s,)) {
192 Ok(pyobj) => pyobj,
193 Err(e) => panic!("python function failed: {}", e.value(py)),
194 }
195}
196
197pub fn map_mul(
198 pyexpr: &[PyExpr],
199 py: Python<'_>,
200 lambda: PyObject,
201 output_type: Option<Wrap<DataType>>,
202 map_groups: bool,
203 returns_scalar: bool,
204) -> PyExpr {
205 let pypolars = polars(py).clone_ref(py);
208
209 let function = move |s: &mut [Column]| {
210 Python::with_gil(|py| {
211 let out = call_lambda_with_columns_slice(py, s, &lambda, &pypolars);
213
214 if map_groups && out.is_none(py) {
216 return Ok(None);
217 }
218
219 Ok(Some(out.to_series(py, &pypolars, "")?.into_column()))
220 })
221 };
222
223 let exprs = pyexpr.iter().map(|pe| pe.clone().inner).collect::<Vec<_>>();
224
225 let output_map = GetOutput::map_field(move |fld| {
226 Ok(match output_type {
227 Some(ref dt) => Field::new(fld.name().clone(), dt.0.clone()),
228 None => fld.clone(),
229 })
230 });
231 if map_groups {
232 polars::lazy::dsl::apply_multiple(function, exprs, output_map, returns_scalar).into()
233 } else {
234 polars::lazy::dsl::map_multiple(function, exprs, output_map).into()
235 }
236}