nodedb_cluster/distributed_spatial/
coordinator.rs1use super::merge::{ShardSpatialResult, SpatialResultMerger};
7use crate::wire::{VShardEnvelope, VShardMessageType};
8
9#[derive(Debug, Clone, zerompk::ToMessagePack, zerompk::FromMessagePack)]
11pub struct SpatialScatterPayload {
12 pub collection: String,
13 pub field: String,
14 pub predicate: String,
15 pub query_geometry: Vec<u8>,
17 pub distance_meters: f64,
18 pub limit: u32,
19}
20
21pub struct SpatialScatterGather {
23 pub source_node: u64,
24 pub shard_ids: Vec<u16>,
25 merger: SpatialResultMerger,
26}
27
28impl SpatialScatterGather {
29 pub fn new(source_node: u64, shard_ids: Vec<u16>) -> Self {
30 let count = shard_ids.len();
31 Self {
32 source_node,
33 shard_ids,
34 merger: SpatialResultMerger::new(count),
35 }
36 }
37
38 pub fn build_scatter_envelopes(
43 &self,
44 collection: &str,
45 field: &str,
46 predicate: &str,
47 query_geometry_json: &[u8],
48 distance_meters: f64,
49 limit: usize,
50 ) -> Vec<(u16, VShardEnvelope)> {
51 let msg = SpatialScatterPayload {
52 collection: collection.to_string(),
53 field: field.to_string(),
54 predicate: predicate.to_string(),
55 query_geometry: query_geometry_json.to_vec(),
56 distance_meters,
57 limit: limit as u32,
58 };
59 let payload_bytes =
60 zerompk::to_msgpack_vec(&msg).expect("SpatialScatterPayload is always serializable");
61
62 self.shard_ids
63 .iter()
64 .map(|&shard_id| {
65 let env = VShardEnvelope::new(
66 VShardMessageType::SpatialScatterRequest,
67 self.source_node,
68 0, shard_id,
70 payload_bytes.clone(),
71 );
72 (shard_id, env)
73 })
74 .collect()
75 }
76
77 pub fn record_response(&mut self, result: &ShardSpatialResult) {
79 self.merger.add_shard_result(result);
80 }
81
82 pub fn all_responded(&self) -> bool {
84 self.merger.all_responded()
85 }
86
87 pub fn merge_results(
89 &mut self,
90 limit: usize,
91 sort_by_distance: bool,
92 ) -> Vec<super::merge::SpatialHit> {
93 self.merger.merge(limit, sort_by_distance)
94 }
95
96 pub fn response_count(&self) -> usize {
97 self.merger.response_count()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::super::merge::{ShardSpatialResult, SpatialHit};
    use super::*;

    /// Every configured shard gets exactly one scatter request addressed to it.
    #[test]
    fn scatter_envelopes_built() {
        let coordinator = SpatialScatterGather::new(1, vec![0, 1, 2]);
        let point = serde_json::json!({"type": "Point", "coordinates": [0.0, 0.0]});
        let query_bytes = serde_json::to_vec(&point).unwrap();

        let envelopes = coordinator.build_scatter_envelopes(
            "buildings",
            "geom",
            "st_dwithin",
            &query_bytes,
            1000.0,
            100,
        );

        assert_eq!(envelopes.len(), 3);
        for pair in envelopes.iter() {
            let (target_shard, envelope) = pair;
            assert_eq!(envelope.msg_type, VShardMessageType::SpatialScatterRequest);
            assert_eq!(envelope.vshard_id, *target_shard);
        }
    }

    /// Responses from all shards are gathered and merged nearest-first.
    #[test]
    fn collect_and_merge() {
        let mut coordinator = SpatialScatterGather::new(1, vec![0, 1]);

        let far = ShardSpatialResult {
            shard_id: 0,
            hits: vec![SpatialHit {
                doc_id: "a".into(),
                shard_id: 0,
                distance_meters: 200.0,
            }],
            success: true,
            error: None,
        };
        let near = ShardSpatialResult {
            shard_id: 1,
            hits: vec![SpatialHit {
                doc_id: "b".into(),
                shard_id: 1,
                distance_meters: 50.0,
            }],
            success: true,
            error: None,
        };

        // Recorded in the same order as before: shard 0, then shard 1.
        let responses = [far, near];
        for response in &responses {
            coordinator.record_response(response);
        }
        assert!(coordinator.all_responded());

        let merged = coordinator.merge_results(10, true);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].doc_id, "b");
    }
}