use std::collections::HashMap;
use std::sync::Arc;
use arrow_schema::DataType;
use async_trait::async_trait;
use lance_core::{Error, Result};
use lance_index::mem_wal::{MEM_WAL_INDEX_NAME, MemWalIndexDetails, ShardingField, ShardingSpec};
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_io::object_store::ObjectStore;
use uuid::Uuid;
use crate::Dataset;
use crate::dataset::CommitBuilder;
use crate::dataset::transaction::{Operation, Transaction};
use crate::index::DatasetIndexExt;
use crate::index::DatasetIndexInternalExt;
use crate::index::mem_wal::{load_mem_wal_index_details, new_mem_wal_index_meta};
use super::ShardWriterConfig;
use super::scanner::flushed_cache::open_flushed_dataset;
use super::scanner::{DatasetCache, ShardSnapshot};
use super::write::MemIndexConfig;
use super::write::ShardWriter;
/// `spec_id` stamped on every `ShardingSpec` built by the `*_sharding_spec` helpers.
const SHARDING_SPEC_ID: u32 = 1;
/// `field_id` given to the single `ShardingField` of each generated spec.
const SHARDING_FIELD_ID: &str = "bucket";
/// `result_type` recorded for the unsharded and bucket transforms.
const SHARDING_RESULT_TYPE: &str = "int32";
/// Transform name written by `bucket_sharding_spec`.
const BUCKET_TRANSFORM: &str = "bucket";
/// Transform name written by `unsharded_sharding_spec`.
const UNSHARDED_TRANSFORM: &str = "unsharded";
/// Transform name written by `identity_sharding_spec`.
const IDENTITY_TRANSFORM: &str = "identity";
/// Key of the bucket-count entry in a bucket field's `parameters` map.
const NUM_BUCKETS_PARAM: &str = "num_buckets";
/// Largest bucket count accepted by `bucket_sharding_spec` (inclusive).
const MAX_NUM_BUCKETS: u32 = 1024;
/// Sharding strategy selected on `InitializeMemWalBuilder`.
///
/// `resolve_sharding` turns the chosen variant into the sharding specs and the
/// shard count that are stored in the MemWAL index details.
#[derive(Debug)]
enum Sharding {
/// Builder default: resolves to no sharding specs and a shard count of 0.
Manual,
/// Resolves to one spec using the "unsharded" transform and a shard count of 1.
Unsharded,
/// Bucket transform over `column`; `num_buckets` is also used as the shard count.
Bucket { column: String, num_buckets: u32 },
/// Identity transform over `column`; resolves to a shard count of 0.
Identity { column: String },
}
/// Builder that configures and commits the MemWAL index for a dataset.
///
/// Nothing is validated or written until `execute()` is awaited.
#[must_use = "InitializeMemWalBuilder does nothing unless `.execute()` is awaited"]
pub struct InitializeMemWalBuilder<'a> {
/// Dataset being initialized; `execute()` overwrites it with the committed dataset.
dataset: &'a mut Dataset,
/// Selected sharding strategy (`Sharding::Manual` unless changed).
sharding: Sharding,
/// Names of existing indexes recorded in the MemWAL index details.
maintained_indexes: Vec<String>,
/// String key/value writer settings recorded in the MemWAL index details.
writer_config_defaults: HashMap<String, String>,
}
impl<'a> InitializeMemWalBuilder<'a> {
    /// Starts a builder for `dataset` with manual sharding, no maintained
    /// indexes and no writer config defaults.
    fn new(dataset: &'a mut Dataset) -> Self {
        let sharding = Sharding::Manual;
        let maintained_indexes = Vec::new();
        let writer_config_defaults = HashMap::new();
        Self {
            dataset,
            sharding,
            maintained_indexes,
            writer_config_defaults,
        }
    }

    /// Selects the unsharded strategy, replacing any earlier sharding choice.
    pub fn unsharded(self) -> Self {
        Self {
            sharding: Sharding::Unsharded,
            ..self
        }
    }

    /// Selects bucket sharding on `column` with `num_buckets` buckets,
    /// replacing any earlier sharding choice.
    ///
    /// The column and bucket count are only validated when `execute()` runs.
    pub fn bucket_sharding(self, column: impl Into<String>, num_buckets: u32) -> Self {
        let column = column.into();
        Self {
            sharding: Sharding::Bucket {
                column,
                num_buckets,
            },
            ..self
        }
    }

    /// Selects identity sharding on `column`, replacing any earlier sharding
    /// choice.
    ///
    /// The column is only validated when `execute()` runs.
    pub fn identity_sharding(self, column: impl Into<String>) -> Self {
        let column = column.into();
        Self {
            sharding: Sharding::Identity { column },
            ..self
        }
    }

    /// Replaces the list of maintained index names with `indexes`.
    pub fn maintained_indexes<I, S>(self, indexes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let names: Vec<String> = indexes.into_iter().map(|name| name.into()).collect();
        Self {
            maintained_indexes: names,
            ..self
        }
    }

    /// Merges the string form of `config` into the writer config defaults.
    ///
    /// Keys already present are overwritten; keys not produced from `config`
    /// are left untouched.
    pub fn writer_config_defaults(mut self, config: ShardWriterConfig) -> Self {
        for (key, value) in writer_config_to_defaults(&config) {
            self.writer_config_defaults.insert(key, value);
        }
        self
    }

    /// Sets a single writer config default, overwriting any previous value
    /// stored under `key`.
    pub fn add_writer_config_default(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        let (key, value): (String, String) = (key.into(), value.into());
        self.writer_config_defaults.insert(key, value);
        self
    }

    /// Validates the configuration and commits the MemWAL index.
    ///
    /// On success the borrowed dataset is replaced with the dataset returned
    /// by the commit.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when the sharding configuration cannot
    /// be resolved, when a maintained index name does not exist on the
    /// dataset, or when the MemWAL index is already present. Failures from
    /// loading indices or committing are propagated.
    pub async fn execute(self) -> Result<()> {
        let Self {
            dataset,
            sharding,
            maintained_indexes,
            writer_config_defaults,
        } = self;

        // Sharding is resolved first so that its errors take precedence.
        let (sharding_specs, num_shards) = resolve_sharding(dataset, sharding)?;

        let existing = dataset.load_indices().await?;
        let first_missing = maintained_indexes
            .iter()
            .find(|wanted| existing.iter().all(|idx| &idx.name != *wanted));
        if let Some(index_name) = first_missing {
            return Err(Error::invalid_input(format!(
                "Index '{}' not found on dataset. maintained_indexes must reference existing indexes.",
                index_name
            )));
        }

        let already_initialized = existing.iter().any(|idx| idx.name == MEM_WAL_INDEX_NAME);
        if already_initialized {
            return Err(Error::invalid_input(
                "MemWAL is already initialized on this dataset.",
            ));
        }

        let details = MemWalIndexDetails {
            num_shards,
            sharding_specs,
            maintained_indexes,
            writer_config_defaults,
            ..Default::default()
        };

        // The index metadata and the transaction both use the current version.
        let base_version = dataset.manifest.version;
        let index_meta = new_mem_wal_index_meta(base_version, details)?;
        let operation = Operation::CreateIndex {
            new_indices: vec![index_meta],
            removed_indices: Vec::new(),
        };
        let transaction = Transaction::new(base_version, operation, None);

        let committed = CommitBuilder::new(Arc::new(dataset.clone()))
            .execute(transaction)
            .await?;
        *dataset = committed;
        Ok(())
    }
}
/// Converts the builder's sharding choice into `(sharding_specs, num_shards)`.
///
/// Manual sharding yields no specs and 0 shards, unsharded yields one spec and
/// 1 shard, bucket sharding yields one spec and `num_buckets` shards, and
/// identity sharding yields one spec and 0 shards.
///
/// # Errors
///
/// Propagates the validation errors of `bucket_sharding_spec` and
/// `identity_sharding_spec`.
fn resolve_sharding(dataset: &Dataset, sharding: Sharding) -> Result<(Vec<ShardingSpec>, u32)> {
    let resolved: (Vec<ShardingSpec>, u32) = match sharding {
        Sharding::Manual => (Vec::new(), 0),
        Sharding::Unsharded => (vec![unsharded_sharding_spec()], 1),
        Sharding::Bucket {
            column,
            num_buckets,
        } => {
            let spec = bucket_sharding_spec(dataset, &column, num_buckets)?;
            (vec![spec], num_buckets)
        }
        Sharding::Identity { column } => {
            let spec = identity_sharding_spec(dataset, &column)?;
            (vec![spec], 0)
        }
    };
    Ok(resolved)
}
/// Builds the sharding spec for an unsharded MemWAL: a single field with the
/// "unsharded" transform, no source columns and no parameters.
fn unsharded_sharding_spec() -> ShardingSpec {
    let field = ShardingField {
        field_id: SHARDING_FIELD_ID.to_string(),
        source_ids: Vec::new(),
        transform: Some(UNSHARDED_TRANSFORM.to_string()),
        expression: None,
        result_type: SHARDING_RESULT_TYPE.to_string(),
        parameters: HashMap::new(),
    };
    ShardingSpec {
        spec_id: SHARDING_SPEC_ID,
        fields: vec![field],
    }
}
/// Builds the sharding spec for bucket sharding on `column`.
///
/// The bucket count is stored as a string under the `num_buckets` parameter of
/// the single sharding field, whose source is the id of `column`.
///
/// # Errors
///
/// Returns an invalid-input error when `num_buckets` is outside
/// `[1, MAX_NUM_BUCKETS]`, when `column` is not in the dataset schema, or when
/// the column's type is not accepted by `is_bucket_sharding_supported_type`.
/// The checks run in that order.
fn bucket_sharding_spec(dataset: &Dataset, column: &str, num_buckets: u32) -> Result<ShardingSpec> {
    if !(1..=MAX_NUM_BUCKETS).contains(&num_buckets) {
        return Err(Error::invalid_input(format!(
            "bucket_sharding: num_buckets must be in [1, {}], got {}",
            MAX_NUM_BUCKETS, num_buckets
        )));
    }

    let Some(source_field) = dataset.schema().field(column) else {
        return Err(Error::invalid_input(format!(
            "bucket_sharding: column '{}' not found on the dataset",
            column
        )));
    };

    let data_type = source_field.data_type();
    if !is_bucket_sharding_supported_type(&data_type) {
        return Err(Error::invalid_input(format!(
            "bucket_sharding: column '{}' has type {:?}, which cannot be used as a shard key",
            column, data_type
        )));
    }

    let mut parameters = HashMap::new();
    parameters.insert(NUM_BUCKETS_PARAM.to_string(), num_buckets.to_string());

    let bucket_field = ShardingField {
        field_id: SHARDING_FIELD_ID.to_string(),
        source_ids: vec![source_field.id],
        transform: Some(BUCKET_TRANSFORM.to_string()),
        expression: None,
        result_type: SHARDING_RESULT_TYPE.to_string(),
        parameters,
    };
    Ok(ShardingSpec {
        spec_id: SHARDING_SPEC_ID,
        fields: vec![bucket_field],
    })
}
fn identity_sharding_spec(dataset: &Dataset, column: &str) -> Result<ShardingSpec> {
let field = dataset.schema().field(column).ok_or_else(|| {
Error::invalid_input(format!(
"identity_sharding: column '{}' not found on the dataset",
column
))
})?;
let data_type = field.data_type();
let result_type = scalar_result_type(&data_type).ok_or_else(|| {
Error::invalid_input(format!(
"identity_sharding: column '{}' has type {:?}, which cannot be used as a shard key",
column, data_type
))
})?;
Ok(ShardingSpec {
spec_id: SHARDING_SPEC_ID,
fields: vec![ShardingField {
field_id: SHARDING_FIELD_ID.to_string(),
source_ids: vec![field.id],
transform: Some(IDENTITY_TRANSFORM.to_string()),
expression: None,
result_type: result_type.to_string(),
parameters: HashMap::new(),
}],
})
}
/// Reports whether a column of `data_type` may be used as a bucket shard key.
///
/// Accepted: boolean, the signed and unsigned integers, 32/64-bit floats,
/// `Date32`, the time and timestamp types, and the two UTF-8 string types.
fn is_bucket_sharding_supported_type(data_type: &DataType) -> bool {
    matches!(
        data_type,
        // Strings
        DataType::Utf8 | DataType::LargeUtf8
        // Temporal
        | DataType::Timestamp(_, _)
        | DataType::Time64(_)
        | DataType::Time32(_)
        | DataType::Date32
        // Floating point
        | DataType::Float64
        | DataType::Float32
        // Unsigned integers
        | DataType::UInt64
        | DataType::UInt32
        | DataType::UInt16
        | DataType::UInt8
        // Signed integers
        | DataType::Int64
        | DataType::Int32
        | DataType::Int16
        | DataType::Int8
        // Boolean
        | DataType::Boolean
    )
}
/// Maps a scalar data type to the name recorded as a sharding field's
/// `result_type`.
///
/// Returns `None` for every type other than boolean, the signed and unsigned
/// integers, and the two UTF-8 string types.
fn scalar_result_type(data_type: &DataType) -> Option<&'static str> {
    match data_type {
        DataType::Boolean => Some("boolean"),
        DataType::Int8 => Some("int8"),
        DataType::Int16 => Some("int16"),
        DataType::Int32 => Some("int32"),
        DataType::Int64 => Some("int64"),
        DataType::UInt8 => Some("uint8"),
        DataType::UInt16 => Some("uint16"),
        DataType::UInt32 => Some("uint32"),
        DataType::UInt64 => Some("uint64"),
        DataType::Utf8 => Some("utf8"),
        DataType::LargeUtf8 => Some("large_utf8"),
        _ => None,
    }
}
/// Flattens a `ShardWriterConfig` into string key/value pairs.
///
/// Durations are written in whole milliseconds under `*_ms` keys. The two
/// optional intervals are only emitted when set. Each entry of `hnsw_params`
/// contributes `hnsw.<index>.num_edges`, `hnsw.<index>.ef_construction` and
/// `hnsw.<index>.max_level`.
fn writer_config_to_defaults(config: &ShardWriterConfig) -> HashMap<String, String> {
    // Settings that are always present.
    let always_present: Vec<(&str, String)> = vec![
        ("durable_write", config.durable_write.to_string()),
        ("sync_indexed_write", config.sync_indexed_write.to_string()),
        ("max_wal_buffer_size", config.max_wal_buffer_size.to_string()),
        ("max_memtable_size", config.max_memtable_size.to_string()),
        ("max_memtable_rows", config.max_memtable_rows.to_string()),
        (
            "max_memtable_batches",
            config.max_memtable_batches.to_string(),
        ),
        (
            "manifest_scan_batch_size",
            config.manifest_scan_batch_size.to_string(),
        ),
        (
            "max_unflushed_memtable_bytes",
            config.max_unflushed_memtable_bytes.to_string(),
        ),
        (
            "backpressure_log_interval_ms",
            config.backpressure_log_interval.as_millis().to_string(),
        ),
        (
            "async_index_buffer_rows",
            config.async_index_buffer_rows.to_string(),
        ),
        (
            "async_index_interval_ms",
            config.async_index_interval.as_millis().to_string(),
        ),
        ("enable_memtable", config.enable_memtable.to_string()),
    ];
    let mut defaults: HashMap<String, String> = always_present
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();

    // Optional intervals are omitted entirely when unset.
    if let Some(flush_interval) = config.max_wal_flush_interval {
        let millis = flush_interval.as_millis().to_string();
        defaults.insert("max_wal_flush_interval_ms".to_string(), millis);
    }
    if let Some(stats_interval) = config.stats_log_interval {
        let millis = stats_interval.as_millis().to_string();
        defaults.insert("stats_log_interval_ms".to_string(), millis);
    }

    // Per-index HNSW build parameters, namespaced by index name.
    for (index_name, params) in &config.hnsw_params {
        defaults.extend([
            (format!("hnsw.{index_name}.num_edges"), params.m.to_string()),
            (
                format!("hnsw.{index_name}.ef_construction"),
                params.ef_construction.to_string(),
            ),
            (
                format!("hnsw.{index_name}.max_level"),
                params.max_level.to_string(),
            ),
        ]);
    }

    defaults
}
/// MemWAL operations exposed on a dataset.
///
/// `initialize_mem_wal` and `mem_wal_writer` must be provided by implementors;
/// the remaining methods have default bodies that return an empty result
/// without doing any work.
#[async_trait]
pub trait DatasetMemWalExt {
/// Returns a builder used to initialize MemWAL on this dataset.
fn initialize_mem_wal(&mut self) -> InitializeMemWalBuilder<'_>;
/// Returns the MemWAL index details, if any.
///
/// The default implementation always returns `Ok(None)`.
async fn mem_wal_index_details(&self) -> Result<Option<MemWalIndexDetails>> {
Ok(None)
}
/// Lists MemWAL shard ids.
///
/// The default implementation always returns an empty list.
async fn list_mem_wal_latest_shard_ids(&self) -> Result<Vec<Uuid>> {
Ok(Vec::new())
}
/// Prewarms MemWAL data for the given shard snapshots, optionally through `cache`.
///
/// The default implementation ignores both arguments and returns `Ok(())`.
async fn prewarm_mem_wal(
&self,
_snapshots: &[ShardSnapshot],
_cache: Option<&Arc<dyn DatasetCache>>,
) -> Result<()> {
Ok(())
}
/// Opens a `ShardWriter` for `shard_id` using `config`.
async fn mem_wal_writer(
&self,
shard_id: Uuid,
config: ShardWriterConfig,
) -> Result<ShardWriter>;
}
/// Prewarms every index of `dataset`, calling `prewarm_index` once per
/// distinct index name.
///
/// # Errors
///
/// Stops at, and returns, the first failure from loading the indices or
/// prewarming one of them.
async fn prewarm_all_indexes(dataset: &Dataset) -> Result<()> {
    let indices = dataset.load_indices().await?;
    let mut warmed_names = std::collections::HashSet::new();
    for index in indices.iter() {
        // `insert` returns false when the name was already handled.
        let first_occurrence = warmed_names.insert(index.name.as_str());
        if !first_occurrence {
            continue;
        }
        dataset.prewarm_index(&index.name).await?;
    }
    Ok(())
}
#[async_trait]
impl DatasetMemWalExt for Dataset {
fn initialize_mem_wal(&mut self) -> InitializeMemWalBuilder<'_> {
InitializeMemWalBuilder::new(self)
}
async fn mem_wal_index_details(&self) -> Result<Option<MemWalIndexDetails>> {
let Some(index_meta) = self.load_index_by_name(MEM_WAL_INDEX_NAME).await? else {
return Ok(None);
};
load_mem_wal_index_details(index_meta).map(Some)
}
async fn list_mem_wal_latest_shard_ids(&self) -> Result<Vec<Uuid>> {
let prefix = super::util::mem_wal_path(&self.branch_location().path);
let object_store = self.object_store(None).await?;
let list_result = object_store
.inner
.list_with_delimiter(Some(&prefix))
.await
.map_err(|e| {
Error::io(format!(
"failed to list MemWAL shard directories at {}: {}",
prefix, e
))
})?;
let mut ids = Vec::new();
for shard_prefix in list_result.common_prefixes {
if let Some(name) = shard_prefix.filename()
&& let Ok(shard_id) = Uuid::parse_str(name)
{
ids.push(shard_id);
}
}
ids.sort();
Ok(ids)
}
async fn prewarm_mem_wal(
&self,
snapshots: &[ShardSnapshot],
cache: Option<&Arc<dyn DatasetCache>>,
) -> Result<()> {
let session = self.session();
let base_path = self.uri().trim_end_matches('/').to_string();
let opens = snapshots
.iter()
.flat_map(|snapshot| {
let shard_id = snapshot.shard_id;
let base_path = &base_path;
let session = &session;
snapshot.flushed_generations.iter().map(move |flushed| {
let path = format!("{}/_mem_wal/{}/{}", base_path, shard_id, flushed.path);
async move {
let dataset =
open_flushed_dataset(&path, Some(session), cache, None).await?;
prewarm_all_indexes(&dataset).await
}
})
})
.collect::<Vec<_>>();
futures::future::try_join_all(opens).await?;
Ok(())
}
async fn mem_wal_writer(
&self,
shard_id: Uuid,
mut config: ShardWriterConfig,
) -> Result<ShardWriter> {
use lance_index::metrics::NoOpMetricsCollector;
let mem_wal_index = self
.open_mem_wal_index(&NoOpMetricsCollector)
.await?
.ok_or_else(|| {
Error::invalid_input(
"MemWAL is not initialized on this dataset. Call initialize_mem_wal() first.",
)
})?;
let maintained_indexes = &mem_wal_index.details.maintained_indexes;
let mut index_configs = Vec::new();
for index_name in maintained_indexes {
let index_meta = self
.load_indices_by_name(index_name)
.await?
.into_iter()
.next()
.ok_or_else(|| {
Error::invalid_input(format!(
"Index '{}' from maintained_indexes not found on dataset",
index_name
))
})?;
let type_url = index_meta
.index_details
.as_ref()
.map(|d| d.type_url.as_str())
.unwrap_or("");
let index_type = MemIndexConfig::detect_index_type(type_url)?;
match index_type {
"btree" => {
index_configs.push(MemIndexConfig::btree_from_metadata(
&index_meta,
self.schema(),
)?);
}
"fts" => {
index_configs.push(MemIndexConfig::fts_from_metadata(
&index_meta,
self.schema(),
)?);
}
"vector" => {
let hnsw_params = config.hnsw_params.get(index_name).cloned();
let vector_config =
load_vector_index_config(self, index_name, &index_meta, hnsw_params)
.await?;
index_configs.push(vector_config);
}
_ => {
return Err(Error::invalid_input(format!(
"Unknown index type: {}",
index_type
)));
}
};
}
config.shard_id = shard_id;
let base_uri = self.uri();
let (store, base_path) = ObjectStore::from_uri(base_uri).await?;
ShardWriter::open(
store,
base_path,
base_uri,
config,
Arc::new(self.schema().into()),
index_configs,
)
.await
}
}
async fn load_vector_index_config(
dataset: &Dataset,
index_name: &str,
index_meta: &lance_table::format::IndexMetadata,
hnsw_params: Option<HnswBuildParams>,
) -> Result<MemIndexConfig> {
use lance_index::metrics::NoOpMetricsCollector;
let field_id = index_meta.fields.first().ok_or_else(|| {
Error::invalid_input(format!("Vector index '{}' has no fields", index_name))
})?;
let field = dataset.schema().field_by_id(*field_id).ok_or_else(|| {
Error::invalid_input(format!("Field not found for vector index '{}'", index_name))
})?;
let column = field.name.clone();
let distance_type = dataset
.open_vector_index(&column, &index_meta.uuid, &NoOpMetricsCollector)
.await
.map_err(|e| {
Error::invalid_input(format!(
"Failed to open base vector index '{}' to inherit distance type: {}",
index_name, e
))
})?
.metric_type();
Ok(match hnsw_params {
Some(params) => MemIndexConfig::hnsw_with_params(
index_name.to_string(),
*field_id,
column,
distance_type,
params,
),
None => MemIndexConfig::hnsw(index_name.to_string(), *field_id, column, distance_type),
})
}
// Tests for `Dataset::prewarm_mem_wal`, run against datasets written to a
// temporary directory.
#[cfg(test)]
mod tests {
use super::super::scanner::FlushedMemTableCache;
use super::*;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance_index::IndexType;
use lance_index::scalar::ScalarIndexParams;
use crate::dataset::WriteParams;
/// Two-column schema: non-nullable `id` and nullable `v`, both `Int32`.
fn id_v_schema() -> Arc<ArrowSchema> {
Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("v", DataType::Int32, true),
]))
}
/// Builds a batch for `schema` where each row has `v = id * 10`.
fn id_v_batch(schema: &Arc<ArrowSchema>, ids: &[i32]) -> RecordBatch {
let vs: Vec<i32> = ids.iter().map(|i| i * 10).collect();
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(ids.to_vec())),
Arc::new(Int32Array::from(vs)),
],
)
.unwrap()
}
/// Prewarming a snapshot with one flushed generation must succeed, and the
/// generation reopened through the same cache must expose its one index.
#[tokio::test]
async fn test_prewarm_mem_wal_opens_and_warms_indexes() {
let tmp = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", tmp.path().to_str().unwrap());
let schema = id_v_schema();
// Base dataset with a single row.
let reader = RecordBatchIterator::new([Ok(id_v_batch(&schema, &[-1]))], schema.clone());
let base = Dataset::write(reader, &base_uri, Some(WriteParams::default()))
.await
.unwrap();
// Write a dataset at the location prewarm_mem_wal derives for a flushed
// generation: `<base>/_mem_wal/<shard_id>/<folder>`.
let shard_id = Uuid::new_v4();
let folder = "deadbeef_gen_1";
let gen_uri = format!("{}/_mem_wal/{}/{}", base_uri, shard_id, folder);
let reader =
RecordBatchIterator::new([Ok(id_v_batch(&schema, &[1, 2, 3]))], schema.clone());
let mut gen_ds = Dataset::write(reader, &gen_uri, Some(WriteParams::default()))
.await
.unwrap();
// Give the generation one BTree index on `id` so there is something to warm.
gen_ds
.create_index(
&["id"],
IndexType::BTree,
Some("id_idx".to_string()),
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
// Snapshot referencing generation 1 stored in `folder`.
let snapshot = ShardSnapshot::new(shard_id)
.with_current_generation(2)
.with_flushed_generation(1, folder.to_string());
let cache: Arc<dyn DatasetCache> = Arc::new(FlushedMemTableCache::new(4));
base.prewarm_mem_wal(std::slice::from_ref(&snapshot), Some(&cache))
.await
.expect("prewarm must open the generation and warm its index");
// Fetch the generation through the same cache and check its index count.
let warmed = cache
.get_or_open(&gen_uri, Some(base.session()))
.await
.unwrap();
assert_eq!(warmed.load_indices().await.unwrap().len(), 1);
}
/// Prewarming with no snapshots, or with a snapshot that has no flushed
/// generations, must succeed.
#[tokio::test]
async fn test_prewarm_mem_wal_empty_is_noop() {
let tmp = tempfile::tempdir().unwrap();
let base_uri = format!("{}/base", tmp.path().to_str().unwrap());
let schema = id_v_schema();
let reader = RecordBatchIterator::new([Ok(id_v_batch(&schema, &[-1]))], schema.clone());
let base = Dataset::write(reader, &base_uri, Some(WriteParams::default()))
.await
.unwrap();
// Empty snapshot slice.
base.prewarm_mem_wal(&[], None).await.unwrap();
// One snapshot, but without any flushed generation.
let empty = ShardSnapshot::new(Uuid::new_v4()).with_current_generation(1);
base.prewarm_mem_wal(std::slice::from_ref(&empty), None)
.await
.unwrap();
}
}