use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::types::{Float64Type, Int64Type};
use arrow_array::{
FixedSizeListArray, Float32Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion_substrait::substrait::proto::{
AggregateFunction, AggregateRel, Expression, FunctionArgument, Plan, PlanRel, Rel, RelRoot,
SortField, Version,
aggregate_function::AggregationInvocation,
aggregate_rel::{Grouping, Measure},
expression::{
FieldReference, ReferenceSegment, RexType,
field_reference::{ReferenceType, RootReference, RootType},
reference_segment::{self, StructField},
},
extensions::{
SimpleExtensionDeclaration, SimpleExtensionUri,
simple_extension_declaration::{ExtensionFunction, MappingType},
},
function_argument::ArgType,
rel::RelType,
sort_field::SortKind,
};
use futures::TryStreamExt;
use lance_datafusion::exec::{LanceExecutionOptions, execute_plan};
use lance_datagen::{array, gen_batch};
use lance_table::format::Fragment;
use prost::Message;
use tempfile::tempdir;
use crate::Dataset;
use crate::dataset::scanner::AggregateExpr;
use crate::index::vector::VectorIndexParams;
use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount, assert_plan_node_equals};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::InvertedIndexParams;
use lance_index::{DatasetIndexExt, IndexType};
use lance_linalg::distance::MetricType;
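
/// Builds a Substrait direct field reference expression pointing at the
/// column with ordinal `field_index` in the scan output.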
fn field_ref(field_index: i32) -> Expression {
Expression {
rex_type: Some(RexType::Selection(Box::new(FieldReference {
reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
StructField {
field: field_index,
child: None,
},
))),
})),
root_type: Some(RootType::RootReference(RootReference {})),
}))),
}
}
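
/// Encodes a Substrait `Plan` containing a single `AggregateRel` with the
/// given measures, grouping expressions, groupings, function extension
/// declarations, and output column names, returning the serialized protobuf bytes.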
fn create_aggregate_rel(
measures: Vec<Measure>,
grouping_expressions: Vec<Expression>,
groupings: Vec<Grouping>,
extensions: Vec<SimpleExtensionDeclaration>,
output_names: Vec<String>,
) -> Vec<u8> {
let aggregate_rel = AggregateRel {
common: None,
        input: None,
        groupings,
measures,
grouping_expressions,
advanced_extension: None,
};
let rel = Rel {
rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
};
let plan = Plan {
version: Some(Version {
major_number: 0,
minor_number: 63,
patch_number: 0,
git_hash: String::new(),
producer: "lance-test".to_string(),
}),
#[allow(deprecated)]
extension_uris: vec![
SimpleExtensionUri {
extension_uri_anchor: 1,
uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml".to_string(),
},
SimpleExtensionUri {
extension_uri_anchor: 2,
uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml".to_string(),
},
],
extensions,
relations: vec![PlanRel {
rel_type: Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Root(
RelRoot {
input: Some(rel),
names: output_names,
},
)),
}],
advanced_extensions: None,
expected_type_urls: vec![],
extension_urns: vec![],
parameter_bindings: vec![],
type_aliases: vec![],
};
plan.encode_to_vec()
}
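
/// Declares an aggregate function extension (e.g. "count", "sum") under the
/// given function anchor so measures can reference it by number.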
fn agg_extension(anchor: u32, name: &str) -> SimpleExtensionDeclaration {
SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
#[allow(deprecated)]
extension_uri_reference: 1,
extension_urn_reference: 0,
function_anchor: anchor,
name: name.to_string(),
})),
}
}
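
/// A `count(*)` measure: an aggregate function with no arguments.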
fn count_star_measure(function_ref: u32) -> Measure {
Measure {
measure: Some(AggregateFunction {
function_reference: function_ref,
            arguments: vec![],
            options: vec![],
output_type: None,
phase: 0,
sorts: vec![],
invocation: AggregationInvocation::All as i32,
#[allow(deprecated)]
args: vec![],
}),
filter: None,
}
}
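
/// A single-argument aggregate measure (sum/min/max/avg/...) over the column
/// at `column_index`.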
fn simple_agg_measure(function_ref: u32, column_index: i32) -> Measure {
Measure {
measure: Some(AggregateFunction {
function_reference: function_ref,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(field_ref(column_index))),
}],
options: vec![],
output_type: None,
phase: 0,
sorts: vec![],
invocation: AggregationInvocation::All as i32,
#[allow(deprecated)]
args: vec![],
}),
filter: None,
}
}
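
/// An aggregate measure with a sort clause, e.g. `first_value(x ORDER BY y)`.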
fn ordered_agg_measure(
function_ref: u32,
column_index: i32,
sort_column_index: i32,
ascending: bool,
) -> Measure {
use datafusion_substrait::substrait::proto::sort_field::SortDirection;
let sort_direction = if ascending {
SortDirection::AscNullsLast
} else {
SortDirection::DescNullsLast
};
Measure {
measure: Some(AggregateFunction {
function_reference: function_ref,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(field_ref(column_index))),
}],
options: vec![],
output_type: None,
phase: 0,
sorts: vec![SortField {
expr: Some(field_ref(sort_column_index)),
sort_kind: Some(SortKind::Direction(sort_direction as i32)),
}],
invocation: AggregationInvocation::All as i32,
#[allow(deprecated)]
args: vec![],
}),
filter: None,
}
}
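
/// Runs the given Substrait aggregate against a full dataset scan and
/// collects the resulting record batches.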
async fn execute_aggregate(
dataset: &Dataset,
aggregate_bytes: &[u8],
) -> crate::Result<Vec<RecordBatch>> {
let mut scanner = dataset.scan();
scanner.aggregate(AggregateExpr::substrait(aggregate_bytes))?;
let plan = scanner.create_plan().await?;
let stream = execute_plan(plan, LanceExecutionOptions::default())?;
stream.try_collect().await.map_err(|e| e.into())
}
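
/// Same as `execute_aggregate`, but restricts the scan to the given fragments.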
async fn execute_aggregate_on_fragments(
dataset: &Dataset,
aggregate_bytes: &[u8],
fragments: Vec<Fragment>,
) -> crate::Result<Vec<RecordBatch>> {
let mut scanner = dataset.scan();
scanner.with_fragments(fragments);
scanner.aggregate(AggregateExpr::substrait(aggregate_bytes))?;
let plan = scanner.create_plan().await?;
let stream = execute_plan(plan, LanceExecutionOptions::default())?;
stream.try_collect().await.map_err(|e| e.into())
}
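
/// Generates a dataset with columns `x` (0, 1, 2, ...), `y` (0, 2, 4, ...),
/// and `category` (cycling 1, 2, 3).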
async fn create_numeric_dataset(uri: &str, num_fragments: u32, rows_per_fragment: u32) -> Dataset {
gen_batch()
.col("x", array::step::<Int64Type>())
.col("y", array::step_custom::<Int64Type>(0, 2))
.col("category", array::cycle::<Int64Type>(vec![1, 2, 3]))
.into_dataset(
uri,
FragmentCount::from(num_fragments),
FragmentRowCount::from(rows_per_fragment),
)
.await
.unwrap()
}
#[tokio::test]
async fn test_count_star_single_fragment() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 100).await;
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
vec![],
vec![],
vec![agg_extension(1, "count")],
vec![],
);
let mut scanner = ds.scan();
scanner
.aggregate(AggregateExpr::substrait(agg_bytes.clone()))
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[], aggr=[count(...)]
LanceRead: uri=..., projection=[], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--",
)
.await
.unwrap();
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 100);
}
#[tokio::test]
async fn test_count_star_multiple_fragments() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 5, 100).await;
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
vec![],
vec![],
vec![agg_extension(1, "count")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 500);
}
#[tokio::test]
async fn test_count_star_subset_fragments() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 5, 100).await;
let all_fragments = ds.get_fragments();
let subset: Vec<Fragment> = all_fragments
.into_iter()
.take(2)
.map(|f| f.metadata)
.collect();
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
vec![],
vec![],
vec![agg_extension(1, "count")],
vec![],
);
let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset)
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 200);
}
#[tokio::test]
async fn test_sum_single_fragment() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 100).await;
let agg_bytes = create_aggregate_rel(
        vec![simple_agg_measure(1, 0)],
        vec![],
vec![],
vec![agg_extension(1, "sum")],
vec![],
);
let mut scanner = ds.scan();
scanner
.aggregate(AggregateExpr::substrait(agg_bytes.clone()))
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[], aggr=[sum(...)]
LanceRead: uri=..., projection=[x], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--",
)
.await
.unwrap();
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 4950);
}
#[tokio::test]
async fn test_sum_multiple_fragments() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 4, 25).await;
let agg_bytes = create_aggregate_rel(
vec![simple_agg_measure(1, 0)],
vec![],
vec![],
vec![agg_extension(1, "sum")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 4950);
}
#[tokio::test]
async fn test_min_max() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 4, 25).await;
let agg_bytes = create_aggregate_rel(
vec![
            simple_agg_measure(1, 0),
            simple_agg_measure(2, 0),
        ],
vec![],
vec![],
vec![agg_extension(1, "min"), agg_extension(2, "max")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 0);
assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 99);
}
#[tokio::test]
async fn test_avg() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 4, 25).await;
let agg_bytes = create_aggregate_rel(
vec![simple_agg_measure(1, 0)],
vec![],
vec![],
vec![agg_extension(1, "avg")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
let avg = batch.column(0).as_primitive::<Float64Type>().value(0);
assert!((avg - 49.5).abs() < 0.001);
}
#[tokio::test]
async fn test_multiple_aggregates() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 4, 25).await;
let agg_bytes = create_aggregate_rel(
vec![
count_star_measure(1),
            simple_agg_measure(2, 0),
            simple_agg_measure(3, 0),
            simple_agg_measure(4, 0),
            simple_agg_measure(5, 0),
        ],
vec![],
vec![],
vec![
agg_extension(1, "count"),
agg_extension(2, "sum"),
agg_extension(3, "min"),
agg_extension(4, "max"),
agg_extension(5, "avg"),
],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.num_columns(), 5);
    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 100);
    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 4950);
    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 0);
    assert_eq!(batch.column(3).as_primitive::<Int64Type>().value(0), 99);
    let avg = batch.column(4).as_primitive::<Float64Type>().value(0);
    assert!((avg - 49.5).abs() < 0.001);
}
#[tokio::test]
async fn test_group_by_with_count() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 4, 30).await;
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
        vec![field_ref(2)],
        vec![Grouping {
            #[allow(deprecated)]
            grouping_expressions: vec![],
            expression_references: vec![0],
        }],
vec![agg_extension(1, "count")],
vec![],
);
let mut scanner = ds.scan();
scanner
.aggregate(AggregateExpr::substrait(agg_bytes.clone()))
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[category@0 as category], aggr=[count(...)]
LanceRead: uri=..., projection=[category], num_fragments=4, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--",
)
.await
.unwrap();
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert!(!results.is_empty());
let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
assert_eq!(batch.num_rows(), 3);
let counts: Vec<i64> = batch
        .column(1)
        .as_primitive::<Int64Type>()
.values()
.to_vec();
for count in counts {
assert_eq!(count, 40);
}
}
#[tokio::test]
async fn test_group_by_with_sum() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 9).await;
let agg_bytes = create_aggregate_rel(
        vec![simple_agg_measure(1, 0)],
        vec![field_ref(2)],
        vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "sum")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert!(!results.is_empty());
let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
assert_eq!(batch.num_rows(), 3);
let categories: Vec<i64> = batch
        .column(0)
        .as_primitive::<Int64Type>()
.values()
.to_vec();
let sums: Vec<i64> = batch
        .column(1)
        .as_primitive::<Int64Type>()
.values()
.to_vec();
let mut results_map = std::collections::HashMap::new();
for (cat, sum) in categories.iter().zip(sums.iter()) {
results_map.insert(*cat, *sum);
}
assert_eq!(results_map.get(&1), Some(&9));
assert_eq!(results_map.get(&2), Some(&12));
assert_eq!(results_map.get(&3), Some(&15));
}
#[tokio::test]
async fn test_aggregate_specific_fragments() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 10, 10).await;
let all_fragments = ds.get_fragments();
let subset: Vec<Fragment> = all_fragments
.into_iter()
.enumerate()
.filter(|(i, _)| *i == 3 || *i == 5 || *i == 7)
.map(|(_, f)| f.metadata)
.collect();
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
vec![],
vec![],
vec![agg_extension(1, "count")],
vec![],
);
let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset)
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 30);
}
#[tokio::test]
async fn test_sum_specific_fragments() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 4, 10).await;
let all_fragments = ds.get_fragments();
let subset: Vec<Fragment> = all_fragments
.into_iter()
.enumerate()
.filter(|(i, _)| *i == 1 || *i == 2)
.map(|(_, f)| f.metadata)
.collect();
let agg_bytes = create_aggregate_rel(
        vec![simple_agg_measure(1, 0)],
        vec![],
vec![],
vec![agg_extension(1, "sum")],
vec![],
);
let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset)
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 390);
}
#[tokio::test]
async fn test_aggregate_with_filter() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 100).await;
let mut scanner = ds.scan();
scanner.filter("x >= 50").unwrap();
let agg_bytes = create_aggregate_rel(
vec![
count_star_measure(1),
            simple_agg_measure(2, 0),
            simple_agg_measure(3, 0),
            simple_agg_measure(4, 0),
        ],
vec![],
vec![],
vec![
agg_extension(1, "count"),
agg_extension(2, "sum"),
agg_extension(3, "min"),
agg_extension(4, "max"),
],
vec![],
);
scanner
.aggregate(AggregateExpr::substrait(agg_bytes))
.unwrap();
let plan = scanner.create_plan().await.unwrap();
let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
let results: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 50);
    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 3725);
    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 50);
    assert_eq!(batch.column(3).as_primitive::<Int64Type>().value(0), 99);
}
#[tokio::test]
async fn test_aggregate_empty_result() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 100).await;
let mut scanner = ds.scan();
scanner.project::<&str>(&[]).unwrap();
scanner.with_row_id();
scanner.filter("x > 1000").unwrap();
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
vec![],
vec![],
vec![agg_extension(1, "count")],
vec![],
);
scanner
.aggregate(AggregateExpr::substrait(agg_bytes))
.unwrap();
let plan = scanner.create_plan().await.unwrap();
let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
let results: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 0);
}
#[tokio::test]
async fn test_aggregate_single_row() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"x",
DataType::Int64,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(arrow_array::Int64Array::from(vec![42]))],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let ds = Dataset::write(reader, uri, None).await.unwrap();
let agg_bytes = create_aggregate_rel(
vec![
count_star_measure(1),
            simple_agg_measure(2, 0),
            simple_agg_measure(3, 0),
            simple_agg_measure(4, 0),
        ],
vec![],
vec![],
vec![
agg_extension(1, "count"),
agg_extension(2, "sum"),
agg_extension(3, "min"),
agg_extension(4, "max"),
],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 1);
    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 42);
    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 42);
    assert_eq!(batch.column(3).as_primitive::<Int64Type>().value(0), 42);
}
#[tokio::test]
async fn test_aggregate_with_aliases() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 100).await;
let agg_bytes = create_aggregate_rel(
vec![
count_star_measure(1),
simple_agg_measure(2, 0),
simple_agg_measure(3, 0),
],
vec![],
vec![],
vec![
agg_extension(1, "count"),
agg_extension(2, "sum"),
agg_extension(3, "min"),
],
vec![
"total_count".to_string(),
"sum_of_x".to_string(),
"min_x".to_string(),
],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
let schema = batch.schema();
assert_eq!(schema.fields().len(), 3);
assert_eq!(schema.field(0).name(), "total_count");
assert_eq!(schema.field(1).name(), "sum_of_x");
assert_eq!(schema.field(2).name(), "min_x");
assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 100);
assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 4950);
assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 0);
}
#[tokio::test]
async fn test_group_by_with_aliases() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 9).await;
let agg_bytes = create_aggregate_rel(
vec![simple_agg_measure(1, 0)],
vec![field_ref(2)],
vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "sum")],
vec!["group_key".to_string(), "total_sum".to_string()],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert!(!results.is_empty());
let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
let schema = batch.schema();
assert_eq!(schema.fields().len(), 2);
assert_eq!(schema.field(0).name(), "group_key");
assert_eq!(schema.field(1).name(), "total_sum");
}
#[tokio::test]
async fn test_first_value_with_order_by() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 9).await;
let agg_bytes = create_aggregate_rel(
        vec![ordered_agg_measure(1, 0, 0, true)],
        vec![field_ref(2)],
        vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "first_value")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert!(!results.is_empty());
let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
assert_eq!(batch.num_rows(), 3);
let categories: Vec<i64> = batch
.column(0)
.as_primitive::<Int64Type>()
.values()
.to_vec();
let first_values: Vec<i64> = batch
.column(1)
.as_primitive::<Int64Type>()
.values()
.to_vec();
let mut results_map = std::collections::HashMap::new();
for (cat, val) in categories.iter().zip(first_values.iter()) {
results_map.insert(*cat, *val);
}
assert_eq!(results_map.get(&1), Some(&0));
assert_eq!(results_map.get(&2), Some(&1));
assert_eq!(results_map.get(&3), Some(&2));
}
#[tokio::test]
async fn test_first_value_with_order_by_desc() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let ds = create_numeric_dataset(uri, 1, 9).await;
let agg_bytes = create_aggregate_rel(
        vec![ordered_agg_measure(1, 0, 0, false)],
        vec![field_ref(2)],
        vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "first_value")],
vec![],
);
let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
assert!(!results.is_empty());
let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
assert_eq!(batch.num_rows(), 3);
let categories: Vec<i64> = batch
.column(0)
.as_primitive::<Int64Type>()
.values()
.to_vec();
let first_values: Vec<i64> = batch
.column(1)
.as_primitive::<Int64Type>()
.values()
.to_vec();
let mut results_map = std::collections::HashMap::new();
for (cat, val) in categories.iter().zip(first_values.iter()) {
results_map.insert(*cat, *val);
}
assert_eq!(results_map.get(&1), Some(&6));
assert_eq!(results_map.get(&2), Some(&7));
assert_eq!(results_map.get(&3), Some(&8));
}
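
/// Creates a dataset with an `id` column, a 4-dimensional `vec` column, a
/// `text` column ("document {i}"), and a string `category` column cycling
/// through category_a/b/c, for the vector-search and FTS aggregation tests.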
async fn create_vector_text_dataset(uri: &str, num_rows: i64) -> Dataset {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new(
"vec",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
true,
),
Field::new("text", DataType::Utf8, false),
Field::new("category", DataType::Utf8, false),
]));
let ids: Vec<i64> = (0..num_rows).collect();
let vectors: Vec<f32> = (0..num_rows).flat_map(|i| vec![i as f32; 4]).collect();
let texts: Vec<String> = (0..num_rows).map(|i| format!("document {}", i)).collect();
let categories: Vec<String> = (0..num_rows)
.map(|i| match i % 3 {
0 => "category_a".to_string(),
1 => "category_b".to_string(),
_ => "category_c".to_string(),
})
.collect();
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(ids)),
Arc::new(
FixedSizeListArray::try_new_from_values(Float32Array::from(vectors), 4).unwrap(),
),
Arc::new(StringArray::from(texts)),
Arc::new(StringArray::from(categories)),
],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
Dataset::write(reader, uri, None).await.unwrap()
}
#[tokio::test]
async fn test_vector_search_with_aggregate() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let mut dataset = create_vector_text_dataset(uri, 100).await;
let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
dataset
.create_index(&["vec"], IndexType::Vector, None, ¶ms, true)
.await
.unwrap();
let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
        vec![field_ref(3)],
        vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "count")],
vec!["category".to_string(), "count".to_string()],
);
let mut scanner = dataset.scan();
scanner
.nearest("vec", &query_vector, 30)
.unwrap()
.project(&["id", "category"])
.unwrap()
.aggregate(AggregateExpr::substrait(agg_bytes))
.unwrap();
let results = scanner.try_into_batch().await.unwrap();
assert!(
results.num_rows() >= 1 && results.num_rows() <= 3,
"Expected 1-3 rows but got {}",
results.num_rows()
);
let counts: Vec<i64> = results
.column(1)
.as_primitive::<Int64Type>()
.values()
.to_vec();
let total: i64 = counts.iter().sum();
assert_eq!(total, 30);
}
#[tokio::test]
async fn test_fts_with_aggregate() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let mut dataset = create_vector_text_dataset(uri, 100).await;
dataset
.create_index(
&["text"],
IndexType::Inverted,
None,
&InvertedIndexParams::default(),
true,
)
.await
.unwrap();
let agg_bytes = create_aggregate_rel(
vec![count_star_measure(1)],
        vec![field_ref(3)],
        vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "count")],
vec!["category".to_string(), "count".to_string()],
);
let mut scanner = dataset.scan();
scanner
.full_text_search(FullTextSearchQuery::new("document".to_string()))
.unwrap()
.project(&["id", "category"])
.unwrap()
.aggregate(AggregateExpr::substrait(agg_bytes))
.unwrap();
let results = scanner.try_into_batch().await.unwrap();
assert_eq!(
results.num_rows(),
3,
"Expected 3 rows but got {}",
results.num_rows()
);
let counts: Vec<i64> = results
.column(1)
.as_primitive::<Int64Type>()
.values()
.to_vec();
let total: i64 = counts.iter().sum();
assert_eq!(total, 100);
for count in &counts {
assert!(*count >= 33 && *count <= 34);
}
}
#[tokio::test]
async fn test_vector_search_with_sum_aggregate() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let mut dataset = create_vector_text_dataset(uri, 100).await;
let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
dataset
.create_index(&["vec"], IndexType::Vector, None, ¶ms, true)
.await
.unwrap();
let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
let agg_bytes = create_aggregate_rel(
        vec![simple_agg_measure(1, 0)],
        vec![field_ref(3)],
        vec![Grouping {
#[allow(deprecated)]
grouping_expressions: vec![],
expression_references: vec![0],
}],
vec![agg_extension(1, "sum")],
vec!["category".to_string(), "sum_id".to_string()],
);
let mut scanner = dataset.scan();
scanner
.nearest("vec", &query_vector, 10)
.unwrap()
.project(&["id", "category"])
.unwrap()
.aggregate(AggregateExpr::substrait(agg_bytes))
.unwrap();
let results = scanner.try_into_batch().await.unwrap();
assert!(
results.num_rows() >= 1 && results.num_rows() <= 3,
"Expected 1-3 rows but got {}",
results.num_rows()
);
assert_eq!(results.num_columns(), 2);
}
#[tokio::test]
async fn test_scanner_count_rows() {
let ds = create_numeric_dataset("memory://test_count_rows", 2, 50).await;
let mut scanner = ds.scan();
scanner
.aggregate(AggregateExpr::builder().count_star().build())
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan.clone(),
"AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))]
LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--",
)
.await
.unwrap();
let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(
batches[0].column(0).as_primitive::<Int64Type>().value(0),
        100
    );
}
#[tokio::test]
async fn test_scanner_count_rows_with_filter() {
let ds = create_numeric_dataset("memory://test_count_rows_filter", 1, 100).await;
let mut scanner = ds.scan();
scanner.filter("x >= 50").unwrap();
scanner
.aggregate(AggregateExpr::builder().count_star().build())
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan.clone(),
"AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))]
LanceRead: uri=..., projection=[x], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=x >= Int64(50), refine_filter=x >= Int64(50)",
)
.await
.unwrap();
let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(
batches[0].column(0).as_primitive::<Int64Type>().value(0),
50
);
}
#[tokio::test]
async fn test_scanner_count_rows_empty_result() {
let ds = create_numeric_dataset("memory://test_count_rows_empty", 1, 100).await;
let mut scanner = ds.scan();
scanner.filter("x > 1000").unwrap(); let count = scanner.count_rows().await.unwrap();
assert_eq!(count, 0);
}
#[tokio::test]
async fn test_scanner_count_rows_with_vector_search() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let mut dataset = create_vector_text_dataset(uri, 100).await;
let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
dataset
.create_index(&["vec"], IndexType::Vector, None, ¶ms, true)
.await
.unwrap();
let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
let mut scanner = dataset.scan();
scanner.nearest("vec", &query_vector, 30).unwrap();
scanner
.aggregate(AggregateExpr::builder().count_star().build())
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan.clone(),
"AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))]
SortExec: TopK(fetch=30), ...
ANNSubIndex: ...
ANNIvfPartition: ...deltas=1",
)
.await
.unwrap();
let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(
batches[0].column(0).as_primitive::<Int64Type>().value(0),
        30
    );
}
#[tokio::test]
async fn test_scanner_count_rows_with_fts() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let mut dataset = create_vector_text_dataset(uri, 100).await;
dataset
.create_index(
&["text"],
IndexType::Inverted,
None,
&InvertedIndexParams::default(),
true,
)
.await
.unwrap();
let mut scanner = dataset.scan();
scanner
.full_text_search(FullTextSearchQuery::new("document".to_string()))
.unwrap();
scanner
.aggregate(AggregateExpr::builder().count_star().build())
.unwrap();
let plan = scanner.create_plan().await.unwrap();
assert_plan_node_equals(
plan.clone(),
"AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))]
MatchQuery: column=text, query=document",
)
.await
.unwrap();
let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(
batches[0].column(0).as_primitive::<Int64Type>().value(0),
100
);
}
#[tokio::test]
async fn test_scanner_count_rows_with_vector_search_and_filter() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let mut dataset = create_vector_text_dataset(uri, 100).await;
let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
dataset
.create_index(&["vec"], IndexType::Vector, None, ¶ms, true)
.await
.unwrap();
let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
let mut scanner = dataset.scan();
scanner
.nearest("vec", &query_vector, 50)
.unwrap()
.filter("category = 'category_a'")
.unwrap();
let count = scanner.count_rows().await.unwrap();
assert!(count > 0 && count <= 50);
}