Skip to main content

reddb_server/storage/engine/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) Index
2//!
3//! A from-scratch implementation of the HNSW algorithm for approximate
4//! nearest neighbor search. No external dependencies.
5//!
6//! # Algorithm Overview
7//!
8//! HNSW builds a multi-layer graph where:
9//! - Layer 0 contains all nodes
10//! - Higher layers contain progressively fewer nodes
11//! - Each layer is a navigable small world graph
12//!
13//! Search starts from the top layer and greedily descends,
14//! using each layer to quickly approach the target region.
15//!
16//! # References
17//!
18//! - Original paper: "Efficient and robust approximate nearest neighbor search
19//!   using Hierarchical Navigable Small World graphs" (Malkov & Yashunin, 2018)
20
21use std::cmp::Reverse;
22use std::collections::{BinaryHeap, HashMap, HashSet};
23use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
24
25use super::distance::{
26    cmp_f32, distance_simd, DistanceMetric, DistanceResult, ReverseDistanceResult,
27};
28
/// Node identifier in the HNSW graph.
///
/// IDs are either handed out sequentially by `HnswIndex::insert` or supplied
/// by the caller through `HnswIndex::insert_with_id`.
pub type NodeId = u64;
31
/// HNSW index configuration parameters.
///
/// Build one with [`HnswConfig::default`] or [`HnswConfig::with_m`] and refine
/// it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Maximum number of connections per node (except layer 0)
    pub m: usize,
    /// Maximum connections at layer 0 (typically 2*M)
    pub m_max0: usize,
    /// Size of dynamic candidate list during construction
    pub ef_construction: usize,
    /// Size of dynamic candidate list during search (can be adjusted);
    /// used by `search` and `search_filtered` when no explicit `ef` is given
    pub ef_search: usize,
    /// Normalization factor for layer assignment (1/ln(M)); multiplies the
    /// exponential sample drawn for each inserted node
    pub ml: f64,
    /// Distance metric to use
    pub metric: DistanceMetric,
}
48
49impl Default for HnswConfig {
50    fn default() -> Self {
51        let m = 16;
52        Self {
53            m,
54            m_max0: m * 2,
55            ef_construction: 100,
56            ef_search: 50,
57            ml: 1.0 / (m as f64).ln(),
58            metric: DistanceMetric::L2,
59        }
60    }
61}
62
63impl HnswConfig {
64    /// Create a new configuration with custom M value
65    pub fn with_m(m: usize) -> Self {
66        Self {
67            m,
68            m_max0: m * 2,
69            ml: 1.0 / (m as f64).ln(),
70            ..Default::default()
71        }
72    }
73
74    /// Set the distance metric
75    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
76        self.metric = metric;
77        self
78    }
79
80    /// Set ef_construction
81    pub fn with_ef_construction(mut self, ef: usize) -> Self {
82        self.ef_construction = ef;
83        self
84    }
85
86    /// Set ef_search
87    pub fn with_ef_search(mut self, ef: usize) -> Self {
88        self.ef_search = ef;
89        self
90    }
91}
92
93/// A node in the HNSW graph
94#[derive(Debug, Clone)]
95struct HnswNode {
96    /// The node's ID
97    id: NodeId,
98    /// The node's vector data
99    vector: Vec<f32>,
100    /// Maximum layer this node appears in
101    max_layer: usize,
102    /// Connections at each layer (layer index -> neighbor IDs)
103    connections: Vec<Vec<NodeId>>,
104}
105
106impl HnswNode {
107    fn new(id: NodeId, vector: Vec<f32>, max_layer: usize) -> Self {
108        let mut connections = Vec::with_capacity(max_layer + 1);
109        for _ in 0..=max_layer {
110            connections.push(Vec::new());
111        }
112        Self {
113            id,
114            vector,
115            max_layer,
116            connections,
117        }
118    }
119}
120
/// HNSW Index for approximate nearest neighbor search.
///
/// Owns every node (vector plus per-layer neighbor lists) in a hash map keyed
/// by [`NodeId`] and remembers the entry point from which searches start.
pub struct HnswIndex {
    /// Configuration parameters
    config: HnswConfig,
    /// All nodes in the index
    nodes: HashMap<NodeId, HnswNode>,
    /// Entry point (node with highest layer); `None` while the index is empty
    entry_point: Option<NodeId>,
    /// Maximum layer in the graph
    max_layer: usize,
    /// Vector dimension; every inserted vector must have exactly this length
    dimension: usize,
    /// Next available node ID (handed out by `insert`)
    next_id: AtomicU64,
    /// Simple RNG state for layer assignment (xorshift64)
    rng_state: u64,
}
138
139impl HnswIndex {
140    /// Create a new empty HNSW index
141    pub fn new(dimension: usize, config: HnswConfig) -> Self {
142        Self {
143            config,
144            nodes: HashMap::new(),
145            entry_point: None,
146            max_layer: 0,
147            dimension,
148            next_id: AtomicU64::new(0),
149            rng_state: 0x853c49e6748fea9b, // Random seed
150        }
151    }
152
153    /// Create with default configuration
154    pub fn with_dimension(dimension: usize) -> Self {
155        Self::new(dimension, HnswConfig::default())
156    }
157
    /// Get the number of vectors in the index.
    ///
    /// Counts the nodes currently stored; deleted nodes are removed from the
    /// map and therefore not counted.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }
162
163    /// Check if the index is empty
164    pub fn is_empty(&self) -> bool {
165        self.nodes.is_empty()
166    }
167
168    /// Get the vector for a node ID
169    pub fn get_vector(&self, id: NodeId) -> Option<&[f32]> {
170        self.nodes.get(&id).map(|n| n.vector.as_slice())
171    }
172
173    /// Insert a vector and return its assigned ID
174    pub fn insert(&mut self, vector: Vec<f32>) -> NodeId {
175        let id = self.next_id.fetch_add(1, AtomicOrdering::SeqCst);
176        self.insert_with_id(id, vector);
177        id
178    }
179
180    /// Insert a vector with a specific ID
181    pub fn insert_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
182        assert_eq!(
183            vector.len(),
184            self.dimension,
185            "Vector dimension mismatch: expected {}, got {}",
186            self.dimension,
187            vector.len()
188        );
189
190        // Assign random layer using exponential distribution
191        let node_layer = self.random_layer();
192
193        // Create the node
194        let node = HnswNode::new(id, vector, node_layer);
195
196        if self.entry_point.is_none() {
197            // First node - just add it
198            self.nodes.insert(id, node);
199            self.entry_point = Some(id);
200            self.max_layer = node_layer;
201            return;
202        }
203
204        let entry_point = self.entry_point.unwrap();
205        let vector = self.nodes.get(&id).map(|n| n.vector.clone());
206
207        // We need to insert the node first so we can access its vector
208        // But we need to find neighbors first... let's clone the vector
209        let vector = node.vector.clone();
210        self.nodes.insert(id, node);
211
212        // Find entry point for search
213        let mut current = entry_point;
214
215        // Traverse from top layer down to node_layer + 1
216        // This finds the best entry point for the insertion layer
217        for layer in (node_layer + 1..=self.max_layer).rev() {
218            current = self.search_layer_single(&vector, current, layer);
219        }
220
221        // For layers from node_layer down to 0, find and connect to neighbors
222        for layer in (0..=node_layer.min(self.max_layer)).rev() {
223            // Find ef_construction nearest neighbors at this layer
224            let neighbors = self.search_layer(&vector, current, self.config.ef_construction, layer);
225
226            // Select M best neighbors
227            let m = if layer == 0 {
228                self.config.m_max0
229            } else {
230                self.config.m
231            };
232            let selected: Vec<NodeId> = neighbors.into_iter().take(m).map(|r| r.id).collect();
233
234            // Connect node to selected neighbors
235            if let Some(node) = self.nodes.get_mut(&id) {
236                node.connections[layer] = selected.clone();
237            }
238
239            // Add bidirectional connections
240            for &neighbor_id in &selected {
241                self.add_connection(neighbor_id, id, layer);
242            }
243
244            // Update current for next layer
245            if let Some(&first) = selected.first() {
246                current = first;
247            }
248        }
249
250        // Update entry point if new node has higher layer
251        if node_layer > self.max_layer {
252            self.entry_point = Some(id);
253            self.max_layer = node_layer;
254        }
255    }
256
257    /// Search for k nearest neighbors
258    pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
259        self.search_with_ef(query, k, self.config.ef_search)
260    }
261
262    /// Search with custom ef parameter
263    pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<DistanceResult> {
264        if self.entry_point.is_none() {
265            return Vec::new();
266        }
267
268        let entry_point = self.entry_point.unwrap();
269        let mut current = entry_point;
270
271        // Traverse from top layer down to layer 1
272        for layer in (1..=self.max_layer).rev() {
273            current = self.search_layer_single(query, current, layer);
274        }
275
276        // Search layer 0 with ef candidates
277        let candidates = self.search_layer(query, current, ef.max(k), 0);
278
279        // Return top k
280        candidates.into_iter().take(k).collect()
281    }
282
283    /// Search with a filter (bitset of allowed IDs)
284    pub fn search_filtered(
285        &self,
286        query: &[f32],
287        k: usize,
288        filter: &HashSet<NodeId>,
289    ) -> Vec<DistanceResult> {
290        self.search_filtered_with_ef(query, k, filter, self.config.ef_search)
291    }
292
293    /// Search with filter and custom ef
294    pub fn search_filtered_with_ef(
295        &self,
296        query: &[f32],
297        k: usize,
298        filter: &HashSet<NodeId>,
299        ef: usize,
300    ) -> Vec<DistanceResult> {
301        if self.entry_point.is_none() || filter.is_empty() {
302            return Vec::new();
303        }
304
305        let entry_point = self.entry_point.unwrap();
306        let mut current = entry_point;
307
308        // Traverse from top layer down to layer 1
309        for layer in (1..=self.max_layer).rev() {
310            current = self.search_layer_single(query, current, layer);
311        }
312
313        // Search layer 0 with filter
314        // We need to search with higher ef to account for filtered results
315        let expanded_ef = (ef * 2).max(k * 4);
316        let candidates = self.search_layer(query, current, expanded_ef, 0);
317
318        // Filter and return top k
319        candidates
320            .into_iter()
321            .filter(|r| filter.contains(&r.id))
322            .take(k)
323            .collect()
324    }
325
326    // =========================================================================
327    // Private methods
328    // =========================================================================
329
330    /// Generate a random layer using exponential distribution
331    fn random_layer(&mut self) -> usize {
332        // Simple xorshift64 PRNG
333        self.rng_state ^= self.rng_state << 13;
334        self.rng_state ^= self.rng_state >> 7;
335        self.rng_state ^= self.rng_state << 17;
336
337        // Convert to uniform [0, 1)
338        let uniform = (self.rng_state as f64) / (u64::MAX as f64);
339
340        // Exponential distribution: -ln(uniform) * ml
341
342        (-uniform.ln() * self.config.ml).floor() as usize
343    }
344
345    /// Search a single layer for the closest node (greedy)
346    fn search_layer_single(&self, query: &[f32], entry: NodeId, layer: usize) -> NodeId {
347        let mut current = entry;
348        let mut current_dist = self.compute_distance(query, current);
349
350        loop {
351            let mut changed = false;
352
353            if let Some(node) = self.nodes.get(&current) {
354                if layer < node.connections.len() {
355                    for &neighbor in &node.connections[layer] {
356                        let dist = self.compute_distance(query, neighbor);
357                        if dist < current_dist {
358                            current_dist = dist;
359                            current = neighbor;
360                            changed = true;
361                        }
362                    }
363                }
364            }
365
366            if !changed {
367                break;
368            }
369        }
370
371        current
372    }
373
374    /// Search a layer for ef nearest neighbors
375    fn search_layer(
376        &self,
377        query: &[f32],
378        entry: NodeId,
379        ef: usize,
380        layer: usize,
381    ) -> Vec<DistanceResult> {
382        let entry_dist = self.compute_distance(query, entry);
383
384        // Candidates: min-heap of nodes to explore (closest first)
385        let mut candidates: BinaryHeap<Reverse<DistanceResult>> = BinaryHeap::new();
386        candidates.push(Reverse(DistanceResult::new(entry, entry_dist)));
387
388        // Results: max-heap of found neighbors (furthest first for pruning)
389        let mut results: BinaryHeap<ReverseDistanceResult> = BinaryHeap::new();
390        results.push(ReverseDistanceResult(DistanceResult::new(
391            entry, entry_dist,
392        )));
393
394        // Visited set
395        let mut visited: HashSet<NodeId> = HashSet::new();
396        visited.insert(entry);
397
398        while let Some(Reverse(current)) = candidates.pop() {
399            // Get furthest result distance for pruning
400            let furthest_dist = results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
401
402            // If current is further than our furthest result, we're done
403            if current.distance > furthest_dist {
404                break;
405            }
406
407            // Explore neighbors
408            if let Some(node) = self.nodes.get(&current.id) {
409                if layer < node.connections.len() {
410                    for &neighbor_id in &node.connections[layer] {
411                        if visited.contains(&neighbor_id) {
412                            continue;
413                        }
414                        visited.insert(neighbor_id);
415
416                        let dist = self.compute_distance(query, neighbor_id);
417                        let furthest_dist =
418                            results.peek().map(|r| r.0.distance).unwrap_or(f32::MAX);
419
420                        // If this neighbor is closer than our furthest result, or we need more results
421                        if dist < furthest_dist || results.len() < ef {
422                            candidates.push(Reverse(DistanceResult::new(neighbor_id, dist)));
423                            results.push(ReverseDistanceResult(DistanceResult::new(
424                                neighbor_id,
425                                dist,
426                            )));
427
428                            // Prune results to ef size
429                            while results.len() > ef {
430                                results.pop();
431                            }
432                        }
433                    }
434                }
435            }
436        }
437
438        // Convert results to sorted vector (closest first)
439        let mut result_vec: Vec<DistanceResult> = results.into_iter().map(|r| r.0).collect();
440        result_vec.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
441        result_vec
442    }
443
444    /// Add a bidirectional connection, pruning if necessary
445    fn add_connection(&mut self, from: NodeId, to: NodeId, layer: usize) {
446        let max_connections = if layer == 0 {
447            self.config.m_max0
448        } else {
449            self.config.m
450        };
451
452        if let Some(node) = self.nodes.get_mut(&from) {
453            // Ensure we have enough layers
454            while node.connections.len() <= layer {
455                node.connections.push(Vec::new());
456            }
457
458            // Add connection if not already present
459            if !node.connections[layer].contains(&to) {
460                node.connections[layer].push(to);
461
462                // Prune if too many connections
463                if node.connections[layer].len() > max_connections {
464                    self.prune_connections(from, layer, max_connections);
465                }
466            }
467        }
468    }
469
470    /// Prune connections to max_connections using simple heuristic
471    fn prune_connections(&mut self, node_id: NodeId, layer: usize, max_connections: usize) {
472        let node_vector = self.nodes.get(&node_id).map(|n| n.vector.clone());
473        let node_vector = match node_vector {
474            Some(v) => v,
475            None => return,
476        };
477
478        // Get all neighbors with distances
479        let neighbors: Vec<NodeId> = self
480            .nodes
481            .get(&node_id)
482            .map(|n| n.connections[layer].clone())
483            .unwrap_or_default();
484
485        let mut scored: Vec<DistanceResult> = neighbors
486            .iter()
487            .map(|&neighbor_id| {
488                let dist = self.compute_distance(&node_vector, neighbor_id);
489                DistanceResult::new(neighbor_id, dist)
490            })
491            .collect();
492
493        // Sort by distance (closest first)
494        scored.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
495
496        // Keep only max_connections closest
497        let kept: Vec<NodeId> = scored
498            .into_iter()
499            .take(max_connections)
500            .map(|r| r.id)
501            .collect();
502
503        if let Some(node) = self.nodes.get_mut(&node_id) {
504            node.connections[layer] = kept;
505        }
506    }
507
508    /// Compute distance between query and a node (SIMD-accelerated)
509    fn compute_distance(&self, query: &[f32], node_id: NodeId) -> f32 {
510        match self.nodes.get(&node_id) {
511            Some(node) => distance_simd(query, &node.vector, self.config.metric),
512            None => f32::MAX,
513        }
514    }
515
516    // ========================================================================
517    // Batch Operations
518    // ========================================================================
519
520    /// Insert multiple vectors at once
521    ///
522    /// More efficient than individual inserts for large batches
523    pub fn insert_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
524        vectors.into_iter().map(|v| self.insert(v)).collect()
525    }
526
527    /// Insert multiple vectors with specific IDs
528    pub fn insert_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
529        for (id, vector) in items {
530            self.insert_with_id(id, vector);
531        }
532    }
533
534    // ========================================================================
535    // Delete Operations
536    // ========================================================================
537
538    /// Remove a node from the index
539    ///
540    /// Note: This is a "soft" delete - the node is marked as deleted and
541    /// excluded from search results, but connections are not fully repaired.
542    /// For better performance after many deletes, rebuild the index.
543    pub fn delete(&mut self, id: NodeId) -> bool {
544        if self.nodes.remove(&id).is_none() {
545            return false;
546        }
547
548        // Update entry point if we deleted it
549        if self.entry_point == Some(id) {
550            self.entry_point = self.nodes.keys().next().copied();
551
552            // Update max_layer based on new entry point
553            if let Some(ep) = self.entry_point {
554                self.max_layer = self.nodes.get(&ep).map(|n| n.max_layer).unwrap_or(0);
555            } else {
556                self.max_layer = 0;
557            }
558        }
559
560        // Remove references to this node from other nodes' connections
561        for node in self.nodes.values_mut() {
562            for layer_connections in node.connections.iter_mut() {
563                layer_connections.retain(|&neighbor| neighbor != id);
564            }
565        }
566
567        true
568    }
569
    /// Check if a node exists.
    ///
    /// Returns `true` when a node is currently stored under `id`.
    pub fn contains(&self, id: NodeId) -> bool {
        self.nodes.contains_key(&id)
    }
574
575    // ========================================================================
576    // Adaptive Search
577    // ========================================================================
578
579    /// Search with adaptive ef based on index size
580    ///
581    /// Automatically adjusts ef_search based on the number of results needed
582    /// and the size of the index for better accuracy/speed tradeoff.
583    pub fn search_adaptive(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
584        // Adaptive ef: at least k, scales with log of index size
585        let n = self.nodes.len();
586        let adaptive_ef = if n < 100 {
587            k.max(10)
588        } else if n < 10000 {
589            k.max(50)
590        } else if n < 100000 {
591            k.max(100)
592        } else {
593            k.max(200)
594        };
595
596        self.search_with_ef(query, k, adaptive_ef)
597    }
598
599    // ========================================================================
600    // Index Statistics
601    // ========================================================================
602
603    /// Get statistics about the index
604    pub fn stats(&self) -> HnswStats {
605        let mut layer_counts = vec![0usize; self.max_layer + 1];
606        let mut total_connections = 0usize;
607        let mut max_connections = 0usize;
608        let mut min_connections = usize::MAX;
609
610        for node in self.nodes.values() {
611            for (layer, layer_count) in layer_counts.iter_mut().enumerate().take(node.max_layer + 1)
612            {
613                *layer_count += 1;
614                let conns = node.connections.get(layer).map(|c| c.len()).unwrap_or(0);
615                total_connections += conns;
616                max_connections = max_connections.max(conns);
617                if conns > 0 {
618                    min_connections = min_connections.min(conns);
619                }
620            }
621        }
622
623        if self.nodes.is_empty() {
624            min_connections = 0;
625        }
626
627        HnswStats {
628            node_count: self.nodes.len(),
629            dimension: self.dimension,
630            max_layer: self.max_layer,
631            layer_counts,
632            total_connections,
633            avg_connections: if self.nodes.is_empty() {
634                0.0
635            } else {
636                total_connections as f64 / self.nodes.len() as f64
637            },
638            max_connections,
639            min_connections,
640            entry_point: self.entry_point,
641        }
642    }
643
644    // ========================================================================
645    // Persistence
646    // ========================================================================
647
648    /// Serialize the index to bytes for storage.
649    ///
650    /// The `HNSW` payload byte layout is owned by `reddb-file` (ADR 0046); this
651    /// only projects the engine state into [`reddb_file::HnswIndexLayout`] and
652    /// maps the distance metric to its on-disk discriminant.
653    pub fn to_bytes(&self) -> Vec<u8> {
654        let metric = match self.config.metric {
655            DistanceMetric::L2 => 0,
656            DistanceMetric::Cosine => 1,
657            DistanceMetric::InnerProduct => 2,
658        };
659        let nodes = self
660            .nodes
661            .values()
662            .map(|node| reddb_file::HnswNodeLayout {
663                id: node.id,
664                max_layer: node.max_layer,
665                vector: node.vector.clone(),
666                connections: node.connections.clone(),
667            })
668            .collect();
669        let layout = reddb_file::HnswIndexLayout {
670            dimension: self.dimension,
671            m: self.config.m,
672            m_max0: self.config.m_max0,
673            ef_construction: self.config.ef_construction,
674            ef_search: self.config.ef_search,
675            ml: self.config.ml,
676            metric,
677            max_layer: self.max_layer,
678            entry_point: self.entry_point,
679            nodes,
680        };
681        reddb_file::encode_hnsw_index(&layout)
682    }
683
684    /// Deserialize index from bytes via the `reddb-file` codec.
685    pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
686        let layout = reddb_file::decode_hnsw_index(bytes).map_err(|e| e.to_string())?;
687
688        let metric = match layout.metric {
689            0 => DistanceMetric::L2,
690            1 => DistanceMetric::Cosine,
691            2 => DistanceMetric::InnerProduct,
692            _ => return Err("Invalid distance metric".to_string()),
693        };
694
695        let config = HnswConfig {
696            m: layout.m,
697            m_max0: layout.m_max0,
698            ef_construction: layout.ef_construction,
699            ef_search: layout.ef_search,
700            ml: layout.ml,
701            metric,
702        };
703
704        let mut nodes = HashMap::new();
705        let mut max_id = 0u64;
706        for node in layout.nodes {
707            max_id = max_id.max(node.id);
708            nodes.insert(
709                node.id,
710                HnswNode {
711                    id: node.id,
712                    max_layer: node.max_layer,
713                    vector: node.vector,
714                    connections: node.connections,
715                },
716            );
717        }
718
719        Ok(Self {
720            config,
721            nodes,
722            entry_point: layout.entry_point,
723            max_layer: layout.max_layer,
724            dimension: layout.dimension,
725            next_id: AtomicU64::new(max_id + 1),
726            rng_state: 12345, // Reset RNG
727        })
728    }
729}
730
/// Statistics about an HNSW index, as produced by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
    /// Number of vectors in the index
    pub node_count: usize,
    /// Vector dimension
    pub dimension: usize,
    /// Maximum layer in the graph
    pub max_layer: usize,
    /// Number of nodes per layer (index = layer)
    pub layer_counts: Vec<usize>,
    /// Total number of connections, summed over all nodes and layers
    pub total_connections: usize,
    /// Average connections per node (total connections / node count)
    pub avg_connections: f64,
    /// Maximum connections on any node (largest single per-layer neighbor list)
    pub max_connections: usize,
    /// Minimum connections on any node (smallest non-empty per-layer neighbor list)
    pub min_connections: usize,
    /// Entry point node ID
    pub entry_point: Option<NodeId>,
}
753
754/// Bitset for efficient filtering
755#[derive(Debug, Clone)]
756pub struct Bitset {
757    bits: Vec<u64>,
758    len: usize,
759}
760
761impl Bitset {
762    /// Create a new bitset with capacity for n elements
763    pub fn with_capacity(n: usize) -> Self {
764        let num_words = n.div_ceil(64);
765        Self {
766            bits: vec![0; num_words],
767            len: n,
768        }
769    }
770
771    /// Create a bitset with all bits set
772    pub fn all(n: usize) -> Self {
773        let num_words = n.div_ceil(64);
774        let mut bits = vec![u64::MAX; num_words];
775
776        // Clear excess bits in last word
777        if !n.is_multiple_of(64) {
778            let last_idx = num_words - 1;
779            let valid_bits = n % 64;
780            bits[last_idx] = (1u64 << valid_bits) - 1;
781        }
782
783        Self { bits, len: n }
784    }
785
786    /// Set a bit
787    pub fn set(&mut self, idx: usize) {
788        if idx < self.len {
789            let word = idx / 64;
790            let bit = idx % 64;
791            self.bits[word] |= 1u64 << bit;
792        }
793    }
794
795    /// Clear a bit
796    pub fn clear(&mut self, idx: usize) {
797        if idx < self.len {
798            let word = idx / 64;
799            let bit = idx % 64;
800            self.bits[word] &= !(1u64 << bit);
801        }
802    }
803
804    /// Check if a bit is set
805    pub fn is_set(&self, idx: usize) -> bool {
806        if idx >= self.len {
807            return false;
808        }
809        let word = idx / 64;
810        let bit = idx % 64;
811        (self.bits[word] & (1u64 << bit)) != 0
812    }
813
814    /// Convert to HashSet for use with HNSW filter
815    pub fn to_hashset(&self) -> HashSet<NodeId> {
816        let mut set = HashSet::new();
817        for i in 0..self.len {
818            if self.is_set(i) {
819                set.insert(i as NodeId);
820            }
821        }
822        set
823    }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector: xorshift64 seeded with `seed`,
    /// each state scaled by `u64::MAX`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        let mut components = Vec::with_capacity(dim);
        for _ in 0..dim {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            components.push((state as f32) / (u64::MAX as f32));
        }
        components
    }

    /// 2-D index holding the points `(i, 0)` for `i` in `0..count`, with ID `i`.
    fn points_on_x_axis(count: u64) -> HnswIndex {
        let mut index = HnswIndex::with_dimension(2);
        for i in 0..count {
            index.insert_with_id(i, vec![i as f32, 0.0]);
        }
        index
    }

    #[test]
    fn test_empty_index() {
        let index = HnswIndex::with_dimension(128);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());

        // Searching an empty index yields nothing
        let hits = index.search(&[0.0; 128], 10);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_single_insert() {
        let mut index = HnswIndex::with_dimension(3);
        let id = index.insert(vec![1.0, 2.0, 3.0]);

        assert!(!index.is_empty());
        assert_eq!(index.len(), 1);
        assert!(index.get_vector(id).is_some());
    }

    #[test]
    fn test_exact_match() {
        let mut index = HnswIndex::with_dimension(3);
        for axis in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] {
            index.insert(axis.to_vec());
        }

        // Querying with a stored vector must return it at distance zero
        let hits = index.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].distance, 0.0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let index = points_on_x_axis(4);

        // (0.9, 0.0) is closest to (1.0, 0.0), stored under ID 1
        let hits = index.search(&[0.9, 0.0], 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_k_nearest() {
        let index = points_on_x_axis(10);

        // The 3 nearest to (4.5, 0.0) are 4, 5, and either 3 or 6
        let hits = index.search(&[4.5, 0.0], 3);
        assert_eq!(hits.len(), 3);

        let ids: HashSet<_> = hits.iter().map(|hit| hit.id).collect();
        assert!(ids.contains(&4));
        assert!(ids.contains(&5));
    }

    #[test]
    fn test_filtered_search() {
        let index = points_on_x_axis(10);

        // Only even IDs may be returned
        let allowed: HashSet<NodeId> = [0, 2, 4, 6, 8].into_iter().collect();

        // Closest even IDs to (5.0, 0.0) are 4 and 6
        let hits = index.search_filtered(&[5.0, 0.0], 2, &allowed);
        assert_eq!(hits.len(), 2);

        let ids: HashSet<_> = hits.iter().map(|hit| hit.id).collect();
        assert!(ids.contains(&4));
        assert!(ids.contains(&6));
    }

    #[test]
    fn test_cosine_distance() {
        let config = HnswConfig::default().with_metric(DistanceMetric::Cosine);
        let mut index = HnswIndex::new(3, config);

        // Two axes plus a vector at 45 degrees between them
        index.insert_with_id(0, vec![1.0, 0.0, 0.0]);
        index.insert_with_id(1, vec![0.0, 1.0, 0.0]);
        index.insert_with_id(2, vec![0.707, 0.707, 0.0]);

        // The 45-degree query must match the 45-degree vector
        let hits = index.search(&[0.707, 0.707, 0.0], 1);
        assert_eq!(hits[0].id, 2);
    }

    #[test]
    fn test_many_vectors() {
        let dim = 64;
        let n: usize = 1000;

        let mut index = HnswIndex::with_dimension(dim);
        for i in 0..n {
            index.insert_with_id(i as u64, random_vector(dim, i as u64));
        }
        assert_eq!(index.len(), n);

        // Search should return k results
        let query = random_vector(dim, 12345);
        let hits = index.search(&query, 10);
        assert_eq!(hits.len(), 10);

        // Results should be sorted by distance
        for pair in hits.windows(2) {
            assert!(pair[1].distance >= pair[0].distance);
        }
    }

    #[test]
    fn test_bitset() {
        let mut bs = Bitset::with_capacity(100);
        for idx in [0, 50, 99] {
            bs.set(idx);
        }

        for idx in [0, 50, 99] {
            assert!(bs.is_set(idx));
        }
        assert!(!bs.is_set(1));
        assert!(!bs.is_set(64));

        bs.clear(50);
        assert!(!bs.is_set(50));
    }

    #[test]
    fn test_bitset_all() {
        let bs = Bitset::all(100);

        assert!((0..100).all(|idx| bs.is_set(idx)));
        assert!(!bs.is_set(100)); // Out of bounds
    }
}