// omendb_core/vector/hnsw/index/mod.rs
// HNSW Index - Main implementation
//
// Architecture:
// - Flattened index (contiguous nodes, u32 node IDs)
// - Separate neighbor storage (fetch only when needed)
// - Cache-optimized layout (64-byte aligned hot data)
//
// Module structure:
// - mod.rs: Core struct, constructors, getters, distance methods
// - insert.rs: Insert operations (single, batch, graph construction)
// - search.rs: Search operations (k-NN, filtered, layer-level)
// - persistence.rs: Save/load to disk
// - stats.rs: Statistics, memory usage, cache optimization
15mod delete;
16mod insert;
17mod persistence;
18mod search;
19mod stats;
20
21#[cfg(test)]
22mod tests;
23
24use super::error::{HNSWError, Result};
25use super::graph_storage::GraphStorage;
26use super::storage::VectorStorage;
27use super::types::{Distance, DistanceFunction, HNSWNode, HNSWParams};
28use crate::compression::RaBitQParams;
29use serde::{Deserialize, Serialize};
30
31/// Index statistics for monitoring and debugging
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct IndexStats {
34 /// Total number of vectors in index
35 pub num_vectors: usize,
36
37 /// Vector dimensionality
38 pub dimensions: usize,
39
40 /// Entry point node ID
41 pub entry_point: Option<u32>,
42
43 /// Maximum level in the graph
44 pub max_level: u8,
45
46 /// Level distribution (count of nodes at each level as their TOP level)
47 pub level_distribution: Vec<usize>,
48
49 /// Average neighbors per node (level 0)
50 pub avg_neighbors_l0: f32,
51
52 /// Max neighbors per node (level 0)
53 pub max_neighbors_l0: usize,
54
55 /// Memory usage in bytes
56 pub memory_bytes: usize,
57
58 /// HNSW parameters
59 pub params: HNSWParams,
60
61 /// Distance function
62 pub distance_function: DistanceFunction,
63
64 /// Whether quantization is enabled
65 pub quantization_enabled: bool,
66}
67
/// HNSW Index
///
/// Hierarchical graph index for approximate nearest neighbor search.
/// Optimized for cache locality and memory efficiency.
///
/// **Note**: Not Clone due to `GraphStorage` containing non-cloneable backends.
/// Use persistence APIs (save/load) instead of cloning.
///
/// Field order is part of the serialized layout for positional formats
/// (e.g. bincode); do not reorder fields.
#[derive(Debug, Serialize, Deserialize)]
pub struct HNSWIndex {
    /// Node metadata (cache-line aligned)
    pub(super) nodes: Vec<HNSWNode>,

    /// Graph storage (mode-dependent: in-memory or hybrid disk+cache)
    pub(super) neighbors: GraphStorage,

    /// Vector storage (full precision or quantized)
    pub(super) vectors: VectorStorage,

    /// Entry point (top-level node)
    pub(super) entry_point: Option<u32>,

    /// Construction parameters
    pub(super) params: HNSWParams,

    /// Distance function
    pub(super) distance_fn: DistanceFunction,

    /// Random number generator seed state
    /// (advanced by the LCG in `random_level`; initialized from `params.seed`)
    pub(super) rng_state: u64,
}
98
99impl HNSWIndex {
100 // =========================================================================
101 // Constructors
102 // =========================================================================
103
104 /// Build an HNSWIndex with pre-created vector storage
105 fn build(vectors: VectorStorage, params: HNSWParams, distance_fn: DistanceFunction) -> Self {
106 let neighbors = GraphStorage::new(params.max_level as usize);
107 Self {
108 nodes: Vec::new(),
109 neighbors,
110 vectors,
111 entry_point: None,
112 rng_state: params.seed,
113 params,
114 distance_fn,
115 }
116 }
117
118 /// Validate params and check that distance function is L2 (required for quantized modes)
119 fn validate_l2_required(
120 params: &HNSWParams,
121 distance_fn: DistanceFunction,
122 mode_name: &str,
123 ) -> Result<()> {
124 params.validate().map_err(HNSWError::InvalidParams)?;
125 if !matches!(distance_fn, DistanceFunction::L2) {
126 return Err(HNSWError::InvalidParams(format!(
127 "{mode_name} only supports L2 distance function"
128 )));
129 }
130 Ok(())
131 }
132
133 /// Create a new empty HNSW index
134 ///
135 /// # Arguments
136 /// * `dimensions` - Vector dimensionality
137 /// * `params` - HNSW construction parameters
138 /// * `distance_fn` - Distance function (L2, Cosine, Dot)
139 /// * `use_quantization` - Whether to use binary quantization
140 pub fn new(
141 dimensions: usize,
142 params: HNSWParams,
143 distance_fn: DistanceFunction,
144 use_quantization: bool,
145 ) -> Result<Self> {
146 params.validate().map_err(HNSWError::InvalidParams)?;
147
148 let vectors = if use_quantization {
149 VectorStorage::new_binary_quantized(dimensions, true)
150 } else {
151 VectorStorage::new_full_precision(dimensions)
152 };
153
154 Ok(Self::build(vectors, params, distance_fn))
155 }
156
157 /// Create a new HNSW index with `RaBitQ` asymmetric search (CLOUD MOAT)
158 ///
159 /// This enables 2-3x faster search by using asymmetric distance computation:
160 /// - Query vector stays full precision
161 /// - Candidate vectors use `RaBitQ` quantization (8x smaller)
162 /// - Final reranking uses full precision for accuracy
163 ///
164 /// # Arguments
165 /// * `dimensions` - Vector dimensionality
166 /// * `params` - HNSW construction parameters
167 /// * `distance_fn` - Distance function (only L2 supported for asymmetric)
168 /// * `rabitq_params` - `RaBitQ` quantization parameters (typically 4-bit)
169 ///
170 /// # Performance
171 /// - Search: 2-3x faster than full precision
172 /// - Memory: 8x smaller quantized storage (+ original for reranking)
173 /// - Recall: 98%+ with reranking
174 ///
175 /// # Example
176 /// ```ignore
177 /// let params = HNSWParams::default();
178 /// let rabitq = RaBitQParams::bits4(); // 4-bit, 8x compression
179 /// let index = HNSWIndex::new_with_asymmetric(128, params, DistanceFunction::L2, rabitq)?;
180 /// ```
181 pub fn new_with_asymmetric(
182 dimensions: usize,
183 params: HNSWParams,
184 distance_fn: DistanceFunction,
185 rabitq_params: RaBitQParams,
186 ) -> Result<Self> {
187 Self::validate_l2_required(¶ms, distance_fn, "RaBitQ asymmetric search")?;
188 let vectors = VectorStorage::new_rabitq_quantized(dimensions, rabitq_params);
189 Ok(Self::build(vectors, params, distance_fn))
190 }
191
192 /// Create new HNSW index with SQ8 (Scalar Quantization)
193 ///
194 /// SQ8 compresses f32 → u8 (4x smaller) and uses direct SIMD operations
195 /// for ~2x faster search than full precision.
196 ///
197 /// # Arguments
198 /// * `dimensions` - Vector dimensionality
199 /// * `params` - HNSW parameters (m, `ef_construction`, `ef_search`)
200 /// * `distance_fn` - Distance function (only L2 supported for SQ8)
201 ///
202 /// # Example
203 /// ```ignore
204 /// let params = HNSWParams::default();
205 /// let index = HNSWIndex::new_with_sq8(768, params, DistanceFunction::L2)?;
206 /// ```
207 pub fn new_with_sq8(
208 dimensions: usize,
209 params: HNSWParams,
210 distance_fn: DistanceFunction,
211 ) -> Result<Self> {
212 Self::validate_l2_required(¶ms, distance_fn, "SQ8 quantization")?;
213 let vectors = VectorStorage::new_sq8_quantized(dimensions);
214 Ok(Self::build(vectors, params, distance_fn))
215 }
216
217 /// Create new HNSW index with binary (1-bit) quantization
218 ///
219 /// Uses SIMD-optimized Hamming distance for fast search.
220 ///
221 /// # Performance
222 /// - Search: 2-4x faster than SQ8 (SIMD Hamming is extremely fast)
223 /// - Memory: 32x smaller quantized storage (+ original for reranking)
224 /// - Recall: ~85% raw, ~95-98% with reranking
225 ///
226 /// # Example
227 /// ```ignore
228 /// let params = HNSWParams::default();
229 /// let index = HNSWIndex::new_with_binary(768, params, DistanceFunction::L2)?;
230 /// ```
231 pub fn new_with_binary(
232 dimensions: usize,
233 params: HNSWParams,
234 distance_fn: DistanceFunction,
235 ) -> Result<Self> {
236 Self::validate_l2_required(¶ms, distance_fn, "Binary quantization")?;
237 let vectors = VectorStorage::new_binary_quantized(dimensions, true);
238 Ok(Self::build(vectors, params, distance_fn))
239 }
240
241 // =========================================================================
242 // Getters
243 // =========================================================================
244
    /// Check if this index uses asymmetric search (`RaBitQ` or `SQ8`)
    ///
    /// Asymmetric search keeps the query full precision while candidates are
    /// quantized; delegates to the underlying `VectorStorage`.
    #[must_use]
    pub fn is_asymmetric(&self) -> bool {
        self.vectors.is_asymmetric()
    }
250
    /// Check if this index uses SQ8 quantization
    ///
    /// Delegates to the underlying `VectorStorage`.
    #[must_use]
    pub fn is_sq8(&self) -> bool {
        self.vectors.is_sq8()
    }
256
    /// Train the quantizer from sample vectors
    ///
    /// # Errors
    /// Returns `HNSWError::InvalidParams` if the underlying storage rejects
    /// the training data.
    pub fn train_quantizer(&mut self, sample_vectors: &[Vec<f32>]) -> Result<()> {
        self.vectors
            .train_quantization(sample_vectors)
            .map_err(HNSWError::InvalidParams)
    }
263
    /// Get number of vectors in index
    #[must_use]
    pub fn len(&self) -> usize {
        // One metadata node is kept per inserted vector.
        self.nodes.len()
    }
269
    /// Check if index is empty
    ///
    /// Equivalent to `self.len() == 0`.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
275
    /// Get dimensions
    ///
    /// Delegates to the vector storage's configured dimensionality.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.vectors.dimensions()
    }
281
    /// Get a vector by ID (full precision)
    ///
    /// Returns None if the ID is invalid or out of bounds.
    #[must_use]
    pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
        self.vectors.get(id)
    }
289
    /// Get entry point
    ///
    /// `None` until the first node is inserted into the graph.
    #[must_use]
    pub fn entry_point(&self) -> Option<u32> {
        self.entry_point
    }
295
296 /// Get node level
297 #[must_use]
298 pub fn node_level(&self, node_id: u32) -> Option<u8> {
299 self.nodes.get(node_id as usize).map(|n| n.level)
300 }
301
    /// Get neighbor count for a node at a level
    ///
    /// Counts the edges stored for `node_id` at `level` in the graph storage.
    #[must_use]
    pub fn neighbor_count(&self, node_id: u32, level: u8) -> usize {
        self.neighbors.get_neighbors(node_id, level).len()
    }
307
    /// Get HNSW parameters
    ///
    /// Borrows the construction parameters this index was created with.
    #[must_use]
    pub fn params(&self) -> &HNSWParams {
        &self.params
    }
313
    /// Get neighbors at level 0 for a node
    ///
    /// Level 0 has the most connections (M*2) and is used for graph merging.
    #[must_use]
    pub fn get_neighbors_level0(&self, node_id: u32) -> Vec<u32> {
        self.neighbors.get_neighbors(node_id, 0)
    }
321
322 // =========================================================================
323 // Internal helpers
324 // =========================================================================
325
326 /// Assign random level to new node
327 ///
328 /// Uses exponential decay: P(level = l) = (1/M)^l
329 /// This ensures most nodes are at level 0, fewer at higher levels.
330 pub(super) fn random_level(&mut self) -> u8 {
331 // Simple LCG for deterministic random numbers
332 self.rng_state = self
333 .rng_state
334 .wrapping_mul(6_364_136_223_846_793_005)
335 .wrapping_add(1);
336 let rand_val = (self.rng_state >> 32) as f32 / u32::MAX as f32;
337
338 // Exponential distribution: -ln(uniform) / ln(M)
339 let level = (-rand_val.ln() * self.params.ml) as u8;
340 level.min(self.params.max_level - 1)
341 }
342
343 // =========================================================================
344 // Distance functions
345 // =========================================================================
346
347 /// Distance between nodes for ordering comparisons
348 ///
349 /// Uses dequantized vectors if storage is quantized (SQ8).
350 #[inline]
351 pub(super) fn distance_between_cmp(&self, id_a: u32, id_b: u32) -> Result<f32> {
352 // Try asymmetric distance first (for SQ8/RaBitQ - use id_b as quantized candidate)
353 if let Some(vec_a) = self.vectors.get_dequantized(id_a) {
354 if let Some(dist) = self.vectors.distance_asymmetric_l2(&vec_a, id_b) {
355 return Ok(dist);
356 }
357 }
358 // Fallback to full precision
359 let vec_a = self
360 .vectors
361 .get(id_a)
362 .ok_or(HNSWError::VectorNotFound(id_a))?;
363 let vec_b = self
364 .vectors
365 .get(id_b)
366 .ok_or(HNSWError::VectorNotFound(id_b))?;
367 Ok(self.distance_fn.distance_for_comparison(vec_a, vec_b))
368 }
369
370 /// Distance from query to node for ordering comparisons
371 ///
372 /// Tries asymmetric distance first (for SQ8/RaBitQ), falls back to full precision.
373 #[inline(always)]
374 pub(super) fn distance_cmp(&self, query: &[f32], id: u32) -> Result<f32> {
375 // Try asymmetric distance first (for SQ8/RaBitQ storage)
376 if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
377 return Ok(dist);
378 }
379 // Fallback to full precision
380 let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
381 Ok(self.distance_fn.distance_for_comparison(query, vec))
382 }
383
384 /// Monomorphized distance computation (static dispatch, no match)
385 ///
386 /// Critical for x86/ARM servers where branch misprediction hurts performance.
387 /// The Distance trait enables compile-time specialization.
388 #[inline(always)]
389 pub(super) fn distance_cmp_mono<D: Distance>(&self, query: &[f32], id: u32) -> Result<f32> {
390 // Try asymmetric distance first (for SQ8/RaBitQ storage)
391 if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
392 return Ok(dist);
393 }
394 // Fallback to full precision with static dispatch
395 let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
396 Ok(D::distance(query, vec))
397 }
398
399 /// Distance from query to node using full precision (f32-to-f32)
400 ///
401 /// Used during graph construction where quantization noise hurts graph quality.
402 /// For RaBitQ, uses stored originals. For SQ8, dequantizes.
403 #[inline]
404 pub(super) fn distance_cmp_full_precision(&self, query: &[f32], id: u32) -> Result<f32> {
405 // Always use dequantized/original vectors for full precision comparison
406 let vec = self
407 .vectors
408 .get_dequantized(id)
409 .ok_or(HNSWError::VectorNotFound(id))?;
410 Ok(self.distance_fn.distance_for_comparison(query, &vec))
411 }
412
413 /// Actual distance (with sqrt for L2)
414 #[inline]
415 pub(super) fn distance_exact(&self, query: &[f32], id: u32) -> Result<f32> {
416 // For SQ8/RaBitQ: use asymmetric distance (returns squared L2)
417 // For Binary: skip asymmetric (hamming is not L2), use original vectors
418 if !self.vectors.is_binary_quantized() {
419 if let Some(dist) = self.vectors.distance_asymmetric_l2(query, id) {
420 return Ok(dist.sqrt());
421 }
422 }
423 let vec = self.vectors.get(id).ok_or(HNSWError::VectorNotFound(id))?;
424 Ok(self.distance_fn.distance(query, vec))
425 }
426
    /// L2 distance using decomposition: ||a-b||² = ||a||² + ||b||² - 2⟨a,b⟩
    ///
    /// ~7% faster than direct L2 by pre-computing vector norms during insert.
    /// Query norm is computed once per search and passed in.
    ///
    /// Returns None if decomposition is not available (non-FullPrecision storage).
    #[inline(always)]
    pub(super) fn distance_l2_decomposed(
        &self,
        query: &[f32],
        query_norm: f32,
        id: u32,
    ) -> Option<f32> {
        // Thin delegation; availability is decided by the storage backend.
        self.vectors.distance_l2_decomposed(query, query_norm, id)
    }
442
443 /// Check if L2 decomposition optimization is available
444 ///
445 /// Returns true if storage supports L2 decomposition AND distance function is L2.
446 #[inline]
447 pub(super) fn supports_l2_decomposition(&self) -> bool {
448 matches!(self.distance_fn, DistanceFunction::L2) && self.vectors.supports_l2_decomposition()
449 }
450}