polars_plan/dsl/python_dsl/python_udf.rs

1use std::io::Cursor;
2use std::sync::{Arc, OnceLock};
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::DataFrame;
7use polars_core::frame::column::Column;
8use polars_core::schema::Schema;
9use polars_utils::pl_str::PlSmallStr;
10use pyo3::prelude::*;
11use pyo3::pybacked::PyBackedBytes;
12use pyo3::types::PyBytes;
13
14use crate::constants::MAP_LIST_NAME;
15use crate::prelude::*;
16
// Will be overwritten on Python Polars start up.
//
// Both hooks start out as `None`. Code in this file reads them inside `unsafe`
// blocks and panics if a hook is still unset when a Python UDF is invoked.
// NOTE(review): `static mut` provides no synchronization; this assumes the
// Python bindings write each hook once during start-up, before any reads.
/// Hook that evaluates a column-wise Python UDF. It receives the input column,
/// the materialized output dtype (if any) and the Python callable.
#[allow(clippy::type_complexity)] // the fn-pointer signature is spelled out inline
pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
    fn(s: Column, output_dtype: Option<DataType>, lambda: &PyObject) -> PolarsResult<Column>,
> = None;
/// Hook that evaluates a DataFrame-level Python UDF with the given Python callable.
pub static mut CALL_DF_UDF_PYTHON: Option<
    fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
25
26pub use polars_utils::python_function::PythonFunction;
27#[cfg(feature = "serde")]
28pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};
29
/// A column UDF backed by a Python callable.
pub struct PythonUdfExpression {
    /// The Python callable invoked on the input column.
    python_function: PyObject,
    /// Optional user-supplied return dtype; it may still need to be resolved
    /// against an input schema.
    output_type: Option<DataTypeExpr>,
    /// `output_type` resolved against the input schema by `resolve_dsl`;
    /// unset until that happens.
    materialized_output_type: OnceLock<DataType>,
    /// Whether the UDF is elementwise (turned into function flags by `Expr::map_python`).
    is_elementwise: bool,
    /// Whether the UDF returns a scalar (turned into `FunctionFlags::RETURNS_SCALAR`).
    returns_scalar: bool,
}
37
38impl PythonUdfExpression {
39    pub fn new(
40        lambda: PyObject,
41        output_type: Option<impl Into<DataTypeExpr>>,
42        is_elementwise: bool,
43        returns_scalar: bool,
44    ) -> Self {
45        let output_type = output_type.map(Into::into);
46        Self {
47            python_function: lambda,
48            output_type,
49            materialized_output_type: OnceLock::new(),
50            is_elementwise,
51            returns_scalar,
52        }
53    }
54
55    #[cfg(feature = "serde")]
56    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
57        // Handle byte mark
58
59        use polars_utils::pl_serialize;
60        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
61        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
62
63        // Handle pickle metadata
64        let use_cloudpickle = buf[0];
65        if use_cloudpickle != 0 {
66            let ser_py_version = &buf[1..3];
67            let cur_py_version = *PYTHON3_VERSION;
68            polars_ensure!(
69                ser_py_version == cur_py_version,
70                InvalidOperation:
71                "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
72                (3, cur_py_version[0], cur_py_version[1]),
73                (3, ser_py_version[0], ser_py_version[1] )
74            );
75        }
76        let buf = &buf[3..];
77
78        // Load UDF metadata
79        let mut reader = Cursor::new(buf);
80        let (output_type, is_elementwise, returns_scalar): (Option<DataTypeExpr>, bool, bool) =
81            pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
82
83        let remainder = &buf[reader.position() as usize..];
84
85        // Load UDF
86        Python::with_gil(|py| {
87            let pickle = PyModule::import(py, "pickle")
88                .expect("unable to import 'pickle'")
89                .getattr("loads")
90                .unwrap();
91            let arg = (PyBytes::new(py, remainder),);
92            let python_function = pickle.call1(arg)?;
93            Ok(Arc::new(Self::new(
94                python_function.into(),
95                output_type,
96                is_elementwise,
97                returns_scalar,
98            )) as Arc<dyn ColumnsUdf>)
99        })
100    }
101}
102
103impl DataFrameUdf for polars_utils::python_function::PythonFunction {
104    fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
105        let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
106        func(df, &self.0)
107    }
108}
109
110impl ColumnsUdf for PythonUdfExpression {
111    fn resolve_dsl(&self, input_schema: &Schema) -> PolarsResult<()> {
112        if let Some(output_type) = self.output_type.as_ref() {
113            let dtype = output_type.clone().into_datatype(input_schema)?;
114            self.materialized_output_type.get_or_init(|| dtype);
115        }
116        Ok(())
117    }
118
119    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
120        let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
121
122        let output_type = self
123            .materialized_output_type
124            .get()
125            .map_or_else(|| DataType::Unknown(Default::default()), |dt| dt.clone());
126        let mut out = func(
127            s[0].clone(),
128            self.materialized_output_type.get().cloned(),
129            &self.python_function,
130        )?;
131        if !matches!(output_type, DataType::Unknown(_)) {
132            let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
133                polars_err!(
134                    SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
135                    output_type, out.dtype(),
136                )
137            })?;
138            if must_cast {
139                out = out.cast(&output_type)?;
140            }
141        }
142
143        Ok(Some(out))
144    }
145
146    #[cfg(feature = "serde")]
147    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
148        // Write byte marks
149
150        use polars_utils::pl_serialize;
151        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
152
153        Python::with_gil(|py| {
154            // Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
155            let pickle = PyModule::import(py, "pickle")
156                .expect("unable to import 'pickle'")
157                .getattr("dumps")
158                .unwrap();
159            let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
160            let (dumped, use_cloudpickle) = match pickle_result {
161                Ok(dumped) => (dumped, false),
162                Err(_) => {
163                    let cloudpickle = PyModule::import(py, "cloudpickle")?
164                        .getattr("dumps")
165                        .unwrap();
166                    let dumped = cloudpickle.call1((self.python_function.clone_ref(py),))?;
167                    (dumped, true)
168                },
169            };
170
171            // Write pickle metadata
172            buf.push(use_cloudpickle as u8);
173            buf.extend_from_slice(&*PYTHON3_VERSION);
174
175            // Write UDF metadata
176            pl_serialize::serialize_into_writer::<_, _, true>(
177                &mut *buf,
178                &(
179                    self.output_type.clone(),
180                    self.is_elementwise,
181                    self.returns_scalar,
182                ),
183            )?;
184
185            // Write UDF
186            let dumped = dumped.extract::<PyBackedBytes>().unwrap();
187            buf.extend_from_slice(&dumped);
188            Ok(())
189        })
190    }
191}
192
/// Serializable version of [`GetOutput`] for Python UDFs.
///
/// Reports the output field of a Python UDF based on its optional return dtype.
/// When no dtype has been materialized, the field's dtype is `DataType::Unknown`.
pub struct PythonGetOutput {
    /// Optional user-supplied return dtype expression; this is the only state
    /// that gets serialized.
    return_dtype: Option<DataTypeExpr>,
    /// `return_dtype` resolved against the input schema by `resolve_dsl`;
    /// unset until that happens.
    materialized_output_type: OnceLock<DataType>,
}
198
199impl PythonGetOutput {
200    pub fn new(return_dtype: Option<impl Into<DataTypeExpr>>) -> Self {
201        Self {
202            return_dtype: return_dtype.map(Into::into),
203            materialized_output_type: OnceLock::new(),
204        }
205    }
206
207    #[cfg(feature = "serde")]
208    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
209        // Skip header.
210
211        use polars_utils::pl_serialize;
212        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
213        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
214
215        let mut reader = Cursor::new(buf);
216        let return_dtype: Option<DataTypeExpr> =
217            pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
218
219        Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
220    }
221}
222
223impl FunctionOutputField for PythonGetOutput {
224    fn resolve_dsl(&self, input_schema: &Schema) -> PolarsResult<()> {
225        if let Some(output_type) = self.return_dtype.as_ref() {
226            let dtype = output_type.clone().into_datatype(input_schema)?;
227            self.materialized_output_type.get_or_init(|| dtype);
228        }
229        Ok(())
230    }
231
232    fn get_field(
233        &self,
234        _input_schema: &Schema,
235        _cntxt: Context,
236        fields: &[Field],
237    ) -> PolarsResult<Field> {
238        // Take the name of first field, just like [`GetOutput::map_field`].
239        let name = fields[0].name();
240        let return_dtype = match self.materialized_output_type.get() {
241            Some(dtype) => dtype.clone(),
242            None => DataType::Unknown(Default::default()),
243        };
244        Ok(Field::new(name.clone(), return_dtype))
245    }
246
247    #[cfg(feature = "serde")]
248    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
249        use polars_utils::pl_serialize;
250
251        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
252        pl_serialize::serialize_into_writer::<_, _, true>(&mut *buf, &self.return_dtype)
253    }
254}
255
256impl Expr {
257    pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
258        let name = if agg_list {
259            MAP_LIST_NAME
260        } else {
261            "python_udf"
262        };
263
264        let returns_scalar = func.returns_scalar;
265        let return_dtype = func.output_type.clone();
266
267        let output_field = PythonGetOutput::new(return_dtype);
268        let output_type = LazySerde::Deserialized(SpecialEq::new(
269            Arc::new(output_field) as Arc<dyn FunctionOutputField>
270        ));
271
272        let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
273        if agg_list {
274            flags |= FunctionFlags::APPLY_LIST;
275        }
276        if func.is_elementwise {
277            flags.set_elementwise();
278        }
279        if returns_scalar {
280            flags |= FunctionFlags::RETURNS_SCALAR;
281        }
282
283        Expr::AnonymousFunction {
284            input: vec![self],
285            function: new_column_udf(func),
286            output_type,
287            options: FunctionOptions {
288                flags,
289                ..Default::default()
290            },
291            fmt_str: Box::new(PlSmallStr::from(name)),
292        }
293    }
294}