use std::sync::Arc;
use arrow::datatypes::*;
use arrow_array::{
ArrayRef, BinaryArray, BinaryViewArray, Float32Array, Float64Array, Int32Array,
LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, StringViewArray,
};
use arrow_schema::DataType;
use lance::Dataset;
use lance::dataset::WriteParams;
use lance::dataset::optimize::{CompactionOptions, compact_files};
use lance_datagen::{ArrayGeneratorExt, RowCount, array, gen_batch};
use lance_index::{DatasetIndexExt, IndexType};
use super::{test_filter, test_scan, test_take};
use crate::utils::DatasetTestCases;
/// Exercises the scan, take, and filter paths on a nullable boolean column.
/// No secondary index is built (index list is `[None]` only).
#[tokio::test]
async fn test_query_bool() {
    // 60 rows: sequential ids plus a true/false cycle with ~10% nulls.
    let data = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col(
            "value",
            array::cycle_bool(vec![true, false]).with_random_nulls(0.1),
        )
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(data)
        .with_index_types("value", [None])
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // A boolean column can be filtered directly and via negation.
            for predicate in ["value", "NOT value"] {
                test_filter(&original, &ds, predicate).await;
            }
        })
        .await
}
/// Exercises scan, take, and filter paths on a nullable integer column of
/// every signed and unsigned width, unindexed and with each scalar index type.
///
/// Attribute order matters: `#[rstest]` must expand before `#[tokio::test]`
/// so rstest can strip the `#[case]` parameter and hand each generated
/// zero-argument test to tokio. With `#[tokio::test]` listed first it expands
/// first, sees a test function that takes an argument, and fails to compile.
#[rstest::rstest]
#[case::int8(DataType::Int8)]
#[case::int16(DataType::Int16)]
#[case::int32(DataType::Int32)]
#[case::int64(DataType::Int64)]
#[case::uint8(DataType::UInt8)]
#[case::uint16(DataType::UInt16)]
#[case::uint32(DataType::UInt32)]
#[case::uint64(DataType::UInt64)]
#[tokio::test]
async fn test_query_integer(#[case] data_type: DataType) {
    // 60 random values with ~10% nulls alongside a sequential id column.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Range predicates and their negations, null checks, and nested
            // OR expressions that stress index-assisted evaluation.
            test_filter(&original, &ds, "value > 20").await;
            test_filter(&original, &ds, "NOT (value > 20)").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
            test_filter(&original, &ds, "(value != 0) OR (value < 20)").await;
            test_filter(&original, &ds, "NOT ((value != 0) OR (value < 20))").await;
            test_filter(
                &original,
                &ds,
                "(value != 5) OR ((value != 52) OR (value IS NULL))",
            )
            .await;
            test_filter(
                &original,
                &ds,
                "NOT ((value != 5) OR ((value != 52) OR (value IS NULL)))",
            )
            .await;
        })
        .await
}
/// BTree-index regression case: the column is nullable and the literal `0`
/// used in the predicates never occurs in the data, combining `!=` / `<`
/// comparisons under OR and NOT.
#[tokio::test]
async fn test_btree_nullable_or_with_absent_value() {
    // Every third row is null; the rest are 100..160, so `0` is absent.
    let values: Int32Array = (0..60)
        .map(|i| (i % 3 != 0).then_some(100 + i))
        .collect();
    let ids = Int32Array::from_iter_values(0..60);
    let batch = RecordBatch::try_from_iter(vec![
        ("id", Arc::new(ids) as ArrayRef),
        ("value", Arc::new(values) as ArrayRef),
    ])
    .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types("value", [Some(IndexType::BTree)])
        .run(|ds: Dataset, original: RecordBatch| async move {
            let predicates = [
                "(value != 0) OR (value < 5)",
                "NOT ((value != 0) OR (value < 5))",
                "value != 0",
                "NOT (value = 0)",
                "value is null",
                "value is not null",
            ];
            for predicate in predicates {
                test_filter(&original, &ds, predicate).await;
            }
        })
        .await;
}
/// Exercises scan, take, and filter paths on a nullable float column for both
/// float widths, unindexed and with each scalar index type. Includes NaN
/// predicates via `isnan`.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::float32(DataType::Float32)]
#[case::float64(DataType::Float64)]
#[tokio::test]
async fn test_query_float(#[case] data_type: DataType) {
    // 60 random values with ~10% nulls alongside a sequential id column.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::BTree),
                Some(IndexType::Bitmap),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            test_filter(&original, &ds, "value > 0.5").await;
            test_filter(&original, &ds, "NOT (value > 0.5)").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
            // NaN handling must agree between indexed and unindexed paths.
            test_filter(&original, &ds, "isnan(value)").await;
            test_filter(&original, &ds, "not isnan(value)").await;
        })
        .await
}
/// Queries a float column holding every IEEE-754 edge value (+/-0, +/-inf,
/// NaN, min/max, null) to check that comparison, null, and NaN predicates
/// behave identically with and without each scalar index.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::float32(DataType::Float32)]
#[case::float64(DataType::Float64)]
#[tokio::test]
async fn test_query_float_special_values(#[case] data_type: DataType) {
    let value_array: Arc<dyn arrow_array::Array> = match data_type {
        DataType::Float32 => Arc::new(Float32Array::from(vec![
            Some(0.0_f32),
            Some(-0.0_f32),
            Some(f32::INFINITY),
            Some(f32::NEG_INFINITY),
            Some(f32::NAN),
            Some(1.0_f32),
            Some(-1.0_f32),
            Some(f32::MIN),
            Some(f32::MAX),
            None,
        ])),
        DataType::Float64 => Arc::new(Float64Array::from(vec![
            Some(0.0_f64),
            Some(-0.0_f64),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
            Some(f64::NAN),
            Some(1.0_f64),
            Some(-1.0_f64),
            Some(f64::MIN),
            Some(f64::MAX),
            None,
        ])),
        _ => unreachable!(),
    };
    let id_array = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));
    let batch =
        RecordBatch::try_from_iter(vec![("id", id_array as ArrayRef), ("value", value_array)])
            .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::BTree),
                Some(IndexType::Bitmap),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // `= 0.0` is interesting because -0.0 compares equal to 0.0.
            test_filter(&original, &ds, "value > 0.0").await;
            test_filter(&original, &ds, "value < 0.0").await;
            test_filter(&original, &ds, "value = 0.0").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
            test_filter(&original, &ds, "isnan(value)").await;
            test_filter(&original, &ds, "not isnan(value)").await;
        })
        .await
}
/// Exercises scan, take, and filter paths on a nullable date column for both
/// date widths, unindexed and with each scalar index type.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::date32(DataType::Date32)]
#[case::date64(DataType::Date64)]
#[tokio::test]
async fn test_query_date(#[case] data_type: DataType) {
    // 60 random dates with ~10% nulls alongside a sequential id column.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Both a dynamic (current_date) and a literal date bound.
            test_filter(&original, &ds, "value < current_date()").await;
            test_filter(&original, &ds, "value > DATE '2024-01-01'").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}
/// Exercises scan, take, and filter paths on a nullable timestamp column at
/// every time unit, unindexed and with each scalar index type.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::timestamp_second(DataType::Timestamp(TimeUnit::Second, None))]
#[case::timestamp_millisecond(DataType::Timestamp(TimeUnit::Millisecond, None))]
#[case::timestamp_microsecond(DataType::Timestamp(TimeUnit::Microsecond, None))]
#[case::timestamp_nanosecond(DataType::Timestamp(TimeUnit::Nanosecond, None))]
#[tokio::test]
async fn test_query_timestamp(#[case] data_type: DataType) {
    // 60 random timestamps with ~10% nulls alongside a sequential id column.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::BTree),
                Some(IndexType::Bitmap),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Both a dynamic (current_timestamp) and a literal bound.
            test_filter(&original, &ds, "value < current_timestamp()").await;
            test_filter(&original, &ds, "value > TIMESTAMP '2024-01-01 00:00:00'").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}
/// Queries a string column containing empty strings and a null, checking
/// equality, inequality, and ordering predicates with and without each
/// scalar index.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::utf8(DataType::Utf8)]
#[case::large_utf8(DataType::LargeUtf8)]
#[tokio::test]
async fn test_query_string(#[case] data_type: DataType) {
    let string_values = vec![
        Some("hello"),
        Some("world"),
        Some(""),
        Some("test"),
        Some("data"),
        Some(""),
        None,
        Some("apple"),
        Some("zebra"),
        Some(""),
    ];
    let value_array: ArrayRef = match data_type {
        DataType::Utf8 => Arc::new(StringArray::from(string_values.clone())),
        DataType::LargeUtf8 => Arc::new(LargeStringArray::from(string_values.clone())),
        // NOTE(review): no `#[case]` declares Utf8View, so this arm is
        // currently unreachable — presumably a case should be added once
        // view types are supported end-to-end; confirm before enabling.
        DataType::Utf8View => Arc::new(StringViewArray::from(string_values.clone())),
        _ => unreachable!(),
    };
    let id_array = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));
    let batch =
        RecordBatch::try_from_iter(vec![("id", id_array as ArrayRef), ("value", value_array)])
            .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            test_filter(&original, &ds, "value = 'hello'").await;
            test_filter(&original, &ds, "value != 'hello'").await;
            // Empty string must be distinguished from null.
            test_filter(&original, &ds, "value = ''").await;
            test_filter(&original, &ds, "value > 'hello'").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}
/// Queries a binary column containing empty values and a null, checking
/// equality and null predicates (with hex literals) with and without each
/// scalar index.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::binary(DataType::Binary)]
#[case::large_binary(DataType::LargeBinary)]
#[tokio::test]
async fn test_query_binary(#[case] data_type: DataType) {
    let binary_values = vec![
        Some(b"hello".as_slice()),
        Some(b"world".as_slice()),
        Some(b"".as_slice()),
        Some(b"test".as_slice()),
        Some(b"data".as_slice()),
        Some(b"".as_slice()),
        None,
        Some(b"apple".as_slice()),
        Some(b"zebra".as_slice()),
        Some(b"".as_slice()),
    ];
    let value_array: ArrayRef = match data_type {
        DataType::Binary => Arc::new(BinaryArray::from(binary_values.clone())),
        DataType::LargeBinary => Arc::new(LargeBinaryArray::from(binary_values.clone())),
        // NOTE(review): no `#[case]` declares BinaryView, so this arm is
        // currently unreachable — presumably a case should be added once
        // view types are supported end-to-end; confirm before enabling.
        DataType::BinaryView => Arc::new(BinaryViewArray::from(binary_values.clone())),
        _ => unreachable!(),
    };
    let id_array = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));
    let batch =
        RecordBatch::try_from_iter(vec![("id", id_array as ArrayRef), ("value", value_array)])
            .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // X'68656C6C6F' is the hex literal for b"hello".
            test_filter(&original, &ds, "value = X'68656C6C6F'").await;
            test_filter(&original, &ds, "value != X'68656C6C6F'").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}
/// Exercises scan, take, and filter paths on a nullable decimal column for
/// both 128- and 256-bit widths. Only Bitmap and BTree indexes are tested —
/// presumably BloomFilter/ZoneMap do not support decimals; confirm before
/// extending the list.
///
/// `#[rstest]` must come before `#[tokio::test]`: rstest expands first,
/// strips the `#[case]` parameter, and injects the tokio test attribute into
/// each generated case; the reverse order fails because `#[tokio::test]`
/// rejects functions with arguments.
#[rstest::rstest]
#[case::decimal128(DataType::Decimal128(38, 10))]
#[case::decimal256(DataType::Decimal256(76, 20))]
#[tokio::test]
async fn test_query_decimal(#[case] data_type: DataType) {
    // 60 random decimals with ~10% nulls alongside a sequential id column.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [None, Some(IndexType::Bitmap), Some(IndexType::BTree)],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            test_filter(&original, &ds, "value > 0").await;
            test_filter(&original, &ds, "value < 0").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}
/// Verifies that a filtered scan returns the correct row count after deleting
/// rows and compacting fragments on a dataset written with stable row ids
/// ("srid"), with a BTree index created only after compaction.
#[tokio::test]
async fn test_filtered_scan_after_compact_with_srid() {
    use arrow::record_batch::RecordBatchIterator;
    // 100 sequential ints, forced into two fragments via max_rows_per_file.
    let data = RecordBatch::try_from_iter(vec![(
        "int_col",
        Arc::new(Int32Array::from_iter_values(0..100)) as ArrayRef,
    )])
    .unwrap();
    let schema = data.schema();
    let params = WriteParams {
        enable_stable_row_ids: true,
        max_rows_per_file: 50,
        ..Default::default()
    };
    let reader = RecordBatchIterator::new(vec![Ok(data)], schema);
    let mut dataset = Dataset::write(reader, "memory://compact_srid_test", Some(params))
        .await
        .unwrap();
    assert_eq!(dataset.get_fragments().len(), 2);
    assert_eq!(dataset.count_rows(None).await.unwrap(), 100);
    // Remove 10 rows, then merge the fragments back together.
    dataset.delete("int_col >= 60 AND int_col < 70").await.unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 90);
    compact_files(&mut dataset, CompactionOptions::default(), None)
        .await
        .unwrap();
    // Index is built post-compaction so it covers the compacted layout.
    dataset
        .create_index(
            &["int_col"],
            IndexType::BTree,
            None,
            &lance_index::scalar::ScalarIndexParams::default(),
            true,
        )
        .await
        .unwrap();
    // The predicate matches every surviving row; deleted rows must stay gone.
    let filtered = dataset
        .scan()
        .filter("int_col < 200")
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(
        filtered.num_rows(),
        90,
        "Expected 90 rows (100 written - 10 deleted) but got {}",
        filtered.num_rows()
    );
}