#![allow(dead_code)]
#![allow(clippy::cast_precision_loss)]
use super::native::rabitq_precision::RaBitQPrecisionHnsw;
use super::native::{CachedSimdDistance, NativeHnsw, NativeNeighbour};
use crate::distance::DistanceMetric;
use std::path::Path;
/// The concrete HNSW implementation held by a `NativeHnswInner`.
// NOTE(review): the size lint is silenced even though the `RaBitQ` payload is already
// boxed — presumably `Standard` is still large inline; confirm the allowance is still needed.
#[allow(clippy::large_enum_variant)]
enum HnswBackend {
/// A `NativeHnsw` graph stored inline.
Standard(NativeHnsw<CachedSimdDistance>),
/// A `RaBitQPrecisionHnsw` stored behind a `Box`.
RaBitQ(Box<RaBitQPrecisionHnsw<CachedSimdDistance>>),
}
/// HNSW index wrapper that forwards every operation to one of two backends.
pub struct NativeHnswInner {
/// The backend selected when the index was constructed or loaded.
backend: HnswBackend,
/// The distance metric the index was constructed or loaded with.
metric: DistanceMetric,
}
impl NativeHnswInner {
    /// Creates an index using `crate::StorageMode::Full`.
    ///
    /// This is a shorthand for [`Self::new_with_storage_mode`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Self::new_with_storage_mode`] reports.
    pub fn new(
        metric: DistanceMetric,
        max_connections: usize,
        max_elements: usize,
        ef_construction: usize,
        dimension: usize,
    ) -> crate::error::Result<Self> {
        let storage_mode = crate::StorageMode::Full;
        Self::new_with_storage_mode(
            metric,
            max_connections,
            max_elements,
            ef_construction,
            dimension,
            storage_mode,
        )
    }

    /// Creates an index whose backend is chosen by `storage_mode`.
    ///
    /// `crate::StorageMode::RaBitQ` builds a boxed `RaBitQPrecisionHnsw`. Every other mode
    /// builds a `NativeHnsw`: through `new_with_dimension` when `dimension` is non-zero,
    /// through `new` when it is zero.
    ///
    /// # Errors
    ///
    /// Propagates the error of `RaBitQPrecisionHnsw::new` or
    /// `NativeHnsw::new_with_dimension`.
    pub fn new_with_storage_mode(
        metric: DistanceMetric,
        max_connections: usize,
        max_elements: usize,
        ef_construction: usize,
        dimension: usize,
        storage_mode: crate::StorageMode,
    ) -> crate::error::Result<Self> {
        // Both backends are built around the same kind of distance object.
        let distance = CachedSimdDistance::new(metric, dimension);

        if matches!(storage_mode, crate::StorageMode::RaBitQ) {
            let quantized = RaBitQPrecisionHnsw::new(
                distance,
                dimension,
                max_connections,
                ef_construction,
                max_elements,
            )?;
            return Ok(Self {
                backend: HnswBackend::RaBitQ(Box::new(quantized)),
                metric,
            });
        }

        let graph = if dimension == 0 {
            NativeHnsw::new(distance, max_connections, ef_construction, max_elements)
        } else {
            NativeHnsw::new_with_dimension(
                distance,
                max_connections,
                ef_construction,
                max_elements,
                dimension,
            )?
        };
        Ok(Self {
            backend: HnswBackend::Standard(graph),
            metric,
        })
    }

    /// Reports the storage mode implied by the active backend.
    ///
    /// The standard backend always reports `Full`, so an index constructed with any mode
    /// other than `RaBitQ` is reported as `Full` here.
    #[must_use]
    pub fn storage_mode(&self) -> crate::StorageMode {
        match &self.backend {
            HnswBackend::RaBitQ(_) => crate::StorageMode::RaBitQ,
            HnswBackend::Standard(_) => crate::StorageMode::Full,
        }
    }
}
impl NativeHnswInner {
    /// Forwards the query to the active backend's `search` and returns its result unchanged.
    #[inline]
    #[must_use]
    pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<(usize, f32)> {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.search(query, k, ef_search),
            HnswBackend::Standard(graph) => graph.search(query, k, ef_search),
        }
    }

    /// Runs the same query as [`Self::search`] but returns `NativeNeighbour` values.
    ///
    /// The standard backend produces them directly through `search_neighbours`. The RaBitQ
    /// backend's `search` tuples are converted one-to-one, in order.
    #[inline]
    #[must_use]
    pub fn search_neighbours(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Vec<NativeNeighbour> {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => {
                let hits = quantized.search(query, k, ef_search);
                hits.into_iter()
                    .map(|(d_id, distance)| NativeNeighbour { d_id, distance })
                    .collect()
            }
            HnswBackend::Standard(graph) => graph.search_neighbours(query, k, ef_search),
        }
    }
}
impl NativeHnswInner {
    /// Inserts one vector and returns the id the backend assigned to it.
    ///
    /// `data` is `(vector, expected_idx)`. The expected index is only compared against the
    /// assigned id: a mismatch is logged as a warning and the assigned id is still returned.
    ///
    /// # Errors
    ///
    /// Propagates the error of the backend's `insert`.
    pub fn insert(&self, data: (&[f32], usize)) -> crate::error::Result<usize> {
        let expected_idx = data.1;
        let assigned_id = match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.insert(data.0)?,
            HnswBackend::Standard(graph) => graph.insert(data.0)?,
        };

        if assigned_id == expected_idx {
            return Ok(assigned_id);
        }

        tracing::warn!(
            "NativeHnsw node_id mismatch: expected {expected_idx}, got {assigned_id} \
             — mapping may be desynchronised under concurrent inserts"
        );
        Ok(assigned_id)
    }

    /// Inserts a batch of `(vector, expected_idx)` entries and returns the assigned ids.
    ///
    /// The standard backend handles the whole batch through its own `parallel_insert`. The
    /// RaBitQ backend inserts entries one at a time, in slice order, via [`Self::insert`].
    ///
    /// # Errors
    ///
    /// Propagates the backend error. On the RaBitQ path the first failing entry stops the
    /// batch; entries before it have already been inserted.
    pub fn parallel_insert(&self, data: &[(&[f32], usize)]) -> crate::error::Result<Vec<usize>> {
        match &self.backend {
            HnswBackend::Standard(graph) => graph.parallel_insert(data),
            // Collecting into `Result` short-circuits on the first `Err`.
            HnswBackend::RaBitQ(_) => data.iter().map(|&entry| self.insert(entry)).collect(),
        }
    }

    /// Forwards `mode` to `set_searching_mode` on the underlying graph.
    ///
    /// For the RaBitQ backend the call goes to its `inner` graph.
    pub fn set_searching_mode(&mut self, mode: bool) {
        match &mut self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.inner.set_searching_mode(mode),
            HnswBackend::Standard(graph) => graph.set_searching_mode(mode),
        }
    }
}
impl NativeHnswInner {
    /// Dumps the underlying graph through its `file_dump`, passing `path` and `basename`.
    ///
    /// For the RaBitQ backend only the `inner` graph is dumped by this call.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error of the graph's `file_dump`.
    pub fn file_dump(&self, path: &Path, basename: &str) -> std::io::Result<()> {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.inner.file_dump(path, basename),
            HnswBackend::Standard(graph) => graph.file_dump(path, basename),
        }
    }

    /// Loads a graph and wraps it in the standard backend.
    ///
    /// Equivalent to [`Self::file_load_with_storage_mode`] with `crate::StorageMode::Full`.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error of `NativeHnsw::file_load`.
    pub fn file_load(
        path: &Path,
        basename: &str,
        metric: DistanceMetric,
        dimension: usize,
    ) -> std::io::Result<Self> {
        Self::file_load_with_storage_mode(
            path,
            basename,
            metric,
            dimension,
            crate::StorageMode::Full,
        )
    }

    /// Loads a graph and wraps it in the backend selected by `storage_mode`.
    ///
    /// The graph is always read with `NativeHnsw::file_load`. For
    /// `crate::StorageMode::RaBitQ` it is then handed to `RaBitQPrecisionHnsw::from_inner`
    /// together with a second, freshly built distance object; every other mode keeps the
    /// graph as the standard backend.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error of `NativeHnsw::file_load`.
    pub fn file_load_with_storage_mode(
        path: &Path,
        basename: &str,
        metric: DistanceMetric,
        dimension: usize,
        storage_mode: crate::StorageMode,
    ) -> std::io::Result<Self> {
        let graph_distance = CachedSimdDistance::new(metric, dimension);
        let graph = NativeHnsw::file_load(path, basename, graph_distance)?;

        if !matches!(storage_mode, crate::StorageMode::RaBitQ) {
            return Ok(Self {
                backend: HnswBackend::Standard(graph),
                metric,
            });
        }

        // The loaded graph consumed the first distance object, so the wrapper gets its own.
        let wrapper_distance = CachedSimdDistance::new(metric, dimension);
        let quantized = RaBitQPrecisionHnsw::from_inner(graph, wrapper_distance, dimension);
        Ok(Self {
            backend: HnswBackend::RaBitQ(Box::new(quantized)),
            metric,
        })
    }
}
impl NativeHnswInner {
    /// Converts a raw distance into a score.
    ///
    /// The standard backend delegates to its own `transform_score`; the RaBitQ backend
    /// returns `raw_distance` unchanged.
    #[inline]
    #[must_use]
    pub fn transform_score(&self, raw_distance: f32) -> f32 {
        match &self.backend {
            HnswBackend::RaBitQ(_) => raw_distance,
            HnswBackend::Standard(graph) => graph.transform_score(raw_distance),
        }
    }

    /// Returns the `len()` reported by the active backend.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.len(),
            HnswBackend::Standard(graph) => graph.len(),
        }
    }

    /// Returns the `is_empty()` reported by the active backend.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.is_empty(),
            HnswBackend::Standard(graph) => graph.is_empty(),
        }
    }

    /// Returns the distance metric stored on this index.
    #[inline]
    #[must_use]
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Computes the distance between `a` and `b` using the underlying graph.
    ///
    /// For the RaBitQ backend the call goes to its `inner` graph.
    #[inline]
    #[must_use]
    pub fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.inner.compute_distance(a, b),
            HnswBackend::Standard(graph) => graph.compute_distance(a, b),
        }
    }

    /// Passes `f` to the underlying graph's `with_vectors_read` and returns what it yields.
    ///
    /// For the RaBitQ backend the call goes to its `inner` graph.
    // NOTE(review): the `R: Default` bound is presumably required by `with_vectors_read`
    // — confirm against its signature.
    #[inline]
    pub fn with_contiguous_vectors<R: Default>(
        &self,
        f: impl FnOnce(&crate::perf_optimizations::ContiguousVectors) -> R,
    ) -> R {
        match &self.backend {
            HnswBackend::RaBitQ(quantized) => quantized.inner.with_vectors_read(f),
            HnswBackend::Standard(graph) => graph.with_vectors_read(f),
        }
    }
}
// SAFETY: NOTE(review): these impls assert that `NativeHnswInner` may be moved to and shared
// between threads. Nothing visible in this file establishes that; it presumably relies on
// synchronisation inside `NativeHnsw` and `RaBitQPrecisionHnsw` — confirm against those types
// and record the actual invariant here.
unsafe impl Send for NativeHnswInner {}
unsafe impl Sync for NativeHnswInner {}