use crate::distance::FloatOrd;
use crate::RetrieveError;
use qntz::rabitq::{QuantizedVector, RaBitQConfig, RaBitQQuantizer};
/// Tunable parameters for an [`IVFRaBitQIndex`].
#[derive(Clone, Debug)]
pub struct IVFRaBitQParams {
/// Requested number of k-means clusters; `build` uses at most one cluster per stored vector.
pub num_clusters: usize,
/// Number of clusters probed by `search` when the caller does not pass an explicit value.
pub nprobe: usize,
/// Forwarded to `RaBitQConfig::total_bits` when the quantizer is created.
pub total_bits: usize,
/// Seed handed to `RaBitQQuantizer::with_config`.
pub seed: u64,
}
impl Default for IVFRaBitQParams {
fn default() -> Self {
Self {
num_clusters: 256,
nprobe: 10,
total_bits: 4,
seed: 42,
}
}
}
/// The vectors assigned to one centroid, together with their quantized codes.
#[derive(Debug)]
struct Cluster {
/// Positions of the member vectors in the index's insertion order.
vector_indices: Vec<u32>,
/// Quantized form of each member; `quantized[i]` corresponds to `vector_indices[i]`.
quantized: Vec<QuantizedVector>,
}
/// Inverted-file index whose per-cluster entries are stored as RaBitQ-quantized vectors.
///
/// Vectors are collected with `add*`, clustered and quantized by `build`, and queried
/// with `search`.
pub struct IVFRaBitQIndex {
/// Length every added vector and every query must have.
dimension: usize,
/// Configuration supplied at construction (`nprobe` can be changed later).
params: IVFRaBitQParams,
/// Set once `build` has completed; adding vectors is rejected afterwards.
built: bool,
/// Set by `compact`; `search` then skips re-scoring against `vectors`.
compacted: bool,
/// Flat storage of the added vectors, `dimension` floats each, unit-normalized on
/// insertion when their norm exceeds `1e-10`. Emptied by `compact`.
vectors: Vec<f32>,
/// Number of vectors added so far.
num_vectors: usize,
/// Caller-supplied id of each stored vector, in insertion order.
doc_ids: Vec<u32>,
/// One entry per centroid, filled by `build`.
clusters: Vec<Cluster>,
/// Flat storage of the k-means centroids, `dimension` floats each.
centroids: Vec<f32>,
/// Quantizer used to encode vectors and to rotate queries.
quantizer: RaBitQQuantizer,
/// HNSW graph over the centroids, created by `build` when the `hnsw` feature is enabled.
#[cfg(feature = "hnsw")]
coarse_quantizer: Option<crate::hnsw::HNSWIndex>,
}
impl IVFRaBitQIndex {
/// Creates an empty, not-yet-built index for vectors of length `dimension`.
///
/// # Errors
///
/// Returns [`RetrieveError::InvalidParameter`] when `dimension` is zero or when the
/// RaBitQ quantizer rejects the configuration derived from `params`.
pub fn new(dimension: usize, params: IVFRaBitQParams) -> Result<Self, RetrieveError> {
    if dimension == 0 {
        let reason = "dimension must be > 0".into();
        return Err(RetrieveError::InvalidParameter(reason));
    }

    let quantizer = RaBitQQuantizer::with_config(
        dimension,
        params.seed,
        RaBitQConfig {
            total_bits: params.total_bits,
            t_const: None,
        },
    )
    .map_err(|err| RetrieveError::InvalidParameter(format!("RaBitQ config: {err}")))?;

    let index = Self {
        dimension,
        params,
        built: false,
        compacted: false,
        vectors: Vec::new(),
        num_vectors: 0,
        doc_ids: Vec::new(),
        clusters: Vec::new(),
        centroids: Vec::new(),
        quantizer,
        #[cfg(feature = "hnsw")]
        coarse_quantizer: None,
    };
    Ok(index)
}
/// Replaces the number of clusters that `search` probes by default.
pub fn set_nprobe(&mut self, nprobe: usize) {
    let params = &mut self.params;
    params.nprobe = nprobe;
}
/// Adds one owned vector under `doc_id`.
///
/// Delegates to `add_slice`, so the same validation and normalization apply.
///
/// # Errors
///
/// Returns whatever `add_slice` returns for the same input.
pub fn add(&mut self, doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
    self.add_slice(doc_id, vector.as_slice())
}
/// Adds one vector under `doc_id`, storing it scaled to unit L2 norm.
///
/// A vector whose norm is not greater than `1e-10` is stored unchanged.
///
/// # Errors
///
/// * [`RetrieveError::InvalidParameter`] if the index has already been built.
/// * [`RetrieveError::DimensionMismatch`] if `vector.len()` differs from the index dimension.
pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
    if self.built {
        let reason = "cannot add vectors after index is built".into();
        return Err(RetrieveError::InvalidParameter(reason));
    }

    let got = vector.len();
    if got != self.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: got,
            doc_dim: self.dimension,
        });
    }

    let squared_sum: f32 = vector.iter().map(|v| v * v).sum();
    let norm = squared_sum.sqrt();
    // Written as `norm > threshold` on purpose: a NaN norm fails the comparison and the
    // vector is stored as-is rather than being divided by NaN.
    let should_normalize = norm > 1e-10;
    if should_normalize {
        self.vectors.extend(vector.iter().map(|v| v / norm));
    } else {
        self.vectors.extend_from_slice(vector);
    }

    self.doc_ids.push(doc_id);
    self.num_vectors += 1;
    Ok(())
}
/// Adds several vectors at once.
///
/// `vectors` is a flat buffer holding one vector of the index dimension per entry of
/// `doc_ids`, in the same order.
///
/// # Errors
///
/// * [`RetrieveError::InvalidParameter`] if `vectors.len()` is not
///   `doc_ids.len() * dimension`.
/// * Any error returned by `add_slice` for an individual vector.
pub fn add_batch(&mut self, doc_ids: &[u32], vectors: &[f32]) -> Result<(), RetrieveError> {
    let dim = self.dimension;
    let expected = doc_ids.len() * dim;
    if vectors.len() != expected {
        return Err(RetrieveError::InvalidParameter(format!(
            "expected {} floats for {} vectors of dim {}, got {}",
            expected,
            doc_ids.len(),
            dim,
            vectors.len()
        )));
    }

    // The length check above guarantees exactly one `dim`-sized row per id.
    for (&doc_id, row) in doc_ids.iter().zip(vectors.chunks_exact(dim)) {
        self.add_slice(doc_id, row)?;
    }
    Ok(())
}
/// Clusters the stored vectors and quantizes each one against its cluster centroid.
///
/// Steps performed:
/// 1. k-means over the stored vectors with `min(params.num_clusters, num_vectors)` clusters;
///    the resulting centroids are flattened into `self.centroids`.
/// 2. With the `hnsw` feature enabled, an HNSW index over the centroids is built and stored
///    as the coarse quantizer.
/// 3. Every vector is assigned to a cluster and encoded with
///    `RaBitQQuantizer::quantize_with_centroid`.
///
/// Calling `build` on an index that is already built does nothing and returns `Ok(())`.
///
/// # Errors
///
/// * [`RetrieveError::EmptyIndex`] if no vectors have been added.
/// * Errors propagated from k-means and, with the `hnsw` feature, from the HNSW index.
/// * [`RetrieveError::InvalidParameter`] wrapping a RaBitQ quantization failure.
pub fn build(&mut self) -> Result<(), RetrieveError> {
    if self.built {
        return Ok(());
    }
    if self.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    // Never request more clusters than there are vectors.
    let num_clusters = self.params.num_clusters.min(self.num_vectors);
    let mut kmeans = crate::partitioning::kmeans::KMeans::new(self.dimension, num_clusters)?;
    kmeans.fit(&self.vectors, self.num_vectors)?;
    self.centroids = kmeans
        .centroids()
        .iter()
        .flat_map(|c: &Vec<f32>| c.iter().copied())
        .collect();

    #[cfg(feature = "hnsw")]
    {
        let nc = self.centroids.len() / self.dimension;
        let mut hnsw = crate::hnsw::HNSWIndex::builder(self.dimension)
            .m(16)
            .ef_construction(200)
            .auto_normalize(true)
            .build()?;
        for i in 0..nc {
            let centroid = self.get_centroid(i);
            hnsw.add_slice(i as u32, centroid)?;
        }
        hnsw.build()?;
        self.coarse_quantizer = Some(hnsw);
    }

    // Group vector positions by the cluster k-means assigned them to.
    let assignments = kmeans.assign_clusters(&self.vectors, self.num_vectors);
    let mut cluster_indices: Vec<Vec<u32>> = vec![Vec::new(); num_clusters];
    for (vector_idx, &cluster_idx) in assignments.iter().enumerate() {
        cluster_indices[cluster_idx].push(vector_idx as u32);
    }

    self.clusters = Vec::with_capacity(num_clusters);
    for (cluster_idx, indices) in cluster_indices.into_iter().enumerate() {
        // Owned copy: the borrow of `self` must end before `self.clusters` is pushed to.
        let centroid = self.get_centroid(cluster_idx).to_vec();
        let mut quantized = Vec::with_capacity(indices.len());
        for &vector_idx in &indices {
            let vec = self.get_vector(vector_idx as usize);
            let qv = self
                .quantizer
                .quantize_with_centroid(vec, &centroid)
                .map_err(|e| {
                    RetrieveError::InvalidParameter(format!("RaBitQ quantize: {e}"))
                })?;
            quantized.push(qv);
        }
        // `quantized[i]` stays aligned with `indices[i]`; `search` relies on this.
        self.clusters.push(Cluster {
            vector_indices: indices,
            quantized,
        });
    }

    self.built = true;
    Ok(())
}
/// Releases the full-precision vectors and marks the index as compacted.
///
/// After this call `search` no longer re-scores candidates against the stored vectors.
///
/// # Panics
///
/// Panics if the index has not been built.
pub fn compact(&mut self) {
    assert!(self.built, "compact() called before build()");
    // Swap in an empty Vec and drop the old buffer.
    drop(std::mem::take(&mut self.vectors));
    self.compacted = true;
}
/// Returns up to `k` `(doc_id, distance)` pairs for `query`, probing the configured
/// default number of clusters (`params.nprobe`).
///
/// # Errors
///
/// Returns whatever `search_with_ef` returns for the same arguments.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
    let nprobe = self.params.nprobe;
    self.search_with_ef(query, k, nprobe)
}
/// Searches the `nprobe` nearest clusters and returns up to `k` `(doc_id, distance)`
/// pairs sorted by ascending distance.
///
/// The query is scaled to unit L2 norm when its norm exceeds `1e-10`. Candidates from the
/// probed clusters are ranked by the quantizer's approximate distance and the best ones
/// are kept in a bounded max-heap. If the index has not been compacted, the survivors are
/// re-scored with `cosine_distance_normalized` against the stored vectors; after
/// `compact()` the approximate distances are returned unchanged.
///
/// # Errors
///
/// * [`RetrieveError::InvalidParameter`] if the index is not built or the quantizer fails
///   to rotate the query.
/// * [`RetrieveError::DimensionMismatch`] if `query.len()` differs from the index dimension.
pub fn search_with_ef(
    &self,
    query: &[f32],
    k: usize,
    nprobe: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
    if !self.built {
        return Err(RetrieveError::InvalidParameter(
            "index must be built before search".into(),
        ));
    }
    if query.len() != self.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: query.len(),
            doc_dim: self.dimension,
        });
    }

    let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
    let query_normalized: Vec<f32> = if query_norm > 1e-10 {
        query.iter().map(|x| x / query_norm).collect()
    } else {
        query.to_vec()
    };
    let query = query_normalized.as_slice();

    let cluster_distances = self.find_nearest_centroids(query, nprobe);

    // Candidate budget for the re-ranking stage. Saturating arithmetic keeps very large
    // `k` / `nprobe` values from overflowing, and the cap at `num_vectors` bounds the heap
    // allocation: there can never be more candidates than stored vectors, so the cap does
    // not change which candidates survive.
    let rerank_size = k
        .saturating_mul(10)
        .max(k.saturating_mul(nprobe))
        .max(64)
        .min(self.num_vectors);

    let rotated_query = self
        .quantizer
        .rotate_query(query)
        .map_err(|e| RetrieveError::InvalidParameter(format!("rotate query: {e}")))?;

    // Max-heap keyed on distance: the root is the worst candidate currently kept. The
    // heap never grows beyond `rerank_size` because the root is popped before each
    // replacement push.
    let mut heap: std::collections::BinaryHeap<(FloatOrd, u32)> =
        std::collections::BinaryHeap::with_capacity(rerank_size);
    for (cluster_idx, _centroid_dist) in &cluster_distances {
        let cluster = &self.clusters[*cluster_idx];
        if cluster.vector_indices.is_empty() {
            continue;
        }
        for (i, qv) in cluster.quantized.iter().enumerate() {
            let dist = RaBitQQuantizer::approximate_l2_sqr_prerotated(&rotated_query, qv);
            let vec_idx = cluster.vector_indices[i];
            if heap.len() < rerank_size {
                heap.push((FloatOrd(dist), vec_idx));
            } else if let Some(&(FloatOrd(worst), _)) = heap.peek() {
                if dist < worst {
                    heap.pop();
                    heap.push((FloatOrd(dist), vec_idx));
                }
            }
        }
    }

    let mut results: Vec<(u32, f32)> = if self.compacted {
        // Full-precision vectors are gone; report the approximate distances.
        heap.into_iter()
            .map(|(FloatOrd(dist), vec_idx)| (self.doc_ids[vec_idx as usize], dist))
            .collect()
    } else {
        // Re-score the surviving candidates against the stored vectors.
        heap.into_iter()
            .map(|(_, vec_idx)| {
                let vec = self.get_vector(vec_idx as usize);
                let dist = crate::distance::cosine_distance_normalized(query, vec);
                (self.doc_ids[vec_idx as usize], dist)
            })
            .collect()
    };

    results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    results.truncate(k);
    Ok(results)
}
/// Returns the number of vectors added to the index so far.
pub fn len(&self) -> usize {
self.num_vectors
}
/// Returns `true` when no vectors have been added.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Reports the payload sizes, in bytes, of the index's main buffers.
///
/// * `vectors_bytes` — the stored full-precision vectors.
/// * `quantized_bytes` — the code buffers of every quantized vector.
/// * `metadata_bytes` — doc ids plus centroids.
/// * `graph_bytes` — always reported as 0.
pub fn memory_usage(&self) -> crate::memory::MemoryReport {
    use std::mem::size_of;

    // NOTE(review): `binary_codes` and `extended_codes` are counted one byte per element
    // and `codes` two bytes per element — presumably matching their element types; confirm
    // against `QuantizedVector`. `Cluster::vector_indices` is not counted anywhere.
    let mut quantized_bytes: usize = 0;
    for cluster in &self.clusters {
        for qv in &cluster.quantized {
            quantized_bytes += qv.binary_codes.len()
                + qv.extended_codes.len()
                + qv.codes.len() * size_of::<u16>();
        }
    }

    let doc_id_bytes = self.doc_ids.len() * size_of::<u32>();
    let centroid_bytes = self.centroids.len() * size_of::<f32>();

    crate::memory::MemoryReport {
        vectors_bytes: self.vectors.len() * size_of::<f32>(),
        graph_bytes: 0,
        quantized_bytes,
        metadata_bytes: doc_id_bytes + centroid_bytes,
    }
}
/// Returns the stored vector at position `idx` as a slice of the flat buffer.
///
/// Panics if the range lies outside `self.vectors` (for example after `compact`).
#[inline]
fn get_vector(&self, idx: usize) -> &[f32] {
    let dim = self.dimension;
    let lo = idx * dim;
    let hi = lo + dim;
    &self.vectors[lo..hi]
}
/// Returns the centroid at position `idx` as a slice of the flat centroid buffer.
///
/// Panics if the range lies outside `self.centroids`.
#[inline]
fn get_centroid(&self, idx: usize) -> &[f32] {
    let dim = self.dimension;
    let lo = idx * dim;
    let hi = lo + dim;
    &self.centroids[lo..hi]
}
/// Returns up to `nprobe` `(cluster index, distance)` pairs for the centroids closest to
/// `query`.
///
/// With the `hnsw` feature and a built coarse quantizer, the HNSW index is queried first
/// and its results are returned as reported. If that search fails, or the feature is off,
/// every centroid is scored with `cosine_distance_normalized` and the closest ones are
/// returned in ascending distance order.
fn find_nearest_centroids(&self, query: &[f32], nprobe: usize) -> Vec<(usize, f32)> {
    let num_centroids = self.centroids.len() / self.dimension;
    // There are only `num_centroids` clusters to probe. Clamping up front keeps a very
    // large caller-supplied `nprobe` from overflowing the `ef` computation below and from
    // asking the HNSW index for more neighbours than it holds.
    let nprobe = nprobe.min(num_centroids);

    #[cfg(feature = "hnsw")]
    if let Some(ref hnsw) = self.coarse_quantizer {
        // Search breadth is twice the number of requested clusters.
        let ef = nprobe.saturating_mul(2);
        if let Ok(results) = hnsw.search(query, nprobe, ef) {
            return results
                .into_iter()
                .map(|(id, d)| (id as usize, d))
                .collect();
        }
    }

    // Exhaustive fallback: score every centroid.
    let mut dists: Vec<(usize, f32)> = (0..num_centroids)
        .map(|idx| {
            let c = self.get_centroid(idx);
            (idx, crate::distance::cosine_distance_normalized(query, c))
        })
        .collect();
    if nprobe < dists.len() {
        // Partition so the `nprobe` smallest distances come first, then drop the rest.
        dists.select_nth_unstable_by(nprobe, |a, b| a.1.total_cmp(&b.1));
        dists.truncate(nprobe);
    }
    dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    dists
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
/// Generates `n * dim` deterministic pseudo-random floats from `seed`.
///
/// Each step advances a wrapping linear-congruential state; the top 31 bits of the state
/// are divided by 2^31 and shifted down by 1.0.
fn make_vectors(n: usize, dim: usize, seed: u64) -> Vec<f32> {
    let total = n * dim;
    let mut state = seed;
    let mut out = Vec::with_capacity(total);
    for _ in 0..total {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        let high_bits = (state >> 33) as f32;
        out.push(high_bits / (1u64 << 31) as f32 - 1.0);
    }
    out
}
/// Builds a small index and checks that querying with a stored vector finds it.
#[test]
fn build_and_search_basic() {
    const DIM: usize = 32;
    const COUNT: usize = 200;

    let vectors = make_vectors(COUNT, DIM, 42);
    let ids: Vec<u32> = (0..COUNT as u32).collect();

    let mut index = IVFRaBitQIndex::new(
        DIM,
        IVFRaBitQParams {
            num_clusters: 8,
            nprobe: 4,
            total_bits: 4,
            seed: 42,
        },
    )
    .unwrap();
    index.add_batch(&ids, &vectors).unwrap();
    index.build().unwrap();

    // Query with the first stored vector (doc id 0).
    let hits = index.search(&vectors[..DIM], 5).unwrap();
    assert!(!hits.is_empty());
    assert!(hits.len() <= 5);
    assert!(
        hits.iter().any(|&(id, _)| id == 0),
        "expected doc_id 0 in results: {:?}",
        hits
    );
}
/// Building an index that holds no vectors must fail.
#[test]
fn empty_index_returns_error() {
    let mut index = IVFRaBitQIndex::new(32, IVFRaBitQParams::default()).unwrap();
    let outcome = index.build();
    assert!(outcome.is_err());
}
/// Adding a vector whose length differs from the index dimension must fail.
#[test]
fn dimension_mismatch_rejected() {
    let mut index = IVFRaBitQIndex::new(32, IVFRaBitQParams::default()).unwrap();
    let too_long = vec![1.0; 64];
    assert!(index.add(0, too_long).is_err());
}
/// The index builds and answers queries with `total_bits` set to 1.
#[test]
fn binary_quantization_works() {
    const DIM: usize = 64;
    const COUNT: usize = 100;

    let vectors = make_vectors(COUNT, DIM, 99);
    let ids: Vec<u32> = (0..COUNT as u32).collect();

    let mut index = IVFRaBitQIndex::new(
        DIM,
        IVFRaBitQParams {
            num_clusters: 4,
            nprobe: 4,
            total_bits: 1,
            seed: 42,
        },
    )
    .unwrap();
    index.add_batch(&ids, &vectors).unwrap();
    index.build().unwrap();

    let hits = index.search(&vectors[..DIM], 3).unwrap();
    assert!(!hits.is_empty());
}
/// After `compact()` the index still returns in-range doc ids and non-negative distances.
#[test]
fn compact_search_works() {
    let n = 200;
    let dim = 32;
    let vectors = make_vectors(n, dim, 42);
    let ids: Vec<u32> = (0..n as u32).collect();

    let mut index = IVFRaBitQIndex::new(
        dim,
        IVFRaBitQParams {
            num_clusters: 8,
            nprobe: 4,
            total_bits: 4,
            seed: 42,
        },
    )
    .unwrap();
    index.add_batch(&ids, &vectors).unwrap();
    index.build().unwrap();
    index.compact();

    let hits = index.search(&vectors[..dim], 5).unwrap();
    assert!(!hits.is_empty());
    assert!(hits.len() <= 5);
    for &(id, dist) in hits.iter() {
        assert!((id as usize) < n, "doc_id {id} out of range");
        assert!(dist >= 0.0, "negative distance {dist}");
    }
}
/// Querying with each stored vector should return that vector first in more than 70% of
/// the cases.
#[test]
fn self_search_recall() {
    let dim = 32;
    let n = 100;
    let vectors = make_vectors(n, dim, 7);
    let ids: Vec<u32> = (0..n as u32).collect();

    let mut index = IVFRaBitQIndex::new(
        dim,
        IVFRaBitQParams {
            num_clusters: 4,
            nprobe: 4,
            total_bits: 4,
            seed: 42,
        },
    )
    .unwrap();
    index.add_batch(&ids, &vectors).unwrap();
    index.build().unwrap();

    // Count the queries whose top hit is the query vector itself.
    let hits = (0..n)
        .filter(|&i| {
            let query = &vectors[i * dim..(i + 1) * dim];
            let top = index.search(query, 1).unwrap();
            top.first().map(|(id, _)| *id) == Some(i as u32)
        })
        .count();

    let recall = hits as f64 / n as f64;
    assert!(
        recall > 0.7,
        "self-search recall too low: {recall:.2} ({hits}/{n})"
    );
}
}