use crate::NodeAddress;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::f32;
use super::*;
use super::query_items::{QueryAddress, QuerySingleton};
/// Working state for a k-nearest-neighbour query: two queues of node addresses that
/// still have to be visited, plus the best candidate points seen so far.
#[derive(Debug)]
pub struct KnnQueryHeap {
/// Nodes queued by `push_nodes`; drained by `closest_unvisited_child_covering_address`.
child_heap: BinaryHeap<QueryAddress>,
/// Nodes already handed out by `closest_unvisited_child_covering_address`, kept so that
/// `closest_unvisited_singleton_covering_address` can hand them out a second time.
singleton_heap: BinaryHeap<QueryAddress>,
/// Point indexes that have already been offered to `dist_heap`, so each point is
/// considered at most once.
known_indexes: HashSet<usize>,
/// Pending lower-bound updates keyed by node address. Written by
/// `increase_estimated_distance`, removed the next time that node is popped.
est_min_dist: HashMap<NodeAddress, f32>,
/// Candidate results, trimmed back to at most `k` entries after every insertion.
/// NOTE(review): presumably ordered so `peek` yields the farthest candidate; the
/// ordering is defined in `query_items` — confirm.
dist_heap: BinaryHeap<QuerySingleton>,
/// Maximum number of results retained in `dist_heap`.
k: usize,
/// Base that `push_nodes` raises to a node's scale index; the result is subtracted
/// from the distance to the node's centre to obtain a lower bound.
scale_base: f32,
}
impl RoutingQueryHeap for KnnQueryHeap {
    /// Offers a batch of nodes to the heap.
    ///
    /// `indexes` and `dists` are walked in lockstep (the shorter slice decides how many
    /// pairs are processed). For every `(scale_index, point_index)` address:
    ///
    /// * a lower bound `max(dist - scale_base^scale_index, 0)` is computed, and the node
    ///   is queued on the child heap when that bound is below the current pruning radius;
    /// * the node's centre point is offered to the result heap unless it has been seen
    ///   before, or the heap is already full and its top entry is strictly closer.
    ///
    /// When `parent_address` is given, the largest lower bound seen in the batch is
    /// recorded as a pending estimate for that parent.
    fn push_nodes(
        &mut self,
        indexes: &[NodeAddress],
        dists: &[f32],
        parent_address: Option<NodeAddress>,
    ) {
        // Refreshed only when an entry is evicted from the result heap below.
        let mut prune_radius = self.max_dist();
        let mut largest_child_bound = 0.0;

        for (&(scale_index, point_index), &dist) in indexes.iter().zip(dists) {
            let lower_bound = (dist - self.scale_base.powi(scale_index)).max(0.0);
            largest_child_bound = lower_bound.max(largest_child_bound);

            if lower_bound < prune_radius {
                self.child_heap.push(QueryAddress {
                    address: (scale_index, point_index),
                    dist_to_center: dist,
                    min_dist: lower_bound,
                });
            }

            // `insert` returns true only for indexes that were not yet recorded.
            if self.known_indexes.insert(point_index) {
                let heap_is_full = self.dist_heap.len() >= self.k;
                let top_is_closer = self
                    .dist_heap
                    .peek()
                    .map_or(false, |top| top.dist < dist);
                if !(heap_is_full && top_is_closer) {
                    self.dist_heap.push(QuerySingleton::new(point_index, dist));
                }
            }

            // Trim back to `k` results and tighten the pruning radius accordingly.
            while self.dist_heap.len() > self.k {
                self.dist_heap.pop();
                prune_radius = self.max_dist();
            }
        }

        if let Some(parent) = parent_address {
            self.increase_estimated_distance(parent, largest_child_bound);
        }
    }
}
impl SingletonQueryHeap for KnnQueryHeap {
    /// Offers a batch of individual points to the result heap.
    ///
    /// `indexes` and `dists` are walked in lockstep. A point is skipped when its index
    /// has been seen before; otherwise it is inserted unless the heap already holds `k`
    /// entries and its top entry is strictly closer. After each insertion the heap is
    /// trimmed back to at most `k` entries.
    fn push_outliers(&mut self, indexes: &[usize], dists: &[f32]) {
        for (&point_index, &dist) in indexes.iter().zip(dists) {
            // `insert` returns false when the index was already recorded.
            if !self.known_indexes.insert(point_index) {
                continue;
            }

            let rejected = self.dist_heap.len() >= self.k
                && self.dist_heap.peek().map_or(false, |top| top.dist < dist);
            if !rejected {
                self.dist_heap.push(QuerySingleton::new(point_index, dist));
            }

            while self.dist_heap.len() > self.k {
                self.dist_heap.pop();
            }
        }
    }
}
impl KnnQueryHeap {
    /// Creates an empty heap that retains at most `k` results.
    ///
    /// `scale_base` is the base raised to a node's scale index when `push_nodes`
    /// computes lower bounds.
    pub fn new(k: usize, scale_base: f32) -> KnnQueryHeap {
        KnnQueryHeap {
            child_heap: BinaryHeap::new(),
            singleton_heap: BinaryHeap::new(),
            est_min_dist: HashMap::new(),
            dist_heap: BinaryHeap::new(),
            known_indexes: HashSet::new(),
            k,
            scale_base,
        }
    }

    /// Pops the next node from the child heap and returns `(dist_to_center, address)`.
    ///
    /// If a pending estimate for the popped node exceeds its stored `min_dist`, the node
    /// is re-queued with the larger bound and the next node is tried instead. The pending
    /// estimate is consumed either way. A returned node is also pushed onto the singleton
    /// heap. Returns `None` once the child heap is empty.
    pub fn closest_unvisited_child_covering_address(&mut self) -> Option<(f32, NodeAddress)> {
        while let Some(mut node) = self.child_heap.pop() {
            match self.est_min_dist.remove(&node.address) {
                Some(updated) if updated > node.min_dist => {
                    node.min_dist = updated;
                    self.child_heap.push(node);
                }
                _ => {
                    let visit = (node.dist_to_center, node.address);
                    self.singleton_heap.push(node);
                    return Some(visit);
                }
            }
        }
        None
    }

    /// Pops the next node from the singleton heap and returns `(dist_to_center, address)`.
    ///
    /// If a pending estimate for the popped node exceeds its stored `min_dist`, the node
    /// is re-queued with the larger bound and the next node is tried instead. The pending
    /// estimate is consumed either way. Returns `None` once the singleton heap is empty.
    pub fn closest_unvisited_singleton_covering_address(&mut self) -> Option<(f32, NodeAddress)> {
        while let Some(mut node) = self.singleton_heap.pop() {
            match self.est_min_dist.remove(&node.address) {
                Some(updated) if updated > node.min_dist => {
                    node.min_dist = updated;
                    self.singleton_heap.push(node);
                }
                _ => return Some((node.dist_to_center, node.address)),
            }
        }
        None
    }

    /// Number of results currently held.
    pub fn len(&self) -> usize {
        self.dist_heap.len()
    }

    /// Whether no results are currently held.
    pub fn is_empty(&self) -> bool {
        self.dist_heap.is_empty()
    }

    /// Total number of nodes waiting on the child and singleton heaps.
    pub fn node_len(&self) -> usize {
        self.child_heap.len() + self.singleton_heap.len()
    }

    /// Distance of the entry on top of the result heap, or `f32::MAX` while fewer than
    /// `k` results are held (or the heap is empty).
    pub fn max_dist(&self) -> f32 {
        if self.len() < self.k {
            return f32::MAX;
        }
        self.dist_heap.peek().map_or(f32::MAX, |top| top.dist)
    }

    /// Consumes the heap and returns the results as `(distance, index)` pairs, in the
    /// reverse of the order in which the result heap pops them.
    pub fn unpack(mut self) -> Vec<(f32, usize)> {
        // Size the output from what is actually stored: `k` is only an upper bound and
        // may be far larger than the number of results.
        let mut result = Vec::with_capacity(self.dist_heap.len());
        while let Some(el) = self.dist_heap.pop() {
            result.push((el.dist, el.index));
        }
        result.reverse();
        result
    }

    /// Records `new_estimate` as the pending lower bound for `address` when it is larger
    /// than the value already pending (a missing entry starts at `0.0`).
    pub fn increase_estimated_distance(&mut self, address: NodeAddress, new_estimate: f32) {
        let pending = self.est_min_dist.entry(address).or_insert(0.0);
        if *pending < new_estimate {
            *pending = new_estimate;
        }
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    /// With `k = 4`, eight points at distances 0.1..=0.8 are offered through both entry
    /// points; the unpacked result must list indexes 1, 2, 3, 4 in that order.
    #[test]
    fn unpacking_has_correct_order() {
        let mut heap = KnnQueryHeap::new(4, 2.0);
        heap.push_outliers(&[2, 4, 6, 8], &[0.2, 0.4, 0.6, 0.8]);
        heap.push_nodes(
            &[(0, 1), (0, 3), (1, 5), (1, 7)],
            &[0.1, 0.3, 0.5, 0.7],
            None,
        );

        let order: Vec<usize> = heap.unpack().into_iter().map(|(_, index)| index).collect();
        // Indexed access on purpose: a result shorter than four entries must fail.
        for (slot, expected) in (1..5).enumerate() {
            assert!(order[slot] == expected);
        }
    }

    /// Snapshot of every node still queued on either heap, sorted by the nodes' own
    /// ordering and reduced to `(min_dist, address)` pairs. The heap is left untouched.
    pub fn clone_unvisited_nodes(heap: &KnnQueryHeap) -> Vec<(f32, NodeAddress)> {
        let mut pending: Vec<QueryAddress> = heap
            .child_heap
            .iter()
            .chain(heap.singleton_heap.iter())
            .cloned()
            .collect();
        pending.sort();
        pending
            .into_iter()
            .map(|node| (node.min_dist, node.address))
            .collect()
    }
}