nodedb_cluster/distributed_vector/
coordinator.rs1use super::merge::{ShardSearchResult, VectorHit, VectorMerger};
9use crate::wire::{VShardEnvelope, VShardMessageType};
10
/// MessagePack body of a `VectorScatterRequest`, identical for every target shard.
///
/// When `has_filter` is `true`, the encoded body is followed on the wire by a
/// filter bitmap framed as a little-endian `u32` byte length and then the raw
/// bitmap bytes (see `VectorScatterGather::build_scatter_envelopes`).
///
/// NOTE(review): depending on how `zerompk` encodes structs (array vs. map),
/// field order and/or names may be part of the wire format — confirm against
/// the shard-side decoder before reordering or renaming fields.
#[derive(Debug, Clone, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub struct VectorScatterPayload {
    /// Collection to search.
    pub collection: String,
    /// Query vector, sent unchanged to every shard.
    pub query_vector: Vec<f32>,
    /// Number of hits requested; the same value is sent to every shard.
    pub top_k: u32,
    /// Search-breadth parameter forwarded to each shard (presumably the HNSW
    /// `ef_search` knob — TODO confirm against the shard-side index).
    pub ef_search: u32,
    /// Whether a length-prefixed filter bitmap follows the MessagePack body.
    pub has_filter: bool,
}
20
/// Coordinator-side state for one scatter/gather vector search.
///
/// Builds one request envelope per target shard and feeds each shard's
/// result into a [`VectorMerger`], which produces the global top-k.
pub struct VectorScatterGather {
    /// Node id stamped as the source on every outgoing envelope.
    pub source_node: u64,
    /// Shards the query is scattered to; one envelope is built per entry.
    pub shard_ids: Vec<u32>,
    /// Accumulates shard responses; constructed with `shard_ids.len()` as
    /// its expected shard count.
    merger: VectorMerger,
}
30
31impl VectorScatterGather {
32 pub fn new(source_node: u64, shard_ids: Vec<u32>) -> Self {
33 let count = shard_ids.len();
34 Self {
35 source_node,
36 shard_ids,
37 merger: VectorMerger::new(count),
38 }
39 }
40
41 pub fn build_scatter_envelopes(
46 &self,
47 collection: &str,
48 query_vector: &[f32],
49 top_k: usize,
50 ef_search: usize,
51 filter_bitmap: Option<&[u8]>,
52 ) -> Vec<(u32, VShardEnvelope)> {
53 let msg = VectorScatterPayload {
54 collection: collection.to_string(),
55 query_vector: query_vector.to_vec(),
56 top_k: top_k as u32,
57 ef_search: ef_search as u32,
58 has_filter: filter_bitmap.is_some(),
59 };
60 let mut payload_bytes =
61 zerompk::to_msgpack_vec(&msg).expect("VectorScatterPayload is always serializable");
62
63 if let Some(bitmap) = filter_bitmap {
65 payload_bytes.extend_from_slice(&(bitmap.len() as u32).to_le_bytes());
66 payload_bytes.extend_from_slice(bitmap);
67 }
68
69 self.shard_ids
70 .iter()
71 .map(|&shard_id| {
72 let env = VShardEnvelope::new(
73 VShardMessageType::VectorScatterRequest,
74 self.source_node,
75 0, shard_id,
77 payload_bytes.clone(),
78 );
79 (shard_id, env)
80 })
81 .collect()
82 }
83
84 pub fn record_response(&mut self, result: &ShardSearchResult) {
86 self.merger.add_shard_result(result);
87 }
88
89 pub fn all_responded(&self) -> bool {
91 self.merger.all_responded()
92 }
93
94 pub fn merge_top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
96 self.merger.top_k(top_k)
97 }
98
99 pub fn response_count(&self) -> usize {
101 self.merger.response_count()
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scatter_envelopes_built() {
        let coord = VectorScatterGather::new(1, vec![0, 1, 2]);
        let query = vec![0.1f32, 0.2, 0.3];
        let envs = coord.build_scatter_envelopes("embeddings", &query, 10, 100, None);
        assert_eq!(envs.len(), 3);

        // Envelopes come back in the same order as the configured shard ids.
        let ids: Vec<u32> = envs.iter().map(|(shard_id, _)| *shard_id).collect();
        assert_eq!(ids, vec![0, 1, 2]);

        for (shard_id, env) in &envs {
            assert_eq!(env.msg_type, VShardMessageType::VectorScatterRequest);
            assert_eq!(env.vshard_id, *shard_id);
            assert!(!env.payload.is_empty());
            // Every shard receives the identical scatter payload.
            assert_eq!(env.payload, envs[0].1.payload);
        }
    }

    #[test]
    fn scatter_with_filter() {
        let coord = VectorScatterGather::new(1, vec![0, 1]);
        let query = vec![1.0f32; 32];
        let filter = vec![0xFF_u8; 128];
        let envs = coord.build_scatter_envelopes("col", &query, 5, 50, Some(&filter));
        assert_eq!(envs.len(), 2);
        let no_filter = coord.build_scatter_envelopes("col", &query, 5, 50, None);
        assert_eq!(no_filter.len(), 2);

        // The filter is appended after the MessagePack body as a 4-byte
        // little-endian length prefix followed by the raw bitmap bytes. The
        // `has_filter` bool encodes to the same size either way, so the body
        // length itself is unchanged.
        let framed = 4 + filter.len();
        for ((_, with), (_, without)) in envs.iter().zip(&no_filter) {
            let payload = &with.payload;
            assert_eq!(payload.len(), without.payload.len() + framed);
            let tail = &payload[payload.len() - framed..];
            assert_eq!(&tail[..4], &(filter.len() as u32).to_le_bytes());
            assert_eq!(&tail[4..], filter.as_slice());
        }
    }

    #[test]
    fn collect_and_merge() {
        let mut coord = VectorScatterGather::new(1, vec![0, 1]);
        assert!(!coord.all_responded());

        coord.record_response(&ShardSearchResult {
            shard_id: 0,
            hits: vec![
                VectorHit {
                    vector_id: 1,
                    distance: 0.1,
                    shard_id: 0,
                    doc_id: None,
                },
                VectorHit {
                    vector_id: 2,
                    distance: 0.5,
                    shard_id: 0,
                    doc_id: None,
                },
            ],
            success: true,
            error: None,
        });
        assert_eq!(coord.response_count(), 1);
        assert!(!coord.all_responded());

        coord.record_response(&ShardSearchResult {
            shard_id: 1,
            hits: vec![
                VectorHit {
                    vector_id: 10,
                    distance: 0.05,
                    shard_id: 1,
                    doc_id: None,
                },
                VectorHit {
                    vector_id: 11,
                    distance: 0.3,
                    shard_id: 1,
                    doc_id: None,
                },
            ],
            success: true,
            error: None,
        });
        assert_eq!(coord.response_count(), 2);
        assert!(coord.all_responded());

        let top2 = coord.merge_top_k(2);
        assert_eq!(top2.len(), 2);
        // Closest overall: distance 0.05 from shard 1.
        assert_eq!(top2[0].vector_id, 10);
        // Second closest: distance 0.1 from shard 0.
        assert_eq!(top2[1].vector_id, 1);
    }
}