use std::collections::HashMap;
use std::sync::{Arc, LazyLock};
use datafusion::common::ScalarValue;
use lance_core::{Error, Result};
use lance_index::metrics::NoOpMetricsCollector;
use lance_index::registry::IndexPluginRegistry;
use lance_index::scalar::btree::BTreeIndex;
use lance_index::scalar::lance_format::LanceIndexStore;
use lance_index::scalar::{
IndexStore as ScalarIndexStore, SargableQuery, ScalarIndex, SearchResult,
};
use uuid::Uuid;
use super::data_source::{FreshTierWatermark, LsmDataSource, LsmGeneration};
use super::flushed_cache::{DatasetCache, open_flushed_dataset};
use crate::dataset::mem_wal::index::encode_pk_tuple;
use crate::dataset::mem_wal::util::PK_INDEX_DIR;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
use crate::session::Session;
/// Process-wide index plugin registry, built lazily on first use with the
/// default plugin set. `open_pk_index` looks up the `"BTree"` plugin here to
/// load flushed primary-key sidecar indexes.
static PK_BTREE_REGISTRY: LazyLock<Arc<IndexPluginRegistry>> =
    LazyLock::new(IndexPluginRegistry::with_default_plugins);
/// Primary-key membership of a single LSM generation: answers "does this key
/// exist in that generation?" for dedup / block-list checks.
#[derive(Clone, Debug)]
pub enum GenMembership {
    /// A memtable whose PK index is held in memory.
    InMemory {
        /// In-memory index holding the generation's primary keys.
        index_store: Arc<IndexStore>,
        /// Row bound passed to `IndexStore::pk_contains_key`. `None` means no
        /// row is visible, so the generation contains no key at all.
        max_visible_row: Option<u64>,
    },
    /// A flushed memtable, queried through its on-disk PK scalar index
    /// (opened by `open_pk_index` via the BTree plugin).
    OnDisk(Arc<dyn ScalarIndex>),
}
impl GenMembership {
pub async fn contains(&self, key: &ScalarValue) -> Result<bool> {
match self {
Self::InMemory {
index_store,
max_visible_row,
} => Ok(max_visible_row.is_some_and(|max| index_store.pk_contains_key(key, max))),
Self::OnDisk(index) => {
let result = index
.search(&SargableQuery::Equals(key.clone()), &NoOpMetricsCollector)
.await
.map_err(|e| Error::io(e.to_string()))?;
Ok(!search_is_empty(&result))
}
}
}
pub async fn contains_keys(&self, keys: &[ScalarValue]) -> Result<Vec<bool>> {
match self {
Self::InMemory {
index_store,
max_visible_row,
} => Ok(keys
.iter()
.map(|key| max_visible_row.is_some_and(|max| index_store.pk_contains_key(key, max)))
.collect()),
Self::OnDisk(index) => {
let btree = index.as_any().downcast_ref::<BTreeIndex>().ok_or_else(|| {
Error::io("flushed PK dedup index is not a BTree".to_string())
})?;
btree
.contains_keys(keys, &NoOpMetricsCollector)
.await
.map_err(|e| Error::io(e.to_string()))
}
}
}
fn is_empty(&self) -> bool {
match self {
Self::InMemory {
index_store,
max_visible_row,
} => max_visible_row.is_none() || index_store.pk_is_empty(),
Self::OnDisk(_) => false,
}
}
}
/// True when a scalar-index search produced no row ids, regardless of
/// whether the result is exact or a bound.
///
/// NOTE(review): `AtMost`/`AtLeast` results are treated like `Exact`. That is
/// fine for BTree `Equals` queries, which presumably return `Exact`. Confirm
/// before using this with other index types.
fn search_is_empty(result: &SearchResult) -> bool {
    let (SearchResult::Exact(row_ids) | SearchResult::AtMost(row_ids) | SearchResult::AtLeast(row_ids)) =
        result;
    row_ids.is_empty()
}
/// Converts primary-key column values into the key stored in the on-disk PK
/// index.
///
/// A single-column key is used as-is, keeping its type. Composite keys (any
/// other arity) are packed into a `Binary` value with `encode_pk_tuple`.
///
/// # Errors
/// Propagates any error from `encode_pk_tuple`.
pub fn on_disk_pk_key(values: &[ScalarValue]) -> Result<ScalarValue> {
    if let [only] = values {
        return Ok(only.clone());
    }
    let encoded = encode_pk_tuple(values)?;
    Ok(ScalarValue::Binary(Some(encoded)))
}
/// Block lists keyed by `(shard, generation)`. Memtable sources use
/// `Some(shard_id)`; the base table uses `(None, LsmGeneration::BASE_TABLE)`.
/// Each value holds the memberships whose keys shadow rows of that source.
pub type SourceBlockLists = HashMap<(Option<Uuid>, LsmGeneration), Vec<GenMembership>>;
/// Per-shard list of `(generation, membership)` pairs, gathered before sorting.
type ShardGenSets = HashMap<Uuid, Vec<(LsmGeneration, GenMembership)>>;
pub async fn compute_source_block_lists(
sources: &[LsmDataSource],
session: Option<&Arc<Session>>,
flushed_cache: Option<&Arc<dyn DatasetCache>>,
) -> Result<SourceBlockLists> {
let mut by_shard: ShardGenSets = HashMap::new();
let mut has_base = false;
let mut flushed_loads = Vec::new();
for source in sources {
match source {
LsmDataSource::BaseTable { .. } => has_base = true,
LsmDataSource::ActiveMemTable {
batch_store,
index_store,
shard_id,
generation,
..
} => {
let membership = in_memory_membership(batch_store, index_store);
by_shard
.entry(*shard_id)
.or_default()
.push((*generation, membership));
}
LsmDataSource::FlushedMemTable {
path,
shard_id,
generation,
..
} => flushed_loads.push(async move {
let index = open_pk_index(path, session, flushed_cache).await?;
Ok::<_, Error>((*shard_id, *generation, GenMembership::OnDisk(index)))
}),
}
}
for (shard_id, generation, membership) in futures::future::try_join_all(flushed_loads).await? {
by_shard
.entry(shard_id)
.or_default()
.push((generation, membership));
}
let mut blocked: SourceBlockLists = HashMap::new();
let mut base_blocked: Vec<GenMembership> = Vec::new();
for (shard, mut gens) in by_shard {
gens.sort_by_key(|(generation, _)| std::cmp::Reverse(*generation));
let mut newer: Vec<GenMembership> = Vec::new();
for (generation, membership) in gens {
if !newer.is_empty() {
blocked.insert((Some(shard), generation), newer.clone());
}
if !membership.is_empty() {
base_blocked.push(membership.clone());
newer.push(membership);
}
}
}
if has_base && !base_blocked.is_empty() {
blocked.insert((None, LsmGeneration::BASE_TABLE), base_blocked);
}
Ok(blocked)
}
/// Collects the PK memberships that fresh-tier rows are checked against,
/// limited to what was visible at the optional per-shard `watermarks` snapshot.
///
/// Each source is handled as follows, in `sources` order:
/// - `BaseTable` contributes nothing.
/// - `ActiveMemTable` with no watermark for its shard uses the whole visible
///   memtable. Otherwise:
///   - a generation above `active_generation` is skipped;
///   - the generation equal to it is bounded to its first
///     `active_batch_count` batches;
///   - older generations are used in full.
/// - `FlushedMemTable` is skipped when its shard has a watermark and its
///   generation is `>= active_generation`. Otherwise its on-disk PK index is
///   opened; all such opens run concurrently.
///
/// Empty memberships are dropped from the result.
///
/// # Errors
/// Propagates the first failure from opening a flushed PK index.
pub async fn fresh_tier_block_list(
    sources: &[LsmDataSource],
    session: Option<&Arc<Session>>,
    flushed_cache: Option<&Arc<dyn DatasetCache>>,
    watermarks: Option<&HashMap<Uuid, FreshTierWatermark>>,
) -> Result<Vec<GenMembership>> {
    // One slot per source keeps the output in `sources` order even though
    // flushed indexes load concurrently and are filled in afterwards.
    let mut slots: Vec<Option<GenMembership>> = Vec::with_capacity(sources.len());
    let mut pending_loads = Vec::new();
    for source in sources {
        let entry = match source {
            LsmDataSource::BaseTable { .. } => None,
            LsmDataSource::ActiveMemTable {
                batch_store,
                index_store,
                shard_id,
                generation,
                ..
            } => match watermarks.and_then(|m| m.get(shard_id)) {
                None => Some(in_memory_membership(batch_store, index_store)),
                Some(mark) => match generation.as_u64().cmp(&mark.active_generation) {
                    std::cmp::Ordering::Greater => None,
                    std::cmp::Ordering::Equal => Some(bounded_in_memory_membership(
                        batch_store,
                        index_store,
                        mark.active_batch_count,
                    )),
                    std::cmp::Ordering::Less => {
                        Some(in_memory_membership(batch_store, index_store))
                    }
                },
            },
            LsmDataSource::FlushedMemTable {
                path,
                shard_id,
                generation,
                ..
            } => {
                let newer_than_snapshot = watermarks
                    .and_then(|m| m.get(shard_id))
                    .is_some_and(|mark| generation.as_u64() >= mark.active_generation);
                if !newer_than_snapshot {
                    let slot = slots.len();
                    pending_loads.push(async move {
                        let index = open_pk_index(path, session, flushed_cache).await?;
                        Ok::<_, Error>((slot, GenMembership::OnDisk(index)))
                    });
                }
                // Placeholder; a pending load (if any) fills it in below.
                None
            }
        };
        slots.push(entry);
    }
    for (slot, membership) in futures::future::try_join_all(pending_loads).await? {
        slots[slot] = Some(membership);
    }
    Ok(slots
        .into_iter()
        .filter_map(|slot| slot.filter(|membership| !membership.is_empty()))
        .collect())
}
/// Builds an in-memory membership for a memtable. Its row bound is the
/// visible row of the batch at the index store's max visible batch position.
fn in_memory_membership(
    batch_store: &Arc<BatchStore>,
    index_store: &Arc<IndexStore>,
) -> GenMembership {
    let visible_batch = index_store.max_visible_batch_position();
    let row_bound = batch_store.max_visible_row(visible_batch);
    GenMembership::InMemory {
        index_store: Arc::clone(index_store),
        max_visible_row: row_bound,
    }
}
/// Builds an in-memory membership limited to the first `batch_count` batch
/// positions of the memtable. A `batch_count` of zero makes nothing visible.
fn bounded_in_memory_membership(
    batch_store: &Arc<BatchStore>,
    index_store: &Arc<IndexStore>,
    batch_count: u64,
) -> GenMembership {
    let max_visible_row = if batch_count == 0 {
        None
    } else {
        let last_batch = (batch_count - 1) as usize;
        batch_store.max_visible_row(last_batch)
    };
    GenMembership::InMemory {
        index_store: Arc::clone(index_store),
        max_visible_row,
    }
}
/// Derives a deterministic pseudo-UUID from `path`, used to key the flushed
/// PK index in the dataset's index cache.
///
/// The low 64 bits hash `path` alone. The high 64 bits hash a fixed salt
/// followed by `path`. `DefaultHasher` output is not guaranteed stable across
/// Rust releases, so treat the value as a runtime cache key only.
fn path_cache_uuid(path: &str) -> Uuid {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let digest = |salt: Option<&str>| -> u64 {
        let mut hasher = DefaultHasher::new();
        if let Some(salt) = salt {
            salt.hash(&mut hasher);
        }
        path.hash(&mut hasher);
        hasher.finish()
    };
    let low = digest(None);
    let high = digest(Some("lance/flushed-pk-index"));
    Uuid::from_u128((u128::from(high) << 64) | u128::from(low))
}
/// Opens the primary-key BTree sidecar index of the flushed memtable at
/// `path`.
///
/// The steps are:
/// 1. Open the flushed dataset through `open_flushed_dataset`.
/// 2. Root a `LanceIndexStore` at `<dataset base>/PK_INDEX_DIR`.
/// 3. Return the cached BTree index if the index cache has one. Otherwise load
///    it with default `BTreeIndexDetails` and store it in the cache.
///
/// # Errors
/// Propagates dataset-open, plugin lookup, index load and cache errors.
/// Failure to encode the BTree details becomes an I/O error.
async fn open_pk_index(
    path: &str,
    session: Option<&Arc<Session>>,
    flushed_cache: Option<&Arc<dyn DatasetCache>>,
) -> Result<Arc<dyn ScalarIndex>> {
    let dataset = open_flushed_dataset(path, session, flushed_cache, None).await?;
    let cache_key = path_cache_uuid(path);
    let index_cache = dataset.index_cache.for_index(&cache_key, None);
    let sidecar_dir = dataset.base.clone().join(PK_INDEX_DIR);
    let store: Arc<dyn ScalarIndexStore> = Arc::new(LanceIndexStore::new(
        dataset.object_store.clone(),
        sidecar_dir,
        Arc::new(index_cache.clone()),
    ));
    let plugin = PK_BTREE_REGISTRY.get_plugin_by_name("BTree")?;
    let cached = plugin
        .get_from_cache(store.clone(), None, &index_cache)
        .await?;
    match cached {
        Some(hit) => Ok(hit),
        None => {
            let details =
                prost_types::Any::from_msg(&lance_index::pbold::BTreeIndexDetails::default())
                    .map_err(|e| Error::io(e.to_string()))?;
            let loaded = plugin
                .load_index(store, &details, None, &index_cache)
                .await?;
            plugin.put_in_cache(&index_cache, loaded.clone()).await?;
            Ok(loaded)
        }
    }
}
/// Builds and writes the primary-key BTree sidecar index for the dataset at
/// `uri` from `batches`, keyed on `pk_columns`.
///
/// Batches are indexed in order with a running row offset. The resulting
/// training batches are streamed into `train_btree_index`, writing into the
/// directory given by `pk_index_path`. If there is no training data, nothing
/// is written.
///
/// NOTE(review): each PK column is registered with its position in
/// `pk_columns` as the `i32` passed to `enable_pk_index`. Confirm that is the
/// value `IndexStore` expects.
/// NOTE(review): when nothing is written, `open_pk_index` on this dataset
/// presumably cannot find a sidecar. Confirm callers never flush empty
/// generations.
///
/// # Errors
/// Propagates indexing, object-store and BTree training errors.
pub async fn write_pk_sidecar(
    uri: &str,
    batches: &[arrow_array::RecordBatch],
    pk_columns: &[&str],
) -> Result<()> {
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use lance_core::cache::LanceCache;
    use lance_index::scalar::btree::train_btree_index;
    use lance_io::object_store::ObjectStore;
    use crate::dataset::mem_wal::util::pk_index_path;

    let pk: Vec<(String, i32)> = pk_columns
        .iter()
        .zip(0i32..)
        .map(|(name, position)| (name.to_string(), position))
        .collect();
    let mut pk_index = IndexStore::new();
    pk_index.enable_pk_index(&pk);

    let mut next_row = 0u64;
    for batch in batches {
        pk_index.insert(batch, next_row)?;
        next_row += batch.num_rows() as u64;
    }

    let training = pk_index.pk_training_batches(8192)?;
    let Some(first) = training.first() else {
        return Ok(());
    };
    let schema = first.schema();

    let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
    let sidecar_store = LanceIndexStore::new(
        object_store,
        pk_index_path(&base_path),
        Arc::new(LanceCache::no_cache()),
    );
    let stream = Box::pin(RecordBatchStreamAdapter::new(
        schema,
        futures::stream::iter(training.into_iter().map(Ok)),
    ));
    train_btree_index(stream, &sidecar_store, 8192, None, None).await?;
    Ok(())
}
// Unit tests for PK membership, block-list construction and fresh-tier
// watermark bounding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::mem_wal::scanner::data_source::{LsmDataSource, LsmGeneration};
    use crate::dataset::mem_wal::write::IndexStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use uuid::Uuid;

    /// Builds a single-column (`id: Int32`, non-nullable) batch from `ids`.
    fn id_batch(ids: &[i32]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids.to_vec()))]).unwrap()
    }

    /// Builds an active memtable source for `shard`/`generation` that holds
    /// one single-row batch per id. Each row is indexed with its batch
    /// position, unlike the `None` position used in
    /// `index_membership_is_snapshot_bounded` to model a not-yet-visible write.
    fn active_source(shard: Uuid, generation: u64, ids: &[i32]) -> LsmDataSource {
        let store = BatchStore::with_capacity(16);
        let mut index = IndexStore::new();
        index.enable_pk_index(&[("id".to_string(), 0)]);
        for &id in ids {
            let b = id_batch(&[id]);
            let (bp, off, _) = store.append(b.clone()).unwrap();
            index.insert_with_batch_position(&b, off, Some(bp)).unwrap();
        }
        LsmDataSource::ActiveMemTable {
            batch_store: Arc::new(store),
            index_store: Arc::new(index),
            schema: id_batch(&[1]).schema(),
            shard_id: shard,
            generation: LsmGeneration::memtable(generation),
        }
    }

    /// Returns true if any membership in `memberships` contains the
    /// single-column key `id`.
    async fn blocks(memberships: &[GenMembership], id: i32) -> bool {
        let key = on_disk_pk_key(&[ScalarValue::Int32(Some(id))]).unwrap();
        for m in memberships {
            if m.contains(&key).await.unwrap() {
                return true;
            }
        }
        false
    }

    #[test]
    fn on_disk_key_is_typed_for_single_and_binary_for_composite() {
        let single = [ScalarValue::Int32(Some(7))];
        assert_eq!(
            on_disk_pk_key(&single).unwrap(),
            ScalarValue::Int32(Some(7))
        );
        let composite = [ScalarValue::Int32(Some(1)), ScalarValue::from("a")];
        assert!(matches!(
            on_disk_pk_key(&composite).unwrap(),
            ScalarValue::Binary(Some(_))
        ));
    }

    // Without watermarks every in-memory generation contributes exactly one
    // membership, and together they cover all inserted ids.
    #[tokio::test]
    async fn fresh_tier_block_list_one_membership_per_in_memory_gen() {
        let shard = Uuid::new_v4();
        let sources = vec![
            active_source(shard, 2, &[1, 2]),
            active_source(shard, 1, &[3]),
        ];
        let memberships = fresh_tier_block_list(&sources, None, None, None)
            .await
            .unwrap();
        assert_eq!(memberships.len(), 2);
        for id in [1, 2, 3] {
            assert!(blocks(&memberships, id).await);
        }
        assert!(!blocks(&memberships, 4).await);
    }

    // Gen 1 is blocked by gen 2's keys. Gen 2 is the newest, so it gets no
    // block list.
    #[tokio::test]
    async fn block_lists_suppress_stale_across_in_memory_gens() {
        let shard = Uuid::new_v4();
        let sources = vec![
            active_source(shard, 1, &[1]),
            active_source(shard, 2, &[1, 2]),
        ];
        let blocked = Box::pin(compute_source_block_lists(&sources, None, None))
            .await
            .unwrap();
        let g1 = LsmGeneration::memtable(1);
        let g2 = LsmGeneration::memtable(2);
        assert!(blocks(&blocked[&(Some(shard), g1)], 1).await);
        assert!(!blocked.contains_key(&(Some(shard), g2)));
    }

    // The base table is blocked by memtable keys (1) but not by keys only in
    // the base itself (3).
    #[tokio::test]
    async fn block_lists_suppress_stale_base_row() {
        use crate::dataset::{Dataset, WriteParams};
        use arrow_array::RecordBatchIterator;
        let base_batch = id_batch(&[1, 3]);
        let schema = base_batch.schema();
        let tmp = tempfile::tempdir().unwrap();
        let uri = format!("{}/base", tmp.path().to_str().unwrap());
        let reader = RecordBatchIterator::new(vec![Ok(base_batch)], schema.clone());
        let base = Arc::new(
            Dataset::write(reader, &uri, Some(WriteParams::default()))
                .await
                .unwrap(),
        );
        let sources = vec![
            LsmDataSource::BaseTable { dataset: base },
            active_source(Uuid::new_v4(), 1, &[1, 2]),
        ];
        let blocked = Box::pin(compute_source_block_lists(&sources, None, None))
            .await
            .unwrap();
        let base_blocked = blocked
            .get(&(None, LsmGeneration::BASE_TABLE))
            .expect("base has a blocked set");
        assert!(blocks(base_blocked, 1).await);
        assert!(!blocks(base_blocked, 3).await);
    }

    // Keys from one shard must never appear in another shard's block list.
    #[tokio::test]
    async fn block_lists_are_keyed_per_shard() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let sources = vec![
            active_source(a, 1, &[1]),
            active_source(a, 2, &[1]),
            active_source(b, 1, &[2]),
            active_source(b, 2, &[2]),
        ];
        let blocked = Box::pin(compute_source_block_lists(&sources, None, None))
            .await
            .unwrap();
        let g1 = LsmGeneration::memtable(1);
        let g2 = LsmGeneration::memtable(2);
        assert!(blocks(&blocked[&(Some(a), g1)], 1).await);
        assert!(!blocks(&blocked[&(Some(a), g1)], 2).await);
        assert!(blocks(&blocked[&(Some(b), g1)], 2).await);
        assert!(!blocks(&blocked[&(Some(b), g1)], 1).await);
        assert!(!blocked.contains_key(&(Some(a), g2)));
        assert!(!blocked.contains_key(&(Some(b), g2)));
    }

    // In gen 2, key 99 is indexed with a batch position (visible). Key 1 is
    // indexed with `None` (not yet visible), so it must not block gen 1's
    // copy of key 1.
    #[tokio::test]
    async fn index_membership_is_snapshot_bounded() {
        let shard = Uuid::new_v4();
        let schema = id_batch(&[1]).schema();
        let g1 = active_source(shard, 1, &[1]);
        let g2_store = BatchStore::with_capacity(8);
        let mut g2_index = IndexStore::new();
        g2_index.enable_pk_index(&[("id".to_string(), 0)]);
        let b0 = id_batch(&[99]);
        let (bp0, off0, _) = g2_store.append(b0.clone()).unwrap();
        g2_index
            .insert_with_batch_position(&b0, off0, Some(bp0))
            .unwrap();
        let b1 = id_batch(&[1]);
        let (_, off1, _) = g2_store.append(b1.clone()).unwrap();
        g2_index
            .insert_with_batch_position(&b1, off1, None)
            .unwrap();
        let g2 = LsmDataSource::ActiveMemTable {
            batch_store: Arc::new(g2_store),
            index_store: Arc::new(g2_index),
            schema,
            shard_id: shard,
            generation: LsmGeneration::memtable(2),
        };
        let blocked = Box::pin(compute_source_block_lists(&[g1, g2], None, None))
            .await
            .unwrap();
        let g1_block = &blocked[&(Some(shard), LsmGeneration::memtable(1))];
        assert!(blocks(g1_block, 99).await);
        assert!(
            !blocks(g1_block, 1).await,
            "a not-yet-visible newer write must not shadow an older visible copy"
        );
    }

    // A watermark on the active generation bounds it to the first
    // `active_batch_count` batches. Without a watermark, all rows count.
    #[tokio::test]
    async fn fresh_tier_watermark_bounds_active_memtable_by_batch_count() {
        use crate::dataset::mem_wal::scanner::data_source::FreshTierWatermark;
        use std::collections::HashMap;
        let shard = Uuid::new_v4();
        let sources = vec![active_source(shard, 1, &[1, 2, 3])];
        let watermarks: HashMap<Uuid, FreshTierWatermark> = [(
            shard,
            FreshTierWatermark {
                active_generation: 1,
                active_batch_count: 2,
            },
        )]
        .into_iter()
        .collect();
        let sets = fresh_tier_block_list(&sources, None, None, Some(&watermarks))
            .await
            .unwrap();
        assert!(blocks(&sets, 1).await);
        assert!(blocks(&sets, 2).await);
        assert!(!blocks(&sets, 3).await);
        let sets = fresh_tier_block_list(&sources, None, None, None)
            .await
            .unwrap();
        for id in [1, 2, 3] {
            assert!(blocks(&sets, id).await);
        }
    }

    // With the watermark at gen 2 / 1 batch:
    // - gen 1 is fully included;
    // - gen 2 only includes its first batch (20, not 21);
    // - gen 3 is excluded entirely.
    #[tokio::test]
    async fn fresh_tier_watermark_excludes_newer_gen_includes_lower_gen() {
        use crate::dataset::mem_wal::scanner::data_source::FreshTierWatermark;
        use std::collections::HashMap;
        let shard = Uuid::new_v4();
        let sources = vec![
            active_source(shard, 3, &[100]),
            active_source(shard, 2, &[20, 21]),
            active_source(shard, 1, &[1, 2]),
        ];
        let watermarks: HashMap<Uuid, FreshTierWatermark> = [(
            shard,
            FreshTierWatermark {
                active_generation: 2,
                active_batch_count: 1,
            },
        )]
        .into_iter()
        .collect();
        let sets = fresh_tier_block_list(&sources, None, None, Some(&watermarks))
            .await
            .unwrap();
        assert!(blocks(&sets, 1).await);
        assert!(blocks(&sets, 2).await);
        assert!(blocks(&sets, 20).await);
        assert!(!blocks(&sets, 21).await);
        assert!(!blocks(&sets, 100).await);
    }

    // A flushed generation is excluded when it is at or above the watermark's
    // active generation and included once the watermark has moved past it.
    // This test writes a real dataset plus PK sidecar to a temp dir.
    #[tokio::test]
    async fn fresh_tier_watermark_excludes_flushed_at_or_above_active() {
        use crate::dataset::mem_wal::scanner::data_source::FreshTierWatermark;
        use crate::dataset::{Dataset, WriteParams};
        use arrow_array::RecordBatchIterator;
        use std::collections::HashMap;
        let flushed_batch = id_batch(&[5]);
        let schema = flushed_batch.schema();
        let tmp = tempfile::tempdir().unwrap();
        let path = format!("{}/gen2", tmp.path().to_str().unwrap());
        let reader = RecordBatchIterator::new(vec![Ok(flushed_batch.clone())], schema.clone());
        Dataset::write(reader, &path, Some(WriteParams::default()))
            .await
            .unwrap();
        write_pk_sidecar(&path, &[flushed_batch], &["id"])
            .await
            .unwrap();
        let shard = Uuid::new_v4();
        let sources = vec![LsmDataSource::FlushedMemTable {
            path,
            shard_id: shard,
            generation: LsmGeneration::memtable(2),
        }];
        let at: HashMap<Uuid, FreshTierWatermark> = [(
            shard,
            FreshTierWatermark {
                active_generation: 2,
                active_batch_count: u64::MAX,
            },
        )]
        .into_iter()
        .collect();
        let sets = fresh_tier_block_list(&sources, None, None, Some(&at))
            .await
            .unwrap();
        assert!(!blocks(&sets, 5).await);
        let above: HashMap<Uuid, FreshTierWatermark> = [(
            shard,
            FreshTierWatermark {
                active_generation: 3,
                active_batch_count: u64::MAX,
            },
        )]
        .into_iter()
        .collect();
        let sets = fresh_tier_block_list(&sources, None, None, Some(&above))
            .await
            .unwrap();
        assert!(blocks(&sets, 5).await);
    }
}