Skip to main content

nodedb_cluster/distributed_vector/
coordinator.rs

1//! Vector scatter-gather coordinator for cross-shard k-NN search.
2//!
3//! Same pattern as graph BSP and timeseries scatter-gather:
4//! coordinator → VShardEnvelope per shard → collect responses → merge.
5
6use super::merge::{ShardSearchResult, VectorHit, VectorMerger};
7use crate::wire::{VShardEnvelope, VShardMessageType};
8
9/// Wire message for vector scatter request payload (zerompk).
10#[derive(Debug, Clone, zerompk::ToMessagePack, zerompk::FromMessagePack)]
11pub struct VectorScatterPayload {
12    pub collection: String,
13    pub query_vector: Vec<f32>,
14    pub top_k: u32,
15    pub ef_search: u32,
16    pub has_filter: bool,
17}
18
19/// Scatter-gather coordinator for distributed k-NN vector search.
20pub struct VectorScatterGather {
21    /// Source node ID (this coordinator's node).
22    pub source_node: u64,
23    /// Target shard IDs to fan out to.
24    pub shard_ids: Vec<u16>,
25    /// Merger collecting shard responses.
26    merger: VectorMerger,
27}
28
29impl VectorScatterGather {
30    pub fn new(source_node: u64, shard_ids: Vec<u16>) -> Self {
31        let count = shard_ids.len();
32        Self {
33            source_node,
34            shard_ids,
35            merger: VectorMerger::new(count),
36        }
37    }
38
39    /// Build scatter envelopes for a k-NN search query.
40    ///
41    /// Returns one `VShardEnvelope` per shard. Each contains the query
42    /// vector + parameters as JSON payload.
43    pub fn build_scatter_envelopes(
44        &self,
45        collection: &str,
46        query_vector: &[f32],
47        top_k: usize,
48        ef_search: usize,
49        filter_bitmap: Option<&[u8]>,
50    ) -> Vec<(u16, VShardEnvelope)> {
51        let msg = VectorScatterPayload {
52            collection: collection.to_string(),
53            query_vector: query_vector.to_vec(),
54            top_k: top_k as u32,
55            ef_search: ef_search as u32,
56            has_filter: filter_bitmap.is_some(),
57        };
58        let mut payload_bytes =
59            zerompk::to_msgpack_vec(&msg).expect("VectorScatterPayload is always serializable");
60
61        // Append filter bitmap as raw bytes after JSON (length-prefixed).
62        if let Some(bitmap) = filter_bitmap {
63            payload_bytes.extend_from_slice(&(bitmap.len() as u32).to_le_bytes());
64            payload_bytes.extend_from_slice(bitmap);
65        }
66
67        self.shard_ids
68            .iter()
69            .map(|&shard_id| {
70                let env = VShardEnvelope::new(
71                    VShardMessageType::VectorScatterRequest,
72                    self.source_node,
73                    0, // target_node resolved by routing table
74                    shard_id,
75                    payload_bytes.clone(),
76                );
77                (shard_id, env)
78            })
79            .collect()
80    }
81
82    /// Record a shard's response.
83    pub fn record_response(&mut self, result: &ShardSearchResult) {
84        self.merger.add_shard_result(result);
85    }
86
87    /// Whether all shards have responded.
88    pub fn all_responded(&self) -> bool {
89        self.merger.all_responded()
90    }
91
92    /// Merge all shard results and return the global top-K.
93    pub fn merge_top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
94        self.merger.top_k(top_k)
95    }
96
97    /// Number of shards that have responded.
98    pub fn response_count(&self) -> usize {
99        self.merger.response_count()
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    // Expands to a `VectorHit` with no document ID attached.
    macro_rules! hit {
        ($vector_id:expr, $distance:expr, $shard_id:expr) => {
            VectorHit {
                vector_id: $vector_id,
                distance: $distance,
                shard_id: $shard_id,
                doc_id: None,
            }
        };
    }

    // Expands to a successful `ShardSearchResult` carrying the given hits.
    macro_rules! shard_ok {
        ($shard_id:expr, $($hit:expr),+ $(,)?) => {
            ShardSearchResult {
                shard_id: $shard_id,
                hits: vec![$($hit),+],
                success: true,
                error: None,
            }
        };
    }

    /// One envelope is produced per shard, addressed to that shard.
    #[test]
    fn scatter_envelopes_built() {
        let coordinator = VectorScatterGather::new(1, vec![0, 1, 2]);
        let query = vec![0.1f32, 0.2, 0.3];

        let envelopes = coordinator.build_scatter_envelopes("embeddings", &query, 10, 100, None);

        assert_eq!(envelopes.len(), 3);
        for (expected_shard, envelope) in &envelopes {
            assert_eq!(envelope.msg_type, VShardMessageType::VectorScatterRequest);
            assert_eq!(envelope.vshard_id, *expected_shard);
            assert!(!envelope.payload.is_empty());
        }
    }

    /// Supplying a filter bitmap makes the payload strictly larger.
    #[test]
    fn scatter_with_filter() {
        let coordinator = VectorScatterGather::new(1, vec![0, 1]);
        let query = vec![1.0f32; 32];
        let filter = vec![0xFF_u8; 128];

        let filtered = coordinator.build_scatter_envelopes("col", &query, 5, 50, Some(&filter));
        assert_eq!(filtered.len(), 2);

        // Payload should be larger than without filter.
        let unfiltered = coordinator.build_scatter_envelopes("col", &query, 5, 50, None);
        let filtered_len = filtered[0].1.payload.len();
        let unfiltered_len = unfiltered[0].1.payload.len();
        assert!(filtered_len > unfiltered_len);
    }

    /// Responses from both shards are merged into a distance-ordered top-K.
    #[test]
    fn collect_and_merge() {
        let mut coordinator = VectorScatterGather::new(1, vec![0, 1]);
        assert!(!coordinator.all_responded());

        let first = shard_ok!(0, hit!(1, 0.1, 0), hit!(2, 0.5, 0));
        coordinator.record_response(&first);
        assert!(!coordinator.all_responded());

        let second = shard_ok!(1, hit!(10, 0.05, 1), hit!(11, 0.3, 1));
        coordinator.record_response(&second);
        assert!(coordinator.all_responded());

        let best_two = coordinator.merge_top_k(2);
        assert_eq!(best_two.len(), 2);
        assert_eq!(best_two[0].vector_id, 10); // distance 0.05
        assert_eq!(best_two[1].vector_id, 1); // distance 0.1
    }
}
184}