use crate::coordinates::Distance;
/// Axis-aligned rectangle described by its four integer edges.
///
/// Containment and intersection tests treat every edge as inclusive, while
/// [`width`](Self::width) and [`height`](Self::height) are plain edge differences.
/// Nothing enforces `left <= right` or `top <= bottom`; inverted bounds simply yield
/// negative extents.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpatialBounds {
    /// x coordinate of the left edge.
    pub left: i32,
    /// y coordinate of the top edge.
    pub top: i32,
    /// x coordinate of the right edge.
    pub right: i32,
    /// y coordinate of the bottom edge.
    pub bottom: i32,
}

impl SpatialBounds {
    /// Creates bounds from explicit edges.
    pub fn new(left: i32, top: i32, right: i32, bottom: i32) -> Self {
        Self { left, top, right, bottom }
    }

    /// Creates bounds of roughly `width` x `height` centred on `(center_x, center_y)`.
    ///
    /// Each half-extent is `size / 2` (integer division), so an odd size loses one unit:
    /// `width = 5` produces `right - left == 4`.
    pub fn from_center_size(center_x: i32, center_y: i32, width: i32, height: i32) -> Self {
        let (half_width, half_height) = (width / 2, height / 2);
        Self::new(
            center_x - half_width,
            center_y - half_height,
            center_x + half_width,
            center_y + half_height,
        )
    }

    /// Horizontal extent, `right - left`.
    pub fn width(&self) -> i32 {
        self.right - self.left
    }

    /// Vertical extent, `bottom - top`.
    pub fn height(&self) -> i32 {
        self.bottom - self.top
    }

    /// `width() * height()`; this product can overflow `i32` for very large bounds.
    pub fn area(&self) -> i32 {
        self.width() * self.height()
    }

    /// Returns `true` if `(x, y)` lies inside the bounds or on an edge.
    pub fn contains_point(&self, x: i32, y: i32) -> bool {
        x >= self.left && x <= self.right && y >= self.top && y <= self.bottom
    }

    /// Returns `true` if the two rectangles overlap; touching edges count as overlap.
    pub fn intersects(&self, other: &SpatialBounds) -> bool {
        self.left <= other.right
            && other.left <= self.right
            && self.top <= other.bottom
            && other.top <= self.bottom
    }

    /// Returns `true` if `other` lies entirely within `self` (shared edges allowed).
    pub fn contains(&self, other: &SpatialBounds) -> bool {
        self.left <= other.left
            && self.right >= other.right
            && self.top <= other.top
            && self.bottom >= other.bottom
    }

    /// Midpoint of the bounds, rounded toward zero on each axis.
    pub fn center(&self) -> (i32, i32) {
        // Average in i64: `left + right` can exceed i32::MAX for wide bounds. The midpoint
        // of two i32 values always lies between them, so narrowing back is lossless, and
        // i64 division truncates toward zero exactly like the former i32 arithmetic.
        let midpoint = |low: i32, high: i32| ((i64::from(low) + i64::from(high)) / 2) as i32;
        (midpoint(self.left, self.right), midpoint(self.top, self.bottom))
    }
}
/// An indexed object: a caller-chosen `id`, a `position`, and a `radius` around it.
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialEntity<C> {
    /// Identifier that `Quadtree::remove` matches against.
    pub id: u32,
    /// Centre point of the entity.
    pub position: C,
    /// Half the side length of the entity's square bounding box.
    pub radius: i32,
}

impl<C> SpatialEntity<C> {
    /// Creates an entity from its parts without validation.
    pub fn new(id: u32, position: C, radius: i32) -> Self {
        Self { id, position, radius }
    }

    /// Square bounding box spanning `position ± radius` on both axes.
    pub fn bounds(&self) -> SpatialBounds
    where
        C: SpatialCoordinate,
    {
        let (x, y) = self.position.to_spatial_coords();
        SpatialBounds::from_center_size(x, y, self.radius * 2, self.radius * 2)
    }

    /// Returns `true` when this entity's bounding box overlaps `bounds` (edges inclusive).
    pub fn intersects_bounds(&self, bounds: &SpatialBounds) -> bool
    where
        C: SpatialCoordinate,
    {
        self.bounds().intersects(bounds)
    }

    /// Returns `true` when the two positions are no farther apart, under `C`'s
    /// [`Distance`] metric, than the sum of both radii.
    ///
    /// A negative radius sum never intersects.
    pub fn intersects_entity(&self, other: &SpatialEntity<C>) -> bool
    where
        C: Distance,
    {
        let distance = self.position.distance(&other.position);
        // Sum in i64: two large i32 radii could overflow, and `as u32` on a negative sum
        // would wrap to a huge threshold that reports an intersection for everything.
        let reach = i64::from(self.radius) + i64::from(other.radius);
        if reach < 0 {
            return false;
        }
        // 0 <= reach <= 2 * i32::MAX < u32::MAX, so this narrowing is lossless.
        distance <= reach as u32
    }
}
/// Conversion between a coordinate type and the plain `(x, y)` integer pair that the
/// quadtree uses for partitioning and bounding-box math.
pub trait SpatialCoordinate {
    /// Builds a coordinate from an `(x, y)` pair.
    fn from_spatial_coords(x: i32, y: i32) -> Self;

    /// Returns this coordinate as an `(x, y)` pair.
    fn to_spatial_coords(&self) -> (i32, i32);
}
/// A quadtree node: either a bucket of entities or a split into four quadrants.
#[derive(Debug)]
enum QuadtreeNode<C> {
    /// Unsplit node that stores its entities directly.
    Leaf { entities: Vec<SpatialEntity<C>> },
    /// Split node; children are boxed because the type is recursive.
    Internal {
        northeast: Box<QuadtreeNode<C>>,
        northwest: Box<QuadtreeNode<C>>,
        southeast: Box<QuadtreeNode<C>>,
        southwest: Box<QuadtreeNode<C>>,
    },
}

impl<C> QuadtreeNode<C> {
    /// Returns a leaf holding no entities.
    fn new_leaf() -> Self {
        Self::Leaf { entities: Vec::new() }
    }
}
/// Point-region quadtree that indexes [`SpatialEntity`] values by their centre position.
///
/// Entities are routed to quadrants by the coordinates of their `position`; queries then
/// compare each candidate's radius-based bounding box with the query rectangle.
#[derive(Debug)]
pub struct Quadtree<C> {
    /// Root of the node hierarchy; starts as an empty leaf.
    root: QuadtreeNode<C>,
    /// Region the root is split around. Entities outside it are still accepted and are
    /// routed by the same centre comparisons as entities inside it.
    bounds: SpatialBounds,
    /// A leaf holding more than this many entities is subdivided (depth permitting).
    max_entities: usize,
    /// Deepest level any insertion has reached since construction or the last `clear`.
    max_depth: usize,
    /// Largest radius inserted since construction or the last `clear` (never below 0).
    /// Queries widen their search by this amount so entities whose centre lies outside
    /// the query but whose extent reaches into it are not pruned away.
    max_radius: i32,
}

impl<C> Quadtree<C>
where
    C: SpatialCoordinate + Clone,
{
    /// Leaves at this depth are never split further, so many entities sharing one point
    /// cannot make subdivision recurse forever.
    const MAX_SUBDIVISION_DEPTH: usize = 16;

    /// Creates an empty tree over `bounds` whose leaves split once they hold more than
    /// `max_entities` entities.
    pub fn new(bounds: SpatialBounds, max_entities: usize) -> Self {
        Self {
            root: QuadtreeNode::new_leaf(),
            bounds,
            max_entities,
            max_depth: 0,
            max_radius: 0,
        }
    }

    /// Inserts `entity`, subdividing any leaf that overflows `max_entities`.
    pub fn insert(&mut self, entity: SpatialEntity<C>) {
        self.max_radius = self.max_radius.max(entity.radius);
        Self::insert_recursive_static(
            &mut self.root,
            entity,
            &self.bounds,
            0,
            self.max_entities,
            &mut self.max_depth,
        );
    }

    /// Removes every entity whose `id` equals `entity_id` and returns them.
    ///
    /// Emptied leaves are kept, so the node layout reported by `stats` does not shrink.
    pub fn remove(&mut self, entity_id: u32) -> Vec<SpatialEntity<C>> {
        let mut removed = Vec::new();
        Self::remove_recursive_static(&mut self.root, entity_id, &mut removed);
        removed
    }

    /// Returns clones of all entities whose bounding box (`position ± radius`) overlaps
    /// `query_bounds`, edges inclusive.
    pub fn query_region(&self, query_bounds: &SpatialBounds) -> Vec<SpatialEntity<C>> {
        // Entities are filed by centre but matched by extent, so an entity centred just
        // outside `query_bounds` can still overlap it. Descend with the query grown by the
        // largest radius so the nodes holding such centres are visited.
        let reach = self.max_radius;
        let search = SpatialBounds::new(
            query_bounds.left.saturating_sub(reach),
            query_bounds.top.saturating_sub(reach),
            query_bounds.right.saturating_add(reach),
            query_bounds.bottom.saturating_add(reach),
        );
        let mut results = Vec::new();
        Self::query_recursive(&self.root, query_bounds, &search, &self.bounds, &mut results);
        results
    }

    /// Returns clones of all entities whose `position` is within `radius` of
    /// `(center_x, center_y)` under `C`'s [`Distance`] metric. The entities' own radii
    /// are not taken into account. A negative `radius` yields no entities.
    pub fn query_circle(&self, center_x: i32, center_y: i32, radius: i32) -> Vec<SpatialEntity<C>>
    where
        C: Distance,
    {
        // A negative radius describes an empty circle; casting it to u32 used to wrap into
        // a huge threshold that accepted every candidate.
        if radius < 0 {
            return Vec::new();
        }
        // radius >= 0 here, so the cast is lossless.
        let max_distance = radius as u32;
        let query_bounds = SpatialBounds::new(
            center_x.saturating_sub(radius),
            center_y.saturating_sub(radius),
            center_x.saturating_add(radius),
            center_y.saturating_add(radius),
        );
        let center_coord = C::from_spatial_coords(center_x, center_y);
        let mut hits = self.query_region(&query_bounds);
        hits.retain(|entity| entity.position.distance(&center_coord) <= max_distance);
        hits
    }

    /// Returns clones of every stored entity, in NE, NW, SE, SW traversal order.
    pub fn all_entities(&self) -> Vec<SpatialEntity<C>> {
        let mut entities = Vec::new();
        Self::collect_all_entities(&self.root, &mut entities);
        entities
    }

    /// Drops all entities and resets the tree to a single empty leaf.
    pub fn clear(&mut self) {
        self.root = QuadtreeNode::new_leaf();
        self.max_depth = 0;
        self.max_radius = 0;
    }

    /// Walks the whole tree and summarises its shape.
    pub fn stats(&self) -> QuadtreeStats {
        let mut stats = QuadtreeStats::default();
        Self::calculate_stats(&self.root, 0, &mut stats);
        stats
    }

    /// Splits `bounds` at its centre into `[northeast, northwest, southeast, southwest]`.
    /// Neighbouring quadrants share their edges on the centre lines.
    fn child_bounds(bounds: &SpatialBounds) -> [SpatialBounds; 4] {
        let (center_x, center_y) = bounds.center();
        [
            SpatialBounds::new(center_x, bounds.top, bounds.right, center_y),
            SpatialBounds::new(bounds.left, bounds.top, center_x, center_y),
            SpatialBounds::new(center_x, center_y, bounds.right, bounds.bottom),
            SpatialBounds::new(bounds.left, center_y, center_x, bounds.bottom),
        ]
    }

    /// Routes `entity` down to a leaf: north means `y <= center_y`, east means
    /// `x >= center_x`. Updates `current_max_depth` with every level visited.
    fn insert_recursive_static(
        node: &mut QuadtreeNode<C>,
        entity: SpatialEntity<C>,
        bounds: &SpatialBounds,
        depth: usize,
        max_entities: usize,
        current_max_depth: &mut usize,
    ) {
        *current_max_depth = (*current_max_depth).max(depth);
        match node {
            QuadtreeNode::Leaf { entities } => {
                entities.push(entity);
                let overflowing = entities.len() > max_entities;
                if overflowing && depth < Self::MAX_SUBDIVISION_DEPTH {
                    Self::subdivide_node_static(node, bounds, depth, max_entities, current_max_depth);
                }
            }
            QuadtreeNode::Internal { northeast, northwest, southeast, southwest } => {
                let (center_x, center_y) = bounds.center();
                let (entity_x, entity_y) = entity.position.to_spatial_coords();
                let [ne_bounds, nw_bounds, se_bounds, sw_bounds] = Self::child_bounds(bounds);
                let (child, quadrant) = match (entity_y <= center_y, entity_x >= center_x) {
                    (true, true) => (northeast, ne_bounds),
                    (true, false) => (northwest, nw_bounds),
                    (false, true) => (southeast, se_bounds),
                    (false, false) => (southwest, sw_bounds),
                };
                Self::insert_recursive_static(
                    child,
                    entity,
                    &quadrant,
                    depth + 1,
                    max_entities,
                    current_max_depth,
                );
            }
        }
    }

    /// Turns a leaf into an internal node with four empty children and re-inserts its
    /// entities below it. Does nothing if `node` is already internal.
    fn subdivide_node_static(
        node: &mut QuadtreeNode<C>,
        bounds: &SpatialBounds,
        depth: usize,
        max_entities: usize,
        current_max_depth: &mut usize,
    ) {
        if let QuadtreeNode::Leaf { entities } = node {
            let displaced = std::mem::take(entities);
            *node = QuadtreeNode::Internal {
                northeast: Box::new(QuadtreeNode::new_leaf()),
                northwest: Box::new(QuadtreeNode::new_leaf()),
                southeast: Box::new(QuadtreeNode::new_leaf()),
                southwest: Box::new(QuadtreeNode::new_leaf()),
            };
            for entity in displaced {
                // `node` is internal now, so this routes each entity into a child.
                Self::insert_recursive_static(node, entity, bounds, depth, max_entities, current_max_depth);
            }
        }
    }

    /// Moves every entity with `id == entity_id` from the subtree into `removed`,
    /// preserving the relative order of both the kept and the removed entities.
    fn remove_recursive_static(
        node: &mut QuadtreeNode<C>,
        entity_id: u32,
        removed: &mut Vec<SpatialEntity<C>>,
    ) {
        match node {
            QuadtreeNode::Leaf { entities } => {
                // Only rebuild leaves that actually contain a match; no cloning required.
                if entities.iter().any(|entity| entity.id == entity_id) {
                    let (matched, kept): (Vec<_>, Vec<_>) = std::mem::take(entities)
                        .into_iter()
                        .partition(|entity| entity.id == entity_id);
                    *entities = kept;
                    removed.extend(matched);
                }
            }
            QuadtreeNode::Internal { northeast, northwest, southeast, southwest } => {
                for child in [northeast, northwest, southeast, southwest] {
                    Self::remove_recursive_static(child, entity_id, removed);
                }
            }
        }
    }

    /// Collects entities overlapping `query_bounds`. Descends only into quadrants that
    /// `search` (the query grown by the largest radius) can reach.
    fn query_recursive(
        node: &QuadtreeNode<C>,
        query_bounds: &SpatialBounds,
        search: &SpatialBounds,
        node_bounds: &SpatialBounds,
        results: &mut Vec<SpatialEntity<C>>,
    ) {
        match node {
            QuadtreeNode::Leaf { entities } => {
                results.extend(
                    entities
                        .iter()
                        .filter(|entity| entity.intersects_bounds(query_bounds))
                        .cloned(),
                );
            }
            QuadtreeNode::Internal { northeast, northwest, southeast, southwest } => {
                let (center_x, center_y) = node_bounds.center();
                // Mirror the routing rule of `insert_recursive_static` (north: y <= center_y,
                // east: x >= center_x) instead of clipping against node bounds, so entities
                // inserted outside the root bounds remain reachable.
                let north = search.top <= center_y;
                let south = search.bottom > center_y;
                let east = search.right >= center_x;
                let west = search.left < center_x;
                let [ne_bounds, nw_bounds, se_bounds, sw_bounds] = Self::child_bounds(node_bounds);
                let children = [
                    (north && east, northeast, ne_bounds),
                    (north && west, northwest, nw_bounds),
                    (south && east, southeast, se_bounds),
                    (south && west, southwest, sw_bounds),
                ];
                for (visit, child, quadrant) in children {
                    if visit {
                        Self::query_recursive(child, query_bounds, search, &quadrant, results);
                    }
                }
            }
        }
    }

    /// Appends clones of every entity in the subtree to `out`.
    fn collect_all_entities(node: &QuadtreeNode<C>, out: &mut Vec<SpatialEntity<C>>) {
        match node {
            QuadtreeNode::Leaf { entities } => out.extend_from_slice(entities),
            QuadtreeNode::Internal { northeast, northwest, southeast, southwest } => {
                for child in [northeast, northwest, southeast, southwest] {
                    Self::collect_all_entities(child, out);
                }
            }
        }
    }

    /// Accumulates node, leaf, entity, and depth counts for the subtree into `stats`.
    fn calculate_stats(node: &QuadtreeNode<C>, depth: usize, stats: &mut QuadtreeStats) {
        stats.total_nodes += 1;
        stats.max_depth = stats.max_depth.max(depth);
        match node {
            QuadtreeNode::Leaf { entities } => {
                stats.leaf_nodes += 1;
                stats.total_entities += entities.len();
                stats.max_entities_per_node = stats.max_entities_per_node.max(entities.len());
                if entities.is_empty() {
                    stats.empty_nodes += 1;
                }
            }
            QuadtreeNode::Internal { northeast, northwest, southeast, southwest } => {
                stats.internal_nodes += 1;
                for child in [northeast, northwest, southeast, southwest] {
                    Self::calculate_stats(child, depth + 1, stats);
                }
            }
        }
    }
}
/// Snapshot of a quadtree's shape, as produced by `Quadtree::stats`.
#[derive(Debug, Default, Clone)]
pub struct QuadtreeStats {
    /// Leaf and internal nodes combined.
    pub total_nodes: usize,
    /// Nodes that store entities directly.
    pub leaf_nodes: usize,
    /// Nodes split into four children.
    pub internal_nodes: usize,
    /// Leaves holding no entities.
    pub empty_nodes: usize,
    /// Deepest node level, with the root at depth 0.
    pub max_depth: usize,
    /// Entities summed across all leaves.
    pub total_entities: usize,
    /// Largest entity count found in a single leaf.
    pub max_entities_per_node: usize,
}

impl QuadtreeStats {
    /// Mean number of entities per leaf, or `0.0` when there are no leaves.
    pub fn average_entities_per_leaf(&self) -> f32 {
        match self.leaf_nodes {
            0 => 0.0,
            leaves => self.total_entities as f32 / leaves as f32,
        }
    }

    /// Fraction of all nodes (internal nodes included) that are not empty leaves,
    /// or `0.0` when there are no nodes.
    pub fn fill_ratio(&self) -> f32 {
        if self.total_nodes == 0 {
            return 0.0;
        }
        let non_empty = self.total_nodes - self.empty_nodes;
        non_empty as f32 / self.total_nodes as f32
    }
}
impl SpatialCoordinate for (i32, i32) {
    /// A tuple already is an `(x, y)` pair, so it is copied out unchanged.
    fn to_spatial_coords(&self) -> (i32, i32) {
        *self
    }

    /// Packs the components straight back into a tuple.
    fn from_spatial_coords(x: i32, y: i32) -> Self {
        (x, y)
    }
}
impl<T> SpatialCoordinate for crate::coordinates::square::Coordinate<T> {
    /// Exposes the square-grid coordinate's `x` and `y` fields as the spatial pair.
    fn to_spatial_coords(&self) -> (i32, i32) {
        let x = self.x;
        let y = self.y;
        (x, y)
    }

    /// Builds a square-grid coordinate through its `new` constructor.
    fn from_spatial_coords(x: i32, y: i32) -> Self {
        Self::new(x, y)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coordinates::square::{Coordinate as SquareCoord, FourConnected};

    /// Shorthand for a four-connected square-grid coordinate.
    fn at(x: i32, y: i32) -> SquareCoord<FourConnected> {
        SquareCoord::<FourConnected>::new(x, y)
    }

    /// Empty tree over the `(0, 0)`-`(100, 100)` square with the given leaf capacity.
    fn tree(max_entities: usize) -> Quadtree<SquareCoord<FourConnected>> {
        Quadtree::new(SpatialBounds::new(0, 0, 100, 100), max_entities)
    }

    #[test]
    fn test_spatial_bounds_creation() {
        let square = SpatialBounds::new(0, 0, 100, 100);
        assert_eq!(square.width(), 100);
        assert_eq!(square.height(), 100);
        assert_eq!(square.area(), 10000);
        assert_eq!(square.center(), (50, 50));
    }

    #[test]
    fn test_spatial_bounds_contains() {
        let region = SpatialBounds::new(10, 10, 50, 50);
        assert!(region.contains_point(25, 25));
        assert!(!region.contains_point(5, 5));
        assert!(!region.contains_point(60, 60));
    }

    #[test]
    fn test_spatial_bounds_intersects() {
        let origin = SpatialBounds::new(0, 0, 50, 50);
        let overlapping = SpatialBounds::new(25, 25, 75, 75);
        let distant = SpatialBounds::new(100, 100, 150, 150);
        assert!(origin.intersects(&overlapping));
        assert!(!origin.intersects(&distant));
    }

    #[test]
    fn test_spatial_entity_creation() {
        let entity = SpatialEntity::new(1, at(10, 20), 5);
        assert_eq!(entity.id, 1);
        assert_eq!(entity.radius, 5);
        assert_eq!(entity.bounds().center(), (10, 20));
    }

    #[test]
    fn test_quadtree_basic_operations() {
        let mut qt = tree(4);
        qt.insert(SpatialEntity::new(1, at(25, 25), 1));
        qt.insert(SpatialEntity::new(2, at(75, 75), 1));
        assert_eq!(qt.all_entities().len(), 2);

        let hits = qt.query_region(&SpatialBounds::new(0, 0, 50, 50));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_quadtree_subdivision() {
        let mut qt = tree(2);
        for id in 0..10u32 {
            let offset = (id * 10) as i32;
            qt.insert(SpatialEntity::new(id, at(offset, offset), 1));
        }
        let stats = qt.stats();
        assert!(stats.max_depth > 0);
        assert_eq!(stats.total_entities, 10);
    }

    #[test]
    fn test_quadtree_circular_query() {
        let mut qt = tree(10);
        for (id, x, y) in [(1, 50, 50), (2, 52, 50), (3, 80, 80)] {
            qt.insert(SpatialEntity::new(id, at(x, y), 1));
        }
        assert_eq!(qt.query_circle(50, 50, 5).len(), 2);
    }

    #[test]
    fn test_quadtree_remove() {
        let mut qt = tree(10);
        qt.insert(SpatialEntity::new(1, at(25, 25), 1));
        qt.insert(SpatialEntity::new(2, at(75, 75), 1));

        let removed = qt.remove(1);
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].id, 1);

        let remaining = qt.all_entities();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, 2);
    }

    #[test]
    fn test_quadtree_stats() {
        let mut qt = tree(5);
        for id in 0..20u32 {
            let offset = (id * 5) as i32;
            qt.insert(SpatialEntity::new(id, at(offset, offset), 1));
        }
        let stats = qt.stats();
        assert_eq!(stats.total_entities, 20);
        assert!(stats.average_entities_per_leaf() > 0.0);
        assert!(stats.fill_ratio() > 0.0);
    }
}