use crate::error::{SparseError, SparseResult};
/// Axis-aligned box described by one lower and one upper bound per dimension.
///
/// The fields are public; [`BoundingBox::new`] is the validating constructor.
#[derive(Clone, Debug)]
pub struct BoundingBox {
    /// Lower bound along each dimension.
    pub min: Vec<f64>,
    /// Upper bound along each dimension (same length as `min` when built via `new`).
    pub max: Vec<f64>,
}
impl BoundingBox {
    /// Builds a box from per-dimension lower (`min`) and upper (`max`) bounds.
    ///
    /// # Errors
    ///
    /// Returns [`SparseError::ValueError`] when the two vectors differ in
    /// length, when any bound is NaN, or when `min[i] > max[i]` for some `i`.
    pub fn new(min: Vec<f64>, max: Vec<f64>) -> SparseResult<Self> {
        if min.len() != max.len() {
            return Err(SparseError::ValueError(format!(
                "BoundingBox: min length {} != max length {}",
                min.len(),
                max.len()
            )));
        }
        for (i, (&lo, &hi)) in min.iter().zip(max.iter()).enumerate() {
            // Every comparison involving NaN is false, so `lo > hi` alone
            // would let NaN bounds slip through the ordering check below.
            if lo.is_nan() || hi.is_nan() {
                return Err(SparseError::ValueError(format!(
                    "BoundingBox: NaN bound at index {}: min={}, max={}",
                    i, lo, hi
                )));
            }
            if lo > hi {
                return Err(SparseError::ValueError(format!(
                    "BoundingBox: min[{}]={} > max[{}]={}",
                    i, lo, i, hi
                )));
            }
        }
        Ok(Self { min, max })
    }

    /// Number of dimensions, taken from the length of `min`.
    pub fn dim(&self) -> usize {
        self.min.len()
    }

    /// Euclidean length of the box diagonal, `sqrt(sum_i (max[i] - min[i])^2)`.
    pub fn diameter(&self) -> f64 {
        self.min
            .iter()
            .zip(self.max.iter())
            .map(|(&lo, &hi)| {
                let extent = hi - lo;
                extent * extent
            })
            .sum::<f64>()
            .sqrt()
    }

    /// Euclidean distance between this box and `other`.
    ///
    /// Along each dimension the gap is `max(0, self.min - other.max,
    /// other.min - self.max)`, so overlapping or touching boxes yield `0.0`.
    /// Dimensions are paired with `zip`: if the two boxes differ in
    /// dimension, only the leading dimensions common to both contribute.
    pub fn distance_to(&self, other: &BoundingBox) -> f64 {
        self.min
            .iter()
            .zip(self.max.iter())
            .zip(other.min.iter().zip(other.max.iter()))
            .map(|((&lo_s, &hi_s), (&lo_o, &hi_o))| {
                let gap = f64::max(0.0, f64::max(lo_s - hi_o, lo_o - hi_s));
                gap * gap
            })
            .sum::<f64>()
            .sqrt()
    }

    /// Smallest box that contains both `self` and `other`.
    ///
    /// # Errors
    ///
    /// Returns [`SparseError::ValueError`] when the dimensions differ, or when
    /// the combined bounds fail the validation performed by [`BoundingBox::new`].
    pub fn union(&self, other: &BoundingBox) -> SparseResult<BoundingBox> {
        if self.dim() != other.dim() {
            return Err(SparseError::ValueError(format!(
                "BoundingBox::union: dimension mismatch {} vs {}",
                self.dim(),
                other.dim()
            )));
        }
        let min = self
            .min
            .iter()
            .zip(other.min.iter())
            .map(|(&a, &b)| f64::min(a, b))
            .collect();
        let max = self
            .max
            .iter()
            .zip(other.max.iter())
            .map(|(&a, &b)| f64::max(a, b))
            .collect();
        BoundingBox::new(min, max)
    }

    /// Midpoint of the box, `0.5 * (min[i] + max[i])` per dimension.
    pub fn center(&self) -> Vec<f64> {
        self.min
            .iter()
            .zip(self.max.iter())
            .map(|(&lo, &hi)| 0.5 * (lo + hi))
            .collect()
    }

    /// Index of the dimension with the largest extent `max[i] - min[i]`.
    ///
    /// When several dimensions share the largest extent the last one is
    /// returned (that is how `Iterator::max_by` resolves ties). Extents that
    /// cannot be ordered are treated as equal. A box with no dimensions
    /// yields `0`.
    pub fn widest_dim(&self) -> usize {
        self.min
            .iter()
            .zip(self.max.iter())
            .enumerate()
            .max_by(|(_, (lo1, hi1)), (_, (lo2, hi2))| {
                let d1 = *hi1 - *lo1;
                let d2 = *hi2 - *lo2;
                d1.partial_cmp(&d2).unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
}
/// One node of a [`ClusterTree`]; children are referenced by their position
/// in the tree's `nodes` vector.
#[derive(Clone, Debug)]
pub struct ClusterNode {
    /// Point indices that belong to this cluster.
    pub indices: Vec<usize>,
    /// Bounding box associated with the cluster.
    pub bbox: BoundingBox,
    /// Position of the left child in `ClusterTree::nodes`, if any.
    pub left: Option<usize>,
    /// Position of the right child in `ClusterTree::nodes`, if any.
    pub right: Option<usize>,
    /// Level of the node in the tree; the root is built with depth 0.
    pub depth: usize,
}
impl ClusterNode {
    /// Reports whether the node has neither a left nor a right child.
    pub fn is_leaf(&self) -> bool {
        matches!((self.left, self.right), (None, None))
    }
}
/// Binary cluster tree stored as a flat vector of nodes.
#[derive(Clone, Debug)]
pub struct ClusterTree {
    /// All nodes of the tree; child links are positions in this vector and
    /// the root sits at position 0.
    pub nodes: Vec<ClusterNode>,
    /// Spatial dimension of the points the tree was built from.
    pub dim: usize,
    /// Clusters holding at most this many indices are not split further.
    pub leaf_size: usize,
}
impl ClusterTree {
    /// Position of the root node inside `nodes` (always 0).
    pub fn root_idx(&self) -> usize {
        0
    }

    /// Borrows the root node.
    ///
    /// # Panics
    ///
    /// Panics if `nodes` is empty.
    pub fn root(&self) -> &ClusterNode {
        &self.nodes[0]
    }

    /// Total number of nodes, leaves included.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Positions (into `nodes`) of every leaf, in storage order.
    pub fn leaf_indices(&self) -> Vec<usize> {
        let mut leaves = Vec::new();
        for (position, node) in self.nodes.iter().enumerate() {
            if node.is_leaf() {
                leaves.push(position);
            }
        }
        leaves
    }

    /// Largest `depth` value stored on any node, or 0 for a tree without nodes.
    pub fn depth(&self) -> usize {
        self.nodes
            .iter()
            .fold(0, |deepest, node| deepest.max(node.depth))
    }
}
/// Builds a binary cluster tree over a flat array of point coordinates.
///
/// `coords` holds the points back to back: point `i` occupies
/// `coords[i * dim..(i + 1) * dim]`. Clusters are split recursively until a
/// cluster holds at most `leaf_size` points. The root is stored at position 0
/// of the returned tree's `nodes`.
///
/// # Errors
///
/// Returns [`SparseError::ValueError`] when
/// * `dim` is 0,
/// * `leaf_size` is 0,
/// * `coords.len()` is not a multiple of `dim`,
/// * `coords` is empty,
/// * any coordinate is NaN.
pub fn build_cluster_tree(
    coords: &[f64],
    dim: usize,
    leaf_size: usize,
) -> SparseResult<ClusterTree> {
    if dim == 0 {
        return Err(SparseError::ValueError(
            "build_cluster_tree: dim must be >= 1".to_string(),
        ));
    }
    if leaf_size == 0 {
        return Err(SparseError::ValueError(
            "build_cluster_tree: leaf_size must be >= 1".to_string(),
        ));
    }
    if coords.len() % dim != 0 {
        return Err(SparseError::ValueError(format!(
            "build_cluster_tree: coords length {} is not divisible by dim {}",
            coords.len(),
            dim
        )));
    }
    let n = coords.len() / dim;
    if n == 0 {
        return Err(SparseError::ValueError(
            "build_cluster_tree: empty coordinate set".to_string(),
        ));
    }
    // NaN never compares less or greater than anything, so the bounding-box
    // scan would silently skip it and the median sort would have no
    // well-defined order. Reject it here with a message that names the point.
    if let Some(pos) = coords.iter().position(|c| c.is_nan()) {
        return Err(SparseError::ValueError(format!(
            "build_cluster_tree: NaN coordinate at point {} (dimension {})",
            pos / dim,
            pos % dim
        )));
    }

    // `coords.len()` is a multiple of `dim`, so every chunk is one full point.
    let points: Vec<Vec<f64>> = coords.chunks_exact(dim).map(|p| p.to_vec()).collect();
    let all_indices: Vec<usize> = (0..n).collect();
    let root_bbox = compute_bbox(&points, &all_indices)?;
    let mut tree = ClusterTree {
        nodes: Vec::new(),
        dim,
        leaf_size,
    };
    build_recursive(&mut tree, &points, all_indices, root_bbox, 0)?;
    Ok(tree)
}
/// Appends the cluster described by `indices`/`bbox` to `tree` and recursively
/// builds its descendants, returning the position of the new node.
///
/// Nodes are stored in pre-order: the node itself first, then its whole left
/// subtree, then its right subtree. A cluster becomes a leaf when it holds at
/// most `tree.leaf_size` indices or when the split leaves one side empty.
///
/// # Errors
///
/// Propagates any error from `compute_bbox` on the child index sets.
fn build_recursive(
    tree: &mut ClusterTree,
    points: &[Vec<f64>],
    indices: Vec<usize>,
    bbox: BoundingBox,
    depth: usize,
) -> SparseResult<usize> {
    let node_idx = tree.nodes.len();

    // Decide on the split *before* storing the node so that `indices` and
    // `bbox` can be moved into it instead of being cloned for every node.
    let halves = if indices.len() <= tree.leaf_size {
        None
    } else {
        let (left, right) = split_indices(points, &indices, bbox.widest_dim());
        if left.is_empty() || right.is_empty() {
            None
        } else {
            Some((left, right))
        }
    };

    tree.nodes.push(ClusterNode {
        indices,
        bbox,
        left: None,
        right: None,
        depth,
    });

    let (left_indices, right_indices) = match halves {
        Some(halves) => halves,
        None => return Ok(node_idx),
    };

    let left_bbox = compute_bbox(points, &left_indices)?;
    let right_bbox = compute_bbox(points, &right_indices)?;
    let left_child = build_recursive(tree, points, left_indices, left_bbox, depth + 1)?;
    let right_child = build_recursive(tree, points, right_indices, right_bbox, depth + 1)?;

    // Child links are filled in only once both subtrees exist.
    let node = &mut tree.nodes[node_idx];
    node.left = Some(left_child);
    node.right = Some(right_child);
    Ok(node_idx)
}
/// Computes the tightest bounding box around the points selected by `indices`.
///
/// The dimension is taken from the first selected point. Bounds start at
/// `+inf` / `-inf` and are tightened with strict comparisons.
///
/// # Errors
///
/// Returns [`SparseError::ValueError`] when `indices` is empty, and forwards
/// any validation error raised by [`BoundingBox::new`].
fn compute_bbox(points: &[Vec<f64>], indices: &[usize]) -> SparseResult<BoundingBox> {
    let first = match indices.first() {
        Some(&idx) => idx,
        None => {
            return Err(SparseError::ValueError(
                "compute_bbox: empty index set".to_string(),
            ));
        }
    };
    let dim = points[first].len();
    let mut lower = vec![f64::INFINITY; dim];
    let mut upper = vec![f64::NEG_INFINITY; dim];

    for point in indices.iter().map(|&idx| &points[idx]) {
        for (axis, &value) in point.iter().enumerate() {
            if value < lower[axis] {
                lower[axis] = value;
            }
            if value > upper[axis] {
                upper[axis] = value;
            }
        }
    }

    BoundingBox::new(lower, upper)
}
/// Splits `indices` into two halves at the median of coordinate `split_dim`.
///
/// The indices are sorted (stably) by `points[i][split_dim]`; the first
/// `len / 2` go to the left half and the rest to the right half, so with an
/// odd count the right half is the larger one. Equal coordinates keep their
/// input order.
///
/// NaN coordinates are ordered by `f64::total_cmp` (a positive NaN sorts
/// after every number), which keeps the comparator a total order as
/// `sort_by` requires.
///
/// # Panics
///
/// Panics if an index is out of range for `points` or a point has fewer than
/// `split_dim + 1` coordinates.
fn split_indices(
    points: &[Vec<f64>],
    indices: &[usize],
    split_dim: usize,
) -> (Vec<usize>, Vec<usize>) {
    let mut keyed: Vec<(f64, usize)> = indices
        .iter()
        .map(|&i| (points[i][split_dim], i))
        .collect();

    // `partial_cmp` is `None` only when a NaN is involved. Mapping that case
    // to `Equal` would make the comparator inconsistent, so fall back to the
    // IEEE total order there; ordinary values (including -0.0 == +0.0) are
    // still compared exactly as before.
    keyed.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or_else(|| a.0.total_cmp(&b.0))
    });

    let mid = keyed.len() / 2;
    let left: Vec<usize> = keyed[..mid].iter().map(|&(_, i)| i).collect();
    let right: Vec<usize> = keyed[mid..].iter().map(|&(_, i)| i).collect();
    (left, right)
}
/// Tests whether the cluster pair (`tau`, `sigma`) is admissible for the
/// separation parameter `eta`.
///
/// The pair is admissible when the boxes are strictly separated
/// (`distance > 0`) and `min(diam(tau), diam(sigma)) <= eta * distance`.
/// Boxes that touch or overlap are never admissible.
pub fn admissibility_check(tau: &BoundingBox, sigma: &BoundingBox, eta: f64) -> bool {
    let smaller_diameter = f64::min(tau.diameter(), sigma.diameter());
    let separation = tau.distance_to(sigma);
    separation > 0.0 && smaller_diameter <= eta * separation
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D tree over the points `0.0, 1.0, ..., n - 1`.
    fn make_1d_tree(n: usize, leaf_size: usize) -> ClusterTree {
        let coords: Vec<f64> = (0..n).map(|i| i as f64).collect();
        build_cluster_tree(&coords, 1, leaf_size).expect("build failed")
    }

    /// Gathers the point indices stored in all leaves, sorted ascending.
    fn sorted_leaf_points(tree: &ClusterTree) -> Vec<usize> {
        let mut gathered: Vec<usize> = tree
            .leaf_indices()
            .iter()
            .flat_map(|&li| tree.nodes[li].indices.iter().copied())
            .collect();
        gathered.sort_unstable();
        gathered
    }

    // ---- BoundingBox -----------------------------------------------------

    #[test]
    fn test_bounding_box_diameter() {
        // 3-4-5 triangle: the diagonal of a 3 x 4 box has length 5.
        let bb = BoundingBox::new(vec![0.0, 0.0], vec![3.0, 4.0]).expect("ok");
        let d = bb.diameter();
        assert!((d - 5.0).abs() < 1e-12, "diameter = {}", d);
    }

    #[test]
    fn test_bounding_box_distance() {
        let a = BoundingBox::new(vec![0.0], vec![1.0]).expect("ok");
        let b = BoundingBox::new(vec![3.0], vec![4.0]).expect("ok");
        let d = a.distance_to(&b);
        assert!((d - 2.0).abs() < 1e-12, "distance = {}", d);
    }

    #[test]
    fn test_bounding_box_overlapping_distance() {
        let a = BoundingBox::new(vec![0.0], vec![2.0]).expect("ok");
        let b = BoundingBox::new(vec![1.0], vec![3.0]).expect("ok");
        let d = a.distance_to(&b);
        assert_eq!(d, 0.0, "overlapping boxes should have distance 0");
    }

    // ---- Tree construction -----------------------------------------------

    #[test]
    fn test_build_cluster_tree_leaf() {
        // Two points with leaf_size 2: the root is never split.
        let tree = make_1d_tree(2, 2);
        assert_eq!(tree.num_nodes(), 1);
        assert!(tree.root().is_leaf());
    }

    #[test]
    fn test_build_cluster_tree_split() {
        // Four points with leaf_size 2: the root plus two leaves.
        let tree = make_1d_tree(4, 2);
        assert_eq!(tree.num_nodes(), 3, "nodes={}", tree.num_nodes());
        assert!(!tree.root().is_leaf());
    }

    #[test]
    fn test_build_cluster_tree_large() {
        // The leaves together must hold every point index exactly once.
        let tree = make_1d_tree(16, 2);
        let all = sorted_leaf_points(&tree);
        let expected: Vec<usize> = (0..16).collect();
        assert_eq!(all, expected, "indices not covered: {:?}", all);
    }

    #[test]
    fn test_2d_cluster_tree() {
        // 4 x 4 integer grid, stored as (i, j) pairs.
        let mut coords = Vec::with_capacity(32);
        for i in 0..4 {
            for j in 0..4 {
                coords.push(i as f64);
                coords.push(j as f64);
            }
        }
        let tree = build_cluster_tree(&coords, 2, 2).expect("build failed");
        let all = sorted_leaf_points(&tree);
        let expected: Vec<usize> = (0..16).collect();
        assert_eq!(all, expected);
    }

    // ---- Admissibility ---------------------------------------------------

    #[test]
    fn test_admissibility_check() {
        // Diameter 1, distance 4: 1 <= 1.0 * 4.
        let tau = BoundingBox::new(vec![0.0], vec![1.0]).expect("ok");
        let sigma = BoundingBox::new(vec![5.0], vec![6.0]).expect("ok");
        assert!(admissibility_check(&tau, &sigma, 1.0));
    }

    #[test]
    fn test_admissibility_check_not_admissible() {
        // Diameter 1, distance 0.5: 1 > 0.5 * 0.5.
        let tau = BoundingBox::new(vec![0.0], vec![1.0]).expect("ok");
        let sigma = BoundingBox::new(vec![1.5], vec![2.5]).expect("ok");
        assert!(!admissibility_check(&tau, &sigma, 0.5));
    }

    #[test]
    fn test_admissibility_overlapping() {
        // Overlapping boxes are rejected no matter how large eta is.
        let tau = BoundingBox::new(vec![0.0], vec![2.0]).expect("ok");
        let sigma = BoundingBox::new(vec![1.0], vec![3.0]).expect("ok");
        assert!(!admissibility_check(&tau, &sigma, 10.0));
    }

    // ---- Argument validation ---------------------------------------------

    #[test]
    fn test_error_dim_zero() {
        let coords = vec![1.0, 2.0];
        assert!(build_cluster_tree(&coords, 0, 1).is_err());
    }

    #[test]
    fn test_error_leaf_size_zero() {
        let coords = vec![1.0, 2.0];
        assert!(build_cluster_tree(&coords, 1, 0).is_err());
    }

    #[test]
    fn test_error_coords_not_divisible() {
        let coords = vec![1.0, 2.0, 3.0];
        assert!(build_cluster_tree(&coords, 2, 1).is_err());
    }
}