use crate::error::{FFTError, FFTResult};
/// A square cell of a 2-D quadtree.
///
/// The cell covers the closed axis-aligned square
/// `[center[k] - half_width, center[k] + half_width]` for `k = 0, 1`.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Centre of the cell, `[x, y]`.
    pub center: [f64; 2],
    /// Half of the cell's side length.
    pub half_width: f64,
    /// Level of the node in its tree.
    pub depth: usize,
    /// Indices (into the point slice the tree was built from) of the points
    /// stored in this cell.
    pub point_indices: Vec<usize>,
    /// The four sub-cells, or `None` while the node has not been subdivided.
    pub children: Option<Box<[TreeNode; 4]>>,
    /// Whether this node is a leaf.
    pub is_leaf: bool,
    /// Identifier assigned during tree construction (0 for a fresh node).
    pub node_id: usize,
}

impl TreeNode {
    /// Creates an empty leaf cell: no points, no children, `node_id == 0`.
    pub fn new(center: [f64; 2], half_width: f64, depth: usize) -> Self {
        Self {
            center,
            half_width,
            depth,
            point_indices: Vec::new(),
            children: None,
            is_leaf: true,
            node_id: 0,
        }
    }

    /// Returns `true` if `p` lies inside the cell; the boundary counts as inside.
    /// Any NaN coordinate yields `false`.
    pub fn contains(&self, p: [f64; 2]) -> bool {
        let [x_lo, y_lo, x_hi, y_hi] = self.bbox();
        (x_lo..=x_hi).contains(&p[0]) && (y_lo..=y_hi).contains(&p[1])
    }

    /// Centres of the four sub-cells, ordered SW, SE, NW, NE
    /// (index = east + 2 * north).
    pub fn child_centers(&self) -> [[f64; 2]; 4] {
        let offset = 0.5 * self.half_width;
        let [x, y] = self.center;
        let (west, east) = (x - offset, x + offset);
        let (south, north) = (y - offset, y + offset);
        [[west, south], [east, south], [west, north], [east, north]]
    }

    /// Index (in `child_centers` order) of the quadrant containing `p`, or
    /// `None` if `p` is outside this cell. Points lying exactly on a dividing
    /// line are assigned to the west / south side.
    pub fn child_index_for(&self, p: [f64; 2]) -> Option<usize> {
        self.contains(p).then(|| {
            let east = usize::from(p[0] > self.center[0]);
            let north = usize::from(p[1] > self.center[1]);
            east + 2 * north
        })
    }

    /// Returns `true` when the two cells overlap or touch (edge or corner),
    /// with an absolute tolerance of `1e-10`.
    ///
    /// Note that a cell is adjacent to itself and to any cell overlapping it.
    pub fn is_adjacent(&self, other: &TreeNode) -> bool {
        const TOLERANCE: f64 = 1e-10;
        let reach = self.half_width + other.half_width + TOLERANCE;
        let gap_x = (self.center[0] - other.center[0]).abs();
        let gap_y = (self.center[1] - other.center[1]).abs();
        gap_x < reach && gap_y < reach
    }

    /// Multipole acceptance test: `true` when the centre-to-centre distance
    /// exceeds `max(half_width) / mac`. Smaller `mac` is stricter.
    ///
    /// `mac` should be positive; non-positive values make the test degenerate.
    pub fn is_well_separated(&self, other: &TreeNode, mac: f64) -> bool {
        let [dx, dy] = [
            self.center[0] - other.center[0],
            self.center[1] - other.center[1],
        ];
        let distance = (dx * dx + dy * dy).sqrt();
        distance > self.half_width.max(other.half_width) / mac
    }

    /// Bounding box of the cell as `[x_min, y_min, x_max, y_max]`.
    pub fn bbox(&self) -> [f64; 4] {
        let [x, y] = self.center;
        let hw = self.half_width;
        [x - hw, y - hw, x + hw, y + hw]
    }
}
/// Adaptive 2-D quadtree over a set of points.
#[derive(Debug, Clone)]
pub struct QuadTree {
    /// Root cell; a square enclosing every input point.
    pub root: TreeNode,
    /// Depth at which subdivision stops unconditionally.
    pub max_depth: usize,
    /// Nodes holding at most this many points are not subdivided.
    pub max_points_per_leaf: usize,
    /// Total number of nodes in the tree, root included.
    pub node_count: usize,
}

impl QuadTree {
    /// Builds a quadtree over `points`.
    ///
    /// The root is a square centred on the points' bounding box, with a
    /// half-width equal to half the larger bounding-box extent plus `1e-6`.
    /// A node is subdivided while it holds more than `max_per_leaf` points and
    /// its depth is below `max_depth`.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` if `points` is empty or if any coordinate
    /// is not finite (NaN or ±∞).
    pub fn build(
        points: &[[f64; 2]],
        max_depth: usize,
        max_per_leaf: usize,
    ) -> FFTResult<Self> {
        if points.is_empty() {
            return Err(FFTError::ValueError("QuadTree: empty point set".into()));
        }
        // NaN coordinates slip past the min/max scan (every comparison is
        // false) and infinite ones turn the root centre/width into inf/NaN, so
        // the resulting tree would be meaningless; reject both up front.
        if let Some(i) = points
            .iter()
            .position(|p| !(p[0].is_finite() && p[1].is_finite()))
        {
            return Err(FFTError::ValueError(format!(
                "QuadTree: non-finite coordinate in point {i}"
            )));
        }

        let first = points[0];
        let (min_x, max_x, min_y, max_y) = points.iter().fold(
            (first[0], first[0], first[1], first[1]),
            |(lo_x, hi_x, lo_y, hi_y), p| {
                (lo_x.min(p[0]), hi_x.max(p[0]), lo_y.min(p[1]), hi_y.max(p[1]))
            },
        );

        // Absolute padding keeps the root non-degenerate when all points share
        // their x and y extent of zero (e.g. a single point).
        let margin = 1e-6;
        let center = [(min_x + max_x) * 0.5, (min_y + max_y) * 0.5];
        let half_width = (max_x - min_x).max(max_y - min_y) * 0.5 + margin;

        let mut tree = QuadTree {
            root: TreeNode::new(center, half_width, 0),
            max_depth,
            max_points_per_leaf: max_per_leaf,
            node_count: 1, // the root
        };
        let all_indices: Vec<usize> = (0..points.len()).collect();
        insert_recursive(
            &mut tree.root,
            points,
            &all_indices,
            0,
            max_depth,
            max_per_leaf,
            &mut tree.node_count,
        );
        Ok(tree)
    }

    /// All leaf nodes in depth-first order, including empty leaves created
    /// when a node was subdivided.
    pub fn leaves(&self) -> Vec<&TreeNode> {
        let mut result = Vec::new();
        collect_leaves(&self.root, &mut result);
        result
    }

    /// All nodes whose depth equals `level`, in depth-first order.
    pub fn level_nodes(&self, level: usize) -> Vec<&TreeNode> {
        let mut result = Vec::new();
        collect_level(&self.root, level, &mut result);
        result
    }

    /// Every node of the tree in breadth-first order, root first.
    pub fn all_nodes_bfs(&self) -> Vec<&TreeNode> {
        let mut result = Vec::new();
        let mut queue = std::collections::VecDeque::new();
        queue.push_back(&self.root);
        while let Some(node) = queue.pop_front() {
            result.push(node);
            if let Some(children) = &node.children {
                queue.extend(children.iter());
            }
        }
        result
    }

    /// Depth of the deepest leaf actually present (may be below `max_depth`).
    pub fn actual_depth(&self) -> usize {
        max_depth_recursive(&self.root)
    }
}
/// Stores `indices` in `node` and, if the node holds more than
/// `max_per_leaf` points and `depth < max_depth`, subdivides it into four
/// children and distributes the points among them recursively.
///
/// `node_count` is the running number of nodes in the tree. Each newly
/// created child takes the current value as its `node_id` before the count is
/// incremented, so with the root at id 0 and the count starting at 1 the ids
/// are dense in `0..node_count`.
fn insert_recursive(
    node: &mut TreeNode,
    points: &[[f64; 2]],
    indices: &[usize],
    depth: usize,
    max_depth: usize,
    max_per_leaf: usize,
    node_count: &mut usize,
) {
    // Every node, not only leaves, records the points that fall inside it.
    node.point_indices.extend_from_slice(indices);
    if depth >= max_depth || indices.len() <= max_per_leaf {
        node.is_leaf = true;
        return;
    }
    node.is_leaf = false;

    let child_hw = node.half_width * 0.5;
    let child_depth = depth + 1;
    let mut children = Box::new(
        node.child_centers()
            .map(|c| TreeNode::new(c, child_hw, child_depth)),
    );
    for child in children.iter_mut() {
        // Assign first, then count: the next free id equals the current count.
        child.node_id = *node_count;
        *node_count += 1;
    }

    // Quadrant index follows `child_centers` order (east + 2 * north); points
    // on a dividing line go west/south. Only the parent centre is compared, so
    // every index lands in exactly one child.
    let mut child_indices: [Vec<usize>; 4] = Default::default();
    for &idx in indices {
        let [x, y] = points[idx];
        let east = usize::from(x > node.center[0]);
        let north = usize::from(y > node.center[1]);
        child_indices[east + 2 * north].push(idx);
    }

    for (child, idxs) in children.iter_mut().zip(&child_indices) {
        if idxs.is_empty() {
            child.is_leaf = true;
        } else {
            insert_recursive(
                child,
                points,
                idxs,
                child_depth,
                max_depth,
                max_per_leaf,
                node_count,
            );
        }
    }
    node.children = Some(children);
}
/// Appends every leaf in the subtree rooted at `node` (including `node`
/// itself if it is a leaf) to `result`, in depth-first order with children
/// visited in array order.
fn collect_leaves<'a>(node: &'a TreeNode, result: &mut Vec<&'a TreeNode>) {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if current.is_leaf {
            result.push(current);
        } else if let Some(children) = &current.children {
            // Push in reverse so the first child is popped (visited) first.
            pending.extend(children.iter().rev());
        }
    }
}
/// Appends to `result` every node in the subtree rooted at `node` whose
/// `depth` equals `level`, in depth-first order. Descent along a path stops
/// at the first match.
fn collect_level<'a>(node: &'a TreeNode, level: usize, result: &mut Vec<&'a TreeNode>) {
    if node.depth != level {
        node.children
            .iter()
            .flat_map(|children| children.iter())
            .for_each(|child| collect_level(child, level, result));
    } else {
        result.push(node);
    }
}
/// Depth of the deepest leaf in the subtree rooted at `node`.
///
/// A leaf, or a non-leaf without children, reports its own depth.
fn max_depth_recursive(node: &TreeNode) -> usize {
    match (node.is_leaf, &node.children) {
        (false, Some(kids)) => kids
            .iter()
            .map(max_depth_recursive)
            .max()
            .unwrap_or(node.depth),
        _ => node.depth,
    }
}
/// 3-D octree descriptor.
///
/// This struct stores only the root cube's geometry and a depth limit; it
/// holds no points or child cells.
#[derive(Debug, Clone)]
pub struct OctTree {
    /// Centre of the root cube, `[x, y, z]`.
    pub center: [f64; 3],
    /// Half of the root cube's side length.
    pub half_width: f64,
    /// Maximum subdivision depth.
    pub max_depth: usize,
}

impl OctTree {
    /// Creates an octree descriptor from its root geometry and depth limit.
    pub fn new(center: [f64; 3], half_width: f64, max_depth: usize) -> Self {
        Self {
            max_depth,
            half_width,
            center,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// `n` points on an arc of the unit circle (angles `i / n`, `i < n`).
    fn arc_points(n: usize) -> Vec<[f64; 2]> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                [t.cos(), t.sin()]
            })
            .collect()
    }

    #[test]
    fn test_tree_node_contains() {
        let node = TreeNode::new([0.0, 0.0], 1.0, 0);
        assert!(node.contains([0.0, 0.0]));
        assert!(node.contains([0.9, 0.9]));
        assert!(!node.contains([1.1, 0.0]));
    }

    #[test]
    fn test_child_centers_quadrants() {
        let node = TreeNode::new([0.0, 0.0], 1.0, 0);
        let centers = node.child_centers();
        assert!((centers[0][0] + 0.5).abs() < 1e-12);
        assert!((centers[0][1] + 0.5).abs() < 1e-12);
        assert!((centers[3][0] - 0.5).abs() < 1e-12);
        assert!((centers[3][1] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_child_index_for() {
        let node = TreeNode::new([0.0, 0.0], 1.0, 0);
        assert_eq!(node.child_index_for([-0.5, -0.5]), Some(0));
        assert_eq!(node.child_index_for([0.5, -0.5]), Some(1));
        assert_eq!(node.child_index_for([-0.5, 0.5]), Some(2));
        assert_eq!(node.child_index_for([0.5, 0.5]), Some(3));
        // Points on the dividing lines go west/south.
        assert_eq!(node.child_index_for([0.0, 0.0]), Some(0));
        assert_eq!(node.child_index_for([2.0, 0.0]), None);
    }

    #[test]
    fn test_bbox() {
        let node = TreeNode::new([1.0, -2.0], 0.5, 0);
        assert_eq!(node.bbox(), [0.5, -2.5, 1.5, -1.5]);
    }

    #[test]
    fn test_quad_tree_build() {
        let points: Vec<[f64; 2]> = (0..100)
            .map(|i| {
                let t = i as f64 / 100.0;
                [t.cos(), t.sin()]
            })
            .collect();
        let tree = QuadTree::build(&points, 5, 4).expect("build failed");
        let leaves = tree.leaves();
        assert!(!leaves.is_empty());
        let mut seen = vec![false; 100];
        for leaf in &leaves {
            for &idx in &leaf.point_indices {
                assert!(!seen[idx], "duplicate index {idx}");
                seen[idx] = true;
            }
        }
        assert!(seen.iter().all(|&s| s));
    }

    #[test]
    fn test_tree_structure_consistency() {
        let points = arc_points(100);
        let tree = QuadTree::build(&points, 5, 4).expect("build failed");
        let nodes = tree.all_nodes_bfs();
        assert_eq!(nodes.len(), tree.node_count);
        let ids: HashSet<usize> = nodes.iter().map(|n| n.node_id).collect();
        assert_eq!(ids.len(), tree.node_count, "node ids must be unique");
        assert_eq!(tree.root.point_indices.len(), points.len());
        assert!(tree.actual_depth() <= 5);
        assert_eq!(tree.level_nodes(0).len(), 1);
        for node in &nodes {
            if let Some(children) = &node.children {
                assert!(!node.is_leaf);
                let child_total: usize =
                    children.iter().map(|c| c.point_indices.len()).sum();
                assert_eq!(child_total, node.point_indices.len());
                for child in children.iter() {
                    assert_eq!(child.depth, node.depth + 1);
                }
            } else {
                assert!(node.is_leaf);
            }
        }
        for level in 0..=tree.actual_depth() {
            assert!(tree.level_nodes(level).iter().all(|n| n.depth == level));
        }
    }

    #[test]
    fn test_single_point_is_root_leaf() {
        let tree = QuadTree::build(&[[1.0, 2.0]], 5, 4).expect("build failed");
        assert!(tree.root.is_leaf);
        assert_eq!(tree.node_count, 1);
        assert_eq!(tree.actual_depth(), 0);
        assert_eq!(tree.leaves().len(), 1);
        assert!(tree.root.contains([1.0, 2.0]));
    }

    #[test]
    fn test_coincident_points_stop_at_max_depth() {
        let points = vec![[0.5, 0.5]; 10];
        let tree = QuadTree::build(&points, 3, 2).expect("build failed");
        assert_eq!(tree.actual_depth(), 3);
        assert_eq!(tree.node_count, 1 + 4 * 3);
        let occupied: Vec<&TreeNode> = tree
            .leaves()
            .into_iter()
            .filter(|leaf| !leaf.point_indices.is_empty())
            .collect();
        assert_eq!(occupied.len(), 1);
        assert_eq!(occupied[0].point_indices.len(), 10);
        assert_eq!(occupied[0].depth, 3);
    }

    #[test]
    fn test_empty_point_set_is_error() {
        let empty: Vec<[f64; 2]> = Vec::new();
        assert!(QuadTree::build(&empty, 5, 4).is_err());
    }

    #[test]
    fn test_adjacency() {
        let n1 = TreeNode::new([0.0, 0.0], 0.5, 0);
        let n2 = TreeNode::new([1.0, 0.0], 0.5, 0);
        let n3 = TreeNode::new([2.0, 0.0], 0.5, 0);
        assert!(n1.is_adjacent(&n2));
        assert!(!n1.is_adjacent(&n3));
    }

    #[test]
    fn test_well_separated() {
        let n1 = TreeNode::new([0.0, 0.0], 0.5, 0);
        let n2 = TreeNode::new([10.0, 0.0], 0.5, 0);
        assert!(n1.is_well_separated(&n2, 0.5));
        let n3 = TreeNode::new([1.0, 0.0], 0.5, 0);
        assert!(!n1.is_well_separated(&n3, 0.5));
    }
}