1use std::sync::Arc;
2
3use polars_compute::rolling::QuantileMethod;
4use polars_core::error::PolarsResult;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{Column, GroupPositions};
7use polars_plan::dsl::{ColumnsUdf, SpecialEq};
8use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
9use polars_utils::IdxSize;
10
11use crate::prelude::{AggregationContext, PhysicalExpr};
12use crate::state::ExecutionState;
13
/// Wraps a column UDF into a `SpecialEq<Arc<dyn ColumnsUdf>>`-compatible value.
///
/// - `wrap!(udf)` stores `udf` unchanged, as `SpecialEq::new(Arc::new(udf))`.
/// - `wrap!(func, a, b, ..)` builds a `move` closure over `&mut [Column]` that forwards the
///   complete input slice as the first argument, i.e. `func(inputs, a, b, ..)`. The extra
///   argument expressions are expanded inside the closure body, so they are evaluated again
///   on every call.
///
/// `SpecialEq` and `Arc` are written without a path prefix, so they are resolved at the
/// invocation site and must be in scope there.
#[macro_export]
macro_rules! wrap {
    ($udf:expr) => {
        SpecialEq::new(Arc::new($udf))
    };

    ($udf:expr, $($extra:expr),*) => {{
        let forward = move |inputs: &mut [::polars_core::prelude::Column]| {
            $udf(inputs, $($extra),*)
        };

        SpecialEq::new(Arc::new(forward))
    }};
}
28
/// Builds a column UDF that forwards the *whole* input slice to `$func`.
///
/// `map_as_slice!(func)` expands to a `move` closure computing `func(inputs)`, and
/// `map_as_slice!(func, a, b, ..)` to one computing `func(inputs, a, b, ..)`, where
/// `inputs: &mut [Column]` holds every input column. The closure is returned as
/// `SpecialEq::new(Arc::new(closure))`. Extra argument expressions live in the closure body
/// and are therefore evaluated on each call.
///
/// Unlike `map!`, which only passes the first column, this gives `$func` access to all
/// inputs. `SpecialEq` and `Arc` must be in scope at the invocation site.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let forward_all = move |inputs: &mut [::polars_core::prelude::Column]| $func(inputs);

        SpecialEq::new(Arc::new(forward_all))
    }};

    ($func:path, $($extra:expr),*) => {{
        let forward_all =
            move |inputs: &mut [::polars_core::prelude::Column]| $func(inputs, $($extra),*);

        SpecialEq::new(Arc::new(forward_all))
    }};
}
50
/// Builds a column UDF that passes the *first* input column to `$func` by value.
///
/// The column is moved out of the input slice with `std::mem::take`, which leaves
/// `Column::default()` in its place, so `$func` receives an owned `Column`. Extra arguments
/// (`map_owned!(func, a, b, ..)`) follow the column and are evaluated inside the closure on
/// each call. Any further input columns are ignored, and the closure panics if the input
/// slice is empty.
///
/// `SpecialEq` and `Arc` must be in scope at the invocation site.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let take_first = move |inputs: &mut [::polars_core::prelude::Column]| {
            $func(std::mem::take(&mut inputs[0]))
        };

        SpecialEq::new(Arc::new(take_first))
    }};

    ($func:path, $($extra:expr),*) => {{
        let take_first = move |inputs: &mut [::polars_core::prelude::Column]| {
            $func(std::mem::take(&mut inputs[0]), $($extra),*)
        };

        SpecialEq::new(Arc::new(take_first))
    }};
}
73
/// Builds a column UDF that passes a shared reference to the *first* input column to `$func`.
///
/// `map!(func)` computes `func(&inputs[0])`; `map!(func, a, b, ..)` computes
/// `func(&inputs[0], a, b, ..)`, with the extra arguments evaluated inside the closure on
/// each call. Any further input columns are ignored, and the closure panics if the input
/// slice is empty.
///
/// `SpecialEq` and `Arc` must be in scope at the invocation site.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let first_only = move |inputs: &mut [::polars_core::prelude::Column]| $func(&inputs[0]);

        SpecialEq::new(Arc::new(first_only))
    }};

    ($func:path, $($extra:expr),*) => {{
        let first_only = move |inputs: &mut [::polars_core::prelude::Column]| {
            $func(&inputs[0], $($extra),*)
        };

        SpecialEq::new(Arc::new(first_only))
    }};
}
95
96#[cfg(feature = "dtype-array")]
97mod array;
98mod binary;
99#[cfg(feature = "bitwise")]
100mod bitwise;
101mod boolean;
102#[cfg(feature = "business")]
103mod business;
104#[cfg(feature = "dtype-categorical")]
105mod cat;
106#[cfg(feature = "cum_agg")]
107mod cum;
108#[cfg(feature = "temporal")]
109mod datetime;
110#[cfg(feature = "dtype-extension")]
111mod extension;
112mod groups_dispatch;
113mod horizontal;
114mod list;
115mod misc;
116mod pow;
117#[cfg(feature = "random")]
118mod random;
119#[cfg(feature = "range")]
120mod range;
121#[cfg(feature = "rolling_window")]
122mod rolling;
123#[cfg(feature = "rolling_window_by")]
124mod rolling_by;
125#[cfg(feature = "round_series")]
126mod round;
127mod shift_and_fill;
128#[cfg(feature = "strings")]
129mod strings;
130#[cfg(feature = "dtype-struct")]
131pub(crate) mod struct_;
132#[cfg(feature = "temporal")]
133mod temporal;
134#[cfg(feature = "trigonometry")]
135mod trigonometry;
136
137pub use groups_dispatch::drop_items;
138
/// Translates an [`IRFunctionExpr`] into the column UDF that evaluates it.
///
/// Variants that carry their own sub-enum (array, binary, categorical, extension, list,
/// string, struct, temporal, bitwise, boolean, business and range functions) are handed to
/// the dispatcher in the matching submodule. All remaining variants are bound directly to an
/// implementation function through the macros defined in this module:
///
/// - `map!(func, ..)` calls `func(&inputs[0], ..)`;
/// - `map_owned!(func, ..)` moves `inputs[0]` out of the slice and calls `func(column, ..)`;
/// - `map_as_slice!(func, ..)` and `wrap!(func, ..)` call `func(inputs, ..)` with the whole
///   `&mut [Column]`; `wrap!(udf)` with a single argument stores `udf` unchanged.
///
/// Extra macro arguments (e.g. `options.clone()`) are expanded inside the closure body, so
/// they are re-evaluated on each call of the returned UDF. That is presumably why non-trivial
/// fields are cloned there rather than moved out of the closure's captures.
pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
    use IRFunctionExpr as F;
    match func {
        // Function families with their own sub-enum: delegate to the submodule dispatcher.
        #[cfg(feature = "dtype-array")]
        F::ArrayExpr(func) => array::function_expr_to_udf(func),
        F::BinaryExpr(func) => binary::function_expr_to_udf(func),
        #[cfg(feature = "dtype-categorical")]
        F::Categorical(func) => cat::function_expr_to_udf(func),
        #[cfg(feature = "dtype-extension")]
        F::Extension(func) => extension::function_expr_to_udf(func),
        F::ListExpr(func) => list::function_expr_to_udf(func),
        #[cfg(feature = "strings")]
        F::StringExpr(func) => strings::function_expr_to_udf(func),
        #[cfg(feature = "dtype-struct")]
        F::StructExpr(func) => struct_::function_expr_to_udf(func),
        #[cfg(feature = "temporal")]
        F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
        #[cfg(feature = "bitwise")]
        F::Bitwise(func) => bitwise::function_expr_to_udf(func),

        F::Boolean(func) => boolean::function_expr_to_udf(func),
        #[cfg(feature = "business")]
        F::Business(func) => business::function_expr_to_udf(func),
        #[cfg(feature = "abs")]
        F::Abs => map!(misc::abs),
        F::Negate => map!(misc::negate),
        F::NullCount => {
            // One-row column holding the input's null count, keeping the input's name.
            // NOTE(review): `as IdxSize` narrows a `usize`; relies on counts fitting in IdxSize.
            let f = |s: &mut [Column]| {
                let s = &s[0];
                Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
            };
            wrap!(f)
        },
        F::Pow(func) => match func {
            IRPowFunction::Generic => wrap!(pow::pow),
            IRPowFunction::Sqrt => map!(pow::sqrt),
            IRPowFunction::Cbrt => map!(pow::cbrt),
        },
        #[cfg(feature = "row_hash")]
        F::Hash(k0, k1, k2, k3) => {
            map!(misc::row_hash, k0, k1, k2, k3)
        },
        #[cfg(feature = "arg_where")]
        F::ArgWhere => {
            wrap!(misc::arg_where)
        },
        #[cfg(feature = "index_of")]
        F::IndexOf => {
            map_as_slice!(misc::index_of)
        },
        #[cfg(feature = "search_sorted")]
        F::SearchSorted { side, descending } => {
            map_as_slice!(misc::search_sorted_impl, side, descending)
        },
        #[cfg(feature = "range")]
        F::Range(func) => range::function_expr_to_udf(func),

        #[cfg(feature = "trigonometry")]
        F::Trigonometry(trig_function) => {
            map!(trigonometry::apply_trigonometric_function, trig_function)
        },
        #[cfg(feature = "trigonometry")]
        F::Atan2 => {
            wrap!(trigonometry::apply_arctan2)
        },

        #[cfg(feature = "sign")]
        F::Sign => {
            map!(misc::sign)
        },
        F::FillNull => {
            map_as_slice!(misc::fill_null)
        },
        #[cfg(feature = "rolling_window")]
        F::RollingExpr { function, options } => {
            // Glob-import the rolling variants so the arms below can name them unqualified.
            use IRRollingFunction::*;
            use polars_plan::plans::IRRollingFunction;
            match function {
                Min => map!(rolling::rolling_min, options.clone()),
                Max => map!(rolling::rolling_max, options.clone()),
                Mean => map!(rolling::rolling_mean, options.clone()),
                Sum => map!(rolling::rolling_sum, options.clone()),
                Quantile => map!(rolling::rolling_quantile, options.clone()),
                Var => map!(rolling::rolling_var, options.clone()),
                Std => map!(rolling::rolling_std, options.clone()),
                Rank => map!(rolling::rolling_rank, options.clone()),
                #[cfg(feature = "moment")]
                Skew => map!(rolling::rolling_skew, options.clone()),
                #[cfg(feature = "moment")]
                Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
                #[cfg(feature = "cov")]
                CorrCov {
                    corr_cov_options,
                    is_corr,
                } => {
                    // Needs more than one input column, hence the slice-based macro.
                    map_as_slice!(
                        rolling::rolling_corr_cov,
                        options.clone(),
                        corr_cov_options,
                        is_corr
                    )
                },
                Map(f) => {
                    map!(rolling::rolling_map, options.clone(), f.clone())
                },
            }
        },
        #[cfg(feature = "rolling_window_by")]
        F::RollingExprBy {
            function_by,
            options,
        } => {
            // The `*_by` kernels receive the whole input slice rather than just the first column.
            use IRRollingFunctionBy::*;
            use polars_plan::plans::IRRollingFunctionBy;
            match function_by {
                MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
                MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
                MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
                SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
                QuantileBy => {
                    map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
                },
                VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
                StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
                RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
            }
        },
        #[cfg(feature = "hist")]
        F::Hist {
            bin_count,
            include_category,
            include_breakpoint,
        } => {
            map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
        },
        F::Rechunk => map!(misc::rechunk),
        F::ShiftAndFill => {
            map_as_slice!(shift_and_fill::shift_and_fill)
        },
        // `drop_nans` takes its input column by value (moved out of the input slice).
        F::DropNans => map_owned!(misc::drop_nans),
        F::DropNulls => map!(misc::drop_nulls),
        #[cfg(feature = "round_series")]
        F::Clip { has_min, has_max } => {
            map_as_slice!(misc::clip, has_min, has_max)
        },
        F::Quantile { method } => map_as_slice!(misc::quantile, method),
        #[cfg(feature = "mode")]
        F::Mode { maintain_order } => map!(misc::mode, maintain_order),
        #[cfg(feature = "moment")]
        F::Skew(bias) => map!(misc::skew, bias),
        #[cfg(feature = "moment")]
        F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
        F::ArgUnique => map!(misc::arg_unique),
        F::ArgMin => map!(misc::arg_min),
        F::ArgMax => map!(misc::arg_max),
        F::ArgSort {
            descending,
            nulls_last,
        } => map!(misc::arg_sort, descending, nulls_last),
        F::MinBy => map_as_slice!(misc::min_by),
        F::MaxBy => map_as_slice!(misc::max_by),
        F::Product => map!(misc::product),
        F::Repeat => map_as_slice!(misc::repeat),
        #[cfg(feature = "rank")]
        F::Rank { options, seed } => map!(misc::rank, options, seed),
        #[cfg(feature = "dtype-struct")]
        F::AsStruct => {
            map_as_slice!(misc::as_struct)
        },
        #[cfg(feature = "top_k")]
        F::TopK { descending } => {
            map_as_slice!(polars_ops::prelude::top_k, descending)
        },
        #[cfg(feature = "top_k")]
        F::TopKBy { descending } => {
            map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
        },
        F::Shift => map_as_slice!(shift_and_fill::shift),
        #[cfg(feature = "cum_agg")]
        F::CumCount { reverse } => map!(cum::cum_count, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumSum { reverse } => map!(cum::cum_sum, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumProd { reverse } => map!(cum::cum_prod, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumMin { reverse } => map!(cum::cum_min, reverse),
        #[cfg(feature = "cum_agg")]
        F::CumMax { reverse } => map!(cum::cum_max, reverse),
        #[cfg(feature = "dtype-struct")]
        F::ValueCounts {
            sort,
            parallel,
            name,
            normalize,
        } => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
        #[cfg(feature = "unique_counts")]
        F::UniqueCounts => map!(misc::unique_counts),
        F::Reverse => map!(misc::reverse),
        #[cfg(feature = "approx_unique")]
        F::ApproxNUnique => map!(misc::approx_n_unique),
        F::Coalesce => map_as_slice!(misc::coalesce),
        #[cfg(feature = "diff")]
        F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
        #[cfg(feature = "pct_change")]
        F::PctChange => map_as_slice!(misc::pct_change),
        #[cfg(feature = "interpolate")]
        F::Interpolate(method) => {
            map!(misc::interpolate, method)
        },
        #[cfg(feature = "interpolate_by")]
        F::InterpolateBy => {
            map_as_slice!(misc::interpolate_by)
        },
        #[cfg(feature = "log")]
        F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
        #[cfg(feature = "log")]
        F::Log => map_as_slice!(misc::log),
        #[cfg(feature = "log")]
        F::Log1p => map!(misc::log1p),
        #[cfg(feature = "log")]
        F::Exp => map!(misc::exp),
        F::Unique(stable) => map!(misc::unique, stable),
        #[cfg(feature = "round_series")]
        F::Round { decimals, mode } => map!(round::round, decimals, mode),
        #[cfg(feature = "round_series")]
        F::RoundSF { digits } => map!(round::round_sig_figs, digits),
        #[cfg(feature = "round_series")]
        F::Truncate { decimals } => map!(round::truncate, decimals),
        #[cfg(feature = "round_series")]
        F::Floor => map!(round::floor),
        #[cfg(feature = "round_series")]
        F::Ceil => map!(round::ceil),
        #[cfg(feature = "fused")]
        F::Fused(op) => map_as_slice!(misc::fused, op),
        F::ConcatExpr { rechunk } => map_as_slice!(misc::concat_expr, rechunk),
        #[cfg(feature = "cov")]
        F::Correlation { method } => map_as_slice!(misc::corr, method),
        #[cfg(feature = "peaks")]
        F::PeakMin => map!(misc::peak_min),
        #[cfg(feature = "peaks")]
        F::PeakMax => map!(misc::peak_max),
        #[cfg(feature = "repeat_by")]
        F::RepeatBy => map_as_slice!(misc::repeat_by),
        #[cfg(feature = "dtype-array")]
        F::Reshape(dims) => map!(misc::reshape, &dims),
        #[cfg(feature = "cutqcut")]
        F::Cut {
            breaks,
            labels,
            left_closed,
            include_breaks,
        } => map!(
            misc::cut,
            breaks.clone(),
            labels.clone(),
            left_closed,
            include_breaks
        ),
        #[cfg(feature = "cutqcut")]
        F::QCut {
            probs,
            labels,
            left_closed,
            allow_duplicates,
            include_breaks,
        } => map!(
            misc::qcut,
            probs.clone(),
            labels.clone(),
            left_closed,
            allow_duplicates,
            include_breaks
        ),
        #[cfg(feature = "rle")]
        F::RLE => map!(polars_ops::series::rle),
        #[cfg(feature = "rle")]
        F::RLEID => map!(polars_ops::series::rle_id),
        F::ToPhysical => map!(misc::to_physical),
        #[cfg(feature = "random")]
        F::Random { method, seed } => {
            use IRRandomMethod::*;
            use polars_plan::plans::IRRandomMethod;
            match method {
                Shuffle => map!(random::shuffle, seed),
                Sample {
                    is_fraction,
                    with_replacement,
                    shuffle,
                } => {
                    // `is_fraction` picks `sample_frac` over `sample_n`; both get the full slice.
                    if is_fraction {
                        map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
                    } else {
                        map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
                    }
                },
            }
        },
        F::SetSortedFlag(sortedness) => map!(misc::set_sorted_flag, sortedness),
        // NOTE(review): no SAFETY justification is recorded for this block. The `unsafe`
        // scope encloses the closure built by `map_as_slice!`, which is what allows the call
        // to `call_plugin` inside it; soundness presumably depends on `lib`/`symbol` naming a
        // valid plugin entry point compatible with `kwargs` — confirm.
        #[cfg(feature = "ffi_plugin")]
        F::FfiPlugin {
            flags: _,
            lib,
            symbol,
            kwargs,
        } => unsafe {
            map_as_slice!(
                polars_plan::plans::plugin::call_plugin,
                lib.as_ref(),
                symbol.as_ref(),
                kwargs.as_ref()
            )
        },

        // Horizontal folds/reductions: the callback and options are borrowed from the closure's
        // captures on each call.
        F::FoldHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::fold,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        F::ReduceHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::reduce,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        #[cfg(feature = "dtype-struct")]
        F::CumReduceHorizontal {
            callback,
            returns_scalar,
            return_dtype,
        } => map_as_slice!(
            horizontal::cum_reduce,
            &callback,
            returns_scalar,
            return_dtype.as_ref()
        ),
        #[cfg(feature = "dtype-struct")]
        F::CumFoldHorizontal {
            callback,
            returns_scalar,
            return_dtype,
            include_init,
        } => map_as_slice!(
            horizontal::cum_fold,
            &callback,
            returns_scalar,
            return_dtype.as_ref(),
            include_init
        ),

        F::MaxHorizontal => wrap!(misc::max_horizontal),
        F::MinHorizontal => wrap!(misc::min_horizontal),
        F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
        F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
        #[cfg(feature = "ewma")]
        F::EwmMean { options } => map!(misc::ewm_mean, options),
        #[cfg(feature = "ewma_by")]
        F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
        #[cfg(feature = "ewma")]
        F::EwmStd { options } => map!(misc::ewm_std, options),
        #[cfg(feature = "ewma")]
        F::EwmVar { options } => map!(misc::ewm_var, options),
        #[cfg(feature = "replace")]
        F::Replace => {
            map_as_slice!(misc::replace)
        },
        #[cfg(feature = "replace")]
        F::ReplaceStrict { return_dtype } => {
            map_as_slice!(misc::replace_strict, return_dtype.clone())
        },

        F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
        F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
        #[cfg(feature = "reinterpret")]
        F::Reinterpret(dtype) => map!(misc::reinterpret, &dtype),
        F::ExtendConstant => map_as_slice!(misc::extend_constant),

        F::RowEncode(dts, variants) => {
            map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
        },
        #[cfg(feature = "dtype-struct")]
        F::RowDecode(fs, variants) => {
            map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
        },
        F::DynamicPred { pred } => {
            map_as_slice!(misc::dynamic_pred, &pred)
        },
    }
}
538
539pub trait GroupsUdf: Send + Sync + 'static {
540 fn evaluate_on_groups<'a>(
541 &self,
542 inputs: &[Arc<dyn PhysicalExpr>],
543 df: &DataFrame,
544 groups: &'a GroupPositions,
545 state: &ExecutionState,
546 ) -> PolarsResult<AggregationContext<'a>>;
547}
548
549pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
550 macro_rules! wrap_groups {
551 ($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
552 struct Wrap($($ty),*);
553 impl GroupsUdf for Wrap {
554 fn evaluate_on_groups<'a>(
555 &self,
556 inputs: &[Arc<dyn PhysicalExpr>],
557 df: &DataFrame,
558 groups: &'a GroupPositions,
559 state: &ExecutionState,
560 ) -> PolarsResult<AggregationContext<'a>> {
561 let Wrap($($n),*) = self;
562 $f(inputs, df, groups, state$(, *$n)*)
563 }
564 }
565
566 SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
567 }};
568 }
569 use IRFunctionExpr as F;
570 Some(match func {
571 F::NullCount => wrap_groups!(groups_dispatch::null_count),
572 F::Reverse => wrap_groups!(groups_dispatch::reverse),
573 F::Boolean(IRBooleanFunction::HasNulls) => wrap_groups!(groups_dispatch::has_nulls),
574 F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
575 let ignore_nulls = *ignore_nulls;
576 wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
577 },
578 F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
579 let ignore_nulls = *ignore_nulls;
580 wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
581 },
582 F::Boolean(IRBooleanFunction::IsEmpty { ignore_nulls }) => {
583 let ignore_nulls = *ignore_nulls;
584 wrap_groups!(groups_dispatch::is_empty, (ignore_nulls, v: bool))
585 },
586 #[cfg(feature = "bitwise")]
587 F::Bitwise(f) => {
588 use polars_plan::plans::IRBitwiseFunction as B;
589 match f {
590 B::And => wrap_groups!(groups_dispatch::bitwise_and),
591 B::Or => wrap_groups!(groups_dispatch::bitwise_or),
592 B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
593 _ => return None,
594 }
595 },
596 F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
597 F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
598
599 F::Quantile { method } => {
600 wrap_groups!(groups_dispatch::quantile, (*method, v: QuantileMethod))
601 },
602 #[cfg(feature = "moment")]
603 F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
604 #[cfg(feature = "moment")]
605 F::Kurtosis(fisher, bias) => {
606 wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
607 },
608
609 F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
610 F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
611 wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
612 },
613 F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
614 wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
615 },
616
617 _ => return None,
618 })
619}