use super::PolyRef;
/// An axis-aligned bounding box in 3D space, described by its two opposite corners.
#[derive(Debug, Clone, Copy)]
pub struct Aabb {
    /// Per-axis lower corner (x, y, z).
    pub min: [f32; 3],
    /// Per-axis upper corner (x, y, z).
    pub max: [f32; 3],
}

impl Aabb {
    /// Wraps the given corners as-is; no ordering check is made (see [`Aabb::is_valid`]).
    pub fn new(min: [f32; 3], max: [f32; 3]) -> Self {
        Aabb { min, max }
    }

    /// Returns an inverted box (`min` = `f32::MAX`, `max` = `f32::MIN` on every axis).
    ///
    /// Growing it with [`Aabb::expand`] or [`Aabb::expand_point`] yields the bounds of what
    /// was added (for finite inputs). It is not [`Aabb::is_valid`] until something is added.
    pub fn empty() -> Self {
        Aabb::new([f32::MAX; 3], [f32::MIN; 3])
    }

    /// `true` when `min <= max` holds on all three axes (NaN coordinates make it `false`).
    pub fn is_valid(&self) -> bool {
        self.min.iter().zip(&self.max).all(|(lo, hi)| lo <= hi)
    }

    /// Grows `self` in place so that it also encloses `other`.
    pub fn expand(&mut self, other: &Aabb) {
        for (lo, &other_lo) in self.min.iter_mut().zip(&other.min) {
            *lo = f32::min(*lo, other_lo);
        }
        for (hi, &other_hi) in self.max.iter_mut().zip(&other.max) {
            *hi = f32::max(*hi, other_hi);
        }
    }

    /// Grows `self` in place so that it also contains `point`.
    pub fn expand_point(&mut self, point: &[f32; 3]) {
        let corners = self.min.iter_mut().zip(self.max.iter_mut());
        for ((lo, hi), &coord) in corners.zip(point) {
            *lo = f32::min(*lo, coord);
            *hi = f32::max(*hi, coord);
        }
    }

    /// `true` when the boxes intersect on every axis; boxes that merely touch count as
    /// overlapping because both comparisons are inclusive.
    pub fn overlaps(&self, other: &Aabb) -> bool {
        (0..3).all(|axis| self.min[axis] <= other.max[axis] && other.min[axis] <= self.max[axis])
    }

    /// Midpoint of the box on each axis.
    pub fn center(&self) -> [f32; 3] {
        let [x0, y0, z0] = self.min;
        let [x1, y1, z1] = self.max;
        [(x0 + x1) * 0.5, (y0 + y1) * 0.5, (z0 + z1) * 0.5]
    }

    /// Total area of the six faces, `2 * (w*h + h*d + d*w)`; not meaningful for inverted boxes.
    pub fn surface_area(&self) -> f32 {
        let [x0, y0, z0] = self.min;
        let [x1, y1, z1] = self.max;
        let (width, height, depth) = (x1 - x0, y1 - y0, z1 - z0);
        2.0 * (width * height + height * depth + depth * width)
    }
}
/// One entry indexed by the BVH: a `PolyRef` (presumably a polygon handle — verify) paired
/// with the bounding box used to place and query it.
#[derive(Debug, Clone)]
pub struct BVHItem {
    /// Value reported by a query when `bounds` overlaps the query box.
    pub poly_ref: PolyRef,
    /// Box used both for tree construction (via its centre) and for overlap tests.
    pub bounds: Aabb,
}
/// A node of the bounding volume hierarchy.
#[derive(Debug, Clone)]
pub enum BVHNode {
    /// Terminal node holding its items directly; `bounds` is expected to enclose all of them.
    Leaf { bounds: Aabb, items: Vec<BVHItem> },
    /// Branch node with two children; `bounds` is expected to enclose both subtrees, since
    /// queries skip the whole subtree when it misses.
    Internal {
        bounds: Aabb,
        left: Box<BVHNode>,
        right: Box<BVHNode>,
    },
}

impl BVHNode {
    /// Returns the bounding box stored on this node, whichever variant it is.
    pub fn bounds(&self) -> &Aabb {
        match self {
            BVHNode::Leaf { bounds, .. } | BVHNode::Internal { bounds, .. } => bounds,
        }
    }

    /// Appends to `results` the `poly_ref` of every item below this node whose bounds
    /// overlap `query_bounds`.
    ///
    /// A node whose own bounds miss the query is skipped together with its subtree. Hits are
    /// appended depth-first (left child before right, leaf items in stored order); anything
    /// already in `results` is left untouched.
    pub(crate) fn query(&self, query_bounds: &Aabb, results: &mut Vec<PolyRef>) {
        let node_hit = self.bounds().overlaps(query_bounds);
        if node_hit {
            match self {
                BVHNode::Leaf { items, .. } => {
                    let hits = items
                        .iter()
                        .filter(|candidate| candidate.bounds.overlaps(query_bounds))
                        .map(|candidate| candidate.poly_ref);
                    results.extend(hits);
                }
                BVHNode::Internal { left, right, .. } => {
                    left.query(query_bounds, results);
                    right.query(query_bounds, results);
                }
            }
        }
    }
}
/// Bounding volume hierarchy over [`BVHItem`]s that answers box-overlap queries.
#[derive(Debug)]
pub struct BVHTree {
    /// Root of the hierarchy; `None` while the tree is empty (never built, built from no
    /// items, or cleared).
    root: Option<BVHNode>,
    /// Largest number of items a node may hold and still be stored as a leaf.
    max_leaf_size: usize,
}
impl Default for BVHTree {
fn default() -> Self {
Self::new()
}
}
impl BVHTree {
    /// Creates an empty tree; leaves produced by [`BVHTree::build`] hold at most 4 items.
    pub fn new() -> Self {
        Self {
            root: None,
            max_leaf_size: 4,
        }
    }

    /// Rebuilds the hierarchy from `items`, replacing whatever the tree held before.
    ///
    /// Passing an empty `Vec` leaves the tree empty, so every later query returns nothing.
    pub fn build(&mut self, items: Vec<BVHItem>) {
        self.root = if items.is_empty() {
            None
        } else {
            Some(self.build_node(items))
        };
    }

    /// Recursively builds a subtree over `items`.
    ///
    /// Sets of at most `max_leaf_size` items become a [`BVHNode::Leaf`]. Larger sets are split
    /// in two with the surface area heuristic (SAH). On each axis the candidate planes are the
    /// centre of the node's bounds, then the centre of every item. An item goes left when its
    /// centre is strictly below the plane. The cheapest candidate that leaves both sides
    /// non-empty wins, and ties keep the earliest candidate. If no candidate separates the items,
    /// they are split in half by position.
    fn build_node(&self, mut items: Vec<BVHItem>) -> BVHNode {
        let mut bounds = Aabb::empty();
        for item in &items {
            bounds.expand(&item.bounds);
        }
        // `.max(1)` guarantees termination: the half-split fallback below needs at least two
        // items to produce two non-empty halves.
        if items.len() <= self.max_leaf_size.max(1) {
            return BVHNode::Leaf { bounds, items };
        }

        let center = bounds.center();
        let mut best_cost = f32::INFINITY;
        let mut best_split: Option<(usize, f32)> = None;
        for (axis, center_val) in center.iter().copied().enumerate() {
            let candidates = std::iter::once(center_val)
                .chain(items.iter().map(|item| item.bounds.center()[axis]));
            for split in candidates {
                let (left_items, right_items) = partition_items(&items, axis, split);
                if left_items.is_empty() || right_items.is_empty() {
                    continue;
                }
                let cost = self.calculate_sah_cost(&left_items, &right_items, &bounds);
                if cost < best_cost {
                    best_cost = cost;
                    best_split = Some((axis, split));
                }
            }
        }

        let (left_items, right_items) = match best_split {
            // The winning plane was already checked to leave both sides non-empty.
            Some((axis, split)) => partition_items(&items, axis, split),
            None => {
                // Every item centre is a candidate, so "no separating plane" means all non-NaN
                // centres coincide on every axis. Any order is as good as another; halving by
                // position always makes progress (len >= 2 here).
                let right = items.split_off(items.len() / 2);
                (items, right)
            }
        };

        let left = Box::new(self.build_node(left_items));
        let right = Box::new(self.build_node(right_items));
        BVHNode::Internal {
            bounds,
            left,
            right,
        }
    }

    /// Estimates the SAH cost of splitting a node bounded by `parent_bounds` into
    /// `left_items` and `right_items`. Lower is better.
    ///
    /// The cost is a fixed traversal cost plus, for each side, its item count weighted by the
    /// fraction of the parent's surface area that side's bounds cover.
    fn calculate_sah_cost(
        &self,
        left_items: &[BVHItem],
        right_items: &[BVHItem],
        parent_bounds: &Aabb,
    ) -> f32 {
        const TRAVERSAL_COST: f32 = 1.0;
        const INTERSECTION_COST: f32 = 1.0;

        let parent_area = parent_bounds.surface_area();
        // A flat or collinear parent has zero area, and a huge one can overflow to infinity.
        // Either way the area ratios would be NaN and no split could ever win, so weight each
        // side by its item count alone.
        let area_is_usable = parent_area > 0.0 && parent_area.is_finite();
        let weighted_count = |side: &[BVHItem]| -> f32 {
            let count = side.len() as f32;
            if !area_is_usable {
                return count;
            }
            let mut side_bounds = Aabb::empty();
            for item in side {
                side_bounds.expand(&item.bounds);
            }
            count * side_bounds.surface_area() / parent_area
        };
        TRAVERSAL_COST
            + INTERSECTION_COST * (weighted_count(left_items) + weighted_count(right_items))
    }

    /// Returns every stored `PolyRef` whose item bounds overlap `query_bounds`.
    ///
    /// Touching boxes count as overlapping. Results come in depth-first, left-to-right order,
    /// and an empty tree yields an empty `Vec`.
    pub fn query(&self, query_bounds: &Aabb) -> Vec<PolyRef> {
        let mut results = Vec::new();
        if let Some(root) = &self.root {
            root.query(query_bounds, &mut results);
        }
        results
    }

    /// Removes all items; subsequent queries return an empty `Vec`.
    // Allowed because this public API may have no caller inside the crate.
    #[allow(dead_code)]
    pub fn clear(&mut self) {
        self.root = None;
    }
}
/// Splits `items` by the plane `center[axis] == split`.
///
/// Returns `(below, rest)`: items whose bounds centre is strictly less than `split` go in
/// `below`, all others (including NaN centres) go in `rest`. Relative order is preserved
/// within each side.
fn partition_items(items: &[BVHItem], axis: usize, split: f32) -> (Vec<BVHItem>, Vec<BVHItem>) {
    items
        .iter()
        .cloned()
        .partition(|candidate| candidate.bounds.center()[axis] < split)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BVHItem` from a reference and its box corners.
    fn item(poly_ref: PolyRef, min: [f32; 3], max: [f32; 3]) -> BVHItem {
        BVHItem {
            poly_ref,
            bounds: Aabb::new(min, max),
        }
    }

    /// Asserts that `tree.query(query)` returns exactly the items a linear scan of `items`
    /// finds, each once, in any order (`items` must have distinct `poly_ref`s).
    fn assert_matches_linear_scan(tree: &BVHTree, items: &[BVHItem], query: &Aabb) {
        let results = tree.query(query);
        let expected: Vec<&BVHItem> = items.iter().filter(|it| it.bounds.overlaps(query)).collect();
        assert_eq!(results.len(), expected.len());
        for it in expected {
            assert!(results.contains(&it.poly_ref));
        }
    }

    #[test]
    fn test_aabb_overlap() {
        let aabb1 = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let aabb2 = Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]);
        let aabb3 = Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]);
        assert!(aabb1.overlaps(&aabb2));
        assert!(aabb2.overlaps(&aabb1));
        assert!(!aabb1.overlaps(&aabb3));
        assert!(!aabb3.overlaps(&aabb1));
    }

    #[test]
    fn test_bvh_query() {
        let mut tree = BVHTree::new();
        tree.build(vec![
            item(PolyRef::new(1), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            item(PolyRef::new(2), [1.0, 0.0, 0.0], [2.0, 1.0, 1.0]),
            item(PolyRef::new(3), [0.0, 1.0, 0.0], [1.0, 2.0, 1.0]),
            item(PolyRef::new(4), [5.0, 5.0, 5.0], [6.0, 6.0, 6.0]),
        ]);
        let query_bounds = Aabb::new([0.5, 0.5, 0.0], [1.5, 1.5, 1.0]);
        let results = tree.query(&query_bounds);
        assert_eq!(results.len(), 3);
        assert!(results.contains(&PolyRef::new(1)));
        assert!(results.contains(&PolyRef::new(2)));
        assert!(results.contains(&PolyRef::new(3)));
        assert!(!results.contains(&PolyRef::new(4)));
    }

    #[test]
    fn test_empty_tree_returns_nothing() {
        let everything = Aabb::new([-1.0e6; 3], [1.0e6; 3]);
        let mut tree = BVHTree::new();
        assert!(tree.query(&everything).is_empty());
        tree.build(Vec::new());
        assert!(tree.query(&everything).is_empty());
        tree.build(vec![item(PolyRef::new(1), [0.0; 3], [1.0; 3])]);
        assert_eq!(tree.query(&everything).len(), 1);
        tree.clear();
        assert!(tree.query(&everything).is_empty());
    }

    #[test]
    fn test_split_tree_matches_linear_scan() {
        // 25 boxes on a 5x5 grid: well above the leaf size, so internal nodes are built.
        let items: Vec<BVHItem> = (0..25)
            .map(|i| {
                let poly_ref = PolyRef::new(i);
                let x = (i % 5) as f32 * 2.0;
                let z = (i / 5) as f32 * 2.0;
                item(poly_ref, [x, 0.0, z], [x + 1.0, 1.0, z + 1.0])
            })
            .collect();
        let mut tree = BVHTree::new();
        tree.build(items.clone());
        let queries = [
            Aabb::new([0.5, 0.0, 0.5], [4.5, 1.0, 2.5]),
            Aabb::new([3.0, 0.0, 3.0], [3.0, 1.0, 3.0]),
            Aabb::new([100.0; 3], [101.0; 3]),
            Aabb::new([-1.0; 3], [100.0; 3]),
        ];
        for query in queries.iter() {
            assert_matches_linear_scan(&tree, &items, query);
        }
    }

    #[test]
    fn test_coincident_items_are_all_found() {
        // Identical boxes leave no separating plane, forcing the fallback split.
        let items: Vec<BVHItem> = (0..10)
            .map(|i| item(PolyRef::new(i), [0.0; 3], [1.0; 3]))
            .collect();
        let mut tree = BVHTree::new();
        tree.build(items.clone());
        assert_matches_linear_scan(&tree, &items, &Aabb::new([0.5; 3], [0.6; 3]));
        assert_matches_linear_scan(&tree, &items, &Aabb::new([2.0; 3], [3.0; 3]));
    }

    #[test]
    fn test_collinear_points_match_linear_scan() {
        // Zero-volume items on a line give every node a zero surface area.
        let items: Vec<BVHItem> = (0..12)
            .map(|i| {
                let poly_ref = PolyRef::new(i);
                let x = i as f32;
                item(poly_ref, [x, 0.0, 0.0], [x, 0.0, 0.0])
            })
            .collect();
        let mut tree = BVHTree::new();
        tree.build(items.clone());
        assert_matches_linear_scan(&tree, &items, &Aabb::new([2.5, -1.0, -1.0], [6.5, 1.0, 1.0]));
        assert_matches_linear_scan(&tree, &items, &Aabb::new([-5.0, -1.0, -1.0], [20.0, 1.0, 1.0]));
    }
}