use std::sync::Arc;
use std::time::Duration;
#[cfg(not(target_arch = "wasm32"))]
use std::time::Instant;
/// Stand-in for `std::time::Instant` on `wasm32` targets.
///
/// NOTE(review): presumably exists because `std::time::Instant` is not usable on
/// the wasm32 target this crate builds for. The shim reads no clock, so every
/// latency measured through it is reported as zero.
#[cfg(target_arch = "wasm32")]
struct Instant;
#[cfg(target_arch = "wasm32")]
impl Instant {
    /// Returns the stateless marker value; no clock is consulted.
    #[inline]
    fn now() -> Self {
        Self
    }

    /// Always reports a zero duration.
    #[inline]
    fn elapsed(&self) -> Duration {
        Duration::from_secs(0)
    }
}
use crate::backends::{GridIndex, KDTree, Quadtree, RTree, SpatialBackend};
use crate::bloom::BloomCache;
use crate::profiler::{Observation, Profiler};
use crate::router::IndexRouter;
use crate::stats::StatsCollector;
use crate::types::{
BBox, BackendKind, BonsaiError, CoordType, DataShape, EntryId, Point, QueryMix, Stats,
};
/// Tuning parameters for a [`BonsaiIndex`].
///
/// Obtain one through [`Default`], or adjust individual fields through the
/// setters on [`BonsaiBuilder`].
#[derive(Clone, Debug)]
pub struct BonsaiConfig {
    /// Backend requested for a freshly built index (default: KD-tree).
    pub initial_backend: BackendKind,
    /// Score threshold for backend migration decisions (default `0.77`).
    /// NOTE(review): exact scale/meaning is defined by the migration logic, not here.
    pub migration_threshold: f64,
    /// Hysteresis window size for migration decisions (default `1000`).
    pub hysteresis_window: usize,
    /// Capacity passed to the profiler's sample reservoir (default `4096`).
    pub reservoir_size: usize,
    /// Memory budget, in bytes, for the bloom cache (default 64 KiB).
    pub bloom_memory_bytes: usize,
    /// Latency budget for a migration (default 50 µs).
    pub max_migration_latency: Duration,
}

impl Default for BonsaiConfig {
    fn default() -> Self {
        let max_migration_latency = Duration::from_micros(50);
        BonsaiConfig {
            initial_backend: BackendKind::KDTree,
            migration_threshold: 0.77,
            hysteresis_window: 1_000,
            reservoir_size: 4_096,
            bloom_memory_bytes: 64 * 1024,
            max_migration_latency,
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct BonsaiBuilder<T, C = f64, const D: usize = 2>
where
C: CoordType,
T: Clone + Send + Sync + 'static,
{
config: BonsaiConfig,
_phantom: std::marker::PhantomData<(T, C)>,
}
impl<T, C, const D: usize> BonsaiBuilder<T, C, D>
where
C: CoordType,
T: Clone + Send + Sync + 'static,
{
pub fn initial_backend(mut self, backend: BackendKind) -> Self {
self.config.initial_backend = backend;
self
}
pub fn migration_threshold(mut self, threshold: f64) -> Self {
self.config.migration_threshold = threshold;
self
}
pub fn hysteresis_window(mut self, window: usize) -> Self {
self.config.hysteresis_window = window;
self
}
pub fn reservoir_size(mut self, size: usize) -> Self {
self.config.reservoir_size = size;
self
}
pub fn bloom_memory_bytes(mut self, bytes: usize) -> Self {
self.config.bloom_memory_bytes = bytes;
self
}
pub fn build(self) -> BonsaiIndex<T, C, D> {
BonsaiIndex::from_config(self.config)
}
}
/// Spatial index over `D`-dimensional points with coordinates of type `C`,
/// storing a payload `T` per entry and delegating storage and queries to a
/// pluggable backend behind an [`IndexRouter`].
pub struct BonsaiIndex<T, C = f64, const D: usize = 2>
where
    T: Clone + Send + Sync + 'static,
    C: CoordType,
{
    /// Routes inserts, removals and queries to the active backend; also reports
    /// whether a migration is in progress.
    pub(crate) router: Arc<IndexRouter<T, C, D>>,
    /// Receives an observation per insert; its data-shape estimate feeds `stats()`.
    profiler: Profiler<C, D>,
    /// Insert/query counters and query latencies.
    stats: Arc<StatsCollector>,
    /// Bloom cache populated with the (degenerate) bounding box of each inserted point.
    bloom: BloomCache<D>,
    /// Configuration the index was created with.
    pub(crate) config: BonsaiConfig,
    /// Migration counter reported by `stats()`.
    pub(crate) migration_count: u64,
    /// Set by `freeze()`, cleared by `unfreeze()`.
    pub(crate) frozen: bool,
    /// Number of live entries, maintained by insert/remove/clear.
    pub(crate) point_count: usize,
}
impl<T, C, const D: usize> std::fmt::Debug for BonsaiIndex<T, C, D>
where
    T: Clone + Send + Sync + 'static,
    C: CoordType,
{
    /// Shows the scalar bookkeeping fields and the config. Router, profiler,
    /// stats and bloom state are elided and rendered as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BonsaiIndex");
        out.field("point_count", &self.point_count);
        out.field("migration_count", &self.migration_count);
        out.field("frozen", &self.frozen);
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
impl<T, C, const D: usize> BonsaiIndex<T, C, D>
where
    C: CoordType,
    T: Clone + Send + Sync + 'static,
{
    /// Returns a [`BonsaiBuilder`] initialised with [`BonsaiConfig::default`].
    pub fn builder() -> BonsaiBuilder<T, C, D> {
        BonsaiBuilder {
            config: BonsaiConfig::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Creates an empty index configured by `config`.
    ///
    /// The active backend is the one named by `config.initial_backend`. The
    /// profiler reservoir and the bloom cache are sized from `reservoir_size`
    /// and `bloom_memory_bytes`.
    pub fn from_config(config: BonsaiConfig) -> Self {
        let backend = backend_for_kind::<T, C, D>(config.initial_backend);
        let router = Arc::new(IndexRouter::new(backend));
        let profiler = Profiler::new(config.reservoir_size);
        let stats = Arc::new(StatsCollector::new());
        // NOTE(review): `7` is presumably the bloom hash-function count; `clear`
        // uses the same value. Confirm against `BloomCache::new`.
        let bloom = BloomCache::new(config.bloom_memory_bytes, 7);
        Self {
            router,
            profiler,
            stats,
            bloom,
            config,
            migration_count: 0,
            frozen: false,
            point_count: 0,
        }
    }

    /// Inserts `point` with `payload` and returns the id assigned by the router.
    ///
    /// The insertion is also counted in the stats collector, fed to the
    /// profiler, and added to the bloom cache.
    pub fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
        let id = self.router.insert(point, payload);
        self.stats.record_insert();
        self.profiler.observe(Observation::Insert(point));
        self.bloom.insert(&point_to_f64_bbox(point));
        self.point_count += 1;
        id
    }

    /// Removes entry `id` and returns its payload. Returns `None` if the router
    /// has no such entry.
    ///
    /// `len()` is decremented only when an entry was actually removed.
    pub fn remove(&mut self, id: EntryId) -> Option<T> {
        let result = self.router.remove(id);
        if result.is_some() {
            self.point_count = self.point_count.saturating_sub(1);
        }
        result
    }

    /// Returns the `(id, payload)` pairs the router reports inside `bbox` and
    /// records the query latency in nanoseconds.
    pub fn range_query(&mut self, bbox: &BBox<C, D>) -> Vec<(EntryId, T)> {
        let t0 = Instant::now();
        let results = self.router.range_query(bbox);
        self.stats.record_query(t0.elapsed().as_nanos() as u64);
        results
    }

    /// Returns up to `k` neighbours of `point` as `(distance, id, payload)`
    /// triples, in the order produced by the active backend, and records the
    /// query latency.
    ///
    /// NOTE(review): the order is presumably ascending distance; confirm against
    /// the backends.
    pub fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, T)> {
        let t0 = Instant::now();
        let results = {
            // SAFETY(review): assumes `IndexRouter::active_ptr` returns a non-null
            // pointer that stays valid while `self.router` is alive. Confirm
            // against `IndexRouter`.
            let active = unsafe { &*self.router.active_ptr() };
            // The read guard is a temporary dropped at the end of this block,
            // before latency is recorded.
            active
                .read()
                .knn_query(point, k)
                .into_iter()
                .map(|(d, id, t)| (d, id, t.clone()))
                .collect()
        };
        self.stats.record_query(t0.elapsed().as_nanos() as u64);
        results
    }

    /// Returns the first neighbour from a 1-NN query, or `None` when the index is empty.
    pub fn nearest(&self, point: &Point<C, D>) -> Option<(f64, EntryId, T)> {
        // Take the first (nearest) result rather than the last.
        self.knn_query(point, 1).into_iter().next()
    }

    /// Pure geometric test: whether `point` lies in `bbox`. The index contents
    /// are not consulted.
    pub fn contains(&self, point: &Point<C, D>, bbox: &BBox<C, D>) -> bool {
        bbox.contains_point(point)
    }

    /// Returns the pairs of entry ids `(self_id, other_id)` produced by the
    /// active backend's spatial join against `other`'s active backend.
    ///
    /// Joining an index with itself takes the backend's read lock only once.
    /// Re-acquiring a read lock already held by the current thread can deadlock
    /// or panic, depending on the lock implementation.
    pub fn spatial_join(&self, other: &BonsaiIndex<T, C, D>) -> Vec<(EntryId, EntryId)> {
        // SAFETY(review): assumes `IndexRouter::active_ptr` returns a non-null
        // pointer that stays valid while the owning router is alive. Both routers
        // are kept alive by `self` and `other` for the duration of this call.
        let active_self = unsafe { &*self.router.active_ptr() };
        let active_other = unsafe { &*other.router.active_ptr() };
        if std::ptr::eq(active_self, active_other) {
            let guard = active_self.read();
            return guard.spatial_join(guard.as_ref());
        }
        let guard_self = active_self.read();
        let guard_other = active_other.read();
        guard_self.spatial_join(guard_other.as_ref())
    }

    /// Returns a snapshot of index statistics.
    ///
    /// `data_shape` falls back to a placeholder until the profiler has produced
    /// an estimate. `last_migration_at` is always `None` here.
    pub fn stats(&self) -> Stats<D> {
        let backend = {
            // SAFETY(review): same assumption about `IndexRouter::active_ptr` as in `knn_query`.
            let active = unsafe { &*self.router.active_ptr() };
            active.read().kind()
        };
        let data_shape = self
            .profiler
            .data_shape()
            .cloned()
            .unwrap_or_else(default_data_shape::<D>);
        Stats {
            backend,
            point_count: self.point_count,
            migrations: self.migration_count,
            last_migration_at: None,
            query_count: self.stats.query_count(),
            data_shape,
            migrating: self.router.is_migrating(),
            dimensions: D,
        }
    }

    /// Requests a specific backend.
    ///
    /// # Errors
    /// Returns [`BonsaiError::MigrationInProgress`] if the router is migrating.
    ///
    /// NOTE(review): otherwise this currently returns `Ok(())` without switching
    /// backends; the requested kind is ignored.
    pub fn force_backend(&mut self, _backend: BackendKind) -> Result<(), BonsaiError> {
        if self.router.is_migrating() {
            return Err(BonsaiError::MigrationInProgress);
        }
        Ok(())
    }

    /// Removes every entry and resets the bloom cache, profiler, stats and `len()`.
    ///
    /// The active backend kind, `config`, `frozen` and `migration_count` are preserved.
    ///
    /// # Errors
    /// Returns [`BonsaiError::MigrationInProgress`] if the router is migrating.
    pub fn clear(&mut self) -> Result<(), BonsaiError> {
        if self.router.is_migrating() {
            return Err(BonsaiError::MigrationInProgress);
        }
        let kind = {
            // SAFETY(review): same assumption about `IndexRouter::active_ptr` as in
            // `knn_query`. The reference does not outlive this block, so it is
            // never used after `self.router` is replaced below.
            let active = unsafe { &*self.router.active_ptr() };
            active.read().kind()
        };
        self.router = Arc::new(IndexRouter::new(backend_for_kind::<T, C, D>(kind)));
        self.bloom = BloomCache::new(self.config.bloom_memory_bytes, 7);
        self.profiler = Profiler::new(self.config.reservoir_size);
        self.stats = Arc::new(StatsCollector::new());
        self.point_count = 0;
        Ok(())
    }

    /// Sets the `frozen` flag.
    pub fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Clears the `frozen` flag.
    pub fn unfreeze(&mut self) {
        self.frozen = false;
    }

    /// Returns the `frozen` flag.
    pub fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.point_count
    }

    /// `true` when the index holds no entries.
    pub fn is_empty(&self) -> bool {
        self.point_count == 0
    }
}

/// Builds a fresh, empty backend of the requested kind.
///
/// Shared by [`BonsaiIndex::from_config`] and [`BonsaiIndex::clear`] so both
/// construct backends the same way.
fn backend_for_kind<T, C, const D: usize>(kind: BackendKind) -> Box<dyn SpatialBackend<T, C, D>>
where
    C: CoordType,
    T: Clone + Send + Sync + 'static,
{
    match kind {
        BackendKind::KDTree => Box::new(KDTree::new()),
        BackendKind::RTree => Box::new(RTree::new()),
        BackendKind::Quadtree => Box::new(Quadtree::new()),
        BackendKind::Grid => Box::new(GridIndex::<T, C, D>::default()),
    }
}
/// Lifts `point` to `f64` coordinates and wraps it in a zero-extent bounding
/// box whose min and max corners are that same point.
///
/// `insert` uses this to key the bloom cache.
fn point_to_f64_bbox<C: CoordType, const D: usize>(point: Point<C, D>) -> BBox<f64, D> {
    let coords = point.coords();
    let lifted: [f64; D] = std::array::from_fn(|axis| coords[axis].into());
    let corner = Point::new(lifted);
    BBox::new(corner, corner)
}
/// Placeholder [`DataShape`] reported by `BonsaiIndex::stats` before the
/// profiler has an estimate.
///
/// It describes zero points in the unit box `[0, 1]^D`, with no skew or
/// overlap, a clustering coefficient of `1.0`, full effective dimensionality
/// `D`, and the default query mix.
fn default_data_shape<const D: usize>() -> DataShape<D> {
    DataShape {
        point_count: 0,
        bbox: BBox::new(Point::new([0.0; D]), Point::new([1.0; D])),
        // No skew or overlap information yet.
        skewness: [0.0; D],
        overlap_ratio: 0.0,
        clustering_coef: 1.0,
        effective_dim: D as f64,
        query_mix: QueryMix::default(),
    }
}
// Unit tests and proptest-based property tests for `BonsaiIndex`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Point;

    // Index with the default configuration and `&'static str` payloads; the
    // coordinate type and dimension use the `BonsaiIndex` defaults (f64, 2).
    fn make_index() -> BonsaiIndex<&'static str> {
        BonsaiIndex::builder().build()
    }

    #[test]
    fn insert_and_len() {
        let mut idx = make_index();
        assert_eq!(idx.len(), 0);
        idx.insert(Point::new([1.0, 2.0]), "a");
        idx.insert(Point::new([3.0, 4.0]), "b");
        assert_eq!(idx.len(), 2);
    }

    // `remove` hands back the stored payload and decrements `len`.
    #[test]
    fn remove_returns_payload() {
        let mut idx: BonsaiIndex<u32> = BonsaiIndex::builder().build();
        let id = idx.insert(Point::new([0.0, 0.0]), 99u32);
        assert_eq!(idx.remove(id), Some(99));
        assert_eq!(idx.len(), 0);
    }

    // Only the point inside the query box is returned.
    #[test]
    fn range_query_basic() {
        let mut idx = make_index();
        idx.insert(Point::new([1.0, 1.0]), "inside");
        idx.insert(Point::new([9.0, 9.0]), "outside");
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([5.0, 5.0]));
        let results = idx.range_query(&bbox);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "inside");
    }

    // A query box far away from every point yields nothing.
    #[test]
    fn range_query_empty_region() {
        let mut idx = make_index();
        idx.insert(Point::new([1.0, 1.0]), "a");
        let far_bbox = BBox::new(Point::new([900.0, 900.0]), Point::new([1000.0, 1000.0]));
        let results = idx.range_query(&far_bbox);
        assert_eq!(results.len(), 0);
    }

    // 1-NN at the origin finds the coincident point at distance ~0.
    #[test]
    fn knn_query_returns_nearest() {
        let mut idx = make_index();
        idx.insert(Point::new([0.0, 0.0]), "origin");
        idx.insert(Point::new([3.0, 4.0]), "far");
        let results = idx.knn_query(&Point::new([0.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].2, "origin");
        assert!((results[0].0).abs() < 1e-9);
    }

    #[test]
    fn nearest_returns_closest() {
        let mut idx = make_index();
        idx.insert(Point::new([1.0, 0.0]), "near");
        idx.insert(Point::new([10.0, 0.0]), "far");
        let (dist, _, payload) = idx.nearest(&Point::new([0.0, 0.0])).unwrap();
        assert_eq!(payload, "near");
        assert!((dist - 1.0).abs() < 1e-9);
    }

    #[test]
    fn nearest_empty_returns_none() {
        let idx = make_index();
        assert!(idx.nearest(&Point::new([0.0, 0.0])).is_none());
    }

    // `contains` is purely geometric; the index is empty here.
    #[test]
    fn contains_geometric_check() {
        let idx = make_index();
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([1.0, 1.0]));
        assert!(idx.contains(&Point::new([0.5, 0.5]), &bbox));
        assert!(!idx.contains(&Point::new([2.0, 0.5]), &bbox));
    }

    #[test]
    fn stats_reflects_inserts() {
        let mut idx = make_index();
        idx.insert(Point::new([1.0, 1.0]), "a");
        idx.insert(Point::new([2.0, 2.0]), "b");
        let s = idx.stats();
        assert_eq!(s.point_count, 2);
        assert_eq!(s.dimensions, 2);
        assert!(!s.migrating);
    }

    #[test]
    fn freeze_unfreeze() {
        let mut idx = make_index();
        assert!(!idx.is_frozen());
        idx.freeze();
        assert!(idx.is_frozen());
        idx.unfreeze();
        assert!(!idx.is_frozen());
    }

    #[test]
    fn force_backend_ok_when_not_migrating() {
        let mut idx = make_index();
        assert!(idx.force_backend(BackendKind::RTree).is_ok());
    }

    // Only checks that a customised builder produces an empty index.
    #[test]
    fn builder_pattern() {
        let idx: BonsaiIndex<i32> = BonsaiIndex::builder()
            .initial_backend(BackendKind::RTree)
            .reservoir_size(512)
            .bloom_memory_bytes(8192)
            .build();
        assert_eq!(idx.len(), 0);
    }

    // Two indexes each holding one identical point join to exactly one pair.
    #[test]
    fn spatial_join_same_points() {
        let mut a: BonsaiIndex<()> = BonsaiIndex::builder().build();
        let mut b: BonsaiIndex<()> = BonsaiIndex::builder().build();
        a.insert(Point::new([1.0, 1.0]), ());
        b.insert(Point::new([1.0, 1.0]), ());
        let pairs = a.spatial_join(&b);
        assert_eq!(pairs.len(), 1);
    }

    #[test]
    fn clear_on_empty_index_succeeds() {
        let mut idx = make_index();
        assert!(idx.clear().is_ok());
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn clear_resets_len() {
        let mut idx = make_index();
        for i in 0..5 {
            idx.insert(Point::new([i as f64, i as f64]), "x");
        }
        assert_eq!(idx.len(), 5);
        idx.clear().unwrap();
        assert_eq!(idx.len(), 0);
    }

    // `clear` must not reset the frozen flag.
    #[test]
    fn clear_preserves_frozen() {
        let mut idx = make_index();
        idx.freeze();
        idx.clear().unwrap();
        assert!(idx.is_frozen());
    }

    // `clear` must not reset the migration counter.
    #[test]
    fn clear_preserves_migration_count() {
        let mut idx = make_index();
        let migrations_before = idx.stats().migrations;
        idx.insert(Point::new([1.0, 1.0]), "a");
        idx.clear().unwrap();
        assert_eq!(idx.stats().migrations, migrations_before);
    }

    // Starts a migration directly on the router, checks that `clear` refuses,
    // then commits the migration so the router is left consistent.
    #[test]
    fn clear_returns_err_when_migrating() {
        let mut idx: BonsaiIndex<&str> = BonsaiIndex::builder().build();
        use crate::backends::KDTree as KD;
        idx.router
            .begin_migration(Box::new(KD::<&str, f64, 2>::new()));
        let result = idx.clear();
        assert!(matches!(result, Err(BonsaiError::MigrationInProgress)));
        idx.router.commit_migration();
    }

    #[test]
    fn clear_empty_range_query_returns_empty() {
        let mut idx = make_index();
        idx.insert(Point::new([1.0, 1.0]), "a");
        idx.clear().unwrap();
        let full = BBox::new(Point::new([-1e9, -1e9]), Point::new([1e9, 1e9]));
        assert!(idx.range_query(&full).is_empty());
    }

    // Entries inserted before `clear` must not leak into later range queries.
    #[test]
    fn clear_then_insert_range_query() {
        let mut idx = make_index();
        idx.insert(Point::new([50.0, 50.0]), "old");
        idx.clear().unwrap();
        idx.insert(Point::new([1.0, 1.0]), "new");
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([5.0, 5.0]));
        let results = idx.range_query(&bbox);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "new");
    }

    // Entries inserted before `clear` must not leak into later kNN queries.
    #[test]
    fn clear_then_insert_knn_query() {
        let mut idx = make_index();
        idx.insert(Point::new([50.0, 50.0]), "old");
        idx.clear().unwrap();
        idx.insert(Point::new([1.0, 1.0]), "new");
        let results = idx.knn_query(&Point::new([0.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].2, "new");
    }

    // Property tests (100 cases each) over random points in [-1000, 1000)^2.
    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        // Random 2-D point with both coordinates in [-1000, 1000).
        fn point_strategy() -> impl Strategy<Value = Point<f64, 2>> {
            (-1000.0_f64..1000.0_f64, -1000.0_f64..1000.0_f64).prop_map(|(x, y)| Point::new([x, y]))
        }

        // Random non-degenerate box: lower corner in [-1000, 1000)^2, and width
        // and height each in [0.001, 500).
        fn bbox_strategy() -> impl Strategy<Value = BBox<f64, 2>> {
            (
                -1000.0_f64..1000.0_f64,
                -1000.0_f64..1000.0_f64,
                0.001_f64..500.0_f64,
                0.001_f64..500.0_f64,
            )
                .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
        }

        proptest! {
            #![proptest_config(proptest::test_runner::Config {
                cases: 100,
                ..Default::default()
            })]
            // The backend kind reported by `stats()` must survive `clear()`.
            #[test]
            fn prop_clear_preserves_backend_kind(
                points in prop::collection::vec(point_strategy(), 0..30),
                backend_idx in 0usize..4,
            ) {
                let kinds = [
                    BackendKind::KDTree,
                    BackendKind::RTree,
                    BackendKind::Quadtree,
                    BackendKind::Grid,
                ];
                let kind = kinds[backend_idx % 4];
                let mut idx: BonsaiIndex<u32> = BonsaiIndex::builder()
                    .initial_backend(kind)
                    .build();
                idx.force_backend(kind).unwrap();
                for (i, p) in points.iter().enumerate() {
                    idx.insert(*p, i as u32);
                }
                let kind_before = idx.stats().backend;
                idx.clear().unwrap();
                prop_assert_eq!(
                    idx.stats().backend,
                    kind_before,
                    "backend kind must be preserved after clear"
                );
            }
        }

        proptest! {
            #![proptest_config(proptest::test_runner::Config {
                cases: 100,
                ..Default::default()
            })]
            // After `clear`, `len()` counts only the inserts made afterwards.
            #[test]
            fn prop_clear_len_invariant(
                pre_inserts in prop::collection::vec(point_strategy(), 0..30),
                post_inserts in prop::collection::vec(point_strategy(), 0..30),
            ) {
                let mut idx: BonsaiIndex<u32> = BonsaiIndex::builder().build();
                for (i, p) in pre_inserts.iter().enumerate() {
                    idx.insert(*p, i as u32);
                }
                idx.clear().unwrap();
                for (i, p) in post_inserts.iter().enumerate() {
                    idx.insert(*p, i as u32);
                }
                prop_assert_eq!(
                    idx.len(),
                    post_inserts.len(),
                    "len() must equal the number of inserts after clear"
                );
            }
        }

        proptest! {
            #![proptest_config(proptest::test_runner::Config {
                cases: 100,
                ..Default::default()
            })]
            // A cleared index answers every range and kNN query with nothing.
            #[test]
            fn prop_clear_empty_queries(
                points in prop::collection::vec(point_strategy(), 1..30),
                query_bbox in bbox_strategy(),
                knn_point in point_strategy(),
                k in 1usize..10,
            ) {
                let mut idx: BonsaiIndex<u32> = BonsaiIndex::builder().build();
                for (i, p) in points.iter().enumerate() {
                    idx.insert(*p, i as u32);
                }
                idx.clear().unwrap();
                prop_assert!(
                    idx.range_query(&query_bbox).is_empty(),
                    "range_query must return empty after clear"
                );
                prop_assert!(
                    idx.knn_query(&knn_point, k).is_empty(),
                    "knn_query must return empty after clear"
                );
            }
        }

        proptest! {
            #![proptest_config(proptest::test_runner::Config {
                cases: 100,
                ..Default::default()
            })]
            // clear-then-insert must be indistinguishable, via a query covering the
            // whole strategy domain, from inserting into a fresh index. Pre-clear
            // payloads are offset by 100_000 so any leak would be visible.
            #[test]
            fn prop_clear_round_trip_equivalence(
                pre_inserts in prop::collection::vec(point_strategy(), 0..20),
                entries in prop::collection::vec(
                    (point_strategy(), 0u32..10_000u32),
                    1..20,
                ),
            ) {
                let full_bbox = BBox::new(
                    Point::new([-1001.0, -1001.0]),
                    Point::new([1001.0, 1001.0]),
                );
                let mut fresh: BonsaiIndex<u32> = BonsaiIndex::builder().build();
                for (p, v) in &entries {
                    fresh.insert(*p, *v);
                }
                let mut fresh_results: Vec<u32> = fresh
                    .range_query(&full_bbox)
                    .into_iter()
                    .map(|(_, v)| v)
                    .collect();
                fresh_results.sort_unstable();
                let mut cleared: BonsaiIndex<u32> = BonsaiIndex::builder().build();
                for (i, p) in pre_inserts.iter().enumerate() {
                    cleared.insert(*p, i as u32 + 100_000);
                }
                cleared.clear().unwrap();
                for (p, v) in &entries {
                    cleared.insert(*p, *v);
                }
                let mut cleared_results: Vec<u32> = cleared
                    .range_query(&full_bbox)
                    .into_iter()
                    .map(|(_, v)| v)
                    .collect();
                cleared_results.sort_unstable();
                prop_assert_eq!(
                    fresh_results,
                    cleared_results,
                    "clear+insert must produce the same range_query results as fresh+insert"
                );
            }
        }
    }
}