nodedb_cluster/distributed_spatial/
merge.rs1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct SpatialHit {
14 pub doc_id: String,
16 pub shard_id: u32,
18 pub distance_meters: f64,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ShardSpatialResult {
26 pub shard_id: u32,
27 pub hits: Vec<SpatialHit>,
28 pub success: bool,
29 pub error: Option<String>,
30}
31
32pub struct SpatialResultMerger {
37 all_hits: Vec<SpatialHit>,
38 responded: usize,
39 expected: usize,
40}
41
42impl SpatialResultMerger {
43 pub fn new(expected_shards: usize) -> Self {
44 Self {
45 all_hits: Vec::new(),
46 responded: 0,
47 expected: expected_shards,
48 }
49 }
50
51 pub fn add_shard_result(&mut self, result: &ShardSpatialResult) {
53 if result.success {
54 self.all_hits.extend_from_slice(&result.hits);
55 }
56 self.responded += 1;
57 }
58
59 pub fn all_responded(&self) -> bool {
61 self.responded >= self.expected
62 }
63
64 pub fn merge(&mut self, limit: usize, sort_by_distance: bool) -> Vec<SpatialHit> {
69 let mut seen = std::collections::HashSet::new();
72 self.all_hits.retain(|h| seen.insert(h.doc_id.clone()));
73
74 if sort_by_distance {
75 self.all_hits.sort_by(|a, b| {
76 a.distance_meters
77 .partial_cmp(&b.distance_meters)
78 .unwrap_or(std::cmp::Ordering::Equal)
79 });
80 }
81
82 self.all_hits.truncate(limit);
83 self.all_hits.clone()
84 }
85
86 pub fn total_hits(&self) -> usize {
88 self.all_hits.len()
89 }
90
91 pub fn response_count(&self) -> usize {
92 self.responded
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit with the given id, shard and distance.
    fn hit(doc_id: &str, shard_id: u32, distance_meters: f64) -> SpatialHit {
        SpatialHit {
            doc_id: doc_id.into(),
            shard_id,
            distance_meters,
        }
    }

    /// Builds a successful shard response carrying `hits`.
    fn shard_ok(shard_id: u32, hits: Vec<SpatialHit>) -> ShardSpatialResult {
        ShardSpatialResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    /// Builds a failed shard response with no hits and the given error text.
    fn shard_failed(shard_id: u32, message: &str) -> ShardSpatialResult {
        ShardSpatialResult {
            shard_id,
            hits: vec![],
            success: false,
            error: Some(message.into()),
        }
    }

    #[test]
    fn merge_two_shards_boolean() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&shard_ok(0, vec![hit("a", 0, 0.0), hit("b", 0, 0.0)]));
        merger.add_shard_result(&shard_ok(1, vec![hit("c", 1, 0.0)]));

        assert!(merger.all_responded());
        let results = merger.merge(10, false);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn merge_with_distance_sort() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&shard_ok(0, vec![hit("far", 0, 500.0)]));
        merger.add_shard_result(&shard_ok(1, vec![hit("near", 1, 100.0)]));

        let results = merger.merge(10, true);
        assert_eq!(results[0].doc_id, "near");
        assert_eq!(results[1].doc_id, "far");
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&shard_ok(0, vec![hit("a", 0, 0.0)]));
        merger.add_shard_result(&shard_failed(1, "timeout"));

        let results = merger.merge(10, false);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn merge_respects_limit() {
        let mut merger = SpatialResultMerger::new(1);
        let hits = (0..100)
            .map(|i| hit(&format!("d{i}"), 0, f64::from(i)))
            .collect();
        merger.add_shard_result(&shard_ok(0, hits));

        let results = merger.merge(5, true);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].distance_meters, 0.0);
    }

    /// The same document reported by two shards must appear once in the result.
    #[test]
    fn merge_deduplicates_by_doc_id() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&shard_ok(0, vec![hit("a", 0, 0.0), hit("b", 0, 0.0)]));
        merger.add_shard_result(&shard_ok(1, vec![hit("a", 1, 0.0)]));
        assert_eq!(merger.total_hits(), 3);

        let results = merger.merge(10, false);
        let ids: Vec<&str> = results.iter().map(|h| h.doc_id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
    }

    /// Failed shards count as responses but contribute no hits.
    #[test]
    fn tracks_responses_and_totals() {
        let mut merger = SpatialResultMerger::new(2);
        assert!(!merger.all_responded());
        assert_eq!(merger.response_count(), 0);

        merger.add_shard_result(&shard_failed(0, "timeout"));
        assert!(!merger.all_responded());
        assert_eq!(merger.response_count(), 1);
        assert_eq!(merger.total_hits(), 0);

        merger.add_shard_result(&shard_ok(1, vec![hit("a", 1, 1.0)]));
        assert!(merger.all_responded());
        assert_eq!(merger.response_count(), 2);
        assert_eq!(merger.total_hits(), 1);
    }
}