Skip to main content

nodedb_cluster/distributed_spatial/
shard_routing.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Spatial-aware shard routing.
4//!
//! Maintains per-shard bounding box metadata: the spatial extent of all
//! geometries on each shard. On a spatial query, the coordinator fans out
//! only to shards whose bounding box overlaps the query region.
8//!
//! Updated at flush time: each shard reports its spatial extent to the
//! routing table, replacing the extent it reported previously. Computing that
//! extent is cheap, since a bounding box union is just the min of mins and
//! max of maxes.
11
12use nodedb_types::BoundingBox;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16/// Per-shard spatial extent for a collection.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ShardSpatialExtent {
19    pub shard_id: u32,
20    /// Bounding box covering all geometries in this shard for this collection.
21    /// None if the shard has no spatial data for this collection.
22    pub extent: Option<BoundingBox>,
23    /// Number of spatial entries on this shard.
24    pub entry_count: usize,
25}
26
27/// Per-collection spatial routing metadata.
28///
29/// Tracks which shards have spatial data and their extent. Used by the
30/// coordinator to avoid fanning out to shards that can't possibly match.
31pub struct SpatialRoutingTable {
32    /// collection_name → shard_id → extent.
33    extents: HashMap<String, HashMap<u32, ShardSpatialExtent>>,
34}
35
36impl SpatialRoutingTable {
37    pub fn new() -> Self {
38        Self {
39            extents: HashMap::new(),
40        }
41    }
42
43    /// Update the spatial extent for a shard.
44    pub fn update_extent(
45        &mut self,
46        collection: &str,
47        shard_id: u32,
48        extent: Option<BoundingBox>,
49        entry_count: usize,
50    ) {
51        let entry = ShardSpatialExtent {
52            shard_id,
53            extent,
54            entry_count,
55        };
56        self.extents
57            .entry(collection.to_string())
58            .or_default()
59            .insert(shard_id, entry);
60    }
61
62    /// Find which shards might contain results for a spatial query.
63    ///
64    /// Returns shard IDs whose extent overlaps the query bounding box.
65    /// If no extent data is available for a shard, it is conservatively
66    /// included (we don't know what's there, so we must check).
67    pub fn route_query(
68        &self,
69        collection: &str,
70        query_bbox: &BoundingBox,
71        all_shard_ids: &[u32],
72    ) -> Vec<u32> {
73        let Some(shard_extents) = self.extents.get(collection) else {
74            // No extent data at all — fan out to all shards (conservative).
75            return all_shard_ids.to_vec();
76        };
77
78        let mut target_shards = Vec::new();
79        for &shard_id in all_shard_ids {
80            match shard_extents.get(&shard_id) {
81                Some(ext) => {
82                    match &ext.extent {
83                        Some(bbox) if !bbox.intersects(query_bbox) => {
84                            // Skip — shard extent doesn't overlap query.
85                        }
86                        _ => target_shards.push(shard_id),
87                    }
88                }
89                None => {
90                    // No extent data for this shard — include conservatively.
91                    target_shards.push(shard_id);
92                }
93            }
94        }
95        target_shards
96    }
97
98    /// Get the total number of spatial entries across all shards for a collection.
99    pub fn total_entries(&self, collection: &str) -> usize {
100        self.extents
101            .get(collection)
102            .map(|shards| shards.values().map(|e| e.entry_count).sum())
103            .unwrap_or(0)
104    }
105}
106
107impl Default for SpatialRoutingTable {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn route_filters_non_overlapping() {
        let mut table = SpatialRoutingTable::new();
        // Shard 0: buildings in NYC area.
        table.update_extent(
            "buildings",
            0,
            Some(BoundingBox::new(-74.1, 40.6, -73.9, 40.8)),
            1000,
        );
        // Shard 1: buildings in London area.
        table.update_extent(
            "buildings",
            1,
            Some(BoundingBox::new(-0.2, 51.4, 0.1, 51.6)),
            500,
        );
        // Shard 2: buildings in Tokyo.
        table.update_extent(
            "buildings",
            2,
            Some(BoundingBox::new(139.6, 35.6, 139.8, 35.8)),
            300,
        );

        // Query near NYC — should only route to shard 0.
        let query = BoundingBox::new(-74.05, 40.7, -73.95, 40.8);
        let targets = table.route_query("buildings", &query, &[0, 1, 2]);
        assert_eq!(targets, vec![0]);
    }

    #[test]
    fn route_includes_unknown_shards() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent(
            "buildings",
            0,
            Some(BoundingBox::new(0.0, 0.0, 10.0, 10.0)),
            100,
        );
        // Shard 1 has no extent data.

        let query = BoundingBox::new(5.0, 5.0, 15.0, 15.0);
        let targets = table.route_query("buildings", &query, &[0, 1]);
        // Both shards included, in input order: 0 overlaps, 1 is unknown.
        // Assert the exact IDs, not just the count, so a wrong shard is caught.
        assert_eq!(targets, vec![0, 1]);
    }

    #[test]
    fn route_no_extent_data_fans_out_all() {
        let table = SpatialRoutingTable::new();
        let targets =
            table.route_query("unknown", &BoundingBox::new(0.0, 0.0, 1.0, 1.0), &[0, 1, 2]);
        assert_eq!(targets, vec![0, 1, 2]);
    }

    #[test]
    fn route_includes_shard_with_none_extent() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("buildings", 0, None, 0);
        table.update_extent(
            "buildings",
            1,
            Some(BoundingBox::new(50.0, 50.0, 60.0, 60.0)),
            10,
        );

        let query = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
        // A recorded `None` extent is treated conservatively (kept), while a
        // recorded, disjoint extent is pruned.
        assert_eq!(table.route_query("buildings", &query, &[0, 1]), vec![0]);
    }

    #[test]
    fn route_preserves_input_order() {
        let mut table = SpatialRoutingTable::new();
        for shard in 0..3 {
            table.update_extent(
                "col",
                shard,
                Some(BoundingBox::new(0.0, 0.0, 10.0, 10.0)),
                1,
            );
        }

        let query = BoundingBox::new(1.0, 1.0, 2.0, 2.0);
        assert_eq!(table.route_query("col", &query, &[2, 0, 1]), vec![2, 0, 1]);
    }

    #[test]
    fn update_replaces_previous_extent() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("col", 0, Some(BoundingBox::new(0.0, 0.0, 1.0, 1.0)), 100);
        // A later report for the same shard overwrites the earlier one.
        table.update_extent("col", 0, Some(BoundingBox::new(10.0, 10.0, 11.0, 11.0)), 7);

        let query = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
        assert!(table.route_query("col", &query, &[0]).is_empty());
        assert_eq!(table.total_entries("col"), 7);
    }

    #[test]
    fn total_entries() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("col", 0, Some(BoundingBox::new(0.0, 0.0, 1.0, 1.0)), 100);
        table.update_extent("col", 1, Some(BoundingBox::new(2.0, 2.0, 3.0, 3.0)), 200);
        assert_eq!(table.total_entries("col"), 300);
    }

    #[test]
    fn total_entries_unknown_collection_is_zero() {
        let table = SpatialRoutingTable::new();
        assert_eq!(table.total_entries("missing"), 0);
    }
}