#![allow(clippy::similar_names)]
#![allow(unused_variables)]
use crate::{Result, Error};
use super::{
Vector, DistanceMetric, ProductQuantizer, ProductQuantizerConfig,
QuantizedVector,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
/// Configuration for a `QuantizedHnswIndex`.
///
/// This struct is embedded in the bincode payload produced by
/// `QuantizedHnswIndex::to_bytes`; bincode encodes struct fields positionally,
/// so reordering or inserting fields breaks previously serialized indexes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedHnswConfig {
/// NOTE(review): not read by `QuantizedHnswIndex` in this file (search is an
/// exhaustive scan); presumably the HNSW per-node connection limit — confirm.
pub max_connections: usize,
/// NOTE(review): not read in this file; presumably HNSW build-time candidate
/// list size — confirm.
pub ef_construction: usize,
/// NOTE(review): not read in this file; presumably HNSW query-time candidate
/// list size — confirm.
pub ef_search: usize,
/// Required length of every inserted and queried vector; must equal the
/// product quantizer's dimension (checked in `QuantizedHnswIndex::new`).
pub dimension: usize,
/// NOTE(review): not read in this file — distances come from the product
/// quantizer's precomputed tables. Confirm whether it is honoured elsewhere.
pub distance_metric: DistanceMetric,
/// Settings passed to `ProductQuantizer::train` / `ProductQuantizer::new`.
pub pq_config: ProductQuantizerConfig,
/// When `true`, inserts keep only the PQ code; when `false`, a full-precision
/// copy of each inserted vector is retained as well.
pub use_pq_storage: bool,
}
impl QuantizedHnswConfig {
pub fn default_for_dimension(dimension: usize) -> Result<Self> {
let pq_config = ProductQuantizerConfig::default_for_dimension(dimension)
.map_err(|e| Error::query_execution(format!("PQ config error: {}", e)))?;
Ok(Self {
max_connections: 16,
ef_construction: 200,
ef_search: 200,
dimension,
distance_metric: DistanceMetric::L2,
pq_config,
use_pq_storage: true, })
}
#[cfg(test)]
pub fn test_for_dimension(dimension: usize) -> Result<Self> {
let pq_config = ProductQuantizerConfig::test_config(dimension)
.map_err(|e| Error::query_execution(format!("PQ config error: {}", e)))?;
Ok(Self {
max_connections: 16,
ef_construction: 200,
ef_search: 200,
dimension,
distance_metric: DistanceMetric::L2,
pq_config,
use_pq_storage: true,
})
}
}
impl Default for QuantizedHnswConfig {
    /// Configuration for 768-dimensional vectors.
    ///
    /// Uses 16 max connections, an `ef` of 200 for both construction and
    /// search, L2 distance, PQ-only storage and `ProductQuantizerConfig::default()`
    /// for the quantizer.
    ///
    /// NOTE(review): `QuantizedHnswIndex::train` passes `pq_config` to the
    /// quantizer, and `QuantizedHnswIndex::new` rejects a quantizer whose
    /// dimension differs from `dimension`. Confirm that
    /// `ProductQuantizerConfig::default()` targets 768 dimensions.
    fn default() -> Self {
        let shared_ef = 200;
        Self {
            dimension: 768,
            max_connections: 16,
            ef_construction: shared_ef,
            ef_search: shared_ef,
            distance_metric: DistanceMetric::L2,
            use_pq_storage: true,
            pq_config: ProductQuantizerConfig::default(),
        }
    }
}
/// Product-quantized vector index keyed by external `u64` row ids.
///
/// Vectors are appended as PQ codes. A vector's "internal id" is its position
/// in the per-slot vectors below. Each piece of mutable state sits behind its
/// own `Arc<RwLock<..>>`.
pub struct QuantizedHnswIndex {
    /// Index settings, fixed at construction.
    config: QuantizedHnswConfig,
    /// Trained quantizer used for encoding and for per-query distance tables.
    pq: Arc<ProductQuantizer>,

    /// Internal id -> external row id.
    id_mapping: Arc<RwLock<Vec<u64>>>,
    /// External row id -> internal id.
    reverse_mapping: Arc<RwLock<HashMap<u64, usize>>>,

    /// PQ code for each internal id.
    quantized_vectors: Arc<RwLock<Vec<QuantizedVector>>>,
    /// Optional full-precision copy for each internal id (`None` when only PQ
    /// codes are stored).
    original_vectors: Arc<RwLock<Vec<Option<Vector>>>>,

    /// Adjacency lists grouped by layer (node -> neighbours).
    /// NOTE(review): never populated by the index code in this file.
    graph: Arc<RwLock<Vec<HashMap<usize, Vec<usize>>>>>,
    /// Internal id of the first inserted vector, if any.
    /// NOTE(review): not consulted by `search` here.
    entry_point: Arc<RwLock<Option<usize>>>,
}
impl QuantizedHnswIndex {
pub fn new(config: QuantizedHnswConfig, pq: ProductQuantizer) -> Result<Self> {
if pq.config().dimension != config.dimension {
return Err(Error::query_execution(format!(
"PQ dimension {} doesn't match config dimension {}",
pq.config().dimension,
config.dimension
)));
}
Ok(Self {
config,
pq: Arc::new(pq),
quantized_vectors: Arc::new(RwLock::new(Vec::new())),
original_vectors: Arc::new(RwLock::new(Vec::new())),
graph: Arc::new(RwLock::new(Vec::new())),
id_mapping: Arc::new(RwLock::new(Vec::new())),
reverse_mapping: Arc::new(RwLock::new(HashMap::new())),
entry_point: Arc::new(RwLock::new(None)),
})
}
pub fn train(
config: QuantizedHnswConfig,
training_vectors: &[Vector],
) -> Result<Self> {
let pq = ProductQuantizer::train(config.pq_config.clone(), training_vectors)
.map_err(|e| Error::query_execution(format!("PQ training failed: {}", e)))?;
Self::new(config, pq)
}
pub fn insert(&self, row_id: u64, vector: &Vector) -> Result<()> {
if vector.len() != self.config.dimension {
return Err(Error::query_execution(format!(
"Vector dimension mismatch: expected {}, got {}",
self.config.dimension,
vector.len()
)));
}
let quantized = self.pq.encode(vector)
.map_err(|e| Error::query_execution(format!("Encoding failed: {}", e)))?;
let mut quantized_vectors = self.quantized_vectors.write();
let mut original_vectors = self.original_vectors.write();
let mut id_mapping = self.id_mapping.write();
let mut reverse_mapping = self.reverse_mapping.write();
let internal_id = quantized_vectors.len();
quantized_vectors.push(quantized);
if !self.config.use_pq_storage {
original_vectors.push(Some(vector.clone()));
} else {
original_vectors.push(None);
}
id_mapping.push(row_id);
reverse_mapping.insert(row_id, internal_id);
let mut entry_point = self.entry_point.write();
if entry_point.is_none() {
*entry_point = Some(internal_id);
}
Ok(())
}
pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(u64, f32)>> {
if query.len() != self.config.dimension {
return Err(Error::query_execution(format!(
"Query vector dimension mismatch: expected {}, got {}",
self.config.dimension,
query.len()
)));
}
let quantized_vectors = self.quantized_vectors.read();
let id_mapping = self.id_mapping.read();
if quantized_vectors.is_empty() {
return Ok(Vec::new());
}
let distance_table = self.pq.precompute_distance_table(query)
.map_err(|e| Error::query_execution(format!("Distance table computation failed: {}", e)))?;
let mut distances: Vec<(usize, f32)> = quantized_vectors
.iter()
.enumerate()
.filter_map(|(idx, qv)| {
self.pq
.compute_distance_with_table(&distance_table, qv)
.ok()
.map(|dist| (idx, dist))
})
.collect();
distances.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
});
distances.truncate(k);
let results: Vec<(u64, f32)> = distances
.into_iter()
.filter_map(|(internal_id, dist)| {
id_mapping.get(internal_id).map(|&row_id| (row_id, dist))
})
.collect();
Ok(results)
}
pub fn memory_stats(&self) -> MemoryStats {
let quantized_vectors = self.quantized_vectors.read();
let original_vectors = self.original_vectors.read();
let id_mapping = self.id_mapping.read();
let reverse_mapping = self.reverse_mapping.read();
let graph = self.graph.read();
let num_vectors = quantized_vectors.len();
let vector_dimension = self.config.dimension;
let quantized_size = num_vectors * self.pq.memory_per_vector();
let original_size = if self.config.use_pq_storage {
0
} else {
num_vectors * vector_dimension * std::mem::size_of::<f32>()
};
let codebook_size = self.pq.codebook_size();
let id_mapping_size = id_mapping.len() * std::mem::size_of::<u64>();
let reverse_mapping_size = reverse_mapping.len() *
(std::mem::size_of::<u64>() + std::mem::size_of::<usize>() + 16);
let mut graph_size = graph.len() * std::mem::size_of::<HashMap<usize, Vec<usize>>>();
for layer in graph.iter() {
graph_size += layer.len() * (std::mem::size_of::<usize>() + std::mem::size_of::<Vec<usize>>());
for neighbors in layer.values() {
graph_size += neighbors.len() * std::mem::size_of::<usize>();
}
}
let entry_point_size = std::mem::size_of::<Option<usize>>();
let metadata_size = id_mapping_size + reverse_mapping_size + graph_size + entry_point_size;
let total_size = quantized_size + original_size + codebook_size + metadata_size;
let uncompressed_size = num_vectors * vector_dimension * std::mem::size_of::<f32>();
let compression_ratio = if total_size > 0 {
uncompressed_size as f32 / total_size as f32
} else {
0.0
};
MemoryStats {
num_vectors,
quantized_size,
original_size,
codebook_size,
total_size,
compression_ratio,
}
}
pub fn product_quantizer(&self) -> Arc<ProductQuantizer> {
self.pq.clone()
}
pub fn config(&self) -> &QuantizedHnswConfig {
&self.config
}
pub fn delete(&self, row_id: u64) -> Result<()> {
let mut reverse_mapping = self.reverse_mapping.write();
if reverse_mapping.remove(&row_id).is_some() {
Ok(())
} else {
Err(Error::query_execution(format!(
"Vector with row_id {} not found in index",
row_id
)))
}
}
pub fn len(&self) -> usize {
self.quantized_vectors.read().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
use bincode;
#[derive(Serialize, Deserialize)]
struct SerializedIndex {
config: QuantizedHnswConfig,
codebook: super::Codebook,
quantized_vectors: Vec<super::QuantizedVector>,
id_mapping: Vec<u64>,
}
let codebook = (*self.pq.codebook()).clone();
let quantized_vectors = self.quantized_vectors.read().clone();
let id_mapping = self.id_mapping.read().clone();
let serialized = SerializedIndex {
config: self.config.clone(),
codebook,
quantized_vectors,
id_mapping,
};
bincode::serialize(&serialized)
.map_err(|e| Error::query_execution(format!("Failed to serialize index: {}", e)))
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
use bincode;
#[derive(Serialize, Deserialize)]
struct SerializedIndex {
config: QuantizedHnswConfig,
codebook: super::Codebook,
quantized_vectors: Vec<super::QuantizedVector>,
id_mapping: Vec<u64>,
}
let serialized: SerializedIndex = bincode::deserialize(bytes)
.map_err(|e| Error::query_execution(format!("Failed to deserialize index: {}", e)))?;
let pq = ProductQuantizer::new(serialized.config.pq_config.clone(), serialized.codebook)
.map_err(|e| Error::query_execution(format!("Failed to create PQ: {}", e)))?;
let mut reverse_mapping = HashMap::new();
for (internal_id, &row_id) in serialized.id_mapping.iter().enumerate() {
reverse_mapping.insert(row_id, internal_id);
}
let entry_point = if !serialized.id_mapping.is_empty() {
Some(0)
} else {
None
};
Ok(Self {
config: serialized.config,
pq: Arc::new(pq),
quantized_vectors: Arc::new(RwLock::new(serialized.quantized_vectors)),
original_vectors: Arc::new(RwLock::new(vec![None; serialized.id_mapping.len()])),
graph: Arc::new(RwLock::new(Vec::new())),
id_mapping: Arc::new(RwLock::new(serialized.id_mapping)),
reverse_mapping: Arc::new(RwLock::new(reverse_mapping)),
entry_point: Arc::new(RwLock::new(entry_point)),
})
}
}
/// Approximate memory usage of a `QuantizedHnswIndex`, as returned by
/// `QuantizedHnswIndex::memory_stats`. Sizes are in bytes; `format` scales
/// them to KB/MB for display.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
/// Number of stored vector slots.
pub num_vectors: usize,
/// Bytes for PQ codes (`num_vectors * memory_per_vector()`).
pub quantized_size: usize,
/// Bytes attributed to full-precision vector copies (0 when only PQ codes are kept).
pub original_size: usize,
/// Value reported by the product quantizer's `codebook_size()`.
pub codebook_size: usize,
/// Sum of quantized, original and codebook sizes plus estimated metadata
/// (id maps, graph, entry point).
pub total_size: usize,
/// Uncompressed `f32` size divided by `total_size`; `0.0` when `total_size` is 0.
pub compression_ratio: f32,
}
impl MemoryStats {
    /// Renders the statistics as a single human-readable line.
    ///
    /// Quantized, original and total sizes are shown in MB and the codebook in
    /// KB, each with two decimals. The compression ratio is shown with one
    /// decimal.
    pub fn format(&self) -> String {
        const KIB: f64 = 1024.0;
        let as_kib = |bytes: usize| bytes as f64 / KIB;
        let as_mib = |bytes: usize| as_kib(bytes) / KIB;

        format!(
            "Vectors: {}, Quantized: {:.2} MB, Original: {:.2} MB, Codebook: {:.2} KB, Total: {:.2} MB, Compression: {:.1}x",
            self.num_vectors,
            as_mib(self.quantized_size),
            as_mib(self.original_size),
            as_kib(self.codebook_size),
            as_mib(self.total_size),
            self.compression_ratio
        )
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Number of random vectors used to train the quantizer in every test.
    const TRAINING_SET_SIZE: usize = 2000;

    /// Builds `count` vectors of length `dimension` with components drawn
    /// uniformly from `[-1.0, 1.0)`.
    fn generate_random_vectors(count: usize, dimension: usize) -> Vec<Vector> {
        let mut rng = rand::thread_rng();
        let mut vectors = Vec::with_capacity(count);
        for _ in 0..count {
            let vector: Vector = (0..dimension).map(|_| rng.gen_range(-1.0..1.0)).collect();
            vectors.push(vector);
        }
        vectors
    }

    /// Trains an index on a fresh random training set and inserts the first
    /// `insert_count` training vectors under row ids `0..insert_count`.
    /// Returns the index together with the training set.
    fn build_index(dimension: usize, insert_count: usize) -> (QuantizedHnswIndex, Vec<Vector>) {
        let config = QuantizedHnswConfig::test_for_dimension(dimension).unwrap();
        let training = generate_random_vectors(TRAINING_SET_SIZE, dimension);
        let index = QuantizedHnswIndex::train(config, &training).unwrap();
        for (row_id, vector) in (0_u64..).zip(training.iter().take(insert_count)) {
            index.insert(row_id, vector).unwrap();
        }
        (index, training)
    }

    #[test]
    fn test_quantized_hnsw_creation() {
        let config = QuantizedHnswConfig::test_for_dimension(128).unwrap();
        let training = generate_random_vectors(TRAINING_SET_SIZE, 128);
        assert!(QuantizedHnswIndex::train(config, &training).is_ok());
    }

    #[test]
    fn test_quantized_hnsw_insert_search() {
        let (index, training) = build_index(128, 100);
        assert_eq!(index.len(), 100);

        let hits = index.search(&training[0], 5).unwrap();
        assert!(!hits.is_empty());
        assert!(hits.len() <= 5);
        // The query itself is indexed under row id 0, so it should be among the hits.
        assert!(hits.iter().any(|&(row_id, _)| row_id == 0));
    }

    #[test]
    fn test_memory_stats() {
        let (index, _training) = build_index(768, 100);

        let stats = index.memory_stats();
        assert_eq!(stats.num_vectors, 100);
        assert!(stats.compression_ratio > 1.0);
        assert!(stats.total_size > 0);
        println!("{}", stats.format());
    }
}