1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod eval;
9mod filter;
10mod gather;
11mod group_iter;
12mod literal;
13#[cfg(feature = "dynamic_group_by")]
14mod rolling;
15mod slice;
16mod sort;
17mod sortby;
18mod ternary;
19mod window;
20
21use std::borrow::Cow;
22use std::fmt::{Display, Formatter};
23
24pub(crate) use aggregation::*;
25pub(crate) use alias::*;
26pub(crate) use apply::*;
27use arrow::array::ArrayRef;
28use arrow::bitmap::MutableBitmap;
29use arrow::legacy::utils::CustomIterTools;
30pub(crate) use binary::*;
31pub(crate) use cast::*;
32pub(crate) use column::*;
33pub(crate) use count::*;
34pub(crate) use eval::*;
35pub(crate) use filter::*;
36pub(crate) use gather::*;
37pub(crate) use literal::*;
38use polars_core::prelude::*;
39use polars_io::predicates::PhysicalIoExpr;
40use polars_plan::prelude::*;
41#[cfg(feature = "dynamic_group_by")]
42pub(crate) use rolling::RollingExpr;
43pub(crate) use slice::*;
44pub(crate) use sort::*;
45pub(crate) use sortby::*;
46pub(crate) use ternary::*;
47pub use window::window_function_format_order_by;
48pub(crate) use window::*;
49
50use crate::state::ExecutionState;
51
/// State of the values carried through the evaluation of an expression on groups.
///
/// Every variant wraps the [`Column`] holding the values; the variant records how those
/// values relate to the groups.
#[derive(Clone, Debug)]
pub enum AggState {
    /// Values aggregated into lists; the column has a list dtype (`flat_dtype` unwraps its
    /// inner dtype).
    AggregatedList(Column),
    /// Values aggregated to scalars; `AggregationContext::new` asserts that the column has
    /// one element per group when it creates this state.
    AggregatedScalar(Column),
    /// Flat values that have not been aggregated yet; `AggregationContext::aggregated`
    /// turns them into lists using the groups.
    NotAggregated(Column),
    /// A literal value; `AggregationContext::finalize` repeats its first element once per
    /// group.
    LiteralScalar(Column),
}
66
67impl AggState {
68 fn try_map<F>(&self, func: F) -> PolarsResult<Self>
69 where
70 F: FnOnce(&Column) -> PolarsResult<Column>,
71 {
72 Ok(match self {
73 AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
74 AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
75 AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),
76 AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
77 })
78 }
79
80 fn is_scalar(&self) -> bool {
81 matches!(self, Self::AggregatedScalar(_))
82 }
83
84 pub fn name(&self) -> &PlSmallStr {
85 match self {
86 AggState::AggregatedList(s)
87 | AggState::NotAggregated(s)
88 | AggState::LiteralScalar(s)
89 | AggState::AggregatedScalar(s) => s.name(),
90 }
91 }
92
93 pub fn flat_dtype(&self) -> &DataType {
94 match self {
95 AggState::AggregatedList(s) => s.dtype().inner_dtype().unwrap(),
96 AggState::NotAggregated(s)
97 | AggState::LiteralScalar(s)
98 | AggState::AggregatedScalar(s) => s.dtype(),
99 }
100 }
101}
102
/// Strategy for lazily recomputing the groups of an [`AggregationContext`]; it is applied
/// the next time `AggregationContext::groups` is called.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// The groups are used as they are.
    No,
    /// Rebuild the groups as contiguous slices, using the length of each current group.
    WithGroupsLen,
    /// Rebuild the groups as contiguous slices, using the list lengths of the current
    /// values; the values must be a list column at that point or the update panics.
    WithSeriesLen,
}
116
/// Values of an expression evaluated on groups, together with the groups they belong to.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The values and their aggregation state.
    pub(crate) state: AggState,
    /// The group positions; either borrowed or owned (they become owned once rebuilt).
    pub(crate) groups: Cow<'a, GroupPositions>,
    /// Pending lazy update of `groups`; resolved by `groups()`.
    pub(crate) update_groups: UpdateGroups,
    /// Initialised to `true` by the constructors and cleared by `normalize_values`.
    /// NOTE(review): presumably tracks whether values and groups still have their original
    /// length — confirm against the callers of `set_original_len`.
    pub(crate) original_len: bool,
}
136
137impl<'a> AggregationContext<'a> {
138 pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
139 match self.update_groups {
140 UpdateGroups::No => {},
141 UpdateGroups::WithGroupsLen => {
142 let mut offset = 0 as IdxSize;
147
148 match self.groups.as_ref().as_ref() {
149 GroupsType::Idx(groups) => {
150 let groups = groups
151 .iter()
152 .map(|g| {
153 let len = g.1.len() as IdxSize;
154 let new_offset = offset + len;
155 let out = [offset, len];
156 offset = new_offset;
157 out
158 })
159 .collect();
160 self.groups = Cow::Owned(
161 GroupsType::Slice {
162 groups,
163 overlapping: false,
164 }
165 .into_sliceable(),
166 )
167 },
168 GroupsType::Slice { groups, .. } => {
172 let groups = groups
174 .iter()
175 .map(|g| {
176 let len = g[1];
177 let new = [offset, g[1]];
178 offset += len;
179 new
180 })
181 .collect();
182 self.groups = Cow::Owned(
183 GroupsType::Slice {
184 groups,
185 overlapping: false,
186 }
187 .into_sliceable(),
188 )
189 },
190 }
191 self.update_groups = UpdateGroups::No;
192 },
193 UpdateGroups::WithSeriesLen => {
194 let s = self.get_values().clone();
195 self.det_groups_from_list(s.as_materialized_series());
196 },
197 }
198 &self.groups
199 }
200
201 pub(crate) fn get_values(&self) -> &Column {
202 match &self.state {
203 AggState::NotAggregated(s)
204 | AggState::AggregatedScalar(s)
205 | AggState::AggregatedList(s) => s,
206 AggState::LiteralScalar(s) => s,
207 }
208 }
209
    /// Returns a reference to the current aggregation state.
    pub fn agg_state(&self) -> &AggState {
        &self.state
    }
213
214 pub(crate) fn is_not_aggregated(&self) -> bool {
215 matches!(
216 &self.state,
217 AggState::NotAggregated(_) | AggState::LiteralScalar(_)
218 )
219 }
220
221 pub(crate) fn is_aggregated(&self) -> bool {
222 !self.is_not_aggregated()
223 }
224
    /// Returns `true` when the state is [`AggState::LiteralScalar`].
    pub(crate) fn is_literal(&self) -> bool {
        matches!(self.state, AggState::LiteralScalar(_))
    }
228
229 fn new(
233 column: Column,
234 groups: Cow<'a, GroupPositions>,
235 aggregated: bool,
236 ) -> AggregationContext<'a> {
237 let series = if aggregated {
238 assert_eq!(column.len(), groups.len());
239 AggState::AggregatedScalar(column)
240 } else {
241 AggState::NotAggregated(column)
242 };
243
244 Self {
245 state: series,
246 groups,
247 update_groups: UpdateGroups::No,
248 original_len: true,
249 }
250 }
251
    /// Replaces the aggregation state; groups and pending group updates are left untouched.
    fn with_agg_state(&mut self, agg_state: AggState) {
        self.state = agg_state;
    }
255
256 fn from_agg_state(
257 agg_state: AggState,
258 groups: Cow<'a, GroupPositions>,
259 ) -> AggregationContext<'a> {
260 Self {
261 state: agg_state,
262 groups,
263 update_groups: UpdateGroups::No,
264 original_len: true,
265 }
266 }
267
    /// Sets the `original_len` flag and returns `self` for chaining.
    pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
        self.original_len = original_len;
        self
    }
272
    /// Schedules the group update that `groups()` will apply on its next call and returns
    /// `self` for chaining.
    pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
        self.update_groups = update;
        self
    }
277
278 fn det_groups_from_list(&mut self, s: &Series) {
279 let mut offset = 0 as IdxSize;
280 let list = s
281 .list()
282 .expect("impl error, should be a list at this point");
283
284 match list.chunks().len() {
285 1 => {
286 let arr = list.downcast_iter().next().unwrap();
287 let offsets = arr.offsets().as_slice();
288
289 let mut previous = 0i64;
290 let groups = offsets[1..]
291 .iter()
292 .map(|&o| {
293 let len = (o - previous) as IdxSize;
294 let new_offset = offset + len;
295
296 previous = o;
297 let out = [offset, len];
298 offset = new_offset;
299 out
300 })
301 .collect_trusted();
302 self.groups = Cow::Owned(
303 GroupsType::Slice {
304 groups,
305 overlapping: false,
306 }
307 .into_sliceable(),
308 );
309 },
310 _ => {
311 let groups = {
312 self.get_values()
313 .list()
314 .expect("impl error, should be a list at this point")
315 .amortized_iter()
316 .map(|s| {
317 if let Some(s) = s {
318 let len = s.as_ref().len() as IdxSize;
319 let new_offset = offset + len;
320 let out = [offset, len];
321 offset = new_offset;
322 out
323 } else {
324 [offset, 0]
325 }
326 })
327 .collect_trusted()
328 };
329 self.groups = Cow::Owned(
330 GroupsType::Slice {
331 groups,
332 overlapping: false,
333 }
334 .into_sliceable(),
335 );
336 },
337 }
338 self.update_groups = UpdateGroups::No;
339 }
340
341 pub(crate) fn with_values(
345 &mut self,
346 column: Column,
347 aggregated: bool,
348 expr: Option<&Expr>,
349 ) -> PolarsResult<&mut Self> {
350 self.with_values_and_args(
351 column,
352 aggregated,
353 expr,
354 false,
355 self.agg_state().is_scalar(),
356 )
357 }
358
359 pub(crate) fn with_values_and_args(
360 &mut self,
361 column: Column,
362 aggregated: bool,
363 expr: Option<&Expr>,
364 preserve_literal: bool,
367 returns_scalar: bool,
368 ) -> PolarsResult<&mut Self> {
369 self.state = match (aggregated, column.dtype()) {
370 (true, &DataType::List(_)) if !returns_scalar => {
371 if column.len() != self.groups.len() {
372 let fmt_expr = if let Some(e) = expr {
373 format!("'{e:?}' ")
374 } else {
375 String::new()
376 };
377 polars_bail!(
378 ComputeError:
379 "aggregation expression '{}' produced a different number of elements: {} \
380 than the number of groups: {} (this is likely invalid)",
381 fmt_expr, column.len(), self.groups.len(),
382 );
383 }
384 AggState::AggregatedList(column)
385 },
386 (true, _) => AggState::AggregatedScalar(column),
387 _ => {
388 match self.state {
389 AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
392 AggState::LiteralScalar(_) if column.len() == 1 && preserve_literal => {
394 AggState::LiteralScalar(column)
395 },
396 _ => AggState::NotAggregated(column.into_column()),
397 }
398 },
399 };
400 Ok(self)
401 }
402
    /// Sets the state to `LiteralScalar(column)` and returns `self` for chaining.
    pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
        self.state = AggState::LiteralScalar(column);
        self
    }
407
408 pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
410 if let AggState::AggregatedList(_) = self.agg_state() {
411 self.with_values(self.flat_naive().into_owned(), false, None)
413 .unwrap();
414 }
415 self.groups = Cow::Owned(groups);
416 self.update_groups = UpdateGroups::No;
418 self
419 }
420
    /// Re-aggregates the values according to the current groups.
    ///
    /// Clears `original_len`, resolves any pending group update, aggregates the flattened
    /// values into lists using the groups, stores them as `AggregatedList` and schedules a
    /// `UpdateGroups::WithGroupsLen` update for the next `groups()` call.
    pub fn normalize_values(&mut self) {
        self.set_original_len(false);
        // Resolve pending lazy updates before the groups are used for aggregation.
        self.groups();
        let values = self.flat_naive();
        // NOTE(review): `agg_list` is unsafe; presumably it requires the groups to be in
        // bounds of `values` — confirm against its documentation.
        let values = unsafe { values.agg_list(&self.groups) };
        self.state = AggState::AggregatedList(values);
        self.with_update_groups(UpdateGroups::WithGroupsLen);
    }
430
431 pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
433 self.aggregated();
434 let out = self.get_values();
435 match self.agg_state() {
436 AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),
437 _ => Cow::Borrowed(out.list().unwrap()),
438 }
439 }
440
    /// Returns the aggregated values, converting the state when necessary.
    ///
    /// * `NotAggregated`: resolves pending group updates, aggregates the flat values into
    ///   lists using the groups, stores them as `AggregatedList` and schedules a
    ///   `UpdateGroups::WithGroupsLen` update.
    /// * `AggregatedList` / `AggregatedScalar`: returns the stored column as is.
    /// * `LiteralScalar`: implodes the literal into a list, repeats it once per group, stores
    ///   that as `AggregatedList` and schedules a `UpdateGroups::WithSeriesLen` update.
    ///
    /// # Panics
    ///
    /// Panics when imploding a literal fails.
    pub fn aggregated(&mut self) -> Column {
        // The state is cloned up front so that `self` can be mutated inside the arms.
        match self.state.clone() {
            AggState::NotAggregated(s) => {
                // Pending lazy group updates must be applied before aggregating with them.
                self.groups();
                #[cfg(debug_assertions)]
                {
                    if self.groups.len() > s.len() {
                        polars_warn!(
                            "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
                        )
                    }
                }

                // NOTE(review): `agg_list` is unsafe; presumably it requires the groups to
                // be in bounds of `s` — confirm against its documentation.
                let out = unsafe { s.agg_list(&self.groups) };
                self.state = AggState::AggregatedList(out.clone());

                // The stored groups are rebuilt lazily from their lengths on next use.
                self.update_groups = UpdateGroups::WithGroupsLen;
                out
            },
            AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
            AggState::LiteralScalar(s) => {
                let rows = self.groups.len();
                // One list holding the literal, repeated once per group.
                let s = s.implode().unwrap();
                let s = s.new_from_index(0, rows);
                let s = s.into_column();
                self.state = AggState::AggregatedList(s.clone());
                self.with_update_groups(UpdateGroups::WithSeriesLen);
                s.clone()
            },
        }
    }
481
482 pub fn finalize(&mut self) -> Column {
484 match &self.state {
487 AggState::LiteralScalar(c) => {
488 let c = c.clone();
489 self.groups();
490 let rows = self.groups.len();
491 c.new_from_index(0, rows)
492 },
493 _ => self.aggregated(),
494 }
495 }
496
497 fn arity_should_explode(&self) -> bool {
500 use AggState::*;
501 match self.agg_state() {
502 LiteralScalar(s) => s.len() == 1,
503 AggregatedScalar(_) => true,
504 _ => false,
505 }
506 }
507
508 pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
509 let _ = self.groups();
510 let groups = self.groups;
511 match self.state {
512 AggState::NotAggregated(c) => (c, groups),
513 AggState::AggregatedScalar(c) => (c, groups),
514 AggState::LiteralScalar(c) => (c, groups),
515 AggState::AggregatedList(c) => {
516 let flattened = c.explode(true).unwrap();
517 let groups = groups.into_owned();
518 let groups = groups.unroll();
543 (flattened, Cow::Owned(groups))
544 },
545 }
546 }
547
548 pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
553 match &self.state {
554 AggState::NotAggregated(c) => Cow::Borrowed(c),
555 AggState::AggregatedList(c) => {
556 if cfg!(debug_assertions) {
557 if let GroupsType::Slice {
560 overlapping: true, ..
561 } = self.groups.as_ref().as_ref()
562 {
563 polars_warn!(
564 "performance - an aggregated list with overlapping groups may consume excessive memory"
565 )
566 }
567 }
568
569 Cow::Owned(c.explode(true).unwrap())
571 },
572 AggState::AggregatedScalar(c) => Cow::Borrowed(c),
573 AggState::LiteralScalar(c) => Cow::Borrowed(c),
574 }
575 }
576
577 fn flat_naive_length(&self) -> usize {
578 match &self.state {
579 AggState::NotAggregated(c) => c.len(),
580 AggState::AggregatedList(c) => c.list().unwrap().inner_length(),
581 AggState::AggregatedScalar(c) => c.len(),
582 AggState::LiteralScalar(_) => 1,
583 }
584 }
585
586 pub(crate) fn take(&mut self) -> Column {
588 let c = match &mut self.state {
589 AggState::NotAggregated(c)
590 | AggState::AggregatedScalar(c)
591 | AggState::AggregatedList(c) => c,
592 AggState::LiteralScalar(c) => c,
593 };
594 std::mem::take(c)
595 }
596
597 fn groups_cover_all_values(&mut self) -> bool {
599 if matches!(
600 self.state,
601 AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
602 ) {
603 return true;
604 }
605
606 let num_values = self.flat_naive_length();
607 match self.groups().as_ref().as_ref() {
608 GroupsType::Idx(groups) => {
609 let mut seen = MutableBitmap::from_len_zeroed(num_values);
610 for (_, g) in groups {
611 for i in g.iter() {
612 unsafe { seen.set_unchecked(*i as usize, true) };
613 }
614 }
615 seen.unset_bits() == 0
616 },
617 GroupsType::Slice {
618 groups,
619 overlapping: true,
620 } => {
621 let mut offset = 0;
623 let mut covers_all = true;
624 for [start, length] in groups {
625 covers_all &= *start <= offset;
626 offset = start + length;
627 }
628 covers_all && offset == num_values as IdxSize
629 },
630
631 GroupsType::Slice {
633 groups,
634 overlapping: false,
635 } => groups.iter().map(|[_, l]| *l as usize).sum::<usize>() == num_values,
636 }
637 }
638
639 fn set_groups_for_undefined_agg_states(&mut self) {
642 match &self.state {
643 AggState::AggregatedList(_) | AggState::NotAggregated(_) => {},
644 AggState::AggregatedScalar(c) => {
645 assert_eq!(self.update_groups, UpdateGroups::No);
646 self.groups = Cow::Owned(
647 GroupsType::Slice {
648 groups: (0..c.len() as IdxSize).map(|i| [i, 1]).collect(),
649 overlapping: false,
650 }
651 .into_sliceable(),
652 );
653 },
654 AggState::LiteralScalar(c) => {
655 assert_eq!(c.len(), 1);
656 assert_eq!(self.update_groups, UpdateGroups::No);
657 self.groups = Cow::Owned(
658 GroupsType::Slice {
659 groups: vec![[0, 1]; self.groups.len()],
660 overlapping: true,
661 }
662 .into_sliceable(),
663 );
664 },
665 }
666 }
667
668 pub fn into_static(&self) -> AggregationContext<'static> {
669 let groups: GroupPositions = GroupPositions::to_owned(&self.groups);
670 let groups: Cow<'static, GroupPositions> = Cow::Owned(groups);
671 AggregationContext {
672 state: self.state.clone(),
673 groups,
674 update_groups: self.update_groups,
675 original_len: self.original_len,
676 }
677 }
678}
679
/// An expression that can be evaluated against a [`DataFrame`], either as a whole or on
/// groups.
pub trait PhysicalExpr: Send + Sync {
    /// Returns the [`Expr`] associated with this physical expression, if any.
    ///
    /// The default implementation returns `None`. In this file the result is used for
    /// `Display` and to detect window functions.
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    /// Evaluates the expression against `df` and returns the resulting column.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;

    /// Evaluates the expression against `df` on the given `groups`, returning an
    /// [`AggregationContext`] that may borrow `groups`.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Returns the output [`Field`] of this expression for the given `input_schema`.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    /// Whether this expression is a literal; defaults to `false`.
    fn is_literal(&self) -> bool {
        false
    }
    /// Whether this expression is scalar.
    /// NOTE(review): the exact contract is defined by the implementors, not visible here.
    fn is_scalar(&self) -> bool;
}
728
729impl Display for &dyn PhysicalExpr {
730 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
731 match self.as_expression() {
732 None => Ok(()),
733 Some(e) => write!(f, "{e:?}"),
734 }
735 }
736}
737
/// Adapter that exposes a [`PhysicalExpr`] through the `PhysicalIoExpr` interface.
pub struct PhysicalIoHelper {
    /// The wrapped expression.
    pub expr: Arc<dyn PhysicalExpr>,
    /// When set, the execution state is flagged as containing a window function before
    /// the expression is evaluated.
    pub has_window_function: bool,
}
745
746impl PhysicalIoExpr for PhysicalIoHelper {
747 fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
748 let mut state: ExecutionState = Default::default();
749 if self.has_window_function {
750 state.insert_has_window_function_flag();
751 }
752 self.expr.evaluate(df, &state).map(|c| {
753 debug_assert_eq!(c.dtype(), &DataType::Boolean);
755 (if c.len() == 1 && df.height() != 1 {
756 c.new_from_index(0, df.height())
758 } else {
759 c
760 })
761 .take_materialized_series()
762 })
763 }
764}
765
766pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
767 let has_window_function = if let Some(expr) = expr.as_expression() {
768 expr.into_iter()
769 .any(|expr| matches!(expr, Expr::Window { .. }))
770 } else {
771 false
772 };
773 Arc::new(PhysicalIoHelper {
774 expr,
775 has_window_function,
776 }) as Arc<dyn PhysicalIoExpr>
777}