use std::{
any::Any,
cmp::Ordering,
collections::{BTreeMap, BinaryHeap, HashMap, HashSet},
fmt::{Debug, Display},
ops::Bound,
sync::Arc,
};
use super::{
AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, MetricsCollector,
OldIndexDataFilter, SargableQuery, ScalarIndex, ScalarIndexParams, SearchResult,
compute_next_prefix,
};
use crate::{Index, IndexType};
use crate::{
frag_reuse::FragReuseIndex,
scalar::{
CreatedIndex, UpdateCriteria,
expression::{SargableQueryParser, ScalarQueryParser},
registry::{ScalarIndexPlugin, TrainingOrdering, TrainingRequest, VALUE_COLUMN_NAME},
},
};
use crate::{metrics::NoOpMetricsCollector, scalar::registry::TrainingCriteria};
use crate::{pbold, scalar::btree::flat::FlatIndex};
use arrow_arith::numeric::add;
use arrow_array::{Array, RecordBatch, UInt32Array, new_empty_array};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use async_trait::async_trait;
use datafusion::physical_plan::{
ExecutionPlan, SendableRecordBatchStream,
sorts::sort_preserving_merge::SortPreservingMergeExec, stream::RecordBatchStreamAdapter,
union::UnionExec,
};
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_physical_expr::{PhysicalSortExpr, expressions::Column};
use deepsize::DeepSizeOf;
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
future::BoxFuture,
stream::{self},
};
use lance_core::{
Error, ROW_ID, Result,
cache::{CacheKey, LanceCache, WeakLanceCache},
error::LanceOptionExt,
utils::{
mask::NullableRowAddrSet,
tokio::get_num_compute_intensive_cpus,
tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
},
};
use lance_datafusion::{
chunker::chunk_concat_stream,
exec::{LanceExecutionOptions, OneShotExec, execute_plan},
};
use lance_io::object_store::ObjectStore;
use log::{debug, warn};
use object_store::path::Path;
use rangemap::RangeInclusiveMap;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize, Serializer};
use tracing::{info, instrument};
mod flat;
const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
const BTREE_PAGES_NAME: &str = "page_data.lance";
pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
const BATCH_SIZE_META_KEY: &str = "batch_size";
const DEFAULT_RANGE_PARTITIONED: bool = false;
const RANGE_PARTITIONED_META_KEY: &str = "range_partitioned";
const PAGE_NUM_PER_RANGE_PARTITION_META_KEY: &str = "page_num_per_range_partition";
const BTREE_INDEX_VERSION: u32 = 0;
pub(crate) const BTREE_VALUES_COLUMN: &str = "values";
pub(crate) const BTREE_IDS_COLUMN: &str = "ids";
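/// A wrapper around [`ScalarValue`] that provides a total ordering so values can be
/// used as keys in the page lookup tree.
///
/// Non-null values compare greater than nulls, and floats are ordered with IEEE 754
/// `total_cmp`. Comparing non-null values of different types panics.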
#[derive(Clone, Debug)]
pub struct OrderableScalarValue(pub ScalarValue);
impl DeepSizeOf for OrderableScalarValue {
fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
self.0.size() - std::mem::size_of::<ScalarValue>()
}
}
impl Display for OrderableScalarValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl PartialEq for OrderableScalarValue {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}
impl Eq for OrderableScalarValue {}
impl PartialOrd for OrderableScalarValue {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderableScalarValue {
fn cmp(&self, other: &Self) -> Ordering {
use ScalarValue::*;
match (&self.0, &other.0) {
(Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => {
if p1.eq(p2) && s1.eq(s2) {
v1.cmp(v2)
} else {
panic!("Attempt to compare decimals with unequal precision / scale")
}
}
(Decimal32(v1, _, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Decimal32(_, _, _), _) => panic!("Attempt to compare decimal with non-decimal"),
(Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => {
if p1.eq(p2) && s1.eq(s2) {
v1.cmp(v2)
} else {
panic!("Attempt to compare decimals with unequal precision / scale")
}
}
(Decimal64(v1, _, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Decimal64(_, _, _), _) => panic!("Attempt to compare decimal with non-decimal"),
(Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
if p1.eq(p2) && s1.eq(s2) {
v1.cmp(v2)
} else {
panic!("Attempt to compare decimals with unequal precision / scale")
}
}
(Decimal128(v1, _, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Decimal128(_, _, _), _) => panic!("Attempt to compare decimal with non-decimal"),
(Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => {
if p1.eq(p2) && s1.eq(s2) {
v1.cmp(v2)
} else {
panic!("Attempt to compare decimals with unequal precision / scale")
}
}
(Decimal256(v1, _, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Decimal256(_, _, _), _) => panic!("Attempt to compare decimal with non-decimal"),
(Boolean(v1), Boolean(v2)) => v1.cmp(v2),
(Boolean(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Boolean(_), _) => panic!("Attempt to compare boolean with non-boolean"),
(Float32(v1), Float32(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.total_cmp(f2),
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(None, None) => Ordering::Equal,
},
(Float32(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Float32(_), _) => panic!("Attempt to compare f32 with non-f32"),
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.total_cmp(f2),
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(None, None) => Ordering::Equal,
},
(Float64(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Float64(_), _) => panic!("Attempt to compare f64 with non-f64"),
(Float16(v1), Float16(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.total_cmp(f2),
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(None, None) => Ordering::Equal,
},
(Float16(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Float16(_), _) => panic!("Attempt to compare f16 with non-f16"),
(Int8(v1), Int8(v2)) => v1.cmp(v2),
(Int8(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Int8(_), _) => panic!("Attempt to compare Int8 with non-Int8"),
(Int16(v1), Int16(v2)) => v1.cmp(v2),
(Int16(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Int16(_), _) => panic!("Attempt to compare Int16 with non-Int16"),
(Int32(v1), Int32(v2)) => v1.cmp(v2),
(Int32(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Int32(_), _) => panic!("Attempt to compare Int32 with non-Int32"),
(Int64(v1), Int64(v2)) => v1.cmp(v2),
(Int64(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Int64(_), _) => panic!("Attempt to compare Int64 with non-Int64"),
(UInt8(v1), UInt8(v2)) => v1.cmp(v2),
(UInt8(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(UInt8(_), _) => panic!("Attempt to compare UInt8 with non-UInt8"),
(UInt16(v1), UInt16(v2)) => v1.cmp(v2),
(UInt16(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(UInt16(_), _) => panic!("Attempt to compare UInt16 with non-UInt16"),
(UInt32(v1), UInt32(v2)) => v1.cmp(v2),
(UInt32(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(UInt32(_), _) => panic!("Attempt to compare UInt32 with non-UInt32"),
(UInt64(v1), UInt64(v2)) => v1.cmp(v2),
(UInt64(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(UInt64(_), _) => panic!("Attempt to compare UInt64 with non-UInt64"),
(Utf8(v1) | Utf8View(v1) | LargeUtf8(v1), Utf8(v2) | Utf8View(v2) | LargeUtf8(v2)) => {
v1.cmp(v2)
}
(Utf8(v1) | Utf8View(v1) | LargeUtf8(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Utf8(_) | Utf8View(_) | LargeUtf8(_), _) => {
panic!("Attempt to compare Utf8 with non-Utf8")
}
(
Binary(v1) | LargeBinary(v1) | BinaryView(v1),
Binary(v2) | LargeBinary(v2) | BinaryView(v2),
) => v1.cmp(v2),
(Binary(v1) | LargeBinary(v1) | BinaryView(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Binary(_) | LargeBinary(_) | BinaryView(_), _) => {
panic!("Attempt to compare Binary with non-Binary")
}
(FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.cmp(v2),
(FixedSizeBinary(_, v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(FixedSizeBinary(_, _), _) => {
panic!("Attempt to compare FixedSizeBinary with non-FixedSizeBinary")
}
(FixedSizeList(left), FixedSizeList(right)) => {
if left.eq(right) {
todo!()
} else {
panic!(
"Attempt to compare fixed size list elements with different widths/fields"
)
}
}
(FixedSizeList(left), Null) => {
if left.is_null(0) {
Ordering::Equal
} else {
Ordering::Greater
}
}
(FixedSizeList(_), _) => {
panic!("Attempt to compare FixedSizeList with non-FixedSizeList")
}
(List(_), List(_)) => todo!(),
(List(left), Null) => {
if left.is_null(0) {
Ordering::Equal
} else {
Ordering::Greater
}
}
(List(_), _) => {
panic!("Attempt to compare List with non-List")
}
(LargeList(_), _) => todo!(),
(Map(_), Map(_)) => todo!(),
(Map(left), Null) => {
if left.is_null(0) {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Map(_), _) => {
panic!("Attempt to compare Map with non-Map")
}
(Date32(v1), Date32(v2)) => v1.cmp(v2),
(Date32(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Date32(_), _) => panic!("Attempt to compare Date32 with non-Date32"),
(Date64(v1), Date64(v2)) => v1.cmp(v2),
(Date64(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Date64(_), _) => panic!("Attempt to compare Date64 with non-Date64"),
(Time32Second(v1), Time32Second(v2)) => v1.cmp(v2),
(Time32Second(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Time32Second(_), _) => panic!("Attempt to compare Time32Second with non-Time32Second"),
(Time32Millisecond(v1), Time32Millisecond(v2)) => v1.cmp(v2),
(Time32Millisecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Time32Millisecond(_), _) => {
panic!("Attempt to compare Time32Millisecond with non-Time32Millisecond")
}
(Time64Microsecond(v1), Time64Microsecond(v2)) => v1.cmp(v2),
(Time64Microsecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Time64Microsecond(_), _) => {
panic!("Attempt to compare Time64Microsecond with non-Time64Microsecond")
}
(Time64Nanosecond(v1), Time64Nanosecond(v2)) => v1.cmp(v2),
(Time64Nanosecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Time64Nanosecond(_), _) => {
panic!("Attempt to compare Time64Nanosecond with non-Time64Nanosecond")
}
(TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.cmp(v2),
(TimestampSecond(v1, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(TimestampSecond(_, _), _) => {
panic!("Attempt to compare TimestampSecond with non-TimestampSecond")
}
(TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.cmp(v2),
(TimestampMillisecond(v1, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(TimestampMillisecond(_, _), _) => {
panic!("Attempt to compare TimestampMillisecond with non-TimestampMillisecond")
}
(TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.cmp(v2),
(TimestampMicrosecond(v1, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(TimestampMicrosecond(_, _), _) => {
panic!("Attempt to compare TimestampMicrosecond with non-TimestampMicrosecond")
}
(TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.cmp(v2),
(TimestampNanosecond(v1, _), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(TimestampNanosecond(_, _), _) => {
panic!("Attempt to compare TimestampNanosecond with non-TimestampNanosecond")
}
(IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.cmp(v2),
(IntervalYearMonth(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(IntervalYearMonth(_), _) => {
panic!("Attempt to compare IntervalYearMonth with non-IntervalYearMonth")
}
(IntervalDayTime(v1), IntervalDayTime(v2)) => v1.cmp(v2),
(IntervalDayTime(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(IntervalDayTime(_), _) => {
panic!("Attempt to compare IntervalDayTime with non-IntervalDayTime")
}
(IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.cmp(v2),
(IntervalMonthDayNano(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(IntervalMonthDayNano(_), _) => {
panic!("Attempt to compare IntervalMonthDayNano with non-IntervalMonthDayNano")
}
(DurationSecond(v1), DurationSecond(v2)) => v1.cmp(v2),
(DurationSecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(DurationSecond(_), _) => {
panic!("Attempt to compare DurationSecond with non-DurationSecond")
}
(DurationMillisecond(v1), DurationMillisecond(v2)) => v1.cmp(v2),
(DurationMillisecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(DurationMillisecond(_), _) => {
panic!("Attempt to compare DurationMillisecond with non-DurationMillisecond")
}
(DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.cmp(v2),
(DurationMicrosecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(DurationMicrosecond(_), _) => {
panic!("Attempt to compare DurationMicrosecond with non-DurationMicrosecond")
}
(DurationNanosecond(v1), DurationNanosecond(v2)) => v1.cmp(v2),
(DurationNanosecond(v1), Null) => {
if v1.is_none() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(DurationNanosecond(_), _) => {
panic!("Attempt to compare DurationNanosecond with non-DurationNanosecond")
}
(Struct(_arr), Struct(_arr2)) => todo!(),
(Struct(arr), Null) => {
if arr.is_empty() {
Ordering::Equal
} else {
Ordering::Greater
}
}
(Struct(_arr), _) => panic!("Attempt to compare Struct with non-Struct"),
(Dictionary(_k1, _v1), Dictionary(_k2, _v2)) => todo!(),
(Dictionary(_, v1), Null) => Self(*v1.clone()).cmp(&Self(ScalarValue::Null)),
(Dictionary(_, _), _) => panic!("Attempt to compare Dictionary with non-Dictionary"),
(Union(_, _, _), _) => todo!("Support for union scalars"),
(Null, Null) => Ordering::Equal,
(Null, _) => todo!(),
}
}
}
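/// The statistics stored for a single page: its maximum value and page number.
/// (The page's minimum value is the key in the lookup tree.)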
#[derive(Debug, DeepSizeOf, PartialEq, Eq)]
struct PageRecord {
max: OrderableScalarValue,
page_number: u32,
}
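/// Extension trait adding a "largest key strictly less than" lookup to [`BTreeMap`].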
trait BTreeMapExt<K, V> {
fn largest_node_less(&self, key: &K) -> Option<(&K, &V)>;
}
impl<K: Ord, V> BTreeMapExt<K, V> for BTreeMap<K, V> {
fn largest_node_less(&self, key: &K) -> Option<(&K, &V)> {
self.range((Bound::Unbounded, Bound::Excluded(key)))
.next_back()
}
}
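/// An in-memory lookup table mapping each page's minimum value to the records of
/// pages that start at that value.
///
/// `null_pages` lists pages containing at least one null; `all_null_pages` lists
/// pages containing only nulls.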
#[derive(Debug, DeepSizeOf, PartialEq, Eq)]
pub struct BTreeLookup {
tree: BTreeMap<OrderableScalarValue, Vec<PageRecord>>,
null_pages: Vec<u32>,
all_null_pages: Vec<u32>,
}
impl BTreeLookup {
fn empty() -> Self {
Self {
tree: BTreeMap::new(),
null_pages: Vec::new(),
all_null_pages: Vec::new(),
}
}
}
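/// How a page relates to a query: `Some` means the page may contain matching rows
/// and must be searched, `All` means every row in the page is known to match.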
#[derive(Debug, Copy, Clone)]
enum Matches {
Some(u32),
All(u32),
}
impl Matches {
fn page_id(&self) -> u32 {
match self {
Self::Some(page_id) => *page_id,
Self::All(page_id) => *page_id,
}
}
}
impl BTreeLookup {
fn new(
tree: BTreeMap<OrderableScalarValue, Vec<PageRecord>>,
null_pages: Vec<u32>,
all_null_pages: Vec<u32>,
) -> Self {
Self {
tree,
null_pages,
all_null_pages,
}
}
fn pages_eq(&self, query: &OrderableScalarValue) -> Vec<Matches> {
if query.0.is_null() {
self.pages_null()
} else {
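            // Using an exclusive upper bound means no page is ever reported as an
            // `All` match for an equality query; candidate pages are always searched.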
self.pages_between((Bound::Included(query), Bound::Excluded(query)))
}
}
fn pages_in(&self, values: impl IntoIterator<Item = OrderableScalarValue>) -> Vec<Matches> {
let page_lists = values
.into_iter()
.map(|val| {
self.pages_eq(&val)
.into_iter()
.map(|matches| matches.page_id())
})
.collect::<Vec<_>>();
let total_size = page_lists.iter().map(|set| set.len()).sum();
let mut heap = BinaryHeap::with_capacity(total_size);
for page_list in page_lists {
heap.extend(page_list);
}
let mut all_pages = heap.into_sorted_vec();
all_pages.dedup();
all_pages.into_iter().map(Matches::Some).collect()
}
fn pages_between(
&self,
range: (Bound<&OrderableScalarValue>, Bound<&OrderableScalarValue>),
) -> Vec<Matches> {
        let lower_bound = match range.0 {
            Bound::Unbounded => Bound::Unbounded,
            // Both bound kinds start from the largest node strictly less than the
            // lower value, since that page may still contain matching rows.
            Bound::Included(lower) | Bound::Excluded(lower) => self
                .tree
                .largest_node_less(lower)
                .map(|val| Bound::Included(val.0))
                .unwrap_or(Bound::Unbounded),
        };
        let upper_bound = match range.1 {
            Bound::Unbounded => Bound::Unbounded,
            // Search the tree inclusively; the exact bound semantics are enforced
            // per page record below.
            Bound::Included(upper) | Bound::Excluded(upper) => Bound::Included(upper),
        };
match (lower_bound, upper_bound) {
(Bound::Excluded(lower), Bound::Excluded(upper))
| (Bound::Excluded(lower), Bound::Included(upper))
| (Bound::Included(lower), Bound::Excluded(upper)) => {
if lower >= upper {
return vec![];
}
}
(Bound::Included(lower), Bound::Included(upper)) => {
if lower > upper {
return vec![];
}
}
_ => {}
}
let mut matches = Vec::new();
for (min, page_records) in self.tree.range((lower_bound, upper_bound)) {
for page_record in page_records {
match lower_bound {
Bound::Unbounded => {}
Bound::Included(lower) => {
if page_record.max.cmp(lower) == Ordering::Less {
continue;
}
}
Bound::Excluded(lower) => {
if page_record.max.cmp(lower) != Ordering::Greater {
continue;
}
}
}
if min.0.is_null() || page_record.max.0.is_null() {
matches.push(Matches::Some(page_record.page_number));
continue;
}
match range.0 {
Bound::Excluded(lower) => {
if min.cmp(lower) != Ordering::Greater {
matches.push(Matches::Some(page_record.page_number));
continue;
}
}
Bound::Included(lower) => {
if min.cmp(lower) == Ordering::Less {
matches.push(Matches::Some(page_record.page_number));
continue;
}
}
Bound::Unbounded => {}
}
match range.1 {
Bound::Excluded(upper) => {
if page_record.max.cmp(upper) != Ordering::Less {
matches.push(Matches::Some(page_record.page_number));
continue;
}
}
Bound::Included(upper) => {
if page_record.max.cmp(upper) == Ordering::Greater {
matches.push(Matches::Some(page_record.page_number));
continue;
}
}
Bound::Unbounded => {}
}
matches.push(Matches::All(page_record.page_number));
}
}
matches
}
fn pages_null(&self) -> Vec<Matches> {
self.null_pages
.iter()
.map(|page_id| Matches::Some(*page_id))
.chain(self.all_null_pages.iter().copied().map(Matches::All))
.collect()
}
}
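/// Opens the page data file (or the set of range-partitioned page files) lazily on
/// first use and caches the reader for subsequent page reads.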
#[derive(Clone)]
struct LazyIndexReader {
index_reader: Arc<tokio::sync::Mutex<Option<Arc<dyn IndexReader>>>>,
store: Arc<dyn IndexStore>,
ranges_to_files: Option<Arc<RangeInclusiveMap<u32, (String, u32)>>>,
}
impl LazyIndexReader {
fn new(
store: Arc<dyn IndexStore>,
ranges_to_files: Option<Arc<RangeInclusiveMap<u32, (String, u32)>>>,
) -> Self {
Self {
index_reader: Arc::new(tokio::sync::Mutex::new(None)),
store,
ranges_to_files,
}
}
async fn get(&self) -> Result<Arc<dyn IndexReader>> {
let mut reader = self.index_reader.lock().await;
if reader.is_none() {
let index_reader = if let Some(ranges_to_files) = &self.ranges_to_files {
Arc::new(LazyRangedIndexReader::new(
self.store.clone(),
ranges_to_files.clone(),
))
} else {
self.store.open_index_file(BTREE_PAGES_NAME).await?
};
*reader = Some(index_reader);
}
Ok(reader.as_ref().unwrap().clone())
}
}
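/// An [`IndexReader`] over range-partitioned page files that opens each partition's
/// file on demand and translates global page indices into file-local indices.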
struct LazyRangedIndexReader {
#[allow(clippy::type_complexity)]
readers:
Arc<tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::OnceCell<Arc<dyn IndexReader>>>>>>,
store: Arc<dyn IndexStore>,
ranges_to_files: Arc<RangeInclusiveMap<u32, (String, u32)>>,
}
impl LazyRangedIndexReader {
fn new(
store: Arc<dyn IndexStore>,
ranges_to_files: Arc<RangeInclusiveMap<u32, (String, u32)>>,
) -> Self {
Self {
readers: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
store,
ranges_to_files,
}
}
async fn get_reader(&self, file_name: &str) -> Result<Arc<dyn IndexReader>> {
let reader_cell = {
let mut guard = self.readers.lock().await;
guard
.entry(file_name.to_string())
.or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
.clone()
};
let reader = reader_cell
.get_or_try_init(|| async { self.store.open_index_file(file_name).await })
.await?;
Ok(reader.clone())
}
async fn get_reader_and_local_page_idx(
&self,
page_idx: u32,
) -> Result<(Arc<dyn IndexReader>, u32)> {
let (page_file_name, offset) = self.ranges_to_files.get(&page_idx).ok_or_else(|| {
Error::internal(format!(
"Unexpected page index, index {} is out of range.",
page_idx
))
})?;
let reader = self.get_reader(page_file_name).await?;
Ok((reader.clone(), page_idx - *offset))
}
}
#[async_trait]
impl IndexReader for LazyRangedIndexReader {
async fn read_record_batch(&self, n: u64, batch_size: u64) -> Result<RecordBatch> {
let (reader, local_page_idx) = self.get_reader_and_local_page_idx(n as u32).await?;
reader
.read_record_batch(local_page_idx as u64, batch_size)
.await
}
async fn read_range(
&self,
_range: std::ops::Range<usize>,
_projection: Option<&[&str]>,
) -> Result<RecordBatch> {
        unimplemented!("Read range is not implemented for the lazy page file reader.");
}
async fn num_batches(&self, batch_size: u64) -> u32 {
let mut total_batches = 0;
for (_, (file_name, _)) in self.ranges_to_files.iter() {
let reader = self
.get_reader(file_name)
.await
.unwrap_or_else(|_| panic!("Cannot open page file {}.", file_name));
total_batches += reader.as_ref().num_batches(batch_size).await;
}
total_batches
}
    fn num_rows(&self) -> usize {
        unimplemented!("Only async functions are available for the lazy page index reader.");
    }
    fn schema(&self) -> &lance_core::datatypes::Schema {
        unimplemented!("Only async functions are available for the lazy page index reader.");
    }
}
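/// A cloneable wrapper around a loaded [`ScalarIndex`], used when caching whole
/// indices.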
#[derive(Debug, Clone, DeepSizeOf)]
pub struct CachedScalarIndex(Arc<dyn ScalarIndex>);
impl CachedScalarIndex {
pub fn new(index: Arc<dyn ScalarIndex>) -> Self {
Self(index)
}
pub fn into_inner(self) -> Arc<dyn ScalarIndex> {
self.0
}
}
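/// Cache key for a single B-tree page; resolves to the page's [`FlatIndex`].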
#[derive(Debug, Clone)]
pub struct BTreePageKey {
pub page_number: u32,
}
impl CacheKey for BTreePageKey {
type ValueType = FlatIndex;
fn key(&self) -> std::borrow::Cow<'_, str> {
format!("page-{}", self.page_number).into()
}
}
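/// A scalar index that stores the indexed column sorted by value.
///
/// The index is persisted as two files: `page_data.lance` holds the sorted values
/// and row ids in fixed-size pages, and `page_lookup.lance` records the min, max,
/// and null count of every page so queries can be routed only to the pages that
/// might contain matching values.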
#[derive(Clone, Debug)]
pub struct BTreeIndex {
page_lookup: Arc<BTreeLookup>,
index_cache: WeakLanceCache,
store: Arc<dyn IndexStore>,
data_type: DataType,
batch_size: u64,
ranges_to_files: Option<Arc<RangeInclusiveMap<u32, (String, u32)>>>,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
}
impl DeepSizeOf for BTreeIndex {
fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
self.page_lookup.deep_size_of_children(context) + self.store.deep_size_of_children(context)
}
}
impl BTreeIndex {
#[allow(clippy::too_many_arguments)]
fn new(
page_lookup: Arc<BTreeLookup>,
store: Arc<dyn IndexStore>,
data_type: DataType,
index_cache: WeakLanceCache,
batch_size: u64,
ranges_to_files: Option<Arc<RangeInclusiveMap<u32, (String, u32)>>>,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Self {
Self {
page_lookup,
store,
data_type,
index_cache,
batch_size,
ranges_to_files,
frag_reuse_index,
}
}
async fn lookup_page(
&self,
page_number: u32,
index_reader: LazyIndexReader,
metrics: &dyn MetricsCollector,
) -> Result<Arc<FlatIndex>> {
self.index_cache
.get_or_insert_with_key(BTreePageKey { page_number }, move || async move {
self.read_page(page_number, index_reader, metrics).await
})
.await
}
#[instrument(level = "debug", skip_all)]
async fn read_page(
&self,
page_number: u32,
index_reader: LazyIndexReader,
metrics: &dyn MetricsCollector,
) -> Result<FlatIndex> {
metrics.record_part_load();
info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="btree", part_id=page_number);
let index_reader = index_reader.get().await?;
let mut serialized_page = index_reader
.read_record_batch(page_number as u64, self.batch_size)
.await?;
if let Some(frag_reuse_index_ref) = self.frag_reuse_index.as_ref() {
serialized_page =
frag_reuse_index_ref.remap_row_ids_record_batch(serialized_page, 1)?;
}
FlatIndex::try_new(serialized_page)
}
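    /// Searches a single page, short-circuiting to "all rows in the page" when the
    /// lookup already proved that every row matches.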
async fn search_page(
&self,
query: &SargableQuery,
matches: Matches,
index_reader: LazyIndexReader,
metrics: &dyn MetricsCollector,
) -> Result<NullableRowAddrSet> {
let subindex = self
.lookup_page(matches.page_id(), index_reader, metrics)
.await?;
match matches {
            Matches::Some(_) => subindex.search(query, metrics),
Matches::All(_) => Ok(match query {
SargableQuery::IsNull() => subindex.all_ignore_nulls(),
_ => subindex.all(),
}),
}
}
#[instrument(level = "debug", skip_all)]
fn try_from_serialized(
data: RecordBatch,
store: Arc<dyn IndexStore>,
index_cache: &LanceCache,
batch_size: u64,
ranges_to_files: Option<Arc<RangeInclusiveMap<u32, (String, u32)>>>,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<Self> {
let mut map = BTreeMap::<OrderableScalarValue, Vec<PageRecord>>::new();
let mut null_pages = Vec::<u32>::new();
let mut all_null_pages = Vec::<u32>::new();
if data.num_rows() == 0 {
let data_type = data.column(0).data_type().clone();
let page_lookup = Arc::new(BTreeLookup::empty());
return Ok(Self::new(
page_lookup,
store,
data_type,
WeakLanceCache::from(index_cache),
batch_size,
ranges_to_files,
frag_reuse_index,
));
}
let mins = data.column(0);
let maxs = data.column(1);
let null_counts = data
.column(2)
.as_any()
.downcast_ref::<UInt32Array>()
.unwrap();
let page_numbers = data
.column(3)
.as_any()
.downcast_ref::<UInt32Array>()
.unwrap();
for idx in 0..data.num_rows() {
let min = OrderableScalarValue(ScalarValue::try_from_array(&mins, idx)?);
let max = OrderableScalarValue(ScalarValue::try_from_array(&maxs, idx)?);
let null_count = null_counts.values()[idx];
let page_number = page_numbers.values()[idx];
            if max.0.is_null() {
                all_null_pages.push(page_number);
                continue;
            }
            map.entry(min)
                .or_default()
                .push(PageRecord { max, page_number });
if null_count > 0 {
null_pages.push(page_number);
}
}
let last_max = ScalarValue::try_from_array(&maxs, data.num_rows() - 1)?;
map.entry(OrderableScalarValue(last_max)).or_default();
let data_type = mins.data_type();
let page_lookup = Arc::new(BTreeLookup::new(map, null_pages, all_null_pages));
Ok(Self::new(
page_lookup,
store,
data_type.clone(),
WeakLanceCache::from(index_cache),
batch_size,
ranges_to_files,
frag_reuse_index,
))
}
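    /// Loads the index from the store by reading the entire lookup file into memory.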
async fn load(
store: Arc<dyn IndexStore>,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
index_cache: &LanceCache,
) -> Result<Arc<Self>> {
let page_lookup_file = store.open_index_file(BTREE_LOOKUP_NAME).await?;
let num_rows_in_lookup = page_lookup_file.num_rows();
let serialized_lookup = page_lookup_file
.read_range(0..num_rows_in_lookup, None)
.await?;
let file_schema = page_lookup_file.schema();
let batch_size = file_schema
.metadata
.get(BATCH_SIZE_META_KEY)
.map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
.unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
let range_partitioned = file_schema
.metadata
.get(RANGE_PARTITIONED_META_KEY)
.map(|bs| bs.parse().unwrap_or(DEFAULT_RANGE_PARTITIONED))
.unwrap_or(DEFAULT_RANGE_PARTITIONED);
let ranges_to_files = if range_partitioned {
            let part_sizes_str = file_schema
                .metadata
                .get(PAGE_NUM_PER_RANGE_PARTITION_META_KEY)
                .expect("Range-partitioned BTree lookup file is missing the page_num_per_range_partition metadata key");
let part_sizes_vec: Vec<(u64, u32)> = serde_json::from_str(part_sizes_str)?;
let mut offset: u32 = 0;
let range_map = part_sizes_vec
.into_iter()
.map(|(id, size)| {
let range = offset..=(offset + size - 1);
                    let file_with_offset = (part_page_data_file_path(id), offset);
                    offset += size;
                    (range, file_with_offset)
})
.collect();
Some(Arc::new(range_map))
} else {
None
};
Ok(Arc::new(Self::try_from_serialized(
serialized_lookup,
store,
index_cache,
batch_size,
ranges_to_files,
frag_reuse_index,
)?))
}
fn train_schema(&self) -> Schema {
let value_field = Field::new(VALUE_COLUMN_NAME, self.data_type.clone(), true);
let row_id_field = Field::new(ROW_ID, DataType::UInt64, false);
Schema::new(vec![value_field, row_id_field])
}
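    /// Streams the index's pages back out as batches of values and row ids in the
    /// training schema.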
async fn into_data_stream(self) -> Result<SendableRecordBatchStream> {
let lazy_reader = LazyIndexReader::new(self.store.clone(), self.ranges_to_files.clone());
let reader = lazy_reader.get().await?;
let new_schema = Arc::new(self.train_schema());
let new_schema_clone = new_schema.clone();
let reader_stream = IndexReaderStream::new(reader, self.batch_size).await;
let batches = reader_stream
.map(|fut| fut.map_err(DataFusionError::from))
.buffered(self.store.io_parallelism())
.map_ok(move |batch| {
RecordBatch::try_new(
new_schema.clone(),
vec![batch.column(0).clone(), batch.column(1).clone()],
)
.unwrap()
})
.boxed();
Ok(Box::pin(RecordBatchStreamAdapter::new(
new_schema_clone,
batches,
)))
}
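    /// Merges the existing index data with `new_data` (both sorted by value) via a
    /// sort-preserving merge, yielding a single sorted stream re-chunked to
    /// `chunk_size` rows.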
async fn combine_old_new(
self,
new_data: SendableRecordBatchStream,
chunk_size: u64,
old_data_filter: Option<OldIndexDataFilter>,
) -> Result<SendableRecordBatchStream> {
let value_column_index = new_data.schema().index_of(VALUE_COLUMN_NAME)?;
let new_input = Arc::new(OneShotExec::new(new_data));
let old_stream = self.into_data_stream().await?;
let old_stream = match old_data_filter {
Some(filter) => filter_row_ids(old_stream, filter),
None => old_stream,
};
let old_input = Arc::new(OneShotExec::new(old_stream));
debug_assert_eq!(
old_input.schema().flattened_fields().len(),
new_input.schema().flattened_fields().len()
);
let sort_expr = PhysicalSortExpr {
expr: Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_index)),
options: SortOptions {
descending: false,
nulls_first: true,
},
};
let all_data = UnionExec::try_new(vec![old_input, new_input])?;
let ordered = Arc::new(SortPreservingMergeExec::new([sort_expr].into(), all_data));
let unchunked = execute_plan(
ordered,
LanceExecutionOptions {
use_spilling: true,
..Default::default()
},
)?;
Ok(chunk_concat_stream(unchunked, chunk_size as usize))
}
}
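/// Filters streamed index data, keeping only the rows whose row ids pass the
/// supplied [`OldIndexDataFilter`].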
fn filter_row_ids(
stream: SendableRecordBatchStream,
old_data_filter: OldIndexDataFilter,
) -> SendableRecordBatchStream {
let schema = stream.schema();
let filtered = stream.map(move |batch_result| {
let batch = batch_result?;
let row_ids = batch[ROW_ID]
.as_any()
.downcast_ref::<arrow_array::UInt64Array>()
            .ok_or_else(|| Error::internal("expected UInt64Array for row_id column".to_string()))?;
let mask = old_data_filter.filter_row_ids(row_ids);
Ok(arrow_select::filter::filter_record_batch(&batch, &mask)?)
});
Box::pin(RecordBatchStreamAdapter::new(schema, filtered))
}
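/// Wraps the values of a [`Bound`] in [`OrderableScalarValue`] so they can be used
/// against the page lookup tree.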
fn wrap_bound(bound: &Bound<ScalarValue>) -> Bound<OrderableScalarValue> {
match bound {
Bound::Unbounded => Bound::Unbounded,
Bound::Included(val) => Bound::Included(OrderableScalarValue(val.clone())),
Bound::Excluded(val) => Bound::Excluded(OrderableScalarValue(val.clone())),
}
}
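/// Serializes an optional value via its `Display` impl, emitting `"N/A"` for `None`.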
fn serialize_with_display<T: Display, S: Serializer>(
value: &Option<T>,
serializer: S,
) -> std::result::Result<S::Ok, S::Error> {
if let Some(value) = value {
serializer.collect_str(value)
} else {
serializer.collect_str("N/A")
}
}
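/// Index statistics reported by [`Index::statistics`].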
#[derive(Serialize)]
struct BTreeStatistics {
#[serde(serialize_with = "serialize_with_display")]
min: Option<OrderableScalarValue>,
#[serde(serialize_with = "serialize_with_display")]
max: Option<OrderableScalarValue>,
num_pages: u32,
}
#[async_trait]
impl Index for BTreeIndex {
fn as_any(&self) -> &dyn Any {
self
}
fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
self
}
fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
        Err(Error::not_supported_source(
            "BTreeIndex is not a vector index".into(),
        ))
}
async fn prewarm(&self) -> Result<()> {
let index_reader = LazyIndexReader::new(self.store.clone(), self.ranges_to_files.clone());
let reader = index_reader.get().await?;
let num_pages = reader.num_batches(self.batch_size).await;
let mut pages = stream::iter(0..num_pages)
.map(|page_idx| {
let index_reader = index_reader.clone();
async move {
let page = self
.read_page(page_idx, index_reader, &NoOpMetricsCollector)
.await?;
Result::Ok((page_idx, page))
}
})
.buffer_unordered(get_num_compute_intensive_cpus());
while let Some((page_idx, page)) = pages.try_next().await? {
let inserted = self
.index_cache
.insert_with_key(
&BTreePageKey {
page_number: page_idx,
},
Arc::new(page),
)
.await;
if !inserted {
return Err(Error::internal(
"Failed to prewarm index: cache is no longer available".to_string(),
));
}
}
Ok(())
}
fn index_type(&self) -> IndexType {
IndexType::BTree
}
fn statistics(&self) -> Result<serde_json::Value> {
let min = self
.page_lookup
.tree
.first_key_value()
.map(|(k, _)| k.clone());
let max = self
.page_lookup
.tree
.last_key_value()
.map(|(k, _)| k.clone());
serde_json::to_value(&BTreeStatistics {
num_pages: self.page_lookup.tree.len() as u32,
min,
max,
})
.map_err(|err| err.into())
}
async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
let mut frag_ids = RoaringBitmap::default();
let lazy_reader = LazyIndexReader::new(self.store.clone(), self.ranges_to_files.clone());
let sub_index_reader = lazy_reader.get().await?;
let mut reader_stream = IndexReaderStream::new(sub_index_reader, self.batch_size)
.await
.buffered(self.store.io_parallelism());
while let Some(serialized) = reader_stream.try_next().await? {
let page = FlatIndex::try_new(serialized)?;
frag_ids |= page.calculate_included_frags()?;
}
Ok(frag_ids)
}
}
#[async_trait]
impl ScalarIndex for BTreeIndex {
async fn search(
&self,
query: &dyn AnyQuery,
metrics: &dyn MetricsCollector,
) -> Result<SearchResult> {
let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
let mut pages = match query {
SargableQuery::Equals(val) => self
.page_lookup
.pages_eq(&OrderableScalarValue(val.clone())),
SargableQuery::Range(start, end) => self
.page_lookup
.pages_between((wrap_bound(start).as_ref(), wrap_bound(end).as_ref())),
SargableQuery::IsIn(values) => self
.page_lookup
.pages_in(values.iter().map(|val| OrderableScalarValue(val.clone()))),
SargableQuery::FullTextSearch(_) => {
return Err(Error::invalid_input(
"full text search is not supported for BTree index, build a inverted index for it",
));
}
SargableQuery::IsNull() => self.page_lookup.pages_null(),
SargableQuery::LikePrefix(prefix) => {
match prefix {
ScalarValue::Utf8(Some(s)) => {
let start = Bound::Included(OrderableScalarValue(prefix.clone()));
let end = match compute_next_prefix(s) {
Some(next) => {
Bound::Excluded(OrderableScalarValue(ScalarValue::Utf8(Some(next))))
}
None => Bound::Unbounded,
};
self.page_lookup
.pages_between((start.as_ref(), end.as_ref()))
}
ScalarValue::LargeUtf8(Some(s)) => {
let start = Bound::Included(OrderableScalarValue(prefix.clone()));
let end = match compute_next_prefix(s) {
Some(next) => Bound::Excluded(OrderableScalarValue(
ScalarValue::LargeUtf8(Some(next)),
)),
None => Bound::Unbounded,
};
self.page_lookup
.pages_between((start.as_ref(), end.as_ref()))
}
_ => {
self.page_lookup
.pages_between((Bound::Unbounded, Bound::Unbounded))
}
}
}
};
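        // Non-null queries still visit pages that contain nulls so the nullable
        // result set can account for null rows.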
if !matches!(query, SargableQuery::IsNull()) {
let existing: HashSet<u32> = pages.iter().map(|m| m.page_id()).collect();
for &page_id in self
.page_lookup
.null_pages
.iter()
.chain(self.page_lookup.all_null_pages.iter())
{
if !existing.contains(&page_id) {
pages.push(Matches::Some(page_id));
}
}
}
let lazy_index_reader =
LazyIndexReader::new(self.store.clone(), self.ranges_to_files.clone());
let page_tasks = pages
.into_iter()
.map(|page_index| {
self.search_page(query, page_index, lazy_index_reader.clone(), metrics)
.boxed()
})
.collect::<Vec<_>>();
debug!("Searching {} btree pages", page_tasks.len());
let results: Vec<NullableRowAddrSet> = stream::iter(page_tasks)
.buffered(get_num_compute_intensive_cpus())
.try_collect()
.await?;
let selection = NullableRowAddrSet::union_all(&results);
Ok(SearchResult::Exact(selection))
}
fn can_remap(&self) -> bool {
true
}
async fn remap(
&self,
mapping: &HashMap<u64, Option<u64>>,
dest_store: &dyn IndexStore,
) -> Result<CreatedIndex> {
let part_page_files: Vec<(Option<u32>, &str)> =
if let Some(ranges_to_files) = &self.ranges_to_files {
ranges_to_files
.iter()
.enumerate()
.map(|(part_id, (_, (path, _)))| (Some(part_id as u32), path.as_str()))
.collect()
} else {
vec![(None, BTREE_PAGES_NAME)]
};
let mapping = Arc::new(mapping.clone());
let train_schema = Arc::new(self.train_schema());
for (part_id, page_file) in part_page_files {
let sub_index_reader = self.store.open_index_file(page_file).await?;
let mapping = mapping.clone();
let train_schema_clone = train_schema.clone();
let train_schema = train_schema.clone();
let remapped_stream = IndexReaderStream::new(sub_index_reader, self.batch_size)
.await
.buffered(self.store.io_parallelism())
.map_err(DataFusionError::from)
.and_then(move |batch| {
let remapped =
FlatIndex::remap_batch(batch, &mapping).map_err(DataFusionError::from);
let with_train_schema = remapped.and_then(|batch| {
RecordBatch::try_new(train_schema.clone(), batch.columns().to_vec())
.map_err(DataFusionError::from)
});
std::future::ready(with_train_schema)
});
let remapped_stream = Box::pin(RecordBatchStreamAdapter::new(
train_schema_clone,
remapped_stream,
));
train_btree_index(remapped_stream, dest_store, self.batch_size, None, part_id).await?;
}
if let Some(ranges_to_files) = &self.ranges_to_files {
let num_parts = ranges_to_files.len();
let page_files = (0..num_parts)
.map(|part_id| part_page_data_file_path((part_id as u64) << 32))
.collect::<Vec<_>>();
let lookup_files = (0..num_parts)
.map(|part_id| part_lookup_file_path((part_id as u64) << 32))
.collect::<Vec<_>>();
merge_metadata_files(dest_store, &page_files, &lookup_files, None).await?;
}
Ok(CreatedIndex {
index_details: prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default())
.unwrap(),
index_version: BTREE_INDEX_VERSION,
files: Some(dest_store.list_files_with_sizes().await?),
})
}
async fn update(
&self,
new_data: SendableRecordBatchStream,
dest_store: &dyn IndexStore,
old_data_filter: Option<OldIndexDataFilter>,
) -> Result<CreatedIndex> {
let merged_data_source = self
.clone()
.combine_old_new(new_data, self.batch_size, old_data_filter)
.await?;
train_btree_index(merged_data_source, dest_store, self.batch_size, None, None).await?;
Ok(CreatedIndex {
index_details: prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default())
.unwrap(),
index_version: BTREE_INDEX_VERSION,
files: Some(dest_store.list_files_with_sizes().await?),
})
}
fn update_criteria(&self) -> UpdateCriteria {
UpdateCriteria::only_new_data(TrainingCriteria::new(TrainingOrdering::Values).with_row_id())
}
fn derive_index_params(&self) -> Result<ScalarIndexParams> {
let params = serde_json::to_value(BTreeParameters {
zone_size: Some(self.batch_size),
range_id: None,
})?;
Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::BTree).with_params(¶ms))
}
}
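/// Summary statistics for a single page of training data.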
struct BatchStats {
min: ScalarValue,
max: ScalarValue,
null_count: u32,
}
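/// Computes page statistics, assuming the batch is already sorted by value so that
/// the minimum is in the first row and the maximum is in the last.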
fn analyze_batch(batch: &RecordBatch) -> Result<BatchStats> {
let values = batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?;
if values.is_empty() {
return Err(Error::internal(
"received an empty batch in btree training".to_string(),
));
}
let min = ScalarValue::try_from_array(&values, 0)
.map_err(|e| Error::internal(format!("failed to get min value from batch: {}", e)))?;
let max = ScalarValue::try_from_array(&values, values.len() - 1)
.map_err(|e| Error::internal(format!("failed to get max value from batch: {}", e)))?;
Ok(BatchStats {
min,
max,
null_count: values.null_count() as u32,
})
}
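/// Abstraction over the per-page sub-index used inside a B-tree page.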
#[async_trait]
pub trait BTreeSubIndex: Debug + Send + Sync + DeepSizeOf {
async fn train(&self, batch: RecordBatch) -> Result<RecordBatch>;
async fn load_subindex(&self, serialized: RecordBatch) -> Result<Arc<dyn ScalarIndex>>;
async fn retrieve_data(&self, serialized: RecordBatch) -> Result<RecordBatch>;
fn schema(&self) -> &Arc<Schema>;
async fn remap_subindex(
&self,
serialized: RecordBatch,
mapping: &HashMap<u64, Option<u64>>,
) -> Result<RecordBatch>;
}
struct EncodedBatch {
stats: BatchStats,
page_number: u32,
}
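/// Writes one page of training data to the page file and returns its statistics.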
async fn train_btree_page(
batch: RecordBatch,
batch_idx: u32,
writer: &mut dyn IndexWriter,
schema: Arc<Schema>,
) -> Result<EncodedBatch> {
let stats = analyze_batch(&batch)?;
let trained = RecordBatch::try_new(
schema.clone(),
vec![
batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?.clone(),
batch.column_by_name(ROW_ID).expect_ok()?.clone(),
],
)?;
writer.write_record_batch(trained).await?;
Ok(EncodedBatch {
stats,
page_number: batch_idx,
})
}
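/// Serializes per-page statistics into the lookup table layout:
/// `(min, max, null_count, page_idx)`.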
fn btree_stats_as_batch(stats: Vec<EncodedBatch>, value_type: &DataType) -> Result<RecordBatch> {
let mins = if stats.is_empty() {
new_empty_array(value_type)
} else {
ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.min.clone()))?
};
let maxs = if stats.is_empty() {
new_empty_array(value_type)
} else {
ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.max.clone()))?
};
let null_counts = UInt32Array::from_iter_values(stats.iter().map(|stat| stat.stats.null_count));
let page_numbers = UInt32Array::from_iter_values(stats.iter().map(|stat| stat.page_number));
let schema = Arc::new(Schema::new(vec![
Field::new("min", mins.data_type().clone(), true),
Field::new("max", maxs.data_type().clone(), true),
Field::new("null_count", null_counts.data_type().clone(), false),
Field::new("page_idx", page_numbers.data_type().clone(), false),
]));
let columns = vec![
mins,
maxs,
Arc::new(null_counts) as Arc<dyn Array>,
Arc::new(page_numbers) as Arc<dyn Array>,
];
Ok(RecordBatch::try_new(schema, columns)?)
}
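/// Trains a new B-tree index from a stream of batches sorted by value.
///
/// The stream is re-chunked into pages of `batch_size` rows; each page is written
/// to the page data file and its min/max/null-count statistics to the lookup file.
/// When `fragment_ids` or `range_id` is provided, the files are written as a
/// partition (`part_<id>_...`) to be merged later with [`merge_index_files`].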
pub async fn train_btree_index(
batches_source: SendableRecordBatchStream,
index_store: &dyn IndexStore,
batch_size: u64,
fragment_ids: Option<Vec<u32>>,
range_id: Option<u32>,
) -> Result<()> {
let partition_id = fragment_ids
.as_ref()
.and_then(|frag_ids| frag_ids.first())
.map(|&first_frag_id| (first_frag_id as u64) << 32)
.or_else(|| range_id.map(|id| (id as u64) << 32));
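    // The partition id is the first fragment id (or the range id) shifted into the
    // upper 32 bits of a u64.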
let flat_schema = Arc::new(Schema::new(vec![
Field::new(
BTREE_VALUES_COLUMN,
batches_source.schema().field(0).data_type().clone(),
true,
),
Field::new(BTREE_IDS_COLUMN, DataType::UInt64, false),
]));
let mut sub_index_file = match partition_id {
None => {
index_store
.new_index_file(BTREE_PAGES_NAME, flat_schema.clone())
.await?
}
Some(partition_id) => {
index_store
.new_index_file(
part_page_data_file_path(partition_id).as_str(),
flat_schema.clone(),
)
.await?
}
};
let mut encoded_batches = Vec::new();
let mut batch_idx = 0;
let value_type = batches_source
.schema()
.field_with_name(VALUE_COLUMN_NAME)?
.data_type()
.clone();
let mut batches_source = chunk_concat_stream(batches_source, batch_size as usize);
while let Some(batch) = batches_source.try_next().await? {
encoded_batches.push(
train_btree_page(
batch,
batch_idx,
sub_index_file.as_mut(),
flat_schema.clone(),
)
.await?,
);
batch_idx += 1;
}
sub_index_file.finish().await?;
let record_batch = btree_stats_as_batch(encoded_batches, &value_type)?;
let mut file_schema = record_batch.schema().as_ref().clone();
file_schema
.metadata
.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
file_schema.metadata.insert(
RANGE_PARTITIONED_META_KEY.to_string(),
range_id.is_some().to_string(),
);
let mut btree_index_file = match partition_id {
None => {
index_store
.new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
.await?
}
Some(partition_id) => {
index_store
.new_index_file(
part_lookup_file_path(partition_id).as_str(),
Arc::new(file_schema),
)
.await?
}
};
btree_index_file.write_record_batch(record_batch).await?;
btree_index_file.finish().await?;
Ok(())
}
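/// Merges the partition files produced by partitioned training into a single
/// B-tree index.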
pub async fn merge_index_files(
object_store: &ObjectStore,
index_dir: &Path,
store: Arc<dyn IndexStore>,
    batch_readahead: Option<usize>,
) -> Result<()> {
let (part_page_files, part_lookup_files) =
list_page_lookup_files(object_store, index_dir).await?;
merge_metadata_files(
store.as_ref(),
&part_page_files,
&part_lookup_files,
        batch_readahead,
)
.await
}
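/// Scans the index directory for the `part_*_page_data.lance` and
/// `part_*_page_lookup.lance` files produced by partitioned training.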
async fn list_page_lookup_files(
object_store: &ObjectStore,
index_dir: &Path,
) -> Result<(Vec<String>, Vec<String>)> {
let mut part_page_files = Vec::new();
let mut part_lookup_files = Vec::new();
let mut list_stream = object_store.list(Some(index_dir.clone()));
while let Some(item) = list_stream.next().await {
match item {
Ok(meta) => {
let file_name = meta.location.filename().unwrap_or_default();
if file_name.starts_with("part_") && file_name.ends_with("_page_data.lance") {
part_page_files.push(file_name.to_string());
}
if file_name.starts_with("part_") && file_name.ends_with("_page_lookup.lance") {
part_lookup_files.push(file_name.to_string());
}
}
Err(_) => continue,
}
}
if part_page_files.is_empty() || part_lookup_files.is_empty() {
return Err(Error::internal(format!(
"No partition metadata files found in index directory: {} (page_files: {}, lookup_files: {})",
index_dir,
part_page_files.len(),
part_lookup_files.len()
)));
}
Ok((part_page_files, part_lookup_files))
}
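/// Validates the partition files and merges them: range-partitioned lookups are
/// concatenated with re-based page indices, otherwise the page data itself is
/// merged with a sort-preserving merge.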
async fn merge_metadata_files(
store: &dyn IndexStore,
part_page_files: &[String],
part_lookup_files: &[String],
    batch_readahead: Option<usize>,
) -> Result<()> {
if part_lookup_files.is_empty() || part_page_files.is_empty() {
return Err(Error::internal(
"No partition files provided for merging".to_string(),
));
}
if part_lookup_files.len() != part_page_files.len() {
return Err(Error::internal(format!(
"Number of partition lookup files ({}) does not match number of partition page files ({})",
part_lookup_files.len(),
part_page_files.len()
)));
}
let mut page_files_map = HashMap::new();
for page_file in part_page_files {
let partition_id = extract_partition_id(page_file)?;
page_files_map.insert(partition_id, page_file);
}
for lookup_file in part_lookup_files {
let partition_id = extract_partition_id(lookup_file)?;
if !page_files_map.contains_key(&partition_id) {
return Err(Error::internal(format!(
"No corresponding page file found for lookup file: {} (partition_id: {})",
lookup_file, partition_id
)));
}
}
let first_lookup_reader = store.open_index_file(&part_lookup_files[0]).await?;
let batch_size = first_lookup_reader
.schema()
.metadata
.get(BATCH_SIZE_META_KEY)
.map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
.unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
let range_partitioned = first_lookup_reader
.schema()
.metadata
.get(RANGE_PARTITIONED_META_KEY)
.map(|bs| bs.parse().unwrap_or(DEFAULT_RANGE_PARTITIONED))
.unwrap_or(DEFAULT_RANGE_PARTITIONED);
let value_type = first_lookup_reader
.schema()
.fields
.first()
.unwrap()
.data_type();
let mut metadata = HashMap::new();
metadata.insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
let lookup_schema = Arc::new(Schema::new(vec![
Field::new("min", value_type.clone(), true),
Field::new("max", value_type.clone(), true),
Field::new("null_count", DataType::UInt32, false),
Field::new("page_idx", DataType::UInt32, false),
]));
if range_partitioned {
merge_range_partitioned_lookups(
store,
part_lookup_files,
lookup_schema,
metadata,
batch_size,
            batch_readahead,
)
.await
} else {
merge_pages_and_lookups(
store,
part_page_files,
part_lookup_files,
&page_files_map,
lookup_schema,
metadata,
batch_size,
            batch_readahead,
)
.await
}
}
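/// Concatenates range-partitioned lookup files in partition order, offsetting each
/// file's page indices by the number of pages already written.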
async fn merge_range_partitioned_lookups(
store: &dyn IndexStore,
part_lookup_files: &[String],
lookup_schema: Arc<Schema>,
mut metadata: HashMap<String, String>,
batch_size: u64,
    batch_readahead: Option<usize>,
) -> Result<()> {
let sorted_part_lookup_files = sort_files_by_partition_id(part_lookup_files)?;
let mut lookup_file = store
.new_index_file(BTREE_LOOKUP_NAME, lookup_schema)
.await?;
let mut pages_per_file: Vec<(u64, u32)> = Vec::with_capacity(sorted_part_lookup_files.len());
let mut num_pages_written = 0u32;
for (part_id, part_lookup_file) in sorted_part_lookup_files {
let lookup_reader = store.open_index_file(&part_lookup_file).await?;
let reader_stream = IndexReaderStream::new(lookup_reader.clone(), batch_size).await;
        let mut stream = reader_stream.buffered(batch_readahead.unwrap_or(1)).boxed();
while let Some(batch) = stream.next().await {
let original_batch = batch?;
let modified_batch = add_offset_to_page_idx(&original_batch, num_pages_written)?;
lookup_file.write_record_batch(modified_batch).await?;
}
pages_per_file.push((part_id, lookup_reader.num_rows() as u32));
num_pages_written += lookup_reader.num_rows() as u32;
}
metadata.insert(RANGE_PARTITIONED_META_KEY.to_string(), "true".to_string());
metadata.insert(
PAGE_NUM_PER_RANGE_PARTITION_META_KEY.to_string(),
serde_json::to_string(&pages_per_file)?,
);
lookup_file.finish_with_metadata(metadata).await?;
cleanup_partition_files(store, part_lookup_files, &[]).await;
Ok(())
}
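/// Merges partitioned page data and lookup files into the final `page_data.lance`
/// and `page_lookup.lance` files.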
#[allow(clippy::too_many_arguments)]
async fn merge_pages_and_lookups(
store: &dyn IndexStore,
part_page_files: &[String],
part_lookup_files: &[String],
page_files_map: &HashMap<u64, &String>,
lookup_schema: Arc<Schema>,
metadata: HashMap<String, String>,
batch_size: u64,
    batch_readahead: Option<usize>,
) -> Result<()> {
let partition_id = extract_partition_id(part_lookup_files[0].as_str())?;
let page_file = page_files_map.get(&partition_id).unwrap();
let page_reader = store.open_index_file(page_file).await?;
let page_schema = page_reader.schema().clone();
let arrow_schema = Arc::new(Schema::from(&page_schema));
let mut page_file = store
.new_index_file(BTREE_PAGES_NAME, arrow_schema.clone())
.await?;
let lookup_entries = merge_pages(
part_lookup_files,
page_files_map,
store,
batch_size,
&mut page_file,
arrow_schema.clone(),
        batch_readahead,
)
.await?;
page_file.finish().await?;
let lookup_batch = RecordBatch::try_new(
lookup_schema.clone(),
vec![
ScalarValue::iter_to_array(lookup_entries.iter().map(|(min, _, _, _)| min.clone()))?,
ScalarValue::iter_to_array(lookup_entries.iter().map(|(_, max, _, _)| max.clone()))?,
Arc::new(UInt32Array::from_iter_values(
lookup_entries
.iter()
.map(|(_, _, null_count, _)| *null_count),
)),
Arc::new(UInt32Array::from_iter_values(
lookup_entries.iter().map(|(_, _, _, page_idx)| *page_idx),
)),
],
)?;
let mut lookup_file = store
.new_index_file(BTREE_LOOKUP_NAME, lookup_schema)
.await?;
lookup_file.write_record_batch(lookup_batch).await?;
lookup_file.finish_with_metadata(metadata).await?;
cleanup_partition_files(store, part_lookup_files, part_page_files).await;
Ok(())
}
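/// Rewrites the `page_idx` column of a lookup batch, shifting every page index by
/// `offset` so that partition-local indices become global.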
fn add_offset_to_page_idx(batch: &RecordBatch, offset: u32) -> Result<RecordBatch> {
let (page_idx_pos, _) = batch.schema().column_with_name("page_idx").ok_or_else(|| {
Error::internal("Column 'page_idx' not found in RecordBatch schema".to_string())
})?;
let page_idx_array = batch
.column(page_idx_pos)
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| {
Error::internal("Failed to downcast 'page_idx' column to UInt32Array".to_string())
})?;
let offset_array = UInt32Array::from(vec![offset; page_idx_array.len()]);
let new_page_idx_array_ref = add(page_idx_array, &offset_array)?;
let mut new_columns = batch.columns().to_vec();
new_columns[page_idx_pos] = new_page_idx_array_ref;
let new_batch = RecordBatch::try_new(batch.schema(), new_columns)?;
Ok(new_batch)
}
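/// Merges the sorted page data from every partition into a single page file using
/// a sort-preserving merge, returning the `(min, max, null_count, page_idx)`
/// lookup entries for the merged pages.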
async fn merge_pages(
part_lookup_files: &[String],
page_files_map: &HashMap<u64, &String>,
store: &dyn IndexStore,
batch_size: u64,
page_file: &mut Box<dyn IndexWriter>,
arrow_schema: Arc<Schema>,
    batch_readahead: Option<usize>,
) -> Result<Vec<(ScalarValue, ScalarValue, u32, u32)>> {
let mut lookup_entries = Vec::new();
let mut page_idx = 0u32;
debug!(
"Starting SortPreservingMerge with {} partitions",
part_lookup_files.len()
);
let value_field = arrow_schema.field(0).clone().with_name(VALUE_COLUMN_NAME);
let row_id_field = arrow_schema.field(1).clone().with_name(ROW_ID);
let stream_schema = Arc::new(Schema::new(vec![value_field, row_id_field]));
let mut inputs: Vec<Arc<dyn ExecutionPlan>> = Vec::new();
for lookup_file in part_lookup_files {
let partition_id = extract_partition_id(lookup_file)?;
let page_file_name = (*page_files_map.get(&partition_id).ok_or_else(|| {
Error::internal(format!(
"Page file not found for partition ID: {}",
partition_id
))
})?)
.clone();
let reader = store.open_index_file(&page_file_name).await?;
let reader_stream = IndexReaderStream::new(reader, batch_size).await;
let stream = reader_stream
.map(|fut| fut.map_err(DataFusionError::from))
            .buffered(batch_readahead.unwrap_or(1))
.boxed();
let sendable_stream =
Box::pin(RecordBatchStreamAdapter::new(stream_schema.clone(), stream));
inputs.push(Arc::new(OneShotExec::new(sendable_stream)));
}
let union_inputs = UnionExec::try_new(inputs)?;
let value_column_index = stream_schema.index_of(VALUE_COLUMN_NAME)?;
let sort_expr = PhysicalSortExpr {
expr: Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_index)),
options: SortOptions {
descending: false,
nulls_first: true,
},
};
let merge_exec = Arc::new(SortPreservingMergeExec::new(
[sort_expr].into(),
union_inputs,
));
let unchunked = execute_plan(
merge_exec,
LanceExecutionOptions {
use_spilling: false,
..Default::default()
},
)?;
let mut chunked_stream = chunk_concat_stream(unchunked, batch_size as usize);
while let Some(batch) = chunked_stream.try_next().await? {
let writer_batch = RecordBatch::try_new(
arrow_schema.clone(),
vec![batch.column(0).clone(), batch.column(1).clone()],
)?;
page_file.write_record_batch(writer_batch).await?;
let min_val = ScalarValue::try_from_array(batch.column(0), 0)?;
let max_val = ScalarValue::try_from_array(batch.column(0), batch.num_rows() - 1)?;
let null_count = batch.column(0).null_count() as u32;
lookup_entries.push((min_val, max_val, null_count, page_idx));
page_idx += 1;
}
Ok(lookup_entries)
}
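/// Sorts partition file names by their numeric partition id.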
fn sort_files_by_partition_id(part_files: &[String]) -> Result<Vec<(u64, String)>> {
let mut files_with_ids: Vec<(u64, &String)> = part_files
.iter()
.map(|file| extract_partition_id(file).map(|id| (id, file)))
.collect::<Result<Vec<_>>>()?;
files_with_ids.sort_unstable_by_key(|k| k.0);
let sorted_files = files_with_ids
.into_iter()
.map(|(id, file)| (id, file.clone()))
.collect();
Ok(sorted_files)
}
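/// Extracts the partition id from a partition file name, e.g.
/// `part_42_page_data.lance` yields `42`.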
fn extract_partition_id(filename: &str) -> Result<u64> {
if !filename.starts_with("part_") {
return Err(Error::internal(format!(
"Invalid partition file name format: {}",
filename
)));
}
let parts: Vec<&str> = filename.split('_').collect();
if parts.len() < 3 {
return Err(Error::internal(format!(
"Invalid partition file name format: {}",
filename
)));
}
parts[1].parse::<u64>().map_err(|_| {
Error::internal(format!(
"Failed to parse partition ID from filename: {}",
filename
))
})
}
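/// Best-effort deletion of partition files after a successful merge; failures are
/// logged but do not fail the merge.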
async fn cleanup_partition_files(
store: &dyn IndexStore,
part_lookup_files: &[String],
part_page_files: &[String],
) {
for file_name in part_lookup_files {
cleanup_single_file(
store,
file_name,
"part_",
"_page_lookup.lance",
"partition lookup",
)
.await;
}
for file_name in part_page_files {
cleanup_single_file(
store,
file_name,
"part_",
"_page_data.lance",
"partition page",
)
.await;
}
}
async fn cleanup_single_file(
store: &dyn IndexStore,
file_name: &str,
expected_prefix: &str,
expected_suffix: &str,
file_type: &str,
) {
if file_name.starts_with(expected_prefix) && file_name.ends_with(expected_suffix) {
match store.delete_index_file(file_name).await {
Ok(()) => {
debug!("Successfully deleted {} file: {}", file_type, file_name);
}
Err(e) => {
warn!(
"Failed to delete {} file '{}': {}. \
This does not affect the merge operation, but may leave \
partition files that should be cleaned up manually.",
file_type, file_name, e
);
}
}
} else {
warn!(
"Skipping deletion of file '{}' as it does not match the expected \
{} file pattern ({}*{})",
file_name, file_type, expected_prefix, expected_suffix
);
}
}
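/// Returns the page data file name for a partition, e.g. `part_42_page_data.lance`.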
pub(crate) fn part_page_data_file_path(partition_id: u64) -> String {
format!("part_{}_{}", partition_id, BTREE_PAGES_NAME)
}
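/// Returns the lookup file name for a partition, e.g. `part_42_page_lookup.lance`.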
pub(crate) fn part_lookup_file_path(partition_id: u64) -> String {
format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME)
}
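/// A stream of read futures, one per batch in an index file, which can be buffered
/// to read batches with bounded parallelism.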
struct IndexReaderStream {
reader: Arc<dyn IndexReader>,
batch_size: u64,
num_batches: u32,
batch_idx: u32,
}
impl IndexReaderStream {
async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
let num_batches = reader.num_batches(batch_size).await;
Self {
reader,
batch_size,
num_batches,
batch_idx: 0,
}
}
}
impl Stream for IndexReaderStream {
type Item = BoxFuture<'static, Result<RecordBatch>>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.batch_idx >= this.num_batches {
return std::task::Poll::Ready(None);
}
let batch_num = this.batch_idx;
this.batch_idx += 1;
let reader_copy = this.reader.clone();
let batch_size = this.batch_size;
let read_task = async move {
reader_copy
.read_record_batch(batch_num as u64, batch_size)
.await
}
.boxed();
std::task::Poll::Ready(Some(read_task))
}
}
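/// User-facing parameters for building a B-tree index, parsed from a JSON string
/// such as `{"zone_size": 4096}`.
///
/// `zone_size` is the number of rows per page (defaults to
/// [`DEFAULT_BTREE_BATCH_SIZE`]); `range_id` identifies the range partition when an
/// index is trained in partitioned pieces.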
#[derive(Debug, Serialize, Deserialize)]
pub struct BTreeParameters {
pub zone_size: Option<u64>,
pub range_id: Option<u32>,
}
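/// A [`TrainingRequest`] for the B-tree: training data must be ordered by value
/// and include row ids.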
struct BTreeTrainingRequest {
parameters: BTreeParameters,
criteria: TrainingCriteria,
}
impl BTreeTrainingRequest {
pub fn new(parameters: BTreeParameters) -> Self {
Self {
parameters,
criteria: TrainingCriteria::new(TrainingOrdering::Values).with_row_id(),
}
}
}
impl TrainingRequest for BTreeTrainingRequest {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn criteria(&self) -> &TrainingCriteria {
&self.criteria
}
}
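/// The [`ScalarIndexPlugin`] implementation registering the B-tree index type.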
#[derive(Debug, Default)]
pub struct BTreeIndexPlugin;
#[async_trait]
impl ScalarIndexPlugin for BTreeIndexPlugin {
fn name(&self) -> &str {
"BTree"
}
fn new_training_request(
&self,
params: &str,
field: &Field,
) -> Result<Box<dyn TrainingRequest>> {
if field.data_type().is_nested() {
return Err(Error::invalid_input_source(
"A btree index can only be created on a non-nested field.".into(),
));
}
let params = serde_json::from_str::<BTreeParameters>(params)?;
Ok(Box::new(BTreeTrainingRequest::new(params)))
}
fn provides_exact_answer(&self) -> bool {
true
}
fn version(&self) -> u32 {
BTREE_INDEX_VERSION
}
fn new_query_parser(
&self,
index_name: String,
_index_details: &prost_types::Any,
) -> Option<Box<dyn ScalarQueryParser>> {
Some(Box::new(SargableQueryParser::new(index_name, false)))
}
async fn train_index(
&self,
data: SendableRecordBatchStream,
index_store: &dyn IndexStore,
request: Box<dyn TrainingRequest>,
fragment_ids: Option<Vec<u32>>,
_progress: Arc<dyn crate::progress::IndexBuildProgress>,
) -> Result<CreatedIndex> {
let request = request
.as_any()
.downcast_ref::<BTreeTrainingRequest>()
.unwrap();
train_btree_index(
data,
index_store,
request
.parameters
.zone_size
.unwrap_or(DEFAULT_BTREE_BATCH_SIZE),
fragment_ids,
request.parameters.range_id,
)
.await?;
Ok(CreatedIndex {
index_details: prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default())
.unwrap(),
index_version: BTREE_INDEX_VERSION,
files: Some(index_store.list_files_with_sizes().await?),
})
}
async fn load_index(
&self,
index_store: Arc<dyn IndexStore>,
_index_details: &prost_types::Any,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
cache: &LanceCache,
) -> Result<Arc<dyn ScalarIndex>> {
Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::Ordering;
use std::{collections::HashMap, sync::Arc};
use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type};
use arrow_array::{FixedSizeListArray, record_batch};
use datafusion::{
execution::{SendableRecordBatchStream, TaskContext},
physical_plan::{ExecutionPlan, sorts::sort::SortExec, stream::RecordBatchStreamAdapter},
};
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};
use deepsize::DeepSizeOf;
use futures::TryStreamExt;
use futures::stream;
use lance_core::utils::mask::RowSetOps;
use lance_core::utils::tempfile::TempObjDir;
use lance_core::{cache::LanceCache, utils::mask::RowAddrTreeMap};
use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt};
use lance_datagen::{ArrayGeneratorExt, BatchCount, RowCount, array, gen_batch};
use lance_io::object_store::ObjectStore;
use object_store::path::Path;
use crate::metrics::LocalMetricsCollector;
use crate::{
metrics::NoOpMetricsCollector,
scalar::{
IndexStore, OldIndexDataFilter, SargableQuery, ScalarIndex, SearchResult,
btree::{BTREE_PAGES_NAME, BTreeIndex},
lance_format::LanceIndexStore,
},
};
use super::{
DEFAULT_BTREE_BATCH_SIZE, OrderableScalarValue, part_lookup_file_path,
part_page_data_file_path, train_btree_index,
};
#[test]
fn test_scalar_value_size() {
let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of();
let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new(
FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
vec![Some(vec![Some(0); 128])],
128,
),
)))
.deep_size_of();
assert!(size_of_i32 > 4);
assert!(size_of_many_i32 > 128 * 4);
}
#[tokio::test]
async fn test_null_ids() {
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let stream = gen_batch()
.col(
"value",
array::rand::<Float32Type>().with_nulls(&[true, false, false, false, false]),
)
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(5000), BatchCount::from(10));
train_btree_index(stream, test_store.as_ref(), 5000, None, None)
.await
.unwrap();
let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
assert_eq!(index.page_lookup.null_pages.len(), 10);
let remap_dir = TempObjDir::default();
let remap_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
remap_dir.clone(),
Arc::new(LanceCache::no_cache()),
));
index
.remap(&HashMap::default(), remap_store.as_ref())
.await
.unwrap();
let remap_index = BTreeIndex::load(remap_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
assert_eq!(remap_index.page_lookup, index.page_lookup);
let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
assert_eq!(original_pages.num_rows(), remapped_pages.num_rows());
let original_data = original_pages
.read_record_batch(0, original_pages.num_rows() as u64)
.await
.unwrap();
let remapped_data = remapped_pages
.read_record_batch(0, remapped_pages.num_rows() as u64)
.await
.unwrap();
assert_eq!(original_data, remapped_data);
}
#[tokio::test]
async fn test_nan_ordering() {
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let values = vec![
0.0,
1.0,
2.0,
3.0,
f64::NAN,
f64::NEG_INFINITY,
f64::INFINITY,
];
let data = gen_batch()
.col("value", array::cycle::<Float64Type>(values.clone()))
.col("_rowid", array::step::<UInt64Type>())
.into_df_exec(RowCount::from(10), BatchCount::from(100));
let schema = data.schema();
let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
let plan = Arc::new(SortExec::new([sort_expr].into(), data));
let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
let stream = break_stream(stream, 64);
let stream = stream.map_err(DataFusionError::from);
let stream =
Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
train_btree_index(stream, test_store.as_ref(), 64, None, None)
.await
.unwrap();
let index = BTreeIndex::load(test_store, None, &LanceCache::no_cache())
.await
.unwrap();
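// Each of the 7 cycled values occupies every 7th row of the 1000 generated
// rows, so an equality query should return exactly rows idx, idx + 7, ...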
for (idx, value) in values.into_iter().enumerate() {
let query = SargableQuery::Equals(ScalarValue::Float64(Some(value)));
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
assert_eq!(
result,
SearchResult::exact(RowAddrTreeMap::from_iter(((idx as u64)..1000).step_by(7)))
);
}
}
#[tokio::test]
async fn test_page_cache() {
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let data = gen_batch()
.col("value", array::step::<Float32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_exec(RowCount::from(1000), BatchCount::from(10));
let schema = data.schema();
let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
let plan = Arc::new(SortExec::new([sort_expr].into(), data));
let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
let stream = break_stream(stream, 64);
let stream = stream.map_err(DataFusionError::from);
let stream =
Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
train_btree_index(stream, test_store.as_ref(), 64, None, None)
.await
.unwrap();
let cache = Arc::new(LanceCache::with_capacity(100 * 1024 * 1024));
let index = BTreeIndex::load(test_store, None, cache.as_ref())
.await
.unwrap();
let query = SargableQuery::Equals(ScalarValue::Float32(Some(0.0)));
let metrics = LocalMetricsCollector::default();
let query1 = index.search(&query, &metrics);
let query2 = index.search(&query, &metrics);
tokio::join!(query1, query2).0.unwrap();
assert_eq!(metrics.parts_loaded.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_like_prefix_search() {
use arrow::datatypes::DataType;
use arrow_array::StringArray;
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let values = vec![
"apple",
"app",
"application",
"banana",
"band",
"test_ns$table1",
"test_ns$table2",
"test_ns2$table1",
"test",
"testing",
];
let row_ids: Vec<u64> = (0..values.len() as u64).collect();
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("value", DataType::Utf8, false),
arrow::datatypes::Field::new("_rowid", DataType::UInt64, false),
]));
let batch = arrow::record_batch::RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(values.clone())),
Arc::new(arrow_array::UInt64Array::from(row_ids)),
],
)
.unwrap();
let stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::once(async { Ok(batch) }),
));
train_btree_index(stream, test_store.as_ref(), 100, None, None)
.await
.unwrap();
let index = BTreeIndex::load(test_store, None, &LanceCache::no_cache())
.await
.unwrap();
let query = SargableQuery::LikePrefix(ScalarValue::Utf8(Some("app".to_string())));
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match &result {
SearchResult::Exact(row_ids) => {
let ids: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
assert!(ids.contains(&0), "Should contain row 0 (apple)");
assert!(ids.contains(&1), "Should contain row 1 (app)");
assert!(ids.contains(&2), "Should contain row 2 (application)");
assert!(!ids.contains(&3), "Should not contain row 3 (banana)");
}
_ => panic!("Expected Exact result"),
}
let query = SargableQuery::LikePrefix(ScalarValue::Utf8(Some("test_ns$".to_string())));
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match &result {
SearchResult::Exact(row_ids) => {
let ids: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
assert!(ids.contains(&5), "Should contain row 5 (test_ns$table1)");
assert!(ids.contains(&6), "Should contain row 6 (test_ns$table2)");
assert!(
!ids.contains(&7),
"Should not contain row 7 (test_ns2$table1)"
);
}
_ => panic!("Expected Exact result"),
}
let query = SargableQuery::LikePrefix(ScalarValue::Utf8(Some("test".to_string())));
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match &result {
SearchResult::Exact(row_ids) => {
let ids: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
assert!(
ids.contains(&5),
"Should contain row 5 (test_ns$table1): {:?}",
ids
);
assert!(
ids.contains(&6),
"Should contain row 6 (test_ns$table2): {:?}",
ids
);
assert!(
ids.contains(&7),
"Should contain row 7 (test_ns2$table1): {:?}",
ids
);
assert!(ids.contains(&8), "Should contain row 8 (test): {:?}", ids);
assert!(
ids.contains(&9),
"Should contain row 9 (testing): {:?}",
ids
);
}
_ => panic!("Expected Exact result"),
}
}
#[tokio::test]
async fn test_like_prefix_search_large_utf8() {
use arrow::datatypes::DataType;
use arrow_array::LargeStringArray;
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let values = vec!["apple", "app", "application", "banana"];
let row_ids: Vec<u64> = (0..values.len() as u64).collect();
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("value", DataType::LargeUtf8, false),
arrow::datatypes::Field::new("_rowid", DataType::UInt64, false),
]));
let batch = arrow::record_batch::RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(LargeStringArray::from(values)),
Arc::new(arrow_array::UInt64Array::from(row_ids)),
],
)
.unwrap();
let stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::once(async { Ok(batch) }),
));
train_btree_index(stream, test_store.as_ref(), 100, None, None)
.await
.unwrap();
let index = BTreeIndex::load(test_store, None, &LanceCache::no_cache())
.await
.unwrap();
let query = SargableQuery::LikePrefix(ScalarValue::LargeUtf8(Some("app".to_string())));
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match &result {
SearchResult::Exact(row_ids) => {
let ids: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
assert!(ids.contains(&0), "Should contain row 0 (apple)");
assert!(ids.contains(&1), "Should contain row 1 (app)");
assert!(ids.contains(&2), "Should contain row 2 (application)");
assert!(!ids.contains(&3), "Should not contain row 3 (banana)");
}
_ => panic!("Expected Exact result"),
}
}
#[tokio::test]
async fn test_fragment_btree_index_consistency() {
let full_tmpdir = TempObjDir::default();
let full_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
full_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let fragment_tmpdir = TempObjDir::default();
let fragment_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
fragment_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let total_count = 2 * DEFAULT_BTREE_BATCH_SIZE;
let full_data_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(total_count / 2), BatchCount::from(2));
let full_data_source = Box::pin(RecordBatchStreamAdapter::new(
full_data_gen.schema(),
full_data_gen,
));
train_btree_index(
full_data_source,
full_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
None,
)
.await
.unwrap();
let half_count = DEFAULT_BTREE_BATCH_SIZE;
let fragment1_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(half_count), BatchCount::from(1));
let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new(
fragment1_gen.schema(),
fragment1_gen,
));
train_btree_index(
fragment1_data_source,
fragment_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
Some(vec![1]),
None,
)
.await
.unwrap();
let start_val = DEFAULT_BTREE_BATCH_SIZE as i32;
let end_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let values_second_half: Vec<i32> = (start_val..end_val).collect();
let row_ids_second_half: Vec<u64> = (start_val as u64..end_val as u64).collect();
let fragment2_gen = gen_batch()
.col("value", array::cycle::<Int32Type>(values_second_half))
.col("_rowid", array::cycle::<UInt64Type>(row_ids_second_half))
.into_df_stream(RowCount::from(half_count), BatchCount::from(1));
let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new(
fragment2_gen.schema(),
fragment2_gen,
));
train_btree_index(
fragment2_data_source,
fragment_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
Some(vec![2]),
None,
)
.await
.unwrap();
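// Partition IDs here are the fragment IDs shifted into the high 32 bits,
// matching how the training calls above tagged fragments 1 and 2.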
let part_page_files = vec![
part_page_data_file_path(1 << 32),
part_page_data_file_path(2 << 32),
];
let part_lookup_files = vec![
part_lookup_file_path(1 << 32),
part_lookup_file_path(2 << 32),
];
super::merge_metadata_files(
fragment_store.as_ref(),
&part_page_files,
&part_lookup_files,
Option::from(1usize),
)
.await
.unwrap();
let full_index = BTreeIndex::load(full_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let merged_index = BTreeIndex::load(fragment_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0)));
let full_result_0 = full_index
.search(&query_0, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_0 = merged_index
.search(&query_0, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(full_result_0, merged_result_0, "Query for value 0 failed");
let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch)));
let full_result_mid_first = full_index
.search(&query_mid_first, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_mid_first = merged_index
.search(&query_mid_first, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_mid_first, merged_result_mid_first,
"Query for value {} failed",
mid_first_batch
);
let first_second_batch = DEFAULT_BTREE_BATCH_SIZE as i32;
let query_first_second =
SargableQuery::Equals(ScalarValue::Int32(Some(first_second_batch)));
let full_result_first_second = full_index
.search(&query_first_second, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_first_second = merged_index
.search(&query_first_second, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_first_second, merged_result_first_second,
"Query for value {} failed",
first_second_batch
);
let mid_second_batch = (DEFAULT_BTREE_BATCH_SIZE + DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let query_mid_second = SargableQuery::Equals(ScalarValue::Int32(Some(mid_second_batch)));
let full_result_mid_second = full_index
.search(&query_mid_second, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_mid_second = merged_index
.search(&query_mid_second, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_mid_second, merged_result_mid_second,
"Query for value {} failed",
mid_second_batch
);
}
#[tokio::test]
async fn test_fragment_btree_index_boundary_queries() {
let full_tmpdir = TempObjDir::default();
let full_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
full_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let fragment_tmpdir = TempObjDir::default();
let fragment_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
fragment_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let total_count = 3 * DEFAULT_BTREE_BATCH_SIZE;
let full_data_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(total_count / 3), BatchCount::from(3));
let full_data_source = Box::pin(RecordBatchStreamAdapter::new(
full_data_gen.schema(),
full_data_gen,
));
train_btree_index(
full_data_source,
full_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
None,
)
.await
.unwrap();
let fragment_size = DEFAULT_BTREE_BATCH_SIZE;
let fragment1_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(fragment_size), BatchCount::from(1));
let fragment1_data_source = Box::pin(RecordBatchStreamAdapter::new(
fragment1_gen.schema(),
fragment1_gen,
));
train_btree_index(
fragment1_data_source,
fragment_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
Some(vec![1]),
None,
)
.await
.unwrap();
let start_val2 = DEFAULT_BTREE_BATCH_SIZE as i32;
let end_val2 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let values_fragment2: Vec<i32> = (start_val2..end_val2).collect();
let row_ids_fragment2: Vec<u64> = (start_val2 as u64..end_val2 as u64).collect();
let fragment2_gen = gen_batch()
.col("value", array::cycle::<Int32Type>(values_fragment2))
.col("_rowid", array::cycle::<UInt64Type>(row_ids_fragment2))
.into_df_stream(RowCount::from(fragment_size), BatchCount::from(1));
let fragment2_data_source = Box::pin(RecordBatchStreamAdapter::new(
fragment2_gen.schema(),
fragment2_gen,
));
train_btree_index(
fragment2_data_source,
fragment_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
Some(vec![2]),
None,
)
.await
.unwrap();
let start_val3 = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let end_val3 = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let values_fragment3: Vec<i32> = (start_val3..end_val3).collect();
let row_ids_fragment3: Vec<u64> = (start_val3 as u64..end_val3 as u64).collect();
let fragment3_gen = gen_batch()
.col("value", array::cycle::<Int32Type>(values_fragment3))
.col("_rowid", array::cycle::<UInt64Type>(row_ids_fragment3))
.into_df_stream(RowCount::from(fragment_size), BatchCount::from(1));
let fragment3_data_source = Box::pin(RecordBatchStreamAdapter::new(
fragment3_gen.schema(),
fragment3_gen,
));
train_btree_index(
fragment3_data_source,
fragment_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
Some(vec![3]),
None,
)
.await
.unwrap();
let part_page_files = vec![
part_page_data_file_path(1 << 32),
part_page_data_file_path(2 << 32),
part_page_data_file_path(3 << 32),
];
let part_lookup_files = vec![
part_lookup_file_path(1 << 32),
part_lookup_file_path(2 << 32),
part_lookup_file_path(3 << 32),
];
super::merge_metadata_files(
fragment_store.as_ref(),
&part_page_files,
&part_lookup_files,
Option::from(1usize),
)
.await
.unwrap();
let full_index = BTreeIndex::load(full_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let merged_index = BTreeIndex::load(fragment_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let query_min = SargableQuery::Equals(ScalarValue::Int32(Some(0)));
let full_result_min = full_index
.search(&query_min, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_min = merged_index
.search(&query_min, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_min, merged_result_min,
"Query for minimum value 0 failed"
);
let max_val = (3 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32;
let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val)));
let full_result_max = full_index
.search(&query_max, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_max = merged_index
.search(&query_max, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_max, merged_result_max,
"Query for maximum value {} failed",
max_val
);
let fragment1_last = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32;
let query_frag1_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment1_last)));
let full_result_frag1_last = full_index
.search(&query_frag1_last, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_frag1_last = merged_index
.search(&query_frag1_last, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_frag1_last, merged_result_frag1_last,
"Query for fragment 1 last value {} failed",
fragment1_last
);
let fragment2_first = DEFAULT_BTREE_BATCH_SIZE as i32;
let query_frag2_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_first)));
let full_result_frag2_first = full_index
.search(&query_frag2_first, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_frag2_first = merged_index
.search(&query_frag2_first, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_frag2_first, merged_result_frag2_first,
"Query for fragment 2 first value {} failed",
fragment2_first
);
let fragment2_last = (2 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32;
let query_frag2_last = SargableQuery::Equals(ScalarValue::Int32(Some(fragment2_last)));
let full_result_frag2_last = full_index
.search(&query_frag2_last, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_frag2_last = merged_index
.search(&query_frag2_last, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_frag2_last, merged_result_frag2_last,
"Query for fragment 2 last value {} failed",
fragment2_last
);
let fragment3_first = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let query_frag3_first = SargableQuery::Equals(ScalarValue::Int32(Some(fragment3_first)));
let full_result_frag3_first = full_index
.search(&query_frag3_first, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_frag3_first = merged_index
.search(&query_frag3_first, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_frag3_first, merged_result_frag3_first,
"Query for fragment 3 first value {} failed",
fragment3_first
);
let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1)));
let full_result_below = full_index
.search(&query_below_min, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_below = merged_index
.search(&query_below_min, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_below, merged_result_below,
"Query for value below minimum (-1) failed"
);
let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1)));
let full_result_above = full_index
.search(&query_above_max, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_above = merged_index
.search(&query_above_max, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_above,
merged_result_above,
"Query for value above maximum ({}) failed",
max_val + 1
);
let range_start = (DEFAULT_BTREE_BATCH_SIZE - 100) as i32;
let range_end = (DEFAULT_BTREE_BATCH_SIZE + 100) as i32;
let query_cross_frag = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))),
);
let full_result_cross = full_index
.search(&query_cross_frag, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_cross = merged_index
.search(&query_cross_frag, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_cross, merged_result_cross,
"Cross-fragment range query [{}, {}] failed",
range_start, range_end
);
let single_frag_start = 100i32;
let single_frag_end = 200i32;
let query_single_frag = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(single_frag_start))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_frag_end))),
);
let full_result_single = full_index
.search(&query_single_frag, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_single = merged_index
.search(&query_single_frag, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_single, merged_result_single,
"Single fragment range query [{}, {}] failed",
single_frag_start, single_frag_end
);
let large_range_start = 100i32;
let large_range_end = (3 * DEFAULT_BTREE_BATCH_SIZE - 100) as i32;
let query_large_range = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))),
);
let full_result_large = full_index
.search(&query_large_range, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_large = merged_index
.search(&query_large_range, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_large, merged_result_large,
"Large range query [{}, {}] failed",
large_range_start, large_range_end
);
let lt_val = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let query_lt = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(0))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(lt_val))),
);
let full_result_lt = full_index
.search(&query_lt, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_lt = merged_index
.search(&query_lt, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_lt, merged_result_lt,
"Less than query (<{}) failed",
lt_val
);
let gt_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let max_range_val = (3 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let query_gt = SargableQuery::Range(
std::collections::Bound::Excluded(ScalarValue::Int32(Some(gt_val))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))),
);
let full_result_gt = full_index
.search(&query_gt, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_gt = merged_index
.search(&query_gt, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_gt, merged_result_gt,
"Greater than query (>{}) failed",
gt_val
);
let lte_val = (DEFAULT_BTREE_BATCH_SIZE - 1) as i32;
let query_lte = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(0))),
std::collections::Bound::Included(ScalarValue::Int32(Some(lte_val))),
);
let full_result_lte = full_index
.search(&query_lte, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_lte = merged_index
.search(&query_lte, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_lte, merged_result_lte,
"Less than or equal query (<={}) failed",
lte_val
);
let gte_val = (2 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let query_gte = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(gte_val))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(max_range_val))),
);
let full_result_gte = full_index
.search(&query_gte, &NoOpMetricsCollector)
.await
.unwrap();
let merged_result_gte = merged_index
.search(&query_gte, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_gte, merged_result_gte,
"Greater than or equal query (>={}) failed",
gte_val
);
}
#[test]
fn test_extract_partition_id() {
assert_eq!(
super::extract_partition_id("part_123_page_data.lance").unwrap(),
123
);
assert_eq!(
super::extract_partition_id("part_456_page_lookup.lance").unwrap(),
456
);
assert_eq!(
super::extract_partition_id("part_4294967296_page_data.lance").unwrap(),
4294967296
);
assert!(super::extract_partition_id("invalid_filename.lance").is_err());
assert!(super::extract_partition_id("part_abc_page_data.lance").is_err());
assert!(super::extract_partition_id("part_123").is_err());
assert!(super::extract_partition_id("part_").is_err());
}
#[tokio::test]
async fn test_cleanup_partition_files() {
let tmpdir = TempObjDir::default();
let test_store: Arc<dyn crate::scalar::IndexStore> = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let lookup_files = vec![
"part_123_page_lookup.lance".to_string(),
"invalid_lookup_file.lance".to_string(),
"part_456_page_lookup.lance".to_string(),
];
let page_files = vec![
"part_123_page_data.lance".to_string(),
"invalid_page_file.lance".to_string(),
"part_456_page_data.lance".to_string(),
];
super::cleanup_partition_files(test_store.as_ref(), &lookup_files, &page_files).await;
}
#[tokio::test]
async fn test_btree_null_handling_in_queries() {
let store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::memory()),
Path::default(),
Arc::new(LanceCache::no_cache()),
));
let batch = record_batch!(
("value", Int32, [None, Some(0), Some(5)]),
("_rowid", UInt64, [0, 1, 2])
)
.unwrap();
let stream = stream::once(futures::future::ok(batch.clone()));
let stream = Box::pin(RecordBatchStreamAdapter::new(batch.schema(), stream));
super::train_btree_index(stream, store.as_ref(), 256, None, None)
.await
.unwrap();
let cache = LanceCache::with_capacity(1024 * 1024);
let index = super::BTreeIndex::load(store.clone(), None, &cache)
.await
.unwrap();
let query = SargableQuery::Equals(ScalarValue::Int32(Some(5)));
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match result {
SearchResult::Exact(row_ids) => {
let actual_rows: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
assert_eq!(actual_rows, vec![2], "Should find row 2 where value == 5");
let null_row_ids = row_ids.null_rows();
assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty");
let null_rows: Vec<u64> =
null_row_ids.row_addrs().unwrap().map(u64::from).collect();
assert_eq!(null_rows, vec![0], "Should report row 0 as null");
}
_ => panic!("Expected Exact search result"),
}
let query = SargableQuery::Range(
std::ops::Bound::Included(ScalarValue::Int32(Some(0))),
std::ops::Bound::Included(ScalarValue::Int32(Some(3))),
);
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match result {
SearchResult::Exact(row_ids) => {
let actual_rows: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
assert_eq!(actual_rows, vec![1], "Should find row 1 where value == 0");
let null_row_ids = row_ids.null_rows();
assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty");
let null_rows: Vec<u64> =
null_row_ids.row_addrs().unwrap().map(u64::from).collect();
assert_eq!(null_rows, vec![0], "Should report row 0 as null");
}
_ => panic!("Expected Exact search result"),
}
let query = SargableQuery::IsIn(vec![
ScalarValue::Int32(Some(0)),
ScalarValue::Int32(Some(5)),
]);
let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
match result {
SearchResult::Exact(row_ids) => {
let mut actual_rows: Vec<u64> = row_ids
.true_rows()
.row_addrs()
.unwrap()
.map(u64::from)
.collect();
actual_rows.sort();
assert_eq!(
actual_rows,
vec![1, 2],
"Should find rows 1 and 2 where value in [0, 5]"
);
let null_row_ids = row_ids.null_rows();
assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty");
let null_rows: Vec<u64> =
null_row_ids.row_addrs().unwrap().map(u64::from).collect();
assert_eq!(null_rows, vec![0], "Should report row 0 as null");
}
_ => panic!("Expected Exact search result"),
}
}
#[tokio::test]
async fn test_range_btree_index_consistency() {
let full_tmpdir = TempObjDir::default();
let full_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
full_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let range_tmpdir = TempObjDir::default();
let range_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
range_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let total_count = 4 * DEFAULT_BTREE_BATCH_SIZE;
let full_data_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(total_count / 4), BatchCount::from(4));
let full_data_source = Box::pin(RecordBatchStreamAdapter::new(
full_data_gen.schema(),
full_data_gen,
));
train_btree_index(
full_data_source,
full_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
None,
)
.await
.unwrap();
let range1_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(
RowCount::from(DEFAULT_BTREE_BATCH_SIZE / 2),
BatchCount::from(5),
);
let range1_data_source = Box::pin(RecordBatchStreamAdapter::new(
range1_gen.schema(),
range1_gen,
));
train_btree_index(
range1_data_source,
range_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
Option::from(0u32),
)
.await
.unwrap();
let start_val = (DEFAULT_BTREE_BATCH_SIZE * 2 + DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let end_val = (4 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let values_second_half: Vec<i32> = (start_val..end_val).collect();
let row_ids_second_half: Vec<u64> = (start_val as u64..end_val as u64).collect();
let range2_gen = gen_batch()
.col("value", array::cycle::<Int32Type>(values_second_half))
.col("_rowid", array::cycle::<UInt64Type>(row_ids_second_half))
.into_df_stream(
RowCount::from(DEFAULT_BTREE_BATCH_SIZE / 2),
BatchCount::from(3),
);
let range2_data_source = Box::pin(RecordBatchStreamAdapter::new(
range2_gen.schema(),
range2_gen,
));
train_btree_index(
range2_data_source,
range_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
Option::from(1u32),
)
.await
.unwrap();
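// For ranged training the partition ID is the range_id shifted into the
// high 32 bits (ranges 0 and 1 above), mirroring the fragment case.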
let part_page_files = vec![
part_page_data_file_path(0 << 32),
part_page_data_file_path(1 << 32),
];
let part_lookup_files = vec![
part_lookup_file_path(0 << 32),
part_lookup_file_path(1 << 32),
];
super::merge_metadata_files(
range_store.as_ref(),
&part_page_files,
&part_lookup_files,
Option::from(1usize),
)
.await
.unwrap();
let full_index = BTreeIndex::load(full_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let ranged_index = BTreeIndex::load(range_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let query_0 = SargableQuery::Equals(ScalarValue::Int32(Some(0)));
let full_result_0 = full_index
.search(&query_0, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_0 = ranged_index
.search(&query_0, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(full_result_0, ranged_result_0, "Query for value 0 failed");
let mid_first_batch = (DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let query_mid_first = SargableQuery::Equals(ScalarValue::Int32(Some(mid_first_batch)));
let full_result_mid_first = full_index
.search(&query_mid_first, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_mid_first = ranged_index
.search(&query_mid_first, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_mid_first, ranged_result_mid_first,
"Query for value {} failed",
mid_first_batch
);
let mid_last_batch = (DEFAULT_BTREE_BATCH_SIZE * 3 + (DEFAULT_BTREE_BATCH_SIZE / 2)) as i32;
let query_mid_last = SargableQuery::Equals(ScalarValue::Int32(Some(mid_last_batch)));
let full_result_mid_last = full_index
.search(&query_mid_last, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_mid_last = ranged_index
.search(&query_mid_last, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_mid_last, ranged_result_mid_last,
"Query for value {} failed",
mid_last_batch
);
let max_val = (4 * DEFAULT_BTREE_BATCH_SIZE - 1) as i32;
let query_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val)));
let full_result_max = full_index
.search(&query_max, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_max = ranged_index
.search(&query_max, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_max, ranged_result_max,
"Query for maximum value {} failed",
max_val
);
let second_first_val = (DEFAULT_BTREE_BATCH_SIZE * 2 + DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let query_second_first = SargableQuery::Equals(ScalarValue::Int32(Some(second_first_val)));
let full_result_second_first = full_index
.search(&query_second_first, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_second_first = ranged_index
.search(&query_second_first, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_second_first, ranged_result_second_first,
"Query for first value of the second page file {} failed",
second_first_val
);
let query_below_min = SargableQuery::Equals(ScalarValue::Int32(Some(-1)));
let full_result_below = full_index
.search(&query_below_min, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_below = ranged_index
.search(&query_below_min, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_below, ranged_result_below,
"Query for value below minimum (-1) failed"
);
let query_above_max = SargableQuery::Equals(ScalarValue::Int32(Some(max_val + 1)));
let full_result_above = full_index
.search(&query_above_max, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_above = ranged_index
.search(&query_above_max, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_above,
ranged_result_above,
"Query for value above maximum ({}) failed",
max_val + 1
);
let range_start =
(DEFAULT_BTREE_BATCH_SIZE * 2 + DEFAULT_BTREE_BATCH_SIZE / 2 - 100) as i32;
let range_end = range_start + 200;
let query_cross_range = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(range_start))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(range_end))),
);
let full_result_cross = full_index
.search(&query_cross_range, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_cross = ranged_index
.search(&query_cross_range, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_cross, ranged_result_cross,
"Cross-range range query [{}, {}] failed",
range_start, range_end
);
let single_range_start = (DEFAULT_BTREE_BATCH_SIZE * 4 - 300) as i32;
let single_range_end = single_range_start + 200;
let query_single_range = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(single_range_start))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(single_range_end))),
);
let full_result_single = full_index
.search(&query_single_range, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_single = ranged_index
.search(&query_single_range, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_single, ranged_result_single,
"Single range query [{}, {}] failed",
single_range_start, single_range_end
);
let large_range_start = 100_i32;
let large_range_end = (DEFAULT_BTREE_BATCH_SIZE * 4 - 100) as i32;
let query_large_range = SargableQuery::Range(
std::collections::Bound::Included(ScalarValue::Int32(Some(large_range_start))),
std::collections::Bound::Excluded(ScalarValue::Int32(Some(large_range_end))),
);
let full_result_large = full_index
.search(&query_large_range, &NoOpMetricsCollector)
.await
.unwrap();
let ranged_result_large = ranged_index
.search(&query_large_range, &NoOpMetricsCollector)
.await
.unwrap();
assert_eq!(
full_result_large, ranged_result_large,
"Large range query [{}, {}] failed",
large_range_start, large_range_end
);
let remap_dir = TempObjDir::default();
let remap_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
remap_dir.clone(),
Arc::new(LanceCache::no_cache()),
));
ranged_index
.remap(&HashMap::default(), remap_store.as_ref())
.await
.unwrap();
let remap_index = BTreeIndex::load(remap_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
assert_eq!(remap_index.page_lookup, ranged_index.page_lookup);
let ranged_pages = range_store
.open_index_file(part_page_data_file_path(1 << 32).as_str())
.await
.unwrap();
let remapped_pages = remap_store
.open_index_file(part_page_data_file_path(1 << 32).as_str())
.await
.unwrap();
assert_eq!(ranged_pages.num_rows(), remapped_pages.num_rows());
let original_data = ranged_pages
.read_record_batch(0, ranged_pages.num_rows() as u64)
.await
.unwrap();
let remapped_data = remapped_pages
.read_record_batch(0, remapped_pages.num_rows() as u64)
.await
.unwrap();
assert_eq!(original_data, remapped_data);
}
#[tokio::test]
async fn test_update_ranged_index() {
let old_tmpdir = TempObjDir::default();
let old_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
old_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let new_tmpdir = TempObjDir::default();
let new_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
new_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let range1_gen = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(
RowCount::from(DEFAULT_BTREE_BATCH_SIZE / 2),
BatchCount::from(5),
);
let range1_data_source = Box::pin(RecordBatchStreamAdapter::new(
range1_gen.schema(),
range1_gen,
));
train_btree_index(
range1_data_source,
old_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
Option::from(1u32),
)
.await
.unwrap();
let start_val = (DEFAULT_BTREE_BATCH_SIZE * 2 + DEFAULT_BTREE_BATCH_SIZE / 2) as i32;
let end_val = (4 * DEFAULT_BTREE_BATCH_SIZE) as i32;
let values_second_half: Vec<i32> = (start_val..end_val).collect();
let row_ids_second_half: Vec<u64> = (start_val as u64..end_val as u64).collect();
let range2_gen = gen_batch()
.col("value", array::cycle::<Int32Type>(values_second_half))
.col("_rowid", array::cycle::<UInt64Type>(row_ids_second_half))
.into_df_stream(
RowCount::from(DEFAULT_BTREE_BATCH_SIZE / 2),
BatchCount::from(3),
);
let range2_data_source = Box::pin(RecordBatchStreamAdapter::new(
range2_gen.schema(),
range2_gen,
));
train_btree_index(
range2_data_source,
old_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
Option::from(2u32),
)
.await
.unwrap();
let part_page_files = vec![
part_page_data_file_path(1 << 32),
part_page_data_file_path(2 << 32),
];
let part_lookup_files = vec![
part_lookup_file_path(1 << 32),
part_lookup_file_path(2 << 32),
];
super::merge_metadata_files(
old_store.as_ref(),
&part_page_files,
&part_lookup_files,
Option::from(1usize),
)
.await
.unwrap();
let start_val = (DEFAULT_BTREE_BATCH_SIZE * 2) as i32;
let end_val = (DEFAULT_BTREE_BATCH_SIZE * 3) as i32;
let row_id_delta = (DEFAULT_BTREE_BATCH_SIZE * 3) as i32;
let values: Vec<i32> = (start_val..end_val).collect();
let row_ids: Vec<u64> =
((start_val + row_id_delta) as u64..(end_val + row_id_delta) as u64).collect();
let update_data = gen_batch()
.col("value", array::cycle::<Int32Type>(values))
.col("_rowid", array::cycle::<UInt64Type>(row_ids))
.into_df_stream(
RowCount::from(DEFAULT_BTREE_BATCH_SIZE / 2),
BatchCount::from(2),
);
let update_data_source = Box::pin(RecordBatchStreamAdapter::new(
update_data.schema(),
update_data,
));
let ranged_index = BTreeIndex::load(old_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
ranged_index
.update(update_data_source, new_store.as_ref(), None)
.await
.expect("Error in updating ranged index");
let updated_index = BTreeIndex::load(new_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
assert!(
updated_index.ranges_to_files.is_none(),
"Updated ranged-btree-index should fall back to non-ranged"
);
let updated_value = (DEFAULT_BTREE_BATCH_SIZE * 2 + (DEFAULT_BTREE_BATCH_SIZE / 2)) as i32;
let updated_query = SargableQuery::Equals(ScalarValue::Int32(Some(updated_value)));
let query_result = updated_index
.search(&updated_query, &NoOpMetricsCollector)
.await
.unwrap();
match query_result {
SearchResult::Exact(row_id_map) => {
assert!(
row_id_map.selected(updated_value as u64),
"Updated index should contain original rowids."
);
assert!(
row_id_map.selected((updated_value + row_id_delta) as u64),
"Updated index should contain new rowids"
);
}
_ => {
panic!("Btree search result should always be Exact.");
}
}
}
#[tokio::test]
async fn test_update_with_exact_row_id_filter() {
let old_tmpdir = TempObjDir::default();
let old_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
old_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let new_tmpdir = TempObjDir::default();
let new_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
new_tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let old_data = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(512), BatchCount::from(2));
let old_data_source = Box::pin(RecordBatchStreamAdapter::new(old_data.schema(), old_data));
train_btree_index(
old_data_source,
old_store.as_ref(),
DEFAULT_BTREE_BATCH_SIZE,
None,
None,
)
.await
.unwrap();
let index = BTreeIndex::load(old_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let new_data = gen_batch()
.col("value", array::step_custom::<Int32Type>(2000, 1))
.col("_rowid", array::step_custom::<UInt64Type>(2000, 1))
.into_df_stream(RowCount::from(100), BatchCount::from(1));
let new_data_source = Box::pin(RecordBatchStreamAdapter::new(new_data.schema(), new_data));
let mut retained_old_rows = RowAddrTreeMap::new();
retained_old_rows.insert_range(0..64);
retained_old_rows.insert_range(300..364);
index
.update(
new_data_source,
new_store.as_ref(),
Some(OldIndexDataFilter::RowIds(retained_old_rows)),
)
.await
.unwrap();
let updated_index = BTreeIndex::load(new_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let present = |value: i32| {
let updated_index = updated_index.clone();
async move {
let query = SargableQuery::Equals(ScalarValue::Int32(Some(value)));
match updated_index
.search(&query, &NoOpMetricsCollector)
.await
.unwrap()
{
SearchResult::Exact(row_id_map) => row_id_map.selected(value as u64),
_ => unreachable!("Btree search result should always be Exact"),
}
}
};
assert!(present(12).await);
assert!(present(320).await);
assert!(!present(120).await);
assert!(!present(420).await);
assert!(present(2005).await);
}
#[tokio::test]
async fn test_btree_remap_big_deletions() {
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let batch_size = 4096;
let total_rows = 15000;
let stream = gen_batch()
.col("value", array::step::<Int32Type>())
.col("_rowid", array::step::<UInt64Type>())
.into_df_stream(RowCount::from(total_rows), BatchCount::from(1));
train_btree_index(stream, test_store.as_ref(), batch_size, None, None)
.await
.unwrap();
let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let mut mapping: HashMap<u64, Option<u64>> = HashMap::new();
for old_id in 1001..10000 {
mapping.insert(old_id, None);
}
let mut new_id_counter = 100_000;
for old_id in (0..1000).chain(10000..15000) {
let new_id = new_id_counter;
new_id_counter += 1;
mapping.insert(old_id, Some(new_id));
}
let remap_dir = TempObjDir::default();
let remap_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
remap_dir.clone(),
Arc::new(LanceCache::no_cache()),
));
index.remap(&mapping, remap_store.as_ref()).await.unwrap();
let remapped_index = BTreeIndex::load(remap_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
let should_exist = vec![0, 500, 1000, 10000, 13000, 14000, 14999];
for value in should_exist {
let query = SargableQuery::Equals(ScalarValue::Int32(Some(value)));
let result = remapped_index
.search(&query, &NoOpMetricsCollector)
.await
.unwrap();
match result {
SearchResult::Exact(row_id_map) => {
assert!(
!row_id_map.is_empty(),
"Value {} should exist in remapped index but was not found",
value
);
}
_ => {
panic!("Btree search result should always be Exact.");
}
}
}
let should_not_exist = vec![1001, 5000, 8000, 9999];
for value in should_not_exist {
let query = SargableQuery::Equals(ScalarValue::Int32(Some(value)));
let result = remapped_index
.search(&query, &NoOpMetricsCollector)
.await
.unwrap();
match result {
SearchResult::Exact(row_id_map) => {
assert!(
row_id_map.is_empty(),
"Value {} should NOT exist in remapped index but was found",
value
);
}
_ => {
panic!("Btree search result should always be Exact.");
}
}
}
}
#[tokio::test]
async fn test_search_tracks_nulls_for_absent_value() {
use arrow_array::{Int32Array, UInt64Array};
let tmpdir = TempObjDir::default();
let test_store = Arc::new(LanceIndexStore::new(
Arc::new(ObjectStore::local()),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));
let num_rows = 5000u64;
let values: Int32Array = (0..num_rows)
.map(|i| {
if i % 5 != 0 {
None
} else {
Some(100 + i as i32)
}
})
.collect();
let row_ids = UInt64Array::from_iter_values(0..num_rows);
let data = arrow_array::RecordBatch::try_from_iter(vec![
("value", Arc::new(values) as arrow_array::ArrayRef),
("_rowid", Arc::new(row_ids) as arrow_array::ArrayRef),
])
.unwrap();
let schema = data.schema();
let stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::iter(vec![Ok(data)]),
));
train_btree_index(stream, test_store.as_ref(), num_rows, None, None)
.await
.unwrap();
let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache())
.await
.unwrap();
assert!(
!index.page_lookup.all_null_pages.is_empty(),
"Test setup requires all-null pages; got null_pages={}, all_null_pages={}",
index.page_lookup.null_pages.len(),
index.page_lookup.all_null_pages.len(),
);
let metrics = NoOpMetricsCollector;
let result = index
.search(
&SargableQuery::Equals(ScalarValue::Int32(Some(0))),
&metrics,
)
.await
.unwrap();
match result {
SearchResult::Exact(set) => {
assert!(set.true_rows().is_empty(), "No rows should match Equals(0)");
assert!(
!set.null_rows().is_empty(),
"Null rows must be tracked even when no pages match the value"
);
}
_ => panic!("BTree search should return Exact"),
}
let result = index
.search(
&SargableQuery::Range(
std::ops::Bound::Unbounded,
std::ops::Bound::Excluded(ScalarValue::Int32(Some(50))),
),
&metrics,
)
.await
.unwrap();
match result {
SearchResult::Exact(set) => {
assert!(set.true_rows().is_empty(), "No rows should be < 50");
assert!(
!set.null_rows().is_empty(),
"Null rows must be tracked for range queries too"
);
}
_ => panic!("BTree search should return Exact"),
}
}
}