1use ahash::{AHashMap, AHashSet};
19use ordered_float::OrderedFloat;
20use rand::Rng;
21use std::cmp::Reverse;
22use std::collections::BinaryHeap;
23
24use crate::error::{Error, Result};
25use super::traits::{DistanceType, IndexConfig, SearchResult, VectorIndex};
26
27#[derive(Debug, Clone)]
29pub struct HNSWConfig {
30 pub base: IndexConfig,
32 pub m: usize,
34 pub m_max0: usize,
36 pub ef_construction: usize,
38 pub ef_search: usize,
40 pub ml: f64,
42}
43
44impl HNSWConfig {
45 #[must_use]
47 pub fn new(dimension: usize) -> Self {
48 let m = 16;
49 Self {
50 base: IndexConfig::new(dimension),
51 m,
52 m_max0: 2 * m,
53 ef_construction: 200,
54 ef_search: 128,
55 ml: 1.0 / (m as f64).ln(),
56 }
57 }
58
59 #[must_use]
61 pub fn with_m(mut self, m: usize) -> Self {
62 self.m = m;
63 self.m_max0 = 2 * m;
64 self.ml = 1.0 / (m as f64).ln();
65 self
66 }
67
68 #[must_use]
70 pub const fn with_ef_construction(mut self, ef: usize) -> Self {
71 self.ef_construction = ef;
72 self
73 }
74
75 #[must_use]
77 pub const fn with_ef_search(mut self, ef: usize) -> Self {
78 self.ef_search = ef;
79 self
80 }
81
82 #[must_use]
84 pub fn with_distance(mut self, distance_type: DistanceType) -> Self {
85 self.base.distance_type = distance_type;
86 self
87 }
88}
89
90#[derive(Debug, Clone)]
92struct HNSWNode {
93 id: String,
95 vector: Vec<f32>,
97 #[allow(dead_code)]
99 level: usize,
100 neighbors: Vec<AHashSet<usize>>,
102}
103
104#[derive(Debug)]
108pub struct HNSWIndex {
109 config: HNSWConfig,
111 nodes: Vec<HNSWNode>,
113 id_to_idx: AHashMap<String, usize>,
115 entry_point: Option<usize>,
117 max_level: usize,
119 rng: parking_lot::Mutex<rand::rngs::SmallRng>,
121}
122
123impl HNSWIndex {
124 #[must_use]
126 pub fn new(config: HNSWConfig) -> Self {
127 use rand::SeedableRng;
128 Self {
129 config,
130 nodes: Vec::new(),
131 id_to_idx: AHashMap::new(),
132 entry_point: None,
133 max_level: 0,
134 rng: parking_lot::Mutex::new(rand::rngs::SmallRng::from_entropy()),
135 }
136 }
137
138 fn random_level(&self) -> usize {
140 let mut rng = self.rng.lock();
141 let mut level = 0;
142 while rng.gen::<f64>() < self.config.ml && level < 16 {
143 level += 1;
144 }
145 level
146 }
147
148 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
150 match self.config.base.distance_type {
151 DistanceType::L2 => {
152 a.iter()
153 .zip(b.iter())
154 .map(|(x, y)| (x - y).powi(2))
155 .sum::<f32>()
156 .sqrt()
157 }
158 DistanceType::InnerProduct => {
159 -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
161 }
162 DistanceType::Cosine => {
163 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
164 let norm_a: f32 = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
165 let norm_b: f32 = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
166 if norm_a == 0.0 || norm_b == 0.0 {
167 1.0
168 } else {
169 1.0 - (dot / (norm_a * norm_b))
170 }
171 }
172 }
173 }
174
175 fn search_layer(
177 &self,
178 query: &[f32],
179 entry_points: Vec<usize>,
180 ef: usize,
181 level: usize,
182 ) -> Vec<(f32, usize)> {
183 let mut visited: AHashSet<usize> = entry_points.iter().copied().collect();
184
185 let mut candidates: BinaryHeap<Reverse<(OrderedFloat<f32>, usize)>> = BinaryHeap::new();
187
188 let mut results: BinaryHeap<(OrderedFloat<f32>, usize)> = BinaryHeap::new();
190
191 for &ep in &entry_points {
193 let dist = self.distance(query, &self.nodes[ep].vector);
194 candidates.push(Reverse((OrderedFloat(dist), ep)));
195 results.push((OrderedFloat(dist), ep));
196 }
197
198 while let Some(Reverse((OrderedFloat(c_dist), c_idx))) = candidates.pop() {
199 let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f32::INFINITY);
201
202 if c_dist > f_dist && results.len() >= ef {
203 break;
204 }
205
206 if level < self.nodes[c_idx].neighbors.len() {
208 for &neighbor_idx in &self.nodes[c_idx].neighbors[level] {
209 if visited.insert(neighbor_idx) {
210 let dist = self.distance(query, &self.nodes[neighbor_idx].vector);
211 let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f32::INFINITY);
212
213 if dist < f_dist || results.len() < ef {
214 candidates.push(Reverse((OrderedFloat(dist), neighbor_idx)));
215 results.push((OrderedFloat(dist), neighbor_idx));
216
217 if results.len() > ef {
218 results.pop();
219 }
220 }
221 }
222 }
223 }
224 }
225
226 let mut result_vec: Vec<_> = results.into_iter().map(|(d, idx)| (d.0, idx)).collect();
228 result_vec.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
229 result_vec
230 }
231
232 fn select_neighbors(&self, candidates: &[(f32, usize)], m: usize) -> Vec<usize> {
234 candidates.iter().take(m).map(|(_, idx)| *idx).collect()
235 }
236
237 fn get_max_connections(&self, level: usize) -> usize {
239 if level == 0 {
240 self.config.m_max0
241 } else {
242 self.config.m
243 }
244 }
245}
246
247impl VectorIndex for HNSWIndex {
248 fn add(&mut self, id: String, vector: &[f32]) -> Result<()> {
249 if vector.len() != self.config.base.dimension {
250 return Err(Error::InvalidQuery {
251 reason: format!(
252 "Dimension mismatch: expected {}, got {}",
253 self.config.base.dimension,
254 vector.len()
255 ),
256 });
257 }
258
259 if self.id_to_idx.contains_key(&id) {
260 return Err(Error::DuplicateRecord { record_id: id });
261 }
262
263 let level = self.random_level();
264 let new_idx = self.nodes.len();
265
266 let mut node = HNSWNode {
268 id: id.clone(),
269 vector: vector.to_vec(),
270 level,
271 neighbors: vec![AHashSet::new(); level + 1],
272 };
273
274 if self.entry_point.is_none() {
276 self.nodes.push(node);
277 self.id_to_idx.insert(id, new_idx);
278 self.entry_point = Some(new_idx);
279 self.max_level = level;
280 return Ok(());
281 }
282
283 let entry_point = self.entry_point.unwrap();
284 let mut curr_ep = vec![entry_point];
285
286 for lc in (level + 1..=self.max_level).rev() {
288 let nearest = self.search_layer(vector, curr_ep.clone(), 1, lc);
289 if !nearest.is_empty() {
290 curr_ep = vec![nearest[0].1];
291 }
292 }
293
294 for lc in (0..=level.min(self.max_level)).rev() {
296 let candidates = self.search_layer(
297 vector,
298 curr_ep.clone(),
299 self.config.ef_construction,
300 lc,
301 );
302
303 let m = self.get_max_connections(lc);
304 let neighbors = self.select_neighbors(&candidates, m);
305
306 node.neighbors[lc] = neighbors.iter().copied().collect();
308
309 for &neighbor_idx in &neighbors {
310 if lc < self.nodes[neighbor_idx].neighbors.len() {
311 self.nodes[neighbor_idx].neighbors[lc].insert(new_idx);
312
313 if self.nodes[neighbor_idx].neighbors[lc].len() > m {
315 let neighbor_vec = &self.nodes[neighbor_idx].vector;
316 let new_node_vec = vector;
318 let mut scored: Vec<_> = self.nodes[neighbor_idx].neighbors[lc]
319 .iter()
320 .map(|&idx| {
321 let dist = if idx == new_idx {
322 self.distance(neighbor_vec, new_node_vec)
323 } else {
324 self.distance(neighbor_vec, &self.nodes[idx].vector)
325 };
326 (dist, idx)
327 })
328 .collect();
329 scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
330 self.nodes[neighbor_idx].neighbors[lc] =
331 scored.into_iter().take(m).map(|(_, idx)| idx).collect();
332 }
333 }
334 }
335
336 if !candidates.is_empty() {
337 curr_ep = vec![candidates[0].1];
338 }
339 }
340
341 self.nodes.push(node);
342 self.id_to_idx.insert(id, new_idx);
343
344 if level > self.max_level {
346 self.entry_point = Some(new_idx);
347 self.max_level = level;
348 }
349
350 Ok(())
351 }
352
353 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
354 if query.len() != self.config.base.dimension {
355 return Err(Error::InvalidQuery {
356 reason: format!(
357 "Query dimension mismatch: expected {}, got {}",
358 self.config.base.dimension,
359 query.len()
360 ),
361 });
362 }
363
364 if self.nodes.is_empty() {
365 return Ok(vec![]);
366 }
367
368 let entry_point = self.entry_point.unwrap();
369 let mut curr_ep = vec![entry_point];
370
371 for lc in (1..=self.max_level).rev() {
373 let nearest = self.search_layer(query, curr_ep.clone(), 1, lc);
374 if !nearest.is_empty() {
375 curr_ep = vec![nearest[0].1];
376 }
377 }
378
379 let results = self.search_layer(query, curr_ep, self.config.ef_search, 0);
381
382 let k = k.min(results.len());
384 Ok(results
385 .into_iter()
386 .take(k)
387 .map(|(dist, idx)| {
388 let actual_dist = match self.config.base.distance_type {
389 DistanceType::InnerProduct => -dist,
390 DistanceType::Cosine => 1.0 - dist,
391 DistanceType::L2 => dist,
392 };
393 SearchResult::new(
394 self.nodes[idx].id.clone(),
395 actual_dist,
396 self.config.base.distance_type,
397 )
398 })
399 .collect())
400 }
401
402 fn remove(&mut self, id: &str) -> Result<bool> {
403 if let Some(&idx) = self.id_to_idx.get(id) {
406 for node in &mut self.nodes {
408 for neighbors in &mut node.neighbors {
409 neighbors.remove(&idx);
410 }
411 }
412 self.id_to_idx.remove(id);
413 self.nodes[idx].id = String::new();
415 self.nodes[idx].vector.clear();
416 Ok(true)
417 } else {
418 Ok(false)
419 }
420 }
421
422 fn contains(&self, id: &str) -> bool {
423 self.id_to_idx.contains_key(id)
424 }
425
426 fn len(&self) -> usize {
427 self.id_to_idx.len()
428 }
429
430 fn dimension(&self) -> usize {
431 self.config.base.dimension
432 }
433
434 fn distance_type(&self) -> DistanceType {
435 self.config.base.distance_type
436 }
437
438 fn clear(&mut self) {
439 self.nodes.clear();
440 self.id_to_idx.clear();
441 self.entry_point = None;
442 self.max_level = 0;
443 }
444
445 fn memory_usage(&self) -> usize {
446 let node_size = self.config.base.dimension * 4 + 64; let neighbor_size = self.config.m * 8 * 2; self.nodes.len() * (node_size + neighbor_size)
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 4-dimensional index holding the five fixture vectors
    /// `a` through `e`.
    fn create_test_index() -> HNSWIndex {
        let mut index = HNSWIndex::new(
            HNSWConfig::new(4)
                .with_m(4)
                .with_ef_construction(16)
                .with_ef_search(16),
        );

        let fixtures: [(&str, [f32; 4]); 5] = [
            ("a", [1.0, 0.0, 0.0, 0.0]),
            ("b", [0.0, 1.0, 0.0, 0.0]),
            ("c", [0.0, 0.0, 1.0, 0.0]),
            ("d", [0.5, 0.5, 0.0, 0.0]),
            ("e", [0.9, 0.1, 0.0, 0.0]),
        ];
        for (id, vector) in &fixtures {
            index.add((*id).to_string(), vector).unwrap();
        }

        index
    }

    /// The best match for a query equal to `a` is `a` itself or its close
    /// neighbor `e`.
    #[test]
    fn test_add_and_search() {
        let hits = create_test_index()
            .search(&[1.0, 0.0, 0.0, 0.0], 3)
            .unwrap();

        assert!(!hits.is_empty());
        assert!(["a", "e"].iter().any(|&expected| hits[0].id == expected));
    }

    /// A populated index returns exactly the requested number of results.
    #[test]
    fn test_recall() {
        let mut index = HNSWIndex::new(HNSWConfig::new(8).with_m(8).with_ef_search(32));

        for i in 0..100 {
            // Deterministic pseudo-pattern with components in [0, 1).
            let components: Vec<f32> = (0..8)
                .map(|j| ((i * j) % 100) as f32 / 100.0)
                .collect();
            index.add(format!("v{}", i), &components).unwrap();
        }

        let hits = index.search(&[0.5; 8], 10).unwrap();
        assert_eq!(hits.len(), 10);
    }

    /// Inserting an identifier that already exists is rejected.
    #[test]
    fn test_duplicate_id() {
        let mut index = create_test_index();

        assert!(index
            .add("a".to_string(), &[0.0, 0.0, 0.0, 1.0])
            .is_err());
    }
}