Skip to main content

nodedb_cluster/distributed_vector/
merge.rs

1//! Cross-shard k-NN merge for distributed vector search.
2//!
3//! Each shard runs local HNSW search and returns its top-K hits with
4//! distances. The coordinator merges results from all shards by distance
5//! re-ranking: sort all hits globally, take the top-K.
6//!
//! This is the standard scatter-gather k-NN merge pattern, as used by
//! distributed vector databases such as Milvus, Qdrant, and Weaviate.
9
10use serde::{Deserialize, Serialize};
11
12/// A single vector search hit from a shard.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct VectorHit {
15    /// Internal vector ID within the shard.
16    pub vector_id: u32,
17    /// Distance to the query vector (lower = closer for L2/cosine).
18    pub distance: f32,
19    /// Which shard produced this hit.
20    pub shard_id: u16,
21    /// Optional document ID associated with this vector.
22    pub doc_id: Option<String>,
23}
24
25/// Results from a single shard's local k-NN search.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ShardSearchResult {
28    pub shard_id: u16,
29    pub hits: Vec<VectorHit>,
30    pub success: bool,
31    pub error: Option<String>,
32}
33
34/// Merges k-NN results from multiple shards.
35///
36/// Standard algorithm: collect all shard-local top-K results, sort globally
37/// by distance, return the top-K. Over-fetch factor is handled per-shard
38/// (each shard returns `top_k * over_fetch_factor` to account for filter
39/// reduction).
40pub struct VectorMerger {
41    /// All hits collected from all shards, unsorted.
42    all_hits: Vec<VectorHit>,
43    /// Number of shards that have responded.
44    responded: usize,
45    /// Total number of shards expected.
46    expected: usize,
47}
48
49impl VectorMerger {
50    pub fn new(expected_shards: usize) -> Self {
51        Self {
52            all_hits: Vec::new(),
53            responded: 0,
54            expected: expected_shards,
55        }
56    }
57
58    /// Add a shard's search results.
59    pub fn add_shard_result(&mut self, result: &ShardSearchResult) {
60        if result.success {
61            self.all_hits.extend_from_slice(&result.hits);
62        }
63        self.responded += 1;
64    }
65
66    /// Whether all expected shards have responded.
67    pub fn all_responded(&self) -> bool {
68        self.responded >= self.expected
69    }
70
71    /// Merge all shard results and return the global top-K.
72    ///
73    /// Sorts by distance ascending (nearest first) and truncates to `top_k`.
74    /// Ties are broken by shard_id for deterministic ordering.
75    pub fn top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
76        self.all_hits.sort_by(|a, b| {
77            a.distance
78                .partial_cmp(&b.distance)
79                .unwrap_or(std::cmp::Ordering::Equal)
80                .then(a.shard_id.cmp(&b.shard_id))
81        });
82        self.all_hits.truncate(top_k);
83        self.all_hits.clone()
84    }
85
86    /// Number of total hits collected (before merge).
87    pub fn total_hits(&self) -> usize {
88        self.all_hits.len()
89    }
90
91    /// Number of shards that responded.
92    pub fn response_count(&self) -> usize {
93        self.responded
94    }
95}
96
/// Choose how many multiples of `top_k` each shard should fetch.
///
/// Metadata post-filters drop some shard-local hits, so shards over-fetch to
/// leave enough survivors for the global top-K:
///
/// | `has_filter` | `estimated_pass_rate`  | factor |
/// |--------------|------------------------|--------|
/// | `false`      | ignored                | 1      |
/// | `true`       | `> 0.5`                | 2      |
/// | `true`       | `<= 0.5`, or NaN       | 3      |
pub fn over_fetch_factor(has_filter: bool, estimated_pass_rate: f64) -> usize {
    let mostly_passes = estimated_pass_rate > 0.5;
    match (has_filter, mostly_passes) {
        (false, _) => 1,
        (true, true) => 2,
        (true, false) => 3,
    }
}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a hit with no associated document ID.
    fn hit(vector_id: u32, distance: f32, shard_id: u16) -> VectorHit {
        VectorHit {
            vector_id,
            distance,
            shard_id,
            doc_id: None,
        }
    }

    /// Build a successful shard response carrying `hits`.
    fn ok_result(shard_id: u16, hits: Vec<VectorHit>) -> ShardSearchResult {
        ShardSearchResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok_result(
            0,
            vec![hit(1, 0.1, 0), hit(2, 0.3, 0), hit(3, 0.5, 0)],
        ));
        merger.add_shard_result(&ok_result(
            1,
            vec![hit(10, 0.05, 1), hit(11, 0.2, 1), hit(12, 0.4, 1)],
        ));

        assert!(merger.all_responded());
        let merged = merger.top_k(3);
        assert_eq!(merged.len(), 3);

        // Globally nearest: shard 1's vector 10 at distance 0.05.
        let nearest = &merged[0];
        assert_eq!((nearest.vector_id, nearest.shard_id), (10, 1));
        assert!((nearest.distance - 0.05).abs() < f32::EPSILON);

        // Then shard 0's vector 1 (0.1), then shard 1's vector 11 (0.2).
        let rest: Vec<u32> = merged[1..].iter().map(|h| h.vector_id).collect();
        assert_eq!(rest, vec![1, 11]);
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok_result(0, vec![hit(1, 0.1, 0)]));

        // Shard 1 timed out and returned nothing.
        merger.add_shard_result(&ShardSearchResult {
            shard_id: 1,
            hits: Vec::new(),
            success: false,
            error: Some("timeout".into()),
        });

        assert!(merger.all_responded());
        let merged = merger.top_k(1);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].vector_id, 1);
    }

    #[test]
    fn top_k_truncation() {
        let mut merger = VectorMerger::new(1);
        let hits = (0..100u32).map(|i| hit(i, i as f32 * 0.01, 0)).collect();
        merger.add_shard_result(&ok_result(0, hits));

        let merged = merger.top_k(5);
        assert_eq!(merged.len(), 5);

        // The five nearest are vectors 0..5, in ascending order.
        let ids: Vec<u32> = merged.iter().map(|h| h.vector_id).collect();
        assert_eq!(ids, (0..5).collect::<Vec<u32>>());
    }

    #[test]
    fn over_fetch_no_filter() {
        assert_eq!(over_fetch_factor(false, 1.0), 1);
    }

    #[test]
    fn over_fetch_light_filter() {
        assert_eq!(over_fetch_factor(true, 0.7), 2);
    }

    #[test]
    fn over_fetch_heavy_filter() {
        assert_eq!(over_fetch_factor(true, 0.2), 3);
    }

    #[test]
    fn deterministic_tie_breaking() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok_result(0, vec![hit(1, 0.5, 0)]));
        merger.add_shard_result(&ok_result(1, vec![hit(2, 0.5, 1)]));

        // Equal distances: the lower shard_id must come first.
        let shard_order: Vec<u16> = merger.top_k(2).iter().map(|h| h.shard_id).collect();
        assert_eq!(shard_order, vec![0, 1]);
    }
}