use std::sync::Arc;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::config::ConfigOptions;
use datafusion::error::Result as DFResult;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
#[allow(deprecated)]
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::{
ExecutionPlan, ExecutionPlanProperties,
aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
coalesce_partitions::CoalescePartitionsExec,
projection::ProjectionExec,
repartition::RepartitionExec,
union::UnionExec,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::expressions::{Column, Literal};
use lance_index::scalar::expression::ScalarIndexExpr;
use log::warn;
use roaring::RoaringBitmap;
use super::count_from_mask::CountFromMaskExec;
use super::filtered_read::{FilteredReadExec, FilteredReadOptions};
use super::scalar_index::ScalarIndexExec;
/// Physical optimizer rule that looks for `AggregateExec` nodes computing an
/// ungrouped, unfiltered, non-distinct `count` over a non-null literal on top of
/// a `FilteredReadExec`, and replaces them with a plan built around
/// `CountFromMaskExec` (optionally unioned with a scan of the fragments the
/// scalar index does not cover).
///
/// The struct is stateless; all of the logic lives in its
/// `PhysicalOptimizerRule` implementation and the helper functions in this file.
#[derive(Debug)]
pub struct CountPushdown;
impl PhysicalOptimizerRule for CountPushdown {
    /// Walks the plan top-down and swaps in the replacement produced by
    /// `try_rewrite` for every `AggregateExec` it accepts.
    ///
    /// Nodes that are not aggregates, and aggregates that `try_rewrite`
    /// declines (`Ok(None)`), are returned unchanged. Any error raised by
    /// `try_rewrite` aborts the traversal and is propagated to the caller.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        let rewritten = plan.transform_down(|node| {
            // Only aggregate nodes are candidates; everything else yields no
            // replacement.
            let replacement = match node.as_any().downcast_ref::<AggregateExec>() {
                Some(agg) => try_rewrite(agg)?,
                None => None,
            };
            Ok(match replacement {
                Some(new_plan) => Transformed::yes(new_plan),
                None => Transformed::no(node),
            })
        })?;
        Ok(rewritten.data)
    }

    /// Identifier of this rule.
    fn name(&self) -> &str {
        "count_pushdown"
    }

    /// Always `true`: the rule asks the optimizer to keep its schema check
    /// enabled for the plans it returns.
    fn schema_check(&self) -> bool {
        true
    }
}
fn try_rewrite(agg: &AggregateExec) -> DFResult<Option<Arc<dyn ExecutionPlan>>> {
let mode = match agg.mode() {
AggregateMode::Single => AggregateMode::Single,
AggregateMode::Partial => AggregateMode::Partial,
_ => return Ok(None),
};
if !agg.group_expr().is_empty() {
return Ok(None);
}
if agg.aggr_expr().is_empty() {
return Ok(None);
}
for (af, filter) in agg.aggr_expr().iter().zip(agg.filter_expr().iter()) {
if !is_count_star(af) {
return Ok(None);
}
if filter.is_some() {
return Ok(None);
}
}
let Some(filtered_read) = strip_row_preserving_wrappers(agg.children()[0]) else {
return Ok(None);
};
if filtered_read.dataset().manifest().uses_stable_row_ids() {
warn!(
"count_pushdown: skipped because the dataset uses stable row ids; \
the count will be computed via a full scan. Reconciling the two id spaces \
would let this query be answered from index metadata."
);
return Ok(None);
}
let options = filtered_read.options();
if options.refine_filter.is_some() {
return Ok(None);
}
if options.full_filter.is_some() && filtered_read.index_input().is_none() {
return Ok(None);
}
if options.scan_range_before_filter.is_some() || options.scan_range_after_filter.is_some() {
return Ok(None);
}
if options.with_deleted_rows {
warn!(
"count_pushdown: skipped because the FilteredReadExec was built \
with with_deleted_rows; the count will be computed via a full \
scan."
);
return Ok(None);
}
if options.fragments.is_some() {
warn!(
"count_pushdown: skipped because the FilteredReadExec was scoped \
to an explicit fragment subset; the count will be computed via a \
full scan. Intersecting that subset into the coverage logic would \
let this query be answered from index metadata."
);
return Ok(None);
}
let dataset = filtered_read.dataset().clone();
let dataset_fragments: RoaringBitmap =
dataset.fragments().iter().map(|f| f.id as u32).collect();
let prefilter_input = filtered_read.index_input().cloned();
let index_coverage = match &prefilter_input {
None => None,
Some(input) => {
let scalar_exec = input
.as_any()
.downcast_ref::<ScalarIndexExec>()
.ok_or_else(|| {
datafusion::error::DataFusionError::Internal(
"count_pushdown: FilteredReadExec.index_input is not a ScalarIndexExec"
.to_string(),
)
})?;
if scalar_exec.expr().needs_recheck() {
return Ok(None);
}
let Some(coverage) = collect_coverage(scalar_exec.expr()) else {
return Ok(None);
};
Some(coverage)
}
};
let aggr_exprs: Vec<Arc<AggregateFunctionExpr>> = agg.aggr_expr().to_vec();
let (partial_stream, partial_state_schema): (Arc<dyn ExecutionPlan>, _) = match index_coverage {
None => {
let exec = CountFromMaskExec::try_new_restricted(
dataset,
aggr_exprs.clone(),
prefilter_input,
None,
)?;
let schema = exec.schema();
(Arc::new(exec), schema)
}
Some(coverage) if (&dataset_fragments - &coverage).is_empty() => {
let exec = CountFromMaskExec::try_new_restricted(
dataset,
aggr_exprs.clone(),
prefilter_input,
None,
)?;
let schema = exec.schema();
(Arc::new(exec), schema)
}
Some(coverage) => {
let uncovered = &dataset_fragments - &coverage;
let pushdown_exec = CountFromMaskExec::try_new_restricted(
dataset,
aggr_exprs.clone(),
prefilter_input,
Some(&dataset_fragments & &coverage),
)?;
let partial_state_schema = pushdown_exec.schema();
let pushdown_branch: Arc<dyn ExecutionPlan> = Arc::new(pushdown_exec);
let scan_branch =
build_scan_branch(filtered_read, options, &uncovered, aggr_exprs.clone())?;
let union: Arc<dyn ExecutionPlan> =
UnionExec::try_new(vec![pushdown_branch, scan_branch])?;
(union, partial_state_schema)
}
};
match mode {
AggregateMode::Partial => {
Ok(Some(partial_stream))
}
AggregateMode::Single => {
let final_input: Arc<dyn ExecutionPlan> =
if partial_stream.output_partitioning().partition_count() > 1 {
Arc::new(CoalescePartitionsExec::new(partial_stream))
} else {
partial_stream
};
let filters: Vec<Option<Arc<dyn datafusion::physical_expr::PhysicalExpr>>> =
(0..aggr_exprs.len()).map(|_| None).collect();
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
aggr_exprs,
filters,
final_input,
partial_state_schema,
)?;
Ok(Some(Arc::new(final_agg)))
}
_ => unreachable!("mode was checked at the top of try_rewrite"),
}
}
/// Builds the scan side of a split plan: a `Partial` aggregate over a new
/// [`FilteredReadExec`] limited to the fragments listed in `uncovered`.
///
/// * `filtered_read` - the original read; only its dataset is reused.
/// * `options` - the original read options; cloned, with `fragments`
///   overwritten by the uncovered subset.
/// * `uncovered` - ids of the dataset fragments to scan.
/// * `aggr_exprs` - the aggregate expressions to evaluate, each with no filter.
///
/// The new read is created without an index input.
///
/// # Errors
///
/// Propagates failures from `FilteredReadExec::try_new` and
/// `AggregateExec::try_new`.
fn build_scan_branch(
    filtered_read: &FilteredReadExec,
    options: &FilteredReadOptions,
    uncovered: &RoaringBitmap,
    aggr_exprs: Vec<Arc<AggregateFunctionExpr>>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
    let dataset = filtered_read.dataset().clone();

    // Keep only the manifest fragments whose id is in the uncovered set.
    let fragments_to_scan: Vec<_> = dataset
        .manifest()
        .fragments
        .iter()
        .filter(|fragment| uncovered.contains(fragment.id as u32))
        .cloned()
        .collect();

    let mut scan_options = options.clone();
    scan_options.fragments = Some(Arc::new(fragments_to_scan));

    let scan: Arc<dyn ExecutionPlan> =
        Arc::new(FilteredReadExec::try_new(dataset, scan_options, None)?);
    let input_schema = scan.schema();

    let no_filters: Vec<Option<Arc<dyn datafusion::physical_expr::PhysicalExpr>>> =
        vec![None; aggr_exprs.len()];

    let partial = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::default(),
        aggr_exprs,
        no_filters,
        scan,
        input_schema,
    )?;
    Ok(Arc::new(partial))
}
/// Looks through operators that do not change the number of rows and returns
/// the [`FilteredReadExec`] underneath them, if any.
///
/// The wrappers that are skipped are `RepartitionExec`, `CoalesceBatchesExec`,
/// `CoalescePartitionsExec`, and a `ProjectionExec` whose expressions are all
/// plain column references that keep the input column's name.
///
/// Returns `None` when:
///
/// * any other operator is encountered before a `FilteredReadExec`;
/// * a projection contains a non-column expression, a renamed column, or a
///   column index outside the input schema;
/// * a wrapper reports a `fetch` limit. Such a wrapper truncates its output,
///   so it is not row-preserving and looking through it could make the
///   pushed-down count larger than the count of the limited input.
fn strip_row_preserving_wrappers(plan: &Arc<dyn ExecutionPlan>) -> Option<&FilteredReadExec> {
    let mut current: &dyn ExecutionPlan = plan.as_ref();
    loop {
        if let Some(filtered_read) = current.as_any().downcast_ref::<FilteredReadExec>() {
            return Some(filtered_read);
        }

        // Coalescing operators can carry a row limit; one that does is no
        // longer a transparent wrapper.
        if current.fetch().is_some() {
            return None;
        }

        let next: &Arc<dyn ExecutionPlan> =
            if let Some(inner) = current.as_any().downcast_ref::<RepartitionExec>() {
                inner.input()
            } else if let Some(inner) = {
                #[allow(deprecated)]
                current.as_any().downcast_ref::<CoalesceBatchesExec>()
            } {
                inner.input()
            } else if let Some(inner) = current.as_any().downcast_ref::<CoalescePartitionsExec>() {
                inner.input()
            } else if let Some(proj) = current.as_any().downcast_ref::<ProjectionExec>() {
                let input_schema = proj.input().schema();
                // Checked lookup: an out-of-range column index makes the
                // projection "not identity" instead of panicking.
                let identity = proj.expr().iter().all(|projection_expr| {
                    projection_expr
                        .expr
                        .as_any()
                        .downcast_ref::<Column>()
                        .is_some_and(|c| {
                            input_schema
                                .fields()
                                .get(c.index())
                                .is_some_and(|field| field.name() == c.name())
                        })
                });
                if !identity {
                    return None;
                }
                proj.input()
            } else {
                return None;
            };
        current = next.as_ref();
    }
}
/// Computes the set of fragment ids for which every leaf query of `expr`
/// reports coverage.
///
/// * A leaf query contributes its own `fragment_bitmap` (which may be absent).
/// * `Not` has the coverage of its operand.
/// * `And` / `Or` are covered only where both operands are, i.e. the
///   intersection of the two sides.
///
/// Returns `None` as soon as any leaf has no fragment bitmap.
fn collect_coverage(expr: &ScalarIndexExpr) -> Option<RoaringBitmap> {
    match expr {
        ScalarIndexExpr::Query(search) => search.fragment_bitmap.clone(),
        ScalarIndexExpr::Not(inner) => collect_coverage(inner),
        ScalarIndexExpr::And(lhs, rhs) | ScalarIndexExpr::Or(lhs, rhs) => {
            let mut covered = collect_coverage(lhs)?;
            covered &= collect_coverage(rhs)?;
            Some(covered)
        }
    }
}
/// Returns `true` when `af` is a non-distinct aggregate named `"count"` whose
/// single argument is a non-null literal.
///
/// Counts over a column, over a null literal, with any other number of
/// arguments, or with `DISTINCT` are all rejected.
fn is_count_star(af: &Arc<AggregateFunctionExpr>) -> bool {
    if af.is_distinct() || af.fun().name() != "count" {
        return false;
    }
    match af.expressions().as_slice() {
        [only_arg] => only_arg
            .as_any()
            .downcast_ref::<Literal>()
            .is_some_and(|lit| !lit.value().is_null()),
        _ => false,
    }
}
// End-to-end tests for `CountPushdown`: each test builds a small on-disk
// dataset, plans a count through the scanner, and checks both the returned
// count and whether the resulting plan contains the pushdown operator.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::datatypes::{Int64Type, UInt64Type};
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion::physical_plan::{ExecutionPlan, displayable};
use futures::TryStreamExt;
use lance_core::utils::tempfile::TempStrDir;
use lance_datagen::gen_batch;
use lance_index::IndexType;
use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
use super::*;
use crate::Dataset;
use crate::dataset::scanner::AggregateExpr;
use crate::index::DatasetIndexExt;
use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount};
/// A dataset together with the temporary directory that backs it.
struct Fixture {
dataset: Arc<Dataset>,
// Held only so the directory is not dropped while the dataset is in use.
_tmp: TempStrDir,
}
/// Builds the default fixture: 4 fragments of 10 rows each with a single
/// `ordered` UInt64 step column, plus a BTree index on that column.
async fn make_fixture() -> Fixture {
let tmp = TempStrDir::default();
let mut dataset = gen_batch()
.col("ordered", lance_datagen::array::step::<UInt64Type>())
.into_dataset(
tmp.as_str(),
FragmentCount::from(4),
FragmentRowCount::from(10),
)
.await
.unwrap();
dataset
.create_index(
&["ordered"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
Fixture {
dataset: Arc::new(dataset),
_tmp: tmp,
}
}
/// Returns `true` if any node in `plan` is a `CountFromMaskExec`.
fn plan_contains_pushdown(plan: &Arc<dyn ExecutionPlan>) -> bool {
let mut found = false;
plan.apply(|node| {
if node.as_any().is::<CountFromMaskExec>() {
found = true;
// No need to keep walking once one match is found.
Ok(TreeNodeRecursion::Stop)
} else {
Ok(TreeNodeRecursion::Continue)
}
})
.unwrap();
found
}
/// Returns `true` if any node in `plan` is a `UnionExec`.
fn plan_contains_union(plan: &Arc<dyn ExecutionPlan>) -> bool {
let mut found = false;
plan.apply(|node| {
if node.as_any().is::<UnionExec>() {
found = true;
// No need to keep walking once one match is found.
Ok(TreeNodeRecursion::Stop)
} else {
Ok(TreeNodeRecursion::Continue)
}
})
.unwrap();
found
}
/// Configures `scanner` for a `count(*)` aggregate, creates and executes the
/// plan, and returns the plan together with the single Int64 count it
/// produced. Panics if the plan does not emit exactly one batch or the first
/// column is not Int64.
async fn run_count(
scanner: &mut crate::dataset::scanner::Scanner,
) -> (Arc<dyn ExecutionPlan>, i64) {
scanner
.aggregate(AggregateExpr::builder().count_star().build())
.unwrap();
let plan = scanner.create_plan().await.unwrap();
let stream = datafusion::physical_plan::execute_stream(
plan.clone(),
Arc::new(datafusion::execution::TaskContext::default()),
)
.unwrap();
let batches: Vec<_> = stream.try_collect().await.unwrap();
assert_eq!(
batches.len(),
1,
"count plan emitted {} batches",
batches.len()
);
let count = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow_array::PrimitiveArray<Int64Type>>()
.expect("count column should be Int64")
.value(0);
(plan, count)
}
// An unfiltered count over the default fixture (4 x 10 rows) must return 40,
// use the pushdown operator, and not need a union.
#[tokio::test]
async fn rule_fires_on_unfiltered_count_star() {
let fixture = make_fixture().await;
let mut scanner = fixture.dataset.scan();
let (plan, count) = run_count(&mut scanner).await;
assert_eq!(count, 40);
assert!(
plan_contains_pushdown(&plan),
"expected CountFromMaskExec in plan: {}",
displayable(plan.as_ref()).indent(true)
);
assert!(
!plan_contains_union(&plan),
"no union expected for unfiltered count, got: {}",
displayable(plan.as_ref()).indent(true)
);
}
// A filter on the indexed column, with the index built over all fragments,
// must be answered by the pushdown operator alone (no union).
#[tokio::test]
async fn rule_fires_when_filter_fully_indexed() {
let fixture = make_fixture().await;
let mut scanner = fixture.dataset.scan();
scanner.filter("ordered < 25").unwrap();
let (plan, count) = run_count(&mut scanner).await;
assert_eq!(count, 25);
assert!(
plan_contains_pushdown(&plan),
"expected CountFromMaskExec in plan: {}",
displayable(plan.as_ref()).indent(true)
);
assert!(
!plan_contains_union(&plan),
"no union expected when index covers every fragment, got: {}",
displayable(plan.as_ref()).indent(true)
);
}
// Appends 10 rows after the index was built, so the appended data is not in
// the index. The plan must then contain both the pushdown operator and a
// union, and the count must include the appended rows (40 + 10).
#[tokio::test]
async fn rule_emits_split_plan_for_partial_index_coverage() {
use crate::dataset::WriteParams;
let tmp = TempStrDir::default();
let mut dataset = gen_batch()
.col("ordered", lance_datagen::array::step::<UInt64Type>())
.into_dataset(
tmp.as_str(),
FragmentCount::from(4),
FragmentRowCount::from(10),
)
.await
.unwrap();
dataset
.create_index(
&["ordered"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
// Extra rows written after index creation.
let extra = gen_batch()
.col("ordered", lance_datagen::array::step::<UInt64Type>())
.into_reader_rows(
lance_datagen::RowCount::from(10),
lance_datagen::BatchCount::from(1),
);
let dataset = Dataset::write(
extra,
tmp.as_str(),
Some(WriteParams {
mode: crate::dataset::WriteMode::Append,
max_rows_per_file: 10,
..Default::default()
}),
)
.await
.unwrap();
let dataset = Arc::new(dataset);
let mut scanner = dataset.scan();
scanner.filter("ordered < 100").unwrap();
let (plan, count) = run_count(&mut scanner).await;
assert_eq!(count, 50);
assert!(
plan_contains_pushdown(&plan),
"expected pushdown branch in split plan: {}",
displayable(plan.as_ref()).indent(true)
);
assert!(
plan_contains_union(&plan),
"expected UnionExec for partial-coverage split, got: {}",
displayable(plan.as_ref()).indent(true)
);
}
// With stable row ids enabled the rule must decline; the count (20 rows minus
// one deleted row) must still be correct.
#[tokio::test]
async fn rule_skips_with_stable_row_ids() {
use crate::dataset::WriteParams;
let tmp = TempStrDir::default();
let mut dataset = gen_batch()
.col("ordered", lance_datagen::array::step::<UInt64Type>())
.into_dataset_with_params(
tmp.as_str(),
FragmentCount::from(2),
FragmentRowCount::from(10),
Some(WriteParams {
max_rows_per_file: 10,
enable_stable_row_ids: true,
..Default::default()
}),
)
.await
.unwrap();
dataset.delete("ordered = 0").await.unwrap();
let dataset = Arc::new(dataset);
let mut scanner = dataset.scan();
let (plan, count) = run_count(&mut scanner).await;
assert_eq!(count, 19);
assert!(
!plan_contains_pushdown(&plan),
"rule must not fire under stable row IDs, got plan: {}",
displayable(plan.as_ref()).indent(true)
);
}
// A filter on a column that has no index must not be pushed down; the count
// is still expected to be correct (34 of 40 rows satisfy `unindexed > 5`).
#[tokio::test]
async fn rule_skips_when_filter_needs_refine() {
let tmp = TempStrDir::default();
let mut dataset = gen_batch()
.col("ordered", lance_datagen::array::step::<UInt64Type>())
.col("unindexed", lance_datagen::array::step::<UInt64Type>())
.into_dataset(
tmp.as_str(),
FragmentCount::from(4),
FragmentRowCount::from(10),
)
.await
.unwrap();
// Only `ordered` gets an index; `unindexed` is left without one.
dataset
.create_index(
&["ordered"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
let dataset = Arc::new(dataset);
let mut scanner = dataset.scan();
scanner.filter("unindexed > 5").unwrap();
let (plan, count) = run_count(&mut scanner).await;
assert_eq!(count, 34);
assert!(
!plan_contains_pushdown(&plan),
"rule should not fire with non-indexed filter, got plan: {}",
displayable(plan.as_ref()).indent(true)
);
}
// Uses a ZoneMap index instead of a BTree. The test expects the rule to
// decline for this index type while the count stays correct.
#[tokio::test]
async fn rule_skips_when_index_is_inexact() {
let tmp = TempStrDir::default();
let mut dataset = gen_batch()
.col("ordered", lance_datagen::array::step::<UInt64Type>())
.into_dataset(
tmp.as_str(),
FragmentCount::from(4),
FragmentRowCount::from(10),
)
.await
.unwrap();
dataset
.create_index(
&["ordered"],
IndexType::ZoneMap,
None,
&ScalarIndexParams::for_builtin(BuiltinIndexType::ZoneMap),
true,
)
.await
.unwrap();
let dataset = Arc::new(dataset);
let mut scanner = dataset.scan();
scanner.filter("ordered < 25").unwrap();
let (plan, count) = run_count(&mut scanner).await;
assert_eq!(count, 25);
assert!(
!plan_contains_pushdown(&plan),
"rule must not fire when the index produces inexact (needs_recheck) results, \
got plan: {}",
displayable(plan.as_ref()).indent(true)
);
}
// A grouped count must not be rewritten. Only the plan shape is checked here;
// the plan is not executed.
#[tokio::test]
async fn rule_skips_count_with_group_by() {
let fixture = make_fixture().await;
let mut scanner = fixture.dataset.scan();
scanner
.aggregate(
AggregateExpr::builder()
.group_by("ordered")
.count_star()
.build(),
)
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert!(
!plan_contains_pushdown(&plan),
"rule should not fire for GROUP BY: {}",
displayable(plan.as_ref()).indent(true)
);
}
}