omendb_core/vector/hnsw/
merge.rs

1// HNSW Graph Merging - IGTM Algorithm
2//
3// Implements Iterative Greedy Tree Merging from:
4// - Elasticsearch Labs: "Speeding up merging of HNSW graphs" (2024-2025)
5// - arXiv:2505.16064 (May 2025) - MERGE-HNSW algorithms
6//
7// Expected speedup: 1.28-1.72x for batch inserts (validated in Lucene 10.2)
8
9use super::error::{HNSWError, Result};
10use super::index::HNSWIndex;
11use std::collections::{HashMap, HashSet};
12use std::time::{Duration, Instant};
13use tracing::{debug, info, instrument, warn};
14
/// Tuning knobs for a graph merge.
#[derive(Debug, Clone)]
pub struct MergeConfig {
    /// How many join-set neighbors a vertex needs before it counts as covered.
    /// Default: 2 (value attributed to the IGTM paper)
    pub min_coverage: usize,

    /// Search width (`ef`) for the fast insertion path; meant to be lower than
    /// `ef_construction` so the fast path stays cheap.
    /// Default: `None`, in which case the merger derives it as `ef_construction` / 2
    pub fast_ef: Option<usize>,

    /// Requests parallel join set computation.
    /// Default: `true`
    ///
    /// NOTE(review): no code visible in this file reads this flag.
    pub parallel_join_set: bool,
}

impl Default for MergeConfig {
    /// Coverage of 2, no explicit `fast_ef`, parallel join set requested.
    fn default() -> Self {
        MergeConfig {
            parallel_join_set: true,
            fast_ef: None,
            min_coverage: 2,
        }
    }
}
39
/// Statistics from a merge operation
#[derive(Clone, Debug)]
pub struct MergeStats {
    /// Total vectors merged from small graph
    pub vectors_merged: usize,

    /// Size of the join set (strategic vertices inserted first)
    pub join_set_size: usize,

    /// Time spent computing join set
    pub join_set_duration: Duration,

    /// Time spent inserting join set
    pub join_set_insert_duration: Duration,

    /// Time spent inserting remaining vectors
    pub remaining_insert_duration: Duration,

    /// Total merge duration
    pub total_duration: Duration,

    /// Vectors inserted using fast path (entry points from join set)
    pub fast_path_inserts: usize,

    /// Vectors inserted using fallback (standard insert)
    pub fallback_inserts: usize,
}

impl MergeStats {
    /// Estimate the speedup over naive insertion from a simple cost model.
    ///
    /// The model charges every join-set vector one full search and every other
    /// vector one fast search, with a fast search assumed to cost a fifth of a
    /// full one:
    ///
    /// `speedup = 1 / (join_set_ratio + remaining_ratio * 0.2)`
    ///
    /// This is a model-based figure, not a measurement; the timing fields are
    /// not consulted.
    ///
    /// # Returns
    /// A value in `[1.0, 5.0]`. An empty merge (`vectors_merged == 0`) reports
    /// `1.0`, i.e. no speedup, instead of the model's best case.
    #[must_use]
    pub fn estimated_speedup(&self) -> f64 {
        // Assumed cost of a fast-path insert relative to a full insert (~5x faster).
        const FAST_PATH_RELATIVE_COST: f64 = 0.2;

        // Nothing was merged, so there is nothing to speed up.
        if self.vectors_merged == 0 {
            return 1.0;
        }

        // The fields are public, so guard against hand-built stats where the join
        // set is larger than the merge; otherwise the remaining ratio goes negative.
        let join_set_ratio = (self.join_set_size as f64 / self.vectors_merged as f64).min(1.0);
        let remaining_ratio = 1.0 - join_set_ratio;

        1.0 / (join_set_ratio + remaining_ratio * FAST_PATH_RELATIVE_COST)
    }
}
84
85/// HNSW Graph Merger using IGTM algorithm
86///
87/// Merges a small graph into a large graph using strategic vertex selection.
88/// Expected 1.3-1.7x speedup over naive insertion.
89pub struct GraphMerger {
90    config: MergeConfig,
91}
92
93impl GraphMerger {
94    /// Create a new graph merger with default configuration
95    #[must_use]
96    pub fn new() -> Self {
97        Self {
98            config: MergeConfig::default(),
99        }
100    }
101
102    /// Create a new graph merger with custom configuration
103    #[must_use]
104    pub fn with_config(config: MergeConfig) -> Self {
105        Self { config }
106    }
107
108    /// Merge a small graph into a large graph using IGTM algorithm
109    ///
110    /// # Algorithm
111    /// 1. Compute join set: Find minimal vertex subset that covers all vertices
112    ///    (every vertex has ≥`min_coverage` neighbors in the join set)
113    /// 2. Insert join set into large graph using standard insertion
114    /// 3. For remaining vertices, use join set neighbors as entry points for fast insertion
115    ///
116    /// # Arguments
117    /// * `large` - Target graph (will be modified)
118    /// * `small` - Source graph (vectors will be moved)
119    ///
120    /// # Returns
121    /// Merge statistics including timing breakdown
122    #[instrument(skip(self, large, small), fields(large_size = large.len(), small_size = small.len()))]
123    pub fn merge_graphs(&self, large: &mut HNSWIndex, small: &HNSWIndex) -> Result<MergeStats> {
124        let total_start = Instant::now();
125        let small_size = small.len();
126
127        if small_size == 0 {
128            return Ok(MergeStats {
129                vectors_merged: 0,
130                join_set_size: 0,
131                join_set_duration: Duration::ZERO,
132                join_set_insert_duration: Duration::ZERO,
133                remaining_insert_duration: Duration::ZERO,
134                total_duration: total_start.elapsed(),
135                fast_path_inserts: 0,
136                fallback_inserts: 0,
137            });
138        }
139
140        info!(
141            large_size = large.len(),
142            small_size = small_size,
143            "Starting IGTM graph merge"
144        );
145
146        // Phase 1: Compute join set
147        let join_set_start = Instant::now();
148        let join_set = self.compute_join_set(small);
149        let join_set_duration = join_set_start.elapsed();
150
151        debug!(
152            join_set_size = join_set.len(),
153            coverage_target = self.config.min_coverage,
154            duration_ms = join_set_duration.as_millis(),
155            "Join set computed"
156        );
157
158        // Phase 2: Insert join set vectors
159        let join_insert_start = Instant::now();
160        for &node_id in &join_set {
161            let vector = small
162                .get_vector(node_id)
163                .ok_or(HNSWError::VectorNotFound(node_id))?;
164            large.insert(vector)?;
165        }
166        let join_set_insert_duration = join_insert_start.elapsed();
167
168        debug!(
169            inserted = join_set.len(),
170            duration_ms = join_set_insert_duration.as_millis(),
171            "Join set inserted"
172        );
173
174        // Phase 3: Insert remaining vectors using fast path
175        let remaining_start = Instant::now();
176        let mut fast_path_inserts = 0;
177        let mut fallback_inserts = 0;
178
179        let fast_ef = self
180            .config
181            .fast_ef
182            .unwrap_or(large.params().ef_construction / 2);
183
184        for node_id in 0..small.len() as u32 {
185            if join_set.contains(&node_id) {
186                continue;
187            }
188
189            let vector = small
190                .get_vector(node_id)
191                .ok_or(HNSWError::VectorNotFound(node_id))?;
192
193            // Find neighbors of this node that are in the join set
194            let small_neighbors = small.get_neighbors_level0(node_id);
195            let entry_points: Vec<u32> = small_neighbors
196                .iter()
197                .filter(|&&n| join_set.contains(&n))
198                .copied()
199                .collect();
200
201            if entry_points.is_empty() {
202                // Fallback: no join set neighbors, use standard insert
203                large.insert(vector)?;
204                fallback_inserts += 1;
205            } else {
206                // Fast path: use join set neighbors as entry points
207                // These vectors were already inserted, so we can find them in large graph
208                large.insert_with_hints(vector, &entry_points, fast_ef)?;
209                fast_path_inserts += 1;
210            }
211        }
212        let remaining_insert_duration = remaining_start.elapsed();
213
214        let total_duration = total_start.elapsed();
215
216        let stats = MergeStats {
217            vectors_merged: small_size,
218            join_set_size: join_set.len(),
219            join_set_duration,
220            join_set_insert_duration,
221            remaining_insert_duration,
222            total_duration,
223            fast_path_inserts,
224            fallback_inserts,
225        };
226
227        info!(
228            vectors_merged = stats.vectors_merged,
229            join_set_size = stats.join_set_size,
230            fast_path_ratio = format!(
231                "{:.1}%",
232                (stats.fast_path_inserts as f64 / stats.vectors_merged.max(1) as f64) * 100.0
233            ),
234            total_ms = stats.total_duration.as_millis(),
235            estimated_speedup = format!("{:.2}x", stats.estimated_speedup()),
236            "IGTM merge complete"
237        );
238
239        Ok(stats)
240    }
241
242    /// Compute join set using greedy covering algorithm
243    ///
244    /// Finds minimal subset J such that every vertex v has ≥`min_coverage` neighbors in J.
245    /// Uses greedy selection: pick vertex maximizing coverage gain at each step.
246    fn compute_join_set(&self, graph: &HNSWIndex) -> HashSet<u32> {
247        let mut join_set = HashSet::new();
248        let mut coverage: HashMap<u32, usize> = HashMap::new();
249
250        let num_vectors = graph.len();
251        if num_vectors == 0 {
252            return join_set;
253        }
254
255        // Greedy selection until all vertices are covered
256        while !self.is_fully_covered(&coverage, graph) {
257            // Find vertex with maximum gain
258            let best = (0..num_vectors as u32)
259                .filter(|id| !join_set.contains(id))
260                .max_by_key(|&id| {
261                    self.calculate_gain(id, &join_set, &coverage, graph)
262                        .unwrap_or(0)
263                });
264
265            if let Some(best_id) = best {
266                join_set.insert(best_id);
267
268                // Update coverage: all neighbors of best_id gain a neighbor in J
269                let neighbors = graph.get_neighbors_level0(best_id);
270                for &neighbor in &neighbors {
271                    *coverage.entry(neighbor).or_insert(0) += 1;
272                }
273
274                // Also update coverage for best_id itself (it's now covered)
275                *coverage.entry(best_id).or_insert(0) += self.config.min_coverage;
276            } else {
277                // No more vertices to add, but not fully covered
278                // This can happen with disconnected components
279                warn!("Join set computation terminated early - graph may have disconnected components");
280                break;
281            }
282        }
283
284        join_set
285    }
286
287    /// Calculate gain for adding vertex to join set
288    ///
289    /// Gain = number of vertices that would increase their coverage
290    #[allow(clippy::unnecessary_wraps)]
291    fn calculate_gain(
292        &self,
293        vertex_id: u32,
294        join_set: &HashSet<u32>,
295        coverage: &HashMap<u32, usize>,
296        graph: &HNSWIndex,
297    ) -> Result<usize> {
298        // Skip if already in join set
299        if join_set.contains(&vertex_id) {
300            return Ok(0);
301        }
302
303        let neighbors = graph.get_neighbors_level0(vertex_id);
304        let mut gain = 0;
305
306        // Gain for self (if not yet covered)
307        let self_coverage = coverage.get(&vertex_id).copied().unwrap_or(0);
308        if self_coverage < self.config.min_coverage {
309            gain += 1;
310        }
311
312        // Gain for each neighbor that would benefit
313        for &neighbor in &neighbors {
314            let neighbor_coverage = coverage.get(&neighbor).copied().unwrap_or(0);
315            if neighbor_coverage < self.config.min_coverage {
316                gain += 1;
317            }
318        }
319
320        Ok(gain)
321    }
322
323    /// Check if all vertices have sufficient coverage
324    fn is_fully_covered(&self, coverage: &HashMap<u32, usize>, graph: &HNSWIndex) -> bool {
325        for node_id in 0..graph.len() as u32 {
326            let c = coverage.get(&node_id).copied().unwrap_or(0);
327            if c < self.config.min_coverage {
328                return false;
329            }
330        }
331        true
332    }
333}
334
335impl Default for GraphMerger {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::hnsw::{DistanceFunction, HNSWParams};

    /// Build an index holding test vectors number `0..num_vectors`.
    fn create_test_index(num_vectors: usize, dim: usize) -> HNSWIndex {
        create_test_index_from(0, num_vectors, dim)
    }

    /// Build an index holding test vectors number `first..first + num_vectors`.
    ///
    /// Test vector `i` has components `(i * dim + j) / 100` for `j` in `0..dim`,
    /// so indexes built from disjoint ranges share no vectors.
    fn create_test_index_from(first: usize, num_vectors: usize, dim: usize) -> HNSWIndex {
        let params = HNSWParams {
            m: 16,
            ef_construction: 100,
            ..Default::default()
        };
        let mut index = HNSWIndex::new(dim, params, DistanceFunction::L2, false).unwrap();

        for i in first..first + num_vectors {
            let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 100.0).collect();
            index.insert(&vector).unwrap();
        }

        index
    }

    #[test]
    fn test_merge_empty_small_graph() {
        let mut large = create_test_index(100, 8);
        let small = HNSWIndex::new(8, HNSWParams::default(), DistanceFunction::L2, false).unwrap();

        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 0);
        assert_eq!(stats.join_set_size, 0);
        assert_eq!(large.len(), 100);
    }

    #[test]
    fn test_merge_small_graphs() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index(50, 8);

        let initial_size = large.len();
        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 50);
        assert_eq!(large.len(), initial_size + 50);
        assert!(stats.join_set_size > 0);
        assert!(stats.join_set_size <= 50);

        // Every merged vector goes through exactly one of the three insert paths.
        assert_eq!(
            stats.join_set_size + stats.fast_path_inserts + stats.fallback_inserts,
            stats.vectors_merged
        );
    }

    #[test]
    fn test_join_set_coverage() {
        let small = create_test_index(100, 8);
        let merger = GraphMerger::new();

        let join_set = merger.compute_join_set(&small);

        // Join set should be non-empty
        assert!(!join_set.is_empty());

        // Join set should be smaller than total (typically 10-30%)
        assert!(join_set.len() < small.len());

        // All vertices should have sufficient coverage
        let mut coverage: HashMap<u32, usize> = HashMap::new();
        for &j_id in &join_set {
            let neighbors = small.get_neighbors_level0(j_id);
            for &n in &neighbors {
                *coverage.entry(n).or_insert(0) += 1;
            }
            *coverage.entry(j_id).or_insert(0) += merger.config.min_coverage;
        }

        for node_id in 0..small.len() as u32 {
            let c = coverage.get(&node_id).copied().unwrap_or(0);
            assert!(
                c >= merger.config.min_coverage,
                "Node {} has insufficient coverage: {} < {}",
                node_id,
                c,
                merger.config.min_coverage
            );
        }
    }

    #[test]
    fn test_merge_preserves_searchability() {
        let mut large = create_test_index(100, 8);
        // Test vectors 100..150 exist only in `small`. Building both indexes from
        // the same range would leave the query vector in `large` before the merge,
        // and the assertions below would pass even if nothing were merged.
        let small = create_test_index_from(100, 50, 8);

        // Remember a vector from small graph (test vector 125 when IDs follow
        // insertion order)
        let test_vector = small.get_vector(25).unwrap().to_vec();

        let merger = GraphMerger::new();
        merger.merge_graphs(&mut large, &small).unwrap();
        assert_eq!(large.len(), 150);

        // Should be able to find similar vectors after merge
        let results = large.search(&test_vector, 5, 50).unwrap();
        assert!(!results.is_empty());

        // The closest pre-merge vector (number 99) differs by 2.08 in every
        // component, far beyond 1.0 in L2 or squared L2, so a hit this close can
        // only be one of the merged vectors.
        assert!(results[0].distance < 1.0);
    }
}