polars_plan/dsl/python_dsl/
python_udf.rs1use std::io::Cursor;
2use std::sync::Arc;
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::DataFrame;
7use polars_core::frame::column::Column;
8use polars_core::schema::Schema;
9use pyo3::prelude::*;
10use pyo3::pybacked::PyBackedBytes;
11use pyo3::types::PyBytes;
12
13use crate::constants::MAP_LIST_NAME;
14use crate::prelude::*;
15
16pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
18 fn(s: Column, lambda: &PyObject) -> PolarsResult<Column>,
19> = None;
20pub static mut CALL_DF_UDF_PYTHON: Option<
21 fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
22> = None;
23
24pub use polars_utils::python_function::{
25 PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION, PythonFunction,
26};
27
28pub struct PythonUdfExpression {
29 python_function: PyObject,
30 output_type: Option<DataType>,
31 is_elementwise: bool,
32 returns_scalar: bool,
33}
34
35impl PythonUdfExpression {
36 pub fn new(
37 lambda: PyObject,
38 output_type: Option<DataType>,
39 is_elementwise: bool,
40 returns_scalar: bool,
41 ) -> Self {
42 Self {
43 python_function: lambda,
44 output_type,
45 is_elementwise,
46 returns_scalar,
47 }
48 }
49
50 #[cfg(feature = "serde")]
51 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
52 use polars_utils::pl_serialize;
55 debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
56 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
57
58 let use_cloudpickle = buf[0];
60 if use_cloudpickle != 0 {
61 let ser_py_version = &buf[1..3];
62 let cur_py_version = *PYTHON3_VERSION;
63 polars_ensure!(
64 ser_py_version == cur_py_version,
65 InvalidOperation:
66 "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
67 (3, cur_py_version[0], cur_py_version[1]),
68 (3, ser_py_version[0], ser_py_version[1] )
69 );
70 }
71 let buf = &buf[3..];
72
73 let mut reader = Cursor::new(buf);
75 let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
76 pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
77
78 let remainder = &buf[reader.position() as usize..];
79
80 Python::with_gil(|py| {
82 let pickle = PyModule::import(py, "pickle")
83 .expect("unable to import 'pickle'")
84 .getattr("loads")
85 .unwrap();
86 let arg = (PyBytes::new(py, remainder),);
87 let python_function = pickle.call1(arg)?;
88 Ok(Arc::new(Self::new(
89 python_function.into(),
90 output_type,
91 is_elementwise,
92 returns_scalar,
93 )) as Arc<dyn ColumnsUdf>)
94 })
95 }
96}
97
98impl DataFrameUdf for polars_utils::python_function::PythonFunction {
99 fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
100 let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
101 func(df, &self.0)
102 }
103}
104
105impl ColumnsUdf for PythonUdfExpression {
106 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
107 let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
108
109 let output_type = self
110 .output_type
111 .clone()
112 .unwrap_or_else(|| DataType::Unknown(Default::default()));
113 let mut out = func(s[0].clone(), &self.python_function)?;
114 if !matches!(output_type, DataType::Unknown(_)) {
115 let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
116 polars_err!(
117 SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
118 output_type, out.dtype(),
119 )
120 })?;
121 if must_cast {
122 out = out.cast(&output_type)?;
123 }
124 }
125
126 Ok(Some(out))
127 }
128
129 #[cfg(feature = "serde")]
130 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
131 use polars_utils::pl_serialize;
134 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
135
136 Python::with_gil(|py| {
137 let pickle = PyModule::import(py, "pickle")
139 .expect("unable to import 'pickle'")
140 .getattr("dumps")
141 .unwrap();
142 let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
143 let (dumped, use_cloudpickle) = match pickle_result {
144 Ok(dumped) => (dumped, false),
145 Err(_) => {
146 let cloudpickle = PyModule::import(py, "cloudpickle")?
147 .getattr("dumps")
148 .unwrap();
149 let dumped = cloudpickle.call1((self.python_function.clone_ref(py),))?;
150 (dumped, true)
151 },
152 };
153
154 buf.push(use_cloudpickle as u8);
156 buf.extend_from_slice(&*PYTHON3_VERSION);
157
158 pl_serialize::serialize_into_writer::<_, _, true>(
160 &mut *buf,
161 &(
162 self.output_type.clone(),
163 self.is_elementwise,
164 self.returns_scalar,
165 ),
166 )?;
167
168 let dumped = dumped.extract::<PyBackedBytes>().unwrap();
170 buf.extend_from_slice(&dumped);
171 Ok(())
172 })
173 }
174}
175
176pub struct PythonGetOutput {
178 return_dtype: Option<DataType>,
179}
180
181impl PythonGetOutput {
182 pub fn new(return_dtype: Option<DataType>) -> Self {
183 Self { return_dtype }
184 }
185
186 #[cfg(feature = "serde")]
187 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
188 use polars_utils::pl_serialize;
191 debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
192 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
193
194 let mut reader = Cursor::new(buf);
195 let return_dtype: Option<DataType> =
196 pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
197
198 Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
199 }
200}
201
202impl FunctionOutputField for PythonGetOutput {
203 fn get_field(
204 &self,
205 _input_schema: &Schema,
206 _cntxt: Context,
207 fields: &[Field],
208 ) -> PolarsResult<Field> {
209 let name = fields[0].name();
211 let return_dtype = match self.return_dtype {
212 Some(ref dtype) => dtype.clone(),
213 None => DataType::Unknown(Default::default()),
214 };
215 Ok(Field::new(name.clone(), return_dtype))
216 }
217
218 #[cfg(feature = "serde")]
219 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
220 use polars_utils::pl_serialize;
221
222 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
223 pl_serialize::serialize_into_writer::<_, _, true>(&mut *buf, &self.return_dtype)
224 }
225}
226
227impl Expr {
228 pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
229 let name = if agg_list {
230 MAP_LIST_NAME
231 } else {
232 "python_udf"
233 };
234
235 let returns_scalar = func.returns_scalar;
236 let return_dtype = func.output_type.clone();
237
238 let output_field = PythonGetOutput::new(return_dtype);
239 let output_type = SpecialEq::new(Arc::new(output_field) as Arc<dyn FunctionOutputField>);
240
241 let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
242 if agg_list {
243 flags |= FunctionFlags::APPLY_LIST;
244 }
245 if func.is_elementwise {
246 flags.set_elementwise();
247 }
248 if returns_scalar {
249 flags |= FunctionFlags::RETURNS_SCALAR;
250 }
251
252 Expr::AnonymousFunction {
253 input: vec![self],
254 function: new_column_udf(func),
255 output_type,
256 options: FunctionOptions {
257 fmt_str: name,
258 flags,
259 ..Default::default()
260 },
261 }
262 }
263}