Skip to main content

nodedb_cluster/distributed_vector/
merge.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Cross-shard k-NN merge for distributed vector search.
4//!
5//! Each shard runs local HNSW search and returns its top-K hits with
6//! distances. The coordinator merges results from all shards by distance
7//! re-ranking: sort all hits globally, take the top-K.
8//!
//! This is the standard scatter-gather k-NN merge used by distributed vector
//! databases such as Milvus, Qdrant, and Weaviate.
11
12use serde::{Deserialize, Serialize};
13
14/// A single vector search hit from a shard.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VectorHit {
17    /// Internal vector ID within the shard.
18    pub vector_id: u32,
19    /// Distance to the query vector (lower = closer for L2/cosine).
20    pub distance: f32,
21    /// Which shard produced this hit.
22    pub shard_id: u32,
23    /// Optional document ID associated with this vector.
24    pub doc_id: Option<String>,
25}
26
27/// Results from a single shard's local k-NN search.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ShardSearchResult {
30    pub shard_id: u32,
31    pub hits: Vec<VectorHit>,
32    pub success: bool,
33    pub error: Option<String>,
34}
35
36/// Merges k-NN results from multiple shards.
37///
38/// Standard algorithm: collect all shard-local top-K results, sort globally
39/// by distance, return the top-K. Over-fetch factor is handled per-shard
40/// (each shard returns `top_k * over_fetch_factor` to account for filter
41/// reduction).
42pub struct VectorMerger {
43    /// All hits collected from all shards, unsorted.
44    all_hits: Vec<VectorHit>,
45    /// Number of shards that have responded.
46    responded: usize,
47    /// Total number of shards expected.
48    expected: usize,
49}
50
51impl VectorMerger {
52    pub fn new(expected_shards: usize) -> Self {
53        Self {
54            all_hits: Vec::new(),
55            responded: 0,
56            expected: expected_shards,
57        }
58    }
59
60    /// Add a shard's search results.
61    pub fn add_shard_result(&mut self, result: &ShardSearchResult) {
62        if result.success {
63            self.all_hits.extend_from_slice(&result.hits);
64        }
65        self.responded += 1;
66    }
67
68    /// Whether all expected shards have responded.
69    pub fn all_responded(&self) -> bool {
70        self.responded >= self.expected
71    }
72
73    /// Merge all shard results and return the global top-K.
74    ///
75    /// Sorts by distance ascending (nearest first) and truncates to `top_k`.
76    /// Ties are broken by shard_id for deterministic ordering.
77    pub fn top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
78        self.all_hits.sort_by(|a, b| {
79            a.distance
80                .partial_cmp(&b.distance)
81                .unwrap_or(std::cmp::Ordering::Equal)
82                .then(a.shard_id.cmp(&b.shard_id))
83        });
84        self.all_hits.truncate(top_k);
85        self.all_hits.clone()
86    }
87
88    /// Number of total hits collected (before merge).
89    pub fn total_hits(&self) -> usize {
90        self.all_hits.len()
91    }
92
93    /// Number of shards that responded.
94    pub fn response_count(&self) -> usize {
95        self.responded
96    }
97}
98
/// Chooses how many times `top_k` each shard should fetch.
///
/// Metadata filters discard some shard-local hits after the search, so shards
/// fetch extra candidates to compensate:
/// - no filter (`has_filter == false`): 1x, regardless of the pass rate
/// - filter with `estimated_pass_rate` strictly above 0.5: 2x
/// - filter with any other pass rate (0.5 or lower, or NaN): 3x
pub fn over_fetch_factor(has_filter: bool, estimated_pass_rate: f64) -> usize {
    match (has_filter, estimated_pass_rate > 0.5) {
        (false, _) => 1,
        (true, true) => 2,
        (true, false) => 3,
    }
}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit with no document ID attached.
    fn hit(vector_id: u32, distance: f32, shard_id: u32) -> VectorHit {
        VectorHit {
            vector_id,
            distance,
            shard_id,
            doc_id: None,
        }
    }

    /// Wraps `hits` in a successful response from shard `shard_id`.
    fn shard_ok(shard_id: u32, hits: Vec<VectorHit>) -> ShardSearchResult {
        ShardSearchResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&shard_ok(
            0,
            vec![hit(1, 0.1, 0), hit(2, 0.3, 0), hit(3, 0.5, 0)],
        ));
        merger.add_shard_result(&shard_ok(
            1,
            vec![hit(10, 0.05, 1), hit(11, 0.2, 1), hit(12, 0.4, 1)],
        ));

        assert!(merger.all_responded());
        let nearest = merger.top_k(3);
        assert_eq!(nearest.len(), 3);

        // Interleaved across shards: 0.05 (shard 1), 0.1 (shard 0), 0.2 (shard 1).
        let best = &nearest[0];
        assert_eq!(best.vector_id, 10);
        assert_eq!(best.shard_id, 1);
        assert!((best.distance - 0.05).abs() < f32::EPSILON);
        assert_eq!(nearest[1].vector_id, 1);
        assert_eq!(nearest[2].vector_id, 11);
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&shard_ok(0, vec![hit(1, 0.1, 0)]));

        // The second shard reports a failure and contributes nothing.
        let failed = ShardSearchResult {
            shard_id: 1,
            hits: vec![],
            success: false,
            error: Some("timeout".into()),
        };
        merger.add_shard_result(&failed);

        assert!(merger.all_responded());
        let nearest = merger.top_k(1);
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0].vector_id, 1);
    }

    #[test]
    fn top_k_truncation() {
        let mut merger = VectorMerger::new(1);
        let hundred = (0..100).map(|i| hit(i, i as f32 * 0.01, 0)).collect();
        merger.add_shard_result(&shard_ok(0, hundred));

        let nearest = merger.top_k(5);
        assert_eq!(nearest.len(), 5);
        // Distances grow with the ID, so the five nearest are IDs 0 through 4.
        let mut expected_id = 0u32;
        for found in &nearest {
            assert_eq!(found.vector_id, expected_id);
            expected_id += 1;
        }
    }

    #[test]
    fn over_fetch_no_filter() {
        let factor = over_fetch_factor(false, 1.0);
        assert_eq!(factor, 1);
    }

    #[test]
    fn over_fetch_light_filter() {
        let factor = over_fetch_factor(true, 0.7);
        assert_eq!(factor, 2);
    }

    #[test]
    fn over_fetch_heavy_filter() {
        let factor = over_fetch_factor(true, 0.2);
        assert_eq!(factor, 3);
    }

    #[test]
    fn deterministic_tie_breaking() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&shard_ok(0, vec![hit(1, 0.5, 0)]));
        merger.add_shard_result(&shard_ok(1, vec![hit(2, 0.5, 1)]));

        // Equal distances: the lower shard_id must come first.
        let nearest = merger.top_k(2);
        assert_eq!(nearest[0].shard_id, 0);
        assert_eq!(nearest[1].shard_id, 1);
    }
}