use super::{QuantizationParams, QuantizedVector, QuantizedDistance, AsymmetricCosine};
use hnsw_rs::prelude::{Distance, Hnsw, Neighbour};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// HNSW index over 8-bit quantized vectors, addressed by caller-supplied string ids.
pub struct QuantizedHNSW {
/// Underlying graph over `u8` vectors, using `QuantizedDistanceWrapper` as its distance.
hnsw: Hnsw<u8, QuantizedDistanceWrapper>,
/// Parameters used to quantize inserted vectors and search queries.
quant_params: QuantizationParams,
/// Length every inserted vector and every query must have.
dimension: usize,
/// Internal numeric id -> caller-supplied node id.
id_map: HashMap<usize, String>,
/// Caller-supplied node id -> internal numeric id.
reverse_map: HashMap<String, usize>,
/// Internal id handed to the next inserted vector; incremented on every insert.
next_id: usize,
/// Number of ids currently mapped (incremented by `insert`, decremented by `remove`).
count: usize,
}
/// Distance functor over quantized (`u8`) vectors: carries the quantization
/// parameters needed to dequantize components and the metric to apply.
#[derive(Clone, Debug)]
pub struct QuantizedDistanceWrapper {
/// Quantization parameters used to dequantize each stored byte.
pub params: QuantizationParams,
/// Vector length the asymmetric distance helper asserts against.
pub dimension: usize,
/// Metric selected for distance evaluation.
pub metric: DistanceMetric,
}
/// Metric applied by `QuantizedDistanceWrapper` when comparing two vectors.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum DistanceMetric {
/// Cosine distance (the default metric).
Cosine,
/// Euclidean (L2) distance.
L2,
/// Dot-product based distance.
Dot,
}
impl Default for DistanceMetric {
fn default() -> Self {
DistanceMetric::Cosine
}
}
impl QuantizedDistanceWrapper {
    /// Builds a wrapper for `dimension`-long vectors quantized with `params`.
    ///
    /// The metric starts out as [`DistanceMetric::Cosine`]; chain
    /// [`with_metric`](Self::with_metric) to select a different one.
    pub fn new(params: QuantizationParams, dimension: usize) -> Self {
        let metric = DistanceMetric::Cosine;
        Self {
            params,
            dimension,
            metric,
        }
    }

    /// Consumes the wrapper and returns it with `metric` selected.
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }

    /// Dequantizes every byte of `quantized` into a freshly allocated `Vec<f32>`.
    #[inline]
    fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
        let mut values = Vec::with_capacity(quantized.len());
        for &code in quantized {
            values.push(self.params.dequantize(code));
        }
        values
    }

    /// Distance between a full-precision `query` and a quantized `stored` vector,
    /// computed with the asymmetric implementation matching the selected metric.
    ///
    /// # Panics
    ///
    /// Panics if `query` or `stored` does not have exactly `self.dimension` elements.
    fn asymmetric_distance(&self, query: &[f32], stored: &[u8]) -> f32 {
        assert_eq!(query.len(), self.dimension);
        assert_eq!(stored.len(), self.dimension);

        // The asymmetric helpers take a `QuantizedVector`, so the stored bytes are copied into one.
        let stored_bytes = stored.to_vec();
        let stored_qv = QuantizedVector::new(stored_bytes, self.params, self.dimension);

        let metric = self.metric;
        match metric {
            DistanceMetric::Cosine => AsymmetricCosine::asymmetric_distance(query, &stored_qv),
            DistanceMetric::L2 => super::AsymmetricL2::asymmetric_distance(query, &stored_qv),
            DistanceMetric::Dot => super::AsymmetricDot::asymmetric_distance(query, &stored_qv),
        }
    }
}
impl Default for QuantizedDistanceWrapper {
    /// Default wrapper: default quantization parameters, 768 dimensions, cosine metric.
    fn default() -> Self {
        // `new` already selects `DistanceMetric::Cosine`.
        Self::new(QuantizationParams::default(), 768)
    }
}
impl Distance<u8> for QuantizedDistanceWrapper {
fn eval(&self, va: &[u8], vb: &[u8]) -> f32 {
let dequantized_a: Vec<f32> = va.iter().map(|&v| self.params.dequantize(v)).collect();
let dequantized_b: Vec<f32> = vb.iter().map(|&v| self.params.dequantize(v)).collect();
match self.metric {
DistanceMetric::Cosine => cosine_distance(&dequantized_a, &dequantized_b),
DistanceMetric::L2 => l2_distance(&dequantized_a, &dequantized_b),
DistanceMetric::Dot => dot_distance(&dequantized_a, &dequantized_b),
}
}
}
/// Cosine distance `1 - cos(a, b)`, clamped so it is never negative.
///
/// Returns `1.0` when either vector has zero norm (which includes empty input).
/// Only the first `a.len()` elements of `b` are used.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // One up-front bounds check; trailing elements of `b` are ignored.
    let b = &b[..a.len()];

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());
    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    // Rounding can push the similarity slightly above 1.0; keep the distance non-negative.
    (1.0 - dot / (norm_a * norm_b)).max(0.0)
}
/// Euclidean distance between `a` and the first `a.len()` elements of `b`.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    // One up-front bounds check; trailing elements of `b` are ignored.
    let b = &b[..a.len()];

    let squared_sum = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let delta = x - y;
        acc + delta * delta
    });
    squared_sum.sqrt()
}
/// Dot-product distance `1 - a·b`, clamped so it is never negative.
///
/// Every pair whose dot product is at least `1.0` therefore maps to distance `0.0`.
/// Only the first `a.len()` elements of `b` are used.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
fn dot_distance(a: &[f32], b: &[f32]) -> f32 {
    // One up-front bounds check; trailing elements of `b` are ignored.
    let b = &b[..a.len()];

    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    (1.0 - dot).max(0.0)
}
impl QuantizedHNSW {
    /// Lower bound for the `ef` search parameter handed to the graph search.
    const MIN_EF_SEARCH: usize = 50;

    /// Creates an empty index for `dimension`-long vectors.
    ///
    /// `quant_params` is used both by the graph's distance function and to
    /// quantize inserted vectors and queries. `max_elements` is forwarded to
    /// `Hnsw::new`. The distance wrapper is built with its default (cosine) metric.
    pub fn new(
        dimension: usize,
        quant_params: QuantizationParams,
        max_elements: usize,
    ) -> Self {
        let distance = QuantizedDistanceWrapper::new(quant_params, dimension);

        // Fixed graph construction parameters passed to `Hnsw::new`.
        let m = 16;
        let max_layer = 16;
        let ef_construction = 200;
        let hnsw = Hnsw::new(m, max_elements, max_layer, ef_construction, distance);

        Self {
            hnsw,
            quant_params,
            dimension,
            id_map: HashMap::new(),
            reverse_map: HashMap::new(),
            next_id: 0,
            count: 0,
        }
    }

    /// Quantizes `vector` and adds it to the index under `node_id`.
    ///
    /// # Errors
    ///
    /// * [`QuantizedHnswError::DimensionMismatch`] if `vector.len()` differs from
    ///   the index dimension.
    /// * [`QuantizedHnswError::NodeExists`] if `node_id` is already mapped.
    ///
    /// Nothing is modified when an error is returned.
    pub fn insert(&mut self, node_id: String, vector: &[f32]) -> Result<(), QuantizedHnswError> {
        if vector.len() != self.dimension {
            return Err(QuantizedHnswError::DimensionMismatch {
                expected: self.dimension,
                got: vector.len(),
            });
        }
        if self.reverse_map.contains_key(&node_id) {
            return Err(QuantizedHnswError::NodeExists(node_id));
        }

        let quantized: Vec<u8> = vector
            .iter()
            .map(|&v| self.quant_params.quantize(v))
            .collect();

        // Internal ids are never reused, even after `remove`.
        let internal_id = self.next_id;
        self.next_id += 1;

        self.hnsw.insert((&quantized, internal_id));
        self.id_map.insert(internal_id, node_id.clone());
        self.reverse_map.insert(node_id, internal_id);
        self.count += 1;
        Ok(())
    }

    /// Returns up to `top_k` `(node_id, score)` pairs for `query`, in the order
    /// produced by the graph search, where `score` is `1.0 - distance`.
    ///
    /// The query is quantized with the index parameters before searching.
    /// An empty vector is returned when `top_k` is zero, when the query length
    /// differs from the index dimension, or when no ids are mapped.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
        if top_k == 0 || query.len() != self.dimension || self.count == 0 {
            return Vec::new();
        }

        let query_quantized: Vec<u8> = query
            .iter()
            .map(|&v| self.quant_params.quantize(v))
            .collect();

        // `remove` only forgets the id mapping; the vector stays in the graph and
        // is filtered out below. Ask for that many extra neighbours so removed
        // entries cannot crowd live ones out of the requested `top_k`.
        let removed = self.next_id.saturating_sub(self.count);
        let fetch_k = top_k.saturating_add(removed);
        let ef_search = Self::MIN_EF_SEARCH.max(fetch_k);

        self.hnsw
            .search(&query_quantized, fetch_k, ef_search)
            .into_iter()
            .filter_map(|neighbour| {
                self.id_map
                    .get(&neighbour.d_id)
                    .map(|node_id| (node_id.clone(), 1.0 - neighbour.distance))
            })
            .take(top_k)
            .collect()
    }

    /// Currently identical to [`search`](Self::search): the call is forwarded
    /// unchanged, so the same guards and the same graph search apply.
    pub fn search_asymmetric_brute_force(
        &self,
        query: &[f32],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        self.search(query, top_k)
    }

    /// Number of ids currently mapped (inserted and not removed).
    pub fn len(&self) -> usize {
        self.count
    }

    /// `true` when no ids are currently mapped.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Vector length this index was created with.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Quantization parameters this index was created with.
    pub fn quantization_params(&self) -> &QuantizationParams {
        &self.quant_params
    }

    /// Forgets `node_id`, returning `true` if it was present.
    ///
    /// Only the id mappings are dropped; the graph itself is not touched, so the
    /// vector remains stored there and is skipped by [`search`](Self::search).
    pub fn remove(&mut self, node_id: &str) -> bool {
        match self.reverse_map.remove(node_id) {
            Some(internal_id) => {
                self.id_map.remove(&internal_id);
                self.count -= 1;
                true
            }
            None => false,
        }
    }

    /// Rough memory estimate in bytes: one byte per component of every vector
    /// handed to the graph, plus a fixed per-entry cost for the id map.
    ///
    /// Graph links and the heap bytes of the id strings are not included.
    pub fn estimated_memory_bytes(&self) -> usize {
        // `next_id` counts every vector ever inserted; removed ones are still
        // held by the graph because `remove` never touches it.
        let vector_data = self.next_id * self.dimension;
        let per_entry = std::mem::size_of::<usize>() + std::mem::size_of::<String>();
        let overhead = self.id_map.len() * per_entry;
        vector_data + overhead
    }
}
/// Errors reported by the quantized HNSW index types in this file.
#[derive(Debug, thiserror::Error)]
pub enum QuantizedHnswError {
/// The supplied vector's length differs from the index dimension.
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch {
/// Dimension the index was created with.
expected: usize,
/// Length of the vector that was supplied.
got: usize,
},
/// An insert used a node id that is already present; carries that id.
#[error("Node {0} already exists")]
NodeExists(String),
/// A node id was not found; carries that id.
/// NOTE(review): not constructed anywhere in the visible code.
#[error("Node {0} not found")]
NodeNotFound(String),
/// A quantization failure described by the contained message.
/// NOTE(review): not constructed anywhere in the visible code.
#[error("Quantization error: {0}")]
QuantizationError(String),
}
/// Quantized HNSW index that can additionally keep full-precision copies of
/// inserted vectors and use them to re-score search candidates.
pub struct HybridQuantizedIndex {
/// Quantized index that performs candidate retrieval.
hnsw: QuantizedHNSW,
/// Full-precision vectors by node id, kept only for inserts that asked for it.
original_vectors: HashMap<String, Vec<f32>>,
/// Whether `search` re-scores candidates; starts out disabled.
use_reranking: bool,
}
impl HybridQuantizedIndex {
    /// Creates an empty hybrid index with reranking disabled.
    ///
    /// All three arguments are forwarded to `QuantizedHNSW::new`.
    pub fn new(
        dimension: usize,
        quant_params: QuantizationParams,
        max_elements: usize,
    ) -> Self {
        Self {
            hnsw: QuantizedHNSW::new(dimension, quant_params, max_elements),
            original_vectors: HashMap::new(),
            use_reranking: false,
        }
    }

    /// Turns candidate re-scoring on for subsequent searches.
    pub fn enable_reranking(&mut self) {
        self.use_reranking = true;
    }

    /// Turns candidate re-scoring off for subsequent searches.
    pub fn disable_reranking(&mut self) {
        self.use_reranking = false;
    }

    /// Inserts `vector` under `node_id`; when `store_original` is set, a
    /// full-precision copy is kept for reranking.
    ///
    /// # Errors
    ///
    /// Propagates the error of the inner index insert (dimension mismatch or
    /// duplicate id). The copy is stored only after that insert succeeded, so a
    /// rejected insert can never replace the stored original of an existing node.
    pub fn insert(
        &mut self,
        node_id: String,
        vector: &[f32],
        store_original: bool,
    ) -> Result<(), QuantizedHnswError> {
        if !store_original {
            return self.hnsw.insert(node_id, vector);
        }

        // Validate and index first; record the original only on success.
        self.hnsw.insert(node_id.clone(), vector)?;
        self.original_vectors.insert(node_id, vector.to_vec());
        Ok(())
    }

    /// Returns up to `top_k` `(node_id, score)` pairs, best score first when
    /// reranking is active.
    ///
    /// With reranking disabled, or when no originals are stored, the call is
    /// forwarded to the inner index unchanged. Otherwise `2 * top_k` candidates
    /// are fetched, candidates with a stored original are re-scored as
    /// `1.0 - asymmetric cosine distance` against the unquantized query,
    /// candidates without one keep the score from the inner index, and the
    /// list is sorted by descending score and cut to `top_k`.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
        if !self.use_reranking || self.original_vectors.is_empty() {
            return self.hnsw.search(query, top_k);
        }

        let params = *self.hnsw.quantization_params();
        let candidates = self.hnsw.search(query, top_k.saturating_mul(2));

        let mut reranked: Vec<(String, f32)> = candidates
            .into_iter()
            .map(|(node_id, approx_score)| {
                let score = match self.original_vectors.get(&node_id) {
                    Some(stored) => {
                        // NOTE(review): the stored original is quantized again here, so the
                        // re-score is asymmetric (f32 query vs. quantized vector), not a
                        // full-precision comparison — confirm this is intended.
                        let stored_qv = QuantizedVector::from_f32(stored, params);
                        1.0 - AsymmetricCosine::asymmetric_distance(query, &stored_qv)
                    }
                    // No original kept for this node: keep it with its approximate
                    // score instead of silently dropping it from the results.
                    None => approx_score,
                };
                (node_id, score)
            })
            .collect();

        reranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        reranked.truncate(top_k);
        reranked
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dimension`-long vector whose components all equal `value`.
    fn create_test_vector(dimension: usize, value: f32) -> Vec<f32> {
        std::iter::repeat(value).take(dimension).collect()
    }

    /// Quantization parameters shared by every test: built from the range `0.0..=1.0`.
    fn unit_range_params() -> QuantizationParams {
        QuantizationParams::from_min_max(0.0, 1.0)
    }

    /// A freshly created index reports its dimension and is empty.
    #[test]
    fn test_quantized_hnsw_creation() {
        let index = QuantizedHNSW::new(128, unit_range_params(), 1000);

        assert_eq!(index.dimension(), 128);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    /// A successful insert is reflected by `len` and `is_empty`.
    #[test]
    fn test_quantized_hnsw_insert() {
        let mut index = QuantizedHNSW::new(128, unit_range_params(), 1000);
        let sample = create_test_vector(128, 0.5);

        index.insert("test".to_string(), &sample).unwrap();

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }

    /// Inserting a vector of the wrong length is rejected.
    #[test]
    fn test_quantized_hnsw_dimension_mismatch() {
        let mut index = QuantizedHNSW::new(128, unit_range_params(), 1000);
        let too_short = create_test_vector(64, 0.5);

        let outcome = index.insert("test".to_string(), &too_short);

        assert!(matches!(
            outcome,
            Err(QuantizedHnswError::DimensionMismatch { .. })
        ));
    }

    /// Re-using a node id is rejected.
    #[test]
    fn test_quantized_hnsw_duplicate_insert() {
        let mut index = QuantizedHNSW::new(128, unit_range_params(), 1000);
        let sample = create_test_vector(128, 0.5);
        index.insert("test".to_string(), &sample).unwrap();

        let outcome = index.insert("test".to_string(), &sample);

        assert!(matches!(outcome, Err(QuantizedHnswError::NodeExists(_))));
    }

    /// Searching a populated index yields between one and `top_k` hits.
    #[test]
    fn test_quantized_hnsw_search() {
        let mut index = QuantizedHNSW::new(64, unit_range_params(), 1000);
        for step in 0..10 {
            let sample = create_test_vector(64, step as f32 / 10.0);
            index.insert(format!("node_{}", step), &sample).unwrap();
        }

        let probe = create_test_vector(64, 0.5);
        let hits = index.search(&probe, 5);

        assert!(!hits.is_empty());
        assert!(hits.len() <= 5);
    }

    /// Searching an empty index yields nothing.
    #[test]
    fn test_quantized_hnsw_empty_search() {
        let index = QuantizedHNSW::new(64, unit_range_params(), 1000);
        let probe = create_test_vector(64, 0.5);

        let hits = index.search(&probe, 10);

        assert!(hits.is_empty());
    }

    /// `remove` reports whether the id existed and updates `len`.
    #[test]
    fn test_quantized_hnsw_remove() {
        let mut index = QuantizedHNSW::new(64, unit_range_params(), 1000);
        let sample = create_test_vector(64, 0.5);
        index.insert("test".to_string(), &sample).unwrap();
        assert_eq!(index.len(), 1);

        assert!(index.remove("test"));
        assert_eq!(index.len(), 0);
        assert!(!index.remove("nonexistent"));
    }

    /// The distance wrapper never produces a negative distance for this pair.
    #[test]
    fn test_distance_wrapper_eval() {
        let params = unit_range_params();
        let wrapper = QuantizedDistanceWrapper::new(params, 3);

        let quantize_all =
            |values: &[f32]| -> Vec<u8> { values.iter().map(|&v| params.quantize(v)).collect() };
        let first = quantize_all(&[0.0f32, 0.5, 1.0]);
        let second = quantize_all(&[1.0f32, 0.5, 0.0]);

        let distance = wrapper.eval(&first, &second);

        assert!(distance >= 0.0);
    }

    /// The memory estimate for 100 vectors of 768 components stays below two bytes per component.
    #[test]
    fn test_memory_efficiency() {
        let mut index = QuantizedHNSW::new(768, unit_range_params(), 1000);
        for row in 0..100 {
            let sample: Vec<f32> = (0..768)
                .map(|col| ((row * 768 + col) % 100) as f32 / 100.0)
                .collect();
            index.insert(format!("node_{}", row), &sample).unwrap();
        }

        let estimate = index.estimated_memory_bytes();

        assert!(estimate < 100 * 768 * 2);
    }
}