1use ahash::HashSet;
2use ahash::HashSetExt;
3use dashmap::DashSet;
4use ndarray::Array1;
5use ndarray::ArrayView1;
6use ndarray_linalg::Scalar;
7use ordered_float::OrderedFloat;
8use parking_lot::Mutex;
9use rayon::iter::IntoParallelIterator;
10use rayon::iter::IntoParallelRefIterator;
11use rayon::iter::ParallelIterator;
12use serde::Deserialize;
13use serde::Serialize;
14use std::collections::BinaryHeap;
15use std::collections::VecDeque;
16use std::sync::atomic::AtomicUsize;
17use std::sync::atomic::Ordering;
18use strum_macros::Display;
19use strum_macros::EnumString;
20
/// Identifier of a node in a [`GreedySearchable`] graph.
pub type Id = usize;
/// Distance function between two vectors; the search routines treat smaller values as closer.
pub type Metric<T> = fn(&ArrayView1<T>, &ArrayView1<T>) -> f64;
23
24#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
25pub struct PointDist {
26 pub id: Id,
27 pub dist: f64,
28}
29
30pub fn metric_euclidean<T: Scalar>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> f64 {
32 let diff = a - b;
33 let squared_diff = &diff * &diff;
34 let sum_squared_diff = squared_diff.sum();
35 sum_squared_diff.to_f64().unwrap().sqrt()
36}
37
38pub fn metric_cosine<T: Scalar>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> f64 {
40 let dot_product = a.dot(b).to_f64().unwrap();
41
42 let a_norm = a.dot(a).to_f64().unwrap();
43 let b_norm = b.dot(b).to_f64().unwrap();
44
45 let denominator = (a_norm * b_norm).sqrt();
46
47 if denominator == 0.0 {
48 1.0
49 } else {
50 1.0 - dot_product / denominator
51 }
52}
53
54#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Display, EnumString, Serialize, Deserialize)]
55pub enum StdMetric {
56 L2,
57 Cosine,
58}
59
60impl StdMetric {
61 pub fn get_fn<T: Scalar>(self) -> fn(&ArrayView1<T>, &ArrayView1<T>) -> f64 {
62 match self {
63 StdMetric::L2 => metric_euclidean::<T>,
64 StdMetric::Cosine => metric_cosine::<T>,
65 }
66 }
67}
68
69#[derive(Clone, Debug, Default, Serialize, Deserialize)]
70pub struct SearchIterationMetrics {
71 pub visited: usize,
72 pub unvisited_dist_sum: f64, pub unvisited_dist_mins: f64, pub dropped_candidates: usize, pub new_candidates: usize,
76 pub dropped_visited: usize,
77 pub dropped_unvisited: usize,
78 pub ground_truth_found: usize, }
80
81#[derive(Clone, Debug, Default, Serialize, Deserialize)]
82pub struct SearchMetrics {
83 pub iterations: Vec<SearchIterationMetrics>,
84}
85
86pub trait GreedySearchable<T: Scalar + Send + Sync>: Send + Sync {
87 fn get_point(&self, id: Id) -> Array1<T>;
88 fn get_out_neighbors(&self, id: Id) -> Vec<Id>;
89}
90
91pub fn greedy_search_fast1<T: Scalar + Send + Sync>(
93 graph: &impl GreedySearchable<T>,
94 query: &ArrayView1<T>,
95 metric: Metric<T>,
96 start: Id,
97 filter: impl (Fn(Id) -> bool) + Send + Sync,
98) -> Option<PointDist> {
99 struct State {
100 cur: PointDist,
101 optima: Option<PointDist>,
102 }
103 let state = Mutex::new(State {
106 cur: PointDist {
107 id: start,
108 dist: metric(&graph.get_point(start).view(), query),
109 },
110 optima: None,
112 });
113 let seen = DashSet::<Id>::new();
114 seen.insert(start);
115 loop {
116 let cur = state.lock().cur.id;
117 graph
118 .get_out_neighbors(cur)
119 .into_par_iter()
120 .filter(|n| !seen.contains(n))
121 .map(|n| PointDist {
122 id: n,
123 dist: metric(&graph.get_point(cur).view(), &graph.get_point(n).view()),
124 })
125 .for_each(|n| {
126 seen.insert(n.id);
128 let mut s = state.lock();
129 if filter(n.id) && !s.optima.is_some_and(|o| n.dist >= o.dist) {
130 s.optima = Some(n);
131 }
132 if n.dist < s.cur.dist {
133 s.cur = n;
134 }
135 });
136 if state.lock().cur.id == cur {
137 break;
139 }
140 }
141 state.into_inner().optima
142}
143
144pub fn greedy_search<T: Scalar + Send + Sync>(
148 graph: &impl GreedySearchable<T>,
149 query: &ArrayView1<T>,
150 k: usize,
151 search_list_cap: usize,
152 beam_width: usize,
153 metric: Metric<T>,
154 start: Id,
155 filter: impl Fn(PointDist) -> bool,
156 mut out_visited: Option<&mut HashSet<Id>>,
157 mut out_metrics: Option<&mut SearchMetrics>,
158 ground_truth: Option<&HashSet<Id>>,
159) -> Vec<PointDist> {
160 assert!(
161 search_list_cap >= k,
162 "search list capacity must be greater than or equal to k"
163 );
164
165 let mut l_unvisited_set = HashSet::new();
172 let mut l_unvisited = VecDeque::<PointDist>::new(); let mut l_visited = VecDeque::<PointDist>::new(); let mut all_visited = HashSet::new();
175 l_unvisited.push_back(PointDist {
176 id: start,
177 dist: metric(&graph.get_point(start).view(), query),
178 });
179 let ground_truth_found = AtomicUsize::new(0);
180 while !l_unvisited.is_empty() {
181 let new_visited = (0..beam_width)
182 .filter_map(|_| l_unvisited.pop_front())
183 .collect::<Vec<_>>();
184 let neighbors = DashSet::new();
185 new_visited.par_iter().for_each(|p_star| {
186 for j in graph.get_out_neighbors(p_star.id) {
187 neighbors.insert(j);
188 }
189 });
190 all_visited.extend(new_visited.iter().map(|e| e.id));
192 l_visited.extend(new_visited.iter().filter(|e| filter(**e)));
193 l_visited
194 .make_contiguous()
195 .sort_unstable_by_key(|s| OrderedFloat(s.dist));
196
197 let new_unvisited = neighbors
198 .par_iter()
199 .filter_map(|neighbor| {
200 let neighbor = *neighbor;
201 if all_visited.contains(&neighbor) {
203 return None;
204 };
205 if l_unvisited_set.contains(&neighbor) {
206 return None;
207 };
208 if ground_truth.is_some_and(|gt| gt.contains(&neighbor)) {
209 ground_truth_found.fetch_add(1, Ordering::Relaxed);
210 }
211 Some(PointDist {
212 id: neighbor,
213 dist: metric(&graph.get_point(neighbor).view(), query),
214 })
215 })
216 .collect::<VecDeque<_>>();
217 l_unvisited_set.extend(new_unvisited.iter().map(|e| e.id));
218 l_unvisited.extend(&new_unvisited);
219 l_unvisited
220 .make_contiguous()
221 .sort_unstable_by_key(|s| OrderedFloat(s.dist));
222
223 let mut dropped_unvisited = 0;
224 let mut dropped_visited = 0;
225 while l_unvisited.len() + l_visited.len() > search_list_cap {
226 let (Some(u), Some(v)) = (l_unvisited.back(), l_visited.back()) else {
227 break;
228 };
229 if u.dist >= v.dist {
230 l_unvisited.pop_back();
231 dropped_unvisited += 1;
232 } else {
233 l_visited.pop_back();
234 dropped_visited += 1;
235 }
236 }
237
238 if let Some(m) = &mut out_metrics {
239 m.iterations.push(SearchIterationMetrics {
240 dropped_candidates: neighbors.len() - new_unvisited.len(),
241 dropped_unvisited,
242 dropped_visited,
243 ground_truth_found: ground_truth_found.load(Ordering::Relaxed),
244 new_candidates: new_unvisited.len(),
245 unvisited_dist_mins: l_unvisited.front().map(|n| n.dist).unwrap_or_default(),
246 unvisited_dist_sum: l_unvisited.iter().map(|n| n.dist).sum(),
247 visited: all_visited.len(),
248 });
249 }
250 }
251
252 let mut closest = Vec::new();
254 while closest.len() < k {
255 match (l_visited.pop_front(), l_unvisited.pop_front()) {
256 (None, None) => break,
257 (Some(v), None) | (None, Some(v)) => closest.push(v),
258 (Some(a), Some(b)) => {
259 if a.dist < b.dist {
260 closest.push(a);
261 closest.push(b);
262 } else {
263 closest.push(b);
264 closest.push(a);
265 };
266 }
267 };
268 }
269 closest.truncate(k);
271 if let Some(out) = &mut out_visited {
272 out.extend(all_visited);
273 };
274 closest
275}
276
277pub fn find_shortest_spanning_tree<T: Scalar + Send + Sync>(
278 graph: &impl GreedySearchable<T>,
279 metric: Metric<T>,
280 start: Id,
281) -> Vec<(Id, Id)> {
282 let mut visited = HashSet::<Id>::new();
283 let mut path = Vec::<(Id, Id)>::new();
284 let mut pq = BinaryHeap::<(OrderedFloat<f64>, Id, Id)>::new();
286 pq.push((OrderedFloat(0.0), start, start));
287 while let Some((_, from, to)) = pq.pop() {
288 if !visited.insert(to) {
289 continue;
291 };
292
293 path.push((from, to));
294
295 let new = graph
297 .get_out_neighbors(to)
298 .into_par_iter()
299 .filter_map(|neighbor| {
300 if visited.contains(&neighbor) {
301 return None;
302 }
303 let dist = metric(
304 &graph.get_point(to).view(),
305 &graph.get_point(neighbor).view(),
306 );
307 Some((OrderedFloat(-dist), to, neighbor))
309 })
310 .collect::<Vec<_>>();
311 pq.extend(new);
312 }
313 path
314}