Skip to main content

reddb_server/storage/engine/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) Index
2//!
3//! A from-scratch implementation of the HNSW algorithm for approximate
4//! nearest neighbor search. No external dependencies.
5//!
6//! # Algorithm Overview
7//!
8//! HNSW builds a multi-layer graph where:
9//! - Layer 0 contains all nodes
10//! - Higher layers contain progressively fewer nodes
11//! - Each layer is a navigable small world graph
12//!
13//! Search starts from the top layer and greedily descends,
14//! using each layer to quickly approach the target region.
15//!
16//! # References
17//!
18//! - Original paper: "Efficient and robust approximate nearest neighbor search
19//!   using Hierarchical Navigable Small World graphs" (Malkov & Yashunin, 2018)
20
21use std::cmp::Reverse;
22use std::collections::{BinaryHeap, HashMap, HashSet};
23use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
24
25use super::distance::{
26    cmp_f32, distance_simd, DistanceMetric, DistanceResult, ReverseDistanceResult,
27};
28
29/// Node identifier in the HNSW graph
30pub type NodeId = u64;
31
32/// HNSW index configuration parameters
33#[derive(Debug, Clone)]
34pub struct HnswConfig {
35    /// Maximum number of connections per node (except layer 0)
36    pub m: usize,
37    /// Maximum connections at layer 0 (typically 2*M)
38    pub m_max0: usize,
39    /// Size of dynamic candidate list during construction
40    pub ef_construction: usize,
41    /// Size of dynamic candidate list during search (can be adjusted)
42    pub ef_search: usize,
43    /// Normalization factor for layer assignment (1/ln(M))
44    pub ml: f64,
45    /// Distance metric to use
46    pub metric: DistanceMetric,
47}
48
49impl Default for HnswConfig {
50    fn default() -> Self {
51        let m = 16;
52        Self {
53            m,
54            m_max0: m * 2,
55            ef_construction: 100,
56            ef_search: 50,
57            ml: 1.0 / (m as f64).ln(),
58            metric: DistanceMetric::L2,
59        }
60    }
61}
62
63impl HnswConfig {
64    /// Create a new configuration with custom M value
65    pub fn with_m(m: usize) -> Self {
66        Self {
67            m,
68            m_max0: m * 2,
69            ml: 1.0 / (m as f64).ln(),
70            ..Default::default()
71        }
72    }
73
74    /// Set the distance metric
75    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
76        self.metric = metric;
77        self
78    }
79
80    /// Set ef_construction
81    pub fn with_ef_construction(mut self, ef: usize) -> Self {
82        self.ef_construction = ef;
83        self
84    }
85
86    /// Set ef_search
87    pub fn with_ef_search(mut self, ef: usize) -> Self {
88        self.ef_search = ef;
89        self
90    }
91}
92
93/// A node in the HNSW graph
94#[derive(Debug, Clone)]
95struct HnswNode {
96    /// The node's ID
97    id: NodeId,
98    /// The node's vector data
99    vector: Vec<f32>,
100    /// Maximum layer this node appears in
101    max_layer: usize,
102    /// Connections at each layer (layer index -> neighbor IDs)
103    connections: Vec<Vec<NodeId>>,
104}
105
106impl HnswNode {
107    fn new(id: NodeId, vector: Vec<f32>, max_layer: usize) -> Self {
108        let mut connections = Vec::with_capacity(max_layer + 1);
109        for _ in 0..=max_layer {
110            connections.push(Vec::new());
111        }
112        Self {
113            id,
114            vector,
115            max_layer,
116            connections,
117        }
118    }
119}
120
121/// HNSW Index for approximate nearest neighbor search
122pub struct HnswIndex {
123    /// Configuration parameters
124    config: HnswConfig,
125    /// All nodes in the index
126    nodes: HashMap<NodeId, HnswNode>,
127    /// Entry point (node with highest layer)
128    entry_point: Option<NodeId>,
129    /// Maximum layer in the graph
130    max_layer: usize,
131    /// Vector dimension
132    dimension: usize,
133    /// Next available node ID
134    next_id: AtomicU64,
135    /// Simple RNG state for layer assignment
136    rng_state: u64,
137}
138
139impl HnswIndex {
140    /// Create a new empty HNSW index
141    pub fn new(dimension: usize, config: HnswConfig) -> Self {
142        Self {
143            config,
144            nodes: HashMap::new(),
145            entry_point: None,
146            max_layer: 0,
147            dimension,
148            next_id: AtomicU64::new(0),
149            rng_state: 0x853c49e6748fea9b, // Random seed
150        }
151    }
152
153    /// Create with default configuration
154    pub fn with_dimension(dimension: usize) -> Self {
155        Self::new(dimension, HnswConfig::default())
156    }
157
158    /// Get the number of vectors in the index
159    pub fn len(&self) -> usize {
160        self.nodes.len()
161    }
162
163    /// Check if the index is empty
164    pub fn is_empty(&self) -> bool {
165        self.nodes.is_empty()
166    }
167
168    /// Get the vector for a node ID
169    pub fn get_vector(&self, id: NodeId) -> Option<&[f32]> {
170        self.nodes.get(&id).map(|n| n.vector.as_slice())
171    }
172
173    /// Insert a vector and return its assigned ID
174    pub fn insert(&mut self, vector: Vec<f32>) -> NodeId {
175        let id = self.next_id.fetch_add(1, AtomicOrdering::SeqCst);
176        self.insert_with_id(id, vector);
177        id
178    }
179
180    /// Insert a vector with a specific ID
181    pub fn insert_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
182        assert_eq!(
183            vector.len(),
184            self.dimension,
185            "Vector dimension mismatch: expected {}, got {}",
186            self.dimension,
187            vector.len()
188        );
189
190        // Assign random layer using exponential distribution
191        let node_layer = self.random_layer();
192
193        // Create the node
194        let node = HnswNode::new(id, vector, node_layer);
195
196        if self.entry_point.is_none() {
197            // First node - just add it
198            self.nodes.insert(id, node);
199            self.entry_point = Some(id);
200            self.max_layer = node_layer;
201            return;
202        }
203
204        let entry_point = self.entry_point.unwrap();
205        let vector = self.nodes.get(&id).map(|n| n.vector.clone());
206
207        // We need to insert the node first so we can access its vector
208        // But we need to find neighbors first... let's clone the vector
209        let vector = node.vector.clone();
210        self.nodes.insert(id, node);
211
212        // Find entry point for search
213        let mut current = entry_point;
214
215        // Traverse from top layer down to node_layer + 1
216        // This finds the best entry point for the insertion layer
217        for layer in (node_layer + 1..=self.max_layer).rev() {
218            current = self.search_layer_single(&vector, current, layer);
219        }
220
221        // For layers from node_layer down to 0, find and connect to neighbors
222        for layer in (0..=node_layer.min(self.max_layer)).rev() {
223            // Find ef_construction nearest neighbors at this layer
224            let neighbors = self.search_layer(&vector, current, self.config.ef_construction, layer);
225
226            // Select M best neighbors
227            let m = if layer == 0 {
228                self.config.m_max0
229            } else {
230                self.config.m
231            };
232            let selected: Vec<NodeId> = neighbors.into_iter().take(m).map(|r| r.id).collect();
233
234            // Connect node to selected neighbors
235            if let Some(node) = self.nodes.get_mut(&id) {
236                node.connections[layer] = selected.clone();
237            }
238
239            // Add bidirectional connections
240            for &neighbor_id in &selected {
241                self.add_connection(neighbor_id, id, layer);
242            }
243
244            // Update current for next layer
245            if let Some(&first) = selected.first() {
246                current = first;
247            }
248        }
249
250        // Update entry point if new node has higher layer
251        if node_layer > self.max_layer {
252            self.entry_point = Some(id);
253            self.max_layer = node_layer;
254        }
255    }
256
257    /// Search for k nearest neighbors
258    pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
259        self.search_with_ef(query, k, self.config.ef_search)
260    }
261
262    /// Search with custom ef parameter
263    pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<DistanceResult> {
264        if self.entry_point.is_none() {
265            return Vec::new();
266        }
267
268        let entry_point = self.entry_point.unwrap();
269        let mut current = entry_point;
270
271        // Traverse from top layer down to layer 1
272        for layer in (1..=self.max_layer).rev() {
273            current = self.search_layer_single(query, current, layer);
274        }
275
276        // Search layer 0 with ef candidates
277        let candidates = self.search_layer(query, current, ef.max(k), 0);
278
279        // Return top k
280        candidates.into_iter().take(k).collect()
281    }
282
283    /// Search with a filter (bitset of allowed IDs)
284    pub fn search_filtered(
285        &self,
286        query: &[f32],
287        k: usize,
288        filter: &HashSet<NodeId>,
289    ) -> Vec<DistanceResult> {
290        self.search_filtered_with_ef(query, k, filter, self.config.ef_search)
291    }
292
293    /// Search with filter and custom ef
294    pub fn search_filtered_with_ef(
295        &self,
296        query: &[f32],
297        k: usize,
298        filter: &HashSet<NodeId>,
299        ef: usize,
300    ) -> Vec<DistanceResult> {
301        if self.entry_point.is_none() || filter.is_empty() {
302            return Vec::new();
303        }
304
305        let entry_point = self.entry_point.unwrap();
306        let mut current = entry_point;
307
308        // Traverse from top layer down to layer 1
309        for layer in (1..=self.max_layer).rev() {
310            current = self.search_layer_single(query, current, layer);
311        }
312
313        // Search layer 0 with filter
314        // We need to search with higher ef to account for filtered results
315        let expanded_ef = (ef * 2).max(k * 4);
316        let candidates = self.search_layer(query, current, expanded_ef, 0);
317
318        // Filter and return top k
319        candidates
320            .into_iter()
321            .filter(|r| filter.contains(&r.id))
322            .take(k)
323            .collect()
324    }
325
326    // =========================================================================
327    // Private methods
328    // =========================================================================
329
330    /// Generate a random layer using exponential distribution
331    fn random_layer(&mut self) -> usize {
332        // Simple xorshift64 PRNG
333        self.rng_state ^= self.rng_state << 13;
334        self.rng_state ^= self.rng_state >> 7;
335        self.rng_state ^= self.rng_state << 17;
336
337        // Convert to uniform [0, 1)
338        let uniform = (self.rng_state as f64) / (u64::MAX as f64);
339
340        // Exponential distribution: -ln(uniform) * ml
341
342        (-uniform.ln() * self.config.ml).floor() as usize
343    }
344
345    /// Search a single layer for the closest node (greedy)
346    fn search_layer_single(&self, query: &[f32], entry: NodeId, layer: usize) -> NodeId {
347        let mut current = entry;
348        let mut current_dist = self.compute_distance(query, current);
349
350        loop {
351            let mut changed = false;
352
353            if let Some(node) = self.nodes.get(&current) {
354                if layer < node.connections.len() {
355                    for &neighbor in &node.connections[layer] {
356                        let dist = self.compute_distance(query, neighbor);
357                        if dist < current_dist {
358                            current_dist = dist;
359                            current = neighbor;
360                            changed = true;
361                        }
362                    }
363                }
364            }
365
366            if !changed {
367                break;
368            }
369        }
370
371        current
372    }
373
374    /// Search a layer for ef nearest neighbors
375    fn search_layer(
376        &self,
377        query: &[f32],
378        entry: NodeId,
379        ef: usize,
380        layer: usize,
381    ) -> Vec<DistanceResult> {
382        let entry_dist = self.compute_distance(query, entry);
383
384        // Candidates: min-heap of nodes to explore (closest first)
385        let mut candidates: BinaryHeap<Reverse<DistanceResult>> = BinaryHeap::new();
386        candidates.push(Reverse(DistanceResult::new(entry, entry_dist)));
387
388        // Results: max-heap of found neighbors (furthest first for pruning)
389        let mut results: BinaryHeap<ReverseDistanceResult> = BinaryHeap::new();
390        results.push(ReverseDistanceResult(DistanceResult::new(
391            entry, entry_dist,
392        )));
393
394        // Visited set
395        let mut visited: HashSet<NodeId> = HashSet::new();
396        visited.insert(entry);
397
398        while let Some(Reverse(current)) = candidates.pop() {
399            // Get furthest result distance for pruning
400            let furthest_dist = results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
401
402            // If current is further than our furthest result, we're done
403            if current.distance > furthest_dist {
404                break;
405            }
406
407            // Explore neighbors
408            if let Some(node) = self.nodes.get(&current.id) {
409                if layer < node.connections.len() {
410                    for &neighbor_id in &node.connections[layer] {
411                        if visited.contains(&neighbor_id) {
412                            continue;
413                        }
414                        visited.insert(neighbor_id);
415
416                        let dist = self.compute_distance(query, neighbor_id);
417                        let furthest_dist =
418                            results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
419
420                        // If this neighbor is closer than our furthest result, or we need more results
421                        if dist < furthest_dist || results.len() < ef {
422                            candidates.push(Reverse(DistanceResult::new(neighbor_id, dist)));
423                            results.push(ReverseDistanceResult(DistanceResult::new(
424                                neighbor_id,
425                                dist,
426                            )));
427
428                            // Prune results to ef size
429                            while results.len() > ef {
430                                results.pop();
431                            }
432                        }
433                    }
434                }
435            }
436        }
437
438        // Convert results to sorted vector (closest first)
439        let mut result_vec: Vec<DistanceResult> = results.into_iter().map(|r| r.0).collect();
440        result_vec.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
441        result_vec
442    }
443
444    /// Add a bidirectional connection, pruning if necessary
445    fn add_connection(&mut self, from: NodeId, to: NodeId, layer: usize) {
446        let max_connections = if layer == 0 {
447            self.config.m_max0
448        } else {
449            self.config.m
450        };
451
452        if let Some(node) = self.nodes.get_mut(&from) {
453            // Ensure we have enough layers
454            while node.connections.len() <= layer {
455                node.connections.push(Vec::new());
456            }
457
458            // Add connection if not already present
459            if !node.connections[layer].contains(&to) {
460                node.connections[layer].push(to);
461
462                // Prune if too many connections
463                if node.connections[layer].len() > max_connections {
464                    self.prune_connections(from, layer, max_connections);
465                }
466            }
467        }
468    }
469
470    /// Prune connections to max_connections using simple heuristic
471    fn prune_connections(&mut self, node_id: NodeId, layer: usize, max_connections: usize) {
472        let node_vector = self.nodes.get(&node_id).map(|n| n.vector.clone());
473        let node_vector = match node_vector {
474            Some(v) => v,
475            None => return,
476        };
477
478        // Get all neighbors with distances
479        let neighbors: Vec<NodeId> = self
480            .nodes
481            .get(&node_id)
482            .map(|n| n.connections[layer].clone())
483            .unwrap_or_default();
484
485        let mut scored: Vec<DistanceResult> = neighbors
486            .iter()
487            .map(|&neighbor_id| {
488                let dist = self.compute_distance(&node_vector, neighbor_id);
489                DistanceResult::new(neighbor_id, dist)
490            })
491            .collect();
492
493        // Sort by distance (closest first)
494        scored.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
495
496        // Keep only max_connections closest
497        let kept: Vec<NodeId> = scored
498            .into_iter()
499            .take(max_connections)
500            .map(|r| r.id)
501            .collect();
502
503        if let Some(node) = self.nodes.get_mut(&node_id) {
504            node.connections[layer] = kept;
505        }
506    }
507
508    /// Compute distance between query and a node (SIMD-accelerated)
509    fn compute_distance(&self, query: &[f32], node_id: NodeId) -> f32 {
510        match self.nodes.get(&node_id) {
511            Some(node) => distance_simd(query, &node.vector, self.config.metric),
512            None => f32::MAX,
513        }
514    }
515
516    // ========================================================================
517    // Batch Operations
518    // ========================================================================
519
520    /// Insert multiple vectors at once
521    ///
522    /// More efficient than individual inserts for large batches
523    pub fn insert_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
524        vectors.into_iter().map(|v| self.insert(v)).collect()
525    }
526
527    /// Insert multiple vectors with specific IDs
528    pub fn insert_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
529        for (id, vector) in items {
530            self.insert_with_id(id, vector);
531        }
532    }
533
534    // ========================================================================
535    // Delete Operations
536    // ========================================================================
537
538    /// Remove a node from the index
539    ///
540    /// Note: This is a "soft" delete - the node is marked as deleted and
541    /// excluded from search results, but connections are not fully repaired.
542    /// For better performance after many deletes, rebuild the index.
543    pub fn delete(&mut self, id: NodeId) -> bool {
544        if self.nodes.remove(&id).is_none() {
545            return false;
546        }
547
548        // Update entry point if we deleted it
549        if self.entry_point == Some(id) {
550            self.entry_point = self.nodes.keys().next().copied();
551
552            // Update max_layer based on new entry point
553            if let Some(ep) = self.entry_point {
554                self.max_layer = self.nodes.get(&ep).map(|n| n.max_layer).unwrap_or(0);
555            } else {
556                self.max_layer = 0;
557            }
558        }
559
560        // Remove references to this node from other nodes' connections
561        for node in self.nodes.values_mut() {
562            for layer_connections in node.connections.iter_mut() {
563                layer_connections.retain(|&neighbor| neighbor != id);
564            }
565        }
566
567        true
568    }
569
570    /// Check if a node exists
571    pub fn contains(&self, id: NodeId) -> bool {
572        self.nodes.contains_key(&id)
573    }
574
575    // ========================================================================
576    // Adaptive Search
577    // ========================================================================
578
579    /// Search with adaptive ef based on index size
580    ///
581    /// Automatically adjusts ef_search based on the number of results needed
582    /// and the size of the index for better accuracy/speed tradeoff.
583    pub fn search_adaptive(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
584        // Adaptive ef: at least k, scales with log of index size
585        let n = self.nodes.len();
586        let adaptive_ef = if n < 100 {
587            k.max(10)
588        } else if n < 10000 {
589            k.max(50)
590        } else if n < 100000 {
591            k.max(100)
592        } else {
593            k.max(200)
594        };
595
596        self.search_with_ef(query, k, adaptive_ef)
597    }
598
599    // ========================================================================
600    // Index Statistics
601    // ========================================================================
602
603    /// Get statistics about the index
604    pub fn stats(&self) -> HnswStats {
605        let mut layer_counts = vec![0usize; self.max_layer + 1];
606        let mut total_connections = 0usize;
607        let mut max_connections = 0usize;
608        let mut min_connections = usize::MAX;
609
610        for node in self.nodes.values() {
611            for (layer, layer_count) in layer_counts.iter_mut().enumerate().take(node.max_layer + 1)
612            {
613                *layer_count += 1;
614                let conns = node.connections.get(layer).map(|c| c.len()).unwrap_or(0);
615                total_connections += conns;
616                max_connections = max_connections.max(conns);
617                if conns > 0 {
618                    min_connections = min_connections.min(conns);
619                }
620            }
621        }
622
623        if self.nodes.is_empty() {
624            min_connections = 0;
625        }
626
627        HnswStats {
628            node_count: self.nodes.len(),
629            dimension: self.dimension,
630            max_layer: self.max_layer,
631            layer_counts,
632            total_connections,
633            avg_connections: if self.nodes.is_empty() {
634                0.0
635            } else {
636                total_connections as f64 / self.nodes.len() as f64
637            },
638            max_connections,
639            min_connections,
640            entry_point: self.entry_point,
641        }
642    }
643
644    // ========================================================================
645    // Persistence
646    // ========================================================================
647
648    /// Serialize the index to bytes for storage
649    pub fn to_bytes(&self) -> Vec<u8> {
650        let mut bytes = Vec::new();
651
652        // Magic number and version
653        bytes.extend_from_slice(b"HNSW");
654        bytes.extend_from_slice(&1u32.to_le_bytes()); // version
655
656        // Config
657        bytes.extend_from_slice(&(self.dimension as u32).to_le_bytes());
658        bytes.extend_from_slice(&(self.config.m as u32).to_le_bytes());
659        bytes.extend_from_slice(&(self.config.m_max0 as u32).to_le_bytes());
660        bytes.extend_from_slice(&(self.config.ef_construction as u32).to_le_bytes());
661        bytes.extend_from_slice(&(self.config.ef_search as u32).to_le_bytes());
662        bytes.extend_from_slice(&self.config.ml.to_le_bytes());
663        bytes.push(match self.config.metric {
664            DistanceMetric::L2 => 0,
665            DistanceMetric::Cosine => 1,
666            DistanceMetric::InnerProduct => 2,
667        });
668
669        // Index state
670        bytes.extend_from_slice(&(self.max_layer as u32).to_le_bytes());
671        bytes.extend_from_slice(&self.entry_point.unwrap_or(u64::MAX).to_le_bytes());
672
673        // Node count
674        bytes.extend_from_slice(&(self.nodes.len() as u64).to_le_bytes());
675
676        // Nodes
677        for (&id, node) in &self.nodes {
678            bytes.extend_from_slice(&id.to_le_bytes());
679            bytes.extend_from_slice(&(node.max_layer as u32).to_le_bytes());
680
681            // Vector
682            for &val in &node.vector {
683                bytes.extend_from_slice(&val.to_le_bytes());
684            }
685
686            // Connections per layer
687            for layer in 0..=node.max_layer {
688                let conns = &node.connections[layer];
689                bytes.extend_from_slice(&(conns.len() as u32).to_le_bytes());
690                for &conn in conns {
691                    bytes.extend_from_slice(&conn.to_le_bytes());
692                }
693            }
694        }
695
696        bytes
697    }
698
699    /// Deserialize index from bytes
700    pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
701        if bytes.len() < 8 {
702            return Err("Data too short".to_string());
703        }
704
705        // Check magic number
706        if &bytes[0..4] != b"HNSW" {
707            return Err("Invalid magic number".to_string());
708        }
709
710        let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
711        if version != 1 {
712            return Err(format!("Unsupported version: {}", version));
713        }
714
715        let mut pos = 8;
716
717        // Config
718        let dimension = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
719        pos += 4;
720        let m = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
721        pos += 4;
722        let m_max0 = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
723        pos += 4;
724        let ef_construction = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
725        pos += 4;
726        let ef_search = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
727        pos += 4;
728        let ml = f64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
729        pos += 8;
730        let metric = match bytes[pos] {
731            0 => DistanceMetric::L2,
732            1 => DistanceMetric::Cosine,
733            2 => DistanceMetric::InnerProduct,
734            _ => return Err("Invalid distance metric".to_string()),
735        };
736        pos += 1;
737
738        let config = HnswConfig {
739            m,
740            m_max0,
741            ef_construction,
742            ef_search,
743            ml,
744            metric,
745        };
746
747        // Index state
748        let max_layer = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
749        pos += 4;
750        let ep_value = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
751        pos += 8;
752        let entry_point = if ep_value == u64::MAX {
753            None
754        } else {
755            Some(ep_value)
756        };
757
758        // Node count
759        let node_count = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
760        pos += 8;
761
762        let mut nodes = HashMap::new();
763        let mut max_id = 0u64;
764
765        for _ in 0..node_count {
766            let id = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
767            pos += 8;
768            max_id = max_id.max(id);
769
770            let level = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
771            pos += 4;
772
773            let mut vector = Vec::with_capacity(dimension);
774            for _ in 0..dimension {
775                vector.push(f32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()));
776                pos += 4;
777            }
778
779            let mut connections = vec![Vec::new(); level + 1];
780            for conn_list in connections.iter_mut().take(level + 1) {
781                let conn_count =
782                    u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()) as usize;
783                pos += 4;
784
785                for _ in 0..conn_count {
786                    let conn = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap());
787                    pos += 8;
788                    conn_list.push(conn);
789                }
790            }
791
792            nodes.insert(
793                id,
794                HnswNode {
795                    id,
796                    max_layer: level,
797                    vector,
798                    connections,
799                },
800            );
801        }
802
803        Ok(Self {
804            config,
805            nodes,
806            entry_point,
807            max_layer,
808            dimension,
809            next_id: AtomicU64::new(max_id + 1),
810            rng_state: 12345, // Reset RNG
811        })
812    }
813}
814
815/// Statistics about an HNSW index
816#[derive(Debug, Clone)]
817pub struct HnswStats {
818    /// Number of vectors in the index
819    pub node_count: usize,
820    /// Vector dimension
821    pub dimension: usize,
822    /// Maximum layer in the graph
823    pub max_layer: usize,
824    /// Number of nodes per layer
825    pub layer_counts: Vec<usize>,
826    /// Total number of connections
827    pub total_connections: usize,
828    /// Average connections per node
829    pub avg_connections: f64,
830    /// Maximum connections on any node
831    pub max_connections: usize,
832    /// Minimum connections on any node
833    pub min_connections: usize,
834    /// Entry point node ID
835    pub entry_point: Option<NodeId>,
836}
837
838/// Bitset for efficient filtering
839#[derive(Debug, Clone)]
840pub struct Bitset {
841    bits: Vec<u64>,
842    len: usize,
843}
844
845impl Bitset {
846    /// Create a new bitset with capacity for n elements
847    pub fn with_capacity(n: usize) -> Self {
848        let num_words = n.div_ceil(64);
849        Self {
850            bits: vec![0; num_words],
851            len: n,
852        }
853    }
854
855    /// Create a bitset with all bits set
856    pub fn all(n: usize) -> Self {
857        let num_words = n.div_ceil(64);
858        let mut bits = vec![u64::MAX; num_words];
859
860        // Clear excess bits in last word
861        if !n.is_multiple_of(64) {
862            let last_idx = num_words - 1;
863            let valid_bits = n % 64;
864            bits[last_idx] = (1u64 << valid_bits) - 1;
865        }
866
867        Self { bits, len: n }
868    }
869
870    /// Set a bit
871    pub fn set(&mut self, idx: usize) {
872        if idx < self.len {
873            let word = idx / 64;
874            let bit = idx % 64;
875            self.bits[word] |= 1u64 << bit;
876        }
877    }
878
879    /// Clear a bit
880    pub fn clear(&mut self, idx: usize) {
881        if idx < self.len {
882            let word = idx / 64;
883            let bit = idx % 64;
884            self.bits[word] &= !(1u64 << bit);
885        }
886    }
887
888    /// Check if a bit is set
889    pub fn is_set(&self, idx: usize) -> bool {
890        if idx >= self.len {
891            return false;
892        }
893        let word = idx / 64;
894        let bit = idx % 64;
895        (self.bits[word] & (1u64 << bit)) != 0
896    }
897
898    /// Convert to HashSet for use with HNSW filter
899    pub fn to_hashset(&self) -> HashSet<NodeId> {
900        let mut set = HashSet::new();
901        for i in 0..self.len {
902            if self.is_set(i) {
903                set.insert(i as NodeId);
904            }
905        }
906        set
907    }
908}
909
910#[cfg(test)]
911mod tests {
912    use super::*;
913
914    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
915        let mut state = seed;
916        (0..dim)
917            .map(|_| {
918                state ^= state << 13;
919                state ^= state >> 7;
920                state ^= state << 17;
921                (state as f32) / (u64::MAX as f32)
922            })
923            .collect()
924    }
925
926    #[test]
927    fn test_empty_index() {
928        let index = HnswIndex::with_dimension(128);
929        assert!(index.is_empty());
930        assert_eq!(index.len(), 0);
931
932        let results = index.search(&vec![0.0; 128], 10);
933        assert!(results.is_empty());
934    }
935
936    #[test]
937    fn test_single_insert() {
938        let mut index = HnswIndex::with_dimension(3);
939        let id = index.insert(vec![1.0, 2.0, 3.0]);
940
941        assert_eq!(index.len(), 1);
942        assert!(!index.is_empty());
943        assert!(index.get_vector(id).is_some());
944    }
945
946    #[test]
947    fn test_exact_match() {
948        let mut index = HnswIndex::with_dimension(3);
949        index.insert(vec![1.0, 0.0, 0.0]);
950        index.insert(vec![0.0, 1.0, 0.0]);
951        index.insert(vec![0.0, 0.0, 1.0]);
952
953        // Search for exact match
954        let results = index.search(&[1.0, 0.0, 0.0], 1);
955        assert_eq!(results.len(), 1);
956        assert_eq!(results[0].distance, 0.0);
957    }
958
959    #[test]
960    fn test_nearest_neighbor() {
961        let mut index = HnswIndex::with_dimension(2);
962        index.insert_with_id(0, vec![0.0, 0.0]);
963        index.insert_with_id(1, vec![1.0, 0.0]);
964        index.insert_with_id(2, vec![2.0, 0.0]);
965        index.insert_with_id(3, vec![3.0, 0.0]);
966
967        // Search for something close to (0.9, 0.0)
968        let results = index.search(&[0.9, 0.0], 1);
969        assert_eq!(results.len(), 1);
970        assert_eq!(results[0].id, 1); // Should find (1.0, 0.0)
971    }
972
973    #[test]
974    fn test_k_nearest() {
975        let mut index = HnswIndex::with_dimension(2);
976        for i in 0..10 {
977            index.insert_with_id(i, vec![i as f32, 0.0]);
978        }
979
980        // Search for 3 nearest to (4.5, 0.0)
981        let results = index.search(&[4.5, 0.0], 3);
982        assert_eq!(results.len(), 3);
983
984        // Should find 4, 5, and either 3 or 6
985        let ids: HashSet<_> = results.iter().map(|r| r.id).collect();
986        assert!(ids.contains(&4));
987        assert!(ids.contains(&5));
988    }
989
990    #[test]
991    fn test_filtered_search() {
992        let mut index = HnswIndex::with_dimension(2);
993        for i in 0..10 {
994            index.insert_with_id(i, vec![i as f32, 0.0]);
995        }
996
997        // Only allow even IDs
998        let filter: HashSet<NodeId> = [0, 2, 4, 6, 8].iter().copied().collect();
999
1000        // Search for nearest to (5.0, 0.0) with filter
1001        let results = index.search_filtered(&[5.0, 0.0], 2, &filter);
1002
1003        // Should find 4 and 6 (closest even numbers to 5)
1004        assert_eq!(results.len(), 2);
1005        let ids: HashSet<_> = results.iter().map(|r| r.id).collect();
1006        assert!(ids.contains(&4));
1007        assert!(ids.contains(&6));
1008    }
1009
1010    #[test]
1011    fn test_cosine_distance() {
1012        let config = HnswConfig::default().with_metric(DistanceMetric::Cosine);
1013        let mut index = HnswIndex::new(3, config);
1014
1015        // Insert normalized vectors
1016        index.insert_with_id(0, vec![1.0, 0.0, 0.0]);
1017        index.insert_with_id(1, vec![0.0, 1.0, 0.0]);
1018        index.insert_with_id(2, vec![0.707, 0.707, 0.0]); // 45 degrees
1019
1020        // Search for something at 45 degrees
1021        let results = index.search(&[0.707, 0.707, 0.0], 1);
1022        assert_eq!(results[0].id, 2);
1023    }
1024
1025    #[test]
1026    fn test_many_vectors() {
1027        let dim = 64;
1028        let n: usize = 1000;
1029
1030        let mut index = HnswIndex::with_dimension(dim);
1031
1032        // Insert many random vectors
1033        for i in 0..n {
1034            let vector = random_vector(dim, i as u64);
1035            index.insert_with_id(i as u64, vector);
1036        }
1037
1038        assert_eq!(index.len(), n);
1039
1040        // Search should return k results
1041        let query = random_vector(dim, 12345);
1042        let results = index.search(&query, 10);
1043        assert_eq!(results.len(), 10);
1044
1045        // Results should be sorted by distance
1046        for i in 1..results.len() {
1047            assert!(results[i].distance >= results[i - 1].distance);
1048        }
1049    }
1050
1051    #[test]
1052    fn test_bitset() {
1053        let mut bs = Bitset::with_capacity(100);
1054
1055        bs.set(0);
1056        bs.set(50);
1057        bs.set(99);
1058
1059        assert!(bs.is_set(0));
1060        assert!(bs.is_set(50));
1061        assert!(bs.is_set(99));
1062        assert!(!bs.is_set(1));
1063        assert!(!bs.is_set(64));
1064
1065        bs.clear(50);
1066        assert!(!bs.is_set(50));
1067    }
1068
1069    #[test]
1070    fn test_bitset_all() {
1071        let bs = Bitset::all(100);
1072
1073        for i in 0..100 {
1074            assert!(bs.is_set(i));
1075        }
1076        assert!(!bs.is_set(100)); // Out of bounds
1077    }
1078}