Skip to main content

oxirs_vec/
index_merger.rs

//! ANN vector index merging — combines multiple flat indices into one (v1.1.0 round 14).
//!
//! Provides utilities to:
//! - Build flat vector indices from individual entries
//! - Merge multiple flat indices with last-write-wins deduplication on entry `id`
//! - Filter entries during a merge (applied after deduplication)
//! - Split an index into contiguous, near-equal partitions
//! - Collect merge statistics
9
10use std::collections::HashMap;
11
/// One stored vector together with its identifier and free-form metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorEntry {
    /// Identifier of the vector; merges deduplicate on this value.
    pub id: u64,
    /// Vector components.
    pub vector: Vec<f32>,
    /// Free-form key/value annotations carried along with the vector.
    pub metadata: HashMap<String, String>,
}

impl VectorEntry {
    /// Build an entry for `id` / `vector` that carries no metadata.
    pub fn new(id: u64, vector: Vec<f32>) -> Self {
        Self::with_metadata(id, vector, HashMap::new())
    }

    /// Build an entry for `id` / `vector` with an explicit `metadata` map.
    pub fn with_metadata(id: u64, vector: Vec<f32>, metadata: HashMap<String, String>) -> Self {
        VectorEntry {
            metadata,
            vector,
            id,
        }
    }
}
42
43/// A flat in-memory ANN index holding a collection of [`VectorEntry`] values.
44///
45/// All entries in the index must have the same dimensionality.
46#[derive(Debug, Clone, PartialEq)]
47pub struct FlatIndex {
48    /// All stored entries.
49    pub entries: Vec<VectorEntry>,
50    /// Dimensionality of the vectors stored in this index.
51    pub dims: usize,
52}
53
54impl FlatIndex {
55    /// Create an empty flat index for vectors of the given dimensionality.
56    pub fn new(dims: usize) -> Self {
57        Self {
58            entries: Vec::new(),
59            dims,
60        }
61    }
62
63    /// Insert an entry.  Returns an error if the entry's vector length does
64    /// not match the index dimensionality.
65    pub fn insert(&mut self, entry: VectorEntry) -> Result<(), MergeError> {
66        if entry.vector.len() != self.dims {
67            return Err(MergeError::DimensionMismatch {
68                expected: self.dims,
69                got: entry.vector.len(),
70            });
71        }
72        self.entries.push(entry);
73        Ok(())
74    }
75
76    /// Number of entries in the index.
77    pub fn len(&self) -> usize {
78        self.entries.len()
79    }
80
81    /// Returns `true` if the index contains no entries.
82    pub fn is_empty(&self) -> bool {
83        self.entries.is_empty()
84    }
85}
86
/// Statistics collected during a merge operation.
///
/// As populated by `IndexMerger::merge_with_stats`, the counts satisfy
/// `deduplicated == total_before - total_after`.
///
/// All fields are plain counters, so the type is `Copy`, comparable with
/// `Eq`, and has an all-zero `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MergeStats {
    /// Number of input indices that were merged.
    pub input_count: usize,
    /// Total number of entries across all input indices before deduplication.
    pub total_before: usize,
    /// Number of entries removed by deduplication (superseded by a later
    /// entry with the same `id`).
    pub deduplicated: usize,
    /// Number of entries in the merged output index.
    pub total_after: usize,
}
99
/// Errors that can occur during index merge / split operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeError {
    /// A vector or index has a different dimensionality than required.
    DimensionMismatch {
        /// Dimensionality that was required.
        expected: usize,
        /// Dimensionality that was actually found.
        got: usize,
    },
    /// No input indices were provided.
    EmptyInput,
    /// The requested number of parts is invalid (0).
    ///
    /// NOTE(review): `IndexMerger::split` is infallible and signals
    /// `parts == 0` by returning an empty `Vec`, so it never produces this
    /// variant; it is kept for API compatibility.
    InvalidParts,
}

impl std::fmt::Display for MergeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MergeError::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
            MergeError::EmptyInput => f.write_str("No input indices provided"),
            MergeError::InvalidParts => f.write_str("Number of parts must be greater than zero"),
        }
    }
}

impl std::error::Error for MergeError {}
126
127/// Combines multiple [`FlatIndex`] instances into a single merged index.
128///
129/// # Deduplication
130/// When two entries share the same `id`, the **last one wins** (insertion
131/// order across indices, then within each index).
132#[derive(Debug, Default)]
133pub struct IndexMerger {
134    indices: Vec<FlatIndex>,
135}
136
137impl IndexMerger {
138    /// Create a new, empty merger.
139    pub fn new() -> Self {
140        Self {
141            indices: Vec::new(),
142        }
143    }
144
145    /// Add an index to the merge set.
146    pub fn add_index(&mut self, idx: FlatIndex) {
147        self.indices.push(idx);
148    }
149
150    /// Merge all added indices into a single [`FlatIndex`].
151    ///
152    /// Deduplication is performed on `id`: if multiple entries share the same
153    /// ID, the one that appears **latest** (last index, last position within
154    /// that index) wins.
155    ///
156    /// Returns [`MergeError::EmptyInput`] if no indices have been added.
157    pub fn merge(&mut self) -> Result<FlatIndex, MergeError> {
158        if self.indices.is_empty() {
159            return Err(MergeError::EmptyInput);
160        }
161
162        let dims = self.indices[0].dims;
163
164        // Validate that all indices share the same dimensionality
165        for idx in &self.indices {
166            if idx.dims != dims {
167                return Err(MergeError::DimensionMismatch {
168                    expected: dims,
169                    got: idx.dims,
170                });
171            }
172        }
173
174        // Last-write-wins deduplication using an ordered map
175        // (we use a Vec to preserve insertion order for iteration, and a
176        // HashMap for O(1) lookup / update)
177        let mut order: Vec<u64> = Vec::new();
178        let mut map: HashMap<u64, VectorEntry> = HashMap::new();
179
180        for idx in &self.indices {
181            for entry in &idx.entries {
182                if !map.contains_key(&entry.id) {
183                    order.push(entry.id);
184                }
185                map.insert(entry.id, entry.clone());
186            }
187        }
188
189        let mut out = FlatIndex::new(dims);
190        for id in &order {
191            if let Some(entry) = map.remove(id) {
192                out.entries.push(entry);
193            }
194        }
195
196        Ok(out)
197    }
198
199    /// Merge all indices, retaining only entries for which `filter` returns
200    /// `true`.  Deduplication happens **before** filtering.
201    pub fn merge_with_filter<F>(&mut self, filter: F) -> Result<FlatIndex, MergeError>
202    where
203        F: Fn(&VectorEntry) -> bool,
204    {
205        let merged = self.merge()?;
206        let dims = merged.dims;
207        let mut out = FlatIndex::new(dims);
208        for entry in merged.entries {
209            if filter(&entry) {
210                out.entries.push(entry);
211            }
212        }
213        Ok(out)
214    }
215
216    /// Merge all indices and return both the merged index and statistics.
217    pub fn merge_with_stats(&mut self) -> Result<(FlatIndex, MergeStats), MergeError> {
218        if self.indices.is_empty() {
219            return Err(MergeError::EmptyInput);
220        }
221
222        let input_count = self.indices.len();
223        let total_before: usize = self.indices.iter().map(|i| i.len()).sum();
224
225        let merged = self.merge()?;
226        let total_after = merged.len();
227        let deduplicated = total_before.saturating_sub(total_after);
228
229        let stats = MergeStats {
230            input_count,
231            total_before,
232            deduplicated,
233            total_after,
234        };
235        Ok((merged, stats))
236    }
237
238    /// Split a [`FlatIndex`] into `parts` evenly-sized sub-indices.
239    ///
240    /// If the number of entries does not divide evenly, the first
241    /// `entries.len() % parts` partitions will each receive one extra entry.
242    ///
243    /// Returns [`MergeError::InvalidParts`] if `parts == 0`.
244    pub fn split(idx: &FlatIndex, parts: usize) -> Vec<FlatIndex> {
245        if parts == 0 {
246            return vec![];
247        }
248        if idx.is_empty() {
249            return (0..parts).map(|_| FlatIndex::new(idx.dims)).collect();
250        }
251
252        let n = idx.entries.len();
253        let base = n / parts;
254        let remainder = n % parts;
255
256        let mut result = Vec::with_capacity(parts);
257        let mut offset = 0usize;
258
259        for i in 0..parts {
260            let chunk_size = base + if i < remainder { 1 } else { 0 };
261            let mut sub = FlatIndex::new(idx.dims);
262            sub.entries
263                .extend_from_slice(&idx.entries[offset..offset + chunk_size]);
264            offset += chunk_size;
265            result.push(sub);
266        }
267        result
268    }
269}
270
271// ---------------------------------------------------------------------------
272// Tests
273// ---------------------------------------------------------------------------
274
#[cfg(test)]
mod tests {
    //! Unit tests for `VectorEntry`, `FlatIndex`, `IndexMerger`, and `MergeError`.

    use super::*;

    /// Build an entry whose vector is `dims` copies of `val`.
    fn make_entry(id: u64, dims: usize, val: f32) -> VectorEntry {
        VectorEntry::new(id, vec![val; dims])
    }

    /// Build an index of dimensionality `dims` from `(id, fill_value)` pairs.
    fn make_index(dims: usize, ids: &[(u64, f32)]) -> FlatIndex {
        let mut idx = FlatIndex::new(dims);
        for (id, val) in ids {
            idx.insert(make_entry(*id, dims, *val)).expect("insert ok");
        }
        idx
    }

    /// Collect the IDs of `idx` in stored order.
    fn ids_of(idx: &FlatIndex) -> Vec<u64> {
        idx.entries.iter().map(|e| e.id).collect()
    }

    // -- FlatIndex -----------------------------------------------------------

    #[test]
    fn test_flat_index_new_is_empty() {
        let idx = FlatIndex::new(4);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.dims, 4);
    }

    #[test]
    fn test_flat_index_insert_valid() {
        let mut idx = FlatIndex::new(3);
        let entry = make_entry(1, 3, 0.5);
        assert!(idx.insert(entry).is_ok());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_flat_index_insert_dimension_mismatch() {
        let mut idx = FlatIndex::new(3);
        let entry = make_entry(1, 4, 0.5);
        assert_eq!(
            idx.insert(entry),
            Err(MergeError::DimensionMismatch {
                expected: 3,
                got: 4
            })
        );
    }

    #[test]
    fn test_flat_index_is_not_empty_after_insert() {
        let mut idx = FlatIndex::new(2);
        idx.insert(make_entry(1, 2, 1.0)).expect("ok");
        assert!(!idx.is_empty());
    }

    // -- IndexMerger::merge -------------------------------------------------

    #[test]
    fn test_merge_empty_returns_error() {
        let mut merger = IndexMerger::new();
        assert_eq!(merger.merge(), Err(MergeError::EmptyInput));
    }

    #[test]
    fn test_merge_single_index() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_merge_two_disjoint_indices() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(3, 3.0), (4, 4.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_merge_deduplication_last_write_wins() {
        // Both indices contain ID 1; the one in `b` should survive.
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(2, &[(1, 9.9)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 1);
        assert!((out.entries[0].vector[0] - 9.9).abs() < 1e-6);
    }

    #[test]
    fn test_merge_deduplication_count() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(2, 2.5), (3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        // IDs: 1, 2 (from b), 3
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_merge_keeps_first_seen_order_with_latest_value() {
        // ID 1 is first seen in `a` but overwritten by `b`; it must stay in
        // first position while carrying b's vector.
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(3, 3.0), (1, 7.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        assert_eq!(ids_of(&out), vec![1, 2, 3]);
        assert_eq!(out.entries[0].vector, vec![7.0_f32, 7.0]);
    }

    #[test]
    fn test_merge_deduplicates_within_single_index() {
        // FlatIndex::insert accepts duplicate IDs; merge must still collapse them.
        let idx = make_index(2, &[(5, 1.0), (5, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out.entries[0].vector, vec![2.0_f32, 2.0]);
    }

    #[test]
    fn test_merge_is_repeatable() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0), (2, 2.0)]));
        merger.add_index(make_index(2, &[(2, 5.0)]));
        let first = merger.merge().expect("ok");
        let second = merger.merge().expect("ok");
        assert_eq!(first, second);
    }

    #[test]
    fn test_merge_dimension_mismatch_error() {
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(3, &[(2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        assert_eq!(
            merger.merge(),
            Err(MergeError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        );
    }

    #[test]
    fn test_merge_preserves_metadata() {
        let mut meta = HashMap::new();
        meta.insert("key".to_string(), "val".to_string());
        let entry = VectorEntry::with_metadata(42, vec![1.0, 2.0], meta.clone());
        let mut idx = FlatIndex::new(2);
        idx.insert(entry).expect("ok");
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("ok");
        assert_eq!(out.entries[0].metadata.get("key"), Some(&"val".to_string()));
    }

    // -- merge_with_filter ---------------------------------------------------

    #[test]
    fn test_merge_with_filter_keeps_matching() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge_with_filter(|e| e.id % 2 == 1).expect("ok");
        assert_eq!(out.len(), 2);
        assert!(out.entries.iter().all(|e| e.id % 2 == 1));
    }

    #[test]
    fn test_merge_with_filter_all_excluded() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge_with_filter(|_| false).expect("ok");
        assert!(out.is_empty());
    }

    #[test]
    fn test_merge_with_filter_all_included() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge_with_filter(|_| true).expect("ok");
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_merge_with_filter_empty_input() {
        let mut merger = IndexMerger::new();
        assert_eq!(
            merger.merge_with_filter(|_| true),
            Err(MergeError::EmptyInput)
        );
    }

    #[test]
    fn test_merge_with_filter_runs_after_dedup() {
        // The surviving entry for ID 1 is b's (9.0), which the filter rejects;
        // a's superseded 1.0 must not resurface.
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(2, &[(1, 9.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger
            .merge_with_filter(|e| e.vector[0] < 5.0)
            .expect("ok");
        assert!(out.is_empty());
    }

    // -- merge_with_stats ----------------------------------------------------

    #[test]
    fn test_merge_stats_no_dedup() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let (out, stats) = merger.merge_with_stats().expect("ok");
        assert_eq!(stats.input_count, 2);
        assert_eq!(stats.total_before, 3);
        assert_eq!(stats.deduplicated, 0);
        assert_eq!(stats.total_after, 3);
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_merge_stats_with_dedup() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(2, 9.0), (3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let (_out, stats) = merger.merge_with_stats().expect("ok");
        assert_eq!(stats.total_before, 4);
        assert_eq!(stats.deduplicated, 1);
        assert_eq!(stats.total_after, 3);
    }

    #[test]
    fn test_merge_stats_empty_input() {
        let mut merger = IndexMerger::new();
        assert_eq!(merger.merge_with_stats(), Err(MergeError::EmptyInput));
    }

    #[test]
    fn test_merge_stats_dimension_mismatch() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0)]));
        merger.add_index(make_index(3, &[(2, 2.0)]));
        assert_eq!(
            merger.merge_with_stats(),
            Err(MergeError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        );
    }

    // -- split ---------------------------------------------------------------

    #[test]
    fn test_split_even() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)]);
        let parts = IndexMerger::split(&idx, 2);
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 2);
        assert_eq!(parts[1].len(), 2);
    }

    #[test]
    fn test_split_uneven() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0)]);
        let parts = IndexMerger::split(&idx, 2);
        assert_eq!(parts.len(), 2);
        // 3 entries / 2 parts → first gets 2, second gets 1
        assert_eq!(parts[0].len(), 2);
        assert_eq!(parts[1].len(), 1);
    }

    #[test]
    fn test_split_into_one() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let parts = IndexMerger::split(&idx, 1);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].len(), 2);
    }

    #[test]
    fn test_split_zero_parts() {
        let idx = make_index(2, &[(1, 1.0)]);
        let parts = IndexMerger::split(&idx, 0);
        assert!(parts.is_empty());
    }

    #[test]
    fn test_split_empty_index() {
        let idx = FlatIndex::new(3);
        let parts = IndexMerger::split(&idx, 3);
        assert_eq!(parts.len(), 3);
        assert!(parts.iter().all(|p| p.is_empty()));
    }

    #[test]
    fn test_split_more_parts_than_entries() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let parts = IndexMerger::split(&idx, 5);
        assert_eq!(parts.len(), 5);
        let total: usize = parts.iter().map(|p| p.len()).sum();
        assert_eq!(total, 2);
    }

    #[test]
    fn test_split_preserves_dims() {
        let idx = make_index(7, &[(1, 1.0), (2, 2.0), (3, 3.0)]);
        let parts = IndexMerger::split(&idx, 2);
        for p in &parts {
            assert_eq!(p.dims, 7);
        }
    }

    #[test]
    fn test_split_total_count_preserved() {
        let ids: Vec<(u64, f32)> = (1u64..=10).map(|i| (i, i as f32)).collect();
        let idx = make_index(4, &ids);
        let parts = IndexMerger::split(&idx, 3);
        let total: usize = parts.iter().map(|p| p.len()).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_split_preserves_entry_order() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0), (5, 5.0)]);
        let parts = IndexMerger::split(&idx, 2);
        assert_eq!(ids_of(&parts[0]), vec![1, 2, 3]);
        assert_eq!(ids_of(&parts[1]), vec![4, 5]);
    }

    // -- Error display -------------------------------------------------------

    #[test]
    fn test_error_display_empty_input() {
        let e = MergeError::EmptyInput;
        assert!(e.to_string().contains("No input"));
    }

    #[test]
    fn test_error_display_dimension_mismatch() {
        let e = MergeError::DimensionMismatch {
            expected: 4,
            got: 3,
        };
        let s = e.to_string();
        assert!(s.contains("4"));
        assert!(s.contains("3"));
    }

    #[test]
    fn test_error_display_invalid_parts() {
        let e = MergeError::InvalidParts;
        assert!(e.to_string().contains("zero"));
    }

    #[test]
    fn test_error_is_std_error() {
        let e: Box<dyn std::error::Error> = Box::new(MergeError::EmptyInput);
        assert!(e.to_string().contains("No input"));
    }

    // -- VectorEntry ---------------------------------------------------------

    #[test]
    fn test_vector_entry_new() {
        let e = VectorEntry::new(7, vec![1.0, 2.0, 3.0]);
        assert_eq!(e.id, 7);
        assert_eq!(e.vector.len(), 3);
        assert!(e.metadata.is_empty());
    }

    #[test]
    fn test_vector_entry_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("source".into(), "test".into());
        let e = VectorEntry::with_metadata(1, vec![0.0], meta);
        assert_eq!(e.metadata.get("source"), Some(&"test".to_string()));
    }

    #[test]
    fn test_index_merger_default() {
        let _m: IndexMerger = IndexMerger::default();
    }

    #[test]
    fn test_merge_three_indices() {
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(2, &[(2, 2.0)]);
        let c = make_index(2, &[(3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        merger.add_index(c);
        let out = merger.merge().expect("ok");
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_merge_large_index() {
        let pairs: Vec<(u64, f32)> = (1u64..=100).map(|i| (i, i as f32)).collect();
        let idx = make_index(4, &pairs);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("ok");
        assert_eq!(out.len(), 100);
    }

    #[test]
    fn test_split_four_parts() {
        let pairs: Vec<(u64, f32)> = (1u64..=8).map(|i| (i, i as f32)).collect();
        let idx = make_index(2, &pairs);
        let parts = IndexMerger::split(&idx, 4);
        assert_eq!(parts.len(), 4);
        assert!(parts.iter().all(|p| p.len() == 2));
    }

    #[test]
    fn test_merge_filter_by_vector_value() {
        let pairs: Vec<(u64, f32)> = (1u64..=10).map(|i| (i, i as f32)).collect();
        let idx = make_index(2, &pairs);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        // Only keep entries where first vector element >= 5.0
        let out = merger
            .merge_with_filter(|e| e.vector[0] >= 5.0)
            .expect("ok");
        assert_eq!(out.len(), 6); // 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
    }

    #[test]
    fn test_flat_index_dims_preserved_through_merge() {
        let idx = make_index(128, &[(1, 0.1), (2, 0.2)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("ok");
        assert_eq!(out.dims, 128);
    }

    #[test]
    fn test_stats_input_count_three() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0)]));
        merger.add_index(make_index(2, &[(2, 2.0)]));
        merger.add_index(make_index(2, &[(3, 3.0)]));
        let (_, stats) = merger.merge_with_stats().expect("ok");
        assert_eq!(stats.input_count, 3);
    }

    #[test]
    fn test_split_single_entry_many_parts() {
        let idx = make_index(2, &[(42, 1.0)]);
        let parts = IndexMerger::split(&idx, 4);
        let total: usize = parts.iter().map(|p| p.len()).sum();
        assert_eq!(total, 1);
        assert_eq!(parts.len(), 4);
    }
}