// nodedb_cluster/distributed_vector/merge.rs
use serde::{Deserialize, Serialize};

14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VectorHit {
17 pub vector_id: u32,
19 pub distance: f32,
21 pub shard_id: u32,
23 pub doc_id: Option<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ShardSearchResult {
30 pub shard_id: u32,
31 pub hits: Vec<VectorHit>,
32 pub success: bool,
33 pub error: Option<String>,
34}
35
36pub struct VectorMerger {
43 all_hits: Vec<VectorHit>,
45 responded: usize,
47 expected: usize,
49}
50
51impl VectorMerger {
52 pub fn new(expected_shards: usize) -> Self {
53 Self {
54 all_hits: Vec::new(),
55 responded: 0,
56 expected: expected_shards,
57 }
58 }
59
60 pub fn add_shard_result(&mut self, result: &ShardSearchResult) {
62 if result.success {
63 self.all_hits.extend_from_slice(&result.hits);
64 }
65 self.responded += 1;
66 }
67
68 pub fn all_responded(&self) -> bool {
70 self.responded >= self.expected
71 }
72
73 pub fn top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
78 self.all_hits.sort_by(|a, b| {
79 a.distance
80 .partial_cmp(&b.distance)
81 .unwrap_or(std::cmp::Ordering::Equal)
82 .then(a.shard_id.cmp(&b.shard_id))
83 });
84 self.all_hits.truncate(top_k);
85 self.all_hits.clone()
86 }
87
88 pub fn total_hits(&self) -> usize {
90 self.all_hits.len()
91 }
92
93 pub fn response_count(&self) -> usize {
95 self.responded
96 }
97}
98
/// Over-fetch multiplier for a distributed vector search.
///
/// Returns `1` when no filter is applied. With a filter, it returns `2` when
/// `estimated_pass_rate` is strictly greater than `0.5`, and `3` otherwise.
/// The `3` case includes exactly `0.5` and NaN.
///
/// NOTE(review): presumably the caller multiplies the requested `k` by this
/// factor so enough hits survive post-filtering — confirm at the call site.
pub fn over_fetch_factor(has_filter: bool, estimated_pass_rate: f64) -> usize {
    let light_filter = estimated_pass_rate > 0.5;
    match (has_filter, light_filter) {
        (false, _) => 1,
        (true, true) => 2,
        (true, false) => 3,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit with no document id.
    fn hit(vector_id: u32, distance: f32, shard_id: u32) -> VectorHit {
        VectorHit {
            vector_id,
            distance,
            shard_id,
            doc_id: None,
        }
    }

    /// Builds a successful shard response.
    fn ok(shard_id: u32, hits: Vec<VectorHit>) -> ShardSearchResult {
        ShardSearchResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    /// Builds a failed shard response with no hits.
    fn failed(shard_id: u32, error: &str) -> ShardSearchResult {
        ShardSearchResult {
            shard_id,
            hits: vec![],
            success: false,
            error: Some(error.into()),
        }
    }

    /// Vector ids of `hits`, in order.
    fn ids(hits: &[VectorHit]) -> Vec<u32> {
        hits.iter().map(|h| h.vector_id).collect()
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = VectorMerger::new(2);
        let shard0 = vec![hit(1, 0.1, 0), hit(2, 0.3, 0), hit(3, 0.5, 0)];
        let shard1 = vec![hit(10, 0.05, 1), hit(11, 0.2, 1), hit(12, 0.4, 1)];
        merger.add_shard_result(&ok(0, shard0));
        merger.add_shard_result(&ok(1, shard1));

        assert!(merger.all_responded());
        let top3 = merger.top_k(3);
        assert_eq!(top3.len(), 3);
        assert_eq!(top3[0].vector_id, 10);
        assert_eq!(top3[0].shard_id, 1);
        assert!((top3[0].distance - 0.05).abs() < f32::EPSILON);
        assert_eq!(top3[1].vector_id, 1);
        assert_eq!(top3[2].vector_id, 11);
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok(0, vec![hit(1, 0.1, 0)]));
        // A failed shard still counts as a response so the merge can complete.
        merger.add_shard_result(&failed(1, "timeout"));

        assert!(merger.all_responded());
        let top1 = merger.top_k(1);
        assert_eq!(top1.len(), 1);
        assert_eq!(top1[0].vector_id, 1);
    }

    #[test]
    fn failed_shard_hits_are_ignored() {
        let mut merger = VectorMerger::new(1);
        merger.add_shard_result(&ShardSearchResult {
            shard_id: 0,
            hits: vec![hit(7, 0.0, 0)],
            success: false,
            error: Some("partial".into()),
        });
        assert!(merger.all_responded());
        assert_eq!(merger.total_hits(), 0);
        assert!(merger.top_k(1).is_empty());
    }

    #[test]
    fn partial_responses_are_not_complete() {
        let mut merger = VectorMerger::new(3);
        assert!(!merger.all_responded());
        merger.add_shard_result(&ok(0, vec![hit(1, 0.1, 0), hit(2, 0.2, 0)]));
        merger.add_shard_result(&failed(1, "timeout"));
        assert!(!merger.all_responded());
        assert_eq!(merger.response_count(), 2);
        assert_eq!(merger.total_hits(), 2);
    }

    #[test]
    fn top_k_truncation() {
        let mut merger = VectorMerger::new(1);
        let hits = (0..100).map(|i| hit(i, i as f32 * 0.01, 0)).collect();
        merger.add_shard_result(&ok(0, hits));

        let top5 = merger.top_k(5);
        assert_eq!(top5.len(), 5);
        assert_eq!(ids(&top5), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn top_k_larger_than_pool_returns_everything() {
        let mut merger = VectorMerger::new(1);
        merger.add_shard_result(&ok(0, vec![hit(2, 0.2, 0), hit(1, 0.1, 0)]));
        assert_eq!(ids(&merger.top_k(10)), vec![1, 2]);
    }

    #[test]
    fn empty_merger() {
        let mut merger = VectorMerger::new(0);
        assert!(merger.all_responded());
        assert!(merger.top_k(5).is_empty());
        assert_eq!(merger.total_hits(), 0);
        assert_eq!(merger.response_count(), 0);
    }

    #[test]
    fn over_fetch_no_filter() {
        assert_eq!(over_fetch_factor(false, 1.0), 1);
    }

    #[test]
    fn over_fetch_light_filter() {
        assert_eq!(over_fetch_factor(true, 0.7), 2);
    }

    #[test]
    fn over_fetch_heavy_filter() {
        assert_eq!(over_fetch_factor(true, 0.2), 3);
    }

    #[test]
    fn over_fetch_boundaries() {
        // Exactly 0.5 is not "> 0.5", so it takes the heavy-filter factor.
        assert_eq!(over_fetch_factor(true, 0.5), 3);
        // Without a filter the pass rate is irrelevant.
        assert_eq!(over_fetch_factor(false, 0.0), 1);
    }

    #[test]
    fn deterministic_tie_breaking() {
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok(0, vec![hit(1, 0.5, 0)]));
        merger.add_shard_result(&ok(1, vec![hit(2, 0.5, 1)]));

        let top2 = merger.top_k(2);
        assert_eq!(top2[0].shard_id, 0);
        assert_eq!(top2[1].shard_id, 1);
    }

    #[test]
    fn tie_breaking_ignores_arrival_order() {
        // Shard 1 answers first; equal distances must still order by shard id.
        let mut merger = VectorMerger::new(2);
        merger.add_shard_result(&ok(1, vec![hit(2, 0.5, 1)]));
        merger.add_shard_result(&ok(0, vec![hit(1, 0.5, 0)]));

        let top2 = merger.top_k(2);
        assert_eq!(top2[0].shard_id, 0);
        assert_eq!(top2[1].shard_id, 1);
    }
}