Skip to main content

nodedb_cluster/distributed_vector/
coordinator.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Vector scatter-gather coordinator for cross-shard k-NN search.
4//!
5//! Same pattern as graph BSP and timeseries scatter-gather:
6//! coordinator → VShardEnvelope per shard → collect responses → merge.
7
8use super::merge::{ShardSearchResult, VectorHit, VectorMerger};
9use crate::wire::{VShardEnvelope, VShardMessageType};
10
/// Wire message for vector scatter request payload (zerompk).
///
/// Built once per query by `VectorScatterGather::build_scatter_envelopes`,
/// serialized with `zerompk::to_msgpack_vec`, and sent unchanged to every
/// target shard.
#[derive(Debug, Clone, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub struct VectorScatterPayload {
    /// Name of the collection to search.
    pub collection: String,
    /// Query vector for the k-NN search.
    pub query_vector: Vec<f32>,
    /// `top_k` value forwarded to each shard.
    pub top_k: u32,
    /// `ef_search` value forwarded to each shard.
    // NOTE(review): presumably the ANN search breadth parameter — confirm
    // against the shard-side handler.
    pub ef_search: u32,
    /// `true` when the coordinator appended a length-prefixed filter bitmap
    /// after the serialized message bytes.
    pub has_filter: bool,
}
20
/// Scatter-gather coordinator for distributed k-NN vector search.
///
/// Holds the fan-out targets for one query and the merger that accumulates
/// the shard responses.
pub struct VectorScatterGather {
    /// Source node ID (this coordinator's node).
    pub source_node: u64,
    /// Target shard IDs to fan out to; one envelope is built per entry.
    pub shard_ids: Vec<u32>,
    /// Merger collecting shard responses; constructed with the number of
    /// target shards.
    merger: VectorMerger,
}
30
31impl VectorScatterGather {
32    pub fn new(source_node: u64, shard_ids: Vec<u32>) -> Self {
33        let count = shard_ids.len();
34        Self {
35            source_node,
36            shard_ids,
37            merger: VectorMerger::new(count),
38        }
39    }
40
41    /// Build scatter envelopes for a k-NN search query.
42    ///
43    /// Returns one `VShardEnvelope` per shard. Each contains the query
44    /// vector + parameters as JSON payload.
45    pub fn build_scatter_envelopes(
46        &self,
47        collection: &str,
48        query_vector: &[f32],
49        top_k: usize,
50        ef_search: usize,
51        filter_bitmap: Option<&[u8]>,
52    ) -> Vec<(u32, VShardEnvelope)> {
53        let msg = VectorScatterPayload {
54            collection: collection.to_string(),
55            query_vector: query_vector.to_vec(),
56            top_k: top_k as u32,
57            ef_search: ef_search as u32,
58            has_filter: filter_bitmap.is_some(),
59        };
60        let mut payload_bytes =
61            zerompk::to_msgpack_vec(&msg).expect("VectorScatterPayload is always serializable");
62
63        // Append filter bitmap as raw bytes after JSON (length-prefixed).
64        if let Some(bitmap) = filter_bitmap {
65            payload_bytes.extend_from_slice(&(bitmap.len() as u32).to_le_bytes());
66            payload_bytes.extend_from_slice(bitmap);
67        }
68
69        self.shard_ids
70            .iter()
71            .map(|&shard_id| {
72                let env = VShardEnvelope::new(
73                    VShardMessageType::VectorScatterRequest,
74                    self.source_node,
75                    0, // target_node resolved by routing table
76                    shard_id,
77                    payload_bytes.clone(),
78                );
79                (shard_id, env)
80            })
81            .collect()
82    }
83
84    /// Record a shard's response.
85    pub fn record_response(&mut self, result: &ShardSearchResult) {
86        self.merger.add_shard_result(result);
87    }
88
89    /// Whether all shards have responded.
90    pub fn all_responded(&self) -> bool {
91        self.merger.all_responded()
92    }
93
94    /// Merge all shard results and return the global top-K.
95    pub fn merge_top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
96        self.merger.top_k(top_k)
97    }
98
99    /// Number of shards that have responded.
100    pub fn response_count(&self) -> usize {
101        self.merger.response_count()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// One envelope per shard, each tagged with the right message type and
    /// shard id and carrying a non-empty payload.
    #[test]
    fn scatter_envelopes_built() {
        let coordinator = VectorScatterGather::new(1, vec![0, 1, 2]);
        let query_vector = vec![0.1f32, 0.2, 0.3];

        let envelopes =
            coordinator.build_scatter_envelopes("embeddings", &query_vector, 10, 100, None);

        assert_eq!(envelopes.len(), 3);
        envelopes.iter().for_each(|(shard_id, env)| {
            assert_eq!(env.msg_type, VShardMessageType::VectorScatterRequest);
            assert_eq!(env.vshard_id, *shard_id);
            assert!(!env.payload.is_empty());
        });
    }

    /// Supplying a filter bitmap must grow the payload.
    #[test]
    fn scatter_with_filter() {
        let coordinator = VectorScatterGather::new(1, vec![0, 1]);
        let query_vector = vec![1.0f32; 32];
        let bitmap = vec![0xFF_u8; 128];

        let filtered =
            coordinator.build_scatter_envelopes("col", &query_vector, 5, 50, Some(&bitmap));
        assert_eq!(filtered.len(), 2);

        // Payload should be larger than without filter.
        let unfiltered = coordinator.build_scatter_envelopes("col", &query_vector, 5, 50, None);
        let filtered_len = filtered[0].1.payload.len();
        let unfiltered_len = unfiltered[0].1.payload.len();
        assert!(filtered_len > unfiltered_len);
    }

    /// Responses from both shards are merged into a distance-ordered top-K.
    #[test]
    fn collect_and_merge() {
        // Small fixture builders so each response reads as data, not boilerplate.
        let hit = |vector_id, distance, shard_id| VectorHit {
            vector_id,
            distance,
            shard_id,
            doc_id: None,
        };
        let ok_response = |shard_id, hits| ShardSearchResult {
            shard_id,
            hits,
            success: true,
            error: None,
        };

        let mut coordinator = VectorScatterGather::new(1, vec![0, 1]);
        assert!(!coordinator.all_responded());

        let from_shard_0 = ok_response(0, vec![hit(1, 0.1, 0), hit(2, 0.5, 0)]);
        coordinator.record_response(&from_shard_0);
        assert!(!coordinator.all_responded());

        let from_shard_1 = ok_response(1, vec![hit(10, 0.05, 1), hit(11, 0.3, 1)]);
        coordinator.record_response(&from_shard_1);
        assert!(coordinator.all_responded());

        let best = coordinator.merge_top_k(2);
        assert_eq!(best.len(), 2);
        assert_eq!(best[0].vector_id, 10); // distance 0.05
        assert_eq!(best[1].vector_id, 1); // distance 0.1
    }
}