use crate::Aabb;
/// Bounding volume hierarchy over a set of axis-aligned bounding boxes.
///
/// Items are identified by their index in the slice given to [`Bvh::build`]; every query
/// returns such indices.
pub struct Bvh {
    /// `None` when the hierarchy was built from an empty slice.
    root: Option<Node>,
}

/// A node of the hierarchy.
///
/// Every node — leaves included — carries the box enclosing everything beneath it, so a
/// query can reject a whole subtree, or a single item, with one test.
enum Node {
    /// One item together with its own bounds.
    Leaf { bounds: Aabb, item_index: usize },
    /// An interior node whose `bounds` enclose both children.
    Branch {
        bounds: Aabb,
        left: Box<Self>,
        right: Box<Self>,
    },
}

impl Node {
    /// The box stored on this node (the item's box for a leaf, the union for a branch).
    fn bounds(&self) -> &Aabb {
        match self {
            Self::Leaf { bounds, .. } | Self::Branch { bounds, .. } => bounds,
        }
    }
}

impl Bvh {
    /// Builds a hierarchy with one leaf per element of `bounds`.
    ///
    /// An empty slice yields an empty hierarchy for which every query returns nothing.
    pub fn build(bounds: &[Aabb]) -> Self {
        if bounds.is_empty() {
            return Self { root: None };
        }
        let mut indices: Vec<usize> = (0..bounds.len()).collect();
        Self {
            root: Some(build_node(bounds, &mut indices)),
        }
    }

    /// Returns the indices of the items whose bounds overlap `query` (per [`Aabb::overlaps`]).
    ///
    /// Indices come out in tree traversal order, not sorted.
    pub fn overlapping(&self, query: &Aabb) -> Vec<usize> {
        let mut results = Vec::new();
        if let Some(root) = &self.root {
            collect_overlapping(root, query, &mut results);
        }
        results
    }

    /// Returns the indices of the items whose bounds are touched by the ray
    /// `origin + t * direction` for some `t` in `[0, max_distance]`.
    ///
    /// `inverse_direction` is the component-wise reciprocal of `direction`; `max_distance` is
    /// therefore measured in multiples of `direction` (a plain distance when it is unit
    /// length). Indices come out in tree traversal order, not sorted.
    pub fn ray_overlapping(
        &self,
        origin: [f32; 3],
        inverse_direction: [f32; 3],
        max_distance: f32,
    ) -> Vec<usize> {
        let mut results = Vec::new();
        if let Some(root) = &self.root {
            collect_ray_overlapping(root, origin, inverse_direction, max_distance, &mut results);
        }
        results
    }
}

/// Recursively builds the subtree containing the items listed in `indices`.
///
/// `indices` is reordered in place. Each level splits at the median item centre along the
/// widest axis of the combined bounds, so the tree is balanced. `indices` must be non-empty.
fn build_node(bounds: &[Aabb], indices: &mut [usize]) -> Node {
    // An empty slice would split into two empty halves forever.
    debug_assert!(!indices.is_empty(), "build_node requires at least one item");
    let combined = indices
        .iter()
        .fold(Aabb::EMPTY, |accumulated, &index| accumulated.merged(&bounds[index]));
    if let [single] = indices {
        // For one item `combined` is `EMPTY` merged with that item's box, which gives the leaf
        // its own bounds without requiring `Aabb: Clone`.
        return Node::Leaf {
            bounds: combined,
            item_index: *single,
        };
    }
    let split_axis = largest_axis(&combined);
    let centre = |index: usize| {
        f32::midpoint(bounds[index].min[split_axis], bounds[index].max[split_axis])
    };
    let middle = indices.len() / 2;
    // Only the median partition is needed, not a full ordering: afterwards every item left of
    // `middle` has a centre <= every item from `middle` onwards. `total_cmp` gives a total
    // order even in the presence of NaN.
    indices.select_nth_unstable_by(middle, |&a, &b| centre(a).total_cmp(&centre(b)));
    // `middle >= 1` because there are at least two items here, so both halves are non-empty.
    let (left_indices, right_indices) = indices.split_at_mut(middle);
    Node::Branch {
        bounds: combined,
        left: Box::new(build_node(bounds, left_indices)),
        right: Box::new(build_node(bounds, right_indices)),
    }
}

/// Index (0, 1 or 2) of the axis along which `bounds` is widest; ties favour the lower index.
fn largest_axis(bounds: &Aabb) -> usize {
    let extents = [
        bounds.max[0] - bounds.min[0],
        bounds.max[1] - bounds.min[1],
        bounds.max[2] - bounds.min[2],
    ];
    if extents[0] >= extents[1] && extents[0] >= extents[2] {
        0
    } else if extents[1] >= extents[2] {
        1
    } else {
        2
    }
}

/// Appends to `results` every item under `node` whose own bounds overlap `query`.
fn collect_overlapping(node: &Node, query: &Aabb, results: &mut Vec<usize>) {
    // Tested for leaves too, so an item is reported only when its own box overlaps.
    if !node.bounds().overlaps(query) {
        return;
    }
    match node {
        Node::Leaf { item_index, .. } => results.push(*item_index),
        Node::Branch { left, right, .. } => {
            collect_overlapping(left, query, results);
            collect_overlapping(right, query, results);
        }
    }
}

/// Appends to `results` every item under `node` whose own bounds the ray touches
/// (see [`Bvh::ray_overlapping`] for the meaning of the ray parameters).
fn collect_ray_overlapping(
    node: &Node,
    origin: [f32; 3],
    inverse_direction: [f32; 3],
    max_distance: f32,
    results: &mut Vec<usize>,
) {
    // Tested for leaves too, so an item is reported only when the ray reaches its own box.
    if !ray_hits(node.bounds(), origin, inverse_direction, max_distance) {
        return;
    }
    match node {
        Node::Leaf { item_index, .. } => results.push(*item_index),
        Node::Branch { left, right, .. } => {
            collect_ray_overlapping(left, origin, inverse_direction, max_distance, results);
            collect_ray_overlapping(right, origin, inverse_direction, max_distance, results);
        }
    }
}
/// Slab test: does the ray `origin + t * direction`, for some `t` in `[0, max_distance]`,
/// touch `bounds`?
///
/// `inverse_direction` is the component-wise reciprocal of `direction`, so `t` (and
/// `max_distance`) are measured in multiples of `direction`.
///
/// A zero direction component has an infinite reciprocal. On such an axis the ray never
/// moves, so it lies inside the slab exactly when the origin is within `[min, max]`. That
/// case is checked directly because the general formula computes `0 * inf = NaN` when the
/// origin sits on a face plane. `f32::min`/`max` then keep the infinite partner, and a ray
/// lying in the face plane would wrongly miss.
fn ray_hits(
    bounds: &Aabb,
    origin: [f32; 3],
    inverse_direction: [f32; 3],
    max_distance: f32,
) -> bool {
    let mut entry = 0.0_f32;
    let mut exit = max_distance;
    for axis in 0..3 {
        let (min, max) = (bounds.min[axis], bounds.max[axis]);
        let inverse = inverse_direction[axis];
        if inverse.is_infinite() {
            // Parallel to this slab: either inside it for every `t` or never.
            if origin[axis] < min || origin[axis] > max {
                return false;
            }
            continue;
        }
        let near = (min - origin[axis]) * inverse;
        let far = (max - origin[axis]) * inverse;
        entry = entry.max(near.min(far));
        exit = exit.min(near.max(far));
    }
    entry <= exit
}