polars_plan/dsl/expr/
expr_dyn_fn.rs

1use std::fmt::Formatter;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use polars_core::utils::try_get_supertype;
6
7use super::*;
8
9/// A wrapper trait for any closure `Fn(Vec<Series>) -> PolarsResult<Series>`
10pub trait ColumnsUdf: Send + Sync {
11    fn as_any(&self) -> &dyn std::any::Any {
12        unimplemented!("as_any not implemented for this 'opaque' function")
13    }
14
15    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>>;
16
17    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
18        polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
19    }
20}
21
22impl<F> ColumnsUdf for F
23where
24    F: Fn(&mut [Column]) -> PolarsResult<Option<Column>> + Send + Sync,
25{
26    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
27        self(s)
28    }
29}
30
31impl Debug for dyn ColumnsUdf {
32    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
33        write!(f, "ColumnUdf")
34    }
35}
36
37/// A wrapper trait for any binary closure `Fn(Column, Column) -> PolarsResult<Column>`
38pub trait ColumnBinaryUdf: Send + Sync {
39    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column>;
40}
41
42impl<F> ColumnBinaryUdf for F
43where
44    F: Fn(Column, Column) -> PolarsResult<Column> + Send + Sync,
45{
46    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column> {
47        self(a, b)
48    }
49}
50
51impl Debug for dyn ColumnBinaryUdf {
52    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53        write!(f, "ColumnBinaryUdf")
54    }
55}
56
57impl Default for SpecialEq<Arc<dyn ColumnBinaryUdf>> {
58    fn default() -> Self {
59        panic!("implementation error");
60    }
61}
62
63impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
64    fn default() -> Self {
65        let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None;
66        SpecialEq::new(Arc::new(output_field))
67    }
68}
69
70pub trait RenameAliasFn: Send + Sync {
71    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;
72
73    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
74        polars_bail!(ComputeError: "serialization not supported for this renaming function")
75    }
76}
77
78impl<F> RenameAliasFn for F
79where
80    F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
81{
82    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
83        self(name)
84    }
85}
86
87impl Debug for dyn RenameAliasFn {
88    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
89        write!(f, "RenameAliasFn")
90    }
91}
92
/// Wrapper type that has special equality properties
/// depending on the inner type specialization.
///
/// For `SpecialEq<Arc<T>>` two values are equal exactly when both `Arc`s
/// point to the same allocation (see [`Arc::ptr_eq`]); the pointee is never
/// compared.
#[derive(Clone)]
pub struct SpecialEq<T>(T);

impl<T> SpecialEq<T> {
    /// Wraps `val`.
    pub fn new(val: T) -> Self {
        SpecialEq(val)
    }
}

impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

// Pointer identity is reflexive, symmetric and transitive, so `Eq` holds for
// every `Arc<T>`. The `?Sized` bound mirrors the `PartialEq` impl so that
// unsized pointees such as `Arc<dyn Trait>` are `Eq` as well.
impl<T: ?Sized> Eq for SpecialEq<Arc<T>> {}
111
112impl PartialEq for SpecialEq<Series> {
113    fn eq(&self, other: &Self) -> bool {
114        self.0 == other.0
115    }
116}
117
118impl<T> Debug for SpecialEq<T> {
119    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120        write!(f, "no_eq")
121    }
122}
123
124impl<T> Deref for SpecialEq<T> {
125    type Target = T;
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
132pub trait BinaryUdfOutputField: Send + Sync {
133    fn get_field(
134        &self,
135        input_schema: &Schema,
136        cntxt: Context,
137        field_a: &Field,
138        field_b: &Field,
139    ) -> Option<Field>;
140}
141
142impl<F> BinaryUdfOutputField for F
143where
144    F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
145{
146    fn get_field(
147        &self,
148        input_schema: &Schema,
149        cntxt: Context,
150        field_a: &Field,
151        field_b: &Field,
152    ) -> Option<Field> {
153        self(input_schema, cntxt, field_a, field_b)
154    }
155}
156
157pub trait FunctionOutputField: Send + Sync {
158    fn get_field(
159        &self,
160        input_schema: &Schema,
161        cntxt: Context,
162        fields: &[Field],
163    ) -> PolarsResult<Field>;
164
165    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
166        polars_bail!(ComputeError: "serialization not supported for this output field")
167    }
168}
169
170pub type GetOutput = LazySerde<SpecialEq<Arc<dyn FunctionOutputField>>>;
171
172impl Default for GetOutput {
173    fn default() -> Self {
174        LazySerde::Deserialized(SpecialEq::new(Arc::new(
175            |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
176        )))
177    }
178}
179
180impl GetOutput {
181    pub fn same_type() -> Self {
182        Default::default()
183    }
184
185    pub fn first() -> Self {
186        LazySerde::Deserialized(SpecialEq::new(Arc::new(
187            |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
188        )))
189    }
190
191    pub fn from_type(dt: DataType) -> Self {
192        LazySerde::Deserialized(SpecialEq::new(Arc::new(
193            move |_: &Schema, _: Context, flds: &[Field]| {
194                Ok(Field::new(flds[0].name().clone(), dt.clone()))
195            },
196        )))
197    }
198
199    pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
200        LazySerde::Deserialized(SpecialEq::new(Arc::new(
201            move |_: &Schema, _: Context, flds: &[Field]| f(&flds[0]),
202        )))
203    }
204
205    pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
206        f: F,
207    ) -> Self {
208        LazySerde::Deserialized(SpecialEq::new(Arc::new(
209            move |_: &Schema, _: Context, flds: &[Field]| f(flds),
210        )))
211    }
212
213    pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
214        f: F,
215    ) -> Self {
216        LazySerde::Deserialized(SpecialEq::new(Arc::new(
217            move |_: &Schema, _: Context, flds: &[Field]| {
218                let mut fld = flds[0].clone();
219                let new_type = f(fld.dtype())?;
220                fld.coerce(new_type);
221                Ok(fld)
222            },
223        )))
224    }
225
226    pub fn float_type() -> Self {
227        Self::map_dtype(|dt| {
228            Ok(match dt {
229                DataType::Float32 => DataType::Float32,
230                _ => DataType::Float64,
231            })
232        })
233    }
234
235    pub fn super_type() -> Self {
236        Self::map_dtypes(|dtypes| {
237            let mut st = dtypes[0].clone();
238            for dt in &dtypes[1..] {
239                st = try_get_supertype(&st, dt)?;
240            }
241            Ok(st)
242        })
243    }
244
245    pub fn map_dtypes<F>(f: F) -> Self
246    where
247        F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
248    {
249        LazySerde::Deserialized(SpecialEq::new(Arc::new(
250            move |_: &Schema, _: Context, flds: &[Field]| {
251                let mut fld = flds[0].clone();
252                let dtypes = flds.iter().map(|fld| fld.dtype()).collect::<Vec<_>>();
253                let new_type = f(&dtypes)?;
254                fld.coerce(new_type);
255                Ok(fld)
256            },
257        )))
258    }
259}
260
261impl<F> FunctionOutputField for F
262where
263    F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
264{
265    fn get_field(
266        &self,
267        input_schema: &Schema,
268        cntxt: Context,
269        fields: &[Field],
270    ) -> PolarsResult<Field> {
271        self(input_schema, cntxt, fields)
272    }
273}
274
275pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
276pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
277    LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
278}
279
280impl OpaqueColumnUdf {
281    pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
282        match self {
283            Self::Deserialized(t) => Ok(t),
284            Self::Named {
285                name: _,
286                payload: _,
287                value: _,
288            } => {
289                panic!("should not be hit")
290            },
291            Self::Bytes(_b) => {
292                feature_gated!("serde";"python", {
293                    serde_expr::deserialize_column_udf(_b.as_ref()).map(SpecialEq::new)
294                })
295            },
296        }
297    }
298}
299
300impl GetOutput {
301    pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn FunctionOutputField>>> {
302        match self {
303            Self::Deserialized(t) => Ok(t),
304            Self::Named {
305                name: _,
306                payload: _,
307                value,
308            } => value.ok_or_else(|| polars_err!(ComputeError: "GetOutput Value not set")),
309            Self::Bytes(_b) => {
310                polars_bail!(ComputeError: "should not be hit")
311            },
312        }
313    }
314}