// nodedb_cluster/distributed_spatial/merge.rs
use std::collections::HashSet;

use serde::{Deserialize, Serialize};

9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct SpatialHit {
12 pub doc_id: String,
14 pub shard_id: u16,
16 pub distance_meters: f64,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ShardSpatialResult {
24 pub shard_id: u16,
25 pub hits: Vec<SpatialHit>,
26 pub success: bool,
27 pub error: Option<String>,
28}
29
30pub struct SpatialResultMerger {
35 all_hits: Vec<SpatialHit>,
36 responded: usize,
37 expected: usize,
38}
39
40impl SpatialResultMerger {
41 pub fn new(expected_shards: usize) -> Self {
42 Self {
43 all_hits: Vec::new(),
44 responded: 0,
45 expected: expected_shards,
46 }
47 }
48
49 pub fn add_shard_result(&mut self, result: &ShardSpatialResult) {
51 if result.success {
52 self.all_hits.extend_from_slice(&result.hits);
53 }
54 self.responded += 1;
55 }
56
57 pub fn all_responded(&self) -> bool {
59 self.responded >= self.expected
60 }
61
62 pub fn merge(&mut self, limit: usize, sort_by_distance: bool) -> Vec<SpatialHit> {
67 let mut seen = std::collections::HashSet::new();
70 self.all_hits.retain(|h| seen.insert(h.doc_id.clone()));
71
72 if sort_by_distance {
73 self.all_hits.sort_by(|a, b| {
74 a.distance_meters
75 .partial_cmp(&b.distance_meters)
76 .unwrap_or(std::cmp::Ordering::Equal)
77 });
78 }
79
80 self.all_hits.truncate(limit);
81 self.all_hits.clone()
82 }
83
84 pub fn total_hits(&self) -> usize {
86 self.all_hits.len()
87 }
88
89 pub fn response_count(&self) -> usize {
90 self.responded
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit with the given id, shard and distance.
    fn hit(doc_id: &str, shard_id: u16, distance_meters: f64) -> SpatialHit {
        SpatialHit {
            doc_id: doc_id.into(),
            shard_id,
            distance_meters,
        }
    }

    /// Builds a successful shard response carrying `hits`.
    fn ok_shard(shard_id: u16, hits: Vec<SpatialHit>) -> ShardSpatialResult {
        ShardSpatialResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    /// Doc ids of `hits`, in order.
    fn ids(hits: &[SpatialHit]) -> Vec<&str> {
        hits.iter().map(|h| h.doc_id.as_str()).collect()
    }

    #[test]
    fn merge_two_shards_boolean() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_shard(0, vec![hit("a", 0, 0.0), hit("b", 0, 0.0)]));
        merger.add_shard_result(&ok_shard(1, vec![hit("c", 1, 0.0)]));

        assert!(merger.all_responded());
        let results = merger.merge(10, false);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn merge_with_distance_sort() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_shard(0, vec![hit("far", 0, 500.0)]));
        merger.add_shard_result(&ok_shard(1, vec![hit("near", 1, 100.0)]));

        let results = merger.merge(10, true);
        assert_eq!(ids(&results), vec!["near", "far"]);
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_shard(0, vec![hit("a", 0, 0.0)]));
        merger.add_shard_result(&ShardSpatialResult {
            shard_id: 1,
            hits: vec![],
            success: false,
            error: Some("timeout".into()),
        });

        let results = merger.merge(10, false);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn failed_shard_hits_are_ignored_but_counted() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ShardSpatialResult {
            shard_id: 0,
            hits: vec![hit("ghost", 0, 1.0)],
            success: false,
            error: Some("timeout".into()),
        });
        assert!(!merger.all_responded());
        assert_eq!(merger.response_count(), 1);
        assert_eq!(merger.total_hits(), 0);

        merger.add_shard_result(&ok_shard(1, vec![hit("a", 1, 2.0)]));
        assert!(merger.all_responded());
        let results = merger.merge(10, false);
        assert_eq!(ids(&results), vec!["a"]);
    }

    #[test]
    fn merge_dedups_across_shards() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_shard(0, vec![hit("a", 0, 1.0), hit("b", 0, 2.0)]));
        merger.add_shard_result(&ok_shard(1, vec![hit("a", 1, 1.0), hit("c", 1, 3.0)]));
        assert_eq!(merger.total_hits(), 4);

        let results = merger.merge(10, false);
        assert_eq!(ids(&results), vec!["a", "b", "c"]);
        // Without sorting, the first copy received is the one kept.
        assert_eq!(results[0].shard_id, 0);
        assert_eq!(merger.total_hits(), 3);
    }

    #[test]
    fn merge_respects_limit() {
        let mut merger = SpatialResultMerger::new(1);
        merger.add_shard_result(&ok_shard(
            0,
            (0..100)
                .map(|i| hit(&format!("d{i}"), 0, i as f64))
                .collect(),
        ));
        let results = merger.merge(5, true);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].distance_meters, 0.0);
    }

    #[test]
    fn merge_with_zero_limit_is_empty() {
        let mut merger = SpatialResultMerger::new(1);
        merger.add_shard_result(&ok_shard(0, vec![hit("a", 0, 1.0)]));
        assert!(merger.merge(0, true).is_empty());
    }
}