libroxanne_search/
lib.rs

1use ahash::HashSet;
2use ahash::HashSetExt;
3use dashmap::DashSet;
4use ndarray::Array1;
5use ndarray::ArrayView1;
6use ndarray_linalg::Scalar;
7use ordered_float::OrderedFloat;
8use parking_lot::Mutex;
9use rayon::iter::IntoParallelIterator;
10use rayon::iter::IntoParallelRefIterator;
11use rayon::iter::ParallelIterator;
12use serde::Deserialize;
13use serde::Serialize;
14use std::collections::BinaryHeap;
15use std::collections::VecDeque;
16use std::sync::atomic::AtomicUsize;
17use std::sync::atomic::Ordering;
18use strum_macros::Display;
19use strum_macros::EnumString;
20
/// Identifier of a node/point in the graph.
pub type Id = usize;
/// A distance function between two vectors; smaller values mean closer.
pub type Metric<T> = fn(&ArrayView1<T>, &ArrayView1<T>) -> f64;

/// A node ID paired with its distance to some reference point (usually the query).
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct PointDist {
  pub id: Id,
  pub dist: f64,
}
29
30// A metric implementation of the Euclidean distance.
31pub fn metric_euclidean<T: Scalar>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> f64 {
32  let diff = a - b;
33  let squared_diff = &diff * &diff;
34  let sum_squared_diff = squared_diff.sum();
35  sum_squared_diff.to_f64().unwrap().sqrt()
36}
37
38// A metric implementation of the cosine distance (NOT similarity).
39pub fn metric_cosine<T: Scalar>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> f64 {
40  let dot_product = a.dot(b).to_f64().unwrap();
41
42  let a_norm = a.dot(a).to_f64().unwrap();
43  let b_norm = b.dot(b).to_f64().unwrap();
44
45  let denominator = (a_norm * b_norm).sqrt();
46
47  if denominator == 0.0 {
48    1.0
49  } else {
50    1.0 - dot_product / denominator
51  }
52}
53
/// Named standard metrics; string-convertible via strum and serde-serialisable.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Display, EnumString, Serialize, Deserialize)]
pub enum StdMetric {
  // Euclidean (L2) distance.
  L2,
  // Cosine distance (1 - cosine similarity).
  Cosine,
}
59
60impl StdMetric {
61  pub fn get_fn<T: Scalar>(self) -> fn(&ArrayView1<T>, &ArrayView1<T>) -> f64 {
62    match self {
63      StdMetric::L2 => metric_euclidean::<T>,
64      StdMetric::Cosine => metric_cosine::<T>,
65    }
66  }
67}
68
/// Diagnostics recorded for one iteration of `greedy_search` (when metrics are requested).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SearchIterationMetrics {
  pub visited: usize,            // Total nodes visited so far (cumulative).
  pub unvisited_dist_sum: f64,   // The sum of distances of all unvisited nodes.
  pub unvisited_dist_mins: f64,  // The min. dist. of all unvisited nodes.
  pub dropped_candidates: usize, // Neighbors of an expanded node that were already visited or in the unvisited queue.
  pub new_candidates: usize,     // Neighbors newly added to the unvisited queue this iteration.
  pub dropped_visited: usize,    // Visited entries evicted to respect the search list capacity.
  pub dropped_unvisited: usize,  // Unvisited entries evicted to respect the search list capacity.
  pub ground_truth_found: usize, // Cumulative.
}
80
/// Full trace of a `greedy_search` run: one entry per expansion iteration.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SearchMetrics {
  pub iterations: Vec<SearchIterationMetrics>,
}
85
/// Minimal graph interface the search routines need: random access to a node's
/// vector and to its directed out-neighbor list.
pub trait GreedySearchable<T: Scalar + Send + Sync>: Send + Sync {
  /// Returns the vector stored at node `id`.
  fn get_point(&self, id: Id) -> Array1<T>;
  /// Returns the IDs of the nodes `id` links to (out-edges).
  fn get_out_neighbors(&self, id: Id) -> Vec<Id>;
}
90
91// Optimised custom function for k=1 and search_list_cap=1.
92pub fn greedy_search_fast1<T: Scalar + Send + Sync>(
93  graph: &impl GreedySearchable<T>,
94  query: &ArrayView1<T>,
95  metric: Metric<T>,
96  start: Id,
97  filter: impl (Fn(Id) -> bool) + Send + Sync,
98) -> Option<PointDist> {
99  struct State {
100    cur: PointDist,
101    optima: Option<PointDist>,
102  }
103  // We traverse filtered nodes but don't allow them as the answer. This is better than ignoring them while traversing, as that breaks the graph (poor/no navigability). However, it's not as exhaustive as a typical search as we do not backtrack if we have found a local optima but it's filtered, as that may cause excessive backtracking and a regression to a full graph scan (e.g. consider a scenario where the only node left is one that is actually the furtherest from the query). For our usages, it's fine to be "approximate" and even if not the most accurate node is returned (or None is returned), as overall it still works out.
104  // TODO Study impact of performance and accuracy compared to full exhaustive search with backtracking while filtered.
105  let state = Mutex::new(State {
106    cur: PointDist {
107      id: start,
108      dist: metric(&graph.get_point(start).view(), query),
109    },
110    // The optima cannot default to `start` as it may be filtered.
111    optima: None,
112  });
113  let seen = DashSet::<Id>::new();
114  seen.insert(start);
115  loop {
116    let cur = state.lock().cur.id;
117    graph
118      .get_out_neighbors(cur)
119      .into_par_iter()
120      .filter(|n| !seen.contains(n))
121      .map(|n| PointDist {
122        id: n,
123        dist: metric(&graph.get_point(cur).view(), &graph.get_point(n).view()),
124      })
125      .for_each(|n| {
126        // If this node is a neighbor of a future expanded node, we don't need to compare the distance to this node, as if it's not the shortest now, it won't be then either.
127        seen.insert(n.id);
128        let mut s = state.lock();
129        if filter(n.id) && !s.optima.is_some_and(|o| n.dist >= o.dist) {
130          s.optima = Some(n);
131        }
132        if n.dist < s.cur.dist {
133          s.cur = n;
134        }
135      });
136    if state.lock().cur.id == cur {
137      // No change, reached local optima.
138      break;
139    }
140  }
141  state.into_inner().optima
142}
143
144// DiskANN paper, Algorithm 1: GreedySearch.
145// Returns a pair: (closest points, visited node IDs).
146// Filtered nodes will be visited and expanded but not considered for the final set of neighbors.
147pub fn greedy_search<T: Scalar + Send + Sync>(
148  graph: &impl GreedySearchable<T>,
149  query: &ArrayView1<T>,
150  k: usize,
151  search_list_cap: usize,
152  beam_width: usize,
153  metric: Metric<T>,
154  start: Id,
155  filter: impl Fn(PointDist) -> bool,
156  mut out_visited: Option<&mut HashSet<Id>>,
157  mut out_metrics: Option<&mut SearchMetrics>,
158  ground_truth: Option<&HashSet<Id>>,
159) -> Vec<PointDist> {
160  assert!(
161    search_list_cap >= k,
162    "search list capacity must be greater than or equal to k"
163  );
164
165  // It's too inefficient to calculate L\V repeatedly.
166  // Since we need both L (return value) and L\V (each iteration), we split L into V and ¬V.
167  // For simplicity, we'll just allow both to reach `k` size, and do a final merge at the end. This doubles the memory requirements, but in reality `k` is often small enough that it's not a problem.
168  // We also need `all_visited` as `l_visited` truncates to `k`, but we also want all visited points in the end.
169  // L = l_visited + l_unvisited
170  // `l_unvisited_set` is for members currently in l_unvisited, to avoid pushing duplicates. (NOTE: This is *not* the same as `all_visited`.) We don't need one for `l_visited` because we only push popped elements from `l_unvisited`, which we guarantee are unique (as previously mentioned).
171  let mut l_unvisited_set = HashSet::new();
172  let mut l_unvisited = VecDeque::<PointDist>::new(); // L \ V
173  let mut l_visited = VecDeque::<PointDist>::new(); // V
174  let mut all_visited = HashSet::new();
175  l_unvisited.push_back(PointDist {
176    id: start,
177    dist: metric(&graph.get_point(start).view(), query),
178  });
179  let ground_truth_found = AtomicUsize::new(0);
180  while !l_unvisited.is_empty() {
181    let new_visited = (0..beam_width)
182      .filter_map(|_| l_unvisited.pop_front())
183      .collect::<Vec<_>>();
184    let neighbors = DashSet::new();
185    new_visited.par_iter().for_each(|p_star| {
186      for j in graph.get_out_neighbors(p_star.id) {
187        neighbors.insert(j);
188      }
189    });
190    // Move to visited section.
191    all_visited.extend(new_visited.iter().map(|e| e.id));
192    l_visited.extend(new_visited.iter().filter(|e| filter(**e)));
193    l_visited
194      .make_contiguous()
195      .sort_unstable_by_key(|s| OrderedFloat(s.dist));
196
197    let new_unvisited = neighbors
198      .par_iter()
199      .filter_map(|neighbor| {
200        let neighbor = *neighbor;
201        // We separate L out into V and not V, so we must manually ensure the property that l_visited and l_unvisited are disjoint.
202        if all_visited.contains(&neighbor) {
203          return None;
204        };
205        if l_unvisited_set.contains(&neighbor) {
206          return None;
207        };
208        if ground_truth.is_some_and(|gt| gt.contains(&neighbor)) {
209          ground_truth_found.fetch_add(1, Ordering::Relaxed);
210        }
211        Some(PointDist {
212          id: neighbor,
213          dist: metric(&graph.get_point(neighbor).view(), query),
214        })
215      })
216      .collect::<VecDeque<_>>();
217    l_unvisited_set.extend(new_unvisited.iter().map(|e| e.id));
218    l_unvisited.extend(&new_unvisited);
219    l_unvisited
220      .make_contiguous()
221      .sort_unstable_by_key(|s| OrderedFloat(s.dist));
222
223    let mut dropped_unvisited = 0;
224    let mut dropped_visited = 0;
225    while l_unvisited.len() + l_visited.len() > search_list_cap {
226      let (Some(u), Some(v)) = (l_unvisited.back(), l_visited.back()) else {
227        break;
228      };
229      if u.dist >= v.dist {
230        l_unvisited.pop_back();
231        dropped_unvisited += 1;
232      } else {
233        l_visited.pop_back();
234        dropped_visited += 1;
235      }
236    }
237
238    if let Some(m) = &mut out_metrics {
239      m.iterations.push(SearchIterationMetrics {
240        dropped_candidates: neighbors.len() - new_unvisited.len(),
241        dropped_unvisited,
242        dropped_visited,
243        ground_truth_found: ground_truth_found.load(Ordering::Relaxed),
244        new_candidates: new_unvisited.len(),
245        unvisited_dist_mins: l_unvisited.front().map(|n| n.dist).unwrap_or_default(),
246        unvisited_dist_sum: l_unvisited.iter().map(|n| n.dist).sum(),
247        visited: all_visited.len(),
248      });
249    }
250  }
251
252  // Find the k closest points from both l_visited + l_unvisited (= L).
253  let mut closest = Vec::new();
254  while closest.len() < k {
255    match (l_visited.pop_front(), l_unvisited.pop_front()) {
256      (None, None) => break,
257      (Some(v), None) | (None, Some(v)) => closest.push(v),
258      (Some(a), Some(b)) => {
259        if a.dist < b.dist {
260          closest.push(a);
261          closest.push(b);
262        } else {
263          closest.push(b);
264          closest.push(a);
265        };
266      }
267    };
268  }
269  // We may have exceeded k due to pushing both a and b in the last match arm.
270  closest.truncate(k);
271  if let Some(out) = &mut out_visited {
272    out.extend(all_visited);
273  };
274  closest
275}
276
277pub fn find_shortest_spanning_tree<T: Scalar + Send + Sync>(
278  graph: &impl GreedySearchable<T>,
279  metric: Metric<T>,
280  start: Id,
281) -> Vec<(Id, Id)> {
282  let mut visited = HashSet::<Id>::new();
283  let mut path = Vec::<(Id, Id)>::new();
284  // Why use Dijkstra instead of simply calculating min(dist_of_in_neighbors_edges) for each node? After all, we traverse and expand every node in both, but the latter can be a bit more parallel. The reason is that the latter will create lots of symmetric edges, because if A is closest to B then so is B to A, but this is bad for forming a dependency path to iterate through the graph as we'll frequently stop due to A already being visited once we get to B. Basically, we use Dijkstra because it has a `visited` set that forces a non-cyclic continuous path.
285  let mut pq = BinaryHeap::<(OrderedFloat<f64>, Id, Id)>::new();
286  pq.push((OrderedFloat(0.0), start, start));
287  while let Some((_, from, to)) = pq.pop() {
288    if !visited.insert(to) {
289      // We've already visited (there was a shorter path to this node).
290      continue;
291    };
292
293    path.push((from, to));
294
295    // Move on to neighbors of `to` in the base shard.
296    let new = graph
297      .get_out_neighbors(to)
298      .into_par_iter()
299      .filter_map(|neighbor| {
300        if visited.contains(&neighbor) {
301          return None;
302        }
303        let dist = metric(
304          &graph.get_point(to).view(),
305          &graph.get_point(neighbor).view(),
306        );
307        // Use negative dist as BinaryHeap is a max-heap.
308        Some((OrderedFloat(-dist), to, neighbor))
309      })
310      .collect::<Vec<_>>();
311    pq.extend(new);
312  }
313  path
314}