datafusion_python/expr/
aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::common::DataFusionError;
19use datafusion::logical_expr::expr::{AggregateFunction, AggregateFunctionParams, Alias};
20use datafusion::logical_expr::logical_plan::Aggregate;
21use datafusion::logical_expr::Expr;
22use pyo3::{prelude::*, IntoPyObjectExt};
23use std::fmt::{self, Display, Formatter};
24
25use super::logical_node::LogicalNode;
26use crate::common::df_schema::PyDFSchema;
27use crate::errors::py_type_err;
28use crate::expr::PyExpr;
29use crate::sql::logical::PyLogicalPlan;
30
31#[pyclass(name = "Aggregate", module = "datafusion.expr", subclass)]
32#[derive(Clone)]
33pub struct PyAggregate {
34    aggregate: Aggregate,
35}
36
37impl From<Aggregate> for PyAggregate {
38    fn from(aggregate: Aggregate) -> PyAggregate {
39        PyAggregate { aggregate }
40    }
41}
42
43impl TryFrom<PyAggregate> for Aggregate {
44    type Error = DataFusionError;
45
46    fn try_from(agg: PyAggregate) -> Result<Self, Self::Error> {
47        Ok(agg.aggregate)
48    }
49}
50
51impl Display for PyAggregate {
52    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
53        write!(
54            f,
55            "Aggregate
56            \nGroupBy(s): {:?}
57            \nAggregates(s): {:?}
58            \nInput: {:?}
59            \nProjected Schema: {:?}",
60            &self.aggregate.group_expr,
61            &self.aggregate.aggr_expr,
62            self.aggregate.input,
63            self.aggregate.schema
64        )
65    }
66}
67
68#[pymethods]
69impl PyAggregate {
70    /// Retrieves the grouping expressions for this `Aggregate`
71    fn group_by_exprs(&self) -> PyResult<Vec<PyExpr>> {
72        Ok(self
73            .aggregate
74            .group_expr
75            .iter()
76            .map(|e| PyExpr::from(e.clone()))
77            .collect())
78    }
79
80    /// Retrieves the aggregate expressions for this `Aggregate`
81    fn aggregate_exprs(&self) -> PyResult<Vec<PyExpr>> {
82        Ok(self
83            .aggregate
84            .aggr_expr
85            .iter()
86            .map(|e| PyExpr::from(e.clone()))
87            .collect())
88    }
89
90    /// Returns the inner Aggregate Expr(s)
91    pub fn agg_expressions(&self) -> PyResult<Vec<PyExpr>> {
92        Ok(self
93            .aggregate
94            .aggr_expr
95            .iter()
96            .map(|e| PyExpr::from(e.clone()))
97            .collect())
98    }
99
100    pub fn agg_func_name(&self, expr: PyExpr) -> PyResult<String> {
101        Self::_agg_func_name(&expr.expr)
102    }
103
104    pub fn aggregation_arguments(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
105        self._aggregation_arguments(&expr.expr)
106    }
107
108    // Retrieves the input `LogicalPlan` to this `Aggregate` node
109    fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
110        Ok(Self::inputs(self))
111    }
112
113    // Resulting Schema for this `Aggregate` node instance
114    fn schema(&self) -> PyDFSchema {
115        (*self.aggregate.schema).clone().into()
116    }
117
118    fn __repr__(&self) -> PyResult<String> {
119        Ok(format!("Aggregate({})", self))
120    }
121}
122
123impl PyAggregate {
124    #[allow(clippy::only_used_in_recursion)]
125    fn _aggregation_arguments(&self, expr: &Expr) -> PyResult<Vec<PyExpr>> {
126        match expr {
127            // TODO: This Alias logic seems to be returning some strange results that we should investigate
128            Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
129            Expr::AggregateFunction(AggregateFunction {
130                func: _,
131                params: AggregateFunctionParams { args, .. },
132                ..
133            }) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()),
134            _ => Err(py_type_err(
135                "Encountered a non Aggregate type in aggregation_arguments",
136            )),
137        }
138    }
139
140    fn _agg_func_name(expr: &Expr) -> PyResult<String> {
141        match expr {
142            Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()),
143            Expr::AggregateFunction(AggregateFunction { func, .. }) => Ok(func.name().to_owned()),
144            _ => Err(py_type_err(
145                "Encountered a non Aggregate type in agg_func_name",
146            )),
147        }
148    }
149}
150
151impl LogicalNode for PyAggregate {
152    fn inputs(&self) -> Vec<PyLogicalPlan> {
153        vec![PyLogicalPlan::from((*self.aggregate.input).clone())]
154    }
155
156    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
157        self.clone().into_bound_py_any(py)
158    }
159}