/// Common interface for geometric regions over `f32` coordinates.
///
/// Implemented in this file by `AxisBox` and `Ball`.
pub trait Region {
/// Number of coordinates the region is defined over.
fn dim(&self) -> usize;
/// Center point of the region, one entry per dimension.
fn center(&self) -> &[f32];
/// Distance from the region to `point`.
///
/// The implementations in this file delegate to L2 helpers in
/// `crate::distance`; the tests in this file expect `0.0` for a point that
/// lies inside the region.
fn distance_to_point(&self, point: &[f32]) -> f32;
/// Whether `point` lies in the region. Both implementations in this file
/// treat the boundary as inside.
fn contains(&self, point: &[f32]) -> bool;
}
/// Axis-aligned box stored as per-dimension lower and upper bounds together
/// with a cached center point.
#[derive(Debug, Clone)]
pub struct AxisBox {
    /// Lower bound in each dimension.
    min: Vec<f32>,
    /// Upper bound in each dimension.
    max: Vec<f32>,
    /// Center, either computed from the bounds or supplied by the caller,
    /// depending on the constructor used.
    center: Vec<f32>,
}

impl AxisBox {
    /// Builds a box from its lower (`min`) and upper (`max`) corners.
    ///
    /// The stored center is the per-dimension midpoint `(lo + hi) * 0.5`.
    ///
    /// # Panics
    ///
    /// Panics if `min` and `max` have different lengths. When debug
    /// assertions are enabled it also panics unless `lo <= hi` holds in every
    /// dimension (which rejects NaN bounds as well).
    pub fn new(min: Vec<f32>, max: Vec<f32>) -> Self {
        assert_eq!(min.len(), max.len(), "min/max dimension mismatch");
        debug_assert!(
            min.iter().zip(&max).all(|(lo, hi)| lo <= hi),
            "min must be <= max in every dimension"
        );

        let mut center = Vec::with_capacity(min.len());
        for (lo, hi) in min.iter().zip(&max) {
            center.push((lo + hi) * 0.5);
        }

        AxisBox { min, max, center }
    }

    /// Builds a box that extends `|offset|` on either side of `center` in
    /// every dimension.
    ///
    /// The sign of each offset is ignored, so the resulting bounds are always
    /// ordered. `center` is stored as given.
    ///
    /// # Panics
    ///
    /// Panics if `center` and `offset` have different lengths.
    pub fn from_center_offset(center: Vec<f32>, offset: Vec<f32>) -> Self {
        assert_eq!(center.len(), offset.len());

        // One pass yields both corners; each value is computed exactly as
        // `c - |o|` and `c + |o|`.
        let (min, max): (Vec<f32>, Vec<f32>) = center
            .iter()
            .zip(&offset)
            .map(|(c, o)| {
                let reach = o.abs();
                (c - reach, c + reach)
            })
            .unzip();

        AxisBox { min, max, center }
    }

    /// Builds a box centered on `mu` whose half-width in each dimension is
    /// `exp(delta) * 0.5`, i.e. `delta` is the natural log of the side length.
    ///
    /// `mu` is stored as the center.
    ///
    /// # Panics
    ///
    /// Panics if `mu` and `delta` have different lengths.
    pub fn from_mu_delta(mu: Vec<f32>, delta: Vec<f32>) -> Self {
        assert_eq!(mu.len(), delta.len());

        let (min, max): (Vec<f32>, Vec<f32>) = mu
            .iter()
            .zip(&delta)
            .map(|(m, d)| {
                let half = d.exp() * 0.5;
                (m - half, m + half)
            })
            .unzip();

        AxisBox {
            min,
            max,
            center: mu,
        }
    }

    /// Lower bound in each dimension.
    pub fn min(&self) -> &[f32] {
        self.min.as_slice()
    }

    /// Upper bound in each dimension.
    pub fn max(&self) -> &[f32] {
        self.max.as_slice()
    }

    /// Half of the side length in each dimension, `(hi - lo) * 0.5`.
    pub fn half_widths(&self) -> Vec<f32> {
        self.max
            .iter()
            .zip(&self.min)
            .map(|(hi, lo)| (hi - lo) * 0.5)
            .collect()
    }

    /// Sum of `ln(hi - lo)` over all dimensions, i.e. the natural log of the
    /// box volume.
    ///
    /// A zero-width side contributes `ln(0) = -inf`.
    pub fn log_volume(&self) -> f32 {
        let side_lengths = self.max.iter().zip(&self.min).map(|(hi, lo)| hi - lo);
        side_lengths.map(f32::ln).sum()
    }
}
impl Region for AxisBox {
    /// Number of dimensions, taken from the length of the lower-bound vector.
    fn dim(&self) -> usize {
        self.min.len()
    }

    /// The center stored when the box was constructed.
    fn center(&self) -> &[f32] {
        &self.center
    }

    /// Forwards to `crate::distance::box_to_point_l2` with this box's bounds.
    ///
    /// NOTE(review): the helper is defined outside this file; the tests here
    /// expect `0.0` for interior points and the Euclidean distance to the box
    /// surface for exterior points.
    fn distance_to_point(&self, point: &[f32]) -> f32 {
        crate::distance::box_to_point_l2(&self.min, &self.max, point)
    }

    /// Returns `true` when `point` has the same dimensionality as the box and
    /// every coordinate lies within `[min, max]` (bounds inclusive).
    ///
    /// A point whose length differs from the box's dimension is never
    /// contained. Any NaN coordinate makes the result `false`.
    fn contains(&self, point: &[f32]) -> bool {
        // `zip` stops at the shortest input, so without this guard a point
        // with missing coordinates would only be partially checked and an
        // empty slice would be reported as inside every box.
        point.len() == self.min.len()
            && point
                .iter()
                .zip(&self.min)
                .zip(&self.max)
                .all(|((p, lo), hi)| lo <= p && p <= hi)
    }
}
/// Ball described by a center point and a radius.
#[derive(Debug, Clone)]
pub struct Ball {
    /// Center point, one entry per dimension.
    center: Vec<f32>,
    /// Radius; validated as `>= 0.0` by [`Ball::new`].
    radius: f32,
}

impl Ball {
    /// Creates a ball around `center` with the given `radius`.
    ///
    /// # Panics
    ///
    /// Panics unless `radius >= 0.0`; a NaN radius fails that comparison and
    /// therefore panics too.
    pub fn new(center: Vec<f32>, radius: f32) -> Self {
        assert!(radius >= 0.0, "radius must be non-negative");
        Ball { center, radius }
    }

    /// Radius supplied at construction.
    pub fn radius(&self) -> f32 {
        let Ball { radius, .. } = self;
        *radius
    }
}
impl Region for Ball {
    /// Number of dimensions, taken from the length of the center vector.
    fn dim(&self) -> usize {
        self.center.len()
    }

    /// The center supplied at construction.
    fn center(&self) -> &[f32] {
        &self.center
    }

    /// Forwards to `crate::distance::ball_to_point_l2` with this ball's
    /// center and radius.
    ///
    /// NOTE(review): the helper is defined outside this file; the tests here
    /// expect `0.0` at the center and `|point - center| - radius` for a point
    /// outside the ball.
    fn distance_to_point(&self, point: &[f32]) -> f32 {
        crate::distance::ball_to_point_l2(&self.center, self.radius, point)
    }

    /// Returns `true` when `point` has the same dimensionality as the ball
    /// and its squared Euclidean distance to the center is at most
    /// `radius * radius` (boundary inclusive).
    ///
    /// A point whose length differs from the ball's dimension is never
    /// contained.
    fn contains(&self, point: &[f32]) -> bool {
        // `zip` stops at the shortest input, so without this guard the
        // missing coordinates of a shorter point would be ignored and an
        // empty slice would always yield a squared distance of zero.
        if point.len() != self.center.len() {
            return false;
        }

        // Compare squared quantities to avoid a square root.
        let dist_sq: f32 = self
            .center
            .iter()
            .zip(point)
            .map(|(c, p)| (c - p).powi(2))
            .sum();
        dist_sq <= self.radius * self.radius
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for approximate float comparisons.
    const TOL: f32 = 1e-6;

    /// True when `actual` is strictly within `TOL` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Interior and boundary points are inside; points past either face are not.
    #[test]
    fn box_contains_interior_point() {
        let unit = AxisBox::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        assert!(unit.contains(&[0.5, 0.5]));
        assert!(unit.contains(&[0.0, 0.0]));
        assert!(unit.contains(&[1.0, 1.0]));
        assert!(!unit.contains(&[1.5, 0.5]));
        assert!(!unit.contains(&[-0.1, 0.5]));
    }

    /// A point inside the box is at distance zero.
    #[test]
    fn box_distance_inside_is_zero() {
        let square = AxisBox::new(vec![0.0, 0.0], vec![2.0, 2.0]);
        let dist = square.distance_to_point(&[1.0, 1.0]);
        assert_eq!(dist, 0.0);
    }

    /// A point directly beyond one face is measured along that axis only.
    #[test]
    fn box_distance_outside() {
        let unit = AxisBox::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let dist = unit.distance_to_point(&[2.0, 0.5]);
        assert!(close(dist, 1.0));
    }

    /// A point diagonal from a corner is measured to that corner.
    #[test]
    fn box_distance_corner() {
        let unit = AxisBox::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let dist = unit.distance_to_point(&[2.0, 2.0]);
        assert!(close(dist, std::f32::consts::SQRT_2));
    }

    /// Center plus/minus offset gives the expected bounds and membership.
    #[test]
    fn box_from_center_offset() {
        let boxed = AxisBox::from_center_offset(vec![1.0, 1.0], vec![0.5, 0.5]);
        assert!(close(boxed.min()[0], 0.5));
        assert!(close(boxed.max()[0], 1.5));
        assert!(boxed.contains(&[1.0, 1.0]));
        assert!(!boxed.contains(&[2.0, 1.0]));
    }

    /// `delta = 0` means side length `exp(0) = 1`, so bounds are `mu -/+ 0.5`.
    #[test]
    fn box_from_mu_delta() {
        let boxed = AxisBox::from_mu_delta(vec![0.0, 0.0], vec![0.0, 0.0]);
        assert!(close(boxed.min()[0], -0.5));
        assert!(close(boxed.max()[0], 0.5));
    }

    /// Membership and distance for a unit ball in three dimensions.
    #[test]
    fn ball_contains_and_distance() {
        let unit_ball = Ball::new(vec![0.0, 0.0, 0.0], 1.0);
        assert!(unit_ball.contains(&[0.5, 0.5, 0.5]));
        assert!(!unit_ball.contains(&[1.0, 1.0, 0.0]));
        assert_eq!(unit_ball.distance_to_point(&[0.0, 0.0, 0.0]), 0.0);
        let dist = unit_ball.distance_to_point(&[2.0, 0.0, 0.0]);
        assert!(close(dist, 1.0));
    }

    /// Points exactly on the sphere count as contained.
    #[test]
    fn ball_boundary_is_contained() {
        let unit_ball = Ball::new(vec![0.0, 0.0], 1.0);
        assert!(unit_ball.contains(&[1.0, 0.0]));
        assert!(unit_ball.contains(&[0.0, 1.0]));
    }
}