omendb_core/vector/hnsw/index/
mod.rs

1// HNSW Index - Main implementation
2//
3// Architecture:
4// - Flattened index (contiguous nodes, u32 node IDs)
5// - Separate neighbor storage (fetch only when needed)
6// - Cache-optimized layout (64-byte aligned hot data)
7//
8// Module structure:
9// - mod.rs: Core struct, constructors, getters, distance methods
10// - insert.rs: Insert operations (single, batch, graph construction)
11// - search.rs: Search operations (k-NN, filtered, layer-level)
12// - persistence.rs: Save/load to disk
13// - stats.rs: Statistics, memory usage, cache optimization
14
15mod delete;
16mod insert;
17mod persistence;
18mod search;
19mod stats;
20
21#[cfg(test)]
22mod tests;
23
24use super::error::{HNSWError, Result};
25use super::graph_storage::GraphStorage;
26use super::storage::VectorStorage;
27use super::types::{Distance, DistanceFunction, HNSWNode, HNSWParams};
28use crate::compression::RaBitQParams;
29use serde::{Deserialize, Serialize};
30
/// Index statistics for monitoring and debugging
///
/// Snapshot of the index state; serializable so it can be exported to
/// monitoring endpoints or logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total number of vectors in index
    pub num_vectors: usize,

    /// Vector dimensionality
    pub dimensions: usize,

    /// Entry point node ID (None for an empty index)
    pub entry_point: Option<u32>,

    /// Maximum level in the graph
    pub max_level: u8,

    /// Level distribution (count of nodes at each level as their TOP level)
    /// NOTE(review): presumably indexed by level, i.e. `level_distribution[l]`
    /// is the node count whose top level is `l` — confirm in stats.rs.
    pub level_distribution: Vec<usize>,

    /// Average neighbors per node (level 0)
    pub avg_neighbors_l0: f32,

    /// Max neighbors per node (level 0)
    pub max_neighbors_l0: usize,

    /// Memory usage in bytes
    pub memory_bytes: usize,

    /// HNSW parameters the index was built with
    pub params: HNSWParams,

    /// Distance function
    pub distance_function: DistanceFunction,

    /// Whether quantization is enabled
    pub quantization_enabled: bool,
}
67
/// HNSW Index
///
/// Hierarchical graph index for approximate nearest neighbor search.
/// Optimized for cache locality and memory efficiency.
///
/// **Note**: Not Clone due to `GraphStorage` containing non-cloneable backends.
/// Use persistence APIs (save/load) instead of cloning.
#[derive(Debug, Serialize, Deserialize)]
pub struct HNSWIndex {
    /// Node metadata (cache-line aligned); a node's ID is its index here
    pub(super) nodes: Vec<HNSWNode>,

    /// Graph storage (mode-dependent: in-memory or hybrid disk+cache)
    pub(super) neighbors: GraphStorage,

    /// Vector storage (full precision or quantized)
    pub(super) vectors: VectorStorage,

    /// Entry point (top-level node); None until the first insert
    pub(super) entry_point: Option<u32>,

    /// Construction parameters
    pub(super) params: HNSWParams,

    /// Distance function
    pub(super) distance_fn: DistanceFunction,

    /// Random number generator seed state (simple LCG; see `random_level`)
    pub(super) rng_state: u64,
}
98
99impl HNSWIndex {
100    // =========================================================================
101    // Constructors
102    // =========================================================================
103
104    /// Build an HNSWIndex with pre-created vector storage
105    fn build(vectors: VectorStorage, params: HNSWParams, distance_fn: DistanceFunction) -> Self {
106        let neighbors = GraphStorage::new(params.max_level as usize);
107        Self {
108            nodes: Vec::new(),
109            neighbors,
110            vectors,
111            entry_point: None,
112            rng_state: params.seed,
113            params,
114            distance_fn,
115        }
116    }
117
118    /// Validate params and check that distance function is L2 (required for quantized modes)
119    fn validate_l2_required(
120        params: &HNSWParams,
121        distance_fn: DistanceFunction,
122        mode_name: &str,
123    ) -> Result<()> {
124        params.validate().map_err(HNSWError::InvalidParams)?;
125        if !matches!(distance_fn, DistanceFunction::L2) {
126            return Err(HNSWError::InvalidParams(format!(
127                "{mode_name} only supports L2 distance function"
128            )));
129        }
130        Ok(())
131    }
132
133    /// Create a new empty HNSW index
134    ///
135    /// # Arguments
136    /// * `dimensions` - Vector dimensionality
137    /// * `params` - HNSW construction parameters
138    /// * `distance_fn` - Distance function (L2, Cosine, Dot)
139    /// * `use_quantization` - Whether to use binary quantization
140    pub fn new(
141        dimensions: usize,
142        params: HNSWParams,
143        distance_fn: DistanceFunction,
144        use_quantization: bool,
145    ) -> Result<Self> {
146        params.validate().map_err(HNSWError::InvalidParams)?;
147
148        let vectors = if use_quantization {
149            VectorStorage::new_binary_quantized(dimensions, true)
150        } else {
151            VectorStorage::new_full_precision(dimensions)
152        };
153
154        Ok(Self::build(vectors, params, distance_fn))
155    }
156
157    /// Create a new HNSW index with `RaBitQ` asymmetric search (CLOUD MOAT)
158    ///
159    /// This enables 2-3x faster search by using asymmetric distance computation:
160    /// - Query vector stays full precision
161    /// - Candidate vectors use `RaBitQ` quantization (8x smaller)
162    /// - Final reranking uses full precision for accuracy
163    ///
164    /// # Arguments
165    /// * `dimensions` - Vector dimensionality
166    /// * `params` - HNSW construction parameters
167    /// * `distance_fn` - Distance function (only L2 supported for asymmetric)
168    /// * `rabitq_params` - `RaBitQ` quantization parameters (typically 4-bit)
169    ///
170    /// # Performance
171    /// - Search: 2-3x faster than full precision
172    /// - Memory: 8x smaller quantized storage (+ original for reranking)
173    /// - Recall: 98%+ with reranking
174    ///
175    /// # Example
176    /// ```ignore
177    /// let params = HNSWParams::default();
178    /// let rabitq = RaBitQParams::bits4(); // 4-bit, 8x compression
179    /// let index = HNSWIndex::new_with_asymmetric(128, params, DistanceFunction::L2, rabitq)?;
180    /// ```
181    pub fn new_with_asymmetric(
182        dimensions: usize,
183        params: HNSWParams,
184        distance_fn: DistanceFunction,
185        rabitq_params: RaBitQParams,
186    ) -> Result<Self> {
187        Self::validate_l2_required(&params, distance_fn, "RaBitQ asymmetric search")?;
188        let vectors = VectorStorage::new_rabitq_quantized(dimensions, rabitq_params);
189        Ok(Self::build(vectors, params, distance_fn))
190    }
191
192    /// Create new HNSW index with SQ8 (Scalar Quantization)
193    ///
194    /// SQ8 compresses f32 → u8 (4x smaller) and uses direct SIMD operations
195    /// for ~2x faster search than full precision.
196    ///
197    /// # Arguments
198    /// * `dimensions` - Vector dimensionality
199    /// * `params` - HNSW parameters (m, `ef_construction`, `ef_search`)
200    /// * `distance_fn` - Distance function (only L2 supported for SQ8)
201    ///
202    /// # Example
203    /// ```ignore
204    /// let params = HNSWParams::default();
205    /// let index = HNSWIndex::new_with_sq8(768, params, DistanceFunction::L2)?;
206    /// ```
207    pub fn new_with_sq8(
208        dimensions: usize,
209        params: HNSWParams,
210        distance_fn: DistanceFunction,
211    ) -> Result<Self> {
212        Self::validate_l2_required(&params, distance_fn, "SQ8 quantization")?;
213        let vectors = VectorStorage::new_sq8_quantized(dimensions);
214        Ok(Self::build(vectors, params, distance_fn))
215    }
216
217    /// Create new HNSW index with binary (1-bit) quantization
218    ///
219    /// Uses SIMD-optimized Hamming distance for fast search.
220    ///
221    /// # Performance
222    /// - Search: 2-4x faster than SQ8 (SIMD Hamming is extremely fast)
223    /// - Memory: 32x smaller quantized storage (+ original for reranking)
224    /// - Recall: ~85% raw, ~95-98% with reranking
225    ///
226    /// # Example
227    /// ```ignore
228    /// let params = HNSWParams::default();
229    /// let index = HNSWIndex::new_with_binary(768, params, DistanceFunction::L2)?;
230    /// ```
231    pub fn new_with_binary(
232        dimensions: usize,
233        params: HNSWParams,
234        distance_fn: DistanceFunction,
235    ) -> Result<Self> {
236        Self::validate_l2_required(&params, distance_fn, "Binary quantization")?;
237        let vectors = VectorStorage::new_binary_quantized(dimensions, true);
238        Ok(Self::build(vectors, params, distance_fn))
239    }
240
241    // =========================================================================
242    // Getters
243    // =========================================================================
244
    /// Check if this index uses asymmetric search (`RaBitQ` or `SQ8`)
    ///
    /// Asymmetric modes compare a full-precision query against quantized
    /// candidates; delegates to the storage backend.
    #[must_use]
    pub fn is_asymmetric(&self) -> bool {
        self.vectors.is_asymmetric()
    }
250
    /// Check if this index uses SQ8 quantization
    ///
    /// Delegates to the storage backend.
    #[must_use]
    pub fn is_sq8(&self) -> bool {
        self.vectors.is_sq8()
    }
256
257    /// Train the quantizer from sample vectors
258    pub fn train_quantizer(&mut self, sample_vectors: &[Vec<f32>]) -> Result<()> {
259        self.vectors
260            .train_quantization(sample_vectors)
261            .map_err(HNSWError::InvalidParams)
262    }
263
    /// Get number of vectors in index
    ///
    /// One node per vector, so the node count is the vector count.
    #[must_use]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }
269
    /// Check if index is empty (no nodes inserted yet)
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
275
    /// Get dimensions (vector dimensionality, as configured at construction)
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.vectors.dimensions()
    }
281
    /// Get a vector by ID (full precision)
    ///
    /// Returns None if the ID is invalid or out of bounds.
    /// NOTE(review): for quantized storage this goes through `vectors.get`,
    /// not `get_dequantized` — confirm it returns the stored originals there.
    #[must_use]
    pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
        self.vectors.get(id)
    }
289
    /// Get entry point (top-level node), or None for an empty index
    #[must_use]
    pub fn entry_point(&self) -> Option<u32> {
        self.entry_point
    }
295
296    /// Get node level
297    #[must_use]
298    pub fn node_level(&self, node_id: u32) -> Option<u8> {
299        self.nodes.get(node_id as usize).map(|n| n.level)
300    }
301
302    /// Get neighbor count for a node at a level
303    #[must_use]
304    pub fn neighbor_count(&self, node_id: u32, level: u8) -> usize {
305        self.neighbors.get_neighbors(node_id, level).len()
306    }
307
    /// Get HNSW parameters (construction-time settings, borrowed)
    #[must_use]
    pub fn params(&self) -> &HNSWParams {
        &self.params
    }
313
    /// Get neighbors at level 0 for a node
    ///
    /// Level 0 has the most connections (M*2) and is used for graph merging.
    #[must_use]
    pub fn get_neighbors_level0(&self, node_id: u32) -> Vec<u32> {
        self.neighbors.get_neighbors(node_id, 0)
    }
321
322    // =========================================================================
323    // Internal helpers
324    // =========================================================================
325
    /// Assign random level to new node
    ///
    /// Uses exponential decay: P(level = l) = (1/M)^l
    /// This ensures most nodes are at level 0, fewer at higher levels.
    ///
    /// Mutates `rng_state`, so a fixed seed yields a deterministic sequence.
    pub(super) fn random_level(&mut self) -> u8 {
        // Simple LCG for deterministic random numbers (Knuth's MMIX
        // multiplier; wrapping ops make the intended overflow well-defined).
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        // High 32 bits are better mixed than the low bits of an LCG;
        // scale into [0, 1].
        let rand_val = (self.rng_state >> 32) as f32 / u32::MAX as f32;

        // Exponential distribution: -ln(uniform) / ln(M)
        // (`params.ml` presumably holds the precomputed 1/ln(M) — confirm in types.rs).
        // rand_val == 0.0 yields +inf, which the float→u8 cast saturates to
        // 255 before the clamp below — no UB, just the max level.
        let level = (-rand_val.ln() * self.params.ml) as u8;
        // Clamp to the configured ceiling; assumes params.validate() enforced
        // max_level >= 1, otherwise this subtraction would underflow.
        level.min(self.params.max_level - 1)
    }
342
343    // =========================================================================
344    // Distance functions
345    // =========================================================================
346
347    /// Distance between nodes for ordering comparisons
348    ///
349    /// Uses dequantized vectors if storage is quantized (SQ8).
350    #[inline]
351    pub(super) fn distance_between_cmp(&self, id_a: u32, id_b: u32) -> Result<f32> {
352        // Try asymmetric distance first (for SQ8/RaBitQ - use id_b as quantized candidate)
353        if let Some(vec_a) = self.vectors.get_dequantized(id_a) {
354            if let Some(dist) = self.vectors.distance_asymmetric_l2(&vec_a, id_b) {
355                return Ok(dist);
356            }
357        }
358        // Fallback to full precision
359        let vec_a = self
360            .vectors
361            .get(id_a)
362            .ok_or(HNSWError::VectorNotFound(id_a))?;
363        let vec_b = self
364            .vectors
365            .get(id_b)
366            .ok_or(HNSWError::VectorNotFound(id_b))?;
367        Ok(self.distance_fn.distance_for_comparison(vec_a, vec_b))
368    }
369
370    /// Distance from query to node for ordering comparisons
371    ///
372    /// Tries asymmetric distance first (for SQ8/RaBitQ), falls back to full precision.
373    #[inline(always)]
374    pub(super) fn distance_cmp(&self, query: &[f32], id: u32) -> Result<f32> {
375        // Try asymmetric distance first (for SQ8/RaBitQ storage)
376        if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
377            return Ok(dist);
378        }
379        // Fallback to full precision
380        let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
381        Ok(self.distance_fn.distance_for_comparison(query, vec))
382    }
383
384    /// Monomorphized distance computation (static dispatch, no match)
385    ///
386    /// Critical for x86/ARM servers where branch misprediction hurts performance.
387    /// The Distance trait enables compile-time specialization.
388    #[inline(always)]
389    pub(super) fn distance_cmp_mono<D: Distance>(&self, query: &[f32], id: u32) -> Result<f32> {
390        // Try asymmetric distance first (for SQ8/RaBitQ storage)
391        if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
392            return Ok(dist);
393        }
394        // Fallback to full precision with static dispatch
395        let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
396        Ok(D::distance(query, vec))
397    }
398
399    /// Distance from query to node using full precision (f32-to-f32)
400    ///
401    /// Used during graph construction where quantization noise hurts graph quality.
402    /// For RaBitQ, uses stored originals. For SQ8, dequantizes.
403    #[inline]
404    pub(super) fn distance_cmp_full_precision(&self, query: &[f32], id: u32) -> Result<f32> {
405        // Always use dequantized/original vectors for full precision comparison
406        let vec = self
407            .vectors
408            .get_dequantized(id)
409            .ok_or(HNSWError::VectorNotFound(id))?;
410        Ok(self.distance_fn.distance_for_comparison(query, &vec))
411    }
412
413    /// Actual distance (with sqrt for L2)
414    #[inline]
415    pub(super) fn distance_exact(&self, query: &[f32], id: u32) -> Result<f32> {
416        // For SQ8/RaBitQ: use asymmetric distance (returns squared L2)
417        // For Binary: skip asymmetric (hamming is not L2), use original vectors
418        if !self.vectors.is_binary_quantized() {
419            if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
420                return Ok(dist.sqrt());
421            }
422        }
423        let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
424        Ok(self.distance_fn.distance(query, vec))
425    }
426
    /// L2 distance using decomposition: ||a-b||² = ||a||² + ||b||² - 2⟨a,b⟩
    ///
    /// ~7% faster than direct L2 by pre-computing vector norms during insert.
    /// Query norm is computed once per search and passed in.
    ///
    /// Returns None if decomposition is not available (non-FullPrecision storage).
    /// Callers should gate on `supports_l2_decomposition` or handle `None`.
    #[inline(always)]
    pub(super) fn distance_l2_decomposed(
        &self,
        query: &[f32],
        query_norm: f32,
        id: u32,
    ) -> Option<f32> {
        // Pure delegation: the storage backend owns the precomputed norms.
        self.vectors.distance_l2_decomposed(query, query_norm, id)
    }
442
443    /// Check if L2 decomposition optimization is available
444    ///
445    /// Returns true if storage supports L2 decomposition AND distance function is L2.
446    #[inline]
447    pub(super) fn supports_l2_decomposition(&self) -> bool {
448        matches!(self.distance_fn, DistanceFunction::L2) && self.vectors.supports_l2_decomposition()
449    }
450}