1use super::error::{HNSWError, Result};
10use super::index::HNSWIndex;
11use std::collections::{HashMap, HashSet};
12use std::time::{Duration, Instant};
13use tracing::{debug, info, instrument, warn};
14
/// Tuning knobs for `GraphMerger`.
#[derive(Clone, Debug)]
pub struct MergeConfig {
    /// Coverage target for join-set selection.
    ///
    /// Every vertex of the graph being merged must reach at least this much coverage.
    /// A vertex gains one unit for each join-set vertex whose level-0 neighbour list
    /// contains it, and `min_coverage` units when it is itself selected. A value of `0`
    /// means every vertex is trivially covered, so the join set is empty.
    pub min_coverage: usize,

    /// Search width handed to hinted inserts of non-join-set vectors.
    ///
    /// `None` means "half of the large index's `ef_construction`".
    pub fast_ef: Option<usize>,

    /// Presumably requests a parallel join-set computation (inferred from the name).
    ///
    /// NOTE(review): `GraphMerger` in this file never reads this flag; the join set is
    /// computed sequentially regardless of its value.
    pub parallel_join_set: bool,
}

impl Default for MergeConfig {
    /// `min_coverage = 2`, `fast_ef = None` (derived from the large index at merge time)
    /// and `parallel_join_set = true`.
    fn default() -> Self {
        let min_coverage = 2;
        Self {
            parallel_join_set: true,
            fast_ef: None,
            min_coverage,
        }
    }
}
39
/// Counters and timings describing one `GraphMerger::merge_graphs` run.
#[derive(Clone, Debug)]
pub struct MergeStats {
    /// Number of vectors taken from the small index (all of them are merged).
    pub vectors_merged: usize,

    /// Number of vectors selected into the join set (inserted with a full insert first).
    pub join_set_size: usize,

    /// Wall-clock time spent choosing the join set.
    pub join_set_duration: Duration,

    /// Wall-clock time spent inserting the join-set vectors.
    pub join_set_insert_duration: Duration,

    /// Wall-clock time spent inserting all non-join-set vectors.
    pub remaining_insert_duration: Duration,

    /// Wall-clock time of the whole merge.
    pub total_duration: Duration,

    /// Non-join-set vectors inserted through the hinted (fast) path.
    pub fast_path_inserts: usize,

    /// Non-join-set vectors that had no usable hint and took a full insert instead.
    pub fallback_inserts: usize,
}

impl MergeStats {
    /// Rough speedup of this merge over inserting every vector with a full insert.
    ///
    /// Join-set inserts and fallback inserts both use the full insert path and are
    /// charged cost 1. Every other merged vector is assumed to cost
    /// `FAST_PATH_RELATIVE_COST` of a full insert. This is a model, not a measurement.
    ///
    /// Returns `1.0` when nothing was merged, because no work means no speedup.
    #[must_use]
    pub fn estimated_speedup(&self) -> f64 {
        /// Assumed cost of a hinted insert relative to a full insert.
        const FAST_PATH_RELATIVE_COST: f64 = 0.2;

        if self.vectors_merged == 0 {
            return 1.0;
        }

        let full_cost_inserts = self.join_set_size.saturating_add(self.fallback_inserts);
        // Clamp so inconsistent hand-built stats cannot produce a negative fast share.
        let full_cost_ratio =
            (full_cost_inserts as f64 / self.vectors_merged as f64).min(1.0);
        let fast_ratio = 1.0 - full_cost_ratio;

        1.0 / (full_cost_ratio + fast_ratio * FAST_PATH_RELATIVE_COST)
    }
}
84
85pub struct GraphMerger {
90 config: MergeConfig,
91}
92
93impl GraphMerger {
94 #[must_use]
96 pub fn new() -> Self {
97 Self {
98 config: MergeConfig::default(),
99 }
100 }
101
102 #[must_use]
104 pub fn with_config(config: MergeConfig) -> Self {
105 Self { config }
106 }
107
108 #[instrument(skip(self, large, small), fields(large_size = large.len(), small_size = small.len()))]
123 pub fn merge_graphs(&self, large: &mut HNSWIndex, small: &HNSWIndex) -> Result<MergeStats> {
124 let total_start = Instant::now();
125 let small_size = small.len();
126
127 if small_size == 0 {
128 return Ok(MergeStats {
129 vectors_merged: 0,
130 join_set_size: 0,
131 join_set_duration: Duration::ZERO,
132 join_set_insert_duration: Duration::ZERO,
133 remaining_insert_duration: Duration::ZERO,
134 total_duration: total_start.elapsed(),
135 fast_path_inserts: 0,
136 fallback_inserts: 0,
137 });
138 }
139
140 info!(
141 large_size = large.len(),
142 small_size = small_size,
143 "Starting IGTM graph merge"
144 );
145
146 let join_set_start = Instant::now();
148 let join_set = self.compute_join_set(small);
149 let join_set_duration = join_set_start.elapsed();
150
151 debug!(
152 join_set_size = join_set.len(),
153 coverage_target = self.config.min_coverage,
154 duration_ms = join_set_duration.as_millis(),
155 "Join set computed"
156 );
157
158 let join_insert_start = Instant::now();
160 let mut small_to_large: HashMap<u32, u32> = HashMap::with_capacity(join_set.len());
161
162 for &small_id in &join_set {
163 let vector = small
164 .get_vector(small_id)
165 .ok_or(HNSWError::VectorNotFound(small_id))?;
166 let large_id = large.insert(vector)?;
167 small_to_large.insert(small_id, large_id);
168 }
169 let join_set_insert_duration = join_insert_start.elapsed();
170
171 debug!(
172 inserted = join_set.len(),
173 duration_ms = join_set_insert_duration.as_millis(),
174 "Join set inserted"
175 );
176
177 let remaining_start = Instant::now();
179 let mut fast_path_inserts = 0;
180 let mut fallback_inserts = 0;
181
182 let fast_ef = self
183 .config
184 .fast_ef
185 .unwrap_or(large.params().ef_construction / 2);
186
187 for node_id in 0..small.len() as u32 {
188 if join_set.contains(&node_id) {
189 continue;
190 }
191
192 let vector = small
193 .get_vector(node_id)
194 .ok_or(HNSWError::VectorNotFound(node_id))?;
195
196 let small_neighbors = small.get_neighbors_level0(node_id);
199 let entry_points: Vec<u32> = small_neighbors
200 .iter()
201 .filter_map(|&small_neighbor_id| small_to_large.get(&small_neighbor_id).copied())
202 .collect();
203
204 if entry_points.is_empty() {
205 large.insert(vector)?;
207 fallback_inserts += 1;
208 } else {
209 large.insert_with_hints(vector, &entry_points, fast_ef)?;
211 fast_path_inserts += 1;
212 }
213 }
214 let remaining_insert_duration = remaining_start.elapsed();
215
216 let total_duration = total_start.elapsed();
217
218 let stats = MergeStats {
219 vectors_merged: small_size,
220 join_set_size: join_set.len(),
221 join_set_duration,
222 join_set_insert_duration,
223 remaining_insert_duration,
224 total_duration,
225 fast_path_inserts,
226 fallback_inserts,
227 };
228
229 info!(
230 vectors_merged = stats.vectors_merged,
231 join_set_size = stats.join_set_size,
232 fast_path_ratio = format!(
233 "{:.1}%",
234 (stats.fast_path_inserts as f64 / stats.vectors_merged.max(1) as f64) * 100.0
235 ),
236 total_ms = stats.total_duration.as_millis(),
237 estimated_speedup = format!("{:.2}x", stats.estimated_speedup()),
238 "IGTM merge complete"
239 );
240
241 Ok(stats)
242 }
243
244 fn compute_join_set(&self, graph: &HNSWIndex) -> HashSet<u32> {
249 let mut join_set = HashSet::new();
250 let mut coverage: HashMap<u32, usize> = HashMap::new();
251
252 let num_vectors = graph.len();
253 if num_vectors == 0 {
254 return join_set;
255 }
256
257 while !self.is_fully_covered(&coverage, graph) {
259 let best = (0..num_vectors as u32)
261 .filter(|id| !join_set.contains(id))
262 .max_by_key(|&id| {
263 self.calculate_gain(id, &join_set, &coverage, graph)
264 .unwrap_or(0)
265 });
266
267 if let Some(best_id) = best {
268 join_set.insert(best_id);
269
270 let neighbors = graph.get_neighbors_level0(best_id);
272 for &neighbor in &neighbors {
273 *coverage.entry(neighbor).or_insert(0) += 1;
274 }
275
276 *coverage.entry(best_id).or_insert(0) += self.config.min_coverage;
278 } else {
279 warn!("Join set computation terminated early - graph may have disconnected components");
282 break;
283 }
284 }
285
286 join_set
287 }
288
289 #[allow(clippy::unnecessary_wraps)]
293 fn calculate_gain(
294 &self,
295 vertex_id: u32,
296 join_set: &HashSet<u32>,
297 coverage: &HashMap<u32, usize>,
298 graph: &HNSWIndex,
299 ) -> Result<usize> {
300 if join_set.contains(&vertex_id) {
302 return Ok(0);
303 }
304
305 let neighbors = graph.get_neighbors_level0(vertex_id);
306 let mut gain = 0;
307
308 let self_coverage = coverage.get(&vertex_id).copied().unwrap_or(0);
310 if self_coverage < self.config.min_coverage {
311 gain += 1;
312 }
313
314 for &neighbor in &neighbors {
316 let neighbor_coverage = coverage.get(&neighbor).copied().unwrap_or(0);
317 if neighbor_coverage < self.config.min_coverage {
318 gain += 1;
319 }
320 }
321
322 Ok(gain)
323 }
324
325 fn is_fully_covered(&self, coverage: &HashMap<u32, usize>, graph: &HNSWIndex) -> bool {
327 for node_id in 0..graph.len() as u32 {
328 let c = coverage.get(&node_id).copied().unwrap_or(0);
329 if c < self.config.min_coverage {
330 return false;
331 }
332 }
333 true
334 }
335}
336
337impl Default for GraphMerger {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::hnsw::{DistanceFunction, HNSWParams};

    /// Builds an L2 index of `num_vectors` deterministic `dim`-dimensional vectors whose
    /// components grow monotonically with the vector index.
    fn create_test_index(num_vectors: usize, dim: usize) -> HNSWIndex {
        let params = HNSWParams {
            m: 16,
            ef_construction: 100,
            ..Default::default()
        };
        let mut index = HNSWIndex::new(dim, params, DistanceFunction::L2, false).unwrap();

        for i in 0..num_vectors {
            let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 100.0).collect();
            index.insert(&vector).unwrap();
        }

        index
    }

    #[test]
    fn test_merge_empty_small_graph() {
        let mut large = create_test_index(100, 8);
        let small = HNSWIndex::new(8, HNSWParams::default(), DistanceFunction::L2, false).unwrap();

        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 0);
        assert_eq!(stats.join_set_size, 0);
        assert_eq!(large.len(), 100);
    }

    #[test]
    fn test_merge_small_graphs() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index(50, 8);

        let initial_size = large.len();
        let merger = GraphMerger::new();
        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        assert_eq!(stats.vectors_merged, 50);
        assert_eq!(large.len(), initial_size + 50);
        assert!(stats.join_set_size > 0);
        assert!(stats.join_set_size <= 50);
    }

    #[test]
    fn test_merge_stats_account_for_every_vector() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index(50, 8);

        let stats = GraphMerger::new().merge_graphs(&mut large, &small).unwrap();

        // Every merged vector takes exactly one of the three insert paths.
        assert_eq!(
            stats.join_set_size + stats.fast_path_inserts + stats.fallback_inserts,
            stats.vectors_merged
        );
    }

    #[test]
    fn test_zero_min_coverage_falls_back_for_every_vector() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index(50, 8);
        let merger = GraphMerger::with_config(MergeConfig {
            min_coverage: 0,
            ..MergeConfig::default()
        });

        let stats = merger.merge_graphs(&mut large, &small).unwrap();

        // No join set means no hints are available, so every vector takes the fallback.
        assert_eq!(stats.join_set_size, 0);
        assert_eq!(stats.fast_path_inserts, 0);
        assert_eq!(stats.fallback_inserts, 50);
        assert_eq!(large.len(), 150);
    }

    #[test]
    fn test_join_set_coverage() {
        let small = create_test_index(100, 8);
        let merger = GraphMerger::new();

        let join_set = merger.compute_join_set(&small);

        assert!(!join_set.is_empty());
        assert!(join_set.len() < small.len());

        // Recompute coverage independently and check every vertex meets the target.
        let mut coverage: HashMap<u32, usize> = HashMap::new();
        for &j_id in &join_set {
            let neighbors = small.get_neighbors_level0(j_id);
            for &n in &neighbors {
                *coverage.entry(n).or_insert(0) += 1;
            }
            *coverage.entry(j_id).or_insert(0) += merger.config.min_coverage;
        }

        for node_id in 0..small.len() as u32 {
            let c = coverage.get(&node_id).copied().unwrap_or(0);
            assert!(
                c >= merger.config.min_coverage,
                "Node {} has insufficient coverage: {} < {}",
                node_id,
                c,
                merger.config.min_coverage
            );
        }
    }

    #[test]
    fn test_join_set_is_deterministic() {
        let small = create_test_index(100, 8);
        let merger = GraphMerger::new();

        assert_eq!(merger.compute_join_set(&small), merger.compute_join_set(&small));
    }

    #[test]
    fn test_merge_preserves_searchability() {
        let mut large = create_test_index(100, 8);
        let small = create_test_index(50, 8);

        let test_vector = small.get_vector(25).unwrap().to_vec();

        let merger = GraphMerger::new();
        merger.merge_graphs(&mut large, &small).unwrap();

        let results = large.search(&test_vector, 5, 50).unwrap();
        assert!(!results.is_empty());

        assert!(results[0].distance < 1.0);
    }
}