omendb_core/vector/hnsw/
merge.rs

1// HNSW Graph Merging - IGTM Algorithm
2//
3// Implements Iterative Greedy Tree Merging from:
4// - Elasticsearch Labs: "Speeding up merging of HNSW graphs" (2024-2025)
5// - arXiv:2505.16064 (May 2025) - MERGE-HNSW algorithms
6//
7// Expected speedup: 1.28-1.72x for batch inserts (validated in Lucene 10.2)
8
9use super::error::{HNSWError, Result};
10use super::index::HNSWIndex;
11use std::collections::{HashMap, HashSet};
12use std::time::{Duration, Instant};
13use tracing::{debug, info, instrument, warn};
14
/// Configuration for graph merging.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MergeConfig {
    /// Minimum number of join-set members that must list a vertex among their level-0
    /// neighbors for that vertex to count as covered. A vertex that is itself in the join
    /// set is always covered.
    ///
    /// `0` makes the join set empty, so every vector is merged with a standard insert.
    /// Default: 2 (from IGTM paper).
    pub min_coverage: usize,

    /// ef parameter for the hinted ("fast path") inserts during a merge. It is normally lower
    /// than `ef_construction`, for speed.
    ///
    /// `None` uses the target graph's `ef_construction / 2`.
    pub fast_ef: Option<usize>,

    /// Intended to enable parallel join-set computation.
    ///
    /// NOTE: currently not read by the merger. The join set is always computed
    /// sequentially; the field is kept for API compatibility.
    pub parallel_join_set: bool,
}

impl Default for MergeConfig {
    /// `min_coverage = 2`, `fast_ef = None` (derived from the target graph), and
    /// `parallel_join_set = true`.
    fn default() -> Self {
        Self {
            min_coverage: 2,
            fast_ef: None,
            parallel_join_set: true,
        }
    }
}
39
/// Statistics from a merge operation.
///
/// For a merge produced by `GraphMerger::merge_graphs`, every source vector is counted
/// exactly once: `join_set_size + fast_path_inserts + fallback_inserts == vectors_merged`.
/// Join-set vectors use a standard insert but are NOT included in `fallback_inserts`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MergeStats {
    /// Total vectors merged from small graph
    pub vectors_merged: usize,

    /// Size of the join set (strategic vertices inserted first)
    pub join_set_size: usize,

    /// Time spent computing join set
    pub join_set_duration: Duration,

    /// Time spent inserting join set
    pub join_set_insert_duration: Duration,

    /// Time spent inserting remaining vectors
    pub remaining_insert_duration: Duration,

    /// Total merge duration
    pub total_duration: Duration,

    /// Non-join-set vectors inserted using fast path (entry points from join set)
    pub fast_path_inserts: usize,

    /// Non-join-set vectors inserted using fallback (standard insert, no join-set neighbor)
    pub fallback_inserts: usize,
}

impl MergeStats {
    /// Assumed cost of a fast-path insert relative to a full insert (i.e. ~5x cheaper).
    const FAST_PATH_RELATIVE_COST: f64 = 0.2;

    /// Estimate the speedup of this merge over naive one-by-one insertion.
    ///
    /// Cost model: a naive merge pays a full insert (cost 1.0) per vector. Here, join-set
    /// vectors and fallback inserts still pay 1.0 each, while fast-path inserts pay
    /// `FAST_PATH_RELATIVE_COST`. The result is `naive_cost / estimated_cost`.
    ///
    /// Counts are clamped so inconsistent stats cannot yield negative costs:
    /// - The join set is capped at `vectors_merged`.
    /// - Fallbacks are capped at the non-join remainder.
    /// - Anything else in the remainder is treated as fast-path.
    ///
    /// Returns `1.0` when nothing was merged.
    #[must_use]
    pub fn estimated_speedup(&self) -> f64 {
        let total = self.vectors_merged;
        if total == 0 {
            return 1.0;
        }

        let join = self.join_set_size.min(total);
        let remaining = total - join;
        let fallback = self.fallback_inserts.min(remaining);
        let fast = remaining - fallback;

        // Strictly positive: total > 0 and every vector contributes at least FAST_PATH_RELATIVE_COST.
        let estimated_cost = (join + fallback) as f64 + fast as f64 * Self::FAST_PATH_RELATIVE_COST;
        total as f64 / estimated_cost
    }
}
84
85/// HNSW Graph Merger using IGTM algorithm
86///
87/// Merges a small graph into a large graph using strategic vertex selection.
88/// Expected 1.3-1.7x speedup over naive insertion.
89pub struct GraphMerger {
90    config: MergeConfig,
91}
92
93impl GraphMerger {
94    /// Create a new graph merger with default configuration
95    #[must_use]
96    pub fn new() -> Self {
97        Self {
98            config: MergeConfig::default(),
99        }
100    }
101
102    /// Create a new graph merger with custom configuration
103    #[must_use]
104    pub fn with_config(config: MergeConfig) -> Self {
105        Self { config }
106    }
107
108    /// Merge a small graph into a large graph using IGTM algorithm
109    ///
110    /// # Algorithm
111    /// 1. Compute join set: Find minimal vertex subset that covers all vertices
112    ///    (every vertex has ≥`min_coverage` neighbors in the join set)
113    /// 2. Insert join set into large graph using standard insertion
114    /// 3. For remaining vertices, use join set neighbors as entry points for fast insertion
115    ///
116    /// # Arguments
117    /// * `large` - Target graph (will be modified)
118    /// * `small` - Source graph (vectors will be moved)
119    ///
120    /// # Returns
121    /// Merge statistics including timing breakdown
122    #[instrument(skip(self, large, small), fields(large_size = large.len(), small_size = small.len()))]
123    pub fn merge_graphs(&self, large: &mut HNSWIndex, small: &HNSWIndex) -> Result<MergeStats> {
124        let total_start = Instant::now();
125        let small_size = small.len();
126
127        if small_size == 0 {
128            return Ok(MergeStats {
129                vectors_merged: 0,
130                join_set_size: 0,
131                join_set_duration: Duration::ZERO,
132                join_set_insert_duration: Duration::ZERO,
133                remaining_insert_duration: Duration::ZERO,
134                total_duration: total_start.elapsed(),
135                fast_path_inserts: 0,
136                fallback_inserts: 0,
137            });
138        }
139
140        info!(
141            large_size = large.len(),
142            small_size = small_size,
143            "Starting IGTM graph merge"
144        );
145
146        // Phase 1: Compute join set
147        let join_set_start = Instant::now();
148        let join_set = self.compute_join_set(small);
149        let join_set_duration = join_set_start.elapsed();
150
151        debug!(
152            join_set_size = join_set.len(),
153            coverage_target = self.config.min_coverage,
154            duration_ms = join_set_duration.as_millis(),
155            "Join set computed"
156        );
157
158        // Phase 2: Insert join set vectors and build ID mapping
159        let join_insert_start = Instant::now();
160        let mut small_to_large: HashMap<u32, u32> = HashMap::with_capacity(join_set.len());
161
162        for &small_id in &join_set {
163            let vector = small
164                .get_vector(small_id)
165                .ok_or(HNSWError::VectorNotFound(small_id))?;
166            let large_id = large.insert(vector)?;
167            small_to_large.insert(small_id, large_id);
168        }
169        let join_set_insert_duration = join_insert_start.elapsed();
170
171        debug!(
172            inserted = join_set.len(),
173            duration_ms = join_set_insert_duration.as_millis(),
174            "Join set inserted"
175        );
176
177        // Phase 3: Insert remaining vectors using fast path
178        let remaining_start = Instant::now();
179        let mut fast_path_inserts = 0;
180        let mut fallback_inserts = 0;
181
182        let fast_ef = self
183            .config
184            .fast_ef
185            .unwrap_or(large.params().ef_construction / 2);
186
187        for node_id in 0..small.len() as u32 {
188            if join_set.contains(&node_id) {
189                continue;
190            }
191
192            let vector = small
193                .get_vector(node_id)
194                .ok_or(HNSWError::VectorNotFound(node_id))?;
195
196            // Find neighbors of this node that are in the join set
197            // Map small graph IDs to large graph IDs
198            let small_neighbors = small.get_neighbors_level0(node_id);
199            let entry_points: Vec<u32> = small_neighbors
200                .iter()
201                .filter_map(|&small_neighbor_id| small_to_large.get(&small_neighbor_id).copied())
202                .collect();
203
204            if entry_points.is_empty() {
205                // Fallback: no join set neighbors, use standard insert
206                large.insert(vector)?;
207                fallback_inserts += 1;
208            } else {
209                // Fast path: use mapped join set neighbors as entry points in large graph
210                large.insert_with_hints(vector, &entry_points, fast_ef)?;
211                fast_path_inserts += 1;
212            }
213        }
214        let remaining_insert_duration = remaining_start.elapsed();
215
216        let total_duration = total_start.elapsed();
217
218        let stats = MergeStats {
219            vectors_merged: small_size,
220            join_set_size: join_set.len(),
221            join_set_duration,
222            join_set_insert_duration,
223            remaining_insert_duration,
224            total_duration,
225            fast_path_inserts,
226            fallback_inserts,
227        };
228
229        info!(
230            vectors_merged = stats.vectors_merged,
231            join_set_size = stats.join_set_size,
232            fast_path_ratio = format!(
233                "{:.1}%",
234                (stats.fast_path_inserts as f64 / stats.vectors_merged.max(1) as f64) * 100.0
235            ),
236            total_ms = stats.total_duration.as_millis(),
237            estimated_speedup = format!("{:.2}x", stats.estimated_speedup()),
238            "IGTM merge complete"
239        );
240
241        Ok(stats)
242    }
243
244    /// Compute join set using greedy covering algorithm
245    ///
246    /// Finds minimal subset J such that every vertex v has ≥`min_coverage` neighbors in J.
247    /// Uses greedy selection: pick vertex maximizing coverage gain at each step.
248    fn compute_join_set(&self, graph: &HNSWIndex) -> HashSet<u32> {
249        let mut join_set = HashSet::new();
250        let mut coverage: HashMap<u32, usize> = HashMap::new();
251
252        let num_vectors = graph.len();
253        if num_vectors == 0 {
254            return join_set;
255        }
256
257        // Greedy selection until all vertices are covered
258        while !self.is_fully_covered(&coverage, graph) {
259            // Find vertex with maximum gain
260            let best = (0..num_vectors as u32)
261                .filter(|id| !join_set.contains(id))
262                .max_by_key(|&id| {
263                    self.calculate_gain(id, &join_set, &coverage, graph)
264                        .unwrap_or(0)
265                });
266
267            if let Some(best_id) = best {
268                join_set.insert(best_id);
269
270                // Update coverage: all neighbors of best_id gain a neighbor in J
271                let neighbors = graph.get_neighbors_level0(best_id);
272                for &neighbor in &neighbors {
273                    *coverage.entry(neighbor).or_insert(0) += 1;
274                }
275
276                // Also update coverage for best_id itself (it's now covered)
277                *coverage.entry(best_id).or_insert(0) += self.config.min_coverage;
278            } else {
279                // No more vertices to add, but not fully covered
280                // This can happen with disconnected components
281                warn!("Join set computation terminated early - graph may have disconnected components");
282                break;
283            }
284        }
285
286        join_set
287    }
288
289    /// Calculate gain for adding vertex to join set
290    ///
291    /// Gain = number of vertices that would increase their coverage
292    #[allow(clippy::unnecessary_wraps)]
293    fn calculate_gain(
294        &self,
295        vertex_id: u32,
296        join_set: &HashSet<u32>,
297        coverage: &HashMap<u32, usize>,
298        graph: &HNSWIndex,
299    ) -> Result<usize> {
300        // Skip if already in join set
301        if join_set.contains(&vertex_id) {
302            return Ok(0);
303        }
304
305        let neighbors = graph.get_neighbors_level0(vertex_id);
306        let mut gain = 0;
307
308        // Gain for self (if not yet covered)
309        let self_coverage = coverage.get(&vertex_id).copied().unwrap_or(0);
310        if self_coverage < self.config.min_coverage {
311            gain += 1;
312        }
313
314        // Gain for each neighbor that would benefit
315        for &neighbor in &neighbors {
316            let neighbor_coverage = coverage.get(&neighbor).copied().unwrap_or(0);
317            if neighbor_coverage < self.config.min_coverage {
318                gain += 1;
319            }
320        }
321
322        Ok(gain)
323    }
324
325    /// Check if all vertices have sufficient coverage
326    fn is_fully_covered(&self, coverage: &HashMap<u32, usize>, graph: &HNSWIndex) -> bool {
327        for node_id in 0..graph.len() as u32 {
328            let c = coverage.get(&node_id).copied().unwrap_or(0);
329            if c < self.config.min_coverage {
330                return false;
331            }
332        }
333        true
334    }
335}
336
337impl Default for GraphMerger {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::hnsw::{DistanceFunction, HNSWParams};

    /// Build an index of `num_vectors` vectors of dimension `dim`. Component `j` of vector
    /// `i` is `(i * dim + j) / 100`.
    fn create_test_index(num_vectors: usize, dim: usize) -> HNSWIndex {
        create_test_index_with_offset(num_vectors, dim, 0)
    }

    /// Like `create_test_index`, but the k-th inserted vector is generated as vector
    /// `offset + k`. Indexes built with non-overlapping offset ranges hold disjoint vectors.
    fn create_test_index_with_offset(num_vectors: usize, dim: usize, offset: usize) -> HNSWIndex {
        let params = HNSWParams {
            m: 16,
            ef_construction: 100,
            ..Default::default()
        };
        let mut index = HNSWIndex::new(dim, params, DistanceFunction::L2, false).unwrap();

        for i in offset..offset + num_vectors {
            let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 100.0).collect();
            index.insert(&vector).unwrap();
        }

        index
    }

    #[test]
    fn test_merge_empty_small_graph() {
        let mut large = create_test_index(100, 8);
        let small = HNSWIndex::new(8, HNSWParams::default(), DistanceFunction::L2, false).unwrap();

        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 0);
        assert_eq!(stats.join_set_size, 0);
        assert_eq!(large.len(), 100);
    }

    #[test]
    fn test_merge_small_graphs() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index(50, 8);

        let initial_size = large.len();
        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 50);
        assert_eq!(large.len(), initial_size + 50);
        assert!(stats.join_set_size > 0);
        assert!(stats.join_set_size <= 50);
        // Every source vector is accounted for exactly once.
        assert_eq!(
            stats.join_set_size + stats.fast_path_inserts + stats.fallback_inserts,
            stats.vectors_merged
        );
    }

    #[test]
    fn test_join_set_coverage() {
        let small = create_test_index(100, 8);
        let merger = GraphMerger::new();

        let join_set = merger.compute_join_set(&small);

        // Join set should be non-empty
        assert!(!join_set.is_empty());

        // Join set should be smaller than total (typically 10-30%)
        assert!(join_set.len() < small.len());

        // All vertices should have sufficient coverage
        let mut coverage: HashMap<u32, usize> = HashMap::new();
        for &j_id in &join_set {
            let neighbors = small.get_neighbors_level0(j_id);
            for &n in &neighbors {
                *coverage.entry(n).or_insert(0) += 1;
            }
            *coverage.entry(j_id).or_insert(0) += merger.config.min_coverage;
        }

        for node_id in 0..small.len() as u32 {
            let c = coverage.get(&node_id).copied().unwrap_or(0);
            assert!(
                c >= merger.config.min_coverage,
                "Node {} has insufficient coverage: {} < {}",
                node_id,
                c,
                merger.config.min_coverage
            );
        }
    }

    #[test]
    fn test_join_set_is_deterministic() {
        let small = create_test_index(100, 8);
        let merger = GraphMerger::new();

        assert_eq!(merger.compute_join_set(&small), merger.compute_join_set(&small));
    }

    #[test]
    fn test_zero_min_coverage_uses_standard_inserts() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index_with_offset(20, 8, 1000);

        let merger = GraphMerger::with_config(MergeConfig {
            min_coverage: 0,
            ..MergeConfig::default()
        });
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        // No join set means no entry-point hints: every vector takes the fallback path.
        assert_eq!(stats.join_set_size, 0);
        assert_eq!(stats.fast_path_inserts, 0);
        assert_eq!(stats.fallback_inserts, 20);
        assert_eq!(large.len(), 120);
    }

    #[test]
    fn test_merge_preserves_searchability() {
        let mut large = create_test_index(100, 8);
        // Offset so the small graph's vectors do not already exist in `large`. Large
        // components are at most ~8.0; small ones start at 80.0, so only a merged small
        // vector can be within distance 1.0 of the query.
        let small = create_test_index_with_offset(50, 8, 1000);

        // Remember a vector from small graph
        let test_vector = small.get_vector(25).unwrap().to_vec();

        let merger = GraphMerger::new();
        merger.merge_graphs(&mut large, &small).unwrap();

        // Should be able to find similar vectors after merge
        let results = large.search(&test_vector, 5, 50).unwrap();
        assert!(!results.is_empty());

        // The nearest result must come from the merged vectors.
        assert!(results[0].distance < 1.0);
    }
}