use std::{cmp::Ordering, collections::HashSet, sync::Arc};
use arrow_array::ArrayRef;
use arrow_schema::DataType;
use datafusion::{
catalog::TableProvider,
common::{ScalarValue, tree_node::Transformed},
datasource::{DefaultTableSource, provider_as_source},
error::{DataFusionError, Result as DfResult},
functions::core::expr_fn::{coalesce, greatest, least},
functions_aggregate::expr_fn::{count, max, min, sum},
logical_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, TableScan, lit},
optimizer::{OptimizerConfig, OptimizerRule, optimizer::ApplyOrder},
prelude::{cast, col},
};
use uuid::Uuid;
use crate::supertable::{
manifest::{SuperfileEntry, add_sum_arrays},
options::{DECIMAL128_PRECISION, DECIMAL128_SCALE},
query::provider::SupertableProvider,
};
#[derive(Debug, Default)]
pub(crate) struct CoveredAggregateRewrite;
impl OptimizerRule for CoveredAggregateRewrite {
fn name(&self) -> &str {
"covered_aggregate_rewrite"
}
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
fn supports_rewrite(&self) -> bool {
true
}
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> DfResult<Transformed<LogicalPlan>> {
match try_rewrite(&plan)? {
Some(rewritten) => Ok(Transformed::yes(rewritten)),
None => Ok(Transformed::no(plan)),
}
}
}
enum AggKind {
CountStar,
Sum(String),
Min(String),
Max(String),
Avg(String),
}
#[derive(Clone)]
struct Bound {
value: ScalarValue,
inclusive: bool,
}
struct RangeFilter {
column: String,
lo: Option<Bound>,
hi: Option<Bound>,
}
enum Class {
Disjoint,
Covered,
Boundary,
}
fn try_rewrite(plan: &LogicalPlan) -> DfResult<Option<LogicalPlan>> {
let LogicalPlan::Aggregate(agg) = plan else {
return Ok(None);
};
if !agg.group_expr.is_empty() {
return Ok(None);
}
let Some(kinds) = parse_aggregates(&agg.aggr_expr) else {
return Ok(None);
};
let (predicate, scan) = match peel_input(agg.input.as_ref()) {
Some(found) => found,
None => return Ok(None),
};
let Some(provider) = provider_of(scan) else {
return Ok(None);
};
if provider.is_segment_restricted() {
return Ok(None);
}
let Some(range) = extract_range(&predicate) else {
return Ok(None);
};
let manifest = provider.manifest();
if !manifest.is_in_process_only() {
return Ok(None);
}
let id_column = manifest.options.id_column.as_str();
let mut covered: Vec<&Arc<SuperfileEntry>> = Vec::new();
let mut boundary: HashSet<Uuid> = HashSet::new();
for entry in &manifest.superfiles {
let class = classify(entry, id_column, &range);
match class {
Class::Disjoint => {}
Class::Covered => {
if provider.entry_is_clean(entry) && has_required_stats(entry, &kinds) {
covered.push(entry);
} else {
boundary.insert(entry.superfile_id);
}
}
Class::Boundary => {
boundary.insert(entry.superfile_id);
}
}
}
if covered.is_empty() {
return Ok(None);
}
let mut partials: Vec<Partial> = Vec::with_capacity(kinds.len());
for kind in &kinds {
match accumulate_partial(kind, &covered) {
Some(partial) => partials.push(partial),
None => return Ok(None),
}
}
let restricted = Arc::new(provider.restricted_to(boundary));
let mut builder = LogicalPlanBuilder::scan(
scan.table_name.clone(),
provider_as_source(restricted as Arc<dyn TableProvider>),
None,
)?
.filter(predicate.clone())?;
let mut partial_exprs: Vec<Expr> = Vec::new();
for (i, kind) in kinds.iter().enumerate() {
match kind {
AggKind::CountStar => {
partial_exprs.push(count(lit(1i64)).alias(format!("__resid_{i}_cnt")));
}
AggKind::Sum(column) => {
partial_exprs.push(sum(col(column)).alias(format!("__resid_{i}_sum")));
}
AggKind::Min(column) => {
partial_exprs.push(min(col(column)).alias(format!("__resid_{i}_min")));
}
AggKind::Max(column) => {
partial_exprs.push(max(col(column)).alias(format!("__resid_{i}_max")));
}
AggKind::Avg(column) => {
partial_exprs.push(sum(col(column)).alias(format!("__resid_{i}_sum")));
partial_exprs.push(count(col(column)).alias(format!("__resid_{i}_cnt")));
}
}
}
builder = builder.aggregate(Vec::<Expr>::new(), partial_exprs)?;
let out_names: Vec<String> = agg
.schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect();
let mut final_exprs: Vec<Expr> = Vec::with_capacity(kinds.len());
for (i, (kind, partial)) in kinds.iter().zip(&partials).enumerate() {
let name = &out_names[i];
let expr = match (kind, partial) {
(AggKind::CountStar, Partial::Count(n)) => {
(col(format!("__resid_{i}_cnt")) + lit(*n)).alias(name)
}
(AggKind::Sum(_), Partial::Sum(value)) => {
let zero = typed_zero(value)?;
(coalesce(vec![col(format!("__resid_{i}_sum")), lit(zero)]) + lit(value.clone()))
.alias(name)
}
(AggKind::Min(_), Partial::Bound(value)) => {
least(vec![col(format!("__resid_{i}_min")), lit(value.clone())]).alias(name)
}
(AggKind::Max(_), Partial::Bound(value)) => {
greatest(vec![col(format!("__resid_{i}_max")), lit(value.clone())]).alias(name)
}
(AggKind::Avg(_), Partial::Avg { sum, count }) => {
let zero = typed_zero(sum)?;
let total_sum =
coalesce(vec![col(format!("__resid_{i}_sum")), lit(zero)]) + lit(sum.clone());
let total_cnt = col(format!("__resid_{i}_cnt")) + lit(*count);
(cast(total_sum, DataType::Float64) / cast(total_cnt, DataType::Float64))
.alias(name)
}
_ => return Ok(None),
};
final_exprs.push(expr);
}
let rewritten = builder.project(final_exprs)?.build()?;
Ok(Some(rewritten))
}
enum Partial {
Count(i64),
Sum(ScalarValue),
Bound(ScalarValue),
Avg { sum: ScalarValue, count: i64 },
}
fn accumulate_partial(kind: &AggKind, covered: &[&Arc<SuperfileEntry>]) -> Option<Partial> {
match kind {
AggKind::CountStar => {
let mut total: i64 = 0;
for entry in covered {
total = total.checked_add(i64::try_from(entry.n_docs).ok()?)?;
}
Some(Partial::Count(total))
}
AggKind::Sum(col) => Some(Partial::Sum(fold_sums(covered, col)?)),
AggKind::Min(col) => {
let (min, _) = fold_bounds(covered, col)?;
Some(Partial::Bound(min))
}
AggKind::Max(col) => {
let (_, max) = fold_bounds(covered, col)?;
Some(Partial::Bound(max))
}
AggKind::Avg(col) => {
let sum = fold_sums(covered, col)?;
let mut count: i64 = 0;
for entry in covered {
let nulls = entry.scalar_stats.get(col)?.null_count?;
let non_null = entry.n_docs.checked_sub(nulls)?;
count = count.checked_add(i64::try_from(non_null).ok()?)?;
}
if count == 0 {
return None;
}
Some(Partial::Avg { sum, count })
}
}
}
fn fold_sums(covered: &[&Arc<SuperfileEntry>], col: &str) -> Option<ScalarValue> {
let mut acc: Option<ArrayRef> = None;
for entry in covered {
let part = entry.scalar_stats.get(col)?.sum.as_ref()?;
acc = Some(match acc {
None => Arc::clone(part),
Some(total) => add_sum_arrays(&total, part)?,
});
}
ScalarValue::try_from_array(&acc?, 0).ok()
}
fn fold_bounds(covered: &[&Arc<SuperfileEntry>], col: &str) -> Option<(ScalarValue, ScalarValue)> {
let mut acc: Option<(ScalarValue, ScalarValue)> = None;
for entry in covered {
let agg = entry.scalar_stats.get(col)?;
let min = ScalarValue::try_from_array(&agg.min, 0).ok()?;
let max = ScalarValue::try_from_array(&agg.max, 0).ok()?;
if min.is_null() || max.is_null() {
return None;
}
acc = match acc {
None => Some((min, max)),
Some((cur_min, cur_max)) => {
let new_min = if min.partial_cmp(&cur_min)? == Ordering::Less {
min
} else {
cur_min
};
let new_max = if max.partial_cmp(&cur_max)? == Ordering::Greater {
max
} else {
cur_max
};
Some((new_min, new_max))
}
};
}
acc
}
fn typed_zero(value: &ScalarValue) -> DfResult<ScalarValue> {
match value {
ScalarValue::Int64(_) => Ok(ScalarValue::Int64(Some(0))),
ScalarValue::UInt64(_) => Ok(ScalarValue::UInt64(Some(0))),
ScalarValue::Float64(_) => Ok(ScalarValue::Float64(Some(0.0))),
other => Err(DataFusionError::Internal(format!(
"covered_agg: unexpected sum type {other:?}"
))),
}
}
fn parse_aggregates(exprs: &[Expr]) -> Option<Vec<AggKind>> {
let mut kinds = Vec::with_capacity(exprs.len());
for expr in exprs {
let inner = match expr {
Expr::Alias(alias) => alias.expr.as_ref(),
other => other,
};
let Expr::AggregateFunction(agg) = inner else {
return None;
};
let params = &agg.params;
if params.distinct || params.filter.is_some() || !params.order_by.is_empty() {
return None;
}
let kind = match (agg.func.name(), params.args.as_slice()) {
("count", []) => AggKind::CountStar,
("count", [Expr::Literal(value, _)]) if !value.is_null() => AggKind::CountStar,
("sum", [Expr::Column(c)]) => AggKind::Sum(c.name.clone()),
("min", [Expr::Column(c)]) => AggKind::Min(c.name.clone()),
("max", [Expr::Column(c)]) => AggKind::Max(c.name.clone()),
("avg", [arg]) => AggKind::Avg(avg_column(arg)?),
_ => return None,
};
kinds.push(kind);
}
Some(kinds)
}
fn avg_column(arg: &Expr) -> Option<String> {
match arg {
Expr::Column(c) => Some(c.name.clone()),
Expr::Cast(cast) => match cast.expr.as_ref() {
Expr::Column(c) => Some(c.name.clone()),
_ => None,
},
_ => None,
}
}
fn peel_input(input: &LogicalPlan) -> Option<(Expr, &TableScan)> {
let mut node = input;
if let LogicalPlan::Projection(projection) = node {
if !projection.expr.iter().all(|e| matches!(e, Expr::Column(_))) {
return None;
}
node = projection.input.as_ref();
}
let LogicalPlan::Filter(Filter {
predicate, input, ..
}) = node
else {
return None;
};
let LogicalPlan::TableScan(scan) = input.as_ref() else {
return None;
};
Some((predicate.clone(), scan))
}
fn provider_of(scan: &TableScan) -> Option<&SupertableProvider> {
scan.source
.as_any()
.downcast_ref::<DefaultTableSource>()?
.table_provider
.as_any()
.downcast_ref::<SupertableProvider>()
}
fn extract_range(predicate: &Expr) -> Option<RangeFilter> {
let mut leaves: Vec<(String, Operator, ScalarValue)> = Vec::new();
collect_range_leaves(predicate, &mut leaves)?;
let mut range: Option<RangeFilter> = None;
for (column, op, value) in leaves {
let entry = range.get_or_insert_with(|| RangeFilter {
column: column.clone(),
lo: None,
hi: None,
});
if entry.column != column {
return None;
}
let (slot, inclusive) = match op {
Operator::Gt => (&mut entry.lo, false),
Operator::GtEq => (&mut entry.lo, true),
Operator::Lt => (&mut entry.hi, false),
Operator::LtEq => (&mut entry.hi, true),
Operator::Eq => {
if entry.lo.is_some() || entry.hi.is_some() {
return None;
}
entry.lo = Some(Bound {
value: value.clone(),
inclusive: true,
});
entry.hi = Some(Bound {
value,
inclusive: true,
});
continue;
}
_ => return None,
};
if slot.is_some() {
return None;
}
*slot = Some(Bound { value, inclusive });
}
let range = range?;
(range.lo.is_some() || range.hi.is_some()).then_some(range)
}
fn collect_range_leaves(expr: &Expr, out: &mut Vec<(String, Operator, ScalarValue)>) -> Option<()> {
match expr {
Expr::BinaryExpr(binary) if binary.op == Operator::And => {
collect_range_leaves(&binary.left, out)?;
collect_range_leaves(&binary.right, out)
}
Expr::BinaryExpr(binary) => {
let (column, op, value) = match (binary.left.as_ref(), binary.right.as_ref()) {
(Expr::Column(c), Expr::Literal(v, _)) => (c.name.clone(), binary.op, v.clone()),
(Expr::Literal(v, _), Expr::Column(c)) => {
(c.name.clone(), binary.op.swap()?, v.clone())
}
_ => return None,
};
if value.is_null() {
return None;
}
out.push((column, op, value));
Some(())
}
Expr::Between(between) if !between.negated => {
let Expr::Column(c) = between.expr.as_ref() else {
return None;
};
let (Expr::Literal(lo, _), Expr::Literal(hi, _)) =
(between.low.as_ref(), between.high.as_ref())
else {
return None;
};
if lo.is_null() || hi.is_null() {
return None;
}
out.push((c.name.clone(), Operator::GtEq, lo.clone()));
out.push((c.name.clone(), Operator::LtEq, hi.clone()));
Some(())
}
_ => None,
}
}
fn classify(entry: &SuperfileEntry, id_column: &str, range: &RangeFilter) -> Class {
let bounds = if range.column == id_column {
Some((
ScalarValue::Decimal128(Some(entry.id_min), DECIMAL128_PRECISION, DECIMAL128_SCALE),
ScalarValue::Decimal128(Some(entry.id_max), DECIMAL128_PRECISION, DECIMAL128_SCALE),
))
} else {
entry.scalar_stats.get(&range.column).and_then(|agg| {
let mn = ScalarValue::try_from_array(&agg.min, 0).ok()?;
let mx = ScalarValue::try_from_array(&agg.max, 0).ok()?;
(!mn.is_null() && !mx.is_null()).then_some((mn, mx))
})
};
let Some((seg_min, seg_max)) = bounds else {
return Class::Boundary;
};
if let Some(lo) = &range.lo {
let cmp = seg_max.partial_cmp(&lo.value);
match (cmp, lo.inclusive) {
(Some(Ordering::Less), _) => return Class::Disjoint,
(Some(Ordering::Equal), false) => return Class::Disjoint,
(None, _) => return Class::Boundary,
_ => {}
}
}
if let Some(hi) = &range.hi {
let cmp = seg_min.partial_cmp(&hi.value);
match (cmp, hi.inclusive) {
(Some(Ordering::Greater), _) => return Class::Disjoint,
(Some(Ordering::Equal), false) => return Class::Disjoint,
(None, _) => return Class::Boundary,
_ => {}
}
}
let lo_ok = match &range.lo {
None => true,
Some(lo) => match (seg_min.partial_cmp(&lo.value), lo.inclusive) {
(Some(Ordering::Greater), _) => true,
(Some(Ordering::Equal), true) => true,
(None, _) => return Class::Boundary,
_ => false,
},
};
let hi_ok = match &range.hi {
None => true,
Some(hi) => match (seg_max.partial_cmp(&hi.value), hi.inclusive) {
(Some(Ordering::Less), _) => true,
(Some(Ordering::Equal), true) => true,
(None, _) => return Class::Boundary,
_ => false,
},
};
if lo_ok && hi_ok {
Class::Covered
} else {
Class::Boundary
}
}
fn has_required_stats(entry: &SuperfileEntry, kinds: &[AggKind]) -> bool {
kinds.iter().all(|kind| match kind {
AggKind::CountStar => true,
AggKind::Sum(col) => entry.scalar_stats.get(col).is_some_and(|a| a.sum.is_some()),
AggKind::Min(col) | AggKind::Max(col) => entry.scalar_stats.contains_key(col),
AggKind::Avg(col) => entry
.scalar_stats
.get(col)
.is_some_and(|a| a.sum.is_some() && a.null_count.is_some()),
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[allow(deprecated)] fn rule_opts_into_rewrite_and_reports_name() {
let rule = CoveredAggregateRewrite;
assert!(rule.supports_rewrite(), "rule must use the rewrite path");
assert_eq!(rule.name(), "covered_aggregate_rewrite");
}
}