use super::native_inner::NativeHnswInner;
use super::params::{HnswParams, SearchQuality};
use super::sharded_mappings::ShardedMappings;
use super::sharded_vectors::ShardedVectors;
use super::upsert::{self, UpsertResult};
use crate::distance::DistanceMetric;
use crate::index::VectorIndex;
use crate::scored_result::ScoredResult;
use crate::validation::validate_dimension_match;
use parking_lot::RwLock;
/// HNSW vector index built on the crate's native graph implementation.
///
/// Three pieces of state work together:
/// - the HNSW graph itself (`inner`),
/// - a mapping between caller-supplied `u64` ids and internal node indices (`mappings`),
/// - an optional side store of raw vectors (`vectors`), scanned by brute-force search.
pub struct NativeHnswIndex {
    /// Dimensionality every inserted vector must have.
    pub(crate) dimension: usize,
    /// Distance metric used to score and order results.
    pub(crate) metric: DistanceMetric,
    /// The HNSW graph, behind a `parking_lot::RwLock`.
    pub(crate) inner: RwLock<NativeHnswInner>,
    /// External id <-> internal node index mapping.
    pub(crate) mappings: ShardedMappings,
    /// Raw vectors keyed by internal index; only populated while
    /// `enable_vector_storage` is `true`.
    pub(crate) vectors: ShardedVectors,
    /// Whether inserted vectors are also copied into `vectors`.
    pub(crate) enable_vector_storage: bool,
    /// Parameters the index was constructed with. They are kept alongside the
    /// index, but the methods in this module never read them back, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    pub(crate) params: HnswParams,
}
impl NativeHnswIndex {
pub fn new(dimension: usize, metric: DistanceMetric) -> crate::error::Result<Self> {
Self::with_params(dimension, metric, HnswParams::auto(dimension))
}
pub fn with_params(
dimension: usize,
metric: DistanceMetric,
params: HnswParams,
) -> crate::error::Result<Self> {
let inner = NativeHnswInner::new_with_options(
metric,
params.max_connections,
params.max_elements,
params.ef_construction,
dimension,
params.storage_mode,
params.alpha,
)?;
Ok(Self {
dimension,
metric,
inner: RwLock::new(inner),
mappings: ShardedMappings::new(),
vectors: ShardedVectors::new(dimension),
enable_vector_storage: true,
params,
})
}
pub fn new_turbo(dimension: usize, metric: DistanceMetric) -> crate::error::Result<Self> {
Self::with_params(dimension, metric, HnswParams::turbo())
}
pub fn new_fast_insert(dimension: usize, metric: DistanceMetric) -> crate::error::Result<Self> {
let mut index = Self::new(dimension, metric)?;
index.enable_vector_storage = false;
Ok(index)
}
#[inline]
#[must_use]
pub fn dimension(&self) -> usize {
self.dimension
}
#[inline]
#[must_use]
pub fn metric(&self) -> DistanceMetric {
self.metric
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.mappings.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.mappings.is_empty()
}
#[inline]
#[must_use]
pub fn has_vector_storage(&self) -> bool {
self.enable_vector_storage
}
#[must_use]
pub fn search(&self, query: &[f32], k: usize) -> Vec<ScoredResult> {
self.search_with_quality(query, k, SearchQuality::Balanced)
}
#[must_use]
pub fn search_with_quality(
&self,
query: &[f32],
k: usize,
quality: SearchQuality,
) -> Vec<ScoredResult> {
let ef_search = quality.ef_search_for_scale(k, self.len());
let inner = self.inner.read();
let neighbors = inner.search_auto(query, k, ef_search);
neighbors
.into_iter()
.filter_map(|(node_id, raw_dist)| {
self.mappings.get_id(node_id).map(|id| {
let score = inner.transform_score(raw_dist);
ScoredResult::new(id, score)
})
})
.collect()
}
#[must_use]
fn upsert_mapping(&self, id: u64) -> UpsertResult {
upsert::upsert_mapping(
&self.mappings,
&self.vectors,
self.enable_vector_storage,
id,
)
}
fn rollback_upsert(&self, id: u64, result: &UpsertResult) {
upsert::rollback_upsert(&self.mappings, id, result);
}
pub fn insert(&self, id: u64, vector: &[f32]) -> crate::error::Result<()> {
validate_dimension_match(self.dimension, vector.len())?;
let result = self.upsert_mapping(id);
if let Err(e) = self.inner.read().insert((vector, result.idx)) {
self.rollback_upsert(id, &result);
return Err(e);
}
if self.enable_vector_storage {
self.vectors.insert(result.idx, vector);
}
Ok(())
}
pub fn insert_batch(&self, items: &[(u64, Vec<f32>)]) -> crate::error::Result<()> {
let upsert_results = upsert::validate_and_register_batch(
&self.mappings,
&self.vectors,
self.enable_vector_storage,
self.dimension,
items,
)?;
let mut data: Vec<(&[f32], usize)> = Vec::with_capacity(items.len());
let mut rollback_info: Vec<(u64, UpsertResult)> = Vec::with_capacity(items.len());
for ((id, vec), result) in items.iter().zip(upsert_results) {
data.push((vec.as_slice(), result.idx));
rollback_info.push((*id, result));
}
let assigned_ids = match self.inner.read().parallel_insert(&data) {
Ok(ids) => ids,
Err(e) => {
upsert::rollback_batch(&self.mappings, &rollback_info);
return Err(e);
}
};
let storage_ids =
upsert::reconcile_batch_mappings(&self.mappings, &rollback_info, &assigned_ids);
if self.enable_vector_storage {
for (vec_slice, idx) in data.iter().map(|(v, _)| *v).zip(storage_ids) {
self.vectors.insert(idx, vec_slice);
}
}
Ok(())
}
pub fn remove(&self, id: u64) -> bool {
upsert::soft_delete(
&self.mappings,
&self.vectors,
self.enable_vector_storage,
id,
)
}
#[allow(clippy::unused_self)]
pub fn set_searching_mode(&self) {}
#[allow(clippy::needless_pass_by_value)]
pub fn insert_batch_parallel<I>(&self, items: I) -> usize
where
I: IntoIterator<Item = (u64, Vec<f32>)>,
{
let items: Vec<_> = items.into_iter().collect();
let count = items.len();
if let Err(e) = self.insert_batch(items.as_slice()) {
tracing::error!("insert_batch_parallel failed: {e}");
return 0;
}
count
}
#[must_use]
pub fn search_batch_parallel(
&self,
queries: &[&[f32]],
k: usize,
quality: SearchQuality,
) -> Vec<Vec<ScoredResult>> {
use rayon::prelude::*;
queries
.par_iter()
.map(|q| self.search_with_quality(q, k, quality))
.collect()
}
#[must_use]
pub fn brute_force_search_parallel(&self, query: &[f32], k: usize) -> Vec<ScoredResult> {
use rayon::prelude::*;
let vectors_snapshot = self.vectors.collect_for_parallel();
if vectors_snapshot.is_empty() {
return Vec::new();
}
let inner = self.inner.read();
let mut results: Vec<ScoredResult> = vectors_snapshot
.par_iter()
.filter_map(|(idx, vec)| {
let id = self.mappings.get_id(*idx)?;
let raw_distance = inner.compute_distance(query, vec);
let score = inner.transform_score(raw_distance);
Some(ScoredResult::new(id, score))
})
.collect();
self.metric.sort_scored_results(&mut results);
results.truncate(k);
results
}
}
/// `VectorIndex` adapter. Every operation forwards to the inherent
/// `NativeHnswIndex` method of the same name, except the two getters,
/// which read the fields directly.
impl VectorIndex for NativeHnswIndex {
    /// Forwards to the inherent `insert`.
    ///
    /// The trait method returns `()`, so failures are logged rather than
    /// propagated.
    fn insert(&self, id: u64, vector: &[f32]) {
        match Self::insert(self, id, vector) {
            Ok(()) => {}
            Err(e) => {
                tracing::error!("NativeHnswIndex::insert failed for id={id}: {e}");
            }
        }
    }

    /// Forwards to the inherent `remove`.
    fn remove(&self, id: u64) -> bool {
        Self::remove(self, id)
    }

    /// Forwards to the inherent `search`.
    fn search(&self, query: &[f32], k: usize) -> Vec<ScoredResult> {
        Self::search(self, query, k)
    }

    /// Forwards to the inherent `len`.
    fn len(&self) -> usize {
        Self::len(self)
    }

    /// Configured vector dimensionality.
    fn dimension(&self) -> usize {
        let dim = self.dimension;
        dim
    }

    /// Configured distance metric.
    fn metric(&self) -> DistanceMetric {
        let metric = self.metric;
        metric
    }
}