Skip to main content

nodedb_vector/delta/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Query-time merge of main HNSW and delta result sets.
4//!
5//! Filters tombstones from both lists, tags each result with its origin,
6//! deduplicates by id (delta wins), and returns the top-k sorted ascending
7//! by distance.
8
9use std::collections::{HashMap, HashSet};
10
/// One entry of the candidate list built by [`merge_results`].
///
/// Carries the vector id, its distance to the query, and a flag recording
/// which of the two input lists the entry was taken from.
#[derive(Clone, Debug, PartialEq)]
pub struct MergedResult {
    /// Identifier of the matched vector.
    pub id: u32,
    /// Distance between the query vector and this vector.
    pub distance: f32,
    /// Origin flag: `true` when the entry was taken from the delta results,
    /// `false` when it was taken from the main results.
    pub from_delta: bool,
}
21
22/// Merge `main` and `delta` result lists into a single sorted top-k list.
23///
24/// Rules:
25/// - Any id present in `tombstones` is excluded.
26/// - If the same id appears in both lists, the delta entry wins (its distance
27///   and `from_delta = true` are used).
28/// - Output is sorted ascending by distance and truncated to `k`.
29pub fn merge_results(
30    main: Vec<(u32, f32)>,
31    delta: Vec<(u32, f32)>,
32    tombstones: &HashSet<u32>,
33    k: usize,
34) -> Vec<MergedResult> {
35    if k == 0 {
36        return Vec::new();
37    }
38
39    // Collect delta entries first (they win on collision).
40    let mut by_id: HashMap<u32, MergedResult> = HashMap::new();
41
42    for (id, dist) in delta {
43        if tombstones.contains(&id) {
44            continue;
45        }
46        by_id.insert(
47            id,
48            MergedResult {
49                id,
50                distance: dist,
51                from_delta: true,
52            },
53        );
54    }
55
56    // Main results only inserted when id not already present (delta wins).
57    for (id, dist) in main {
58        if tombstones.contains(&id) {
59            continue;
60        }
61        by_id.entry(id).or_insert(MergedResult {
62            id,
63            distance: dist,
64            from_delta: false,
65        });
66    }
67
68    let mut merged: Vec<MergedResult> = by_id.into_values().collect();
69
70    // Partial sort to get top-k cheaply.
71    if k < merged.len() {
72        merged.select_nth_unstable_by(k, |a, b| {
73            a.distance
74                .partial_cmp(&b.distance)
75                .unwrap_or(std::cmp::Ordering::Equal)
76        });
77        merged.truncate(k);
78    }
79
80    merged.sort_unstable_by(|a, b| {
81        a.distance
82            .partial_cmp(&b.distance)
83            .unwrap_or(std::cmp::Ordering::Equal)
84    });
85
86    merged
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Ids of `results`, in output order.
    fn ids(results: &[MergedResult]) -> Vec<u32> {
        results.iter().map(|r| r.id).collect()
    }

    /// Origin flags of `results`, in output order.
    fn origins(results: &[MergedResult]) -> Vec<bool> {
        results.iter().map(|r| r.from_delta).collect()
    }

    #[test]
    fn delta_beats_main_wins_collision() {
        // id 1 is reported by both lists; the delta entry must be the one kept.
        let merged = merge_results(
            vec![(1, 1.0), (2, 2.0)],
            vec![(3, 0.5), (1, 0.9)],
            &HashSet::new(),
            3,
        );

        // Expected top-3: 3@0.5 (delta), 1@0.9 (delta), 2@2.0 (main).
        assert_eq!(ids(&merged), vec![3, 1, 2]);
        assert_eq!(origins(&merged), vec![true, true, false]);
        assert!((merged[1].distance - 0.9).abs() < 1e-6);
    }

    #[test]
    fn top_two_delta_first() {
        // Spec example: main {1:1.0, 2:2.0}, delta {3:0.5} → top-2 = [3,1]
        let merged = merge_results(
            vec![(1, 1.0f32), (2, 2.0f32)],
            vec![(3, 0.5f32)],
            &HashSet::new(),
            2,
        );
        assert_eq!(ids(&merged), vec![3, 1]);
    }

    #[test]
    fn tombstone_excludes_from_both() {
        let tombstones: HashSet<u32> = [2u32].iter().copied().collect();
        let merged = merge_results(
            vec![(1, 1.0f32), (2, 2.0f32)],
            vec![(3, 0.5f32)],
            &tombstones,
            10,
        );
        // The tombstoned id is gone and the two remaining ids are returned.
        assert!(!merged.iter().any(|r| r.id == 2));
        assert_eq!(merged.len(), 2);
    }

    #[test]
    fn empty_inputs_returns_empty() {
        let merged = merge_results(Vec::new(), Vec::new(), &HashSet::new(), 10);
        assert_eq!(merged.len(), 0);
    }

    #[test]
    fn k_zero_returns_empty() {
        let merged = merge_results(vec![(1, 0.5f32)], Vec::new(), &HashSet::new(), 0);
        assert_eq!(merged.len(), 0);
    }
}