use crate::*;
use alloc::{vec, vec::Vec};
use hashbrown::HashSet;
use rand_core::{RngCore, SeedableRng};
use rand_pcg::Pcg64;
use rustc_hash::FxHasher;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use space::{CandidatesVec, MetricPoint, Neighbor};
/// Layered nearest-neighbor graph over features of type `T`.
///
/// Type parameters:
/// * `T` - the stored feature type.
/// * `M` - length of the neighbor array of every node above the zero layer
///   (defaults to `typenum::U12`); also the base of the logarithm used when
///   drawing a new feature's level.
/// * `M0` - length of the neighbor array of every zero-layer node
///   (defaults to `typenum::U24`).
/// * `R` - the pseudo random number generator used to draw levels
///   (defaults to `Pcg64`).
#[derive(Clone)]
#[cfg_attr(
feature = "serde1",
derive(Serialize, Deserialize),
serde(bound(
serialize = "T: Serialize, R: Serialize",
deserialize = "T: Deserialize<'de>, R: Deserialize<'de>"
))
)]
pub struct HNSW<
T,
M: ArrayLength<u32> = typenum::U12,
M0: ArrayLength<u32> = typenum::U24,
R = Pcg64,
> {
// Zero-layer graph nodes; one is pushed for every inserted feature.
zero: Vec<ZeroNode<M0>>,
// The inserted features, indexed by zero-layer item id.
features: Vec<T>,
// Graph layers above the zero layer; `layers[i]` holds the nodes of level `i + 1`.
layers: Vec<Vec<Node<M>>>,
// Source of randomness for choosing the level of each inserted feature.
prng: R,
// Tuning parameters; `ef_construction` is read while inserting.
params: Params,
}
/// A node of the zero layer. It stores only its neighbor slots, because its
/// position in the zero layer is also the index of its feature.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize), serde(bound = ""))]
struct ZeroNode<N: ArrayLength<u32>> {
// Zero-layer indices of the neighbors; unused slots hold the sentinel `!0`.
neighbors: GenericArray<u32, N>,
}
impl<N: ArrayLength<u32>> ZeroNode<N> {
    /// Yields the occupied neighbor slots in order, stopping at the first
    /// slot that still holds the `!0` (all bits set) sentinel.
    fn neighbors<'a>(&'a self) -> impl Iterator<Item = u32> + 'a {
        self.neighbors
            .iter()
            .take_while(|&&slot| slot != core::u32::MAX)
            .copied()
    }
}
/// A node of a layer above the zero layer.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize), serde(bound = ""))]
struct Node<N: ArrayLength<u32>> {
// Zero-layer item id (and feature index) this node stands for.
zero_node: u32,
// Index of the same feature's node in the layer directly below
// (a zero-layer index when this node is on the first layer above zero).
next_node: u32,
// Indices, within this node's own layer, of its neighbors; unused slots hold `!0`.
neighbors: GenericArray<u32, N>,
}
impl<N: ArrayLength<u32>> Node<N> {
    /// Yields the occupied neighbor slots in order, stopping at the first
    /// slot that still holds the `!0` (all bits set) sentinel.
    fn neighbors<'a>(&'a self) -> impl Iterator<Item = u32> + 'a {
        self.neighbors
            .iter()
            .take_while(|&&slot| slot != core::u32::MAX)
            .copied()
    }
}
/// Reusable scratch state for searches and insertions.
///
/// Every search clears it before use, so one instance can be shared across
/// many calls.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Searcher {
// Nodes whose neighbors still have to be expanded; used as a stack (`pop`).
candidates: Vec<Neighbor>,
// Capacity-limited collection of the best results found so far.
nearest: CandidatesVec,
// Zero-layer ids of the features that were already visited.
seen: HashSet<u32, core::hash::BuildHasherDefault<FxHasher>>,
}
impl Searcher {
    /// Creates a searcher with empty scratch collections.
    pub fn new() -> Self {
        Self::default()
    }

    /// Empties the candidate stack, the nearest set and the visited set.
    fn clear(&mut self) {
        let Self {
            candidates,
            nearest,
            seen,
        } = self;
        candidates.clear();
        nearest.clear();
        seen.clear();
    }
}
impl<T, M: ArrayLength<u32>, M0: ArrayLength<u32>, R> HNSW<T, M, M0, R>
where
    R: RngCore + SeedableRng,
{
    /// Creates an empty index with default parameters and a default-seeded PRNG.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates an empty index that uses `params` and a default-seeded PRNG.
    pub fn new_params(params: Params) -> Self {
        let defaults = Self::default();
        Self { params, ..defaults }
    }
}
impl<T, M: ArrayLength<u32>, M0: ArrayLength<u32>, R> HNSW<T, M, M0, R>
where
R: RngCore,
T: MetricPoint,
{
/// Creates an empty index with default parameters that draws levels from `prng`.
pub fn new_prng(prng: R) -> Self {
    let params = Default::default();
    Self {
        prng,
        params,
        zero: Vec::new(),
        features: Vec::new(),
        layers: Vec::new(),
    }
}
/// Creates an empty index that uses `params` and draws levels from `prng`.
pub fn new_params_and_prng(params: Params, prng: R) -> Self {
    Self {
        params,
        prng,
        zero: Vec::new(),
        features: Vec::new(),
        layers: Vec::new(),
    }
}
/// Inserts the feature `q` into the index and returns its zero-layer item id.
///
/// `searcher` is scratch state; it is cleared and overwritten by this call
/// (except when the index is still empty, where it is left untouched).
pub fn insert(&mut self, q: T, searcher: &mut Searcher) -> u32 {
// Draw the highest level on which this feature will get a node.
let level = self.random_level();
// First feature: there is nothing to search or link against, so add it directly.
if self.is_empty() {
self.zero.push(ZeroNode {
neighbors: core::iter::repeat(!0).collect(),
});
self.features.push(q);
// Create every layer up to `level`, each holding only this feature at index 0.
while self.layers.len() < level {
let node = Node {
zero_node: 0,
next_node: 0,
neighbors: core::iter::repeat(!0).collect(),
};
self.layers.push(vec![node]);
}
return 0;
}
// Start at the entry point. If the new feature reaches the current top layer
// (or above), linking starts right away, so the wide `ef_construction` result
// set is needed immediately; otherwise a single best result is enough.
self.initialize_searcher(
&q,
searcher,
if level >= self.layers.len() {
self.params.ef_construction
} else {
1
},
);
// Layers above the feature's level: only descend, no nodes are created.
for ix in (level..self.layers.len()).rev() {
self.search_single_layer(&q, searcher, &self.layers[ix]);
// The layer reached after `ix == level` is the first one the feature is
// linked into, so widen the result set for it.
self.lower_search(
&self.layers[ix],
searcher,
if ix == level {
self.params.ef_construction
} else {
1
},
);
}
// Layers at or below the feature's level (`layers[ix]` is level `ix + 1`):
// search, create and link the node, then descend.
for ix in (0..core::cmp::min(level, self.layers.len())).rev() {
self.search_single_layer(&q, searcher, &self.layers[ix]);
self.create_node(&q, &searcher.nearest, ix + 1);
self.lower_search(&self.layers[ix], searcher, self.params.ef_construction);
}
// Finally search and link on the zero layer, then store the feature itself.
self.search_zero_layer(&q, searcher);
self.create_node(&q, &searcher.nearest, 0);
self.features.push(q);
let zero_node = (self.zero.len() - 1) as u32;
// If the drawn level exceeds the existing layers, add new top layers that
// contain only this feature.
while self.layers.len() < level {
let node = Node {
zero_node,
// Point at the most recently pushed node of the current top layer,
// or at the zero-layer node when there is no upper layer yet.
next_node: self
.layers
.last()
.map(|l| (l.len() - 1) as u32)
.unwrap_or(zero_node),
neighbors: core::iter::repeat(!0).collect(),
};
self.layers.push(vec![node]);
}
zero_node
}
/// Searches the zero layer for the features nearest to `q`.
///
/// Shorthand for [`HNSW::search_layer`] with `level` set to `0`; see there
/// for the meaning of `ef`, `searcher` and `dest` and for the return value.
pub fn nearest<'a>(
    &self,
    q: &T,
    ef: usize,
    searcher: &mut Searcher,
    dest: &'a mut [Neighbor],
) -> &'a mut [Neighbor] {
    const ZERO_LEVEL: usize = 0;
    self.search_layer(q, ef, ZERO_LEVEL, searcher, dest)
}
/// Returns the feature stored under the zero-layer item id `item`.
///
/// # Panics
///
/// Panics if `item` is not a valid zero-layer item id.
pub fn feature(&self, item: u32) -> &T {
    let feature_ix = item as usize;
    &self.features[feature_ix]
}
/// Returns the feature behind `item`, where `item` is an index into the
/// layer at `level`.
///
/// The item is first translated to its zero-layer id with
/// [`HNSW::layer_item_id`] and panics under the same conditions.
pub fn layer_feature(&self, level: usize, item: u32) -> &T {
    let zero_id = self.layer_item_id(level, item);
    &self.features[zero_id as usize]
}
/// Converts `item`, an index into the layer at `level`, into the item id of
/// the same feature on the zero layer.
///
/// On level `0` the item already is a zero-layer id and is returned as is.
///
/// # Panics
///
/// Panics if `level` is not smaller than [`HNSW::layers`] or if `item` is out
/// of bounds for that level.
pub fn layer_item_id(&self, level: usize, item: u32) -> u32 {
    if level == 0 {
        item
    } else {
        // Level `n >= 1` is stored at `self.layers[n - 1]`, the same mapping
        // `layer_len` and `search_layer` use.
        self.layers[level - 1][item as usize].zero_node
    }
}
/// Returns the number of levels in the index, counting the zero layer.
pub fn layers(&self) -> usize {
    let upper_layers = self.layers.len();
    upper_layers + 1
}
/// Returns the number of nodes on the zero layer; one is added per inserted feature.
pub fn len(&self) -> usize {
self.zero.len()
}
/// Returns the number of items on the layer at `level`.
///
/// Level `0` reports the number of stored features; a level that does not
/// exist reports `0`.
pub fn layer_len(&self, level: usize) -> usize {
    match level.checked_sub(1) {
        // Level 0 has no entry in `self.layers`.
        None => self.features.len(),
        // Level `n >= 1` lives at `self.layers[n - 1]`, if it exists.
        Some(upper_ix) => self.layers.get(upper_ix).map_or(0, |layer| layer.len()),
    }
}
/// Returns `true` when the zero layer holds no nodes, i.e. nothing was inserted yet.
pub fn is_empty(&self) -> bool {
self.zero.is_empty()
}
/// Returns `true` when the layer at `level` holds no items or does not exist.
pub fn layer_is_empty(&self, level: usize) -> bool {
    let item_count = self.layer_len(level);
    item_count == 0
}
/// Searches the layer at `level` for the items nearest to `q`.
///
/// The search descends from the top layer and keeps a single best result
/// per layer until it reaches the layer directly above the zero layer, from
/// which up to `ef` results are carried into the zero layer. `searcher` is
/// cleared and reused as scratch space. The result is whatever
/// `CandidatesVec::fill_slice` writes into `dest`; indices in the returned
/// neighbors refer to items of the layer at `level`.
///
/// Returns an empty slice when the index holds no features or when `level`
/// is not smaller than [`HNSW::layers`].
pub fn search_layer<'a>(
    &self,
    q: &T,
    ef: usize,
    level: usize,
    searcher: &mut Searcher,
    dest: &'a mut [Neighbor],
) -> &'a mut [Neighbor] {
    let level_missing = level >= self.layers();
    if self.features.is_empty() || level_missing {
        return &mut [];
    }

    // Without upper layers the zero layer is searched right away, so it
    // needs the full `ef` capacity from the start.
    let initial_cap = if self.layers.is_empty() { ef } else { 1 };
    self.initialize_searcher(q, searcher, initial_cap);

    for ix in (0..self.layers.len()).rev() {
        let layer = &self.layers[ix];
        self.search_single_layer(q, searcher, layer);
        // `self.layers[ix]` is level `ix + 1`.
        if ix + 1 == level {
            return searcher.nearest.fill_slice(dest);
        }
        // Only the step down into the zero layer widens the result set.
        let next_cap = if ix == 0 { ef } else { 1 };
        self.lower_search(layer, searcher, next_cap);
    }

    self.search_zero_layer(q, searcher);
    searcher.nearest.fill_slice(dest)
}
/// Expands the searcher's candidates within one upper `layer` until the
/// candidate stack is empty.
///
/// Every unvisited neighbor is measured against `q`; it is queued for
/// expansion only if the nearest set accepts it.
fn search_single_layer(&self, q: &T, searcher: &mut Searcher, layer: &[Node<M>]) {
    while let Some(expanding) = searcher.candidates.pop() {
        for linked in layer[expanding.index].neighbors() {
            let zero_id = layer[linked as usize].zero_node;
            // Visits are tracked by zero-layer id rather than by in-layer index.
            if !searcher.seen.insert(zero_id) {
                continue;
            }
            let found = Neighbor {
                index: linked as usize,
                distance: T::distance(q, &self.features[zero_id as usize]),
            };
            if searcher.nearest.push(found) {
                searcher.candidates.push(found);
            }
        }
    }
}
/// Expands the searcher's candidates within the zero layer until the
/// candidate stack is empty.
///
/// Every unvisited neighbor is measured against `q`; it is queued for
/// expansion only if the nearest set accepts it.
fn search_zero_layer(&self, q: &T, searcher: &mut Searcher) {
    while let Some(expanding) = searcher.candidates.pop() {
        for linked in self.zero[expanding.index].neighbors() {
            // On the zero layer the node index is the feature index.
            if !searcher.seen.insert(linked) {
                continue;
            }
            let found = Neighbor {
                index: linked as usize,
                distance: T::distance(q, &self.features[linked as usize]),
            };
            if searcher.nearest.push(found) {
                searcher.candidates.push(found);
            }
        }
    }
}
/// Moves the search from `layer` down to the layer below it.
///
/// Only the best result found on `layer` survives: it is translated through
/// its `next_node` link and becomes the sole candidate and sole nearest
/// entry for the lower layer, whose nearest set gets capacity `m`.
///
/// # Panics
///
/// Panics if the nearest set is empty.
fn lower_search(&self, layer: &[Node<M>], searcher: &mut Searcher, m: usize) {
    searcher.candidates.clear();
    let best = searcher.nearest.best().unwrap();
    searcher.nearest.clear();
    searcher.nearest.set_cap(m);
    // Same feature, same distance; only the index changes to the lower layer's.
    let lowered = Neighbor {
        index: layer[best.index].next_node as usize,
        distance: best.distance,
    };
    searcher.nearest.push(lowered);
    searcher.candidates.push(lowered);
}
/// Resets `searcher` and seeds it with the entry point of the index.
///
/// The entry point is item `0` of the top layer (or of the zero layer when
/// there are no upper layers). The nearest set gets capacity `cap`.
fn initialize_searcher(&self, q: &T, searcher: &mut Searcher, cap: usize) {
    searcher.clear();
    searcher.nearest.set_cap(cap);

    let entry = Neighbor {
        index: 0,
        distance: T::distance(q, self.entry_feature()),
    };
    searcher.candidates.push(entry);
    searcher.nearest.push(entry);

    // Mark the entry point as visited under its zero-layer id.
    let entry_zero_id = match self.layers.last() {
        Some(top_layer) => top_layer[0].zero_node,
        None => 0,
    };
    searcher.seen.insert(entry_zero_id);
}
/// Returns the feature of the entry point: item `0` of the top layer, or
/// feature `0` when there are no upper layers.
///
/// # Panics
///
/// Panics if the index holds no features.
fn entry_feature(&self) -> &T {
    let entry_ix = self
        .layers
        .last()
        .map_or(0, |top_layer| top_layer[0].zero_node as usize);
    &self.features[entry_ix]
}
/// Draws the highest level for a new feature.
///
/// Computes `-ln(u) / ln(M)` truncated to an integer, where `u` is a uniform
/// sample in `[0, 1]` built from one `u32` of the index's PRNG and `M` is the
/// neighbor count of the upper layers.
fn random_level(&mut self) -> usize {
    let uniform: f64 = self.prng.next_u32() as f64 / core::u32::MAX as f64;
    let level_multiplier = libm::log(M::to_usize() as f64).recip();
    // `ln(u)` is non-positive on `(0, 1]`, so the logarithm itself must be
    // negated. Negating its argument instead yields NaN, which the cast
    // turns into level 0 for every feature, so no upper layer is ever built.
    //
    // NOTE(review): a raw draw of 0 gives `-ln(0) = +inf`, which a saturating
    // cast maps to `usize::MAX`; `insert` then tries to build that many
    // layers. Chance is 2^-32 per call — consider clamping if that matters.
    (-libm::log(uniform) * level_multiplier) as usize
}
/// Creates the node for the feature `q` on `layer` and links it both ways.
///
/// The node's neighbor slots are filled from `nearest` (as many as fit), and
/// each of those neighbors is offered the new node as a neighbor in return
/// via `add_neighbor`. Layer `0` creates a zero-layer node; any other value
/// creates a node in `self.layers[layer - 1]`.
///
/// The feature itself must be pushed by the caller afterwards: the new
/// node's zero-layer id is taken from the current length of `self.zero`.
fn create_node(&mut self, q: &T, nearest: &CandidatesVec, layer: usize) {
    if layer == 0 {
        let fresh_ix = self.zero.len() as u32;
        let mut slots: GenericArray<u32, M0> = core::iter::repeat(!0).collect();
        slots
            .iter_mut()
            .zip(nearest.iter())
            .for_each(|(slot, found)| *slot = found.index as u32);
        let fresh = ZeroNode { neighbors: slots };
        for linked in fresh.neighbors() {
            self.add_neighbor(q, fresh_ix, linked, layer);
        }
        self.zero.push(fresh);
        return;
    }

    let upper_ix = layer - 1;
    let fresh_ix = self.layers[upper_ix].len() as u32;
    let mut slots: GenericArray<u32, M> = core::iter::repeat(!0).collect();
    slots
        .iter_mut()
        .zip(nearest.iter())
        .for_each(|(slot, found)| *slot = found.index as u32);

    // The node one layer down does not exist yet; it will be appended there,
    // so its index equals that layer's current length.
    let below_len = if upper_ix == 0 {
        self.zero.len()
    } else {
        self.layers[upper_ix - 1].len()
    };
    let fresh = Node {
        zero_node: self.zero.len() as u32,
        next_node: below_len as u32,
        neighbors: slots,
    };
    for linked in fresh.neighbors() {
        self.add_neighbor(q, fresh_ix, linked, layer);
    }
    self.layers[upper_ix].push(fresh);
}
/// Offers the node `node_ix` (whose feature is `q`) as a neighbor to the
/// existing node `target_ix` on `layer`.
///
/// The target's worst neighbor slot is located (an empty slot counts as the
/// worst possible) and is overwritten with `node_ix` only if `q` is strictly
/// closer to the target than that slot's current occupant.
fn add_neighbor(&mut self, q: &T, node_ix: u32, target_ix: u32, layer: usize) {
// Look up the target's feature and neighbor slots; the zero layer stores
// them differently from the upper layers.
let (target_feature, target_neighbors) = if layer == 0 {
(
&self.features[target_ix as usize],
&self.zero[target_ix as usize].neighbors[..],
)
} else {
let target = &self.layers[layer - 1][target_ix as usize];
(
&self.features[target.zero_node as usize],
&target.neighbors[..],
)
};
// Find the slot holding the farthest neighbor of the target.
let (worst_ix, worst_distance) = target_neighbors
.iter()
.enumerate()
.map(|(ix, &n)| {
// An empty slot (`!0`) gets the maximum distance so it is replaced first.
let distance = if n == !0 {
core::u32::MAX
} else {
T::distance(
target_feature,
// Upper-layer neighbors are in-layer indices and must be
// translated to a feature index through `zero_node`.
&self.features[if layer == 0 {
n as usize
} else {
self.layers[layer - 1][n as usize].zero_node as usize
}],
)
};
(ix, distance)
})
// Minimizing the bitwise complement selects the largest distance; unlike
// `max_by_key`, `min_by_key` returns the first of several equal elements.
.min_by_key(|&(_, distance)| !distance)
.unwrap();
// Replace the worst slot only when the new node is strictly closer.
if T::distance(q, target_feature) < worst_distance {
if layer == 0 {
self.zero[target_ix as usize].neighbors[worst_ix] = node_ix;
} else {
self.layers[layer - 1][target_ix as usize].neighbors[worst_ix] = node_ix;
}
}
}
}
impl<T, M: ArrayLength<u32>, M0: ArrayLength<u32>, R> Default for HNSW<T, M, M0, R>
where
    R: SeedableRng,
{
    /// Creates an empty index with `Params::new()` and a PRNG built from the
    /// default value of its seed type.
    fn default() -> Self {
        let seed = R::Seed::default();
        Self {
            prng: R::from_seed(seed),
            params: Params::new(),
            zero: Vec::new(),
            features: Vec::new(),
            layers: Vec::new(),
        }
    }
}