// polars_expr/dispatch/mod.rs

1use std::sync::Arc;
2
3use polars_core::error::PolarsResult;
4use polars_core::frame::DataFrame;
5use polars_core::prelude::{Column, GroupPositions};
6use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
8use polars_utils::IdxSize;
9
10use crate::prelude::{AggregationContext, PhysicalExpr};
11use crate::state::ExecutionState;
12
/// Wraps a column UDF as `SpecialEq::new(Arc::new(..))`.
///
/// * `wrap!(udf)` stores `udf` unchanged.
/// * `wrap!(udf, args...)` stores a closure over the column slice that calls
///   `udf(columns, args...)`. The argument expressions live inside the `move`
///   closure, so whatever they capture is moved in and they are evaluated on
///   every call.
///
/// `SpecialEq` and `Arc` are resolved at the call site.
#[macro_export]
macro_rules! wrap {
    ($udf:expr) => {
        SpecialEq::new(Arc::new($udf))
    };

    ($udf:expr, $($extra:expr),*) => {{
        let bound = move |columns: &mut [::polars_core::prelude::Column]| {
            $udf(columns, $($extra),*)
        };

        SpecialEq::new(Arc::new(bound))
    }};
}
27
/// Builds a UDF from a function that takes the whole column slice:
/// `$callee(columns)` or `$callee(columns, args...)`.
///
/// * all expression arguments are in the slice.
/// * the first element is the root expression.
/// * extra `args` sit inside the `move` closure and are evaluated on each call.
///
/// `SpecialEq` and `Arc` are resolved at the call site.
#[macro_export]
macro_rules! map_as_slice {
    ($callee:path) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| $callee(columns);

        SpecialEq::new(Arc::new(udf))
    }};

    ($callee:path, $($extra:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $callee(columns, $($extra),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
49
/// Builds a UDF from a function that takes its first input by value:
/// `$callee(column)` or `$callee(column, args...)`.
///
/// The first column is moved out of the slice with `std::mem::take`, which
/// leaves a default `Column` behind in slot 0. Indexing slot 0 panics on an
/// empty slice.
///
/// `SpecialEq` and `Arc` are resolved at the call site.
#[macro_export]
macro_rules! map_owned {
    ($callee:path) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $callee(std::mem::take(&mut columns[0]))
        };

        SpecialEq::new(Arc::new(udf))
    }};

    ($callee:path, $($extra:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            let owned = std::mem::take(&mut columns[0]);
            $callee(owned, $($extra),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
72
/// Builds a UDF from a function that borrows only the first input:
/// `$callee(&column)` or `$callee(&column, args...)`.
///
/// Only slot 0 of the column slice is read; indexing it panics on an empty
/// slice. Extra `args` sit inside the `move` closure and are evaluated on each
/// call.
///
/// `SpecialEq` and `Arc` are resolved at the call site.
#[macro_export]
macro_rules! map {
    ($callee:path) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| $callee(&columns[0]);

        SpecialEq::new(Arc::new(udf))
    }};

    ($callee:path, $($extra:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            let first = &columns[0];
            $callee(first, $($extra),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
94
95#[cfg(feature = "dtype-array")]
96mod array;
97mod binary;
98#[cfg(feature = "bitwise")]
99mod bitwise;
100mod boolean;
101#[cfg(feature = "business")]
102mod business;
103#[cfg(feature = "dtype-categorical")]
104mod cat;
105#[cfg(feature = "cum_agg")]
106mod cum;
107#[cfg(feature = "temporal")]
108mod datetime;
109mod groups_dispatch;
110mod horizontal;
111mod list;
112mod misc;
113mod pow;
114#[cfg(feature = "random")]
115mod random;
116#[cfg(feature = "range")]
117mod range;
118#[cfg(feature = "rolling_window")]
119mod rolling;
120#[cfg(feature = "rolling_window_by")]
121mod rolling_by;
122#[cfg(feature = "round_series")]
123mod round;
124mod shift_and_fill;
125#[cfg(feature = "strings")]
126mod strings;
127#[cfg(feature = "dtype-struct")]
128mod struct_;
129#[cfg(feature = "temporal")]
130mod temporal;
131#[cfg(feature = "trigonometry")]
132mod trigonometry;
133
134pub use groups_dispatch::drop_items;
135
136pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
137    use IRFunctionExpr as F;
138    match func {
139        // Namespaces
140        #[cfg(feature = "dtype-array")]
141        F::ArrayExpr(func) => array::function_expr_to_udf(func),
142        F::BinaryExpr(func) => binary::function_expr_to_udf(func),
143        #[cfg(feature = "dtype-categorical")]
144        F::Categorical(func) => cat::function_expr_to_udf(func),
145        F::ListExpr(func) => list::function_expr_to_udf(func),
146        #[cfg(feature = "strings")]
147        F::StringExpr(func) => strings::function_expr_to_udf(func),
148        #[cfg(feature = "dtype-struct")]
149        F::StructExpr(func) => struct_::function_expr_to_udf(func),
150        #[cfg(feature = "temporal")]
151        F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
152        #[cfg(feature = "bitwise")]
153        F::Bitwise(func) => bitwise::function_expr_to_udf(func),
154
155        // Other expressions
156        F::Boolean(func) => boolean::function_expr_to_udf(func),
157        #[cfg(feature = "business")]
158        F::Business(func) => business::function_expr_to_udf(func),
159        #[cfg(feature = "abs")]
160        F::Abs => map!(misc::abs),
161        F::Negate => map!(misc::negate),
162        F::NullCount => {
163            let f = |s: &mut [Column]| {
164                let s = &s[0];
165                Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
166            };
167            wrap!(f)
168        },
169        F::Pow(func) => match func {
170            IRPowFunction::Generic => wrap!(pow::pow),
171            IRPowFunction::Sqrt => map!(pow::sqrt),
172            IRPowFunction::Cbrt => map!(pow::cbrt),
173        },
174        #[cfg(feature = "row_hash")]
175        F::Hash(k0, k1, k2, k3) => {
176            map!(misc::row_hash, k0, k1, k2, k3)
177        },
178        #[cfg(feature = "arg_where")]
179        F::ArgWhere => {
180            wrap!(misc::arg_where)
181        },
182        #[cfg(feature = "index_of")]
183        F::IndexOf => {
184            map_as_slice!(misc::index_of)
185        },
186        #[cfg(feature = "search_sorted")]
187        F::SearchSorted { side, descending } => {
188            map_as_slice!(misc::search_sorted_impl, side, descending)
189        },
190        #[cfg(feature = "range")]
191        F::Range(func) => range::function_expr_to_udf(func),
192
193        #[cfg(feature = "trigonometry")]
194        F::Trigonometry(trig_function) => {
195            map!(trigonometry::apply_trigonometric_function, trig_function)
196        },
197        #[cfg(feature = "trigonometry")]
198        F::Atan2 => {
199            wrap!(trigonometry::apply_arctan2)
200        },
201
202        #[cfg(feature = "sign")]
203        F::Sign => {
204            map!(misc::sign)
205        },
206        F::FillNull => {
207            map_as_slice!(misc::fill_null)
208        },
209        #[cfg(feature = "rolling_window")]
210        F::RollingExpr { function, options } => {
211            use IRRollingFunction::*;
212            use polars_plan::plans::IRRollingFunction;
213            match function {
214                Min => map!(rolling::rolling_min, options.clone()),
215                Max => map!(rolling::rolling_max, options.clone()),
216                Mean => map!(rolling::rolling_mean, options.clone()),
217                Sum => map!(rolling::rolling_sum, options.clone()),
218                Quantile => map!(rolling::rolling_quantile, options.clone()),
219                Var => map!(rolling::rolling_var, options.clone()),
220                Std => map!(rolling::rolling_std, options.clone()),
221                Rank => map!(rolling::rolling_rank, options.clone()),
222                #[cfg(feature = "moment")]
223                Skew => map!(rolling::rolling_skew, options.clone()),
224                #[cfg(feature = "moment")]
225                Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
226                #[cfg(feature = "cov")]
227                CorrCov {
228                    corr_cov_options,
229                    is_corr,
230                } => {
231                    map_as_slice!(
232                        rolling::rolling_corr_cov,
233                        options.clone(),
234                        corr_cov_options,
235                        is_corr
236                    )
237                },
238                Map(f) => {
239                    map!(rolling::rolling_map, options.clone(), f.clone())
240                },
241            }
242        },
243        #[cfg(feature = "rolling_window_by")]
244        F::RollingExprBy {
245            function_by,
246            options,
247        } => {
248            use IRRollingFunctionBy::*;
249            use polars_plan::plans::IRRollingFunctionBy;
250            match function_by {
251                MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
252                MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
253                MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
254                SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
255                QuantileBy => {
256                    map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
257                },
258                VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
259                StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
260                RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
261            }
262        },
263        #[cfg(feature = "hist")]
264        F::Hist {
265            bin_count,
266            include_category,
267            include_breakpoint,
268        } => {
269            map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
270        },
271        F::Rechunk => map!(misc::rechunk),
272        F::Append { upcast } => map_as_slice!(misc::append, upcast),
273        F::ShiftAndFill => {
274            map_as_slice!(shift_and_fill::shift_and_fill)
275        },
276        F::DropNans => map_owned!(misc::drop_nans),
277        F::DropNulls => map!(misc::drop_nulls),
278        #[cfg(feature = "round_series")]
279        F::Clip { has_min, has_max } => {
280            map_as_slice!(misc::clip, has_min, has_max)
281        },
282        #[cfg(feature = "mode")]
283        F::Mode => map!(misc::mode),
284        #[cfg(feature = "moment")]
285        F::Skew(bias) => map!(misc::skew, bias),
286        #[cfg(feature = "moment")]
287        F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
288        F::ArgUnique => map!(misc::arg_unique),
289        F::ArgMin => map!(misc::arg_min),
290        F::ArgMax => map!(misc::arg_max),
291        F::ArgSort {
292            descending,
293            nulls_last,
294        } => map!(misc::arg_sort, descending, nulls_last),
295        F::Product => map!(misc::product),
296        F::Repeat => map_as_slice!(misc::repeat),
297        #[cfg(feature = "rank")]
298        F::Rank { options, seed } => map!(misc::rank, options, seed),
299        #[cfg(feature = "dtype-struct")]
300        F::AsStruct => {
301            map_as_slice!(misc::as_struct)
302        },
303        #[cfg(feature = "top_k")]
304        F::TopK { descending } => {
305            map_as_slice!(polars_ops::prelude::top_k, descending)
306        },
307        #[cfg(feature = "top_k")]
308        F::TopKBy { descending } => {
309            map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
310        },
311        F::Shift => map_as_slice!(shift_and_fill::shift),
312        #[cfg(feature = "cum_agg")]
313        F::CumCount { reverse } => map!(cum::cum_count, reverse),
314        #[cfg(feature = "cum_agg")]
315        F::CumSum { reverse } => map!(cum::cum_sum, reverse),
316        #[cfg(feature = "cum_agg")]
317        F::CumProd { reverse } => map!(cum::cum_prod, reverse),
318        #[cfg(feature = "cum_agg")]
319        F::CumMin { reverse } => map!(cum::cum_min, reverse),
320        #[cfg(feature = "cum_agg")]
321        F::CumMax { reverse } => map!(cum::cum_max, reverse),
322        #[cfg(feature = "dtype-struct")]
323        F::ValueCounts {
324            sort,
325            parallel,
326            name,
327            normalize,
328        } => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
329        #[cfg(feature = "unique_counts")]
330        F::UniqueCounts => map!(misc::unique_counts),
331        F::Reverse => map!(misc::reverse),
332        #[cfg(feature = "approx_unique")]
333        F::ApproxNUnique => map!(misc::approx_n_unique),
334        F::Coalesce => map_as_slice!(misc::coalesce),
335        #[cfg(feature = "diff")]
336        F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
337        #[cfg(feature = "pct_change")]
338        F::PctChange => map_as_slice!(misc::pct_change),
339        #[cfg(feature = "interpolate")]
340        F::Interpolate(method) => {
341            map!(misc::interpolate, method)
342        },
343        #[cfg(feature = "interpolate_by")]
344        F::InterpolateBy => {
345            map_as_slice!(misc::interpolate_by)
346        },
347        #[cfg(feature = "log")]
348        F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
349        #[cfg(feature = "log")]
350        F::Log => map_as_slice!(misc::log),
351        #[cfg(feature = "log")]
352        F::Log1p => map!(misc::log1p),
353        #[cfg(feature = "log")]
354        F::Exp => map!(misc::exp),
355        F::Unique(stable) => map!(misc::unique, stable),
356        #[cfg(feature = "round_series")]
357        F::Round { decimals, mode } => map!(round::round, decimals, mode),
358        #[cfg(feature = "round_series")]
359        F::RoundSF { digits } => map!(round::round_sig_figs, digits),
360        #[cfg(feature = "round_series")]
361        F::Floor => map!(round::floor),
362        #[cfg(feature = "round_series")]
363        F::Ceil => map!(round::ceil),
364        #[cfg(feature = "fused")]
365        F::Fused(op) => map_as_slice!(misc::fused, op),
366        F::ConcatExpr(rechunk) => map_as_slice!(misc::concat_expr, rechunk),
367        #[cfg(feature = "cov")]
368        F::Correlation { method } => map_as_slice!(misc::corr, method),
369        #[cfg(feature = "peaks")]
370        F::PeakMin => map!(misc::peak_min),
371        #[cfg(feature = "peaks")]
372        F::PeakMax => map!(misc::peak_max),
373        #[cfg(feature = "repeat_by")]
374        F::RepeatBy => map_as_slice!(misc::repeat_by),
375        #[cfg(feature = "dtype-array")]
376        F::Reshape(dims) => map!(misc::reshape, &dims),
377        #[cfg(feature = "cutqcut")]
378        F::Cut {
379            breaks,
380            labels,
381            left_closed,
382            include_breaks,
383        } => map!(
384            misc::cut,
385            breaks.clone(),
386            labels.clone(),
387            left_closed,
388            include_breaks
389        ),
390        #[cfg(feature = "cutqcut")]
391        F::QCut {
392            probs,
393            labels,
394            left_closed,
395            allow_duplicates,
396            include_breaks,
397        } => map!(
398            misc::qcut,
399            probs.clone(),
400            labels.clone(),
401            left_closed,
402            allow_duplicates,
403            include_breaks
404        ),
405        #[cfg(feature = "rle")]
406        F::RLE => map!(polars_ops::series::rle),
407        #[cfg(feature = "rle")]
408        F::RLEID => map!(polars_ops::series::rle_id),
409        F::ToPhysical => map!(misc::to_physical),
410        #[cfg(feature = "random")]
411        F::Random { method, seed } => {
412            use IRRandomMethod::*;
413            use polars_plan::plans::IRRandomMethod;
414            match method {
415                Shuffle => map!(random::shuffle, seed),
416                Sample {
417                    is_fraction,
418                    with_replacement,
419                    shuffle,
420                } => {
421                    if is_fraction {
422                        map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
423                    } else {
424                        map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
425                    }
426                },
427            }
428        },
429        F::SetSortedFlag(sorted) => map!(misc::set_sorted_flag, sorted),
430        #[cfg(feature = "ffi_plugin")]
431        F::FfiPlugin {
432            flags: _,
433            lib,
434            symbol,
435            kwargs,
436        } => unsafe {
437            map_as_slice!(
438                polars_plan::plans::plugin::call_plugin,
439                lib.as_ref(),
440                symbol.as_ref(),
441                kwargs.as_ref()
442            )
443        },
444
445        F::FoldHorizontal {
446            callback,
447            returns_scalar,
448            return_dtype,
449        } => map_as_slice!(
450            horizontal::fold,
451            &callback,
452            returns_scalar,
453            return_dtype.as_ref()
454        ),
455        F::ReduceHorizontal {
456            callback,
457            returns_scalar,
458            return_dtype,
459        } => map_as_slice!(
460            horizontal::reduce,
461            &callback,
462            returns_scalar,
463            return_dtype.as_ref()
464        ),
465        #[cfg(feature = "dtype-struct")]
466        F::CumReduceHorizontal {
467            callback,
468            returns_scalar,
469            return_dtype,
470        } => map_as_slice!(
471            horizontal::cum_reduce,
472            &callback,
473            returns_scalar,
474            return_dtype.as_ref()
475        ),
476        #[cfg(feature = "dtype-struct")]
477        F::CumFoldHorizontal {
478            callback,
479            returns_scalar,
480            return_dtype,
481            include_init,
482        } => map_as_slice!(
483            horizontal::cum_fold,
484            &callback,
485            returns_scalar,
486            return_dtype.as_ref(),
487            include_init
488        ),
489
490        F::MaxHorizontal => wrap!(misc::max_horizontal),
491        F::MinHorizontal => wrap!(misc::min_horizontal),
492        F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
493        F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
494        #[cfg(feature = "ewma")]
495        F::EwmMean { options } => map!(misc::ewm_mean, options),
496        #[cfg(feature = "ewma_by")]
497        F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
498        #[cfg(feature = "ewma")]
499        F::EwmStd { options } => map!(misc::ewm_std, options),
500        #[cfg(feature = "ewma")]
501        F::EwmVar { options } => map!(misc::ewm_var, options),
502        #[cfg(feature = "replace")]
503        F::Replace => {
504            map_as_slice!(misc::replace)
505        },
506        #[cfg(feature = "replace")]
507        F::ReplaceStrict { return_dtype } => {
508            map_as_slice!(misc::replace_strict, return_dtype.clone())
509        },
510
511        F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
512        F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
513        #[cfg(feature = "reinterpret")]
514        F::Reinterpret(signed) => map!(misc::reinterpret, signed),
515        F::ExtendConstant => map_as_slice!(misc::extend_constant),
516
517        F::RowEncode(dts, variants) => {
518            map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
519        },
520        #[cfg(feature = "dtype-struct")]
521        F::RowDecode(fs, variants) => {
522            map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
523        },
524    }
525}
526
527pub trait GroupsUdf: Send + Sync + 'static {
528    fn evaluate_on_groups<'a>(
529        &self,
530        inputs: &[Arc<dyn PhysicalExpr>],
531        df: &DataFrame,
532        groups: &'a GroupPositions,
533        state: &ExecutionState,
534    ) -> PolarsResult<AggregationContext<'a>>;
535}
536
537pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
538    macro_rules! wrap_groups {
539        ($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
540            struct Wrap($($ty),*);
541            impl GroupsUdf for Wrap {
542                fn evaluate_on_groups<'a>(
543                    &self,
544                    inputs: &[Arc<dyn PhysicalExpr>],
545                    df: &DataFrame,
546                    groups: &'a GroupPositions,
547                    state: &ExecutionState,
548                ) -> PolarsResult<AggregationContext<'a>> {
549                    let Wrap($($n),*) = self;
550                    $f(inputs, df, groups, state$(, *$n)*)
551                }
552            }
553
554            SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
555        }};
556    }
557    use IRFunctionExpr as F;
558    Some(match func {
559        F::NullCount => wrap_groups!(groups_dispatch::null_count),
560        F::Reverse => wrap_groups!(groups_dispatch::reverse),
561        F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
562            let ignore_nulls = *ignore_nulls;
563            wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
564        },
565        F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
566            let ignore_nulls = *ignore_nulls;
567            wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
568        },
569        #[cfg(feature = "bitwise")]
570        F::Bitwise(f) => {
571            use polars_plan::plans::IRBitwiseFunction as B;
572            match f {
573                B::And => wrap_groups!(groups_dispatch::bitwise_and),
574                B::Or => wrap_groups!(groups_dispatch::bitwise_or),
575                B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
576                _ => return None,
577            }
578        },
579        F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
580        F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
581
582        #[cfg(feature = "moment")]
583        F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
584        #[cfg(feature = "moment")]
585        F::Kurtosis(fisher, bias) => {
586            wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
587        },
588
589        F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
590        F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
591            wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
592        },
593        F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
594            wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
595        },
596
597        _ => return None,
598    })
599}