omendb_core/vector/hnsw/index/
mod.rs

// HNSW Index - Main implementation
//
// Architecture:
// - Flattened index (contiguous nodes, u32 node IDs)
// - Separate neighbor storage (fetch only when needed)
// - Cache-optimized layout (64-byte aligned hot data)
//
// Module structure:
// - mod.rs: Core struct, constructors, getters, distance methods
// - insert.rs: Insert operations (single, batch, graph construction)
// - delete.rs: Delete operations
// - search.rs: Search operations (k-NN, filtered, layer-level)
// - persistence.rs: Save/load to disk
// - stats.rs: Statistics, memory usage, cache optimization

mod delete;
mod insert;
mod persistence;
mod search;
mod stats;

#[cfg(test)]
mod tests;

use super::error::{HNSWError, Result};
use super::graph_storage::GraphStorage;
use super::storage::VectorStorage;
use super::types::{Distance, DistanceFunction, HNSWNode, HNSWParams};
use crate::compression::RaBitQParams;
use serde::{Deserialize, Serialize};

/// Index statistics for monitoring and debugging
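///
/// Example of consuming the stats (a hedged sketch: how the value is produced
/// lives in `stats.rs`, and the accessor is not shown in this file):
/// ```ignore
/// let stats: IndexStats = /* obtained from the stats module */;
/// println!(
///     "{} vectors x {} dims, max level {}, ~{:.1} neighbors at L0, {} bytes",
///     stats.num_vectors, stats.dimensions, stats.max_level,
///     stats.avg_neighbors_l0, stats.memory_bytes,
/// );
/// ```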
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total number of vectors in index
    pub num_vectors: usize,

    /// Vector dimensionality
    pub dimensions: usize,

    /// Entry point node ID
    pub entry_point: Option<u32>,

    /// Maximum level in the graph
    pub max_level: u8,

    /// Level distribution (count of nodes having each level as their top level)
    pub level_distribution: Vec<usize>,

    /// Average neighbors per node (level 0)
    pub avg_neighbors_l0: f32,

    /// Max neighbors per node (level 0)
    pub max_neighbors_l0: usize,

    /// Memory usage in bytes
    pub memory_bytes: usize,

    /// HNSW parameters
    pub params: HNSWParams,

    /// Distance function
    pub distance_function: DistanceFunction,

    /// Whether quantization is enabled
    pub quantization_enabled: bool,
}

/// HNSW Index
///
/// Hierarchical graph index for approximate nearest neighbor search.
/// Optimized for cache locality and memory efficiency.
///
/// **Note**: Not Clone due to `GraphStorage` containing non-cloneable backends.
/// Use persistence APIs (save/load) instead of cloning.
#[derive(Debug, Serialize, Deserialize)]
pub struct HNSWIndex {
    /// Node metadata (cache-line aligned)
    pub(super) nodes: Vec<HNSWNode>,

    /// Graph storage (mode-dependent: in-memory or hybrid disk+cache)
    pub(super) neighbors: GraphStorage,

    /// Vector storage (full precision or quantized)
    pub(super) vectors: VectorStorage,

    /// Entry point (top-level node)
    pub(super) entry_point: Option<u32>,

    /// Construction parameters
    pub(super) params: HNSWParams,

    /// Distance function
    pub(super) distance_fn: DistanceFunction,

    /// Random number generator seed state
    pub(super) rng_state: u64,
}

impl HNSWIndex {
    // =========================================================================
    // Constructors
    // =========================================================================

    /// Create a new empty HNSW index
    ///
    /// # Arguments
    /// * `dimensions` - Vector dimensionality
    /// * `params` - HNSW construction parameters
    /// * `distance_fn` - Distance function (L2, Cosine, Dot)
    /// * `use_quantization` - Whether to use binary quantization
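    ///
    /// # Example
    /// A minimal usage sketch (the dimension and flag values are illustrative):
    /// ```ignore
    /// let params = HNSWParams::default();
    /// let index = HNSWIndex::new(768, params, DistanceFunction::L2, false)?;
    /// ```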
    pub fn new(
        dimensions: usize,
        params: HNSWParams,
        distance_fn: DistanceFunction,
        use_quantization: bool,
    ) -> Result<Self> {
        params.validate().map_err(HNSWError::InvalidParams)?;

        let vectors = if use_quantization {
            VectorStorage::new_binary_quantized(dimensions, true)
        } else {
            VectorStorage::new_full_precision(dimensions)
        };

        let neighbors = GraphStorage::new(params.max_level as usize);

        Ok(Self {
            nodes: Vec::new(),
            neighbors,
            vectors,
            entry_point: None,
            params,
            distance_fn,
            rng_state: params.seed,
        })
    }

    /// Create a new HNSW index with `RaBitQ` asymmetric search (CLOUD MOAT)
    ///
    /// This enables 2-3x faster search by using asymmetric distance computation:
    /// - Query vector stays full precision
    /// - Candidate vectors use `RaBitQ` quantization (8x smaller)
    /// - Final reranking uses full precision for accuracy
    ///
    /// # Arguments
    /// * `dimensions` - Vector dimensionality
    /// * `params` - HNSW construction parameters
    /// * `distance_fn` - Distance function (only L2 supported for asymmetric)
    /// * `rabitq_params` - `RaBitQ` quantization parameters (typically 4-bit)
    ///
    /// # Performance
    /// - Search: 2-3x faster than full precision
    /// - Memory: 8x smaller quantized storage (+ original for reranking)
    /// - Recall: 98%+ with reranking
    ///
    /// # Example
    /// ```ignore
    /// let params = HNSWParams::default();
    /// let rabitq = RaBitQParams::bits4(); // 4-bit, 8x compression
    /// let index = HNSWIndex::new_with_asymmetric(128, params, DistanceFunction::L2, rabitq)?;
    /// ```
    pub fn new_with_asymmetric(
        dimensions: usize,
        params: HNSWParams,
        distance_fn: DistanceFunction,
        rabitq_params: RaBitQParams,
    ) -> Result<Self> {
        params.validate().map_err(HNSWError::InvalidParams)?;

        // RaBitQ asymmetric search only supports L2 distance
        if !matches!(distance_fn, DistanceFunction::L2) {
            return Err(HNSWError::InvalidParams(
                "Asymmetric search only supports L2 distance function".to_string(),
            ));
        }

        let vectors = VectorStorage::new_rabitq_quantized(dimensions, rabitq_params);
        let neighbors = GraphStorage::new(params.max_level as usize);

        Ok(Self {
            nodes: Vec::new(),
            neighbors,
            vectors,
            entry_point: None,
            params,
            distance_fn,
            rng_state: params.seed,
        })
    }

    /// Create a new HNSW index with SQ8 (Scalar Quantization)
    ///
    /// SQ8 compresses f32 → u8 (4x smaller) and uses direct SIMD operations
    /// for ~2x faster search than full precision.
    ///
    /// # Arguments
    /// * `dimensions` - Vector dimensionality
    /// * `params` - HNSW parameters (m, `ef_construction`, `ef_search`)
    /// * `distance_fn` - Distance function (only L2 supported for SQ8)
    ///
    /// # Example
    /// ```ignore
    /// let params = HNSWParams::default();
    /// let index = HNSWIndex::new_with_sq8(768, params, DistanceFunction::L2)?;
    /// ```
    pub fn new_with_sq8(
        dimensions: usize,
        params: HNSWParams,
        distance_fn: DistanceFunction,
    ) -> Result<Self> {
        params.validate().map_err(HNSWError::InvalidParams)?;

        // SQ8 asymmetric search only supports L2 distance
        if !matches!(distance_fn, DistanceFunction::L2) {
            return Err(HNSWError::InvalidParams(
                "SQ8 asymmetric search only supports L2 distance function".to_string(),
            ));
        }

        let vectors = VectorStorage::new_sq8_quantized(dimensions);
        let neighbors = GraphStorage::new(params.max_level as usize);

        Ok(Self {
            nodes: Vec::new(),
            neighbors,
            vectors,
            entry_point: None,
            params,
            distance_fn,
            rng_state: params.seed,
        })
    }

    /// Create a new HNSW index with binary (1-bit) quantization
    ///
    /// Uses SIMD-optimized Hamming distance for fast search.
    ///
    /// # Performance
    /// - Search: 2-4x faster than SQ8 (SIMD Hamming is extremely fast)
    /// - Memory: 32x smaller quantized storage (+ original for reranking)
    /// - Recall: ~85% raw, ~95-98% with reranking
    ///
    /// # Example
    /// ```ignore
    /// let params = HNSWParams::default();
    /// let index = HNSWIndex::new_with_binary(768, params, DistanceFunction::L2)?;
    /// ```
    pub fn new_with_binary(
        dimensions: usize,
        params: HNSWParams,
        distance_fn: DistanceFunction,
    ) -> Result<Self> {
        params.validate().map_err(HNSWError::InvalidParams)?;

        // Binary quantization only supports L2 distance
        if !matches!(distance_fn, DistanceFunction::L2) {
            return Err(HNSWError::InvalidParams(
                "Binary quantization only supports L2 distance function".to_string(),
            ));
        }

        let vectors = VectorStorage::new_binary_quantized(dimensions, true);
        let neighbors = GraphStorage::new(params.max_level as usize);

        Ok(Self {
            nodes: Vec::new(),
            neighbors,
            vectors,
            entry_point: None,
            params,
            distance_fn,
            rng_state: params.seed,
        })
    }

    // =========================================================================
    // Getters
    // =========================================================================

    /// Check if this index uses asymmetric search (`RaBitQ` or `SQ8`)
    #[must_use]
    pub fn is_asymmetric(&self) -> bool {
        self.vectors.is_asymmetric()
    }

    /// Check if this index uses SQ8 quantization
    #[must_use]
    pub fn is_sq8(&self) -> bool {
        self.vectors.is_sq8()
    }

    /// Train the quantizer from sample vectors
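    ///
    /// A hedged usage sketch (the sample set and the helper that produces it are
    /// illustrative, not defined in this file):
    /// ```ignore
    /// let mut index = HNSWIndex::new_with_sq8(768, HNSWParams::default(), DistanceFunction::L2)?;
    /// let samples: Vec<Vec<f32>> = load_representative_vectors(); // hypothetical helper
    /// index.train_quantizer(&samples)?;
    /// ```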
    pub fn train_quantizer(&mut self, sample_vectors: &[Vec<f32>]) -> Result<()> {
        self.vectors
            .train_quantization(sample_vectors)
            .map_err(HNSWError::InvalidParams)
    }

    /// Get number of vectors in index
    #[must_use]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Check if index is empty
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Get dimensions
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.vectors.dimensions()
    }

    /// Get a vector by ID (full precision)
    ///
    /// Returns None if the ID is invalid or out of bounds.
    #[must_use]
    pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
        self.vectors.get(id)
    }

    /// Get entry point
    #[must_use]
    pub fn entry_point(&self) -> Option<u32> {
        self.entry_point
    }

    /// Get node level
    #[must_use]
    pub fn node_level(&self, node_id: u32) -> Option<u8> {
        self.nodes.get(node_id as usize).map(|n| n.level)
    }

    /// Get neighbor count for a node at a level
    #[must_use]
    pub fn neighbor_count(&self, node_id: u32, level: u8) -> usize {
        self.neighbors.get_neighbors(node_id, level).len()
    }

    /// Get HNSW parameters
    #[must_use]
    pub fn params(&self) -> &HNSWParams {
        &self.params
    }

    /// Get neighbors at level 0 for a node
    ///
    /// Level 0 has the most connections (M*2) and is used for graph merging.
    #[must_use]
    pub fn get_neighbors_level0(&self, node_id: u32) -> Vec<u32> {
        self.neighbors.get_neighbors(node_id, 0)
    }

    // =========================================================================
    // Internal helpers
    // =========================================================================

    /// Assign a random level to a new node
    ///
    /// Uses exponential decay: P(level >= l) = (1/M)^l, so most nodes land at
    /// level 0 and progressively fewer appear at higher levels.
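    ///
    /// Worked example (assuming the conventional `ml = 1 / ln(M)` scaling used below):
    /// with M = 16, roughly 1/16 ≈ 6.25% of nodes reach level >= 1 and
    /// 1/256 ≈ 0.39% reach level >= 2.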
    pub(super) fn random_level(&mut self) -> u8 {
        // Simple LCG for deterministic random numbers
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        let rand_val = (self.rng_state >> 32) as f32 / u32::MAX as f32;

        // Exponential distribution: level = floor(-ln(uniform) * ml), with ml = 1/ln(M)
        let level = (-rand_val.ln() * self.params.ml) as u8;
        level.min(self.params.max_level - 1)
    }

    // =========================================================================
    // Distance functions
    // =========================================================================

    /// Distance between two stored nodes for ordering comparisons
    ///
    /// For quantized storage (SQ8/RaBitQ), tries asymmetric distance using a
    /// dequantized `id_a` against the quantized `id_b`; otherwise falls back
    /// to full precision.
    #[inline]
    pub(super) fn distance_between_cmp(&self, id_a: u32, id_b: u32) -> Result<f32> {
        // Try asymmetric distance first (for SQ8/RaBitQ - use id_b as quantized candidate)
        if let Some(vec_a) = self.vectors.get_dequantized(id_a) {
            if let Some(dist) = self.vectors.distance_asymmetric_l2(&vec_a, id_b) {
                return Ok(dist);
            }
        }
        // Fallback to full precision
        let vec_a = self
            .vectors
            .get(id_a)
            .ok_or(HNSWError::VectorNotFound(id_a))?;
        let vec_b = self
            .vectors
            .get(id_b)
            .ok_or(HNSWError::VectorNotFound(id_b))?;
        Ok(self.distance_fn.distance_for_comparison(vec_a, vec_b))
    }

    /// Distance from query to node for ordering comparisons
    ///
    /// Tries asymmetric distance first (for SQ8/RaBitQ), falls back to full precision.
    #[inline(always)]
    pub(super) fn distance_cmp(&self, query: &[f32], id: u32) -> Result<f32> {
        // Try asymmetric distance first (for SQ8/RaBitQ storage)
        if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
            return Ok(dist);
        }
        // Fallback to full precision
        let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
        Ok(self.distance_fn.distance_for_comparison(query, vec))
    }

    /// Monomorphized distance computation (static dispatch, no match)
    ///
    /// Critical for x86/ARM servers where branch misprediction hurts performance.
    /// The Distance trait enables compile-time specialization.
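    ///
    /// Hypothetical call-site sketch (the concrete `Distance` implementor name is
    /// assumed for illustration and is not defined in this module):
    /// ```ignore
    /// // `L2Distance` is a placeholder for whatever type implements `Distance`
    /// let dist = self.distance_cmp_mono::<L2Distance>(query, candidate_id)?;
    /// ```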
    #[inline(always)]
    pub(super) fn distance_cmp_mono<D: Distance>(&self, query: &[f32], id: u32) -> Result<f32> {
        // Try asymmetric distance first (for SQ8/RaBitQ storage)
        if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
            return Ok(dist);
        }
        // Fallback to full precision with static dispatch
        let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
        Ok(D::distance(query, vec))
    }

    /// Distance from query to node using full precision (f32-to-f32)
    ///
    /// Used during graph construction where quantization noise hurts graph quality.
    /// For RaBitQ, uses stored originals. For SQ8, dequantizes.
    #[inline]
    pub(super) fn distance_cmp_full_precision(&self, query: &[f32], id: u32) -> Result<f32> {
        // Always use dequantized/original vectors for full precision comparison
        let vec = self
            .vectors
            .get_dequantized(id)
            .ok_or(HNSWError::VectorNotFound(id))?;
        Ok(self.distance_fn.distance_for_comparison(query, &vec))
    }

    /// Actual distance (with sqrt for L2)
    #[inline]
    pub(super) fn distance_exact(&self, query: &[f32], id: u32) -> Result<f32> {
        // Try asymmetric distance first (for SQ8/RaBitQ storage)
        if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
            return Ok(dist.sqrt());
        }
        let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
        Ok(self.distance_fn.distance(query, vec))
    }

    /// Asymmetric distance for `RaBitQ` search (CLOUD MOAT - HOT PATH)
    ///
    /// Query stays full precision, candidate uses quantized representation.
    /// Falls back to regular `distance_cmp` if not using asymmetric storage.
    #[inline]
    pub(super) fn distance_asymmetric(&self, query: &[f32], id: u32) -> Result<f32> {
        // Try asymmetric distance first (for RaBitQ storage)
        if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
            return Ok(dist);
        }

        // Fallback to regular distance for non-RaBitQ storage
        self.distance_cmp(query, id)
    }

    /// L2 distance using decomposition: ||a-b||² = ||a||² + ||b||² - 2⟨a,b⟩
    ///
    /// ~7% faster than direct L2 by pre-computing vector norms during insert.
    /// Query norm is computed once per search and passed in.
    ///
    /// Returns None if decomposition is not available (non-FullPrecision storage).
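    ///
    /// Usage sketch (hedged: assumes `query_norm` is the squared norm ||q||²,
    /// matching the decomposition above; `candidate_id` is illustrative):
    /// ```ignore
    /// let query_norm: f32 = query.iter().map(|x| x * x).sum();
    /// if let Some(d) = self.distance_l2_decomposed(query, query_norm, candidate_id) {
    ///     // `d` can be used for candidate ordering without the generic fallback path
    /// }
    /// ```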
    #[inline(always)]
    pub(super) fn distance_l2_decomposed(
        &self,
        query: &[f32],
        query_norm: f32,
        id: u32,
    ) -> Option<f32> {
        self.vectors.distance_l2_decomposed(query, query_norm, id)
    }

    /// Check if L2 decomposition optimization is available
    ///
    /// Returns true if storage supports L2 decomposition AND distance function is L2.
    #[inline]
    pub(super) fn supports_l2_decomposition(&self) -> bool {
        matches!(self.distance_fn, DistanceFunction::L2) && self.vectors.supports_l2_decomposition()
    }
}