use std::collections::HashMap;
use std::sync::Arc;
use arrow_schema::DataType;
use async_trait::async_trait;
use lance_core::{Error, Result};
use lance_index::mem_wal::{MEM_WAL_INDEX_NAME, MemWalIndexDetails, ShardingField, ShardingSpec};
use lance_io::object_store::ObjectStore;
use uuid::Uuid;
use crate::Dataset;
use crate::dataset::CommitBuilder;
use crate::dataset::transaction::{Operation, Transaction};
use crate::index::DatasetIndexExt;
use crate::index::DatasetIndexInternalExt;
use crate::index::mem_wal::{load_mem_wal_index_details, new_mem_wal_index_meta};
use super::ShardWriterConfig;
use super::write::MemIndexConfig;
use super::write::ShardWriter;
/// `spec_id` assigned to every sharding spec built by this module.
const SHARDING_SPEC_ID: u32 = 1;
/// `field_id` of the single sharding field in each spec this module builds
/// (used for the unsharded, bucket and identity specs alike).
const SHARDING_FIELD_ID: &str = "bucket";
/// `result_type` recorded for unsharded and bucket specs; identity specs derive
/// their result type from the source column instead.
const SHARDING_RESULT_TYPE: &str = "int32";
/// Transform name for hash-bucket sharding.
const BUCKET_TRANSFORM: &str = "bucket";
/// Transform name for the single-shard (unsharded) layout.
const UNSHARDED_TRANSFORM: &str = "unsharded";
/// Transform name for sharding directly on a column's value.
const IDENTITY_TRANSFORM: &str = "identity";
/// Parameter key under which the bucket count is stored in a bucket spec.
const NUM_BUCKETS_PARAM: &str = "num_buckets";
/// Inclusive upper bound accepted for `num_buckets` in bucket sharding.
const MAX_NUM_BUCKETS: u32 = 1024;
/// Sharding strategy selected on an `InitializeMemWalBuilder`.
#[derive(Debug)]
enum Sharding {
    /// No sharding spec is recorded; shards are managed by the caller.
    Manual,
    /// A single shard covering all rows.
    Unsharded,
    /// Hash rows of `column` into `num_buckets` buckets.
    Bucket {
        column: String,
        num_buckets: u32,
    },
    /// Shard on the raw value of `column`.
    Identity {
        column: String,
    },
}
/// Builder that configures and commits the MemWAL index for a dataset.
///
/// Obtained from `DatasetMemWalExt::initialize_mem_wal`. Nothing is written until
/// [`InitializeMemWalBuilder::execute`] is awaited.
#[must_use = "InitializeMemWalBuilder does nothing unless `.execute()` is awaited"]
pub struct InitializeMemWalBuilder<'a> {
    /// Dataset to initialize; replaced by the newly committed version on success.
    dataset: &'a mut Dataset,
    /// Selected sharding strategy (manual, i.e. no spec, unless overridden).
    sharding: Sharding,
    /// Names of existing indexes that MemWAL writers build in-memory configs for.
    maintained_indexes: Vec<String>,
    /// Writer configuration defaults, stored as string key/value pairs.
    writer_config_defaults: HashMap<String, String>,
}

impl<'a> InitializeMemWalBuilder<'a> {
    /// Creates a builder with manual sharding, no maintained indexes and no
    /// writer defaults.
    fn new(dataset: &'a mut Dataset) -> Self {
        Self {
            dataset,
            sharding: Sharding::Manual,
            maintained_indexes: Vec::new(),
            writer_config_defaults: HashMap::new(),
        }
    }

    /// Uses a single shard for all rows (`num_shards` = 1).
    pub fn unsharded(mut self) -> Self {
        self.sharding = Sharding::Unsharded;
        self
    }

    /// Hash-buckets rows on `column` into `num_buckets` shards.
    ///
    /// `num_buckets` must be in `[1, 1024]` and the column must exist with a
    /// supported type; both are checked by [`Self::execute`].
    pub fn bucket_sharding(mut self, column: impl Into<String>, num_buckets: u32) -> Self {
        self.sharding = Sharding::Bucket {
            column: column.into(),
            num_buckets,
        };
        self
    }

    /// Shards directly on the value of `column`.
    ///
    /// The column must exist and have an integer, string or boolean type; this is
    /// checked by [`Self::execute`].
    pub fn identity_sharding(mut self, column: impl Into<String>) -> Self {
        self.sharding = Sharding::Identity {
            column: column.into(),
        };
        self
    }

    /// Replaces the list of existing index names MemWAL writers should maintain.
    ///
    /// Every name must refer to an index already on the dataset. Repeated names are
    /// collapsed to their first occurrence by [`Self::execute`].
    pub fn maintained_indexes<I, S>(mut self, indexes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.maintained_indexes = indexes.into_iter().map(Into::into).collect();
        self
    }

    /// Records every field of `config` as a writer default, overwriting keys that
    /// were set earlier.
    pub fn writer_config_defaults(mut self, config: ShardWriterConfig) -> Self {
        self.writer_config_defaults
            .extend(writer_config_to_defaults(&config));
        self
    }

    /// Sets a single writer default, overwriting any previous value for `key`.
    pub fn add_writer_config_default(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.writer_config_defaults.insert(key.into(), value.into());
        self
    }

    /// Validates the configuration and commits a `CreateIndex` transaction that adds
    /// the MemWAL index. On success `dataset` is replaced by the committed version.
    ///
    /// # Errors
    ///
    /// - The sharding configuration is invalid: unknown column, unsupported type, or
    ///   `num_buckets` out of range.
    /// - MemWAL is already initialized on the dataset.
    /// - A maintained index name does not exist on the dataset.
    /// - Loading indices or committing fails.
    pub async fn execute(self) -> Result<()> {
        let Self {
            dataset,
            sharding,
            mut maintained_indexes,
            writer_config_defaults,
        } = self;
        let (sharding_specs, num_shards) = resolve_sharding(dataset, sharding)?;
        let indices = dataset.load_indices().await?;
        // Check re-initialization first so the caller sees the real reason the
        // call cannot succeed, rather than a complaint about an index name.
        if indices.iter().any(|idx| idx.name == MEM_WAL_INDEX_NAME) {
            return Err(Error::invalid_input(
                "MemWAL is already initialized on this dataset.",
            ));
        }
        // Collapse repeated names, keeping first-occurrence order. Otherwise each
        // duplicate is persisted and a writer builds one in-memory config per entry,
        // maintaining the same index more than once.
        let mut seen = std::collections::HashSet::new();
        maintained_indexes.retain(|name| seen.insert(name.clone()));
        for index_name in &maintained_indexes {
            if !indices.iter().any(|idx| &idx.name == index_name) {
                return Err(Error::invalid_input(format!(
                    "Index '{}' not found on dataset. maintained_indexes must reference existing indexes.",
                    index_name
                )));
            }
        }
        let details = MemWalIndexDetails {
            num_shards,
            sharding_specs,
            maintained_indexes,
            writer_config_defaults,
            ..Default::default()
        };
        let index_meta = new_mem_wal_index_meta(dataset.manifest.version, details)?;
        let transaction = Transaction::new(
            dataset.manifest.version,
            Operation::CreateIndex {
                new_indices: vec![index_meta],
                removed_indices: vec![],
            },
            None,
        );
        let new_dataset = CommitBuilder::new(Arc::new(dataset.clone()))
            .execute(transaction)
            .await?;
        *dataset = new_dataset;
        Ok(())
    }
}
/// Turns the builder's sharding choice into the specs stored in the MemWAL
/// details, plus the `num_shards` value to record.
///
/// Manual → no specs and 0 shards; unsharded → one spec and 1 shard; bucket →
/// one spec and `num_buckets` shards; identity → one spec and 0 shards
/// (presumably "not known up front" — TODO confirm).
fn resolve_sharding(dataset: &Dataset, sharding: Sharding) -> Result<(Vec<ShardingSpec>, u32)> {
    let resolved = match sharding {
        Sharding::Manual => (vec![], 0),
        Sharding::Unsharded => (vec![unsharded_sharding_spec()], 1),
        Sharding::Bucket {
            column,
            num_buckets,
        } => {
            let spec = bucket_sharding_spec(dataset, &column, num_buckets)?;
            (vec![spec], num_buckets)
        }
        Sharding::Identity { column } => {
            let spec = identity_sharding_spec(dataset, &column)?;
            (vec![spec], 0)
        }
    };
    Ok(resolved)
}
/// Builds the spec for the single-shard layout: one field with the `unsharded`
/// transform, no source columns and no parameters.
fn unsharded_sharding_spec() -> ShardingSpec {
    let only_field = ShardingField {
        field_id: SHARDING_FIELD_ID.to_string(),
        source_ids: vec![],
        transform: Some(UNSHARDED_TRANSFORM.to_string()),
        expression: None,
        result_type: SHARDING_RESULT_TYPE.to_string(),
        parameters: HashMap::new(),
    };
    ShardingSpec {
        spec_id: SHARDING_SPEC_ID,
        fields: vec![only_field],
    }
}
/// Builds a hash-bucket sharding spec over `column`.
///
/// # Errors
///
/// Returns an invalid-input error when `num_buckets` is outside
/// `[1, MAX_NUM_BUCKETS]`, when `column` is not in the dataset schema, or when
/// its type is rejected by `is_bucket_sharding_supported_type`.
fn bucket_sharding_spec(dataset: &Dataset, column: &str, num_buckets: u32) -> Result<ShardingSpec> {
    let bucket_count_ok = (1..=MAX_NUM_BUCKETS).contains(&num_buckets);
    if !bucket_count_ok {
        return Err(Error::invalid_input(format!(
            "bucket_sharding: num_buckets must be in [1, {}], got {}",
            MAX_NUM_BUCKETS, num_buckets
        )));
    }

    let schema = dataset.schema();
    let Some(source_field) = schema.field(column) else {
        return Err(Error::invalid_input(format!(
            "bucket_sharding: column '{}' not found on the dataset",
            column
        )));
    };

    let source_type = source_field.data_type();
    if !is_bucket_sharding_supported_type(&source_type) {
        return Err(Error::invalid_input(format!(
            "bucket_sharding: column '{}' has type {:?}, which cannot be used as a shard key",
            column, source_type
        )));
    }

    // The bucket count travels as a string parameter on the field.
    let parameters = HashMap::from([(NUM_BUCKETS_PARAM.to_string(), num_buckets.to_string())]);
    let bucket_field = ShardingField {
        field_id: SHARDING_FIELD_ID.to_string(),
        source_ids: vec![source_field.id],
        transform: Some(BUCKET_TRANSFORM.to_string()),
        expression: None,
        result_type: SHARDING_RESULT_TYPE.to_string(),
        parameters,
    };
    Ok(ShardingSpec {
        spec_id: SHARDING_SPEC_ID,
        fields: vec![bucket_field],
    })
}
fn identity_sharding_spec(dataset: &Dataset, column: &str) -> Result<ShardingSpec> {
let field = dataset.schema().field(column).ok_or_else(|| {
Error::invalid_input(format!(
"identity_sharding: column '{}' not found on the dataset",
column
))
})?;
let data_type = field.data_type();
let result_type = scalar_result_type(&data_type).ok_or_else(|| {
Error::invalid_input(format!(
"identity_sharding: column '{}' has type {:?}, which cannot be used as a shard key",
column, data_type
))
})?;
Ok(ShardingSpec {
spec_id: SHARDING_SPEC_ID,
fields: vec![ShardingField {
field_id: SHARDING_FIELD_ID.to_string(),
source_ids: vec![field.id],
transform: Some(IDENTITY_TRANSFORM.to_string()),
expression: None,
result_type: result_type.to_string(),
parameters: HashMap::new(),
}],
})
}
/// Returns `true` when a column of `data_type` may be used as a bucket shard key.
///
/// Accepted types:
/// - 8–64-bit signed and unsigned integers
/// - `Float32` and `Float64`
/// - `Date32`, and `Time32`/`Time64`/`Timestamp` of any unit
/// - `Boolean`, `Utf8` and `LargeUtf8`
///
/// Anything else (for example `Date64`, `Float16`, binary or nested types) is
/// rejected.
fn is_bucket_sharding_supported_type(data_type: &DataType) -> bool {
    let integral = matches!(
        data_type,
        DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
    );
    let floating = matches!(data_type, DataType::Float32 | DataType::Float64);
    let temporal = matches!(
        data_type,
        DataType::Date32 | DataType::Time32(_) | DataType::Time64(_) | DataType::Timestamp(_, _)
    );
    let textual_or_bool = matches!(
        data_type,
        DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8
    );
    integral || floating || temporal || textual_or_bool
}
/// Maps a column type to the result-type name stored in an identity sharding
/// spec.
///
/// Only integers, `Utf8`/`LargeUtf8` and `Boolean` have a name. Every other type
/// yields `None`.
fn scalar_result_type(data_type: &DataType) -> Option<&'static str> {
    match data_type {
        DataType::Boolean => Some("boolean"),
        DataType::Int8 => Some("int8"),
        DataType::Int16 => Some("int16"),
        DataType::Int32 => Some("int32"),
        DataType::Int64 => Some("int64"),
        DataType::UInt8 => Some("uint8"),
        DataType::UInt16 => Some("uint16"),
        DataType::UInt32 => Some("uint32"),
        DataType::UInt64 => Some("uint64"),
        DataType::Utf8 => Some("utf8"),
        DataType::LargeUtf8 => Some("large_utf8"),
        _ => None,
    }
}
/// Flattens a `ShardWriterConfig` into string key/value pairs for storage in the
/// MemWAL index details.
///
/// Durations are written as whole milliseconds under keys ending in `_ms`. The
/// optional `max_wal_flush_interval` and `stats_log_interval` are emitted only
/// when set.
fn writer_config_to_defaults(config: &ShardWriterConfig) -> HashMap<String, String> {
    let always_present = [
        ("durable_write", config.durable_write.to_string()),
        ("sync_indexed_write", config.sync_indexed_write.to_string()),
        ("max_wal_buffer_size", config.max_wal_buffer_size.to_string()),
        ("max_memtable_size", config.max_memtable_size.to_string()),
        ("max_memtable_rows", config.max_memtable_rows.to_string()),
        ("max_memtable_batches", config.max_memtable_batches.to_string()),
        (
            "manifest_scan_batch_size",
            config.manifest_scan_batch_size.to_string(),
        ),
        (
            "max_unflushed_memtable_bytes",
            config.max_unflushed_memtable_bytes.to_string(),
        ),
        (
            "backpressure_log_interval_ms",
            config.backpressure_log_interval.as_millis().to_string(),
        ),
        (
            "async_index_buffer_rows",
            config.async_index_buffer_rows.to_string(),
        ),
        (
            "async_index_interval_ms",
            config.async_index_interval.as_millis().to_string(),
        ),
        ("enable_memtable", config.enable_memtable.to_string()),
    ];
    let mut defaults: HashMap<String, String> = always_present
        .into_iter()
        .map(|(key, value)| (key.to_owned(), value))
        .collect();

    // Optional intervals are only recorded when the config sets them.
    if let Some(flush_every) = config.max_wal_flush_interval {
        defaults.insert(
            "max_wal_flush_interval_ms".to_owned(),
            flush_every.as_millis().to_string(),
        );
    }
    if let Some(log_every) = config.stats_log_interval {
        defaults.insert(
            "stats_log_interval_ms".to_owned(),
            log_every.as_millis().to_string(),
        );
    }
    defaults
}
/// MemWAL operations available on a dataset.
#[async_trait]
pub trait DatasetMemWalExt {
    /// Starts configuring MemWAL initialization. The returned builder commits
    /// nothing until its `execute()` is awaited.
    fn initialize_mem_wal(&mut self) -> InitializeMemWalBuilder<'_>;

    /// Details of the MemWAL index, or `None` when it is absent.
    ///
    /// Implementors that do not override this always report `None`.
    async fn mem_wal_index_details(&self) -> Result<Option<MemWalIndexDetails>> {
        Ok(None)
    }

    /// Shard ids found in MemWAL storage.
    ///
    /// Implementors that do not override this always report an empty list.
    async fn list_mem_wal_latest_shard_ids(&self) -> Result<Vec<Uuid>> {
        Ok(vec![])
    }

    /// Opens a writer for `shard_id` using `config`.
    async fn mem_wal_writer(
        &self,
        shard_id: Uuid,
        config: ShardWriterConfig,
    ) -> Result<ShardWriter>;
}
#[async_trait]
impl DatasetMemWalExt for Dataset {
    /// Returns a builder that mutably borrows this dataset. The MemWAL index is
    /// committed only when the builder's `execute()` is awaited.
    fn initialize_mem_wal(&mut self) -> InitializeMemWalBuilder<'_> {
        InitializeMemWalBuilder::new(self)
    }

    /// Loads the index named `MEM_WAL_INDEX_NAME` and decodes its details.
    ///
    /// Returns `Ok(None)` when no such index exists on this dataset.
    async fn mem_wal_index_details(&self) -> Result<Option<MemWalIndexDetails>> {
        let Some(index_meta) = self.load_index_by_name(MEM_WAL_INDEX_NAME).await? else {
            return Ok(None);
        };
        load_mem_wal_index_details(index_meta).map(Some)
    }

    /// Lists shard ids from the MemWAL path under this dataset's branch location.
    ///
    /// Each common prefix of a delimited listing is treated as a shard directory.
    /// Entries whose final path segment does not parse as a UUID are skipped
    /// silently, and the result is sorted ascending. A listing failure becomes an
    /// I/O error that includes the listed prefix.
    async fn list_mem_wal_latest_shard_ids(&self) -> Result<Vec<Uuid>> {
        let prefix = super::util::mem_wal_path(&self.branch_location().path);
        let object_store = self.object_store(None).await?;
        let list_result = object_store
            .inner
            .list_with_delimiter(Some(&prefix))
            .await
            .map_err(|e| {
                Error::io(format!(
                    "failed to list MemWAL shard directories at {}: {}",
                    prefix, e
                ))
            })?;
        let mut ids = Vec::new();
        for shard_prefix in list_result.common_prefixes {
            // Non-UUID directory names are ignored rather than treated as errors.
            if let Some(name) = shard_prefix.filename()
                && let Ok(shard_id) = Uuid::parse_str(name)
            {
                ids.push(shard_id);
            }
        }
        ids.sort();
        Ok(ids)
    }

    /// Opens a `ShardWriter` for `shard_id`.
    ///
    /// One in-memory index config is built per entry in the MemWAL index's
    /// `maintained_indexes`, according to its detected type (`btree`, `fts` or
    /// `vector`). `config.shard_id` is overwritten with `shard_id`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if MemWAL is not initialized, if a maintained
    /// index is missing from the dataset, or if its type is unsupported. Errors from
    /// type detection, config construction, store creation and `ShardWriter::open`
    /// are propagated.
    async fn mem_wal_writer(
        &self,
        shard_id: Uuid,
        mut config: ShardWriterConfig,
    ) -> Result<ShardWriter> {
        use lance_index::metrics::NoOpMetricsCollector;
        let mem_wal_index = self
            .open_mem_wal_index(&NoOpMetricsCollector)
            .await?
            .ok_or_else(|| {
                Error::invalid_input(
                    "MemWAL is not initialized on this dataset. Call initialize_mem_wal() first.",
                )
            })?;
        let maintained_indexes = &mem_wal_index.details.maintained_indexes;
        let mut index_configs = Vec::new();
        for index_name in maintained_indexes {
            let index_meta = self.load_index_by_name(index_name).await?.ok_or_else(|| {
                Error::invalid_input(format!(
                    "Index '{}' from maintained_indexes not found on dataset",
                    index_name
                ))
            })?;
            // An index without details yields an empty type URL; how that is
            // classified is left to `detect_index_type`.
            let type_url = index_meta
                .index_details
                .as_ref()
                .map(|d| d.type_url.as_str())
                .unwrap_or("");
            let index_type = MemIndexConfig::detect_index_type(type_url)?;
            match index_type {
                "btree" => {
                    index_configs.push(MemIndexConfig::btree_from_metadata(
                        &index_meta,
                        self.schema(),
                    )?);
                }
                "fts" => {
                    index_configs.push(MemIndexConfig::fts_from_metadata(
                        &index_meta,
                        self.schema(),
                    )?);
                }
                "vector" => {
                    // Vector configs need the base index opened to inherit its
                    // distance type.
                    let vector_config =
                        load_vector_index_config(self, index_name, &index_meta).await?;
                    index_configs.push(vector_config);
                }
                _ => {
                    return Err(Error::invalid_input(format!(
                        "Unknown index type: {}",
                        index_type
                    )));
                }
            };
        }
        config.shard_id = shard_id;
        // NOTE(review): this opens a fresh store from `self.uri()`, whereas
        // `list_mem_wal_latest_shard_ids` uses `self.object_store(None)` and
        // `branch_location()`. Storage options the dataset was opened with, and any
        // branch-specific location, may therefore not apply here. Confirm.
        let base_uri = self.uri();
        let (store, base_path) = ObjectStore::from_uri(base_uri).await?;
        ShardWriter::open(
            store,
            base_path,
            base_uri,
            config,
            Arc::new(self.schema().into()),
            index_configs,
        )
        .await
    }
}
async fn load_vector_index_config(
dataset: &Dataset,
index_name: &str,
index_meta: &lance_table::format::IndexMetadata,
) -> Result<MemIndexConfig> {
use lance_index::metrics::NoOpMetricsCollector;
let field_id = index_meta.fields.first().ok_or_else(|| {
Error::invalid_input(format!("Vector index '{}' has no fields", index_name))
})?;
let field = dataset.schema().field_by_id(*field_id).ok_or_else(|| {
Error::invalid_input(format!("Field not found for vector index '{}'", index_name))
})?;
let column = field.name.clone();
let distance_type = dataset
.open_vector_index(&column, &index_meta.uuid.to_string(), &NoOpMetricsCollector)
.await
.map_err(|e| {
Error::invalid_input(format!(
"Failed to open base vector index '{}' to inherit distance type: {}",
index_name, e
))
})?
.metric_type();
Ok(MemIndexConfig::hnsw(
index_name.to_string(),
*field_id,
column,
distance_type,
))
}