1use super::error::{HNSWError, Result};
10use super::index::HNSWIndex;
11use std::collections::{HashMap, HashSet};
12use std::time::{Duration, Instant};
13use tracing::{debug, info, instrument, warn};
14
/// Tuning parameters for [`GraphMerger`].
#[derive(Clone, Debug)]
pub struct MergeConfig {
    /// Coverage every node of the small graph must reach during join-set selection.
    ///
    /// A join-set member adds 1 to the coverage of each of its level-0 neighbors and
    /// `min_coverage` to its own coverage.
    pub min_coverage: usize,

    /// Search width (`ef`) used for hinted fast-path inserts.
    ///
    /// `None` means "half of the target index's `ef_construction`".
    pub fast_ef: Option<usize>,

    /// Requests parallel join-set computation.
    ///
    /// NOTE(review): nothing in this module reads this flag; join-set selection in
    /// `GraphMerger` is currently sequential.
    pub parallel_join_set: bool,
}

impl Default for MergeConfig {
    /// `min_coverage = 2`, `fast_ef = None`, `parallel_join_set = true`.
    fn default() -> Self {
        let min_coverage = 2;
        let fast_ef = None;
        let parallel_join_set = true;

        Self {
            min_coverage,
            fast_ef,
            parallel_join_set,
        }
    }
}
39
/// Timing and counting statistics for one `GraphMerger::merge_graphs` call.
#[derive(Clone, Debug)]
pub struct MergeStats {
    /// Number of vectors taken from the small index (its `len()`).
    pub vectors_merged: usize,

    /// Number of vectors selected for the join set (inserted with a full insert).
    pub join_set_size: usize,

    /// Time spent selecting the join set.
    pub join_set_duration: Duration,

    /// Time spent inserting the join set into the large index.
    pub join_set_insert_duration: Duration,

    /// Time spent inserting every vector outside the join set.
    pub remaining_insert_duration: Duration,

    /// Wall-clock time of the whole merge.
    pub total_duration: Duration,

    /// Vectors outside the join set that were inserted via the hinted fast path.
    pub fast_path_inserts: usize,

    /// Vectors outside the join set that had no join-set neighbor and fell back to a
    /// full insert.
    pub fallback_inserts: usize,
}

impl MergeStats {
    /// Heuristic cost of one hinted fast-path insert relative to one full insert.
    const FAST_PATH_COST_RATIO: f64 = 0.2;

    /// Rough estimate of how much faster the merge was than inserting every vector with a
    /// full insert.
    ///
    /// Model: join-set inserts and fallback inserts each cost one full insert. Every other
    /// merged vector costs [`Self::FAST_PATH_COST_RATIO`] of a full insert. The result is
    /// `vectors_merged / modeled_cost`. It is a heuristic, not a measurement.
    ///
    /// Returns `1.0` (no speedup) when nothing was merged.
    #[must_use]
    pub fn estimated_speedup(&self) -> f64 {
        if self.vectors_merged == 0 {
            return 1.0;
        }

        // Fallbacks use the full insert path, so they must not be counted as fast-path work.
        // Clamp so inconsistent hand-built stats can never yield a negative fast-path count.
        let full_cost_inserts = self
            .join_set_size
            .saturating_add(self.fallback_inserts)
            .min(self.vectors_merged);
        let fast_inserts = self.vectors_merged - full_cost_inserts;

        let modeled_cost =
            full_cost_inserts as f64 + fast_inserts as f64 * Self::FAST_PATH_COST_RATIO;
        self.vectors_merged as f64 / modeled_cost
    }
}
84
85pub struct GraphMerger {
90 config: MergeConfig,
91}
92
93impl GraphMerger {
94 #[must_use]
96 pub fn new() -> Self {
97 Self {
98 config: MergeConfig::default(),
99 }
100 }
101
102 #[must_use]
104 pub fn with_config(config: MergeConfig) -> Self {
105 Self { config }
106 }
107
108 #[instrument(skip(self, large, small), fields(large_size = large.len(), small_size = small.len()))]
123 pub fn merge_graphs(&self, large: &mut HNSWIndex, small: &HNSWIndex) -> Result<MergeStats> {
124 let total_start = Instant::now();
125 let small_size = small.len();
126
127 if small_size == 0 {
128 return Ok(MergeStats {
129 vectors_merged: 0,
130 join_set_size: 0,
131 join_set_duration: Duration::ZERO,
132 join_set_insert_duration: Duration::ZERO,
133 remaining_insert_duration: Duration::ZERO,
134 total_duration: total_start.elapsed(),
135 fast_path_inserts: 0,
136 fallback_inserts: 0,
137 });
138 }
139
140 info!(
141 large_size = large.len(),
142 small_size = small_size,
143 "Starting IGTM graph merge"
144 );
145
146 let join_set_start = Instant::now();
148 let join_set = self.compute_join_set(small);
149 let join_set_duration = join_set_start.elapsed();
150
151 debug!(
152 join_set_size = join_set.len(),
153 coverage_target = self.config.min_coverage,
154 duration_ms = join_set_duration.as_millis(),
155 "Join set computed"
156 );
157
158 let join_insert_start = Instant::now();
160 for &node_id in &join_set {
161 let vector = small
162 .get_vector(node_id)
163 .ok_or(HNSWError::VectorNotFound(node_id))?;
164 large.insert(vector)?;
165 }
166 let join_set_insert_duration = join_insert_start.elapsed();
167
168 debug!(
169 inserted = join_set.len(),
170 duration_ms = join_set_insert_duration.as_millis(),
171 "Join set inserted"
172 );
173
174 let remaining_start = Instant::now();
176 let mut fast_path_inserts = 0;
177 let mut fallback_inserts = 0;
178
179 let fast_ef = self
180 .config
181 .fast_ef
182 .unwrap_or(large.params().ef_construction / 2);
183
184 for node_id in 0..small.len() as u32 {
185 if join_set.contains(&node_id) {
186 continue;
187 }
188
189 let vector = small
190 .get_vector(node_id)
191 .ok_or(HNSWError::VectorNotFound(node_id))?;
192
193 let small_neighbors = small.get_neighbors_level0(node_id);
195 let entry_points: Vec<u32> = small_neighbors
196 .iter()
197 .filter(|&&n| join_set.contains(&n))
198 .copied()
199 .collect();
200
201 if entry_points.is_empty() {
202 large.insert(vector)?;
204 fallback_inserts += 1;
205 } else {
206 large.insert_with_hints(vector, &entry_points, fast_ef)?;
209 fast_path_inserts += 1;
210 }
211 }
212 let remaining_insert_duration = remaining_start.elapsed();
213
214 let total_duration = total_start.elapsed();
215
216 let stats = MergeStats {
217 vectors_merged: small_size,
218 join_set_size: join_set.len(),
219 join_set_duration,
220 join_set_insert_duration,
221 remaining_insert_duration,
222 total_duration,
223 fast_path_inserts,
224 fallback_inserts,
225 };
226
227 info!(
228 vectors_merged = stats.vectors_merged,
229 join_set_size = stats.join_set_size,
230 fast_path_ratio = format!(
231 "{:.1}%",
232 (stats.fast_path_inserts as f64 / stats.vectors_merged.max(1) as f64) * 100.0
233 ),
234 total_ms = stats.total_duration.as_millis(),
235 estimated_speedup = format!("{:.2}x", stats.estimated_speedup()),
236 "IGTM merge complete"
237 );
238
239 Ok(stats)
240 }
241
242 fn compute_join_set(&self, graph: &HNSWIndex) -> HashSet<u32> {
247 let mut join_set = HashSet::new();
248 let mut coverage: HashMap<u32, usize> = HashMap::new();
249
250 let num_vectors = graph.len();
251 if num_vectors == 0 {
252 return join_set;
253 }
254
255 while !self.is_fully_covered(&coverage, graph) {
257 let best = (0..num_vectors as u32)
259 .filter(|id| !join_set.contains(id))
260 .max_by_key(|&id| {
261 self.calculate_gain(id, &join_set, &coverage, graph)
262 .unwrap_or(0)
263 });
264
265 if let Some(best_id) = best {
266 join_set.insert(best_id);
267
268 let neighbors = graph.get_neighbors_level0(best_id);
270 for &neighbor in &neighbors {
271 *coverage.entry(neighbor).or_insert(0) += 1;
272 }
273
274 *coverage.entry(best_id).or_insert(0) += self.config.min_coverage;
276 } else {
277 warn!("Join set computation terminated early - graph may have disconnected components");
280 break;
281 }
282 }
283
284 join_set
285 }
286
287 #[allow(clippy::unnecessary_wraps)]
291 fn calculate_gain(
292 &self,
293 vertex_id: u32,
294 join_set: &HashSet<u32>,
295 coverage: &HashMap<u32, usize>,
296 graph: &HNSWIndex,
297 ) -> Result<usize> {
298 if join_set.contains(&vertex_id) {
300 return Ok(0);
301 }
302
303 let neighbors = graph.get_neighbors_level0(vertex_id);
304 let mut gain = 0;
305
306 let self_coverage = coverage.get(&vertex_id).copied().unwrap_or(0);
308 if self_coverage < self.config.min_coverage {
309 gain += 1;
310 }
311
312 for &neighbor in &neighbors {
314 let neighbor_coverage = coverage.get(&neighbor).copied().unwrap_or(0);
315 if neighbor_coverage < self.config.min_coverage {
316 gain += 1;
317 }
318 }
319
320 Ok(gain)
321 }
322
323 fn is_fully_covered(&self, coverage: &HashMap<u32, usize>, graph: &HNSWIndex) -> bool {
325 for node_id in 0..graph.len() as u32 {
326 let c = coverage.get(&node_id).copied().unwrap_or(0);
327 if c < self.config.min_coverage {
328 return false;
329 }
330 }
331 true
332 }
333}
334
335impl Default for GraphMerger {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::hnsw::{DistanceFunction, HNSWParams};

    /// Offset for "small" indexes so their vectors never coincide with a 100-vector
    /// "large" index built from 0.
    const SMALL_OFFSET: usize = 1000;

    /// Builds an index holding `num_vectors` points numbered `start..start + num_vectors`.
    ///
    /// Point `i` has coordinates `(i * dim + j) / 100` for `j in 0..dim`. All points lie on
    /// one line, and indexes built from non-overlapping ranges hold disjoint vectors.
    fn create_test_index_from(start: usize, num_vectors: usize, dim: usize) -> HNSWIndex {
        let params = HNSWParams {
            m: 16,
            ef_construction: 100,
            ..Default::default()
        };
        let mut index = HNSWIndex::new(dim, params, DistanceFunction::L2, false).unwrap();

        for i in start..start + num_vectors {
            let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 100.0).collect();
            index.insert(&vector).unwrap();
        }

        index
    }

    /// Builds an index of `num_vectors` points numbered from 0.
    fn create_test_index(num_vectors: usize, dim: usize) -> HNSWIndex {
        create_test_index_from(0, num_vectors, dim)
    }

    #[test]
    fn test_merge_empty_small_graph() {
        let mut large = create_test_index(100, 8);
        let small = HNSWIndex::new(8, HNSWParams::default(), DistanceFunction::L2, false).unwrap();

        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 0);
        assert_eq!(stats.join_set_size, 0);
        assert_eq!(large.len(), 100);
    }

    #[test]
    fn test_merge_small_graphs() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index_from(SMALL_OFFSET, 50, 8);

        let initial_size = large.len();
        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 50);
        assert_eq!(large.len(), initial_size + 50);
        assert!(stats.join_set_size > 0);
        assert!(stats.join_set_size <= 50);
        // Every small vector goes through exactly one of the three insert paths.
        assert_eq!(
            stats.join_set_size + stats.fast_path_inserts + stats.fallback_inserts,
            stats.vectors_merged
        );
    }

    #[test]
    fn test_join_set_coverage() {
        let small = create_test_index(100, 8);
        let merger = GraphMerger::new();

        let join_set = merger.compute_join_set(&small);

        assert!(!join_set.is_empty());
        assert!(join_set.len() < small.len());

        // Recompute coverage independently with the same rules as the merger.
        let mut coverage: HashMap<u32, usize> = HashMap::new();
        for &j_id in &join_set {
            let neighbors = small.get_neighbors_level0(j_id);
            for &n in &neighbors {
                *coverage.entry(n).or_insert(0) += 1;
            }
            *coverage.entry(j_id).or_insert(0) += merger.config.min_coverage;
        }

        for node_id in 0..small.len() as u32 {
            let c = coverage.get(&node_id).copied().unwrap_or(0);
            assert!(
                c >= merger.config.min_coverage,
                "Node {} has insufficient coverage: {} < {}",
                node_id,
                c,
                merger.config.min_coverage
            );
        }
    }

    #[test]
    fn test_merge_preserves_searchability() {
        let mut large = create_test_index(100, 8);
        // Disjoint from `large`, so a near-zero-distance hit can only be the merged copy.
        let small = create_test_index_from(SMALL_OFFSET, 50, 8);

        let test_vector = small.get_vector(25).unwrap().to_vec();

        // Sanity check: before merging, nothing in `large` is close to the query.
        let before = large.search(&test_vector, 1, 50).unwrap();
        assert!(before.is_empty() || before[0].distance > 1.0);

        let merger = GraphMerger::new();
        merger.merge_graphs(&mut large, &small).unwrap();

        let results = large.search(&test_vector, 5, 50).unwrap();
        assert!(!results.is_empty());
        // Adjacent test points are >= 0.05 apart (squared or not), so this pins an exact hit.
        assert!(
            results[0].distance < 0.01,
            "merged vector not found: nearest distance {}",
            results[0].distance
        );
    }
}