polars_plan/dsl/python_dsl/
python_udf.rs1use std::io::Cursor;
2use std::sync::{Arc, OnceLock};
3
4use polars_core::datatypes::{DataType, Field};
5use polars_core::error::*;
6use polars_core::frame::DataFrame;
7use polars_core::frame::column::Column;
8use polars_core::schema::Schema;
9use polars_utils::pl_str::PlSmallStr;
10use pyo3::prelude::*;
11
12use crate::dsl::udf::try_infer_udf_output_dtype;
13use crate::prelude::*;
14
/// Hook that executes a Python callable (`lambda`) over a slice of columns.
///
/// Starts out as `None`. Every call site in this file reads it inside an `unsafe`
/// block and `unwrap`s it, so it must have been assigned before any Python column
/// UDF is evaluated or has its output dtype inferred.
///
/// NOTE(review): presumably assigned once by the Python bindings during start-up —
/// confirm against the crate that registers the hook.
#[allow(clippy::type_complexity)]
pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
    fn(s: &[Column], output_dtype: Option<DataType>, lambda: &PyObject) -> PolarsResult<Column>,
> = None;
/// Hook that executes a Python callable (`lambda`) over a whole `DataFrame`.
///
/// Starts out as `None`; the `DataFrameUdf` implementation in this file `unwrap`s
/// it, so it must have been assigned before a Python `DataFrame` UDF is called.
pub static mut CALL_DF_UDF_PYTHON: Option<
    fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
23
24pub use polars_utils::python_function::PythonFunction;
25#[cfg(feature = "serde")]
26pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};
27
/// A Python callable packaged so that it can be used as a column-level UDF inside
/// an expression.
pub struct PythonUdfExpression {
    /// The Python object handed to the `CALL_COLUMNS_UDF_PYTHON` hook.
    python_function: PyObject,
    /// Output dtype declared by the caller. When `None`, `get_field` obtains the
    /// dtype by passing the callable to `try_infer_udf_output_dtype`.
    output_type: Option<DataTypeExpr>,
    /// Cache of the resolved output field. It is filled by `get_field` (or restored
    /// by `try_deserialize`) and must be set before `call_udf` runs.
    materialized_field: OnceLock<Field>,
    /// When `true`, `Expr::map_many_python` marks the function as elementwise.
    is_elementwise: bool,
    /// When `true`, `Expr::map_many_python` adds `FunctionFlags::RETURNS_SCALAR`.
    returns_scalar: bool,
}
35
36impl PythonUdfExpression {
37 pub fn new(
38 lambda: PyObject,
39 output_type: Option<impl Into<DataTypeExpr>>,
40 is_elementwise: bool,
41 returns_scalar: bool,
42 ) -> Self {
43 let output_type = output_type.map(Into::into);
44 Self {
45 python_function: lambda,
46 output_type,
47 materialized_field: OnceLock::new(),
48 is_elementwise,
49 returns_scalar,
50 }
51 }
52
53 #[cfg(feature = "serde")]
54 pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {
55 use polars_utils::pl_serialize;
56
57 if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {
58 polars_bail!(InvalidOperation: "serialization expected python magic byte mark");
59 }
60 let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
61
62 let mut reader = Cursor::new(buf);
64 let (output_type, materialized, is_elementwise, returns_scalar): (
65 Option<DataTypeExpr>,
66 Option<Field>,
67 bool,
68 bool,
69 ) = pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
70
71 let buf = &buf[reader.position() as usize..];
72 let python_function = pl_serialize::python_object_deserialize(buf)?;
73
74 let mut udf = Self::new(python_function, output_type, is_elementwise, returns_scalar);
75 if let Some(materialized) = materialized {
76 udf.materialized_field = OnceLock::from(materialized);
77 }
78
79 Ok(Arc::new(udf))
80 }
81}
82
83impl DataFrameUdf for polars_utils::python_function::PythonFunction {
84 fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
85 let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
86 func(df, &self.0)
87 }
88}
89
90impl ColumnsUdf for PythonUdfExpression {
91 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {
92 let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
93 let field = self
94 .materialized_field
95 .get()
96 .expect("should have been materialized at this point");
97 let mut out = func(
98 s,
99 self.materialized_field.get().map(|f| f.dtype.clone()),
100 &self.python_function,
101 )?;
102
103 let must_cast = out.dtype().matches_schema_type(field.dtype()).map_err(|_| {
104 polars_err!(
105 SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
106 field.dtype(), out.dtype(),
107 )
108 })?;
109 if must_cast {
110 out = out.cast(field.dtype())?;
111 }
112
113 Ok(out)
114 }
115}
116
117impl AnonymousColumnsUdf for PythonUdfExpression {
118 fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
119 self as _
120 }
121 fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
122 Arc::new(Self {
123 python_function: Python::with_gil(|py| self.python_function.clone_ref(py)),
124 output_type: self.output_type.clone(),
125 materialized_field: OnceLock::new(),
126 is_elementwise: self.is_elementwise,
127 returns_scalar: self.returns_scalar,
128 }) as _
129 }
130
131 #[cfg(feature = "serde")]
132 fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
133 use polars_utils::pl_serialize;
134
135 buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
137
138 pl_serialize::serialize_into_writer::<_, _, true>(
140 &mut *buf,
141 &(
142 self.output_type.clone(),
143 self.materialized_field.get().cloned(),
144 self.is_elementwise,
145 self.returns_scalar,
146 ),
147 )?;
148
149 pl_serialize::python_object_serialize(&self.python_function, buf)?;
150 Ok(())
151 }
152
153 fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {
154 let field = match self.materialized_field.get() {
155 Some(f) => f.clone(),
156 None => {
157 let dtype = match self.output_type.as_ref() {
158 None => {
159 let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
160 let f = |s: &[Column]| func(s, None, &self.python_function);
161 try_infer_udf_output_dtype(&f as _, fields)?
162 },
163 Some(output_type) => output_type
164 .clone()
165 .into_datatype_with_self(input_schema, fields[0].dtype())?,
166 };
167
168 let name = fields[0].name();
170 let f = Field::new(name.clone(), dtype);
171 self.materialized_field.get_or_init(|| f.clone());
172 f
173 },
174 };
175 Ok(field)
176 }
177}
178
179impl Expr {
180 pub fn map_python(self, func: PythonUdfExpression) -> Expr {
181 Self::map_many_python(vec![self], func)
182 }
183
184 pub fn map_many_python(exprs: Vec<Expr>, func: PythonUdfExpression) -> Expr {
185 const NAME: &str = "python_udf";
186
187 let returns_scalar = func.returns_scalar;
188
189 let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
190 if func.is_elementwise {
191 flags.set_elementwise();
192 }
193 if returns_scalar {
194 flags |= FunctionFlags::RETURNS_SCALAR;
195 }
196
197 Expr::AnonymousFunction {
198 input: exprs,
199 function: new_column_udf(func),
200 options: FunctionOptions {
201 flags,
202 ..Default::default()
203 },
204 fmt_str: Box::new(PlSmallStr::from(NAME)),
205 }
206 }
207}