1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
//! Shared graph construction utilities used by multiple ANN index modules.
//!
//! - [`build_knn_graph_nndescent`]: NN-descent (Dong et al., 2011) kNN graph construction.
//! - [`ensure_connectivity`]: O(n * k) graph connectivity repair via union-find.
//!
//! Both functions are conditionally used depending on which algorithm features
//! are enabled. Allow dead_code at the module level to avoid per-feature-combo lint noise.
#![allow(dead_code)]
use smallvec::SmallVec;
/// Build a kNN graph using NN-descent (Dong et al., 2011).
///
/// Every node starts with `k` neighbors drawn from a fixed-seed LCG, so the result is
/// deterministic whenever `dist_fn` is. Up to 10 refinement rounds follow: for each node,
/// every pair of its current neighbors is offered to each other as candidates. Refinement
/// stops early once a round performs at most `0.001 * n * k` insertions.
///
/// Returns one neighbor list per node as `Vec<SmallVec<[u32; 16]>>`. Each list holds exactly
/// `min(k, n - 1)` distinct neighbor IDs (never the node itself), sorted by distance, closest
/// first. The lists are approximate nearest neighbors unless `k >= n - 1`, in which case every
/// list contains all other nodes. NaN distances are ranked after every real distance.
///
/// Node IDs are stored as `u32`, so IDs are truncated if `n` exceeds the `u32` range.
///
/// # Arguments
/// - `n`: number of nodes
/// - `k`: target neighbor count per node (clamped to `n - 1`)
/// - `dist_fn`: `dist_fn(i, j)` returns the distance between nodes `i` and `j`
// Index loops are kept on purpose: the loop index is both a node id passed to `dist_fn`
// and an index into `nn`, whose other entries are mutated inside the loop body.
#[allow(clippy::needless_range_loop)]
pub fn build_knn_graph_nndescent<F>(n: usize, k: usize, dist_fn: F) -> Vec<SmallVec<[u32; 16]>>
where
    F: Fn(usize, usize) -> f32,
{
    // Ascending distance order in which every NaN ranks after every number, so a NaN
    // distance can never displace a real neighbor or break the sorted invariant.
    fn cmp_dist(a: f32, b: f32) -> std::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    // Offer `(dist, id)` to a sorted candidate list capped at `k` entries.
    // Returns true when the candidate was inserted.
    fn try_insert(list: &mut Vec<(f32, u32)>, k: usize, dist: f32, id: u32) -> bool {
        if list.len() >= k {
            if let Some(&(worst, _)) = list.last() {
                // A full list only accepts candidates strictly better than its worst entry.
                if cmp_dist(dist, worst).is_ge() {
                    return false;
                }
            }
        }
        if list.iter().any(|&(_, nid)| nid == id) {
            return false;
        }
        // Insert after all entries that are not farther, keeping the list sorted.
        let pos = list.partition_point(|&(d, _)| cmp_dist(d, dist).is_le());
        list.insert(pos, (dist, id));
        if list.len() > k {
            list.pop();
        }
        true
    }

    let k = k.min(n.saturating_sub(1));
    if k == 0 {
        return vec![SmallVec::new(); n];
    }

    // Built per element: cloning an empty `Vec` (as `vec![v; n]` does) drops its capacity.
    // One extra slot covers the transient `k + 1` length inside `try_insert`.
    let mut nn: Vec<Vec<(f32, u32)>> = (0..n).map(|_| Vec::with_capacity(k + 1)).collect();

    // LCG RNG for deterministic random initialization.
    let mut rng: u64 = 0xdeadbeef_cafebabe;
    let lcg_next = |state: &mut u64| -> usize {
        *state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Use the high bits; the low bits of an LCG have short periods.
        (*state >> 33) as usize
    };

    // Initialize each node with k random distinct neighbors.
    let max_attempts = n.saturating_mul(4);
    for i in 0..n {
        let mut added = 0usize;
        let mut attempts = 0usize;
        while added < k && attempts < max_attempts {
            attempts += 1;
            let j = lcg_next(&mut rng) % n;
            if j == i {
                continue;
            }
            if nn[i].iter().any(|&(_, id)| id == j as u32) {
                continue;
            }
            let d = dist_fn(i, j);
            nn[i].push((d, j as u32));
            added += 1;
        }
        // The attempt cap can leave the list short when k is close to n - 1. Top it up
        // with the lowest-numbered unused nodes; k <= n - 1 guarantees enough candidates.
        for j in 0..n {
            if nn[i].len() >= k {
                break;
            }
            if j == i || nn[i].iter().any(|&(_, id)| id == j as u32) {
                continue;
            }
            nn[i].push((dist_fn(i, j), j as u32));
        }
        nn[i].sort_unstable_by(|a, b| cmp_dist(a.0, b.0));
    }

    let max_iters = 10usize;
    let early_stop_threshold = (0.001 * (n * k) as f64) as usize;
    // Reused buffer holding the ids of the node currently being processed. A node never
    // appears in its own list, so that list is not modified while its pairs are visited.
    let mut neighbor_ids: Vec<u32> = Vec::with_capacity(k);
    for _ in 0..max_iters {
        let mut updates = 0usize;
        for u in 0..n {
            neighbor_ids.clear();
            neighbor_ids.extend(nn[u].iter().map(|&(_, id)| id));
            // Ids within one list are distinct, so every pair below has v1 != v2.
            for (a, &v1) in neighbor_ids.iter().enumerate() {
                for &v2 in &neighbor_ids[a + 1..] {
                    let d12 = dist_fn(v1 as usize, v2 as usize);
                    if try_insert(&mut nn[v1 as usize], k, d12, v2) {
                        updates += 1;
                    }
                    if try_insert(&mut nn[v2 as usize], k, d12, v1) {
                        updates += 1;
                    }
                }
            }
        }
        if updates <= early_stop_threshold {
            break;
        }
    }

    // Drop the distances; callers only receive ids, already ordered closest first.
    nn.into_iter()
        .map(|list| list.into_iter().map(|(_, id)| id).collect())
        .collect()
}
/// Ensure the graph rooted at `entry_point` is fully connected.
///
/// Uses union-find to identify connected components, then bridges each isolated
/// component to the main component via beam search on the existing graph.
/// Complexity: O(n * k) where k is average neighbor degree, vs O(n^2) for
/// the brute-force approach.
///
/// # Arguments
/// - `neighbors`: mutable neighbor adjacency lists
/// - `entry_point`: the root/medoid node index
/// - `dist_fn`: `dist_fn(i, j)` returns the distance between nodes `i` and `j`
pub fn ensure_connectivity<F>(neighbors: &mut [SmallVec<[u32; 16]>], entry_point: u32, dist_fn: F)
where
F: Fn(usize, usize) -> f32,
{
let n = neighbors.len();
if n <= 1 {
return;
}
// Union-Find with path compression and union by rank.
let mut parent: Vec<u32> = (0..n as u32).collect();
let mut rank: Vec<u8> = vec![0; n];
fn find(parent: &mut [u32], x: u32) -> u32 {
let mut r = x;
while parent[r as usize] != r {
parent[r as usize] = parent[parent[r as usize] as usize]; // path halving
r = parent[r as usize];
}
r
}
fn union(parent: &mut [u32], rank: &mut [u8], a: u32, b: u32) {
let ra = find(parent, a);
let rb = find(parent, b);
if ra == rb {
return;
}
if rank[ra as usize] < rank[rb as usize] {
parent[ra as usize] = rb;
} else if rank[ra as usize] > rank[rb as usize] {
parent[rb as usize] = ra;
} else {
parent[rb as usize] = ra;
rank[ra as usize] += 1;
}
}
// Build union-find from existing edges.
for (i, nbrs) in neighbors.iter().enumerate() {
for &nb in nbrs {
union(&mut parent, &mut rank, i as u32, nb);
}
}
let entry_root = find(&mut parent, entry_point);
// Collect one representative per non-entry component.
// For each isolated node, find its closest reachable neighbor via a local
// neighborhood scan (check neighbors-of-neighbors) instead of O(n) brute force.
for i in 0..n {
if find(&mut parent, i as u32) == entry_root {
continue;
}
// Strategy: scan all neighbors of our own neighbors to find one in the
// entry component. If that fails, fall back to scanning the entry point's
// neighborhood. This is O(k^2) per isolated node instead of O(n).
let mut best_id = entry_point;
let mut best_dist = dist_fn(i, entry_point as usize);
// Check neighbors-of-neighbors for a bridge.
let my_neighbors: SmallVec<[u32; 16]> = neighbors[i].clone();
for &nb in &my_neighbors {
let nb_neighbors: SmallVec<[u32; 16]> = neighbors[nb as usize].clone();
for &nb2 in &nb_neighbors {
if find(&mut parent, nb2) == entry_root {
let d = dist_fn(i, nb2 as usize);
if d < best_dist {
best_dist = d;
best_id = nb2;
}
}
}
}
// Scan entry point's neighbors (guaranteed reachable).
let entry_neighbors: SmallVec<[u32; 16]> = neighbors[entry_point as usize].clone();
for &nb in &entry_neighbors {
let d = dist_fn(i, nb as usize);
if d < best_dist {
best_dist = d;
best_id = nb;
}
}
// Two-hop: scan entry neighbors' neighbors for better bridge candidates.
// Costs O(k^2) per isolated node but covers a much larger neighborhood.
for &enb in &entry_neighbors {
let enb_neighbors: SmallVec<[u32; 16]> = neighbors[enb as usize].clone();
for &enb2 in &enb_neighbors {
if find(&mut parent, enb2) == entry_root {
let d = dist_fn(i, enb2 as usize);
if d < best_dist {
best_dist = d;
best_id = enb2;
}
}
}
}
// Bridge the components.
neighbors[i].push(best_id);
neighbors[best_id as usize].push(i as u32);
union(&mut parent, &mut rank, i as u32, best_id);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Distance between node indices, as if node `i` sat at coordinate `i` on a line.
    fn index_gap(i: usize, j: usize) -> f32 {
        (i as f32 - j as f32).abs()
    }

    /// Builds adjacency lists from plain slices of neighbor ids.
    fn adjacency(lists: &[&[u32]]) -> Vec<SmallVec<[u32; 16]>> {
        lists
            .iter()
            .map(|ids| ids.iter().copied().collect())
            .collect()
    }

    /// Copies adjacency lists into plain vectors so whole graphs can be compared.
    fn snapshot(graph: &[SmallVec<[u32; 16]>]) -> Vec<Vec<u32>> {
        graph.iter().map(|list| list.to_vec()).collect()
    }

    /// Marks every node reachable from `start` by following stored edges (depth-first).
    fn reachable_from(graph: &[SmallVec<[u32; 16]>], start: usize) -> Vec<bool> {
        let mut visited = vec![false; graph.len()];
        let mut stack = vec![start];
        visited[start] = true;
        while let Some(node) = stack.pop() {
            for &nb in &graph[node] {
                let nb = nb as usize;
                if !visited[nb] {
                    visited[nb] = true;
                    stack.push(nb);
                }
            }
        }
        visited
    }

    #[test]
    fn test_nndescent_basic() {
        // 2D points in a line: 0, 1, 2, 3, 4
        let points: Vec<[f32; 2]> = (0..5).map(|i| [i as f32, 0.0]).collect();
        let dist_fn = |i: usize, j: usize| {
            let dx = points[i][0] - points[j][0];
            let dy = points[i][1] - points[j][1];
            (dx * dx + dy * dy).sqrt()
        };
        let neighbors = build_knn_graph_nndescent(5, 2, dist_fn);
        assert_eq!(neighbors.len(), 5);

        for (node, list) in neighbors.iter().enumerate() {
            // Each node should have 2 distinct neighbors, none of them itself.
            assert_eq!(list.len(), 2);
            assert_ne!(list[0], list[1], "duplicate neighbor for node {node}");
            for &id in list {
                assert!((id as usize) < 5, "neighbor id out of range for node {node}");
                assert_ne!(id as usize, node, "self-loop on node {node}");
            }
            // Lists are ordered closest first.
            assert!(dist_fn(node, list[0] as usize) <= dist_fn(node, list[1] as usize));
        }

        // Node 0's nearest should be node 1
        assert!(neighbors[0].contains(&1));
        assert_eq!(neighbors[0][0], 1);
        // Node 4's nearest should be node 3
        assert!(neighbors[4].contains(&3));
        assert_eq!(neighbors[4][0], 3);
    }

    #[test]
    fn test_nndescent_degenerate_inputs() {
        // No nodes, a single node, or k == 0 all produce empty neighbor lists.
        assert!(build_knn_graph_nndescent(0, 3, index_gap).is_empty());

        let single = build_knn_graph_nndescent(1, 3, index_gap);
        assert_eq!(single.len(), 1);
        assert!(single[0].is_empty());

        let zero_k = build_knn_graph_nndescent(4, 0, index_gap);
        assert_eq!(zero_k.len(), 4);
        assert!(zero_k.iter().all(|list| list.is_empty()));
    }

    #[test]
    fn test_ensure_connectivity_already_connected() {
        let mut neighbors = adjacency(&[&[1], &[0, 2], &[1]]);
        let before = snapshot(&neighbors);
        ensure_connectivity(&mut neighbors, 0, index_gap);
        // Should be unchanged (already connected): no list may gain or lose an edge.
        assert_eq!(snapshot(&neighbors), before);
    }

    #[test]
    fn test_ensure_connectivity_disconnected() {
        // Two components: {0, 1} and {2, 3}
        let mut neighbors = adjacency(&[&[1], &[0], &[3], &[2]]);
        ensure_connectivity(&mut neighbors, 0, index_gap);

        // Existing edges are preserved; repair only adds edges.
        assert!(neighbors[0].contains(&1));
        assert!(neighbors[1].contains(&0));
        assert!(neighbors[2].contains(&3));
        assert!(neighbors[3].contains(&2));

        // Verify all nodes reachable from entry_point 0
        let visited = reachable_from(&neighbors, 0);
        assert!(visited.iter().all(|&v| v), "All nodes should be reachable");
    }
}