use std::sync::Arc;
use bytes::Bytes;
use lance_core::cache::LanceCache;
use lance_core::{Error, Result};
use lance_index::IndexType;
use lance_index::mem_wal::{FlushedGeneration, RegionManifest};
use lance_index::scalar::{IndexStore, ScalarIndexParams};
use lance_io::object_store::ObjectStore;
use lance_table::format::IndexMetadata;
use log::info;
use object_store::path::Path;
use uuid::Uuid;
use super::super::index::MemIndexConfig;
use super::super::memtable::MemTable;
use crate::Dataset;
use crate::dataset::mem_wal::manifest::RegionManifestStore;
use crate::dataset::mem_wal::util::{flushed_memtable_path, generate_random_hash};
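/// Outcome of a successful MemTable flush: the generation that was written,
/// how many rows it contains, and the highest WAL entry position the flush
/// covers (entries at or before this position are now durable in the flushed
/// generation).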
#[derive(Debug, Clone)]
pub struct FlushResult {
pub generation: FlushedGeneration,
pub rows_flushed: usize,
pub covered_wal_entry_position: u64,
}
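/// Writes a frozen MemTable out to object storage as a standalone Lance
/// dataset generation and records the result in the region manifest.
///
/// A minimal usage sketch, assuming the surrounding setup (object store,
/// manifest store, claimed epoch, populated MemTable) exists elsewhere:
///
/// ```ignore
/// let flusher = MemTableFlusher::new(store, base_path, base_uri, region_id, manifest_store);
/// // `epoch` must have been claimed from the manifest store beforehand;
/// // the MemTable must be non-empty and fully flushed to the WAL.
/// let result = flusher.flush(&memtable, epoch).await?;
/// println!(
///     "flushed {} rows as generation {}",
///     result.rows_flushed, result.generation.generation
/// );
/// ```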
pub struct MemTableFlusher {
object_store: Arc<ObjectStore>,
base_path: Path,
base_uri: String,
region_id: Uuid,
manifest_store: Arc<RegionManifestStore>,
}
impl MemTableFlusher {
pub fn new(
object_store: Arc<ObjectStore>,
base_path: Path,
base_uri: impl Into<String>,
region_id: Uuid,
manifest_store: Arc<RegionManifestStore>,
) -> Self {
Self {
object_store,
base_path,
base_uri: base_uri.into(),
region_id,
manifest_store,
}
}
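/// Converts an object-store `Path` under `base_path` into a full URI rooted at
/// `base_uri`, so a flushed generation can be opened as a regular dataset.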
fn path_to_uri(&self, path: &Path) -> String {
let path_str = path.as_ref();
let base_str = self.base_path.as_ref();
let relative = if let Some(stripped) = path_str.strip_prefix(base_str) {
stripped.trim_start_matches('/')
} else {
path_str
};
let base = self.base_uri.trim_end_matches('/');
if relative.is_empty() {
base.to_string()
} else {
format!("{}/{}", base, relative)
}
}
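/// Flushes `memtable` as a new generation. Steps: verify this writer still
/// holds `epoch` (not fenced), reject empty or WAL-unflushed MemTables, write
/// the data file and bloom filter, then commit the new generation to the
/// region manifest.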
pub async fn flush(&self, memtable: &MemTable, epoch: u64) -> Result<FlushResult> {
self.manifest_store.check_fenced(epoch).await?;
if memtable.row_count() == 0 {
return Err(Error::invalid_input("Cannot flush empty MemTable"));
}
if !memtable.all_flushed_to_wal() {
return Err(Error::invalid_input(
"MemTable has unflushed fragments - WAL flush required first",
));
}
let random_hash = generate_random_hash();
let generation = memtable.generation();
let gen_folder_name = format!("{}_gen_{}", random_hash, generation);
let gen_path =
flushed_memtable_path(&self.base_path, &self.region_id, &random_hash, generation);
info!(
"Flushing MemTable generation {} to {} ({} rows, {} batches)",
generation,
gen_path,
memtable.row_count(),
memtable.batch_count()
);
let rows_flushed = self.write_data_file(&gen_path, memtable).await?;
let bloom_path = gen_path.child("bloom_filter.bin");
self.write_bloom_filter(&bloom_path, memtable.bloom_filter())
.await?;
let last_wal_entry_position = memtable.last_flushed_wal_entry_position();
let new_manifest = self
.update_manifest(epoch, generation, &gen_folder_name, last_wal_entry_position)
.await?;
info!(
"Flushed generation {} for region {} (manifest version {})",
generation, self.region_id, new_manifest.version
);
Ok(FlushResult {
generation: FlushedGeneration {
generation,
path: gen_folder_name,
},
rows_flushed,
covered_wal_entry_position: last_wal_entry_position,
})
}
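/// Writes the MemTable's batches (in reverse scan order) as a single-file
/// Lance dataset at `path`. Returns the number of rows written.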
async fn write_data_file(&self, path: &Path, memtable: &MemTable) -> Result<usize> {
use arrow_array::RecordBatchIterator;
use crate::dataset::WriteParams;
if memtable.row_count() == 0 {
return Ok(0);
}
let (batches, total_rows) = memtable.scan_batches_reversed().await?;
if batches.is_empty() {
return Ok(0);
}
let uri = self.path_to_uri(path);
let reader =
RecordBatchIterator::new(batches.into_iter().map(Ok), memtable.schema().clone());
let write_params = WriteParams {
max_rows_per_file: usize::MAX,
..Default::default()
};
Dataset::write(reader, &uri, Some(write_params)).await?;
Ok(total_rows)
}
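/// Serializes the MemTable's split-block bloom filter to `path`, letting
/// readers skip this generation during point lookups.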
async fn write_bloom_filter(
&self,
path: &Path,
bloom: &lance_index::scalar::bloomfilter::sbbf::Sbbf,
) -> Result<()> {
let data = bloom.to_bytes();
self.object_store
.inner
.put(path, Bytes::from(data).into())
.await
.map_err(|e| Error::io(format!("Failed to write bloom filter: {}", e)))?;
Ok(())
}
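/// Like [`Self::flush`], but additionally materializes the MemTable's
/// in-memory indexes (BTree, IVF-PQ, FTS) on the flushed generation before
/// committing the manifest update.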
pub async fn flush_with_indexes(
&self,
memtable: &MemTable,
epoch: u64,
index_configs: &[MemIndexConfig],
) -> Result<FlushResult> {
self.manifest_store.check_fenced(epoch).await?;
if memtable.row_count() == 0 {
return Err(Error::invalid_input("Cannot flush empty MemTable"));
}
if !memtable.all_flushed_to_wal() {
return Err(Error::invalid_input(
"MemTable has unflushed fragments - WAL flush required first",
));
}
let random_hash = generate_random_hash();
let generation = memtable.generation();
let gen_folder_name = format!("{}_gen_{}", random_hash, generation);
let gen_path =
flushed_memtable_path(&self.base_path, &self.region_id, &random_hash, generation);
info!(
"Flushing MemTable generation {} with indexes to {} ({} rows, {} batches)",
generation,
gen_path,
memtable.row_count(),
memtable.batch_count()
);
let total_rows = self.write_data_file(&gen_path, memtable).await?;
let created_indexes = self
.create_indexes(&gen_path, index_configs, memtable.indexes(), total_rows)
.await?;
if !created_indexes.is_empty() {
info!(
"Created {} BTree indexes on flushed generation {}",
created_indexes.len(),
generation
);
}
if let Some(registry) = memtable.indexes() {
let uri = self.path_to_uri(&gen_path);
let mut dataset = Dataset::open(&uri).await?;
for config in index_configs {
if let MemIndexConfig::IvfPq(ivf_pq_config) = config
&& let Some(mem_index) = registry.get_ivf_pq(&ivf_pq_config.name)
{
let mut index_meta = self
.create_ivf_pq_index(&gen_path, ivf_pq_config, mem_index, total_rows)
.await?;
let schema = dataset.schema();
let field_idx = schema
.field(&ivf_pq_config.column)
.map(|f| f.id)
.unwrap_or(0);
index_meta.fields = vec![field_idx];
index_meta.dataset_version = dataset.version().version;
let fragment_ids: roaring::RoaringBitmap =
dataset.fragment_bitmap.as_ref().clone();
index_meta.fragment_bitmap = Some(fragment_ids);
use crate::dataset::transaction::{Operation, Transaction};
let transaction = Transaction::new(
index_meta.dataset_version,
Operation::CreateIndex {
new_indices: vec![index_meta],
removed_indices: vec![],
},
None,
);
dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
info!(
"Created IVF-PQ index '{}' on flushed generation {}",
ivf_pq_config.name, generation
);
}
}
self.create_fts_indexes(&gen_path, index_configs, memtable.indexes(), total_rows)
.await?;
}
let bloom_path = gen_path.child("bloom_filter.bin");
self.write_bloom_filter(&bloom_path, memtable.bloom_filter())
.await?;
let last_wal_entry_position = memtable.last_flushed_wal_entry_position();
let new_manifest = self
.update_manifest(epoch, generation, &gen_folder_name, last_wal_entry_position)
.await?;
info!(
"Flushed generation {} for region {} (manifest version {})",
generation, self.region_id, new_manifest.version
);
Ok(FlushResult {
generation: FlushedGeneration {
generation,
path: gen_folder_name,
},
rows_flushed: total_rows,
covered_wal_entry_position: last_wal_entry_position,
})
}
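/// Builds BTree scalar indexes on the flushed generation. When the in-memory
/// registry holds a matching BTree index, its already-ordered batches are
/// handed to the builder as preprocessed training data.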
async fn create_indexes(
&self,
gen_path: &Path,
index_configs: &[MemIndexConfig],
mem_indexes: Option<&super::super::index::IndexStore>,
total_rows: usize,
) -> Result<Vec<IndexMetadata>> {
use arrow_array::RecordBatchIterator;
use crate::index::CreateIndexBuilder;
let uri = self.path_to_uri(gen_path);
let btree_configs: Vec<_> = index_configs
.iter()
.filter_map(|c| match c {
MemIndexConfig::BTree(cfg) => Some(cfg),
MemIndexConfig::IvfPq(_) => None,
MemIndexConfig::Fts(_) => None,
})
.collect();
if btree_configs.is_empty() {
return Ok(vec![]);
}
let mut dataset = Dataset::open(&uri).await?;
let mut created_indexes = Vec::new();
for btree_cfg in btree_configs {
let params = ScalarIndexParams::default();
let mut builder = CreateIndexBuilder::new(
&mut dataset,
&[btree_cfg.column.as_str()],
IndexType::BTree,
&params,
)
.name(btree_cfg.name.clone());
if let Some(registry) = mem_indexes
&& let Some(btree_index) = registry.get_btree(&btree_cfg.name)
{
let training_batches =
btree_index.to_training_batches_reversed(8192, total_rows)?;
if !training_batches.is_empty() {
let schema = training_batches[0].schema();
let reader =
RecordBatchIterator::new(training_batches.into_iter().map(Ok), schema);
builder = builder.preprocessed_data(Box::new(reader));
}
}
let index_meta = builder.execute_uncommitted().await?;
created_indexes.push(index_meta.clone());
use crate::dataset::transaction::{Operation, Transaction};
let transaction = Transaction::new(
index_meta.dataset_version,
Operation::CreateIndex {
new_indices: vec![index_meta],
removed_indices: vec![],
},
None,
);
dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
}
Ok(created_indexes)
}
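/// Builds full-text (inverted) indexes by writing each in-memory FTS index's
/// partition directly into a `LanceIndexStore` under `_indices/<uuid>`, then
/// committing a `CreateIndex` transaction for it.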
async fn create_fts_indexes(
&self,
gen_path: &Path,
index_configs: &[MemIndexConfig],
mem_indexes: Option<&super::super::index::IndexStore>,
total_rows: usize,
) -> Result<()> {
use lance_index::pbold;
use lance_index::scalar::inverted::current_fts_format_version;
use lance_index::scalar::lance_format::LanceIndexStore;
let fts_configs: Vec<_> = index_configs
.iter()
.filter_map(|c| match c {
MemIndexConfig::Fts(cfg) => Some(cfg),
_ => None,
})
.collect();
if fts_configs.is_empty() {
return Ok(());
}
let Some(registry) = mem_indexes else {
return Ok(());
};
let uri = self.path_to_uri(gen_path);
let mut dataset = Dataset::open(&uri).await?;
for fts_cfg in fts_configs {
let Some(fts_index) = registry.get_fts(&fts_cfg.name) else {
continue;
};
if fts_index.is_empty() {
continue;
}
let partition_id = uuid::Uuid::new_v4().as_u64_pair().0;
let mut inner_builder =
fts_index.to_index_builder_reversed(partition_id, total_rows)?;
let index_uuid = uuid::Uuid::new_v4();
let index_dir = gen_path.child("_indices").child(index_uuid.to_string());
let index_store = LanceIndexStore::new(
self.object_store.clone(),
index_dir.clone(),
Arc::new(LanceCache::no_cache()),
);
inner_builder.write(&index_store).await?;
self.write_fts_metadata(&index_store, partition_id, fts_cfg)
.await?;
let details = pbold::InvertedIndexDetails::try_from(&fts_cfg.params)?;
let index_details = prost_types::Any::from_msg(&details)
.map_err(|e| Error::io(format!("Failed to serialize index details: {}", e)))?;
let schema = dataset.schema();
let field_idx = schema.field(&fts_cfg.column).map(|f| f.id).unwrap_or(0);
let fragment_ids: roaring::RoaringBitmap = dataset.fragment_bitmap.as_ref().clone();
let index_meta = IndexMetadata {
uuid: index_uuid,
name: fts_cfg.name.clone(),
fields: vec![field_idx],
dataset_version: dataset.version().version,
fragment_bitmap: Some(fragment_ids),
index_details: Some(Arc::new(index_details)),
index_version: current_fts_format_version().index_version() as i32,
created_at: None,
base_id: None,
files: None,
};
use crate::dataset::transaction::{Operation, Transaction};
let transaction = Transaction::new(
index_meta.dataset_version,
Operation::CreateIndex {
new_indices: vec![index_meta],
removed_indices: vec![],
},
None,
);
dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
info!(
"Created FTS index '{}' on column '{}' (direct flush)",
fts_cfg.name, fts_cfg.column
);
}
Ok(())
}
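/// Writes the top-level `metadata.lance` file for an FTS index. The index
/// parameters, partition list, and token-set format travel in the schema
/// metadata map; the single placeholder row only gives the file a batch to
/// write.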
async fn write_fts_metadata(
&self,
index_store: &lance_index::scalar::lance_format::LanceIndexStore,
partition_id: u64,
config: &super::super::index::FtsIndexConfig,
) -> Result<()> {
use arrow_array::{RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use lance_index::scalar::inverted::TokenSetFormat;
let params_json = serde_json::to_string(&config.params)?;
let partitions_json = serde_json::to_string(&[partition_id])?;
let token_set_format = TokenSetFormat::default().to_string();
let schema = Arc::new(
Schema::new(vec![Field::new("_placeholder", DataType::Utf8, true)]).with_metadata(
[
("params".to_string(), params_json),
("partitions".to_string(), partitions_json),
("token_set_format".to_string(), token_set_format),
]
.into(),
),
);
let placeholder_array = Arc::new(StringArray::from(vec![None::<&str>]));
let batch = RecordBatch::try_new(schema.clone(), vec![placeholder_array])?;
let mut writer = index_store.new_index_file("metadata.lance", schema).await?;
writer.write_record_batch(batch).await?;
writer.finish().await?;
Ok(())
}
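/// Writes an IVF-PQ index for the flushed generation: an auxiliary storage
/// file holding row ids plus transposed PQ codes per partition, and an index
/// file holding the IVF metadata. The returned `IndexMetadata` carries
/// placeholder `fields`/`dataset_version`/`fragment_bitmap` values that the
/// caller fills in before committing.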
async fn create_ivf_pq_index(
&self,
gen_path: &Path,
config: &super::super::index::IvfPqIndexConfig,
mem_index: &super::super::index::IvfPqMemIndex,
total_rows: usize,
) -> Result<IndexMetadata> {
use arrow_schema::{Field, Schema as ArrowSchema};
use lance_core::ROW_ID;
use lance_file::writer::FileWriter;
use lance_index::pb;
use lance_index::vector::flat::index::FlatIndex;
use lance_index::vector::ivf::storage::IVF_METADATA_KEY;
use lance_index::vector::quantizer::{
Quantization, QuantizationMetadata, QuantizerMetadata,
};
use lance_index::vector::storage::STORAGE_METADATA_KEY;
use lance_index::vector::v3::subindex::IvfSubIndex;
use lance_index::vector::{DISTANCE_TYPE_KEY, PQ_CODE_COLUMN};
use lance_index::{
INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, INDEX_METADATA_SCHEMA_KEY,
IndexMetadata as IndexMetaSchema,
};
use prost::Message;
use std::sync::Arc;
let index_uuid = uuid::Uuid::new_v4();
let index_dir = gen_path.child("_indices").child(index_uuid.to_string());
let partition_batches = mem_index.to_partition_batches_reversed(total_rows)?;
let ivf_model = mem_index.ivf_model();
let pq = mem_index.pq();
let distance_type = mem_index.distance_type();
let pq_code_len = pq.num_sub_vectors * pq.num_bits as usize / 8;
let storage_schema: ArrowSchema = ArrowSchema::new(vec![
Field::new(ROW_ID, arrow_schema::DataType::UInt64, false),
Field::new(
PQ_CODE_COLUMN,
arrow_schema::DataType::FixedSizeList(
Arc::new(Field::new("item", arrow_schema::DataType::UInt8, false)),
pq_code_len as i32,
),
false,
),
]);
let index_schema: ArrowSchema = FlatIndex::schema().as_ref().clone();
let storage_path = index_dir.child(INDEX_AUXILIARY_FILE_NAME);
let index_path = index_dir.child(INDEX_FILE_NAME);
let mut storage_writer = FileWriter::try_new(
self.object_store.create(&storage_path).await?,
(&storage_schema).try_into()?,
Default::default(),
)?;
let mut index_writer = FileWriter::try_new(
self.object_store.create(&index_path).await?,
(&index_schema).try_into()?,
Default::default(),
)?;
let mut storage_ivf = lance_index::vector::ivf::storage::IvfModel::empty();
let centroids = ivf_model
.centroids
.clone()
.ok_or_else(|| Error::io("IVF model has no centroids"))?;
let mut index_ivf = lance_index::vector::ivf::storage::IvfModel::new(centroids, None);
let mut partition_index_metadata = Vec::with_capacity(ivf_model.num_partitions());
let partition_map: std::collections::HashMap<usize, _> =
partition_batches.into_iter().collect();
for part_id in 0..ivf_model.num_partitions() {
if let Some(batch) = partition_map.get(&part_id) {
let transposed_batch = transpose_pq_batch(batch, pq_code_len)?;
storage_writer.write_batch(&transposed_batch).await?;
storage_ivf.add_partition(transposed_batch.num_rows() as u32);
index_ivf.add_partition(0);
partition_index_metadata.push(String::new());
} else {
storage_ivf.add_partition(0);
index_ivf.add_partition(0);
partition_index_metadata.push(String::new());
}
}
let storage_ivf_pb = pb::Ivf::try_from(&storage_ivf)?;
storage_writer.add_schema_metadata(DISTANCE_TYPE_KEY, distance_type.to_string());
let ivf_buffer_pos = storage_writer
.add_global_buffer(storage_ivf_pb.encode_to_vec().into())
.await?;
storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string());
let pq_metadata = pq.metadata(Some(QuantizationMetadata {
codebook_position: Some(0),
codebook: None,
transposed: true,
}));
if let Some(extra_metadata) = pq_metadata.extra_metadata()? {
let idx = storage_writer.add_global_buffer(extra_metadata).await?;
let mut pq_meta = pq_metadata;
pq_meta.set_buffer_index(idx);
let storage_partition_metadata = vec![serde_json::to_string(&pq_meta)?];
storage_writer.add_schema_metadata(
STORAGE_METADATA_KEY,
serde_json::to_string(&storage_partition_metadata)?,
);
}
let index_ivf_pb = pb::Ivf::try_from(&index_ivf)?;
let index_metadata = IndexMetaSchema {
index_type: "IVF_PQ".to_string(),
distance_type: distance_type.to_string(),
};
index_writer.add_schema_metadata(
INDEX_METADATA_SCHEMA_KEY,
serde_json::to_string(&index_metadata)?,
);
let ivf_buffer_pos = index_writer
.add_global_buffer(index_ivf_pb.encode_to_vec().into())
.await?;
index_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string());
index_writer.add_schema_metadata(
FlatIndex::metadata_key(),
serde_json::to_string(&partition_index_metadata)?,
);
storage_writer.finish().await?;
index_writer.finish().await?;
let index_details = Some(std::sync::Arc::new(prost_types::Any {
type_url: "type.googleapis.com/lance.index.VectorIndexDetails".to_string(),
value: vec![],
}));
let index_meta = IndexMetadata {
uuid: index_uuid,
name: config.name.clone(),
fields: vec![0],
dataset_version: 0,
fragment_bitmap: None,
index_details,
base_id: None,
created_at: Some(chrono::Utc::now()),
index_version: 1,
files: None,
};
Ok(index_meta)
}
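/// Commits a manifest update that appends the flushed generation, advances
/// `current_generation`, and records the WAL entry position the flush covers.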
async fn update_manifest(
&self,
epoch: u64,
generation: u64,
gen_path: &str,
covered_wal_entry_position: u64,
) -> Result<RegionManifest> {
let gen_path = gen_path.to_string();
self.manifest_store
.commit_update(epoch, |current| {
let mut flushed_generations = current.flushed_generations.clone();
flushed_generations.push(FlushedGeneration {
generation,
path: gen_path.clone(),
});
RegionManifest {
version: current.version + 1,
replay_after_wal_entry_position: covered_wal_entry_position,
wal_entry_position_last_seen: current
.wal_entry_position_last_seen
.max(covered_wal_entry_position),
current_generation: generation + 1,
flushed_generations,
..current.clone()
}
})
.await
}
}
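/// Rewrites a partition batch so its PQ codes are stored transposed
/// (column-major), the layout the IVF-PQ storage file expects.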
fn transpose_pq_batch(
batch: &arrow_array::RecordBatch,
pq_code_len: usize,
) -> Result<arrow_array::RecordBatch> {
use arrow_array::FixedSizeListArray;
use arrow_array::cast::AsArray;
use arrow_schema::Field;
use lance_core::ROW_ID;
use lance_index::vector::PQ_CODE_COLUMN;
use lance_index::vector::pq::storage::transpose;
use std::sync::Arc;
let row_ids = batch
.column_by_name(ROW_ID)
.ok_or_else(|| Error::io("Missing _rowid column in partition batch"))?;
let pq_codes = batch
.column_by_name(PQ_CODE_COLUMN)
.ok_or_else(|| Error::io("Missing __pq_code column in partition batch"))?;
let pq_codes_fsl = pq_codes.as_fixed_size_list();
let codes_flat = pq_codes_fsl
.values()
.as_primitive::<arrow_array::types::UInt8Type>();
let transposed = transpose(codes_flat, pq_code_len, batch.num_rows());
let inner_field = Arc::new(Field::new("item", arrow_schema::DataType::UInt8, false));
let transposed_fsl = Arc::new(
FixedSizeListArray::try_new(inner_field, pq_code_len as i32, Arc::new(transposed), None)
.map_err(|e| Error::io(format!("Failed to create transposed PQ array: {}", e)))?,
);
arrow_array::RecordBatch::try_new(batch.schema(), vec![row_ids.clone(), transposed_fsl])
.map_err(|e| Error::io(format!("Failed to create transposed batch: {}", e)))
}
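/// Request to flush a MemTable, typically sent to a background task; the
/// optional `done` channel receives the flush outcome.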
pub struct TriggerMemTableFlush {
pub memtable: Arc<MemTable>,
pub done: Option<tokio::sync::oneshot::Sender<Result<FlushResult>>>,
}
impl std::fmt::Debug for TriggerMemTableFlush {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TriggerMemTableFlush")
.field("memtable_gen", &self.memtable.generation())
.field("memtable_rows", &self.memtable.row_count())
.field("has_done", &self.done.is_some())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use std::sync::Arc;
use tempfile::TempDir;
async fn create_local_store() -> (Arc<ObjectStore>, Path, String, TempDir) {
let temp_dir = tempfile::tempdir().unwrap();
let uri = format!("file://{}", temp_dir.path().display());
let (store, path) = ObjectStore::from_uri(&uri).await.unwrap();
(store, path, uri, temp_dir)
}
fn create_test_schema() -> Arc<ArrowSchema> {
Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]))
}
fn create_test_batch(schema: &ArrowSchema, num_rows: usize) -> RecordBatch {
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from_iter_values(0..num_rows as i32)),
Arc::new(StringArray::from_iter_values(
(0..num_rows).map(|i| format!("name_{}", i)),
)),
],
)
.unwrap()
}
#[tokio::test]
async fn test_flusher_requires_wal_flush() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let region_id = Uuid::new_v4();
let manifest_store = Arc::new(RegionManifestStore::new(
store.clone(),
&base_path,
region_id,
2,
));
let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap();
let schema = create_test_schema();
let mut memtable = MemTable::new(schema.clone(), 1, vec![]).unwrap();
memtable
.insert(create_test_batch(&schema, 10))
.await
.unwrap();
assert!(!memtable.all_flushed_to_wal());
let flusher = MemTableFlusher::new(store, base_path, base_uri, region_id, manifest_store);
let result = flusher.flush(&memtable, epoch).await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("unflushed fragments")
);
}
#[tokio::test]
async fn test_flusher_empty_memtable() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let region_id = Uuid::new_v4();
let manifest_store = Arc::new(RegionManifestStore::new(
store.clone(),
&base_path,
region_id,
2,
));
let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap();
let schema = create_test_schema();
let memtable = MemTable::new(schema, 1, vec![]).unwrap();
let flusher = MemTableFlusher::new(store, base_path, base_uri, region_id, manifest_store);
let result = flusher.flush(&memtable, epoch).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("empty MemTable"));
}
#[tokio::test]
async fn test_flusher_success() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let region_id = Uuid::new_v4();
let manifest_store = Arc::new(RegionManifestStore::new(
store.clone(),
&base_path,
region_id,
2,
));
let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap();
let schema = create_test_schema();
let mut memtable = MemTable::new(schema.clone(), 1, vec![]).unwrap();
let frag_id = memtable
.insert(create_test_batch(&schema, 10))
.await
.unwrap();
memtable.mark_wal_flushed(&[frag_id], 1, &[0]);
assert!(memtable.all_flushed_to_wal());
let flusher = MemTableFlusher::new(
store.clone(),
base_path,
base_uri,
region_id,
manifest_store.clone(),
);
let result = flusher.flush(&memtable, epoch).await.unwrap();
assert_eq!(result.generation.generation, 1);
assert_eq!(result.rows_flushed, 10);
assert_eq!(result.covered_wal_entry_position, 1);
let updated_manifest = manifest_store.read_latest().await.unwrap().unwrap();
assert_eq!(updated_manifest.version, 2);
assert_eq!(updated_manifest.replay_after_wal_entry_position, 1);
assert_eq!(updated_manifest.current_generation, 2);
assert_eq!(updated_manifest.flushed_generations.len(), 1);
}
#[tokio::test]
async fn test_flusher_with_btree_index() {
use super::super::super::index::{BTreeIndexConfig, IndexStore};
use lance_index::DatasetIndexExt;
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let region_id = Uuid::new_v4();
let manifest_store = Arc::new(RegionManifestStore::new(
store.clone(),
&base_path,
region_id,
2,
));
let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap();
let index_configs = vec![MemIndexConfig::BTree(BTreeIndexConfig {
name: "id_btree".to_string(),
field_id: 0,
column: "id".to_string(),
})];
let schema = create_test_schema();
let mut memtable = MemTable::new(schema.clone(), 1, vec![]).unwrap();
let registry = IndexStore::from_configs(&index_configs, 100_000, 8).unwrap();
memtable.set_indexes(registry);
let frag_id = memtable
.insert(create_test_batch(&schema, 10))
.await
.unwrap();
memtable.mark_wal_flushed(&[frag_id], 1, &[0]);
let flusher = MemTableFlusher::new(
store.clone(),
base_path.clone(),
base_uri.clone(),
region_id,
manifest_store.clone(),
);
let result = flusher
.flush_with_indexes(&memtable, epoch, &index_configs)
.await
.unwrap();
assert_eq!(result.generation.generation, 1);
assert_eq!(result.rows_flushed, 10);
let gen_uri = format!(
"{}/_mem_wal/{}/{}",
base_uri, region_id, result.generation.path
);
let dataset = Dataset::open(&gen_uri).await.unwrap();
let indices = dataset.load_indices().await.unwrap();
assert_eq!(indices.len(), 1);
assert_eq!(indices[0].name, "id_btree");
let batch = dataset
.scan()
.filter("id = 5")
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(batch.num_rows(), 1);
let id_col = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();
assert_eq!(id_col.value(0), 5);
let mut scan = dataset.scan();
scan.filter("id = 5").unwrap();
scan.prefilter(true);
let plan = scan.create_plan().await.unwrap();
crate::utils::test::assert_plan_node_equals(
plan,
"LanceRead: ...full_filter=id = Int32(5)...
ScalarIndexQuery: query=[id = 5]@id_btree",
)
.await
.unwrap();
}
#[tokio::test]
async fn test_flusher_with_ivf_pq_index() {
use super::super::super::index::{IndexStore, IvfPqIndexConfig};
use arrow_array::{FixedSizeListArray, Float32Array};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::DatasetIndexExt;
use lance_index::vector::ivf::storage::IvfModel;
use lance_index::vector::kmeans::{KMeansParams, train_kmeans};
use lance_index::vector::pq::PQBuildParams;
use lance_linalg::distance::DistanceType;
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let region_id = Uuid::new_v4();
let manifest_store = Arc::new(RegionManifestStore::new(
store.clone(),
&base_path,
region_id,
2,
));
let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap();
let vector_dim = 8;
let num_vectors = 300;
let num_partitions = 4;
let num_sub_vectors = 2;
let vector_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, false)),
vector_dim as i32,
),
false,
),
]));
let vectors: Vec<f32> = (0..num_vectors * vector_dim)
.map(|i| ((i as f32 * 0.1).sin() + (i as f32 * 0.05).cos()) * 0.5)
.collect();
let vectors_array = Float32Array::from(vectors);
let kmeans_params = KMeansParams::new(None, 10, 1, DistanceType::L2);
let kmeans = train_kmeans::<arrow_array::types::Float32Type>(
&vectors_array,
kmeans_params,
vector_dim,
num_partitions,
num_vectors,
)
.unwrap();
let centroids_flat = kmeans
.centroids
.as_any()
.downcast_ref::<Float32Array>()
.expect("Centroids should be Float32Array")
.clone();
let centroids_fsl =
FixedSizeListArray::try_new_from_values(centroids_flat, vector_dim as i32).unwrap();
let ivf_model = IvfModel::new(centroids_fsl, None);
let vectors_fsl =
FixedSizeListArray::try_new_from_values(vectors_array.clone(), vector_dim as i32)
.unwrap();
let pq_params = PQBuildParams::new(num_sub_vectors, 8);
let pq = pq_params.build(&vectors_fsl, DistanceType::L2).unwrap();
let index_configs = vec![MemIndexConfig::IvfPq(Box::new(IvfPqIndexConfig {
name: "vector_ivf_pq".to_string(),
field_id: 1,
column: "vector".to_string(),
ivf_model: ivf_model.clone(),
pq: pq.clone(),
distance_type: DistanceType::L2,
}))];
let mut memtable = MemTable::new(vector_schema.clone(), 1, vec![]).unwrap();
let mut registry = IndexStore::from_configs(&index_configs, 100_000, 8).unwrap();
registry.add_ivf_pq(
"vector_ivf_pq".to_string(),
1, "vector".to_string(),
ivf_model,
pq,
DistanceType::L2,
);
memtable.set_indexes(registry);
let ids = Int32Array::from_iter_values(0..num_vectors as i32);
let inner_field = Arc::new(Field::new("item", DataType::Float32, false));
let vectors_fsl_data = FixedSizeListArray::try_new(
inner_field,
vector_dim as i32,
Arc::new(vectors_array),
None,
)
.unwrap();
let batch = RecordBatch::try_new(
vector_schema.clone(),
vec![Arc::new(ids), Arc::new(vectors_fsl_data)],
)
.unwrap();
let frag_id = memtable.insert(batch).await.unwrap();
memtable.mark_wal_flushed(&[frag_id], 1, &[0]);
let flusher = MemTableFlusher::new(
store.clone(),
base_path.clone(),
base_uri.clone(),
region_id,
manifest_store.clone(),
);
let result = flusher
.flush_with_indexes(&memtable, epoch, &index_configs)
.await
.unwrap();
assert_eq!(result.generation.generation, 1);
assert_eq!(result.rows_flushed, num_vectors);
let gen_uri = format!(
"{}/_mem_wal/{}/{}",
base_uri, region_id, result.generation.path
);
let dataset = Dataset::open(&gen_uri).await.unwrap();
let indices = dataset.load_indices().await.unwrap();
assert_eq!(indices.len(), 1);
assert_eq!(indices[0].name, "vector_ivf_pq");
let query_vector: Vec<f32> = (0..vector_dim)
.map(|i| ((i as f32 * 0.1).sin() + (i as f32 * 0.05).cos()) * 0.5)
.collect();
let query_array = Float32Array::from(query_vector);
let batch = dataset
.scan()
.nearest("vector", &query_array, 10)
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(batch.num_rows(), 10);
let distance_col = batch
.column_by_name("_distance")
.unwrap()
.as_any()
.downcast_ref::<Float32Array>()
.unwrap();
assert!(
distance_col.value(0) >= 0.0,
"First distance should be non-negative"
);
for i in 1..10 {
assert!(
distance_col.value(i - 1) <= distance_col.value(i),
"Distances should be sorted: {} > {}",
distance_col.value(i - 1),
distance_col.value(i)
);
}
let id_col = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
for i in 0..10 {
let id = id_col.value(i);
assert!(
id >= 0 && id < num_vectors as i32,
"ID {} should be in range [0, {})",
id,
num_vectors
);
}
let mut scan = dataset.scan();
scan.nearest("vector", &query_array, 10).unwrap();
let plan = scan.create_plan().await.unwrap();
crate::utils::test::assert_plan_node_equals(
plan,
"ProjectionExec: expr=[id@2 as id, vector@3 as vector, _distance@0 as _distance]
Take: ...
CoalesceBatchesExec: ...
SortExec: TopK...
ANNSubIndex: name=vector_ivf_pq, k=10, deltas=1, metric=L2
ANNIvfPartition: ...deltas=1",
)
.await
.unwrap();
}
#[tokio::test]
async fn test_flusher_with_fts_index() {
use super::super::super::index::{FtsIndexConfig, IndexStore};
use arrow_array::StringArray;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance_index::DatasetIndexExt;
use std::sync::Arc;
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let region_id = Uuid::new_v4();
let manifest_store = Arc::new(RegionManifestStore::new(
store.clone(),
&base_path,
region_id,
2,
));
let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("text", DataType::Utf8, true),
]));
let index_configs = vec![MemIndexConfig::Fts(FtsIndexConfig::new(
"text_fts".to_string(),
1,
"text".to_string(),
))];
let mut memtable = MemTable::new(schema.clone(), 1, vec![]).unwrap();
let registry = IndexStore::from_configs(&index_configs, 100_000, 8).unwrap();
memtable.set_indexes(registry);
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec![
"hello world",
"quick brown fox",
"lazy dog jumps",
])),
],
)
.unwrap();
let frag_id = memtable.insert(batch).await.unwrap();
memtable.mark_wal_flushed(&[frag_id], 1, &[0]);
let flusher = MemTableFlusher::new(
store.clone(),
base_path.clone(),
base_uri.clone(),
region_id,
manifest_store.clone(),
);
let result = flusher
.flush_with_indexes(&memtable, epoch, &index_configs)
.await
.unwrap();
assert_eq!(result.generation.generation, 1);
assert_eq!(result.rows_flushed, 3);
let gen_uri = format!(
"{}/_mem_wal/{}/{}",
base_uri, region_id, result.generation.path
);
let dataset = Dataset::open(&gen_uri).await.unwrap();
let indices = dataset.load_indices().await.unwrap();
assert_eq!(indices.len(), 1);
assert_eq!(indices[0].name, "text_fts");
use lance_index::scalar::FullTextSearchQuery;
let batch = dataset
.scan()
.full_text_search(FullTextSearchQuery::new("hello".to_owned()))
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(batch.num_rows(), 1);
let id_col = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();
assert_eq!(
id_col.value(0),
1,
"Should find document with 'hello world'"
);
let batch = dataset
.scan()
.full_text_search(FullTextSearchQuery::new("fox".to_owned()))
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(batch.num_rows(), 1);
let id_col = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();
assert_eq!(
id_col.value(0),
2,
"Should find document with 'quick brown fox'"
);
let mut scan = dataset.scan();
scan.full_text_search(FullTextSearchQuery::new("hello".to_owned()))
.unwrap();
let plan = scan.create_plan().await.unwrap();
crate::utils::test::assert_plan_node_equals(
plan,
"ProjectionExec: expr=[id@2 as id, text@3 as text, _score@1 as _score]
Take: ...
CoalesceBatchesExec: ...
MatchQuery: column=text, query=hello",
)
.await
.unwrap();
}
}