use super::VectorAccessor;
use super::quantization::{BinaryQuantizer, ProductQuantizer, QuantizationType, ScalarQuantizer};
use super::{HnswConfig, HnswIndex, compute_distance};
use grafeo_common::types::NodeId;
use ordered_float::OrderedFloat;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// An HNSW index that keeps full-precision vectors alongside a quantized
/// representation, using the quantized form for cheap candidate ranking and
/// the full vectors for optional exact rescoring.
pub struct QuantizedHnswIndex {
    /// Underlying HNSW graph; vectors themselves are stored in `vectors`.
    hnsw: HnswIndex,
    /// Full-precision vectors, used for graph construction and exact rescoring.
    vectors: RwLock<HashMap<NodeId, Arc<[f32]>>>,
    /// Which quantization scheme this index uses (fixed at construction).
    quantization_type: QuantizationType,
    /// Trained scalar quantizer; `None` until the training threshold is hit.
    scalar_quantizer: RwLock<Option<ScalarQuantizer>>,
    /// Trained product quantizer; `None` until the training threshold is hit.
    product_quantizer: RwLock<Option<ProductQuantizer>>,
    /// Per-node scalar-quantized codes (one `u8` per dimension).
    scalar_vectors: RwLock<HashMap<NodeId, Vec<u8>>>,
    /// Per-node binary-quantized codes (packed bits).
    binary_vectors: RwLock<HashMap<NodeId, Vec<u64>>>,
    /// Per-node product-quantization codes (one `u8` per subvector).
    product_codes: RwLock<HashMap<NodeId, Vec<u8>>>,
    /// When true, search re-ranks quantized candidates with exact distances.
    rescore: bool,
    /// Oversampling multiplier applied to `k` when `rescore` is enabled.
    rescore_factor: usize,
    /// Number of inserted vectors to collect before training a quantizer.
    training_threshold: usize,
    /// Vectors buffered for quantizer training; cleared after training.
    training_samples: RwLock<Vec<Arc<[f32]>>>,
    /// Set once the scalar/product quantizer has been trained.
    quantizer_trained: RwLock<bool>,
}
impl QuantizedHnswIndex {
/// Creates a quantized HNSW index with the given configuration and scheme.
///
/// Defaults: rescoring enabled with a 2x oversampling factor, and trainable
/// quantizers (scalar/product) train after 1000 inserted vectors.
#[must_use]
pub fn new(config: HnswConfig, quantization: QuantizationType) -> Self {
    Self {
        hnsw: HnswIndex::new(config),
        vectors: RwLock::new(HashMap::new()),
        quantization_type: quantization,
        scalar_quantizer: RwLock::new(None),
        product_quantizer: RwLock::new(None),
        scalar_vectors: RwLock::new(HashMap::new()),
        binary_vectors: RwLock::new(HashMap::new()),
        product_codes: RwLock::new(HashMap::new()),
        rescore: true,
        rescore_factor: 2,
        training_threshold: 1000,
        training_samples: RwLock::new(Vec::new()),
        quantizer_trained: RwLock::new(false),
    }
}
/// Like [`Self::new`], but seeds the underlying HNSW graph's RNG for
/// deterministic level assignment (useful in tests).
#[must_use]
pub fn with_seed(config: HnswConfig, quantization: QuantizationType, seed: u64) -> Self {
    Self {
        hnsw: HnswIndex::with_seed(config, seed),
        vectors: RwLock::new(HashMap::new()),
        quantization_type: quantization,
        scalar_quantizer: RwLock::new(None),
        product_quantizer: RwLock::new(None),
        scalar_vectors: RwLock::new(HashMap::new()),
        binary_vectors: RwLock::new(HashMap::new()),
        product_codes: RwLock::new(HashMap::new()),
        rescore: true,
        rescore_factor: 2,
        training_threshold: 1000,
        training_samples: RwLock::new(Vec::new()),
        quantizer_trained: RwLock::new(false),
    }
}
#[must_use]
pub fn without_rescore(mut self) -> Self {
self.rescore = false;
self
}
#[must_use]
pub fn with_rescore_factor(mut self, factor: usize) -> Self {
self.rescore_factor = factor.max(1);
self
}
#[must_use]
pub fn with_training_threshold(mut self, threshold: usize) -> Self {
self.training_threshold = threshold.max(10);
self
}
/// Returns the quantization scheme this index was built with.
#[must_use]
pub fn quantization_type(&self) -> QuantizationType {
    self.quantization_type
}
/// Returns the HNSW configuration (dimensions, metric, ef, ...).
#[must_use]
pub fn config(&self) -> &HnswConfig {
    self.hnsw.config()
}
/// Number of vectors currently indexed in the HNSW graph.
#[must_use]
pub fn len(&self) -> usize {
    self.hnsw.len()
}
/// Returns `true` if the index contains no vectors.
#[must_use]
pub fn is_empty(&self) -> bool {
    self.hnsw.is_empty()
}
/// Estimates vector-storage bytes: full-precision vectors plus the
/// quantized representation stored on top of them.
#[must_use]
pub fn memory_usage(&self) -> usize {
    let count = self.hnsw.len();
    let dims = self.config().dimensions;
    // Full-precision storage: one f32 (4 bytes) per dimension per vector.
    let full_precision = count * dims * 4;
    // Extra bytes consumed by the active quantization scheme.
    let quantized = match self.quantization_type {
        QuantizationType::None => 0,
        QuantizationType::Scalar => count * dims,
        QuantizationType::Binary => count * BinaryQuantizer::bytes_needed(dims),
        QuantizationType::Product { num_subvectors } => count * num_subvectors,
    };
    full_precision + quantized
}
/// Compression ratio the quantized codes alone would achieve versus
/// full-precision storage, as reported by the quantization scheme.
#[must_use]
pub fn theoretical_compression_ratio(&self) -> f32 {
    self.quantization_type
        .compression_ratio(self.config().dimensions) as f32
}
/// Ratio of actual vector storage to pure full-precision storage.
///
/// Exceeds 1.0 when quantization is active, because this index keeps the
/// full vectors (for rescoring) in addition to the quantized codes.
#[must_use]
pub fn memory_ratio(&self) -> f32 {
    // Bytes a full-precision-only index of the same size would use.
    let full_size = self.hnsw.len() * self.config().dimensions * 4;
    match full_size {
        0 => 1.0,
        _ => self.memory_usage() as f32 / full_size as f32,
    }
}
/// Builds a vector accessor over a point-in-time snapshot of the stored
/// vectors, so the read lock is released before HNSW traversal begins.
///
/// NOTE(review): this copies the whole id->vector map (Arc refcount bumps,
/// but still O(len)) and is called on every insert and search, making each
/// operation O(n) in index size — worth revisiting for large indexes.
fn accessor(&self) -> impl VectorAccessor + '_ {
    let vectors = self.vectors.read();
    let snapshot: HashMap<NodeId, Arc<[f32]>> =
        vectors.iter().map(|(&id, v)| (id, Arc::clone(v))).collect();
    move |id: NodeId| -> Option<Arc<[f32]>> { snapshot.get(&id).cloned() }
}
/// Inserts a vector: stores the full-precision copy, links it into the
/// HNSW graph, then records its quantized form per the active scheme.
pub fn insert(&self, id: NodeId, vector: &[f32]) {
    // Store the full vector first so the accessor below can see it.
    let arc: Arc<[f32]> = vector.into();
    self.vectors.write().insert(id, arc);
    let accessor = self.accessor();
    self.hnsw.insert(id, vector, &accessor);
    match self.quantization_type {
        QuantizationType::None => {}
        QuantizationType::Scalar => self.insert_scalar_quantized(id, vector),
        QuantizationType::Binary => self.insert_binary_quantized(id, vector),
        QuantizationType::Product { num_subvectors } => {
            self.insert_product_quantized(id, vector, num_subvectors);
        }
    }
}
/// Stores the scalar-quantized form of `vector`, training the quantizer
/// once `training_threshold` samples have accumulated.
fn insert_scalar_quantized(&self, id: NodeId, vector: &[f32]) {
    let trained = *self.quantizer_trained.read();
    if trained {
        // Quantizer ready: encode this vector and store its codes.
        if let Some(ref quantizer) = *self.scalar_quantizer.read() {
            let quantized = quantizer.quantize(vector);
            self.scalar_vectors.write().insert(id, quantized);
        }
    } else {
        // Still collecting training data; no codes exist yet (search falls
        // back to full-precision until training completes).
        let vector_arc: Arc<[f32]> = vector.into();
        let mut samples = self.training_samples.write();
        samples.push(vector_arc);
        if samples.len() >= self.training_threshold {
            // Threshold reached: train, then backfill codes for every vector
            // inserted so far (`insert` already put the current one into
            // `vectors`, so it is covered too).
            let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_ref()).collect();
            let quantizer = ScalarQuantizer::train(&refs);
            let mut scalar_vecs = self.scalar_vectors.write();
            let vectors = self.vectors.read();
            for (&old_id, old_vec) in vectors.iter() {
                scalar_vecs.insert(old_id, quantizer.quantize(old_vec));
            }
            *self.scalar_quantizer.write() = Some(quantizer);
            // NOTE(review): the trained-flag read above and this write are
            // separate critical sections; fully concurrent inserts could
            // race the training step — confirm external synchronization.
            *self.quantizer_trained.write() = true;
            samples.clear();
        }
    }
}
/// Stores the binary-quantized bits for `vector`.
///
/// Binary quantization is data-independent, so unlike the scalar and
/// product paths there is no training phase.
fn insert_binary_quantized(&self, id: NodeId, vector: &[f32]) {
    self.binary_vectors
        .write()
        .insert(id, BinaryQuantizer::quantize(vector));
}
/// Stores the product-quantization codes for `vector`, training the
/// quantizer once `training_threshold` samples have accumulated.
fn insert_product_quantized(&self, id: NodeId, vector: &[f32], num_subvectors: usize) {
    let trained = *self.quantizer_trained.read();
    if trained {
        // Quantizer ready: encode this vector and store its PQ codes.
        if let Some(ref quantizer) = *self.product_quantizer.read() {
            let codes = quantizer.quantize(vector);
            self.product_codes.write().insert(id, codes);
        }
    } else {
        // Still collecting training data.
        let vector_arc: Arc<[f32]> = vector.into();
        let mut samples = self.training_samples.write();
        samples.push(vector_arc);
        if samples.len() >= self.training_threshold {
            let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_ref()).collect();
            // Train, presumably with 256 centroids per subspace and 10
            // k-means iterations — confirm against ProductQuantizer::train.
            let quantizer = ProductQuantizer::train(&refs, num_subvectors, 256, 10);
            // Backfill codes for every vector inserted so far (the current
            // one is already in `vectors`, so it is covered too).
            let mut codes = self.product_codes.write();
            let vectors = self.vectors.read();
            for (&old_id, old_vec) in vectors.iter() {
                codes.insert(old_id, quantizer.quantize(old_vec));
            }
            *self.product_quantizer.write() = Some(quantizer);
            // NOTE(review): trained-flag read/write are separate critical
            // sections; concurrent inserts could race the training step.
            *self.quantizer_trained.write() = true;
            samples.clear();
        }
    }
}
/// Searches for the `k` nearest neighbors using the configured default `ef`.
#[must_use]
pub fn search(&self, query: &[f32], k: usize) -> Vec<(NodeId, f32)> {
    self.search_with_ef(query, k, self.config().ef)
}
/// Searches with an explicit `ef`, dispatching to the search path that
/// matches the active quantization scheme.
#[must_use]
pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(NodeId, f32)> {
    let accessor = self.accessor();
    match self.quantization_type {
        // No quantization: plain full-precision HNSW search.
        QuantizationType::None => {
            self.hnsw.search_with_ef(query, k, ef, &accessor)
        }
        QuantizationType::Scalar => self.search_scalar_quantized(query, k, ef, &accessor),
        QuantizationType::Binary => self.search_binary_quantized(query, k, ef, &accessor),
        QuantizationType::Product { .. } => {
            self.search_product_quantized(query, k, ef, &accessor)
        }
    }
}
/// Scalar-quantized search: HNSW candidate generation (over full vectors
/// via the accessor) followed by optional exact rescoring.
fn search_scalar_quantized(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    // Before training there are no scalar codes; fall back to plain search.
    let trained = *self.quantizer_trained.read();
    if !trained {
        return self.hnsw.search_with_ef(query, k, ef, accessor);
    }
    // Oversample when rescoring so exact re-ranking has extra candidates.
    let num_candidates = if self.rescore {
        k.saturating_mul(self.rescore_factor)
    } else {
        k
    };
    let candidates = self
        .hnsw
        .search_with_ef(query, num_candidates, ef, accessor);
    if !self.rescore {
        return candidates.into_iter().take(k).collect();
    }
    self.rescore_candidates(query, candidates, k)
}
/// Binary-quantized search: re-ranks HNSW candidates by approximate
/// (bit-level) distance, then optionally rescores with exact distances.
fn search_binary_quantized(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    let binary_vecs = self.binary_vectors.read();
    // No quantized data yet: fall back to plain HNSW search.
    if binary_vecs.is_empty() {
        return self.hnsw.search_with_ef(query, k, ef, accessor);
    }
    let query_bits = BinaryQuantizer::quantize(query);
    let dims = self.config().dimensions;
    // Binary distances are coarse, so oversample more aggressively (x2 on
    // top of the configured rescore factor) when rescoring is enabled.
    let num_candidates = if self.rescore {
        k.saturating_mul(self.rescore_factor).saturating_mul(2)
    } else {
        k
    };
    let hnsw_candidates = self
        .hnsw
        .search_with_ef(query, num_candidates, ef, accessor);
    // Score candidates by approximate distance; skip any id whose bits are
    // missing (e.g. removed concurrently).
    let mut scored: Vec<(NodeId, f32)> = hnsw_candidates
        .iter()
        .filter_map(|(id, _)| {
            binary_vecs.get(id).map(|bits| {
                let approx_dist =
                    BinaryQuantizer::approximate_euclidean(&query_bits, bits, dims);
                (*id, approx_dist)
            })
        })
        .collect();
    scored.sort_by_key(|(_, d)| OrderedFloat(*d));
    // Fix: the previous `scored.truncate(num_candidates)` was dead code —
    // `scored` can never hold more than `num_candidates` entries because
    // the HNSW search returns at most that many. It has been removed.
    if !self.rescore {
        scored.truncate(k);
        return scored;
    }
    // `rescore_candidates` truncates to `k` after exact re-ranking.
    self.rescore_candidates(query, scored, k)
}
/// Product-quantized search: re-ranks HNSW candidates with an asymmetric
/// distance table, then rescores survivors with exact distances.
fn search_product_quantized(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    accessor: &impl VectorAccessor,
) -> Vec<(NodeId, f32)> {
    // Before the PQ codebooks are trained no codes exist; fall back to the
    // plain full-precision HNSW search.
    let trained = *self.quantizer_trained.read();
    if !trained {
        return self.hnsw.search_with_ef(query, k, ef, accessor);
    }
    // Oversample when rescoring so exact re-ranking has extra candidates.
    let num_candidates = if self.rescore {
        k.saturating_mul(self.rescore_factor)
    } else {
        k
    };
    let candidates = self
        .hnsw
        .search_with_ef(query, num_candidates, ef, accessor);
    if !self.rescore {
        return candidates.into_iter().take(k).collect();
    }
    // Past this point `self.rescore` is known true (the previous inner
    // `if self.rescore` check and its unreachable fall-through were dead
    // code and have been removed).
    let pq_guard = self.product_quantizer.read();
    let codes_guard = self.product_codes.read();
    if let Some(ref pq) = *pq_guard {
        // One table per query: each code distance is then a cheap lookup sum.
        let table = pq.build_distance_table(query);
        let mut scored: Vec<(NodeId, f32)> = candidates
            .into_iter()
            .filter_map(|(id, _)| {
                codes_guard.get(&id).map(|codes| {
                    let dist = pq.distance_with_table(&table, codes);
                    // sqrt: table distances are presumably squared — this
                    // matches the original behavior; confirm against
                    // ProductQuantizer::distance_with_table.
                    (id, dist.sqrt())
                })
            })
            .collect();
        scored.sort_by_key(|(_, d)| OrderedFloat(*d));
        // Fix: previously the list was truncated to `k` *before* exact
        // rescoring, which discarded the oversampled candidates and made
        // `rescore_factor` ineffective on this path. Rescore the full
        // candidate set; `rescore_candidates` truncates to `k` itself.
        self.rescore_candidates(query, scored, k)
    } else {
        // Trained flag set but quantizer missing (unexpected): fall back to
        // the HNSW ordering.
        candidates.into_iter().take(k).collect()
    }
}
/// Replaces approximate candidate distances with exact distances computed
/// from the stored full-precision vectors, returning the best `k`.
fn rescore_candidates(
    &self,
    query: &[f32],
    candidates: Vec<(NodeId, f32)>,
    k: usize,
) -> Vec<(NodeId, f32)> {
    let metric = self.config().metric;
    let store = self.vectors.read();
    // Drop ids whose full vector is gone; rank the rest by exact distance.
    let mut exact: Vec<(NodeId, f32)> = Vec::with_capacity(candidates.len());
    for (id, _approx_dist) in candidates {
        if let Some(full) = store.get(&id) {
            exact.push((id, compute_distance(query, full, metric)));
        }
    }
    exact.sort_by_key(|(_, d)| OrderedFloat(*d));
    exact.truncate(k);
    exact
}
/// Returns the stored full-precision vector for `id`, if present.
#[must_use]
pub fn get(&self, id: NodeId) -> Option<Arc<[f32]>> {
    self.vectors.read().get(&id).cloned()
}
/// Returns `true` if `id` is present in the HNSW graph.
#[must_use]
pub fn contains(&self, id: NodeId) -> bool {
    self.hnsw.contains(id)
}
/// Removes `id` from the vector store, its quantized representation, and
/// the HNSW graph; returns whether the graph contained the id.
pub fn remove(&self, id: NodeId) -> bool {
    self.vectors.write().remove(&id);
    // Drop the quantized form matching the active scheme.
    match self.quantization_type {
        QuantizationType::None => {}
        QuantizationType::Scalar => {
            self.scalar_vectors.write().remove(&id);
        }
        QuantizationType::Binary => {
            self.binary_vectors.write().remove(&id);
        }
        QuantizationType::Product { .. } => {
            self.product_codes.write().remove(&id);
        }
    }
    // NOTE(review): a copy of the vector may linger in `training_samples`
    // until training fires; harmless for search results, but uses memory
    // and lets a removed vector influence quantizer training.
    self.hnsw.remove(id)
}
/// Inserts every `(id, vector)` pair sequentially; each insert also
/// updates the quantized storage and may trigger quantizer training.
pub fn batch_insert<'a, I>(&self, vectors: I)
where
    I: IntoIterator<Item = (NodeId, &'a [f32])>,
{
    vectors
        .into_iter()
        .for_each(|(id, vec)| self.insert(id, vec));
}
/// Runs one search per query, in parallel when the `parallel` feature is
/// enabled (rayon), otherwise sequentially.
#[must_use]
pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Vec<Vec<(NodeId, f32)>> {
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries
            .par_iter()
            .map(|query| self.search(query, k))
            .collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries.iter().map(|query| self.search(query, k)).collect()
    }
}
/// Post-filtered search: searches for `k.max(allowlist.len())` results,
/// then keeps only ids present in `allowlist`.
///
/// NOTE(review): post-filtering does not guarantee `k` allowed hits —
/// callers may receive fewer than `k` results. Widening the search to the
/// allowlist size mitigates but does not eliminate this.
#[must_use]
pub fn search_with_filter(
    &self,
    query: &[f32],
    k: usize,
    allowlist: &std::collections::HashSet<NodeId>,
) -> Vec<(NodeId, f32)> {
    let results = self.search(query, k.max(allowlist.len()));
    results
        .into_iter()
        .filter(|(id, _)| allowlist.contains(id))
        .take(k)
        .collect()
}
/// Like [`Self::search_with_filter`] but with an explicit `ef` parameter.
///
/// NOTE(review): same post-filter caveat — fewer than `k` results may be
/// returned when the allowlist excludes top candidates.
#[must_use]
pub fn search_with_ef_and_filter(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    allowlist: &std::collections::HashSet<NodeId>,
) -> Vec<(NodeId, f32)> {
    let results = self.search_with_ef(query, k.max(allowlist.len()), ef);
    results
        .into_iter()
        .filter(|(id, _)| allowlist.contains(id))
        .take(k)
        .collect()
}
/// Runs one `search_with_ef` per query, in parallel when the `parallel`
/// feature is enabled (rayon), otherwise sequentially.
#[must_use]
pub fn batch_search_with_ef(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    ef: usize,
) -> Vec<Vec<(NodeId, f32)>> {
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries
            .par_iter()
            .map(|query| self.search_with_ef(query, k, ef))
            .collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries
            .iter()
            .map(|query| self.search_with_ef(query, k, ef))
            .collect()
    }
}
/// Runs one `search_with_filter` per query against a shared allowlist, in
/// parallel when the `parallel` feature is enabled, otherwise sequentially.
#[must_use]
pub fn batch_search_with_filter(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    allowlist: &std::collections::HashSet<NodeId>,
) -> Vec<Vec<(NodeId, f32)>> {
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries
            .par_iter()
            .map(|query| self.search_with_filter(query, k, allowlist))
            .collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries
            .iter()
            .map(|query| self.search_with_filter(query, k, allowlist))
            .collect()
    }
}
/// Runs one `search_with_ef_and_filter` per query against a shared
/// allowlist, in parallel when the `parallel` feature is enabled.
#[must_use]
pub fn batch_search_with_ef_and_filter(
    &self,
    queries: &[Vec<f32>],
    k: usize,
    ef: usize,
    allowlist: &std::collections::HashSet<NodeId>,
) -> Vec<Vec<(NodeId, f32)>> {
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        queries
            .par_iter()
            .map(|query| self.search_with_ef_and_filter(query, k, ef, allowlist))
            .collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        queries
            .iter()
            .map(|query| self.search_with_ef_and_filter(query, k, ef, allowlist))
            .collect()
    }
}
/// Captures the HNSW graph topology (entry point, max level, per-node
/// neighbor lists). Vectors and quantized codes are NOT included.
#[must_use]
pub fn snapshot_topology(&self) -> (Option<NodeId>, usize, Vec<(NodeId, Vec<Vec<NodeId>>)>) {
    self.hnsw.snapshot_topology()
}
/// Restores a previously captured HNSW graph topology.
///
/// Callers must separately re-insert or already hold the matching vectors;
/// this only rebuilds the graph structure.
pub fn restore_topology(
    &self,
    entry_point: Option<NodeId>,
    max_level: usize,
    node_data: Vec<(NodeId, Vec<Vec<NodeId>>)>,
) {
    self.hnsw
        .restore_topology(entry_point, max_level, node_data);
}
/// Total estimated heap usage: HNSW graph structures plus vector storage
/// (the latter via the `memory_usage` estimate, not exact allocations).
#[must_use]
pub fn heap_memory_bytes(&self) -> usize {
    self.hnsw.heap_memory_bytes() + self.memory_usage()
}
}
impl std::fmt::Debug for QuantizedHnswIndex {
    /// Summarizes the index (size and quantization settings) without
    /// dumping vector contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("QuantizedHnswIndex");
        dbg.field("len", &self.len());
        dbg.field("quantization", &self.quantization_type);
        dbg.field("rescore", &self.rescore);
        dbg.field("rescore_factor", &self.rescore_factor);
        dbg.field(
            "theoretical_compression",
            &self.theoretical_compression_ratio(),
        );
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::vector::DistanceMetric;

    /// Builds `n` deterministic vectors of dimension `dim` with strictly
    /// increasing component values, so nearest-neighbor order is predictable.
    fn create_test_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| ((i * dim + j) as f32) / (n * dim) as f32)
                    .collect()
            })
            .collect()
    }

    // With QuantizationType::None the index behaves like plain HNSW.
    #[test]
    fn test_quantized_hnsw_no_quantization() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::None);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        assert_eq!(index.len(), 50);
        let results = index.search(&vectors[25], 5);
        assert_eq!(results.len(), 5);
        // Querying with vector 25 (1-based id 26) must return itself first.
        assert_eq!(results[0].0, NodeId::new(26));
    }

    // Scalar quantization: trains after 10 inserts, then searches correctly.
    #[test]
    fn test_quantized_hnsw_scalar_quantization() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::with_seed(config, QuantizationType::Scalar, 42)
            .with_training_threshold(10);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        assert_eq!(index.len(), 50);
        assert_eq!(index.theoretical_compression_ratio(), 4.0);
        // Ratio exceeds 1.0 because codes are stored ON TOP of full vectors.
        assert!(index.memory_ratio() > 1.0);
        let results = index.search(&vectors[25], 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, NodeId::new(26));
    }

    // Binary quantization needs no training; exact match still ranks first.
    #[test]
    fn test_quantized_hnsw_binary_quantization() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::with_seed(config, QuantizationType::Binary, 42);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        assert_eq!(index.len(), 50);
        let results = index.search(&vectors[25], 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, NodeId::new(26));
    }

    // Rescoring disabled: search still returns k results.
    #[test]
    fn test_quantized_hnsw_without_rescore() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::with_seed(config, QuantizationType::Scalar, 42)
            .with_training_threshold(10)
            .without_rescore();
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let results = index.search(&vectors[25], 5);
        assert_eq!(results.len(), 5);
    }

    // remove() drops the id from both the graph and the vector stores.
    #[test]
    fn test_quantized_hnsw_remove() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Binary);
        index.insert(NodeId::new(1), &[0.1, 0.2, 0.3, 0.4]);
        index.insert(NodeId::new(2), &[0.5, 0.6, 0.7, 0.8]);
        assert_eq!(index.len(), 2);
        assert!(index.remove(NodeId::new(1)));
        assert_eq!(index.len(), 1);
        assert!(!index.contains(NodeId::new(1)));
        assert!(index.contains(NodeId::new(2)));
    }

    // batch_insert + batch_search cover the bulk APIs.
    #[test]
    fn test_quantized_hnsw_batch_operations() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::with_seed(config, QuantizationType::Scalar, 42)
            .with_training_threshold(10);
        let vectors = create_test_vectors(50, 4);
        let pairs: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (NodeId::new(i as u64 + 1), v.as_slice()))
            .collect();
        index.batch_insert(pairs);
        assert_eq!(index.len(), 50);
        let queries = vec![vectors[10].clone(), vectors[30].clone()];
        let results = index.batch_search(&queries, 3);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 3);
        assert_eq!(results[1].len(), 3);
    }

    // Compression ratios reported by each QuantizationType variant.
    #[test]
    fn test_quantization_type_enum() {
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);
        assert_eq!(
            QuantizationType::Product { num_subvectors: 8 }.compression_ratio(dims),
            192
        );
    }

    // memory_usage is zero when empty and grows after an insert.
    #[test]
    fn test_quantized_hnsw_memory_usage() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Scalar);
        assert_eq!(index.memory_usage(), 0);
        index.insert(NodeId::new(1), &vec![0.1f32; 384]);
        assert!(index.memory_usage() > 0);
    }

    // Product quantization: trains after 20 inserts, 32 dims / 8 subvectors.
    #[test]
    fn test_quantized_hnsw_product_quantization() {
        let config = HnswConfig::new(32, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::with_seed(
            config,
            QuantizationType::Product { num_subvectors: 8 },
            42,
        )
        .with_training_threshold(20);
        let vectors = create_test_vectors(50, 32);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        assert_eq!(index.len(), 50);
        assert_eq!(index.theoretical_compression_ratio(), 16.0);
        let results = index.search(&vectors[25], 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, NodeId::new(26));
    }

    // Before training fires, PQ search falls back to full-precision HNSW.
    #[test]
    fn test_quantized_hnsw_product_before_training() {
        let config = HnswConfig::new(16, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::with_seed(
            config,
            QuantizationType::Product { num_subvectors: 4 },
            42,
        )
        .with_training_threshold(100);
        let vectors = create_test_vectors(10, 16);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let results = index.search(&vectors[5], 3);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, NodeId::new(6));
    }

    // Post-filter search only returns ids from the allowlist.
    #[test]
    fn test_search_with_allowlist_filter() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Scalar);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let allowlist: std::collections::HashSet<NodeId> =
            [1, 10, 25, 40].iter().map(|&i| NodeId::new(i)).collect();
        let results = index.search_with_filter(&vectors[24], 3, &allowlist);
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
        for (id, _) in &results {
            assert!(allowlist.contains(id), "result {id:?} not in allowlist");
        }
    }

    // Same allowlist guarantee for the explicit-ef variant.
    #[test]
    fn test_search_with_ef_and_allowlist_filter() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Binary);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let allowlist: std::collections::HashSet<NodeId> =
            [5, 15, 30].iter().map(|&i| NodeId::new(i)).collect();
        let results = index.search_with_ef_and_filter(&vectors[14], 2, 50, &allowlist);
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        for (id, _) in &results {
            assert!(allowlist.contains(id));
        }
    }

    // Batch search honors the per-query k with an explicit ef.
    #[test]
    fn test_batch_search_with_ef() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Scalar);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let queries = vec![vectors[10].clone(), vectors[30].clone()];
        let results = index.batch_search_with_ef(&queries, 3, 50);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 3);
        assert_eq!(results[1].len(), 3);
    }

    // Batch filtered search: every result in every batch is allowed.
    #[test]
    fn test_batch_search_with_filter() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Binary);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let allowlist: std::collections::HashSet<NodeId> = (1..=25).map(NodeId::new).collect();
        let queries = vec![vectors[5].clone(), vectors[20].clone()];
        let results = index.batch_search_with_filter(&queries, 3, &allowlist);
        assert_eq!(results.len(), 2);
        for batch in &results {
            for (id, _) in batch {
                assert!(allowlist.contains(id));
            }
        }
    }

    // Batch filtered search with an explicit ef.
    #[test]
    fn test_batch_search_with_ef_and_filter() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Scalar);
        let vectors = create_test_vectors(50, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let allowlist: std::collections::HashSet<NodeId> = (10..=40).map(NodeId::new).collect();
        let queries = vec![vectors[15].clone()];
        let results = index.batch_search_with_ef_and_filter(&queries, 5, 100, &allowlist);
        assert_eq!(results.len(), 1);
        for (id, _) in &results[0] {
            assert!(allowlist.contains(id));
        }
    }

    // A topology snapshot restored into a second index yields a working graph.
    #[test]
    fn test_snapshot_and_restore_topology() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config.clone(), QuantizationType::Scalar);
        let vectors = create_test_vectors(20, 4);
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(NodeId::new(i as u64 + 1), vec);
        }
        let before = index.search(&vectors[10], 5);
        let (entry, max_level, nodes) = index.snapshot_topology();
        // Second index gets the same vectors, then the captured topology.
        let index2 = QuantizedHnswIndex::new(config, QuantizationType::Scalar);
        for (i, vec) in vectors.iter().enumerate() {
            index2.insert(NodeId::new(i as u64 + 1), vec);
        }
        index2.restore_topology(entry, max_level, nodes);
        let after = index2.search(&vectors[10], 5);
        assert_eq!(before.len(), after.len());
    }

    // heap_memory_bytes grows after inserting a vector.
    #[test]
    fn test_heap_memory_bytes() {
        let config = HnswConfig::new(4, DistanceMetric::Euclidean);
        let index = QuantizedHnswIndex::new(config, QuantizationType::Scalar);
        let empty_mem = index.heap_memory_bytes();
        index.insert(NodeId::new(1), &[0.1, 0.2, 0.3, 0.4]);
        assert!(
            index.heap_memory_bytes() > empty_mem,
            "memory should grow after insert"
        );
    }
}