use std::cmp::{max, Ordering, Reverse};
use std::collections::BinaryHeap;
use std::collections::HashSet;
#[cfg(feature = "indicatif")]
use std::sync::atomic::{self, AtomicUsize};
#[cfg(feature = "indicatif")]
use indicatif::ProgressBar;
use ordered_float::OrderedFloat;
use parking_lot::{Mutex, RwLock};
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod types;
pub use types::PointId;
use types::{Candidate, Layer, LayerId, UpperNode, Visited, ZeroNode, INVALID};
/// Collects the optional parameters used to construct an [`Hnsw`] index.
///
/// Every field left as `None` is replaced by a default when the index is built.
pub struct Builder {
/// Candidate list size used by `Hnsw::search` on the zero layer (falls back to 100).
ef_search: Option<usize>,
/// Candidate list size used while inserting points during construction (falls back to 100).
ef_construction: Option<usize>,
/// Neighbor selection heuristic; `None` selects the plain nearest candidates instead.
heuristic: Option<Heuristic>,
/// Parameter handed to `LayerId::random` when assigning points to layers
/// (falls back to `ln(M)`).
ml: Option<f32>,
/// Seed for the random number generator; entropy is used when unset.
seed: Option<u64>,
/// Optional progress bar that is updated while the index is being built.
#[cfg(feature = "indicatif")]
progress: Option<ProgressBar>,
}
impl Builder {
    /// Sets the candidate list size used while inserting points during construction.
    pub fn ef_construction(self, ef_construction: usize) -> Self {
        Self {
            ef_construction: Some(ef_construction),
            ..self
        }
    }

    /// Sets the candidate list size used on the zero layer when searching the built index.
    pub fn ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: Some(ef),
            ..self
        }
    }

    /// Chooses the neighbor selection heuristic.
    ///
    /// Passing `None` disables the heuristic, so the nearest candidates are linked directly.
    pub fn select_heuristic(self, params: Option<Heuristic>) -> Self {
        Self {
            heuristic: params,
            ..self
        }
    }

    /// Sets the parameter that is forwarded to the random layer assignment.
    pub fn ml(self, ml: f32) -> Self {
        Self {
            ml: Some(ml),
            ..self
        }
    }

    /// Fixes the seed of the random number generator used for layer assignment.
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Attaches a progress bar that is updated during construction.
    #[cfg(feature = "indicatif")]
    pub fn progress(self, bar: ProgressBar) -> Self {
        Self {
            progress: Some(bar),
            ..self
        }
    }

    /// Builds the index from `points`.
    ///
    /// Returns the index together with a vector that maps each input position to the
    /// `PointId` the point received inside the index.
    pub fn build<P: Point>(self, points: &[P]) -> (Hnsw<P>, Vec<PointId>) {
        Hnsw::new(points, self)
    }
}
impl Default for Builder {
    /// Leaves every tunable unset except the selection heuristic, which starts out as
    /// `Heuristic::default()`.
    fn default() -> Self {
        let heuristic = Some(Heuristic::default());
        Self {
            heuristic,
            ef_construction: None,
            ef_search: None,
            ml: None,
            seed: None,
            #[cfg(feature = "indicatif")]
            progress: None,
        }
    }
}
/// Switches for the neighbor selection heuristic used while linking points.
#[derive(Copy, Clone, Debug)]
pub struct Heuristic {
    /// Also consider the neighbors of each candidate when selecting links.
    pub extend_candidates: bool,
    /// Fill remaining link slots with candidates that the heuristic discarded.
    pub keep_pruned: bool,
}

impl Default for Heuristic {
    /// Candidate extension is off and pruned candidates are kept.
    fn default() -> Self {
        Self {
            keep_pruned: true,
            extend_candidates: false,
        }
    }
}
/// A built index over points of type `P`, searched with `Hnsw::search`.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> {
/// Candidate list size used on the zero layer when searching.
ef_search: usize,
/// The indexed points, stored in the order given by their `PointId`.
points: Vec<P>,
/// One node per point for the bottom (zero) layer.
zero: Vec<ZeroNode>,
/// Nodes of the layers above the zero layer; layer `l` is stored at index `l - 1`.
layers: Vec<Vec<UpperNode>>,
}
impl<P> Hnsw<P>
where
P: Point,
{
/// Returns a `Builder` with default settings for constructing an index.
pub fn builder() -> Builder {
Builder::default()
}
/// Builds the index from `points` using the settings in `builder`.
///
/// Returns the index and a vector mapping each position in the input `points` to the
/// `PointId` assigned to that point inside the index.
///
/// # Panics
///
/// Panics if `points.len()` is not smaller than `u32::MAX`.
fn new(points: &[P], builder: Builder) -> (Self, Vec<PointId>) {
// Resolve the optional builder settings to concrete values.
let ef_search = builder.ef_search.unwrap_or(100);
let ef_construction = builder.ef_construction.unwrap_or(100);
let ml = builder.ml.unwrap_or_else(|| (M as f32).ln());
let heuristic = builder.heuristic;
// A fixed seed gives a reproducible generator; otherwise it is seeded from entropy.
let mut rng = match builder.seed {
Some(seed) => SmallRng::seed_from_u64(seed),
None => SmallRng::from_entropy(),
};
#[cfg(feature = "indicatif")]
let progress = builder.progress;
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
bar.set_draw_delta(1_000);
bar.set_length(points.len() as u64);
bar.set_message("Build index (preparation)");
}
// Nothing to index: return an empty index and an empty id mapping.
if points.is_empty() {
return (
Self {
ef_search,
zero: Vec::new(),
points: Vec::new(),
layers: Vec::new(),
},
Vec::new(),
);
}
// Point ids are stored as `u32` below, so the input must fit.
assert!(points.len() < u32::MAX as usize);
// Pair every input index with a randomly drawn layer id and sort the pairs in
// descending order, so the nodes with the highest layer id come first.
let mut nodes = (0..points.len())
.map(|i| (LayerId::random(ml, &mut rng), i))
.collect::<Vec<_>>();
nodes.sort_unstable_by_key(|&n| Reverse(n));
// Count how many distinct layer ids were drawn.
let (mut num_layers, mut prev) = (1, nodes[0].0);
for (layer, _) in nodes.iter() {
if *layer != prev {
num_layers += 1;
prev = *layer;
}
}
// Copy the points into sorted order and renumber the layers densely, from
// `num_layers - 1` for the first group down to 0 for the last. `out` records
// which new `PointId` each original input index received.
let mut cur_layer = LayerId(num_layers - 1);
let mut prev_layer = nodes[0].0;
let mut new_points = Vec::with_capacity(points.len());
let mut new_nodes = Vec::with_capacity(points.len());
let mut out = vec![INVALID; points.len()];
for (i, &(layer, idx)) in nodes.iter().enumerate() {
if prev_layer != layer {
cur_layer = LayerId(cur_layer.0 - 1);
prev_layer = layer;
}
let pid = PointId(i as u32);
new_points.push(points[idx].clone());
new_nodes.push((cur_layer, pid));
out[idx] = pid;
}
let (points, nodes) = (new_points, new_nodes);
debug_assert_eq!(nodes.last().unwrap().0, LayerId(0));
debug_assert_eq!(nodes.first().unwrap().0, LayerId(num_layers - 1));
// The first node carries the top layer.
let top = match nodes.first() {
Some((top, _)) => *top,
None => LayerId(0),
};
// Count the nodes assigned to each layer.
let mut sizes = vec![0; top.0 + 1];
for (layer, _) in nodes.iter().copied() {
sizes[layer.0] += 1;
}
// Turn the counts into index ranges over `nodes`, from the top layer down to layer 0.
// `max(start, 1)` leaves out index 0: that point is never inserted through the loop
// below and instead serves as the entry point (`PointId(0)`) of every search.
let mut start = 0;
let mut ranges = Vec::with_capacity(top.0);
for (i, size) in sizes.into_iter().enumerate().rev() {
ranges.push((LayerId(i), max(start, 1)..start + size));
start += size;
}
// One (initially empty) vector per upper layer and one lockable zero-layer node per point.
let mut layers = vec![vec![]; top.0];
let zero = points
.iter()
.map(|_| RwLock::new(ZeroNode::default()))
.collect::<Vec<_>>();
// Pool of reusable search states shared by the parallel workers.
let pool = SearchPool::new(points.len());
#[cfg(feature = "indicatif")]
let done = AtomicUsize::new(0);
// Insert the points layer by layer, starting at the top layer.
for (layer, range) in ranges {
// Upper layers follow `M` links per node, the zero layer `M * 2`.
let num = if layer.0 > 0 { M } else { M * 2 };
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
bar.set_message(&format!("Building index (layer {})", layer.0));
}
// All points of the current layer are inserted in parallel.
nodes[range].into_par_iter().for_each(|(_, pid)| {
let (mut search, mut insertion) = pool.pop();
let point = &points.as_slice()[*pid];
// Every search starts from the entry point.
search.reset();
search.push(PointId(0), point, &points);
for cur in top.descend() {
// Above the layer being built only the single best candidate is tracked.
search.ef = if cur <= layer { ef_construction } else { 1 };
match cur > layer {
true => {
// Layers above the current one are already finished: search the stored
// snapshot and restart from the results found so far.
search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
search.cull();
}
false => {
// The layer being built lives in `zero`; search it and stop descending.
search.search(point, zero.as_slice(), &points, num);
break;
}
}
}
insertion.ef = ef_construction;
// Link the point to the candidates that the search found.
insert(
*pid,
&mut insertion,
&mut search,
&zero,
&points,
&heuristic,
);
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
let value = done.fetch_add(1, atomic::Ordering::Relaxed);
// Only refresh the bar once per 1000 inserted points.
if value % 1000 == 0 {
bar.set_position(value as u64);
}
}
// Hand the search states back for reuse.
pool.push((search, insertion));
});
// After finishing an upper layer, store a snapshot of the current `zero` nodes,
// converted with `UpperNode::from_zero`, as that layer.
if layer.0 > 0 {
let mut upper = Vec::with_capacity(zero.len());
upper.extend(zero.iter().map(|zero| UpperNode::from_zero(&zero.read())));
layers[layer.0 - 1] = upper;
}
}
#[cfg(feature = "indicatif")]
if let Some(bar) = progress {
bar.finish();
}
// Construction is done, so the locks around the zero-layer nodes can be removed.
(
Self {
ef_search,
zero: zero.into_iter().map(|node| node.into_inner()).collect(),
points,
layers,
},
out,
)
}
/// Searches the index for the points nearest to `point`.
///
/// The ids of the points found are written to the front of `out`, nearest first, and
/// the number of ids written is returned. At most `out.len()` ids are written; the
/// return value is 0 when the index is empty. `search` holds the scratch state and is
/// reset at the start of every call.
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
if self.points.is_empty() {
return 0;
}
search.visited.reserve_capacity(self.points.len());
search.reset();
// Every search starts from the entry point.
search.push(PointId(0), point, &self.points);
for cur in LayerId(self.layers.len()).descend() {
// Upper layers track only the single best candidate and follow `M` links per node;
// the zero layer uses the configured `ef_search` and `M * 2` links.
let (ef, num) = match cur.is_zero() {
true => (self.ef_search, M * 2),
false => (1, M),
};
search.ef = ef;
match cur.0 {
0 => search.search(point, self.zero.as_slice(), &self.points, num),
l => search.search(point, self.layers[l - 1].as_slice(), &self.points, num),
}
// Restart from the current results before moving to the next layer.
if !cur.is_zero() {
search.cull();
}
}
let nearest = search.select_simple(out.len());
for (i, candidate) in nearest.iter().enumerate() {
out[i] = candidate.pid;
}
nearest.len()
}
/// Iterates over all indexed points together with their `PointId`, in storage order.
pub fn iter(&self) -> impl Iterator<Item = (PointId, &P)> {
self.points
.iter()
.enumerate()
.map(|(i, p)| (PointId(i as u32), p))
}
}
fn insert<P: Point>(
new: PointId,
insertion: &mut Search,
search: &mut Search,
layer: &[RwLock<ZeroNode>],
points: &[P],
heuristic: &Option<Heuristic>,
) {
let mut node = layer[new].write();
let found = match heuristic {
None => search.select_simple(M * 2),
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
};
debug_assert_eq!(
found.len(),
found.iter().map(|c| c.pid).collect::<HashSet<_>>().len()
);
for (i, candidate) in found.iter().enumerate() {
let &Candidate { distance, pid } = candidate;
if let Some(heuristic) = heuristic {
let found = insertion.add_neighbor_heuristic(
new,
layer.nearest_iter(pid),
layer,
&points[pid],
points,
*heuristic,
);
layer[pid]
.write()
.rewrite(found.iter().map(|candidate| candidate.pid));
node.set(i, pid);
} else {
let old = &points[pid];
let idx = layer[pid]
.read()
.binary_search_by(|third| {
let third = match third {
pid if pid.is_valid() => *pid,
_ => return Ordering::Greater,
};
distance.cmp(&old.distance(&points[third]).into())
})
.unwrap_or_else(|e| e);
layer[pid].write().insert(idx, new);
node.set(i, pid);
}
}
}
/// A lock-protected stack of reusable `(Search, Search)` pairs.
struct SearchPool {
/// Pairs that are currently not in use.
pool: Mutex<Vec<(Search, Search)>>,
/// Capacity passed to `Search::new` whenever a fresh pair has to be created.
len: usize,
}
impl SearchPool {
    /// Creates an empty pool; `len` is the capacity given to every `Search` it creates.
    fn new(len: usize) -> Self {
        let pool = Mutex::new(Vec::new());
        Self { pool, len }
    }

    /// Takes a pair of search states out of the pool, creating a fresh pair when the
    /// pool is empty.
    fn pop(&self) -> (Search, Search) {
        self.pool
            .lock()
            .pop()
            .unwrap_or_else(|| (Search::new(self.len), Search::new(self.len)))
    }

    /// Puts a pair of search states back into the pool for later reuse.
    fn push(&self, item: (Search, Search)) {
        let mut idle = self.pool.lock();
        idle.push(item);
    }
}
/// Reusable scratch state for a search through the index.
pub struct Search {
/// Ids of the points that were already considered during the current search.
visited: Visited,
/// Candidates whose neighbors still have to be explored; wrapped in `Reverse` so the
/// closest candidate is popped first.
candidates: BinaryHeap<Reverse<Candidate>>,
/// The best candidates found so far, kept sorted with the nearest first.
nearest: Vec<Candidate>,
/// Scratch buffer holding the input candidates of the selection heuristic.
working: Vec<Candidate>,
/// Scratch buffer holding the candidates that the selection heuristic rejected.
discarded: Vec<Candidate>,
/// Upper bound on the number of entries kept in `nearest` while searching.
ef: usize,
}
impl Search {
    /// Creates an empty search state whose visited set is created with `capacity`.
    fn new(capacity: usize) -> Self {
        Self {
            visited: Visited::with_capacity(capacity),
            ..Self::default()
        }
    }

    /// Explores `layer` from the queued candidates, collecting the points closest to
    /// `point` in `self.nearest`.
    ///
    /// At most `links` neighbors are followed per expanded candidate. The loop ends when
    /// no candidates are left or the closest remaining candidate is further away than
    /// the furthest result kept so far.
    fn search<L: Layer, P: Point>(&mut self, point: &P, layer: L, points: &[P], links: usize) {
        loop {
            let current = match self.candidates.pop() {
                Some(Reverse(current)) => current,
                None => break,
            };

            let out_of_reach = self
                .nearest
                .last()
                .map_or(false, |furthest| current.distance > furthest.distance);
            if out_of_reach {
                break;
            }

            for pid in layer.nearest_iter(current.pid).take(links) {
                self.push(pid, point, points);
            }

            // Keep the result list at `ef` entries so the furthest entry stays tight and
            // the loop can stop as early as possible.
            self.nearest.truncate(self.ef);
        }
    }

    /// Recomputes the neighbor list of `point` from its `current` neighbors plus `new`.
    ///
    /// The state is reset, all these points are pushed as candidates, and the selection
    /// heuristic picks the resulting neighbors.
    fn add_neighbor_heuristic<L: Layer, P: Point>(
        &mut self,
        new: PointId,
        current: impl Iterator<Item = PointId>,
        layer: L,
        point: &P,
        points: &[P],
        params: Heuristic,
    ) -> &[Candidate] {
        self.reset();
        for pid in std::iter::once(new).chain(current) {
            self.push(pid, point, points);
        }
        self.select_heuristic(point, layer, points, params)
    }

    /// Reduces `self.nearest` to at most `M * 2` neighbors for `point`.
    ///
    /// A candidate is selected only if it is not closer to an already selected point
    /// than it is to `point`; other candidates are set aside and, with
    /// `params.keep_pruned`, used to fill up the remaining slots. With
    /// `params.extend_candidates` the not yet visited neighbors of every candidate
    /// take part in the selection as well.
    fn select_heuristic<L: Layer, P: Point>(
        &mut self,
        point: &P,
        layer: L,
        points: &[P],
        params: Heuristic,
    ) -> &[Candidate] {
        let limit = M * 2;

        // Gather the input candidates in `working`.
        self.working.clear();
        for &seed in &self.nearest {
            self.working.push(seed);
            if !params.extend_candidates {
                continue;
            }

            for hop in layer.nearest_iter(seed.pid) {
                if self.visited.insert(hop) {
                    let distance = OrderedFloat::from(point.distance(&points[hop]));
                    self.working.push(Candidate { distance, pid: hop });
                }
            }
        }

        // The extra candidates were appended out of order.
        if params.extend_candidates {
            self.working.sort_unstable();
        }

        self.nearest.clear();
        self.discarded.clear();
        for candidate in self.working.drain(..) {
            if self.nearest.len() >= limit {
                break;
            }

            // Reject candidates that are closer to a selected point than to `point`.
            let candidate_point = &points[candidate.pid];
            let shadowed = self.nearest.iter().any(|kept| {
                OrderedFloat::from(candidate_point.distance(&points[kept.pid])) < candidate.distance
            });

            if shadowed {
                self.discarded.push(candidate);
            } else {
                self.nearest.push(candidate);
            }
        }

        if params.keep_pruned {
            // Top up with rejected candidates, in the order in which they were rejected.
            let room = limit.saturating_sub(self.nearest.len());
            self.nearest.extend(self.discarded.drain(..).take(room));
        }

        &self.nearest
    }

    /// Considers `pid` as a candidate for the points nearest to `point`.
    ///
    /// Points that were visited before are ignored, as are points that would sort at
    /// position `ef` or later in the result list. Accepted points are inserted in
    /// sorted position and queued for exploration.
    fn push<P: Point>(&mut self, pid: PointId, point: &P, points: &[P]) {
        if !self.visited.insert(pid) {
            return;
        }

        let distance = OrderedFloat::from(point.distance(&points[pid]));
        let new = Candidate { distance, pid };
        match self.nearest.binary_search(&new) {
            // The visited check above rules out pushing the same point twice.
            Ok(_) => unreachable!(),
            Err(idx) if idx < self.ef => {
                self.nearest.insert(idx, new);
                self.candidates.push(Reverse(new));
            }
            Err(_) => {}
        }
    }

    /// Restarts exploration from the current results: the candidate queue and the
    /// visited set are both replaced by the contents of `self.nearest`.
    fn cull(&mut self) {
        self.candidates.clear();
        self.candidates
            .extend(self.nearest.iter().copied().map(Reverse));

        self.visited.clear();
        self.visited.extend(self.nearest.iter().map(|kept| kept.pid));
    }

    /// Clears all collections so the state can be used for a new search; `ef` is kept.
    fn reset(&mut self) {
        // Destructuring makes the compiler flag any field that is added later and
        // not handled here.
        let Self {
            visited,
            candidates,
            nearest,
            working,
            discarded,
            ef: _,
        } = self;

        candidates.clear();
        visited.clear();
        discarded.clear();
        working.clear();
        nearest.clear();
    }

    /// Cuts the result list down to the `num` nearest candidates and returns it.
    fn select_simple(&mut self, num: usize) -> &[Candidate] {
        self.nearest.truncate(num);
        self.nearest.as_slice()
    }
}
impl Default for Search {
    /// Creates an empty search state with a zero-capacity visited set and `ef` set to 1.
    fn default() -> Self {
        let visited = Visited::with_capacity(0);
        Self {
            ef: 1,
            visited,
            candidates: BinaryHeap::default(),
            nearest: Vec::default(),
            working: Vec::default(),
            discarded: Vec::default(),
        }
    }
}
/// A point that can be stored in the index.
///
/// Points are cloned into the index when it is built and shared between threads
/// during construction, hence the `Clone + Sync` bounds.
pub trait Point: Clone + Sync {
/// Returns the distance between `self` and `other`; smaller values rank as nearer.
// NOTE(review): `insert` reuses the distance measured from the new point to a neighbor
// as the distance from that neighbor to the new point, so this is presumably expected
// to be symmetric -- confirm.
fn distance(&self, other: &Self) -> f32;
}
/// Number of links followed per node on the upper layers; the zero layer uses `M * 2`.
/// Its natural logarithm is also the default for the builder's `ml` parameter.
const M: usize = 32;