use std::collections::HashMap;
use std::mem::size_of;
use std::sync::Arc;
use iqdb_index::{Index, IndexCore, IndexStats};
use iqdb_types::{DistanceMetric, Hit, IqdbError, Metadata, Result, SearchParams, VectorId};
use crate::config::HnswConfig;
use crate::graph::NodeIdx;
use crate::rng::SplitMix64;
use crate::{insert as insert_algo, search as search_algo};
/// In-memory HNSW index state.
///
/// Nodes are addressed by `NodeIdx`; external callers address vectors by
/// `VectorId`, translated through `id_to_node`. Deletion is a soft delete:
/// `delete` sets the node's `tombstoned` flag and drops its id mapping.
///
/// NOTE(review): `vectors`, `ids`, `metadata`, `seqs`, `tombstoned` and
/// `node_layer` look like parallel per-node arrays indexed by `NodeIdx`;
/// the code that fills them lives in the `insert` module — confirm there.
#[derive(Debug)]
pub struct HnswIndex {
/// Required length of every vector; `check_dim` compares against it.
pub(crate) dim: usize,
/// Distance metric supplied at construction.
pub(crate) metric: DistanceMetric,
/// Configuration as accepted by the constructor's validation.
pub(crate) cfg: HnswConfig,
/// `1 / ln(cfg.m)`, computed once at construction.
/// NOTE(review): presumably the level-sampling multiplier used by the
/// insert path — confirm in the `insert` module.
pub(crate) m_l_inv: f64,
/// Stored vectors, shared via `Arc`.
pub(crate) vectors: Vec<Arc<[f32]>>,
/// External ids (presumably one per node — see the struct-level note).
pub(crate) ids: Vec<VectorId>,
/// Optional metadata (presumably one slot per node).
pub(crate) metadata: Vec<Option<Metadata>>,
/// Sequence numbers (presumably one per node, drawn from `next_seq`).
pub(crate) seqs: Vec<u64>,
/// Soft-delete flags indexed by node; `delete` sets the flag to `true`.
pub(crate) tombstoned: Vec<bool>,
/// One layer value per entry; summarized by `node_layer_histogram`.
pub(crate) node_layer: Vec<u8>,
/// Nested adjacency storage of `NodeIdx` lists.
/// NOTE(review): the nesting order (layer-then-node or node-then-layer)
/// is not established in this file — confirm in the `graph`/`insert` code.
pub(crate) layers: Vec<Vec<Vec<NodeIdx>>>,
/// Maps an external id to its node; the entry is removed on `delete`.
pub(crate) id_to_node: HashMap<VectorId, NodeIdx>,
/// `None` on a freshly constructed index.
/// NOTE(review): presumably the graph entry point for searches — confirm.
pub(crate) entry: Option<NodeIdx>,
/// Starts at 0 on a freshly constructed index.
/// NOTE(review): presumably the highest populated layer — confirm.
pub(crate) top_layer: u8,
/// Starts at 0 on a freshly constructed index.
pub(crate) next_seq: u64,
/// Random generator seeded from `cfg.seed` at construction.
pub(crate) rng: SplitMix64,
/// Count reported by `len()`; decremented (saturating) by `delete`.
pub(crate) live_count: usize,
}
impl HnswIndex {
    /// Validates `dim` and `cfg`, then builds an empty index.
    ///
    /// The random generator is seeded from `cfg.seed`, and `m_l_inv` is
    /// precomputed as `1 / ln(cfg.m)`.
    ///
    /// # Errors
    ///
    /// Returns [`IqdbError::InvalidConfig`] when any of the following holds:
    /// - `dim == 0`
    /// - `cfg.m == 0`
    /// - `cfg.m == 1` (because `ln(1) == 0`, `1 / ln(m)` would be infinite)
    /// - `cfg.ef_construction < cfg.m`
    /// - `cfg.ef_search == 0`
    /// - `cfg.filter_widen == 0`
    pub fn new_unconfigured(dim: usize, metric: DistanceMetric, cfg: HnswConfig) -> Result<Self> {
        if dim == 0 {
            return Err(IqdbError::InvalidConfig {
                reason: "HnswIndex dim must be greater than zero",
            });
        }
        if cfg.m == 0 {
            return Err(IqdbError::InvalidConfig {
                reason: "HnswConfig.m must be greater than zero",
            });
        }
        // `m_l_inv` below is 1 / ln(m). ln(1) is exactly 0.0, so m == 1 would
        // store +infinity instead of a usable finite multiplier.
        if cfg.m == 1 {
            return Err(IqdbError::InvalidConfig {
                reason: "HnswConfig.m must be at least 2",
            });
        }
        if cfg.ef_construction < cfg.m {
            return Err(IqdbError::InvalidConfig {
                reason: "HnswConfig.ef_construction must be >= m",
            });
        }
        if cfg.ef_search == 0 {
            return Err(IqdbError::InvalidConfig {
                reason: "HnswConfig.ef_search must be greater than zero",
            });
        }
        if cfg.filter_widen == 0 {
            return Err(IqdbError::InvalidConfig {
                reason: "HnswConfig.filter_widen must be greater than zero",
            });
        }

        // m >= 2 is guaranteed here, so ln(m) > 0 and the quotient is finite.
        let m_l_inv = 1.0_f64 / (cfg.m as f64).ln();
        let rng = SplitMix64::new(cfg.seed);

        Ok(Self {
            dim,
            metric,
            cfg,
            m_l_inv,
            vectors: Vec::new(),
            ids: Vec::new(),
            metadata: Vec::new(),
            seqs: Vec::new(),
            tombstoned: Vec::new(),
            node_layer: Vec::new(),
            layers: Vec::new(),
            id_to_node: HashMap::new(),
            entry: None,
            top_layer: 0,
            next_seq: 0,
            rng,
            live_count: 0,
        })
    }

    /// Returns the vector length this index was created with.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the distance metric this index was created with.
    #[must_use]
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Returns `live_count`, the counter that `delete` decrements.
    #[must_use]
    pub fn len(&self) -> usize {
        self.live_count
    }

    /// Returns `true` when `live_count` is zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.live_count == 0
    }

    /// Returns a copy of the configuration the index was built with.
    #[must_use]
    pub fn config(&self) -> HnswConfig {
        self.cfg
    }

    /// Counts how many entries of `node_layer` hold each layer value.
    ///
    /// Element `i` of the result is the number of entries equal to `i`. The
    /// result has `max_layer + 1` elements, or is empty when `node_layer` is
    /// empty. `delete` in this file does not remove entries from
    /// `node_layer`, so tombstoned nodes are still counted.
    #[must_use]
    pub fn node_layer_histogram(&self) -> Vec<usize> {
        let max_layer = match self.node_layer.iter().copied().max() {
            Some(max) => usize::from(max),
            None => return Vec::new(),
        };

        let mut histogram = vec![0_usize; max_layer + 1];
        for &layer in &self.node_layer {
            let bucket = &mut histogram[usize::from(layer)];
            *bucket = bucket.saturating_add(1);
        }
        histogram
    }

    /// Forwards to `search_algo::search_with_ef` with a caller-supplied `ef`.
    ///
    /// NOTE(review): presumably `ef` overrides `cfg.ef_search` for this one
    /// query — confirm in the `search` module.
    #[doc(hidden)]
    pub fn search_with_ef(
        &self,
        query: &[f32],
        params: &SearchParams,
        ef: usize,
    ) -> Result<Vec<Hit>> {
        search_algo::search_with_ef(self, query, params, ef)
    }

    /// Checks that `vector_len` equals the index dimensionality.
    ///
    /// # Errors
    ///
    /// Returns [`IqdbError::DimensionMismatch`] carrying the expected and the
    /// found length when they differ.
    pub(crate) fn check_dim(&self, vector_len: usize) -> Result<()> {
        if vector_len == self.dim {
            Ok(())
        } else {
            Err(IqdbError::DimensionMismatch {
                expected: self.dim,
                found: vector_len,
            })
        }
    }

    /// Estimates the heap bytes held by the index's containers.
    ///
    /// Every `Vec` is charged `capacity * size_of::<element>()`, and each
    /// vector payload adds its `f32` data plus two `usize` words for the `Arc`
    /// reference counts. Heap memory owned *inside* `Metadata` values and the
    /// hash map's internal bookkeeping are not included, so the result is a
    /// lower-bound style approximation.
    fn approximate_memory_bytes(&self) -> usize {
        // Strong + weak counters stored in front of each Arc payload.
        let arc_header_bytes = 2 * size_of::<usize>();

        let vector_payload_bytes: usize = self
            .vectors
            .iter()
            .map(|vector| vector.len() * size_of::<f32>() + arc_header_bytes)
            .sum();
        let vectors_bytes =
            vector_payload_bytes + self.vectors.capacity() * size_of::<Arc<[f32]>>();

        let ids_bytes = self.ids.capacity() * size_of::<VectorId>();
        let metadata_bytes = self.metadata.capacity() * size_of::<Option<Metadata>>();
        let seqs_bytes = self.seqs.capacity() * size_of::<u64>();
        let tombstoned_bytes = self.tombstoned.capacity() * size_of::<bool>();
        let node_layer_bytes = self.node_layer.capacity() * size_of::<u8>();

        // `layers` has three nesting levels; each level's buffer is charged
        // once. The outermost buffer holds `Vec<Vec<NodeIdx>>` headers, the
        // middle buffers hold `Vec<NodeIdx>` headers, and the innermost
        // buffers hold the `NodeIdx` values themselves.
        let outer_layers_bytes = self.layers.capacity() * size_of::<Vec<Vec<NodeIdx>>>();
        let nested_layers_bytes: usize = self
            .layers
            .iter()
            .map(|group| {
                let headers = group.capacity() * size_of::<Vec<NodeIdx>>();
                let links: usize = group
                    .iter()
                    .map(|adjacency| adjacency.capacity() * size_of::<NodeIdx>())
                    .sum();
                headers + links
            })
            .sum();
        let layers_bytes = outer_layers_bytes + nested_layers_bytes;

        let id_to_node_bytes =
            self.id_to_node.capacity() * (size_of::<VectorId>() + size_of::<NodeIdx>());

        vectors_bytes
            + ids_bytes
            + metadata_bytes
            + seqs_bytes
            + tombstoned_bytes
            + node_layer_bytes
            + layers_bytes
            + id_to_node_bytes
    }
}
impl IndexCore for HnswIndex {
    /// Forwards to `insert_algo::insert_node`, which performs the insertion.
    fn insert(
        &mut self,
        id: VectorId,
        vector: Arc<[f32]>,
        metadata: Option<Metadata>,
    ) -> Result<()> {
        insert_algo::insert_node(self, id, vector, metadata)
    }

    /// Soft-deletes the vector registered under `id`.
    ///
    /// The id is removed from `id_to_node`, the node's `tombstoned` flag is
    /// set, and `live_count` is decremented (saturating at zero). No other
    /// per-node storage is modified here.
    ///
    /// # Errors
    ///
    /// Returns [`IqdbError::NotFound`] when `id` has no mapping.
    ///
    /// # Panics
    ///
    /// Panics if the mapped node index is out of bounds for `tombstoned`.
    fn delete(&mut self, id: &VectorId) -> Result<()> {
        match self.id_to_node.remove(id) {
            Some(node) => {
                self.tombstoned[node as usize] = true;
                self.live_count = self.live_count.saturating_sub(1);
                Ok(())
            }
            None => Err(IqdbError::NotFound),
        }
    }

    /// Forwards to `search_algo::search`.
    fn search(&self, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
        search_algo::search(self, query, params)
    }

    /// Same value as the inherent `HnswIndex::len`: the `live_count` counter.
    fn len(&self) -> usize {
        self.live_count
    }

    /// `true` when `live_count` is zero.
    fn is_empty(&self) -> bool {
        self.live_count == 0
    }

    /// Vector length the index was created with.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Distance metric the index was created with.
    fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Does nothing and always returns `Ok(())`.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }

    /// Reports `live_count`, the approximate memory estimate, and the fixed
    /// index type tag `"hnsw"`; `disk_bytes` and `extra` are always `None`.
    fn stats(&self) -> IndexStats {
        let memory_bytes = self.approximate_memory_bytes();
        let n_vectors = self.live_count;
        IndexStats {
            n_vectors,
            memory_bytes,
            disk_bytes: None,
            index_type: "hnsw",
            extra: None,
        }
    }
}
impl Index for HnswIndex {
type Config = HnswConfig;
fn new(dim: usize, metric: DistanceMetric, config: Self::Config) -> Result<Self> {
Self::new_unconfigured(dim, metric, config)
}
}