use crate::quantizer::{IdentityQuantizer, Quantizer, QueryEvaluator};
use crate::topk_selectors::topk_heap::TopkHeap;
use crate::topk_selectors::OnlineTopKSelector;
use crate::{hnsw_utils::*, DistanceType};
use crate::{Dataset, Float, GrowableDataset};
use crate::{DotProduct, EuclideanDistance};
use config_hnsw::ConfigHnsw;
use hnsw_builder::HnswBuilder;
use level::Level;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
use std::marker::PhantomData;
/// Layered (HNSW-style) proximity graph over an encoded dataset.
///
/// Vectors are stored in the order given by `id_permutation`: internal id `i`
/// (a position in `dataset`) corresponds to id `id_permutation[i]` of the source
/// dataset the index was built from. `GraphIndex::search` translates results back
/// to those source ids.
///
/// NOTE(review): this struct is (de)serialized through serde derives; reordering or
/// retyping fields likely breaks previously serialized indexes when a positional
/// format (e.g. bincode) is used — confirm before changing the field list.
#[derive(Serialize, Deserialize)]
pub struct GraphIndex<D, Q>
where
D: Dataset<Q>,
Q: Quantizer<DatasetType = D>,
{
// Graph layers. `levels[0]` is the base layer used for the final ef-bounded search;
// the layers above it are descended greedily, top-most first, during a search.
levels: Box<[Level]>,
// Encoded vectors, stored in internal-id (permuted) order.
dataset: D,
// Value of `config.get_num_neighbors_per_vec()` recorded at build time; not read by
// the search code in this file section.
num_neighbors_per_vec: usize,
// Maps an internal id (position in `dataset`) to the id in the source dataset.
id_permutation: Box<[usize]>,
// Internal id of the vector every search starts from.
entry_vec: usize,
// Ties the quantizer type `Q` to the struct without storing a value of it.
_phantom: PhantomData<Q>,
}
impl<D, Q> GraphIndex<D, Q>
where
    D: Dataset<Q> + GrowableDataset<Q>,
    Q: Quantizer<DatasetType = D>,
{
    /// Builds an index over `source_dataset`, encoding its vectors with `quantizer`.
    ///
    /// The graph (levels, id permutation and entry vector) is computed by
    /// `HnswBuilder` from the source vectors. The vectors are then pushed into a
    /// fresh `D` in permutation order, so that internal slot `i` holds source vector
    /// `id_permutation[i]`.
    ///
    /// # Panics
    /// Panics if the builder's permutation has fewer than `source_dataset.len()`
    /// entries.
    pub fn from_dataset<'a, SD, IQ>(
        source_dataset: &'a SD,
        config: &ConfigHnsw,
        quantizer: Q,
    ) -> Self
    where
        SD: Dataset<IQ> + Sync,
        IQ: IdentityQuantizer<DatasetType = SD, T: Float> + Sync + 'a,
        <IQ as Quantizer>::Evaluator<'a>:
            QueryEvaluator<'a, QueryType = <SD as Dataset<IQ>>::DataType<'a>>,
        D: GrowableDataset<Q, InputDataType<'a> = <SD as Dataset<IQ>>::DataType<'a>>,
        <Q as Quantizer>::InputItem: 'a,
    {
        let mut builder = HnswBuilder::new(config.get_num_neighbors_per_vec(), source_dataset);
        let (graph_levels, permutation, entry) = builder.compute_graph(config);

        // Re-encode the source vectors in graph order: slot `i` of the encoded
        // dataset receives source vector `permutation[i]`.
        let mut encoded = D::new(quantizer, source_dataset.dim());
        for slot in 0..source_dataset.len() {
            let original_id = permutation[slot];
            let source_vec = source_dataset.get(original_id);
            encoded.push(&source_vec);
        }

        Self::new(
            graph_levels,
            encoded,
            config.get_num_neighbors_per_vec(),
            permutation,
            entry,
        )
    }

    /// Assembles an index from already-built parts, freezing the growable `Vec`s
    /// into boxed slices.
    fn new(
        levels: Vec<Level>,
        dataset: D,
        num_neighbors_per_vec: usize,
        id_permutation: Vec<usize>,
        entry_vec: usize,
    ) -> Self {
        let levels = levels.into_boxed_slice();
        let id_permutation = id_permutation.into_boxed_slice();
        Self {
            levels,
            dataset,
            num_neighbors_per_vec,
            id_permutation,
            entry_vec,
            _phantom: PhantomData,
        }
    }
}
impl<D, Q> GraphIndex<D, Q>
where
    D: Dataset<Q> + Sync,
    Q: Quantizer<InputItem: Float, DatasetType = D> + Sync,
{
    /// Dimensionality of the indexed vectors, as reported by the encoded dataset.
    pub fn dim(&self) -> usize {
        self.dataset.dim()
    }

    /// Returns the approximate `k` nearest neighbors of `query`.
    ///
    /// Each result is `(distance, id)`, where `id` refers to the source dataset the
    /// index was built from: internal ids are translated through `id_permutation`.
    /// When the quantizer's distance is [`DistanceType::DotProduct`], the sign of the
    /// score is flipped before it is returned.
    ///
    /// `config` supplies `ef_search` (see [`Self::find_k_nearest_neighbors`]).
    pub fn search<'a, QD, QQ>(
        &self,
        query: QD::DataType<'a>,
        k: usize,
        config: &ConfigHnsw,
    ) -> Vec<(f32, usize)>
    where
        QD: Dataset<QQ> + Sync,
        QQ: Quantizer<DatasetType = QD> + Sync,
        <Q as Quantizer>::Evaluator<'a>:
            QueryEvaluator<'a, QueryType = <QD as Dataset<QQ>>::DataType<'a>>,
        <Q as Quantizer>::InputItem: EuclideanDistance<<Q as Quantizer>::InputItem>
            + DotProduct<<Q as Quantizer>::InputItem>,
        <Q as Quantizer>::InputItem: 'a,
    {
        let query_topk = self.find_k_nearest_neighbors(query, k, config);
        // NOTE(review): presumably dot-product scores are negated internally so that
        // "smaller is better" holds for the search; flip them back for callers.
        let negate = self.dataset.quantizer().distance() == DistanceType::DotProduct;
        // Single pass: translate internal ids to source ids and fix the sign.
        query_topk
            .into_iter()
            .map(|(dis, internal_id)| {
                let dis = if negate { -dis } else { dis };
                (dis, self.id_permutation[internal_id])
            })
            .collect()
    }

    /// Runs the layered search and returns the `k` best `(distance, internal_id)`
    /// pairs, where ids are positions in the encoded dataset (not source ids).
    ///
    /// Starting at `entry_vec`, every level above the base one is descended greedily
    /// (top-most first). The base level (`levels[0]`) is then explored with a
    /// candidate budget of `ef = max(config.get_ef_search(), k)`.
    ///
    /// # Panics
    /// Panics if the index has no levels.
    pub fn find_k_nearest_neighbors<'a>(
        &self,
        query_vec: <Q::Evaluator<'a> as QueryEvaluator<'a>>::QueryType,
        k: usize,
        config: &ConfigHnsw,
    ) -> Vec<(f32, usize)>
    where
        <Q as Quantizer>::InputItem: EuclideanDistance<<Q as Quantizer>::InputItem>
            + DotProduct<<Q as Quantizer>::InputItem>,
    {
        let query_evaluator = self.dataset.query_evaluator(query_vec);

        // Greedy descent through the upper levels, top-most first.
        let mut nearest_vec = self.entry_vec;
        let mut dis_nearest_vec = query_evaluator.compute_distance(&self.dataset, nearest_vec);
        for level in self.levels.iter().skip(1).rev() {
            level.greedy_update_nearest(
                &self.dataset,
                &query_evaluator,
                &mut nearest_vec,
                &mut dis_nearest_vec,
            );
        }

        // Best-first search on the base level with an `ef`-sized candidate budget.
        let ef = std::cmp::max(config.get_ef_search(), k);
        let mut top_candidates = self.search_from_candidates_unbounded(
            Node(dis_nearest_vec, nearest_vec),
            &query_evaluator,
            ef,
            &self.levels[0],
        );

        // `top_candidates` pops its furthest element first. Drop the surplus beyond
        // `k`, then hand the survivors to the top-k selector for the final ordering.
        while top_candidates.len() > k {
            top_candidates.pop();
        }
        let mut topk_heap = TopkHeap::new(k);
        while let Some(node) = top_candidates.pop() {
            topk_heap.push_with_id(node.distance(), node.id_vec());
        }
        topk_heap.topk()
    }

    /// Best-first exploration of `level` starting from `starting_node`.
    ///
    /// Repeatedly expands the closest unexpanded candidate until it is further than
    /// the worst result kept so far. `ef` is forwarded to `add_neighbor_to_heaps`,
    /// which presumably caps the result heap at `ef` entries (NOTE(review): confirm).
    /// Returns the result heap, whose top is the furthest kept node.
    fn search_from_candidates_unbounded<'a, E>(
        &self,
        starting_node: Node,
        query_evaluator: &E,
        ef: usize,
        level: &Level,
    ) -> BinaryHeap<Node>
    where
        E: QueryEvaluator<'a, Q = Q>,
        Q: Quantizer<DatasetType = D>,
        <Q as Quantizer>::InputItem: EuclideanDistance<<Q as Quantizer>::InputItem>
            + DotProduct<<Q as Quantizer>::InputItem>,
    {
        // Results kept so far; max-heap, so the furthest result is on top.
        let mut top_candidates: BinaryHeap<Node> = BinaryHeap::new();
        // Frontier still to expand; `Reverse` makes it a min-heap (closest on top).
        let mut candidates: BinaryHeap<Reverse<Node>> = BinaryHeap::new();
        let mut visited_table = HashSet::with_capacity(ef * 32);

        top_candidates.push(starting_node);
        candidates.push(Reverse(starting_node));
        visited_table.insert(starting_node.id_vec());

        while let Some(&Reverse(closest)) = candidates.peek() {
            // NOTE(review): assumes `add_neighbor_to_heaps` never empties
            // `top_candidates`, which starts out holding `starting_node`.
            let worst_kept = top_candidates.peek().unwrap().distance();
            if closest.distance() > worst_kept {
                // Nothing left on the frontier can improve the result set.
                break;
            }
            candidates.pop();

            let neighbors = level.get_neighbors_from_id(closest.id_vec());
            self.process_neighbors(
                neighbors,
                &mut visited_table,
                query_evaluator,
                |dis_neigh, neighbor| {
                    add_neighbor_to_heaps(
                        &mut candidates,
                        &mut top_candidates,
                        Node(dis_neigh, neighbor),
                        ef,
                    );
                },
            )
        }
        top_candidates
    }

    /// Computes query distances for the not-yet-visited ids in `neighbors`.
    ///
    /// Each such id is marked visited, and its `(distance, id)` is reported to
    /// `add_distances_fn`. Ids are batched four at a time through
    /// `compute_four_distances`. A trailing partial batch falls back to one
    /// `compute_distance` call per id.
    fn process_neighbors<'a, E, F>(
        &self,
        neighbors: &[usize],
        visited_table: &mut HashSet<usize>,
        query_evaluator: &E,
        mut add_distances_fn: F,
    ) where
        E: QueryEvaluator<'a, Q = Q>,
        F: FnMut(f32, usize),
    {
        // Stack buffer instead of a `Vec`: this runs once per expanded node, so a
        // heap allocation per call is avoidable overhead in the search hot path.
        let mut batch = [0usize; 4];
        let mut len = 0;
        for &neighbor in neighbors {
            // `insert` returns `false` for an already-present id, which saves the
            // separate `contains` lookup.
            if !visited_table.insert(neighbor) {
                continue;
            }
            batch[len] = neighbor;
            len += 1;
            if len == batch.len() {
                let distances =
                    query_evaluator.compute_four_distances(&self.dataset, batch.iter().copied());
                for (dis_neigh, &id) in distances.zip(batch.iter()) {
                    add_distances_fn(dis_neigh, id);
                }
                len = 0;
            }
        }
        for &id in &batch[..len] {
            let distance_neighbor: f32 = query_evaluator.compute_distance(&self.dataset, id);
            add_distances_fn(distance_neighbor, id);
        }
    }

    /// Prints a memory breakdown of the index to stdout and returns the total in bytes.
    ///
    /// "Links structure" sums the graph levels, the id permutation and the two
    /// `usize` scalar fields (`num_neighbors_per_vec`, `entry_vec`).
    pub fn print_space_usage_byte(&self) -> usize {
        println!("Space Usage:");
        let forward: usize = self.dataset.get_space_usage_bytes();
        println!("\tForward Index: {} Bytes", forward);

        let levels: usize = self
            .levels
            .iter()
            .map(|level| level.get_space_usage_bytes())
            .sum();
        let permutation: usize = self.id_permutation.len() * std::mem::size_of::<usize>();
        // `num_neighbors_per_vec` and `entry_vec`.
        let scalars: usize = 2 * std::mem::size_of::<usize>();
        let links = levels + permutation + scalars;
        let total = forward + links;

        println!("\tLinks structure: {} Bytes", links);
        println!("\tTotal: {} Bytes", total);
        total
    }
}