nodedb_cluster/distributed_vector/
coordinator.rs1use super::merge::{ShardSearchResult, VectorHit, VectorMerger};
7use crate::wire::{VShardEnvelope, VShardMessageType};
8
9#[derive(Debug, Clone, zerompk::ToMessagePack, zerompk::FromMessagePack)]
11pub struct VectorScatterPayload {
12 pub collection: String,
13 pub query_vector: Vec<f32>,
14 pub top_k: u32,
15 pub ef_search: u32,
16 pub has_filter: bool,
17}
18
19pub struct VectorScatterGather {
21 pub source_node: u64,
23 pub shard_ids: Vec<u16>,
25 merger: VectorMerger,
27}
28
29impl VectorScatterGather {
30 pub fn new(source_node: u64, shard_ids: Vec<u16>) -> Self {
31 let count = shard_ids.len();
32 Self {
33 source_node,
34 shard_ids,
35 merger: VectorMerger::new(count),
36 }
37 }
38
39 pub fn build_scatter_envelopes(
44 &self,
45 collection: &str,
46 query_vector: &[f32],
47 top_k: usize,
48 ef_search: usize,
49 filter_bitmap: Option<&[u8]>,
50 ) -> Vec<(u16, VShardEnvelope)> {
51 let msg = VectorScatterPayload {
52 collection: collection.to_string(),
53 query_vector: query_vector.to_vec(),
54 top_k: top_k as u32,
55 ef_search: ef_search as u32,
56 has_filter: filter_bitmap.is_some(),
57 };
58 let mut payload_bytes =
59 zerompk::to_msgpack_vec(&msg).expect("VectorScatterPayload is always serializable");
60
61 if let Some(bitmap) = filter_bitmap {
63 payload_bytes.extend_from_slice(&(bitmap.len() as u32).to_le_bytes());
64 payload_bytes.extend_from_slice(bitmap);
65 }
66
67 self.shard_ids
68 .iter()
69 .map(|&shard_id| {
70 let env = VShardEnvelope::new(
71 VShardMessageType::VectorScatterRequest,
72 self.source_node,
73 0, shard_id,
75 payload_bytes.clone(),
76 );
77 (shard_id, env)
78 })
79 .collect()
80 }
81
82 pub fn record_response(&mut self, result: &ShardSearchResult) {
84 self.merger.add_shard_result(result);
85 }
86
87 pub fn all_responded(&self) -> bool {
89 self.merger.all_responded()
90 }
91
92 pub fn merge_top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
94 self.merger.top_k(top_k)
95 }
96
97 pub fn response_count(&self) -> usize {
99 self.merger.response_count()
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// One request envelope is produced per shard, addressed to that shard
    /// and carrying a non-empty payload.
    #[test]
    fn scatter_envelopes_built() {
        let coordinator = VectorScatterGather::new(1, vec![0, 1, 2]);
        let query_vector = vec![0.1f32, 0.2, 0.3];

        let envelopes =
            coordinator.build_scatter_envelopes("embeddings", &query_vector, 10, 100, None);

        assert_eq!(envelopes.len(), 3);
        for (shard_id, envelope) in envelopes.iter() {
            assert_eq!(envelope.msg_type, VShardMessageType::VectorScatterRequest);
            assert_eq!(envelope.vshard_id, *shard_id);
            assert!(!envelope.payload.is_empty());
        }
    }

    /// Supplying a filter bitmap makes the payload larger than the same
    /// request without one.
    #[test]
    fn scatter_with_filter() {
        let coordinator = VectorScatterGather::new(1, vec![0, 1]);
        let query_vector = vec![1.0f32; 32];
        let bitmap = vec![0xFF_u8; 128];

        let filtered = coordinator.build_scatter_envelopes(
            "col",
            &query_vector,
            5,
            50,
            Some(bitmap.as_slice()),
        );
        assert_eq!(filtered.len(), 2);

        let unfiltered = coordinator.build_scatter_envelopes("col", &query_vector, 5, 50, None);
        let filtered_len = filtered[0].1.payload.len();
        let unfiltered_len = unfiltered[0].1.payload.len();
        assert!(filtered_len > unfiltered_len);
    }

    /// Responses from two shards are collected and merged into a global
    /// top-2.
    #[test]
    fn collect_and_merge() {
        // Small constructor for hits without a document id.
        let hit = |vector_id, distance, shard_id| VectorHit {
            vector_id,
            distance,
            shard_id,
            doc_id: None,
        };

        let mut coordinator = VectorScatterGather::new(1, vec![0, 1]);
        assert!(!coordinator.all_responded());

        // First shard answers: still waiting on the second one.
        let shard_zero = ShardSearchResult {
            shard_id: 0,
            hits: vec![hit(1, 0.1, 0), hit(2, 0.5, 0)],
            success: true,
            error: None,
        };
        coordinator.record_response(&shard_zero);
        assert!(!coordinator.all_responded());

        // Second shard answers: the round is complete.
        let shard_one = ShardSearchResult {
            shard_id: 1,
            hits: vec![hit(10, 0.05, 1), hit(11, 0.3, 1)],
            success: true,
            error: None,
        };
        coordinator.record_response(&shard_one);
        assert!(coordinator.all_responded());

        // Expected: the two hits with the smallest distances, closest first
        // (vector 10 at 0.05, then vector 1 at 0.1).
        let top2 = coordinator.merge_top_k(2);
        assert_eq!(top2.len(), 2);
        let (closest, runner_up) = (&top2[0], &top2[1]);
        assert_eq!(closest.vector_id, 10);
        assert_eq!(runner_up.vector_id, 1);
    }
}