use nodedb_types::BoundingBox;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Spatial summary reported for one shard of one collection.
///
/// Derives serde `Serialize`/`Deserialize`. Depending on the serde format in use,
/// field names and/or field order may be part of the encoded form, so do not
/// rename or reorder them casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardSpatialExtent {
    /// Shard this extent describes.
    pub shard_id: u16,
    /// Bounding box reported for the shard's data in the collection, if any.
    /// `SpatialRoutingTable::route_query` never prunes a shard whose extent is `None`.
    pub extent: Option<BoundingBox>,
    /// Entry count reported by the shard; summed by `SpatialRoutingTable::total_entries`.
    pub entry_count: usize,
}
/// Per-collection index of shard spatial extents, used to skip shards whose
/// reported bounding box cannot intersect a query.
#[derive(Debug, Clone)]
pub struct SpatialRoutingTable {
    /// collection name -> (shard id -> most recent extent report for that shard).
    extents: HashMap<String, HashMap<u16, ShardSpatialExtent>>,
}
impl SpatialRoutingTable {
    /// Creates an empty routing table with no recorded extents.
    pub fn new() -> Self {
        Self {
            extents: HashMap::new(),
        }
    }

    /// Records the extent and entry count reported by `shard_id` for `collection`,
    /// replacing any earlier report from the same shard.
    ///
    /// An `extent` of `None` is stored as-is; [`route_query`](Self::route_query)
    /// never prunes such a shard.
    pub fn update_extent(
        &mut self,
        collection: &str,
        shard_id: u16,
        extent: Option<BoundingBox>,
        entry_count: usize,
    ) {
        let entry = ShardSpatialExtent {
            shard_id,
            extent,
            entry_count,
        };
        // Look up by `&str` first so repeated updates for an already-known collection
        // do not allocate a fresh owned key; only a brand-new collection needs one.
        if let Some(shards) = self.extents.get_mut(collection) {
            shards.insert(shard_id, entry);
        } else {
            self.extents
                .insert(collection.to_owned(), HashMap::from([(shard_id, entry)]));
        }
    }

    /// Returns the subset of `all_shard_ids` that may hold data intersecting `query_bbox`.
    ///
    /// A shard is pruned only when it has a recorded extent that does not intersect
    /// the query. Shards with no report, or whose reported extent is `None`, are kept.
    /// If `collection` has no reports at all, every shard is returned. The output keeps
    /// the order (and any duplicates) of `all_shard_ids`.
    pub fn route_query(
        &self,
        collection: &str,
        query_bbox: &BoundingBox,
        all_shard_ids: &[u16],
    ) -> Vec<u16> {
        let Some(shard_extents) = self.extents.get(collection) else {
            // Nothing known about this collection: fan out to every shard.
            return all_shard_ids.to_vec();
        };
        all_shard_ids
            .iter()
            .copied()
            .filter(|shard_id| {
                match shard_extents
                    .get(shard_id)
                    .and_then(|ext| ext.extent.as_ref())
                {
                    Some(bbox) => bbox.intersects(query_bbox),
                    // Missing information must never cause a shard to be skipped.
                    None => true,
                }
            })
            .collect()
    }

    /// Sums the most recently reported `entry_count` of every shard in `collection`.
    /// Returns 0 for a collection with no reports.
    pub fn total_entries(&self, collection: &str) -> usize {
        self.extents
            .get(collection)
            .map_or(0, |shards| shards.values().map(|ext| ext.entry_count).sum())
    }
}
impl Default for SpatialRoutingTable {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shards whose recorded extent is disjoint from the query are pruned.
    #[test]
    fn route_filters_non_overlapping() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent(
            "buildings",
            0,
            Some(BoundingBox::new(-74.1, 40.6, -73.9, 40.8)),
            1000,
        );
        table.update_extent(
            "buildings",
            1,
            Some(BoundingBox::new(-0.2, 51.4, 0.1, 51.6)),
            500,
        );
        table.update_extent(
            "buildings",
            2,
            Some(BoundingBox::new(139.6, 35.6, 139.8, 35.8)),
            300,
        );
        let query = BoundingBox::new(-74.05, 40.7, -73.95, 40.8);
        let targets = table.route_query("buildings", &query, &[0, 1, 2]);
        assert_eq!(targets, vec![0]);
    }

    /// Shards with no report are always kept.
    #[test]
    fn route_includes_unknown_shards() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent(
            "buildings",
            0,
            Some(BoundingBox::new(0.0, 0.0, 10.0, 10.0)),
            100,
        );
        let query = BoundingBox::new(5.0, 5.0, 15.0, 15.0);
        let targets = table.route_query("buildings", &query, &[0, 1]);
        assert_eq!(targets, vec![0, 1]);
    }

    /// A shard that reported `None` as its extent is never pruned.
    #[test]
    fn route_keeps_shards_without_recorded_extent() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("buildings", 0, None, 0);
        table.update_extent(
            "buildings",
            1,
            Some(BoundingBox::new(100.0, 10.0, 101.0, 11.0)),
            10,
        );
        let query = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
        let targets = table.route_query("buildings", &query, &[0, 1]);
        assert_eq!(targets, vec![0]);
    }

    /// With no data for the collection, every shard is targeted.
    #[test]
    fn route_no_extent_data_fans_out_all() {
        let table = SpatialRoutingTable::new();
        let targets =
            table.route_query("unknown", &BoundingBox::new(0.0, 0.0, 1.0, 1.0), &[0, 1, 2]);
        assert_eq!(targets, vec![0, 1, 2]);
    }

    /// Reports for one collection do not affect routing for another.
    #[test]
    fn route_collections_are_independent() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("a", 0, Some(BoundingBox::new(100.0, 10.0, 101.0, 11.0)), 5);
        let query = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
        assert!(table.route_query("a", &query, &[0]).is_empty());
        assert_eq!(table.route_query("b", &query, &[0]), vec![0]);
    }

    /// Output follows the caller's shard order, not insertion or id order.
    #[test]
    fn route_preserves_caller_shard_order() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("col", 3, Some(BoundingBox::new(0.0, 0.0, 10.0, 10.0)), 1);
        table.update_extent("col", 1, Some(BoundingBox::new(0.0, 0.0, 10.0, 10.0)), 1);
        let query = BoundingBox::new(2.0, 2.0, 3.0, 3.0);
        assert_eq!(table.route_query("col", &query, &[3, 1]), vec![3, 1]);
    }

    #[test]
    fn total_entries() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("col", 0, Some(BoundingBox::new(0.0, 0.0, 1.0, 1.0)), 100);
        table.update_extent("col", 1, Some(BoundingBox::new(2.0, 2.0, 3.0, 3.0)), 200);
        assert_eq!(table.total_entries("col"), 300);
    }

    #[test]
    fn total_entries_unknown_collection_is_zero() {
        let table = SpatialRoutingTable::new();
        assert_eq!(table.total_entries("missing"), 0);
    }

    /// A second report from the same shard replaces the first (count and extent).
    #[test]
    fn update_extent_replaces_previous_report() {
        let mut table = SpatialRoutingTable::new();
        table.update_extent("col", 0, Some(BoundingBox::new(0.0, 0.0, 1.0, 1.0)), 100);
        table.update_extent("col", 0, Some(BoundingBox::new(50.0, 50.0, 51.0, 51.0)), 40);
        assert_eq!(table.total_entries("col"), 40);
        let query = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
        assert!(table.route_query("col", &query, &[0]).is_empty());
    }
}