use std::fmt;
use serde::Serialize;
use super::*;
use crate::v1::binder::{BoundAggCall, BoundExpr, ExprVisitor};
use crate::v1::optimizer::logical_plan_rewriter::ExprRewriter;
#[derive(Debug, Clone, Serialize)]
pub struct LogicalAggregate {
agg_calls: Vec<BoundAggCall>,
group_keys: Vec<BoundExpr>,
child: PlanRef,
}
impl LogicalAggregate {
pub fn new(agg_calls: Vec<BoundAggCall>, group_keys: Vec<BoundExpr>, child: PlanRef) -> Self {
LogicalAggregate {
agg_calls,
group_keys,
child,
}
}
pub fn agg_calls(&self) -> &[BoundAggCall] {
self.agg_calls.as_ref()
}
pub fn group_keys(&self) -> &[BoundExpr] {
self.group_keys.as_ref()
}
pub fn clone_with_rewrite_expr(
&self,
new_child: PlanRef,
rewriter: &impl ExprRewriter,
) -> Self {
let mut new_agg_calls = self.agg_calls().to_vec();
let mut new_keys = self.group_keys().to_vec();
for agg in &mut new_agg_calls {
for arg in &mut agg.args {
rewriter.rewrite_expr(arg);
}
}
for keys in &mut new_keys {
rewriter.rewrite_expr(keys);
}
LogicalAggregate::new(new_agg_calls, new_keys, new_child)
}
}
impl PlanTreeNodeUnary for LogicalAggregate {
fn child(&self) -> PlanRef {
self.child.clone()
}
#[must_use]
fn clone_with_child(&self, child: PlanRef) -> Self {
Self::new(self.agg_calls().to_vec(), self.group_keys().to_vec(), child)
}
}
impl_plan_tree_node_for_unary!(LogicalAggregate);
impl PlanNode for LogicalAggregate {
fn schema(&self) -> Vec<ColumnDesc> {
let child_schema = self.child.schema();
self.group_keys
.iter()
.map(|expr| ColumnDesc::new(expr.return_type(), expr.format_name(&child_schema), false))
.chain(self.agg_calls.iter().map(|agg_call| {
agg_call
.return_type
.clone()
.to_column(format!("{}", agg_call.kind))
}))
.collect()
}
fn estimated_cardinality(&self) -> usize {
self.child().estimated_cardinality()
}
fn prune_col(&self, required_cols: BitSet) -> PlanRef {
let group_keys_len = self.group_keys.len();
let mut visitor =
CollectRequiredCols(BitSet::with_capacity(group_keys_len + self.agg_calls.len()));
let mut new_agg_calls: Vec<_> = required_cols
.iter()
.filter(|&index| index >= group_keys_len)
.map(|index| {
let call = &self.agg_calls[index - group_keys_len];
call.args.iter().for_each(|expr| {
visitor.visit_expr(expr);
});
self.agg_calls[index - group_keys_len].clone()
})
.collect();
self.group_keys
.iter()
.for_each(|group| visitor.visit_expr(group));
let input_cols = visitor.0;
let mapper = Mapper::new_with_bitset(&input_cols);
for call in &mut new_agg_calls {
call.args.iter_mut().for_each(|expr| {
mapper.rewrite_expr(expr);
})
}
let mut group_keys = self.group_keys.clone();
group_keys
.iter_mut()
.for_each(|expr| mapper.rewrite_expr(expr));
let new_agg = LogicalAggregate::new(
new_agg_calls.clone(),
group_keys,
self.child.prune_col(input_cols),
);
let bitset = BitSet::from_iter(0..group_keys_len);
if bitset.is_subset(&required_cols) {
new_agg.into_plan_ref()
} else {
let mut new_projection: Vec<BoundExpr> = required_cols
.iter()
.filter(|&i| i < group_keys_len)
.map(|index| {
BoundExpr::InputRef(BoundInputRef {
index,
return_type: self.group_keys[index].return_type(),
})
})
.collect();
for (index, item) in new_agg_calls.iter().enumerate() {
new_projection.push(BoundExpr::InputRef(BoundInputRef {
index: group_keys_len + index,
return_type: item.return_type.clone(),
}))
}
LogicalProjection::new(new_projection, new_agg.into_plan_ref()).into_plan_ref()
}
}
}
impl fmt::Display for LogicalAggregate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "LogicalAggregate: {} agg calls", self.agg_calls.len(),)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::DataTypeKind;
use crate::v1::binder::AggKind;
#[test]
fn test_aggregate_out_names() {
let plan = LogicalAggregate::new(
vec![
BoundAggCall {
kind: AggKind::Sum,
args: vec![],
return_type: DataTypeKind::Float64.not_null(),
},
BoundAggCall {
kind: AggKind::Avg,
args: vec![],
return_type: DataTypeKind::Float64.not_null(),
},
BoundAggCall {
kind: AggKind::Count,
args: vec![],
return_type: DataTypeKind::Float64.not_null(),
},
BoundAggCall {
kind: AggKind::RowCount,
args: vec![],
return_type: DataTypeKind::Float64.not_null(),
},
],
vec![],
Arc::new(Dummy::new(Vec::new())),
);
let column_names = plan.out_names();
assert_eq!(column_names[0], "sum");
assert_eq!(column_names[1], "avg");
assert_eq!(column_names[2], "count");
assert_eq!(column_names[3], "count");
}
#[test]
fn test_prune_aggregate() {
let ty = DataTypeKind::Int32.not_null();
let col_descs = vec![
ty.clone().to_column("v1".into()),
ty.clone().to_column("v2".into()),
ty.clone().to_column("v3".into()),
];
let table_scan = LogicalTableScan::new(
crate::catalog::TableRefId {
database_id: 0,
schema_id: 0,
table_id: 0,
},
vec![1, 2, 3],
col_descs,
false,
false,
None,
);
let input_refs = vec![
BoundExpr::InputRef(BoundInputRef {
index: 0,
return_type: ty.clone(),
}),
BoundExpr::InputRef(BoundInputRef {
index: 1,
return_type: ty.clone(),
}),
BoundExpr::InputRef(BoundInputRef {
index: 2,
return_type: ty,
}),
];
let aggregate = LogicalAggregate::new(
vec![
BoundAggCall {
kind: AggKind::Sum,
args: vec![input_refs[0].clone()],
return_type: DataTypeKind::Int32.not_null(),
},
BoundAggCall {
kind: AggKind::Avg,
args: vec![input_refs[1].clone()],
return_type: DataTypeKind::Int32.not_null(),
},
],
vec![input_refs[2].clone()],
Arc::new(table_scan),
);
let mut required_cols = BitSet::new();
required_cols.insert(2);
let plan = aggregate.prune_col(required_cols);
let plan = plan.as_logical_projection().unwrap();
assert_eq!(plan.project_expressions(), vec![input_refs[1].clone()]);
let plan = plan.child();
let plan = plan.as_logical_aggregate().unwrap();
assert_eq!(
plan.agg_calls(),
vec![BoundAggCall {
kind: AggKind::Avg,
args: vec![input_refs[0].clone()],
return_type: DataTypeKind::Int32.not_null(),
}]
);
}
}