polars_plan/dsl/python_dsl/
python_udf.rs

1use std::io::Cursor;
2use std::sync::Arc;
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::DataFrame;
7use polars_core::frame::column::Column;
8use polars_core::schema::Schema;
9use pyo3::prelude::*;
10use pyo3::pybacked::PyBackedBytes;
11use pyo3::types::PyBytes;
12
13use crate::constants::MAP_LIST_NAME;
14use crate::prelude::*;
15
// Will be overwritten on Python Polars start up.
/// Hook used by `PythonUdfExpression::call_udf` to run a Python callable on a
/// single [`Column`].
///
/// Starts out as `None`; the call site `unwrap()`s it, so evaluating a Python
/// column UDF before the hook has been set panics.
pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
    fn(s: Column, lambda: &PyObject) -> PolarsResult<Column>,
> = None;
/// Hook used by the [`DataFrameUdf`] implementation of `PythonFunction` to run
/// a Python callable on a whole [`DataFrame`].
///
/// Starts out as `None`; the call site `unwrap()`s it, so evaluating a Python
/// DataFrame UDF before the hook has been set panics.
pub static mut CALL_DF_UDF_PYTHON: Option<
    fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
23
24pub use polars_utils::python_function::{
25    PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION, PythonFunction,
26};
27
/// A Python callable together with the metadata needed to evaluate it as a
/// column UDF and to (de)serialize it.
pub struct PythonUdfExpression {
    /// The Python object handed to `CALL_COLUMNS_UDF_PYTHON` on every call.
    python_function: PyObject,
    /// Declared output dtype. When set (and not `Unknown`), `call_udf` checks
    /// the result against it and casts if required.
    output_type: Option<DataType>,
    /// Forwarded by `Expr::map_python` as the elementwise function flag.
    is_elementwise: bool,
    /// Forwarded by `Expr::map_python` as `FunctionFlags::RETURNS_SCALAR`.
    returns_scalar: bool,
}
34
35impl PythonUdfExpression {
36    pub fn new(
37        lambda: PyObject,
38        output_type: Option<DataType>,
39        is_elementwise: bool,
40        returns_scalar: bool,
41    ) -> Self {
42        Self {
43            python_function: lambda,
44            output_type,
45            is_elementwise,
46            returns_scalar,
47        }
48    }
49
50    #[cfg(feature = "serde")]
51    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
52        // Handle byte mark
53
54        use polars_utils::pl_serialize;
55        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
56        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
57
58        // Handle pickle metadata
59        let use_cloudpickle = buf[0];
60        if use_cloudpickle != 0 {
61            let ser_py_version = &buf[1..3];
62            let cur_py_version = *PYTHON3_VERSION;
63            polars_ensure!(
64                ser_py_version == cur_py_version,
65                InvalidOperation:
66                "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
67                (3, cur_py_version[0], cur_py_version[1]),
68                (3, ser_py_version[0], ser_py_version[1] )
69            );
70        }
71        let buf = &buf[3..];
72
73        // Load UDF metadata
74        let mut reader = Cursor::new(buf);
75        let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
76            pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
77
78        let remainder = &buf[reader.position() as usize..];
79
80        // Load UDF
81        Python::with_gil(|py| {
82            let pickle = PyModule::import(py, "pickle")
83                .expect("unable to import 'pickle'")
84                .getattr("loads")
85                .unwrap();
86            let arg = (PyBytes::new(py, remainder),);
87            let python_function = pickle.call1(arg)?;
88            Ok(Arc::new(Self::new(
89                python_function.into(),
90                output_type,
91                is_elementwise,
92                returns_scalar,
93            )) as Arc<dyn ColumnsUdf>)
94        })
95    }
96}
97
98impl DataFrameUdf for polars_utils::python_function::PythonFunction {
99    fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
100        let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
101        func(df, &self.0)
102    }
103}
104
105impl ColumnsUdf for PythonUdfExpression {
106    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
107        let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
108
109        let output_type = self
110            .output_type
111            .clone()
112            .unwrap_or_else(|| DataType::Unknown(Default::default()));
113        let mut out = func(s[0].clone(), &self.python_function)?;
114        if !matches!(output_type, DataType::Unknown(_)) {
115            let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
116                polars_err!(
117                    SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
118                    output_type, out.dtype(),
119                )
120            })?;
121            if must_cast {
122                out = out.cast(&output_type)?;
123            }
124        }
125
126        Ok(Some(out))
127    }
128
129    #[cfg(feature = "serde")]
130    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
131        // Write byte marks
132
133        use polars_utils::pl_serialize;
134        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
135
136        Python::with_gil(|py| {
137            // Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
138            let pickle = PyModule::import(py, "pickle")
139                .expect("unable to import 'pickle'")
140                .getattr("dumps")
141                .unwrap();
142            let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
143            let (dumped, use_cloudpickle) = match pickle_result {
144                Ok(dumped) => (dumped, false),
145                Err(_) => {
146                    let cloudpickle = PyModule::import(py, "cloudpickle")?
147                        .getattr("dumps")
148                        .unwrap();
149                    let dumped = cloudpickle.call1((self.python_function.clone_ref(py),))?;
150                    (dumped, true)
151                },
152            };
153
154            // Write pickle metadata
155            buf.push(use_cloudpickle as u8);
156            buf.extend_from_slice(&*PYTHON3_VERSION);
157
158            // Write UDF metadata
159            pl_serialize::serialize_into_writer::<_, _, true>(
160                &mut *buf,
161                &(
162                    self.output_type.clone(),
163                    self.is_elementwise,
164                    self.returns_scalar,
165                ),
166            )?;
167
168            // Write UDF
169            let dumped = dumped.extract::<PyBackedBytes>().unwrap();
170            buf.extend_from_slice(&dumped);
171            Ok(())
172        })
173    }
174}
175
/// Serializable version of [`GetOutput`] for Python UDFs.
///
/// Resolves the output field of a Python UDF from an optional, user-declared
/// return dtype.
pub struct PythonGetOutput {
    /// Declared return dtype; `None` makes `get_field` report
    /// `DataType::Unknown`.
    return_dtype: Option<DataType>,
}
180
181impl PythonGetOutput {
182    pub fn new(return_dtype: Option<DataType>) -> Self {
183        Self { return_dtype }
184    }
185
186    #[cfg(feature = "serde")]
187    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
188        // Skip header.
189
190        use polars_utils::pl_serialize;
191        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
192        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
193
194        let mut reader = Cursor::new(buf);
195        let return_dtype: Option<DataType> =
196            pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
197
198        Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
199    }
200}
201
202impl FunctionOutputField for PythonGetOutput {
203    fn get_field(
204        &self,
205        _input_schema: &Schema,
206        _cntxt: Context,
207        fields: &[Field],
208    ) -> PolarsResult<Field> {
209        // Take the name of first field, just like [`GetOutput::map_field`].
210        let name = fields[0].name();
211        let return_dtype = match self.return_dtype {
212            Some(ref dtype) => dtype.clone(),
213            None => DataType::Unknown(Default::default()),
214        };
215        Ok(Field::new(name.clone(), return_dtype))
216    }
217
218    #[cfg(feature = "serde")]
219    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
220        use polars_utils::pl_serialize;
221
222        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
223        pl_serialize::serialize_into_writer::<_, _, true>(&mut *buf, &self.return_dtype)
224    }
225}
226
227impl Expr {
228    pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
229        let name = if agg_list {
230            MAP_LIST_NAME
231        } else {
232            "python_udf"
233        };
234
235        let returns_scalar = func.returns_scalar;
236        let return_dtype = func.output_type.clone();
237
238        let output_field = PythonGetOutput::new(return_dtype);
239        let output_type = SpecialEq::new(Arc::new(output_field) as Arc<dyn FunctionOutputField>);
240
241        let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
242        if agg_list {
243            flags |= FunctionFlags::APPLY_LIST;
244        }
245        if func.is_elementwise {
246            flags.set_elementwise();
247        }
248        if returns_scalar {
249            flags |= FunctionFlags::RETURNS_SCALAR;
250        }
251
252        Expr::AnonymousFunction {
253            input: vec![self],
254            function: new_column_udf(func),
255            output_type,
256            options: FunctionOptions {
257                fmt_str: name,
258                flags,
259                ..Default::default()
260            },
261        }
262    }
263}