use crate::box_trait::BoxError;
use crate::HyperBox;
/// A region that can score how strongly it contains another region of the same kind.
///
/// Implemented in this module for the hyperbox backends (by forwarding to [`HyperBox`]) and
/// for `Ball`, `Ellipsoid`, `Subspace`, `GaussianBox`, `SphericalCap` and `AnnularSector`,
/// so generic code can rank candidate children against a parent regardless of geometry.
pub trait Region
where
    Self: Sized,
{
    /// Dimension of the space this region lives in.
    ///
    /// This is the ambient dimension, not the intrinsic one: e.g. a subspace spanned by two
    /// 3-vectors reports `3`.
    fn dim(&self) -> usize;

    /// Scores how strongly `inner` is contained in `self`.
    ///
    /// Higher is "more contained": a region sitting well inside `self` must outscore one
    /// that straddles or lies outside it. The absolute scale is backend-specific
    /// (NOTE(review): several backends call this a probability, presumably in `[0, 1]` —
    /// confirm per backend before relying on the range).
    ///
    /// # Errors
    ///
    /// Propagates any [`BoxError`] raised by the backend's containment computation.
    fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError>;

    /// Default `k` forwarded by implementations whose containment function takes a `k`
    /// argument (ball, ellipsoid, Gaussian box, spherical cap).
    ///
    /// Its interpretation is defined by those functions (NOTE(review): likely a
    /// sharpness/temperature — confirm). Implementations that do not take `k` ignore it.
    const DEFAULT_K: f32 = 1.0;
}
/// Implements [`Region`] for a type by forwarding to its [`HyperBox`] implementation.
///
/// `dim` goes through `HyperBox::dim`, and `subsumption_score` through
/// `HyperBox::containment_prob`, with the container (`self`) passed first and the
/// candidate (`inner`) second.
macro_rules! impl_region_via_hyperbox {
    ($region:ty) => {
        impl Region for $region {
            fn dim(&self) -> usize {
                // Called through the trait path so this always reaches the `HyperBox`
                // implementation, never an inherent method of the same name.
                HyperBox::dim(self)
            }

            fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
                let container = self;
                HyperBox::containment_prob(container, inner)
            }
        }
    };
}
impl_region_via_hyperbox! { crate::ndarray_backend::NdarrayBox }
impl_region_via_hyperbox! { crate::ndarray_backend::NdarrayGumbelBox }

// The candle-backed box types are only compiled with the `candle-backend` feature.
#[cfg(feature = "candle-backend")]
impl_region_via_hyperbox! { crate::candle_backend::CandleBox }
#[cfg(feature = "candle-backend")]
impl_region_via_hyperbox! { crate::candle_backend::CandleGumbelBox }
impl Region for crate::Ball {
fn dim(&self) -> usize {
self.dim()
}
fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
crate::ball::containment_prob(inner, self, Self::DEFAULT_K)
}
}
impl Region for crate::Ellipsoid {
fn dim(&self) -> usize {
self.dim()
}
fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
crate::ellipsoid::containment_prob(inner, self, Self::DEFAULT_K)
}
}
impl Region for crate::Subspace {
fn dim(&self) -> usize {
self.dim()
}
fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
crate::subspace::containment_score(inner, self)
}
}
impl Region for crate::GaussianBox {
fn dim(&self) -> usize {
self.dim()
}
fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
crate::gaussian::containment_prob(inner, self, Self::DEFAULT_K)
}
}
impl Region for crate::SphericalCap {
fn dim(&self) -> usize {
self.dim()
}
fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
crate::spherical_cap::containment_prob(inner, self, Self::DEFAULT_K)
}
}
/// Annular sectors are always two-dimensional here and score containment with the
/// infallible `crate::annular::containment_score`.
impl Region for crate::AnnularSector {
    fn dim(&self) -> usize {
        // The dimension is a fixed constant rather than something read from the value.
        const PLANE_DIM: usize = 2;
        PLANE_DIM
    }

    fn subsumption_score(&self, inner: &Self) -> Result<f32, BoxError> {
        // The annular scorer cannot fail, so its score is simply wrapped in `Ok`.
        let score = crate::annular::containment_score(inner, self);
        Ok(score)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the candidate that `outer` scores highest under [`Region::subsumption_score`].
    ///
    /// Each candidate is scored exactly once. On ties the last maximal candidate wins, the
    /// same tie-breaking as [`Iterator::max_by`].
    ///
    /// # Panics
    ///
    /// Panics if `candidates` is empty or if any score computation returns an error.
    fn most_subsumed<'a, R: Region>(outer: &R, candidates: &'a [R]) -> &'a R {
        candidates
            .iter()
            .map(|candidate| {
                let score = outer
                    .subsumption_score(candidate)
                    .expect("subsumption_score failed for a test candidate");
                (score, candidate)
            })
            .max_by(|(a, _), (b, _)| a.total_cmp(b))
            .map(|(_, candidate)| candidate)
            .expect("most_subsumed requires at least one candidate")
    }

    #[test]
    fn ball_subsumption_is_monotone() {
        let outer = crate::Ball::new(vec![0.0, 0.0], 2.0).unwrap();
        let inside = crate::Ball::new(vec![0.0, 0.0], 0.3).unwrap();
        let straddling = crate::Ball::new(vec![1.9, 0.0], 1.0).unwrap();
        assert_eq!(Region::dim(&outer), 2);

        let s_in = outer.subsumption_score(&inside).unwrap();
        let s_out = outer.subsumption_score(&straddling).unwrap();
        assert!(
            s_in > s_out,
            "a deeply-contained ball must score higher than a straddling one: {s_in} vs {s_out}"
        );

        // A fixed-size array suffices; `&cands` coerces to the `&[R]` the helper expects.
        let cands = [straddling, inside];
        let best = most_subsumed(&outer, &cands);
        assert!(
            (best.radius() - 0.3).abs() < 1e-6,
            "ranking should pick the inside ball"
        );
    }

    #[test]
    fn ball_region_matches_free_function() {
        let outer = crate::Ball::new(vec![0.0, 0.0], 2.0).unwrap();
        let inner = crate::Ball::new(vec![0.0, 0.0], 0.5).unwrap();
        let via_trait = outer.subsumption_score(&inner).unwrap();
        let via_free =
            crate::ball::containment_prob(&inner, &outer, <crate::Ball as Region>::DEFAULT_K)
                .unwrap();
        // Exact equality is intended: both paths perform the identical computation.
        assert_eq!(via_trait, via_free);
    }

    #[test]
    fn ndarray_box_is_a_region() {
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;

        let outer = NdarrayBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap();
        let inner = NdarrayBox::new(array![0.2, 0.2], array![0.8, 0.8], 1.0).unwrap();
        let disjoint = NdarrayBox::new(array![5.0, 5.0], array![6.0, 6.0], 1.0).unwrap();

        // Goes through the generic bound to prove the macro-generated impl is usable as `R: Region`.
        fn score<R: Region>(outer: &R, inner: &R) -> (usize, f32) {
            (outer.dim(), outer.subsumption_score(inner).unwrap())
        }

        let (d, p_in) = score(&outer, &inner);
        let (_, p_out) = score(&outer, &disjoint);
        assert_eq!(d, 2);
        assert!(
            p_in > p_out,
            "contained box should outscore disjoint: {p_in} vs {p_out}"
        );
    }

    /// Scores `contained` and `loose` against `outer`, returning `(contained, loose)` scores.
    ///
    /// Callers assert the first exceeds the second (the monotonicity property under test).
    fn score_pair<R: Region>(outer: &R, contained: &R, loose: &R) -> (f32, f32) {
        (
            outer
                .subsumption_score(contained)
                .expect("scoring the contained region failed"),
            outer
                .subsumption_score(loose)
                .expect("scoring the loose region failed"),
        )
    }

    #[test]
    fn gaussian_region_is_monotone() {
        let outer = crate::GaussianBox::new(vec![0.0, 0.0], vec![2.0, 2.0]).unwrap();
        let inside = crate::GaussianBox::new(vec![0.0, 0.0], vec![0.4, 0.4]).unwrap();
        let far = crate::GaussianBox::new(vec![6.0, 6.0], vec![0.4, 0.4]).unwrap();
        assert_eq!(Region::dim(&outer), 2);
        let (s_in, s_out) = score_pair(&outer, &inside, &far);
        assert!(
            s_in > s_out,
            "concentric child should outscore the far one: {s_in} vs {s_out}"
        );
    }

    #[test]
    fn spherical_cap_region_is_monotone() {
        let outer = crate::SphericalCap::new(vec![0.0, 0.0, 1.0], 1.2).unwrap();
        let inside = crate::SphericalCap::new(vec![0.0, 0.0, 1.0], 0.2).unwrap();
        let away = crate::SphericalCap::new(vec![0.0, 0.0, -1.0], 0.2).unwrap();
        assert_eq!(Region::dim(&outer), 3);
        let (s_in, s_out) = score_pair(&outer, &inside, &away);
        assert!(
            s_in > s_out,
            "aligned inner cap should outscore the opposed one: {s_in} vs {s_out}"
        );
    }

    #[test]
    fn annular_sector_region_is_monotone() {
        let outer = crate::AnnularSector::new(0.0, 0.0, 1.0, 4.0, 0.0, 2.0).unwrap();
        let inside = crate::AnnularSector::new(0.0, 0.0, 1.5, 3.5, 0.5, 1.5).unwrap();
        let spilling = crate::AnnularSector::new(0.0, 0.0, 0.2, 6.0, 0.0, 3.0).unwrap();
        assert_eq!(Region::dim(&outer), 2);
        let (s_in, s_out) = score_pair(&outer, &inside, &spilling);
        assert!(
            s_in > s_out,
            "contained sector should outscore the spilling one: {s_in} vs {s_out}"
        );
    }

    #[test]
    fn ellipsoid_region_is_monotone() {
        let outer = crate::Ellipsoid::from_log_diagonal(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        let inside =
            crate::Ellipsoid::from_log_diagonal(vec![0.0, 0.0], vec![-1.0, -1.0]).unwrap();
        let far = crate::Ellipsoid::from_log_diagonal(vec![6.0, 6.0], vec![-1.0, -1.0]).unwrap();
        assert_eq!(Region::dim(&outer), 2);
        let (s_in, s_out) = score_pair(&outer, &inside, &far);
        assert!(
            s_in > s_out,
            "concentric child should outscore the far one: {s_in} vs {s_out}"
        );
    }

    #[test]
    fn subspace_region_is_monotone() {
        let outer =
            crate::Subspace::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]).unwrap();
        let inside = crate::Subspace::new(vec![vec![1.0, 0.0, 0.0]]).unwrap();
        let orthogonal = crate::Subspace::new(vec![vec![0.0, 0.0, 1.0]]).unwrap();
        // `dim` is the ambient dimension (3), not the number of basis vectors (2).
        assert_eq!(Region::dim(&outer), 3);
        let (s_in, s_out) = score_pair(&outer, &inside, &orthogonal);
        assert!(
            s_in > s_out,
            "contained axis should outscore the orthogonal one: {s_in} vs {s_out}"
        );
    }
}