#![allow(clippy::similar_names)]
#![allow(unused_variables)]
use crate::{Result, Error};
use super::{Vector, DistanceMetric};
use hnsw_rs::prelude::*;
use parking_lot::RwLock;
use std::sync::Arc;
use serde::{Serialize, Deserialize};
/// Configuration shared by the HNSW index wrappers in this module.
///
/// The struct is serializable; keep the field order stable if a
/// position-sensitive serde format is in use (NOTE(review): confirm which
/// formats actually persist this type).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
/// Maximum number of links per graph node handed to the underlying
/// `hnsw_rs` graph.
pub max_connections: usize,
/// Intended construction-time candidate list size (HNSW `ef_construction`).
pub ef_construction: usize,
/// Required length of every inserted vector and every query vector;
/// mismatches are rejected with an error.
pub dimension: usize,
/// Distance function; selects which concrete index
/// `MultiMetricHnswIndex::new` builds.
pub distance_metric: DistanceMetric,
/// Baseline search `ef`. `HnswIndex` uses it unchanged when
/// `dynamic_ef_search` is `false`, and as the value to scale otherwise.
pub ef_search_base: usize,
/// When `true`, `HnswIndex` derives the search `ef` from `k` and the
/// index size instead of using `ef_search_base` directly.
pub dynamic_ef_search: bool,
/// Lower bound applied to the dynamically computed search `ef`.
pub ef_search_min: usize,
/// Upper bound applied to the dynamically computed search `ef`, and to the
/// caller-supplied value in `HnswIndex::search_with_ef`.
pub ef_search_max: usize,
}
impl Default for HnswConfig {
fn default() -> Self {
Self {
max_connections: 16,
ef_construction: 200,
dimension: 1536, distance_metric: DistanceMetric::L2,
ef_search_base: 200,
dynamic_ef_search: true,
ef_search_min: 50,
ef_search_max: 500,
}
}
}
/// HNSW index over `f32` vectors using the `hnsw_rs` `DistL2` distance.
///
/// Callers address vectors by `u64` row id; the index stores them under
/// dense internal ids (insertion order) and keeps both directions of the
/// mapping next to the graph.
pub struct HnswIndex {
/// Underlying `hnsw_rs` graph.
index: Arc<RwLock<Hnsw<'static, f32, DistL2>>>,
/// Configuration captured at construction time.
config: HnswConfig,
/// Internal id -> row id. The position in the vector is the internal id;
/// entries are only ever appended, so the length counts every insert.
id_mapping: Arc<RwLock<Vec<u64>>>,
/// Row id -> internal id. `delete` removes the entry for the row.
reverse_mapping: Arc<RwLock<std::collections::HashMap<u64, usize>>>,
}
impl HnswIndex {
pub fn new(config: HnswConfig) -> Result<Self> {
let max_nb_connection = config.max_connections;
let ef_construction = config.ef_construction;
let index = Hnsw::<f32, DistL2>::new(
max_nb_connection,
config.dimension,
ef_construction,
100, DistL2,
);
Ok(Self {
index: Arc::new(RwLock::new(index)),
config,
id_mapping: Arc::new(RwLock::new(Vec::new())),
reverse_mapping: Arc::new(RwLock::new(std::collections::HashMap::new())),
})
}
pub fn insert(&self, row_id: u64, vector: &Vector) -> Result<()> {
if vector.len() != self.config.dimension {
return Err(Error::query_execution(format!(
"Vector dimension mismatch: expected {}, got {}",
self.config.dimension,
vector.len()
)));
}
let mut id_mapping = self.id_mapping.write();
let mut reverse_mapping = self.reverse_mapping.write();
let hnsw_id = id_mapping.len();
id_mapping.push(row_id);
reverse_mapping.insert(row_id, hnsw_id);
let index = self.index.write();
let data_id = DataId::from(hnsw_id);
index.insert((vector.as_slice(), data_id));
Ok(())
}
pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(u64, f32)>> {
if query.len() != self.config.dimension {
return Err(Error::query_execution(format!(
"Query vector dimension mismatch: expected {}, got {}",
self.config.dimension,
query.len()
)));
}
let index = self.index.read();
let id_mapping = self.id_mapping.read();
let ef_search = self.calculate_ef_search(k, id_mapping.len());
let results = index.search(query.as_slice(), k, ef_search);
let mapped_results: Vec<(u64, f32)> = results
.into_iter()
.filter_map(|neighbor| {
let hnsw_id = neighbor.d_id as usize;
id_mapping.get(hnsw_id).map(|&row_id| {
(row_id, neighbor.distance)
})
})
.collect();
Ok(mapped_results)
}
fn calculate_ef_search(&self, k: usize, index_size: usize) -> usize {
if !self.config.dynamic_ef_search {
return self.config.ef_search_base;
}
let k_based = k * 2;
let size_factor = if index_size > 1000 {
(index_size as f64).log2() / 10.0 } else {
1.0
};
let adjusted = ((self.config.ef_search_base as f64 * size_factor) as usize).max(k_based);
adjusted.clamp(self.config.ef_search_min, self.config.ef_search_max)
}
pub fn search_with_ef(&self, query: &Vector, k: usize, ef_search: usize) -> Result<Vec<(u64, f32)>> {
if query.len() != self.config.dimension {
return Err(Error::query_execution(format!(
"Query vector dimension mismatch: expected {}, got {}",
self.config.dimension,
query.len()
)));
}
let index = self.index.read();
let id_mapping = self.id_mapping.read();
let ef_search = ef_search.clamp(k, self.config.ef_search_max);
let results = index.search(query.as_slice(), k, ef_search);
let mapped_results: Vec<(u64, f32)> = results
.into_iter()
.filter_map(|neighbor| {
let hnsw_id = neighbor.d_id as usize;
id_mapping.get(hnsw_id).map(|&row_id| {
(row_id, neighbor.distance)
})
})
.collect();
Ok(mapped_results)
}
pub fn delete(&self, row_id: u64) -> Result<()> {
let mut reverse_mapping = self.reverse_mapping.write();
if let Some(&hnsw_id) = reverse_mapping.get(&row_id) {
reverse_mapping.remove(&row_id);
Ok(())
} else {
Err(Error::query_execution(format!(
"Vector with row_id {} not found in index",
row_id
)))
}
}
pub fn len(&self) -> usize {
self.id_mapping.read().len()
}
pub fn is_empty(&self) -> bool {
self.id_mapping.read().is_empty()
}
pub fn dimension(&self) -> usize {
self.config.dimension
}
}
/// An HNSW index whose distance function is chosen at construction time.
///
/// Each variant wraps the concrete index type for one `DistanceMetric`.
pub enum MultiMetricHnswIndex {
/// Index built for `DistanceMetric::L2`.
L2(HnswIndex),
/// Index built for `DistanceMetric::Cosine`.
Cosine(CosineHnswIndex),
/// Index built for `DistanceMetric::InnerProduct`.
InnerProduct(InnerProductHnswIndex),
}
impl MultiMetricHnswIndex {
    /// Builds the concrete index matching `config.distance_metric`.
    ///
    /// # Errors
    /// Propagates any error from the selected index constructor.
    pub fn new(config: HnswConfig) -> Result<Self> {
        let built = match config.distance_metric {
            DistanceMetric::L2 => Self::L2(HnswIndex::new(config)?),
            DistanceMetric::Cosine => Self::Cosine(CosineHnswIndex::new(config)?),
            DistanceMetric::InnerProduct => {
                Self::InnerProduct(InnerProductHnswIndex::new(config)?)
            }
        };
        Ok(built)
    }

    /// Forwards to the wrapped index's `insert`.
    pub fn insert(&self, row_id: u64, vector: &Vector) -> Result<()> {
        match self {
            Self::L2(inner) => inner.insert(row_id, vector),
            Self::Cosine(inner) => inner.insert(row_id, vector),
            Self::InnerProduct(inner) => inner.insert(row_id, vector),
        }
    }

    /// Forwards to the wrapped index's `search`.
    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(u64, f32)>> {
        match self {
            Self::L2(inner) => inner.search(query, k),
            Self::Cosine(inner) => inner.search(query, k),
            Self::InnerProduct(inner) => inner.search(query, k),
        }
    }

    /// Forwards to the wrapped index's `delete`.
    pub fn delete(&self, row_id: u64) -> Result<()> {
        match self {
            Self::L2(inner) => inner.delete(row_id),
            Self::Cosine(inner) => inner.delete(row_id),
            Self::InnerProduct(inner) => inner.delete(row_id),
        }
    }

    /// Vector length accepted by the wrapped index.
    pub fn dimension(&self) -> usize {
        match self {
            Self::L2(inner) => inner.dimension(),
            Self::Cosine(inner) => inner.dimension(),
            Self::InnerProduct(inner) => inner.dimension(),
        }
    }

    /// Length reported by the wrapped index.
    pub fn len(&self) -> usize {
        match self {
            Self::L2(inner) => inner.len(),
            Self::Cosine(inner) => inner.len(),
            Self::InnerProduct(inner) => inner.len(),
        }
    }

    /// `true` when the wrapped index reports a length of zero.
    pub fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }
}
/// HNSW index over `f32` vectors using the `hnsw_rs` `DistCosine` distance.
///
/// Row ids supplied by callers are mapped to dense internal ids assigned in
/// insertion order.
pub struct CosineHnswIndex {
/// Underlying `hnsw_rs` graph.
index: Arc<RwLock<Hnsw<'static, f32, DistCosine>>>,
/// Configuration captured at construction time.
config: HnswConfig,
/// Internal id -> row id; append-only, so the length counts every insert.
id_mapping: Arc<RwLock<Vec<u64>>>,
/// Row id -> internal id. `delete` removes the entry for the row.
reverse_mapping: Arc<RwLock<std::collections::HashMap<u64, usize>>>,
}
impl CosineHnswIndex {
    /// Creates an empty cosine-distance index configured by `config`.
    ///
    /// # Errors
    /// Never fails today; the `Result` return type is part of the public
    /// signature and is kept for callers.
    pub fn new(config: HnswConfig) -> Result<Self> {
        // `Hnsw::new` takes, in order:
        //   (max_nb_connection, max_elements, max_layer, ef_construction, dist).
        // The vector dimension is not one of its parameters.
        // `max_elements` is documented by hnsw_rs as a pre-allocation hint,
        // and `max_layer` is truncated by the library to its own maximum (16).
        const CAPACITY_HINT: usize = 10_000;
        const MAX_LAYER: usize = 16;

        let index = Hnsw::<f32, DistCosine>::new(
            config.max_connections,
            CAPACITY_HINT,
            MAX_LAYER,
            config.ef_construction,
            DistCosine,
        );
        Ok(Self {
            index: Arc::new(RwLock::new(index)),
            config,
            id_mapping: Arc::new(RwLock::new(Vec::new())),
            reverse_mapping: Arc::new(RwLock::new(std::collections::HashMap::new())),
        })
    }

    /// Adds `vector` under `row_id`.
    ///
    /// Inserting a `row_id` that is already present makes the new vector the
    /// live one for that row; the older graph node stays in the graph but is
    /// filtered out of search results.
    ///
    /// # Errors
    /// Returns a query-execution error when `vector.len()` differs from the
    /// configured dimension.
    pub fn insert(&self, row_id: u64, vector: &Vector) -> Result<()> {
        if vector.len() != self.config.dimension {
            return Err(Error::query_execution(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                vector.len()
            )));
        }

        // Lock order used everywhere in this type:
        //   id_mapping -> reverse_mapping -> index.
        // Taking them in one order prevents an insert and a search from
        // each holding a lock the other one is waiting for.
        let mut id_mapping = self.id_mapping.write();
        let mut reverse_mapping = self.reverse_mapping.write();
        let index = self.index.write();

        // Internal ids are dense and assigned in insertion order.
        let hnsw_id = id_mapping.len();
        id_mapping.push(row_id);
        reverse_mapping.insert(row_id, hnsw_id);
        index.insert((vector.as_slice(), DataId::from(hnsw_id)));
        Ok(())
    }

    /// Returns up to `k` `(row_id, distance)` pairs for `query`, in the order
    /// produced by the underlying index.
    ///
    /// The search `ef` comes from the configuration (see
    /// `calculate_ef_search`). Deleted or superseded entries are dropped, so
    /// fewer than `k` pairs may come back when such entries are among the
    /// nearest candidates.
    ///
    /// # Errors
    /// Returns a query-execution error when `query.len()` differs from the
    /// configured dimension.
    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(u64, f32)>> {
        if query.len() != self.config.dimension {
            return Err(Error::query_execution(format!(
                "Query vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                query.len()
            )));
        }

        // Same lock order as `insert`: id_mapping -> reverse_mapping -> index.
        let id_mapping = self.id_mapping.read();
        let reverse_mapping = self.reverse_mapping.read();
        let index = self.index.read();

        // An `ef` below `k` cannot yield `k` neighbours.
        let ef_search = self.calculate_ef_search(k, id_mapping.len()).max(k);

        let hits = index
            .search(query.as_slice(), k, ef_search)
            .into_iter()
            .filter_map(|neighbor| {
                let hnsw_id = neighbor.d_id as usize;
                let row_id = *id_mapping.get(hnsw_id)?;
                // A node is live only while the reverse mapping still points
                // at it; `delete` and re-insertion both break that link.
                if reverse_mapping.get(&row_id) == Some(&hnsw_id) {
                    Some((row_id, neighbor.distance))
                } else {
                    None
                }
            })
            .collect();
        Ok(hits)
    }

    /// Computes the search `ef` for a query asking for `k` neighbours against
    /// an index holding `index_size` inserted vectors.
    ///
    /// With `dynamic_ef_search` disabled this is `ef_search_base`. Otherwise
    /// the base is scaled by `log2(index_size) / 10` once the index holds
    /// more than 1000 vectors, raised to at least `2 * k`, and bounded by
    /// `ef_search_min` / `ef_search_max`. With the default configuration and
    /// a small index this evaluates to 200.
    fn calculate_ef_search(&self, k: usize, index_size: usize) -> usize {
        if !self.config.dynamic_ef_search {
            return self.config.ef_search_base;
        }

        let size_factor = if index_size > 1000 {
            (index_size as f64).log2() / 10.0
        } else {
            1.0
        };
        let scaled = (self.config.ef_search_base as f64 * size_factor) as usize;
        let wanted = scaled.max(k.saturating_mul(2));

        // `clamp` panics when min > max. With an inverted configuration the
        // upper bound wins instead of aborting the query.
        let upper = self.config.ef_search_max;
        let lower = self.config.ef_search_min.min(upper);
        wanted.clamp(lower, upper)
    }

    /// Marks `row_id` as deleted.
    ///
    /// The graph node is kept as a tombstone, so `len` does not shrink; the
    /// row simply stops appearing in search results.
    ///
    /// # Errors
    /// Returns a query-execution error when `row_id` is not present.
    pub fn delete(&self, row_id: u64) -> Result<()> {
        let removed = self.reverse_mapping.write().remove(&row_id);
        if removed.is_some() {
            Ok(())
        } else {
            Err(Error::query_execution(format!(
                "Vector with row_id {} not found in index",
                row_id
            )))
        }
    }

    /// Vector length this index accepts.
    pub fn dimension(&self) -> usize {
        self.config.dimension
    }

    /// Number of vectors ever inserted, including deleted ones.
    pub fn len(&self) -> usize {
        self.id_mapping.read().len()
    }

    /// `true` when nothing has been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.id_mapping.read().is_empty()
    }
}
/// HNSW index over `f32` vectors using the `hnsw_rs` `DistDot` distance.
///
/// Row ids supplied by callers are mapped to dense internal ids assigned in
/// insertion order.
///
/// NOTE(review): `DistDot` in `hnsw_rs` appears to expect L2-normalized
/// input; confirm that callers normalize vectors before inserting/querying.
pub struct InnerProductHnswIndex {
/// Underlying `hnsw_rs` graph.
index: Arc<RwLock<Hnsw<'static, f32, DistDot>>>,
/// Configuration captured at construction time.
config: HnswConfig,
/// Internal id -> row id; append-only, so the length counts every insert.
id_mapping: Arc<RwLock<Vec<u64>>>,
/// Row id -> internal id. `delete` removes the entry for the row.
reverse_mapping: Arc<RwLock<std::collections::HashMap<u64, usize>>>,
}
impl InnerProductHnswIndex {
    /// Creates an empty inner-product index configured by `config`.
    ///
    /// # Errors
    /// Never fails today; the `Result` return type is part of the public
    /// signature and is kept for callers.
    pub fn new(config: HnswConfig) -> Result<Self> {
        // `Hnsw::new` takes, in order:
        //   (max_nb_connection, max_elements, max_layer, ef_construction, dist).
        // The vector dimension is not one of its parameters.
        // `max_elements` is documented by hnsw_rs as a pre-allocation hint,
        // and `max_layer` is truncated by the library to its own maximum (16).
        const CAPACITY_HINT: usize = 10_000;
        const MAX_LAYER: usize = 16;

        let index = Hnsw::<f32, DistDot>::new(
            config.max_connections,
            CAPACITY_HINT,
            MAX_LAYER,
            config.ef_construction,
            DistDot,
        );
        Ok(Self {
            index: Arc::new(RwLock::new(index)),
            config,
            id_mapping: Arc::new(RwLock::new(Vec::new())),
            reverse_mapping: Arc::new(RwLock::new(std::collections::HashMap::new())),
        })
    }

    /// Adds `vector` under `row_id`.
    ///
    /// Inserting a `row_id` that is already present makes the new vector the
    /// live one for that row; the older graph node stays in the graph but is
    /// filtered out of search results.
    ///
    /// NOTE(review): `DistDot` in `hnsw_rs` appears to expect L2-normalized
    /// vectors; confirm that callers normalize before inserting.
    ///
    /// # Errors
    /// Returns a query-execution error when `vector.len()` differs from the
    /// configured dimension.
    pub fn insert(&self, row_id: u64, vector: &Vector) -> Result<()> {
        if vector.len() != self.config.dimension {
            return Err(Error::query_execution(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                vector.len()
            )));
        }

        // Lock order used everywhere in this type:
        //   id_mapping -> reverse_mapping -> index.
        // Taking them in one order prevents an insert and a search from
        // each holding a lock the other one is waiting for.
        let mut id_mapping = self.id_mapping.write();
        let mut reverse_mapping = self.reverse_mapping.write();
        let index = self.index.write();

        // Internal ids are dense and assigned in insertion order.
        let hnsw_id = id_mapping.len();
        id_mapping.push(row_id);
        reverse_mapping.insert(row_id, hnsw_id);
        index.insert((vector.as_slice(), DataId::from(hnsw_id)));
        Ok(())
    }

    /// Returns up to `k` `(row_id, distance)` pairs for `query`, in the order
    /// produced by the underlying index.
    ///
    /// The search `ef` comes from the configuration (see
    /// `calculate_ef_search`). Deleted or superseded entries are dropped, so
    /// fewer than `k` pairs may come back when such entries are among the
    /// nearest candidates.
    ///
    /// # Errors
    /// Returns a query-execution error when `query.len()` differs from the
    /// configured dimension.
    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(u64, f32)>> {
        if query.len() != self.config.dimension {
            return Err(Error::query_execution(format!(
                "Query vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                query.len()
            )));
        }

        // Same lock order as `insert`: id_mapping -> reverse_mapping -> index.
        let id_mapping = self.id_mapping.read();
        let reverse_mapping = self.reverse_mapping.read();
        let index = self.index.read();

        // An `ef` below `k` cannot yield `k` neighbours.
        let ef_search = self.calculate_ef_search(k, id_mapping.len()).max(k);

        let hits = index
            .search(query.as_slice(), k, ef_search)
            .into_iter()
            .filter_map(|neighbor| {
                let hnsw_id = neighbor.d_id as usize;
                let row_id = *id_mapping.get(hnsw_id)?;
                // A node is live only while the reverse mapping still points
                // at it; `delete` and re-insertion both break that link.
                if reverse_mapping.get(&row_id) == Some(&hnsw_id) {
                    Some((row_id, neighbor.distance))
                } else {
                    None
                }
            })
            .collect();
        Ok(hits)
    }

    /// Computes the search `ef` for a query asking for `k` neighbours against
    /// an index holding `index_size` inserted vectors.
    ///
    /// With `dynamic_ef_search` disabled this is `ef_search_base`. Otherwise
    /// the base is scaled by `log2(index_size) / 10` once the index holds
    /// more than 1000 vectors, raised to at least `2 * k`, and bounded by
    /// `ef_search_min` / `ef_search_max`. With the default configuration and
    /// a small index this evaluates to 200.
    fn calculate_ef_search(&self, k: usize, index_size: usize) -> usize {
        if !self.config.dynamic_ef_search {
            return self.config.ef_search_base;
        }

        let size_factor = if index_size > 1000 {
            (index_size as f64).log2() / 10.0
        } else {
            1.0
        };
        let scaled = (self.config.ef_search_base as f64 * size_factor) as usize;
        let wanted = scaled.max(k.saturating_mul(2));

        // `clamp` panics when min > max. With an inverted configuration the
        // upper bound wins instead of aborting the query.
        let upper = self.config.ef_search_max;
        let lower = self.config.ef_search_min.min(upper);
        wanted.clamp(lower, upper)
    }

    /// Marks `row_id` as deleted.
    ///
    /// The graph node is kept as a tombstone, so `len` does not shrink; the
    /// row simply stops appearing in search results.
    ///
    /// # Errors
    /// Returns a query-execution error when `row_id` is not present.
    pub fn delete(&self, row_id: u64) -> Result<()> {
        let removed = self.reverse_mapping.write().remove(&row_id);
        if removed.is_some() {
            Ok(())
        } else {
            Err(Error::query_execution(format!(
                "Vector with row_id {} not found in index",
                row_id
            )))
        }
    }

    /// Vector length this index accepts.
    pub fn dimension(&self) -> usize {
        self.config.dimension
    }

    /// Number of vectors ever inserted, including deleted ones.
    pub fn len(&self) -> usize {
        self.id_mapping.read().len()
    }

    /// `true` when nothing has been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.id_mapping.read().is_empty()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Fully spelled-out L2 configuration, so the tests using it do not
    /// depend on `HnswConfig::default()`.
    fn explicit_l2_config(dimension: usize) -> HnswConfig {
        HnswConfig {
            dimension,
            max_connections: 16,
            ef_construction: 200,
            distance_metric: DistanceMetric::L2,
            ef_search_base: 200,
            dynamic_ef_search: true,
            ef_search_min: 50,
            ef_search_max: 500,
        }
    }

    /// Three orthogonal unit vectors; the query closest to the first axis
    /// must rank row 1 first.
    #[test]
    fn test_hnsw_basic() {
        let l2_index = HnswIndex::new(explicit_l2_config(3)).unwrap();

        l2_index.insert(1, &vec![1.0, 0.0, 0.0]).unwrap();
        l2_index.insert(2, &vec![0.0, 1.0, 0.0]).unwrap();
        l2_index.insert(3, &vec![0.0, 0.0, 1.0]).unwrap();

        let near_x_axis = vec![1.0, 0.1, 0.0];
        let hits = l2_index.search(&near_x_axis, 2).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, 1);
    }

    /// A vector of the wrong length is rejected on insert.
    #[test]
    fn test_dimension_validation() {
        let l2_index = HnswIndex::new(HnswConfig {
            dimension: 3,
            ..Default::default()
        })
        .unwrap();

        let outcome = l2_index.insert(1, &vec![1.0, 0.0]);
        assert!(outcome.is_err());
    }

    /// The metric-dispatching wrapper builds and searches a cosine index.
    #[test]
    fn test_multi_metric_index() {
        let cosine_index = MultiMetricHnswIndex::new(HnswConfig {
            dimension: 2,
            distance_metric: DistanceMetric::Cosine,
            ..Default::default()
        })
        .unwrap();

        cosine_index.insert(1, &vec![1.0, 0.0]).unwrap();
        cosine_index.insert(2, &vec![0.0, 1.0]).unwrap();

        let hits = cosine_index.search(&vec![0.7, 0.7], 1).unwrap();
        assert_eq!(hits.len(), 1);
    }

    /// `len` / `is_empty` track inserts for every metric, and a delete leaves
    /// the length unchanged.
    #[test]
    fn test_vector_count_tracking() {
        for (metric, name) in [
            (DistanceMetric::L2, "L2"),
            (DistanceMetric::Cosine, "Cosine"),
            (DistanceMetric::InnerProduct, "InnerProduct"),
        ] {
            let index = MultiMetricHnswIndex::new(HnswConfig {
                dimension: 3,
                distance_metric: metric,
                ..Default::default()
            })
            .unwrap();

            assert_eq!(index.len(), 0, "{} index should start empty", name);
            assert!(index.is_empty(), "{} index should be empty", name);

            index.insert(1, &vec![1.0, 0.0, 0.0]).unwrap();
            assert_eq!(index.len(), 1, "{} index should have 1 vector", name);
            assert!(!index.is_empty(), "{} index should not be empty", name);

            index.insert(2, &vec![0.0, 1.0, 0.0]).unwrap();
            assert_eq!(index.len(), 2, "{} index should have 2 vectors", name);

            index.insert(3, &vec![0.0, 0.0, 1.0]).unwrap();
            assert_eq!(index.len(), 3, "{} index should have 3 vectors", name);

            index.delete(2).unwrap();
            assert_eq!(index.len(), 3, "{} index length should remain 3 (tombstone)", name);
        }
    }

    /// `len` / `is_empty` on the concrete L2 index follow the insert count.
    #[test]
    fn test_index_len_methods() {
        let l2_index = HnswIndex::new(explicit_l2_config(2)).unwrap();
        assert_eq!(l2_index.len(), 0);
        assert!(l2_index.is_empty());

        l2_index.insert(1, &vec![1.0, 0.0]).unwrap();
        assert_eq!(l2_index.len(), 1);
        assert!(!l2_index.is_empty());

        l2_index.insert(2, &vec![0.0, 1.0]).unwrap();
        assert_eq!(l2_index.len(), 2);
    }
}