1use std::sync::Arc;
2
3use polars_core::error::PolarsResult;
4use polars_core::frame::DataFrame;
5use polars_core::prelude::{Column, GroupPositions};
6use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
8use polars_utils::IdxSize;
9
10use crate::prelude::{AggregationContext, PhysicalExpr};
11use crate::state::ExecutionState;
12
/// Wraps a callable into a `SpecialEq<Arc<_>>` column UDF.
///
/// * `wrap!(f)` wraps `f` unchanged, so `f` must already accept the whole
///   `&mut [Column]` input slice.
/// * `wrap!(f, a, b, ..)` builds a `move` closure `|s| f(s, a, b, ..)`. Variables referenced
///   by the extra arguments are moved into the closure, and the argument expressions are
///   evaluated on every call.
///
/// `SpecialEq` and `Arc` are referenced through absolute paths, so call sites do not need to
/// import them. A trailing comma is accepted after the extra arguments.
#[macro_export]
macro_rules! wrap {
    ($e:expr) => {
        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new($e))
    };

    ($e:expr, $($args:expr),* $(,)?) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| $e(s, $($args),*);
        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
27
/// Builds a column UDF that hands the whole input slice to `$func`.
///
/// `map_as_slice!(f, a, b, ..)` expands to a `move` closure `|s: &mut [Column]| f(s, a, b, ..)`
/// wrapped in `SpecialEq<Arc<_>>`. All inputs, including the root expression, are reachable
/// through the slice. The extra argument expressions are evaluated on every call.
///
/// Paths are absolute, so call sites do not need `Arc`/`SpecialEq` in scope. A trailing comma
/// is accepted.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path $(, $args:expr)* $(,)?) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| $func(s $(, $args)*);
        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
49
/// Builds a column UDF that passes the first input column to `$func` by value.
///
/// The first column is moved out of the input slice with `std::mem::take`, which leaves
/// `Column::default()` in its place. The owned column is then passed to `$func`, followed by
/// any extra arguments. Those argument expressions are evaluated on every call.
///
/// # Panics
/// The generated closure panics if the input slice is empty (it indexes `[0]`).
///
/// Paths are absolute, so call sites do not need `Arc`/`SpecialEq` in scope. A trailing comma
/// is accepted.
#[macro_export]
macro_rules! map_owned {
    ($func:path $(, $args:expr)* $(,)?) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            let c = ::std::mem::take(&mut s[0]);
            $func(c $(, $args)*)
        };
        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
72
/// Builds a column UDF that passes a reference to the first input column to `$func`.
///
/// `map!(f, a, b, ..)` expands to a `move` closure `|s| f(&s[0], a, b, ..)` wrapped in
/// `SpecialEq<Arc<_>>`. The extra argument expressions are evaluated on every call.
///
/// # Panics
/// The generated closure panics if the input slice is empty (it indexes `[0]`).
///
/// Paths are absolute, so call sites do not need `Arc`/`SpecialEq` in scope. A trailing comma
/// is accepted.
#[macro_export]
macro_rules! map {
    ($func:path $(, $args:expr)* $(,)?) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            let c = &s[0];
            $func(c $(, $args)*)
        };
        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
94
95#[cfg(feature = "dtype-array")]
96mod array;
97mod binary;
98#[cfg(feature = "bitwise")]
99mod bitwise;
100mod boolean;
101#[cfg(feature = "business")]
102mod business;
103#[cfg(feature = "dtype-categorical")]
104mod cat;
105#[cfg(feature = "cum_agg")]
106mod cum;
107#[cfg(feature = "temporal")]
108mod datetime;
109mod groups_dispatch;
110mod horizontal;
111mod list;
112mod misc;
113mod pow;
114#[cfg(feature = "random")]
115mod random;
116#[cfg(feature = "range")]
117mod range;
118#[cfg(feature = "rolling_window")]
119mod rolling;
120#[cfg(feature = "rolling_window_by")]
121mod rolling_by;
122#[cfg(feature = "round_series")]
123mod round;
124mod shift_and_fill;
125#[cfg(feature = "strings")]
126mod strings;
127#[cfg(feature = "dtype-struct")]
128mod struct_;
129#[cfg(feature = "temporal")]
130mod temporal;
131#[cfg(feature = "trigonometry")]
132mod trigonometry;
133
134pub use groups_dispatch::drop_items;
135
/// Translates a planner-level [`IRFunctionExpr`] into the column UDF that evaluates it.
///
/// Namespaced variants (array, binary, categorical, list, string, struct, temporal, bitwise,
/// boolean, business, range) are delegated to the converter of the matching submodule. Every
/// other variant is wrapped inline with one of the macros defined in this module:
///
/// * `map!` passes `&inputs[0]` to the target function.
/// * `map_owned!` moves `inputs[0]` out of the slice and passes it by value.
/// * `map_as_slice!` and `wrap!(f, ..)` pass the whole input slice.
/// * `wrap!(f)` wraps `f` unchanged.
///
/// Extra macro arguments (e.g. `options.clone()`) are evaluated inside the generated closure,
/// i.e. once per invocation of the UDF, not once at construction time.
///
/// Many arms are compiled only when the corresponding cargo feature is enabled.
pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
    use IRFunctionExpr as F;
    match func {
        // Namespaces: delegate to the submodule that owns the namespace.
        #[cfg(feature = "dtype-array")]
        F::ArrayExpr(func) => array::function_expr_to_udf(func),
        F::BinaryExpr(func) => binary::function_expr_to_udf(func),
        #[cfg(feature = "dtype-categorical")]
        F::Categorical(func) => cat::function_expr_to_udf(func),
        F::ListExpr(func) => list::function_expr_to_udf(func),
        #[cfg(feature = "strings")]
        F::StringExpr(func) => strings::function_expr_to_udf(func),
        #[cfg(feature = "dtype-struct")]
        F::StructExpr(func) => struct_::function_expr_to_udf(func),
        #[cfg(feature = "temporal")]
        F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
        #[cfg(feature = "bitwise")]
        F::Bitwise(func) => bitwise::function_expr_to_udf(func),

        F::Boolean(func) => boolean::function_expr_to_udf(func),
        #[cfg(feature = "business")]
        F::Business(func) => business::function_expr_to_udf(func),

        // Other expressions: wrapped inline.
        #[cfg(feature = "abs")]
        F::Abs => map!(misc::abs),
        F::Negate => map!(misc::negate),
        F::NullCount => {
            // Produces a single-row column, named after the first input, holding its
            // null count as `IdxSize`.
            let f = |s: &mut [Column]| {
                let s = &s[0];
                Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
            };
            wrap!(f)
        },
        F::Pow(func) => match func {
            // The generic power receives the full input slice; sqrt/cbrt only the base.
            IRPowFunction::Generic => wrap!(pow::pow),
            IRPowFunction::Sqrt => map!(pow::sqrt),
            IRPowFunction::Cbrt => map!(pow::cbrt),
        },
        #[cfg(feature = "row_hash")]
        F::Hash(k0, k1, k2, k3) => {
            map!(misc::row_hash, k0, k1, k2, k3)
        },
        #[cfg(feature = "arg_where")]
        F::ArgWhere => {
            wrap!(misc::arg_where)
        },
        #[cfg(feature = "index_of")]
        F::IndexOf => {
            map_as_slice!(misc::index_of)
        },
        #[cfg(feature = "search_sorted")]
        F::SearchSorted { side, descending } => {
            map_as_slice!(misc::search_sorted_impl, side, descending)
        },
        #[cfg(feature = "range")]
        F::Range(func) => range::function_expr_to_udf(func),

        #[cfg(feature = "trigonometry")]
        F::Trigonometry(trig_function) => {
            map!(trigonometry::apply_trigonometric_function, trig_function)
        },
        #[cfg(feature = "trigonometry")]
        F::Atan2 => {
            wrap!(trigonometry::apply_arctan2)
        },

        #[cfg(feature = "sign")]
        F::Sign => {
            map!(misc::sign)
        },
        F::FillNull => {
            map_as_slice!(misc::fill_null)
        },
        #[cfg(feature = "rolling_window")]
        F::RollingExpr { function, options } => {
            use IRRollingFunction::*;
            use polars_plan::plans::IRRollingFunction;
            // Each arm's closure captures `options` by move; the `.clone()` runs on every
            // call because the rolling kernels take the options by value.
            match function {
                Min => map!(rolling::rolling_min, options.clone()),
                Max => map!(rolling::rolling_max, options.clone()),
                Mean => map!(rolling::rolling_mean, options.clone()),
                Sum => map!(rolling::rolling_sum, options.clone()),
                Quantile => map!(rolling::rolling_quantile, options.clone()),
                Var => map!(rolling::rolling_var, options.clone()),
                Std => map!(rolling::rolling_std, options.clone()),
                Rank => map!(rolling::rolling_rank, options.clone()),
                #[cfg(feature = "moment")]
                Skew => map!(rolling::rolling_skew, options.clone()),
                #[cfg(feature = "moment")]
                Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
                #[cfg(feature = "cov")]
                CorrCov {
                    corr_cov_options,
                    is_corr,
                } => {
                    // Correlation/covariance needs both input columns, hence the slice form.
                    map_as_slice!(
                        rolling::rolling_corr_cov,
                        options.clone(),
                        corr_cov_options,
                        is_corr
                    )
                },
                Map(f) => {
                    map!(rolling::rolling_map, options.clone(), f.clone())
                },
            }
        },
        #[cfg(feature = "rolling_window_by")]
        F::RollingExprBy {
            function_by,
            options,
        } => {
            use IRRollingFunctionBy::*;
            use polars_plan::plans::IRRollingFunctionBy;
            // The `*_by` kernels receive the full input slice (values plus the `by` column).
            match function_by {
                MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
                MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
                MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
                SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
                QuantileBy => {
                    map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
                },
                VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
                StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
                RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
            }
        },
        #[cfg(feature = "hist")]
        F::Hist {
            bin_count,
            include_category,
            include_breakpoint,
        } => {
            map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
        },
        F::Rechunk => map!(misc::rechunk),
        F::Append { upcast } => map_as_slice!(misc::append, upcast),
        F::ShiftAndFill => {
            map_as_slice!(shift_and_fill::shift_and_fill)
        },
        // `drop_nans` takes its input by value, so the first column is moved out of the slice.
        F::DropNans => map_owned!(misc::drop_nans),
        F::DropNulls => map!(misc::drop_nulls),
        #[cfg(feature = "round_series")]
        F::Clip { has_min, has_max } => {
            map_as_slice!(misc::clip, has_min, has_max)
        },
        #[cfg(feature = "mode")]
        F::Mode => map!(misc::mode),
        #[cfg(feature = "moment")]
        F::Skew(bias) => map!(misc::skew, bias),
        #[cfg(feature = "moment")]
        F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
        F::ArgUnique => map!(misc::arg_unique),
        F::ArgMin => map!(misc::arg_min),
        F::ArgMax => map!(misc::arg_max),
        F::ArgSort {
            descending,
            nulls_last,
        } => map!(misc::arg_sort, descending, nulls_last),
        F::Product => map!(misc::product),
        F::Repeat => map_as_slice!(misc::repeat),
        #[cfg(feature = "rank")]
        F::Rank { options, seed } => map!(misc::rank, options, seed),
        #[cfg(feature = "dtype-struct")]
        F::AsStruct => {
            map_as_slice!(misc::as_struct)
        },
        #[cfg(feature = "top_k")]
        F::TopK { descending } => {
            map_as_slice!(polars_ops::prelude::top_k, descending)
        },
        #[cfg(feature = "top_k")]
        F::TopKBy { descending } => {
            map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
        },
        F::Shift => map_as_slice!(shift_and_fill::shift),
        #[cfg(feature = "cum_agg")]
        F::CumCount { reverse } => map!(cum::cum_count, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumSum { reverse } => map!(cum::cum_sum, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumProd { reverse } => map!(cum::cum_prod, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumMin { reverse } => map!(cum::cum_min, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumMax { reverse } => map!(cum::cum_max, reverse),
        #[cfg(feature = "dtype-struct")]
        F::ValueCounts {
            sort,
            parallel,
            name,
            normalize,
        } => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
        #[cfg(feature = "unique_counts")]
        F::UniqueCounts => map!(misc::unique_counts),
        F::Reverse => map!(misc::reverse),
        #[cfg(feature = "approx_unique")]
        F::ApproxNUnique => map!(misc::approx_n_unique),
        F::Coalesce => map_as_slice!(misc::coalesce),
        #[cfg(feature = "diff")]
        F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
        #[cfg(feature = "pct_change")]
        F::PctChange => map_as_slice!(misc::pct_change),
        #[cfg(feature = "interpolate")]
        F::Interpolate(method) => {
            map!(misc::interpolate, method)
        },
        #[cfg(feature = "interpolate_by")]
        F::InterpolateBy => {
            map_as_slice!(misc::interpolate_by)
        },
        #[cfg(feature = "log")]
        F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
        #[cfg(feature = "log")]
        F::Log => map_as_slice!(misc::log),
        #[cfg(feature = "log")]
        F::Log1p => map!(misc::log1p),
        #[cfg(feature = "log")]
        F::Exp => map!(misc::exp),
        F::Unique(stable) => map!(misc::unique, stable),
        #[cfg(feature = "round_series")]
        F::Round { decimals, mode } => map!(round::round, decimals, mode),
        #[cfg(feature = "round_series")]
        F::RoundSF { digits } => map!(round::round_sig_figs, digits),
        #[cfg(feature = "round_series")]
        F::Floor => map!(round::floor),
        #[cfg(feature = "round_series")]
        F::Ceil => map!(round::ceil),
        #[cfg(feature = "fused")]
        F::Fused(op) => map_as_slice!(misc::fused, op),
        F::ConcatExpr(rechunk) => map_as_slice!(misc::concat_expr, rechunk),
        #[cfg(feature = "cov")]
        F::Correlation { method } => map_as_slice!(misc::corr, method),
        #[cfg(feature = "peaks")]
        F::PeakMin => map!(misc::peak_min),
        #[cfg(feature = "peaks")]
        F::PeakMax => map!(misc::peak_max),
        #[cfg(feature = "repeat_by")]
        F::RepeatBy => map_as_slice!(misc::repeat_by),
        #[cfg(feature = "dtype-array")]
        F::Reshape(dims) => map!(misc::reshape, &dims),
        #[cfg(feature = "cutqcut")]
        F::Cut {
            breaks,
            labels,
            left_closed,
            include_breaks,
        } => map!(
            misc::cut,
            breaks.clone(),
            labels.clone(),
            left_closed,
            include_breaks
        ),
        #[cfg(feature = "cutqcut")]
        F::QCut {
            probs,
            labels,
            left_closed,
            allow_duplicates,
            include_breaks,
        } => map!(
            misc::qcut,
            probs.clone(),
            labels.clone(),
            left_closed,
            allow_duplicates,
            include_breaks
        ),
        #[cfg(feature = "rle")]
        F::RLE => map!(polars_ops::series::rle),
        #[cfg(feature = "rle")]
        F::RLEID => map!(polars_ops::series::rle_id),
        F::ToPhysical => map!(misc::to_physical),
        #[cfg(feature = "random")]
        F::Random { method, seed } => {
            use IRRandomMethod::*;
            use polars_plan::plans::IRRandomMethod;
            match method {
                Shuffle => map!(random::shuffle, seed),
                Sample {
                    is_fraction,
                    with_replacement,
                    shuffle,
                } => {
                    // Fraction- vs count-based sampling is chosen once, when the UDF is built.
                    if is_fraction {
                        map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
                    } else {
                        map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
                    }
                },
            }
        },
        F::SetSortedFlag(sorted) => map!(misc::set_sorted_flag, sorted),
        #[cfg(feature = "ffi_plugin")]
        F::FfiPlugin {
            flags: _,
            lib,
            symbol,
            kwargs,
        } => unsafe {
            // SAFETY: NOTE(review): the `unsafe` block lexically encloses the generated
            // closure that calls `call_plugin` with the captured `lib`, `symbol` and
            // `kwargs`. The invariants `call_plugin` requires are not established in this
            // function — confirm against `call_plugin`'s documentation.
            map_as_slice!(
                polars_plan::plans::plugin::call_plugin,
                lib.as_ref(),
                symbol.as_ref(),
                kwargs.as_ref()
            )
        },

        // Horizontal folds/reductions driven by a user-supplied callback over all inputs.
        F::FoldHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::fold,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        F::ReduceHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::reduce,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        #[cfg(feature = "dtype-struct")]
        F::CumReduceHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::cum_reduce,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        #[cfg(feature = "dtype-struct")]
        F::CumFoldHorizontal {
            callback,
            returns_scalar,
            return_dtype,
            include_init,
        } => map_as_slice!(
            horizontal::cum_fold,
            &callback,
            returns_scalar,
            return_dtype.as_ref(),
            include_init
        ),

        F::MaxHorizontal => wrap!(misc::max_horizontal),
        F::MinHorizontal => wrap!(misc::min_horizontal),
        F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
        F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
        #[cfg(feature = "ewma")]
        F::EwmMean { options } => map!(misc::ewm_mean, options),
        #[cfg(feature = "ewma_by")]
        F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
        #[cfg(feature = "ewma")]
        F::EwmStd { options } => map!(misc::ewm_std, options),
        #[cfg(feature = "ewma")]
        F::EwmVar { options } => map!(misc::ewm_var, options),
        #[cfg(feature = "replace")]
        F::Replace => {
            map_as_slice!(misc::replace)
        },
        #[cfg(feature = "replace")]
        F::ReplaceStrict { return_dtype } => {
            map_as_slice!(misc::replace_strict, return_dtype.clone())
        },

        F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
        F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
        #[cfg(feature = "reinterpret")]
        F::Reinterpret(signed) => map!(misc::reinterpret, signed),
        F::ExtendConstant => map_as_slice!(misc::extend_constant),

        F::RowEncode(dts, variants) => {
            map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
        },
        #[cfg(feature = "dtype-struct")]
        F::RowDecode(fs, variants) => {
            map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
        },
    }
}
526
527pub trait GroupsUdf: Send + Sync + 'static {
528 fn evaluate_on_groups<'a>(
529 &self,
530 inputs: &[Arc<dyn PhysicalExpr>],
531 df: &DataFrame,
532 groups: &'a GroupPositions,
533 state: &ExecutionState,
534 ) -> PolarsResult<AggregationContext<'a>>;
535}
536
537pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
538 macro_rules! wrap_groups {
539 ($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
540 struct Wrap($($ty),*);
541 impl GroupsUdf for Wrap {
542 fn evaluate_on_groups<'a>(
543 &self,
544 inputs: &[Arc<dyn PhysicalExpr>],
545 df: &DataFrame,
546 groups: &'a GroupPositions,
547 state: &ExecutionState,
548 ) -> PolarsResult<AggregationContext<'a>> {
549 let Wrap($($n),*) = self;
550 $f(inputs, df, groups, state$(, *$n)*)
551 }
552 }
553
554 SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
555 }};
556 }
557 use IRFunctionExpr as F;
558 Some(match func {
559 F::NullCount => wrap_groups!(groups_dispatch::null_count),
560 F::Reverse => wrap_groups!(groups_dispatch::reverse),
561 F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
562 let ignore_nulls = *ignore_nulls;
563 wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
564 },
565 F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
566 let ignore_nulls = *ignore_nulls;
567 wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
568 },
569 #[cfg(feature = "bitwise")]
570 F::Bitwise(f) => {
571 use polars_plan::plans::IRBitwiseFunction as B;
572 match f {
573 B::And => wrap_groups!(groups_dispatch::bitwise_and),
574 B::Or => wrap_groups!(groups_dispatch::bitwise_or),
575 B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
576 _ => return None,
577 }
578 },
579 F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
580 F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
581
582 #[cfg(feature = "moment")]
583 F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
584 #[cfg(feature = "moment")]
585 F::Kurtosis(fisher, bias) => {
586 wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
587 },
588
589 F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
590 F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
591 wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
592 },
593 F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
594 wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
595 },
596
597 _ => return None,
598 })
599}