1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod eval;
9mod filter;
10mod gather;
11mod group_iter;
12mod literal;
13#[cfg(feature = "dynamic_group_by")]
14mod rolling;
15mod slice;
16mod sort;
17mod sortby;
18mod ternary;
19mod window;
20
21use std::borrow::Cow;
22use std::fmt::{Display, Formatter};
23
24pub(crate) use aggregation::*;
25pub(crate) use alias::*;
26pub(crate) use apply::*;
27use arrow::array::ArrayRef;
28use arrow::legacy::utils::CustomIterTools;
29pub(crate) use binary::*;
30pub(crate) use cast::*;
31pub(crate) use column::*;
32pub(crate) use count::*;
33pub(crate) use eval::*;
34pub(crate) use filter::*;
35pub(crate) use gather::*;
36pub(crate) use literal::*;
37use polars_core::prelude::*;
38use polars_io::predicates::PhysicalIoExpr;
39use polars_plan::prelude::*;
40#[cfg(feature = "dynamic_group_by")]
41pub(crate) use rolling::RollingExpr;
42pub(crate) use slice::*;
43pub(crate) use sort::*;
44pub(crate) use sortby::*;
45pub(crate) use ternary::*;
46pub use window::window_function_format_order_by;
47pub(crate) use window::*;
48
49use crate::state::ExecutionState;
50
51#[derive(Clone, Debug)]
52pub enum AggState {
53 AggregatedList(Column),
56 AggregatedScalar(Column),
60 NotAggregated(Column),
62 Literal(Column),
63}
64
65impl AggState {
66 fn try_map<F>(&self, func: F) -> PolarsResult<Self>
67 where
68 F: FnOnce(&Column) -> PolarsResult<Column>,
69 {
70 Ok(match self {
71 AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
72 AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
73 AggState::Literal(c) => AggState::Literal(func(c)?),
74 AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
75 })
76 }
77
78 fn is_scalar(&self) -> bool {
79 matches!(self, Self::AggregatedScalar(_))
80 }
81}
82
/// A pending recalculation of an [`AggregationContext`]'s groups. `groups()` applies it lazily
/// before returning the groups.
#[derive(Clone, Copy, PartialEq)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) enum UpdateGroups {
    /// The groups are up to date; nothing to do.
    No,
    /// Index groups are rebuilt as consecutive slices that keep each group's length.
    WithGroupsLen,
    /// The groups are recomputed from the list lengths of the current (list) values.
    WithSeriesLen,
}
97
/// Result of evaluating a physical expression in a group-by context: the evaluated values,
/// the groups they belong to, and bookkeeping on how those groups must be updated.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The evaluated values together with their aggregation state.
    state: AggState,
    /// Group positions. They may be borrowed from the caller; when they are recomputed,
    /// an owned copy (`Cow::Owned`) is stored instead.
    groups: Cow<'a, GroupPositions>,
    /// Initialised to `false`; set to `true` after `aggregated()` implodes not-aggregated
    /// values. NOTE(review): presumably means the exploded values are ordered by group —
    /// confirm against users of this field.
    sorted: bool,
    /// Pending group recalculation that `groups()` applies before returning the groups.
    update_groups: UpdateGroups,
    /// Initialised to `true` and only changed through `set_original_len`.
    /// NOTE(review): presumably tracks whether the values still have the input's length —
    /// confirm.
    original_len: bool,
}
118
119impl<'a> AggregationContext<'a> {
120 pub(crate) fn dtype(&self) -> DataType {
121 match &self.state {
122 AggState::Literal(s) => s.dtype().clone(),
123 AggState::AggregatedList(s) => s.list().unwrap().inner_dtype().clone(),
124 AggState::AggregatedScalar(s) => s.dtype().clone(),
125 AggState::NotAggregated(s) => s.dtype().clone(),
126 }
127 }
128 pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
129 match self.update_groups {
130 UpdateGroups::No => {},
131 UpdateGroups::WithGroupsLen => {
132 let mut offset = 0 as IdxSize;
137
138 match self.groups.as_ref().as_ref() {
139 GroupsType::Idx(groups) => {
140 let groups = groups
141 .iter()
142 .map(|g| {
143 let len = g.1.len() as IdxSize;
144 let new_offset = offset + len;
145 let out = [offset, len];
146 offset = new_offset;
147 out
148 })
149 .collect();
150 self.groups = Cow::Owned(
151 GroupsType::Slice {
152 groups,
153 rolling: false,
154 }
155 .into_sliceable(),
156 )
157 },
158 GroupsType::Slice { .. } => {},
160 }
161 self.update_groups = UpdateGroups::No;
162 },
163 UpdateGroups::WithSeriesLen => {
164 let s = self.get_values().clone();
165 self.det_groups_from_list(s.as_materialized_series());
166 },
167 }
168 &self.groups
169 }
170
171 pub(crate) fn get_values(&self) -> &Column {
172 match &self.state {
173 AggState::NotAggregated(s)
174 | AggState::AggregatedScalar(s)
175 | AggState::AggregatedList(s) => s,
176 AggState::Literal(s) => s,
177 }
178 }
179
180 pub fn agg_state(&self) -> &AggState {
181 &self.state
182 }
183
184 pub(crate) fn is_not_aggregated(&self) -> bool {
185 matches!(
186 &self.state,
187 AggState::NotAggregated(_) | AggState::Literal(_)
188 )
189 }
190
191 pub(crate) fn is_aggregated(&self) -> bool {
192 !self.is_not_aggregated()
193 }
194
195 pub(crate) fn is_literal(&self) -> bool {
196 matches!(self.state, AggState::Literal(_))
197 }
198
199 fn new(
203 column: Column,
204 groups: Cow<'a, GroupPositions>,
205 aggregated: bool,
206 ) -> AggregationContext<'a> {
207 let series = match (aggregated, column.dtype()) {
208 (true, &DataType::List(_)) => {
209 assert_eq!(column.len(), groups.len());
210 AggState::AggregatedList(column)
211 },
212 (true, _) => {
213 assert_eq!(column.len(), groups.len());
214 AggState::AggregatedScalar(column)
215 },
216 _ => AggState::NotAggregated(column),
217 };
218
219 Self {
220 state: series,
221 groups,
222 sorted: false,
223 update_groups: UpdateGroups::No,
224 original_len: true,
225 }
226 }
227
228 fn with_agg_state(&mut self, agg_state: AggState) {
229 self.state = agg_state;
230 }
231
232 fn from_agg_state(
233 agg_state: AggState,
234 groups: Cow<'a, GroupPositions>,
235 ) -> AggregationContext<'a> {
236 Self {
237 state: agg_state,
238 groups,
239 sorted: false,
240 update_groups: UpdateGroups::No,
241 original_len: true,
242 }
243 }
244
245 fn from_literal(lit: Column, groups: Cow<'a, GroupPositions>) -> AggregationContext<'a> {
246 Self {
247 state: AggState::Literal(lit),
248 groups,
249 sorted: false,
250 update_groups: UpdateGroups::No,
251 original_len: true,
252 }
253 }
254
255 pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
256 self.original_len = original_len;
257 self
258 }
259
260 pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
261 self.update_groups = update;
262 self
263 }
264
265 fn det_groups_from_list(&mut self, s: &Series) {
266 let mut offset = 0 as IdxSize;
267 let list = s
268 .list()
269 .expect("impl error, should be a list at this point");
270
271 match list.chunks().len() {
272 1 => {
273 let arr = list.downcast_iter().next().unwrap();
274 let offsets = arr.offsets().as_slice();
275
276 let mut previous = 0i64;
277 let groups = offsets[1..]
278 .iter()
279 .map(|&o| {
280 let len = (o - previous) as IdxSize;
281 let new_offset = offset + len + (len == 0) as IdxSize;
284
285 previous = o;
286 let out = [offset, len];
287 offset = new_offset;
288 out
289 })
290 .collect_trusted();
291 self.groups = Cow::Owned(
292 GroupsType::Slice {
293 groups,
294 rolling: false,
295 }
296 .into_sliceable(),
297 );
298 },
299 _ => {
300 let groups = {
301 self.get_values()
302 .list()
303 .expect("impl error, should be a list at this point")
304 .amortized_iter()
305 .map(|s| {
306 if let Some(s) = s {
307 let len = s.as_ref().len() as IdxSize;
308 let new_offset = offset + len;
309 let out = [offset, len];
310 offset = new_offset;
311 out
312 } else {
313 [offset, 0]
314 }
315 })
316 .collect_trusted()
317 };
318 self.groups = Cow::Owned(
319 GroupsType::Slice {
320 groups,
321 rolling: false,
322 }
323 .into_sliceable(),
324 );
325 },
326 }
327 self.update_groups = UpdateGroups::No;
328 }
329
330 pub(crate) fn with_values(
334 &mut self,
335 column: Column,
336 aggregated: bool,
337 expr: Option<&Expr>,
338 ) -> PolarsResult<&mut Self> {
339 self.with_values_and_args(
340 column,
341 aggregated,
342 expr,
343 false,
344 self.agg_state().is_scalar(),
345 )
346 }
347
348 pub(crate) fn with_values_and_args(
349 &mut self,
350 column: Column,
351 aggregated: bool,
352 expr: Option<&Expr>,
353 mapped: bool,
356 returns_scalar: bool,
357 ) -> PolarsResult<&mut Self> {
358 self.state = match (aggregated, column.dtype()) {
359 (true, &DataType::List(_)) if !returns_scalar => {
360 if column.len() != self.groups.len() {
361 let fmt_expr = if let Some(e) = expr {
362 format!("'{e:?}' ")
363 } else {
364 String::new()
365 };
366 polars_bail!(
367 ComputeError:
368 "aggregation expression '{}' produced a different number of elements: {} \
369 than the number of groups: {} (this is likely invalid)",
370 fmt_expr, column.len(), self.groups.len(),
371 );
372 }
373 AggState::AggregatedList(column)
374 },
375 (true, _) => AggState::AggregatedScalar(column),
376 _ => {
377 match self.state {
378 AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
381 AggState::Literal(_) if column.len() == 1 && mapped => {
383 AggState::Literal(column)
384 },
385 _ => AggState::NotAggregated(column.into_column()),
386 }
387 },
388 };
389 Ok(self)
390 }
391
392 pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
393 self.state = AggState::Literal(column);
394 self
395 }
396
397 pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
399 if let AggState::AggregatedList(_) = self.agg_state() {
400 self.with_values(self.flat_naive().into_owned(), false, None)
402 .unwrap();
403 }
404 self.groups = Cow::Owned(groups);
405 self.update_groups = UpdateGroups::No;
407 self
408 }
409
410 pub(crate) fn _implode_no_agg(&mut self) {
411 match self.state.clone() {
412 AggState::NotAggregated(_) => {
413 let _ = self.aggregated();
414 let AggState::AggregatedList(s) = self.state.clone() else {
415 unreachable!()
416 };
417 self.state = AggState::AggregatedScalar(s);
418 },
419 AggState::AggregatedList(s) => {
420 self.state = AggState::AggregatedScalar(s);
421 },
422 _ => unreachable!("should only be called in non-agg/list-agg state by aggregation.rs"),
423 }
424 }
425
426 pub fn aggregated(&mut self) -> Column {
428 match self.state.clone() {
431 AggState::NotAggregated(s) => {
432 self.groups();
437 #[cfg(debug_assertions)]
438 {
439 if self.groups.len() > s.len() {
440 polars_warn!(
441 "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
442 )
443 }
444 }
445
446 let out = unsafe { s.agg_list(&self.groups) };
449 self.state = AggState::AggregatedList(out.clone());
450
451 self.sorted = true;
452 self.update_groups = UpdateGroups::WithGroupsLen;
453 out
454 },
455 AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
456 AggState::Literal(s) => {
457 self.groups();
458 let rows = self.groups.len();
459 let s = s.new_from_index(0, rows);
460 let out = s
461 .reshape_list(&[
462 ReshapeDimension::new_dimension(rows as u64),
463 ReshapeDimension::Infer,
464 ])
465 .unwrap();
466 self.state = AggState::AggregatedList(out.clone());
467 out.into_column()
468 },
469 }
470 }
471
472 pub fn finalize(&mut self) -> Column {
474 match &self.state {
477 AggState::Literal(c) => {
478 let c = c.clone();
479 self.groups();
480 let rows = self.groups.len();
481 c.new_from_index(0, rows)
482 },
483 _ => self.aggregated(),
484 }
485 }
486
487 fn arity_should_explode(&self) -> bool {
490 use AggState::*;
491 match self.agg_state() {
492 Literal(s) => s.len() == 1,
493 AggregatedScalar(_) => true,
494 _ => false,
495 }
496 }
497
498 pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
499 let _ = self.groups();
500 let groups = self.groups;
501 match self.state {
502 AggState::NotAggregated(c) => (c, groups),
503 AggState::AggregatedScalar(c) => (c, groups),
504 AggState::Literal(c) => (c, groups),
505 AggState::AggregatedList(c) => {
506 let flattened = c.explode(false).unwrap();
507 let groups = groups.into_owned();
508 let groups = groups.unroll();
533 (flattened, Cow::Owned(groups))
534 },
535 }
536 }
537
538 pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
543 match &self.state {
544 AggState::NotAggregated(c) => Cow::Borrowed(c),
545 AggState::AggregatedList(c) => {
546 #[cfg(debug_assertions)]
547 {
548 if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() {
551 panic!(
552 "implementation error, polars should not hit this branch for overlapping groups"
553 )
554 }
555 }
556
557 Cow::Owned(c.explode(false).unwrap())
558 },
559 AggState::AggregatedScalar(c) => Cow::Borrowed(c),
560 AggState::Literal(c) => Cow::Borrowed(c),
561 }
562 }
563
564 pub(crate) fn take(&mut self) -> Column {
566 let c = match &mut self.state {
567 AggState::NotAggregated(c)
568 | AggState::AggregatedScalar(c)
569 | AggState::AggregatedList(c) => c,
570 AggState::Literal(c) => c,
571 };
572 std::mem::take(c)
573 }
574}
575
576pub trait PhysicalExpr: Send + Sync {
579 fn as_expression(&self) -> Option<&Expr> {
580 None
581 }
582
583 fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;
585
586 fn evaluate_inline(&self) -> Option<Column> {
592 self.evaluate_inline_impl(4)
593 }
594
595 fn evaluate_inline_impl(&self, _depth_limit: u8) -> Option<Column> {
597 None
598 }
599
600 #[allow(clippy::ptr_arg)]
623 fn evaluate_on_groups<'a>(
624 &self,
625 df: &DataFrame,
626 groups: &'a GroupPositions,
627 state: &ExecutionState,
628 ) -> PolarsResult<AggregationContext<'a>>;
629
630 fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;
632
633 fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
635 None
636 }
637
638 fn is_literal(&self) -> bool {
639 false
640 }
641 fn is_scalar(&self) -> bool;
642}
643
644impl Display for &dyn PhysicalExpr {
645 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
646 match self.as_expression() {
647 None => Ok(()),
648 Some(e) => write!(f, "{e:?}"),
649 }
650 }
651}
652
653pub struct PhysicalIoHelper {
657 pub expr: Arc<dyn PhysicalExpr>,
658 pub has_window_function: bool,
659}
660
661impl PhysicalIoExpr for PhysicalIoHelper {
662 fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
663 let mut state: ExecutionState = Default::default();
664 if self.has_window_function {
665 state.insert_has_window_function_flag();
666 }
667 self.expr
668 .evaluate(df, &state)
669 .map(|c| c.take_materialized_series())
670 }
671}
672
673pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
674 let has_window_function = if let Some(expr) = expr.as_expression() {
675 expr.into_iter()
676 .any(|expr| matches!(expr, Expr::Window { .. }))
677 } else {
678 false
679 };
680 Arc::new(PhysicalIoHelper {
681 expr,
682 has_window_function,
683 }) as Arc<dyn PhysicalIoExpr>
684}
685
686pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
687 #[allow(clippy::ptr_arg)]
695 fn evaluate_partitioned(
696 &self,
697 df: &DataFrame,
698 groups: &GroupPositions,
699 state: &ExecutionState,
700 ) -> PolarsResult<Column>;
701
702 #[allow(clippy::ptr_arg)]
704 fn finalize(
705 &self,
706 partitioned: Column,
707 groups: &GroupPositions,
708 state: &ExecutionState,
709 ) -> PolarsResult<Column>;
710}