polars_plan/dsl/python_dsl/
python_udf.rs

1use std::io::Cursor;
2use std::sync::{Arc, OnceLock};
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::DataFrame;
7use polars_core::frame::column::Column;
8use polars_core::schema::Schema;
9use polars_utils::pl_str::PlSmallStr;
10use pyo3::prelude::*;
11
12use crate::dsl::udf::try_infer_udf_output_dtype;
13use crate::prelude::*;
14
15// Will be overwritten on Python Polars start up.
16#[allow(clippy::type_complexity)]
17pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
18    fn(s: &[Column], output_dtype: Option<DataType>, lambda: &PyObject) -> PolarsResult<Column>,
19> = None;
20pub static mut CALL_DF_UDF_PYTHON: Option<
21    fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
22> = None;
23
24pub use polars_utils::python_function::PythonFunction;
25#[cfg(feature = "serde")]
26pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};
27
28pub struct PythonUdfExpression {
29    python_function: PyObject,
30    output_type: Option<DataTypeExpr>,
31    materialized_field: OnceLock<Field>,
32    is_elementwise: bool,
33    returns_scalar: bool,
34}
35
36impl PythonUdfExpression {
37    pub fn new(
38        lambda: PyObject,
39        output_type: Option<impl Into<DataTypeExpr>>,
40        is_elementwise: bool,
41        returns_scalar: bool,
42    ) -> Self {
43        let output_type = output_type.map(Into::into);
44        Self {
45            python_function: lambda,
46            output_type,
47            materialized_field: OnceLock::new(),
48            is_elementwise,
49            returns_scalar,
50        }
51    }
52
53    #[cfg(feature = "serde")]
54    pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {
55        use polars_utils::pl_serialize;
56
57        if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {
58            polars_bail!(InvalidOperation: "serialization expected python magic byte mark");
59        }
60        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
61
62        // Load UDF metadata
63        let mut reader = Cursor::new(buf);
64        let (output_type, materialized, is_elementwise, returns_scalar): (
65            Option<DataTypeExpr>,
66            Option<Field>,
67            bool,
68            bool,
69        ) = pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
70
71        let buf = &buf[reader.position() as usize..];
72        let python_function = pl_serialize::python_object_deserialize(buf)?;
73
74        let mut udf = Self::new(python_function, output_type, is_elementwise, returns_scalar);
75        if let Some(materialized) = materialized {
76            udf.materialized_field = OnceLock::from(materialized);
77        }
78
79        Ok(Arc::new(udf))
80    }
81}
82
83impl DataFrameUdf for polars_utils::python_function::PythonFunction {
84    fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
85        let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
86        func(df, &self.0)
87    }
88}
89
90impl ColumnsUdf for PythonUdfExpression {
91    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {
92        let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
93        let field = self
94            .materialized_field
95            .get()
96            .expect("should have been materialized at this point");
97        let mut out = func(
98            s,
99            self.materialized_field.get().map(|f| f.dtype.clone()),
100            &self.python_function,
101        )?;
102
103        let must_cast = out.dtype().matches_schema_type(field.dtype()).map_err(|_| {
104            polars_err!(
105                SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
106                field.dtype(), out.dtype(),
107            )
108        })?;
109        if must_cast {
110            out = out.cast(field.dtype())?;
111        }
112
113        Ok(out)
114    }
115}
116
117impl AnonymousColumnsUdf for PythonUdfExpression {
118    fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
119        self as _
120    }
121    fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
122        Arc::new(Self {
123            python_function: Python::with_gil(|py| self.python_function.clone_ref(py)),
124            output_type: self.output_type.clone(),
125            materialized_field: OnceLock::new(),
126            is_elementwise: self.is_elementwise,
127            returns_scalar: self.returns_scalar,
128        }) as _
129    }
130
131    #[cfg(feature = "serde")]
132    fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
133        use polars_utils::pl_serialize;
134
135        // Write byte marks
136        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
137
138        // Write UDF metadata
139        pl_serialize::serialize_into_writer::<_, _, true>(
140            &mut *buf,
141            &(
142                self.output_type.clone(),
143                self.materialized_field.get().cloned(),
144                self.is_elementwise,
145                self.returns_scalar,
146            ),
147        )?;
148
149        pl_serialize::python_object_serialize(&self.python_function, buf)?;
150        Ok(())
151    }
152
153    fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {
154        let field = match self.materialized_field.get() {
155            Some(f) => f.clone(),
156            None => {
157                let dtype = match self.output_type.as_ref() {
158                    None => {
159                        let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
160                        let f = |s: &[Column]| func(s, None, &self.python_function);
161                        try_infer_udf_output_dtype(&f as _, fields)?
162                    },
163                    Some(output_type) => output_type
164                        .clone()
165                        .into_datatype_with_self(input_schema, fields[0].dtype())?,
166                };
167
168                // Take the name of first field, just like `map_field`.
169                let name = fields[0].name();
170                let f = Field::new(name.clone(), dtype);
171                self.materialized_field.get_or_init(|| f.clone());
172                f
173            },
174        };
175        Ok(field)
176    }
177}
178
179impl Expr {
180    pub fn map_python(self, func: PythonUdfExpression) -> Expr {
181        Self::map_many_python(vec![self], func)
182    }
183
184    pub fn map_many_python(exprs: Vec<Expr>, func: PythonUdfExpression) -> Expr {
185        const NAME: &str = "python_udf";
186
187        let returns_scalar = func.returns_scalar;
188
189        let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
190        if func.is_elementwise {
191            flags.set_elementwise();
192        }
193        if returns_scalar {
194            flags |= FunctionFlags::RETURNS_SCALAR;
195        }
196
197        Expr::AnonymousFunction {
198            input: exprs,
199            function: new_column_udf(func),
200            options: FunctionOptions {
201                flags,
202                ..Default::default()
203            },
204            fmt_str: Box::new(PlSmallStr::from(NAME)),
205        }
206    }
207}