use std::collections::HashSet;
use std::future;
use std::sync::Arc;
use std::{collections::HashMap, pin::Pin};
use arrow::array::{AsArray as _, PrimitiveBuilder, UInt32Builder, UInt64Builder};
use arrow::compute::sort_to_indices;
use arrow::datatypes::{self};
use arrow::datatypes::{Float16Type, Float64Type, UInt8Type, UInt64Type};
use arrow_array::types::Float32Type;
use arrow_array::{
Array, ArrayRef, ArrowPrimitiveType, BooleanArray, FixedSizeListArray, PrimitiveArray,
RecordBatch, UInt32Array, UInt64Array,
};
use arrow_schema::{DataType, Field, Fields};
use futures::{FutureExt, stream};
use futures::{
Stream,
prelude::stream::{StreamExt, TryStreamExt},
};
use itertools::Itertools;
use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
use lance_core::ROW_ID;
use lance_core::datatypes::Schema;
use lance_core::utils::tempfile::TempStdDir;
use lance_core::utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu};
use lance_core::{Error, ROW_ID_FIELD, Result};
use lance_encoding::version::LanceFileVersion;
use lance_file::writer::{FileWriter, FileWriterOptions};
use lance_index::frag_reuse::FragReuseIndex;
use lance_index::metrics::NoOpMetricsCollector;
use lance_index::optimize::OptimizeOptions;
use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress};
use lance_index::vector::bq::storage::{RABIT_CODE_COLUMN, pack_codes, unpack_codes};
use lance_index::vector::kmeans::KMeansParams;
use lance_index::vector::pq::storage::transpose;
use lance_index::vector::quantizer::{
QuantizationMetadata, QuantizationType, QuantizerBuildParams,
};
use lance_index::vector::quantizer::{QuantizerMetadata, QuantizerStorage};
use lance_index::vector::shared::{SupportedIvfIndexType, write_unified_ivf_and_index_metadata};
use lance_index::vector::storage::STORAGE_METADATA_KEY;
use lance_index::vector::transform::Flatten;
use lance_index::vector::v3::shuffler::{EmptyReader, IvfShufflerReader};
use lance_index::vector::v3::subindex::SubIndexType;
use lance_index::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN, PQ_CODE_COLUMN, VectorIndex};
use lance_index::vector::{PART_ID_FIELD, ivf::storage::IvfModel};
use lance_index::{
INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, pb,
vector::{
DISTANCE_TYPE_KEY,
ivf::{IvfBuildParams, storage::IVF_METADATA_KEY},
quantizer::Quantization,
storage::{StorageBuilder, VectorStore},
transform::Transformer,
v3::{
shuffler::{ShuffleReader, Shuffler},
subindex::IvfSubIndex,
},
},
};
use lance_index::{
INDEX_METADATA_SCHEMA_KEY, IndexMetadata, IndexType, MAX_PARTITION_SIZE_FACTOR,
MIN_PARTITION_SIZE_PERCENT,
};
use lance_io::local::to_local_path;
use lance_io::stream::RecordBatchStream;
use lance_io::{object_store::ObjectStore, stream::RecordBatchStreamAdapter};
use lance_linalg::distance::{DistanceType, Dot, L2, Normalize};
use lance_linalg::kernels::normalize_fsl;
use log::info;
use object_store::path::Path;
use prost::Message;
use tracing::{Level, instrument, span};
use crate::Dataset;
use crate::dataset::ProjectionRequest;
use crate::dataset::index::dataset_format_version;
use crate::index::vector::ivf::v2::PartitionEntry;
use crate::index::vector::utils::{infer_vector_dim, infer_vector_element_type};
use super::v2::IVFIndex;
use super::{
ivf::load_precomputed_partitions_if_available,
utils::{self, get_vector_type},
};
/// Stably sort a batch by its `ROW_ID` column, if present.
///
/// Stability matters: merged batches may contain duplicate row ids and
/// callers rely on the relative order of equal ids being preserved.
///
/// Returns the batch unchanged (cloned) when there is no row-id column,
/// fewer than two rows, or the ids are already in order — in the sorted
/// case the original code performed an identity-permutation `take`, which
/// copied the batch for nothing.
fn stable_sort_batch_by_row_id(batch: &RecordBatch) -> Result<RecordBatch> {
    if let Some(row_id_col) = batch.column_by_name(ROW_ID) {
        let row_ids = row_id_col.as_primitive::<UInt64Type>();
        // Only build and apply a permutation when it would actually reorder rows.
        if row_ids.len() > 1 && !row_ids.values().is_sorted() {
            let mut order: Vec<u32> = (0..row_ids.len() as u32).collect();
            // `sort_by_key` is a stable sort: equal row ids keep their order.
            order.sort_by_key(|&i| row_ids.value(i as usize));
            let indices = UInt32Array::from(order);
            return Ok(batch.take(&indices)?);
        }
    }
    Ok(batch.clone())
}
/// Number of candidate partitions considered when reassigning vectors during
/// a partition split/join.
/// NOTE(review): the only uses are in the partition-adjustment logic beyond
/// this view — confirm the exact semantics there.
const REASSIGN_RANGE: usize = 64;
/// Builder for an IVF-partitioned vector index with sub-index type `S`
/// and quantization `Q`.
pub struct IvfIndexBuilder<S: IvfSubIndex, Q: Quantization> {
    store: ObjectStore,
    // Name of the vector column being indexed.
    column: String,
    // Directory where the final index files are written.
    index_dir: Path,
    distance_type: DistanceType,
    // Source dataset; set by both constructors, read by train/shuffle paths.
    dataset: Option<Dataset>,
    // Shuffler used to bucket rows by partition; `None` for remap-only flows.
    shuffler: Option<Arc<dyn Shuffler>>,
    ivf_params: Option<IvfBuildParams>,
    quantizer_params: Option<Q::BuildParams>,
    sub_index_params: Option<S::BuildParams>,
    // `_temp_dir` keeps the temporary directory alive for the builder's
    // lifetime; `temp_dir` is the same location as an object-store `Path`.
    _temp_dir: TempStdDir,
    temp_dir: Path,
    // Trained (or injected via `with_ivf`) IVF model.
    ivf: Option<IvfModel>,
    // Trained (or injected via `with_quantizer`) quantizer.
    quantizer: Option<Q>,
    // Reader over shuffled, partition-bucketed data.
    shuffle_reader: Option<Arc<dyn ShuffleReader>>,
    // Existing delta indices that may be merged into this build.
    existing_indices: Vec<Arc<dyn VectorIndex>>,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    // Restrict shuffling to these fragment ids (distributed indexing).
    fragment_filter: Option<Vec<u32>>,
    optimize_options: Option<OptimizeOptions>,
    // Number of existing delta indices merged by the last build.
    merged_num: usize,
    // Whether PQ codes are transposed / RaBitQ codes packed when written.
    transpose_codes: bool,
    format_version: LanceFileVersion,
    progress: Arc<dyn IndexBuildProgress>,
}
/// Stream of built partitions, in partition order. Each item is `None` for
/// an empty partition, otherwise the quantized storage, the sub-index built
/// over it, and the accumulated loss for that partition.
type BuildStream<S, Q> =
    Pin<Box<dyn Stream<Item = Result<Option<(<Q as Quantization>::Storage, S, f64)>>> + Send>>;
impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q> {
/// Create a builder for a brand-new IVF index over `column` of `dataset`.
///
/// Intermediates (shuffle output, etc.) go to a temporary directory that
/// lives as long as the builder (`_temp_dir` keeps it alive).
#[allow(clippy::too_many_arguments)]
pub fn new(
    dataset: Dataset,
    column: String,
    index_dir: Path,
    distance_type: DistanceType,
    shuffler: Box<dyn Shuffler>,
    ivf_params: Option<IvfBuildParams>,
    quantizer_params: Option<Q::BuildParams>,
    sub_index_params: S::BuildParams,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<Self> {
    // The temp dir is removed when `_temp_dir` (stored on the builder) drops.
    let temp_dir = TempStdDir::default();
    let temp_dir_path = Path::from_filesystem_path(&temp_dir)?;
    let format_version = dataset_format_version(&dataset);
    Ok(Self {
        store: dataset.object_store().clone(),
        column,
        index_dir,
        distance_type,
        dataset: Some(dataset),
        shuffler: Some(shuffler.into()),
        ivf_params,
        quantizer_params,
        sub_index_params: Some(sub_index_params),
        _temp_dir: temp_dir,
        temp_dir: temp_dir_path,
        ivf: None,
        quantizer: None,
        shuffle_reader: None,
        existing_indices: Vec::new(),
        frag_reuse_index,
        fragment_filter: None,
        optimize_options: None,
        merged_num: 0,
        // Codes are written in on-disk (transposed/packed) layout by default.
        transpose_codes: true,
        format_version,
        progress: Arc::new(NoopIndexBuildProgress),
    })
}
/// Create a builder for an incremental optimization of an existing index.
///
/// No IVF/quantizer build params are set; the IVF model and quantizer are
/// expected to be supplied via `with_ivf`/`with_quantizer` or reused from
/// existing indices (NOTE(review): assumed — confirm at the call sites).
#[allow(clippy::too_many_arguments)]
pub fn new_incremental(
    dataset: Dataset,
    column: String,
    index_dir: Path,
    distance_type: DistanceType,
    shuffler: Box<dyn Shuffler>,
    sub_index_params: S::BuildParams,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    optimize_options: OptimizeOptions,
) -> Result<Self> {
    let mut builder = Self::new(
        dataset,
        column,
        index_dir,
        distance_type,
        shuffler,
        None,
        None,
        sub_index_params,
        frag_reuse_index,
    )?;
    builder.optimize_options = Some(optimize_options);
    Ok(builder)
}
/// Create a builder that rewrites an existing IVF index with remapped row
/// ids (see [`Self::remap`]).
///
/// Reuses the existing index's distance type, IVF model, and quantizer;
/// no shuffler or build params are needed.
pub fn new_remapper(
    dataset: Dataset,
    column: String,
    index_dir: Path,
    index: Arc<dyn VectorIndex>,
) -> Result<Self> {
    let ivf_index = index
        .as_any()
        .downcast_ref::<IVFIndex<S, Q>>()
        .ok_or(Error::invalid_input("existing index is not IVF index"))?;
    // The temp dir is removed when `_temp_dir` (stored on the builder) drops.
    let temp_dir = TempStdDir::default();
    let temp_dir_path = Path::from_filesystem_path(&temp_dir)?;
    let format_version = dataset_format_version(&dataset);
    Ok(Self {
        store: dataset.object_store().clone(),
        column,
        index_dir,
        // Inherit the distance type from the index being remapped.
        distance_type: ivf_index.metric_type(),
        dataset: Some(dataset),
        shuffler: None,
        ivf_params: None,
        quantizer_params: None,
        sub_index_params: None,
        _temp_dir: temp_dir,
        temp_dir: temp_dir_path,
        ivf: Some(ivf_index.ivf_model().clone()),
        quantizer: Some(ivf_index.quantizer().try_into()?),
        shuffle_reader: None,
        existing_indices: vec![index],
        frag_reuse_index: None,
        fragment_filter: None,
        optimize_options: None,
        merged_num: 0,
        transpose_codes: true,
        format_version,
        progress: Arc::new(NoopIndexBuildProgress),
    })
}
/// Run the full index build: train (or load) the IVF model and the
/// quantizer, shuffle the dataset into partitions (unless a shuffle reader
/// was injected), then build and merge all partitions into the final files.
///
/// Returns the number of existing delta indices merged into the result.
pub async fn build(&mut self) -> Result<usize> {
    let progress = self.progress.clone();
    // IVF training progress is reported in k-means iterations.
    let max_iters = self.ivf_params.as_ref().map(|p| p.max_iters as u64);
    progress
        .stage_start("train_ivf", max_iters, "iterations")
        .await?;
    self.with_ivf(self.load_or_build_ivf().await?);
    progress.stage_complete("train_ivf").await?;
    progress.stage_start("train_quantizer", None, "").await?;
    self.with_quantizer(self.load_or_build_quantizer().await?);
    progress.stage_complete("train_quantizer").await?;
    // Skip shuffling when a reader was already provided (e.g. distributed builds).
    if self.shuffle_reader.is_none() {
        let num_rows = self.num_rows_to_shuffle().await?;
        progress.stage_start("shuffle", num_rows, "rows").await?;
        self.shuffle_dataset().await?;
        progress.stage_complete("shuffle").await?;
    }
    let num_partitions = self.ivf.as_ref().map(|ivf| ivf.num_partitions() as u64);
    progress
        .stage_start("merge_partitions", num_partitions, "partitions")
        .await?;
    let build_idx_stream = self.build_partitions().boxed().await?;
    self.merge_partitions(build_idx_stream).await?;
    progress.stage_complete("merge_partitions").await?;
    Ok(self.merged_num)
}
/// Rewrite the existing index, translating old row ids per `mapping`.
///
/// `mapping` maps old row id -> new row id; a `None` value presumably marks
/// a deleted row (NOTE(review): confirm with the `remap` implementations).
/// Partitions are remapped concurrently and merged into new index files.
pub async fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Result<()> {
    if self.existing_indices.is_empty() {
        return Err(Error::invalid_input(
            "No existing indices available for remapping",
        ));
    }
    let Some(ivf) = self.ivf.as_ref() else {
        return Err(Error::invalid_input("IVF model not set before remapping"));
    };
    log::info!("remap {} partitions", ivf.num_partitions());
    let existing_index = self.existing_indices[0].clone();
    // Share one copy of the mapping across all per-partition tasks.
    let mapping = Arc::new(mapping.clone());
    let build_iter =
        (0..ivf.num_partitions()).map(move |part_id| {
            let existing_index = existing_index.clone();
            let mapping = mapping.clone();
            async move {
                let ivf_index = existing_index
                    .as_any()
                    .downcast_ref::<IVFIndex<S, Q>>()
                    .ok_or(Error::invalid_input("existing index is not IVF index"))?;
                let part = ivf_index
                    .load_partition(part_id, false, &NoOpMetricsCollector)
                    .await?;
                let part = part.as_any().downcast_ref::<PartitionEntry<S, Q>>().ok_or(
                    Error::internal("failed to downcast partition entry".to_string()),
                )?;
                // Remap storage first, then the sub-index against the new storage.
                let storage = part.storage.remap(&mapping)?;
                let index = part.index.remap(&mapping, &storage)?;
                // Loss of 0.0: remapping does not change the clustering loss here.
                Result::Ok(Some((storage, index, 0.0)))
            }
        });
    self.merge_partitions(
        stream::iter(build_iter)
            .buffered(get_num_compute_intensive_cpus())
            .boxed(),
    )
    .await?;
    Ok(())
}
/// Use a pre-trained IVF model instead of training one.
pub fn with_ivf(&mut self, ivf: IvfModel) -> &mut Self {
    self.ivf = Some(ivf);
    self
}

/// Use a pre-trained quantizer instead of training one.
pub fn with_quantizer(&mut self, quantizer: Q) -> &mut Self {
    self.quantizer = Some(quantizer);
    self
}

/// Set existing (delta) indices that may be merged into the build.
pub fn with_existing_indices(&mut self, indices: Vec<Arc<dyn VectorIndex>>) -> &mut Self {
    self.existing_indices = indices;
    self
}

/// Restrict the build to the given fragment ids (distributed indexing).
pub fn with_fragment_filter(&mut self, fragment_ids: Vec<u32>) -> &mut Self {
    self.fragment_filter = Some(fragment_ids);
    self
}

/// Enable/disable the on-disk code layout (transposed PQ / packed RaBitQ codes).
pub fn with_transpose(&mut self, transpose: bool) -> &mut Self {
    self.transpose_codes = transpose;
    self
}

/// Set the progress reporter used during the build.
pub fn with_progress(&mut self, progress: Arc<dyn IndexBuildProgress>) -> &mut Self {
    self.progress = progress;
    self
}
#[instrument(name = "load_or_build_ivf", level = "debug", skip_all)]
async fn load_or_build_ivf(&self) -> Result<IvfModel> {
match &self.ivf {
Some(ivf) => Ok(ivf.clone()),
None => {
let Some(dataset) = self.dataset.as_ref() else {
return Err(Error::invalid_input(
"dataset not set before loading or building IVF",
));
};
let dim = utils::get_vector_dim(dataset.schema(), &self.column)?;
let ivf_params = self
.ivf_params
.as_ref()
.ok_or(Error::invalid_input("IVF build params not set"))?;
super::build_ivf_model(
dataset,
&self.column,
dim,
self.distance_type,
ivf_params,
self.progress.clone(),
)
.await
}
}
}
/// Return the already-configured quantizer, or train one from a sample of
/// the dataset.
///
/// Training data is normalized when the distance type is cosine and, for
/// residual quantizers, converted to residuals against the IVF centroids;
/// in both cases the quantizer itself is then trained with plain L2.
///
/// # Errors
/// Fails when no quantizer is set and either the dataset or the quantizer
/// build params are missing.
#[instrument(name = "load_or_build_quantizer", level = "debug", skip_all)]
async fn load_or_build_quantizer(&self) -> Result<Q> {
    // Fast path: quantizer already provided (e.g. by a remapper).
    if let Some(quantizer) = self.quantizer.clone() {
        return Ok(quantizer);
    }
    let Some(dataset) = self.dataset.as_ref() else {
        return Err(Error::invalid_input(
            "dataset not set before loading or building quantizer",
        ));
    };
    let sample_size_hint = match &self.quantizer_params {
        Some(params) => params.sample_size(),
        // Default sample size when no build params supply one.
        None => 256 * 256,
    };
    let start = std::time::Instant::now();
    info!(
        "loading training data for quantizer. sample size: {}",
        sample_size_hint
    );
    let training_data =
        utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint).await?;
    info!(
        "Finished loading training data in {:02} seconds",
        start.elapsed().as_secs_f32()
    );
    // Cosine is handled as L2 over normalized vectors, so normalize the sample.
    let training_data = if self.distance_type == DistanceType::Cosine {
        lance_linalg::kernels::normalize_fsl_owned(training_data)?
    } else {
        training_data
    };
    // Drop non-finite vectors before training (see utils::filter_finite_training_data).
    let training_data = utils::filter_finite_training_data(training_data)?;
    let training_data = match (self.ivf.as_ref(), Q::use_residual(self.distance_type)) {
        (Some(ivf), true) => {
            // Residual quantizers are trained on (vector - assigned centroid).
            let ivf_transformer = lance_index::vector::ivf::new_ivf_transformer(
                ivf.centroids.clone().unwrap(),
                DistanceType::L2,
                vec![],
            );
            span!(Level::INFO, "compute residual for PQ training")
                .in_scope(|| ivf_transformer.compute_residual(&training_data))?
        }
        _ => training_data,
    };
    info!("Start to train quantizer");
    let start = std::time::Instant::now();
    // `self.quantizer` is known to be `None` here (checked at the top), so
    // always train; the previous `match` on `&self.quantizer` at this point
    // had an unreachable `Some` arm.
    let quantizer_params = self
        .quantizer_params
        .as_ref()
        .ok_or(Error::invalid_input("quantizer build params not set"))?;
    let quantizer = Q::build(&training_data, DistanceType::L2, quantizer_params)?;
    info!(
        "Trained quantizer in {:02} seconds",
        start.elapsed().as_secs_f32()
    );
    Ok(quantizer)
}
/// Rename the column at `row_id_idx` to the canonical `ROW_ID` name,
/// leaving its data type, nullability, and data untouched.
fn rename_row_id(
    stream: impl RecordBatchStream + Unpin + 'static,
    row_id_idx: usize,
) -> impl RecordBatchStream + Unpin + 'static {
    // Rebuild the schema with only the target field renamed.
    let fields: Fields = stream
        .schema()
        .fields
        .iter()
        .enumerate()
        .map(|(idx, field)| {
            if idx == row_id_idx {
                arrow_schema::Field::new(ROW_ID, field.data_type().clone(), field.is_nullable())
            } else {
                field.as_ref().clone()
            }
        })
        .collect();
    let new_schema = Arc::new(arrow_schema::Schema::new(fields));
    let batch_schema = new_schema.clone();
    // Re-wrap each batch with the renamed schema; columns are shared as-is.
    let renamed = stream.map_ok(move |batch| {
        RecordBatch::try_new(batch_schema.clone(), batch.columns().to_vec()).unwrap()
    });
    RecordBatchStreamAdapter::new(new_schema, renamed)
}
/// Estimate how many rows the shuffle stage will process, for progress
/// reporting. Returns `None` when no dataset is attached.
async fn num_rows_to_shuffle(&self) -> Result<Option<u64>> {
    let Some(dataset) = self.dataset.as_ref() else {
        return Ok(None);
    };
    if let Some(fragment_ids) = &self.fragment_filter {
        // Count only the fragments selected for this (distributed) build.
        let selected = dataset
            .get_fragments()
            .into_iter()
            .filter(|f| fragment_ids.contains(&(f.id() as u32)));
        let counts: Vec<usize> = futures::stream::iter(selected)
            .map(|f| async move { f.count_rows(None).await })
            .buffer_unordered(16)
            .try_collect()
            .await?;
        Ok(Some(counts.into_iter().sum::<usize>() as u64))
    } else {
        Ok(Some(dataset.count_rows(None).await? as u64))
    }
}
/// Shuffle the dataset's vector column into IVF partitions.
///
/// Uses precomputed shuffle buffers when configured in the IVF params;
/// otherwise scans the (optionally fragment-filtered) dataset with row ids
/// attached, then feeds the stream to [`Self::shuffle_data`].
async fn shuffle_dataset(&mut self) -> Result<()> {
    let Some(dataset) = self.dataset.as_ref() else {
        return Err(Error::invalid_input("dataset not set before shuffling"));
    };
    let stream = match self
        .ivf_params
        .as_ref()
        .and_then(|p| p.precomputed_shuffle_buffers.as_ref())
    {
        Some((uri, _)) => {
            // Precomputed buffers: open the dataset that contains them.
            // NOTE(review): trimming a trailing "data" assumes the buffers
            // live under `<dataset root>/data` — confirm with the writer side.
            let uri = to_local_path(uri);
            let uri = uri.trim_end_matches("data");
            log::info!("shuffle with precomputed shuffle buffers from {}", uri);
            let ds = Dataset::open(uri).await?;
            ds.scan().try_into_stream().await?
        }
        _ => {
            log::info!("shuffle column {} over dataset", self.column);
            // Scan only the vector column, with row ids attached.
            let mut builder = dataset.scan();
            builder
                .batch_readahead(get_num_compute_intensive_cpus())
                .project(&[self.column.as_str()])?
                .with_row_id();
            // Distributed builds index only a subset of fragments.
            if let Some(fragment_ids) = &self.fragment_filter {
                log::info!(
                    "applying fragment filter for distributed indexing: {:?}",
                    fragment_ids
                );
                let all_fragments = dataset.fragments();
                let filtered_fragments: Vec<_> = all_fragments
                    .iter()
                    .filter(|fragment| fragment_ids.contains(&(fragment.id as u32)))
                    .cloned()
                    .collect();
                builder.with_fragments(filtered_fragments);
            }
            let (vector_type, _) = get_vector_type(dataset.schema(), &self.column)?;
            let is_multivector = matches!(vector_type, datatypes::DataType::List(_));
            // Multivector rows can be large; bound memory with small batches.
            if is_multivector {
                builder.batch_size(64);
            }
            builder.try_into_stream().await?
        }
    };
    // Precomputed buffers may carry the row id under a literal "row_id"
    // column; rename it to the canonical ROW_ID name expected downstream.
    if let Some((row_id_idx, _)) = stream.schema().column_with_name("row_id") {
        self.shuffle_data(Some(Self::rename_row_id(stream, row_id_idx)))
            .await?;
    } else {
        self.shuffle_data(Some(stream)).await?;
    }
    Ok(())
}
/// Shuffle a stream of vectors into IVF partitions via the configured shuffler.
///
/// Each batch is assigned partition ids — either from a precomputed
/// row-id -> partition map, or by the IVF transformer (which also produces
/// quantization codes) — before being handed to the shuffler. A `None`
/// input means "nothing to shuffle" and installs an empty reader.
pub async fn shuffle_data(
    &mut self,
    data: Option<impl RecordBatchStream + Unpin + 'static>,
) -> Result<&mut Self> {
    let Some(ivf) = self.ivf.as_ref() else {
        return Err(Error::invalid_input("IVF not set before shuffle data"));
    };
    let Some(data) = data else {
        self.shuffle_reader = Some(Arc::new(EmptyReader));
        return Ok(self);
    };
    let Some(quantizer) = self.quantizer.clone() else {
        return Err(Error::invalid_input(
            "quantizer not set before shuffle data",
        ));
    };
    let Some(shuffler) = self.shuffler.as_ref() else {
        return Err(Error::invalid_input("shuffler not set before shuffle data"));
    };
    let code_column = quantizer.column();
    // Transformer that assigns partition ids and quantizes vectors.
    let transformer = Arc::new(
        lance_index::vector::ivf::new_ivf_transformer_with_quantizer(
            ivf.centroids.clone().unwrap(),
            self.distance_type,
            &self.column,
            quantizer.into(),
            None,
        )?,
    );
    // Optional row-id -> partition-id map computed out of band.
    let precomputed_partitions = if let Some(params) = self.ivf_params.as_ref() {
        load_precomputed_partitions_if_available(params)
            .await?
            .unwrap_or_default()
    } else {
        HashMap::new()
    };
    let partition_map = Arc::new(precomputed_partitions);
    let mut transformed_stream = Box::pin(
        data.map(move |batch| {
            let partition_map = partition_map.clone();
            let ivf_transformer = transformer.clone();
            // Each batch is transformed on its own tokio task.
            tokio::spawn(async move {
                let mut batch = batch?;
                if !partition_map.is_empty() {
                    // Look up each row's partition in the precomputed map;
                    // missing rows get a null partition id.
                    let row_ids = &batch[ROW_ID];
                    let part_ids = UInt32Array::from_iter(
                        row_ids
                            .as_primitive::<UInt64Type>()
                            .values()
                            .iter()
                            .map(|row_id| partition_map.get(row_id).copied()),
                    );
                    let part_ids = UInt32Array::from(part_ids);
                    batch = batch
                        .try_with_column(PART_ID_FIELD.clone(), Arc::new(part_ids.clone()))
                        .expect("failed to add part id column");
                    // Filter out rows that had no precomputed partition.
                    if part_ids.null_count() > 0 {
                        log::info!(
                            "Filter out rows without valid partition IDs: null_count={}",
                            part_ids.null_count()
                        );
                        let indices = UInt32Array::from_iter(
                            part_ids
                                .iter()
                                .enumerate()
                                .filter_map(|(idx, v)| v.map(|_| idx as u32)),
                        );
                        assert_eq!(indices.len(), batch.num_rows() - part_ids.null_count());
                        batch = batch.take(&indices)?;
                    }
                }
                // If quantization codes are already present (precomputed
                // buffers), pass the batch through; otherwise transform it.
                match batch.schema().column_with_name(code_column) {
                    Some(_) => {
                        Ok(batch)
                    }
                    None => ivf_transformer.transform(&batch),
                }
            })
        })
        .buffered(get_num_compute_intensive_cpus())
        // JoinError from a panicked task is unrecoverable; unwrap it.
        .map(|x| x.unwrap())
        .peekable(),
    );
    // Peek one batch to learn the post-transform schema and detect empty input.
    let batch = transformed_stream.as_mut().peek_mut().await;
    let schema = match batch {
        Some(Ok(b)) => b.schema(),
        // Swap the peeked error out of the stream so it can be returned by value.
        Some(Err(e)) => return Err(std::mem::replace(e, Error::Stop)),
        None => {
            log::info!("no data to shuffle");
            // Install a reader with one zero-sized bucket per partition.
            self.shuffle_reader = Some(Arc::new(IvfShufflerReader::new(
                Arc::new(self.store.clone()),
                self.temp_dir.clone(),
                vec![0; ivf.num_partitions()],
                0.0,
            )));
            return Ok(self);
        }
    };
    self.shuffle_reader = Some(
        shuffler
            .shuffle(Box::new(RecordBatchStreamAdapter::new(
                schema,
                transformed_stream,
            )))
            .await?
            .into(),
    );
    Ok(self)
}
/// Build (storage, sub-index) pairs for every IVF partition.
///
/// Returns a stream yielding one item per partition, in order; `None`
/// items are empty partitions. Before building, the IVF model may be
/// rebalanced by splitting or joining one partition — unless an explicit
/// number of indices to merge was configured, in which case no adjustment
/// is made.
#[instrument(name = "build_partitions", level = "debug", skip_all)]
async fn build_partitions(&mut self) -> Result<BuildStream<S, Q>> {
    let Some(ivf) = self.ivf.as_ref() else {
        return Err(Error::invalid_input(
            "IVF not set before building partitions",
        ));
    };
    let Some(quantizer) = self.quantizer.clone() else {
        return Err(Error::invalid_input(
            "quantizer not set before building partition",
        ));
    };
    let Some(sub_index_params) = self.sub_index_params.clone() else {
        return Err(Error::invalid_input(
            "sub index params not set before building partition",
        ));
    };
    let Some(reader) = self.shuffle_reader.as_ref() else {
        return Err(Error::invalid_input(
            "shuffle reader not set before building partitions",
        ));
    };
    let reader = reader.clone();
    let num_indices_to_merge = self
        .optimize_options
        .as_ref()
        .and_then(|opt| opt.num_indices_to_merge);
    // Default plan: no split/join; merge the last N delta indices
    // (all of them when retraining).
    let no_partition_adjustment = || {
        let is_retrain = self
            .optimize_options
            .as_ref()
            .map(|opt| opt.retrain)
            .unwrap_or(false);
        let num_to_merge = match is_retrain {
            true => self.existing_indices.len(),
            false => num_indices_to_merge.unwrap_or(0),
        };
        let indices_to_merge = self.existing_indices
            [self.existing_indices.len().saturating_sub(num_to_merge)..]
            .to_vec();
        (
            vec![None; ivf.num_partitions()],
            Arc::new(indices_to_merge),
            None,
        )
    };
    // Decide whether to split or join a partition before building.
    let (assign_batches, merge_indices, partition_adjustment) = if num_indices_to_merge
        .is_some()
        || self.optimize_options.is_none()
    {
        no_partition_adjustment()
    } else {
        match Self::check_partition_adjustment(ivf, reader.as_ref(), &self.existing_indices)? {
            Some(partition_adjustment) => match partition_adjustment {
                PartitionAdjustment::Split(partition) => {
                    log::info!(
                        "split partition {}, will merge all {} delta indices",
                        partition,
                        self.existing_indices.len()
                    );
                    let split_results = self.split_partition(partition, ivf).await?;
                    let Some(ivf) = self.ivf.as_mut() else {
                        return Err(Error::invalid_input(
                            "IVF not set before building partitions",
                        ));
                    };
                    // The split replaced one centroid with several new ones.
                    ivf.centroids = Some(split_results.new_centroids);
                    (
                        split_results.assign_batches,
                        Arc::new(self.existing_indices.clone()),
                        Some(partition_adjustment),
                    )
                }
                PartitionAdjustment::Join(partition) => {
                    log::info!("join partition {}", partition);
                    let results = self.join_partition(partition, ivf).await?;
                    let Some(ivf) = self.ivf.as_mut() else {
                        return Err(Error::invalid_input(
                            "IVF model not set before joining partition",
                        ));
                    };
                    ivf.centroids = Some(results.new_centroids);
                    (
                        results.assign_batches,
                        Arc::new(self.existing_indices.clone()),
                        Some(partition_adjustment),
                    )
                }
            },
            None => no_partition_adjustment(),
        }
    };
    self.merged_num = merge_indices.len();
    log::info!(
        "merge {}/{} delta indices",
        self.merged_num,
        self.existing_indices.len()
    );
    let distance_type = self.distance_type;
    let column = self.column.clone();
    let frag_reuse_index = self.frag_reuse_index.clone();
    let build_iter =
        assign_batches
            .into_iter()
            .enumerate()
            .map(move |(partition, assign_batch)| {
                let reader = reader.clone();
                let indices = merge_indices.clone();
                let distance_type = distance_type;
                let quantizer = quantizer.clone();
                let sub_index_params = sub_index_params.clone();
                let column = column.clone();
                let frag_reuse_index = frag_reuse_index.clone();
                // A split partition's rows were redistributed into
                // `assign_batch`es, so its old batches must not be re-read.
                let skip_existing_batches =
                    partition_adjustment == Some(PartitionAdjustment::Split(partition));
                // After a join, new partition indices at/above the joined
                // one shift by one relative to the reader's numbering.
                // NOTE(review): confirm against `join_partition`'s output.
                let partition = match partition_adjustment {
                    Some(PartitionAdjustment::Join(joined_partition))
                        if partition >= joined_partition =>
                    {
                        partition + 1
                    }
                    _ => partition,
                };
                async move {
                    let (mut batches, loss) = if skip_existing_batches {
                        (Vec::new(), 0.0)
                    } else {
                        Self::take_partition_batches(
                            partition,
                            indices.as_ref(),
                            Some(reader.as_ref()),
                        )
                        .await?
                    };
                    // CPU-heavy: filter deletions, then build storage + sub-index.
                    spawn_cpu(move || {
                        if let Some((assign_batch, deleted_row_ids)) = assign_batch {
                            // Drop rows the adjustment deleted before
                            // appending the reassigned rows.
                            if !deleted_row_ids.is_empty() {
                                let deleted_row_ids = HashSet::<u64>::from_iter(
                                    deleted_row_ids.values().iter().copied(),
                                );
                                for batch in batches.iter_mut() {
                                    let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                                    let mask =
                                        BooleanArray::from_iter(row_ids.iter().map(|row_id| {
                                            row_id.map(|row_id| {
                                                !deleted_row_ids.contains(&row_id)
                                            })
                                        }));
                                    *batch = arrow::compute::filter_record_batch(batch, &mask)?;
                                }
                            }
                            if assign_batch.num_rows() > 0 {
                                let assign_batch = assign_batch.drop_column(PART_ID_COLUMN)?;
                                batches.push(assign_batch);
                            }
                        }
                        let num_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
                        if num_rows == 0 {
                            return Ok(None);
                        }
                        let (storage, sub_index) = Self::build_index(
                            distance_type,
                            quantizer,
                            sub_index_params,
                            batches,
                            column,
                            frag_reuse_index,
                        )?;
                        Ok(Some((storage, sub_index, loss)))
                    })
                    .await
                }
            });
    Ok(stream::iter(build_iter)
        .buffered(get_num_compute_intensive_cpus())
        .boxed())
}
/// Build the quantized storage for one partition's batches, then index the
/// stored vectors with the sub-index type `S`.
#[instrument(name = "build_index", level = "debug", skip_all)]
#[allow(clippy::too_many_arguments)]
fn build_index(
    distance_type: DistanceType,
    quantizer: Q,
    sub_index_params: S::BuildParams,
    batches: Vec<RecordBatch>,
    column: String,
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<(Q::Storage, S)> {
    // Quantize and assemble the partition's vector storage first...
    let builder = StorageBuilder::new(column, distance_type, quantizer, frag_reuse_index)?;
    let storage = builder.build(batches)?;
    // ...then build the sub-index over the finished storage.
    let sub_index = S::index_vectors(&storage, sub_index_params)?;
    Ok((storage, sub_index))
}
/// Collect every record batch belonging to partition `part_id`: rows from
/// existing delta indices plus freshly shuffled rows from `reader`.
///
/// Also returns the loss accumulated from the shuffled batches' metadata.
/// Each returned batch is stably sorted by row id.
#[instrument(name = "take_partition_batches", level = "debug", skip_all)]
async fn take_partition_batches(
    part_id: usize,
    existing_indices: &[Arc<dyn VectorIndex>],
    reader: Option<&dyn ShuffleReader>,
) -> Result<(Vec<RecordBatch>, f64)> {
    let mut batches = Vec::new();
    for existing_index in existing_indices.iter() {
        let existing_index = existing_index
            .as_any()
            .downcast_ref::<IVFIndex<S, Q>>()
            .ok_or(Error::invalid_input("existing index is not IVF index"))?;
        // A delta index may have fewer partitions than the current model.
        if part_id >= existing_index.ivf_model().num_partitions() {
            continue;
        }
        let part_storage = existing_index.load_partition_storage(part_id).await?;
        let mut part_batches = part_storage.to_batches()?.collect::<Vec<_>>();
        // Stored codes use the on-disk layout; convert them back to the
        // row-major/unpacked layout expected when rebuilding storage.
        match Q::quantization_type() {
            QuantizationType::Product => {
                for batch in part_batches.iter_mut() {
                    if batch.num_rows() == 0 {
                        continue;
                    }
                    let codes = batch[PQ_CODE_COLUMN]
                        .as_fixed_size_list()
                        .values()
                        .as_primitive::<datatypes::UInt8Type>();
                    let codes_num_bytes = codes.len() / batch.num_rows();
                    // PQ codes are stored transposed; transpose them back.
                    let original_codes = transpose(codes, codes_num_bytes, batch.num_rows());
                    let original_codes = FixedSizeListArray::try_new_from_values(
                        original_codes,
                        codes_num_bytes as i32,
                    )?;
                    *batch = batch
                        .replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))?
                        .drop_column(PART_ID_COLUMN)?;
                }
            }
            QuantizationType::Rabit => {
                for batch in part_batches.iter_mut() {
                    if batch.num_rows() == 0 {
                        continue;
                    }
                    // RaBitQ codes are stored packed; unpack them.
                    let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list();
                    let original_codes = unpack_codes(codes);
                    *batch = batch
                        .replace_column_by_name(RABIT_CODE_COLUMN, Arc::new(original_codes))?
                        .drop_column(PART_ID_COLUMN)?;
                }
            }
            _ => {}
        }
        // Deterministic row order within each batch.
        for batch in part_batches.iter_mut() {
            if batch.num_rows() == 0 {
                continue;
            }
            *batch = stable_sort_batch_by_row_id(batch)?;
        }
        batches.extend(part_batches);
    }
    // Append newly shuffled rows for this partition, if any.
    let mut loss = 0.0;
    if let Some(reader) = reader
        && reader.partition_size(part_id)? > 0
    {
        let mut partition_data =
            reader
                .read_partition(part_id)
                .await?
                .ok_or(Error::invalid_input(format!(
                    "partition {} is empty",
                    part_id
                )))?;
        while let Some(batch) = partition_data.try_next().await? {
            // Shuffled batches carry their loss contribution in schema metadata.
            loss += batch
                .metadata()
                .get(LOSS_METADATA_KEY)
                .map(|s| s.parse::<f64>().unwrap_or(0.0))
                .unwrap_or(0.0);
            let batch = batch.drop_column(PART_ID_COLUMN)?;
            let batch = stable_sort_batch_by_row_id(&batch)?;
            batches.push(batch);
        }
    }
    Ok((batches, loss))
}
/// Consume the per-partition build stream and write the final index files:
/// an auxiliary storage file (quantized vectors) and an index file
/// (sub-index data), each finished with IVF offsets and metadata.
#[instrument(name = "merge_partitions", level = "debug", skip_all)]
async fn merge_partitions(&mut self, mut build_stream: BuildStream<S, Q>) -> Result<()> {
    let Some(ivf) = self.ivf.as_ref() else {
        return Err(Error::invalid_input("IVF not set before merge partitions"));
    };
    let Some(quantizer) = self.quantizer.clone() else {
        return Err(Error::invalid_input(
            "quantizer not set before merge partitions",
        ));
    };
    let is_pq = Q::quantization_type() == QuantizationType::Product;
    let is_rq = Q::quantization_type() == QuantizationType::Rabit;
    let storage_path = self.index_dir.child(INDEX_AUXILIARY_FILE_NAME);
    let index_path = self.index_dir.child(INDEX_FILE_NAME);
    // Storage schema: row id + the quantizer's code column(s).
    let mut fields = vec![ROW_ID_FIELD.clone(), quantizer.field()];
    fields.extend(quantizer.extra_fields());
    let storage_schema: Schema = (&arrow_schema::Schema::new(fields)).try_into()?;
    let writer_options = FileWriterOptions {
        format_version: Some(self.format_version),
        ..Default::default()
    };
    let mut storage_writer = FileWriter::try_new(
        self.store.create(&storage_path).await?,
        storage_schema.clone(),
        writer_options.clone(),
    )?;
    let mut index_writer = FileWriter::try_new(
        self.store.create(&index_path).await?,
        S::schema().as_ref().try_into()?,
        writer_options,
    )?;
    // Two IVF tables: partition lengths for the storage file and for the
    // index file respectively.
    let mut storage_ivf = IvfModel::empty();
    let mut index_ivf = IvfModel::new(ivf.centroids.clone().unwrap(), ivf.loss);
    let mut partition_index_metadata = Vec::with_capacity(ivf.num_partitions());
    let mut part_id = 0;
    let mut total_loss = 0.0;
    let progress = self.progress.clone();
    log::info!("merging {} partitions", ivf.num_partitions());
    while let Some(part) = build_stream.try_next().await? {
        // `part_id` counts processed partitions (so it is 1-based in logs).
        part_id += 1;
        progress.stage_progress("merge_partitions", part_id).await?;
        let Some((storage, index, loss)) = part else {
            log::warn!("partition {} is empty, skipping", part_id);
            // Record zero-length partitions so offsets stay aligned.
            storage_ivf.add_partition(0);
            index_ivf.add_partition(0);
            partition_index_metadata.push(String::new());
            continue;
        };
        total_loss += loss;
        if storage.len() == 0 {
            storage_ivf.add_partition(0);
        } else {
            let batches = storage.to_batches()?.collect::<Vec<_>>();
            let mut batch =
                arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;
            // Storage batches are assumed to carry codes in on-disk layout
            // here; restore row-major order before sorting rows.
            // NOTE(review): confirm against `Q::Storage::to_batches`.
            if is_pq && batch.column_by_name(PQ_CODE_COLUMN).is_some() {
                let codes_fsl = batch
                    .column_by_name(PQ_CODE_COLUMN)
                    .unwrap()
                    .as_fixed_size_list();
                let num_rows = batch.num_rows();
                let bytes_per_code = codes_fsl.value_length() as usize;
                let codes = codes_fsl.values().as_primitive::<datatypes::UInt8Type>();
                let original_codes = transpose(codes, bytes_per_code, num_rows);
                let original_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
                    original_codes,
                    bytes_per_code as i32,
                )?);
                batch = batch.replace_column_by_name(PQ_CODE_COLUMN, original_fsl)?;
            }
            // Likewise unpack RaBitQ codes before sorting.
            if is_rq && batch.column_by_name(RABIT_CODE_COLUMN).is_some() {
                let codes_fsl = batch
                    .column_by_name(RABIT_CODE_COLUMN)
                    .unwrap()
                    .as_fixed_size_list();
                let unpacked = Arc::new(unpack_codes(codes_fsl));
                batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, unpacked)?;
            }
            // Deterministic on-disk row order within the partition.
            batch = stable_sort_batch_by_row_id(&batch)?;
            // Re-apply the on-disk layout (transpose/pack) if enabled.
            if is_pq && self.transpose_codes && batch.column_by_name(PQ_CODE_COLUMN).is_some() {
                let codes_fsl = batch
                    .column_by_name(PQ_CODE_COLUMN)
                    .unwrap()
                    .as_fixed_size_list();
                let num_rows = batch.num_rows();
                let bytes_per_code = codes_fsl.value_length() as usize;
                let codes = codes_fsl.values().as_primitive::<datatypes::UInt8Type>();
                let transposed_codes = transpose(codes, num_rows, bytes_per_code);
                let transposed_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
                    transposed_codes,
                    bytes_per_code as i32,
                )?);
                batch = batch.replace_column_by_name(PQ_CODE_COLUMN, transposed_fsl)?;
            }
            if is_rq
                && self.transpose_codes
                && batch.column_by_name(RABIT_CODE_COLUMN).is_some()
            {
                let codes_fsl = batch
                    .column_by_name(RABIT_CODE_COLUMN)
                    .unwrap()
                    .as_fixed_size_list();
                let packed = Arc::new(pack_codes(codes_fsl));
                batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, packed)?;
            }
            storage_writer.write_batch(&batch).await?;
            storage_ivf.add_partition(batch.num_rows() as u32);
        }
        let index_batch = index.to_batch()?;
        if index_batch.num_rows() == 0 {
            index_ivf.add_partition(0);
            partition_index_metadata.push(String::new());
        } else {
            index_writer.write_batch(&index_batch).await?;
            index_ivf.add_partition(index_batch.num_rows() as u32);
            // Per-partition sub-index metadata travels in batch schema metadata.
            partition_index_metadata.push(
                index_batch
                    .schema()
                    .metadata
                    .get(S::metadata_key())
                    .cloned()
                    .unwrap_or_default(),
            );
        }
    }
    match self.shuffle_reader.as_ref() {
        Some(reader) => {
            if let Some(loss) = reader.total_loss() {
                total_loss += loss;
            }
            index_ivf.loss = Some(total_loss);
        }
        None => {
            // No shuffle happened (e.g. remap): keep the model's existing loss.
        }
    }
    // --- Storage file trailer: distance type, IVF offsets, quantizer metadata.
    let storage_ivf_pb = pb::Ivf::try_from(&storage_ivf)?;
    storage_writer.add_schema_metadata(DISTANCE_TYPE_KEY, self.distance_type.to_string());
    let ivf_buffer_pos = storage_writer
        .add_global_buffer(storage_ivf_pb.encode_to_vec().into())
        .await?;
    storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string());
    let quant_type = Q::quantization_type();
    // Record the code layout actually used so readers can interpret it.
    let transposed = match quant_type {
        QuantizationType::Product | QuantizationType::Rabit => self.transpose_codes,
        _ => false,
    };
    let mut metadata = quantizer.metadata(Some(QuantizationMetadata {
        codebook_position: Some(0),
        codebook: None,
        transposed,
    }));
    if let Some(extra_metadata) = metadata.extra_metadata()? {
        let idx = storage_writer.add_global_buffer(extra_metadata).await?;
        metadata.set_buffer_index(idx);
    }
    let metadata = serde_json::to_string(&metadata)?;
    let storage_partition_metadata = vec![metadata];
    storage_writer.add_schema_metadata(
        STORAGE_METADATA_KEY,
        serde_json::to_string(&storage_partition_metadata)?,
    );
    // --- Index file trailer: unified layout when supported, legacy otherwise.
    let index_type_str = index_type_string(S::name().try_into()?, Q::quantization_type());
    if let Some(idx_type) = SupportedIvfIndexType::from_index_type_str(&index_type_str) {
        write_unified_ivf_and_index_metadata(
            &mut index_writer,
            &index_ivf,
            self.distance_type,
            idx_type,
        )
        .await?;
    } else {
        let index_ivf_pb = pb::Ivf::try_from(&index_ivf)?;
        let index_metadata = IndexMetadata {
            index_type: index_type_str,
            distance_type: self.distance_type.to_string(),
        };
        index_writer.add_schema_metadata(
            INDEX_METADATA_SCHEMA_KEY,
            serde_json::to_string(&index_metadata)?,
        );
        let ivf_buffer_pos = index_writer
            .add_global_buffer(index_ivf_pb.encode_to_vec().into())
            .await?;
        index_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string());
    }
    index_writer.add_schema_metadata(
        S::metadata_key(),
        serde_json::to_string(&partition_index_metadata)?,
    );
    storage_writer.finish().await?;
    index_writer.finish().await?;
    log::info!("merging {} partitions done", ivf.num_partitions());
    Ok(())
}
async fn take_vectors(
dataset: &Dataset,
column: &str,
store: &ObjectStore,
row_ids: &[u64],
) -> Result<Vec<RecordBatch>> {
let projection = Arc::new(dataset.schema().project(&[column])?);
let mut batches = Vec::new();
let row_ids = dataset.filter_deleted_ids(row_ids).await?;
for chunk in row_ids.chunks(store.block_size()) {
let batch = dataset
.take_rows(chunk, ProjectionRequest::Schema(projection.clone()))
.await?;
if batch.num_rows() != chunk.len() {
return Err(Error::invalid_input(format!(
"batch.num_rows() != chunk.len() ({} != {})",
batch.num_rows(),
chunk.len()
)));
}
let batch = batch.try_with_column(
ROW_ID_FIELD.clone(),
Arc::new(UInt64Array::from(chunk.to_vec())),
)?;
batches.push(batch);
}
Ok(batches)
}
async fn load_partition_raw_vectors(
&self,
part_idx: usize,
) -> Result<Option<(UInt64Array, FixedSizeListArray)>> {
let Some(dataset) = self.dataset.as_ref() else {
return Err(Error::invalid_input(
"dataset not set before split partition",
));
};
let mut row_ids = self.partition_row_ids(part_idx).await?;
if !row_ids.is_sorted() {
row_ids.sort();
}
row_ids.dedup();
let batches = Self::take_vectors(dataset, &self.column, &self.store, &row_ids).await?;
if batches.is_empty() {
return Ok(None);
}
let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;
let batch = Flatten::new(&self.column).transform(&batch)?;
let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().clone();
let vectors = batch
.column_by_qualified_name(&self.column)
.ok_or(Error::invalid_input(format!(
"vector column {} not found in batch {}",
self.column,
batch.schema()
)))?
.as_fixed_size_list()
.clone();
Ok(Some((row_ids, vectors)))
}
fn check_partition_adjustment(
ivf: &IvfModel,
reader: &dyn ShuffleReader,
existing_indices: &[Arc<dyn VectorIndex>],
) -> Result<Option<PartitionAdjustment>> {
let index_type = IndexType::try_from(
index_type_string(S::name().try_into()?, Q::quantization_type()).as_str(),
)?;
let mut split_partition = None;
let mut join_partition = None;
let mut max_partition_size = 0;
let mut min_partition_size = usize::MAX;
for partition in 0..ivf.num_partitions() {
let mut num_rows = reader.partition_size(partition)?;
for index in existing_indices.iter() {
num_rows += index.partition_size(partition);
}
if num_rows > max_partition_size
&& num_rows > MAX_PARTITION_SIZE_FACTOR * index_type.target_partition_size()
{
max_partition_size = num_rows;
split_partition = Some(partition);
}
if ivf.num_partitions() > 1
&& num_rows < min_partition_size
&& num_rows < MIN_PARTITION_SIZE_PERCENT * index_type.target_partition_size() / 100
{
min_partition_size = num_rows;
join_partition = Some(partition);
}
}
if let Some(partition) = split_partition {
Ok(Some(PartitionAdjustment::Split(partition)))
} else if let Some(partition) = join_partition {
Ok(Some(PartitionAdjustment::Join(partition)))
} else {
Ok(None)
}
}
async fn split_partition(&self, part_idx: usize, ivf: &IvfModel) -> Result<AssignResult> {
let Some((row_ids, vectors)) = self.load_partition_raw_vectors(part_idx).await? else {
return Ok(AssignResult {
assign_batches: vec![None; ivf.num_partitions()],
new_centroids: ivf.centroids_array().unwrap().clone(),
});
};
let element_type = infer_vector_element_type(vectors.data_type())?;
match element_type {
DataType::Float16 => {
self.split_partition_impl::<Float16Type>(part_idx, ivf, &row_ids, &vectors)
.await
}
DataType::Float32 => {
self.split_partition_impl::<Float32Type>(part_idx, ivf, &row_ids, &vectors)
.await
}
DataType::Float64 => {
self.split_partition_impl::<Float64Type>(part_idx, ivf, &row_ids, &vectors)
.await
}
DataType::UInt8 => {
self.split_partition_impl::<UInt8Type>(part_idx, ivf, &row_ids, &vectors)
.await
}
dt => Err(Error::invalid_input(format!(
"vectors must be float16, float32, float64 or uint8, but got {:?}",
dt
))),
}
}
    /// Core split implementation for element type `T`.
    ///
    /// Trains a 2-means on this partition's vectors to produce two new
    /// centroids (one replaces the old centroid in place, one is appended),
    /// assigns the partition's own vectors among the two new centroids and
    /// nearby "reassign candidate" partitions, then pulls over vectors from
    /// candidate partitions that are now closer to one of the new centroids.
    async fn split_partition_impl<T: ArrowPrimitiveType>(
        &self,
        part_idx: usize,
        ivf: &IvfModel,
        row_ids: &UInt64Array,
        vectors: &FixedSizeListArray,
    ) -> Result<AssignResult>
    where
        T::Native: Dot + L2 + Normalize,
        PrimitiveArray<T>: From<Vec<T::Native>>,
    {
        let centroids = ivf.centroids_array().unwrap();
        // New centroid layout: all existing centroids, with slot `part_idx`
        // overwritten and one extra appended below.
        let mut new_centroids: Vec<ArrayRef> = Vec::with_capacity(ivf.num_partitions() + 1);
        new_centroids.extend(centroids.iter().map(|vec| vec.unwrap()));
        let dimension = infer_vector_dim(vectors.data_type())?;
        // Cosine is trained as L2 over normalized vectors.
        let (normalized_dist_type, normalized_vectors) = match self.distance_type {
            DistanceType::Cosine => {
                let vectors = normalize_fsl(vectors)?;
                (DistanceType::L2, vectors)
            }
            _ => (self.distance_type, vectors.clone()),
        };
        let params = KMeansParams::new(None, 50, 1, normalized_dist_type);
        // Train k=2 k-means to split this partition in two.
        let kmeans = lance_index::vector::kmeans::train_kmeans::<T>(
            normalized_vectors.values().as_primitive::<T>(),
            params,
            dimension,
            2,
            256,
        )?;
        let c0 = ivf
            .centroid(part_idx)
            .ok_or(Error::invalid_input("original centroid not found"))?;
        // The two trained centroids are stored back-to-back in the k-means
        // result; slice them apart.
        let c1 = kmeans.centroids.slice(0, dimension);
        let c2 = kmeans.centroids.slice(dimension, dimension);
        new_centroids[part_idx] = c1.clone();
        new_centroids.push(c2.clone());
        let centroid1_part_idx = part_idx;
        let centroid2_part_idx = new_centroids.len() - 1;
        let new_centroids = new_centroids
            .iter()
            .map(|vec| vec.as_ref())
            .collect::<Vec<_>>();
        let new_centroids = arrow::compute::concat(&new_centroids)?;
        let (reassign_part_ids, reassign_part_centroids) =
            self.select_reassign_candidates(ivf, part_idx, &c0)?;
        // Distances of each vector to the old centroid and the two new ones.
        let d0 = self.distance_type.arrow_batch_func()(&c0, vectors)?;
        let d1 = self.distance_type.arrow_batch_func()(&c1, vectors)?;
        let d2 = self.distance_type.arrow_batch_func()(&c2, vectors)?;
        let d0 = d0.values();
        let d1 = d1.values();
        let d2 = d2.values();
        // One op list per partition in the post-split layout (old count + 1).
        let mut assign_ops = vec![Vec::new(); ivf.num_partitions() + 1];
        // Assign this partition's own vectors. `true` marks the original
        // partition as deleted, so every vector must be placed somewhere.
        self.assign_vectors::<T>(
            part_idx,
            centroid1_part_idx,
            centroid2_part_idx,
            row_ids,
            vectors,
            d0,
            d1,
            d2,
            &reassign_part_ids,
            &reassign_part_centroids,
            true,
            &mut assign_ops,
        )?;
        let reassign_targets = reassign_part_ids
            .values()
            .iter()
            .copied()
            .enumerate()
            .collect::<Vec<_>>();
        if !reassign_targets.is_empty() {
            // For each candidate partition, load its vectors and compute which
            // of them should move to one of the new centroids. Candidates are
            // processed concurrently; the distance math runs on the CPU pool.
            let builder = self;
            let distance_type = self.distance_type;
            let reassign_part_ids_clone = reassign_part_ids.clone();
            let reassign_part_centroids_clone = reassign_part_centroids.clone();
            stream::iter(
                reassign_targets
                    .into_iter()
                    .map(move |(candidate_idx, part_id)| {
                        // Clone the shared inputs into the async closure; the
                        // arrays are moved again into the spawned CPU task.
                        let builder = builder;
                        let reassign_part_ids = reassign_part_ids_clone.clone();
                        let reassign_part_centroids = reassign_part_centroids_clone.clone();
                        let centroid1 = c1.clone();
                        let centroid2 = c2.clone();
                        async move {
                            let part_idx = part_id as usize;
                            let Some((row_ids, vectors)) =
                                builder.load_partition_raw_vectors(part_idx).await?
                            else {
                                // Empty candidate partition: nothing to move.
                                return Ok::<Vec<(usize, AssignOp)>, Error>(Vec::new());
                            };
                            let ops = spawn_cpu(move || {
                                Self::compute_reassign_assign_ops::<T>(
                                    distance_type,
                                    part_idx,
                                    candidate_idx,
                                    centroid1_part_idx,
                                    centroid2_part_idx,
                                    &row_ids,
                                    &vectors,
                                    centroid1,
                                    centroid2,
                                    &reassign_part_ids,
                                    &reassign_part_centroids,
                                )
                            })
                            .await?;
                            Ok(ops)
                        }
                    }),
            )
            .buffered(get_num_compute_intensive_cpus())
            .try_for_each(|ops| {
                // Fold the per-candidate ops into the shared per-partition
                // op lists (sequential, so no locking is needed).
                for (target_idx, op) in ops {
                    assign_ops[target_idx].push(op);
                }
                future::ready(Ok(()))
            })
            .await?;
        }
        let new_centroids =
            FixedSizeListArray::try_new_from_values(new_centroids, dimension as i32)?;
        let assign_batches = self.build_assign_batch::<T>(&new_centroids, &assign_ops)?;
        Ok(AssignResult {
            assign_batches,
            new_centroids,
        })
    }
async fn join_partition(&self, part_idx: usize, ivf: &IvfModel) -> Result<AssignResult> {
let centroids = ivf.centroids_array().unwrap();
let mut new_centroids: Vec<ArrayRef> = Vec::with_capacity(ivf.num_partitions() - 1);
new_centroids.extend(centroids.iter().enumerate().filter_map(|(i, vec)| {
if i == part_idx {
None
} else {
Some(vec.unwrap())
}
}));
let new_centroids = new_centroids
.iter()
.map(|vec| vec.as_ref())
.collect::<Vec<_>>();
let new_centroids = arrow::compute::concat(&new_centroids)?;
let new_centroids =
FixedSizeListArray::try_new_from_values(new_centroids, centroids.value_length())?;
let Some((row_ids, vectors)) = self.load_partition_raw_vectors(part_idx).await? else {
return Ok(AssignResult {
assign_batches: vec![None; ivf.num_partitions() - 1],
new_centroids,
});
};
match vectors.value_type() {
DataType::Float16 => {
self.join_partition_impl::<Float16Type>(
part_idx,
ivf,
&row_ids,
&vectors,
new_centroids,
)
.await
}
DataType::Float32 => {
self.join_partition_impl::<Float32Type>(
part_idx,
ivf,
&row_ids,
&vectors,
new_centroids,
)
.await
}
DataType::Float64 => {
self.join_partition_impl::<Float64Type>(
part_idx,
ivf,
&row_ids,
&vectors,
new_centroids,
)
.await
}
DataType::UInt8 => {
self.join_partition_impl::<UInt8Type>(
part_idx,
ivf,
&row_ids,
&vectors,
new_centroids,
)
.await
}
dt => Err(Error::invalid_input(format!(
"vectors must be float16, float32, float64 or uint8, but got {:?}",
dt
))),
}
}
async fn join_partition_impl<T: ArrowPrimitiveType>(
&self,
part_idx: usize,
ivf: &IvfModel,
row_ids: &UInt64Array,
vectors: &FixedSizeListArray,
new_centroids: FixedSizeListArray,
) -> Result<AssignResult>
where
T::Native: Dot + L2 + Normalize,
PrimitiveArray<T>: From<Vec<T::Native>>,
{
assert_eq!(row_ids.len(), vectors.len());
let c0 = ivf
.centroid(part_idx)
.ok_or(Error::invalid_input("original centroid not found"))?;
let (reassign_part_ids, reassign_part_centroids) =
self.select_reassign_candidates(ivf, part_idx, &c0)?;
let new_part_id = |idx: usize| -> usize {
if idx < part_idx {
idx
} else {
idx - 1
}
};
let mut assign_ops = vec![Vec::new(); ivf.num_partitions() - 1];
for (i, &row_id) in row_ids.values().iter().enumerate() {
let ReassignPartition::ReassignCandidate(idx) = self.reassign_vectors(
vectors.value(i).as_primitive::<T>(),
None,
&reassign_part_ids,
&reassign_part_centroids,
)?
else {
log::warn!("this is a bug, the vector is not reassigned");
continue;
};
assign_ops[new_part_id(idx as usize)].push(AssignOp::Add((row_id, vectors.value(i))));
}
let assign_batches = self.build_assign_batch::<T>(&new_centroids, &assign_ops)?;
Ok(AssignResult {
assign_batches,
new_centroids,
})
}
fn build_assign_batch<T: ArrowPrimitiveType>(
&self,
centroids: &FixedSizeListArray,
assign_ops: &[Vec<AssignOp>],
) -> Result<Vec<Option<(RecordBatch, UInt64Array)>>> {
let Some(dataset) = self.dataset.as_ref() else {
return Err(Error::invalid_input(
"dataset not set before building assign batch",
));
};
let Some(quantizer) = self.quantizer.clone() else {
return Err(Error::invalid_input(
"quantizer not set before building assign batch",
));
};
let Some(vector_field) =
dataset
.schema()
.field(&self.column)
.map(|f| match f.data_type() {
DataType::List(inner) | DataType::LargeList(inner) => {
Field::new(self.column.as_str(), inner.data_type().clone(), true)
}
_ => f.into(),
})
else {
return Err(Error::invalid_input(
"vector field not found in dataset schema",
));
};
let transformer = Arc::new(
lance_index::vector::ivf::new_ivf_transformer_with_quantizer(
centroids.clone(),
self.distance_type,
vector_field.name().as_str(),
quantizer.into(),
None,
)?,
);
let num_rows = assign_ops
.iter()
.map(|ops| {
ops.iter()
.map(|op| match op {
AssignOp::Add(_) => 1,
AssignOp::Remove(_) => 0,
})
.sum::<usize>()
})
.sum::<usize>();
let mut row_ids_builder = UInt64Builder::with_capacity(num_rows);
let mut vector_builder =
PrimitiveBuilder::<T>::with_capacity(num_rows * centroids.value_length() as usize);
let mut part_ids_builder = UInt32Builder::with_capacity(num_rows);
let mut deleted_row_ids = UInt64Builder::with_capacity(num_rows);
let mut ops_count = Vec::with_capacity(assign_ops.len());
for (part_idx, ops) in assign_ops.iter().enumerate() {
let mut add_count = 0;
let mut remove_count = 0;
for op in ops {
match op {
AssignOp::Add((row_id, vector)) => {
row_ids_builder.append_value(*row_id);
vector_builder.append_array(vector.as_primitive::<T>());
part_ids_builder.append_value(part_idx as u32);
add_count += 1;
}
AssignOp::Remove(row_id) => {
deleted_row_ids.append_value(*row_id);
remove_count += 1;
}
}
}
ops_count.push((add_count, remove_count));
}
let row_ids = row_ids_builder.finish();
let vector = FixedSizeListArray::try_new_from_values(
vector_builder.finish(),
centroids.value_length(),
)?;
let part_ids = part_ids_builder.finish();
let deleted_row_ids = deleted_row_ids.finish();
let schema = arrow_schema::Schema::new(vec![
ROW_ID_FIELD.clone(),
vector_field,
PART_ID_FIELD.clone(),
]);
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(row_ids), Arc::new(vector), Arc::new(part_ids)],
)?;
let batch = transformer.transform(&batch)?;
let mut results = Vec::with_capacity(assign_ops.len());
let mut add_offset = 0;
let mut remove_offset = 0;
for (add_count, remove_count) in ops_count.into_iter() {
if add_count == 0 && remove_count == 0 {
results.push(None);
continue;
}
let batch = batch.slice(add_offset, add_count);
let deleted_row_ids = deleted_row_ids.slice(remove_offset, remove_count);
results.push(Some((batch, deleted_row_ids)));
add_offset += add_count;
remove_offset += remove_count;
}
Ok(results)
}
async fn partition_row_ids(&self, part_idx: usize) -> Result<Vec<u64>> {
let mut row_ids = Vec::new();
for index in self.existing_indices.iter() {
if part_idx >= index.ivf_model().num_partitions() {
log::warn!(
"partition index is {} but the number of partitions is {}, skip loading it",
part_idx,
index.ivf_model().num_partitions()
);
continue;
}
let mut reader = index
.partition_reader(part_idx, false, &NoOpMetricsCollector)
.await?;
while let Some(batch) = reader.try_next().await? {
row_ids.extend(batch[ROW_ID].as_primitive::<UInt64Type>().values());
}
}
if let Some(reader) = self.shuffle_reader.as_ref() {
if let Some(mut reader) = reader.read_partition(part_idx).await? {
while let Some(batch) = reader.try_next().await? {
row_ids.extend(batch[ROW_ID].as_primitive::<UInt64Type>().values());
}
}
}
Ok(row_ids)
}
    /// Pick nearby partitions (by centroid distance to `c0`) that vectors
    /// from partition `part_idx` may be reassigned to. Thin wrapper over
    /// [`select_reassign_candidates_impl`] using this builder's distance type.
    fn select_reassign_candidates(
        &self,
        ivf: &IvfModel,
        part_idx: usize,
        c0: &ArrayRef,
    ) -> Result<(UInt32Array, FixedSizeListArray)> {
        select_reassign_candidates_impl(self.distance_type, ivf, part_idx, c0)
    }
    /// Assign each vector of a split partition to its destination, pushing
    /// the resulting ops into `assign_ops` (indexed by partition id in the
    /// post-split layout). Thin wrapper over [`Self::assign_vectors_impl`]
    /// using this builder's distance type.
    #[allow(clippy::too_many_arguments)]
    fn assign_vectors<T: ArrowPrimitiveType>(
        &self,
        part_idx: usize,
        centroid1_part_idx: usize,
        centroid2_part_idx: usize,
        row_ids: &UInt64Array,
        vectors: &FixedSizeListArray,
        d0: &[f32],
        d1: &[f32],
        d2: &[f32],
        reassign_part_ids: &UInt32Array,
        reassign_part_centroids: &FixedSizeListArray,
        deleted_original_partition: bool,
        assign_ops: &mut [Vec<AssignOp>],
    ) -> Result<()> {
        Self::assign_vectors_impl::<T, _>(
            self.distance_type,
            part_idx,
            centroid1_part_idx,
            centroid2_part_idx,
            row_ids,
            vectors,
            d0,
            d1,
            d2,
            reassign_part_ids,
            reassign_part_centroids,
            deleted_original_partition,
            |idx, op| assign_ops[idx].push(op),
        )
    }
    /// Core assignment routine shared by the split pass and the
    /// reassignment-candidate pass.
    ///
    /// For each vector `i`, its distance to the original centroid (`d0[i]`)
    /// is compared against the two split centroids (`d1[i]`, `d2[i]`):
    /// - Closest to the original centroid: if that partition still exists
    ///   (`deleted_original_partition == false`) the vector stays put;
    ///   otherwise it is routed via [`Self::reassign_vectors_impl`] to a new
    ///   centroid or a nearby candidate partition.
    /// - Closer to a new centroid: a `Remove` op is emitted for the original
    ///   partition (only when it still exists) and an `Add` op for the closer
    ///   of the two new centroids.
    ///
    /// Ops are emitted through `sink(partition_idx, op)`.
    #[allow(clippy::too_many_arguments)]
    fn assign_vectors_impl<T: ArrowPrimitiveType, F: FnMut(usize, AssignOp)>(
        distance_type: DistanceType,
        part_idx: usize,
        centroid1_part_idx: usize,
        centroid2_part_idx: usize,
        row_ids: &UInt64Array,
        vectors: &FixedSizeListArray,
        d0: &[f32],
        d1: &[f32],
        d2: &[f32],
        reassign_part_ids: &UInt32Array,
        reassign_part_centroids: &FixedSizeListArray,
        deleted_original_partition: bool,
        mut sink: F,
    ) -> Result<()> {
        for (i, &row_id) in row_ids.values().iter().enumerate() {
            if d0[i] <= d1[i] && d0[i] <= d2[i] {
                // Still closest to its current centroid.
                if !deleted_original_partition {
                    continue;
                }
                // The original partition is gone: pick a new home among the
                // split centroids and the reassignment candidates.
                match Self::reassign_vectors_impl(
                    distance_type,
                    vectors.value(i).as_primitive::<T>(),
                    Some((d1[i], d2[i])),
                    reassign_part_ids,
                    reassign_part_centroids,
                )? {
                    ReassignPartition::NewCentroid1 => {
                        sink(
                            centroid1_part_idx,
                            AssignOp::Add((row_id, vectors.value(i))),
                        );
                    }
                    ReassignPartition::NewCentroid2 => {
                        sink(
                            centroid2_part_idx,
                            AssignOp::Add((row_id, vectors.value(i))),
                        );
                    }
                    ReassignPartition::ReassignCandidate(idx) => {
                        sink(idx as usize, AssignOp::Add((row_id, vectors.value(i))));
                    }
                }
            } else {
                // Closer to one of the new centroids: move it there, deleting
                // it from its current partition if that partition survives.
                if !deleted_original_partition {
                    sink(part_idx, AssignOp::Remove(row_id));
                }
                if d1[i] <= d2[i] {
                    sink(
                        centroid1_part_idx,
                        AssignOp::Add((row_id, vectors.value(i))),
                    );
                } else {
                    sink(
                        centroid2_part_idx,
                        AssignOp::Add((row_id, vectors.value(i))),
                    );
                }
            }
        }
        Ok(())
    }
#[allow(clippy::too_many_arguments)]
fn compute_reassign_assign_ops<T: ArrowPrimitiveType>(
distance_type: DistanceType,
part_idx: usize,
candidate_idx: usize,
centroid1_part_idx: usize,
centroid2_part_idx: usize,
row_ids: &UInt64Array,
vectors: &FixedSizeListArray,
centroid1: ArrayRef,
centroid2: ArrayRef,
reassign_part_ids: &UInt32Array,
reassign_part_centroids: &FixedSizeListArray,
) -> Result<Vec<(usize, AssignOp)>>
where
T::Native: Dot + L2 + Normalize,
PrimitiveArray<T>: From<Vec<T::Native>>,
{
let d0 = distance_type.arrow_batch_func()(
reassign_part_centroids.value(candidate_idx).as_ref(),
vectors,
)?;
let d1 = distance_type.arrow_batch_func()(centroid1.as_ref(), vectors)?;
let d2 = distance_type.arrow_batch_func()(centroid2.as_ref(), vectors)?;
let d0 = d0.values();
let d1 = d1.values();
let d2 = d2.values();
let mut ops = Vec::new();
Self::assign_vectors_impl::<T, _>(
distance_type,
part_idx,
centroid1_part_idx,
centroid2_part_idx,
row_ids,
vectors,
d0,
d1,
d2,
reassign_part_ids,
reassign_part_centroids,
false,
|idx, op| ops.push((idx, op)),
)?;
Ok(ops)
}
    /// Choose the destination partition for a single vector. Thin wrapper
    /// over [`Self::reassign_vectors_impl`] using this builder's distance
    /// type.
    fn reassign_vectors<T: ArrowPrimitiveType>(
        &self,
        vector: &PrimitiveArray<T>,
        split_centroids_dists: Option<(f32, f32)>,
        reassign_candidate_ids: &UInt32Array,
        reassign_candidate_centroids: &FixedSizeListArray,
    ) -> Result<ReassignPartition> {
        Self::reassign_vectors_impl(
            self.distance_type,
            vector,
            split_centroids_dists,
            reassign_candidate_ids,
            reassign_candidate_centroids,
        )
    }
fn reassign_vectors_impl<T: ArrowPrimitiveType>(
distance_type: DistanceType,
vector: &PrimitiveArray<T>,
split_centroids_dists: Option<(f32, f32)>,
reassign_candidate_ids: &UInt32Array,
reassign_candidate_centroids: &FixedSizeListArray,
) -> Result<ReassignPartition> {
let dists = distance_type.arrow_batch_func()(vector, reassign_candidate_centroids)?;
let min_dist_idx = dists.values().iter().position_min_by(|a, b| a.total_cmp(b));
let min_dist = min_dist_idx
.map(|idx| dists.value(idx))
.unwrap_or(f32::INFINITY);
match split_centroids_dists {
Some((d1, d2)) => {
if min_dist <= d1 && min_dist <= d2 {
Ok(ReassignPartition::ReassignCandidate(
reassign_candidate_ids.value(min_dist_idx.unwrap()),
))
} else if d1 <= d2 {
Ok(ReassignPartition::NewCentroid1)
} else {
Ok(ReassignPartition::NewCentroid2)
}
}
None => Ok(ReassignPartition::ReassignCandidate(
reassign_candidate_ids.value(min_dist_idx.unwrap()),
)),
}
}
}
/// Select up to `REASSIGN_RANGE` partitions whose centroids are nearest to
/// `c0`, excluding the partition being adjusted (`part_idx`), for use as
/// reassignment targets. Returns the candidate partition ids together with
/// their centroids.
fn select_reassign_candidates_impl(
    distance_type: DistanceType,
    ivf: &IvfModel,
    part_idx: usize,
    c0: &ArrayRef,
) -> Result<(UInt32Array, FixedSizeListArray)> {
    // Fetch one extra entry so that dropping `part_idx` (typically the
    // nearest, at distance 0 to its own centroid) still leaves enough.
    let fetch_count = std::cmp::min(REASSIGN_RANGE + 1, ivf.num_partitions());
    let centroids = ivf.centroids_array().unwrap();
    let centroid_dists = distance_type.arrow_batch_func()(&c0, centroids)?;
    let nearest = sort_to_indices(centroid_dists.as_ref(), None, Some(fetch_count))?;
    let keep = fetch_count.saturating_sub(1);
    let candidate_ids: Vec<u32> = nearest
        .values()
        .iter()
        .copied()
        .filter(|&idx| idx as usize != part_idx)
        .take(keep)
        .collect();
    let candidate_ids = UInt32Array::from(candidate_ids);
    let candidate_centroids = arrow::compute::take(centroids, &candidate_ids, None)?;
    Ok((
        candidate_ids,
        candidate_centroids.as_fixed_size_list().clone(),
    ))
}
/// Result of a partition split/join: per-partition assignment batches plus
/// the updated centroid set.
struct AssignResult {
    // One entry per partition in the new layout: rows to add (already
    // transformed batch) and row ids to delete, or `None` when untouched.
    assign_batches: Vec<Option<(RecordBatch, UInt64Array)>>,
    // Centroids after the adjustment (split adds one, join removes one).
    new_centroids: FixedSizeListArray,
}
/// A single vector-level mutation produced while rebalancing partitions.
#[derive(Debug, Clone)]
enum AssignOp {
    /// Add `(row_id, vector)` to the target partition.
    Add((u64, ArrayRef)),
    /// Remove the row with this row id from the target partition.
    Remove(u64),
}
/// Destination chosen for a vector during reassignment.
#[derive(Debug, Copy, Clone)]
enum ReassignPartition {
    /// The first centroid produced by a split.
    NewCentroid1,
    /// The second centroid produced by a split.
    NewCentroid2,
    /// An existing nearby partition, identified by its partition id.
    ReassignCandidate(u32),
}
/// Rebalancing action chosen by `check_partition_adjustment`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum PartitionAdjustment {
    /// Split the oversized partition at this index into two.
    Split(usize),
    /// Remove the undersized partition at this index and redistribute its
    /// rows among the remaining partitions.
    Join(usize),
}
/// Build the canonical index-type string, e.g. `IVF_PQ`.
///
/// A flat sub-index is elided, and when the sub-index and quantizer stringify
/// identically the name is emitted only once.
pub(crate) fn index_type_string(sub_index: SubIndexType, quantizer: QuantizationType) -> String {
    let sub = sub_index.to_string();
    let quant = quantizer.to_string();
    if matches!(sub_index, SubIndexType::Flat) || sub == quant {
        format!("IVF_{}", quant)
    } else {
        format!("IVF_{}_{}", sub, quant)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Float32Array;
    use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer};

    // Candidate selection must never propose the partition that is being
    // split/joined as a reassignment target.
    #[test]
    fn select_reassign_candidates_skips_deleted_partition() {
        let dim = 4;
        // Two identical (all-zero) centroids; partition 1 is being adjusted.
        let centroid_values = Float32Array::from(vec![0.0_f32; dim * 2]);
        let centroids =
            FixedSizeListArray::try_new_from_values(centroid_values, dim as i32).unwrap();
        let mut ivf = IvfModel::new(centroids, None);
        ivf.lengths = vec![10, 20];
        ivf.offsets = vec![0, 10];
        let c0 = ivf.centroid(1).unwrap();
        let (reassign_ids, reassign_centroids) =
            select_reassign_candidates_impl(DistanceType::L2, &ivf, 1, &c0).unwrap();
        // Only partition 0 may remain as a candidate.
        assert_eq!(reassign_ids.len(), 1);
        assert_eq!(reassign_ids.value(0), 0);
        assert_eq!(reassign_centroids.len(), 1);
        let expected_centroid = ivf.centroid(0).unwrap();
        assert_eq!(
            reassign_centroids
                .value(0)
                .as_primitive::<Float32Type>()
                .values(),
            expected_centroid.as_primitive::<Float32Type>().values()
        );
    }

    // A vector that moved near a new split centroid must be removed from its
    // candidate partition and added to the new centroid's partition.
    #[test]
    fn compute_reassign_assign_ops_moves_vectors_to_new_centroids() {
        // Row 1 at (0,0) is closest to centroid1; row 2 at (10,10) stays
        // closest to the candidate centroid at (9,9).
        let row_ids = UInt64Array::from(vec![1_u64, 2_u64]);
        let vectors = FixedSizeListArray::try_new_from_values(
            Float32Array::from(vec![0.0_f32, 0.0, 10.0, 10.0]),
            2,
        )
        .unwrap();
        let reassign_part_ids = UInt32Array::from(vec![0_u32]);
        let reassign_part_centroids =
            FixedSizeListArray::try_new_from_values(Float32Array::from(vec![9.0_f32, 9.0]), 2)
                .unwrap();
        let centroid1: ArrayRef = Arc::new(Float32Array::from(vec![0.0_f32, 0.0]));
        let centroid2: ArrayRef = Arc::new(Float32Array::from(vec![20.0_f32, 20.0]));
        let ops = IvfIndexBuilder::<FlatIndex, FlatQuantizer>::compute_reassign_assign_ops::<
            Float32Type,
        >(
            DistanceType::L2,
            0,
            0,
            1,
            2,
            &row_ids,
            &vectors,
            centroid1,
            centroid2,
            &reassign_part_ids,
            &reassign_part_centroids,
        )
        .unwrap();
        // Exactly one row moves: removed from partition 0, added to
        // centroid1's partition (id 1) with its original vector.
        assert_eq!(ops.len(), 2);
        assert!(matches!(ops[0], (0, AssignOp::Remove(1))));
        match &ops[1] {
            (1, AssignOp::Add((row_id, vector))) => {
                assert_eq!(*row_id, 1);
                assert_eq!(
                    vector.as_primitive::<Float32Type>().values(),
                    &[0.0_f32, 0.0]
                );
            }
            other => panic!("unexpected op: {:?}", other),
        }
    }
}