1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod element;
9mod eval;
10#[cfg(feature = "dtype-struct")]
11mod field;
12mod filter;
13mod gather;
14mod group_iter;
15mod literal;
16#[cfg(feature = "dynamic_group_by")]
17mod rolling;
18mod slice;
19mod sort;
20mod sortby;
21#[cfg(feature = "dtype-struct")]
22mod structeval;
23mod ternary;
24mod window;
25
26use std::borrow::Cow;
27use std::fmt::{Display, Formatter};
28
29pub(crate) use aggregation::*;
30pub(crate) use alias::*;
31pub(crate) use apply::*;
32use arrow::array::ArrayRef;
33use arrow::bitmap::MutableBitmap;
34use arrow::legacy::utils::CustomIterTools;
35pub(crate) use binary::*;
36pub(crate) use cast::*;
37pub(crate) use column::*;
38pub(crate) use count::*;
39pub(crate) use element::*;
40pub(crate) use eval::*;
41#[cfg(feature = "dtype-struct")]
42pub(crate) use field::*;
43pub(crate) use filter::*;
44pub(crate) use gather::*;
45pub(crate) use literal::*;
46use polars_core::prelude::*;
47use polars_io::predicates::PhysicalIoExpr;
48use polars_plan::prelude::*;
49#[cfg(feature = "dynamic_group_by")]
50pub(crate) use rolling::RollingExpr;
51pub(crate) use slice::*;
52pub(crate) use sort::*;
53pub(crate) use sortby::*;
54#[cfg(feature = "dtype-struct")]
55pub(crate) use structeval::*;
56pub(crate) use ternary::*;
57pub use window::window_function_format_order_by;
58pub(crate) use window::*;
59
60use crate::state::ExecutionState;
61
/// Aggregation state of the values carried by an `AggregationContext`.
///
/// Every variant wraps the `Column` holding the values in that state.
#[derive(Clone, Debug)]
pub enum AggState {
    /// Values that were aggregated into a list column, e.g. by `agg_list` in
    /// `AggregationContext::aggregated`. `flat_dtype` unwraps the inner dtype of this
    /// column, so it is expected to have a `List` dtype.
    AggregatedList(Column),
    /// Result of an aggregation. `AggregationContext::new` asserts that the column has
    /// as many values as there are groups when it creates this state.
    AggregatedScalar(Column),
    /// Flat values that are not aggregated yet; `AggregationContext::aggregated` turns
    /// them into an `AggregatedList` using the groups.
    NotAggregated(Column),
    /// A literal value. `AggregationContext::finalize` broadcasts its first value to
    /// one row per group.
    LiteralScalar(Column),
}
76
77impl AggState {
78 fn try_map<F>(&self, func: F) -> PolarsResult<Self>
79 where
80 F: FnOnce(&Column) -> PolarsResult<Column>,
81 {
82 Ok(match self {
83 AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
84 AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
85 AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),
86 AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
87 })
88 }
89
90 fn is_scalar(&self) -> bool {
91 matches!(self, Self::AggregatedScalar(_))
92 }
93
94 pub fn name(&self) -> &PlSmallStr {
95 match self {
96 AggState::AggregatedList(s)
97 | AggState::NotAggregated(s)
98 | AggState::LiteralScalar(s)
99 | AggState::AggregatedScalar(s) => s.name(),
100 }
101 }
102
103 pub fn flat_dtype(&self) -> &DataType {
104 match self {
105 AggState::AggregatedList(s) => s.dtype().inner_dtype().unwrap(),
106 AggState::NotAggregated(s)
107 | AggState::LiteralScalar(s)
108 | AggState::AggregatedScalar(s) => s.dtype(),
109 }
110 }
111}
112
/// Pending, lazily applied update of the groups of an `AggregationContext`.
///
/// The update is carried out by `AggregationContext::groups`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// The groups are up to date; nothing has to be done.
    No,
    /// Rebuild the groups as back-to-back slices that keep the length of every
    /// current group.
    WithGroupsLen,
    /// Rebuild the groups as back-to-back slices from the list lengths of the current
    /// values. The values must be a list column at that point, otherwise the update
    /// panics.
    WithSeriesLen,
}
126
/// Values of an expression evaluated on groups, together with those groups.
///
/// `Debug` is only derived in builds with debug assertions enabled.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The values and the aggregation state they are currently in.
    pub(crate) state: AggState,
    /// Group positions that belong to `state`. Borrowed from the caller, or owned once
    /// this context has rewritten them.
    pub(crate) groups: Cow<'a, GroupPositions>,
    /// Update that still has to be applied to `groups`; it is applied lazily by
    /// `groups()`, which resets this field to `UpdateGroups::No`.
    pub(crate) update_groups: UpdateGroups,
    /// Initialised to `true` by the constructors and set to `false` by
    /// `normalize_values`.
    /// NOTE(review): presumably records whether values and groups still have their
    /// original length (e.g. nothing was filtered) — confirm against the callers.
    pub(crate) original_len: bool,
}
146
147impl<'a> AggregationContext<'a> {
148 pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
149 match self.update_groups {
150 UpdateGroups::No => {},
151 UpdateGroups::WithGroupsLen => {
152 let mut offset = 0 as IdxSize;
157
158 match self.groups.as_ref().as_ref() {
159 GroupsType::Idx(groups) => {
160 let groups = groups
161 .iter()
162 .map(|g| {
163 let len = g.1.len() as IdxSize;
164 let new_offset = offset + len;
165 let out = [offset, len];
166 offset = new_offset;
167 out
168 })
169 .collect();
170 self.groups =
171 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
172 },
173 GroupsType::Slice { groups, .. } => {
177 let groups = groups
179 .iter()
180 .map(|g| {
181 let len = g[1];
182 let new = [offset, g[1]];
183 offset += len;
184 new
185 })
186 .collect();
187 self.groups =
188 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
189 },
190 }
191 self.update_groups = UpdateGroups::No;
192 },
193 UpdateGroups::WithSeriesLen => {
194 let s = self.get_values().clone();
195 self.det_groups_from_list(s.as_materialized_series());
196 },
197 }
198 &self.groups
199 }
200
201 pub(crate) fn get_values(&self) -> &Column {
202 match &self.state {
203 AggState::NotAggregated(s)
204 | AggState::AggregatedScalar(s)
205 | AggState::AggregatedList(s) => s,
206 AggState::LiteralScalar(s) => s,
207 }
208 }
209
210 pub fn agg_state(&self) -> &AggState {
211 &self.state
212 }
213
214 pub(crate) fn is_not_aggregated(&self) -> bool {
215 matches!(
216 &self.state,
217 AggState::NotAggregated(_) | AggState::LiteralScalar(_)
218 )
219 }
220
221 pub(crate) fn is_aggregated(&self) -> bool {
222 !self.is_not_aggregated()
223 }
224
225 pub(crate) fn is_literal(&self) -> bool {
226 matches!(self.state, AggState::LiteralScalar(_))
227 }
228
229 fn new(
233 column: Column,
234 groups: Cow<'a, GroupPositions>,
235 aggregated: bool,
236 ) -> AggregationContext<'a> {
237 let series = if aggregated {
238 assert_eq!(column.len(), groups.len());
239 AggState::AggregatedScalar(column)
240 } else {
241 AggState::NotAggregated(column)
242 };
243
244 Self {
245 state: series,
246 groups,
247 update_groups: UpdateGroups::No,
248 original_len: true,
249 }
250 }
251
252 fn with_agg_state(&mut self, agg_state: AggState) {
253 self.state = agg_state;
254 }
255
256 fn from_agg_state(
257 agg_state: AggState,
258 groups: Cow<'a, GroupPositions>,
259 ) -> AggregationContext<'a> {
260 Self {
261 state: agg_state,
262 groups,
263 update_groups: UpdateGroups::No,
264 original_len: true,
265 }
266 }
267
268 pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
269 self.original_len = original_len;
270 self
271 }
272
273 pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
274 self.update_groups = update;
275 self
276 }
277
278 fn det_groups_from_list(&mut self, s: &Series) {
279 let mut offset = 0 as IdxSize;
280 let list = s
281 .list()
282 .expect("impl error, should be a list at this point");
283
284 match list.chunks().len() {
285 1 => {
286 let arr = list.downcast_iter().next().unwrap();
287 let offsets = arr.offsets().as_slice();
288
289 let mut previous = 0i64;
290 let groups = offsets[1..]
291 .iter()
292 .map(|&o| {
293 let len = (o - previous) as IdxSize;
294 let new_offset = offset + len;
295
296 previous = o;
297 let out = [offset, len];
298 offset = new_offset;
299 out
300 })
301 .collect_trusted();
302 self.groups =
303 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
304 },
305 _ => {
306 let groups = {
307 self.get_values()
308 .list()
309 .expect("impl error, should be a list at this point")
310 .amortized_iter()
311 .map(|s| {
312 if let Some(s) = s {
313 let len = s.as_ref().len() as IdxSize;
314 let new_offset = offset + len;
315 let out = [offset, len];
316 offset = new_offset;
317 out
318 } else {
319 [offset, 0]
320 }
321 })
322 .collect_trusted()
323 };
324 self.groups =
325 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
326 },
327 }
328 self.update_groups = UpdateGroups::No;
329 }
330
331 pub(crate) fn with_values(
335 &mut self,
336 column: Column,
337 aggregated: bool,
338 expr: Option<&Expr>,
339 ) -> PolarsResult<&mut Self> {
340 self.with_values_and_args(
341 column,
342 aggregated,
343 expr,
344 false,
345 self.agg_state().is_scalar(),
346 )
347 }
348
349 pub(crate) fn with_values_and_args(
350 &mut self,
351 column: Column,
352 aggregated: bool,
353 expr: Option<&Expr>,
354 preserve_literal: bool,
357 returns_scalar: bool,
358 ) -> PolarsResult<&mut Self> {
359 self.state = match (aggregated, column.dtype()) {
360 (true, &DataType::List(_)) if !returns_scalar => {
361 if column.len() != self.groups.len() {
362 let fmt_expr = if let Some(e) = expr {
363 format!("'{e:?}' ")
364 } else {
365 String::new()
366 };
367 polars_bail!(
368 ComputeError:
369 "aggregation expression '{}' produced a different number of elements: {} \
370 than the number of groups: {} (this is likely invalid)",
371 fmt_expr, column.len(), self.groups.len(),
372 );
373 }
374 AggState::AggregatedList(column)
375 },
376 (true, _) => AggState::AggregatedScalar(column),
377 _ => {
378 match self.state {
379 AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
382 AggState::LiteralScalar(_) if column.len() == 1 && preserve_literal => {
384 AggState::LiteralScalar(column)
385 },
386 _ => AggState::NotAggregated(column.into_column()),
387 }
388 },
389 };
390 Ok(self)
391 }
392
393 pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
394 self.state = AggState::LiteralScalar(column);
395 self
396 }
397
398 pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
400 if let AggState::AggregatedList(_) = self.agg_state() {
401 self.with_values(self.flat_naive().into_owned(), false, None)
403 .unwrap();
404 }
405 self.groups = Cow::Owned(groups);
406 self.update_groups = UpdateGroups::No;
408 self
409 }
410
411 pub fn normalize_values(&mut self) {
413 self.set_original_len(false);
414 self.groups();
415 let values = self.flat_naive();
416 let values = unsafe { values.agg_list(&self.groups) };
417 self.state = AggState::AggregatedList(values);
418 self.with_update_groups(UpdateGroups::WithGroupsLen);
419 }
420
421 pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
423 self.aggregated();
424 let out = self.get_values();
425 match self.agg_state() {
426 AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),
427 _ => Cow::Borrowed(out.list().unwrap()),
428 }
429 }
430
431 pub fn aggregated(&mut self) -> Column {
433 match self.state.clone() {
436 AggState::NotAggregated(s) => {
437 self.groups();
442 #[cfg(debug_assertions)]
443 {
444 if self.groups.len() > s.len() {
445 polars_warn!(
446 "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
447 )
448 }
449 }
450
451 let out = unsafe { s.agg_list(&self.groups) };
454 self.state = AggState::AggregatedList(out.clone());
455
456 self.update_groups = UpdateGroups::WithGroupsLen;
457 out
458 },
459 AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
460 AggState::LiteralScalar(s) => {
461 let rows = self.groups.len();
462 let s = s.implode().unwrap();
463 let s = s.new_from_index(0, rows);
464 let s = s.into_column();
465 self.state = AggState::AggregatedList(s.clone());
466 self.with_update_groups(UpdateGroups::WithSeriesLen);
467 s.clone()
468 },
469 }
470 }
471
472 pub fn finalize(&mut self) -> Column {
474 match &self.state {
477 AggState::LiteralScalar(c) => {
478 let c = c.clone();
479 self.groups();
480 let rows = self.groups.len();
481 c.new_from_index(0, rows)
482 },
483 _ => self.aggregated(),
484 }
485 }
486
487 fn arity_should_explode(&self) -> bool {
490 use AggState::*;
491 match self.agg_state() {
492 LiteralScalar(s) => s.len() == 1,
493 AggregatedScalar(_) => true,
494 _ => false,
495 }
496 }
497
498 pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
499 let _ = self.groups();
500 let groups = self.groups;
501 match self.state {
502 AggState::NotAggregated(c) => (c, groups),
503 AggState::AggregatedScalar(c) => (c, groups),
504 AggState::LiteralScalar(c) => (c, groups),
505 AggState::AggregatedList(c) => {
506 let flattened = c
507 .explode(ExplodeOptions {
508 empty_as_null: false,
509 keep_nulls: true,
510 })
511 .unwrap();
512 let groups = groups.into_owned();
513 let groups = groups.unroll();
538 (flattened, Cow::Owned(groups))
539 },
540 }
541 }
542
543 pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
548 match &self.state {
549 AggState::NotAggregated(c) => Cow::Borrowed(c),
550 AggState::AggregatedList(c) => {
551 if cfg!(debug_assertions) {
552 if self.groups.is_overlapping() {
555 polars_warn!(
556 "performance - an aggregated list with overlapping groups may consume excessive memory"
557 )
558 }
559 }
560
561 Cow::Owned(
563 c.explode(ExplodeOptions {
564 empty_as_null: false,
565 keep_nulls: true,
566 })
567 .unwrap(),
568 )
569 },
570 AggState::AggregatedScalar(c) => Cow::Borrowed(c),
571 AggState::LiteralScalar(c) => Cow::Borrowed(c),
572 }
573 }
574
575 fn flat_naive_length(&self) -> usize {
576 match &self.state {
577 AggState::NotAggregated(c) => c.len(),
578 AggState::AggregatedList(c) => c.list().unwrap().inner_length(),
579 AggState::AggregatedScalar(c) => c.len(),
580 AggState::LiteralScalar(_) => 1,
581 }
582 }
583
584 pub(crate) fn take(&mut self) -> Column {
586 let c = match &mut self.state {
587 AggState::NotAggregated(c)
588 | AggState::AggregatedScalar(c)
589 | AggState::AggregatedList(c) => c,
590 AggState::LiteralScalar(c) => c,
591 };
592 std::mem::take(c)
593 }
594
595 fn groups_cover_all_values(&mut self) -> bool {
597 if matches!(
598 self.state,
599 AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
600 ) {
601 return true;
602 }
603
604 let num_values = self.flat_naive_length();
605 match self.groups().as_ref().as_ref() {
606 GroupsType::Idx(groups) => {
607 let mut seen = MutableBitmap::from_len_zeroed(num_values);
608 for (_, g) in groups {
609 for i in g.iter() {
610 unsafe { seen.set_unchecked(*i as usize, true) };
611 }
612 }
613 seen.unset_bits() == 0
614 },
615 GroupsType::Slice {
616 groups,
617 overlapping: true,
618 monotonic: _,
619 } => {
620 let mut offset = 0;
622 let mut covers_all = true;
623 for [start, length] in groups {
624 covers_all &= *start <= offset;
625 offset = start + length;
626 }
627 covers_all && offset == num_values as IdxSize
628 },
629
630 GroupsType::Slice {
632 groups,
633 overlapping: false,
634 monotonic: _,
635 } => groups.iter().map(|[_, l]| *l as usize).sum::<usize>() == num_values,
636 }
637 }
638
639 fn set_groups_for_undefined_agg_states(&mut self) {
642 match &self.state {
643 AggState::AggregatedList(_) | AggState::NotAggregated(_) => {},
644 AggState::AggregatedScalar(c) => {
645 assert_eq!(self.update_groups, UpdateGroups::No);
646 self.groups = Cow::Owned({
647 let groups = (0..c.len() as IdxSize).map(|i| [i, 1]).collect();
648 GroupsType::new_slice(groups, false, true).into_sliceable()
649 });
650 },
651 AggState::LiteralScalar(c) => {
652 assert_eq!(c.len(), 1);
653 assert_eq!(self.update_groups, UpdateGroups::No);
654 self.groups = Cow::Owned({
655 let groups = vec![[0, 1]; self.groups.len()];
656 GroupsType::new_slice(groups, true, true).into_sliceable()
657 });
658 },
659 }
660 }
661
662 pub fn into_static(&self) -> AggregationContext<'static> {
663 let groups: GroupPositions = GroupPositions::to_owned(&self.groups);
664 let groups: Cow<'static, GroupPositions> = Cow::Owned(groups);
665 AggregationContext {
666 state: self.state.clone(),
667 groups,
668 update_groups: self.update_groups,
669 original_len: self.original_len,
670 }
671 }
672}
673
/// An expression that can be evaluated against a `DataFrame`, either as a whole or
/// per group.
pub trait PhysicalExpr: Send + Sync {
    /// Returns the logical `Expr` behind this physical expression, if one is available.
    /// The default implementation returns `None`.
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    /// Returns a column name for this expression, if it has one.
    /// The default implementation returns `None`.
    /// NOTE(review): presumably only set for plain column expressions — confirm.
    fn as_column(&self) -> Option<PlSmallStr> {
        None
    }

    /// Evaluates the expression on `df` and returns the resulting column.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;

    /// Evaluates the expression on `df` in the context of `groups` and returns the
    /// values together with their (possibly rewritten) groups as an
    /// `AggregationContext`.
    ///
    /// There is deliberately no default implementation here: every implementor has to
    /// define its own group semantics.
    // NOTE(review): the reason for allowing `clippy::ptr_arg` is not visible in this
    // file; presumably `groups` is stored by reference inside the returned context.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Returns the output field (name and dtype) of this expression for `input_schema`.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    /// Whether this expression is a literal. The default implementation returns `false`.
    fn is_literal(&self) -> bool {
        false
    }
    /// Whether this expression is scalar.
    /// NOTE(review): presumably "produces a single value per group/frame" — confirm.
    fn is_scalar(&self) -> bool;
}
726
727impl Display for &dyn PhysicalExpr {
728 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
729 match self.as_expression() {
730 None => Ok(()),
731 Some(e) => write!(f, "{e:?}"),
732 }
733 }
734}
735
/// Adapter that exposes a `PhysicalExpr` through the `PhysicalIoExpr` interface of
/// `polars_io`.
pub struct PhysicalIoHelper {
    /// The wrapped physical expression.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Whether the expression contains a window function; if so, the execution state
    /// used for evaluation is flagged accordingly.
    pub has_window_function: bool,
}
743
744impl PhysicalIoExpr for PhysicalIoHelper {
745 fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
746 let mut state: ExecutionState = Default::default();
747 if self.has_window_function {
748 state.insert_has_window_function_flag();
749 }
750 self.expr.evaluate(df, &state).map(|c| {
751 debug_assert_eq!(c.dtype(), &DataType::Boolean);
753 (if c.len() == 1 && df.height() != 1 {
754 c.new_from_index(0, df.height())
756 } else {
757 c
758 })
759 .take_materialized_series()
760 })
761 }
762}
763
764pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
765 let has_window_function = if let Some(expr) = expr.as_expression() {
766 expr.into_iter().any(|expr| {
767 #[cfg(feature = "dynamic_group_by")]
768 if matches!(expr, Expr::Rolling { .. }) {
769 return true;
770 }
771
772 matches!(expr, Expr::Over { .. })
773 })
774 } else {
775 false
776 };
777 Arc::new(PhysicalIoHelper {
778 expr,
779 has_window_function,
780 }) as Arc<dyn PhysicalIoExpr>
781}