use crate::ndarray_backend::ndarray_box::NdarrayBox;
use crate::utils::{bessel_log_volume, gumbel_lse_max, gumbel_lse_min};
use crate::utils::{gumbel_membership_prob, map_gumbel_to_bounds, sample_gumbel};
use crate::{Box, BoxError};
use ndarray::Array1;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Gumbel box embedding backed by an [`NdarrayBox`].
///
/// The wrapped box stores the `min`/`max` corners and the temperature. This type layers
/// Gumbel-style volume, log-sum-exp intersection, point membership and sampling on top.
#[derive(Clone, Debug)]
pub struct NdarrayGumbelBox {
    /// Underlying box holding the corners and the temperature.
    inner: NdarrayBox,
}
impl Serialize for NdarrayGumbelBox {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.inner.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for NdarrayGumbelBox {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let inner = NdarrayBox::deserialize(deserializer)?;
Ok(Self { inner })
}
}
impl NdarrayGumbelBox {
    /// Creates a Gumbel box with corners `min`/`max` at the given `temperature`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `NdarrayBox::new`. This module's tests show that
    /// mismatched `min`/`max` lengths and `min > max` are both rejected.
    pub fn new(min: Array1<f32>, max: Array1<f32>, temperature: f32) -> Result<Self, BoxError> {
        let inner = NdarrayBox::new(min, max, temperature)?;
        Ok(Self { inner })
    }
}
impl Box for NdarrayGumbelBox {
    type Scalar = f32;
    type Vector = Array1<f32>;

    /// Lower corner, delegated to the inner box.
    fn min(&self) -> &Self::Vector {
        self.inner.min()
    }

    /// Upper corner, delegated to the inner box.
    fn max(&self) -> &Self::Vector {
        self.inner.max()
    }

    /// Number of dimensions, delegated to the inner box.
    fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// Soft volume of the box at its stored temperature, computed by `bessel_log_volume`.
    ///
    /// The stored temperature is passed for both temperature arguments of
    /// `bessel_log_volume`; only the linear-scale volume it returns is kept.
    fn volume(&self) -> Result<Self::Scalar, BoxError> {
        let t = self.inner.temperature;
        let (_, vol) = match (self.min().as_slice(), self.max().as_slice()) {
            (Some(mins), Some(maxs)) => bessel_log_volume(mins, maxs, t, t),
            // `as_slice` is `None` when an array is not contiguous in standard order;
            // copy into owned buffers rather than panicking.
            _ => bessel_log_volume(
                self.min().to_vec().as_slice(),
                self.max().to_vec().as_slice(),
                t,
                t,
            ),
        };
        Ok(vol)
    }

    /// Log-sum-exp smoothed intersection of `self` and `other`.
    ///
    /// Each lower bound is `gumbel_lse_min(self.min, other.min, t)` and each upper bound is
    /// `gumbel_lse_max(self.max, other.max, t)`. Here `t` is **`self`'s** temperature,
    /// so the operation is only symmetric when both boxes share a temperature. The result
    /// is built with `NdarrayBox::new_unchecked` (presumably skipping the `min <= max`
    /// validation `new` performs — confirm in `NdarrayBox`).
    ///
    /// # Errors
    ///
    /// `BoxError::DimensionMismatch` if the boxes differ in dimension.
    fn intersection(&self, other: &Self) -> Result<Self, BoxError> {
        if self.dim() != other.dim() {
            return Err(BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }
        let t = self.inner.temperature;
        let new_min: Array1<f32> = self
            .min()
            .iter()
            .zip(other.min().iter())
            .map(|(&a, &b)| gumbel_lse_min(a, b, t))
            .collect();
        let new_max: Array1<f32> = self
            .max()
            .iter()
            .zip(other.max().iter())
            .map(|(&a, &b)| gumbel_lse_max(a, b, t))
            .collect();
        Ok(Self {
            inner: NdarrayBox::new_unchecked(new_min, new_max, t),
        })
    }

    /// Approximate probability that `other` is contained in `self`.
    ///
    /// Computed as `vol(self ∩ other) / vol(other)` using the soft volume and LSE
    /// intersection of this impl, clamped to `[0, 1]`. Returns `0.0` when `other`'s
    /// volume is at or below `1e-30`, to avoid dividing by a vanishing volume.
    ///
    /// # Errors
    ///
    /// `BoxError::DimensionMismatch` if the boxes differ in dimension.
    fn containment_prob(&self, other: &Self) -> Result<Self::Scalar, BoxError> {
        const MIN_VOLUME: f32 = 1e-30;
        let inter_vol = self.intersection(other)?.volume()?;
        let other_vol = other.volume()?;
        if other_vol <= MIN_VOLUME {
            return Ok(0.0);
        }
        Ok((inter_vol / other_vol).clamp(0.0, 1.0))
    }

    /// Intersection-over-union style overlap score.
    ///
    /// Computed as `vol(self ∩ other) / (vol(self) + vol(other) - vol(self ∩ other))`,
    /// clamped to `[0, 1]`. Returns `0.0` when that union volume is at or below `1e-30`.
    ///
    /// # Errors
    ///
    /// `BoxError::DimensionMismatch` if the boxes differ in dimension.
    fn overlap_prob(&self, other: &Self) -> Result<Self::Scalar, BoxError> {
        const MIN_VOLUME: f32 = 1e-30;
        let inter_vol = self.intersection(other)?.volume()?;
        let self_vol = self.volume()?;
        let other_vol = other.volume()?;
        let union_vol = self_vol + other_vol - inter_vol;
        if union_vol <= MIN_VOLUME {
            return Ok(0.0);
        }
        Ok((inter_vol / union_vol).clamp(0.0, 1.0))
    }

    /// Union, delegated to the inner box. It is not recomputed with Gumbel smoothing here.
    fn union(&self, other: &Self) -> Result<Self, BoxError> {
        Ok(Self {
            inner: self.inner.union(&other.inner)?,
        })
    }

    /// Center point, delegated to the inner box.
    fn center(&self) -> Result<Self::Vector, BoxError> {
        self.inner.center()
    }

    /// Distance between boxes, delegated to the inner box.
    fn distance(&self, other: &Self) -> Result<Self::Scalar, BoxError> {
        self.inner.distance(&other.inner)
    }

    /// Keeps the first `k` dimensions, delegated to the inner box.
    fn truncate(&self, k: usize) -> Result<Self, BoxError> {
        Ok(Self {
            inner: self.inner.truncate(k)?,
        })
    }
}
impl NdarrayGumbelBox {
    /// Temperature stored in the inner box.
    pub fn temperature(&self) -> f32 {
        self.inner.temperature
    }

    /// Soft membership of `point` in the box.
    ///
    /// Returns the product over dimensions of
    /// `gumbel_membership_prob(point[i], min[i], max[i], temperature)`.
    ///
    /// # Errors
    ///
    /// `BoxError::DimensionMismatch` if `point.len()` differs from the box dimension.
    pub fn membership_probability(&self, point: &Array1<f32>) -> Result<f32, BoxError> {
        if point.len() != self.dim() {
            return Err(BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: point.len(),
            });
        }
        let temp = self.temperature();
        let prob = point
            .iter()
            .zip(self.min().iter())
            .zip(self.max().iter())
            .map(|((&coord, &lo), &hi)| gumbel_membership_prob(coord, lo, hi, temp))
            .product();
        Ok(prob)
    }

    /// Draws one point from the box's Gumbel distribution.
    ///
    /// Each dimension takes a uniform draw from a linear congruential generator (mod 2^32).
    /// The draw is mapped through `sample_gumbel` and then `map_gumbel_to_bounds` with the
    /// dimension's bounds and the stored temperature.
    ///
    /// The generator is seeded from the system clock mixed with a per-process call
    /// counter, so results are not reproducible and are not suitable for cryptography.
    pub fn sample(&self) -> Array1<f32> {
        // Numerical Recipes LCG parameters; state is kept modulo 2^32.
        const A: u64 = 1664525;
        const C: u64 = 1013904223;
        const M: u64 = 1u64 << 32;
        // 2^24: every integer below this is exactly representable as f32, so the
        // quotient below lies in [0, 1) and can never round up to 1.0.
        const UNIT_SCALE: f32 = 16_777_216.0;
        // Distinguishes calls that land on the same clock tick. Only uniqueness of the
        // returned value matters, so Relaxed ordering is sufficient.
        static CALLS: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

        let temp = self.temperature();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        let call = CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed) as u64;
        // Odd multiplier: distinct call counts map to distinct low-32-bit salts.
        let mut seed = nanos ^ call.wrapping_mul(0x9E37_79B9_7F4A_7C15);

        self.min()
            .iter()
            .zip(self.max().iter())
            .map(|(&lo, &hi)| {
                seed = A.wrapping_mul(seed).wrapping_add(C) % M;
                // Use the high 24 bits: the low bits of a power-of-two-modulus LCG have
                // short periods, and 24 bits convert to f32 exactly.
                let u = (seed >> 8) as f32 / UNIT_SCALE;
                let gumbel = sample_gumbel(u, 1e-7);
                map_gumbel_to_bounds(gumbel, lo, hi, temp)
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    //! Example-based tests for `NdarrayGumbelBox`: membership, containment, volume,
    //! intersection, serde round-trip and delegation to the inner box.
    use super::*;
    use crate::Box as BoxTrait;
    use ndarray::array;

    #[test]
    fn membership_prob_inside_point_is_high() {
        let gb = NdarrayGumbelBox::new(array![0.0, 0.0], array![2.0, 2.0], 1.0).unwrap();
        let point = array![1.0, 1.0];
        let p = gb.membership_probability(&point).unwrap();
        assert!(
            p > 0.2,
            "Center point should have non-trivial membership, got {}",
            p
        );
        assert!(p <= 1.0, "Membership must be <= 1.0");
    }

    #[test]
    fn membership_prob_far_outside_is_low() {
        let gb = NdarrayGumbelBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap();
        let point = array![10.0, 10.0];
        let p = gb.membership_probability(&point).unwrap();
        assert!(
            p < 0.01,
            "Far-outside point should have near-zero membership, got {}",
            p
        );
        assert!(p >= 0.0, "Membership must be >= 0.0");
    }

    #[test]
    fn membership_prob_always_in_unit_interval() {
        let gb =
            NdarrayGumbelBox::new(array![0.0, 0.0, 0.0], array![1.0, 1.0, 1.0], 0.5).unwrap();
        let test_points = vec![
            array![0.5, 0.5, 0.5],
            array![-5.0, -5.0, -5.0],
            array![10.0, 10.0, 10.0],
            array![0.0, 0.0, 0.0],
            array![1.0, 1.0, 1.0],
        ];
        for pt in &test_points {
            let p = gb.membership_probability(pt).unwrap();
            assert!(
                (0.0..=1.0).contains(&p),
                "Membership {} out of [0,1] for point {:?}",
                p,
                pt
            );
        }
    }

    #[test]
    fn containment_monotonicity_nested_gumbel_boxes() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![10.0, 10.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![1.0, 1.0], array![9.0, 9.0], 1.0).unwrap();
        let c = NdarrayGumbelBox::new(array![2.0, 2.0], array![8.0, 8.0], 1.0).unwrap();
        let p_b_in_a = a.containment_prob(&b).unwrap();
        let p_c_in_a = a.containment_prob(&c).unwrap();
        let p_c_in_b = b.containment_prob(&c).unwrap();
        assert!(p_b_in_a > 0.7, "B should be inside A, got {}", p_b_in_a);
        assert!(p_c_in_a > 0.7, "C should be inside A, got {}", p_c_in_a);
        assert!(p_c_in_b > 0.7, "C should be inside B, got {}", p_c_in_b);
        let a_sharp = NdarrayGumbelBox::new(array![0.0, 0.0], array![10.0, 10.0], 0.01).unwrap();
        let b_sharp = NdarrayGumbelBox::new(array![1.0, 1.0], array![9.0, 9.0], 0.01).unwrap();
        let p_sharp = a_sharp.containment_prob(&b_sharp).unwrap();
        assert!(
            p_sharp > 0.99,
            "At low T, containment should be ~1.0, got {}",
            p_sharp
        );
    }

    #[test]
    fn low_temperature_sharpens_membership() {
        let gb_sharp = NdarrayGumbelBox::new(array![0.0, 0.0], array![2.0, 2.0], 0.01).unwrap();
        let gb_soft = NdarrayGumbelBox::new(array![0.0, 0.0], array![2.0, 2.0], 100.0).unwrap();
        let inside = array![1.0, 1.0];
        let p_sharp = gb_sharp.membership_probability(&inside).unwrap();
        let p_soft = gb_soft.membership_probability(&inside).unwrap();
        assert!(
            p_sharp > p_soft || (p_sharp - p_soft).abs() < 0.05,
            "Low temp should give sharper membership for interior: sharp={}, soft={}",
            p_sharp,
            p_soft
        );
        let outside = array![5.0, 5.0];
        let p_sharp_out = gb_sharp.membership_probability(&outside).unwrap();
        let p_soft_out = gb_soft.membership_probability(&outside).unwrap();
        assert!(
            p_sharp_out < p_soft_out || (p_sharp_out - p_soft_out).abs() < 0.05,
            "Low temp should give lower membership for exterior: sharp={}, soft={}",
            p_sharp_out,
            p_soft_out
        );
    }

    #[test]
    fn gumbel_box_serde_round_trip() {
        let original = NdarrayGumbelBox::new(array![0.1, 0.2], array![0.8, 0.9], 0.5).unwrap();
        let json = serde_json::to_string(&original).expect("serialize");
        let deserialized: NdarrayGumbelBox = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(original.dim(), deserialized.dim());
        assert_eq!(original.temperature(), deserialized.temperature());
        for i in 0..original.dim() {
            assert!(
                (original.min()[i] - deserialized.min()[i]).abs() < 1e-7,
                "min mismatch at dim {}",
                i
            );
            assert!(
                (original.max()[i] - deserialized.max()[i]).abs() < 1e-7,
                "max mismatch at dim {}",
                i
            );
        }
    }

    #[test]
    fn sample_produces_finite_values() {
        let gb =
            NdarrayGumbelBox::new(array![0.0, 0.0, 0.0], array![1.0, 1.0, 1.0], 1.0).unwrap();
        let s = gb.sample();
        assert_eq!(s.len(), 3);
        for &v in s.iter() {
            assert!(v.is_finite(), "Sampled value must be finite, got {}", v);
        }
    }

    #[test]
    fn gumbel_box_dim_mismatch_error() {
        let result = NdarrayGumbelBox::new(array![0.0, 0.0], array![1.0], 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn gumbel_box_invalid_bounds_error() {
        let result = NdarrayGumbelBox::new(array![5.0], array![1.0], 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn membership_prob_dimension_mismatch() {
        let gb = NdarrayGumbelBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap();
        let point = array![0.5];
        let result = gb.membership_probability(&point);
        assert!(result.is_err());
    }

    #[test]
    fn temperature_accessor_returns_construction_value() {
        let gb = NdarrayGumbelBox::new(array![0.0], array![1.0], 0.42).unwrap();
        assert!((gb.temperature() - 0.42).abs() < 1e-7);
    }

    #[test]
    fn membership_at_boundary_is_near_half_per_dim() {
        let gb = NdarrayGumbelBox::new(array![0.0], array![10.0], 1.0).unwrap();
        let p_at_min = gb.membership_probability(&array![0.0]).unwrap();
        let p_at_max = gb.membership_probability(&array![10.0]).unwrap();
        assert!(
            (p_at_min - 0.5).abs() < 0.05,
            "At min boundary expected ~0.5, got {}",
            p_at_min
        );
        assert!(
            (p_at_max - 0.5).abs() < 0.05,
            "At max boundary expected ~0.5, got {}",
            p_at_max
        );
    }

    #[test]
    fn very_low_temperature_membership_approaches_hard() {
        let gb = NdarrayGumbelBox::new(array![0.0, 0.0], array![4.0, 4.0], 0.001).unwrap();
        let center = array![2.0, 2.0];
        let p = gb.membership_probability(&center).unwrap();
        assert!(
            p > 0.99,
            "At very low temp, center should have membership ~1.0, got {}",
            p
        );
    }

    #[test]
    fn very_low_temperature_outside_membership_approaches_zero() {
        let gb = NdarrayGumbelBox::new(array![0.0, 0.0], array![4.0, 4.0], 0.001).unwrap();
        let outside = array![-1.0, -1.0];
        let p = gb.membership_probability(&outside).unwrap();
        assert!(
            p < 0.01,
            "At very low temp, outside should have membership ~0.0, got {}",
            p
        );
    }

    #[test]
    fn gumbel_box_intersection_uses_lse() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![2.0, 2.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![1.0, 1.0], array![3.0, 3.0], 1.0).unwrap();
        let inter = a.intersection(&b).unwrap();
        let vol = inter.volume().unwrap();
        assert!(
            vol > 0.0,
            "intersection volume should be positive, got {}",
            vol
        );
        assert!(
            vol < 1.0,
            "LSE intersection volume should be < hard vol (1.0), got {}",
            vol
        );
        let a_sharp = NdarrayGumbelBox::new(array![0.0, 0.0], array![2.0, 2.0], 0.01).unwrap();
        let b_sharp = NdarrayGumbelBox::new(array![1.0, 1.0], array![3.0, 3.0], 0.01).unwrap();
        let inter_sharp = a_sharp.intersection(&b_sharp).unwrap();
        let vol_sharp = inter_sharp.volume().unwrap();
        assert!(
            (vol_sharp - 1.0).abs() < 0.1,
            "At low T, intersection volume should be ~1.0, got {}",
            vol_sharp
        );
    }

    #[test]
    fn gumbel_box_union_delegates() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap();
        let u = a.union(&a).unwrap();
        let vol_a = a.volume().unwrap();
        let vol_u = u.volume().unwrap();
        assert!((vol_a - vol_u).abs() < 1e-6);
    }

    #[test]
    fn gumbel_box_center_delegates() {
        let gb = NdarrayGumbelBox::new(array![0.0, 4.0], array![2.0, 8.0], 1.0).unwrap();
        let c = gb.center().unwrap();
        assert!((c[0] - 1.0).abs() < 1e-6);
        assert!((c[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn gumbel_box_distance_delegates() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![3.0, 0.0], array![4.0, 1.0], 1.0).unwrap();
        let d = a.distance(&b).unwrap();
        assert!(
            (d - 2.0).abs() < 1e-5,
            "Gap of 2 in x only, expected 2.0, got {}",
            d
        );
    }

    #[test]
    fn gumbel_box_truncate_delegates() {
        let gb =
            NdarrayGumbelBox::new(array![0.0, 1.0, 2.0], array![3.0, 4.0, 5.0], 0.5).unwrap();
        let t = gb.truncate(2).unwrap();
        assert_eq!(t.dim(), 2);
        assert!((t.min()[0] - 0.0).abs() < 1e-7);
        assert!((t.max()[1] - 4.0).abs() < 1e-7);
    }

    #[test]
    fn membership_decreases_moving_away_from_center() {
        let gb = NdarrayGumbelBox::new(array![0.0, 0.0], array![4.0, 4.0], 1.0).unwrap();
        let center = array![2.0, 2.0];
        let near_edge = array![3.5, 3.5];
        let outside = array![6.0, 6.0];
        let p_center = gb.membership_probability(&center).unwrap();
        let p_edge = gb.membership_probability(&near_edge).unwrap();
        let p_out = gb.membership_probability(&outside).unwrap();
        assert!(
            p_center > p_edge,
            "Center ({}) should have higher membership than near-edge ({})",
            p_center,
            p_edge
        );
        assert!(
            p_edge > p_out,
            "Near-edge ({}) should have higher membership than outside ({})",
            p_edge,
            p_out
        );
    }

    #[test]
    fn bessel_volume_monotone_in_side_length() {
        let small = NdarrayGumbelBox::new(array![0.0], array![2.0], 1.0).unwrap();
        let large = NdarrayGumbelBox::new(array![0.0], array![5.0], 1.0).unwrap();
        let v_small = small.volume().unwrap();
        let v_large = large.volume().unwrap();
        assert!(
            v_large > v_small,
            "larger box should have larger vol: {v_large} vs {v_small}"
        );
    }

    #[test]
    fn bessel_volume_positive_for_nonempty_box() {
        let b = NdarrayGumbelBox::new(array![0.0, 0.0, 0.0], array![1.0, 1.0, 1.0], 1.0).unwrap();
        let v = b.volume().unwrap();
        assert!(
            v > 0.0,
            "non-empty box should have positive volume, got {v}"
        );
        assert!(v.is_finite(), "volume should be finite");
    }

    #[test]
    fn bessel_volume_approaches_hard_at_low_temperature() {
        let b = NdarrayGumbelBox::new(array![0.0, 0.0], array![3.0, 4.0], 0.01).unwrap();
        let v = b.volume().unwrap();
        let hard_vol = 3.0 * 4.0;
        assert!(
            (v - hard_vol).abs() / hard_vol < 0.05,
            "At low T, Bessel vol ({v}) should be close to hard vol ({hard_vol})"
        );
    }

    #[test]
    fn bessel_volume_smaller_than_hard_at_high_temperature() {
        let b = NdarrayGumbelBox::new(array![0.0, 0.0], array![5.0, 5.0], 1.0).unwrap();
        let bessel_vol = b.volume().unwrap();
        let hard_vol = 5.0 * 5.0;
        assert!(
            bessel_vol < hard_vol,
            "Bessel vol ({bessel_vol}) should be < hard vol ({hard_vol}) due to gamma offset"
        );
    }

    #[test]
    fn lse_intersection_symmetric() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![3.0, 3.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![1.0, 1.0], array![4.0, 4.0], 1.0).unwrap();
        let ab = a.intersection(&b).unwrap();
        let ba = b.intersection(&a).unwrap();
        for d in 0..2 {
            assert!(
                (ab.min()[d] - ba.min()[d]).abs() < 1e-6,
                "intersection min should be symmetric at dim {d}"
            );
            assert!(
                (ab.max()[d] - ba.max()[d]).abs() < 1e-6,
                "intersection max should be symmetric at dim {d}"
            );
        }
    }

    #[test]
    fn lse_intersection_approaches_hard_at_low_temperature() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![3.0, 3.0], 0.01).unwrap();
        let b = NdarrayGumbelBox::new(array![1.0, 1.0], array![4.0, 4.0], 0.01).unwrap();
        let inter = a.intersection(&b).unwrap();
        assert!(
            (inter.min()[0] - 1.0).abs() < 0.05,
            "LSE min should be ~1.0 at low T, got {}",
            inter.min()[0]
        );
        assert!(
            (inter.max()[0] - 3.0).abs() < 0.05,
            "LSE max should be ~3.0 at low T, got {}",
            inter.max()[0]
        );
    }

    #[test]
    fn lse_intersection_bounds_are_inside_parents() {
        let a = NdarrayGumbelBox::new(array![0.0, 1.0], array![5.0, 6.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![2.0, 0.0], array![4.0, 7.0], 1.0).unwrap();
        let inter = a.intersection(&b).unwrap();
        assert!(
            inter.min()[0] >= 2.0 - 1e-6,
            "lse min[0] ({}) should be >= 2.0",
            inter.min()[0]
        );
        assert!(
            inter.min()[1] >= 1.0 - 1e-6,
            "lse min[1] ({}) should be >= 1.0",
            inter.min()[1]
        );
        assert!(
            inter.max()[0] <= 4.0 + 1e-6,
            "lse max[0] ({}) should be <= 4.0",
            inter.max()[0]
        );
        assert!(
            inter.max()[1] <= 6.0 + 1e-6,
            "lse max[1] ({}) should be <= 6.0",
            inter.max()[1]
        );
    }

    #[test]
    fn disjoint_boxes_have_near_zero_gumbel_containment() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![1.0, 1.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![10.0, 10.0], array![11.0, 11.0], 1.0).unwrap();
        let p = a.containment_prob(&b).unwrap();
        assert!(
            p < 0.01,
            "disjoint boxes should have near-zero containment, got {p}"
        );
    }

    #[test]
    fn gumbel_overlap_prob_reasonable() {
        let a = NdarrayGumbelBox::new(array![0.0, 0.0], array![3.0, 3.0], 1.0).unwrap();
        let b = NdarrayGumbelBox::new(array![1.0, 1.0], array![4.0, 4.0], 1.0).unwrap();
        let p = a.overlap_prob(&b).unwrap();
        assert!(
            p > 0.0,
            "overlapping boxes should have positive overlap, got {p}"
        );
        assert!(p <= 1.0, "overlap should be <= 1.0");
        let c = NdarrayGumbelBox::new(array![10.0, 10.0], array![11.0, 11.0], 1.0).unwrap();
        let p_disjoint = a.overlap_prob(&c).unwrap();
        assert!(
            p_disjoint < 0.01,
            "disjoint boxes should have near-zero overlap, got {p_disjoint}"
        );
    }
}
#[cfg(test)]
mod proptest_tests {
    //! Property-based tests for `NdarrayGumbelBox`. Boxes are generated as `(lo, width)`
    //! pairs per dimension, with `max = lo + width`.
    use super::*;
    use ndarray::Array1;
    use proptest::prelude::*;

    /// Splits `(lo, width)` pairs into parallel `(mins, maxs)` vectors with `max = lo + width`.
    fn bounds_from_pairs(pairs: &[(f32, f32)]) -> (Vec<f32>, Vec<f32>) {
        pairs.iter().map(|&(lo, width)| (lo, lo + width)).unzip()
    }

    /// Builds a Gumbel box from explicit corners; panics (failing the case) if rejected.
    fn gumbel_box(mins: &[f32], maxs: &[f32], temp: f32) -> NdarrayGumbelBox {
        NdarrayGumbelBox::new(Array1::from(mins.to_vec()), Array1::from(maxs.to_vec()), temp)
            .unwrap()
    }

    proptest! {
        // Membership is always a finite value in [0, 1].
        #[test]
        fn proptest_membership_in_unit_interval(
            box_pairs in proptest::collection::vec((-20.0f32..20.0f32, 1.0f32..10.0f32), 1..=5),
            temp in 0.01f32..10.0f32,
            point_coords in proptest::collection::vec(-30.0f32..30.0f32, 1..=5),
        ) {
            let dim = box_pairs.len();
            prop_assume!(point_coords.len() >= dim);
            let (mins, maxs) = bounds_from_pairs(&box_pairs);
            let gb = gumbel_box(&mins, &maxs, temp);
            let point = Array1::from(point_coords[..dim].to_vec());
            let p = gb.membership_probability(&point).unwrap();
            prop_assert!(
                (0.0..=1.0).contains(&p),
                "membership_probability must be in [0,1], got {} (temp={}, point={:?})",
                p, temp, point
            );
            prop_assert!(p.is_finite(), "membership must be finite, got {}", p);
        }
    }

    proptest! {
        // The box center is at least as much a member as the lower corner.
        #[test]
        fn proptest_center_membership_gt_boundary(
            box_pairs in proptest::collection::vec((-10.0f32..10.0f32, 2.0f32..10.0f32), 1..=4),
            temp in 0.1f32..5.0f32,
        ) {
            let (mins, maxs) = bounds_from_pairs(&box_pairs);
            let gb = gumbel_box(&mins, &maxs, temp);
            let center: Vec<f32> = mins.iter().zip(maxs.iter())
                .map(|(lo, hi)| (lo + hi) / 2.0)
                .collect();
            let p_center = gb.membership_probability(&Array1::from(center)).unwrap();
            let p_boundary = gb.membership_probability(&Array1::from(mins)).unwrap();
            prop_assert!(
                p_center >= p_boundary - 1e-6,
                "center membership ({}) should be >= boundary membership ({}), temp={}",
                p_center, p_boundary, temp
            );
        }
    }

    proptest! {
        // For a point outside the box, raising the temperature does not lower membership.
        #[test]
        fn proptest_temperature_monotonicity_outside(
            box_pairs in proptest::collection::vec((-10.0f32..10.0f32, 2.0f32..10.0f32), 1..=3),
            temp_lo in 0.1f32..1.0f32,
            temp_delta in 1.0f32..50.0f32,
            offset in 1.0f32..5.0f32,
        ) {
            let temp_hi = temp_lo + temp_delta;
            let (mins, maxs) = bounds_from_pairs(&box_pairs);
            let gb_lo = gumbel_box(&mins, &maxs, temp_lo);
            let gb_hi = gumbel_box(&mins, &maxs, temp_hi);
            let outside: Vec<f32> = mins.iter().map(|lo| lo - offset).collect();
            let outside_pt = Array1::from(outside);
            let p_lo = gb_lo.membership_probability(&outside_pt).unwrap();
            let p_hi = gb_hi.membership_probability(&outside_pt).unwrap();
            prop_assert!(
                p_hi >= p_lo - 1e-5,
                "higher temp ({}) should give >= membership for outside point than lower temp ({}): p_hi={}, p_lo={}",
                temp_hi, temp_lo, p_hi, p_lo
            );
        }
    }

    proptest! {
        // Soft volume is finite and non-negative.
        #[test]
        fn proptest_bessel_volume_non_negative(
            box_pairs in proptest::collection::vec((-10.0f32..10.0f32, 0.5f32..10.0f32), 1..=5),
            temp in 0.01f32..5.0f32,
        ) {
            let (mins, maxs) = bounds_from_pairs(&box_pairs);
            let gb = gumbel_box(&mins, &maxs, temp);
            let vol = gb.volume().unwrap();
            prop_assert!(vol >= 0.0, "Bessel volume must be >= 0, got {vol}");
            prop_assert!(vol.is_finite(), "Bessel volume must be finite, got {vol}");
        }
    }

    proptest! {
        // Containment probability is a finite value in [0, 1].
        #[test]
        fn proptest_gumbel_containment_in_unit_interval(
            a_pairs in proptest::collection::vec((-5.0f32..5.0f32, 1.0f32..5.0f32), 1..=3),
            b_pairs in proptest::collection::vec((-5.0f32..5.0f32, 1.0f32..5.0f32), 1..=3),
            temp in 0.1f32..3.0f32,
        ) {
            let dim = a_pairs.len().min(b_pairs.len());
            prop_assume!(dim > 0);
            let (a_min, a_max) = bounds_from_pairs(&a_pairs[..dim]);
            let (b_min, b_max) = bounds_from_pairs(&b_pairs[..dim]);
            let a = gumbel_box(&a_min, &a_max, temp);
            let b = gumbel_box(&b_min, &b_max, temp);
            let p = a.containment_prob(&b).unwrap();
            prop_assert!(p >= -1e-6, "containment prob must be >= 0, got {p}");
            prop_assert!(p <= 1.0 + 1e-6, "containment prob must be <= 1, got {p}");
            prop_assert!(p.is_finite(), "containment prob must be finite, got {p}");
        }
    }

    proptest! {
        // With equal temperatures, intersection volume does not depend on argument order.
        #[test]
        fn proptest_lse_intersection_symmetric(
            a_pairs in proptest::collection::vec((-5.0f32..5.0f32, 1.0f32..5.0f32), 1..=3),
            b_pairs in proptest::collection::vec((-5.0f32..5.0f32, 1.0f32..5.0f32), 1..=3),
            temp in 0.1f32..3.0f32,
        ) {
            let dim = a_pairs.len().min(b_pairs.len());
            prop_assume!(dim > 0);
            let (a_min, a_max) = bounds_from_pairs(&a_pairs[..dim]);
            let (b_min, b_max) = bounds_from_pairs(&b_pairs[..dim]);
            let a = gumbel_box(&a_min, &a_max, temp);
            let b = gumbel_box(&b_min, &b_max, temp);
            let ab = a.intersection(&b).unwrap();
            let ba = b.intersection(&a).unwrap();
            let vol_ab = ab.volume().unwrap();
            let vol_ba = ba.volume().unwrap();
            prop_assert!(
                (vol_ab - vol_ba).abs() < 1e-4,
                "intersection volume should be symmetric: {vol_ab} vs {vol_ba}"
            );
        }
    }

    proptest! {
        // LSE intersection bounds never extend beyond the hard intersection bounds.
        #[test]
        fn proptest_lse_intersection_within_parents(
            a_pairs in proptest::collection::vec((-5.0f32..5.0f32, 0.5f32..5.0f32), 1..=4),
            b_pairs in proptest::collection::vec((-5.0f32..5.0f32, 0.5f32..5.0f32), 1..=4),
            temp in 0.01f32..3.0f32,
        ) {
            let dim = a_pairs.len().min(b_pairs.len());
            prop_assume!(dim > 0);
            let (a_min, a_max) = bounds_from_pairs(&a_pairs[..dim]);
            let (b_min, b_max) = bounds_from_pairs(&b_pairs[..dim]);
            let a = gumbel_box(&a_min, &a_max, temp);
            let b = gumbel_box(&b_min, &b_max, temp);
            let inter = a.intersection(&b).unwrap();
            for d in 0..dim {
                let hard_min = a_min[d].max(b_min[d]);
                let hard_max = a_max[d].min(b_max[d]);
                prop_assert!(
                    inter.min()[d] >= hard_min - 1e-5,
                    "dim {d}: LSE intersection min {} < hard min {hard_min}",
                    inter.min()[d]
                );
                prop_assert!(
                    inter.max()[d] <= hard_max + 1e-5,
                    "dim {d}: LSE intersection max {} > hard max {hard_max}",
                    inter.max()[d]
                );
            }
        }
    }

    proptest! {
        // At low temperature the soft volume stays positive and does not exceed the hard volume.
        #[test]
        fn proptest_bessel_volume_bounded_by_hard_volume_at_low_t(
            pairs in proptest::collection::vec((-5.0f32..5.0f32, 2.0f32..10.0f32), 1..=4),
            temp in 0.001f32..0.1f32,
        ) {
            let (mins, maxs) = bounds_from_pairs(&pairs);
            let hard_vol: f32 = pairs.iter().map(|&(_, w)| w).product();
            let gb = gumbel_box(&mins, &maxs, temp);
            let bessel_vol = gb.volume().unwrap();
            prop_assert!(
                bessel_vol <= hard_vol + 1e-3,
                "Bessel vol ({bessel_vol}) should be <= hard vol ({hard_vol}) at T={temp}"
            );
            prop_assert!(
                bessel_vol > 0.0,
                "Bessel vol should be positive, got {bessel_vol}"
            );
        }
    }

    proptest! {
        // A box strictly nested inside another with margin is contained with high
        // probability at low temperature.
        #[test]
        fn proptest_gumbel_containment_approaches_hard_at_low_t(
            b_pairs in proptest::collection::vec((-3.0f32..3.0f32, 0.5f32..2.0f32), 1..=3),
            margin in 0.5f32..3.0f32,
        ) {
            let (b_min, b_max) = bounds_from_pairs(&b_pairs);
            let a_min: Vec<f32> = b_min.iter().map(|lo| lo - margin).collect();
            let a_max: Vec<f32> = b_max.iter().map(|hi| hi + margin).collect();
            let t_low = 0.01f32;
            let a = gumbel_box(&a_min, &a_max, t_low);
            let b = gumbel_box(&b_min, &b_max, t_low);
            let containment = a.containment_prob(&b).unwrap();
            prop_assert!(
                containment > 0.9,
                "hard containment=1.0 but Gumbel containment at T={t_low} is only {containment}"
            );
        }
    }

    proptest! {
        // Volume is finite and non-negative across stored temperatures and up to 8 dims.
        // `volume()` only uses the stored temperature, so no separate query temperature
        // is generated.
        #[test]
        fn proptest_volume_non_negative_any_temperature(
            box_pairs in proptest::collection::vec((-10.0f32..10.0f32, 0.1f32..10.0f32), 1..=8),
            stored_temp in 0.1f32..10.0f32,
        ) {
            let (mins, maxs) = bounds_from_pairs(&box_pairs);
            let gb = gumbel_box(&mins, &maxs, stored_temp);
            let vol = gb.volume().unwrap();
            prop_assert!(vol >= 0.0,
                "Gumbel volume must be >= 0 for any temp, got {vol} (stored_t={stored_temp})");
            prop_assert!(vol.is_finite(),
                "Gumbel volume must be finite, got {vol}");
        }
    }
}