#[cfg(feature = "ggez")]
use ggez::{
graphics::{self, DrawMode},
Context,
};
use failure::{Error, Fail};
use snowflake::ProcessUniqueId as Uid;
use std::collections::{HashMap, HashSet};
use crate::rect::*;
use crate::Point2;
/// A single node of a quadtree that stores axis-aligned rectangles keyed by id.
///
/// A node starts out as a leaf (`children` is `None`). Once it already holds
/// `capacity` objects, the next insertion subdivides it into four quadrants
/// and tries to push the new rectangle down into one of them.
#[derive(Clone, Debug, PartialEq)]
pub struct QTreeNode {
/// Region covered by this node; `insert` rejects rectangles that do not fit
/// inside it.
pub boundary: Rect,
/// Rectangles stored directly in this node, keyed by their id. This can grow
/// beyond `capacity` when a rectangle fits in none of the child quadrants.
objects: HashMap<Uid, Rect>,
/// The four child quadrants in NE, NW, SW, SE order, or `None` while the
/// node has not been subdivided.
children: Option<Box<[Self; 4]>>,
/// Number of objects this node accepts before it subdivides; the same value
/// is handed to every child created by the subdivision.
pub capacity: usize,
}
/// Errors reported by quadtree operations.
#[derive(Clone, Debug, Fail)]
pub enum QTreeError {
/// Returned by `QTreeNode::insert` when the rectangle is not contained in
/// the boundary of the node it is inserted into.
#[fail(display = "The supplied rectangle doesn't fit the boundary")]
RectDoesNotFit,
}
impl QTreeNode {
    /// Creates an empty, unsubdivided node covering `boundary`.
    ///
    /// `capacity` is the number of objects the node stores itself before an
    /// insertion makes it subdivide.
    pub fn new(boundary: Rect, capacity: usize) -> Self {
        Self {
            boundary,
            objects: HashMap::new(),
            children: None,
            capacity,
        }
    }

    /// Splits this node into four equally sized quadrants (NE, NW, SW, SE).
    ///
    /// Does nothing when the node has already been subdivided. Every child
    /// inherits this node's `capacity`.
    ///
    /// # Panics
    ///
    /// Panics if `Rect::corner` yields no value for one of the four corners.
    fn subdiv(&mut self) {
        if self.children.is_some() {
            return;
        }
        let capacity = self.capacity;
        let parent = &self.boundary;
        let half_w = parent.w_half / 2.0;
        let half_h = parent.h_half / 2.0;
        // A quadrant is centered halfway between the parent's center and the
        // respective corner and has half the parent's extent on each axis.
        let quadrant = |which| {
            let corner = parent.corner(which).unwrap();
            let center = Point2::new(
                (parent.center.x + corner.x) / 2.0,
                (parent.center.y + corner.y) / 2.0,
            );
            QTreeNode::new(
                Rect {
                    center,
                    w_half: half_w,
                    h_half: half_h,
                },
                capacity,
            )
        };
        let quads = Box::new([quadrant(NE), quadrant(NW), quadrant(SW), quadrant(SE)]);
        self.children = Some(quads);
    }

    /// Inserts `rect` under the key `id`.
    ///
    /// The rectangle is stored in this node while it holds fewer than
    /// `capacity` objects. Otherwise the node is subdivided (if it is not
    /// already) and the rectangle is handed to the first child quadrant that
    /// fully contains it. A rectangle that fits in no quadrant (for example
    /// one straddling a split line) stays in this node, even beyond
    /// `capacity`.
    ///
    /// # Errors
    ///
    /// Returns `QTreeError::RectDoesNotFit` when `rect` is not contained in
    /// this node's boundary.
    pub fn insert(&mut self, rect: &Rect, id: Uid) -> Result<(), Error> {
        if !self.boundary.contains_rect(rect) {
            return Err(QTreeError::RectDoesNotFit.into());
        }
        if self.objects.len() < self.capacity {
            self.objects.insert(id, rect.clone());
            return Ok(());
        }
        // No-op when the node already has children.
        self.subdiv();
        if let Some(children) = self.children.as_mut() {
            // Pick the quadrant up front instead of probing each child with a
            // failing `insert`: a child only rejects a rectangle it does not
            // contain, so this selects the same child without building an
            // error value per non-fitting quadrant.
            let target = children
                .iter_mut()
                .find(|child| child.boundary.contains_rect(rect));
            if let Some(child) = target {
                return child.insert(rect, id);
            }
        }
        self.objects.insert(id, rect.clone());
        Ok(())
    }

    /// Returns the ids of stored rectangles that contain `point`.
    ///
    /// With `limit == Some(n)` the search stops as soon as `n` ids have been
    /// collected across the whole subtree, so the result holds at most `n`
    /// ids (`Some(0)` yields an empty set). Which matches are returned when
    /// the limit cuts the search short is unspecified. `None` returns every
    /// match.
    pub fn query_point(&self, point: &Point2, limit: Option<usize>) -> HashSet<Uid> {
        let mut found = HashSet::new();
        self.query_point_into(point, limit, &mut found);
        found
    }

    /// Recursive worker for `query_point`: adds matches from this subtree to
    /// `found`, stopping once `found` has reached `limit`.
    fn query_point_into(&self, point: &Point2, limit: Option<usize>, found: &mut HashSet<Uid>) {
        if Self::limit_reached(found, limit) || !self.boundary.contains_point(point) {
            return;
        }
        for (id, obj) in &self.objects {
            if obj.contains_point(point) {
                found.insert(*id);
                if Self::limit_reached(found, limit) {
                    return;
                }
            }
        }
        if let Some(children) = self.children.as_ref() {
            for child in children.iter() {
                child.query_point_into(point, limit, found);
                if Self::limit_reached(found, limit) {
                    return;
                }
            }
        }
    }

    /// Tells whether `found` already holds as many ids as `limit` allows.
    fn limit_reached(found: &HashSet<Uid>, limit: Option<usize>) -> bool {
        limit.map_or(false, |max| found.len() >= max)
    }
}
#[cfg(feature = "ggez")]
impl QTreeNode {
    /// Draws the boundary of this node and then, recursively, the boundaries
    /// of all of its descendants using the given draw `mode`.
    ///
    /// Stops at the first drawing error and returns it.
    pub fn draw_regions(&self, ctx: &mut Context, mode: DrawMode) -> Result<(), Error> {
        graphics::rectangle(ctx, mode, self.boundary.to_ggez())?;
        match self.children.as_ref() {
            Some(quads) => quads
                .iter()
                .try_for_each(|quad| quad.draw_regions(ctx, mode)),
            None => Ok(()),
        }
    }

    /// Draws every rectangle stored in this node and then, recursively, the
    /// rectangles stored in all of its descendants using the given draw `mode`.
    ///
    /// Stops at the first drawing error and returns it.
    pub fn draw_objects(&self, ctx: &mut Context, mode: DrawMode) -> Result<(), Error> {
        for stored in self.objects.values() {
            graphics::rectangle(ctx, mode, stored.to_ggez())?;
        }
        match self.children.as_ref() {
            Some(quads) => quads
                .iter()
                .try_for_each(|quad| quad.draw_objects(ctx, mode)),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `subdiv` must create four children whose boundaries are the parent's
    /// quadrants, listed in NE, NW, SW, SE order.
    #[test]
    fn subdiv_produces_children() {
        let rect = Rect {
            center: Point2::new(rand::random(), rand::random()),
            w_half: rand::random(),
            h_half: rand::random(),
        };
        let expected_rects: Vec<Rect> = [NE, NW, SW, SE]
            .iter()
            .map(|&which| {
                let corner = rect.corner(which).unwrap();
                Rect {
                    center: Point2::new(
                        (rect.center.x + corner.x) / 2.0,
                        (rect.center.y + corner.y) / 2.0,
                    ),
                    w_half: rect.w_half / 2.0,
                    h_half: rect.h_half / 2.0,
                }
            })
            .collect();

        let mut qt = QTreeNode::new(rect, 4);
        qt.subdiv();

        let children = qt.children.expect("subdiv must create children");
        // Only the boundaries are compared, so clone those rather than whole nodes.
        let found_rects: Vec<_> = children.iter().map(|node| node.boundary.clone()).collect();
        assert_eq!(found_rects, expected_rects);
    }

    /// Items that fit no quadrant stay in the root even past its capacity,
    /// while an item that fits a quadrant is pushed down into that child.
    #[test]
    fn insert_capacity_works() {
        let boundary = Rect::new(0.0, 0.0, 200.0, 200.0);
        let capacity = 4;
        let mut qt = QTreeNode::new(boundary, capacity);
        let mut item = Rect::new(50.0, 50.0, 50.0, 50.0);
        for _ in 0..=capacity {
            let id = Uid::new();
            qt.insert(&item, id).unwrap();
            assert_eq!(qt.objects[&id], item);
            item.center.x += 5.0;
        }

        let fitting_item = Rect::new(10.0, 10.0, 10.0, 10.0);
        let id = Uid::new();
        qt.insert(&fitting_item, id).unwrap();

        assert!(qt.children.is_some());
        let children = qt.children.as_ref().unwrap();
        assert_eq!(children[NW].objects[&id], fitting_item);
    }

    /// A rectangle that is not contained in the root boundary is rejected.
    #[test]
    fn insert_discards_not_fitting() {
        let boundary = Rect::new(10.0, 10.0, 10.0, 10.0);
        let item = Rect::new(0.0, 0.0, 20.0, 20.0);
        let mut qt = QTreeNode::new(boundary, 4);
        assert!(qt.insert(&item, Uid::new()).is_err());
    }

    /// An unlimited point query reports every stored rectangle containing
    /// the point, including those stored beyond the node's capacity.
    #[test]
    fn query_point_finds_all_rects() {
        let boundary = Rect::new(0.0, 0.0, 10.0, 10.0);
        let capacity = 4;
        let mut qt = QTreeNode::new(boundary.clone(), capacity);
        for _ in 0..=capacity {
            qt.insert(&boundary, Uid::new()).unwrap();
        }
        let found_rects = qt.query_point(&Point2::new(5.0, 5.0), None);
        assert_eq!(found_rects.len(), capacity + 1);
    }
}