polars_plan/dsl/python_dsl/
python_udf.rs1use std::io::Cursor;
2use std::sync::{Arc, OnceLock};
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::DataFrame;
7use polars_core::frame::column::Column;
8use polars_core::schema::Schema;
9use polars_utils::pl_str::PlSmallStr;
10use pyo3::prelude::*;
11use pyo3::pybacked::PyBackedBytes;
12use pyo3::types::PyBytes;
13
14use crate::constants::MAP_LIST_NAME;
15use crate::prelude::*;
16
17#[allow(clippy::type_complexity)]
19pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
20 fn(s: Column, output_dtype: Option<DataType>, lambda: &PyObject) -> PolarsResult<Column>,
21> = None;
22pub static mut CALL_DF_UDF_PYTHON: Option<
23 fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
24> = None;
25
26pub use polars_utils::python_function::PythonFunction;
27#[cfg(feature = "serde")]
28pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};
29
/// A Python callable wrapped as a column UDF, together with the metadata the planner needs.
pub struct PythonUdfExpression {
    /// The Python callable that is invoked on the first input column.
    python_function: PyObject,
    /// User-declared return dtype; an expression that is resolved against the input schema.
    output_type: Option<DataTypeExpr>,
    /// `output_type` resolved to a concrete `DataType` by `resolve_dsl`; empty until then.
    materialized_output_type: OnceLock<DataType>,
    /// Whether the UDF is elementwise (`map_python` sets the elementwise flag from this).
    is_elementwise: bool,
    /// Whether the UDF returns a scalar (`map_python` sets `RETURNS_SCALAR` from this).
    returns_scalar: bool,
}
37
38impl PythonUdfExpression {
39 pub fn new(
40 lambda: PyObject,
41 output_type: Option<impl Into<DataTypeExpr>>,
42 is_elementwise: bool,
43 returns_scalar: bool,
44 ) -> Self {
45 let output_type = output_type.map(Into::into);
46 Self {
47 python_function: lambda,
48 output_type,
49 materialized_output_type: OnceLock::new(),
50 is_elementwise,
51 returns_scalar,
52 }
53 }
54
55 #[cfg(feature = "serde")]
56 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
57 use polars_utils::pl_serialize;
60 debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
61 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
62
63 let use_cloudpickle = buf[0];
65 if use_cloudpickle != 0 {
66 let ser_py_version = &buf[1..3];
67 let cur_py_version = *PYTHON3_VERSION;
68 polars_ensure!(
69 ser_py_version == cur_py_version,
70 InvalidOperation:
71 "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
72 (3, cur_py_version[0], cur_py_version[1]),
73 (3, ser_py_version[0], ser_py_version[1] )
74 );
75 }
76 let buf = &buf[3..];
77
78 let mut reader = Cursor::new(buf);
80 let (output_type, is_elementwise, returns_scalar): (Option<DataTypeExpr>, bool, bool) =
81 pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
82
83 let remainder = &buf[reader.position() as usize..];
84
85 Python::with_gil(|py| {
87 let pickle = PyModule::import(py, "pickle")
88 .expect("unable to import 'pickle'")
89 .getattr("loads")
90 .unwrap();
91 let arg = (PyBytes::new(py, remainder),);
92 let python_function = pickle.call1(arg)?;
93 Ok(Arc::new(Self::new(
94 python_function.into(),
95 output_type,
96 is_elementwise,
97 returns_scalar,
98 )) as Arc<dyn ColumnsUdf>)
99 })
100 }
101}
102
103impl DataFrameUdf for polars_utils::python_function::PythonFunction {
104 fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
105 let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
106 func(df, &self.0)
107 }
108}
109
110impl ColumnsUdf for PythonUdfExpression {
111 fn resolve_dsl(&self, input_schema: &Schema) -> PolarsResult<()> {
112 if let Some(output_type) = self.output_type.as_ref() {
113 let dtype = output_type.clone().into_datatype(input_schema)?;
114 self.materialized_output_type.get_or_init(|| dtype);
115 }
116 Ok(())
117 }
118
119 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
120 let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
121
122 let output_type = self
123 .materialized_output_type
124 .get()
125 .map_or_else(|| DataType::Unknown(Default::default()), |dt| dt.clone());
126 let mut out = func(
127 s[0].clone(),
128 self.materialized_output_type.get().cloned(),
129 &self.python_function,
130 )?;
131 if !matches!(output_type, DataType::Unknown(_)) {
132 let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
133 polars_err!(
134 SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
135 output_type, out.dtype(),
136 )
137 })?;
138 if must_cast {
139 out = out.cast(&output_type)?;
140 }
141 }
142
143 Ok(Some(out))
144 }
145
146 #[cfg(feature = "serde")]
147 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
148 use polars_utils::pl_serialize;
151 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
152
153 Python::with_gil(|py| {
154 let pickle = PyModule::import(py, "pickle")
156 .expect("unable to import 'pickle'")
157 .getattr("dumps")
158 .unwrap();
159 let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
160 let (dumped, use_cloudpickle) = match pickle_result {
161 Ok(dumped) => (dumped, false),
162 Err(_) => {
163 let cloudpickle = PyModule::import(py, "cloudpickle")?
164 .getattr("dumps")
165 .unwrap();
166 let dumped = cloudpickle.call1((self.python_function.clone_ref(py),))?;
167 (dumped, true)
168 },
169 };
170
171 buf.push(use_cloudpickle as u8);
173 buf.extend_from_slice(&*PYTHON3_VERSION);
174
175 pl_serialize::serialize_into_writer::<_, _, true>(
177 &mut *buf,
178 &(
179 self.output_type.clone(),
180 self.is_elementwise,
181 self.returns_scalar,
182 ),
183 )?;
184
185 let dumped = dumped.extract::<PyBackedBytes>().unwrap();
187 buf.extend_from_slice(&dumped);
188 Ok(())
189 })
190 }
191}
192
/// Output-field resolver for Python UDFs: reports the UDF's result field during planning.
pub struct PythonGetOutput {
    /// User-declared return dtype; an expression that is resolved against the input schema.
    return_dtype: Option<DataTypeExpr>,
    /// `return_dtype` resolved by `resolve_dsl`; `get_field` reports `Unknown` while this is empty.
    materialized_output_type: OnceLock<DataType>,
}
198
199impl PythonGetOutput {
200 pub fn new(return_dtype: Option<impl Into<DataTypeExpr>>) -> Self {
201 Self {
202 return_dtype: return_dtype.map(Into::into),
203 materialized_output_type: OnceLock::new(),
204 }
205 }
206
207 #[cfg(feature = "serde")]
208 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
209 use polars_utils::pl_serialize;
212 debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
213 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
214
215 let mut reader = Cursor::new(buf);
216 let return_dtype: Option<DataTypeExpr> =
217 pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
218
219 Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
220 }
221}
222
223impl FunctionOutputField for PythonGetOutput {
224 fn resolve_dsl(&self, input_schema: &Schema) -> PolarsResult<()> {
225 if let Some(output_type) = self.return_dtype.as_ref() {
226 let dtype = output_type.clone().into_datatype(input_schema)?;
227 self.materialized_output_type.get_or_init(|| dtype);
228 }
229 Ok(())
230 }
231
232 fn get_field(
233 &self,
234 _input_schema: &Schema,
235 _cntxt: Context,
236 fields: &[Field],
237 ) -> PolarsResult<Field> {
238 let name = fields[0].name();
240 let return_dtype = match self.materialized_output_type.get() {
241 Some(dtype) => dtype.clone(),
242 None => DataType::Unknown(Default::default()),
243 };
244 Ok(Field::new(name.clone(), return_dtype))
245 }
246
247 #[cfg(feature = "serde")]
248 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
249 use polars_utils::pl_serialize;
250
251 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
252 pl_serialize::serialize_into_writer::<_, _, true>(&mut *buf, &self.return_dtype)
253 }
254}
255
256impl Expr {
257 pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
258 let name = if agg_list {
259 MAP_LIST_NAME
260 } else {
261 "python_udf"
262 };
263
264 let returns_scalar = func.returns_scalar;
265 let return_dtype = func.output_type.clone();
266
267 let output_field = PythonGetOutput::new(return_dtype);
268 let output_type = LazySerde::Deserialized(SpecialEq::new(
269 Arc::new(output_field) as Arc<dyn FunctionOutputField>
270 ));
271
272 let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
273 if agg_list {
274 flags |= FunctionFlags::APPLY_LIST;
275 }
276 if func.is_elementwise {
277 flags.set_elementwise();
278 }
279 if returns_scalar {
280 flags |= FunctionFlags::RETURNS_SCALAR;
281 }
282
283 Expr::AnonymousFunction {
284 input: vec![self],
285 function: new_column_udf(func),
286 output_type,
287 options: FunctionOptions {
288 flags,
289 ..Default::default()
290 },
291 fmt_str: Box::new(PlSmallStr::from(name)),
292 }
293 }
294}