1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod filter;
9mod gather;
10mod group_iter;
11mod literal;
12#[cfg(feature = "dynamic_group_by")]
13mod rolling;
14mod slice;
15mod sort;
16mod sortby;
17mod ternary;
18mod window;
19
20use std::borrow::Cow;
21use std::fmt::{Display, Formatter};
22
23pub(crate) use aggregation::*;
24pub(crate) use alias::*;
25pub(crate) use apply::*;
26use arrow::array::ArrayRef;
27use arrow::legacy::utils::CustomIterTools;
28pub(crate) use binary::*;
29pub(crate) use cast::*;
30pub(crate) use column::*;
31pub(crate) use count::*;
32pub(crate) use filter::*;
33pub(crate) use gather::*;
34pub(crate) use literal::*;
35use polars_core::prelude::*;
36use polars_io::predicates::PhysicalIoExpr;
37use polars_plan::prelude::*;
38#[cfg(feature = "dynamic_group_by")]
39pub(crate) use rolling::RollingExpr;
40pub(crate) use slice::*;
41pub(crate) use sort::*;
42pub(crate) use sortby::*;
43pub(crate) use ternary::*;
44pub use window::window_function_format_order_by;
45pub(crate) use window::*;
46
47use crate::state::ExecutionState;
48
/// Aggregation state of the values held by an [`AggregationContext`].
#[derive(Clone, Debug)]
pub enum AggState {
    /// Aggregated values of dtype `List`; `AggregationContext::new` asserts
    /// that the column has one row per group.
    AggregatedList(Column),
    /// Aggregated values of any non-`List` dtype; `AggregationContext::new`
    /// asserts that the column has one row per group.
    AggregatedScalar(Column),
    /// Flat values that still have to be aggregated;
    /// `AggregationContext::aggregated` turns them into lists via `agg_list`.
    NotAggregated(Column),
    /// A literal value; `AggregationContext::finalize` repeats its first
    /// element once per group.
    Literal(Column),
}
62
63impl AggState {
64 fn try_map<F>(&self, func: F) -> PolarsResult<Self>
65 where
66 F: FnOnce(&Column) -> PolarsResult<Column>,
67 {
68 Ok(match self {
69 AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
70 AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
71 AggState::Literal(c) => AggState::Literal(func(c)?),
72 AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
73 })
74 }
75}
76
/// Pending (lazy) group update that `AggregationContext::groups` applies the
/// next time the groups are requested.
#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// Leave the groups as they are.
    No,
    /// Rebuild index groups as contiguous `[start, len]` slices, using the
    /// length of each current group.
    WithGroupsLen,
    /// Rebuild the groups from the list lengths of the current values
    /// (see `AggregationContext::det_groups_from_list`, which expects a list column).
    WithSeriesLen,
}
91
/// Values together with the groups they belong to, as produced while
/// evaluating an expression on groups.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The values and their aggregation state.
    state: AggState,
    /// Group positions belonging to `state`; borrowed or owned.
    groups: Cow<'a, GroupPositions>,
    /// Set to `true` once `aggregated` has built lists from flat values.
    /// NOTE(review): not read anywhere in this part of the file.
    sorted: bool,
    /// Pending lazy group update, applied by `groups()`.
    update_groups: UpdateGroups,
    /// Initialised to `true` by every constructor and changed through
    /// `set_original_len`. NOTE(review): presumably tracks whether values and
    /// groups still have their original length — confirm with callers.
    original_len: bool,
}
112
113impl<'a> AggregationContext<'a> {
114 pub(crate) fn dtype(&self) -> DataType {
115 match &self.state {
116 AggState::Literal(s) => s.dtype().clone(),
117 AggState::AggregatedList(s) => s.list().unwrap().inner_dtype().clone(),
118 AggState::AggregatedScalar(s) => s.dtype().clone(),
119 AggState::NotAggregated(s) => s.dtype().clone(),
120 }
121 }
122 pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
123 match self.update_groups {
124 UpdateGroups::No => {},
125 UpdateGroups::WithGroupsLen => {
126 let mut offset = 0 as IdxSize;
131
132 match self.groups.as_ref().as_ref() {
133 GroupsType::Idx(groups) => {
134 let groups = groups
135 .iter()
136 .map(|g| {
137 let len = g.1.len() as IdxSize;
138 let new_offset = offset + len;
139 let out = [offset, len];
140 offset = new_offset;
141 out
142 })
143 .collect();
144 self.groups = Cow::Owned(
145 GroupsType::Slice {
146 groups,
147 rolling: false,
148 }
149 .into_sliceable(),
150 )
151 },
152 GroupsType::Slice { .. } => {},
154 }
155 self.update_groups = UpdateGroups::No;
156 },
157 UpdateGroups::WithSeriesLen => {
158 let s = self.get_values().clone();
159 self.det_groups_from_list(s.as_materialized_series());
160 },
161 }
162 &self.groups
163 }
164
165 pub(crate) fn get_values(&self) -> &Column {
166 match &self.state {
167 AggState::NotAggregated(s)
168 | AggState::AggregatedScalar(s)
169 | AggState::AggregatedList(s) => s,
170 AggState::Literal(s) => s,
171 }
172 }
173
/// Borrow the current aggregation state.
pub fn agg_state(&self) -> &AggState {
    &self.state
}
177
178 pub(crate) fn is_not_aggregated(&self) -> bool {
179 matches!(
180 &self.state,
181 AggState::NotAggregated(_) | AggState::Literal(_)
182 )
183 }
184
185 pub(crate) fn is_aggregated(&self) -> bool {
186 !self.is_not_aggregated()
187 }
188
189 pub(crate) fn is_literal(&self) -> bool {
190 matches!(self.state, AggState::Literal(_))
191 }
192
193 fn new(
197 column: Column,
198 groups: Cow<'a, GroupPositions>,
199 aggregated: bool,
200 ) -> AggregationContext<'a> {
201 let series = match (aggregated, column.dtype()) {
202 (true, &DataType::List(_)) => {
203 assert_eq!(column.len(), groups.len());
204 AggState::AggregatedList(column)
205 },
206 (true, _) => {
207 assert_eq!(column.len(), groups.len());
208 AggState::AggregatedScalar(column)
209 },
210 _ => AggState::NotAggregated(column),
211 };
212
213 Self {
214 state: series,
215 groups,
216 sorted: false,
217 update_groups: UpdateGroups::No,
218 original_len: true,
219 }
220 }
221
/// Replace the aggregation state; groups and flags are left untouched.
fn with_agg_state(&mut self, agg_state: AggState) {
    self.state = agg_state;
}
225
226 fn from_agg_state(
227 agg_state: AggState,
228 groups: Cow<'a, GroupPositions>,
229 ) -> AggregationContext<'a> {
230 Self {
231 state: agg_state,
232 groups,
233 sorted: false,
234 update_groups: UpdateGroups::No,
235 original_len: true,
236 }
237 }
238
239 fn from_literal(lit: Column, groups: Cow<'a, GroupPositions>) -> AggregationContext<'a> {
240 Self {
241 state: AggState::Literal(lit),
242 groups,
243 sorted: false,
244 update_groups: UpdateGroups::No,
245 original_len: true,
246 }
247 }
248
/// Set the `original_len` flag and return `self` for chaining.
pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
    self.original_len = original_len;
    self
}
253
/// Schedule a lazy group update, applied on the next call to `groups()`.
/// Returns `self` for chaining.
pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
    self.update_groups = update;
    self
}
258
259 pub(crate) fn det_groups_from_list(&mut self, s: &Series) {
260 let mut offset = 0 as IdxSize;
261 let list = s
262 .list()
263 .expect("impl error, should be a list at this point");
264
265 match list.chunks().len() {
266 1 => {
267 let arr = list.downcast_iter().next().unwrap();
268 let offsets = arr.offsets().as_slice();
269
270 let mut previous = 0i64;
271 let groups = offsets[1..]
272 .iter()
273 .map(|&o| {
274 let len = (o - previous) as IdxSize;
275 let new_offset = offset + len + (len == 0) as IdxSize;
278
279 previous = o;
280 let out = [offset, len];
281 offset = new_offset;
282 out
283 })
284 .collect_trusted();
285 self.groups = Cow::Owned(
286 GroupsType::Slice {
287 groups,
288 rolling: false,
289 }
290 .into_sliceable(),
291 );
292 },
293 _ => {
294 let groups = {
295 self.get_values()
296 .list()
297 .expect("impl error, should be a list at this point")
298 .amortized_iter()
299 .map(|s| {
300 if let Some(s) = s {
301 let len = s.as_ref().len() as IdxSize;
302 let new_offset = offset + len;
303 let out = [offset, len];
304 offset = new_offset;
305 out
306 } else {
307 [offset, 0]
308 }
309 })
310 .collect_trusted()
311 };
312 self.groups = Cow::Owned(
313 GroupsType::Slice {
314 groups,
315 rolling: false,
316 }
317 .into_sliceable(),
318 );
319 },
320 }
321 self.update_groups = UpdateGroups::No;
322 }
323
324 pub(crate) fn with_values(
328 &mut self,
329 column: Column,
330 aggregated: bool,
331 expr: Option<&Expr>,
332 ) -> PolarsResult<&mut Self> {
333 self.with_values_and_args(column, aggregated, expr, false)
334 }
335
336 pub(crate) fn with_values_and_args(
337 &mut self,
338 column: Column,
339 aggregated: bool,
340 expr: Option<&Expr>,
341 mapped: bool,
344 ) -> PolarsResult<&mut Self> {
345 self.state = match (aggregated, column.dtype()) {
346 (true, &DataType::List(_)) => {
347 if column.len() != self.groups.len() {
348 let fmt_expr = if let Some(e) = expr {
349 format!("'{e:?}' ")
350 } else {
351 String::new()
352 };
353 polars_bail!(
354 ComputeError:
355 "aggregation expression '{}' produced a different number of elements: {} \
356 than the number of groups: {} (this is likely invalid)",
357 fmt_expr, column.len(), self.groups.len(),
358 );
359 }
360 AggState::AggregatedList(column)
361 },
362 (true, _) => AggState::AggregatedScalar(column),
363 _ => {
364 match self.state {
365 AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
368 AggState::Literal(_) if column.len() == 1 && mapped => {
370 AggState::Literal(column)
371 },
372 _ => AggState::NotAggregated(column.into_column()),
373 }
374 },
375 };
376 Ok(self)
377 }
378
/// Set the state to `AggState::Literal(column)` and return `self` for chaining.
pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
    self.state = AggState::Literal(column);
    self
}
383
384 pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
386 if let AggState::AggregatedList(_) = self.agg_state() {
387 self.with_values(self.flat_naive().into_owned(), false, None)
389 .unwrap();
390 }
391 self.groups = Cow::Owned(groups);
392 self.update_groups = UpdateGroups::No;
394 self
395 }
396
397 pub fn aggregated(&mut self) -> Column {
399 match self.state.clone() {
402 AggState::NotAggregated(s) => {
403 self.groups();
408 #[cfg(debug_assertions)]
409 {
410 if self.groups.len() > s.len() {
411 polars_warn!("groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by")
412 }
413 }
414
415 let out = unsafe { s.agg_list(&self.groups) };
418 self.state = AggState::AggregatedList(out.clone());
419
420 self.sorted = true;
421 self.update_groups = UpdateGroups::WithGroupsLen;
422 out
423 },
424 AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
425 AggState::Literal(s) => {
426 self.groups();
427 let rows = self.groups.len();
428 let s = s.new_from_index(0, rows);
429 let out = s
430 .reshape_list(&[
431 ReshapeDimension::new_dimension(rows as u64),
432 ReshapeDimension::Infer,
433 ])
434 .unwrap();
435 self.state = AggState::AggregatedList(out.clone());
436 out.into_column()
437 },
438 }
439 }
440
441 pub fn finalize(&mut self) -> Column {
443 match &self.state {
446 AggState::Literal(c) => {
447 let c = c.clone();
448 self.groups();
449 let rows = self.groups.len();
450 c.new_from_index(0, rows)
451 },
452 _ => self.aggregated(),
453 }
454 }
455
456 fn arity_should_explode(&self) -> bool {
459 use AggState::*;
460 match self.agg_state() {
461 Literal(s) => s.len() == 1,
462 AggregatedScalar(_) => true,
463 _ => false,
464 }
465 }
466
467 pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
468 let _ = self.groups();
469 let groups = self.groups;
470 match self.state {
471 AggState::NotAggregated(c) => (c, groups),
472 AggState::AggregatedScalar(c) => (c, groups),
473 AggState::Literal(c) => (c, groups),
474 AggState::AggregatedList(c) => {
475 let flattened = c.explode().unwrap();
476 let groups = groups.into_owned();
477 let groups = groups.unroll();
502 (flattened, Cow::Owned(groups))
503 },
504 }
505 }
506
507 pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
512 match &self.state {
513 AggState::NotAggregated(c) => Cow::Borrowed(c),
514 AggState::AggregatedList(c) => {
515 #[cfg(debug_assertions)]
516 {
517 if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() {
520 panic!("implementation error, polars should not hit this branch for overlapping groups")
521 }
522 }
523
524 Cow::Owned(c.explode().unwrap())
525 },
526 AggState::AggregatedScalar(c) => Cow::Borrowed(c),
527 AggState::Literal(c) => Cow::Borrowed(c),
528 }
529 }
530
531 pub(crate) fn take(&mut self) -> Column {
533 let c = match &mut self.state {
534 AggState::NotAggregated(c)
535 | AggState::AggregatedScalar(c)
536 | AggState::AggregatedList(c) => c,
537 AggState::Literal(c) => c,
538 };
539 std::mem::take(c)
540 }
541}
542
/// An executable expression that can be evaluated on a whole `DataFrame` or
/// on groups of it.
pub trait PhysicalExpr: Send + Sync {
    /// The logical [`Expr`] behind this physical expression, if available.
    /// Defaults to `None`.
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    /// Evaluate the expression on `df` and return the resulting column.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;

    /// Try to evaluate the expression without any input; delegates to
    /// `evaluate_inline_impl` with a depth limit of 4.
    fn evaluate_inline(&self) -> Option<Column> {
        self.evaluate_inline_impl(4)
    }

    /// Implementation hook for `evaluate_inline`. The default returns `None`.
    /// NOTE(review): `_depth_limit` presumably bounds recursion into nested
    /// expressions — confirm with the implementors.
    fn evaluate_inline_impl(&self, _depth_limit: u8) -> Option<Column> {
        None
    }

    /// Evaluate the expression on `groups` of `df`, returning the values
    /// together with their groups as an [`AggregationContext`].
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Resolve the output field of this expression for `input_schema`.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    /// View this expression as a [`PartitionedAggregation`], if it supports
    /// that. Defaults to `None`.
    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        None
    }

    /// Collect column names into `lv`.
    /// NOTE(review): presumably the columns this expression reads — confirm.
    fn collect_live_columns(&self, lv: &mut PlIndexSet<PlSmallStr>);

    /// View this expression as a statistics evaluator, if it supports that.
    /// Defaults to `None`.
    fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
        None
    }
    /// Whether this expression is a literal. Defaults to `false`.
    fn is_literal(&self) -> bool {
        false
    }
    /// Whether this expression is scalar.
    /// NOTE(review): exact meaning is defined by the implementors — confirm.
    fn is_scalar(&self) -> bool;
}
620
621impl Display for &dyn PhysicalExpr {
622 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
623 match self.as_expression() {
624 None => Ok(()),
625 Some(e) => write!(f, "{e:?}"),
626 }
627 }
628}
629
/// Adapter that exposes a [`PhysicalExpr`] through the `PhysicalIoExpr` trait.
pub struct PhysicalIoHelper {
    /// The wrapped physical expression.
    pub expr: Arc<dyn PhysicalExpr>,
    /// When set, the window-function flag is inserted into the execution
    /// state before the expression is evaluated.
    pub has_window_function: bool,
}
637
638impl PhysicalIoExpr for PhysicalIoHelper {
639 fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
640 let mut state: ExecutionState = Default::default();
641 if self.has_window_function {
642 state.insert_has_window_function_flag();
643 }
644 self.expr
645 .evaluate(df, &state)
646 .map(|c| c.take_materialized_series())
647 }
648
649 fn collect_live_columns(&self, live_columns: &mut PlIndexSet<PlSmallStr>) {
650 self.expr.collect_live_columns(live_columns);
651 }
652
653 #[cfg(feature = "parquet")]
654 fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
655 self.expr.as_stats_evaluator()
656 }
657}
658
659pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
660 let has_window_function = if let Some(expr) = expr.as_expression() {
661 expr.into_iter()
662 .any(|expr| matches!(expr, Expr::Window { .. }))
663 } else {
664 false
665 };
666 Arc::new(PhysicalIoHelper {
667 expr,
668 has_window_function,
669 }) as Arc<dyn PhysicalIoExpr>
670}
671
/// A [`PhysicalExpr`] whose aggregation can be computed in two steps:
/// `evaluate_partitioned` followed by `finalize`.
pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
    /// Compute the partial aggregation on `groups` of `df`.
    /// NOTE(review): presumably called once per partition — confirm with the
    /// callers.
    #[allow(clippy::ptr_arg)]
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column>;

    /// Turn the `partitioned` result into the final column for `groups`.
    /// NOTE(review): presumably combines the per-partition results — confirm.
    #[allow(clippy::ptr_arg)]
    fn finalize(
        &self,
        partitioned: Column,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column>;
}