use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use super::node::{Entry, Node, RTree, Rectangle};
#[derive(Debug, Clone)]
pub struct BoundingBox {
pub min: Vec<f64>,
pub max: Vec<f64>,
}
impl BoundingBox {
pub fn new(min: Vec<f64>, max: Vec<f64>) -> SpatialResult<Self> {
if min.len() != max.len() {
return Err(SpatialError::DimensionError(format!(
"BoundingBox min ({}) and max ({}) must have the same length",
min.len(),
max.len()
)));
}
for i in 0..min.len() {
if min[i] > max[i] {
return Err(SpatialError::ValueError(format!(
"BoundingBox min[{i}]={} > max[{i}]={}",
min[i], max[i]
)));
}
}
Ok(Self { min, max })
}
pub fn from_point(point: &[f64]) -> Self {
Self {
min: point.to_vec(),
max: point.to_vec(),
}
}
pub fn ndim(&self) -> usize {
self.min.len()
}
pub fn volume(&self) -> f64 {
self.min
.iter()
.zip(self.max.iter())
.map(|(lo, hi)| hi - lo)
.product()
}
pub fn half_perimeter(&self) -> f64 {
self.min
.iter()
.zip(self.max.iter())
.map(|(lo, hi)| hi - lo)
.sum()
}
pub fn center(&self) -> Vec<f64> {
self.min
.iter()
.zip(self.max.iter())
.map(|(lo, hi)| (lo + hi) * 0.5)
.collect()
}
pub fn contains_point(&self, point: &[f64]) -> SpatialResult<bool> {
if point.len() != self.ndim() {
return Err(SpatialError::DimensionError(format!(
"Point dimension {} != BoundingBox dimension {}",
point.len(),
self.ndim()
)));
}
Ok(self
.min
.iter()
.zip(self.max.iter())
.zip(point.iter())
.all(|((lo, hi), p)| *p >= *lo && *p <= *hi))
}
pub fn overlaps(&self, other: &BoundingBox) -> SpatialResult<bool> {
if self.ndim() != other.ndim() {
return Err(SpatialError::DimensionError(format!(
"BoundingBox dimensions {} and {} differ",
self.ndim(),
other.ndim()
)));
}
Ok(self
.min
.iter()
.zip(self.max.iter())
.zip(other.min.iter())
.zip(other.max.iter())
.all(|(((lo, hi), olo), ohi)| lo <= ohi && hi >= olo))
}
pub fn overlap_volume(&self, other: &BoundingBox) -> SpatialResult<f64> {
if self.ndim() != other.ndim() {
return Err(SpatialError::DimensionError(format!(
"BoundingBox dimensions {} and {} differ",
self.ndim(),
other.ndim()
)));
}
let mut vol = 1.0_f64;
for i in 0..self.ndim() {
let lo = self.min[i].max(other.min[i]);
let hi = self.max[i].min(other.max[i]);
if hi < lo {
return Ok(0.0);
}
vol *= hi - lo;
}
Ok(vol)
}
pub fn union(&self, other: &BoundingBox) -> SpatialResult<BoundingBox> {
if self.ndim() != other.ndim() {
return Err(SpatialError::DimensionError(format!(
"BoundingBox dimensions {} and {} differ",
self.ndim(),
other.ndim()
)));
}
let min: Vec<f64> = self
.min
.iter()
.zip(other.min.iter())
.map(|(a, b)| a.min(*b))
.collect();
let max: Vec<f64> = self
.max
.iter()
.zip(other.max.iter())
.map(|(a, b)| a.max(*b))
.collect();
Ok(BoundingBox { min, max })
}
pub fn enlargement_needed(&self, other: &BoundingBox) -> SpatialResult<f64> {
let enlarged = self.union(other)?;
Ok(enlarged.volume() - self.volume())
}
pub fn min_distance_to_point(&self, point: &[f64]) -> SpatialResult<f64> {
if point.len() != self.ndim() {
return Err(SpatialError::DimensionError(format!(
"Point dimension {} != BoundingBox dimension {}",
point.len(),
self.ndim()
)));
}
let sq: f64 = self
.min
.iter()
.zip(self.max.iter())
.zip(point.iter())
.map(|((lo, hi), p)| {
if *p < *lo {
(p - lo).powi(2)
} else if *p > *hi {
(p - hi).powi(2)
} else {
0.0
}
})
.sum();
Ok(sq.sqrt())
}
pub fn from_rectangle(r: &Rectangle) -> Self {
Self {
min: r.min.to_vec(),
max: r.max.to_vec(),
}
}
pub fn to_rectangle(&self) -> SpatialResult<Rectangle> {
let min = Array1::from_vec(self.min.clone());
let max = Array1::from_vec(self.max.clone());
Rectangle::new(min, max)
}
}
#[derive(Debug, Clone)]
pub struct RTreeNode<T: Clone> {
pub bbox: BoundingBox,
pub kind: RTreeNodeKind<T>,
}
#[derive(Debug, Clone)]
pub enum RTreeNodeKind<T: Clone> {
Internal(Vec<RTreeNode<T>>),
Leaf(Vec<(usize, T)>),
}
pub struct RTreeIterator<'a, T: Clone> {
query_min: Array1<f64>,
query_max: Array1<f64>,
stack: Vec<&'a Node<T>>,
buffer: Vec<(usize, T)>,
}
impl<'a, T: Clone> RTreeIterator<'a, T> {
pub(crate) fn new(
root: &'a Node<T>,
query_min: Array1<f64>,
query_max: Array1<f64>,
) -> Self {
Self {
query_min,
query_max,
stack: vec![root],
buffer: Vec::new(),
}
}
fn advance(&mut self) {
while self.buffer.is_empty() {
let node = match self.stack.pop() {
Some(n) => n,
None => return,
};
let query_rect =
match Rectangle::new(self.query_min.clone(), self.query_max.clone()) {
Ok(r) => r,
Err(_) => return,
};
for entry in &node.entries {
let intersects = entry.mbr().intersects(&query_rect).unwrap_or(false);
if !intersects {
continue;
}
match entry {
Entry::Leaf { index, data, .. } => {
self.buffer.push((*index, data.clone()));
}
Entry::NonLeaf { child, .. } => {
self.stack.push(child);
}
}
}
}
}
}
impl<'a, T: Clone> Iterator for RTreeIterator<'a, T> {
type Item = (usize, T);
fn next(&mut self) -> Option<Self::Item> {
if self.buffer.is_empty() {
self.advance();
}
if self.buffer.is_empty() {
None
} else {
Some(self.buffer.remove(0))
}
}
}
impl<T: Clone> RTree<T> {
pub fn iter_range(
&self,
min: Array1<f64>,
max: Array1<f64>,
) -> SpatialResult<RTreeIterator<'_, T>> {
if min.len() != self.ndim() || max.len() != self.ndim() {
return Err(SpatialError::DimensionError(format!(
"Range dimensions ({}, {}) do not match tree dimension {}",
min.len(),
max.len(),
self.ndim()
)));
}
Ok(RTreeIterator::new(&self.root, min, max))
}
pub fn root_bbox(&self) -> Option<BoundingBox> {
self.root
.mbr()
.ok()
.flatten()
.as_ref()
.map(BoundingBox::from_rectangle)
}
}
pub struct STRPacking;
impl STRPacking {
pub fn build_2d<T: Clone>(
points: &ArrayView2<f64>,
data: Vec<T>,
node_capacity: usize,
) -> SpatialResult<RTree<T>> {
if points.ncols() != 2 {
return Err(SpatialError::DimensionError(format!(
"STRPacking::build_2d requires 2-column input, got {}",
points.ncols()
)));
}
if points.nrows() != data.len() {
return Err(SpatialError::ValueError(format!(
"points has {} rows but data has {} elements",
points.nrows(),
data.len()
)));
}
if node_capacity < 2 {
return Err(SpatialError::ValueError(
"node_capacity must be at least 2".to_string(),
));
}
let n = points.nrows();
if n == 0 {
return RTree::new(2, 1, node_capacity);
}
let mut items: Vec<(f64, f64, usize, T)> = points
.rows()
.into_iter()
.enumerate()
.zip(data.into_iter())
.map(|((i, row), d)| (row[0], row[1], i, d))
.collect();
items.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
let capacity = node_capacity;
let num_leaves = (n + capacity - 1) / capacity;
let slice_count = (num_leaves as f64).sqrt().ceil() as usize;
let slice_size = slice_count * capacity;
for chunk in items.chunks_mut(slice_size) {
chunk.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
}
let min_cap = (capacity / 2).max(1);
let mut tree = RTree::new(2, min_cap, capacity)?;
for (x, y, orig_idx, d) in items {
let pt = Array1::from_vec(vec![x, y]);
let _ = orig_idx; tree.insert(pt, d)?;
}
Ok(tree)
}
pub fn build_nd<T: Clone>(
points: &ArrayView2<f64>,
data: Vec<T>,
node_capacity: usize,
) -> SpatialResult<RTree<T>> {
let ndim = points.ncols();
if ndim == 0 {
return Err(SpatialError::DimensionError(
"Points must have at least 1 dimension".to_string(),
));
}
if points.nrows() != data.len() {
return Err(SpatialError::ValueError(format!(
"points has {} rows but data has {} elements",
points.nrows(),
data.len()
)));
}
if node_capacity < 2 {
return Err(SpatialError::ValueError(
"node_capacity must be at least 2".to_string(),
));
}
let n = points.nrows();
if n == 0 {
let min_cap = (node_capacity / 2).max(1);
return RTree::new(ndim, min_cap, node_capacity);
}
let mut items: Vec<(Vec<f64>, T)> = points
.rows()
.into_iter()
.zip(data.into_iter())
.map(|(row, d)| (row.to_vec(), d))
.collect();
Self::sort_by_dimension(&mut items, 0, ndim);
let min_cap = (node_capacity / 2).max(1);
let mut tree = RTree::new(ndim, min_cap, node_capacity)?;
for (coords, d) in items {
tree.insert(Array1::from_vec(coords), d)?;
}
Ok(tree)
}
fn sort_by_dimension<T: Clone>(
items: &mut [(Vec<f64>, T)],
dim: usize,
ndim: usize,
) {
if items.len() <= 1 {
return;
}
let d = dim % ndim;
items.sort_by(|a, b| {
a.0[d]
.partial_cmp(&b.0[d])
.unwrap_or(std::cmp::Ordering::Equal)
});
if dim + 1 < ndim {
let mid = items.len() / 2;
let (left, right) = items.split_at_mut(mid);
Self::sort_by_dimension(left, dim + 1, ndim);
Self::sort_by_dimension(right, dim + 1, ndim);
}
}
pub fn compute_bbox(points: &ArrayView2<f64>) -> SpatialResult<BoundingBox> {
let ndim = points.ncols();
if ndim == 0 {
return Err(SpatialError::DimensionError(
"Points must have at least 1 dimension".to_string(),
));
}
if points.nrows() == 0 {
return Err(SpatialError::ValueError(
"Cannot compute bounding box of empty point set".to_string(),
));
}
let mut min = vec![f64::INFINITY; ndim];
let mut max = vec![f64::NEG_INFINITY; ndim];
for row in points.rows() {
for (i, &v) in row.iter().enumerate() {
if v < min[i] {
min[i] = v;
}
if v > max[i] {
max[i] = v;
}
}
}
BoundingBox::new(min, max)
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_bounding_box_basic() {
let bb = BoundingBox::new(vec![0.0, 0.0], vec![2.0, 3.0]).unwrap();
assert!((bb.volume() - 6.0).abs() < 1e-12);
assert!((bb.half_perimeter() - 5.0).abs() < 1e-12);
}
#[test]
fn test_bounding_box_overlap() {
let a = BoundingBox::new(vec![0.0, 0.0], vec![2.0, 2.0]).unwrap();
let b = BoundingBox::new(vec![1.0, 1.0], vec![3.0, 3.0]).unwrap();
let c = BoundingBox::new(vec![5.0, 5.0], vec![6.0, 6.0]).unwrap();
assert!(a.overlaps(&b).unwrap());
assert!(!a.overlaps(&c).unwrap());
}
#[test]
fn test_bounding_box_overlap_volume() {
let a = BoundingBox::new(vec![0.0, 0.0], vec![2.0, 2.0]).unwrap();
let b = BoundingBox::new(vec![1.0, 1.0], vec![3.0, 3.0]).unwrap();
let ov = a.overlap_volume(&b).unwrap();
assert!((ov - 1.0).abs() < 1e-12);
}
#[test]
fn test_bounding_box_enlargement() {
let a = BoundingBox::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
let b = BoundingBox::new(vec![0.0, 0.0], vec![2.0, 2.0]).unwrap();
let enl = a.enlargement_needed(&b).unwrap();
assert!((enl - 3.0).abs() < 1e-12);
}
#[test]
fn test_str_packing_2d() {
let points = array![
[0.0_f64, 0.0],
[1.0, 1.0],
[2.0, 0.5],
[3.0, 2.0],
[0.5, 2.5],
[1.5, 0.5],
];
let data: Vec<usize> = (0..6).collect();
let tree = STRPacking::build_2d(&points.view(), data, 3).unwrap();
assert_eq!(tree.size(), 6);
}
#[test]
fn test_str_packing_nd() {
let points = array![
[0.0_f64, 0.0, 0.0],
[1.0, 1.0, 1.0],
[2.0, 0.5, 1.5],
[3.0, 2.0, 0.5],
];
let data: Vec<usize> = (0..4).collect();
let tree = STRPacking::build_nd(&points.view(), data, 2).unwrap();
assert_eq!(tree.size(), 4);
}
#[test]
fn test_rtree_iter_range() {
let mut tree = RTree::new(2, 2, 4).unwrap();
use scirs2_core::ndarray::Array1;
let pts: Vec<[f64; 2]> = vec![
[0.0, 0.0],
[1.0, 1.0],
[2.0, 2.0],
[5.0, 5.0],
];
for (i, p) in pts.iter().enumerate() {
tree.insert(Array1::from_vec(p.to_vec()), i).unwrap();
}
let min = Array1::from_vec(vec![-0.5, -0.5]);
let max = Array1::from_vec(vec![2.5, 2.5]);
let results: Vec<_> = tree.iter_range(min, max).unwrap().collect();
assert_eq!(results.len(), 3);
}
#[test]
fn test_str_compute_bbox() {
let pts = array![[1.0_f64, 2.0], [3.0, 0.0], [2.0, 5.0]];
let bb = STRPacking::compute_bbox(&pts.view()).unwrap();
assert!((bb.min[0] - 1.0).abs() < 1e-12);
assert!((bb.max[0] - 3.0).abs() < 1e-12);
assert!((bb.min[1] - 0.0).abs() < 1e-12);
assert!((bb.max[1] - 5.0).abs() < 1e-12);
}
}