use crate::vector::product_quantize::{PQCode, PQConfig, PQDistanceCache, ProductQuantizer};
use crate::{Document, RagError, Result, SearchResult};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
/// Configuration for an HNSW graph whose vectors are stored as product-quantized codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQHNSWConfig {
    /// Maximum connections per node on layers above 0.
    pub m: usize,
    /// Maximum connections per node on layer 0 (typically 2 * m).
    pub m0: usize,
    /// Candidate-list size (ef) used while building the graph.
    pub ef_construction: usize,
    /// Candidate-list size (ef) used at query time.
    pub ef_search: usize,
    /// Level-generation multiplier; node level ~ floor(-ln(U) * ml).
    pub ml: f32,
    /// Precompute a centroid-pair distance table to speed up symmetric distances.
    pub use_distance_cache: bool,
    /// Keep the original f32 embedding alongside the PQ code (required for reranking).
    pub store_original: bool,
    /// Number of ADC candidates to rerank with original vectors; 0 disables reranking.
    pub rerank_candidates: usize,
}
impl Default for PQHNSWConfig {
    /// Standard HNSW defaults: m = 16, m0 = 32, ml = 1/ln(m), caching on,
    /// originals discarded, reranking disabled.
    fn default() -> Self {
        const M: usize = 16;
        Self {
            m: M,
            m0: 2 * M,
            ef_construction: 200,
            ef_search: 50,
            ml: (M as f32).ln().recip(),
            use_distance_cache: true,
            store_original: false,
            rerank_candidates: 0,
        }
    }
}
impl PQHNSWConfig {
    /// Enables exact reranking: originals are retained and the top
    /// `candidates` ADC hits are rescored with full-precision cosine similarity.
    pub fn with_reranking(mut self, candidates: usize) -> Self {
        self.rerank_candidates = candidates;
        self.store_original = true;
        self
    }

    /// Overrides the query-time candidate-list size (ef).
    pub fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }
}
/// A single graph node: the document payload plus its PQ code and adjacency lists.
#[derive(Debug, Clone)]
struct PQNode {
    /// Document id as supplied by the caller.
    id: String,
    /// Raw document text, returned verbatim in search results.
    content: String,
    /// Product-quantized representation of the embedding.
    code: PQCode,
    /// Full-precision embedding; present only when `store_original` is set.
    original: Option<Vec<f32>>,
    /// Optional caller-supplied metadata, echoed back in results.
    metadata: Option<serde_json::Value>,
    /// connections[layer] = neighbor node indices on that layer;
    /// `connections.len() - 1` is this node's top layer.
    connections: Vec<HashSet<usize>>,
}
/// HNSW approximate-nearest-neighbor index over product-quantized vectors.
pub struct PQHNSWIndex {
    /// Graph construction / search parameters.
    config: PQHNSWConfig,
    /// Trained product quantizer used to encode and compare vectors.
    pq: ProductQuantizer,
    /// Optional precomputed centroid-pair distance table for symmetric distance.
    distance_cache: Option<PQDistanceCache>,
    /// All nodes; a node's index in this Vec is its graph id.
    nodes: Vec<PQNode>,
    /// Index of the node on the topmost layer; `None` while the index is empty.
    entry_point: Option<usize>,
    /// Highest layer currently present in the graph.
    max_layer: usize,
}
impl PQHNSWIndex {
pub fn train(
pq_config: PQConfig,
training_data: &[Vec<f32>],
config: PQHNSWConfig,
) -> Result<Self> {
let pq = ProductQuantizer::train(training_data, pq_config)?;
let distance_cache = if config.use_distance_cache {
Some(PQDistanceCache::build(&pq))
} else {
None
};
Ok(Self {
config,
pq,
distance_cache,
nodes: Vec::new(),
entry_point: None,
max_layer: 0,
})
}
pub fn from_quantizer(pq: ProductQuantizer, config: PQHNSWConfig) -> Self {
let distance_cache = if config.use_distance_cache {
Some(PQDistanceCache::build(&pq))
} else {
None
};
Self {
config,
pq,
distance_cache,
nodes: Vec::new(),
entry_point: None,
max_layer: 0,
}
}
pub fn add(&mut self, document: Document) -> Result<()> {
let dim = self.pq.config().dim;
if document.embedding.len() != dim {
return Err(RagError::DimensionMismatch {
expected: dim,
actual: document.embedding.len(),
});
}
if document.embedding.iter().any(|v| !v.is_finite()) {
return Err(RagError::InvalidInput(
"embedding contains non-finite values (NaN or Inf)".to_string(),
));
}
let node_id = self.nodes.len();
let node_level = self.random_level();
let mut connections = Vec::with_capacity(node_level + 1);
for _ in 0..=node_level {
connections.push(HashSet::new());
}
let code = self.pq.encode(&document.embedding);
let original = if self.config.store_original {
Some(document.embedding)
} else {
None
};
let node = PQNode {
id: document.id,
content: document.content,
code,
original,
metadata: document.metadata,
connections,
};
self.nodes.push(node);
if self.entry_point.is_none() {
self.entry_point = Some(node_id);
self.max_layer = node_level;
return Ok(());
}
self.insert_node(node_id, node_level);
if node_level > self.max_layer {
self.max_layer = node_level;
self.entry_point = Some(node_id);
}
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
let dim = self.pq.config().dim;
if query.len() != dim {
return Err(RagError::DimensionMismatch {
expected: dim,
actual: query.len(),
});
}
if self.nodes.is_empty() {
return Ok(Vec::new());
}
let distance_table = self.pq.compute_distance_table(query);
let entry_point = self.entry_point.unwrap();
let mut current_nearest = vec![entry_point];
for layer in (1..=self.max_layer).rev() {
current_nearest = self.search_layer_adc(&distance_table, ¤t_nearest, 1, layer);
}
let ef = self.config.ef_search.max(k);
let candidates = if self.config.rerank_candidates > 0 {
self.config.ef_search.max(self.config.rerank_candidates)
} else {
ef
};
current_nearest = self.search_layer_adc(&distance_table, ¤t_nearest, candidates, 0);
let results = if self.config.store_original && self.config.rerank_candidates > 0 {
self.rerank_with_original(query, ¤t_nearest, k)
} else {
self.to_search_results_adc(query, ¤t_nearest, &distance_table, k)
};
Ok(results)
}
pub fn search_symmetric(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
let dim = self.pq.config().dim;
if query.len() != dim {
return Err(RagError::DimensionMismatch {
expected: dim,
actual: query.len(),
});
}
if self.nodes.is_empty() {
return Ok(Vec::new());
}
let query_code = self.pq.encode(query);
let entry_point = self.entry_point.unwrap();
let mut current_nearest = vec![entry_point];
for layer in (1..=self.max_layer).rev() {
current_nearest = self.search_layer_symmetric(&query_code, ¤t_nearest, 1, layer);
}
let ef = self.config.ef_search.max(k);
current_nearest = self.search_layer_symmetric(&query_code, ¤t_nearest, ef, 0);
let results = self.to_search_results_symmetric(&query_code, ¤t_nearest, k);
Ok(results)
}
    /// Number of documents currently stored.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }
    /// True when the index holds no documents.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
    /// Removes all documents; the trained quantizer (and its cache) are kept.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.entry_point = None;
        self.max_layer = 0;
    }
    /// Embedding dimension accepted by `add` and `search`.
    pub fn embedding_dim(&self) -> usize {
        self.pq.config().dim
    }
    /// Borrow of the underlying product quantizer.
    pub fn quantizer(&self) -> &ProductQuantizer {
        &self.pq
    }
    /// Rough estimate of the index's memory footprint in bytes.
    ///
    /// Counts per-node storage (PQ code, optional original vector, and a flat
    /// 100-byte guess for strings/metadata/graph links), plus the codebook
    /// and, when enabled, the symmetric distance cache.
    /// NOTE(review): `code_size` assumes one byte per subquantizer code —
    /// confirm against the `PQCode` representation.
    pub fn memory_usage(&self) -> usize {
        if self.nodes.is_empty() {
            return 0;
        }
        let code_size = self.pq.config().num_subvectors;
        let original_size = if self.config.store_original {
            // Stored originals are f32s: 4 bytes per dimension.
            self.pq.config().dim * 4
        } else {
            0
        };
        // Flat per-node guess covering id/content strings, metadata and links.
        let overhead_per_node = 100;
        // Codebook: one f32 per (subvector, centroid, component).
        let codebook_size = self.pq.config().num_subvectors
            * self.pq.config().num_centroids()
            * self.pq.config().subvector_dim()
            * 4;
        // Distance cache: one f32 per (subvector, centroid, centroid) triple.
        let cache_size = if self.distance_cache.is_some() {
            self.pq.config().num_subvectors
                * self.pq.config().num_centroids()
                * self.pq.config().num_centroids()
                * 4
        } else {
            0
        };
        self.nodes.len() * (code_size + original_size + overhead_per_node)
            + codebook_size
            + cache_size
    }
pub fn compression_ratio(&self) -> f32 {
if self.nodes.is_empty() {
return 0.0;
}
let full_size = self.nodes.len() * self.pq.config().dim * 4;
let actual_size = self.memory_usage();
full_size as f32 / actual_size as f32
}
fn random_level(&self) -> usize {
let mut rng = rand::thread_rng();
let uniform: f32 = rng.gen::<f32>().max(f32::EPSILON);
(-uniform.ln() * self.config.ml).floor() as usize
}
    /// Wires an already-pushed node (`node_id`, top layer `node_level`) into
    /// the graph: greedy descent on the layers above the node's level, then
    /// connect-and-prune on every layer the node occupies.
    /// Assumes `entry_point` is `Some`.
    fn insert_node(&mut self, node_id: usize, node_level: usize) {
        let entry_point = self.entry_point.unwrap();
        let mut current_nearest = vec![entry_point];
        let node_code = self.nodes[node_id].code.clone();
        // Phase 1: layers above the node's level — ef = 1 greedy search, just
        // to find a good entry point for the layers below.
        for layer in (node_level + 1..=self.max_layer).rev() {
            current_nearest = self.search_layer_symmetric(&node_code, &current_nearest, 1, layer);
        }
        // Phase 2: on each layer the node lives on, gather ef_construction
        // candidates, keep up to m (m0 on layer 0), and link bidirectionally.
        for layer in (0..=node_level).rev() {
            current_nearest = self.search_layer_symmetric(
                &node_code,
                &current_nearest,
                self.config.ef_construction,
                layer,
            );
            let m = if layer == 0 {
                self.config.m0
            } else {
                self.config.m
            };
            let neighbors = self.select_neighbors(&current_nearest, &node_code, m);
            for &neighbor_id in &neighbors {
                self.nodes[node_id].connections[layer].insert(neighbor_id);
                // Back-link only if the neighbor actually exists on this layer.
                if layer < self.nodes[neighbor_id].connections.len() {
                    self.nodes[neighbor_id].connections[layer].insert(node_id);
                    let neighbor_m = if layer == 0 {
                        self.config.m0
                    } else {
                        self.config.m
                    };
                    // If the back-link pushed the neighbor over its degree
                    // bound, prune its adjacency back to the closest links
                    // (distances measured from the neighbor's own code).
                    if self.nodes[neighbor_id].connections[layer].len() > neighbor_m {
                        let neighbor_code = self.nodes[neighbor_id].code.clone();
                        let neighbor_connections: Vec<usize> = self.nodes[neighbor_id].connections
                            [layer]
                            .iter()
                            .copied()
                            .collect();
                        let pruned = self.select_neighbors(
                            &neighbor_connections,
                            &neighbor_code,
                            neighbor_m,
                        );
                        self.nodes[neighbor_id].connections[layer] = pruned.into_iter().collect();
                    }
                }
            }
        }
    }
    /// Best-first beam search on a single layer using ADC distances (the
    /// query-side `distance_table` against stored codes). Returns up to `ef`
    /// node ids, sorted by ascending distance.
    fn search_layer_adc(
        &self,
        distance_table: &[Vec<f32>],
        entry_points: &[usize],
        ef: usize,
        layer: usize,
    ) -> Vec<usize> {
        let mut visited = HashSet::new();
        // Min-heap of nodes still to expand (Reverse flips BinaryHeap's order).
        let mut candidates = BinaryHeap::new();
        // Max-heap of the best `ef` results so far; peek() is the current worst.
        let mut best = BinaryHeap::new();
        for &ep in entry_points {
            let dist = self
                .pq
                .distance_with_table(distance_table, &self.nodes[ep].code);
            candidates.push(Reverse((OrderedFloat(dist), ep)));
            best.push((OrderedFloat(dist), ep));
            visited.insert(ep);
        }
        while let Some(Reverse((current_dist, current_id))) = candidates.pop() {
            // Terminate once the nearest unexpanded node cannot improve `best`.
            if best.len() >= ef {
                if let Some(&(furthest_dist, _)) = best.peek() {
                    if current_dist > furthest_dist {
                        break;
                    }
                }
            }
            // Nodes whose top layer is below `layer` carry no adjacency here.
            if layer < self.nodes[current_id].connections.len() {
                for &neighbor_id in &self.nodes[current_id].connections[layer] {
                    if !visited.contains(&neighbor_id) {
                        visited.insert(neighbor_id);
                        let dist = self
                            .pq
                            .distance_with_table(distance_table, &self.nodes[neighbor_id].code);
                        let dist_ord = OrderedFloat(dist);
                        if best.len() < ef {
                            candidates.push(Reverse((dist_ord, neighbor_id)));
                            best.push((dist_ord, neighbor_id));
                        } else if let Some(&(furthest_dist, _)) = best.peek() {
                            // Admit only neighbors that beat the current worst.
                            if dist_ord < furthest_dist {
                                candidates.push(Reverse((dist_ord, neighbor_id)));
                                best.push((dist_ord, neighbor_id));
                                if best.len() > ef {
                                    best.pop();
                                }
                            }
                        }
                    }
                }
            }
        }
        // Drain the heap and return ids ordered nearest-first.
        let mut results: Vec<(f32, usize)> = best
            .into_iter()
            .map(|(OrderedFloat(dist), id)| (dist, id))
            .collect();
        results.sort_by(|a, b| a.0.total_cmp(&b.0));
        results.into_iter().map(|(_, id)| id).collect()
    }
    /// Best-first beam search on a single layer using symmetric (code-to-code)
    /// distances. Returns up to `ef` node ids, sorted by ascending distance.
    fn search_layer_symmetric(
        &self,
        query_code: &PQCode,
        entry_points: &[usize],
        ef: usize,
        layer: usize,
    ) -> Vec<usize> {
        let mut visited = HashSet::new();
        // Min-heap of nodes still to expand (Reverse flips BinaryHeap's order).
        let mut candidates = BinaryHeap::new();
        // Max-heap of the best `ef` results so far; peek() is the current worst.
        let mut best = BinaryHeap::new();
        for &ep in entry_points {
            let dist = self.distance_symmetric(query_code, &self.nodes[ep].code);
            candidates.push(Reverse((OrderedFloat(dist), ep)));
            best.push((OrderedFloat(dist), ep));
            visited.insert(ep);
        }
        while let Some(Reverse((current_dist, current_id))) = candidates.pop() {
            // Terminate once the nearest unexpanded node cannot improve `best`.
            if best.len() >= ef {
                if let Some(&(furthest_dist, _)) = best.peek() {
                    if current_dist > furthest_dist {
                        break;
                    }
                }
            }
            // Nodes whose top layer is below `layer` carry no adjacency here.
            if layer < self.nodes[current_id].connections.len() {
                for &neighbor_id in &self.nodes[current_id].connections[layer] {
                    if !visited.contains(&neighbor_id) {
                        visited.insert(neighbor_id);
                        let dist =
                            self.distance_symmetric(query_code, &self.nodes[neighbor_id].code);
                        let dist_ord = OrderedFloat(dist);
                        if best.len() < ef {
                            candidates.push(Reverse((dist_ord, neighbor_id)));
                            best.push((dist_ord, neighbor_id));
                        } else if let Some(&(furthest_dist, _)) = best.peek() {
                            // Admit only neighbors that beat the current worst.
                            if dist_ord < furthest_dist {
                                candidates.push(Reverse((dist_ord, neighbor_id)));
                                best.push((dist_ord, neighbor_id));
                                if best.len() > ef {
                                    best.pop();
                                }
                            }
                        }
                    }
                }
            }
        }
        // Drain the heap and return ids ordered nearest-first.
        let mut results: Vec<(f32, usize)> = best
            .into_iter()
            .map(|(OrderedFloat(dist), id)| (dist, id))
            .collect();
        results.sort_by(|a, b| a.0.total_cmp(&b.0));
        results.into_iter().map(|(_, id)| id).collect()
    }
fn select_neighbors(&self, candidates: &[usize], query_code: &PQCode, m: usize) -> Vec<usize> {
let mut scored: Vec<(f32, usize)> = candidates
.iter()
.map(|&id| {
let dist = self.distance_symmetric(query_code, &self.nodes[id].code);
(dist, id)
})
.collect();
scored.sort_by(|a, b| a.0.total_cmp(&b.0));
scored.truncate(m);
scored.into_iter().map(|(_, id)| id).collect()
}
fn distance_symmetric(&self, a: &PQCode, b: &PQCode) -> f32 {
if let Some(ref cache) = self.distance_cache {
cache.distance(a, b)
} else {
self.pq.symmetric_distance(a, b)
}
}
fn to_search_results_adc(
&self,
_query: &[f32],
node_ids: &[usize],
distance_table: &[Vec<f32>],
k: usize,
) -> Vec<SearchResult> {
let mut results: Vec<SearchResult> = node_ids
.iter()
.map(|&id| {
let node = &self.nodes[id];
let dist = self.pq.distance_with_table(distance_table, &node.code);
let score = 1.0 / (1.0 + dist);
SearchResult {
id: node.id.clone(),
content: node.content.clone(),
score,
metadata: node.metadata.clone(),
}
})
.collect();
results.sort_by(|a, b| b.score.total_cmp(&a.score));
results.truncate(k);
results
}
fn to_search_results_symmetric(
&self,
query_code: &PQCode,
node_ids: &[usize],
k: usize,
) -> Vec<SearchResult> {
let mut results: Vec<SearchResult> = node_ids
.iter()
.map(|&id| {
let node = &self.nodes[id];
let dist = self.distance_symmetric(query_code, &node.code);
let score = 1.0 / (1.0 + dist);
SearchResult {
id: node.id.clone(),
content: node.content.clone(),
score,
metadata: node.metadata.clone(),
}
})
.collect();
results.sort_by(|a, b| b.score.total_cmp(&a.score));
results.truncate(k);
results
}
fn rerank_with_original(
&self,
query: &[f32],
node_ids: &[usize],
k: usize,
) -> Vec<SearchResult> {
let mut results: Vec<SearchResult> = node_ids
.iter()
.filter_map(|&id| {
let node = &self.nodes[id];
node.original.as_ref().map(|original| {
let score = crate::vector::cosine_similarity(query, original).unwrap_or(0.0);
SearchResult {
id: node.id.clone(),
content: node.content.clone(),
score,
metadata: node.metadata.clone(),
}
})
})
.collect();
results.sort_by(|a, b| b.score.total_cmp(&a.score));
results.truncate(k);
results
}
}
/// Trait adapter: each method forwards to the inherent impl of the same name
/// (inherent methods take precedence over trait methods, so this does not recurse).
impl crate::index::VectorIndex for PQHNSWIndex {
    fn add(&mut self, document: Document) -> Result<()> {
        self.add(document)
    }
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        self.search(query, k)
    }
    fn len(&self) -> usize {
        self.len()
    }
    fn clear(&mut self) {
        self.clear()
    }
    fn embedding_dim(&self) -> usize {
        self.embedding_dim()
    }
}
/// Total-order wrapper over `f32` so distances can live in `BinaryHeap`s.
///
/// Equality is defined via `f32::total_cmp` so that `PartialEq` agrees with
/// `Ord`. The previously-derived IEEE `==` disagreed with `total_cmp` on NaN
/// and ±0.0, violating the `Ord` contract (`a == b` iff
/// `a.cmp(&b) == Ordering::Equal`), which ordered collections rely on.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);
impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0).is_eq()
    }
}
impl Eq for OrderedFloat {}
impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Deterministic pseudo-random vectors with components in [-1, 1),
    // seeded so tests are reproducible.
    fn generate_random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
                    .collect()
            })
            .collect()
    }
    // Minimal Document with generated content and no metadata.
    fn create_test_document(id: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: format!("Content for {}", id),
            embedding,
            metadata: None,
        }
    }
    // Default config carries the standard HNSW parameters.
    #[test]
    fn test_pq_hnsw_config() {
        let config = PQHNSWConfig::default();
        assert_eq!(config.m, 16);
        assert_eq!(config.ef_construction, 200);
        assert!(!config.store_original);
    }
    // Training produces an empty index; adding one document bumps len.
    #[test]
    fn test_pq_hnsw_basic() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(500, dim, 42);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        assert!(index.is_empty());
        let doc = create_test_document("doc1", generate_random_vectors(1, dim, 100)[0].clone());
        index.add(doc).unwrap();
        assert_eq!(index.len(), 1);
    }
    // ADC search returns exactly k results, sorted by descending score.
    #[test]
    fn test_pq_hnsw_search() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(500, dim, 42);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        for i in 0..100 {
            let doc = create_test_document(
                &format!("doc{}", i),
                generate_random_vectors(1, dim, i)[0].clone(),
            );
            index.add(doc).unwrap();
        }
        assert_eq!(index.len(), 100);
        let query = generate_random_vectors(1, dim, 999)[0].clone();
        let results = index.search(&query, 10).unwrap();
        assert_eq!(results.len(), 10);
        // Scores must be monotonically non-increasing.
        for i in 0..results.len() - 1 {
            assert!(results[i].score >= results[i + 1].score);
        }
    }
    // The PQ index should use well under half the memory of raw f32 storage.
    #[test]
    fn test_pq_hnsw_compression() {
        let dim = 384;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(500, dim, 42);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        for i in 0..5000 {
            let doc = create_test_document(
                &format!("doc{}", i),
                generate_random_vectors(1, dim, i)[0].clone(),
            );
            index.add(doc).unwrap();
        }
        let memory = index.memory_usage();
        let full_size = 5000 * dim * 4;
        println!("PQ HNSW memory: {} bytes", memory);
        println!("Full f32 would be: {} bytes", full_size);
        println!("Compression ratio: {:.1}x", index.compression_ratio());
        assert!(
            memory < full_size / 2,
            "Memory should be < 50% of full size, got {}%",
            (memory as f64 / full_size as f64 * 100.0) as u32
        );
    }
    // With reranking enabled, scores are cosine similarities in [-1, 1].
    #[test]
    fn test_pq_hnsw_with_reranking() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(500, dim, 42);
        let config = PQHNSWConfig::default().with_reranking(50);
        let mut index = PQHNSWIndex::train(pq_config, &training_data, config).unwrap();
        for i in 0..100 {
            let doc = create_test_document(
                &format!("doc{}", i),
                generate_random_vectors(1, dim, i)[0].clone(),
            );
            index.add(doc).unwrap();
        }
        let query = generate_random_vectors(1, dim, 999)[0].clone();
        let results = index.search(&query, 10).unwrap();
        assert_eq!(results.len(), 10);
        for result in &results {
            assert!(result.score >= -1.0 && result.score <= 1.0);
        }
    }
    // Adding a wrong-dimension embedding is rejected.
    #[test]
    fn test_pq_hnsw_dimension_mismatch() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(100, dim, 42);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        let doc = create_test_document("doc1", vec![0.5; 32]);
        assert!(index.add(doc).is_err());
    }
    // Symmetric and ADC search both return k results on the same index.
    #[test]
    fn test_pq_hnsw_symmetric_search() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(500, dim, 42);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        for i in 0..100 {
            let doc = create_test_document(
                &format!("doc{}", i),
                generate_random_vectors(1, dim, i)[0].clone(),
            );
            index.add(doc).unwrap();
        }
        let query = generate_random_vectors(1, dim, 999)[0].clone();
        let symmetric_results = index.search_symmetric(&query, 10).unwrap();
        assert_eq!(symmetric_results.len(), 10);
        let adc_results = index.search(&query, 10).unwrap();
        assert_eq!(adc_results.len(), 10);
    }
    // NaN embeddings are rejected and leave the index unchanged.
    #[test]
    fn test_pq_hnsw_add_nan_embedding_rejected() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(42)
            .with_kmeans_iterations(10);
        let training_data = generate_random_vectors(100, dim, 42);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        let mut embedding = vec![0.5; dim];
        embedding[10] = f32::NAN;
        let doc = create_test_document("nan_doc", embedding);
        assert!(index.add(doc).is_err());
        assert_eq!(index.len(), 0);
    }
    // A NaN query must not panic (it may return arbitrary results or an error).
    #[test]
    fn test_pq_search_with_nan_query_does_not_panic() {
        let dim = 64;
        let pq_config = PQConfig::new(dim, 8, 8)
            .with_seed(7)
            .with_kmeans_iterations(5);
        let training_data = generate_random_vectors(256, dim, 11);
        let mut index =
            PQHNSWIndex::train(pq_config, &training_data, PQHNSWConfig::default()).unwrap();
        index
            .add(create_test_document(
                "doc1",
                generate_random_vectors(1, dim, 99)[0].clone(),
            ))
            .unwrap();
        let query = vec![f32::NAN; dim];
        let outcome = std::panic::catch_unwind(|| index.search(&query, 1));
        assert!(
            outcome.is_ok(),
            "PQ search panicked when query contains NaN"
        );
    }
}