Skip to main content

nodedb_cluster/distributed_spatial/
merge.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Cross-shard spatial result merging.
4//!
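//! Unlike vector search (which needs distance re-ranking), most spatial
//! predicates are boolean — a document either matches or it doesn't — so the
//! merge is a concatenation of shard results with deduplication. Distance
//! predicates (ST_DWithin) are additionally sorted nearest-first before the
//! result limit is applied.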
8
9use serde::{Deserialize, Serialize};
10
11/// A single spatial match from a shard.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct SpatialHit {
14    /// Document ID.
15    pub doc_id: String,
16    /// Which shard produced this hit.
17    pub shard_id: u32,
18    /// Distance to query geometry in meters (for ST_DWithin ordering).
19    /// 0.0 for non-distance predicates (ST_Contains, ST_Intersects).
20    pub distance_meters: f64,
21}
22
23/// Results from a single shard's local spatial query.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ShardSpatialResult {
26    pub shard_id: u32,
27    pub hits: Vec<SpatialHit>,
28    pub success: bool,
29    pub error: Option<String>,
30}
31
32/// Merges spatial results from multiple shards.
33///
34/// For boolean predicates (ST_Contains, ST_Intersects): simple concatenation.
35/// For distance predicates (ST_DWithin): merge and sort by distance, take limit.
36pub struct SpatialResultMerger {
37    all_hits: Vec<SpatialHit>,
38    responded: usize,
39    expected: usize,
40}
41
42impl SpatialResultMerger {
43    pub fn new(expected_shards: usize) -> Self {
44        Self {
45            all_hits: Vec::new(),
46            responded: 0,
47            expected: expected_shards,
48        }
49    }
50
51    /// Add a shard's results.
52    pub fn add_shard_result(&mut self, result: &ShardSpatialResult) {
53        if result.success {
54            self.all_hits.extend_from_slice(&result.hits);
55        }
56        self.responded += 1;
57    }
58
59    /// Whether all expected shards have responded.
60    pub fn all_responded(&self) -> bool {
61        self.responded >= self.expected
62    }
63
64    /// Merge all results: deduplicate by doc_id, optionally sort by distance.
65    ///
66    /// For ST_DWithin, results are sorted by distance (nearest first).
67    /// For boolean predicates, order is arbitrary. Truncates to `limit`.
68    pub fn merge(&mut self, limit: usize, sort_by_distance: bool) -> Vec<SpatialHit> {
69        // Deduplicate by doc_id (a document can only appear on one shard,
70        // but defensive in case of ghost stubs or migration overlap).
71        let mut seen = std::collections::HashSet::new();
72        self.all_hits.retain(|h| seen.insert(h.doc_id.clone()));
73
74        if sort_by_distance {
75            self.all_hits.sort_by(|a, b| {
76                a.distance_meters
77                    .partial_cmp(&b.distance_meters)
78                    .unwrap_or(std::cmp::Ordering::Equal)
79            });
80        }
81
82        self.all_hits.truncate(limit);
83        self.all_hits.clone()
84    }
85
86    /// Total hits collected (before merge).
87    pub fn total_hits(&self) -> usize {
88        self.all_hits.len()
89    }
90
91    pub fn response_count(&self) -> usize {
92        self.responded
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `SpatialHit`.
    fn hit(doc_id: &str, shard_id: u32, distance_meters: f64) -> SpatialHit {
        SpatialHit {
            doc_id: doc_id.to_owned(),
            shard_id,
            distance_meters,
        }
    }

    /// A successful response from `shard_id` carrying `hits`.
    fn ok_response(shard_id: u32, hits: Vec<SpatialHit>) -> ShardSpatialResult {
        ShardSpatialResult {
            shard_id,
            hits,
            success: true,
            error: None,
        }
    }

    #[test]
    fn merge_two_shards_boolean() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_response(0, vec![hit("a", 0, 0.0), hit("b", 0, 0.0)]));
        merger.add_shard_result(&ok_response(1, vec![hit("c", 1, 0.0)]));

        assert!(merger.all_responded());
        assert_eq!(merger.merge(10, false).len(), 3);
    }

    #[test]
    fn merge_with_distance_sort() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_response(0, vec![hit("far", 0, 500.0)]));
        merger.add_shard_result(&ok_response(1, vec![hit("near", 1, 100.0)]));

        let merged = merger.merge(10, true);
        assert_eq!(merged[0].doc_id, "near");
        assert_eq!(merged[1].doc_id, "far");
    }

    #[test]
    fn merge_with_failed_shard() {
        let mut merger = SpatialResultMerger::new(2);
        merger.add_shard_result(&ok_response(0, vec![hit("a", 0, 0.0)]));
        merger.add_shard_result(&ShardSpatialResult {
            shard_id: 1,
            hits: Vec::new(),
            success: false,
            error: Some("timeout".to_owned()),
        });

        assert_eq!(merger.merge(10, false).len(), 1);
    }

    #[test]
    fn merge_respects_limit() {
        let many: Vec<SpatialHit> = (0..100)
            .map(|i: i32| hit(&format!("d{i}"), 0, f64::from(i)))
            .collect();
        let mut merger = SpatialResultMerger::new(1);
        merger.add_shard_result(&ok_response(0, many));

        let top = merger.merge(5, true);
        assert_eq!(top.len(), 5);
        assert_eq!(top[0].distance_meters, 0.0);
    }
}