datafusion_python/expr/
aggregate.rs1use datafusion::common::DataFusionError;
19use datafusion::logical_expr::expr::{AggregateFunction, AggregateFunctionParams, Alias};
20use datafusion::logical_expr::logical_plan::Aggregate;
21use datafusion::logical_expr::Expr;
22use pyo3::{prelude::*, IntoPyObjectExt};
23use std::fmt::{self, Display, Formatter};
24
25use super::logical_node::LogicalNode;
26use crate::common::df_schema::PyDFSchema;
27use crate::errors::py_type_err;
28use crate::expr::PyExpr;
29use crate::sql::logical::PyLogicalPlan;
30
31#[pyclass(name = "Aggregate", module = "datafusion.expr", subclass)]
32#[derive(Clone)]
33pub struct PyAggregate {
34 aggregate: Aggregate,
35}
36
37impl From<Aggregate> for PyAggregate {
38 fn from(aggregate: Aggregate) -> PyAggregate {
39 PyAggregate { aggregate }
40 }
41}
42
43impl TryFrom<PyAggregate> for Aggregate {
44 type Error = DataFusionError;
45
46 fn try_from(agg: PyAggregate) -> Result<Self, Self::Error> {
47 Ok(agg.aggregate)
48 }
49}
50
51impl Display for PyAggregate {
52 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
53 write!(
54 f,
55 "Aggregate
56 \nGroupBy(s): {:?}
57 \nAggregates(s): {:?}
58 \nInput: {:?}
59 \nProjected Schema: {:?}",
60 &self.aggregate.group_expr,
61 &self.aggregate.aggr_expr,
62 self.aggregate.input,
63 self.aggregate.schema
64 )
65 }
66}
67
68#[pymethods]
69impl PyAggregate {
70 fn group_by_exprs(&self) -> PyResult<Vec<PyExpr>> {
72 Ok(self
73 .aggregate
74 .group_expr
75 .iter()
76 .map(|e| PyExpr::from(e.clone()))
77 .collect())
78 }
79
80 fn aggregate_exprs(&self) -> PyResult<Vec<PyExpr>> {
82 Ok(self
83 .aggregate
84 .aggr_expr
85 .iter()
86 .map(|e| PyExpr::from(e.clone()))
87 .collect())
88 }
89
90 pub fn agg_expressions(&self) -> PyResult<Vec<PyExpr>> {
92 Ok(self
93 .aggregate
94 .aggr_expr
95 .iter()
96 .map(|e| PyExpr::from(e.clone()))
97 .collect())
98 }
99
100 pub fn agg_func_name(&self, expr: PyExpr) -> PyResult<String> {
101 Self::_agg_func_name(&expr.expr)
102 }
103
104 pub fn aggregation_arguments(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
105 self._aggregation_arguments(&expr.expr)
106 }
107
108 fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
110 Ok(Self::inputs(self))
111 }
112
113 fn schema(&self) -> PyDFSchema {
115 (*self.aggregate.schema).clone().into()
116 }
117
118 fn __repr__(&self) -> PyResult<String> {
119 Ok(format!("Aggregate({})", self))
120 }
121}
122
123impl PyAggregate {
124 #[allow(clippy::only_used_in_recursion)]
125 fn _aggregation_arguments(&self, expr: &Expr) -> PyResult<Vec<PyExpr>> {
126 match expr {
127 Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
129 Expr::AggregateFunction(AggregateFunction {
130 func: _,
131 params: AggregateFunctionParams { args, .. },
132 ..
133 }) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()),
134 _ => Err(py_type_err(
135 "Encountered a non Aggregate type in aggregation_arguments",
136 )),
137 }
138 }
139
140 fn _agg_func_name(expr: &Expr) -> PyResult<String> {
141 match expr {
142 Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()),
143 Expr::AggregateFunction(AggregateFunction { func, .. }) => Ok(func.name().to_owned()),
144 _ => Err(py_type_err(
145 "Encountered a non Aggregate type in agg_func_name",
146 )),
147 }
148 }
149}
150
151impl LogicalNode for PyAggregate {
152 fn inputs(&self) -> Vec<PyLogicalPlan> {
153 vec![PyLogicalPlan::from((*self.aggregate.input).clone())]
154 }
155
156 fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
157 self.clone().into_bound_py_any(py)
158 }
159}