1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use super::*;
use crate::internals::priority_queue::PriorityQueue;
use hashbrown::HashSet;

/// Read-only query operations over a loaded Annoy index.
pub trait AnnoyIndexSearchApi {
    /// Returns an owned copy of the vector stored for item `item_index`.
    fn get_item_vector(&self, item_index: i64) -> Vec<f32>;

    /// Approximate nearest-neighbour search for an arbitrary query vector.
    ///
    /// * `query_vector` – the vector to search around.
    /// * `n_results` – maximum number of ids to return.
    /// * `search_k` – candidate budget for the tree walk; in the implementation in
    ///   this module a value `<= 0` selects a default derived from `n_results`
    ///   and the number of trees.
    /// * `should_include_distance` – when `true`, distances are reported
    ///   alongside the ids.
    fn get_nearest(
        &self,
        query_vector: &[f32],
        n_results: usize,
        search_k: i32,
        should_include_distance: bool,
    ) -> AnnoyIndexSearchResult;

    /// Like [`AnnoyIndexSearchApi::get_nearest`], but uses the stored vector of
    /// item `item_index` as the query.
    fn get_nearest_to_item(
        &self,
        item_index: i64,
        n_results: usize,
        search_k: i32,
        should_include_distance: bool,
    ) -> AnnoyIndexSearchResult;
}

impl AnnoyIndexSearchApi for AnnoyIndex {
    /// Returns an owned copy of the vector stored for item `item_index`.
    ///
    /// # Panics
    ///
    /// Panics if `item_index` is negative or not less than `self.size`.
    fn get_item_vector(&self, item_index: i64) -> Vec<f32> {
        // A bare `as usize` cast would silently wrap a negative index into a huge
        // value, and the offset multiplication below could then overflow; reject
        // such indices up front with a clear message instead.
        let item_id = match usize::try_from(item_index) {
            Ok(id) if id < self.size => id,
            _ => panic!(
                "item index {} out of range (index holds {} items)",
                item_index, self.size
            ),
        };
        let node_offset = item_id * self.node_size;
        self.get_node_slice_with_offset(node_offset)
            .iter()
            .copied()
            .collect()
    }

    /// Approximate nearest-neighbour search for `query_vector`.
    ///
    /// The search has two phases:
    /// 1. Tree walk. Every root is seeded into a priority queue at `f32::MAX`.
    ///    Nodes are popped and expanded until at least `search_k` distinct
    ///    candidate ids have been collected or the queue runs dry. When
    ///    `search_k <= 0` the budget is `min(n_results, self.size).max(1) *`
    ///    the number of trees.
    /// 2. Ranking. Each candidate that is a single-item node is scored with
    ///    `get_distance_no_norm` against the query, and the best `n_results` are
    ///    returned (fewer if there are not enough candidates).
    ///
    /// When `should_include_distance` is `true`, every returned id has a matching
    /// entry in `distance_list`, converted with `normalized_distance`.
    fn get_nearest(
        &self,
        query_vector: &[f32],
        n_results: usize,
        search_k: i32,
        should_include_distance: bool,
    ) -> AnnoyIndexSearchResult {
        let result_capacity = n_results.min(self.size).max(1);
        // A non-positive `search_k` selects the default candidate budget.
        let search_k_fixed = if search_k > 0 {
            search_k as usize
        } else {
            result_capacity * self.roots.len()
        };

        // NOTE(review): the `false` flag is assumed to make this a max-priority
        // queue (largest margin popped first) — confirm against PriorityQueue.
        let mut pq = PriorityQueue::with_capacity(result_capacity, false);
        for &root_id in self.roots.iter() {
            pq.push(root_id as i32, f32::MAX);
        }

        // Candidate ids; a set drops duplicates reached through different trees.
        let mut nearest_neighbors = HashSet::new();
        while nearest_neighbors.len() < search_k_fixed {
            let (top_node_id_i32, top_node_margin) = match pq.pop() {
                Some(top) => top,
                None => break,
            };
            let top_node_id = top_node_id_i32 as usize;
            let top_node = self.get_node_from_id(top_node_id);
            let top_node_header = top_node.header;
            let top_node_offset = top_node.offset;
            let n_descendants = top_node_header.get_n_descendant();
            if n_descendants == 1 && top_node_id < self.size {
                // Single-item node: the node id is itself a candidate.
                nearest_neighbors.insert(top_node_id_i32);
            } else if n_descendants <= self.max_descendants {
                // Small node: its descendant ids are read directly as candidates.
                let children_id_slice =
                    self.get_descendant_id_slice(top_node_offset, n_descendants as usize);
                nearest_neighbors.extend(children_id_slice);
            } else {
                // Split node: push both children, each bounded by the parent's
                // priority and the query's signed margin to the split.
                let v = self.get_node_slice_with_offset(top_node_offset);
                let margin = self.get_margin(v, query_vector, top_node_offset);
                let children_id = top_node_header.get_children_id_slice();
                // NOTE: Hamming has different logic to calculate margin
                pq.push(children_id[1], top_node_margin.min(margin));
                pq.push(children_id[0], top_node_margin.min(-margin));
            }
        }

        // NOTE(review): the `true` flag is assumed to make this a min-priority
        // queue so the closest candidates are popped first — confirm.
        let mut sorted_nns = PriorityQueue::with_capacity(nearest_neighbors.len(), true);
        for nn_id in nearest_neighbors {
            let node = self.get_node_from_id(nn_id as usize);
            let n_descendants = node.header.get_n_descendant();
            // Only single-item nodes are real items; skip anything else.
            if n_descendants != 1 {
                continue;
            }

            let s = self.get_node_slice_with_offset(nn_id as usize * self.node_size);
            sorted_nns.push(nn_id, self.get_distance_no_norm(s, query_vector));
        }

        let final_result_capacity = n_results.min(sorted_nns.len());
        let mut id_list = Vec::with_capacity(final_result_capacity);
        let mut distance_list = Vec::with_capacity(if should_include_distance {
            final_result_capacity
        } else {
            0
        });
        for _ in 0..final_result_capacity {
            let (nn_id, nn_distance) = sorted_nns
                .pop()
                .expect("popping at most sorted_nns.len() entries");
            id_list.push(nn_id as u64);
            if should_include_distance {
                distance_list.push(self.normalized_distance(nn_distance));
            }
        }

        AnnoyIndexSearchResult {
            count: final_result_capacity,
            is_distance_included: should_include_distance,
            id_list,
            distance_list,
        }
    }

    /// Nearest-neighbour search using the stored vector of item `item_index` as
    /// the query; see `get_nearest` for the meaning of the remaining arguments.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `get_item_vector`.
    fn get_nearest_to_item(
        &self,
        item_index: i64,
        n_results: usize,
        search_k: i32,
        should_include_distance: bool,
    ) -> AnnoyIndexSearchResult {
        let item_vector = self.get_item_vector(item_index);
        self.get_nearest(
            item_vector.as_slice(),
            n_results,
            search_k,
            should_include_distance,
        )
    }
}