use std::collections::BinaryHeap;
use serde::{Deserialize, Serialize};
use crate::bbox::{BoundingBoxN, SpatialEntryN};
use crate::iter::SpatialIterN;
use crate::node::{str_build_nodes, InsertResult, NodeN};
use crate::{SpatialConfig, SpatialError, SplitStrategy};
/// An N-dimensional spatial index storing [`SpatialEntryN`] values.
///
/// The tree lives under `root`. `len` caches the number of stored entries so
/// that [`SpatialIndexN::len`] does not need to walk the tree. `config` is
/// handed to the node layer on every insertion and removal and to the bulk
/// loader.
pub struct SpatialIndexN<const D: usize, T> {
    /// Number of entries currently stored; maintained by insert/remove/clear.
    len: usize,
    /// Node-capacity and split-strategy settings used by this index.
    config: SpatialConfig,
    /// Root node of the tree (crate-visible).
    pub(crate) root: NodeN<D, T>,
}
impl<const D: usize, T: Clone> Clone for SpatialIndexN<D, T> {
    /// Clones the tree via `NodeN::clone`; the cached length and the
    /// configuration are copied unchanged.
    fn clone(&self) -> Self {
        let Self { root, len, config } = self;
        Self {
            config: *config,
            len: *len,
            root: root.clone(),
        }
    }
}
impl<const D: usize, T> Default for SpatialIndexN<D, T> {
fn default() -> Self {
Self::new()
}
}
impl<const D: usize, T> SpatialIndexN<D, T> {
    /// Creates an empty index using [`SpatialConfig::DEFAULT`].
    #[must_use]
    pub const fn new() -> Self {
        Self {
            root: NodeN::Leaf {
                entries: Vec::new(),
            },
            len: 0,
            config: SpatialConfig::DEFAULT,
        }
    }

    /// Creates an empty index that uses `config` for all later insertions,
    /// removals and splits.
    #[must_use]
    pub const fn with_config(config: SpatialConfig) -> Self {
        Self {
            root: NodeN::Leaf {
                entries: Vec::new(),
            },
            len: 0,
            config,
        }
    }

    /// Returns the configuration this index was created with.
    #[must_use]
    pub const fn config(&self) -> SpatialConfig {
        self.config
    }

    /// Inserts `entry` into the index.
    ///
    /// The top-level insertion allows the node layer to request forced
    /// reinsertion. Any entries it hands back are reinserted with reinsertion
    /// disabled, so reinsertion can never cascade.
    pub fn insert(&mut self, entry: SpatialEntryN<D, T>) {
        self.len += 1;
        match self.root.insert_rstar(entry, self.config, true) {
            InsertResult::Ok => {},
            InsertResult::Split(sibling_bounds, sibling) => {
                self.promote_root(sibling_bounds, sibling);
            },
            InsertResult::Reinsert(entries) => {
                for e in entries {
                    self.insert_internal(e);
                }
            },
        }
    }

    /// Removes an entry found within `region` for which `pred` returns `true`.
    ///
    /// After removal the root is collapsed (see `shrink_root`). Any orphaned
    /// entries returned by the node layer are then reinserted.
    ///
    /// NOTE(review): `len` is decremented by exactly one. This assumes
    /// `NodeN::remove` removes at most one matching entry — confirm.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::NotFound`] if the node layer reports that
    /// nothing was removed.
    pub fn remove<F>(&mut self, region: BoundingBoxN<D>, pred: F) -> Result<(), SpatialError>
    where
        F: Fn(&SpatialEntryN<D, T>) -> bool,
    {
        let (found, orphans) = self.root.remove(region, &pred, self.config);
        if !found {
            return Err(SpatialError::NotFound);
        }
        self.len -= 1;
        self.shrink_root();
        for entry in orphans {
            self.insert_internal(entry);
        }
        Ok(())
    }

    /// Grows the tree by one level. The current root and the split-off
    /// `sibling` become the two children of a new internal root.
    fn promote_root(&mut self, sibling_bounds: BoundingBoxN<D>, sibling: NodeN<D, T>) {
        // Temporarily park an empty leaf so the old root can be moved out.
        let old_root = std::mem::replace(
            &mut self.root,
            NodeN::Leaf {
                entries: Vec::new(),
            },
        );
        // An old root without bounds gets a zero box as a fallback.
        let old_bounds = old_root
            .bounds()
            .unwrap_or_else(BoundingBoxN::from_raw_zero);
        self.root = NodeN::Internal {
            children: vec![(old_bounds, old_root), (sibling_bounds, sibling)],
        };
    }

    /// Inserts `entry` without allowing forced reinsertion. Used for
    /// reinserted and orphaned entries; does not touch `len`.
    fn insert_internal(&mut self, entry: SpatialEntryN<D, T>) {
        match self.root.insert_rstar(entry, self.config, false) {
            InsertResult::Ok => {},
            InsertResult::Split(sb, sn) => self.promote_root(sb, sn),
            InsertResult::Reinsert(_) => {
                unreachable!("reinsert with allow_reinsert=false");
            },
        }
    }

    /// Repeatedly collapses the root while it is an internal node with zero
    /// or one children.
    ///
    /// - An empty internal root becomes an empty leaf.
    /// - A single-child root is replaced by that child.
    fn shrink_root(&mut self) {
        loop {
            match &mut self.root {
                NodeN::Internal { children } if children.is_empty() => {
                    self.root = NodeN::Leaf {
                        entries: Vec::new(),
                    };
                },
                NodeN::Internal { children } if children.len() == 1 => {
                    let (_, child) = children.pop().expect("single child");
                    self.root = child;
                },
                _ => break,
            }
        }
    }

    /// Returns references to the entries the node layer reports for `region`
    /// (presumably those whose bounds intersect it), in tree order.
    #[must_use]
    pub fn query_region(&self, region: BoundingBoxN<D>) -> Vec<&SpatialEntryN<D, T>> {
        let mut results = Vec::new();
        self.root.query_region(region, &mut results);
        results
    }

    /// Returns the `k` nearest-neighbour candidates collected by
    /// `NodeN::query_nearest_heap`. They are sorted by ascending
    /// `min_dist_sq_nd` (bounding-box to point) distance to `point`.
    ///
    /// Returns an empty vector when `k == 0`.
    #[must_use]
    pub fn query_nearest_nd(&self, point: [f32; D], k: usize) -> Vec<&SpatialEntryN<D, T>> {
        if k == 0 {
            return Vec::new();
        }
        let mut heap = BinaryHeap::new();
        self.root.query_nearest_heap(&point, &mut heap, k);
        // Compute each distance once rather than twice per comparison.
        let mut keyed: Vec<_> = heap
            .into_iter()
            .map(|c| (c.entry, c.entry.bounds.min_dist_sq_nd(&point)))
            .collect();
        // `total_cmp` is a total order even when a distance is NaN.
        // `partial_cmp(..).unwrap_or(Equal)` is not, which can make `sort_by` panic.
        keyed.sort_by(|a, b| a.1.total_cmp(&b.1));
        keyed.into_iter().map(|(entry, _)| entry).collect()
    }

    /// Returns the `k` nearest-neighbour candidates collected by
    /// `NodeN::query_nearest_by_centroid_heap`. They are sorted by ascending
    /// `center_dist_sq_nd` (box-centre to point) distance to `point`.
    ///
    /// Returns an empty vector when `k == 0`.
    #[must_use]
    pub fn query_nearest_by_centroid_nd(
        &self,
        point: [f32; D],
        k: usize,
    ) -> Vec<&SpatialEntryN<D, T>> {
        if k == 0 {
            return Vec::new();
        }
        let mut heap = BinaryHeap::new();
        self.root
            .query_nearest_by_centroid_heap(&point, &mut heap, k);
        let mut keyed: Vec<_> = heap
            .into_iter()
            .map(|c| (c.entry, c.entry.bounds.center_dist_sq_nd(&point)))
            .collect();
        keyed.sort_by(|a, b| a.1.total_cmp(&b.1));
        keyed.into_iter().map(|(entry, _)| entry).collect()
    }

    /// Collects `(entry, squared_distance)` pairs that the node layer reports
    /// within radius `r` of `point`, sorted by ascending squared distance.
    ///
    /// Returns an empty vector when `r` is negative or NaN. A NaN radius
    /// would otherwise propagate a NaN `r_sq` into the tree walk.
    fn within_radius_sorted_sq(&self, point: [f32; D], r: f32) -> Vec<(&SpatialEntryN<D, T>, f32)> {
        if r.is_nan() || r < 0.0 {
            return Vec::new();
        }
        let mut results = Vec::new();
        self.root.query_within_radius(&point, r * r, &mut results);
        results.sort_by(|a, b| a.1.total_cmp(&b.1));
        results
    }

    /// Returns entries within radius `r` of `point`, nearest first.
    ///
    /// Returns an empty vector when `r` is negative or NaN.
    #[must_use]
    pub fn query_within_radius_nd(&self, point: [f32; D], r: f32) -> Vec<&SpatialEntryN<D, T>> {
        self.within_radius_sorted_sq(point, r)
            .into_iter()
            .map(|(entry, _)| entry)
            .collect()
    }

    /// Like [`SpatialIndexN::query_within_radius_nd`], but each entry is paired
    /// with its (non-squared) distance to `point`.
    ///
    /// Returns an empty vector when `r` is negative or NaN.
    #[must_use]
    pub fn query_within_radius_with_distances_nd(
        &self,
        point: [f32; D],
        r: f32,
    ) -> Vec<(&SpatialEntryN<D, T>, f32)> {
        self.within_radius_sorted_sq(point, r)
            .into_iter()
            .map(|(entry, dist_sq)| (entry, dist_sq.sqrt()))
            .collect()
    }

    /// Number of entries stored in the index.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the index holds no entries.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Removes every entry. The configuration is kept.
    pub fn clear(&mut self) {
        self.root = NodeN::Leaf {
            entries: Vec::new(),
        };
        self.len = 0;
    }

    /// Builds an index from `entries` with [`SpatialConfig::DEFAULT`].
    #[must_use]
    pub fn bulk_load(entries: Vec<SpatialEntryN<D, T>>) -> Self {
        Self::bulk_load_with_config(entries, SpatialConfig::DEFAULT)
    }

    /// Builds an index from `entries` in one pass via `str_build_nodes`
    /// (presumably Sort-Tile-Recursive packing), using `config`.
    ///
    /// An empty input yields the same result as [`SpatialIndexN::with_config`].
    #[must_use]
    pub fn bulk_load_with_config(entries: Vec<SpatialEntryN<D, T>>, config: SpatialConfig) -> Self {
        if entries.is_empty() {
            return Self::with_config(config);
        }
        let len = entries.len();
        let root = str_build_nodes(entries, config);
        Self { root, len, config }
    }

    /// Returns an iterator over references to every entry.
    ///
    /// All references are collected into a `Vec` up front. Iteration order
    /// is the order produced by `NodeN::collect_all`.
    #[must_use]
    pub fn iter(&self) -> SpatialIterN<'_, D, T> {
        let mut entries = Vec::new();
        self.root.collect_all(&mut entries);
        SpatialIterN { entries, pos: 0 }
    }
}
impl<T> SpatialIndexN<2, T> {
    /// 2-D wrapper around [`SpatialIndexN::query_nearest_nd`] taking `(x, y)`.
    #[must_use]
    pub fn query_nearest(&self, x: f32, y: f32, k: usize) -> Vec<&SpatialEntryN<2, T>> {
        let point = [x, y];
        Self::query_nearest_nd(self, point, k)
    }

    /// 2-D wrapper around [`SpatialIndexN::query_nearest_by_centroid_nd`]
    /// taking `(x, y)`.
    #[must_use]
    pub fn query_nearest_by_centroid(&self, x: f32, y: f32, k: usize) -> Vec<&SpatialEntryN<2, T>> {
        let point = [x, y];
        Self::query_nearest_by_centroid_nd(self, point, k)
    }

    /// 2-D wrapper around [`SpatialIndexN::query_within_radius_nd`] taking
    /// `(x, y)`.
    #[must_use]
    pub fn query_within_radius(&self, x: f32, y: f32, r: f32) -> Vec<&SpatialEntryN<2, T>> {
        let point = [x, y];
        Self::query_within_radius_nd(self, point, r)
    }

    /// 2-D wrapper around
    /// [`SpatialIndexN::query_within_radius_with_distances_nd`] taking `(x, y)`.
    #[must_use]
    pub fn query_within_radius_with_distances(
        &self,
        x: f32,
        y: f32,
        r: f32,
    ) -> Vec<(&SpatialEntryN<2, T>, f32)> {
        let point = [x, y];
        Self::query_within_radius_with_distances_nd(self, point, r)
    }
}
impl<T> SpatialIndexN<3, T> {
    /// 3-D wrapper around [`SpatialIndexN::query_nearest_nd`] taking
    /// `(x, y, z)`.
    #[must_use]
    pub fn query_nearest(&self, x: f32, y: f32, z: f32, k: usize) -> Vec<&SpatialEntryN<3, T>> {
        let point = [x, y, z];
        Self::query_nearest_nd(self, point, k)
    }

    /// 3-D wrapper around [`SpatialIndexN::query_nearest_by_centroid_nd`]
    /// taking `(x, y, z)`.
    #[must_use]
    // Kept from the original; presumably silences clippy about the short
    // single-letter coordinate parameter names.
    #[allow(clippy::similar_names)]
    pub fn query_nearest_by_centroid(
        &self,
        x: f32,
        y: f32,
        z: f32,
        k: usize,
    ) -> Vec<&SpatialEntryN<3, T>> {
        let point = [x, y, z];
        Self::query_nearest_by_centroid_nd(self, point, k)
    }

    /// 3-D wrapper around [`SpatialIndexN::query_within_radius_nd`] taking
    /// `(x, y, z)`.
    #[must_use]
    pub fn query_within_radius(&self, x: f32, y: f32, z: f32, r: f32) -> Vec<&SpatialEntryN<3, T>> {
        let point = [x, y, z];
        Self::query_within_radius_nd(self, point, r)
    }

    /// 3-D wrapper around
    /// [`SpatialIndexN::query_within_radius_with_distances_nd`] taking
    /// `(x, y, z)`.
    #[must_use]
    pub fn query_within_radius_with_distances(
        &self,
        x: f32,
        y: f32,
        z: f32,
        r: f32,
    ) -> Vec<(&SpatialEntryN<3, T>, f32)> {
        let point = [x, y, z];
        Self::query_within_radius_with_distances_nd(self, point, r)
    }
}
impl<'a, const D: usize, T> IntoIterator for &'a SpatialIndexN<D, T> {
type Item = &'a SpatialEntryN<D, T>;
type IntoIter = SpatialIterN<'a, D, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Borrowing serialization view of a [`SpatialIndexN`]. It holds the config
/// fields plus references to every stored entry, so serializing does not
/// clone entries.
///
/// Field names and declaration order define the serialized layout. Keep them
/// in sync with `SpatialIndexDtoN`, which is used for deserialization.
#[derive(Serialize)]
struct SpatialIndexDtoRefN<'a, const D: usize, T> {
    /// Format version written by the `Serialize` impl.
    version: u8,
    /// Maximum entries per node, from `SpatialConfig::max_entries()`.
    max_entries: usize,
    /// Split strategy encoded as `SplitStrategy as u8`.
    split_strategy: u8,
    /// Every entry of the index, in `NodeN::collect_all` order.
    entries: Vec<&'a SpatialEntryN<D, T>>,
}
/// Owned deserialization counterpart of `SpatialIndexDtoRefN`.
///
/// Field names and declaration order must match that struct. The
/// `Deserialize` impl for [`SpatialIndexN`] validates `version` and
/// `split_strategy`, then rebuilds the tree from `entries` by bulk loading.
#[derive(Deserialize)]
struct SpatialIndexDtoN<const D: usize, T> {
    /// Format version; only `3` is accepted.
    version: u8,
    /// Maximum entries per node; validated by `SpatialConfig::with_strategy`.
    max_entries: usize,
    /// Encoded split strategy: `0` = Linear, `1` = RStar.
    split_strategy: u8,
    /// All stored entries.
    entries: Vec<SpatialEntryN<D, T>>,
}
impl<const D: usize, T: Serialize> Serialize for SpatialIndexN<D, T> {
    /// Serializes the index as its configuration plus a flat list of entries.
    ///
    /// The tree shape is not persisted; the `Deserialize` impl rebuilds it
    /// by bulk loading.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        const FORMAT_VERSION: u8 = 3;
        let mut flat = Vec::with_capacity(self.len);
        self.root.collect_all(&mut flat);
        let max_entries = self.config.max_entries();
        let split_strategy = self.config.split_strategy() as u8;
        SpatialIndexDtoRefN {
            version: FORMAT_VERSION,
            max_entries,
            split_strategy,
            entries: flat,
        }
        .serialize(serializer)
    }
}
impl<'de, const D: usize, T: Deserialize<'de>> Deserialize<'de> for SpatialIndexN<D, T> {
fn deserialize<De: serde::Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
let dto = SpatialIndexDtoN::<D, T>::deserialize(deserializer)?;
if dto.version != 3 {
return Err(serde::de::Error::custom("unsupported SpatialIndex version"));
}
let strategy = match dto.split_strategy {
0 => SplitStrategy::Linear,
1 => SplitStrategy::RStar,
_ => return Err(serde::de::Error::custom("unsupported split strategy value")),
};
let config = SpatialConfig::with_strategy(dto.max_entries, strategy)
.map_err(serde::de::Error::custom)?;
Ok(Self::bulk_load_with_config(dto.entries, config))
}
}