use super::distance::DistanceEngine;
use super::graph::{NativeHnsw, NO_ENTRY_POINT};
use super::layer::NodeId;
use crate::quantization::{RaBitQIndex, RaBitQVectorStore};
use parking_lot::{Mutex, RwLock};
use std::sync::Arc;
/// Tuning knobs for the RaBitQ coarse-search + exact-rerank pipeline.
#[derive(Debug, Clone)]
pub struct RaBitQPrecisionConfig {
    /// Multiplier applied to the requested `k` to size the coarse candidate set.
    pub oversampling_ratio: usize,
    /// Index-size threshold; not read anywhere in the code visible here.
    pub min_index_size: usize,
}

impl Default for RaBitQPrecisionConfig {
    /// Returns the default configuration: 6x oversampling, 5000 minimum index size.
    fn default() -> Self {
        const DEFAULT_OVERSAMPLING_RATIO: usize = 6;
        const DEFAULT_MIN_INDEX_SIZE: usize = 5000;

        Self {
            oversampling_ratio: DEFAULT_OVERSAMPLING_RATIO,
            min_index_size: DEFAULT_MIN_INDEX_SIZE,
        }
    }
}
/// HNSW index wrapper that pairs a `NativeHnsw` graph with an optional
/// RaBitQ quantizer: searches use the quantized codes for a coarse pass and
/// then rerank candidates against the full-precision vectors held by `inner`.
pub struct RaBitQPrecisionHnsw<D: DistanceEngine> {
// Underlying graph; also the source of full-precision vectors for reranking.
pub(in crate::index::hnsw) inner: NativeHnsw<D>,
// Trained quantizer; `None` until training has completed.
rabitq_index: RwLock<Option<Arc<RaBitQIndex>>>,
// Encoded vector codes; `None` until training has completed.
rabitq_store: RwLock<Option<RaBitQVectorStore>>,
// Expected length of every inserted vector (checked with a debug assertion on insert).
dimension: usize,
// Number of buffered samples that triggers quantizer training.
training_sample_size: usize,
// Copies of inserted vectors collected while the quantizer is untrained.
training_buffer: Mutex<Vec<Vec<f32>>>,
}
impl<D: DistanceEngine> RaBitQPrecisionHnsw<D> {
/// Creates an empty index with an untrained quantizer.
///
/// The training threshold is `max_elements` capped at 1000 samples.
///
/// # Errors
/// Propagates any error from `NativeHnsw::new_with_dimension`.
pub fn new(
    distance: D,
    dimension: usize,
    max_connections: usize,
    ef_construction: usize,
    max_elements: usize,
) -> crate::error::Result<Self> {
    const TRAINING_SAMPLE_CAP: usize = 1000;

    // Build the graph first so a construction failure returns before anything else is allocated.
    let inner = NativeHnsw::new_with_dimension(
        distance,
        max_connections,
        ef_construction,
        max_elements,
        dimension,
    )?;

    Ok(Self {
        inner,
        rabitq_index: RwLock::new(None),
        rabitq_store: RwLock::new(None),
        dimension,
        training_sample_size: max_elements.min(TRAINING_SAMPLE_CAP),
        training_buffer: Mutex::new(Vec::with_capacity(TRAINING_SAMPLE_CAP)),
    })
}
/// Wraps an existing `NativeHnsw` with an untrained quantizer.
///
/// `_distance` is accepted but not used. The training threshold is fixed at
/// 1000 samples.
#[must_use]
pub fn from_inner(inner: NativeHnsw<D>, _distance: D, dimension: usize) -> Self {
    const DEFAULT_TRAINING_SAMPLES: usize = 1000;

    let training_buffer = Mutex::new(Vec::with_capacity(DEFAULT_TRAINING_SAMPLES));
    Self {
        inner,
        rabitq_index: RwLock::new(None),
        rabitq_store: RwLock::new(None),
        dimension,
        training_sample_size: DEFAULT_TRAINING_SAMPLES,
        training_buffer,
    }
}
/// Returns the element count reported by the inner index.
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
/// Returns `true` when the inner index reports itself as empty.
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
/// Returns `true` once a trained quantizer has been installed.
#[must_use]
pub fn is_quantizer_trained(&self) -> bool {
// The read guard is a temporary and is released when this expression ends.
self.rabitq_index.read().is_some()
}
/// Inserts `vector` into the index and returns the id assigned by the inner graph.
///
/// When the quantizer is trained, the vector is encoded and its code is
/// appended to the code store; otherwise the vector goes through the
/// training phase. In both cases this happens *before* the insertion into
/// the inner graph.
///
/// The dimension check is a `debug_assert`, so it is only enforced in debug
/// builds.
///
/// # Errors
/// Propagates errors from encoding, from the training phase, and from the
/// inner insert.
pub fn insert(&self, vector: &[f32]) -> crate::error::Result<NodeId> {
    debug_assert_eq!(vector.len(), self.dimension);

    // Clone the Arc out of the lock; the read guard is a temporary that is
    // released at the end of this statement, before any encoding work.
    let trained = self.rabitq_index.read().as_ref().map(Arc::clone);

    match trained {
        Some(quantizer) => {
            let code = quantizer.encode(vector)?;
            let mut store_slot = self.rabitq_store.write();
            if let Some(store) = store_slot.as_mut() {
                store.push(&code.bits, code.correction);
            }
        }
        None => self.insert_training_phase(vector)?,
    }

    self.inner.insert(vector)
}
/// Buffers a copy of `vector` as a quantizer training sample and triggers
/// training once `training_sample_size` samples have been collected.
///
/// Without the `persistence` feature, `train_rabitq` is a no-op that never
/// drains the buffer. Buffering in that configuration would retain a copy of
/// every inserted vector indefinitely, so this method returns immediately
/// instead.
///
/// # Errors
/// Propagates any error returned by `train_rabitq`.
fn insert_training_phase(&self, vector: &[f32]) -> crate::error::Result<()> {
    // No trainer is compiled in: samples would never be consumed.
    if !cfg!(feature = "persistence") {
        return Ok(());
    }

    // Scope the mutex guard so it is released before training, which locks
    // the buffer again itself.
    let threshold_reached = {
        let mut buffer = self.training_buffer.lock();
        buffer.push(vector.to_vec());
        buffer.len() >= self.training_sample_size
    };

    if threshold_reached {
        self.train_rabitq()?;
    }
    Ok(())
}
/// Returns up to `k` `(id, score)` pairs for `query`.
///
/// Uses the quantized coarse-search + rerank path when a quantizer is
/// trained, and the plain inner search otherwise.
#[must_use]
pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<(NodeId, f32)> {
    // The read guard is released at the end of this statement.
    let quantizer_ready = self.rabitq_index.read().is_some();

    if quantizer_ready {
        self.search_rabitq_precision(query, k, ef_search)
    } else {
        self.search_and_transform(query, k, ef_search)
    }
}
/// Runs the inner index search and maps every raw result through
/// `transform_score`, preserving the order returned by the inner search.
fn search_and_transform(
    &self,
    query: &[f32],
    k: usize,
    ef_search: usize,
) -> Vec<(NodeId, f32)> {
    let raw_hits = self.inner.search(query, k, ef_search);

    raw_hits
        .into_iter()
        .map(|(node, raw_score)| {
            let score = self.inner.transform_score(raw_score);
            (node, score)
        })
        .collect()
}
/// Trains the quantizer immediately from whatever samples are buffered.
///
/// Does nothing when a quantizer is already installed or when the training
/// buffer is empty.
///
/// # Errors
/// Propagates any error returned by `train_rabitq`.
pub fn force_train_quantizer(&self) -> crate::error::Result<()> {
    // Each guard below is a temporary of its `if` condition and is released
    // before the next statement runs.
    if self.rabitq_index.read().is_some() {
        return Ok(());
    }
    if self.training_buffer.lock().is_empty() {
        return Ok(());
    }
    self.train_rabitq()
}
}
impl<D: DistanceEngine> RaBitQPrecisionHnsw<D> {
/// Trains the quantizer from the buffered samples and builds the code store.
///
/// Holds the index write lock for the whole operation, so concurrent callers
/// serialize here and the second one returns early once a quantizer exists.
/// Returns `Ok(())` without training when the buffer is empty.
///
/// NOTE(review): the samples are moved out of the buffer before training; if
/// training or encoding fails they are dropped with the early return.
///
/// # Errors
/// Propagates errors from `RaBitQIndex::train` and `RaBitQIndex::encode`.
#[cfg(feature = "persistence")]
fn train_rabitq(&self) -> crate::error::Result<()> {
    // Extra store capacity reserved beyond the current index size.
    const STORE_HEADROOM: usize = 1000;

    let mut index_slot = self.rabitq_index.write();
    if index_slot.is_some() {
        return Ok(());
    }

    // Move the samples out while holding the buffer lock only briefly.
    let samples = {
        let mut pending = self.training_buffer.lock();
        if pending.is_empty() {
            return Ok(());
        }
        std::mem::take(&mut *pending)
    };

    let quantizer = Arc::new(RaBitQIndex::train(&samples, 42)?);

    let capacity = self.inner.len() + STORE_HEADROOM;
    let mut codes = RaBitQVectorStore::new(self.dimension, capacity);
    for sample in &samples {
        let encoded = quantizer.encode(sample)?;
        codes.push(&encoded.bits, encoded.correction);
    }

    // Publish the store before the index so the store is already populated
    // when the quantizer becomes visible.
    *self.rabitq_store.write() = Some(codes);
    *index_slot = Some(quantizer);
    Ok(())
}
/// Variant compiled when the `persistence` feature is disabled: performs no
/// training and always succeeds. It leaves the quantizer unset and does not
/// touch the training buffer.
#[cfg(not(feature = "persistence"))]
fn train_rabitq(&self) -> crate::error::Result<()> {
Ok(())
}
fn search_rabitq_precision(
&self,
query: &[f32],
k: usize,
ef_search: usize,
) -> Vec<(NodeId, f32)> {
let index_guard = self.rabitq_index.read();
let Some(rabitq) = index_guard.as_ref() else {
return self.search_and_transform(query, k, ef_search);
};
let rabitq = Arc::clone(rabitq);
drop(index_guard);
let store_guard = self.rabitq_store.read();
let Some(store) = store_guard.as_ref() else {
return self.search_and_transform(query, k, ef_search);
};
let Some(prepared) = rabitq.prepare_query(query) else {
return self.search_and_transform(query, k, ef_search);
};
let config = RaBitQPrecisionConfig::default();
let candidates_k = k * config.oversampling_ratio;
let coarse = self.search_layer_rabitq(&prepared, candidates_k, ef_search, &rabitq, store);
if coarse.is_empty() {
return Vec::new();
}
let candidate_ids: Vec<NodeId> = coarse.into_iter().map(|(id, _)| id).collect();
self.rerank_with_exact_f32(query, &candidate_ids, k)
}
/// Rescores `candidate_ids` against the full-precision vectors and returns
/// the `k` entries with the smallest transformed score, in ascending order.
///
/// Candidates whose vector is not found are skipped. Returns an empty vector
/// when the inner vector storage is absent.
fn rerank_with_exact_f32(
    &self,
    query: &[f32],
    candidate_ids: &[NodeId],
    k: usize,
) -> Vec<(NodeId, f32)> {
    let vectors_guard = self.inner.vectors.read();
    let Some(vectors) = vectors_guard.as_ref() else {
        return Vec::new();
    };

    let mut scored: Vec<(NodeId, f32)> = Vec::with_capacity(candidate_ids.len());
    for &node_id in candidate_ids {
        if let Some(stored) = vectors.get(node_id) {
            let raw = self.inner.compute_distance(query, stored);
            scored.push((node_id, self.inner.transform_score(raw)));
        }
    }

    // Stable sort so entries with equal scores keep their candidate order.
    scored.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    scored.truncate(k);
    scored
}
}