annoy_rs/types/serving/
mod.rs

1use super::*;
2use crate::internals::priority_queue::PriorityQueue;
3// use std::collections::HashSet;
4// Benchmark does not show visible perf difference, use hashbrown which is cryptographically secure instead
5// use ahash::AHashSet as HashSet;
6use hashbrown::HashSet;
7
8pub trait AnnoyIndexSearchApi {
9    fn get_item_vector(&self, item_index: u64) -> Vec<f32>;
10    fn get_nearest(
11        &self,
12        query_vector: &[f32],
13        n_results: usize,
14        search_k: i32,
15        should_include_distance: bool,
16    ) -> AnnoyIndexSearchResult;
17    fn get_nearest_to_item(
18        &self,
19        item_index: u64,
20        n_results: usize,
21        search_k: i32,
22        should_include_distance: bool,
23    ) -> AnnoyIndexSearchResult;
24}
25
26impl AnnoyIndexSearchApi for AnnoyIndex {
27    fn get_item_vector(&self, item_index: u64) -> Vec<f32> {
28        let node_offset = item_index as usize * self.node_size;
29        let slice = self.get_node_slice_with_offset(node_offset);
30        slice.iter().map(|&a| a).collect()
31    }
32
33    fn get_nearest(
34        &self,
35        query_vector: &[f32],
36        n_results: usize,
37        search_k: i32,
38        should_include_distance: bool,
39    ) -> AnnoyIndexSearchResult {
40        let result_capacity = n_results.min(self.size).max(1);
41        let search_k_fixed = if search_k > 0 {
42            search_k as usize
43        } else {
44            result_capacity * self.roots.len()
45        };
46
47        let mut pq = PriorityQueue::with_capacity(result_capacity, false);
48        for i in 0..self.roots.len() {
49            let id = self.roots[i];
50            pq.push(id as i32, f32::MAX);
51        }
52
53        let mut nearest_neighbors = Vec::with_capacity(search_k_fixed);
54        // let mut nearest_neighbors = HashSet::with_capacity(search_k_fixed);
55        while pq.len() > 0 && nearest_neighbors.len() < search_k_fixed {
56            if let Some((top_node_id_i32, top_node_margin)) = pq.pop() {
57                let top_node_id = top_node_id_i32 as usize;
58                let top_node = self.get_node_from_id(top_node_id);
59                let top_node_header = top_node.header;
60                let top_node_offset = top_node.offset;
61                let n_descendants = top_node_header.get_n_descendant();
62                if n_descendants == 1 && top_node_id < self.size {
63                    nearest_neighbors.push(top_node_id_i32);
64                    // nearest_neighbors.insert(top_node_id_i32);
65                } else if n_descendants <= self.max_descendants {
66                    let children_id_slice =
67                        self.get_descendant_id_slice(top_node_offset, n_descendants as usize);
68                    for &child_id in children_id_slice {
69                        nearest_neighbors.push(child_id);
70                        // nearest_neighbors.insert(child_id);
71                    }
72                } else {
73                    let v = self.get_node_slice_with_offset(top_node_offset);
74                    let margin = self.get_margin(v, query_vector, top_node_offset);
75                    let children_id = top_node_header.get_children_id_slice();
76                    // NOTE: Hamming has different logic to calculate margin
77                    pq.push(children_id[1], top_node_margin.min(margin));
78                    pq.push(children_id[0], top_node_margin.min(-margin));
79                }
80            }
81        }
82        // let mut nearest_neighbors: Vec<i32> = nearest_neighbors.into_iter().collect();
83        nearest_neighbors.sort();
84        let mut sorted_nns = PriorityQueue::with_capacity(nearest_neighbors.len(), true);
85        let mut nn_id_last = -1;
86        for nn_id in nearest_neighbors {
87            if nn_id == nn_id_last {
88                continue;
89            }
90            nn_id_last = nn_id;
91            let node = self.get_node_from_id(nn_id as usize);
92            let n_descendants = node.header.get_n_descendant();
93            if n_descendants != 1 {
94                continue;
95            }
96
97            let s = self.get_node_slice_with_offset(nn_id as usize * self.node_size);
98            sorted_nns.push(nn_id, self.get_distance_no_norm(s, query_vector));
99        }
100
101        let final_result_capcity = n_results.min(sorted_nns.len());
102        let mut id_list = Vec::with_capacity(final_result_capcity);
103        let mut distance_list = Vec::with_capacity(if should_include_distance {
104            final_result_capcity
105        } else {
106            0
107        });
108        for _i in 0..final_result_capcity {
109            let nn = &sorted_nns.pop().unwrap();
110            id_list.push(nn.0 as u64);
111            if should_include_distance {
112                distance_list.push(self.normalized_distance(nn.1));
113            }
114        }
115        return AnnoyIndexSearchResult {
116            count: final_result_capcity,
117            is_distance_included: should_include_distance,
118            id_list: id_list,
119            distance_list: distance_list,
120        };
121    }
122
123    fn get_nearest_to_item(
124        &self,
125        item_index: u64,
126        n_results: usize,
127        search_k: i32,
128        should_include_distance: bool,
129    ) -> AnnoyIndexSearchResult {
130        let item_vector = self.get_item_vector(item_index);
131        self.get_nearest(
132            item_vector.as_slice(),
133            n_results,
134            search_k,
135            should_include_distance,
136        )
137    }
138}