use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::backends::SpatialBackend;
use crate::types::{BBox, CoordType, EntryId, Point};
pub(crate) struct BackendBox<T, C, const D: usize>
where
C: CoordType,
{
pub inner: RwLock<Box<dyn SpatialBackend<T, C, D>>>,
}
impl<T, C, const D: usize> BackendBox<T, C, D>
where
C: CoordType,
T: Send + Sync + 'static,
{
fn new(backend: Box<dyn SpatialBackend<T, C, D>>) -> *mut Self {
Box::into_raw(Box::new(Self {
inner: RwLock::new(backend),
}))
}
pub(crate) fn read(&self) -> RwLockReadGuard<'_, Box<dyn SpatialBackend<T, C, D>>> {
self.inner.read().expect("backend RwLock poisoned")
}
pub(crate) fn write(&self) -> RwLockWriteGuard<'_, Box<dyn SpatialBackend<T, C, D>>> {
self.inner.write().expect("backend RwLock poisoned")
}
}
pub struct IndexRouter<T, C, const D: usize>
where
C: CoordType,
T: Send + Sync + 'static,
{
active: AtomicPtr<BackendBox<T, C, D>>,
shadow: AtomicPtr<BackendBox<T, C, D>>,
migrating: AtomicBool,
}
impl<T, C, const D: usize> IndexRouter<T, C, D>
where
C: CoordType,
T: Clone + Send + Sync + 'static,
{
pub fn new(backend: Box<dyn SpatialBackend<T, C, D>>) -> Self {
Self {
active: AtomicPtr::new(BackendBox::new(backend)),
shadow: AtomicPtr::new(std::ptr::null_mut()),
migrating: AtomicBool::new(false),
}
}
pub fn insert(&self, point: Point<C, D>, payload: T) -> EntryId {
let active = unsafe { &*self.active.load(Ordering::SeqCst) };
let id = active.write().insert(point, payload.clone());
if self.migrating.load(Ordering::SeqCst) {
let shadow_ptr = self.shadow.load(Ordering::SeqCst);
if !shadow_ptr.is_null() {
let shadow = unsafe { &*shadow_ptr };
shadow.write().insert(point, payload);
}
}
id
}
pub fn remove(&self, id: EntryId) -> Option<T> {
let active = unsafe { &*self.active.load(Ordering::SeqCst) };
let result = active.write().remove(id);
if self.migrating.load(Ordering::SeqCst) {
let shadow_ptr = self.shadow.load(Ordering::SeqCst);
if !shadow_ptr.is_null() {
let shadow = unsafe { &*shadow_ptr };
shadow.write().remove(id);
}
}
result
}
pub fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, T)>
where
T: Clone,
{
let active = unsafe { &*self.active.load(Ordering::SeqCst) };
active
.read()
.range_query(bbox)
.into_iter()
.map(|(id, t)| (id, t.clone()))
.collect()
}
pub fn len(&self) -> usize {
let active = unsafe { &*self.active.load(Ordering::SeqCst) };
active.read().len()
}
pub fn is_empty(&self) -> bool {
let active = unsafe { &*self.active.load(Ordering::SeqCst) };
active.read().is_empty()
}
pub fn is_migrating(&self) -> bool {
self.migrating.load(Ordering::SeqCst)
}
pub(crate) fn active_ptr(&self) -> *mut BackendBox<T, C, D> {
self.active.load(Ordering::SeqCst)
}
pub fn begin_migration(&self, shadow: Box<dyn SpatialBackend<T, C, D>>) -> bool {
if self
.migrating
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return false;
}
let ptr = BackendBox::new(shadow);
self.shadow.store(ptr, Ordering::SeqCst);
true
}
pub fn commit_migration(&self) {
assert!(
self.migrating.load(Ordering::SeqCst),
"commit_migration called without an active migration"
);
let shadow_ptr = self.shadow.swap(std::ptr::null_mut(), Ordering::SeqCst);
assert!(
!shadow_ptr.is_null(),
"shadow pointer is null during commit"
);
let old_ptr = self.active.swap(shadow_ptr, Ordering::SeqCst);
self.migrating.store(false, Ordering::SeqCst);
if !old_ptr.is_null() {
drop(unsafe { Box::from_raw(old_ptr) });
}
}
}
impl<T, C, const D: usize> Drop for IndexRouter<T, C, D>
where
C: CoordType,
T: Send + Sync + 'static,
{
fn drop(&mut self) {
let active_ptr = *self.active.get_mut();
if !active_ptr.is_null() {
drop(unsafe { Box::from_raw(active_ptr) });
}
let shadow_ptr = *self.shadow.get_mut();
if !shadow_ptr.is_null() {
drop(unsafe { Box::from_raw(shadow_ptr) });
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backends::KDTree;
use crate::types::Point;
fn make_router() -> IndexRouter<usize, f64, 2> {
IndexRouter::new(Box::new(KDTree::<usize, f64, 2>::new()))
}
#[test]
fn insert_and_len() {
let router = make_router();
router.insert(Point::new([1.0, 2.0]), 42);
router.insert(Point::new([3.0, 4.0]), 99);
assert_eq!(router.len(), 2);
}
#[test]
fn remove_returns_payload() {
let router = make_router();
let id = router.insert(Point::new([1.0, 2.0]), 7usize);
let payload = router.remove(id);
assert_eq!(payload, Some(7));
assert_eq!(router.len(), 0);
}
#[test]
fn range_query_returns_inserted_points() {
let router = make_router();
router.insert(Point::new([1.0, 1.0]), 1usize);
router.insert(Point::new([5.0, 5.0]), 2usize);
let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
let results = router.range_query(&bbox);
assert_eq!(results.len(), 1);
assert_eq!(results[0].1, 1);
}
#[test]
fn migration_swap_preserves_entries() {
let router = make_router();
for i in 0..10usize {
router.insert(Point::new([i as f64, i as f64]), i);
}
assert_eq!(router.len(), 10);
let new_backend = Box::new(KDTree::<usize, f64, 2>::new());
assert!(router.begin_migration(new_backend));
assert!(router.is_migrating());
for i in 10..15usize {
router.insert(Point::new([i as f64, i as f64]), i);
}
router.commit_migration();
assert!(!router.is_migrating());
assert_eq!(router.len(), 5);
}
#[test]
fn begin_migration_rejects_concurrent() {
let router = make_router();
let b1 = Box::new(KDTree::<usize, f64, 2>::new());
let b2 = Box::new(KDTree::<usize, f64, 2>::new());
assert!(router.begin_migration(b1));
assert!(!router.begin_migration(b2));
router.commit_migration();
}
#[test]
fn concurrent_inserts_and_queries() {
use std::sync::Arc;
use std::thread;
let router = Arc::new(make_router());
let mut handles = Vec::new();
for t in 0..4u64 {
let r = Arc::clone(&router);
handles.push(thread::spawn(move || {
for i in 0..100usize {
r.insert(Point::new([t as f64 * 100.0 + i as f64, 0.0]), i);
}
}));
}
for _ in 0..2u64 {
let r = Arc::clone(&router);
handles.push(thread::spawn(move || {
let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([1000.0, 1000.0]));
for _ in 0..50 {
let _ = r.range_query(&bbox);
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
assert_eq!(router.len(), 400);
}
}
#[cfg(all(loom, test))]
mod loom_tests {
use loom::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use loom::sync::Arc;
use loom::thread;
struct LoomRouter {
active: AtomicPtr<LoomBackendBox>,
migrating: AtomicBool,
}
struct LoomBackendBox {
data: std::sync::Mutex<Vec<([f64; 2], usize)>>,
}
impl LoomBackendBox {
fn new() -> *mut Self {
Box::into_raw(Box::new(Self {
data: std::sync::Mutex::new(Vec::new()),
}))
}
fn insert(&self, x: f64, y: f64, val: usize) {
self.data.lock().unwrap().push(([x, y], val));
}
fn query(&self) -> usize {
self.data.lock().unwrap().len()
}
}
impl LoomRouter {
fn new() -> Self {
Self {
active: AtomicPtr::new(LoomBackendBox::new()),
migrating: AtomicBool::new(false),
}
}
fn insert(&self, x: f64, y: f64, val: usize) {
let ptr = self.active.load(Ordering::SeqCst);
unsafe { &*ptr }.insert(x, y, val);
}
fn query(&self) -> usize {
let ptr = self.active.load(Ordering::SeqCst);
unsafe { &*ptr }.query()
}
}
impl Drop for LoomRouter {
fn drop(&mut self) {
let ptr = self.active.load(Ordering::Relaxed);
if !ptr.is_null() {
drop(unsafe { Box::from_raw(ptr) });
}
}
}
#[test]
fn prop_concurrent_read_write_safety() {
loom::model(|| {
let router = Arc::new(LoomRouter::new());
let r1 = Arc::clone(&router);
let writer = thread::spawn(move || {
r1.insert(1.0, 1.0, 1);
r1.insert(2.0, 2.0, 2);
});
let r2 = Arc::clone(&router);
let reader = thread::spawn(move || {
let _ = r2.query();
});
writer.join().unwrap();
reader.join().unwrap();
});
}
}