use super::native_inner::NativeHnswInner;
use super::params::{HnswParams, SearchQuality};
use super::sharded_mappings::ShardedMappings;
use super::sharded_vectors::ShardedVectors;
use super::upsert::{self, UpsertResult};
use crate::distance::DistanceMetric;
use crate::index::VectorIndex;
use crate::scored_result::ScoredResult;
use parking_lot::RwLock;
use std::path::Path;
/// HNSW vector index that pairs a native graph (`NativeHnswInner`) with
/// sharded id mappings and an optional sharded store of the raw vectors.
pub struct NativeHnswIndex {
/// Vector length accepted by `insert` / `insert_batch`.
dimension: usize,
/// Distance metric handed to the graph at construction time and used to
/// order brute-force results.
metric: DistanceMetric,
/// The HNSW graph. Every access in this file takes the read side of the lock.
inner: RwLock<NativeHnswInner>,
/// Mappings between external `u64` ids and internal graph indices.
pub(crate) mappings: ShardedMappings,
/// Raw vectors keyed by internal index; written only when
/// `enable_vector_storage` is true.
vectors: ShardedVectors,
/// Whether inserted vectors are also kept in `vectors`.
enable_vector_storage: bool,
/// Parameters the index was created with; stored but not read in this file
/// (hence the `dead_code` allowance).
#[allow(dead_code)] params: HnswParams,
}
impl NativeHnswIndex {
/// Creates an empty index whose parameters come from `HnswParams::auto(dimension)`.
///
/// # Errors
///
/// Returns whatever error `with_params` reports.
pub fn new(dimension: usize, metric: DistanceMetric) -> crate::error::Result<Self> {
Self::with_params(dimension, metric, HnswParams::auto(dimension))
}
/// Creates an empty index from explicitly supplied HNSW parameters.
///
/// The graph is built from `params.max_connections`, `params.max_elements`
/// and `params.ef_construction`. Vector storage starts enabled and the
/// mappings start empty.
///
/// # Errors
///
/// Propagates the error returned by `NativeHnswInner::new`.
pub fn with_params(
    dimension: usize,
    metric: DistanceMetric,
    params: HnswParams,
) -> crate::error::Result<Self> {
    // Build the graph first: if it fails, nothing else has been allocated.
    let graph = NativeHnswInner::new(
        metric,
        params.max_connections,
        params.max_elements,
        params.ef_construction,
        dimension,
    )?;

    let index = Self {
        dimension,
        metric,
        inner: RwLock::new(graph),
        mappings: ShardedMappings::new(),
        vectors: ShardedVectors::new(dimension),
        enable_vector_storage: true,
        params,
    };
    Ok(index)
}
/// Creates an empty index configured with `HnswParams::turbo()`.
///
/// # Errors
///
/// Returns whatever error `with_params` reports.
pub fn new_turbo(dimension: usize, metric: DistanceMetric) -> crate::error::Result<Self> {
Self::with_params(dimension, metric, HnswParams::turbo())
}
/// Creates an index like [`Self::new`] but with vector storage switched off,
/// so inserts update only the graph and the id mappings.
///
/// # Errors
///
/// Returns whatever error [`Self::new`] reports.
pub fn new_fast_insert(dimension: usize, metric: DistanceMetric) -> crate::error::Result<Self> {
    Self::new(dimension, metric).map(|mut index| {
        index.enable_vector_storage = false;
        index
    })
}
/// Persists the index into the directory `path`, creating it if needed.
///
/// Written in order: the graph dump (base name `"native_hnsw"`), the id
/// mappings, the vectors (or their cleanup when storage is disabled, as
/// decided by `persistence::save_or_cleanup_vectors`), and finally the
/// metadata (dimension, metric, storage flag). `params` is not persisted.
///
/// # Errors
///
/// Returns the first I/O error raised by any of those steps.
pub fn save<P: AsRef<Path>>(&self, path: P) -> std::io::Result<()> {
    use super::persistence::{self, HnswMappingsData, HnswMeta};

    let dir = path.as_ref();
    std::fs::create_dir_all(dir)?;

    // The read guard is bound to a local, so it is held until this function returns.
    let graph = self.inner.read();
    graph.file_dump(dir, "native_hnsw")?;

    let (id_to_idx, idx_to_id, next_idx) = self.mappings.as_parts();
    persistence::save_mappings(
        dir,
        &HnswMappingsData {
            id_to_idx,
            idx_to_id,
            next_idx,
        },
    )?;

    persistence::save_or_cleanup_vectors(dir, self.enable_vector_storage, &self.vectors)?;

    let meta = HnswMeta {
        dimension: self.dimension,
        metric: self.metric,
        enable_vector_storage: self.enable_vector_storage,
    };
    persistence::save_meta(dir, &meta)?;
    Ok(())
}
/// Restores an index previously written by `save` from the directory `path`.
///
/// The `_dimension` and `_metric` arguments are ignored: both values are
/// taken from the persisted metadata. The build parameters are not
/// restored; `params` is re-derived with `HnswParams::auto`.
/// Whether vector storage ends up enabled is decided by
/// `persistence::load_vectors_or_disable`.
///
/// # Errors
///
/// Returns the first I/O error raised while reading the metadata, the
/// graph, the mappings or the vectors (in that order).
pub fn load<P: AsRef<Path>>(
    path: P,
    _dimension: usize,
    _metric: DistanceMetric,
) -> std::io::Result<Self> {
    use super::persistence;

    let dir = path.as_ref();
    let meta = persistence::load_meta(dir)?;
    let graph = NativeHnswInner::file_load(dir, "native_hnsw", meta.metric, meta.dimension)?;

    let persisted = persistence::load_mappings(dir)?;
    let mappings = ShardedMappings::from_parts(
        persisted.id_to_idx,
        persisted.idx_to_id,
        persisted.next_idx,
    );

    let (vectors, enable_vector_storage) = persistence::load_vectors_or_disable(dir, &meta)?;

    let restored = Self {
        dimension: meta.dimension,
        metric: meta.metric,
        inner: RwLock::new(graph),
        mappings,
        vectors,
        enable_vector_storage,
        params: HnswParams::auto(meta.dimension),
    };
    Ok(restored)
}
/// Returns the vector length this index accepts.
#[inline]
#[must_use]
pub fn dimension(&self) -> usize {
self.dimension
}
/// Returns the distance metric the index was created (or loaded) with.
#[inline]
#[must_use]
pub fn metric(&self) -> DistanceMetric {
self.metric
}
/// Returns the number of entries reported by the id mappings.
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.mappings.len()
}
/// Returns `true` when the id mappings report no entries.
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.mappings.is_empty()
}
/// Returns `true` when inserted vectors are also kept in the vector store.
#[inline]
#[must_use]
pub fn has_vector_storage(&self) -> bool {
self.enable_vector_storage
}
/// Searches for the `k` nearest neighbours of `query` using
/// `SearchQuality::Balanced`; see `search_with_quality` for details.
#[must_use]
pub fn search(&self, query: &[f32], k: usize) -> Vec<ScoredResult> {
self.search_with_quality(query, k, SearchQuality::Balanced)
}
#[must_use]
pub fn search_with_quality(
&self,
query: &[f32],
k: usize,
quality: SearchQuality,
) -> Vec<ScoredResult> {
let ef_search = quality.ef_search(k);
let inner = self.inner.read();
let neighbors = inner.search(query, k, ef_search);
neighbors
.into_iter()
.filter_map(|n| {
self.mappings.get_id(n.d_id).map(|id| {
let score = inner.transform_score(n.distance);
ScoredResult::new(id, score)
})
})
.collect()
}
/// Delegates to `upsert::upsert_mapping` with this index's mappings, vector
/// store and storage flag, returning its `UpsertResult` for `id`.
// NOTE(review): presumably this allocates (or reuses) the internal index for
// `id`; the exact semantics live in the `upsert` module — confirm there.
#[must_use]
fn upsert_mapping(&self, id: u64) -> UpsertResult {
upsert::upsert_mapping(
&self.mappings,
&self.vectors,
self.enable_vector_storage,
id,
)
}
/// Delegates to `upsert::rollback_upsert` for `id`, passing the `UpsertResult`
/// obtained from the earlier upsert. Only the mappings are handed over, not
/// the vector store.
fn rollback_upsert(&self, id: u64, result: &UpsertResult) {
upsert::rollback_upsert(&self.mappings, id, result);
}
/// Inserts (or upserts) `vector` under the external id `id`.
///
/// Steps: validate the vector length, register the id mapping, insert into
/// the graph, then store the raw vector when vector storage is enabled.
///
/// # Errors
///
/// * `Error::DimensionMismatch` when `vector.len()` differs from the index
///   dimension; nothing is modified in that case.
/// * Any error from the graph insert; the mapping change is rolled back
///   before the error is returned.
pub fn insert(&self, id: u64, vector: &[f32]) -> crate::error::Result<()> {
    let actual = vector.len();
    if actual != self.dimension {
        return Err(crate::error::Error::DimensionMismatch {
            expected: self.dimension,
            actual,
        });
    }

    let upserted = self.upsert_mapping(id);

    // The read guard is a temporary of the scrutinee, so it is still held
    // while the rollback in the error arm runs.
    match self.inner.read().insert((vector, upserted.idx)) {
        Ok(_) => {}
        Err(e) => {
            self.rollback_upsert(id, &upserted);
            return Err(e);
        }
    }

    if self.enable_vector_storage {
        self.vectors.insert(upserted.idx, vector);
    }
    Ok(())
}
/// Inserts a batch of `(id, vector)` pairs through the graph's parallel insert.
///
/// All vectors are validated before anything is modified. The id mappings
/// are then upserted for the whole batch, the graph insert runs, the
/// mappings are reconciled with the indices the graph reports, and the raw
/// vectors are stored when vector storage is enabled.
///
/// # Errors
///
/// * `Error::DimensionMismatch` for the first vector whose length differs
///   from the index dimension; nothing is modified in that case.
/// * Any error from the graph's `parallel_insert`; every mapping upsert of
///   the batch is rolled back, last to first, before returning.
pub fn insert_batch(&self, items: &[(u64, Vec<f32>)]) -> crate::error::Result<()> {
    if let Some((_, bad)) = items.iter().find(|(_, vec)| vec.len() != self.dimension) {
        return Err(crate::error::Error::DimensionMismatch {
            expected: self.dimension,
            actual: bad.len(),
        });
    }

    let ids: Vec<u64> = items.iter().map(|item| item.0).collect();
    let upserts = upsert::upsert_mapping_batch(
        &self.mappings,
        &self.vectors,
        self.enable_vector_storage,
        &ids,
    );

    // `data` feeds the graph; `rollback_info` remembers how to undo each upsert.
    let (data, rollback_info): (Vec<(&[f32], usize)>, Vec<(u64, UpsertResult)>) = items
        .iter()
        .zip(upserts)
        .map(|((id, vec), result)| ((vec.as_slice(), result.idx), (*id, result)))
        .unzip();

    // The read guard lives for the whole `let` statement, so the rollback
    // below still runs while it is held.
    let assigned_ids = match self.inner.read().parallel_insert(&data) {
        Ok(assigned) => assigned,
        Err(e) => {
            rollback_info
                .iter()
                .rev()
                .for_each(|(id, result)| self.rollback_upsert(*id, result));
            return Err(e);
        }
    };

    // Reconciliation runs even when storage is disabled: it fixes the mappings.
    let storage_ids = self.reconcile_batch_mappings(&rollback_info, &assigned_ids);
    if self.enable_vector_storage {
        for (&(vector, _), slot) in data.iter().zip(storage_ids) {
            self.vectors.insert(slot, vector);
        }
    }
    Ok(())
}
/// Aligns the id mappings with the indices the graph actually assigned.
///
/// Pairs `assigned_ids` with `rollback_info` positionally (stopping at the
/// shorter of the two). Where the assigned index differs from the index
/// reserved by the upsert, the reserved index's reverse mapping is removed
/// and the external id is re-registered under the assigned index.
///
/// Returns, per pair, the index under which the vector should be stored —
/// which is always the assigned index.
fn reconcile_batch_mappings(
    &self,
    rollback_info: &[(u64, UpsertResult)],
    assigned_ids: &[usize],
) -> Vec<usize> {
    assigned_ids
        .iter()
        .zip(rollback_info)
        .map(|(&assigned, (ext_id, result))| {
            if assigned != result.idx {
                self.mappings.remove_reverse(result.idx);
                self.mappings.restore(*ext_id, assigned);
            }
            // When the indices agree, `assigned` equals the reserved index.
            assigned
        })
        .collect()
}
/// Removes the mapping for `id` and, when vector storage is enabled, the
/// stored vector at the index it mapped to.
///
/// The graph itself is not touched by this method.
///
/// Returns `true` if `id` was mapped, `false` otherwise.
pub fn remove(&self, id: u64) -> bool {
    match self.mappings.remove(id) {
        None => false,
        Some(slot) => {
            if self.enable_vector_storage {
                self.vectors.remove(slot);
            }
            true
        }
    }
}
/// No-op: the body is empty.
// NOTE(review): presumably kept so callers written against another HNSW
// backend keep compiling — confirm. `self` is unused, hence the lint allowance.
#[allow(clippy::unused_self)]
pub fn set_searching_mode(&self) {}
/// Collects `items` into a batch and inserts it with `insert_batch`.
///
/// Returns the number of items in the batch on success. If `insert_batch`
/// fails, the error is logged and `0` is returned.
#[allow(clippy::needless_pass_by_value)]
pub fn insert_batch_parallel<I>(&self, items: I) -> usize
where
    I: IntoIterator<Item = (u64, Vec<f32>)>,
{
    let batch: Vec<(u64, Vec<f32>)> = items.into_iter().collect();
    match self.insert_batch(&batch) {
        Ok(()) => batch.len(),
        Err(e) => {
            tracing::error!("insert_batch_parallel failed: {e}");
            0
        }
    }
}
/// Runs `search_with_quality` for every query in `queries` through rayon's
/// parallel iterator, using the same `k` and `quality` for each.
///
/// Returns one result list per query.
#[must_use]
pub fn search_batch_parallel(
&self,
queries: &[&[f32]],
k: usize,
quality: SearchQuality,
) -> Vec<Vec<ScoredResult>> {
use rayon::prelude::*;
queries
.par_iter()
.map(|q| self.search_with_quality(q, k, quality))
.collect()
}
#[must_use]
pub fn brute_force_search_parallel(&self, query: &[f32], k: usize) -> Vec<ScoredResult> {
use rayon::prelude::*;
let vectors_snapshot = self.vectors.collect_for_parallel();
if vectors_snapshot.is_empty() {
return Vec::new();
}
let inner = self.inner.read();
let mut results: Vec<ScoredResult> = vectors_snapshot
.par_iter()
.filter_map(|(idx, vec)| {
let id = self.mappings.get_id(*idx)?;
let distance = inner.compute_distance(query, vec);
Some(ScoredResult::new(id, distance))
})
.collect();
self.metric.sort_scored_results(&mut results);
results.truncate(k);
results
}
}
/// `VectorIndex` implementation that forwards every call to the inherent
/// method of the same name.
impl VectorIndex for NativeHnswIndex {
    /// Inserts `vector` under `id`. The trait signature cannot return an
    /// error, so a failed insert is logged instead.
    fn insert(&self, id: u64, vector: &[f32]) {
        let outcome = Self::insert(self, id, vector);
        if let Err(e) = outcome {
            tracing::error!("NativeHnswIndex::insert failed for id={id}: {e}");
        }
    }

    /// Removes `id`; returns whether it was present.
    fn remove(&self, id: u64) -> bool {
        Self::remove(self, id)
    }

    /// Searches with the default quality of the inherent `search`.
    fn search(&self, query: &[f32], k: usize) -> Vec<ScoredResult> {
        Self::search(self, query, k)
    }

    /// Number of entries in the index.
    fn len(&self) -> usize {
        Self::len(self)
    }

    /// Vector length accepted by the index.
    fn dimension(&self) -> usize {
        Self::dimension(self)
    }

    /// Distance metric of the index.
    fn metric(&self) -> DistanceMetric {
        Self::metric(self)
    }
}