1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
use std::collections::{BinaryHeap, HashMap, HashSet};
use super::node::Node;
use crate::{distance_metric::DistanceMetric, primitives::vector::Vector};
use anyhow::{bail, Result};
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
/// World is the main struct that represents the full HNSW graph world
// NOTE: this struct is serialized field-by-field with bincode (`dump` / `new_from_dump`);
// renaming, reordering or retyping fields changes the dump format.
#[derive(Clone, Serialize, Deserialize)]
pub struct World {
    /// nodes holds every node in the world, keyed by node id
    nodes: HashMap<u32, Node>,
    /// level_entrypoints holds the id of the entrypoint node for each level
    /// index = level, value = id of the entrypoint node
    level_entrypoints: Vec<u32>,
    /// m is the maximum number of connections a node keeps on a level;
    /// pruning drops the furthest connections beyond m
    m: usize,
    /// ef_construction is the maximum number of connections to explore for a node during construction
    /// (must be >= m; enforced by `World::new`)
    ef_construction: usize,
    /// ef_search is the maximum number of connections to explore for a node during search;
    /// `search` rejects requests for more than ef_search results
    ef_search: usize,
    /// max_level is the maximum level of the HNSW graph; new nodes are never assigned a level
    /// above it, and it is raised as the number of nodes grows (see `calculate_max_level`)
    max_level: usize,
    /// distance_metric is the distance metric used to calculate distances between vectors
    distance_metric: DistanceMetric,
}
impl World {
pub fn new(
m: usize,
ef_construction: usize,
ef_search: usize,
distance_metric: DistanceMetric,
) -> Result<Self> {
// ef_construction must be >= M
if ef_construction < m {
bail!("ef_construction must be >= M");
}
Ok(Self {
nodes: HashMap::new(),
level_entrypoints: vec![],
m,
ef_construction,
ef_search,
max_level: 1,
distance_metric,
})
}
/// new_from_dump creates a new world from a serialized dump
pub fn new_from_dump(data: &[u8]) -> Result<Self> {
bincode::deserialize(data)
.map_err(|e| anyhow::anyhow!("Failed to deserialize world: {}", e))
}
/// pick_node_level picks the level at which a new node should be inserted based on the probabalistic insertion strategy.
pub(crate) fn pick_node_level(&self) -> usize {
let p = 1.0 / (self.m as f32);
let mut level = 0;
while fastrand::f32() < (1.0 - p) && level < self.max_level {
level += 1;
}
level
}
/// get_entrypoint_node gets the entrypoint node for the HNSW graph.
fn get_entrypoint_node(&self) -> Node {
self.get_entrypoint_node_per_level(self.max_level)
}
/// get_entrypoint_node_per_level gets the entrypoint node for a given level
fn get_entrypoint_node_per_level(&self, level: usize) -> Node {
if self.level_entrypoints.is_empty() {
return self.nodes.values().next().unwrap().clone();
}
let id = self.level_entrypoints[level];
self.nodes.get(&id).unwrap().clone()
}
/// greedy_search performs a greedy search for the k nearest neighbours to the query vector
fn greedy_search(&self, query: &Vector, entry_node: &Node, level: usize) -> Vec<u32> {
let mut visited = HashSet::new();
let mut candidates: BinaryHeap<(OrderedFloat<f32>, u32)> = BinaryHeap::new();
let mut best_candidates: BinaryHeap<(OrderedFloat<f32>, u32)> = BinaryHeap::new();
// get the distance between the new node and the entry node
let distance = entry_node.distance(query, &self.distance_metric);
// add the entry node to the candidates
// we're using negatives here because BinaryHeap is a max heap by default and we want min heap behaviour to find the nearest neighbours, not the furthest
candidates.push((-OrderedFloat(distance), entry_node.id()));
// add to the visited set
visited.insert(entry_node.id());
best_candidates.push((OrderedFloat(distance), entry_node.id()));
// let's go through the graph to find the best candidates
while !candidates.is_empty() {
let (current_dist, current_id) = candidates.pop().unwrap();
// if the current distance is greater than the best candidate, skip it because it's a bad candidate
if !best_candidates.is_empty() && -current_dist > best_candidates.peek().unwrap().0 {
continue;
}
// at this level, we need to check the neighbours
for neighbour_id in self.nodes.get(¤t_id).unwrap().connections(level) {
// if we've already visited this node, skip it
if visited.contains(&neighbour_id) {
continue;
}
// visit the node
visited.insert(neighbour_id);
// get the distance between the new node and the neighbour
let distance = self
.nodes
.get(&neighbour_id)
.unwrap()
.distance(query, &self.distance_metric);
// if this candidate is better than the best candidate
let ef_size = if level == 0 {
self.ef_search
} else {
self.ef_construction
};
if best_candidates.len() < ef_size
|| OrderedFloat(distance) < best_candidates.peek().unwrap().0
{
// The new candidate is strictly better (smaller distance) than the worst one we have so far.
candidates.push((OrderedFloat(distance), neighbour_id));
best_candidates.push((OrderedFloat(distance), neighbour_id));
// Enforce ef_construction by popping the largest distance from best_candidates if needed (max heap, so the root node is the furthest and therefore the worst)
if best_candidates.len() > ef_size {
best_candidates.pop();
}
}
}
}
// return our best candidates
best_candidates.into_iter().map(|(_, id)| id).collect()
}
/// insert_node inserts a new node into the world.
/// id must be fully unique in the World
// 1. pick the level at which to insert the node
// 2. find the M nearest neighbors for the node at the chosen level
// 3. connect the new node to the neighbors and on all lower levels
// 4. recursively connect the new node to the neighbors' neighbors
// 5. if the new node has no connections, add it to the graph at level 0
pub fn insert_vector(&mut self, id: u32, vector: Vector) -> Result<()> {
// If this is the first node, initialize it as the entrypoint for all levels
if self.nodes.is_empty() {
let initial_level = self.pick_node_level();
let node = Node::new(id, vector, initial_level);
self.nodes.insert(node.id(), node.clone());
self.level_entrypoints = vec![id; initial_level + 1];
self.max_level = initial_level;
return Ok(());
}
// ensure the id is completely unique
if self.nodes.contains_key(&id) {
bail!("Node id must be unique");
}
let new_max_level = calculate_max_level(self.nodes.len() + 1, self.m);
if new_max_level > self.max_level {
self.max_level = new_max_level;
self.level_entrypoints.resize(new_max_level + 1, id);
}
let level = self.pick_node_level();
let mut node = Node::new(id, vector, level);
// add the new node to the world
self.nodes.insert(node.id(), node.clone());
// Start from the top-level entry point
let mut current_node_id = self.level_entrypoints[self.max_level];
for lvl in (level + 1..=self.max_level).rev() {
let current_node = self.nodes.get(¤t_node_id).unwrap();
let candidates = self.greedy_search(node.value(), current_node, lvl);
// Pick the closest candidate as the new entry point for the next level down
if let Some(closest_id) = candidates.iter().min_by(|&id_a, &id_b| {
let dist_a = self
.nodes
.get(id_a)
.unwrap()
.distance(&node.value(), &self.distance_metric);
let dist_b = self
.nodes
.get(id_b)
.unwrap()
.distance(&node.value(), &self.distance_metric);
dist_a.partial_cmp(&dist_b).unwrap()
}) {
current_node_id = *closest_id;
}
}
// Now we are at the correct insertion level (node_level), perform a local search here
let insertion_node = self.nodes.get(¤t_node_id).unwrap();
let nearest_neighbours = self.greedy_search(node.value(), insertion_node, level);
// Connect new node with found neighbors and prune
for &nbr_id in &nearest_neighbours {
// Connect `node` and neighbor
node.connect(self.nodes.get_mut(&nbr_id).unwrap(), level);
}
// prune the node if it has more than M connections
if node.connections(level).len() > self.m {
self.prune_node_connections(node.id(), level);
}
println!("Inserted node {} at level {}", id, level);
println!("My neighbours are {:?}", node.connections(level));
Ok(())
}
/// prune_node_connections prunes the connections of a node at a given level by getting rid of the furthest connections
fn prune_node_connections(&mut self, node_id: u32, level: usize) {
let mut distances: Vec<(u32, f32)> = {
let node = self.nodes.get(&node_id).unwrap();
node.connections(level)
.iter()
.map(|&neighbour_id| {
let distance = node.distance(
self.nodes.get(&neighbour_id).unwrap().value(),
&self.distance_metric,
);
(neighbour_id, distance)
})
.collect()
};
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let connections_to_remove = distances
.iter()
.skip(self.m)
.map(|(id, _)| *id)
.collect::<Vec<_>>();
// Remove connections from the current node
let node = self.nodes.get_mut(&node_id).unwrap();
node.remove_connections(&connections_to_remove, level);
// Remove the corresponding back-connections from neighbor nodes
for &neighbor_id in &connections_to_remove {
if let Some(neighbor) = self.nodes.get_mut(&neighbor_id) {
neighbor.remove_connections(&vec![node_id], level);
}
}
}
/// search gets the k nearest neighbours to the query vector using beam search
pub fn search(&self, query: &Vector, k: usize, beam_width: usize) -> Result<Vec<u32>> {
if k > self.ef_search {
bail!(
"k is greater than the maximum number of connections to explore for a node during search"
);
}
let candidates = self.beam_search(query, beam_width);
let mut results: Vec<(u32, f32)> = candidates
.into_iter()
.map(|id| {
let node = self.nodes.get(&id).unwrap();
let distance = node.distance(query, &self.distance_metric);
(id, distance)
})
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// Return top k results
Ok(results.into_iter().take(k).map(|(id, _)| id).collect())
}
/// beam_search performs a beam search for the k nearest neighbours to the query vector
fn beam_search(&self, query: &Vector, beam_width: usize) -> Vec<u32> {
let mut candidates: BinaryHeap<(OrderedFloat<f32>, u32)> = BinaryHeap::new();
let entrypoint_node = self.get_entrypoint_node();
let initial_distance = entrypoint_node.distance(query, &self.distance_metric);
candidates.push((OrderedFloat(initial_distance), entrypoint_node.id()));
let mut visited = HashSet::new();
let mut final_candidates = Vec::new();
for level in (0..=self.max_level).rev() {
let mut next_candidates = Vec::new();
// Store current candidates before processing
let current_candidates: Vec<_> = candidates.drain().collect();
for (dist, candidate_id) in current_candidates {
if visited.contains(&candidate_id) {
continue;
}
visited.insert(candidate_id);
let candidate = self.nodes.get(&candidate_id).unwrap();
let local_best = self.greedy_search(&query, candidate, level);
// Add the current candidate to final candidates
final_candidates.push((dist, candidate_id));
for &id in &local_best {
if !visited.contains(&id) {
let dist = self
.nodes
.get(&id)
.unwrap()
.distance(query, &self.distance_metric);
next_candidates.push((OrderedFloat(dist), id));
}
}
}
// Combine current and next candidates
candidates = next_candidates.into_iter().collect();
}
// Return the best candidates we've found
final_candidates.sort_by_key(|(dist, _)| *dist);
final_candidates
.into_iter()
.take(beam_width)
.map(|(_, id)| id)
.collect()
}
/// dump serializes the world to binary data so the user can save it for later use without abstraction
pub fn dump(&self) -> Result<Vec<u8>> {
bincode::serialize(&self).map_err(|e| anyhow::anyhow!("Failed to serialize world: {}", e))
}
}
/// calculate_max_level returns the level cap for an HNSW graph holding `n` nodes with at most
/// `m` connections per node: `ceil(ln(n) / ln(m))`, never less than 1.
///
/// With promotion probability p = 1/m, about n·(1/m)^l nodes reach level l, so the highest level
/// expected to hold a node is roughly log_m(n).
/// For `m < 2` the ratio is undefined (ln(1) = 0 would yield infinity), so the floor of 1 is
/// returned instead.
fn calculate_max_level(n: usize, m: usize) -> usize {
    if m < 2 {
        return 1;
    }
    let levels = (n as f64).ln() / (m as f64).ln();
    // For n == 0 the ratio is -inf; the saturating cast maps it to 0 and the floor lifts it to 1.
    (levels.ceil() as usize).max(1)
}
#[cfg(test)]
mod tests {
    use crate::distance_metric::CosineDistance;
    use super::*;

    /// Builds the world configuration shared by these tests (M = 5, ef = 10, cosine).
    fn cosine_world() -> Result<World> {
        World::new(5, 10, 10, DistanceMetric::Cosine(CosineDistance))
    }

    #[test]
    fn test_world_insert_and_search() -> Result<()> {
        let mut world = cosine_world()?;
        let test_vectors = vec![
            (1, Vector::new_f32(&[1.0, 0.0, 0.0])),
            (2, Vector::new_f32(&[0.0, 1.0, 0.0])),
            (3, Vector::new_f32(&[0.0, 0.0, 1.0])),
            (4, Vector::new_f32(&[0.7, 0.7, 0.0])),
        ];
        for (id, vector) in test_vectors {
            world.insert_vector(id, vector)?;
        }
        let query = Vector::new_f32(&[0.8, 0.8, 0.0]);
        let results = world.search(&query, 2, 5)?;
        assert!(!results.is_empty(), "Should find at least 1 result");
        assert!(results.len() <= 2, "Should return at most k results");
        Ok(())
    }

    #[test]
    fn test_world_rejects_duplicate_ids() -> Result<()> {
        let mut world = cosine_world()?;
        world.insert_vector(1, Vector::new_f32(&[1.0, 0.0, 0.0]))?;
        assert!(world
            .insert_vector(1, Vector::new_f32(&[0.0, 1.0, 0.0]))
            .is_err());
        Ok(())
    }

    #[test]
    fn test_world_rejects_ef_construction_below_m() {
        assert!(World::new(5, 4, 10, DistanceMetric::Cosine(CosineDistance)).is_err());
    }

    #[test]
    fn test_search_rejects_k_above_ef_search() -> Result<()> {
        let mut world = cosine_world()?;
        world.insert_vector(1, Vector::new_f32(&[1.0, 0.0, 0.0]))?;
        let query = Vector::new_f32(&[1.0, 0.0, 0.0]);
        assert!(world.search(&query, 11, 5).is_err());
        Ok(())
    }

    #[test]
    fn test_world_dump_and_load() -> Result<()> {
        // make world, dump it, hash it, load it, hash it, assert equal
        // (an empty world keeps the dump independent of HashMap iteration order)
        let world = cosine_world()?;
        let dump = world.dump()?;
        let original_hash = blake3::hash(&dump);
        let loaded_world = World::new_from_dump(&dump)?;
        let loaded_hash = blake3::hash(&loaded_world.dump()?);
        assert_eq!(original_hash, loaded_hash);
        Ok(())
    }
}