use crate::ndarray_backend::NdarrayBox;
use crate::utils::BOUNDARY_CONTAINMENT_THRESHOLD;
use crate::{Box, BoxError};
use ndarray::Array1;
/// Euclidean distance from `point` to the closest point of the axis-aligned box `box_`.
///
/// On each axis a coordinate inside `[min, max]` contributes nothing. Otherwise the gap to
/// the nearer face is squared and accumulated, and the square root of the total is returned.
/// A point inside or on the surface of the box therefore has distance `0.0`.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] when `point.len() != box_.dim()`.
pub fn vector_to_box_distance(point: &Array1<f32>, box_: &NdarrayBox) -> Result<f32, BoxError> {
    let (expected, actual) = (box_.dim(), point.len());
    if actual != expected {
        return Err(BoxError::DimensionMismatch { expected, actual });
    }
    let squared = point
        .iter()
        .zip(box_.min().iter().zip(box_.max().iter()))
        .fold(0.0_f32, |acc, (&p, (&lo, &hi))| {
            if p < lo {
                // Below the lower face on this axis.
                acc + (lo - p) * (lo - p)
            } else if p > hi {
                // Above the upper face on this axis.
                acc + (p - hi) * (p - hi)
            } else {
                // Within the extent on this axis: no contribution.
                acc
            }
        });
    Ok(squared.sqrt())
}
/// Smallest per-axis clearance between `inner` and the faces of `outer`, when `inner` is
/// considered contained in `outer`.
///
/// Containment is decided by `outer.containment_prob(inner)`. A value below
/// [`BOUNDARY_CONTAINMENT_THRESHOLD`] yields `Ok(None)`.
///
/// Otherwise, each axis contributes the smaller of `inner.min - outer.min` and
/// `outer.max - inner.max`, and the minimum over all axes is returned. Because containment is
/// a probability threshold rather than a strict geometric test, the value can be negative if
/// `inner` pokes past `outer` on some axis. If the fold never drops below infinity (e.g. a
/// zero-dimensional box), `Some(0.0)` is returned.
///
/// # Errors
///
/// Propagates any error returned by `containment_prob`.
pub fn boundary_distance(outer: &NdarrayBox, inner: &NdarrayBox) -> Result<Option<f32>, BoxError> {
    if outer.containment_prob(inner)? < BOUNDARY_CONTAINMENT_THRESHOLD {
        return Ok(None);
    }
    let (outer_lo, outer_hi) = (outer.min(), outer.max());
    let (inner_lo, inner_hi) = (inner.min(), inner.max());
    let clearance = (0..outer.dim())
        .map(|axis| {
            let below = inner_lo[axis] - outer_lo[axis];
            let above = outer_hi[axis] - inner_hi[axis];
            below.min(above)
        })
        .fold(f32::INFINITY, f32::min);
    let clearance = if clearance == f32::INFINITY {
        0.0
    } else {
        clearance
    };
    Ok(Some(clearance))
}
pub fn query2box_distance(
query_box: &NdarrayBox,
entity_point: &Array1<f32>,
alpha: f32,
) -> Result<f32, BoxError> {
let center: Vec<f32> = query_box
.min()
.iter()
.zip(query_box.max().iter())
.map(|(lo, hi)| (lo + hi) * 0.5)
.collect();
let offset: Vec<f32> = query_box
.min()
.iter()
.zip(query_box.max().iter())
.map(|(lo, hi)| (hi - lo) * 0.5)
.collect();
crate::distance::query2box_distance(¢er, &offset, entity_point.as_slice().unwrap(), alpha)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Unit square `[0, 1] x [0, 1]` shared by the point/boundary tests.
    fn unit_square() -> NdarrayBox {
        NdarrayBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap()
    }

    /// Query square `[0, 2] x [0, 2]`, i.e. center `(1, 1)` and offset `(1, 1)`.
    fn query_square() -> NdarrayBox {
        NdarrayBox::new(array![0.0, 0.0], array![2.0, 2.0], 1.0).unwrap()
    }

    #[test]
    fn test_vector_to_box_distance_inside() {
        let d = vector_to_box_distance(&array![0.5, 0.5], &unit_square()).unwrap();
        assert_eq!(d, 0.0);
    }

    #[test]
    fn test_vector_to_box_distance_outside() {
        let d = vector_to_box_distance(&array![2.0, 2.0], &unit_square()).unwrap();
        assert!((d - 2.0_f32.sqrt()).abs() < 1e-5);
    }

    #[test]
    fn test_vector_to_box_distance_partial() {
        // Inside on x, one unit above the box on y.
        let d = vector_to_box_distance(&array![0.5, 2.0], &unit_square()).unwrap();
        assert!((d - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_boundary_distance_contained() {
        let inner = NdarrayBox::new(array![0.2, 0.2], array![0.8, 0.8], 1.0).unwrap();
        let gap = boundary_distance(&unit_square(), &inner).unwrap();
        assert!(gap.is_some());
        let gap = gap.unwrap();
        assert!(gap >= 0.0);
        assert!(gap <= 0.2);
    }

    #[test]
    fn test_boundary_distance_not_contained() {
        let inner = NdarrayBox::new(array![0.5, 0.5], array![1.5, 1.5], 1.0).unwrap();
        let gap = boundary_distance(&unit_square(), &inner).unwrap();
        assert!(gap.is_none());
    }

    #[test]
    fn test_query2box_inside() {
        let d = query2box_distance(&query_square(), &array![1.0, 1.0], 0.02).unwrap();
        assert_eq!(d, 0.0, "entity at center: distance should be 0");
    }

    #[test]
    fn test_query2box_outside() {
        // NOTE(review): under the usual Query2Box formula (L1 outside distance plus `alpha`
        // times the L1 inside distance), this input gives 6.0 + 0.02 * 2.0 = 6.04, which
        // would miss the 1e-5 tolerance below. Confirm how
        // `crate::distance::query2box_distance` weights the inside term.
        let d = query2box_distance(&query_square(), &array![5.0, 5.0], 0.02).unwrap();
        assert!((d - 6.0).abs() < 1e-5, "expected 6.0, got {d}");
    }

    #[test]
    fn test_query2box_dim_mismatch() {
        assert!(query2box_distance(&query_square(), &array![1.0], 0.02).is_err());
    }
}