1use std::fmt::Formatter;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use polars_core::utils::try_get_supertype;
6
7use super::*;
8
9pub trait ColumnsUdf: Send + Sync {
11 fn as_any(&self) -> &dyn std::any::Any {
12 unimplemented!("as_any not implemented for this 'opaque' function")
13 }
14
15 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>>;
16
17 fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
18 polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
19 }
20}
21
22impl<F> ColumnsUdf for F
23where
24 F: Fn(&mut [Column]) -> PolarsResult<Option<Column>> + Send + Sync,
25{
26 fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Option<Column>> {
27 self(s)
28 }
29}
30
31impl Debug for dyn ColumnsUdf {
32 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
33 write!(f, "ColumnUdf")
34 }
35}
36
37pub trait ColumnBinaryUdf: Send + Sync {
39 fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column>;
40}
41
42impl<F> ColumnBinaryUdf for F
43where
44 F: Fn(Column, Column) -> PolarsResult<Column> + Send + Sync,
45{
46 fn call_udf(&self, a: Column, b: Column) -> PolarsResult<Column> {
47 self(a, b)
48 }
49}
50
51impl Debug for dyn ColumnBinaryUdf {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 write!(f, "ColumnBinaryUdf")
54 }
55}
56
57impl Default for SpecialEq<Arc<dyn ColumnBinaryUdf>> {
58 fn default() -> Self {
59 panic!("implementation error");
60 }
61}
62
63impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
64 fn default() -> Self {
65 let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None;
66 SpecialEq::new(Arc::new(output_field))
67 }
68}
69
70pub trait RenameAliasFn: Send + Sync {
71 fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;
72
73 fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
74 polars_bail!(ComputeError: "serialization not supported for this renaming function")
75 }
76}
77
78impl<F> RenameAliasFn for F
79where
80 F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
81{
82 fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
83 self(name)
84 }
85}
86
87impl Debug for dyn RenameAliasFn {
88 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
89 write!(f, "RenameAliasFn")
90 }
91}
92
/// Newtype wrapper whose equality and `Debug` behaviour are not derived but implemented
/// separately for specific wrapped types.
#[derive(Clone)]
pub struct SpecialEq<T>(T);

impl<T> SpecialEq<T> {
    /// Wraps `val` without modifying it.
    pub fn new(val: T) -> Self {
        Self(val)
    }
}
103
104impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
105 fn eq(&self, other: &Self) -> bool {
106 Arc::ptr_eq(&self.0, &other.0)
107 }
108}
109
110impl<T> Eq for SpecialEq<Arc<T>> {}
111
112impl PartialEq for SpecialEq<Series> {
113 fn eq(&self, other: &Self) -> bool {
114 self.0 == other.0
115 }
116}
117
118impl<T> Debug for SpecialEq<T> {
119 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120 write!(f, "no_eq")
121 }
122}
123
124impl<T> Deref for SpecialEq<T> {
125 type Target = T;
126
127 fn deref(&self) -> &Self::Target {
128 &self.0
129 }
130}
131
132pub trait BinaryUdfOutputField: Send + Sync {
133 fn get_field(
134 &self,
135 input_schema: &Schema,
136 cntxt: Context,
137 field_a: &Field,
138 field_b: &Field,
139 ) -> Option<Field>;
140}
141
142impl<F> BinaryUdfOutputField for F
143where
144 F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
145{
146 fn get_field(
147 &self,
148 input_schema: &Schema,
149 cntxt: Context,
150 field_a: &Field,
151 field_b: &Field,
152 ) -> Option<Field> {
153 self(input_schema, cntxt, field_a, field_b)
154 }
155}
156
157pub trait FunctionOutputField: Send + Sync {
158 fn get_field(
159 &self,
160 input_schema: &Schema,
161 cntxt: Context,
162 fields: &[Field],
163 ) -> PolarsResult<Field>;
164
165 fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
166 polars_bail!(ComputeError: "serialization not supported for this output field")
167 }
168}
169
170pub type GetOutput = LazySerde<SpecialEq<Arc<dyn FunctionOutputField>>>;
171
172impl Default for GetOutput {
173 fn default() -> Self {
174 LazySerde::Deserialized(SpecialEq::new(Arc::new(
175 |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
176 )))
177 }
178}
179
180impl GetOutput {
181 pub fn same_type() -> Self {
182 Default::default()
183 }
184
185 pub fn first() -> Self {
186 LazySerde::Deserialized(SpecialEq::new(Arc::new(
187 |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
188 )))
189 }
190
191 pub fn from_type(dt: DataType) -> Self {
192 LazySerde::Deserialized(SpecialEq::new(Arc::new(
193 move |_: &Schema, _: Context, flds: &[Field]| {
194 Ok(Field::new(flds[0].name().clone(), dt.clone()))
195 },
196 )))
197 }
198
199 pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
200 LazySerde::Deserialized(SpecialEq::new(Arc::new(
201 move |_: &Schema, _: Context, flds: &[Field]| f(&flds[0]),
202 )))
203 }
204
205 pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
206 f: F,
207 ) -> Self {
208 LazySerde::Deserialized(SpecialEq::new(Arc::new(
209 move |_: &Schema, _: Context, flds: &[Field]| f(flds),
210 )))
211 }
212
213 pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
214 f: F,
215 ) -> Self {
216 LazySerde::Deserialized(SpecialEq::new(Arc::new(
217 move |_: &Schema, _: Context, flds: &[Field]| {
218 let mut fld = flds[0].clone();
219 let new_type = f(fld.dtype())?;
220 fld.coerce(new_type);
221 Ok(fld)
222 },
223 )))
224 }
225
226 pub fn float_type() -> Self {
227 Self::map_dtype(|dt| {
228 Ok(match dt {
229 DataType::Float32 => DataType::Float32,
230 _ => DataType::Float64,
231 })
232 })
233 }
234
235 pub fn super_type() -> Self {
236 Self::map_dtypes(|dtypes| {
237 let mut st = dtypes[0].clone();
238 for dt in &dtypes[1..] {
239 st = try_get_supertype(&st, dt)?;
240 }
241 Ok(st)
242 })
243 }
244
245 pub fn map_dtypes<F>(f: F) -> Self
246 where
247 F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
248 {
249 LazySerde::Deserialized(SpecialEq::new(Arc::new(
250 move |_: &Schema, _: Context, flds: &[Field]| {
251 let mut fld = flds[0].clone();
252 let dtypes = flds.iter().map(|fld| fld.dtype()).collect::<Vec<_>>();
253 let new_type = f(&dtypes)?;
254 fld.coerce(new_type);
255 Ok(fld)
256 },
257 )))
258 }
259}
260
261impl<F> FunctionOutputField for F
262where
263 F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
264{
265 fn get_field(
266 &self,
267 input_schema: &Schema,
268 cntxt: Context,
269 fields: &[Field],
270 ) -> PolarsResult<Field> {
271 self(input_schema, cntxt, fields)
272 }
273}
274
275pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
276pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
277 LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
278}
279
280impl OpaqueColumnUdf {
281 pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
282 match self {
283 Self::Deserialized(t) => Ok(t),
284 Self::Named {
285 name: _,
286 payload: _,
287 value: _,
288 } => {
289 panic!("should not be hit")
290 },
291 Self::Bytes(_b) => {
292 feature_gated!("serde";"python", {
293 serde_expr::deserialize_column_udf(_b.as_ref()).map(SpecialEq::new)
294 })
295 },
296 }
297 }
298}
299
300impl GetOutput {
301 pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn FunctionOutputField>>> {
302 match self {
303 Self::Deserialized(t) => Ok(t),
304 Self::Named {
305 name: _,
306 payload: _,
307 value,
308 } => value.ok_or_else(|| polars_err!(ComputeError: "GetOutput Value not set")),
309 Self::Bytes(_b) => {
310 polars_bail!(ComputeError: "should not be hit")
311 },
312 }
313 }
314}