use super::distance::DistanceEngine;
use super::graph::NativeHnsw;
use super::layer::NodeId;
use super::quantization::{QuantizedVectorStore, ScalarQuantizer};
use std::sync::Arc;
/// Tuning knobs consumed by [`DualPrecisionHnsw::search_with_config`].
#[derive(Clone, Debug)]
pub struct DualPrecisionConfig {
    /// Multiplier applied to `k` to size the int8 candidate set that is then
    /// re-ranked with exact f32 distances.
    pub oversampling_ratio: usize,
    /// When `false`, searches go straight to the full-precision graph.
    pub use_int8_traversal: bool,
    /// Indexes holding fewer vectors than this skip int8 traversal entirely.
    pub min_index_size: usize,
    /// Requests timing diagnostics.
    /// NOTE(review): not read by any code in this module — confirm where it is used.
    pub debug_timings: bool,
}

impl Default for DualPrecisionConfig {
    /// Int8 traversal enabled for indexes of at least 10 000 vectors, 4x
    /// oversampling, timing diagnostics off.
    fn default() -> Self {
        let oversampling_ratio = 4;
        let min_index_size = 10_000;
        Self {
            use_int8_traversal: true,
            debug_timings: false,
            oversampling_ratio,
            min_index_size,
        }
    }
}
/// HNSW index pairing a full-precision [`NativeHnsw`] graph with an optional
/// int8 scalar-quantized copy of the inserted vectors.
///
/// Until the quantizer is trained, inserted vectors are additionally buffered
/// in f32 as training samples.
pub struct DualPrecisionHnsw<D>
where
    D: DistanceEngine,
{
    /// Full-precision graph; search reads its layers, entry point and f32 vectors.
    inner: NativeHnsw<D>,
    /// Scalar quantizer; `None` until trained from the buffered samples.
    quantizer: Option<Arc<ScalarQuantizer>>,
    /// Int8 copies of inserted vectors, looked up by `NodeId` during int8
    /// traversal; `None` until the quantizer is trained.
    quantized_store: Option<QuantizedVectorStore>,
    /// Expected length of every inserted vector (checked by `debug_assert` on insert).
    dimension: usize,
    /// Buffered-sample count that triggers quantizer training.
    training_sample_size: usize,
    /// f32 copies of vectors inserted before training; released once trained.
    training_buffer: Vec<Vec<f32>>,
}
impl<D: DistanceEngine> DualPrecisionHnsw<D> {
    /// Upper bound on how many vectors are buffered before the quantizer is trained.
    const MAX_TRAINING_SAMPLES: usize = 1000;

    /// Creates an empty dual-precision index.
    ///
    /// The quantizer is trained lazily once `min(1000, max_elements)` vectors
    /// have been inserted, or when [`Self::force_train_quantizer`] is called.
    ///
    /// # Errors
    ///
    /// Propagates any error from `NativeHnsw::new_with_dimension`.
    pub fn new(
        distance: D,
        dimension: usize,
        max_connections: usize,
        ef_construction: usize,
        max_elements: usize,
    ) -> crate::error::Result<Self> {
        let training_sample_size = Self::MAX_TRAINING_SAMPLES.min(max_elements);
        Ok(Self {
            inner: NativeHnsw::new_with_dimension(
                distance,
                max_connections,
                ef_construction,
                max_elements,
                dimension,
            )?,
            quantizer: None,
            quantized_store: None,
            dimension,
            training_sample_size,
            training_buffer: Vec::with_capacity(training_sample_size),
        })
    }

    /// Number of vectors in the underlying graph.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the underlying graph holds no vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns `true` once the int8 quantizer has been trained.
    #[must_use]
    pub fn is_quantizer_trained(&self) -> bool {
        self.quantizer.is_some()
    }

    /// Inserts `vector` and returns its node id.
    ///
    /// Before training, the vector is copied into the training buffer, and
    /// training runs once the buffer reaches the sample size. After training,
    /// the vector is pushed into the int8 store.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying graph insert. On error, neither the
    /// training buffer nor the int8 store is modified.
    pub fn insert(&mut self, vector: &[f32]) -> crate::error::Result<NodeId> {
        debug_assert_eq!(vector.len(), self.dimension);
        // Graph first: if it rejects the vector, the store/buffer must not grow,
        // otherwise store slots drift out of alignment with graph node ids.
        let node_id = self.inner.insert(vector)?;
        if let Some(store) = self.quantized_store.as_mut() {
            store.push(vector);
        } else {
            self.training_buffer.push(vector.to_vec());
            if self.training_buffer.len() >= self.training_sample_size {
                self.train_quantizer();
            }
        }
        Ok(node_id)
    }

    /// Trains the scalar quantizer on the buffered samples and builds the int8 store.
    ///
    /// Samples are pushed into the store in insertion order. Search looks store
    /// entries up by `NodeId`, so this assumes node ids are assigned sequentially
    /// from 0 (TODO confirm against `NativeHnsw::insert`). The buffer's memory is
    /// released afterwards. Does nothing when the buffer is empty.
    fn train_quantizer(&mut self) {
        if self.training_buffer.is_empty() {
            return;
        }
        // Taking the buffer leaves an unallocated Vec behind, freeing its memory.
        let samples = std::mem::take(&mut self.training_buffer);
        let refs: Vec<&[f32]> = samples.iter().map(Vec::as_slice).collect();
        let quantizer = Arc::new(ScalarQuantizer::train(&refs));
        // Room for what is already in the graph plus headroom for further inserts.
        let capacity = self.inner.len().saturating_add(Self::MAX_TRAINING_SAMPLES);
        let mut store = QuantizedVectorStore::new(Arc::clone(&quantizer), capacity);
        for sample in &samples {
            store.push(sample);
        }
        self.quantizer = Some(quantizer);
        self.quantized_store = Some(store);
    }

    /// Trains the quantizer now on whatever has been buffered so far.
    ///
    /// This is a no-op if the quantizer is already trained or nothing is buffered.
    pub fn force_train_quantizer(&mut self) {
        if self.quantizer.is_none() && !self.training_buffer.is_empty() {
            self.train_quantizer();
        }
    }

    /// Returns the `k` nearest neighbours of `query` as `(node, distance)`, ascending.
    ///
    /// Without a trained quantizer, this delegates to the f32 graph search.
    /// Otherwise it over-fetches from the f32 graph and re-ranks exactly. Int8
    /// traversal is only used through [`Self::search_with_config`].
    #[must_use]
    pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<(NodeId, f32)> {
        if self.quantizer.is_none() {
            return self.inner.search(query, k, ef_search);
        }
        self.search_dual_precision(query, k, ef_search)
    }

    /// Over-fetches candidates from the f32 graph, then re-ranks them with exact distances.
    fn search_dual_precision(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Vec<(NodeId, f32)> {
        // Saturating arithmetic: huge caller-supplied values must not overflow.
        let rerank_k = ef_search.saturating_mul(2).max(k.saturating_mul(4));
        let candidates = self.inner.search(query, rerank_k, ef_search);
        if candidates.is_empty() {
            return candidates;
        }
        let candidate_ids: Vec<NodeId> = candidates.into_iter().map(|(id, _)| id).collect();
        self.rerank_with_exact_f32(query, &candidate_ids, k)
    }

    /// Scores each candidate with the exact f32 distance and returns the `k` closest, ascending.
    ///
    /// Candidates with no stored f32 vector are skipped. The result is empty if
    /// the graph's vector storage (`inner.vectors`) is `None`.
    fn rerank_with_exact_f32(
        &self,
        query: &[f32],
        candidate_ids: &[NodeId],
        k: usize,
    ) -> Vec<(NodeId, f32)> {
        let vectors_guard = self.inner.vectors.read();
        let Some(vectors) = vectors_guard.as_ref() else {
            return Vec::new();
        };
        let mut reranked: Vec<(NodeId, f32)> = candidate_ids
            .iter()
            .filter_map(|&node_id| {
                let vec = vectors.get(node_id)?;
                Some((node_id, self.inner.compute_distance(query, vec)))
            })
            .collect();
        // total_cmp gives a total order even if a distance is NaN.
        reranked.sort_by(|a, b| a.1.total_cmp(&b.1));
        reranked.truncate(k);
        reranked
    }

    /// The trained quantizer, if any.
    #[must_use]
    pub fn quantizer(&self) -> Option<&Arc<ScalarQuantizer>> {
        self.quantizer.as_ref()
    }

    /// Like [`Self::search`], but may traverse the graph with int8 distances.
    ///
    /// Int8 traversal is used only when `config.use_int8_traversal` is set, the
    /// quantizer is trained, and the index holds at least `config.min_index_size`
    /// vectors. Otherwise the f32 graph search is used directly.
    #[must_use]
    pub fn search_with_config(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
        config: &DualPrecisionConfig,
    ) -> Vec<(NodeId, f32)> {
        let use_int8 = config.use_int8_traversal
            && self.quantizer.is_some()
            && self.inner.len() >= config.min_index_size;
        if !use_int8 {
            return self.inner.search(query, k, ef_search);
        }
        self.search_int8_traversal(query, k, ef_search, config)
    }

    /// Finds `k * oversampling_ratio` candidates with int8 distances, then re-ranks them in f32.
    ///
    /// Falls back to the f32 graph search if int8 traversal finds nothing (for
    /// example, when the entry point has no quantized copy).
    fn search_int8_traversal(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
        config: &DualPrecisionConfig,
    ) -> Vec<(NodeId, f32)> {
        if k == 0 {
            return Vec::new();
        }
        let (Some(quantizer), Some(store)) =
            (self.quantizer.as_ref(), self.quantized_store.as_ref())
        else {
            debug_assert!(
                false,
                "Invariant violated: int8 traversal requires trained quantizer and store"
            );
            return self.inner.search(query, k, ef_search);
        };
        let query_quantized = quantizer.quantize(query);
        // A ratio of 0 would request zero candidates and silently return nothing.
        let candidates_k = k.saturating_mul(config.oversampling_ratio.max(1));
        let coarse_candidates =
            self.search_layer_int8(&query_quantized.data, candidates_k, ef_search, store);
        if coarse_candidates.is_empty() {
            return self.inner.search(query, k, ef_search);
        }
        let candidate_ids: Vec<NodeId> = coarse_candidates.into_iter().map(|(id, _)| id).collect();
        self.rerank_with_exact_f32(query, &candidate_ids, k)
    }

    /// HNSW search using int8 L2 distances.
    ///
    /// Descends the upper layers greedily, then runs a best-first search of
    /// width `max(ef_search, k)` on layer 0. Returns up to `k` `(node, int8
    /// distance)` pairs in ascending order, with ties broken by node id.
    fn search_layer_int8(
        &self,
        query_int8: &[u8],
        k: usize,
        ef_search: usize,
        store: &QuantizedVectorStore,
    ) -> Vec<(NodeId, u32)> {
        use rustc_hash::FxHashSet;
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;

        let entry_point = *self.inner.entry_point.read();
        let Some(ep) = entry_point else {
            return Vec::new();
        };
        let max_layer = self
            .inner
            .max_layer
            .load(std::sync::atomic::Ordering::Relaxed);
        let mut current_ep = ep;
        for layer_idx in (1..=max_layer).rev() {
            current_ep = self.greedy_search_int8(query_int8, current_ep, layer_idx, store);
        }

        let ef = ef_search.max(k);
        let mut visited: FxHashSet<NodeId> = FxHashSet::default();
        // Min-heap of nodes still to expand; max-heap of the best `ef` found so far.
        let mut candidates: BinaryHeap<Reverse<(u32, NodeId)>> = BinaryHeap::new();
        let mut results: BinaryHeap<(u32, NodeId)> = BinaryHeap::new();
        Self::init_search_from_ep(
            store,
            query_int8,
            current_ep,
            &mut visited,
            &mut candidates,
            &mut results,
        );

        {
            // One read guard for the whole base-layer pass. The greedy descent above
            // has already released its own guard, so read locks are never nested.
            let layers = self.inner.layers.read();
            let base_layer = &layers[0];
            while let Some(Reverse((c_dist, c_node))) = candidates.pop() {
                let furthest = results.peek().map_or(u32::MAX, |r| r.0);
                if c_dist > furthest && results.len() >= ef {
                    break;
                }
                let _ = base_layer.with_neighbors(c_node, |neighbors| {
                    Self::process_int8_neighbors(
                        store,
                        query_int8,
                        neighbors,
                        ef,
                        &mut visited,
                        &mut candidates,
                        &mut results,
                    );
                });
            }
        }

        results
            .into_sorted_vec()
            .into_iter()
            .take(k)
            .map(|(dist, node)| (node, dist))
            .collect()
    }

    /// Seeds the search heaps with the entry point, if it has an int8 copy in the store.
    fn init_search_from_ep(
        store: &QuantizedVectorStore,
        query_int8: &[u8],
        ep: NodeId,
        visited: &mut rustc_hash::FxHashSet<NodeId>,
        candidates: &mut std::collections::BinaryHeap<std::cmp::Reverse<(u32, NodeId)>>,
        results: &mut std::collections::BinaryHeap<(u32, NodeId)>,
    ) {
        if let Some(ep_slice) = store.get_slice(ep) {
            let dist = store
                .quantizer()
                .distance_l2_quantized_slice(query_int8, ep_slice);
            candidates.push(std::cmp::Reverse((dist, ep)));
            results.push((dist, ep));
            visited.insert(ep);
        }
    }

    /// Scores unvisited `neighbors`, admitting each one that beats the current
    /// worst result, or any neighbour while fewer than `ef` results are held.
    ///
    /// The result heap is trimmed back to `ef` entries after each admission.
    fn process_int8_neighbors(
        store: &QuantizedVectorStore,
        query_int8: &[u8],
        neighbors: &[NodeId],
        ef: usize,
        visited: &mut rustc_hash::FxHashSet<NodeId>,
        candidates: &mut std::collections::BinaryHeap<std::cmp::Reverse<(u32, NodeId)>>,
        results: &mut std::collections::BinaryHeap<(u32, NodeId)>,
    ) {
        let quantizer = store.quantizer();
        for &neighbor in neighbors {
            if !visited.insert(neighbor) {
                continue;
            }
            let Some(neighbor_slice) = store.get_slice(neighbor) else {
                continue;
            };
            let dist = quantizer.distance_l2_quantized_slice(query_int8, neighbor_slice);
            let furthest = results.peek().map_or(u32::MAX, |r| r.0);
            if dist < furthest || results.len() < ef {
                candidates.push(std::cmp::Reverse((dist, neighbor)));
                results.push((dist, neighbor));
                if results.len() > ef {
                    results.pop();
                }
            }
        }
    }

    /// Greedy walk on `layer`: repeatedly moves to any closer neighbour (by int8
    /// distance) until none improves, then returns the final node.
    fn greedy_search_int8(
        &self,
        query_int8: &[u8],
        entry: NodeId,
        layer: usize,
        store: &QuantizedVectorStore,
    ) -> NodeId {
        let quantizer = store.quantizer();
        let mut current = entry;
        let mut current_dist = store.get_slice(entry).map_or(u32::MAX, |s| {
            quantizer.distance_l2_quantized_slice(query_int8, s)
        });
        // Held for the whole walk instead of being re-acquired on every step.
        let layers = self.inner.layers.read();
        let graph_layer = &layers[layer];
        loop {
            let mut improved = false;
            let _ = graph_layer.with_neighbors(current, |neighbors| {
                for &neighbor in neighbors {
                    if let Some(neighbor_slice) = store.get_slice(neighbor) {
                        let dist =
                            quantizer.distance_l2_quantized_slice(query_int8, neighbor_slice);
                        if dist < current_dist {
                            current = neighbor;
                            current_dist = dist;
                            improved = true;
                        }
                    }
                }
            });
            if !improved {
                break;
            }
        }
        current
    }
}