use hnsw_rs::prelude::Hnsw;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
use super::distance::{
clear_adc_query_context, set_adc_query_context, AdcDistanceMetric, Int8AdcDistance,
};
use super::quantization::Quantize;
use super::vector::Int8QuantizedVector;
/// Construction and search parameters for [`Int8HnswIndex`].
///
/// All numeric fields are checked by [`Int8HnswParams::validate`]; `metric` is not
/// validated. Defaults (see the `Default` impl): `m = 16`, `ef_construction = 200`,
/// `ef_search = 50`, `max_elements = 100_000`, `max_layer = 16`, cosine metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Int8HnswParams {
    /// HNSW `M` parameter; forwarded as the first argument of `Hnsw::new`.
    /// Must be > 0.
    pub m: usize,
    /// Candidate-list size used while building the graph; forwarded to `Hnsw::new`.
    /// Must be >= `m`.
    pub ef_construction: usize,
    /// Candidate-list size used at query time; `search` uses at least `top_k`.
    /// Must be > 0.
    pub ef_search: usize,
    /// Element-count argument forwarded to `Hnsw::new` (presumably a capacity
    /// hint — TODO confirm against the `hnsw_rs` version in use). Must be > 0.
    pub max_elements: usize,
    /// Layer-count argument forwarded to `Hnsw::new`. Must be > 0.
    pub max_layer: usize,
    /// Distance metric for the ADC distance; also selects how `search` converts
    /// distances into similarity scores.
    pub metric: AdcDistanceMetric,
}
impl Default for Int8HnswParams {
fn default() -> Self {
Self {
m: 16,
ef_construction: 200,
ef_search: 50,
max_elements: 100_000,
max_layer: 16,
metric: AdcDistanceMetric::Cosine,
}
}
}
impl Int8HnswParams {
    /// Same as [`Int8HnswParams::default`]; the usual starting point for the
    /// `with_*` builder chain.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns these parameters with `m` replaced.
    #[must_use]
    pub fn with_m(self, m: usize) -> Self {
        Self { m, ..self }
    }

    /// Returns these parameters with `ef_construction` replaced.
    #[must_use]
    pub fn with_ef_construction(self, ef: usize) -> Self {
        Self {
            ef_construction: ef,
            ..self
        }
    }

    /// Returns these parameters with `ef_search` replaced.
    #[must_use]
    pub fn with_ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: ef,
            ..self
        }
    }

    /// Returns these parameters with `max_elements` replaced.
    #[must_use]
    pub fn with_max_elements(self, max: usize) -> Self {
        Self {
            max_elements: max,
            ..self
        }
    }

    /// Returns these parameters with `max_layer` replaced.
    #[must_use]
    pub fn with_max_layer(self, max: usize) -> Self {
        Self {
            max_layer: max,
            ..self
        }
    }

    /// Returns these parameters with `metric` replaced.
    #[must_use]
    pub fn with_metric(self, metric: AdcDistanceMetric) -> Self {
        Self { metric, ..self }
    }

    /// Checks the numeric parameters for basic consistency.
    ///
    /// # Errors
    ///
    /// Returns [`Int8HnswError::InvalidParameter`] describing the first violated
    /// rule, checked in this order: `m > 0`, `ef_construction >= m`,
    /// `ef_search > 0`, `max_elements > 0`, `max_layer > 0`.
    pub fn validate(&self) -> Result<(), Int8HnswError> {
        // Each entry is (rule violated?, message); order decides which message wins.
        let rules: [(bool, &str); 5] = [
            (self.m == 0, "m must be > 0"),
            (self.ef_construction < self.m, "ef_construction must be >= m"),
            (self.ef_search == 0, "ef_search must be > 0"),
            (self.max_elements == 0, "max_elements must be > 0"),
            (self.max_layer == 0, "max_layer must be > 0"),
        ];
        rules
            .iter()
            .find(|(violated, _)| *violated)
            .map_or(Ok(()), |(_, msg)| {
                Err(Int8HnswError::InvalidParameter((*msg).to_string()))
            })
    }
}
/// Approximate nearest-neighbour index over INT8-quantized embeddings, backed by
/// an `hnsw_rs` graph.
///
/// Embeddings are accepted as `f32` and quantized on insert. Searches go through
/// [`Int8AdcDistance`]. Presumably this is an asymmetric distance (float query vs.
/// INT8 stored vectors), since `search` hands the raw `f32` query to
/// `set_adc_query_context` — NOTE(review): confirm in the distance module.
/// External string ids are mapped to dense internal `usize` ids used by the graph.
pub struct Int8HnswIndex {
    /// Graph storing the quantized vectors, keyed by internal id.
    hnsw: Hnsw<'static, Int8QuantizedVector, Int8AdcDistance>,
    /// Internal id -> external node id, for live (non-removed) nodes only.
    id_map: HashMap<usize, String>,
    /// External node id -> internal id, for live (non-removed) nodes only.
    reverse_map: HashMap<String, usize>,
    /// Internal ids removed via `remove`. Their vectors are not unlinked from the
    /// graph; they are filtered out of search results instead.
    deleted: HashSet<usize>,
    /// Next internal id to hand out; only reset by `clear`, so ids are never reused.
    next_id: usize,
    /// Required length of every inserted embedding and every query.
    dimension: usize,
    /// Parameters the index was constructed with.
    params: Int8HnswParams,
    /// Number of live (non-removed) nodes.
    count: usize,
    /// Element count passed to `Hnsw::new`; reused when `clear` rebuilds the graph.
    max_elements: usize,
}
impl Int8HnswIndex {
pub fn new(dimension: usize) -> Self {
Self::with_params(dimension, Int8HnswParams::default())
}
pub fn with_params(dimension: usize, params: Int8HnswParams) -> Self {
params.validate().unwrap_or_else(|e| {
tracing::warn!("Invalid INT8 HNSW params, using defaults: {:?}", e);
});
let distance = Int8AdcDistance::new(params.metric);
let max_elements = params.max_elements;
let hnsw = Hnsw::new(
params.m,
params.max_elements,
params.max_layer,
params.ef_construction,
distance,
);
Self {
hnsw,
id_map: HashMap::new(),
reverse_map: HashMap::new(),
deleted: HashSet::new(),
next_id: 0,
dimension,
params,
count: 0,
max_elements,
}
}
pub fn insert(&mut self, node_id: String, embedding: Vec<f32>) -> Result<(), Int8HnswError> {
if embedding.len() != self.dimension {
return Err(Int8HnswError::DimensionMismatch {
expected: self.dimension,
got: embedding.len(),
});
}
if self.reverse_map.contains_key(&node_id) {
return Err(Int8HnswError::NodeExists(node_id));
}
let quantized: Int8QuantizedVector = embedding.quantize();
let internal_id = self.next_id;
self.next_id += 1;
self.hnsw.insert((&vec![quantized], internal_id));
self.id_map.insert(internal_id, node_id.clone());
self.reverse_map.insert(node_id, internal_id);
self.count += 1;
Ok(())
}
pub fn insert_batch(&mut self, vectors: impl IntoIterator<Item = (String, Vec<f32>)>) -> usize {
let mut inserted = 0;
for (node_id, embedding) in vectors {
if self.insert(node_id, embedding).is_ok() {
inserted += 1;
}
}
inserted
}
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
if query.len() != self.dimension {
return Vec::new();
}
if self.count == 0 {
return Vec::new();
}
set_adc_query_context(query, self.params.metric);
let dummy_query = Int8QuantizedVector::new(
vec![0i8; self.dimension],
super::vector::Int8QuantizedVectorMetadata::default(),
self.dimension,
);
let ef_search = self.params.ef_search.max(top_k);
let results = self
.hnsw
.search(std::slice::from_ref(&dummy_query), top_k, ef_search);
clear_adc_query_context();
let mut output = Vec::new();
for neighbour in results.into_iter() {
let internal_id = neighbour.d_id;
let dist = neighbour.distance;
if self.deleted.contains(&internal_id) {
continue;
}
if let Some(node_id) = self.id_map.get(&internal_id) {
let similarity = match self.params.metric {
AdcDistanceMetric::Cosine | AdcDistanceMetric::Dot => (1.0 - dist).max(0.0),
AdcDistanceMetric::L2Squared => {
let max_dist = 4.0f32; (1.0 - dist / max_dist).max(0.0)
}
};
output.push((node_id.clone(), similarity));
}
}
output.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
output
}
#[must_use]
pub fn len(&self) -> usize {
self.count
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.count == 0
}
#[must_use]
pub fn dimension(&self) -> usize {
self.dimension
}
#[must_use]
pub fn params(&self) -> &Int8HnswParams {
&self.params
}
pub fn remove(&mut self, node_id: &str) -> bool {
if let Some(internal_id) = self.reverse_map.remove(node_id) {
self.id_map.remove(&internal_id);
self.deleted.insert(internal_id);
self.count -= 1;
true
} else {
false
}
}
pub fn clear(&mut self) {
let distance = Int8AdcDistance::new(self.params.metric);
self.hnsw = Hnsw::new(
self.params.m,
self.max_elements,
self.params.max_layer,
self.params.ef_construction,
distance,
);
self.id_map.clear();
self.reverse_map.clear();
self.deleted.clear();
self.next_id = 0;
self.count = 0;
}
#[must_use]
pub fn estimated_memory_bytes(&self) -> usize {
let vector_data = self.count * self.dimension + self.count * 32;
let edge_data = self.count * self.params.m * std::mem::size_of::<usize>();
let map_overhead = self.id_map.len()
* (std::mem::size_of::<usize>() + std::mem::size_of::<String>())
+ self.reverse_map.len()
* (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
vector_data + edge_data + map_overhead
}
#[must_use]
pub fn memory_reduction_ratio(&self) -> f32 {
let quantized_memory = self.estimated_memory_bytes() as f32;
let f32_memory = (self.count * self.dimension * 4) as f32;
if f32_memory == 0.0 {
return 0.0;
}
1.0 - (quantized_memory / f32_memory)
}
}
impl Default for Int8HnswIndex {
fn default() -> Self {
Self::new(768)
}
}
/// Errors produced by [`Int8HnswIndex`] and [`Int8HnswParams`].
#[derive(Debug, Error)]
pub enum Int8HnswError {
    /// An embedding's length differs from the index dimension.
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch {
        /// Dimension the index was created with.
        expected: usize,
        /// Length of the rejected embedding.
        got: usize,
    },
    /// The node id is already live in the index.
    #[error("Node {0} already exists")]
    NodeExists(String),
    /// The node id is not present in the index.
    /// NOTE(review): not constructed by the index methods in this module.
    #[error("Node {0} not found")]
    NodeNotFound(String),
    /// A parameter failed validation; the payload names the violated rule.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// An insertion failed; the payload carries a description.
    /// NOTE(review): not constructed by the index methods in this module.
    #[error("Insertion failed: {0}")]
    InsertionFailed(String),
}
#[cfg(test)]
mod tests {
    //! Unit tests for the INT8-quantized HNSW index and its parameters.
    use super::*;

    /// Builds a `dimension`-long vector whose every component equals `value`.
    fn create_test_vector(dimension: usize, value: f32) -> Vec<f32> {
        std::iter::repeat(value).take(dimension).collect()
    }

    #[test]
    fn test_int8_hnsw_creation() {
        let idx = Int8HnswIndex::new(768);
        assert_eq!(idx.dimension(), 768);
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
    }

    #[test]
    fn test_int8_hnsw_with_params() {
        let idx = Int8HnswIndex::with_params(
            768,
            Int8HnswParams::new()
                .with_m(32)
                .with_ef_construction(400)
                .with_ef_search(100),
        );
        let p = idx.params();
        assert_eq!((p.m, p.ef_construction, p.ef_search), (32, 400, 100));
    }

    #[test]
    fn test_int8_hnsw_insert() {
        let mut idx = Int8HnswIndex::new(64);
        idx.insert("test".to_string(), create_test_vector(64, 0.5))
            .expect("insert of a correctly sized vector should succeed");
        assert_eq!(idx.len(), 1);
        assert!(!idx.is_empty());
    }

    #[test]
    fn test_int8_hnsw_dimension_mismatch() {
        let mut idx = Int8HnswIndex::new(64);
        let outcome = idx.insert("test".to_string(), create_test_vector(32, 0.5));
        assert!(matches!(
            outcome,
            Err(Int8HnswError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_int8_hnsw_duplicate_insert() {
        let mut idx = Int8HnswIndex::new(64);
        let embedding = create_test_vector(64, 0.5);
        idx.insert("test".into(), embedding.clone()).unwrap();
        let second = idx.insert("test".into(), embedding);
        assert!(matches!(second, Err(Int8HnswError::NodeExists(_))));
    }

    #[test]
    fn test_int8_hnsw_search() {
        let mut idx = Int8HnswIndex::new(64);
        for n in 0..10 {
            let level = n as f32 / 10.0;
            idx.insert(format!("node_{}", n), create_test_vector(64, level))
                .unwrap();
        }
        let hits = idx.search(&create_test_vector(64, 0.5), 5);
        assert!(!hits.is_empty());
        assert!(hits.len() <= 5);
        assert!(hits.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_int8_hnsw_empty_search() {
        let idx = Int8HnswIndex::new(64);
        assert!(idx.search(&create_test_vector(64, 0.5), 10).is_empty());
    }

    #[test]
    fn test_int8_hnsw_wrong_dimension_search() {
        let mut idx = Int8HnswIndex::new(64);
        idx.insert("test".to_string(), create_test_vector(64, 0.5))
            .unwrap();
        let short_query = create_test_vector(32, 0.5);
        assert!(idx.search(&short_query, 10).is_empty());
    }

    #[test]
    fn test_int8_hnsw_batch_insert() {
        let mut idx = Int8HnswIndex::new(64);
        let batch = (0..10).map(|n| {
            (
                format!("node_{}", n),
                create_test_vector(64, n as f32 / 10.0),
            )
        });
        assert_eq!(idx.insert_batch(batch), 10);
        assert_eq!(idx.len(), 10);
    }

    #[test]
    fn test_int8_hnsw_remove() {
        let mut idx = Int8HnswIndex::new(64);
        idx.insert("test".to_string(), create_test_vector(64, 0.5))
            .unwrap();
        assert_eq!(idx.len(), 1);
        assert!(idx.remove("test"));
        assert_eq!(idx.len(), 0);
        assert!(!idx.remove("nonexistent"));
    }

    #[test]
    fn test_int8_hnsw_clear() {
        let mut idx = Int8HnswIndex::new(64);
        for n in 0..5 {
            idx.insert(format!("node_{}", n), create_test_vector(64, 0.5))
                .unwrap();
        }
        assert_eq!(idx.len(), 5);
        idx.clear();
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
    }

    #[test]
    fn test_int8_hnsw_params_validation() {
        assert!(Int8HnswParams::default().validate().is_ok());

        let zero_m = Int8HnswParams {
            m: 0,
            ..Default::default()
        };
        assert!(zero_m.validate().is_err());

        let ef_below_m = Int8HnswParams {
            m: 10,
            ef_construction: 5,
            ..Default::default()
        };
        assert!(ef_below_m.validate().is_err());
    }

    #[test]
    fn test_memory_efficiency() {
        let mut idx = Int8HnswIndex::new(768);
        for n in 0..100 {
            let embedding = (0..768)
                .map(|k| ((n * 768 + k) % 100) as f32 / 100.0)
                .collect();
            idx.insert(format!("node_{}", n), embedding).unwrap();
        }
        let quantized_bytes = idx.estimated_memory_bytes();
        let ratio = idx.memory_reduction_ratio();
        assert!(ratio > 0.5, "Memory reduction too low: {}", ratio);
        let raw_f32_bytes = 100 * 768 * 4;
        assert!(quantized_bytes < raw_f32_bytes);
    }

    #[test]
    fn test_int8_hnsw_different_metrics() {
        let metrics = [
            AdcDistanceMetric::Cosine,
            AdcDistanceMetric::L2Squared,
            AdcDistanceMetric::Dot,
        ];
        for metric in metrics {
            let mut idx =
                Int8HnswIndex::with_params(64, Int8HnswParams::new().with_metric(metric));
            for n in 0..5 {
                idx.insert(format!("node_{}", n), create_test_vector(64, n as f32 / 5.0))
                    .unwrap();
            }
            let hits = idx.search(&create_test_vector(64, 0.5), 3);
            assert!(!hits.is_empty(), "Metric {:?} failed", metric);
        }
    }

    #[test]
    fn test_int8_hnsw_search_consistency() {
        let mut idx = Int8HnswIndex::new(64);
        for axis in 0..5 {
            let mut basis = vec![0.0; 64];
            basis[axis] = 1.0;
            idx.insert(format!("node_{}", axis), basis).unwrap();
        }
        let mut query = vec![0.0f32; 64];
        query[0] = 1.0;

        let runs: Vec<Vec<(String, f32)>> = (0..3).map(|_| idx.search(&query, 3)).collect();
        let id_lists: Vec<Vec<&str>> = runs
            .iter()
            .map(|run| run.iter().map(|(id, _)| id.as_str()).collect())
            .collect();
        assert_eq!(id_lists[0], id_lists[1]);
        assert_eq!(id_lists[1], id_lists[2]);
    }
}