use std::sync::Arc;
use std::{any::Any, collections::HashMap};
pub mod builder;
pub mod ivf;
pub mod pq;
pub mod utils;
#[cfg(test)]
mod fixture_test;
use self::{ivf::*, pq::PQIndex};
use arrow_schema::DataType;
use builder::IvfIndexBuilder;
use lance_core::utils::tempfile::TempStdDir;
use lance_file::previous::reader::FileReader as PreviousFileReader;
use lance_index::frag_reuse::FragReuseIndex;
use lance_index::metrics::NoOpMetricsCollector;
use lance_index::optimize::OptimizeOptions;
use lance_index::progress::{IndexBuildProgress, noop_progress};
use lance_index::vector::bq::builder::RabitQuantizer;
use lance_index::vector::bq::{RQBuildParams, RQRotationType};
use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer};
use lance_index::vector::hnsw::HNSW;
use lance_index::vector::ivf::builder::recommended_num_partitions;
use lance_index::vector::ivf::storage::IvfModel;
use object_store::path::Path;
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::pq::ProductQuantizer;
use lance_index::vector::quantizer::QuantizationType;
use lance_index::vector::v3::shuffler::{Shuffler, create_ivf_shuffler};
use lance_index::vector::v3::subindex::SubIndexType;
use lance_index::vector::{
VectorIndex,
hnsw::{
builder::HnswBuildParams,
index::{HNSWIndex, HNSWIndexOptions},
},
ivf::IvfBuildParams,
pq::PQBuildParams,
sq::{ScalarQuantizer, builder::SQBuildParams},
};
use lance_index::{
DatasetIndexExt, INDEX_AUXILIARY_FILE_NAME, INDEX_METADATA_SCHEMA_KEY, IndexType,
};
use lance_io::traits::Reader;
use lance_linalg::distance::*;
use lance_table::format::{IndexMetadata, list_index_files_with_sizes};
use serde::Serialize;
use tracing::instrument;
use utils::get_vector_type;
use uuid::Uuid;
use super::{DatasetIndexInternalExt, IndexParams, pb, vector_index_details};
use crate::dataset::index::dataset_format_version;
use crate::dataset::transaction::{Operation, Transaction};
use crate::{Error, Result, dataset::Dataset, index::pb::vector_index_stage::Stage};
/// Name reported by [`IndexParams::index_name`] for vector index parameters.
pub const LANCE_VECTOR_INDEX: &str = "__lance_vector_index";
/// Parameters for one stage of a multi-stage vector index build.
///
/// A vector index is described as a pipeline of stages, e.g. `[Ivf]`,
/// `[Ivf, PQ]` or `[Ivf, Hnsw, SQ]`; see [`VectorIndexParams::index_type`]
/// for the combinations recognized by the builders in this module.
#[derive(Debug, Clone)]
pub enum StageParams {
    /// IVF partitioning (expected to be the first stage).
    Ivf(IvfBuildParams),
    /// HNSW graph sub-index.
    Hnsw(HnswBuildParams),
    /// Product quantization.
    PQ(PQBuildParams),
    /// Scalar quantization.
    SQ(SQBuildParams),
    /// RabitQ quantization.
    RQ(RQBuildParams),
}
/// On-disk file format version for a vector index.
///
/// NOTE(review): `Serialize` emits the variant names ("Legacy"/"V3") while
/// [`IndexFileVersion::try_from`] parses lowercase strings — confirm the two
/// are never expected to round-trip through each other.
#[derive(Debug, Clone, Serialize)]
pub enum IndexFileVersion {
    /// Pre-v3 single-file format (only supported for IVF_PQ builds here).
    Legacy,
    /// Current format written by `IvfIndexBuilder`.
    V3,
}
impl IndexFileVersion {
pub fn try_from(version: &str) -> Result<Self> {
match version.to_lowercase().as_str() {
"legacy" => Ok(Self::Legacy),
"v3" => Ok(Self::V3),
_ => Err(Error::index(format!(
"Invalid index file version: {}",
version
))),
}
}
}
/// User-facing parameters describing how to build a vector index.
#[derive(Debug, Clone)]
pub struct VectorIndexParams {
    /// Build stages in pipeline order; the first stage must be
    /// [`StageParams::Ivf`].
    pub stages: Vec<StageParams>,
    /// Distance metric used for both training and search.
    pub metric_type: MetricType,
    /// On-disk index file format version to write.
    pub version: IndexFileVersion,
    /// When true, PQ/RQ builders are configured with `with_transpose(false)`
    /// (the code-transposition step is skipped).
    pub skip_transpose: bool,
}
impl VectorIndexParams {
    /// Set the on-disk index file format version; returns `self` for chaining.
    pub fn version(&mut self, version: IndexFileVersion) -> &mut Self {
        self.version = version;
        self
    }

    /// Set whether the PQ/RQ code transposition step is skipped; chains.
    pub fn skip_transpose(&mut self, skip_transpose: bool) -> &mut Self {
        self.skip_transpose = skip_transpose;
        self
    }

    /// Assemble params from a stage pipeline, using the default V3 file
    /// format with transposition enabled.
    fn from_stages(stages: Vec<StageParams>, metric_type: MetricType) -> Self {
        Self {
            stages,
            metric_type,
            version: IndexFileVersion::V3,
            skip_transpose: false,
        }
    }

    /// IVF-only (flat storage) index with `num_partitions` partitions.
    pub fn ivf_flat(num_partitions: usize, metric_type: MetricType) -> Self {
        let ivf = StageParams::Ivf(IvfBuildParams::new(num_partitions));
        Self::from_stages(vec![ivf], metric_type)
    }

    /// IVF-only index with fully customized IVF parameters.
    pub fn with_ivf_flat_params(metric_type: MetricType, ivf: IvfBuildParams) -> Self {
        Self::from_stages(vec![StageParams::Ivf(ivf)], metric_type)
    }

    /// IVF + product quantization with the most common knobs exposed.
    pub fn ivf_pq(
        num_partitions: usize,
        num_bits: u8,
        num_sub_vectors: usize,
        metric_type: MetricType,
        max_iterations: usize,
    ) -> Self {
        let pq = PQBuildParams {
            num_bits: num_bits as usize,
            num_sub_vectors,
            max_iters: max_iterations,
            ..Default::default()
        };
        Self::from_stages(
            vec![
                StageParams::Ivf(IvfBuildParams::new(num_partitions)),
                StageParams::PQ(pq),
            ],
            metric_type,
        )
    }

    /// IVF + RabitQ with the default rotation type.
    pub fn ivf_rq(num_partitions: usize, num_bits: u8, distance_type: DistanceType) -> Self {
        Self::ivf_rq_with_rotation(
            num_partitions,
            num_bits,
            distance_type,
            RQRotationType::default(),
        )
    }

    /// IVF + RabitQ with an explicit rotation type.
    pub fn ivf_rq_with_rotation(
        num_partitions: usize,
        num_bits: u8,
        distance_type: DistanceType,
        rotation_type: RQRotationType,
    ) -> Self {
        Self::from_stages(
            vec![
                StageParams::Ivf(IvfBuildParams::new(num_partitions)),
                StageParams::RQ(RQBuildParams::with_rotation_type(num_bits, rotation_type)),
            ],
            distance_type,
        )
    }

    /// IVF + PQ with fully customized parameters for both stages.
    pub fn with_ivf_pq_params(
        metric_type: MetricType,
        ivf: IvfBuildParams,
        pq: PQBuildParams,
    ) -> Self {
        Self::from_stages(vec![StageParams::Ivf(ivf), StageParams::PQ(pq)], metric_type)
    }

    /// IVF + scalar quantization with fully customized parameters.
    pub fn with_ivf_sq_params(
        metric_type: MetricType,
        ivf: IvfBuildParams,
        sq: SQBuildParams,
    ) -> Self {
        Self::from_stages(vec![StageParams::Ivf(ivf), StageParams::SQ(sq)], metric_type)
    }

    /// IVF + RabitQ with fully customized parameters for both stages.
    pub fn with_ivf_rq_params(
        metric_type: MetricType,
        ivf: IvfBuildParams,
        rq: RQBuildParams,
    ) -> Self {
        Self::from_stages(vec![StageParams::Ivf(ivf), StageParams::RQ(rq)], metric_type)
    }

    /// IVF + HNSW over flat (unquantized) storage.
    pub fn ivf_hnsw(
        distance_type: DistanceType,
        ivf: IvfBuildParams,
        hnsw: HnswBuildParams,
    ) -> Self {
        Self::from_stages(
            vec![StageParams::Ivf(ivf), StageParams::Hnsw(hnsw)],
            distance_type,
        )
    }

    /// IVF + HNSW over PQ-compressed storage.
    pub fn with_ivf_hnsw_pq_params(
        metric_type: MetricType,
        ivf: IvfBuildParams,
        hnsw: HnswBuildParams,
        pq: PQBuildParams,
    ) -> Self {
        Self::from_stages(
            vec![
                StageParams::Ivf(ivf),
                StageParams::Hnsw(hnsw),
                StageParams::PQ(pq),
            ],
            metric_type,
        )
    }

    /// IVF + HNSW over SQ-compressed storage.
    pub fn with_ivf_hnsw_sq_params(
        metric_type: MetricType,
        ivf: IvfBuildParams,
        hnsw: HnswBuildParams,
        sq: SQBuildParams,
    ) -> Self {
        Self::from_stages(
            vec![
                StageParams::Ivf(ivf),
                StageParams::Hnsw(hnsw),
                StageParams::SQ(sq),
            ],
            metric_type,
        )
    }

    /// Classify the stage pipeline into a concrete [`IndexType`].
    ///
    /// Unrecognized combinations (including an empty pipeline) fall back to
    /// the generic `IndexType::Vector`.
    pub fn index_type(&self) -> IndexType {
        let second = self.stages.get(1);
        let last = self.stages.last();
        match (self.stages.len(), second, last) {
            (0, _, _) => IndexType::Vector,
            (1, _, Some(StageParams::Ivf(_))) => IndexType::IvfFlat,
            (2, _, Some(StageParams::PQ(_))) => IndexType::IvfPq,
            (2, _, Some(StageParams::SQ(_))) => IndexType::IvfSq,
            (2, _, Some(StageParams::RQ(_))) => IndexType::IvfRq,
            (2, _, Some(StageParams::Hnsw(_))) => IndexType::IvfHnswFlat,
            (3, Some(StageParams::Hnsw(_)), Some(StageParams::PQ(_))) => IndexType::IvfHnswPq,
            (3, Some(StageParams::Hnsw(_)), Some(StageParams::SQ(_))) => IndexType::IvfHnswSq,
            _ => IndexType::Vector,
        }
    }
}
impl IndexParams for VectorIndexParams {
    /// Allow downcasting from `&dyn IndexParams` back to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which these parameters are identified.
    fn index_name(&self) -> &str {
        LANCE_VECTOR_INDEX
    }
}
/// Validate `params` against the dataset and prepare the pieces shared by
/// all vector index builds.
///
/// Checks that the stage pipeline is non-empty and starts with an IVF stage,
/// that precomputed IVF centroids are present when `require_precomputed_ivf`
/// is set, and that multivector (List) columns use the cosine metric.
/// Resolves the IVF partition count (explicit value, or derived from the row
/// count via `recommended_num_partitions`) and creates the shuffler used to
/// partition vectors.
///
/// `mode` is only used to prefix error messages.
///
/// Returns `(element_type, index_type, resolved_ivf_params, shuffler)`.
async fn prepare_vector_segment_build(
    dataset: &Dataset,
    column: &str,
    params: &VectorIndexParams,
    progress: Arc<dyn IndexBuildProgress>,
    mode: &str,
    require_precomputed_ivf: bool,
) -> Result<(DataType, IndexType, IvfBuildParams, Box<dyn Shuffler>)> {
    let stages = &params.stages;
    if stages.is_empty() {
        return Err(Error::index(format!("{mode}: must have at least 1 stage")));
    }
    // The first stage must always be IVF partitioning.
    let StageParams::Ivf(ivf_params0) = &stages[0] else {
        return Err(Error::index(format!(
            "{mode}: invalid stages: {:?}",
            stages
        )));
    };
    // Distributed builds need a shared global model, hence precomputed centroids.
    if require_precomputed_ivf && ivf_params0.centroids.is_none() {
        return Err(Error::index(format!(
            "{mode}: missing precomputed IVF centroids; please provide \
            IvfBuildParams.centroids for distributed segment build"
        )));
    }
    // Multivector columns (List of vectors) are only supported with cosine.
    let (vector_type, element_type) = get_vector_type(dataset.schema(), column)?;
    if let DataType::List(_) = vector_type
        && params.metric_type != DistanceType::Cosine
    {
        return Err(Error::index(format!(
            "{mode}: multivector type supports only cosine distance"
        )));
    }
    let num_rows = dataset.count_rows(None).await?;
    let index_type = params.index_type();
    // Resolve the partition count: an explicit value wins; otherwise derive it
    // from the row count and the (index-type-specific) target partition size.
    let num_partitions = ivf_params0.num_partitions.unwrap_or_else(|| {
        recommended_num_partitions(
            num_rows,
            ivf_params0
                .target_partition_size
                .unwrap_or(index_type.target_partition_size()),
        )
    });
    let mut ivf_params = ivf_params0.clone();
    ivf_params.num_partitions = Some(num_partitions);
    let format_version = dataset_format_version(dataset);
    // NOTE(review): `temp_dir` is a guard that removes the directory when
    // dropped, which happens when this function returns — confirm the shuffler
    // (re)creates its output path lazily, otherwise the guard would need to be
    // returned alongside the shuffler.
    let temp_dir = TempStdDir::default();
    let temp_dir_path = Path::from_filesystem_path(&temp_dir)?;
    let shuffler = create_ivf_shuffler(
        temp_dir_path,
        num_partitions,
        format_version,
        Some(progress),
    );
    Ok((element_type, index_type, ivf_params, shuffler))
}
/// Build one segment of a vector index in distributed mode.
///
/// Unlike [`build_vector_index`], this requires precomputed IVF centroids
/// (and, for PQ variants, a precomputed PQ codebook) so that every worker
/// produces partitions consistent with the shared global model. Only the
/// fragments listed in `fragment_ids` are indexed, and output is written
/// under `uuid` in the dataset's index directory.
///
/// Returns the parsed segment UUID on success.
///
/// # Errors
/// Fails on malformed stage pipelines, missing precomputed models, legacy
/// IVF_PQ format requests, unsupported element types, and IVF_RQ (not
/// supported in distributed mode).
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug", skip(dataset))]
pub(crate) async fn build_distributed_vector_index(
    dataset: &Dataset,
    column: &str,
    _name: &str,
    uuid: &str,
    params: &VectorIndexParams,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    fragment_ids: &[u32],
    progress: Arc<dyn IndexBuildProgress>,
) -> Result<Uuid> {
    // Shared validation; `true` enforces the precomputed-centroids requirement.
    let (element_type, index_type, ivf_params, shuffler) = prepare_vector_segment_build(
        dataset,
        column,
        params,
        progress.clone(),
        "Build Distributed Vector Index",
        true,
    )
    .await?;
    let stages = &params.stages;
    let ivf_centroids = ivf_params
        .centroids
        .as_ref()
        .expect("precomputed IVF centroids required for distributed indexing; checked above")
        .as_ref()
        .clone();
    let filtered_dataset = dataset.clone();
    let segment_uuid = Uuid::parse_str(uuid)
        .map_err(|err| Error::invalid_input(format!("Invalid index UUID '{uuid}': {err}")))?;
    let index_dir = dataset.indices_dir().child(segment_uuid.to_string());
    let fragment_filter = fragment_ids.to_vec();
    // Every worker builds against the same IVF model (no per-segment training).
    let make_ivf_model = || IvfModel::new(ivf_centroids.clone(), None);
    // Build a ProductQuantizer from the user-supplied codebook so all segments
    // share identical PQ codes.
    let make_global_pq = |pq_params: &PQBuildParams| -> Result<ProductQuantizer> {
        if pq_params.codebook.is_none() {
            return Err(Error::index(
                "Build Distributed Vector Index: missing precomputed PQ codebook; \
                please provide PQBuildParams.codebook for distributed indexing"
                    .to_string(),
            ));
        }
        let dim = crate::index::vector::utils::get_vector_dim(filtered_dataset.schema(), column)?;
        let metric_type = params.metric_type;
        let pre_codebook = pq_params
            .codebook
            .clone()
            .expect("checked above that PQ codebook is present");
        let codebook_fsl =
            arrow_array::FixedSizeListArray::try_new_from_values(pre_codebook, dim as i32)?;
        Ok(ProductQuantizer::new(
            pq_params.num_sub_vectors,
            pq_params.num_bits as u32,
            dim,
            codebook_fsl,
            // Cosine is substituted with L2 for the quantizer — presumably
            // because vectors are normalized upstream; TODO confirm.
            if metric_type == MetricType::Cosine {
                MetricType::L2
            } else {
                metric_type
            },
        ))
    };
    match index_type {
        // Flat (unquantized) storage: floats use FlatQuantizer, uint8 uses
        // the binary flat quantizer.
        IndexType::IvfFlat => match element_type {
            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
                let ivf_model = make_ivf_model();
                IvfIndexBuilder::<FlatIndex, FlatQuantizer>::new(
                    filtered_dataset,
                    column.to_owned(),
                    index_dir.clone(),
                    params.metric_type,
                    shuffler,
                    Some(ivf_params),
                    Some(()),
                    (),
                    frag_reuse_index,
                )?
                .with_ivf(ivf_model)
                .with_fragment_filter(fragment_filter)
                .with_progress(progress.clone())
                .build()
                .await?;
            }
            DataType::UInt8 => {
                let ivf_model = make_ivf_model();
                IvfIndexBuilder::<FlatIndex, FlatBinQuantizer>::new(
                    filtered_dataset,
                    column.to_owned(),
                    index_dir.clone(),
                    params.metric_type,
                    shuffler,
                    Some(ivf_params),
                    Some(()),
                    (),
                    frag_reuse_index,
                )?
                .with_ivf(ivf_model)
                .with_fragment_filter(fragment_filter)
                .with_progress(progress.clone())
                .build()
                .await?;
            }
            _ => {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid data type: {:?}",
                    element_type
                )));
            }
        },
        IndexType::IvfPq => {
            // PQ params are taken from the LAST stage (pipelines may include
            // intermediate stages before PQ).
            let len = stages.len();
            let StageParams::PQ(pq_params) = &stages[len - 1] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            match params.version {
                IndexFileVersion::Legacy => {
                    return Err(Error::index(
                        "Distributed indexing does not support legacy IVF_PQ format".to_string(),
                    ));
                }
                IndexFileVersion::V3 => {
                    let ivf_model = make_ivf_model();
                    let global_pq = make_global_pq(pq_params)?;
                    IvfIndexBuilder::<FlatIndex, ProductQuantizer>::new(
                        filtered_dataset,
                        column.to_owned(),
                        index_dir.clone(),
                        params.metric_type,
                        shuffler,
                        Some(ivf_params),
                        Some(pq_params.clone()),
                        (),
                        frag_reuse_index,
                    )?
                    .with_ivf(ivf_model)
                    .with_quantizer(global_pq)
                    // Distributed segments are written untransposed.
                    .with_transpose(false)
                    .with_fragment_filter(fragment_filter)
                    .with_progress(progress.clone())
                    .build()
                    .await?;
                }
            }
        }
        IndexType::IvfSq => {
            let StageParams::SQ(sq_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<FlatIndex, ScalarQuantizer>::new(
                filtered_dataset,
                column.to_owned(),
                index_dir.clone(),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(sq_params.clone()),
                (),
                frag_reuse_index,
            )?
            .with_fragment_filter(fragment_filter)
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfHnswFlat => {
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<HNSW, FlatQuantizer>::new(
                filtered_dataset,
                column.to_owned(),
                index_dir.clone(),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(()),
                hnsw_params.clone(),
                frag_reuse_index,
            )?
            .with_fragment_filter(fragment_filter)
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfHnswPq => {
            // Three-stage pipeline: stages[1] = HNSW, stages[2] = PQ.
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            let StageParams::PQ(pq_params) = &stages[2] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            let ivf_model = make_ivf_model();
            let global_pq = make_global_pq(pq_params)?;
            IvfIndexBuilder::<HNSW, ProductQuantizer>::new(
                filtered_dataset,
                column.to_owned(),
                index_dir.clone(),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(pq_params.clone()),
                hnsw_params.clone(),
                frag_reuse_index,
            )?
            .with_ivf(ivf_model)
            .with_quantizer(global_pq)
            .with_transpose(false)
            .with_fragment_filter(fragment_filter)
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfHnswSq => {
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            let StageParams::SQ(sq_params) = &stages[2] else {
                return Err(Error::index(format!(
                    "Build Distributed Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<HNSW, ScalarQuantizer>::new(
                filtered_dataset,
                column.to_owned(),
                index_dir.clone(),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(sq_params.clone()),
                hnsw_params.clone(),
                frag_reuse_index,
            )?
            .with_fragment_filter(fragment_filter)
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfRq => {
            return Err(Error::index(format!(
                "Build Distributed Vector Index: invalid index type: {:?} \
                is not supported in distributed mode; skipping this shard",
                index_type
            )));
        }
        _ => {
            return Err(Error::index(format!(
                "Build Distributed Vector Index: invalid index type: {:?}",
                index_type
            )));
        }
    };
    Ok(segment_uuid)
}
/// Build a new vector index over `column`, writing it under `uuid` in the
/// dataset's index directory.
///
/// The index kind is derived from `params.stages` (see
/// [`VectorIndexParams::index_type`]). IVF_PQ additionally honors
/// `params.version`: `Legacy` goes through the old `build_ivf_pq_index`
/// writer, `V3` through the `IvfIndexBuilder` pipeline.
///
/// # Errors
/// Fails on malformed stage pipelines, unsupported element types, and
/// unrecognized stage combinations.
#[instrument(level = "debug", skip(dataset))]
pub(crate) async fn build_vector_index(
    dataset: &Dataset,
    column: &str,
    name: &str,
    uuid: &str,
    params: &VectorIndexParams,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    progress: Arc<dyn IndexBuildProgress>,
) -> Result<()> {
    // Shared validation; `false` — precomputed centroids are not required here.
    let (element_type, index_type, ivf_params, shuffler) = prepare_vector_segment_build(
        dataset,
        column,
        params,
        progress.clone(),
        "Build Vector Index",
        false,
    )
    .await?;
    let stages = &params.stages;
    match index_type {
        // Flat (unquantized) storage: floats use FlatQuantizer, uint8 uses
        // the binary flat quantizer.
        IndexType::IvfFlat => match element_type {
            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
                IvfIndexBuilder::<FlatIndex, FlatQuantizer>::new(
                    dataset.clone(),
                    column.to_owned(),
                    dataset.indices_dir().child(uuid),
                    params.metric_type,
                    shuffler,
                    Some(ivf_params),
                    Some(()),
                    (),
                    frag_reuse_index,
                )?
                .with_progress(progress.clone())
                .build()
                .await?;
            }
            DataType::UInt8 => {
                IvfIndexBuilder::<FlatIndex, FlatBinQuantizer>::new(
                    dataset.clone(),
                    column.to_owned(),
                    dataset.indices_dir().child(uuid),
                    params.metric_type,
                    shuffler,
                    Some(ivf_params),
                    Some(()),
                    (),
                    frag_reuse_index,
                )?
                .with_progress(progress.clone())
                .build()
                .await?;
            }
            _ => {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid data type: {:?}",
                    element_type
                )));
            }
        },
        IndexType::IvfPq => {
            // PQ params come from the LAST stage of the pipeline.
            let len = stages.len();
            let StageParams::PQ(pq_params) = &stages[len - 1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            match params.version {
                IndexFileVersion::Legacy => {
                    build_ivf_pq_index(
                        dataset,
                        column,
                        name,
                        uuid,
                        params.metric_type,
                        &ivf_params,
                        pq_params,
                        progress.clone(),
                    )
                    .await?;
                }
                IndexFileVersion::V3 => {
                    let mut builder = IvfIndexBuilder::<FlatIndex, ProductQuantizer>::new(
                        dataset.clone(),
                        column.to_owned(),
                        dataset.indices_dir().child(uuid),
                        params.metric_type,
                        shuffler,
                        Some(ivf_params),
                        Some(pq_params.clone()),
                        (),
                        frag_reuse_index,
                    )?;
                    builder
                        // Transposition is on unless explicitly skipped via params.
                        .with_transpose(!params.skip_transpose)
                        .with_progress(progress.clone())
                        .build()
                        .await?;
                }
            }
        }
        IndexType::IvfSq => {
            let StageParams::SQ(sq_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<FlatIndex, ScalarQuantizer>::new(
                dataset.clone(),
                column.to_owned(),
                dataset.indices_dir().child(uuid),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(sq_params.clone()),
                (),
                frag_reuse_index,
            )?
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfRq => {
            let StageParams::RQ(rq_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            let mut builder = IvfIndexBuilder::<FlatIndex, RabitQuantizer>::new(
                dataset.clone(),
                column.to_owned(),
                dataset.indices_dir().child(uuid),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(rq_params.clone()),
                (),
                frag_reuse_index,
            )?;
            builder
                .with_transpose(!params.skip_transpose)
                .with_progress(progress.clone())
                .build()
                .await?;
        }
        IndexType::IvfHnswFlat => {
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<HNSW, FlatQuantizer>::new(
                dataset.clone(),
                column.to_owned(),
                dataset.indices_dir().child(uuid),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(()),
                hnsw_params.clone(),
                frag_reuse_index,
            )?
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfHnswPq => {
            // Three-stage pipeline: stages[1] = HNSW, stages[2] = PQ.
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            let StageParams::PQ(pq_params) = &stages[2] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<HNSW, ProductQuantizer>::new(
                dataset.clone(),
                column.to_owned(),
                dataset.indices_dir().child(uuid),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(pq_params.clone()),
                hnsw_params.clone(),
                frag_reuse_index,
            )?
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        IndexType::IvfHnswSq => {
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            let StageParams::SQ(sq_params) = &stages[2] else {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid stages: {:?}",
                    stages
                )));
            };
            IvfIndexBuilder::<HNSW, ScalarQuantizer>::new(
                dataset.clone(),
                column.to_owned(),
                dataset.indices_dir().child(uuid),
                params.metric_type,
                shuffler,
                Some(ivf_params),
                Some(sq_params.clone()),
                hnsw_params.clone(),
                frag_reuse_index,
            )?
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        _ => {
            return Err(Error::index(format!(
                "Build Vector Index: invalid index type: {:?}",
                index_type
            )));
        }
    };
    Ok(())
}
/// Incrementally extend an existing vector index with new data.
///
/// The IVF model and quantizer are taken from `existing_index` rather than
/// retrained; the builder type is chosen from the existing index's
/// `(sub_index_type, quantization_type)` pair, not from `params.stages`
/// (only HNSW build parameters are still read from the stages). Fails if
/// `params` requests a partition count different from the existing index.
#[instrument(level = "debug", skip(dataset, existing_index, frag_reuse_index))]
pub(crate) async fn build_vector_index_incremental(
    dataset: &Dataset,
    column: &str,
    uuid: &str,
    params: &VectorIndexParams,
    existing_index: Arc<dyn VectorIndex>,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    progress: Arc<dyn IndexBuildProgress>,
) -> Result<()> {
    let stages = &params.stages;
    if stages.is_empty() {
        return Err(Error::index(
            "Build Vector Index: must have at least 1 stage".to_string(),
        ));
    };
    // First stage must be IVF, mirroring the fresh-build validation.
    let StageParams::Ivf(ivf_params) = &stages[0] else {
        return Err(Error::index(format!(
            "Build Vector Index: invalid stages: {:?}",
            stages
        )));
    };
    // Multivector columns only support cosine distance.
    let (vector_type, element_type) = get_vector_type(dataset.schema(), column)?;
    if let DataType::List(_) = vector_type
        && params.metric_type != DistanceType::Cosine
    {
        return Err(Error::index(
            "Build Vector Index: multivector type supports only cosine distance".to_string(),
        ));
    }
    // Reuse the trained model and quantizer from the existing index.
    let ivf_model = existing_index.ivf_model().clone();
    let quantizer = existing_index.quantizer();
    // If params specify a partition count it must match the existing index.
    let expected_partitions = ivf_params
        .num_partitions
        .unwrap_or(ivf_model.num_partitions());
    if ivf_model.num_partitions() != expected_partitions {
        return Err(Error::index(format!(
            "Number of partitions mismatch: existing index has {} partitions, but params specify {}",
            ivf_model.num_partitions(),
            expected_partitions
        )));
    }
    let format_version = dataset_format_version(dataset);
    // NOTE(review): as in `prepare_vector_segment_build`, the temp-dir guard
    // is dropped at function exit — confirm the shuffler tolerates that.
    let temp_dir = TempStdDir::default();
    let temp_dir_path = Path::from_filesystem_path(&temp_dir)?;
    let shuffler = create_ivf_shuffler(
        temp_dir_path,
        ivf_model.num_partitions(),
        format_version,
        Some(progress.clone()),
    );
    let index_dir = dataset.indices_dir().child(uuid);
    // Dispatch on the EXISTING index's shape, not on params.stages.
    let (sub_index_type, quantization_type) = existing_index.sub_index_type();
    match (sub_index_type, quantization_type) {
        (SubIndexType::Flat, QuantizationType::Flat) => match element_type {
            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
                IvfIndexBuilder::<FlatIndex, FlatQuantizer>::new_incremental(
                    dataset.clone(),
                    column.to_owned(),
                    index_dir,
                    params.metric_type,
                    shuffler,
                    (),
                    frag_reuse_index,
                    OptimizeOptions::append(),
                )?
                .with_ivf(ivf_model)
                .with_quantizer(quantizer.try_into()?)
                .with_progress(progress.clone())
                .build()
                .await?;
            }
            DataType::UInt8 => {
                IvfIndexBuilder::<FlatIndex, FlatBinQuantizer>::new_incremental(
                    dataset.clone(),
                    column.to_owned(),
                    index_dir,
                    params.metric_type,
                    shuffler,
                    (),
                    frag_reuse_index,
                    OptimizeOptions::append(),
                )?
                .with_ivf(ivf_model)
                .with_quantizer(quantizer.try_into()?)
                .with_progress(progress.clone())
                .build()
                .await?;
            }
            _ => {
                return Err(Error::index(format!(
                    "Build Vector Index: invalid data type: {:?}",
                    element_type
                )));
            }
        },
        (SubIndexType::Flat, QuantizationType::Product) => {
            let mut builder = IvfIndexBuilder::<FlatIndex, ProductQuantizer>::new_incremental(
                dataset.clone(),
                column.to_owned(),
                index_dir,
                params.metric_type,
                shuffler,
                (),
                frag_reuse_index,
                OptimizeOptions::append(),
            )?;
            builder
                .with_ivf(ivf_model)
                .with_quantizer(quantizer.try_into()?)
                .with_transpose(!params.skip_transpose)
                .with_progress(progress.clone())
                .build()
                .await?;
        }
        (SubIndexType::Flat, QuantizationType::Scalar) => {
            IvfIndexBuilder::<FlatIndex, ScalarQuantizer>::new_incremental(
                dataset.clone(),
                column.to_owned(),
                index_dir,
                params.metric_type,
                shuffler,
                (),
                frag_reuse_index,
                OptimizeOptions::append(),
            )?
            .with_ivf(ivf_model)
            .with_quantizer(quantizer.try_into()?)
            .with_progress(progress.clone())
            .build()
            .await?;
        }
        (SubIndexType::Flat, QuantizationType::Rabit) => {
            let mut builder = IvfIndexBuilder::<FlatIndex, RabitQuantizer>::new_incremental(
                dataset.clone(),
                column.to_owned(),
                index_dir,
                params.metric_type,
                shuffler,
                (),
                frag_reuse_index,
                OptimizeOptions::append(),
            )?;
            builder
                .with_ivf(ivf_model)
                .with_quantizer(quantizer.try_into()?)
                .with_transpose(!params.skip_transpose)
                .with_progress(progress.clone())
                .build()
                .await?;
        }
        (SubIndexType::Hnsw, quantization_type) => {
            // NOTE(review): `stages[1]` panics (rather than erroring) when the
            // caller supplied a single-stage pipeline for an HNSW index —
            // consider `stages.get(1)`.
            let StageParams::Hnsw(hnsw_params) = &stages[1] else {
                return Err(Error::index(format!(
                    "Build Vector Index: HNSW index missing HNSW params in stages: {:?}",
                    stages
                )));
            };
            match quantization_type {
                QuantizationType::Flat => {
                    IvfIndexBuilder::<HNSW, FlatQuantizer>::new_incremental(
                        dataset.clone(),
                        column.to_owned(),
                        index_dir,
                        params.metric_type,
                        shuffler,
                        hnsw_params.clone(),
                        frag_reuse_index,
                        OptimizeOptions::append(),
                    )?
                    .with_ivf(ivf_model)
                    .with_quantizer(quantizer.try_into()?)
                    .with_progress(progress.clone())
                    .build()
                    .await?;
                }
                QuantizationType::Product => {
                    IvfIndexBuilder::<HNSW, ProductQuantizer>::new_incremental(
                        dataset.clone(),
                        column.to_owned(),
                        index_dir,
                        params.metric_type,
                        shuffler,
                        hnsw_params.clone(),
                        frag_reuse_index,
                        OptimizeOptions::append(),
                    )?
                    .with_ivf(ivf_model)
                    .with_quantizer(quantizer.try_into()?)
                    .with_progress(progress.clone())
                    .build()
                    .await?;
                }
                QuantizationType::Scalar => {
                    IvfIndexBuilder::<HNSW, ScalarQuantizer>::new_incremental(
                        dataset.clone(),
                        column.to_owned(),
                        index_dir,
                        params.metric_type,
                        shuffler,
                        hnsw_params.clone(),
                        frag_reuse_index,
                        OptimizeOptions::append(),
                    )?
                    .with_ivf(ivf_model)
                    .with_quantizer(quantizer.try_into()?)
                    .with_progress(progress.clone())
                    .build()
                    .await?;
                }
                QuantizationType::Rabit => {
                    return Err(Error::index(
                        "Rabit quantization is not supported for HNSW index".to_string(),
                    ));
                }
            }
        }
    }
    Ok(())
}
#[instrument(level = "debug", skip_all)]
pub(crate) async fn build_empty_vector_index(
_dataset: &Dataset,
column: &str,
name: &str,
_uuid: &str,
_params: &VectorIndexParams,
) -> Result<()> {
Err(Error::not_supported_source(
format!(
"Creating empty vector indices with train=False is not yet implemented. \
Index '{}' for column '{}' cannot be created without training.",
name, column
)
.into(),
))
}
/// Re-map the row addresses of a vector index after data movement (e.g.
/// compaction), writing the result under `new_uuid`.
///
/// `mapping` sends old row addresses to new ones; a `None` value marks the
/// row as deleted. Legacy indices (downcastable to [`IVFIndex`]) go through
/// `remap_index_file`; all other indices use the v3 remap path.
#[instrument(level = "debug", skip_all, fields(old_uuid = old_uuid.to_string(), new_uuid = new_uuid.to_string(), num_rows = mapping.len()))]
pub(crate) async fn remap_vector_index(
    dataset: Arc<Dataset>,
    column: &str,
    old_uuid: &Uuid,
    new_uuid: &Uuid,
    old_metadata: &IndexMetadata,
    mapping: &HashMap<u64, Option<u64>>,
) -> Result<()> {
    let old_index = dataset
        .open_vector_index(column, &old_uuid.to_string(), &NoOpMetricsCollector)
        .await?;
    // The downcast distinguishes the legacy in-file format from v3.
    if let Some(ivf_index) = old_index.as_any().downcast_ref::<IVFIndex>() {
        remap_index_file(
            dataset.as_ref(),
            &old_uuid.to_string(),
            &new_uuid.to_string(),
            old_metadata.dataset_version,
            ivf_index,
            mapping,
            old_metadata.name.clone(),
            column.to_string(),
            vec![],
        )
        .await?;
    } else {
        remap_index_file_v3(
            dataset.as_ref(),
            &new_uuid.to_string(),
            old_index,
            mapping,
            column.to_string(),
        )
        .await?;
    }
    Ok(())
}
#[instrument(level = "debug", skip(dataset, vec_idx, reader))]
pub(crate) async fn open_vector_index(
dataset: Arc<Dataset>,
uuid: &str,
vec_idx: &lance_index::pb::VectorIndex,
reader: Arc<dyn Reader>,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<Arc<dyn VectorIndex>> {
let metric_type = pb::VectorMetricType::try_from(vec_idx.metric_type)?.into();
let mut last_stage: Option<Arc<dyn VectorIndex>> = None;
let frag_reuse_uuid = dataset.frag_reuse_index_uuid().await;
for stg in vec_idx.stages.iter().rev() {
match stg.stage.as_ref() {
#[allow(unused_variables)]
Some(Stage::Transform(tf)) => {
if last_stage.is_none() {
return Err(Error::index(format!(
"Invalid vector index stages: {:?}",
vec_idx.stages
)));
}
}
Some(Stage::Ivf(ivf_pb)) => {
if last_stage.is_none() {
return Err(Error::index(format!(
"Invalid vector index stages: {:?}",
vec_idx.stages
)));
}
let ivf = IvfModel::try_from(ivf_pb.to_owned())?;
last_stage = Some(Arc::new(IVFIndex::try_new(
uuid,
ivf,
reader.clone(),
last_stage.unwrap(),
metric_type,
dataset
.index_cache
.for_index(uuid, frag_reuse_uuid.as_ref()),
)?));
}
Some(Stage::Pq(pq_proto)) => {
if last_stage.is_some() {
return Err(Error::index(format!(
"Invalid vector index stages: {:?}",
vec_idx.stages
)));
};
let pq = ProductQuantizer::from_proto(pq_proto, metric_type)?;
last_stage = Some(Arc::new(PQIndex::new(
pq,
metric_type,
frag_reuse_index.clone(),
)));
}
Some(Stage::Diskann(_)) => {
return Err(Error::index(
"DiskANN support is removed from Lance.".to_string(),
));
}
_ => {}
}
}
if last_stage.is_none() {
return Err(Error::index(format!(
"Invalid index stages: {:?}",
vec_idx.stages
)));
}
let idx = last_stage.unwrap();
Ok(idx)
}
/// Open a v2 ("previous format") vector index file.
///
/// Reads the index metadata embedded in the file's schema metadata to
/// determine the index type and distance type. IVF_HNSW_PQ and IVF_HNSW_SQ
/// are handled natively (the auxiliary file is opened alongside the main
/// reader); any other type is delegated to a matching vector index
/// extension registered in the dataset's session.
#[instrument(level = "debug", skip(dataset, reader))]
pub(crate) async fn open_vector_index_v2(
    dataset: Arc<Dataset>,
    column: &str,
    uuid: &str,
    reader: PreviousFileReader,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<Arc<dyn VectorIndex>> {
    // Index type/distance type are stored as JSON in the schema metadata.
    let index_metadata = reader
        .schema()
        .metadata
        .get(INDEX_METADATA_SCHEMA_KEY)
        .ok_or(Error::index("Index Metadata not found".to_owned()))?;
    let index_metadata: lance_index::IndexMetadata = serde_json::from_str(index_metadata)?;
    let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
    let frag_reuse_uuid = dataset.frag_reuse_index_uuid().await;
    let index_meta = dataset
        .load_index(uuid)
        .await?
        .ok_or_else(|| Error::index(format!("Index with id {} does not exist", uuid)))?;
    let index_dir = dataset.indice_files_dir(&index_meta)?;
    let index: Arc<dyn VectorIndex> = match index_metadata.index_type.as_str() {
        "IVF_HNSW_PQ" => {
            let aux_path = index_dir.child(uuid).child(INDEX_AUXILIARY_FILE_NAME);
            let aux_reader = dataset.object_store().open(&aux_path).await?;
            let ivf_data = IvfModel::load(&reader).await?;
            // PQ-backed HNSW searches on residuals.
            let options = HNSWIndexOptions { use_residual: true };
            let hnsw = HNSWIndex::<ProductQuantizer>::try_new(
                reader.object_reader.clone(),
                aux_reader.into(),
                options,
            )
            .await?;
            // NOTE(review): the IvfModel is converted to protobuf and back —
            // presumably to normalize the loaded model; confirm whether the
            // roundtrip is still required.
            let pb_ivf = pb::Ivf::try_from(&ivf_data)?;
            let ivf = IvfModel::try_from(pb_ivf)?;
            Arc::new(IVFIndex::try_new(
                uuid,
                ivf,
                reader.object_reader.clone(),
                Arc::new(hnsw),
                distance_type,
                dataset
                    .index_cache
                    .for_index(uuid, frag_reuse_uuid.as_ref()),
            )?)
        }
        "IVF_HNSW_SQ" => {
            let aux_path = index_dir.child(uuid).child(INDEX_AUXILIARY_FILE_NAME);
            let aux_reader = dataset.object_store().open(&aux_path).await?;
            let ivf_data = IvfModel::load(&reader).await?;
            let options = HNSWIndexOptions {
                use_residual: false,
            };
            let hnsw = HNSWIndex::<ScalarQuantizer>::try_new(
                reader.object_reader.clone(),
                aux_reader.into(),
                options,
            )
            .await?;
            let pb_ivf = pb::Ivf::try_from(&ivf_data)?;
            let ivf = IvfModel::try_from(pb_ivf)?;
            Arc::new(IVFIndex::try_new(
                uuid,
                ivf,
                reader.object_reader.clone(),
                Arc::new(hnsw),
                distance_type,
                dataset
                    .index_cache
                    .for_index(uuid, frag_reuse_uuid.as_ref()),
            )?)
        }
        index_type => {
            // Unknown types may be served by a registered index extension.
            if let Some(ext) = dataset
                .session
                .index_extensions
                .get(&(IndexType::Vector, index_type.to_string()))
            {
                ext.clone()
                    .to_vector()
                    .ok_or(Error::internal(
                        "unable to cast index extension to vector".to_string(),
                    ))?
                    .load_index(dataset.clone(), column, uuid, reader)
                    .await?
            } else {
                return Err(Error::index(format!(
                    "Unsupported index type: {}",
                    index_metadata.index_type
                )));
            }
        }
    };
    Ok(index)
}
pub async fn initialize_vector_index(
target_dataset: &mut Dataset,
source_dataset: &Dataset,
source_index: &IndexMetadata,
field_names: &[&str],
) -> Result<()> {
if field_names.is_empty() || field_names.len() > 1 {
return Err(Error::index(format!(
"Unsupported fields for vector index: {:?}",
field_names
)));
}
let column_name = field_names[0];
let source_vector_index = source_dataset
.open_vector_index(
column_name,
&source_index.uuid.to_string(),
&NoOpMetricsCollector,
)
.await?;
let metric_type = source_vector_index.metric_type();
let ivf_model = source_vector_index.ivf_model();
let quantizer = source_vector_index.quantizer();
let (sub_index_type, quantization_type) = source_vector_index.sub_index_type();
let ivf_params = derive_ivf_params(ivf_model);
let params = match (sub_index_type, quantization_type) {
(SubIndexType::Flat, QuantizationType::Flat) => {
VectorIndexParams::with_ivf_flat_params(metric_type, ivf_params)
}
(SubIndexType::Flat, QuantizationType::Product) => {
let pq_quantizer: ProductQuantizer = quantizer.try_into()?;
let pq_params = derive_pq_params(&pq_quantizer);
VectorIndexParams::with_ivf_pq_params(metric_type, ivf_params, pq_params)
}
(SubIndexType::Flat, QuantizationType::Scalar) => {
let sq_quantizer: ScalarQuantizer = quantizer.try_into()?;
let sq_params = derive_sq_params(&sq_quantizer);
VectorIndexParams::with_ivf_sq_params(metric_type, ivf_params, sq_params)
}
(SubIndexType::Flat, QuantizationType::Rabit) => {
let rabit_quantizer: RabitQuantizer = quantizer.try_into()?;
let rabit_params = derive_rabit_params(&rabit_quantizer);
VectorIndexParams::with_ivf_rq_params(metric_type, ivf_params, rabit_params)
}
(SubIndexType::Hnsw, quantization_type) => {
let hnsw_params = derive_hnsw_params(source_vector_index.as_ref());
match quantization_type {
QuantizationType::Flat => {
VectorIndexParams::ivf_hnsw(metric_type, ivf_params, hnsw_params)
}
QuantizationType::Product => {
let pq_quantizer: ProductQuantizer = quantizer.try_into()?;
let pq_params = derive_pq_params(&pq_quantizer);
VectorIndexParams::with_ivf_hnsw_pq_params(
metric_type,
ivf_params,
hnsw_params,
pq_params,
)
}
QuantizationType::Scalar => {
let sq_quantizer: ScalarQuantizer = quantizer.try_into()?;
let sq_params = derive_sq_params(&sq_quantizer);
VectorIndexParams::with_ivf_hnsw_sq_params(
metric_type,
ivf_params,
hnsw_params,
sq_params,
)
}
QuantizationType::Rabit => {
return Err(Error::index(
"Rabit quantization is not supported for HNSW index".to_string(),
));
}
}
}
};
let new_uuid = Uuid::new_v4();
let frag_reuse_index = target_dataset
.open_frag_reuse_index(&NoOpMetricsCollector)
.await?;
build_vector_index_incremental(
target_dataset,
column_name,
&new_uuid.to_string(),
¶ms,
source_vector_index,
frag_reuse_index,
noop_progress(),
)
.await?;
let index_dir = target_dataset.indices_dir().child(new_uuid.to_string());
let files = list_index_files_with_sizes(&target_dataset.object_store, &index_dir).await?;
let field = target_dataset.schema().field(column_name).ok_or_else(|| {
Error::index(format!(
"Column '{}' not found in target dataset",
column_name
))
})?;
let fragment_bitmap = Some(target_dataset.fragment_bitmap.as_ref().clone());
let new_idx = IndexMetadata {
uuid: new_uuid,
name: source_index.name.clone(),
fields: vec![field.id],
dataset_version: target_dataset.manifest.version,
fragment_bitmap,
index_details: Some(Arc::new(vector_index_details())),
index_version: source_index.index_version,
created_at: Some(chrono::Utc::now()),
base_id: None,
files: Some(files),
};
let transaction = Transaction::new(
target_dataset.manifest.version,
Operation::CreateIndex {
new_indices: vec![new_idx],
removed_indices: vec![],
},
None,
);
target_dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
Ok(())
}
/// Reconstruct IVF build parameters from an already-trained IVF model.
///
/// The partition count and centroids are copied from `ivf_model` so a rebuild
/// reuses the source clustering instead of retraining; the remaining knobs are
/// pinned to the builder's conventional defaults.
fn derive_ivf_params(ivf_model: &IvfModel) -> IvfBuildParams {
    // Carry over the trained centroids; `None` would mean "train from scratch".
    let centroids = ivf_model.centroids.clone().map(Arc::new);
    IvfBuildParams {
        num_partitions: Some(ivf_model.num_partitions()),
        centroids,
        target_partition_size: None,
        max_iters: 50,
        #[allow(deprecated)]
        retrain: false,
        sample_rate: 256,
        precomputed_partitions_file: None,
        precomputed_shuffle_buffers: None,
        shuffle_partition_batches: 1024 * 10,
        shuffle_partition_concurrency: 2,
        storage_options: None,
    }
}
/// Reconstruct PQ build parameters from an existing product quantizer,
/// carrying over its codebook so no retraining is required.
fn derive_pq_params(pq: &ProductQuantizer) -> PQBuildParams {
    // A present codebook short-circuits k-means training during the build.
    let codebook = Some(Arc::new(pq.codebook.clone()));
    PQBuildParams {
        num_sub_vectors: pq.num_sub_vectors,
        num_bits: pq.num_bits as usize,
        codebook,
        max_iters: 50,
        kmeans_redos: 1,
        sample_rate: 256,
    }
}
/// Reconstruct SQ build parameters from an existing scalar quantizer.
/// Only the bit width is intrinsic to the quantizer; the sample rate is the
/// conventional default.
fn derive_sq_params(sq: &ScalarQuantizer) -> SQBuildParams {
    SQBuildParams {
        num_bits: sq.num_bits(),
        sample_rate: 256,
    }
}
/// Reconstruct RabitQ build parameters from an existing Rabit quantizer,
/// preserving both the bit width and the rotation scheme.
fn derive_rabit_params(rq: &RabitQuantizer) -> RQBuildParams {
    RQBuildParams {
        num_bits: rq.num_bits(),
        rotation_type: rq.rotation_type(),
    }
}
/// Derive HNSW build parameters from a source index's `statistics()` JSON.
///
/// The parameters are read from `statistics()["sub_index"]["params"]`; any
/// value that is missing (or the whole path is absent, or statistics cannot be
/// computed) falls back to the defaults below. The defaults were previously
/// duplicated as magic numbers in every `unwrap_or`; they are now named once.
fn derive_hnsw_params(source_index: &dyn VectorIndex) -> HnswBuildParams {
    const DEFAULT_MAX_LEVEL: u16 = 4;
    const DEFAULT_M: usize = 20;
    const DEFAULT_EF_CONSTRUCTION: usize = 100;

    // Keep the owned statistics value alive while we borrow into it.
    let stats = source_index.statistics().ok();
    let params = stats
        .as_ref()
        .and_then(|s| s.get("sub_index"))
        .and_then(|s| s.get("params"));

    HnswBuildParams {
        max_level: params
            .and_then(|p| p.get("max_level"))
            .and_then(|v| v.as_u64())
            .map(|v| v as u16)
            .unwrap_or(DEFAULT_MAX_LEVEL),
        m: params
            .and_then(|p| p.get("m"))
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(DEFAULT_M),
        ef_construction: params
            .and_then(|p| p.get("ef_construction"))
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(DEFAULT_EF_CONSTRUCTION),
        prefetch_distance: None,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::Dataset;
use arrow_array::Array;
use arrow_array::RecordBatch;
use arrow_array::types::{Float32Type, Int32Type};
use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
use lance_core::utils::tempfile::TempStrDir;
use lance_datagen::{BatchCount, RowCount, array};
use lance_file::writer::FileWriterOptions;
use lance_index::DatasetIndexExt;
use lance_index::metrics::NoOpMetricsCollector;
use lance_linalg::distance::MetricType;
#[tokio::test]
async fn test_initialize_vector_index_ivf_pq() {
    // End-to-end check that `initialize_vector_index` mirrors an IVF_PQ index
    // from a source dataset onto a target dataset of the same schema, reusing
    // the source's trained centroids and PQ parameters, and that the result
    // is queryable.
    let test_dir = TempStrDir::default();
    let source_uri = format!("{}/source", test_dir.as_str());
    let target_uri = format!("{}/target", test_dir.as_str());
    // Source dataset: 300 rows with a 32-d float vector column.
    let source_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(300), BatchCount::from(1));
    let mut source_dataset = Dataset::write(source_reader, &source_uri, None)
        .await
        .unwrap();
    // IVF_PQ: 10 partitions, 8-bit codes, 16 sub-vectors, L2, 50 k-means iters.
    let params = VectorIndexParams::ivf_pq(10, 8, 16, MetricType::L2, 50);
    source_dataset
        .create_index(
            &["vector"],
            IndexType::Vector,
            Some("vector_ivf_pq".to_string()),
            &params,
            false,
        )
        .await
        .unwrap();
    // Re-open so the loaded index list reflects the committed manifest.
    let source_dataset = Dataset::open(&source_uri).await.unwrap();
    let source_indices = source_dataset.load_indices().await.unwrap();
    let source_index = source_indices
        .iter()
        .find(|idx| idx.name == "vector_ivf_pq")
        .unwrap();
    // Target dataset: different random data, same schema and dimension.
    let target_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(300), BatchCount::from(1));
    let mut target_dataset = Dataset::write(target_reader, &target_uri, None)
        .await
        .unwrap();
    // Mirror the source index onto the target.
    initialize_vector_index(
        &mut target_dataset,
        &source_dataset,
        source_index,
        &["vector"],
    )
    .await
    .unwrap();
    // The target should now expose exactly one index, named like the source.
    let target_indices = target_dataset.load_indices().await.unwrap();
    assert_eq!(target_indices.len(), 1, "Target should have 1 index");
    assert_eq!(
        target_indices[0].name, "vector_ivf_pq",
        "Index name should match"
    );
    assert_eq!(
        target_indices[0].fields,
        vec![1],
        "Index should be on field 1 (vector)"
    );
    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    // Verify index type, metric, and partition count carried over.
    let stats = target_vector_index.statistics().unwrap();
    assert_eq!(
        stats.get("index_type").and_then(|v| v.as_str()),
        Some("IVF_PQ"),
        "Index type should be IVF_PQ"
    );
    assert_eq!(
        stats.get("metric_type").and_then(|v| v.as_str()),
        Some("l2"),
        "Metric type should be L2"
    );
    assert_eq!(
        stats.get("num_partitions").and_then(|v| v.as_u64()),
        Some(10),
        "Should have 10 partitions"
    );
    let source_vector_index = source_dataset
        .open_vector_index(
            "vector",
            &source_index.uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let source_ivf_model = source_vector_index.ivf_model();
    let target_ivf_model = target_vector_index.ivf_model();
    assert_eq!(
        source_ivf_model.num_partitions(),
        target_ivf_model.num_partitions(),
        "Source and target should have same number of partitions"
    );
    // The centroids must be copied verbatim — compare element by element.
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        assert_eq!(
            source_centroids.len(),
            target_centroids.len(),
            "Centroids arrays should have same length"
        );
        for i in 0..source_centroids.len() {
            let source_centroid = source_centroids.value(i);
            let target_centroid = target_centroids.value(i);
            let source_data = source_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            let target_data = target_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            assert_eq!(
                source_data.values(),
                target_data.values(),
                "Centroid {} values should be identical between source and target",
                i
            );
        }
    } else {
        panic!("Both source and target should have centroids");
    }
    // Derived build parameters should also agree between the two indices.
    let source_ivf_params = derive_ivf_params(source_ivf_model);
    let target_ivf_params = derive_ivf_params(target_ivf_model);
    assert_eq!(
        source_ivf_params.num_partitions, target_ivf_params.num_partitions,
        "IVF num_partitions should match"
    );
    assert_eq!(
        target_ivf_params.num_partitions,
        Some(10),
        "Should have 10 partitions as configured"
    );
    let source_quantizer = source_vector_index.quantizer();
    let target_quantizer = target_vector_index.quantizer();
    let source_pq: ProductQuantizer = source_quantizer.try_into().unwrap();
    let target_pq: ProductQuantizer = target_quantizer.try_into().unwrap();
    let source_pq_params = derive_pq_params(&source_pq);
    let target_pq_params = derive_pq_params(&target_pq);
    assert_eq!(
        source_pq_params.num_sub_vectors, target_pq_params.num_sub_vectors,
        "PQ num_sub_vectors should match"
    );
    assert_eq!(
        source_pq_params.num_bits, target_pq_params.num_bits,
        "PQ num_bits should match"
    );
    assert_eq!(
        target_pq_params.num_sub_vectors, 16,
        "PQ should have 16 sub vectors"
    );
    assert_eq!(target_pq_params.num_bits, 8, "PQ should use 8 bits");
    // Finally, the mirrored index must serve ANN queries on the target data.
    let query_vector = lance_datagen::gen_batch()
        .anon_col(array::rand_vec::<Float32Type>(32.into()))
        .into_batch_rows(RowCount::from(1))
        .unwrap()
        .column(0)
        .clone();
    let query_vector = query_vector
        .as_any()
        .downcast_ref::<arrow_array::FixedSizeListArray>()
        .unwrap();
    let results = target_dataset
        .scan()
        .nearest("vector", &query_vector.value(0), 10)
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 10, "Should return 10 nearest neighbors");
}
#[tokio::test]
async fn test_initialize_vector_index_ivf_flat() {
    // Same scenario as the IVF_PQ test but for an IVF_FLAT index with cosine
    // metric: centroids must be copied verbatim and the index must be usable.
    let test_dir = TempStrDir::default();
    let source_uri = format!("{}/source", test_dir.as_str());
    let target_uri = format!("{}/target", test_dir.as_str());
    // Source dataset: 300 rows with a 64-d float vector column.
    let source_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(64.into()))
        .into_reader_rows(RowCount::from(300), BatchCount::from(1));
    let mut source_dataset = Dataset::write(source_reader, &source_uri, None)
        .await
        .unwrap();
    // IVF_FLAT: 8 partitions, cosine distance.
    let params = VectorIndexParams::ivf_flat(8, MetricType::Cosine);
    source_dataset
        .create_index(
            &["vector"],
            IndexType::Vector,
            Some("vector_ivf_flat".to_string()),
            &params,
            false,
        )
        .await
        .unwrap();
    // Re-open so the loaded index list reflects the committed manifest.
    let source_dataset = Dataset::open(&source_uri).await.unwrap();
    let source_indices = source_dataset.load_indices().await.unwrap();
    let source_index = source_indices
        .iter()
        .find(|idx| idx.name == "vector_ivf_flat")
        .unwrap();
    // Target dataset: different random data, same schema and dimension.
    let target_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(64.into()))
        .into_reader_rows(RowCount::from(300), BatchCount::from(1));
    let mut target_dataset = Dataset::write(target_reader, &target_uri, None)
        .await
        .unwrap();
    initialize_vector_index(
        &mut target_dataset,
        &source_dataset,
        source_index,
        &["vector"],
    )
    .await
    .unwrap();
    let target_indices = target_dataset.load_indices().await.unwrap();
    assert_eq!(target_indices.len(), 1, "Target should have 1 index");
    assert_eq!(
        target_indices[0].name, "vector_ivf_flat",
        "Index name should match"
    );
    assert_eq!(
        target_indices[0].fields,
        vec![1],
        "Index should be on field 1 (vector)"
    );
    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let stats = target_vector_index.statistics().unwrap();
    assert_eq!(
        stats.get("index_type").and_then(|v| v.as_str()),
        Some("IVF_FLAT"),
        "Index type should be IVF_FLAT"
    );
    // Accept either casing of the metric name in statistics output.
    let metric = stats
        .get("metric_type")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    assert!(
        metric == "cosine" || metric == "Cosine",
        "Metric type should be Cosine, got: {}",
        metric
    );
    assert_eq!(
        stats.get("num_partitions").and_then(|v| v.as_u64()),
        Some(8),
        "Should have 8 partitions"
    );
    let source_vector_index = source_dataset
        .open_vector_index(
            "vector",
            &source_index.uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let source_ivf_model = source_vector_index.ivf_model();
    let target_ivf_model = target_vector_index.ivf_model();
    assert_eq!(
        source_ivf_model.num_partitions(),
        target_ivf_model.num_partitions(),
        "Source and target should have same number of partitions"
    );
    // The centroids must be copied verbatim — compare element by element.
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        assert_eq!(
            source_centroids.len(),
            target_centroids.len(),
            "Centroids arrays should have same length"
        );
        for i in 0..source_centroids.len() {
            let source_centroid = source_centroids.value(i);
            let target_centroid = target_centroids.value(i);
            let source_data = source_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            let target_data = target_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            assert_eq!(
                source_data.values(),
                target_data.values(),
                "Centroid {} values should be identical between source and target",
                i
            );
        }
    } else {
        panic!("Both source and target should have centroids");
    }
    let source_ivf_params = derive_ivf_params(source_ivf_model);
    let target_ivf_params = derive_ivf_params(target_ivf_model);
    assert_eq!(
        source_ivf_params.num_partitions, target_ivf_params.num_partitions,
        "IVF num_partitions should match"
    );
    assert_eq!(
        target_ivf_params.num_partitions,
        Some(8),
        "Should have 8 partitions as configured"
    );
    // The mirrored index must serve ANN queries on the target data.
    let query_vector = lance_datagen::gen_batch()
        .anon_col(array::rand_vec::<Float32Type>(64.into()))
        .into_batch_rows(RowCount::from(1))
        .unwrap()
        .column(0)
        .clone();
    let query_vector = query_vector
        .as_any()
        .downcast_ref::<arrow_array::FixedSizeListArray>()
        .unwrap();
    let results = target_dataset
        .scan()
        .nearest("vector", &query_vector.value(0), 5)
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 5, "Should return 5 nearest neighbors");
}
#[tokio::test]
async fn test_build_distributed_invalid_fragment_ids() {
    // A distributed build given only unknown fragment ids is expected to
    // succeed (per the assertion below) rather than error out.
    let tmp_dir = TempStrDir::default();
    let uri = format!("{}/ds", tmp_dir.as_str());
    // Small dataset: 128 rows with a 32-d float vector column.
    let data = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(128), BatchCount::from(1));
    let dataset = Dataset::write(data, &uri, None).await.unwrap();

    let fragments = dataset.fragments();
    assert!(
        !fragments.is_empty(),
        "Dataset should have at least one fragment"
    );
    // Pick a fragment id guaranteed not to exist in the dataset.
    let invalid_id = fragments.iter().map(|f| f.id as u32).max().unwrap() + 1000;

    // Train the IVF model up front and feed its centroids back into the
    // build params, mirroring the distributed-build workflow.
    let mut ivf_params = IvfBuildParams {
        num_partitions: Some(4),
        ..Default::default()
    };
    let dim = utils::get_vector_dim(dataset.schema(), "vector").unwrap();
    let ivf_model = build_ivf_model(
        &dataset,
        "vector",
        dim,
        MetricType::L2,
        &ivf_params,
        noop_progress(),
    )
    .await
    .unwrap();
    ivf_params.centroids = ivf_model.centroids.clone().map(Arc::new);
    let params = VectorIndexParams::with_ivf_flat_params(MetricType::L2, ivf_params);

    let uuid = Uuid::new_v4().to_string();
    let result = build_distributed_vector_index(
        &dataset,
        "vector",
        "vector_ivf_flat_dist",
        &uuid,
        &params,
        None,
        &[invalid_id],
        noop_progress(),
    )
    .await;
    assert!(
        result.is_ok(),
        "Expected Ok for invalid fragment ids, got {:?}",
        result
    );
}
#[tokio::test]
async fn test_build_distributed_empty_fragment_ids() {
    // A distributed build given an empty fragment-id slice is expected to
    // succeed (per the assertion below) rather than error out.
    let tmp_dir = TempStrDir::default();
    let uri = format!("{}/ds", tmp_dir.as_str());
    // Small dataset: 128 rows with a 32-d float vector column.
    let data = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(128), BatchCount::from(1));
    let dataset = Dataset::write(data, &uri, None).await.unwrap();

    // Train the IVF model up front and feed its centroids back into the
    // build params, mirroring the distributed-build workflow.
    let mut ivf_params = IvfBuildParams {
        num_partitions: Some(4),
        ..Default::default()
    };
    let dim = utils::get_vector_dim(dataset.schema(), "vector").unwrap();
    let ivf_model = build_ivf_model(
        &dataset,
        "vector",
        dim,
        MetricType::L2,
        &ivf_params,
        noop_progress(),
    )
    .await
    .unwrap();
    ivf_params.centroids = ivf_model.centroids.clone().map(Arc::new);
    let params = VectorIndexParams::with_ivf_flat_params(MetricType::L2, ivf_params);

    let uuid = Uuid::new_v4().to_string();
    let result = build_distributed_vector_index(
        &dataset,
        "vector",
        "vector_ivf_flat_dist",
        &uuid,
        &params,
        None,
        &[],
        noop_progress(),
    )
    .await;
    assert!(
        result.is_ok(),
        "Expected Ok for empty fragment ids, got {:?}",
        result
    );
}
#[tokio::test]
async fn test_train_ivf_progress_is_emitted_before_completion() {
    use std::sync::atomic::{AtomicBool, Ordering};
    // Progress listener that records the relative ordering of `train_ivf`
    // events: all progress callbacks must fire before the stage completes.
    #[derive(Debug)]
    struct RecordingProgress {
        // Set once stage_complete("train_ivf") has fired.
        train_ivf_complete: AtomicBool,
        saw_train_ivf_progress_before_complete: AtomicBool,
        saw_train_ivf_progress_after_complete: AtomicBool,
    }
    #[async_trait::async_trait]
    impl IndexBuildProgress for RecordingProgress {
        async fn stage_start(&self, _: &str, _: Option<u64>, _: &str) -> Result<()> {
            Ok(())
        }
        async fn stage_progress(&self, stage: &str, _: u64) -> Result<()> {
            // Classify each train_ivf progress event as before/after completion.
            if stage == "train_ivf" {
                if self.train_ivf_complete.load(Ordering::Relaxed) {
                    self.saw_train_ivf_progress_after_complete
                        .store(true, Ordering::Relaxed);
                } else {
                    self.saw_train_ivf_progress_before_complete
                        .store(true, Ordering::Relaxed);
                }
            }
            Ok(())
        }
        async fn stage_complete(&self, stage: &str) -> Result<()> {
            if stage == "train_ivf" {
                self.train_ivf_complete.store(true, Ordering::Relaxed);
            }
            Ok(())
        }
    }
    // Build a small IVF_FLAT index with the recording listener attached.
    let test_dir = TempStrDir::default();
    let uri = format!("{}/ds", test_dir.as_str());
    let reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(128), BatchCount::from(1));
    let dataset = Dataset::write(reader, &uri, None).await.unwrap();
    let params = VectorIndexParams::ivf_flat(4, MetricType::L2);
    let uuid = Uuid::new_v4().to_string();
    let progress = Arc::new(RecordingProgress {
        train_ivf_complete: AtomicBool::new(false),
        saw_train_ivf_progress_before_complete: AtomicBool::new(false),
        saw_train_ivf_progress_after_complete: AtomicBool::new(false),
    });
    build_vector_index(
        &dataset,
        "vector",
        "vector_ivf_flat_progress",
        &uuid,
        &params,
        None,
        progress.clone(),
    )
    .await
    .unwrap();
    // At least one progress event must precede completion, and none may
    // follow it.
    assert!(
        progress
            .saw_train_ivf_progress_before_complete
            .load(Ordering::Relaxed),
        "expected at least one train_ivf progress event before completion"
    );
    assert!(
        !progress
            .saw_train_ivf_progress_after_complete
            .load(Ordering::Relaxed),
        "found train_ivf progress after completion"
    );
}
#[tokio::test]
async fn test_build_distributed_training_metadata_missing() {
    // A distributed build must fail with Error::Index mentioning
    // "missing precomputed IVF centroids" when the global training file
    // exists but contains no IVF training metadata.
    let test_dir = TempStrDir::default();
    let uri = format!("{}/ds", test_dir.as_str());
    let reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(128), BatchCount::from(1));
    let dataset = Dataset::write(reader, &uri, None).await.unwrap();
    let params = VectorIndexParams::ivf_flat(4, MetricType::L2);
    let uuid = Uuid::new_v4().to_string();
    // Write a bogus global_training.idx into the index directory: a valid
    // Lance file with an unrelated schema and zero rows, i.e. no centroids.
    let out_base = dataset.indices_dir().child(&*uuid);
    let training_path = out_base.child("global_training.idx");
    let writer = dataset.object_store().create(&training_path).await.unwrap();
    let arrow_schema = ArrowSchema::new(vec![Field::new("dummy", ArrowDataType::Int32, true)]);
    let mut v2w = lance_file::writer::FileWriter::try_new(
        writer,
        lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap(),
        FileWriterOptions::default(),
    )
    .unwrap();
    let empty_batch = RecordBatch::new_empty(Arc::new(arrow_schema));
    v2w.write_batch(&empty_batch).await.unwrap();
    v2w.finish().await.unwrap();
    // Use a real fragment id so the failure comes from the bad training file,
    // not from fragment resolution.
    let fragments = dataset.fragments();
    assert!(
        !fragments.is_empty(),
        "Dataset should have at least one fragment"
    );
    let valid_id = fragments[0].id as u32;
    let result = build_distributed_vector_index(
        &dataset,
        "vector",
        "vector_ivf_flat_dist",
        &uuid,
        &params,
        None,
        &[valid_id],
        noop_progress(),
    )
    .await;
    // Only Error::Index with the expected message is acceptable.
    match result {
        Err(Error::Index { message, .. }) => {
            assert!(
                message.contains("missing precomputed IVF centroids"),
                "Unexpected error message: {}",
                message
            );
        }
        Ok(_) => panic!("Expected Error::Index when IVF training metadata is missing, got Ok"),
        Err(e) => panic!(
            "Expected Error::Index when IVF training metadata is missing, got {:?}",
            e
        ),
    }
}
#[tokio::test]
async fn test_initialize_vector_index_empty_dataset() {
    // Mirror an IVF_PQ index onto a ZERO-row target: the copied centroids
    // must survive both a later append and optimize_indices().
    let test_dir = TempStrDir::default();
    let source_uri = format!("{}/source", test_dir.as_str());
    let target_uri = format!("{}/target", test_dir.as_str());
    // Source dataset: 300 rows with a 32-d float vector column, indexed.
    let source_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(300), BatchCount::from(1));
    let mut source_dataset = Dataset::write(source_reader, &source_uri, None)
        .await
        .unwrap();
    let params = VectorIndexParams::ivf_pq(10, 8, 16, MetricType::L2, 50);
    source_dataset
        .create_index(
            &["vector"],
            IndexType::Vector,
            Some("vector_ivf_pq".to_string()),
            &params,
            false,
        )
        .await
        .unwrap();
    // Re-open so the loaded index list reflects the committed manifest.
    let source_dataset = Dataset::open(&source_uri).await.unwrap();
    let source_indices = source_dataset.load_indices().await.unwrap();
    let source_index = source_indices
        .iter()
        .find(|idx| idx.name == "vector_ivf_pq")
        .unwrap();
    // Target dataset: same schema but zero rows.
    let empty_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(0), BatchCount::from(1));
    let mut target_dataset = Dataset::write(empty_reader, &target_uri, None)
        .await
        .unwrap();
    // Initializing on an empty dataset should still succeed and commit an
    // index carrying the source's trained state.
    initialize_vector_index(
        &mut target_dataset,
        &source_dataset,
        source_index,
        &["vector"],
    )
    .await
    .unwrap();
    let target_indices = target_dataset.load_indices().await.unwrap();
    assert_eq!(target_indices.len(), 1, "Empty target should have 1 index");
    assert_eq!(
        target_indices[0].name, "vector_ivf_pq",
        "Index name should match"
    );
    let source_vector_index = source_dataset
        .open_vector_index(
            "vector",
            &source_index.uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    // Even without data, the IVF model structure must match the source.
    let source_ivf_model = source_vector_index.ivf_model();
    let target_ivf_model = target_vector_index.ivf_model();
    assert_eq!(
        source_ivf_model.num_partitions(),
        target_ivf_model.num_partitions(),
        "Empty dataset should still have same number of partitions as source"
    );
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        assert_eq!(
            source_centroids.len(),
            target_centroids.len(),
            "Centroids arrays should have same length even for empty dataset"
        );
        for i in 0..source_centroids.len() {
            let source_centroid = source_centroids.value(i);
            let target_centroid = target_centroids.value(i);
            let source_data = source_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            let target_data = target_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            assert_eq!(
                source_data.values(),
                target_data.values(),
                "Empty dataset should have identical centroids from source"
            );
        }
    } else {
        panic!("Both source and empty target should have centroids");
    }
    // Now append real data and optimize, merging delta indices.
    let new_data_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(100), BatchCount::from(1));
    target_dataset.append(new_data_reader, None).await.unwrap();
    // NOTE(review): redundant with the top-level import; kept as-is.
    use lance_index::optimize::OptimizeOptions;
    target_dataset
        .optimize_indices(&OptimizeOptions::merge(10))
        .await
        .unwrap();
    let target_dataset = Dataset::open(&target_uri).await.unwrap();
    // After optimization there should be one merged index covering the
    // appended fragment, with nothing left unindexed.
    let index_stats = target_dataset
        .index_statistics("vector_ivf_pq")
        .await
        .unwrap();
    let stats_json: serde_json::Value = serde_json::from_str(&index_stats).unwrap();
    assert_eq!(
        stats_json["num_indices"], 1,
        "Should have only 1 merged index after optimize with high num_indices_to_merge"
    );
    assert_eq!(
        stats_json["num_indexed_fragments"], 1,
        "Should have indexed the appended fragment (empty dataset has no fragments)"
    );
    assert_eq!(
        stats_json["num_unindexed_fragments"], 0,
        "All fragments should be indexed after optimization"
    );
    // The index must serve ANN queries after append + optimize.
    let query_vector = lance_datagen::gen_batch()
        .anon_col(array::rand_vec::<Float32Type>(32.into()))
        .into_batch_rows(RowCount::from(1))
        .unwrap()
        .column(0)
        .clone();
    let query_vector = query_vector
        .as_any()
        .downcast_ref::<arrow_array::FixedSizeListArray>()
        .unwrap();
    let results = target_dataset
        .scan()
        .nearest("vector", &query_vector.value(0), 5)
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(
        results.num_rows(),
        5,
        "Should return 5 nearest neighbors after optimizing index"
    );
    // Centroids must be unchanged by optimize_indices().
    let target_indices = target_dataset.load_indices().await.unwrap();
    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let target_ivf_model = target_vector_index.ivf_model();
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        for i in 0..source_centroids.len() {
            let source_centroid = source_centroids.value(i);
            let target_centroid = target_centroids.value(i);
            let source_data = source_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            let target_data = target_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            assert_eq!(
                source_data.values(),
                target_data.values(),
                "Centroids should remain identical after optimize_indices"
            );
        }
    }
}
#[tokio::test]
async fn test_initialize_vector_index_ivf_sq() {
    // Same mirroring scenario for an IVF_SQ index with dot-product metric:
    // partitions, centroids, and SQ bit width must carry over to the target.
    let test_dir = TempStrDir::default();
    let source_uri = format!("{}/source", test_dir.as_str());
    let target_uri = format!("{}/target", test_dir.as_str());
    // Source dataset: 400 rows with a 32-d float vector column.
    let source_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(400), BatchCount::from(1));
    let mut source_dataset = Dataset::write(source_reader, &source_uri, None)
        .await
        .unwrap();
    // NOTE(review): these shadow the top-level imports; kept as-is.
    use lance_index::vector::ivf::IvfBuildParams;
    use lance_index::vector::sq::builder::SQBuildParams;
    // IVF_SQ: 6 partitions, default SQ settings, dot-product metric.
    let ivf_params = IvfBuildParams::new(6);
    let sq_params = SQBuildParams::default();
    let params = VectorIndexParams::with_ivf_sq_params(MetricType::Dot, ivf_params, sq_params);
    source_dataset
        .create_index(
            &["vector"],
            IndexType::Vector,
            Some("vector_ivf_sq".to_string()),
            &params,
            false,
        )
        .await
        .unwrap();
    // Re-open so the loaded index list reflects the committed manifest.
    let source_dataset = Dataset::open(&source_uri).await.unwrap();
    let source_indices = source_dataset.load_indices().await.unwrap();
    let source_index = source_indices
        .iter()
        .find(|idx| idx.name == "vector_ivf_sq")
        .unwrap();
    // Target dataset: different random data, same schema and dimension.
    let target_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(400), BatchCount::from(1));
    let mut target_dataset = Dataset::write(target_reader, &target_uri, None)
        .await
        .unwrap();
    initialize_vector_index(
        &mut target_dataset,
        &source_dataset,
        source_index,
        &["vector"],
    )
    .await
    .unwrap();
    let target_indices = target_dataset.load_indices().await.unwrap();
    assert_eq!(target_indices.len(), 1, "Target should have 1 index");
    assert_eq!(
        target_indices[0].name, "vector_ivf_sq",
        "Index name should match"
    );
    assert_eq!(
        target_indices[0].fields,
        vec![1],
        "Index should be on field 1 (vector)"
    );
    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let stats = target_vector_index.statistics().unwrap();
    assert_eq!(
        stats.get("index_type").and_then(|v| v.as_str()),
        Some("IVF_SQ"),
        "Index type should be IVF_SQ"
    );
    // Accept either casing of the metric name in statistics output.
    let metric = stats
        .get("metric_type")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    assert!(
        metric == "dot" || metric == "Dot",
        "Metric type should be Dot, got: {}",
        metric
    );
    assert_eq!(
        stats.get("num_partitions").and_then(|v| v.as_u64()),
        Some(6),
        "Should have 6 partitions"
    );
    let source_vector_index = source_dataset
        .open_vector_index(
            "vector",
            &source_index.uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let source_ivf_model = source_vector_index.ivf_model();
    let target_ivf_model = target_vector_index.ivf_model();
    assert_eq!(
        source_ivf_model.num_partitions(),
        target_ivf_model.num_partitions(),
        "Source and target should have same number of partitions"
    );
    // The centroids must be copied verbatim — compare element by element.
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        assert_eq!(
            source_centroids.len(),
            target_centroids.len(),
            "Centroids arrays should have same length"
        );
        for i in 0..source_centroids.len() {
            let source_centroid = source_centroids.value(i);
            let target_centroid = target_centroids.value(i);
            let source_data = source_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            let target_data = target_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            assert_eq!(
                source_data.values(),
                target_data.values(),
                "Centroid {} values should be identical between source and target",
                i
            );
        }
    } else {
        panic!("Both source and target should have centroids");
    }
    let source_ivf_params = derive_ivf_params(source_ivf_model);
    let target_ivf_params = derive_ivf_params(target_ivf_model);
    assert_eq!(
        source_ivf_params.num_partitions, target_ivf_params.num_partitions,
        "IVF num_partitions should match"
    );
    assert_eq!(
        target_ivf_params.num_partitions,
        Some(6),
        "Should have 6 partitions as configured"
    );
    // SQ bit width must carry over as well.
    let source_quantizer = source_vector_index.quantizer();
    let target_quantizer = target_vector_index.quantizer();
    let source_sq: ScalarQuantizer = source_quantizer.try_into().unwrap();
    let target_sq: ScalarQuantizer = target_quantizer.try_into().unwrap();
    let source_sq_params = derive_sq_params(&source_sq);
    let target_sq_params = derive_sq_params(&target_sq);
    assert_eq!(
        source_sq_params.num_bits, target_sq_params.num_bits,
        "SQ num_bits should match"
    );
    // The mirrored index must serve ANN queries on the target data.
    let query_vector = lance_datagen::gen_batch()
        .anon_col(array::rand_vec::<Float32Type>(32.into()))
        .into_batch_rows(RowCount::from(1))
        .unwrap()
        .column(0)
        .clone();
    let query_vector = query_vector
        .as_any()
        .downcast_ref::<arrow_array::FixedSizeListArray>()
        .unwrap();
    let results = target_dataset
        .scan()
        .nearest("vector", &query_vector.value(0), 15)
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 15, "Should return 15 nearest neighbors");
}
#[tokio::test]
async fn test_initialize_vector_index_ivf_hnsw_pq() {
    // End-to-end check that `initialize_vector_index` clones an IVF_HNSW_PQ index
    // definition from a source dataset onto a target dataset: the target must end
    // up with an index of the same name whose IVF centroids are byte-identical,
    // whose PQ configuration (num_sub_vectors / num_bits) matches, whose HNSW
    // build parameters are recoverable, and which can serve a nearest-neighbor
    // query against the target data.
    let test_dir = TempStrDir::default();
    let source_uri = format!("{}/source", test_dir.as_str());
    let target_uri = format!("{}/target", test_dir.as_str());

    // Source dataset: 400 rows of random 32-dim float vectors plus an id column.
    let source_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(400), BatchCount::from(1));
    let mut source_dataset = Dataset::write(source_reader, &source_uri, None)
        .await
        .unwrap();

    // IVF with 8 partitions, HNSW(max_level=6, m=24, ef_construction=120),
    // PQ with 8 sub-vectors at 8 bits each; L2 metric.
    let ivf_params = IvfBuildParams {
        num_partitions: Some(8),
        ..Default::default()
    };
    let hnsw_params = HnswBuildParams {
        max_level: 6,
        m: 24,
        ef_construction: 120,
        prefetch_distance: None,
    };
    let pq_params = PQBuildParams {
        num_sub_vectors: 8,
        num_bits: 8,
        ..Default::default()
    };
    let params = VectorIndexParams::with_ivf_hnsw_pq_params(
        MetricType::L2,
        ivf_params,
        hnsw_params,
        pq_params,
    );
    source_dataset
        .create_index(
            &["vector"],
            IndexType::Vector,
            Some("vector_ivf_hnsw_pq".to_string()),
            // Fixed: this argument was mis-encoded as `¶ms` (HTML-entity
            // mojibake of `&params`), which does not compile.
            &params,
            false,
        )
        .await
        .unwrap();

    // Reopen to pick up the committed index, then locate its metadata by name.
    let source_dataset = Dataset::open(&source_uri).await.unwrap();
    let source_indices = source_dataset.load_indices().await.unwrap();
    let source_index = source_indices
        .iter()
        .find(|idx| idx.name == "vector_ivf_hnsw_pq")
        .unwrap();

    // Target dataset: same schema, fewer rows (100).
    let target_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(100), BatchCount::from(1));
    let mut target_dataset = Dataset::write(target_reader, &target_uri, None)
        .await
        .unwrap();

    // Operation under test: initialize the target's index from the source's.
    initialize_vector_index(
        &mut target_dataset,
        &source_dataset,
        source_index,
        &["vector"],
    )
    .await
    .unwrap();

    let target_indices = target_dataset.load_indices().await.unwrap();
    assert_eq!(target_indices.len(), 1, "Target should have 1 index");
    assert_eq!(
        target_indices[0].name, "vector_ivf_hnsw_pq",
        "Index name should match"
    );

    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();

    // Statistics should reflect the source's index type, metric, and partitioning.
    let stats = target_vector_index.statistics().unwrap();
    assert_eq!(
        stats.get("index_type").and_then(|v| v.as_str()),
        Some("IVF_HNSW_PQ"),
        "Index type should be IVF_HNSW_PQ"
    );
    assert_eq!(
        stats.get("metric_type").and_then(|v| v.as_str()),
        Some("l2"),
        "Metric type should be L2"
    );
    assert_eq!(
        stats.get("num_partitions").and_then(|v| v.as_u64()),
        Some(8),
        "Should have 8 partitions"
    );

    let source_vector_index = source_dataset
        .open_vector_index(
            "vector",
            &source_index.uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let source_ivf_model = source_vector_index.ivf_model();
    let target_ivf_model = target_vector_index.ivf_model();
    assert_eq!(
        source_ivf_model.num_partitions(),
        target_ivf_model.num_partitions(),
        "Source and target should have same number of partitions"
    );

    // The IVF centroids must be copied verbatim, not re-trained on target data.
    // Here we spot-check the first centroid only.
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        assert_eq!(
            source_centroids.len(),
            target_centroids.len(),
            "Centroids arrays should have same length"
        );
        let source_centroid = source_centroids.value(0);
        let target_centroid = target_centroids.value(0);
        let source_data = source_centroid
            .as_any()
            .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
            .expect("Centroid should be Float32Array");
        let target_data = target_centroid
            .as_any()
            .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
            .expect("Centroid should be Float32Array");
        assert_eq!(
            source_data.values(),
            target_data.values(),
            "Centroid values should be identical between source and target"
        );
    } else {
        panic!("Both source and target should have centroids");
    }

    // PQ configuration must survive the copy (reported via the sub_index stats).
    let sub_index = stats
        .get("sub_index")
        .and_then(|v| v.as_object())
        .expect("IVF_HNSW_PQ index should have sub_index");
    assert_eq!(
        sub_index.get("nbits").and_then(|v| v.as_u64()),
        Some(8),
        "PQ should use 8 bits"
    );
    assert_eq!(
        sub_index.get("num_sub_vectors").and_then(|v| v.as_u64()),
        Some(8),
        "PQ should have 8 sub vectors"
    );

    // Round-trip the build params through the derive_* helpers and compare.
    let source_ivf_params = derive_ivf_params(source_ivf_model);
    let target_ivf_params = derive_ivf_params(target_ivf_model);
    assert_eq!(
        source_ivf_params.num_partitions, target_ivf_params.num_partitions,
        "IVF num_partitions should match"
    );
    assert_eq!(
        target_ivf_params.num_partitions,
        Some(8),
        "Should have 8 partitions as configured"
    );
    let source_quantizer = source_vector_index.quantizer();
    let target_quantizer = target_vector_index.quantizer();
    let source_pq: ProductQuantizer = source_quantizer.try_into().unwrap();
    let target_pq: ProductQuantizer = target_quantizer.try_into().unwrap();
    let source_pq_params = derive_pq_params(&source_pq);
    let target_pq_params = derive_pq_params(&target_pq);
    assert_eq!(
        source_pq_params.num_sub_vectors, target_pq_params.num_sub_vectors,
        "PQ num_sub_vectors should match"
    );
    assert_eq!(
        source_pq_params.num_bits, target_pq_params.num_bits,
        "PQ num_bits should match"
    );
    assert_eq!(
        target_pq_params.num_sub_vectors, 8,
        "PQ should have 8 sub vectors"
    );
    assert_eq!(target_pq_params.num_bits, 8, "PQ should use 8 bits");

    // The HNSW build parameters should be recoverable from the copied index.
    let derived_hnsw_params = derive_hnsw_params(target_vector_index.as_ref());
    assert_eq!(
        derived_hnsw_params.max_level, 6,
        "HNSW max_level should be extracted as 6 from source index"
    );
    assert_eq!(
        derived_hnsw_params.m, 24,
        "HNSW m should be extracted as 24 from source index"
    );
    assert_eq!(
        derived_hnsw_params.ef_construction, 120,
        "HNSW ef_construction should be extracted as 120 from source index"
    );

    // Finally, the initialized index must actually answer a kNN query.
    let query_vector = lance_datagen::gen_batch()
        .anon_col(array::rand_vec::<Float32Type>(32.into()))
        .into_batch_rows(RowCount::from(1))
        .unwrap()
        .column(0)
        .clone();
    let query_vector = query_vector
        .as_any()
        .downcast_ref::<arrow_array::FixedSizeListArray>()
        .unwrap();
    let results = target_dataset
        .scan()
        .nearest("vector", &query_vector.value(0), 5)
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 5, "Should return 5 nearest neighbors");
}
#[tokio::test]
async fn test_initialize_vector_index_ivf_hnsw_sq() {
    // End-to-end check that `initialize_vector_index` clones an IVF_HNSW_SQ index
    // definition from a source dataset onto a target dataset: same index name,
    // byte-identical IVF centroids (checked for every partition), matching SQ
    // bit width and HNSW build parameters, and a working kNN query under the
    // cosine metric.
    let test_dir = TempStrDir::default();
    let source_uri = format!("{}/source", test_dir.as_str());
    let target_uri = format!("{}/target", test_dir.as_str());

    // Source dataset: 300 rows of random 32-dim float vectors plus an id column.
    let source_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(300), BatchCount::from(1));
    let mut source_dataset = Dataset::write(source_reader, &source_uri, None)
        .await
        .unwrap();

    // IVF with 6 partitions, HNSW(max_level=5, m=16, ef_construction=80),
    // scalar quantization at 8 bits; cosine metric.
    let ivf_params = IvfBuildParams {
        num_partitions: Some(6),
        ..Default::default()
    };
    let hnsw_params = HnswBuildParams {
        max_level: 5,
        m: 16,
        ef_construction: 80,
        prefetch_distance: None,
    };
    let sq_params = SQBuildParams {
        num_bits: 8,
        ..Default::default()
    };
    let params = VectorIndexParams::with_ivf_hnsw_sq_params(
        MetricType::Cosine,
        ivf_params,
        hnsw_params,
        sq_params,
    );
    source_dataset
        .create_index(
            &["vector"],
            IndexType::Vector,
            Some("vector_ivf_hnsw_sq".to_string()),
            // Fixed: this argument was mis-encoded as `¶ms` (HTML-entity
            // mojibake of `&params`), which does not compile.
            &params,
            false,
        )
        .await
        .unwrap();

    // Reopen to pick up the committed index, then locate its metadata by name.
    let source_dataset = Dataset::open(&source_uri).await.unwrap();
    let source_indices = source_dataset.load_indices().await.unwrap();
    let source_index = source_indices
        .iter()
        .find(|idx| idx.name == "vector_ivf_hnsw_sq")
        .unwrap();

    // Target dataset: same schema, fewer rows (100).
    let target_reader = lance_datagen::gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("vector", array::rand_vec::<Float32Type>(32.into()))
        .into_reader_rows(RowCount::from(100), BatchCount::from(1));
    let mut target_dataset = Dataset::write(target_reader, &target_uri, None)
        .await
        .unwrap();

    // Operation under test: initialize the target's index from the source's.
    initialize_vector_index(
        &mut target_dataset,
        &source_dataset,
        source_index,
        &["vector"],
    )
    .await
    .unwrap();

    let target_indices = target_dataset.load_indices().await.unwrap();
    assert_eq!(target_indices.len(), 1, "Target should have 1 index");
    assert_eq!(
        target_indices[0].name, "vector_ivf_hnsw_sq",
        "Index name should match"
    );

    let target_vector_index = target_dataset
        .open_vector_index(
            "vector",
            &target_indices[0].uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();

    // Statistics should reflect the source's index type, metric, and partitioning.
    let stats = target_vector_index.statistics().unwrap();
    assert_eq!(
        stats.get("index_type").and_then(|v| v.as_str()),
        Some("IVF_HNSW_SQ"),
        "Index type should be IVF_HNSW_SQ"
    );
    assert_eq!(
        stats.get("metric_type").and_then(|v| v.as_str()),
        Some("cosine"),
        "Metric type should be cosine"
    );
    assert_eq!(
        stats.get("num_partitions").and_then(|v| v.as_u64()),
        Some(6),
        "Should have 6 partitions"
    );

    let source_vector_index = source_dataset
        .open_vector_index(
            "vector",
            &source_index.uuid.to_string(),
            &NoOpMetricsCollector,
        )
        .await
        .unwrap();
    let source_ivf_model = source_vector_index.ivf_model();
    let target_ivf_model = target_vector_index.ivf_model();
    assert_eq!(
        source_ivf_model.num_partitions(),
        target_ivf_model.num_partitions(),
        "Source and target should have same number of partitions"
    );

    // SQ configuration must survive the copy (reported via the sub_index stats).
    let sub_index = stats
        .get("sub_index")
        .and_then(|v| v.as_object())
        .expect("IVF_HNSW_SQ index should have sub_index");
    assert_eq!(
        sub_index.get("num_bits").and_then(|v| v.as_u64()),
        Some(8),
        "SQ should use 8 bits"
    );

    // The IVF centroids must be copied verbatim, not re-trained on target data.
    // Unlike the PQ test, this one checks every centroid, not just the first.
    if let (Some(source_centroids), Some(target_centroids)) =
        (&source_ivf_model.centroids, &target_ivf_model.centroids)
    {
        assert_eq!(
            source_centroids.len(),
            target_centroids.len(),
            "Centroids arrays should have same length"
        );
        for i in 0..source_centroids.len() {
            let source_centroid = source_centroids.value(i);
            let target_centroid = target_centroids.value(i);
            let source_data = source_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            let target_data = target_centroid
                .as_any()
                .downcast_ref::<arrow_array::PrimitiveArray<arrow_array::types::Float32Type>>()
                .expect("Centroid should be Float32Array");
            assert_eq!(
                source_data.values(),
                target_data.values(),
                "Centroid {} values should be identical between source and target",
                i
            );
        }
    } else {
        panic!("Both source and target should have centroids");
    }

    // Round-trip the build params through the derive_* helpers and compare.
    let source_ivf_params = derive_ivf_params(source_ivf_model);
    let target_ivf_params = derive_ivf_params(target_ivf_model);
    assert_eq!(
        source_ivf_params.num_partitions, target_ivf_params.num_partitions,
        "IVF num_partitions should match"
    );
    assert_eq!(
        target_ivf_params.num_partitions,
        Some(6),
        "Should have 6 partitions as configured"
    );
    let source_quantizer = source_vector_index.quantizer();
    let target_quantizer = target_vector_index.quantizer();
    let source_sq: ScalarQuantizer = source_quantizer.try_into().unwrap();
    let target_sq: ScalarQuantizer = target_quantizer.try_into().unwrap();
    let source_sq_params = derive_sq_params(&source_sq);
    let target_sq_params = derive_sq_params(&target_sq);
    assert_eq!(
        source_sq_params.num_bits, target_sq_params.num_bits,
        "SQ num_bits should match"
    );
    assert_eq!(target_sq_params.num_bits, 8, "SQ should use 8 bits");

    // The HNSW build parameters should be recoverable from the copied index.
    let derived_hnsw_params = derive_hnsw_params(target_vector_index.as_ref());
    assert_eq!(
        derived_hnsw_params.max_level, 5,
        "HNSW max_level should be extracted as 5 from source index"
    );
    assert_eq!(
        derived_hnsw_params.m, 16,
        "HNSW m should be extracted as 16 from source index"
    );
    assert_eq!(
        derived_hnsw_params.ef_construction, 80,
        "HNSW ef_construction should be extracted as 80 from source index"
    );

    // Finally, the initialized index must actually answer a kNN query.
    let query_vector = lance_datagen::gen_batch()
        .anon_col(array::rand_vec::<Float32Type>(32.into()))
        .into_batch_rows(RowCount::from(1))
        .unwrap()
        .column(0)
        .clone();
    let query_vector = query_vector
        .as_any()
        .downcast_ref::<arrow_array::FixedSizeListArray>()
        .unwrap();
    let results = target_dataset
        .scan()
        .nearest("vector", &query_vector.value(0), 5)
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 5, "Should return 5 nearest neighbors");
}
}