1use std::fmt::Formatter;
2use std::ops::Deref;
3use std::sync::Arc;
4
5#[cfg(feature = "python")]
6use polars_utils::pl_serialize::deserialize_map_bytes;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9
10use super::*;
11
/// A user-defined function that operates on a slice of `Column`s.
///
/// Implementors must be `Send + Sync`. Only `call_udf` is required; the other
/// methods have defaults that report the operation as unsupported for an
/// "opaque" function.
pub trait ColumnsUdf: Send + Sync {
    /// Returns a `&dyn Any` view of the function.
    ///
    /// NOTE(review): presumably meant for downcasting to the concrete type;
    /// confirm against callers.
    ///
    /// # Panics
    ///
    /// The default implementation always panics via `unimplemented!`.
    fn as_any(&self) -> &dyn std::any::Any {
        unimplemented!("as_any not implemented for this 'opaque' function")
    }

    /// Runs the function on the columns in `s`.
    ///
    /// The slice is handed over mutably, so an implementation may modify its
    /// inputs. The result is optional: `Ok(None)` is a valid return value.
    fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>>;

    /// Writes a serialized form of the function into `_buf`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with a `ComputeError` stating
    /// that serialization is not supported.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
    }
}
24
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Serializes the wrapped UDF as a byte string.
    ///
    /// The bytes are produced by `ColumnsUdf::try_serialize`. When that call
    /// fails, its error message is forwarded as a custom serializer error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;

        let mut bytes = Vec::new();
        match self.0.try_serialize(&mut bytes) {
            Ok(()) => serializer.serialize_bytes(&bytes),
            Err(err) => Err(S::Error::custom(err.to_string())),
        }
    }
}
39
#[cfg(feature = "serde")]
impl<T: Serialize + Clone> Serialize for LazySerde<T> {
    /// Serializes whichever representation is currently held: the raw bytes
    /// or the already-deserialized value.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            LazySerde::Bytes(raw) => raw.serialize(serializer),
            LazySerde::Deserialized(value) => value.serialize(serializer),
        }
    }
}
52
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
    /// Reads the payload as raw bytes and stores it in the `Bytes` variant.
    ///
    /// The bytes are not decoded into a `T` here.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        bytes::Bytes::deserialize(deserializer).map(Self::Bytes)
    }
}
63
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
    /// Rebuilds a serialized UDF from its byte representation.
    ///
    /// With the `python` feature enabled, a payload that starts with
    /// `PYTHON_SERDE_MAGIC_BYTE_MARK` is handed to
    /// `PythonUdfExpression::try_deserialize`. Every other payload, and every
    /// payload when the `python` feature is disabled, is rejected with a
    /// custom deserializer error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;

        const UNSUPPORTED: &str = "deserialization not supported for this 'opaque' function";

        #[cfg(feature = "python")]
        {
            deserialize_map_bytes(deserializer, |buf| {
                // Only payloads tagged as Python UDFs can be reconstructed.
                if !buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    return Err(D::Error::custom(UNSUPPORTED));
                }
                match crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(&buf) {
                    Ok(udf) => Ok(SpecialEq::new(udf)),
                    Err(err) => Err(D::Error::custom(format!("{err}"))),
                }
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            let _ = deserializer;

            Err(D::Error::custom(UNSUPPORTED))
        }
    }
}
96
97impl<F> ColumnsUdf for F
98where
99 F: Fn(&mut [Column]) -> PolarsResult<Option<Column>> + Send + Sync,
100{
101 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
102 self(s)
103 }
104}
105
106impl Debug for dyn ColumnsUdf {
107 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108 write!(f, "ColumnUdf")
109 }
110}
111
/// A user-defined function that combines two `Column`s into one.
///
/// Implementors must be `Send + Sync`.
pub trait ColumnBinaryUdf: Send + Sync {
    /// Applies the function to `a` and `b`, taking both columns by value.
    fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column>;
}
116
117impl<F> ColumnBinaryUdf for F
118where
119 F: Fn(Column, Column) -> PolarsResult<Column> + Send + Sync,
120{
121 fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column> {
122 self(a, b)
123 }
124}
125
126impl Debug for dyn ColumnBinaryUdf {
127 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128 write!(f, "ColumnBinaryUdf")
129 }
130}
131
132impl Default for SpecialEq<Arc<dyn ColumnBinaryUdf>> {
133 fn default() -> Self {
134 panic!("implementation error");
135 }
136}
137
138impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
139 fn default() -> Self {
140 let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None;
141 SpecialEq::new(Arc::new(output_field))
142 }
143}
144
/// A fallible function that maps a name to another name.
///
/// Implementors must be `Send + Sync`. Only `call` is required.
pub trait RenameAliasFn: Send + Sync {
    /// Computes the new name for `name`.
    fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;

    /// Writes a serialized form of the function into `_buf`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with a `ComputeError` stating
    /// that serialization is not supported.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this renaming function")
    }
}
152
153impl<F> RenameAliasFn for F
154where
155 F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
156{
157 fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
158 self(name)
159 }
160}
161
162impl Debug for dyn RenameAliasFn {
163 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164 write!(f, "RenameAliasFn")
165 }
166}
167
/// Newtype wrapper around a value of type `T`.
///
/// Impl blocks in this file give the wrapper its own comparison and
/// formatting behaviour: `Arc` payloads compare by pointer identity, the
/// `Debug` output is the fixed text `no_eq`, and the wrapper dereferences to
/// the wrapped value.
#[derive(Clone)]
pub struct SpecialEq<T>(T);
172
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Series> {
    /// Serializes exactly like the wrapped `Series`; the wrapper adds nothing.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self(series) = self;
        series.serialize(serializer)
    }
}
182
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Series> {
    /// Deserializes a `Series` and wraps it.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        Series::deserialize(deserializer).map(SpecialEq)
    }
}
193
#[cfg(feature = "serde")]
impl<T: Serialize> Serialize for SpecialEq<Arc<T>> {
    /// Serializes the shared payload; the wrapper itself adds nothing.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self(shared) = self;
        shared.serialize(serializer)
    }
}
203
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<Arc<T>> {
    /// Deserializes a `T`, moves it into a fresh `Arc`, and wraps that.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        T::deserialize(deserializer).map(|value| SpecialEq(Arc::new(value)))
    }
}
214
215impl<T> SpecialEq<T> {
216 pub fn new(val: T) -> Self {
217 SpecialEq(val)
218 }
219}
220
221impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
222 fn eq(&self, other: &Self) -> bool {
223 Arc::ptr_eq(&self.0, &other.0)
224 }
225}
226
227impl<T> Eq for SpecialEq<Arc<T>> {}
228
229impl PartialEq for SpecialEq<Series> {
230 fn eq(&self, other: &Self) -> bool {
231 self.0 == other.0
232 }
233}
234
235impl<T> Debug for SpecialEq<T> {
236 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237 write!(f, "no_eq")
238 }
239}
240
241impl<T> Deref for SpecialEq<T> {
242 type Target = T;
243
244 fn deref(&self) -> &Self::Target {
245 &self.0
246 }
247}
248
/// Resolves the output `Field` of a binary function from its two input fields.
///
/// Implementors must be `Send + Sync`.
pub trait BinaryUdfOutputField: Send + Sync {
    /// Returns the output field for inputs `field_a` and `field_b`, or `None`.
    ///
    /// `input_schema` and `cntxt` are passed along for the implementation to
    /// use. NOTE(review): what `None` means to callers is not visible in this
    /// file; confirm at the call sites.
    fn get_field(
        &self,
        input_schema: &Schema,
        cntxt: Context,
        field_a: &Field,
        field_b: &Field,
    ) -> Option<Field>;
}
258
259impl<F> BinaryUdfOutputField for F
260where
261 F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
262{
263 fn get_field(
264 &self,
265 input_schema: &Schema,
266 cntxt: Context,
267 field_a: &Field,
268 field_b: &Field,
269 ) -> Option<Field> {
270 self(input_schema, cntxt, field_a, field_b)
271 }
272}
273
/// Resolves the output `Field` of a function from its input fields.
///
/// Implementors must be `Send + Sync`. Only `get_field` is required.
pub trait FunctionOutputField: Send + Sync {
    /// Computes the output field from the input `fields`.
    ///
    /// `input_schema` and `cntxt` are passed along for the implementation to
    /// use.
    fn get_field(
        &self,
        input_schema: &Schema,
        cntxt: Context,
        fields: &[Field],
    ) -> PolarsResult<Field>;

    /// Writes a serialized form of the resolver into `_buf`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with a `ComputeError` stating
    /// that serialization is not supported.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
        polars_bail!(ComputeError: "serialization not supported for this output field")
    }
}
286
/// A shared, type-erased `FunctionOutputField` wrapped in `SpecialEq`.
pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;
288
289impl Default for GetOutput {
290 fn default() -> Self {
291 SpecialEq::new(Arc::new(
292 |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
293 ))
294 }
295}
296
297impl GetOutput {
298 pub fn same_type() -> Self {
299 Default::default()
300 }
301
302 pub fn first() -> Self {
303 SpecialEq::new(Arc::new(
304 |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
305 ))
306 }
307
308 pub fn from_type(dt: DataType) -> Self {
309 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
310 Ok(Field::new(flds[0].name().clone(), dt.clone()))
311 }))
312 }
313
314 pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
315 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
316 f(&flds[0])
317 }))
318 }
319
320 pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
321 f: F,
322 ) -> Self {
323 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
324 f(flds)
325 }))
326 }
327
328 pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
329 f: F,
330 ) -> Self {
331 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
332 let mut fld = flds[0].clone();
333 let new_type = f(fld.dtype())?;
334 fld.coerce(new_type);
335 Ok(fld)
336 }))
337 }
338
339 pub fn float_type() -> Self {
340 Self::map_dtype(|dt| {
341 Ok(match dt {
342 DataType::Float32 => DataType::Float32,
343 _ => DataType::Float64,
344 })
345 })
346 }
347
348 pub fn super_type() -> Self {
349 Self::map_dtypes(|dtypes| {
350 let mut st = dtypes[0].clone();
351 for dt in &dtypes[1..] {
352 st = try_get_supertype(&st, dt)?;
353 }
354 Ok(st)
355 })
356 }
357
358 pub fn map_dtypes<F>(f: F) -> Self
359 where
360 F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
361 {
362 SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
363 let mut fld = flds[0].clone();
364 let dtypes = flds.iter().map(|fld| fld.dtype()).collect::<Vec<_>>();
365 let new_type = f(&dtypes)?;
366 fld.coerce(new_type);
367 Ok(fld)
368 }))
369 }
370}
371
372impl<F> FunctionOutputField for F
373where
374 F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
375{
376 fn get_field(
377 &self,
378 input_schema: &Schema,
379 cntxt: Context,
380 fields: &[Field],
381 ) -> PolarsResult<Field> {
382 self(input_schema, cntxt, fields)
383 }
384}
385
#[cfg(feature = "serde")]
impl Serialize for GetOutput {
    /// Serializes the wrapped resolver as a byte string.
    ///
    /// The bytes are produced by `FunctionOutputField::try_serialize`. When
    /// that call fails, its error message is forwarded as a custom serializer
    /// error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;

        let mut bytes = Vec::new();
        match self.0.try_serialize(&mut bytes) {
            Ok(()) => serializer.serialize_bytes(&bytes),
            Err(err) => Err(S::Error::custom(err.to_string())),
        }
    }
}
400
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
    /// Rebuilds a serialized output-field resolver from its byte representation.
    ///
    /// With the `python` feature enabled, a payload that starts with
    /// `PYTHON_SERDE_MAGIC_BYTE_MARK` is handed to
    /// `PythonGetOutput::try_deserialize`. Every other payload, and every
    /// payload when the `python` feature is disabled, is rejected with a
    /// custom deserializer error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;

        const UNSUPPORTED: &str = "deserialization not supported for this output field";

        #[cfg(feature = "python")]
        {
            deserialize_map_bytes(deserializer, |buf| {
                // Only payloads tagged as Python objects can be reconstructed.
                if !buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                    return Err(D::Error::custom(UNSUPPORTED));
                }
                match self::python_dsl::PythonGetOutput::try_deserialize(&buf) {
                    Ok(get_output) => Ok(SpecialEq::new(get_output)),
                    Err(err) => Err(D::Error::custom(format!("{err}"))),
                }
            })?
        }
        #[cfg(not(feature = "python"))]
        {
            let _ = deserializer;

            Err(D::Error::custom(UNSUPPORTED))
        }
    }
}
432
#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Serializes the wrapped renaming function as a byte string.
    ///
    /// The bytes are produced by `RenameAliasFn::try_serialize`. When that
    /// call fails, its error message is forwarded as a custom serializer error.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;

        let mut bytes = Vec::new();
        match self.0.try_serialize(&mut bytes) {
            Ok(()) => serializer.serialize_bytes(&bytes),
            Err(err) => Err(S::Error::custom(err.to_string())),
        }
    }
}
447
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
    /// Always fails: the deserializer is never read and a custom error is
    /// returned for every input.
    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let reason = "deserialization not supported for this renaming function";
        Err(<D::Error as serde::de::Error>::custom(reason))
    }
}