use crate::*;
use rand_core::{RngCore, SeedableRng};
use rand_pcg::Pcg64;
use rustc_hash::FxHasher;
use std::collections::HashSet;
/// Hierarchical navigable small-world graph over features of type `T`.
///
/// * `M` — neighbor slots per node on each upper layer; also the base of the logarithm used
///   when drawing an inserted item's level.
/// * `M0` — neighbor slots per node on the zero layer.
/// * `R` — PRNG used to draw each inserted item's level.
#[derive(Clone)]
pub struct HNSW<T, M = typenum::U12, M0 = typenum::U24, R = Pcg64>
where
    M: ArrayLength<u32>,
    M0: ArrayLength<u32>,
{
    /// Level 0: one node per inserted item, in insertion order.
    zero: Vec<ZeroNode<M0>>,
    /// Item features, indexed the same way as `zero`.
    features: Vec<T>,
    /// Upper levels: `layers[i]` holds level `i + 1`.
    layers: Vec<Vec<Node<M>>>,
    /// Source of randomness for level assignment.
    prng: R,
    /// Construction/search parameters (e.g. `ef_construction`).
    params: Params,
}
/// A node on the zero layer (level 0).
///
/// `neighbors` holds zero-layer indices; unused slots contain the sentinel `!0`.
#[derive(Clone)]
struct ZeroNode<N: ArrayLength<u32>> {
    neighbors: GenericArray<u32, N>,
}

impl<N: ArrayLength<u32>> ZeroNode<N> {
    /// Iterates over the populated neighbor slots, stopping at the first `!0` sentinel.
    fn neighbors<'a>(&'a self) -> impl Iterator<Item = u32> + 'a {
        self.neighbors.iter().take_while(|&&slot| slot != !0).cloned()
    }
}
/// A node on an upper layer (level `>= 1`).
#[derive(Clone, Debug)]
struct Node<N: ArrayLength<u32>> {
    /// Index of the same item on the zero layer; also indexes `HNSW::features`.
    zero_node: u32,
    /// Index of the same item's node on the level directly below (the zero layer for level 1).
    next_node: u32,
    /// Indices of neighbor nodes on this same level; unused slots contain `!0`.
    neighbors: GenericArray<u32, N>,
}

impl<N: ArrayLength<u32>> Node<N> {
    /// Iterates over the populated neighbor slots, stopping at the first `!0` sentinel.
    fn neighbors<'a>(&'a self) -> impl Iterator<Item = u32> + 'a {
        self.neighbors.iter().take_while(|&&slot| slot != !0).cloned()
    }
}
/// Reusable scratch state for [`HNSW::insert`] and [`HNSW::nearest`].
///
/// It is reset at the start of each search, so one instance can be passed to repeated calls.
#[derive(Clone, Debug, Default)]
pub struct Searcher {
    candidates: Candidates,
    nearest: FixedCandidates,
    /// Items already scored in the current search, keyed by zero-layer index.
    seen: HashSet<u32, std::hash::BuildHasherDefault<FxHasher>>,
}

impl Searcher {
    /// Creates an empty searcher.
    pub fn new() -> Self {
        Self::default()
    }

    /// Empties the visited set and both candidate pools.
    fn clear(&mut self) {
        self.seen.clear();
        self.nearest.clear();
        self.candidates.clear();
    }
}
// Constructors that build the PRNG through the `Default` impl (fixed default seed).
impl<T, M: ArrayLength<u32>, M0: ArrayLength<u32>, R> HNSW<T, M, M0, R>
where
    R: RngCore + SeedableRng,
{
    /// Creates an empty index; equivalent to `HNSW::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates an empty index that uses `params`; everything else comes from `HNSW::default()`.
    pub fn new_params(params: Params) -> Self {
        let mut hnsw = Self::default();
        hnsw.params = params;
        hnsw
    }
}
// Operations that need a distance metric on `T` and a source of randomness `R`.
//
// Layout: level 0 lives in `self.zero`; level `l >= 1` lives in `self.layers[l - 1]`.
// Indices stored in a layer's `neighbors` refer to nodes of that same layer, and the
// sentinel `!0` marks an unused neighbor slot.
impl<T, M: ArrayLength<u32>, M0: ArrayLength<u32>, R> HNSW<T, M, M0, R>
where
    R: RngCore,
    T: Distance,
{
    /// Creates an empty index that draws insertion levels from `prng` and uses
    /// `Params::default()`.
    pub fn new_prng(prng: R) -> Self {
        Self {
            zero: vec![],
            features: vec![],
            layers: vec![],
            prng,
            params: Default::default(),
        }
    }

    /// Creates an empty index with explicit `params` that draws insertion levels from `prng`.
    pub fn new_params_and_prng(params: Params, prng: R) -> Self {
        Self {
            zero: vec![],
            features: vec![],
            layers: vec![],
            prng,
            params,
        }
    }

    /// Inserts `q` into the index and returns its item index.
    ///
    /// The returned index addresses `q` on the zero layer and in `feature`. `searcher` is
    /// scratch state whose contents may be overwritten.
    ///
    /// The item is assigned a random level `level` and linked into level 0 (`self.zero`) and
    /// every upper level `1..=level` (`self.layers[..level]`). New single-node layers are added
    /// when `level` exceeds the current height; a node on a new top layer becomes the entry
    /// point used by later searches.
    pub fn insert(&mut self, q: T, searcher: &mut Searcher) -> u32 {
        let level = self.random_level();
        if self.is_empty() {
            // First item: it is the only node on every level up to `level`. Each new layer's
            // node 0 points down at node 0 of the layer below (or of the zero layer).
            self.zero.push(ZeroNode {
                neighbors: std::iter::repeat(!0).collect(),
            });
            self.features.push(q);
            while self.layers.len() < level {
                let node = Node {
                    zero_node: 0,
                    next_node: 0,
                    neighbors: std::iter::repeat(!0).collect(),
                };
                self.layers.push(vec![node]);
            }
            return 0;
        }
        // If the item goes on the current top layer (or higher), the search starts on a layer
        // where the item will be linked, so it needs the full construction pool immediately;
        // otherwise a greedy (pool of 1) descent is enough.
        self.initialize_searcher(
            &q,
            searcher,
            if level >= self.layers.len() {
                self.params.ef_construction
            } else {
                1
            },
        );
        // Greedy descent through the layers above the item's level. When stepping down into the
        // highest level the item will occupy (`ix == level`), widen the pool to `ef_construction`.
        for ix in (level..self.layers.len()).rev() {
            self.search_layer(&q, searcher, &self.layers[ix]);
            self.lower_search(
                &self.layers[ix],
                searcher,
                if ix == level {
                    self.params.ef_construction
                } else {
                    1
                },
            );
        }
        // Link the item into every already-existing upper layer at or below its level.
        for ix in (0..std::cmp::min(level, self.layers.len())).rev() {
            self.search_layer(&q, searcher, &self.layers[ix]);
            self.create_node(&q, &searcher.nearest, ix + 1);
            self.lower_search(&self.layers[ix], searcher, self.params.ef_construction);
        }
        // Finally link it into the zero layer.
        self.search_zero_layer(&q, searcher);
        self.create_node(&q, &searcher.nearest, 0);
        self.features.push(q);
        let zero_node = (self.zero.len() - 1) as u32;
        // Grow new layers above the old top; each holds only this item and points down at the
        // item's node on the layer below.
        while self.layers.len() < level {
            let node = Node {
                zero_node,
                next_node: self
                    .layers
                    .last()
                    .map(|l| (l.len() - 1) as u32)
                    .unwrap_or(zero_node),
                neighbors: std::iter::repeat(!0).collect(),
            };
            self.layers.push(vec![node]);
        }
        zero_node
    }

    /// Searches for the approximate nearest neighbors of `q`.
    ///
    /// Upper layers are descended greedily; `ef` caps the candidate pool on the zero layer.
    /// Results (zero-layer item indices) are written into `dest` through
    /// `FixedCandidates::fill_slice`, whose returned slice is passed through (presumably the
    /// filled prefix of `dest` — confirm). Returns an empty slice when the index is empty.
    /// `searcher` is scratch state and is overwritten.
    pub fn nearest<'a>(
        &self,
        q: &T,
        ef: usize,
        searcher: &mut Searcher,
        dest: &'a mut [u32],
    ) -> &'a mut [u32] {
        if self.features.is_empty() {
            return &mut [];
        }
        self.initialize_searcher(q, searcher, if self.layers.is_empty() { ef } else { 1 });
        for (ix, layer) in self.layers.iter().enumerate().rev() {
            self.search_layer(q, searcher, layer);
            // Widen the pool to `ef` only when stepping down into the zero layer.
            self.lower_search(layer, searcher, if ix == 0 { ef } else { 1 });
        }
        self.search_zero_layer(q, searcher);
        searcher.nearest.fill_slice(dest)
    }

    /// Returns the feature stored for item index `item`.
    ///
    /// # Panics
    /// Panics if `item` is not a valid item index.
    pub fn feature(&self, item: u32) -> &T {
        &self.features[item as usize]
    }

    /// Returns the number of inserted items.
    pub fn len(&self) -> usize {
        self.zero.len()
    }

    /// Returns `true` if no items have been inserted.
    pub fn is_empty(&self) -> bool {
        self.zero.is_empty()
    }

    /// Expands the candidate pool over one upper `layer` until no candidates remain.
    ///
    /// Each popped candidate's neighbors are scored against `q`. A neighbor is queued for
    /// further expansion only when `nearest.push` returns `true`. `seen` is keyed by
    /// zero-layer index, which identifies an item on every layer, so the same set is reused
    /// for the whole descent.
    fn search_layer(&self, q: &T, searcher: &mut Searcher, layer: &[Node<M>]) {
        while let Some((_, node)) = searcher.candidates.pop() {
            for neighbor in layer[node as usize].neighbors() {
                let neighbor_zero = layer[neighbor as usize].zero_node;
                if searcher.seen.insert(neighbor_zero) {
                    let distance = T::distance(q, &self.features[neighbor_zero as usize]);
                    if searcher.nearest.push(distance, neighbor) {
                        searcher.candidates.push(distance, neighbor);
                    }
                }
            }
        }
    }

    /// Same expansion as `search_layer`, but over the zero layer, whose node indices are
    /// item indices.
    fn search_zero_layer(&self, q: &T, searcher: &mut Searcher) {
        while let Some((_, node)) = searcher.candidates.pop() {
            for neighbor in self.zero[node as usize].neighbors() {
                if searcher.seen.insert(neighbor) {
                    let distance = T::distance(q, &self.features[neighbor as usize]);
                    if searcher.nearest.push(distance, neighbor) {
                        searcher.candidates.push(distance, neighbor);
                    }
                }
            }
        }
    }

    /// Moves the search from `layer` down to the layer directly below it.
    ///
    /// Keeps only the entry popped from `nearest` (assumed to be the closest — confirm against
    /// `FixedCandidates::pop`), follows its `next_node` link, and reseeds both pools with that
    /// node, capping `nearest` at `m`. The stored distance is reused because `next_node` points
    /// at the same item's node on the lower layer.
    fn lower_search(&self, layer: &[Node<M>], searcher: &mut Searcher, m: usize) {
        searcher.candidates.clear();
        let (distance, node) = searcher.nearest.pop().unwrap();
        searcher.nearest.clear();
        searcher.nearest.set_cap(m);
        let new_node = layer[node as usize].next_node;
        searcher.nearest.push(distance, new_node);
        searcher.candidates.push(distance, new_node);
    }

    /// Resets `searcher` and seeds both pools with the entry point.
    ///
    /// The entry point is node 0 of the top upper layer, or item 0 of the zero layer when
    /// there are no upper layers. `cap` bounds the `nearest` pool.
    fn initialize_searcher(&self, q: &T, searcher: &mut Searcher, cap: usize) {
        searcher.clear();
        searcher.nearest.set_cap(cap);
        let entry_distance = T::distance(q, self.entry_feature());
        searcher.candidates.push(entry_distance, 0);
        searcher.nearest.push(entry_distance, 0);
        // `seen` is keyed by zero-layer index, so record the entry item's zero-layer index.
        searcher.seen.insert(
            self.layers
                .last()
                .map(|layer| layer[0].zero_node)
                .unwrap_or(0),
        );
    }

    /// Returns the feature of the entry point (node 0 of the top layer, else item 0).
    ///
    /// # Panics
    /// Panics if the index is empty.
    fn entry_feature(&self) -> &T {
        if let Some(last_layer) = self.layers.last() {
            &self.features[last_layer[0].zero_node as usize]
        } else {
            &self.features[0]
        }
    }

    /// Draws the level for a new item as `floor(-ln(U) / ln(M))` with `U` uniform in `(0, 1)`.
    fn random_level(&mut self) -> usize {
        use rand_distr::{Distribution, Standard};
        // `Standard` samples `f32` from the half-open range `[0, 1)`. A draw of exactly `0.0`
        // makes `ln` return `-inf`, which would cast to `usize::MAX` and send `insert` into an
        // unbounded layer-allocation loop. Redraw until the sample is strictly positive; every
        // non-zero draw is used unchanged.
        let uniform: f32 = loop {
            let sample: f32 = Standard.sample(&mut self.prng);
            if sample > 0.0 {
                break sample;
            }
        };
        (-uniform.ln() * (M::to_usize() as f32).ln().recip()) as usize
    }

    /// Creates the node for the item being inserted (feature `q`) on level `layer` and
    /// appends it to that level.
    ///
    /// The node's neighbor list is filled from `nearest` (slots `fill_slice` does not write
    /// keep the `!0` sentinel). Each of those neighbors is then offered a back-link through
    /// `add_neighbor`. On an upper level, `zero_node` and `next_node` are set to the indices the
    /// item's nodes will receive on the zero layer and on the level below. This is valid because
    /// `insert` creates nodes top-down and pushes the lower ones afterwards.
    fn create_node(&mut self, q: &T, nearest: &FixedCandidates, layer: usize) {
        if layer == 0 {
            let new_index = self.zero.len();
            let mut neighbors: GenericArray<u32, M0> = std::iter::repeat(!0).collect();
            nearest.fill_slice(&mut neighbors);
            let node = ZeroNode { neighbors };
            for neighbor in node.neighbors() {
                self.add_neighbor(q, new_index as u32, neighbor, layer);
            }
            self.zero.push(node);
        } else {
            let new_index = self.layers[layer - 1].len();
            let mut neighbors: GenericArray<u32, M> = std::iter::repeat(!0).collect();
            nearest.fill_slice(&mut neighbors);
            let node = Node {
                zero_node: self.zero.len() as u32,
                next_node: if layer == 1 {
                    self.zero.len()
                } else {
                    self.layers[layer - 2].len()
                } as u32,
                neighbors,
            };
            for neighbor in node.neighbors() {
                self.add_neighbor(q, new_index as u32, neighbor, layer);
            }
            self.layers[layer - 1].push(node);
        }
    }

    /// Offers `node_ix` (whose feature is `q`) as a neighbor of `target_ix` on level `layer`.
    ///
    /// Picks the target's worst slot: the first empty (`!0`) slot if there is one, otherwise
    /// the neighbor farthest from the target. That slot is overwritten with `node_ix` when `q`
    /// is closer to the target than the slot's current occupant.
    fn add_neighbor(&mut self, q: &T, node_ix: u32, target_ix: u32, layer: usize) {
        let (target_feature, target_neighbors) = if layer == 0 {
            (
                &self.features[target_ix as usize],
                &self.zero[target_ix as usize].neighbors[..],
            )
        } else {
            let target = &self.layers[layer - 1][target_ix as usize];
            (
                &self.features[target.zero_node as usize],
                &target.neighbors[..],
            )
        };
        let (worst_ix, worst_distance) = target_neighbors
            .iter()
            .enumerate()
            .map(|(ix, &n)| {
                // Empty slots rank as infinitely far so they are filled first.
                let distance = if n == !0 {
                    std::u32::MAX
                } else {
                    T::distance(
                        target_feature,
                        &self.features[if layer == 0 {
                            n as usize
                        } else {
                            self.layers[layer - 1][n as usize].zero_node as usize
                        }],
                    )
                };
                (ix, distance)
            })
            // `min_by_key` on `!distance` selects the largest distance and, on ties, the
            // EARLIEST slot. That keeps occupied slots contiguous at the front, which
            // `neighbors()` relies on because it stops at the first `!0`. `max_by_key` would
            // return the last tie instead.
            .min_by_key(|&(_, distance)| !distance)
            .unwrap();
        if T::distance(q, target_feature) < worst_distance {
            if layer == 0 {
                self.zero[target_ix as usize].neighbors[worst_ix] = node_ix;
            } else {
                self.layers[layer - 1][target_ix as usize].neighbors[worst_ix] = node_ix;
            }
        }
    }
}
// An empty index using `Params::new()` and a PRNG built from `R::Seed::default()`, so every
// default-constructed index is seeded identically.
impl<T, M: ArrayLength<u32>, M0: ArrayLength<u32>, R> Default for HNSW<T, M, M0, R>
where
    R: SeedableRng,
{
    fn default() -> Self {
        let prng = R::from_seed(Default::default());
        let params = Params::new();
        Self {
            zero: Vec::new(),
            features: Vec::new(),
            layers: Vec::new(),
            prng,
            params,
        }
    }
}