// polars_expr/dispatch/mod.rs

1use std::sync::Arc;
2
3use polars_core::error::PolarsResult;
4use polars_core::frame::DataFrame;
5use polars_core::prelude::{Column, GroupPositions};
6use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
8use polars_utils::IdxSize;
9
10use crate::prelude::{AggregationContext, PhysicalExpr};
11use crate::state::ExecutionState;
12
/// Wrap a callable into a `SpecialEq<Arc<_>>` column UDF.
///
/// * `wrap!(f)` puts `f` into the `Arc` unchanged, so `f` must already be a
///   complete UDF (in this module: a closure over `&mut [Column]`, or a
///   function such as `pow::pow`).
/// * `wrap!(f, args...)` builds a closure over the input column slice that
///   calls `f(columns, args...)`.
///
/// Extra `args` are evaluated inside the generated closure, i.e. on every
/// call, not once when the UDF is built.
///
/// `Arc` is referenced through its full `::std` path (like `Column` below), so
/// call sites do not need to import it. `SpecialEq` is left unqualified because
/// callers may reach it through different re-export paths; it must be in scope
/// at the call site.
#[macro_export]
macro_rules! wrap {
    ($e:expr) => {
        SpecialEq::new(::std::sync::Arc::new($e))
    };

    ($e:expr, $($args:expr),*) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $e(s, $($args),*)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
27
/// Build a column UDF that hands the whole input slice to `$func`.
///
/// * `map_as_slice!(func)` calls `func(columns)`.
/// * `map_as_slice!(func, args...)` calls `func(columns, args...)`.
///
/// All expression arguments are in the slice; the first element is the root
/// expression. Extra `args` are evaluated inside the generated closure on
/// every call, which is why non-`Copy` captures are passed as `x.clone()` at
/// the call sites in this module.
///
/// `Arc` is referenced through its full `::std` path (like `Column`), so call
/// sites do not need to import it; `SpecialEq` must be in scope at the call
/// site.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $func(s)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $func(s, $($args),*)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
49
/// Build a column UDF that passes the first input column to `$func` by value.
///
/// * `map_owned!(func)` calls `func(column)`.
/// * `map_owned!(func, args...)` calls `func(column, args...)`.
///
/// The column is moved out of slot `0` with [`std::mem::take`], which leaves
/// `Column::default()` behind in the caller's slice. Only the first input is
/// forwarded. Extra `args` are evaluated inside the closure on every call.
///
/// # Panics
/// The generated closure panics if it is invoked with an empty slice.
///
/// `Arc` is referenced through its full `::std` path (like `Column`), so call
/// sites do not need to import it; `SpecialEq` must be in scope at the call
/// site.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = std::mem::take(&mut c[0]);
            $func(c)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = std::mem::take(&mut c[0]);
            $func(c, $($args),*)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
72
/// Build a column UDF that passes a shared reference to the first input column.
///
/// * `map!(func)` calls `func(&column)`.
/// * `map!(func, args...)` calls `func(&column, args...)`.
///
/// Only the first input is forwarded; the slice itself is not modified. Extra
/// `args` are evaluated inside the generated closure on every call.
///
/// # Panics
/// The generated closure panics if it is invoked with an empty slice.
///
/// `Arc` is referenced through its full `::std` path (like `Column`), so call
/// sites do not need to import it; `SpecialEq` must be in scope at the call
/// site.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = &c[0];
            $func(c)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = &c[0];
            $func(c, $($args),*)
        };

        SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
94
95#[cfg(feature = "dtype-array")]
96mod array;
97mod binary;
98#[cfg(feature = "bitwise")]
99mod bitwise;
100mod boolean;
101#[cfg(feature = "business")]
102mod business;
103#[cfg(feature = "dtype-categorical")]
104mod cat;
105#[cfg(feature = "cum_agg")]
106mod cum;
107#[cfg(feature = "temporal")]
108mod datetime;
109#[cfg(feature = "dtype-extension")]
110mod extension;
111mod groups_dispatch;
112mod horizontal;
113mod list;
114mod misc;
115mod pow;
116#[cfg(feature = "random")]
117mod random;
118#[cfg(feature = "range")]
119mod range;
120#[cfg(feature = "rolling_window")]
121mod rolling;
122#[cfg(feature = "rolling_window_by")]
123mod rolling_by;
124#[cfg(feature = "round_series")]
125mod round;
126mod shift_and_fill;
127#[cfg(feature = "strings")]
128mod strings;
129#[cfg(feature = "dtype-struct")]
130pub(crate) mod struct_;
131#[cfg(feature = "temporal")]
132mod temporal;
133#[cfg(feature = "trigonometry")]
134mod trigonometry;
135
136pub use groups_dispatch::drop_items;
137
138pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
139    use IRFunctionExpr as F;
140    match func {
141        // Namespaces
142        #[cfg(feature = "dtype-array")]
143        F::ArrayExpr(func) => array::function_expr_to_udf(func),
144        F::BinaryExpr(func) => binary::function_expr_to_udf(func),
145        #[cfg(feature = "dtype-categorical")]
146        F::Categorical(func) => cat::function_expr_to_udf(func),
147        #[cfg(feature = "dtype-extension")]
148        F::Extension(func) => extension::function_expr_to_udf(func),
149        F::ListExpr(func) => list::function_expr_to_udf(func),
150        #[cfg(feature = "strings")]
151        F::StringExpr(func) => strings::function_expr_to_udf(func),
152        #[cfg(feature = "dtype-struct")]
153        F::StructExpr(func) => struct_::function_expr_to_udf(func),
154        #[cfg(feature = "temporal")]
155        F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
156        #[cfg(feature = "bitwise")]
157        F::Bitwise(func) => bitwise::function_expr_to_udf(func),
158
159        // Other expressions
160        F::Boolean(func) => boolean::function_expr_to_udf(func),
161        #[cfg(feature = "business")]
162        F::Business(func) => business::function_expr_to_udf(func),
163        #[cfg(feature = "abs")]
164        F::Abs => map!(misc::abs),
165        F::Negate => map!(misc::negate),
166        F::NullCount => {
167            let f = |s: &mut [Column]| {
168                let s = &s[0];
169                Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
170            };
171            wrap!(f)
172        },
173        F::Pow(func) => match func {
174            IRPowFunction::Generic => wrap!(pow::pow),
175            IRPowFunction::Sqrt => map!(pow::sqrt),
176            IRPowFunction::Cbrt => map!(pow::cbrt),
177        },
178        #[cfg(feature = "row_hash")]
179        F::Hash(k0, k1, k2, k3) => {
180            map!(misc::row_hash, k0, k1, k2, k3)
181        },
182        #[cfg(feature = "arg_where")]
183        F::ArgWhere => {
184            wrap!(misc::arg_where)
185        },
186        #[cfg(feature = "index_of")]
187        F::IndexOf => {
188            map_as_slice!(misc::index_of)
189        },
190        #[cfg(feature = "search_sorted")]
191        F::SearchSorted { side, descending } => {
192            map_as_slice!(misc::search_sorted_impl, side, descending)
193        },
194        #[cfg(feature = "range")]
195        F::Range(func) => range::function_expr_to_udf(func),
196
197        #[cfg(feature = "trigonometry")]
198        F::Trigonometry(trig_function) => {
199            map!(trigonometry::apply_trigonometric_function, trig_function)
200        },
201        #[cfg(feature = "trigonometry")]
202        F::Atan2 => {
203            wrap!(trigonometry::apply_arctan2)
204        },
205
206        #[cfg(feature = "sign")]
207        F::Sign => {
208            map!(misc::sign)
209        },
210        F::FillNull => {
211            map_as_slice!(misc::fill_null)
212        },
213        #[cfg(feature = "rolling_window")]
214        F::RollingExpr { function, options } => {
215            use IRRollingFunction::*;
216            use polars_plan::plans::IRRollingFunction;
217            match function {
218                Min => map!(rolling::rolling_min, options.clone()),
219                Max => map!(rolling::rolling_max, options.clone()),
220                Mean => map!(rolling::rolling_mean, options.clone()),
221                Sum => map!(rolling::rolling_sum, options.clone()),
222                Quantile => map!(rolling::rolling_quantile, options.clone()),
223                Var => map!(rolling::rolling_var, options.clone()),
224                Std => map!(rolling::rolling_std, options.clone()),
225                Rank => map!(rolling::rolling_rank, options.clone()),
226                #[cfg(feature = "moment")]
227                Skew => map!(rolling::rolling_skew, options.clone()),
228                #[cfg(feature = "moment")]
229                Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
230                #[cfg(feature = "cov")]
231                CorrCov {
232                    corr_cov_options,
233                    is_corr,
234                } => {
235                    map_as_slice!(
236                        rolling::rolling_corr_cov,
237                        options.clone(),
238                        corr_cov_options,
239                        is_corr
240                    )
241                },
242                Map(f) => {
243                    map!(rolling::rolling_map, options.clone(), f.clone())
244                },
245            }
246        },
247        #[cfg(feature = "rolling_window_by")]
248        F::RollingExprBy {
249            function_by,
250            options,
251        } => {
252            use IRRollingFunctionBy::*;
253            use polars_plan::plans::IRRollingFunctionBy;
254            match function_by {
255                MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
256                MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
257                MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
258                SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
259                QuantileBy => {
260                    map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
261                },
262                VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
263                StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
264                RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
265            }
266        },
267        #[cfg(feature = "hist")]
268        F::Hist {
269            bin_count,
270            include_category,
271            include_breakpoint,
272        } => {
273            map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
274        },
275        F::Rechunk => map!(misc::rechunk),
276        F::Append { upcast } => map_as_slice!(misc::append, upcast),
277        F::ShiftAndFill => {
278            map_as_slice!(shift_and_fill::shift_and_fill)
279        },
280        F::DropNans => map_owned!(misc::drop_nans),
281        F::DropNulls => map!(misc::drop_nulls),
282        #[cfg(feature = "round_series")]
283        F::Clip { has_min, has_max } => {
284            map_as_slice!(misc::clip, has_min, has_max)
285        },
286        #[cfg(feature = "mode")]
287        F::Mode { maintain_order } => map!(misc::mode, maintain_order),
288        #[cfg(feature = "moment")]
289        F::Skew(bias) => map!(misc::skew, bias),
290        #[cfg(feature = "moment")]
291        F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
292        F::ArgUnique => map!(misc::arg_unique),
293        F::ArgMin => map!(misc::arg_min),
294        F::ArgMax => map!(misc::arg_max),
295        F::ArgSort {
296            descending,
297            nulls_last,
298        } => map!(misc::arg_sort, descending, nulls_last),
299        F::MinBy => map_as_slice!(misc::min_by),
300        F::MaxBy => map_as_slice!(misc::max_by),
301        F::Product => map!(misc::product),
302        F::Repeat => map_as_slice!(misc::repeat),
303        #[cfg(feature = "rank")]
304        F::Rank { options, seed } => map!(misc::rank, options, seed),
305        #[cfg(feature = "dtype-struct")]
306        F::AsStruct => {
307            map_as_slice!(misc::as_struct)
308        },
309        #[cfg(feature = "top_k")]
310        F::TopK { descending } => {
311            map_as_slice!(polars_ops::prelude::top_k, descending)
312        },
313        #[cfg(feature = "top_k")]
314        F::TopKBy { descending } => {
315            map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
316        },
317        F::Shift => map_as_slice!(shift_and_fill::shift),
318        #[cfg(feature = "cum_agg")]
319        F::CumCount { reverse } => map!(cum::cum_count, reverse),
320        #[cfg(feature = "cum_agg")]
321        F::CumSum { reverse } => map!(cum::cum_sum, reverse),
322        #[cfg(feature = "cum_agg")]
323        F::CumProd { reverse } => map!(cum::cum_prod, reverse),
324        #[cfg(feature = "cum_agg")]
325        F::CumMin { reverse } => map!(cum::cum_min, reverse),
326        #[cfg(feature = "cum_agg")]
327        F::CumMax { reverse } => map!(cum::cum_max, reverse),
328        #[cfg(feature = "dtype-struct")]
329        F::ValueCounts {
330            sort,
331            parallel,
332            name,
333            normalize,
334        } => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
335        #[cfg(feature = "unique_counts")]
336        F::UniqueCounts => map!(misc::unique_counts),
337        F::Reverse => map!(misc::reverse),
338        #[cfg(feature = "approx_unique")]
339        F::ApproxNUnique => map!(misc::approx_n_unique),
340        F::Coalesce => map_as_slice!(misc::coalesce),
341        #[cfg(feature = "diff")]
342        F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
343        #[cfg(feature = "pct_change")]
344        F::PctChange => map_as_slice!(misc::pct_change),
345        #[cfg(feature = "interpolate")]
346        F::Interpolate(method) => {
347            map!(misc::interpolate, method)
348        },
349        #[cfg(feature = "interpolate_by")]
350        F::InterpolateBy => {
351            map_as_slice!(misc::interpolate_by)
352        },
353        #[cfg(feature = "log")]
354        F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
355        #[cfg(feature = "log")]
356        F::Log => map_as_slice!(misc::log),
357        #[cfg(feature = "log")]
358        F::Log1p => map!(misc::log1p),
359        #[cfg(feature = "log")]
360        F::Exp => map!(misc::exp),
361        F::Unique(stable) => map!(misc::unique, stable),
362        #[cfg(feature = "round_series")]
363        F::Round { decimals, mode } => map!(round::round, decimals, mode),
364        #[cfg(feature = "round_series")]
365        F::RoundSF { digits } => map!(round::round_sig_figs, digits),
366        #[cfg(feature = "round_series")]
367        F::Floor => map!(round::floor),
368        #[cfg(feature = "round_series")]
369        F::Ceil => map!(round::ceil),
370        #[cfg(feature = "fused")]
371        F::Fused(op) => map_as_slice!(misc::fused, op),
372        F::ConcatExpr(rechunk) => map_as_slice!(misc::concat_expr, rechunk),
373        #[cfg(feature = "cov")]
374        F::Correlation { method } => map_as_slice!(misc::corr, method),
375        #[cfg(feature = "peaks")]
376        F::PeakMin => map!(misc::peak_min),
377        #[cfg(feature = "peaks")]
378        F::PeakMax => map!(misc::peak_max),
379        #[cfg(feature = "repeat_by")]
380        F::RepeatBy => map_as_slice!(misc::repeat_by),
381        #[cfg(feature = "dtype-array")]
382        F::Reshape(dims) => map!(misc::reshape, &dims),
383        #[cfg(feature = "cutqcut")]
384        F::Cut {
385            breaks,
386            labels,
387            left_closed,
388            include_breaks,
389        } => map!(
390            misc::cut,
391            breaks.clone(),
392            labels.clone(),
393            left_closed,
394            include_breaks
395        ),
396        #[cfg(feature = "cutqcut")]
397        F::QCut {
398            probs,
399            labels,
400            left_closed,
401            allow_duplicates,
402            include_breaks,
403        } => map!(
404            misc::qcut,
405            probs.clone(),
406            labels.clone(),
407            left_closed,
408            allow_duplicates,
409            include_breaks
410        ),
411        #[cfg(feature = "rle")]
412        F::RLE => map!(polars_ops::series::rle),
413        #[cfg(feature = "rle")]
414        F::RLEID => map!(polars_ops::series::rle_id),
415        F::ToPhysical => map!(misc::to_physical),
416        #[cfg(feature = "random")]
417        F::Random { method, seed } => {
418            use IRRandomMethod::*;
419            use polars_plan::plans::IRRandomMethod;
420            match method {
421                Shuffle => map!(random::shuffle, seed),
422                Sample {
423                    is_fraction,
424                    with_replacement,
425                    shuffle,
426                } => {
427                    if is_fraction {
428                        map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
429                    } else {
430                        map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
431                    }
432                },
433            }
434        },
435        F::SetSortedFlag(sorted) => map!(misc::set_sorted_flag, sorted),
436        #[cfg(feature = "ffi_plugin")]
437        F::FfiPlugin {
438            flags: _,
439            lib,
440            symbol,
441            kwargs,
442        } => unsafe {
443            map_as_slice!(
444                polars_plan::plans::plugin::call_plugin,
445                lib.as_ref(),
446                symbol.as_ref(),
447                kwargs.as_ref()
448            )
449        },
450
451        F::FoldHorizontal {
452            callback,
453            returns_scalar,
454            return_dtype,
455        } => map_as_slice!(
456            horizontal::fold,
457            &callback,
458            returns_scalar,
459            return_dtype.as_ref()
460        ),
461        F::ReduceHorizontal {
462            callback,
463            returns_scalar,
464            return_dtype,
465        } => map_as_slice!(
466            horizontal::reduce,
467            &callback,
468            returns_scalar,
469            return_dtype.as_ref()
470        ),
471        #[cfg(feature = "dtype-struct")]
472        F::CumReduceHorizontal {
473            callback,
474            returns_scalar,
475            return_dtype,
476        } => map_as_slice!(
477            horizontal::cum_reduce,
478            &callback,
479            returns_scalar,
480            return_dtype.as_ref()
481        ),
482        #[cfg(feature = "dtype-struct")]
483        F::CumFoldHorizontal {
484            callback,
485            returns_scalar,
486            return_dtype,
487            include_init,
488        } => map_as_slice!(
489            horizontal::cum_fold,
490            &callback,
491            returns_scalar,
492            return_dtype.as_ref(),
493            include_init
494        ),
495
496        F::MaxHorizontal => wrap!(misc::max_horizontal),
497        F::MinHorizontal => wrap!(misc::min_horizontal),
498        F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
499        F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
500        #[cfg(feature = "ewma")]
501        F::EwmMean { options } => map!(misc::ewm_mean, options),
502        #[cfg(feature = "ewma_by")]
503        F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
504        #[cfg(feature = "ewma")]
505        F::EwmStd { options } => map!(misc::ewm_std, options),
506        #[cfg(feature = "ewma")]
507        F::EwmVar { options } => map!(misc::ewm_var, options),
508        #[cfg(feature = "replace")]
509        F::Replace => {
510            map_as_slice!(misc::replace)
511        },
512        #[cfg(feature = "replace")]
513        F::ReplaceStrict { return_dtype } => {
514            map_as_slice!(misc::replace_strict, return_dtype.clone())
515        },
516
517        F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
518        F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
519        #[cfg(feature = "reinterpret")]
520        F::Reinterpret(signed) => map!(misc::reinterpret, signed),
521        F::ExtendConstant => map_as_slice!(misc::extend_constant),
522
523        F::RowEncode(dts, variants) => {
524            map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
525        },
526        #[cfg(feature = "dtype-struct")]
527        F::RowDecode(fs, variants) => {
528            map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
529        },
530    }
531}
532
533pub trait GroupsUdf: Send + Sync + 'static {
534    fn evaluate_on_groups<'a>(
535        &self,
536        inputs: &[Arc<dyn PhysicalExpr>],
537        df: &DataFrame,
538        groups: &'a GroupPositions,
539        state: &ExecutionState,
540    ) -> PolarsResult<AggregationContext<'a>>;
541}
542
543pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
544    macro_rules! wrap_groups {
545        ($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
546            struct Wrap($($ty),*);
547            impl GroupsUdf for Wrap {
548                fn evaluate_on_groups<'a>(
549                    &self,
550                    inputs: &[Arc<dyn PhysicalExpr>],
551                    df: &DataFrame,
552                    groups: &'a GroupPositions,
553                    state: &ExecutionState,
554                ) -> PolarsResult<AggregationContext<'a>> {
555                    let Wrap($($n),*) = self;
556                    $f(inputs, df, groups, state$(, *$n)*)
557                }
558            }
559
560            SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
561        }};
562    }
563    use IRFunctionExpr as F;
564    Some(match func {
565        F::NullCount => wrap_groups!(groups_dispatch::null_count),
566        F::Reverse => wrap_groups!(groups_dispatch::reverse),
567        F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
568            let ignore_nulls = *ignore_nulls;
569            wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
570        },
571        F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
572            let ignore_nulls = *ignore_nulls;
573            wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
574        },
575        #[cfg(feature = "bitwise")]
576        F::Bitwise(f) => {
577            use polars_plan::plans::IRBitwiseFunction as B;
578            match f {
579                B::And => wrap_groups!(groups_dispatch::bitwise_and),
580                B::Or => wrap_groups!(groups_dispatch::bitwise_or),
581                B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
582                _ => return None,
583            }
584        },
585        F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
586        F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
587
588        #[cfg(feature = "moment")]
589        F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
590        #[cfg(feature = "moment")]
591        F::Kurtosis(fisher, bias) => {
592            wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
593        },
594
595        F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
596        F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
597            wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
598        },
599        F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
600            wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
601        },
602
603        _ => return None,
604    })
605}