use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{Array, BinaryArray, Int64Array, RecordBatch};
use arrow_schema::{Schema, SchemaRef};
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
execution_plan::{Boundedness, EmissionType},
metrics::{ExecutionPlanMetricsSet, MetricsSet},
};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use futures::{StreamExt, TryStreamExt};
use lance_core::{Error, Result};
use lance_select::{RowAddrMask, RowAddrSelection, RowAddrTreeMap};
use lance_table::format::Fragment;
use roaring::RoaringBitmap;
use tracing::instrument;
use super::utils::InstrumentedRecordBatchStreamAdapter;
use crate::Dataset;
use crate::index::prefilter::DatasetPreFilter;
/// Execution plan node that derives a row count from fragment metadata, an
/// optional prefilter mask, and the dataset's deletion mask, and emits that
/// count as a single-row batch with one `Int64` column per aggregate.
#[derive(Debug)]
pub struct CountFromMaskExec {
/// Dataset whose fragments are counted.
dataset: Arc<Dataset>,
/// Aggregates whose state fields define the output schema; one output
/// column is produced per entry.
aggregate_funcs: Vec<Arc<AggregateFunctionExpr>>,
/// Optional child plan. When present, column 0 of its first batch is decoded
/// as a `RowAddrMask` and combined with the fragment allow-list.
prefilter_input: Option<Arc<dyn ExecutionPlan>>,
/// Optional set of fragment ids; when present it is intersected with the
/// ids of the dataset's fragments before counting.
restrict_to_fragments: Option<RoaringBitmap>,
/// Output schema, built from the aggregates' state fields.
schema: SchemaRef,
/// Plan properties computed once at construction time.
properties: Arc<PlanProperties>,
/// Execution metrics for this node.
metrics: ExecutionPlanMetricsSet,
}
/// Renders the node as the fixed label `CountFromMask` for every display format.
impl DisplayAs for CountFromMaskExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Every format shares one label; the variants are still listed
        // explicitly so the match stays exhaustive.
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => f.write_str("CountFromMask"),
        }
    }
}
impl CountFromMaskExec {
pub fn try_new(
dataset: Arc<Dataset>,
aggregate_funcs: Vec<Arc<AggregateFunctionExpr>>,
prefilter_input: Option<Arc<dyn ExecutionPlan>>,
) -> Result<Self> {
Self::try_new_restricted(dataset, aggregate_funcs, prefilter_input, None)
}
pub fn try_new_restricted(
dataset: Arc<Dataset>,
aggregate_funcs: Vec<Arc<AggregateFunctionExpr>>,
prefilter_input: Option<Arc<dyn ExecutionPlan>>,
restrict_to_fragments: Option<RoaringBitmap>,
) -> Result<Self> {
if aggregate_funcs.is_empty() {
return Err(Error::invalid_input(
"CountFromMaskExec requires at least one aggregate".to_string(),
));
}
let state_fields = aggregate_funcs
.iter()
.map(|agg| agg.state_fields())
.collect::<datafusion::error::Result<Vec<_>>>()
.map_err(|e| Error::invalid_input(e.to_string()))?
.into_iter()
.flatten()
.collect::<Vec<_>>();
let state_fields_owned: Vec<arrow_schema::Field> =
state_fields.iter().map(|f| f.as_ref().clone()).collect();
let schema: SchemaRef = Arc::new(Schema::new(state_fields_owned));
let properties = Arc::new(PlanProperties::new(
EquivalenceProperties::new(schema.clone()),
Partitioning::RoundRobinBatch(1),
EmissionType::Incremental,
Boundedness::Bounded,
));
Ok(Self {
dataset,
aggregate_funcs,
prefilter_input,
restrict_to_fragments,
schema,
properties,
metrics: ExecutionPlanMetricsSet::new(),
})
}
async fn load_prefilter(
prefilter_input: Arc<dyn ExecutionPlan>,
context: Arc<datafusion::execution::context::TaskContext>,
) -> Result<RowAddrMask> {
let mut stream = prefilter_input.execute(0, context).map_err(Error::from)?;
let batch = stream
.try_next()
.await
.map_err(Error::from)?
.ok_or_else(|| {
Error::internal(
"CountFromMaskExec: prefilter input produced no batches".to_string(),
)
})?;
while stream.try_next().await.map_err(Error::from)?.is_some() {}
let result_col = batch
.column(0)
.as_any()
.downcast_ref::<BinaryArray>()
.ok_or_else(|| {
Error::internal(format!(
"CountFromMaskExec: prefilter result column has type {:?}, expected Binary",
batch.column(0).data_type()
))
})?;
RowAddrMask::from_arrow(result_col)
}
/// ANDs the fragment allow-list with the optional prefilter mask and the
/// optional deletion mask, in that order.
///
/// A mask that is `None` is skipped. The deletion mask is cloned out of its
/// `Arc` because the `&` operator takes its operands by value.
fn combine_masks(
    fragments_allow: RowAddrTreeMap,
    prefilter: Option<RowAddrMask>,
    deletion_mask: Option<Arc<RowAddrMask>>,
) -> RowAddrMask {
    let mut combined = RowAddrMask::AllowList(fragments_allow);
    if let Some(prefilter) = prefilter {
        combined = combined & prefilter;
    }
    if let Some(deletions) = deletion_mask {
        combined = combined & (*deletions).clone();
    }
    combined
}
fn count_from_mask(mask: &RowAddrMask, dataset: &Dataset) -> Result<i64> {
let allow = mask.allow_list().ok_or_else(|| {
Error::internal("CountFromMaskExec: combined mask is not an AllowList".to_string())
})?;
let frag_map: HashMap<u32, &Fragment> = dataset
.fragments()
.iter()
.map(|f| (f.id as u32, f))
.collect();
let mut count = 0i64;
for (frag_id, sel) in allow.iter() {
match sel {
RowAddrSelection::Full => {
let frag = frag_map.get(frag_id).ok_or_else(|| {
Error::internal(format!(
"CountFromMaskExec: fragment {} not found in manifest",
frag_id
))
})?;
let n = frag.physical_rows.ok_or_else(|| {
Error::internal(format!(
"CountFromMaskExec: physical_rows missing for fragment {}",
frag_id
))
})?;
count += n as i64;
}
RowAddrSelection::Partial(bitmap) => {
count += bitmap.len() as i64;
}
}
}
Ok(count)
}
/// Computes the count and packages it as the node's single output batch.
///
/// Steps:
/// 1. Load the prefilter mask from `prefilter_input`, if there is one.
/// 2. Determine the covered fragments: all dataset fragments, intersected
///    with `restrict_to_fragments` when it is set.
/// 3. Build an allow-list holding row offsets `0..physical_rows` for each
///    covered fragment.
/// 4. AND the allow-list with the prefilter and the deletion mask, then count
///    the surviving rows.
///
/// The result is one row with `aggregate_funcs_len` `Int64` columns, each
/// holding the same count.
///
/// # Errors
///
/// Returns an error when the prefilter or deletion mask cannot be loaded,
/// when a fragment id or a fragment's `physical_rows` does not fit in 32 bits
/// (a plain `as u32` cast would silently truncate and miscount), when a
/// covered fragment has no `physical_rows` recorded, or when the output batch
/// does not match `schema`.
#[instrument(name = "count_from_mask", skip_all, level = "debug")]
async fn do_execute(
    dataset: Arc<Dataset>,
    aggregate_funcs_len: usize,
    prefilter_input: Option<Arc<dyn ExecutionPlan>>,
    restrict_to_fragments: Option<RoaringBitmap>,
    context: Arc<datafusion::execution::context::TaskContext>,
    schema: SchemaRef,
) -> Result<RecordBatch> {
    let prefilter = match prefilter_input {
        Some(input) => Some(Self::load_prefilter(input, context).await?),
        None => None,
    };

    // Index the manifest once, rejecting ids that cannot be narrowed to u32.
    let manifest_fragments = dataset.fragments();
    let mut frag_map: HashMap<u32, &Fragment> =
        HashMap::with_capacity(manifest_fragments.len());
    let mut dataset_fragments = RoaringBitmap::new();
    for frag in manifest_fragments.iter() {
        let frag_id = u32::try_from(frag.id).map_err(|_| {
            Error::internal(format!(
                "CountFromMaskExec: fragment id {} does not fit in 32 bits",
                frag.id
            ))
        })?;
        frag_map.insert(frag_id, frag);
        dataset_fragments.insert(frag_id);
    }

    let fragments_covered = match restrict_to_fragments {
        Some(restrict) => dataset_fragments & restrict,
        None => dataset_fragments,
    };

    // Use explicit per-fragment bitmaps rather than "full fragment" markers,
    // so the allow-list carries concrete row offsets into the mask AND below.
    let mut fragments_allow = RowAddrTreeMap::new();
    for frag_id in fragments_covered.iter() {
        let frag = frag_map.get(&frag_id).ok_or_else(|| {
            Error::internal(format!(
                "CountFromMaskExec: fragment {} not in manifest",
                frag_id
            ))
        })?;
        let physical = frag.physical_rows.ok_or_else(|| {
            Error::internal(format!(
                "CountFromMaskExec: physical_rows missing for fragment {}",
                frag_id
            ))
        })?;
        let row_count = u32::try_from(physical).map_err(|_| {
            Error::internal(format!(
                "CountFromMaskExec: fragment {} has {} physical rows, which does not fit in 32 bits",
                frag_id, physical
            ))
        })?;
        let mut bitmap = RoaringBitmap::new();
        bitmap.insert_range(0u32..row_count);
        fragments_allow.insert_bitmap(frag_id, bitmap);
    }

    let deletion_mask =
        match DatasetPreFilter::create_deletion_mask(dataset.clone(), fragments_covered) {
            Some(fut) => Some(fut.await?),
            None => None,
        };

    let combined = Self::combine_masks(fragments_allow, prefilter, deletion_mask);
    let count = Self::count_from_mask(&combined, dataset.as_ref())?;

    // Every aggregate receives the same count as its partial state.
    let arrays: Vec<Arc<dyn Array>> = (0..aggregate_funcs_len)
        .map(|_| Arc::new(Int64Array::from(vec![count])) as Arc<dyn Array>)
        .collect();
    Ok(RecordBatch::try_new(schema, arrays)?)
}
}
/// DataFusion plan integration: a node with at most one child (the prefilter)
/// that reports exactly one output row.
impl ExecutionPlan for CountFromMaskExec {
    fn name(&self) -> &str {
        "CountFromMaskExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// The prefilter plan is the only possible child.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.prefilter_input.iter().collect()
    }

    /// Rebuilds the node with zero or one child; more than one is an internal error.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        let child_count = children.len();
        if child_count > 1 {
            return Err(datafusion::error::DataFusionError::Internal(format!(
                "CountFromMaskExec accepts 0 or 1 children, got {}",
                child_count
            )));
        }
        // At most one element remains, so `next()` maps 0 -> None and 1 -> Some.
        let prefilter_input = children.into_iter().next();

        let rebuilt = Self {
            dataset: self.dataset.clone(),
            aggregate_funcs: self.aggregate_funcs.clone(),
            prefilter_input,
            restrict_to_fragments: self.restrict_to_fragments.clone(),
            schema: self.schema.clone(),
            properties: self.properties.clone(),
            metrics: self.metrics.clone(),
        };
        Ok(Arc::new(rebuilt))
    }

    /// Returns a stream that yields the single count batch once polled.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> datafusion::error::Result<datafusion::physical_plan::SendableRecordBatchStream> {
        let output_schema = Arc::clone(&self.schema);
        let batch_fut = Self::do_execute(
            self.dataset.clone(),
            self.aggregate_funcs.len(),
            self.prefilter_input.clone(),
            self.restrict_to_fragments.clone(),
            context,
            Arc::clone(&output_schema),
        );
        // The work is deferred: nothing runs until the stream is first polled.
        let stream =
            futures::stream::once(async move { batch_fut.await.map_err(|err| err.into()) })
                .boxed();
        let instrumented = InstrumentedRecordBatchStreamAdapter::new(
            output_schema,
            stream,
            partition,
            &self.metrics,
        );
        Ok(Box::pin(instrumented))
    }

    /// Reports exactly one output row; all other statistics are unknown.
    fn partition_statistics(
        &self,
        _partition: Option<usize>,
    ) -> datafusion::error::Result<datafusion::physical_plan::Statistics> {
        let mut stats = datafusion::physical_plan::Statistics::new_unknown(&self.schema);
        stats.num_rows = datafusion::common::stats::Precision::Exact(1);
        Ok(stats)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
use std::{ops::Bound, sync::Arc};
use arrow::datatypes::{Int64Type, UInt64Type};
use datafusion::common::DFSchema;
use datafusion::execution::TaskContext;
use datafusion::functions_aggregate;
use datafusion::logical_expr::lit;
use datafusion::physical_expr::execution_props::ExecutionProps;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion::scalar::ScalarValue;
use futures::TryStreamExt;
use lance_core::utils::tempfile::TempStrDir;
use lance_datagen::gen_batch;
use lance_index::IndexType;
use lance_index::scalar::{
SargableQuery, ScalarIndexParams,
expression::{ScalarIndexExpr, ScalarIndexSearch},
};
use lance_select::result::IndexExprResultWireFormat;
use lance_select::{RowAddrMask, RowAddrTreeMap};
use super::*;
use crate::Dataset;
use crate::index::DatasetIndexExt;
use crate::io::exec::scalar_index::ScalarIndexExec;
use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount};
fn count_star_expr(input_schema: &SchemaRef) -> Arc<AggregateFunctionExpr> {
let expr = functions_aggregate::count::count(lit(1));
let df_schema = DFSchema::try_from(input_schema.as_ref().clone()).unwrap();
let (agg_expr, _filter, _order_by) = create_aggregate_expr_and_maybe_filter(
&expr,
&df_schema,
input_schema.as_ref(),
&ExecutionProps::default(),
)
.unwrap();
agg_expr
}
/// Test fixture pairing a dataset with the temporary directory it was written to.
struct Fixture {
/// Dataset under test.
dataset: Arc<Dataset>,
/// Never read by the tests; presumably kept so the temporary directory
/// outlives the dataset — TODO confirm `TempStrDir` cleans up on drop.
_tmp: TempStrDir,
}
/// Creates a dataset of 4 fragments with 10 rows each, holding a single
/// `ordered` column, and builds a BTree index on that column.
async fn make_fixture() -> Fixture {
    let tmp = TempStrDir::default();

    let generator = gen_batch().col("ordered", lance_datagen::array::step::<UInt64Type>());
    let mut dataset = generator
        .into_dataset(
            tmp.as_str(),
            FragmentCount::from(4),
            FragmentRowCount::from(10),
        )
        .await
        .unwrap();

    let index_params = ScalarIndexParams::default();
    dataset
        .create_index(&["ordered"], IndexType::BTree, None, &index_params, true)
        .await
        .unwrap();

    Fixture {
        dataset: Arc::new(dataset),
        _tmp: tmp,
    }
}
/// Arrow schema with one non-nullable `UInt64` column named `ordered`.
fn input_schema() -> SchemaRef {
    let ordered = arrow_schema::Field::new("ordered", arrow_schema::DataType::UInt64, false);
    Arc::new(Schema::new(vec![ordered]))
}
/// Executes partition 0 of `plan`, checks that it yields exactly one single-row
/// batch, and returns the `Int64` value in the first column.
async fn run(plan: CountFromMaskExec) -> i64 {
    let task_ctx = Arc::new(TaskContext::default());
    let batches: Vec<RecordBatch> = plan
        .execute(0, task_ctx)
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    assert_eq!(batches.len(), 1);

    let batch = &batches[0];
    assert_eq!(batch.num_rows(), 1);

    let counts = batch
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::PrimitiveArray<Int64Type>>()
        .expect("count partial state should be Int64");
    counts.value(0)
}
/// Constructing the node without any aggregate must fail with a descriptive error.
#[tokio::test]
async fn try_new_rejects_empty_aggregate_funcs() {
    let fixture = make_fixture().await;
    let result = CountFromMaskExec::try_new(fixture.dataset, vec![], None);
    let err = result.unwrap_err();
    let message = err.to_string();
    assert!(message.contains("at least one aggregate"), "{err}");
}
/// A fully selected fragment counts its physical rows; a partially selected
/// fragment counts only the rows that were inserted.
#[tokio::test]
async fn count_from_mask_mixes_full_and_partial() {
    let fixture = make_fixture().await;

    let mut allowed = RowAddrTreeMap::new();
    // Fragment 0 is selected in full.
    allowed.insert_fragment(0);
    // Fragment 1 contributes offsets 0, 1 and 2; an address is
    // (fragment id << 32) | row offset.
    for offset in 0u64..3 {
        allowed.insert((1u64 << 32) | offset);
    }

    let mask = RowAddrMask::AllowList(allowed);
    let count = CountFromMaskExec::count_from_mask(&mask, fixture.dataset.as_ref()).unwrap();
    assert_eq!(count, 10 + 3);
}
/// Without a prefilter or deletions, every row of the fixture is counted.
#[tokio::test]
async fn execute_count_no_prefilter() {
    let fixture = make_fixture().await;
    let aggregates = vec![count_star_expr(&input_schema())];
    let plan = CountFromMaskExec::try_new(fixture.dataset.clone(), aggregates, None).unwrap();

    // The fixture has 4 fragments of 10 rows each.
    assert_eq!(run(plan).await, 40);
}
/// A scalar-index prefilter for `ordered < 25` limits the count to the matching rows.
#[tokio::test]
async fn execute_count_with_allow_list_prefilter() {
    let fixture = make_fixture().await;

    let below_25 = SargableQuery::Range(
        Bound::Unbounded,
        Bound::Excluded(ScalarValue::UInt64(Some(25))),
    );
    let search = ScalarIndexSearch {
        column: "ordered".to_string(),
        index_name: "ordered_idx".to_string(),
        index_type: "BTree".to_string(),
        query: Arc::new(below_25),
        needs_recheck: false,
        fragment_bitmap: None,
    };
    let prefilter: Arc<dyn ExecutionPlan> = Arc::new(ScalarIndexExec::new(
        fixture.dataset.clone(),
        ScalarIndexExpr::Query(search),
        IndexExprResultWireFormat::default(),
    ));

    let aggregates = vec![count_star_expr(&input_schema())];
    let plan =
        CountFromMaskExec::try_new(fixture.dataset.clone(), aggregates, Some(prefilter)).unwrap();

    assert_eq!(run(plan).await, 25);
}
/// Negating the `ordered < 25` index query counts only the rows it does not match.
#[tokio::test]
async fn execute_count_with_block_list_prefilter() {
    let fixture = make_fixture().await;

    let below_25 = SargableQuery::Range(
        Bound::Unbounded,
        Bound::Excluded(ScalarValue::UInt64(Some(25))),
    );
    let search = ScalarIndexSearch {
        column: "ordered".to_string(),
        index_name: "ordered_idx".to_string(),
        index_type: "BTree".to_string(),
        query: Arc::new(below_25),
        needs_recheck: false,
        fragment_bitmap: None,
    };
    let negated = ScalarIndexExpr::Not(Box::new(ScalarIndexExpr::Query(search)));
    let prefilter: Arc<dyn ExecutionPlan> = Arc::new(ScalarIndexExec::new(
        fixture.dataset.clone(),
        negated,
        IndexExprResultWireFormat::default(),
    ));

    let aggregates = vec![count_star_expr(&input_schema())];
    let plan =
        CountFromMaskExec::try_new(fixture.dataset.clone(), aggregates, Some(prefilter)).unwrap();

    // 40 rows in total, minus the 25 matched by the negated query.
    assert_eq!(run(plan).await, 15);
}
/// Rows removed by a delete must not be counted.
#[tokio::test]
async fn execute_count_respects_deletions() {
    let fixture = make_fixture().await;

    // Work on a clone of the dataset so the delete can take `&mut`.
    let mut dataset = (*fixture.dataset).clone();
    dataset.delete("ordered < 10").await.unwrap();

    let aggregates = vec![count_star_expr(&input_schema())];
    let plan = CountFromMaskExec::try_new(Arc::new(dataset), aggregates, None).unwrap();

    // 40 rows in total, minus the 10 deleted.
    assert_eq!(run(plan).await, 30);
}
}