Skip to main content

polars_expr/dispatch/
mod.rs

1use std::sync::Arc;
2
3use polars_compute::rolling::QuantileMethod;
4use polars_core::error::PolarsResult;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{Column, GroupPositions};
7use polars_plan::dsl::{ColumnsUdf, SpecialEq};
8use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
9use polars_utils::IdxSize;
10
11use crate::prelude::{AggregationContext, PhysicalExpr};
12use crate::state::ExecutionState;
13
/// Wraps a column UDF in `SpecialEq<Arc<_>>`.
///
/// * `wrap!(udf)` wraps the given expression as-is.
/// * `wrap!(callable, args...)` first builds a closure over `&mut [Column]` that forwards the
///   whole slice, followed by `args`, to `callable`, and wraps that closure.
///
/// `SpecialEq` and `Arc` are referenced unqualified, so both must be in scope where the macro is
/// invoked. The `args` expressions are expanded inside the closure body and are therefore
/// evaluated on every call of the resulting UDF.
#[macro_export]
macro_rules! wrap {
    ($udf:expr) => {
        SpecialEq::new(Arc::new($udf))
    };

    ($callable:expr, $($extra:expr),*) => {{
        let forward = move |columns: &mut [::polars_core::prelude::Column]| {
            $callable(columns, $($extra),*)
        };

        SpecialEq::new(Arc::new(forward))
    }};
}
28
/// Adapts a function that consumes the complete column slice, optionally followed by `args`.
///
/// * all expression arguments are in the slice.
/// * the first element is the root expression.
///
/// The generated closure takes `&mut [Column]` and hands that slice straight to `$func`. Extra
/// `args` are expanded inside the closure body, so they are evaluated on every call.
/// `SpecialEq` and `Arc` must be in scope at the invocation site.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let on_slice = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(columns)
        };

        SpecialEq::new(Arc::new(on_slice))
    }};

    ($func:path, $($extra:expr),*) => {{
        let on_slice = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(columns, $($extra),*)
        };

        SpecialEq::new(Arc::new(on_slice))
    }};
}
50
/// Adapts a function that takes the first column by value.
///
/// * `func(Column)`
/// * `func(Column, args)`
///
/// The generated closure moves the column at index 0 out of the input slice with
/// `std::mem::take`, leaving the column type's default value in its place, and passes the owned
/// column to `$func`. Indexing panics if the slice is empty. Extra `args` are expanded inside the
/// closure body, so they are evaluated on every call.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let on_owned = move |columns: &mut [::polars_core::prelude::Column]| {
            let first = std::mem::take(&mut columns[0]);
            $func(first)
        };

        SpecialEq::new(Arc::new(on_owned))
    }};

    ($func:path, $($extra:expr),*) => {{
        let on_owned = move |columns: &mut [::polars_core::prelude::Column]| {
            let first = std::mem::take(&mut columns[0]);
            $func(first, $($extra),*)
        };

        SpecialEq::new(Arc::new(on_owned))
    }};
}
73
/// Adapts a function that only needs a shared borrow of the first column.
///
/// * `func(&Column)`
/// * `func(&Column, args)`
///
/// The generated closure borrows the column at index 0 of the input slice (panicking if the
/// slice is empty) and passes it to `$func`. Extra `args` are expanded inside the closure body,
/// so they are evaluated on every call.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let on_first = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(&columns[0])
        };

        SpecialEq::new(Arc::new(on_first))
    }};

    ($func:path, $($extra:expr),*) => {{
        let on_first = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(&columns[0], $($extra),*)
        };

        SpecialEq::new(Arc::new(on_first))
    }};
}
95
96#[cfg(feature = "dtype-array")]
97mod array;
98mod binary;
99#[cfg(feature = "bitwise")]
100mod bitwise;
101mod boolean;
102#[cfg(feature = "business")]
103mod business;
104#[cfg(feature = "dtype-categorical")]
105mod cat;
106#[cfg(feature = "cum_agg")]
107mod cum;
108#[cfg(feature = "temporal")]
109mod datetime;
110#[cfg(feature = "dtype-extension")]
111mod extension;
112mod groups_dispatch;
113mod horizontal;
114mod list;
115mod misc;
116mod pow;
117#[cfg(feature = "random")]
118mod random;
119#[cfg(feature = "range")]
120mod range;
121#[cfg(feature = "rolling_window")]
122mod rolling;
123#[cfg(feature = "rolling_window_by")]
124mod rolling_by;
125#[cfg(feature = "round_series")]
126mod round;
127mod shift_and_fill;
128#[cfg(feature = "strings")]
129mod strings;
130#[cfg(feature = "dtype-struct")]
131pub(crate) mod struct_;
132#[cfg(feature = "temporal")]
133mod temporal;
134#[cfg(feature = "trigonometry")]
135mod trigonometry;
136
137pub use groups_dispatch::drop_items;
138
139pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
140    use IRFunctionExpr as F;
141    match func {
142        // Namespaces
143        #[cfg(feature = "dtype-array")]
144        F::ArrayExpr(func) => array::function_expr_to_udf(func),
145        F::BinaryExpr(func) => binary::function_expr_to_udf(func),
146        #[cfg(feature = "dtype-categorical")]
147        F::Categorical(func) => cat::function_expr_to_udf(func),
148        #[cfg(feature = "dtype-extension")]
149        F::Extension(func) => extension::function_expr_to_udf(func),
150        F::ListExpr(func) => list::function_expr_to_udf(func),
151        #[cfg(feature = "strings")]
152        F::StringExpr(func) => strings::function_expr_to_udf(func),
153        #[cfg(feature = "dtype-struct")]
154        F::StructExpr(func) => struct_::function_expr_to_udf(func),
155        #[cfg(feature = "temporal")]
156        F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
157        #[cfg(feature = "bitwise")]
158        F::Bitwise(func) => bitwise::function_expr_to_udf(func),
159
160        // Other expressions
161        F::Boolean(func) => boolean::function_expr_to_udf(func),
162        #[cfg(feature = "business")]
163        F::Business(func) => business::function_expr_to_udf(func),
164        #[cfg(feature = "abs")]
165        F::Abs => map!(misc::abs),
166        F::Negate => map!(misc::negate),
167        F::NullCount => {
168            let f = |s: &mut [Column]| {
169                let s = &s[0];
170                Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
171            };
172            wrap!(f)
173        },
174        F::Pow(func) => match func {
175            IRPowFunction::Generic => wrap!(pow::pow),
176            IRPowFunction::Sqrt => map!(pow::sqrt),
177            IRPowFunction::Cbrt => map!(pow::cbrt),
178        },
179        #[cfg(feature = "row_hash")]
180        F::Hash(k0, k1, k2, k3) => {
181            map!(misc::row_hash, k0, k1, k2, k3)
182        },
183        #[cfg(feature = "arg_where")]
184        F::ArgWhere => {
185            wrap!(misc::arg_where)
186        },
187        #[cfg(feature = "index_of")]
188        F::IndexOf => {
189            map_as_slice!(misc::index_of)
190        },
191        #[cfg(feature = "search_sorted")]
192        F::SearchSorted { side, descending } => {
193            map_as_slice!(misc::search_sorted_impl, side, descending)
194        },
195        #[cfg(feature = "range")]
196        F::Range(func) => range::function_expr_to_udf(func),
197
198        #[cfg(feature = "trigonometry")]
199        F::Trigonometry(trig_function) => {
200            map!(trigonometry::apply_trigonometric_function, trig_function)
201        },
202        #[cfg(feature = "trigonometry")]
203        F::Atan2 => {
204            wrap!(trigonometry::apply_arctan2)
205        },
206
207        #[cfg(feature = "sign")]
208        F::Sign => {
209            map!(misc::sign)
210        },
211        F::FillNull => {
212            map_as_slice!(misc::fill_null)
213        },
214        #[cfg(feature = "rolling_window")]
215        F::RollingExpr { function, options } => {
216            use IRRollingFunction::*;
217            use polars_plan::plans::IRRollingFunction;
218            match function {
219                Min => map!(rolling::rolling_min, options.clone()),
220                Max => map!(rolling::rolling_max, options.clone()),
221                Mean => map!(rolling::rolling_mean, options.clone()),
222                Sum => map!(rolling::rolling_sum, options.clone()),
223                Quantile => map!(rolling::rolling_quantile, options.clone()),
224                Var => map!(rolling::rolling_var, options.clone()),
225                Std => map!(rolling::rolling_std, options.clone()),
226                Rank => map!(rolling::rolling_rank, options.clone()),
227                #[cfg(feature = "moment")]
228                Skew => map!(rolling::rolling_skew, options.clone()),
229                #[cfg(feature = "moment")]
230                Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
231                #[cfg(feature = "cov")]
232                CorrCov {
233                    corr_cov_options,
234                    is_corr,
235                } => {
236                    map_as_slice!(
237                        rolling::rolling_corr_cov,
238                        options.clone(),
239                        corr_cov_options,
240                        is_corr
241                    )
242                },
243                Map(f) => {
244                    map!(rolling::rolling_map, options.clone(), f.clone())
245                },
246            }
247        },
248        #[cfg(feature = "rolling_window_by")]
249        F::RollingExprBy {
250            function_by,
251            options,
252        } => {
253            use IRRollingFunctionBy::*;
254            use polars_plan::plans::IRRollingFunctionBy;
255            match function_by {
256                MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
257                MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
258                MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
259                SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
260                QuantileBy => {
261                    map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
262                },
263                VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
264                StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
265                RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
266            }
267        },
268        #[cfg(feature = "hist")]
269        F::Hist {
270            bin_count,
271            include_category,
272            include_breakpoint,
273        } => {
274            map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
275        },
276        F::Rechunk => map!(misc::rechunk),
277        F::ShiftAndFill => {
278            map_as_slice!(shift_and_fill::shift_and_fill)
279        },
280        F::DropNans => map_owned!(misc::drop_nans),
281        F::DropNulls => map!(misc::drop_nulls),
282        #[cfg(feature = "round_series")]
283        F::Clip { has_min, has_max } => {
284            map_as_slice!(misc::clip, has_min, has_max)
285        },
286        F::Quantile { method } => map_as_slice!(misc::quantile, method),
287        #[cfg(feature = "mode")]
288        F::Mode { maintain_order } => map!(misc::mode, maintain_order),
289        #[cfg(feature = "moment")]
290        F::Skew(bias) => map!(misc::skew, bias),
291        #[cfg(feature = "moment")]
292        F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
293        F::ArgUnique => map!(misc::arg_unique),
294        F::ArgMin => map!(misc::arg_min),
295        F::ArgMax => map!(misc::arg_max),
296        F::ArgSort {
297            descending,
298            nulls_last,
299        } => map!(misc::arg_sort, descending, nulls_last),
300        F::MinBy => map_as_slice!(misc::min_by),
301        F::MaxBy => map_as_slice!(misc::max_by),
302        F::Product => map!(misc::product),
303        F::Repeat => map_as_slice!(misc::repeat),
304        #[cfg(feature = "rank")]
305        F::Rank { options, seed } => map!(misc::rank, options, seed),
306        #[cfg(feature = "dtype-struct")]
307        F::AsStruct => {
308            map_as_slice!(misc::as_struct)
309        },
310        #[cfg(feature = "top_k")]
311        F::TopK { descending } => {
312            map_as_slice!(polars_ops::prelude::top_k, descending)
313        },
314        #[cfg(feature = "top_k")]
315        F::TopKBy { descending } => {
316            map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
317        },
318        F::Shift => map_as_slice!(shift_and_fill::shift),
319        #[cfg(feature = "cum_agg")]
320        F::CumCount { reverse } => map!(cum::cum_count, reverse),
321        #[cfg(feature = "cum_agg")]
322        F::CumSum { reverse } => map!(cum::cum_sum, reverse),
323        #[cfg(feature = "cum_agg")]
324        F::CumProd { reverse } => map!(cum::cum_prod, reverse),
325        #[cfg(feature = "cum_agg")]
326        F::CumMin { reverse } => map!(cum::cum_min, reverse),
327        #[cfg(feature = "cum_agg")]
328        F::CumMax { reverse } => map!(cum::cum_max, reverse),
329        #[cfg(feature = "dtype-struct")]
330        F::ValueCounts {
331            sort,
332            parallel,
333            name,
334            normalize,
335        } => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
336        #[cfg(feature = "unique_counts")]
337        F::UniqueCounts => map!(misc::unique_counts),
338        F::Reverse => map!(misc::reverse),
339        #[cfg(feature = "approx_unique")]
340        F::ApproxNUnique => map!(misc::approx_n_unique),
341        F::Coalesce => map_as_slice!(misc::coalesce),
342        #[cfg(feature = "diff")]
343        F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
344        #[cfg(feature = "pct_change")]
345        F::PctChange => map_as_slice!(misc::pct_change),
346        #[cfg(feature = "interpolate")]
347        F::Interpolate(method) => {
348            map!(misc::interpolate, method)
349        },
350        #[cfg(feature = "interpolate_by")]
351        F::InterpolateBy => {
352            map_as_slice!(misc::interpolate_by)
353        },
354        #[cfg(feature = "log")]
355        F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
356        #[cfg(feature = "log")]
357        F::Log => map_as_slice!(misc::log),
358        #[cfg(feature = "log")]
359        F::Log1p => map!(misc::log1p),
360        #[cfg(feature = "log")]
361        F::Exp => map!(misc::exp),
362        F::Unique(stable) => map!(misc::unique, stable),
363        #[cfg(feature = "round_series")]
364        F::Round { decimals, mode } => map!(round::round, decimals, mode),
365        #[cfg(feature = "round_series")]
366        F::RoundSF { digits } => map!(round::round_sig_figs, digits),
367        #[cfg(feature = "round_series")]
368        F::Truncate { decimals } => map!(round::truncate, decimals),
369        #[cfg(feature = "round_series")]
370        F::Floor => map!(round::floor),
371        #[cfg(feature = "round_series")]
372        F::Ceil => map!(round::ceil),
373        #[cfg(feature = "fused")]
374        F::Fused(op) => map_as_slice!(misc::fused, op),
375        F::ConcatExpr { rechunk } => map_as_slice!(misc::concat_expr, rechunk),
376        #[cfg(feature = "cov")]
377        F::Correlation { method } => map_as_slice!(misc::corr, method),
378        #[cfg(feature = "peaks")]
379        F::PeakMin => map!(misc::peak_min),
380        #[cfg(feature = "peaks")]
381        F::PeakMax => map!(misc::peak_max),
382        #[cfg(feature = "repeat_by")]
383        F::RepeatBy => map_as_slice!(misc::repeat_by),
384        #[cfg(feature = "dtype-array")]
385        F::Reshape(dims) => map!(misc::reshape, &dims),
386        #[cfg(feature = "cutqcut")]
387        F::Cut {
388            breaks,
389            labels,
390            left_closed,
391            include_breaks,
392        } => map!(
393            misc::cut,
394            breaks.clone(),
395            labels.clone(),
396            left_closed,
397            include_breaks
398        ),
399        #[cfg(feature = "cutqcut")]
400        F::QCut {
401            probs,
402            labels,
403            left_closed,
404            allow_duplicates,
405            include_breaks,
406        } => map!(
407            misc::qcut,
408            probs.clone(),
409            labels.clone(),
410            left_closed,
411            allow_duplicates,
412            include_breaks
413        ),
414        #[cfg(feature = "rle")]
415        F::RLE => map!(polars_ops::series::rle),
416        #[cfg(feature = "rle")]
417        F::RLEID => map!(polars_ops::series::rle_id),
418        F::ToPhysical => map!(misc::to_physical),
419        #[cfg(feature = "random")]
420        F::Random { method, seed } => {
421            use IRRandomMethod::*;
422            use polars_plan::plans::IRRandomMethod;
423            match method {
424                Shuffle => map!(random::shuffle, seed),
425                Sample {
426                    is_fraction,
427                    with_replacement,
428                    shuffle,
429                } => {
430                    if is_fraction {
431                        map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
432                    } else {
433                        map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
434                    }
435                },
436            }
437        },
438        F::SetSortedFlag(sortedness) => map!(misc::set_sorted_flag, sortedness),
439        #[cfg(feature = "ffi_plugin")]
440        F::FfiPlugin {
441            flags: _,
442            lib,
443            symbol,
444            kwargs,
445        } => unsafe {
446            map_as_slice!(
447                polars_plan::plans::plugin::call_plugin,
448                lib.as_ref(),
449                symbol.as_ref(),
450                kwargs.as_ref()
451            )
452        },
453
454        F::FoldHorizontal {
455            callback,
456            returns_scalar,
457            return_dtype,
458        } => map_as_slice!(
459            horizontal::fold,
460            &callback,
461            returns_scalar,
462            return_dtype.as_ref()
463        ),
464        F::ReduceHorizontal {
465            callback,
466            returns_scalar,
467            return_dtype,
468        } => map_as_slice!(
469            horizontal::reduce,
470            &callback,
471            returns_scalar,
472            return_dtype.as_ref()
473        ),
474        #[cfg(feature = "dtype-struct")]
475        F::CumReduceHorizontal {
476            callback,
477            returns_scalar,
478            return_dtype,
479        } => map_as_slice!(
480            horizontal::cum_reduce,
481            &callback,
482            returns_scalar,
483            return_dtype.as_ref()
484        ),
485        #[cfg(feature = "dtype-struct")]
486        F::CumFoldHorizontal {
487            callback,
488            returns_scalar,
489            return_dtype,
490            include_init,
491        } => map_as_slice!(
492            horizontal::cum_fold,
493            &callback,
494            returns_scalar,
495            return_dtype.as_ref(),
496            include_init
497        ),
498
499        F::MaxHorizontal => wrap!(misc::max_horizontal),
500        F::MinHorizontal => wrap!(misc::min_horizontal),
501        F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
502        F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
503        #[cfg(feature = "ewma")]
504        F::EwmMean { options } => map!(misc::ewm_mean, options),
505        #[cfg(feature = "ewma_by")]
506        F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
507        #[cfg(feature = "ewma")]
508        F::EwmStd { options } => map!(misc::ewm_std, options),
509        #[cfg(feature = "ewma")]
510        F::EwmVar { options } => map!(misc::ewm_var, options),
511        #[cfg(feature = "replace")]
512        F::Replace => {
513            map_as_slice!(misc::replace)
514        },
515        #[cfg(feature = "replace")]
516        F::ReplaceStrict { return_dtype } => {
517            map_as_slice!(misc::replace_strict, return_dtype.clone())
518        },
519
520        F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
521        F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
522        #[cfg(feature = "reinterpret")]
523        F::Reinterpret(dtype) => map!(misc::reinterpret, &dtype),
524        F::ExtendConstant => map_as_slice!(misc::extend_constant),
525
526        F::RowEncode(dts, variants) => {
527            map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
528        },
529        #[cfg(feature = "dtype-struct")]
530        F::RowDecode(fs, variants) => {
531            map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
532        },
533        F::DynamicPred { pred } => {
534            map_as_slice!(misc::dynamic_pred, &pred)
535        },
536    }
537}
538
539pub trait GroupsUdf: Send + Sync + 'static {
540    fn evaluate_on_groups<'a>(
541        &self,
542        inputs: &[Arc<dyn PhysicalExpr>],
543        df: &DataFrame,
544        groups: &'a GroupPositions,
545        state: &ExecutionState,
546    ) -> PolarsResult<AggregationContext<'a>>;
547}
548
549pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
550    macro_rules! wrap_groups {
551        ($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
552            struct Wrap($($ty),*);
553            impl GroupsUdf for Wrap {
554                fn evaluate_on_groups<'a>(
555                    &self,
556                    inputs: &[Arc<dyn PhysicalExpr>],
557                    df: &DataFrame,
558                    groups: &'a GroupPositions,
559                    state: &ExecutionState,
560                ) -> PolarsResult<AggregationContext<'a>> {
561                    let Wrap($($n),*) = self;
562                    $f(inputs, df, groups, state$(, *$n)*)
563                }
564            }
565
566            SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
567        }};
568    }
569    use IRFunctionExpr as F;
570    Some(match func {
571        F::NullCount => wrap_groups!(groups_dispatch::null_count),
572        F::Reverse => wrap_groups!(groups_dispatch::reverse),
573        F::Boolean(IRBooleanFunction::HasNulls) => wrap_groups!(groups_dispatch::has_nulls),
574        F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
575            let ignore_nulls = *ignore_nulls;
576            wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
577        },
578        F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
579            let ignore_nulls = *ignore_nulls;
580            wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
581        },
582        F::Boolean(IRBooleanFunction::IsEmpty { ignore_nulls }) => {
583            let ignore_nulls = *ignore_nulls;
584            wrap_groups!(groups_dispatch::is_empty, (ignore_nulls, v: bool))
585        },
586        #[cfg(feature = "bitwise")]
587        F::Bitwise(f) => {
588            use polars_plan::plans::IRBitwiseFunction as B;
589            match f {
590                B::And => wrap_groups!(groups_dispatch::bitwise_and),
591                B::Or => wrap_groups!(groups_dispatch::bitwise_or),
592                B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
593                _ => return None,
594            }
595        },
596        F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
597        F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
598
599        F::Quantile { method } => {
600            wrap_groups!(groups_dispatch::quantile, (*method, v: QuantileMethod))
601        },
602        #[cfg(feature = "moment")]
603        F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
604        #[cfg(feature = "moment")]
605        F::Kurtosis(fisher, bias) => {
606            wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
607        },
608
609        F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
610        F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
611            wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
612        },
613        F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
614            wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
615        },
616
617        _ => return None,
618    })
619}