use std::collections::HashMap;
use parking_lot::RwLock;
use rstar::{primitives::GeomWithData, RTree, AABB};
use super::entity::EntityId;
/// Errors reported by the spatial index manager.
#[derive(Debug, Clone, PartialEq)]
pub enum SpatialIndexError {
    /// No index has been created for this `(collection, column)` pair.
    MissingIndex { collection: String, column: String },
}

impl std::fmt::Display for SpatialIndexError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: the destructuring is irrefutable, and adding a
        // variant later turns this into a compile error that must be handled.
        let Self::MissingIndex { collection, column } = self;
        write!(
            f,
            "spatial index for column '{column}' was not found in collection '{collection}'"
        )
    }
}

impl std::error::Error for SpatialIndexError {}
type SpatialEntry = GeomWithData<[f64; 2], EntityId>;
fn make_entry(lat: f64, lon: f64, entity_id: EntityId) -> SpatialEntry {
GeomWithData::new([lon, lat], entity_id)
}
pub use crate::geo::haversine_km;
/// Converts a distance in kilometres to an approximate angular span in degrees.
///
/// Uses a fixed 111.32 km per degree (roughly one degree of arc along the
/// equator). The result is approximate; callers apply their own safety margin.
fn km_to_approx_degrees(km: f64) -> f64 {
    const KM_PER_DEGREE: f64 = 111.32;
    km / KM_PER_DEGREE
}
/// One hit produced by the spatial search methods.
#[derive(Debug, Clone)]
pub struct SpatialSearchResult {
    /// Entity whose indexed point matched the query.
    pub entity_id: EntityId,
    /// Great-circle distance in kilometres from the query point. Bounding-box
    /// searches have no query centre and report `0.0`.
    pub distance_km: f64,
}
/// An R-tree backed index of geographic points for a single column.
///
/// `tree` holds one `[lon, lat]` entry per entity. `points` records each
/// entity's current `(lat, lon)`, so its tree entry can be found again and
/// removed when the entity moves or is deleted. `points` is also the source of
/// truth for the index's length.
pub struct SpatialIndex {
    tree: RTree<SpatialEntry>,
    points: HashMap<EntityId, (f64, f64)>,
    /// Name of the indexed column.
    pub column: String,
}
impl SpatialIndex {
pub fn new(column: impl Into<String>) -> Self {
Self {
tree: RTree::new(),
points: HashMap::new(),
column: column.into(),
}
}
pub fn bulk_load(column: impl Into<String>, data: Vec<(EntityId, f64, f64)>) -> Self {
let mut points = HashMap::with_capacity(data.len());
let entries: Vec<SpatialEntry> = data
.into_iter()
.map(|(id, lat, lon)| {
points.insert(id, (lat, lon));
make_entry(lat, lon, id)
})
.collect();
Self {
tree: RTree::bulk_load(entries),
points,
column: column.into(),
}
}
pub fn insert(&mut self, entity_id: EntityId, lat: f64, lon: f64) {
if let Some((old_lat, old_lon)) = self.points.remove(&entity_id) {
self.tree.remove(&make_entry(old_lat, old_lon, entity_id));
}
self.tree.insert(make_entry(lat, lon, entity_id));
self.points.insert(entity_id, (lat, lon));
}
pub fn remove(&mut self, entity_id: EntityId) -> bool {
if let Some((lat, lon)) = self.points.remove(&entity_id) {
self.tree.remove(&make_entry(lat, lon, entity_id));
true
} else {
false
}
}
pub fn search_radius(
&self,
center_lat: f64,
center_lon: f64,
radius_km: f64,
limit: usize,
) -> Vec<SpatialSearchResult> {
let deg = km_to_approx_degrees(radius_km) * 1.2; let aabb = AABB::from_corners(
[center_lon - deg, center_lat - deg],
[center_lon + deg, center_lat + deg],
);
let mut results: Vec<SpatialSearchResult> = self
.tree
.locate_in_envelope(&aabb)
.filter_map(|entry| {
let [lon, lat] = *entry.geom();
let dist = haversine_km(center_lat, center_lon, lat, lon);
if dist <= radius_km {
Some(SpatialSearchResult {
entity_id: entry.data,
distance_km: dist,
})
} else {
None
}
})
.collect();
results.sort_by(|a, b| {
a.distance_km
.partial_cmp(&b.distance_km)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(limit);
results
}
pub fn search_bbox(
&self,
min_lat: f64,
min_lon: f64,
max_lat: f64,
max_lon: f64,
limit: usize,
) -> Vec<SpatialSearchResult> {
let aabb = AABB::from_corners([min_lon, min_lat], [max_lon, max_lat]);
self.tree
.locate_in_envelope(&aabb)
.take(limit)
.map(|entry| SpatialSearchResult {
entity_id: entry.data,
distance_km: 0.0, })
.collect()
}
pub fn search_nearest(&self, lat: f64, lon: f64, k: usize) -> Vec<SpatialSearchResult> {
self.tree
.nearest_neighbor_iter(&[lon, lat])
.take(k)
.map(|entry| {
let [elon, elat] = *entry.geom();
SpatialSearchResult {
entity_id: entry.data,
distance_km: haversine_km(lat, lon, elat, elon),
}
})
.collect()
}
pub fn len(&self) -> usize {
self.points.len()
}
pub fn is_empty(&self) -> bool {
self.points.is_empty()
}
pub fn memory_bytes(&self) -> usize {
std::mem::size_of::<Self>()
+ self.points.len() * 32 + self.tree.size() * std::mem::size_of::<SpatialEntry>()
}
}
/// Registry of [`SpatialIndex`]es keyed by `(collection, column)`.
///
/// The map sits behind a `parking_lot::RwLock`. Searches and stats take the
/// read lock; creating, dropping and mutating indices take the write lock.
pub struct SpatialIndexManager {
    indices: RwLock<HashMap<(String, String), SpatialIndex>>,
}
impl SpatialIndexManager {
pub fn new() -> Self {
Self {
indices: RwLock::new(HashMap::new()),
}
}
pub fn create_index(&self, collection: &str, column: &str) {
let mut indices = self.indices.write();
let key = (collection.to_string(), column.to_string());
indices
.entry(key)
.or_insert_with(|| SpatialIndex::new(column));
}
pub fn drop_index(&self, collection: &str, column: &str) -> bool {
let mut indices = self.indices.write();
indices
.remove(&(collection.to_string(), column.to_string()))
.is_some()
}
pub fn insert(
&self,
collection: &str,
column: &str,
entity_id: EntityId,
lat: f64,
lon: f64,
) -> Result<(), SpatialIndexError> {
let mut indices = self.indices.write();
if let Some(index) = indices.get_mut(&(collection.to_string(), column.to_string())) {
index.insert(entity_id, lat, lon);
Ok(())
} else {
Err(SpatialIndexError::MissingIndex {
collection: collection.to_string(),
column: column.to_string(),
})
}
}
pub fn remove(
&self,
collection: &str,
column: &str,
entity_id: EntityId,
) -> Result<bool, SpatialIndexError> {
let mut indices = self.indices.write();
if let Some(index) = indices.get_mut(&(collection.to_string(), column.to_string())) {
Ok(index.remove(entity_id))
} else {
Err(SpatialIndexError::MissingIndex {
collection: collection.to_string(),
column: column.to_string(),
})
}
}
pub fn search_radius(
&self,
collection: &str,
column: &str,
center_lat: f64,
center_lon: f64,
radius_km: f64,
limit: usize,
) -> Result<Vec<SpatialSearchResult>, SpatialIndexError> {
let indices = self.indices.read();
if let Some(idx) = indices.get(&(collection.to_string(), column.to_string())) {
Ok(idx.search_radius(center_lat, center_lon, radius_km, limit))
} else {
Err(SpatialIndexError::MissingIndex {
collection: collection.to_string(),
column: column.to_string(),
})
}
}
pub fn search_bbox(
&self,
collection: &str,
column: &str,
min_lat: f64,
min_lon: f64,
max_lat: f64,
max_lon: f64,
limit: usize,
) -> Result<Vec<SpatialSearchResult>, SpatialIndexError> {
let indices = self.indices.read();
if let Some(idx) = indices.get(&(collection.to_string(), column.to_string())) {
Ok(idx.search_bbox(min_lat, min_lon, max_lat, max_lon, limit))
} else {
Err(SpatialIndexError::MissingIndex {
collection: collection.to_string(),
column: column.to_string(),
})
}
}
pub fn search_nearest(
&self,
collection: &str,
column: &str,
lat: f64,
lon: f64,
k: usize,
) -> Result<Vec<SpatialSearchResult>, SpatialIndexError> {
let indices = self.indices.read();
if let Some(idx) = indices.get(&(collection.to_string(), column.to_string())) {
Ok(idx.search_nearest(lat, lon, k))
} else {
Err(SpatialIndexError::MissingIndex {
collection: collection.to_string(),
column: column.to_string(),
})
}
}
pub fn index_stats(
&self,
collection: &str,
column: &str,
) -> Result<SpatialIndexStats, SpatialIndexError> {
let indices = self.indices.read();
if let Some(idx) = indices.get(&(collection.to_string(), column.to_string())) {
Ok(SpatialIndexStats {
column: column.to_string(),
collection: collection.to_string(),
point_count: idx.len(),
memory_bytes: idx.memory_bytes(),
})
} else {
Err(SpatialIndexError::MissingIndex {
collection: collection.to_string(),
column: column.to_string(),
})
}
}
}
impl Default for SpatialIndexManager {
fn default() -> Self {
Self::new()
}
}
/// Summary of one spatial index, as returned by the manager's `index_stats`.
#[derive(Debug, Clone)]
pub struct SpatialIndexStats {
    /// Indexed column name.
    pub column: String,
    /// Collection that owns the column.
    pub collection: String,
    /// Number of indexed entities.
    pub point_count: usize,
    /// Estimated memory footprint of the index in bytes.
    pub memory_bytes: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    // (lat, lon) fixtures shared by the tests below.
    const PARIS: (f64, f64) = (48.8566, 2.3522);
    const LONDON: (f64, f64) = (51.5074, -0.1278);
    const BERLIN: (f64, f64) = (52.5200, 13.4050);
    const TOKYO: (f64, f64) = (35.6762, 139.6503);
    const BRUSSELS: (f64, f64) = (50.8503, 4.3517);

    /// Builds a `"location"` index by inserting `points` in the given order.
    fn index_with(points: &[(EntityId, (f64, f64))]) -> SpatialIndex {
        let mut idx = SpatialIndex::new("location");
        for &(id, (lat, lon)) in points {
            idx.insert(id, lat, lon);
        }
        idx
    }

    /// Raw entity ids of `results`, in result order.
    fn raw_ids(results: &[SpatialSearchResult]) -> Vec<u64> {
        results.iter().map(|r| r.entity_id.raw()).collect()
    }

    #[test]
    fn test_haversine() {
        let d = haversine_km(PARIS.0, PARIS.1, LONDON.0, LONDON.1);
        assert!((d - 344.0).abs() < 5.0, "Paris-London: {d} km");
    }

    #[test]
    fn test_spatial_insert_and_radius() {
        let idx = index_with(&[
            (EntityId::new(1), PARIS),
            (EntityId::new(2), LONDON),
            (EntityId::new(3), BERLIN),
            (EntityId::new(4), TOKYO),
        ]);
        let ids = raw_ids(&idx.search_radius(PARIS.0, PARIS.1, 500.0, 10));
        assert!(ids.contains(&1), "Should find Paris");
        assert!(ids.contains(&2), "Should find London");
        assert!(!ids.contains(&4), "Should NOT find Tokyo");
    }

    #[test]
    fn test_spatial_bbox() {
        let idx = index_with(&[
            (EntityId::new(1), PARIS),
            (EntityId::new(2), LONDON),
            (EntityId::new(3), TOKYO),
        ]);
        let ids = raw_ids(&idx.search_bbox(40.0, -10.0, 55.0, 20.0, 10));
        assert!(ids.contains(&1));
        assert!(ids.contains(&2));
        assert!(!ids.contains(&3));
    }

    #[test]
    fn test_spatial_nearest() {
        let idx = index_with(&[
            (EntityId::new(1), PARIS),
            (EntityId::new(2), LONDON),
            (EntityId::new(3), BERLIN),
        ]);
        let results = idx.search_nearest(BRUSSELS.0, BRUSSELS.1, 2);
        assert_eq!(results.len(), 2);
        assert!(results[0].distance_km < results[1].distance_km);
    }

    #[test]
    fn test_spatial_remove() {
        let mut idx = index_with(&[(EntityId::new(1), PARIS), (EntityId::new(2), LONDON)]);
        assert_eq!(idx.len(), 2);
        idx.remove(EntityId::new(1));
        assert_eq!(idx.len(), 1);
        let results = idx.search_nearest(PARIS.0, PARIS.1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, EntityId::new(2));
    }

    #[test]
    fn test_spatial_bulk_load() {
        let data = vec![
            (EntityId::new(1), PARIS.0, PARIS.1),
            (EntityId::new(2), LONDON.0, LONDON.1),
            (EntityId::new(3), BERLIN.0, BERLIN.1),
        ];
        let idx = SpatialIndex::bulk_load("location", data);
        assert_eq!(idx.len(), 3);
    }

    #[test]
    fn test_spatial_manager() {
        let mgr = SpatialIndexManager::new();
        mgr.create_index("sites", "location");
        mgr.insert("sites", "location", EntityId::new(1), PARIS.0, PARIS.1)
            .expect("spatial insert should succeed");
        mgr.insert("sites", "location", EntityId::new(2), LONDON.0, LONDON.1)
            .expect("spatial insert should succeed");
        let results = mgr
            .search_radius("sites", "location", PARIS.0, PARIS.1, 500.0, 10)
            .unwrap();
        assert!(!results.is_empty());
        let stats = mgr.index_stats("sites", "location").unwrap();
        assert_eq!(stats.point_count, 2);
    }

    #[test]
    fn test_spatial_manager_recovers_from_poisoned_lock() {
        let mgr = SpatialIndexManager::new();
        mgr.create_index("sites", "location");
        // Panic while holding the write lock; parking_lot locks must stay usable.
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = mgr.indices.write();
            panic!("poison spatial index manager");
        }));
        mgr.insert("sites", "location", EntityId::new(1), PARIS.0, PARIS.1)
            .expect("spatial insert should recover after poison");
        let results = mgr
            .search_nearest("sites", "location", PARIS.0, PARIS.1, 1)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, EntityId::new(1));
    }

    #[test]
    fn test_spatial_manager_lookup_missing_index_returns_error() {
        let mgr = SpatialIndexManager::new();
        let err = mgr
            .search_nearest("sites", "location", PARIS.0, PARIS.1, 1)
            .expect_err("spatial lookup should fail when the index does not exist");
        let expected = SpatialIndexError::MissingIndex {
            collection: "sites".to_string(),
            column: "location".to_string(),
        };
        assert_eq!(err, expected);
    }
}