use super::distance::DistanceEngine;
use super::graph::NativeHnsw;
use super::layer::NodeId;
use super::quantization::{QuantizedVectorStore, ScalarQuantizer};
use std::sync::Arc;
/// Tuning knobs consumed by `DualPrecisionHnsw::search_with_config`.
#[derive(Debug, Clone)]
pub struct DualPrecisionConfig {
    /// Multiplier applied to `k` to decide how many coarse int8 candidates are
    /// collected before exact f32 re-ranking.
    pub oversampling_ratio: usize,
    /// When `false`, `search_with_config` skips the int8 traversal and uses the
    /// regular search path instead.
    pub use_int8_traversal: bool,
    /// Indexes holding fewer elements than this are searched with the regular
    /// path even when int8 traversal is enabled.
    pub min_index_size: usize,
    /// Not read by the code visible in this file; presumably toggles timing
    /// diagnostics elsewhere. TODO confirm against callers.
    pub debug_timings: bool,
}

impl Default for DualPrecisionConfig {
    /// Builds the default configuration: int8 traversal enabled, 4x
    /// oversampling, a 10 000-element activation threshold, timings off.
    fn default() -> Self {
        let oversampling_ratio = 4;
        let min_index_size = 10_000;
        DualPrecisionConfig {
            use_int8_traversal: true,
            debug_timings: false,
            oversampling_ratio,
            min_index_size,
        }
    }
}
/// HNSW index wrapper that pairs a full-precision `NativeHnsw` graph with an
/// optional scalar quantizer and quantized vector store.
///
/// The quantizer and store start out as `None`; they are created together once
/// enough inserted vectors have been buffered for training.
pub struct DualPrecisionHnsw<D: DistanceEngine> {
/// The wrapped full-precision graph; size queries and searches delegate to it.
pub(in crate::index::hnsw) inner: NativeHnsw<D>,
/// Trained scalar quantizer, or `None` while training has not happened yet.
quantizer: Option<Arc<ScalarQuantizer>>,
/// Quantized copies of inserted vectors; created at the same time as `quantizer`.
quantized_store: Option<QuantizedVectorStore>,
/// Vector dimension given at construction; `insert` debug-asserts against it.
dimension: usize,
/// Number of buffered vectors that triggers quantizer training during `insert`.
training_sample_size: usize,
/// Owned copies of vectors inserted before training; emptied once training runs.
training_buffer: Vec<Vec<f32>>,
}
impl<D: DistanceEngine> DualPrecisionHnsw<D> {
/// Creates an empty dual-precision index with an untrained quantizer.
///
/// The quantizer is trained automatically by `insert` once
/// `min(1000, max_elements)` vectors have been buffered.
///
/// # Errors
///
/// Propagates any error returned by `NativeHnsw::new_with_dimension`.
pub fn new(
    distance: D,
    dimension: usize,
    max_connections: usize,
    ef_construction: usize,
    max_elements: usize,
) -> crate::error::Result<Self> {
    // Upper bound on how many vectors are buffered before training kicks in.
    const MAX_TRAINING_SAMPLES: usize = 1000;

    let inner = NativeHnsw::new_with_dimension(
        distance,
        max_connections,
        ef_construction,
        max_elements,
        dimension,
    )?;
    let training_sample_size = MAX_TRAINING_SAMPLES.min(max_elements);

    Ok(Self {
        inner,
        quantizer: None,
        quantized_store: None,
        dimension,
        training_sample_size,
        // Reserve only what will actually be buffered: for small indexes the
        // sample size is below the 1000-vector cap, so a fixed 1000 over-allocates.
        training_buffer: Vec::with_capacity(training_sample_size),
    })
}
/// Returns the number of elements reported by the inner graph.
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
/// Returns `true` when the inner graph reports that it is empty.
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
/// Reports whether a trained quantizer is currently available.
#[must_use]
pub fn is_quantizer_trained(&self) -> bool {
    let trained = self.quantizer();
    trained.is_some()
}
pub fn insert(&mut self, vector: &[f32]) -> crate::error::Result<NodeId> {
debug_assert_eq!(vector.len(), self.dimension);
if let Some(ref mut store) = self.quantized_store {
store.push(vector);
} else {
self.training_buffer.push(vector.to_vec());
if self.training_buffer.len() >= self.training_sample_size {
self.train_quantizer();
}
}
self.inner.insert(vector)
}
/// Trains the scalar quantizer from the buffered vectors and builds the
/// quantized store from those same vectors.
///
/// Does nothing when the buffer is empty. On completion the buffer is emptied
/// and its allocation released.
///
/// # Panics
///
/// Panics if `ScalarQuantizer::train` returns an error.
fn train_quantizer(&mut self) {
    if self.training_buffer.is_empty() {
        return;
    }

    let samples: Vec<&[f32]> = self
        .training_buffer
        .iter()
        .map(|sample| sample.as_slice())
        .collect();
    let trained =
        ScalarQuantizer::train(&samples).expect("invariant: training_buffer is non-empty");
    let quantizer = Arc::new(trained);

    // Leave headroom beyond the current graph size for subsequent inserts.
    let capacity = self.inner.len() + 1000;
    let mut store = QuantizedVectorStore::new(Arc::clone(&quantizer), capacity);
    for sample in &self.training_buffer {
        store.push(sample);
    }

    self.quantizer = Some(quantizer);
    self.quantized_store = Some(store);

    // The raw samples are no longer needed once they live in the store.
    self.training_buffer.clear();
    self.training_buffer.shrink_to_fit();
}
/// Trains the quantizer immediately from whatever has been buffered so far.
///
/// A no-op when a quantizer already exists or when nothing has been buffered.
pub fn force_train_quantizer(&mut self) {
    let already_trained = self.quantizer.is_some();
    if already_trained || self.training_buffer.is_empty() {
        return;
    }
    self.train_quantizer();
}
/// Searches for the `k` best matches to `query`.
///
/// Uses the over-fetch-and-re-rank path once a quantizer has been trained and
/// the plain graph search (with score transformation) before that.
#[must_use]
pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<(NodeId, f32)> {
    if self.quantizer.is_some() {
        self.search_dual_precision(query, k, ef_search)
    } else {
        self.search_and_transform(query, k, ef_search)
    }
}
/// Runs the inner graph search and maps every raw score through the inner
/// graph's `transform_score`, keeping the order returned by the graph.
fn search_and_transform(
    &self,
    query: &[f32],
    k: usize,
    ef_search: usize,
) -> Vec<(NodeId, f32)> {
    let raw_hits = self.inner.search(query, k, ef_search);
    let mut hits = Vec::with_capacity(raw_hits.len());
    for (id, raw_score) in raw_hits {
        hits.push((id, self.inner.transform_score(raw_score)));
    }
    hits
}
/// Over-fetches `max(2 * ef_search, 4 * k)` candidates from the inner graph and
/// re-ranks them with exact f32 distances, returning at most `k` results.
fn search_dual_precision(
    &self,
    query: &[f32],
    k: usize,
    ef_search: usize,
) -> Vec<(NodeId, f32)> {
    let rerank_k = std::cmp::max(ef_search * 2, k * 4);
    let candidates = self.inner.search(query, rerank_k, ef_search);
    if candidates.is_empty() {
        return candidates;
    }

    // Only the ids are carried forward; scores are recomputed during re-ranking.
    let mut candidate_ids: Vec<NodeId> = Vec::with_capacity(candidates.len());
    for &(id, _) in &candidates {
        candidate_ids.push(id);
    }
    self.rerank_with_exact_f32(query, &candidate_ids, k)
}
/// Recomputes the distance from `query` to each candidate using the vectors
/// stored in the inner graph, then returns the `k` entries with the smallest
/// transformed score.
///
/// Candidates whose vector cannot be looked up are skipped. An empty result is
/// returned when the inner vector storage is absent.
///
/// NOTE(review): results are sorted ascending by transformed score, which
/// assumes a lower score is a better match; confirm this holds for every
/// metric `transform_score` supports.
pub(super) fn rerank_with_exact_f32(
    &self,
    query: &[f32],
    candidate_ids: &[NodeId],
    k: usize,
) -> Vec<(NodeId, f32)> {
    let vectors_guard = self.inner.vectors.read();
    let Some(vectors) = vectors_guard.as_ref() else {
        return Vec::new();
    };

    let mut reranked: Vec<(NodeId, f32)> = Vec::with_capacity(candidate_ids.len());
    for &node_id in candidate_ids {
        if let Some(vec) = vectors.get(node_id) {
            let raw_dist = self.inner.compute_distance(query, vec);
            reranked.push((node_id, self.inner.transform_score(raw_dist)));
        }
    }

    reranked.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    reranked.truncate(k);
    reranked
}
/// Returns a reference to the trained quantizer, or `None` if training has not
/// happened yet.
#[must_use]
pub fn quantizer(&self) -> Option<&Arc<ScalarQuantizer>> {
self.quantizer.as_ref()
}
/// Searches for the `k` best matches to `query`, choosing the strategy from
/// `config`.
///
/// The int8 traversal is used only when all of the following hold: a quantizer
/// has been trained, `config.use_int8_traversal` is set, and the inner graph
/// holds at least `config.min_index_size` elements. Otherwise the plain graph
/// search with score transformation is used.
#[must_use]
pub fn search_with_config(
    &self,
    query: &[f32],
    k: usize,
    ef_search: usize,
    config: &DualPrecisionConfig,
) -> Vec<(NodeId, f32)> {
    let int8_enabled = self.quantizer.is_some() && config.use_int8_traversal;
    if int8_enabled && self.inner.len() >= config.min_index_size {
        self.search_int8_traversal(query, k, ef_search, config)
    } else {
        self.search_and_transform(query, k, ef_search)
    }
}
/// Collects coarse candidates with the quantized (int8) query, then re-ranks
/// them with exact f32 distances and returns at most `k` results.
///
/// The number of coarse candidates requested is `k * config.oversampling_ratio`,
/// with the ratio clamped to at least 1 and the product saturating instead of
/// wrapping.
///
/// If the quantizer or the quantized store is missing (an invariant violation
/// that trips a debug assertion), the call falls back to the plain graph search
/// with score transformation so results stay on the same scale as every other
/// fallback path.
fn search_int8_traversal(
    &self,
    query: &[f32],
    k: usize,
    ef_search: usize,
    config: &DualPrecisionConfig,
) -> Vec<(NodeId, f32)> {
    let (Some(quantizer), Some(store)) =
        (self.quantizer.as_ref(), self.quantized_store.as_ref())
    else {
        debug_assert!(
            false,
            "Invariant violated: int8 traversal requires trained quantizer and store"
        );
        // Use the transforming fallback: returning raw `inner.search` scores here
        // would disagree with what every other search path hands to callers.
        return self.search_and_transform(query, k, ef_search);
    };

    let query_quantized = quantizer.quantize(query);

    // A ratio of 0 would request zero candidates and silently yield no results;
    // treat it as "no oversampling" so at least `k` candidates are considered.
    let candidates_k = k.saturating_mul(config.oversampling_ratio.max(1));

    let coarse_candidates =
        self.search_layer_int8(&query_quantized.data, candidates_k, ef_search, store);
    if coarse_candidates.is_empty() {
        return Vec::new();
    }

    let candidate_ids: Vec<NodeId> = coarse_candidates.into_iter().map(|(id, _)| id).collect();
    self.rerank_with_exact_f32(query, &candidate_ids, k)
}
}