polars_plan/dsl/
expr_dyn_fn.rs

1use std::fmt::Formatter;
2use std::ops::Deref;
3use std::sync::Arc;
4
5#[cfg(feature = "python")]
6use polars_utils::pl_serialize::deserialize_map_bytes;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9
10use super::*;
11
/// A wrapper trait for any closure `Fn(&mut [Column]) -> PolarsResult<Option<Column>>`.
pub trait ColumnsUdf: Send + Sync {
    /// Returns `self` as a [`std::any::Any`] reference.
    ///
    /// # Panics
    ///
    /// The default implementation always panics via `unimplemented!`; implementors
    /// must override it to make this callable.
    fn as_any(&self) -> &dyn std::any::Any {
        unimplemented!("as_any not implemented for this 'opaque' function")
    }

    /// Invokes the function on the given mutable slice of columns.
    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>>;

    /// Writes a serialized form of this function into `_buf`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with a `ComputeError`.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
    }
}
24
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Serializes the wrapped UDF as a byte string.
    ///
    /// The bytes are produced by `ColumnsUdf::try_serialize`; if that fails, its error
    /// message is forwarded as a custom serializer error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;

        let mut payload = Vec::new();
        match self.0.try_serialize(&mut payload) {
            Ok(()) => serializer.serialize_bytes(&payload),
            Err(err) => Err(S::Error::custom(format!("{err}"))),
        }
    }
}
39
#[cfg(feature = "serde")]
impl<T: Serialize + Clone> Serialize for LazySerde<T> {
    /// Serializes whichever representation is currently held: the raw bytes of the
    /// `Bytes` variant, or the value of the `Deserialized` variant.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Self::Bytes(raw) => raw.serialize(serializer),
            Self::Deserialized(value) => value.serialize(serializer),
        }
    }
}
52
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
    /// Reads the input as `bytes::Bytes` and stores it in the `Bytes` variant;
    /// the payload is not decoded into `T` here.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        bytes::Bytes::deserialize(deserializer).map(Self::Bytes)
    }
}
63
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Rebuilds a UDF from its serialized bytes.
    ///
    /// With the `python` feature enabled, payloads starting with
    /// `PYTHON_SERDE_MAGIC_BYTE_MARK` are handed to `PythonUdfExpression::try_deserialize`.
    /// Any other payload, and every input when `python` is disabled, is rejected with
    /// a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;

        #[cfg(feature = "python")]
        {
            deserialize_map_bytes(deserializer, |buf| {
                // Only payloads tagged with the Python marker can be decoded.
                if !buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    return Err(D::Error::custom(
                        "deserialization not supported for this 'opaque' function",
                    ));
                }

                match crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(&buf) {
                    Ok(udf) => Ok(SpecialEq::new(udf)),
                    Err(err) => Err(D::Error::custom(format!("{err}"))),
                }
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            // The deserializer is intentionally left unread in this configuration.
            _ = deserializer;

            Err(D::Error::custom(
                "deserialization not supported for this 'opaque' function",
            ))
        }
    }
}
96
97impl<F> ColumnsUdf for F
98where
99    F: Fn(&mut [Column]) -> PolarsResult<Option<Column>> + Send + Sync,
100{
101    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
102        self(s)
103    }
104}
105
106impl Debug for dyn ColumnsUdf {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        write!(f, "ColumnUdf")
109    }
110}
111
/// A wrapper trait for any binary closure `Fn(Column, Column) -> PolarsResult<Column>`
pub trait ColumnBinaryUdf: Send + Sync {
    /// Invokes the function on the two columns, taking ownership of both.
    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column>;
}
116
117impl<F> ColumnBinaryUdf for F
118where
119    F: Fn(Column, Column) -> PolarsResult<Column> + Send + Sync,
120{
121    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column> {
122        self(a, b)
123    }
124}
125
126impl Debug for dyn ColumnBinaryUdf {
127    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128        write!(f, "ColumnBinaryUdf")
129    }
130}
131
132impl Default for SpecialEq<Arc<dyn ColumnBinaryUdf>> {
133    fn default() -> Self {
134        panic!("implementation error");
135    }
136}
137
138impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
139    fn default() -> Self {
140        let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None;
141        SpecialEq::new(Arc::new(output_field))
142    }
143}
144
/// A fallible renaming function that maps one name to another.
///
/// Implemented for every `Send + Sync` closure `Fn(&PlSmallStr) -> PolarsResult<PlSmallStr>`.
pub trait RenameAliasFn: Send + Sync {
    /// Computes the new name for `name`.
    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;

    /// Writes a serialized form of this function into `_buf`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with a `ComputeError`.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this renaming function")
    }
}
152
153impl<F> RenameAliasFn for F
154where
155    F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
156{
157    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
158        self(name)
159    }
160}
161
162impl Debug for dyn RenameAliasFn {
163    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164        write!(f, "RenameAliasFn")
165    }
166}
167
/// Wrapper type that has special equality properties
/// depending on the inner type specialization.
///
/// In this file, `SpecialEq<Arc<T>>` compares by pointer identity (`Arc::ptr_eq`),
/// while `SpecialEq<Series>` compares the wrapped values with `==`.
#[derive(Clone)]
pub struct SpecialEq<T>(T);
172
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Series> {
    /// Serializes exactly as the wrapped `Series` does; the wrapper adds nothing.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self(series) = self;
        series.serialize(serializer)
    }
}
182
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Series> {
    /// Deserializes a `Series` and wraps it.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        Series::deserialize(deserializer).map(SpecialEq)
    }
}
193
#[cfg(feature = "serde")]
impl<T: Serialize> Serialize for SpecialEq<Arc<T>> {
    /// Serializes the value behind the `Arc`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self(shared) = self;
        shared.serialize(serializer)
    }
}
203
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<Arc<T>> {
    /// Deserializes a `T`, places it in a fresh `Arc`, and wraps it.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        T::deserialize(deserializer).map(|inner| SpecialEq(Arc::new(inner)))
    }
}
214
215impl<T> SpecialEq<T> {
216    pub fn new(val: T) -> Self {
217        SpecialEq(val)
218    }
219}
220
221impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
222    fn eq(&self, other: &Self) -> bool {
223        Arc::ptr_eq(&self.0, &other.0)
224    }
225}
226
// Marks the pointer-identity comparison of `SpecialEq<Arc<T>>` as a full equivalence.
// NOTE(review): the `PartialEq` impl for `SpecialEq<Arc<T>>` accepts `T: ?Sized`, but this
// impl carries the implicit `T: Sized` bound, so it does not cover `Arc<dyn Trait>`.
// Confirm whether that asymmetry is intentional.
impl<T> Eq for SpecialEq<Arc<T>> {}
228
229impl PartialEq for SpecialEq<Series> {
230    fn eq(&self, other: &Self) -> bool {
231        self.0 == other.0
232    }
233}
234
235impl<T> Debug for SpecialEq<T> {
236    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237        write!(f, "no_eq")
238    }
239}
240
241impl<T> Deref for SpecialEq<T> {
242    type Target = T;
243
244    fn deref(&self) -> &Self::Target {
245        &self.0
246    }
247}
248
/// Computes an optional [`Field`] from an input schema, a context, and two input fields.
///
/// Implemented for every `Send + Sync` closure
/// `Fn(&Schema, Context, &Field, &Field) -> Option<Field>`.
/// NOTE(review): presumably this resolves the output field of a binary UDF, judging by
/// the trait name; confirm against the call sites.
pub trait BinaryUdfOutputField: Send + Sync {
    /// Returns the resulting field, or `None` when the implementation provides none.
    fn get_field(
        &self,
        input_schema: &Schema,
        cntxt: Context,
        field_a: &Field,
        field_b: &Field,
    ) -> Option<Field>;
}
258
259impl<F> BinaryUdfOutputField for F
260where
261    F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
262{
263    fn get_field(
264        &self,
265        input_schema: &Schema,
266        cntxt: Context,
267        field_a: &Field,
268        field_b: &Field,
269    ) -> Option<Field> {
270        self(input_schema, cntxt, field_a, field_b)
271    }
272}
273
/// Computes a [`Field`] from an input schema, a context, and a slice of input fields.
///
/// Implemented for every `Send + Sync` closure
/// `Fn(&Schema, Context, &[Field]) -> PolarsResult<Field>`.
pub trait FunctionOutputField: Send + Sync {
    /// Returns the resulting field, or an error if it cannot be determined.
    fn get_field(
        &self,
        input_schema: &Schema,
        cntxt: Context,
        fields: &[Field],
    ) -> PolarsResult<Field>;

    /// Writes a serialized form of this resolver into `_buf`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with a `ComputeError`.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this output field")
    }
}
286
/// A shared, type-erased [`FunctionOutputField`] wrapped in [`SpecialEq`].
pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;
288
289impl Default for GetOutput {
290    fn default() -> Self {
291        SpecialEq::new(Arc::new(
292            |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
293        ))
294    }
295}
296
297impl GetOutput {
298    pub fn same_type() -> Self {
299        Default::default()
300    }
301
302    pub fn first() -> Self {
303        SpecialEq::new(Arc::new(
304            |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
305        ))
306    }
307
308    pub fn from_type(dt: DataType) -> Self {
309        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
310            Ok(Field::new(flds[0].name().clone(), dt.clone()))
311        }))
312    }
313
314    pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
315        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
316            f(&flds[0])
317        }))
318    }
319
320    pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
321        f: F,
322    ) -> Self {
323        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
324            f(flds)
325        }))
326    }
327
328    pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
329        f: F,
330    ) -> Self {
331        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
332            let mut fld = flds[0].clone();
333            let new_type = f(fld.dtype())?;
334            fld.coerce(new_type);
335            Ok(fld)
336        }))
337    }
338
339    pub fn float_type() -> Self {
340        Self::map_dtype(|dt| {
341            Ok(match dt {
342                DataType::Float32 => DataType::Float32,
343                _ => DataType::Float64,
344            })
345        })
346    }
347
348    pub fn super_type() -> Self {
349        Self::map_dtypes(|dtypes| {
350            let mut st = dtypes[0].clone();
351            for dt in &dtypes[1..] {
352                st = try_get_supertype(&st, dt)?;
353            }
354            Ok(st)
355        })
356    }
357
358    pub fn map_dtypes<F>(f: F) -> Self
359    where
360        F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
361    {
362        SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
363            let mut fld = flds[0].clone();
364            let dtypes = flds.iter().map(|fld| fld.dtype()).collect::<Vec<_>>();
365            let new_type = f(&dtypes)?;
366            fld.coerce(new_type);
367            Ok(fld)
368        }))
369    }
370}
371
372impl<F> FunctionOutputField for F
373where
374    F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
375{
376    fn get_field(
377        &self,
378        input_schema: &Schema,
379        cntxt: Context,
380        fields: &[Field],
381    ) -> PolarsResult<Field> {
382        self(input_schema, cntxt, fields)
383    }
384}
385
#[cfg(feature = "serde")]
impl Serialize for GetOutput {
    /// Serializes the wrapped resolver as a byte string.
    ///
    /// The bytes are produced by `FunctionOutputField::try_serialize`; if that fails,
    /// its error message is forwarded as a custom serializer error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;

        let mut payload = Vec::new();
        match self.0.try_serialize(&mut payload) {
            Ok(()) => serializer.serialize_bytes(&payload),
            Err(err) => Err(S::Error::custom(format!("{err}"))),
        }
    }
}
400
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
    /// Rebuilds an output-field resolver from its serialized bytes.
    ///
    /// With the `python` feature enabled, payloads starting with
    /// `PYTHON_SERDE_MAGIC_BYTE_MARK` are handed to `PythonGetOutput::try_deserialize`.
    /// Any other payload, and every input when `python` is disabled, is rejected with
    /// a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;

        #[cfg(feature = "python")]
        {
            deserialize_map_bytes(deserializer, |buf| {
                // Only payloads tagged with the Python marker can be decoded.
                if !buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    return Err(D::Error::custom(
                        "deserialization not supported for this output field",
                    ));
                }

                match self::python_dsl::PythonGetOutput::try_deserialize(&buf) {
                    Ok(get_output) => Ok(SpecialEq::new(get_output)),
                    Err(err) => Err(D::Error::custom(format!("{err}"))),
                }
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            // The deserializer is intentionally left unread in this configuration.
            _ = deserializer;

            Err(D::Error::custom(
                "deserialization not supported for this output field",
            ))
        }
    }
}
432
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Serializes the wrapped renaming function as a byte string.
    ///
    /// The bytes are produced by `RenameAliasFn::try_serialize`; if that fails, its
    /// error message is forwarded as a custom serializer error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;

        let mut payload = Vec::new();
        match self.0.try_serialize(&mut payload) {
            Ok(()) => serializer.serialize_bytes(&payload),
            Err(err) => Err(S::Error::custom(format!("{err}"))),
        }
    }
}
447
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Renaming functions cannot be deserialized: this always returns a custom error
    /// and never reads from the deserializer.
    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let unsupported = <D::Error as serde::de::Error>::custom(
            "deserialization not supported for this renaming function",
        );
        Err(unsupported)
    }
}