// nodedb_cluster/distributed_vector/merge.rs
use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct VectorHit {
15 pub vector_id: u32,
17 pub distance: f32,
19 pub shard_id: u16,
21 pub doc_id: Option<String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ShardSearchResult {
28 pub shard_id: u16,
29 pub hits: Vec<VectorHit>,
30 pub success: bool,
31 pub error: Option<String>,
32}
33
34pub struct VectorMerger {
41 all_hits: Vec<VectorHit>,
43 responded: usize,
45 expected: usize,
47}
48
49impl VectorMerger {
50 pub fn new(expected_shards: usize) -> Self {
51 Self {
52 all_hits: Vec::new(),
53 responded: 0,
54 expected: expected_shards,
55 }
56 }
57
58 pub fn add_shard_result(&mut self, result: &ShardSearchResult) {
60 if result.success {
61 self.all_hits.extend_from_slice(&result.hits);
62 }
63 self.responded += 1;
64 }
65
66 pub fn all_responded(&self) -> bool {
68 self.responded >= self.expected
69 }
70
71 pub fn top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
76 self.all_hits.sort_by(|a, b| {
77 a.distance
78 .partial_cmp(&b.distance)
79 .unwrap_or(std::cmp::Ordering::Equal)
80 .then(a.shard_id.cmp(&b.shard_id))
81 });
82 self.all_hits.truncate(top_k);
83 self.all_hits.clone()
84 }
85
86 pub fn total_hits(&self) -> usize {
88 self.all_hits.len()
89 }
90
91 pub fn response_count(&self) -> usize {
93 self.responded
94 }
95}
96
/// Picks an over-fetch factor from the filter state.
///
/// * `has_filter == false` → `1`
/// * filter present and `estimated_pass_rate > 0.5` → `2`
/// * filter present otherwise (including exactly `0.5` and NaN) → `3`
///
/// NOTE(review): presumably callers multiply the requested top-K by this
/// factor when querying each shard so that enough hits survive filtering —
/// confirm against the call site.
pub fn over_fetch_factor(has_filter: bool, estimated_pass_rate: f64) -> usize {
    match (has_filter, estimated_pass_rate > 0.5) {
        (false, _) => 1,
        (true, true) => 2,
        (true, false) => 3,
    }
}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit with no document id.
    fn hit(vector_id: u32, distance: f32, shard_id: u16) -> VectorHit {
        VectorHit {
            vector_id,
            distance,
            shard_id,
            doc_id: None,
        }
    }

    /// Builds a successful shard result carrying `hits`.
    fn ok_result(shard_id: u16, hits: Vec<VectorHit>) -> ShardSearchResult {
        ShardSearchResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = VectorMerger::new(2);

        merger.add_shard_result(&ok_result(
            0,
            vec![hit(1, 0.1, 0), hit(2, 0.3, 0), hit(3, 0.5, 0)],
        ));
        merger.add_shard_result(&ok_result(
            1,
            vec![hit(10, 0.05, 1), hit(11, 0.2, 1), hit(12, 0.4, 1)],
        ));

        assert!(merger.all_responded());
        let top3 = merger.top_k(3);
        assert_eq!(top3.len(), 3);
        // Closest hit overall comes from shard 1.
        assert_eq!(top3[0].vector_id, 10);
        assert_eq!(top3[0].shard_id, 1);
        assert!((top3[0].distance - 0.05).abs() < f32::EPSILON);
        // Then shard 0's best hit, then shard 1's second-best.
        assert_eq!(top3[1].vector_id, 1);
        assert_eq!(top3[2].vector_id, 11);
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = VectorMerger::new(2);

        merger.add_shard_result(&ok_result(0, vec![hit(1, 0.1, 0)]));

        // A failed shard still counts as a response but contributes no hits.
        merger.add_shard_result(&ShardSearchResult {
            shard_id: 1,
            hits: vec![],
            success: false,
            error: Some("timeout".into()),
        });

        assert!(merger.all_responded());
        let top1 = merger.top_k(1);
        assert_eq!(top1.len(), 1);
        assert_eq!(top1[0].vector_id, 1);
    }

    #[test]
    fn top_k_truncation() {
        let mut merger = VectorMerger::new(1);
        merger.add_shard_result(&ok_result(
            0,
            (0..100).map(|i| hit(i, i as f32 * 0.01, 0)).collect(),
        ));

        let top5 = merger.top_k(5);
        assert_eq!(top5.len(), 5);
        for (i, hit) in top5.iter().enumerate() {
            // Distances grow with the id, so the ranking follows id order.
            assert_eq!(hit.vector_id, i as u32);
        }
    }

    #[test]
    fn over_fetch_no_filter() {
        assert_eq!(over_fetch_factor(false, 1.0), 1);
    }

    #[test]
    fn over_fetch_light_filter() {
        assert_eq!(over_fetch_factor(true, 0.7), 2);
    }

    #[test]
    fn over_fetch_heavy_filter() {
        assert_eq!(over_fetch_factor(true, 0.2), 3);
    }

    #[test]
    fn deterministic_tie_breaking() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok_result(0, vec![hit(1, 0.5, 0)]));
        merger.add_shard_result(&ok_result(1, vec![hit(2, 0.5, 1)]));

        let top2 = merger.top_k(2);
        // Equal distances: the lower shard id ranks first.
        assert_eq!(top2[0].shard_id, 0);
        assert_eq!(top2[1].shard_id, 1);
    }
}