use alloc::collections::BinaryHeap;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::cmp::Ordering as CmpOrdering;
use crate::TableDefinition;
use crate::error::StorageError;
use crate::observer::DatabaseObserver;
#[cfg(feature = "metrics")]
use crate::observer::DbMetrics;
use crate::storage_traits::{ReadTable, StorageRead, StorageWrite, WriteTable};
use crate::vector_ops::{DistanceMetric, Neighbor, l2_normalize};
use super::adc::IntAdcTable;
use super::cluster_blob::{ClusterBlobRef, merge_into_blob, remove_from_blob};
use super::config::{
FORMAT_V0_LEGACY, IndexConfig, IvfPqIndexDefinition, STATE_TRAINED, SearchParams,
};
use super::kmeans;
use super::metadata::{MetadataMap, passes_filter};
use super::pq::{self, Codebooks};
use super::types::{decode_index_config, encode_index_config};
/// Owned entry handed to `merge_into_blob`: `(vector_id, pq_codes, extra)`.
///
/// Every entry built in this file passes `None` for the third slot.
/// NOTE(review): the meaning of that slot is defined by `cluster_blob` — confirm there.
type OwnedBlobEntry = (u64, Vec<u8>, Option<Vec<u8>>);
/// Upper bound on the candidate-heap capacity used by `search` when it reranks
/// approximate hits against the stored raw vectors.
const MAX_RERANK_CANDIDATES: usize = 1_000_000;
/// Progress events emitted by `IvfPqIndex::train_with_progress`, both to the
/// caller-supplied callback and to the index's `DatabaseObserver`.
#[derive(Debug, Clone)]
pub enum TrainProgress {
/// All training vectors have been read; `count` is how many were collected.
CollectingVectors { count: usize },
/// k-means over the collected vectors is about to start.
TrainingCentroids {
/// Number of clusters requested from k-means.
num_clusters: usize,
/// Number of training vectors fed to k-means.
num_vectors: usize,
},
/// Per-vector residuals (vector minus its nearest centroid) are about to be computed.
ComputingResiduals { num_vectors: usize },
/// Product-quantization codebooks are about to be trained on the residuals.
TrainingCodebooks { num_subvectors: usize },
/// Centroids, codebooks and the updated config are about to be written to storage.
Persisting,
/// Training finished and the in-memory caches were refreshed.
Done,
}
/// Name of the internal table holding the encoded index config for index `name`.
fn meta_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":meta");
    table
}
/// Name of the internal table holding one centroid row per cluster for index `name`.
fn centroids_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":centroids");
    table
}
/// Name of the internal table holding one serialized PQ codebook per sub-vector for index `name`.
fn codebooks_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":codebooks");
    table
}
/// Name of the internal table holding one packed cluster blob per cluster id for index `name`.
fn clusters_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":clusters");
    table
}
/// Name of the internal table holding raw (little-endian f32) vectors for index `name`.
fn vectors_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":vectors");
    table
}
/// Name of the internal table mapping each vector id to its cluster id for index `name`.
fn assignments_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":assignments");
    table
}
/// Name of the internal table holding encoded per-vector metadata for index `name`.
fn vector_meta_name(name: &str) -> String {
    let mut table = String::from("__ivfpq:");
    table.push_str(name);
    table.push_str(":vector_meta");
    table
}
/// Checks the structural invariants of an IVF-PQ configuration.
///
/// # Errors
/// Returns an invalid-index-config error for the first violated rule, checked in
/// this order: `num_subvectors > 0`, `dim > 0`, `dim` divisible by
/// `num_subvectors`, `num_clusters > 0`.
fn validate_config(config: &IndexConfig) -> crate::Result<()> {
    if config.num_subvectors == 0 {
        Err(StorageError::invalid_index_config(
            "IVF-PQ: num_subvectors must be > 0",
        ))
    } else if config.dim == 0 {
        Err(StorageError::invalid_index_config(
            "IVF-PQ: dim must be > 0",
        ))
    } else if config.dim as usize % config.num_subvectors as usize != 0 {
        // Each PQ sub-vector must cover the same number of dimensions.
        Err(StorageError::invalid_index_config(alloc::format!(
            "IVF-PQ: dim ({}) must be divisible by num_subvectors ({})",
            config.dim,
            config.num_subvectors,
        )))
    } else if config.num_clusters == 0 {
        Err(StorageError::invalid_index_config(
            "IVF-PQ: num_clusters must be > 0",
        ))
    } else {
        Ok(())
    }
}
/// Read-write handle to one IVF-PQ index inside a write transaction.
///
/// All persistent state lives in internal `__ivfpq:{name}:*` tables reached
/// through `txn`; the struct itself only caches the config, centroids and codebooks.
pub struct IvfPqIndex<'txn, T: StorageWrite> {
/// Transaction through which every table of this index is opened.
txn: &'txn T,
/// In-memory copy of the index config; written back by `flush`, `train` and on drop when dirty.
pub(crate) config: IndexConfig,
/// User-visible index name, used to derive the internal table names and in error messages.
name: String,
/// Cluster count taken from the index definition at open time; this is what training asks k-means for.
requested_num_clusters: u32,
/// Lazily loaded flat centroid buffer (cluster-major, `dim` floats per cluster).
centroids: Option<Vec<f32>>,
/// Lazily loaded PQ codebooks.
codebooks: Option<Codebooks>,
/// True when `config` has in-memory changes that were not yet persisted.
config_dirty: bool,
/// Receives training progress notifications.
observer: Arc<dyn DatabaseObserver>,
/// Shared database counters; `search` bumps `vector_searches`.
#[cfg(feature = "metrics")]
db_metrics: Arc<DbMetrics>,
}
impl<'txn, T: StorageWrite> IvfPqIndex<'txn, T> {
/// Opens (or creates) the index described by `definition` inside `txn`.
///
/// If a `"config"` row already exists in the meta table it is decoded and used
/// as-is; otherwise the definition's config is validated and persisted.
///
/// # Errors
/// - invalid-index-config when a freshly created config fails `validate_config`;
/// - format error when the stored config is legacy format v0 and not in state 0;
/// - any storage error raised while opening or writing tables.
pub(crate) fn open(
txn: &'txn T,
definition: &IvfPqIndexDefinition,
observer: Arc<dyn DatabaseObserver>,
#[cfg(feature = "metrics")] db_metrics: Arc<DbMetrics>,
) -> crate::Result<Self> {
let name = String::from(definition.name());
let mn = meta_name(&name);
let meta_def = TableDefinition::<&str, &[u8]>::new_internal(&mn);
let mut meta_table = txn.open_storage_table(meta_def)?;
let existing = meta_table.st_get(&"config")?;
let config = if let Some(guard) = existing {
// The stored config wins over the definition for an existing index.
decode_index_config(guard.value())
} else {
let config = definition.to_config();
validate_config(&config)?;
let bytes = encode_index_config(&config);
// `existing` is dropped explicitly before the table is written to.
drop(existing); meta_table.st_insert(&"config", &bytes.as_slice())?;
config
};
if config.format_version == FORMAT_V0_LEGACY && config.state != 0 {
return Err(StorageError::format_error(alloc::format!(
"IVF-PQ '{name}': legacy format v0 -- re-train required for blob format v1",
)));
}
{
// Open every data table once up front; the handles are discarded immediately.
// NOTE(review): presumably this makes the tables exist for later read-only
// transactions — confirm. The `vector_meta` table is not opened here.
let cn = centroids_name(&name);
let _ = txn.open_storage_table(TableDefinition::<u32, &[u8]>::new_internal(&cn))?;
let cb = codebooks_name(&name);
let _ = txn.open_storage_table(TableDefinition::<u32, &[u8]>::new_internal(&cb))?;
let cl = clusters_name(&name);
let _ = txn.open_storage_table(TableDefinition::<u32, &[u8]>::new_internal(&cl))?;
let vn = vectors_name(&name);
let _ = txn.open_storage_table(TableDefinition::<u64, &[u8]>::new_internal(&vn))?;
let an = assignments_name(&name);
let _ = txn.open_storage_table(TableDefinition::<u64, u32>::new_internal(&an))?;
}
// Training uses the cluster count from the definition, not the stored one.
let requested_num_clusters = definition.num_clusters();
Ok(Self {
txn,
config,
name,
requested_num_clusters,
centroids: None,
codebooks: None,
config_dirty: false,
observer,
#[cfg(feature = "metrics")]
db_metrics,
})
}
pub fn config(&self) -> &IndexConfig {
&self.config
}
/// Writes the in-memory config back to storage if it has unsaved changes.
///
/// # Errors
/// Propagates any storage error from persisting the config; the dirty flag is
/// only cleared after a successful write.
pub fn flush(&mut self) -> crate::Result<()> {
    if !self.config_dirty {
        return Ok(());
    }
    self.persist_config_inner()?;
    self.config_dirty = false;
    Ok(())
}
pub fn train<I>(&mut self, training_vectors: I, max_iter: usize) -> crate::Result<()>
where
I: Iterator<Item = (u64, Vec<f32>)>,
{
validate_config(&self.config)?;
let dim = self.config.dim as usize;
let num_clusters = self.requested_num_clusters as usize;
let num_subvectors = self.config.num_subvectors as usize;
let mut flat: Vec<f32> = Vec::new();
for (_id, mut vec) in training_vectors {
if vec.len() != dim {
return Err(StorageError::dimension_mismatch(&self.name, dim, vec.len()));
}
if self.config.metric == DistanceMetric::Cosine {
l2_normalize(&mut vec);
}
flat.extend_from_slice(&vec);
}
let n = flat.len() / dim;
if n == 0 {
return Err(StorageError::invalid_index_config(alloc::format!(
"IVF-PQ '{}': no training vectors provided",
self.name,
)));
}
let centroid_data = kmeans::kmeans(&flat, dim, num_clusters, max_iter, self.config.metric);
let actual_k = centroid_data.len() / dim;
let old_k = self.config.num_clusters as usize;
#[allow(clippy::cast_possible_truncation)]
{
self.config.num_clusters = actual_k as u32;
}
let mut residuals = Vec::with_capacity(flat.len());
for i in 0..n {
let vec_slice = &flat[i * dim..(i + 1) * dim];
let (cid, _) = kmeans::assign_nearest(
vec_slice,
¢roid_data,
dim,
actual_k,
self.config.metric,
);
let c_offset = cid as usize * dim;
for d in 0..dim {
residuals.push(vec_slice[d] - centroid_data[c_offset + d]);
}
}
let codebooks_trained = pq::train_codebooks(
&residuals,
dim,
num_subvectors,
max_iter,
self.config.metric,
)?;
self.clear_stale_training_data(old_k, actual_k)?;
{
let tn = centroids_name(&self.name);
let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
let mut table = self.txn.open_storage_table(def)?;
for c in 0..actual_k {
let bytes = f32_slice_to_le_bytes(¢roid_data[c * dim..(c + 1) * dim]);
#[allow(clippy::cast_possible_truncation)]
table.st_insert(&(c as u32), &bytes.as_slice())?;
}
}
{
let tn = codebooks_name(&self.name);
let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
let mut table = self.txn.open_storage_table(def)?;
for m in 0..num_subvectors {
let bytes = codebooks_trained.serialize_codebook(m);
#[allow(clippy::cast_possible_truncation)]
table.st_insert(&(m as u32), &bytes.as_slice())?;
}
}
self.config.state = STATE_TRAINED;
self.config.num_vectors = 0;
self.persist_config_inner()?;
self.config_dirty = false;
self.centroids = Some(centroid_data);
self.codebooks = Some(codebooks_trained);
Ok(())
}
pub fn train_with_progress<I, F>(
&mut self,
training_vectors: I,
max_iter: usize,
progress: F,
) -> crate::Result<()>
where
I: Iterator<Item = (u64, Vec<f32>)>,
F: Fn(&TrainProgress),
{
validate_config(&self.config)?;
let dim = self.config.dim as usize;
let num_clusters = self.requested_num_clusters as usize;
let num_subvectors = self.config.num_subvectors as usize;
let mut flat: Vec<f32> = Vec::new();
let mut count = 0usize;
for (_id, mut vec) in training_vectors {
if vec.len() != dim {
return Err(StorageError::dimension_mismatch(&self.name, dim, vec.len()));
}
if self.config.metric == DistanceMetric::Cosine {
l2_normalize(&mut vec);
}
flat.extend_from_slice(&vec);
count += 1;
}
let p = TrainProgress::CollectingVectors { count };
progress(&p);
self.observer.on_train_progress(&self.name, &p);
let n = flat.len() / dim;
if n == 0 {
return Err(StorageError::invalid_index_config(alloc::format!(
"IVF-PQ '{}': no training vectors provided",
self.name,
)));
}
let p = TrainProgress::TrainingCentroids {
num_clusters,
num_vectors: n,
};
progress(&p);
self.observer.on_train_progress(&self.name, &p);
let centroid_data = kmeans::kmeans(&flat, dim, num_clusters, max_iter, self.config.metric);
let actual_k = centroid_data.len() / dim;
let old_k = self.config.num_clusters as usize;
#[allow(clippy::cast_possible_truncation)]
{
self.config.num_clusters = actual_k as u32;
}
let p = TrainProgress::ComputingResiduals { num_vectors: n };
progress(&p);
self.observer.on_train_progress(&self.name, &p);
let mut residuals = Vec::with_capacity(flat.len());
for i in 0..n {
let vec_slice = &flat[i * dim..(i + 1) * dim];
let (cid, _) = kmeans::assign_nearest(
vec_slice,
¢roid_data,
dim,
actual_k,
self.config.metric,
);
let c_offset = cid as usize * dim;
for d in 0..dim {
residuals.push(vec_slice[d] - centroid_data[c_offset + d]);
}
}
let p = TrainProgress::TrainingCodebooks { num_subvectors };
progress(&p);
self.observer.on_train_progress(&self.name, &p);
let codebooks_trained = pq::train_codebooks(
&residuals,
dim,
num_subvectors,
max_iter,
self.config.metric,
)?;
self.clear_stale_training_data(old_k, actual_k)?;
let p = TrainProgress::Persisting;
progress(&p);
self.observer.on_train_progress(&self.name, &p);
{
let tn = centroids_name(&self.name);
let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
let mut table = self.txn.open_storage_table(def)?;
for c in 0..actual_k {
let bytes = f32_slice_to_le_bytes(¢roid_data[c * dim..(c + 1) * dim]);
#[allow(clippy::cast_possible_truncation)]
table.st_insert(&(c as u32), &bytes.as_slice())?;
}
}
{
let tn = codebooks_name(&self.name);
let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
let mut table = self.txn.open_storage_table(def)?;
for m in 0..num_subvectors {
let bytes = codebooks_trained.serialize_codebook(m);
#[allow(clippy::cast_possible_truncation)]
table.st_insert(&(m as u32), &bytes.as_slice())?;
}
}
self.config.state = STATE_TRAINED;
self.config.num_vectors = 0;
self.persist_config_inner()?;
self.config_dirty = false;
self.centroids = Some(centroid_data);
self.codebooks = Some(codebooks_trained);
let p = TrainProgress::Done;
progress(&p);
self.observer.on_train_progress(&self.name, &p);
Ok(())
}
/// Inserts or replaces a single vector.
///
/// The vector is assigned to its nearest centroid, its residual is PQ-encoded
/// and merged into that cluster's blob. If the id was previously assigned to a
/// different cluster it is first removed from the old blob. When
/// `store_raw_vectors` is enabled the (normalized, for cosine) vector is also
/// stored verbatim. `num_vectors` only grows for ids that were not yet assigned.
///
/// # Errors
/// - index-not-trained if the index has not been trained;
/// - dimension-mismatch if `vector.len() != config.dim`;
/// - invalid-index-config if the vector contains a non-finite component;
/// - any storage or blob-decoding error.
#[allow(clippy::cast_possible_truncation)]
pub fn insert(&mut self, vector_id: u64, vector: &[f32]) -> crate::Result<()> {
    self.ensure_trained()?;
    let dim = self.config.dim as usize;
    if vector.len() != dim {
        return Err(StorageError::dimension_mismatch(
            &self.name,
            dim,
            vector.len(),
        ));
    }
    Self::validate_finite(vector, &self.name)?;
    // Cosine distance works on unit vectors; normalize a private copy if needed.
    let vec_owned;
    let vec_ref = if self.config.metric == DistanceMetric::Cosine {
        vec_owned = crate::vector_ops::l2_normalized(vector);
        &vec_owned
    } else {
        vector
    };
    let centroids = self.load_centroids()?;
    let (cluster_id, _) = kmeans::assign_nearest(
        vec_ref,
        &centroids,
        dim,
        self.config.num_clusters as usize,
        self.config.metric,
    );
    let c_offset = cluster_id as usize * dim;
    let residual: Vec<f32> = vec_ref
        .iter()
        .enumerate()
        .map(|(d, &v)| v - centroids[c_offset + d])
        .collect();
    let codebooks = self.load_codebooks()?;
    let pq_codes = codebooks.encode(&residual);
    // Look up the previous assignment (if any) before touching the blobs.
    let old_cluster = {
        let tn = assignments_name(&self.name);
        let def = TableDefinition::<u64, u32>::new_internal(&tn);
        let table = self.txn.open_storage_table(def)?;
        table.st_get(&vector_id)?.map(|g| g.value())
    };
    let pq_len = self.config.num_subvectors as u16;
    if let Some(old_cid) = old_cluster
        && old_cid != cluster_id
    {
        // The vector moved clusters: drop its stale entry from the old blob.
        self.remove_from_cluster_blob(old_cid, vector_id, pq_len)?;
    }
    {
        let tn = clusters_name(&self.name);
        let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
        let mut table = self.txn.open_storage_table(def)?;
        let existing_blob = table.st_get(&cluster_id)?;
        let existing_ref = match existing_blob {
            Some(ref guard) => Some(ClusterBlobRef::new(guard.value(), pq_len, dim)?),
            None => None,
        };
        let mut new_entries: Vec<OwnedBlobEntry> = vec![(vector_id, pq_codes, None)];
        let merged = merge_into_blob(existing_ref.as_ref(), &mut new_entries, pq_len);
        // Release the read guard before writing the merged blob back.
        drop(existing_blob);
        table.st_insert(&cluster_id, &merged.as_slice())?;
    }
    if self.config.store_raw_vectors {
        let raw_bytes = f32_slice_to_le_bytes(vec_ref);
        let vn = vectors_name(&self.name);
        let vdef = TableDefinition::<u64, &[u8]>::new_internal(&vn);
        let mut vt = self.txn.open_storage_table(vdef)?;
        vt.st_insert(&vector_id, &raw_bytes.as_slice())?;
    }
    {
        let tn = assignments_name(&self.name);
        let def = TableDefinition::<u64, u32>::new_internal(&tn);
        let mut table = self.txn.open_storage_table(def)?;
        table.st_insert(&vector_id, &cluster_id)?;
    }
    if old_cluster.is_none() {
        // Only brand-new ids change the vector count; it is persisted lazily.
        self.config.num_vectors = self.config.num_vectors.saturating_add(1);
        self.config_dirty = true;
    }
    Ok(())
}
/// Inserts or replaces many vectors, grouping blob updates per cluster so each
/// touched cluster blob is rewritten once.
///
/// Returns the number of ids that were not previously assigned (i.e. newly
/// added vectors); `num_vectors` grows by the same amount.
///
/// The assignments table is updated while the input is being consumed, whereas
/// cluster blobs and raw vectors are written afterwards.
/// NOTE(review): if a later item fails validation, earlier assignment rows have
/// already been written — this presumably relies on the caller aborting the
/// transaction on error; confirm.
///
/// # Errors
/// - index-not-trained if the index has not been trained;
/// - dimension-mismatch / invalid-index-config for a malformed vector;
/// - any storage or blob-decoding error.
#[allow(clippy::cast_possible_truncation)]
pub fn insert_batch<I>(&mut self, vectors: I) -> crate::Result<u64>
where
    I: Iterator<Item = (u64, Vec<f32>)>,
{
    self.ensure_trained()?;
    let dim = self.config.dim as usize;
    let centroids = self.load_centroids()?;
    let num_clusters = self.config.num_clusters as usize;
    let metric = self.config.metric;
    let store_raw = self.config.store_raw_vectors;
    let pq_len = self.config.num_subvectors as u16;
    let codebooks = self.load_codebooks()?;
    // One bucket of pending blob entries per cluster id.
    let mut grouped: Vec<Vec<OwnedBlobEntry>> = Vec::new();
    grouped.resize_with(num_clusters, Vec::new);
    let mut raw_vectors: Vec<(u64, Vec<u8>)> = Vec::new();
    let an = assignments_name(&self.name);
    let ad = TableDefinition::<u64, u32>::new_internal(&an);
    let mut at = self.txn.open_storage_table(ad)?;
    // (vector id, previous cluster) for vectors that changed cluster.
    let mut old_assignments: Vec<(u64, u32)> = Vec::new();
    let mut new_count = 0u64;
    for (vector_id, mut vec) in vectors {
        if vec.len() != dim {
            return Err(StorageError::dimension_mismatch(&self.name, dim, vec.len()));
        }
        Self::validate_finite(&vec, &self.name)?;
        if metric == DistanceMetric::Cosine {
            l2_normalize(&mut vec);
        }
        let (cluster_id, _) =
            kmeans::assign_nearest(&vec, &centroids, dim, num_clusters, metric);
        let c_offset = cluster_id as usize * dim;
        let residual: Vec<f32> = vec
            .iter()
            .enumerate()
            .map(|(d, &v)| v - centroids[c_offset + d])
            .collect();
        let pq_codes = codebooks.encode(&residual);
        if store_raw {
            raw_vectors.push((vector_id, f32_slice_to_le_bytes(&vec)));
        }
        let old_cluster = at.st_get(&vector_id)?.map(|g| g.value());
        if let Some(old_cid) = old_cluster {
            if old_cid != cluster_id {
                old_assignments.push((vector_id, old_cid));
            }
        } else {
            new_count += 1;
        }
        at.st_insert(&vector_id, &cluster_id)?;
        grouped[cluster_id as usize].push((vector_id, pq_codes, None));
    }
    // Close the assignments table before the blob helpers reopen tables.
    drop(at);
    // Purge moved vectors from their previous cluster blobs first.
    for &(vid, old_cid) in &old_assignments {
        self.remove_from_cluster_blob(old_cid, vid, pq_len)?;
    }
    {
        let tn = clusters_name(&self.name);
        let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
        let mut table = self.txn.open_storage_table(def)?;
        for (cid, mut entries) in grouped.into_iter().enumerate() {
            if entries.is_empty() {
                continue;
            }
            let cid_u32 = cid as u32;
            let existing_blob = table.st_get(&cid_u32)?;
            let existing_ref = match existing_blob {
                Some(ref guard) => Some(ClusterBlobRef::new(guard.value(), pq_len, dim)?),
                None => None,
            };
            let merged = merge_into_blob(existing_ref.as_ref(), &mut entries, pq_len);
            // Release the read guard before writing the merged blob back.
            drop(existing_blob);
            table.st_insert(&cid_u32, &merged.as_slice())?;
        }
    }
    if !raw_vectors.is_empty() {
        let vn = vectors_name(&self.name);
        let vdef = TableDefinition::<u64, &[u8]>::new_internal(&vn);
        let mut vt = self.txn.open_storage_table(vdef)?;
        for (vid, raw) in &raw_vectors {
            vt.st_insert(vid, &raw.as_slice())?;
        }
    }
    if new_count > 0 {
        self.config.num_vectors = self.config.num_vectors.saturating_add(new_count);
        self.config_dirty = true;
    }
    Ok(new_count)
}
/// Removes a vector from the index.
///
/// Returns `Ok(false)` when the id has no assignment (nothing else is touched
/// in that case) and `Ok(true)` after the assignment, the cluster-blob entry,
/// the raw vector (when raw storage is enabled) and the metadata row were removed.
///
/// # Errors
/// Propagates any storage or blob-decoding error.
#[allow(clippy::cast_possible_truncation)]
pub fn remove(&mut self, vector_id: u64) -> crate::Result<bool> {
    // Step 1: drop the assignment row and learn which cluster held the vector.
    let owning_cluster = {
        let table_name = assignments_name(&self.name);
        let definition = TableDefinition::<u64, u32>::new_internal(&table_name);
        let mut assignments = self.txn.open_storage_table(definition)?;
        let Some(previous) = assignments.st_remove(&vector_id)? else {
            return Ok(false);
        };
        previous.value()
    };

    // Step 2: strip the PQ entry from that cluster's blob.
    let code_len = self.config.num_subvectors as u16;
    self.remove_from_cluster_blob(owning_cluster, vector_id, code_len)?;

    // Step 3: drop the raw vector, if raw vectors are kept at all.
    if self.config.store_raw_vectors {
        let table_name = vectors_name(&self.name);
        let definition = TableDefinition::<u64, &[u8]>::new_internal(&table_name);
        let mut raw_vectors = self.txn.open_storage_table(definition)?;
        raw_vectors.st_remove(&vector_id)?;
    }

    // Step 4: account for the removal; the config is persisted lazily.
    self.config.num_vectors = self.config.num_vectors.saturating_sub(1);
    self.config_dirty = true;

    // Step 5: drop any metadata attached to the vector.
    let table_name = vector_meta_name(&self.name);
    let definition = TableDefinition::<u64, &[u8]>::new_internal(&table_name);
    let mut metadata = self.txn.open_storage_table(definition)?;
    metadata.st_remove(&vector_id)?;
    Ok(true)
}
/// Stores (or overwrites) the encoded metadata for `vector_id`.
///
/// No check is made that the vector itself exists in the index.
///
/// # Errors
/// Propagates any storage error.
pub fn insert_metadata(&mut self, vector_id: u64, metadata: &MetadataMap) -> crate::Result<()> {
    let table_name = vector_meta_name(&self.name);
    let payload = metadata.encode();
    let definition = TableDefinition::<u64, &[u8]>::new_internal(&table_name);
    let mut rows = self.txn.open_storage_table(definition)?;
    rows.st_insert(&vector_id, &payload.as_slice())?;
    Ok(())
}
/// Deletes the metadata row for `vector_id`, if there is one.
///
/// # Errors
/// Propagates any storage error.
pub fn remove_metadata(&mut self, vector_id: u64) -> crate::Result<()> {
    let table_name = vector_meta_name(&self.name);
    let definition = TableDefinition::<u64, &[u8]>::new_internal(&table_name);
    let mut rows = self.txn.open_storage_table(definition)?;
    rows.st_remove(&vector_id)?;
    Ok(())
}
/// Approximate nearest-neighbour search inside the write transaction.
///
/// Pending config changes are flushed first. The `nprobe` nearest clusters are
/// scanned (clamped to `1..=num_clusters`); every entry is scored with an
/// asymmetric-distance table built from the query residual for that cluster.
/// When `params.rerank` is set and raw vectors are stored, up to
/// `max(candidates, k)` hits (capped at `MAX_RERANK_CANDIDATES`) are re-scored
/// against the raw vectors; otherwise the best `k` approximate hits are returned.
///
/// With a filter, an entry is skipped only when it has metadata that fails the
/// filter; entries without a metadata row are kept.
///
/// A zero-norm query under the cosine metric yields an empty result.
///
/// # Errors
/// - index-not-trained if the index has not been trained;
/// - dimension-mismatch if `query.len() != config.dim`;
/// - any storage or blob-decoding error.
#[allow(clippy::cast_possible_truncation)]
pub fn search(
    &mut self,
    query: &[f32],
    params: &SearchParams,
) -> crate::Result<Vec<Neighbor<u64>>> {
    self.ensure_trained()?;
    self.flush()?;
    let dim = self.config.dim as usize;
    if query.len() != dim {
        return Err(StorageError::dimension_mismatch(
            &self.name,
            dim,
            query.len(),
        ));
    }
    let centroids = self.load_centroids()?;
    let codebooks = self.load_codebooks()?;
    let query_owned;
    let q = if self.config.metric == DistanceMetric::Cosine {
        // A zero vector has no direction, so nothing can be ranked against it.
        if crate::vector_ops::l2_norm(query) == 0.0 {
            return Ok(Vec::new());
        }
        query_owned = crate::vector_ops::l2_normalized(query);
        &query_owned
    } else {
        query
    };
    let nprobe = (params.nprobe).max(1).min(self.config.num_clusters) as usize;
    let probes = kmeans::nearest_clusters(
        q,
        &centroids,
        dim,
        self.config.num_clusters as usize,
        nprobe,
        self.config.metric,
        params.diversity,
    );
    // Reranking is only possible when the raw vectors were kept.
    let want_rerank = params.rerank && self.config.store_raw_vectors;
    let cap = if want_rerank {
        params.candidates.max(params.k).min(MAX_RERANK_CANDIDATES)
    } else {
        params.k
    };
    let mut heap = CandidateHeap::new(cap);
    let pq_len = self.config.num_subvectors as u16;
    let metric = self.config.metric;
    {
        let tn = clusters_name(&self.name);
        let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
        let table = self.txn.open_storage_table(def)?;
        // The metadata table is only needed when a filter is present.
        let meta_table = if params.filter.is_some() {
            let mn = vector_meta_name(&self.name);
            let mdef = TableDefinition::<u64, &[u8]>::new_internal(&mn);
            Some(self.txn.open_storage_table(mdef)?)
        } else {
            None
        };
        let mut query_residual = vec![0.0f32; dim];
        for &(cid, _) in &probes {
            let c_offset = cid as usize * dim;
            for d in 0..dim {
                query_residual[d] = q[d] - centroids[c_offset + d];
            }
            // Clusters without a blob hold no vectors.
            let Some(blob_data) = table.st_get(&cid)? else {
                continue;
            };
            let blob = ClusterBlobRef::new(blob_data.value(), pq_len, dim)?;
            let adc = IntAdcTable::build(&query_residual, &codebooks, metric);
            let pq_block = blob.pq_codes_block();
            let m = pq_len as usize;
            for i in 0..blob.count() {
                let codes = &pq_block[i as usize * m..(i as usize + 1) * m];
                let dist = adc.to_f32(adc.approximate_distance(codes));
                let vid = blob.vector_id(i);
                if let Some(ref filter) = params.filter
                    && let Some(ref mt) = meta_table
                    && let Some(guard) = mt.st_get(&vid)?
                    && !passes_filter(guard.value(), filter)
                {
                    continue;
                }
                heap.push(vid, dist);
            }
        }
    }
    #[cfg(feature = "metrics")]
    self.db_metrics
        .vector_searches
        .fetch_add(1, portable_atomic::Ordering::Relaxed);
    if want_rerank {
        let sorted = heap.into_sorted();
        rerank_from_vectors_table_write(self.txn, q, &sorted, &self.name, dim, metric, params.k)
    } else {
        Ok(heap.into_sorted().into_iter().take(params.k).collect())
    }
}
/// Rejects vectors containing NaN or infinite components.
///
/// # Errors
/// Returns an invalid-index-config error naming the first offending value and
/// its position.
fn validate_finite(vector: &[f32], name: &str) -> crate::Result<()> {
    let first_bad = vector
        .iter()
        .copied()
        .enumerate()
        .find(|(_, component)| !component.is_finite());
    match first_bad {
        None => Ok(()),
        Some((i, v)) => Err(StorageError::invalid_index_config(alloc::format!(
            "IVF-PQ '{name}': vector contains non-finite value ({v}) at index {i}",
        ))),
    }
}
/// Discards data that a new training run invalidates.
///
/// Centroid rows with ids in `new_k..old_k` are removed (only relevant when the
/// cluster count shrank), then the cluster-blob, raw-vector and assignment
/// tables are emptied, in that order.
///
/// NOTE(review): the `vector_meta` table is not touched here — confirm that
/// metadata is meant to survive a re-train.
///
/// # Errors
/// Propagates any storage error.
#[allow(clippy::cast_possible_truncation)]
fn clear_stale_training_data(&self, old_k: usize, new_k: usize) -> crate::Result<()> {
    let shrank = new_k < old_k;
    if shrank {
        let centroid_table = centroids_name(&self.name);
        let definition = TableDefinition::<u32, &[u8]>::new_internal(&centroid_table);
        let mut centroid_rows = self.txn.open_storage_table(definition)?;
        for stale_id in new_k..old_k {
            centroid_rows.st_remove(&(stale_id as u32))?;
        }
    }

    {
        let cluster_table = clusters_name(&self.name);
        let definition = TableDefinition::<u32, &[u8]>::new_internal(&cluster_table);
        let mut blobs = self.txn.open_storage_table(definition)?;
        blobs.st_drain_all()?;
    }

    {
        let raw_table = vectors_name(&self.name);
        let definition = TableDefinition::<u64, &[u8]>::new_internal(&raw_table);
        let mut raw_vectors = self.txn.open_storage_table(definition)?;
        raw_vectors.st_drain_all()?;
    }

    {
        let assignment_table = assignments_name(&self.name);
        let definition = TableDefinition::<u64, u32>::new_internal(&assignment_table);
        let mut assignments = self.txn.open_storage_table(definition)?;
        assignments.st_drain_all()?;
    }

    Ok(())
}
/// Succeeds only when the index is in the trained state.
///
/// # Errors
/// Returns an index-not-trained error carrying the index name otherwise.
fn ensure_trained(&self) -> crate::Result<()> {
    if self.config.state == STATE_TRAINED {
        Ok(())
    } else {
        Err(StorageError::index_not_trained(&self.name))
    }
}
/// Encodes the in-memory config and writes it under the `"config"` key of the
/// meta table. Does not touch the dirty flag; callers manage it.
///
/// # Errors
/// Propagates any storage error.
fn persist_config_inner(&self) -> crate::Result<()> {
    let table_name = meta_name(&self.name);
    let definition = TableDefinition::<&str, &[u8]>::new_internal(&table_name);
    let mut meta = self.txn.open_storage_table(definition)?;
    let encoded = encode_index_config(&self.config);
    let payload: &[u8] = encoded.as_slice();
    meta.st_insert(&"config", &payload)?;
    Ok(())
}
fn load_centroids(&mut self) -> crate::Result<Vec<f32>> {
if let Some(ref c) = self.centroids {
return Ok(c.clone());
}
let data = self.read_centroids()?;
self.centroids = Some(data.clone());
Ok(data)
}
/// Reads all `num_clusters` centroid rows from storage into one flat buffer
/// (cluster-major, `dim` little-endian f32 values per row).
///
/// # Errors
/// - `Corrupted` if a centroid row is missing or its byte length is not `dim * 4`;
/// - any storage error.
#[allow(clippy::cast_possible_truncation)]
fn read_centroids(&self) -> crate::Result<Vec<f32>> {
    let dim = self.config.dim as usize;
    let cluster_count = self.config.num_clusters as usize;
    let expected_bytes = dim * 4;

    let table_name = centroids_name(&self.name);
    let definition = TableDefinition::<u32, &[u8]>::new_internal(&table_name);
    let rows = self.txn.open_storage_table(definition)?;

    let mut flat = Vec::with_capacity(cluster_count * dim);
    for c in 0..cluster_count {
        let Some(row) = rows.st_get(&(c as u32))? else {
            return Err(StorageError::Corrupted(alloc::format!(
                "IVF-PQ '{}': missing centroid {c}",
                self.name,
            )));
        };
        let raw = row.value();
        if raw.len() != expected_bytes {
            return Err(StorageError::Corrupted(alloc::format!(
                "IVF-PQ '{}': centroid {c} byte length {} != expected {expected_bytes} (dim={dim})",
                self.name,
                raw.len(),
            )));
        }
        // `chunks_exact(4)` only yields 4-byte slices, so indexing cannot fail.
        flat.extend(
            raw.chunks_exact(4)
                .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]])),
        );
    }
    Ok(flat)
}
fn load_codebooks(&mut self) -> crate::Result<Codebooks> {
if let Some(ref cb) = self.codebooks {
return Ok(cb.clone());
}
let cb = self.read_codebooks()?;
self.codebooks = Some(cb.clone());
Ok(cb)
}
/// Reads one serialized codebook per sub-vector from storage and concatenates
/// them into a single `Codebooks` value.
///
/// # Errors
/// - `Corrupted` if a codebook row is missing;
/// - any storage error.
#[allow(clippy::cast_possible_truncation)]
fn read_codebooks(&self) -> crate::Result<Codebooks> {
    let num_subvectors = self.config.num_subvectors as usize;
    let sub_dim = self.config.sub_dim();

    let table_name = codebooks_name(&self.name);
    let definition = TableDefinition::<u32, &[u8]>::new_internal(&table_name);
    let rows = self.txn.open_storage_table(definition)?;

    // Capacity assumes 256 code words per sub-vector codebook.
    let mut data = Vec::with_capacity(num_subvectors * 256 * sub_dim);
    for i in 0..num_subvectors {
        let Some(row) = rows.st_get(&(i as u32))? else {
            return Err(StorageError::Corrupted(alloc::format!(
                "IVF-PQ '{}': missing codebook {i}",
                self.name,
            )));
        };
        let decoded = Codebooks::deserialize_codebook(row.value(), sub_dim);
        data.extend_from_slice(&decoded);
    }

    Ok(Codebooks {
        data,
        num_subvectors,
        sub_dim,
    })
}
/// Removes `vector_id` from the blob of `cluster_id`.
///
/// If the cluster has no blob, nothing happens. Otherwise the blob is rewritten
/// without the entry, or the whole row is deleted when `remove_from_blob`
/// reports that nothing is left to store.
///
/// # Errors
/// Propagates any storage or blob-decoding error.
#[allow(clippy::cast_possible_truncation)]
fn remove_from_cluster_blob(
    &self,
    cluster_id: u32,
    vector_id: u64,
    pq_len: u16,
) -> crate::Result<()> {
    let dim = self.config.dim as usize;
    let table_name = clusters_name(&self.name);
    let definition = TableDefinition::<u32, &[u8]>::new_internal(&table_name);
    let mut blobs = self.txn.open_storage_table(definition)?;

    // Compute the replacement while the read guard is alive; the guard is
    // released before the table is written to.
    let replacement = match blobs.st_get(&cluster_id)? {
        None => return Ok(()),
        Some(current) => {
            let view = ClusterBlobRef::new(current.value(), pq_len, dim)?;
            remove_from_blob(&view, vector_id, pq_len)
        }
    };

    if let Some(shrunk) = replacement {
        blobs.st_insert(&cluster_id, &shrunk.as_slice())?;
    } else {
        blobs.st_remove(&cluster_id)?;
    }
    Ok(())
}
}
/// Best-effort persistence of unsaved config changes when the handle goes away.
impl<T: StorageWrite> Drop for IvfPqIndex<'_, T> {
    fn drop(&mut self) {
        if !self.config_dirty {
            return;
        }
        // Errors cannot be reported from `drop`; callers who care should call
        // `flush` explicitly beforehand.
        let _ = self.persist_config_inner();
    }
}
// Generates a rerank function named `$fn_name`, generic over a transaction type
// bounded by `$trait_bound`. The generated function re-scores `candidates`
// against the raw vectors table of `index_name` using the exact `metric`, and
// returns at most `k` neighbours sorted by ascending distance.
//
// Candidates whose raw vector is missing, or whose stored byte length is not
// `dim * 4`, are silently dropped from the result.
macro_rules! impl_rerank_from_vectors {
($fn_name:ident, $trait_bound:path) => {
#[allow(clippy::too_many_arguments)]
fn $fn_name<S: $trait_bound>(
txn: &S,
query: &[f32],
candidates: &[Neighbor<u64>],
index_name: &str,
dim: usize,
metric: DistanceMetric,
k: usize,
) -> crate::Result<Vec<Neighbor<u64>>> {
let vn = vectors_name(index_name);
let vdef = TableDefinition::<u64, &[u8]>::new_internal(&vn);
let vt = txn.open_storage_table(vdef)?;
// Visit candidates in ascending key order.
// NOTE(review): presumably to make table lookups more sequential — confirm.
let mut sorted_cands: Vec<&Neighbor<u64>> = candidates.iter().collect();
sorted_cands.sort_unstable_by_key(|c| c.key);
let expected_bytes = dim * 4;
// Scratch buffer reused for every decoded raw vector.
let mut raw_buf = vec![0.0f32; dim];
let mut results: Vec<Neighbor<u64>> = Vec::with_capacity(sorted_cands.len());
for cand in &sorted_cands {
if let Some(guard) = vt.st_get(&cand.key)? {
let raw = guard.value();
// The length check guarantees `bytes_to_f32_buf` fills exactly `raw_buf`.
if raw.len() == expected_bytes {
bytes_to_f32_buf(raw, &mut raw_buf);
results.push(Neighbor {
key: cand.key,
distance: metric.compute(query, &raw_buf),
});
}
}
}
results.sort_unstable_by(|a, b| a.distance.total_cmp(&b.distance));
results.truncate(k);
Ok(results)
}
};
}
// Read-transaction variant, used by `ReadOnlyIvfPqIndex::search`.
impl_rerank_from_vectors!(rerank_from_vectors_table, StorageRead);
// Write-transaction variant, used by `IvfPqIndex::search`.
impl_rerank_from_vectors!(rerank_from_vectors_table_write, StorageWrite);
/// Read-only snapshot of an IVF-PQ index.
///
/// Config, centroids and codebooks are loaded eagerly by `open`; the
/// transaction is not stored and must be passed to `search` on each call.
pub struct ReadOnlyIvfPqIndex {
/// Config decoded from the meta table at open time.
config: IndexConfig,
/// Index name, used to derive internal table names and in error messages.
name: String,
/// Flat centroid buffer (cluster-major, `dim` floats per cluster).
centroids: Vec<f32>,
/// PQ codebooks, one per sub-vector.
codebooks: Codebooks,
/// Shared database counters; `search` bumps `vector_searches`.
#[cfg(feature = "metrics")]
db_metrics: Arc<DbMetrics>,
}
impl ReadOnlyIvfPqIndex {
/// Loads the config, centroids and codebooks of an existing index through a
/// read transaction.
///
/// # Errors
/// - index-not-trained when the meta table has no `"config"` row;
/// - `Corrupted` when a centroid or codebook row is missing, or a centroid row
///   does not hold exactly `dim * 4` bytes (the same check the read-write
///   handle performs, so a damaged row is reported here instead of causing an
///   out-of-bounds access during search);
/// - any storage error.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn open<R: StorageRead>(
    txn: &R,
    definition: &IvfPqIndexDefinition,
    #[cfg(feature = "metrics")] db_metrics: Arc<DbMetrics>,
) -> crate::Result<Self> {
    let name = String::from(definition.name());
    let mn = meta_name(&name);
    let md = TableDefinition::<&str, &[u8]>::new_internal(&mn);
    let mt = txn.open_storage_table(md)?;
    let config = match mt.st_get(&"config")? {
        Some(guard) => decode_index_config(guard.value()),
        None => {
            return Err(StorageError::index_not_trained(alloc::format!(
                "{name} (missing config)",
            )));
        }
    };
    let dim = config.dim as usize;
    let num_clusters = config.num_clusters as usize;
    let centroids = {
        let tn = centroids_name(&name);
        let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
        let table = txn.open_storage_table(def)?;
        let expected_bytes = dim * 4;
        let mut flat = Vec::with_capacity(num_clusters * dim);
        for c in 0..num_clusters {
            let guard = table.st_get(&(c as u32))?.ok_or_else(|| {
                StorageError::Corrupted(
                    alloc::format!("IVF-PQ '{name}': missing centroid {c}",),
                )
            })?;
            let raw = guard.value();
            // A short or long row would misalign every later centroid.
            if raw.len() != expected_bytes {
                return Err(StorageError::Corrupted(alloc::format!(
                    "IVF-PQ '{name}': centroid {c} byte length {} != expected {expected_bytes} (dim={dim})",
                    raw.len(),
                )));
            }
            flat.extend(
                raw.chunks_exact(4)
                    .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]])),
            );
        }
        flat
    };
    let codebooks = {
        let num_subvectors = config.num_subvectors as usize;
        let sub_dim = config.sub_dim();
        let tn = codebooks_name(&name);
        let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
        let table = txn.open_storage_table(def)?;
        let mut data = Vec::with_capacity(num_subvectors * 256 * sub_dim);
        for m in 0..num_subvectors {
            let guard = table.st_get(&(m as u32))?.ok_or_else(|| {
                StorageError::Corrupted(
                    alloc::format!("IVF-PQ '{name}': missing codebook {m}",),
                )
            })?;
            data.extend_from_slice(&Codebooks::deserialize_codebook(guard.value(), sub_dim));
        }
        Codebooks {
            data,
            num_subvectors,
            sub_dim,
        }
    };
    Ok(Self {
        config,
        name,
        centroids,
        codebooks,
        #[cfg(feature = "metrics")]
        db_metrics,
    })
}
pub fn config(&self) -> &IndexConfig {
&self.config
}
/// Approximate nearest-neighbour search through a read transaction.
///
/// Uses the centroids and codebooks cached at open time. The `nprobe` nearest
/// clusters are scanned (clamped to `1..=num_clusters`) and each entry is scored
/// with an asymmetric-distance table built from the query residual. When
/// `params.rerank` is set and raw vectors are stored, the best candidates are
/// re-scored against the raw vectors; otherwise the best `k` approximate hits
/// are returned.
///
/// # Errors
/// - index-not-trained if the stored state is not trained;
/// - dimension-mismatch if `query.len() != config.dim`;
/// - any storage or blob-decoding error.
#[allow(clippy::cast_possible_truncation)]
pub fn search<R: StorageRead>(
&self,
txn: &R,
query: &[f32],
params: &SearchParams,
) -> crate::Result<Vec<Neighbor<u64>>> {
if self.config.state != STATE_TRAINED {
return Err(StorageError::index_not_trained(&self.name));
}
let dim = self.config.dim as usize;
if query.len() != dim {
return Err(StorageError::dimension_mismatch(
&self.name,
dim,
query.len(),
));
}
let query_owned;
let q = if self.config.metric == DistanceMetric::Cosine {
// A zero-norm query cannot be normalized; return no results.
if crate::vector_ops::l2_norm(query) == 0.0 {
return Ok(Vec::new());
}
query_owned = crate::vector_ops::l2_normalized(query);
&query_owned
} else {
query
};
// Probe at least one cluster and never more than exist.
let nprobe = (params.nprobe).max(1).min(self.config.num_clusters) as usize;
let probes = kmeans::nearest_clusters(
q,
&self.centroids,
dim,
self.config.num_clusters as usize,
nprobe,
self.config.metric,
params.diversity,
);
// With reranking, keep a wider candidate pool (bounded by MAX_RERANK_CANDIDATES).
let cap = if params.rerank && self.config.store_raw_vectors {
params.candidates.max(params.k).min(MAX_RERANK_CANDIDATES)
} else {
params.k
};
let mut heap = CandidateHeap::new(cap);
let pq_len = self.config.num_subvectors as u16;
let metric = self.config.metric;
let want_rerank = params.rerank && self.config.store_raw_vectors;
{
let tn = clusters_name(&self.name);
let def = TableDefinition::<u32, &[u8]>::new_internal(&tn);
let table = txn.open_storage_table(def)?;
// The metadata table is only opened when a filter is present.
let meta_table = if params.filter.is_some() {
let mn = vector_meta_name(&self.name);
let mdef = TableDefinition::<u64, &[u8]>::new_internal(&mn);
Some(txn.open_storage_table(mdef)?)
} else {
None
};
let mut query_residual = vec![0.0f32; dim];
for &(cid, _) in &probes {
// Residual of the query relative to this cluster's centroid.
let c_offset = cid as usize * dim;
for d in 0..dim {
query_residual[d] = q[d] - self.centroids[c_offset + d];
}
// Clusters without a blob are skipped.
let Some(blob_data) = table.st_get(&cid)? else {
continue;
};
let blob = ClusterBlobRef::new(blob_data.value(), pq_len, dim)?;
let adc = IntAdcTable::build(&query_residual, &self.codebooks, metric);
let pq_block = blob.pq_codes_block();
let m = pq_len as usize;
for i in 0..blob.count() {
// Entry `i` occupies `m` consecutive code bytes in the block.
let codes = &pq_block[i as usize * m..(i as usize + 1) * m];
let dist = adc.to_f32(adc.approximate_distance(codes));
let vid = blob.vector_id(i);
// Skip only entries that HAVE metadata and fail the filter;
// entries without a metadata row are kept.
if let Some(ref filter) = params.filter
&& let Some(ref mt) = meta_table
&& let Some(guard) = mt.st_get(&vid)?
&& !passes_filter(guard.value(), filter)
{
continue;
}
heap.push(vid, dist);
}
}
}
#[cfg(feature = "metrics")]
self.db_metrics
.vector_searches
.fetch_add(1, portable_atomic::Ordering::Relaxed);
if want_rerank {
let sorted = heap.into_sorted();
rerank_from_vectors_table(txn, q, &sorted, &self.name, dim, metric, params.k)
} else {
Ok(heap.into_sorted().into_iter().take(params.k).collect())
}
}
}
/// Bounded collection that retains at most `capacity` search candidates with
/// the smallest distances seen so far.
struct CandidateHeap {
/// Maximum number of entries retained.
capacity: usize,
/// Max-heap ordered by distance, so `peek` yields the worst retained candidate.
heap: BinaryHeap<CandidateEntry>,
}
/// One search candidate: a vector id together with its (approximate) distance.
///
/// Ordering and equality are both defined by `f32::total_cmp` on `distance`
/// alone, so they agree with each other and equality is reflexive even for NaN
/// distances, as `Eq` and `Ord` require. (A derived `PartialEq` would use
/// float `==`, under which a NaN entry is not equal to itself.)
struct CandidateEntry {
    vector_id: u64,
    distance: f32,
}

impl PartialEq for CandidateEntry {
    fn eq(&self, other: &Self) -> bool {
        // Keep `==` consistent with `Ord::cmp`.
        self.cmp(other) == CmpOrdering::Equal
    }
}

impl Eq for CandidateEntry {}

impl PartialOrd for CandidateEntry {
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}

impl Ord for CandidateEntry {
    /// Total order by distance; larger distance compares greater.
    fn cmp(&self, other: &Self) -> CmpOrdering {
        self.distance.total_cmp(&other.distance)
    }
}
impl CandidateHeap {
    /// Creates an empty heap that will retain at most `capacity` candidates.
    ///
    /// Only a bounded amount of storage is reserved up front: `capacity` comes
    /// from caller-supplied search parameters, and reserving `capacity + 1`
    /// slots would overflow for `usize::MAX` and attempt a huge allocation for
    /// very large values. The heap still grows on demand up to `capacity`.
    fn new(capacity: usize) -> Self {
        const MAX_PREALLOCATED_ENTRIES: usize = 4096;
        Self {
            capacity,
            heap: BinaryHeap::with_capacity(capacity.min(MAX_PREALLOCATED_ENTRIES)),
        }
    }

    /// Offers a candidate. It is kept if there is room, or if it is strictly
    /// closer than the current worst retained candidate (which it replaces).
    ///
    /// NaN distances are ignored: they cannot be ranked, and a NaN admitted
    /// while the heap is filling could sit at the top forever because no finite
    /// distance compares `<` to it.
    fn push(&mut self, vector_id: u64, distance: f32) {
        if distance.is_nan() {
            return;
        }
        if self.heap.len() < self.capacity {
            self.heap.push(CandidateEntry {
                vector_id,
                distance,
            });
        } else if let Some(worst) = self.heap.peek()
            && distance < worst.distance
        {
            self.heap.pop();
            self.heap.push(CandidateEntry {
                vector_id,
                distance,
            });
        }
    }

    /// Consumes the heap and returns its candidates sorted by ascending distance.
    fn into_sorted(self) -> Vec<Neighbor<u64>> {
        let mut items: Vec<Neighbor<u64>> = self
            .heap
            .into_iter()
            .map(|e| Neighbor {
                key: e.vector_id,
                distance: e.distance,
            })
            .collect();
        items.sort_unstable_by(|a, b| a.distance.total_cmp(&b.distance));
        items
    }
}
/// Serializes `floats` as consecutive little-endian 4-byte values.
///
/// The output has exactly `floats.len() * 4` bytes on every target endianness.
#[inline]
fn f32_slice_to_le_bytes(floats: &[f32]) -> Vec<u8> {
    let mut encoded = vec![0u8; floats.len() * 4];
    for (slot, value) in encoded.chunks_exact_mut(4).zip(floats) {
        slot.copy_from_slice(&value.to_le_bytes());
    }
    encoded
}
/// Decodes consecutive little-endian 4-byte groups from `bytes` into `buf`,
/// starting at `buf[0]`. Trailing bytes that do not form a full group are
/// ignored.
///
/// # Panics
/// Panics if `bytes` holds more complete 4-byte groups than `buf` has slots.
#[inline]
fn bytes_to_f32_buf(bytes: &[u8], buf: &mut [f32]) {
    let mut slot = 0usize;
    for quad in bytes.chunks_exact(4) {
        buf[slot] = f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]);
        slot += 1;
    }
}