1use std::sync::Arc;
2
3use polars_core::error::PolarsResult;
4use polars_core::frame::DataFrame;
5use polars_core::prelude::{Column, GroupPositions};
6use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
8use polars_utils::IdxSize;
9
10use crate::prelude::{AggregationContext, PhysicalExpr};
11use crate::state::ExecutionState;
12
/// Wrap a column UDF into a `SpecialEq<Arc<_>>`.
///
/// * `wrap!(f)` wraps `f` unchanged; `f` is expected to already be usable as the UDF (in this
///   file: a closure or function taking `&mut [Column]`).
/// * `wrap!(f, a, b, ...)` builds a `move` closure that calls `f(columns, a, b, ...)`. The extra
///   argument expressions are evaluated inside the closure on every call, so any variables they
///   mention are moved into the closure.
///
/// Every path is fully qualified so this exported macro expands correctly regardless of what the
/// call site happens to have imported.
#[macro_export]
macro_rules! wrap {
    ($e:expr) => {
        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new($e))
    };

    ($e:expr, $($args:expr),*) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $e(s, $($args),*)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
27
/// Build a UDF that passes the whole input slice (`&mut [Column]`) to `$func`.
///
/// `map_as_slice!(f)` calls `f(columns)`; `map_as_slice!(f, a, b, ...)` calls
/// `f(columns, a, b, ...)`. The extra argument expressions are evaluated inside the generated
/// `move` closure on every call.
///
/// Paths are fully qualified so the exported macro does not depend on the caller's imports.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $func(s)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $func(s, $($args),*)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
49
/// Build a UDF that moves the first input column out of the slice and passes it by value.
///
/// The column is taken with `std::mem::take`, which leaves `Column::default()` in slot 0.
/// Indexing `[0]` panics if the input slice is empty.
/// `map_owned!(f)` calls `f(column)`; `map_owned!(f, a, b, ...)` calls `f(column, a, b, ...)`,
/// with the extra argument expressions evaluated inside the `move` closure on every call.
///
/// Paths are fully qualified so the exported macro does not depend on the caller's imports.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = ::std::mem::take(&mut c[0]);
            $func(c)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = ::std::mem::take(&mut c[0]);
            $func(c, $($args),*)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
72
/// Build a UDF that passes a reference to the first input column (`&columns[0]`) to `$func`.
///
/// Any further input columns are ignored. Indexing `[0]` panics if the input slice is empty.
/// `map!(f)` calls `f(&column)`; `map!(f, a, b, ...)` calls `f(&column, a, b, ...)`, with the
/// extra argument expressions evaluated inside the `move` closure on every call.
///
/// Paths are fully qualified so the exported macro does not depend on the caller's imports.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = &c[0];
            $func(c)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = &c[0];
            $func(c, $($args),*)
        };

        ::polars_plan::dsl::SpecialEq::new(::std::sync::Arc::new(f))
    }};
}
94
95#[cfg(feature = "dtype-array")]
96mod array;
97mod binary;
98#[cfg(feature = "bitwise")]
99mod bitwise;
100mod boolean;
101#[cfg(feature = "business")]
102mod business;
103#[cfg(feature = "dtype-categorical")]
104mod cat;
105#[cfg(feature = "cum_agg")]
106mod cum;
107#[cfg(feature = "temporal")]
108mod datetime;
109#[cfg(feature = "dtype-extension")]
110mod extension;
111mod groups_dispatch;
112mod horizontal;
113mod list;
114mod misc;
115mod pow;
116#[cfg(feature = "random")]
117mod random;
118#[cfg(feature = "range")]
119mod range;
120#[cfg(feature = "rolling_window")]
121mod rolling;
122#[cfg(feature = "rolling_window_by")]
123mod rolling_by;
124#[cfg(feature = "round_series")]
125mod round;
126mod shift_and_fill;
127#[cfg(feature = "strings")]
128mod strings;
129#[cfg(feature = "dtype-struct")]
130pub(crate) mod struct_;
131#[cfg(feature = "temporal")]
132mod temporal;
133#[cfg(feature = "trigonometry")]
134mod trigonometry;
135
136pub use groups_dispatch::drop_items;
137
/// Translate an [`IRFunctionExpr`] into the column-level UDF that evaluates it.
///
/// Namespaced variants (`ArrayExpr`, `BinaryExpr`, `ListExpr`, `StringExpr`, `Range`, ...) are
/// forwarded to the dispatcher of the matching submodule. Every other variant is wrapped here
/// with this file's helper macros:
///
/// * `map!(f, args..)` calls `f(&columns[0], args..)`; only the first input is used.
/// * `map_owned!(f, args..)` moves the first input out of the slice and calls `f(column, args..)`.
/// * `map_as_slice!(f, args..)` and `wrap!(f, args..)` call `f(columns, args..)` with all inputs.
///
/// Extra `args` are captured by the generated `move` closure and evaluated on every call.
/// Captured non-`Copy` values that the callee presumably takes by value are therefore cloned per
/// call, which is why several arms below write `x.clone()`.
///
/// Arms guarded by `#[cfg(feature = "...")]` only exist when that cargo feature is enabled.
pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
    use IRFunctionExpr as F;
    match func {
        // Namespaced expressions: delegate to the submodule's own dispatcher.
        #[cfg(feature = "dtype-array")]
        F::ArrayExpr(func) => array::function_expr_to_udf(func),
        F::BinaryExpr(func) => binary::function_expr_to_udf(func),
        #[cfg(feature = "dtype-categorical")]
        F::Categorical(func) => cat::function_expr_to_udf(func),
        #[cfg(feature = "dtype-extension")]
        F::Extension(func) => extension::function_expr_to_udf(func),
        F::ListExpr(func) => list::function_expr_to_udf(func),
        #[cfg(feature = "strings")]
        F::StringExpr(func) => strings::function_expr_to_udf(func),
        #[cfg(feature = "dtype-struct")]
        F::StructExpr(func) => struct_::function_expr_to_udf(func),
        #[cfg(feature = "temporal")]
        F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
        #[cfg(feature = "bitwise")]
        F::Bitwise(func) => bitwise::function_expr_to_udf(func),

        F::Boolean(func) => boolean::function_expr_to_udf(func),
        #[cfg(feature = "business")]
        F::Business(func) => business::function_expr_to_udf(func),
        #[cfg(feature = "abs")]
        F::Abs => map!(misc::abs),
        F::Negate => map!(misc::negate),
        F::NullCount => {
            // One-row column, named after the first input, holding its null count as `IdxSize`.
            let f = |s: &mut [Column]| {
                let s = &s[0];
                Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
            };
            wrap!(f)
        },
        F::Pow(func) => match func {
            // The generic power takes the full input slice; sqrt/cbrt only need the base column.
            IRPowFunction::Generic => wrap!(pow::pow),
            IRPowFunction::Sqrt => map!(pow::sqrt),
            IRPowFunction::Cbrt => map!(pow::cbrt),
        },
        #[cfg(feature = "row_hash")]
        F::Hash(k0, k1, k2, k3) => {
            map!(misc::row_hash, k0, k1, k2, k3)
        },
        #[cfg(feature = "arg_where")]
        F::ArgWhere => {
            wrap!(misc::arg_where)
        },
        #[cfg(feature = "index_of")]
        F::IndexOf => {
            map_as_slice!(misc::index_of)
        },
        #[cfg(feature = "search_sorted")]
        F::SearchSorted { side, descending } => {
            map_as_slice!(misc::search_sorted_impl, side, descending)
        },
        #[cfg(feature = "range")]
        F::Range(func) => range::function_expr_to_udf(func),

        #[cfg(feature = "trigonometry")]
        F::Trigonometry(trig_function) => {
            map!(trigonometry::apply_trigonometric_function, trig_function)
        },
        #[cfg(feature = "trigonometry")]
        F::Atan2 => {
            wrap!(trigonometry::apply_arctan2)
        },

        #[cfg(feature = "sign")]
        F::Sign => {
            map!(misc::sign)
        },
        F::FillNull => {
            map_as_slice!(misc::fill_null)
        },
        #[cfg(feature = "rolling_window")]
        F::RollingExpr { function, options } => {
            use IRRollingFunction::*;
            use polars_plan::plans::IRRollingFunction;
            // Each arm's closure owns `options`; `options.clone()` runs on every invocation.
            match function {
                Min => map!(rolling::rolling_min, options.clone()),
                Max => map!(rolling::rolling_max, options.clone()),
                Mean => map!(rolling::rolling_mean, options.clone()),
                Sum => map!(rolling::rolling_sum, options.clone()),
                Quantile => map!(rolling::rolling_quantile, options.clone()),
                Var => map!(rolling::rolling_var, options.clone()),
                Std => map!(rolling::rolling_std, options.clone()),
                Rank => map!(rolling::rolling_rank, options.clone()),
                #[cfg(feature = "moment")]
                Skew => map!(rolling::rolling_skew, options.clone()),
                #[cfg(feature = "moment")]
                Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
                // Correlation/covariance takes two input columns, hence the slice form.
                #[cfg(feature = "cov")]
                CorrCov {
                    corr_cov_options,
                    is_corr,
                } => {
                    map_as_slice!(
                        rolling::rolling_corr_cov,
                        options.clone(),
                        corr_cov_options,
                        is_corr
                    )
                },
                Map(f) => {
                    map!(rolling::rolling_map, options.clone(), f.clone())
                },
            }
        },
        #[cfg(feature = "rolling_window_by")]
        F::RollingExprBy {
            function_by,
            options,
        } => {
            use IRRollingFunctionBy::*;
            use polars_plan::plans::IRRollingFunctionBy;
            // The `*_by` variants receive every input column (the values plus the `by`
            // column(s)), so they all use the slice form.
            match function_by {
                MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
                MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
                MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
                SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
                QuantileBy => {
                    map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
                },
                VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
                StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
                RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
            }
        },
        #[cfg(feature = "hist")]
        F::Hist {
            bin_count,
            include_category,
            include_breakpoint,
        } => {
            map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
        },
        F::Rechunk => map!(misc::rechunk),
        F::Append { upcast } => map_as_slice!(misc::append, upcast),
        F::ShiftAndFill => {
            map_as_slice!(shift_and_fill::shift_and_fill)
        },
        // `drop_nans` receives the first column by value (moved out of the input slice).
        F::DropNans => map_owned!(misc::drop_nans),
        F::DropNulls => map!(misc::drop_nulls),
        #[cfg(feature = "round_series")]
        F::Clip { has_min, has_max } => {
            map_as_slice!(misc::clip, has_min, has_max)
        },
        #[cfg(feature = "mode")]
        F::Mode { maintain_order } => map!(misc::mode, maintain_order),
        #[cfg(feature = "moment")]
        F::Skew(bias) => map!(misc::skew, bias),
        #[cfg(feature = "moment")]
        F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
        F::ArgUnique => map!(misc::arg_unique),
        F::ArgMin => map!(misc::arg_min),
        F::ArgMax => map!(misc::arg_max),
        F::ArgSort {
            descending,
            nulls_last,
        } => map!(misc::arg_sort, descending, nulls_last),
        F::MinBy => map_as_slice!(misc::min_by),
        F::MaxBy => map_as_slice!(misc::max_by),
        F::Product => map!(misc::product),
        F::Repeat => map_as_slice!(misc::repeat),
        #[cfg(feature = "rank")]
        F::Rank { options, seed } => map!(misc::rank, options, seed),
        #[cfg(feature = "dtype-struct")]
        F::AsStruct => {
            map_as_slice!(misc::as_struct)
        },
        #[cfg(feature = "top_k")]
        F::TopK { descending } => {
            map_as_slice!(polars_ops::prelude::top_k, descending)
        },
        #[cfg(feature = "top_k")]
        F::TopKBy { descending } => {
            map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
        },
        F::Shift => map_as_slice!(shift_and_fill::shift),
        #[cfg(feature = "cum_agg")]
        F::CumCount { reverse } => map!(cum::cum_count, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumSum { reverse } => map!(cum::cum_sum, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumProd { reverse } => map!(cum::cum_prod, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumMin { reverse } => map!(cum::cum_min, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumMax { reverse } => map!(cum::cum_max, reverse),
        #[cfg(feature = "dtype-struct")]
        F::ValueCounts {
            sort,
            parallel,
            name,
            normalize,
        } => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
        #[cfg(feature = "unique_counts")]
        F::UniqueCounts => map!(misc::unique_counts),
        F::Reverse => map!(misc::reverse),
        #[cfg(feature = "approx_unique")]
        F::ApproxNUnique => map!(misc::approx_n_unique),
        F::Coalesce => map_as_slice!(misc::coalesce),
        #[cfg(feature = "diff")]
        F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
        #[cfg(feature = "pct_change")]
        F::PctChange => map_as_slice!(misc::pct_change),
        #[cfg(feature = "interpolate")]
        F::Interpolate(method) => {
            map!(misc::interpolate, method)
        },
        #[cfg(feature = "interpolate_by")]
        F::InterpolateBy => {
            map_as_slice!(misc::interpolate_by)
        },
        #[cfg(feature = "log")]
        F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
        #[cfg(feature = "log")]
        F::Log => map_as_slice!(misc::log),
        #[cfg(feature = "log")]
        F::Log1p => map!(misc::log1p),
        #[cfg(feature = "log")]
        F::Exp => map!(misc::exp),
        F::Unique(stable) => map!(misc::unique, stable),
        #[cfg(feature = "round_series")]
        F::Round { decimals, mode } => map!(round::round, decimals, mode),
        #[cfg(feature = "round_series")]
        F::RoundSF { digits } => map!(round::round_sig_figs, digits),
        #[cfg(feature = "round_series")]
        F::Floor => map!(round::floor),
        #[cfg(feature = "round_series")]
        F::Ceil => map!(round::ceil),
        #[cfg(feature = "fused")]
        F::Fused(op) => map_as_slice!(misc::fused, op),
        F::ConcatExpr(rechunk) => map_as_slice!(misc::concat_expr, rechunk),
        #[cfg(feature = "cov")]
        F::Correlation { method } => map_as_slice!(misc::corr, method),
        #[cfg(feature = "peaks")]
        F::PeakMin => map!(misc::peak_min),
        #[cfg(feature = "peaks")]
        F::PeakMax => map!(misc::peak_max),
        #[cfg(feature = "repeat_by")]
        F::RepeatBy => map_as_slice!(misc::repeat_by),
        // `dims` is moved into the closure and lent to `reshape` on each call.
        #[cfg(feature = "dtype-array")]
        F::Reshape(dims) => map!(misc::reshape, &dims),
        #[cfg(feature = "cutqcut")]
        F::Cut {
            breaks,
            labels,
            left_closed,
            include_breaks,
        } => map!(
            misc::cut,
            breaks.clone(),
            labels.clone(),
            left_closed,
            include_breaks
        ),
        #[cfg(feature = "cutqcut")]
        F::QCut {
            probs,
            labels,
            left_closed,
            allow_duplicates,
            include_breaks,
        } => map!(
            misc::qcut,
            probs.clone(),
            labels.clone(),
            left_closed,
            allow_duplicates,
            include_breaks
        ),
        #[cfg(feature = "rle")]
        F::RLE => map!(polars_ops::series::rle),
        #[cfg(feature = "rle")]
        F::RLEID => map!(polars_ops::series::rle_id),
        F::ToPhysical => map!(misc::to_physical),
        #[cfg(feature = "random")]
        F::Random { method, seed } => {
            use IRRandomMethod::*;
            use polars_plan::plans::IRRandomMethod;
            match method {
                Shuffle => map!(random::shuffle, seed),
                Sample {
                    is_fraction,
                    with_replacement,
                    shuffle,
                } => {
                    // The sample size (fraction or count) is chosen here, once, rather than
                    // per call.
                    if is_fraction {
                        map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
                    } else {
                        map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
                    }
                },
            }
        },
        F::SetSortedFlag(sorted) => map!(misc::set_sorted_flag, sorted),
        // NOTE(review): the `unsafe` block lexically encloses the generated closure, so it covers
        // the `call_plugin` call made on each invocation. That call presumably executes a symbol
        // loaded from the plugin library `lib`; soundness relies on `lib`/`symbol`/`kwargs`
        // describing a valid plugin. Confirm against `call_plugin`'s safety contract.
        #[cfg(feature = "ffi_plugin")]
        F::FfiPlugin {
            flags: _,
            lib,
            symbol,
            kwargs,
        } => unsafe {
            map_as_slice!(
                polars_plan::plans::plugin::call_plugin,
                lib.as_ref(),
                symbol.as_ref(),
                kwargs.as_ref()
            )
        },

        // Horizontal fold/reduce family: the user callback is borrowed from the closure's
        // captured copy on every call.
        F::FoldHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::fold,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        F::ReduceHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::reduce,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        #[cfg(feature = "dtype-struct")]
        F::CumReduceHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::cum_reduce,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        #[cfg(feature = "dtype-struct")]
        F::CumFoldHorizontal {
            callback,
            returns_scalar,
            return_dtype,
            include_init,
        } => map_as_slice!(
            horizontal::cum_fold,
            &callback,
            returns_scalar,
            return_dtype.as_ref(),
            include_init
        ),

        F::MaxHorizontal => wrap!(misc::max_horizontal),
        F::MinHorizontal => wrap!(misc::min_horizontal),
        F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
        F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
        #[cfg(feature = "ewma")]
        F::EwmMean { options } => map!(misc::ewm_mean, options),
        #[cfg(feature = "ewma_by")]
        F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
        #[cfg(feature = "ewma")]
        F::EwmStd { options } => map!(misc::ewm_std, options),
        #[cfg(feature = "ewma")]
        F::EwmVar { options } => map!(misc::ewm_var, options),
        #[cfg(feature = "replace")]
        F::Replace => {
            map_as_slice!(misc::replace)
        },
        #[cfg(feature = "replace")]
        F::ReplaceStrict { return_dtype } => {
            map_as_slice!(misc::replace_strict, return_dtype.clone())
        },

        F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
        F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
        #[cfg(feature = "reinterpret")]
        F::Reinterpret(signed) => map!(misc::reinterpret, signed),
        F::ExtendConstant => map_as_slice!(misc::extend_constant),

        F::RowEncode(dts, variants) => {
            map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
        },
        #[cfg(feature = "dtype-struct")]
        F::RowDecode(fs, variants) => {
            map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
        },
    }
}
532
533pub trait GroupsUdf: Send + Sync + 'static {
534 fn evaluate_on_groups<'a>(
535 &self,
536 inputs: &[Arc<dyn PhysicalExpr>],
537 df: &DataFrame,
538 groups: &'a GroupPositions,
539 state: &ExecutionState,
540 ) -> PolarsResult<AggregationContext<'a>>;
541}
542
543pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
544 macro_rules! wrap_groups {
545 ($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
546 struct Wrap($($ty),*);
547 impl GroupsUdf for Wrap {
548 fn evaluate_on_groups<'a>(
549 &self,
550 inputs: &[Arc<dyn PhysicalExpr>],
551 df: &DataFrame,
552 groups: &'a GroupPositions,
553 state: &ExecutionState,
554 ) -> PolarsResult<AggregationContext<'a>> {
555 let Wrap($($n),*) = self;
556 $f(inputs, df, groups, state$(, *$n)*)
557 }
558 }
559
560 SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
561 }};
562 }
563 use IRFunctionExpr as F;
564 Some(match func {
565 F::NullCount => wrap_groups!(groups_dispatch::null_count),
566 F::Reverse => wrap_groups!(groups_dispatch::reverse),
567 F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
568 let ignore_nulls = *ignore_nulls;
569 wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
570 },
571 F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
572 let ignore_nulls = *ignore_nulls;
573 wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
574 },
575 #[cfg(feature = "bitwise")]
576 F::Bitwise(f) => {
577 use polars_plan::plans::IRBitwiseFunction as B;
578 match f {
579 B::And => wrap_groups!(groups_dispatch::bitwise_and),
580 B::Or => wrap_groups!(groups_dispatch::bitwise_or),
581 B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
582 _ => return None,
583 }
584 },
585 F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
586 F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
587
588 #[cfg(feature = "moment")]
589 F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
590 #[cfg(feature = "moment")]
591 F::Kurtosis(fisher, bias) => {
592 wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
593 },
594
595 F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
596 F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
597 wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
598 },
599 F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
600 wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
601 },
602
603 _ => return None,
604 })
605}