use pounce_common::types::Number;
/// A closed interval `[lo, hi]` of `Number`s, used as an outward-rounded
/// enclosure of a real quantity.
///
/// Infinite bounds denote an unbounded side (see `Interval::ENTIRE`). The
/// interval is considered empty when `lo > hi` or either bound is NaN;
/// `Interval::EMPTY` encodes this as `lo = +inf`, `hi = -inf`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Interval {
    /// Inclusive lower bound.
    pub lo: Number,
    /// Inclusive upper bound.
    pub hi: Number,
}
impl Interval {
    /// The whole extended real line, `[-inf, +inf]`.
    pub const ENTIRE: Interval = Interval {
        lo: Number::NEG_INFINITY,
        hi: Number::INFINITY,
    };
    /// The canonical empty interval, encoded as `lo = +inf > hi = -inf`.
    pub const EMPTY: Interval = Interval {
        lo: Number::INFINITY,
        hi: Number::NEG_INFINITY,
    };

    /// Builds `[lo, hi]`, returning [`Interval::EMPTY`] if either bound is NaN
    /// or `lo > hi`.
    pub fn new(lo: Number, hi: Number) -> Self {
        if lo.is_nan() || hi.is_nan() || lo > hi {
            return Self::EMPTY;
        }
        Self { lo, hi }
    }

    /// The degenerate interval `[x, x]`; a NaN `x` yields [`Interval::EMPTY`].
    pub fn point(x: Number) -> Self {
        if x.is_nan() {
            return Self::EMPTY;
        }
        Self { lo: x, hi: x }
    }

    /// True when the interval holds no value: `lo > hi` or either bound is NaN.
    pub fn is_empty(&self) -> bool {
        self.lo > self.hi || self.lo.is_nan() || self.hi.is_nan()
    }

    /// True when the interval is exactly `[-inf, +inf]`.
    pub fn is_entire(&self) -> bool {
        self.lo == Number::NEG_INFINITY && self.hi == Number::INFINITY
    }

    /// True when `lo <= x <= hi`; always false for an empty interval or NaN `x`.
    pub fn contains(&self, x: Number) -> bool {
        !self.is_empty() && self.lo <= x && x <= self.hi
    }

    /// True when `0` lies in the interval.
    pub fn contains_zero(&self) -> bool {
        self.contains(0.0)
    }

    /// `hi - lo`, or `0` for an empty interval. The subtraction is ordinary
    /// round-to-nearest, not rounded outward.
    pub fn width(&self) -> Number {
        if self.is_empty() {
            0.0
        } else {
            self.hi - self.lo
        }
    }

    /// Set intersection; empty if either operand is empty or they are disjoint.
    pub fn intersect(self, other: Self) -> Self {
        if self.is_empty() || other.is_empty() {
            return Self::EMPTY;
        }
        Self::new(self.lo.max(other.lo), self.hi.min(other.hi))
    }

    /// Smallest interval containing both operands; an empty operand is ignored.
    pub fn hull(self, other: Self) -> Self {
        if self.is_empty() {
            return other;
        }
        if other.is_empty() {
            return self;
        }
        Self::new(self.lo.min(other.lo), self.hi.max(other.hi))
    }

    /// Enclosure of `{x + y}`; each bound is stepped one ULP outward.
    pub fn add(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: round_down(self.lo + rhs.lo),
            hi: round_up(self.hi + rhs.hi),
        }
    }

    /// Enclosure of `{x - y}`: `[lo - rhs.hi, hi - rhs.lo]`, stepped outward.
    pub fn sub(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: round_down(self.lo - rhs.hi),
            hi: round_up(self.hi - rhs.lo),
        }
    }

    /// Exact negation `[-hi, -lo]` (no rounding is needed).
    pub fn neg(self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: -self.hi,
            hi: -self.lo,
        }
    }

    /// Enclosure of `{x * y}`: the hull of the four endpoint products, stepped
    /// one ULP outward.
    pub fn mul(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        // IEEE-754 defines `0 * ±inf = NaN`, but a zero endpoint means 0 belongs
        // to that operand and `0 * y = 0` for every real `y`, so the product
        // contributes exactly 0. Without this, e.g. `[0, 0] * ENTIRE` produced
        // four NaN products, NaN bounds, and was then reported as empty.
        let prod = |x: Number, y: Number| -> Number {
            if x == 0.0 || y == 0.0 {
                0.0
            } else {
                x * y
            }
        };
        let p1 = prod(self.lo, rhs.lo);
        let p2 = prod(self.lo, rhs.hi);
        let p3 = prod(self.hi, rhs.lo);
        let p4 = prod(self.hi, rhs.hi);
        Self {
            lo: round_down(p1.min(p2).min(p3.min(p4))),
            hi: round_up(p1.max(p2).max(p3.max(p4))),
        }
    }

    /// Enclosure of `{x / y}`, computed as `self * [1/rhs.hi, 1/rhs.lo]`.
    ///
    /// If `rhs` contains zero (including the point `[0, 0]`), returns
    /// [`Interval::ENTIRE`] as a conservative enclosure.
    pub fn div(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        if rhs.contains_zero() {
            return Self::ENTIRE;
        }
        self.mul(Self {
            lo: round_down(1.0 / rhs.hi),
            hi: round_up(1.0 / rhs.lo),
        })
    }

    /// Enclosure of `{x^n}` for a non-negative integer exponent.
    ///
    /// `n == 0` returns `[1, 1]` (so `0^0` is taken as 1) and `n == 1` returns
    /// `self` unchanged. Otherwise, the bounds come from square-and-multiply with
    /// every partial product rounded outward, so they are valid for every `u32`.
    pub fn pow_uint(self, n: u32) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        if n == 0 {
            return Self::point(1.0);
        }
        if n == 1 {
            return self;
        }
        // `f64::powi` is not used: its exponent is `i32` (large `u32` values
        // would wrap negative), and std leaves its precision unspecified, so a
        // single outward ULP would not be guaranteed to enclose the result.
        let (a, b) = (self.lo, self.hi);
        if n % 2 == 1 {
            // Odd powers are increasing and odd: x^n = -(|x|^n) for x < 0.
            let lower = |x: Number| -> Number {
                if x >= 0.0 {
                    Self::pow_nonneg_down(x, n)
                } else {
                    -Self::pow_nonneg_up(-x, n)
                }
            };
            let upper = |x: Number| -> Number {
                if x >= 0.0 {
                    Self::pow_nonneg_up(x, n)
                } else {
                    -Self::pow_nonneg_down(-x, n)
                }
            };
            Self {
                lo: lower(a),
                hi: upper(b),
            }
        } else if a >= 0.0 {
            // Even power, non-negative interval: increasing.
            Self {
                lo: Self::pow_nonneg_down(a, n),
                hi: Self::pow_nonneg_up(b, n),
            }
        } else if b <= 0.0 {
            // Even power, non-positive interval: decreasing in x, so use |x|.
            Self {
                lo: Self::pow_nonneg_down(-b, n),
                hi: Self::pow_nonneg_up(-a, n),
            }
        } else {
            // Even power, straddling zero: minimum 0, maximum at the larger |endpoint|.
            Self {
                lo: 0.0,
                hi: Self::pow_nonneg_up((-a).max(b), n),
            }
        }
    }

    /// Enclosure of `{sqrt(x)}` over the non-negative part of the interval;
    /// empty when the whole interval is negative.
    pub fn sqrt(self) -> Self {
        if self.is_empty() || self.hi < 0.0 {
            return Self::EMPTY;
        }
        Self {
            // sqrt is non-negative, so the outward step never needs to go below 0.
            lo: round_down(self.lo.max(0.0).sqrt()).max(0.0),
            hi: round_up(self.hi.sqrt()),
        }
    }

    /// Enclosure of `{exp(x)}` (exp is increasing).
    ///
    /// NOTE(review): one outward ULP assumes `exp` is accurate to within one
    /// ULP; std does not guarantee this — confirm for target platforms.
    pub fn exp(self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        Self {
            // exp is positive, so a lower bound below 0 is never needed
            // (e.g. when exp(lo) underflows to 0).
            lo: round_down(self.lo.exp()).max(0.0),
            hi: round_up(self.hi.exp()),
        }
    }

    /// Enclosure of `{ln(x)}` over the positive part of the interval.
    ///
    /// Returns an empty interval when `hi <= 0`. If `lo <= 0`, the lower bound
    /// is `-inf`. The same one-ULP accuracy assumption as [`Interval::exp`]
    /// applies.
    pub fn ln(self) -> Self {
        if self.is_empty() || self.hi <= 0.0 {
            return Self::EMPTY;
        }
        let lo = if self.lo <= 0.0 {
            Number::NEG_INFINITY
        } else {
            round_down(self.lo.ln())
        };
        Self {
            lo,
            hi: round_up(self.hi.ln()),
        }
    }

    /// Exact enclosure of `{|x|}`; an interval straddling zero maps to
    /// `[0, max(|lo|, hi)]`.
    pub fn abs(self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        if self.lo >= 0.0 {
            self
        } else if self.hi <= 0.0 {
            self.neg()
        } else {
            Self {
                lo: 0.0,
                hi: self.lo.abs().max(self.hi),
            }
        }
    }

    /// Enclosure of `{sin(x)}`, delegating to the shared periodic-image helper.
    pub fn sin(self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        trig_image(self.lo, self.hi, |x| x.sin(), SIN_PEAKS, SIN_TROUGHS)
    }

    /// Enclosure of `{cos(x)}`, delegating to the shared periodic-image helper.
    pub fn cos(self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        trig_image(self.lo, self.hi, |x| x.cos(), COS_PEAKS, COS_TROUGHS)
    }

    /// Enclosure of `{min(x, y)}`: `[min(lo), min(hi)]`; empty if either is empty.
    pub fn min(self, other: Self) -> Self {
        if self.is_empty() || other.is_empty() {
            return Self::EMPTY;
        }
        Self::new(self.lo.min(other.lo), self.hi.min(other.hi))
    }

    /// Enclosure of `{max(x, y)}`: `[max(lo), max(hi)]`; empty if either is empty.
    pub fn max(self, other: Self) -> Self {
        if self.is_empty() || other.is_empty() {
            return Self::EMPTY;
        }
        Self::new(self.lo.max(other.lo), self.hi.max(other.hi))
    }

    /// Lower bound on `x^n` for `x >= 0`, using binary (square-and-multiply)
    /// exponentiation with every product stepped one ULP toward `-inf`.
    ///
    /// Multiplication of non-negative values is monotone, so lower bounds on the
    /// factors give a lower bound on the product. Results are clamped at 0
    /// because the exact power is non-negative. `next_down` maps an overflowed
    /// `+inf` to `Number::MAX`, which is still a valid lower bound.
    fn pow_nonneg_down(x: Number, n: u32) -> Number {
        let mut acc: Number = 1.0;
        let mut base = x;
        let mut e = n;
        loop {
            if e & 1 == 1 {
                acc = (acc * base).next_down().max(0.0);
            }
            e >>= 1;
            if e == 0 {
                return acc;
            }
            base = (base * base).next_down().max(0.0);
        }
    }

    /// Upper bound on `x^n` for `x >= 0`, the mirror of `pow_nonneg_down` with
    /// every product stepped one ULP toward `+inf`.
    fn pow_nonneg_up(x: Number, n: u32) -> Number {
        // 0^n is exactly 0 for n >= 1; skip the loop to avoid a spurious
        // subnormal upper bound.
        if x == 0.0 && n > 0 {
            return 0.0;
        }
        let mut acc: Number = 1.0;
        let mut base = x;
        let mut e = n;
        loop {
            if e & 1 == 1 {
                acc = (acc * base).next_up();
            }
            e >>= 1;
            if e == 0 {
                return acc;
            }
            base = (base * base).next_up();
        }
    }
}
/// Steps `x` one ULP toward negative infinity to turn a rounded result into a
/// lower bound.
///
/// This assumes `x` is within one ULP of the true value, as for a correctly
/// rounded `+`, `-`, `*`, `/` or `sqrt`. It delegates to [`f64::next_down`]:
/// `-inf` and NaN pass through unchanged, and `+inf` maps to `Number::MAX`.
/// The latter matters when `+inf` comes from overflow of a finite quantity:
/// `MAX` is still a valid lower bound for it, while `+inf` is not.
fn round_down(x: Number) -> Number {
    x.next_down()
}
/// Steps `x` one ULP toward positive infinity to turn a rounded result into an
/// upper bound.
///
/// This assumes `x` is within one ULP of the true value. It delegates to
/// [`f64::next_up`]: `+inf` and NaN pass through unchanged, and `-inf` maps to
/// `Number::MIN`. The latter matters when `-inf` comes from overflow of a finite
/// quantity: `MIN` is still a valid upper bound for it, while `-inf` is not.
fn round_up(x: Number) -> Number {
    x.next_up()
}
fn powi(x: Number, n: i32) -> Number {
x.powi(n)
}
/// One full period of `sin` and `cos` (2π).
const TWO_PI: Number = std::f64::consts::TAU;
/// Phase of a maximum of `sin` (π/2); the others lie at this value plus `2πk`.
const SIN_PEAKS: Number = std::f64::consts::FRAC_PI_2;
/// Phase of a minimum of `sin` (−π/2); the others lie at this value plus `2πk`.
const SIN_TROUGHS: Number = -SIN_PEAKS;
/// Phase of a maximum of `cos` (0).
const COS_PEAKS: Number = 0.0;
/// Phase of a minimum of `cos` (π).
const COS_TROUGHS: Number = std::f64::consts::PI;
fn trig_image<F>(
lo: Number,
hi: Number,
f: F,
peak_offset: Number,
trough_offset: Number,
) -> Interval
where
F: Fn(Number) -> Number,
{
if !lo.is_finite() || !hi.is_finite() {
return Interval::new(-1.0, 1.0);
}
if hi - lo >= TWO_PI {
return Interval::new(-1.0, 1.0);
}
let crosses = |offset: Number| -> bool {
let k = ((lo - offset) / TWO_PI).ceil();
let x = offset + TWO_PI * k;
x <= hi
};
let endpoint_lo = f(lo);
let endpoint_hi = f(hi);
let mut local_min = endpoint_lo.min(endpoint_hi);
let mut local_max = endpoint_lo.max(endpoint_hi);
if crosses(peak_offset) {
local_max = 1.0;
}
if crosses(trough_offset) {
local_min = -1.0;
}
Interval {
lo: round_down(local_min),
hi: round_up(local_max),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed absolute/relative closeness check with tolerance `eps + eps * |b|`.
    fn approx_eq(a: Number, b: Number, eps: Number) -> bool {
        let gap = (a - b).abs();
        gap <= eps + eps * b.abs()
    }

    /// Shorthand for [`Interval::new`].
    fn iv(lo: Number, hi: Number) -> Interval {
        Interval::new(lo, hi)
    }

    #[test]
    fn empty_propagates() {
        let empty = Interval::EMPTY;
        let unit = iv(0.0, 1.0);
        let results = [
            empty.add(unit),
            unit.add(empty),
            empty.mul(unit),
            empty.sqrt(),
            empty.exp(),
        ];
        assert!(results.iter().all(Interval::is_empty));
    }

    #[test]
    fn new_normalizes_malformed() {
        for (lo, hi) in [(1.0, 0.0), (Number::NAN, 1.0), (1.0, Number::NAN)] {
            assert!(iv(lo, hi).is_empty());
        }
    }

    #[test]
    fn entire_is_entire() {
        let entire = Interval::ENTIRE;
        assert!(entire.is_entire() && entire.contains_zero());
        assert!(!Interval::EMPTY.is_entire());
    }

    #[test]
    fn add_widens_outward() {
        let sum = iv(1.0, 2.0).add(iv(3.0, 4.0));
        assert!(sum.lo <= 4.0 && 4.0 - sum.lo < 1e-15);
        assert!(sum.hi >= 6.0 && sum.hi - 6.0 < 1e-15);
    }

    #[test]
    fn sub_uses_cross_endpoints() {
        let diff = iv(1.0, 2.0).sub(iv(3.0, 4.0));
        assert!(diff.lo <= -3.0 && -3.0 - diff.lo < 1e-15);
        assert!(diff.hi >= -1.0 && diff.hi + 1.0 < 1e-15);
    }

    #[test]
    fn mul_handles_sign_crossings() {
        let prod = iv(-2.0, 3.0).mul(iv(-1.0, 4.0));
        for v in [-8.0, 12.0] {
            assert!(prod.contains(v));
        }
        assert!(prod.lo <= -8.0 && prod.hi >= 12.0);
    }

    #[test]
    fn div_by_zero_crossing_yields_entire() {
        assert!(iv(1.0, 2.0).div(iv(-1.0, 1.0)).is_entire());
    }

    #[test]
    fn div_disjoint_from_zero_inverts_correctly() {
        let quot = iv(1.0, 4.0).div(iv(2.0, 4.0));
        assert!(quot.contains(0.25) && quot.contains(2.0));
        assert!(quot.lo <= 0.25 && quot.hi >= 2.0);
    }

    #[test]
    fn pow_uint_even_straddles_zero() {
        let sq = iv(-2.0, 3.0).pow_uint(2);
        assert_eq!(sq.lo, 0.0);
        assert!(sq.hi >= 9.0);
    }

    #[test]
    fn pow_uint_even_negative() {
        let sq = iv(-4.0, -2.0).pow_uint(2);
        assert!(sq.lo <= 4.0 && sq.hi >= 16.0);
    }

    #[test]
    fn pow_uint_odd() {
        let cube = iv(-2.0, 3.0).pow_uint(3);
        assert!(cube.lo <= -8.0 && cube.hi >= 27.0);
    }

    #[test]
    fn pow_zero_and_one() {
        let base = iv(2.0, 5.0);
        let zeroth = base.pow_uint(0);
        assert_eq!((zeroth.lo, zeroth.hi), (1.0, 1.0));
        assert_eq!(base.pow_uint(1), base);
    }

    #[test]
    fn sqrt_clips_negative_lo() {
        let root = iv(-1.0, 4.0).sqrt();
        assert!(root.lo <= 0.0);
        assert!(root.lo >= -1e-300, "outward bump should be at most ~1 ULP");
        assert!(root.hi >= 2.0);
    }

    #[test]
    fn sqrt_of_fully_negative_is_empty() {
        let root = iv(-4.0, -1.0).sqrt();
        assert!(root.is_empty());
    }

    #[test]
    fn exp_is_monotone() {
        let grown = iv(0.0, 1.0).exp();
        assert!([1.0, std::f64::consts::E].iter().all(|&v| grown.contains(v)));
    }

    #[test]
    fn ln_of_non_positive_is_empty() {
        for hi in [-1.0, 0.0] {
            assert!(iv(-2.0, hi).ln().is_empty());
        }
    }

    #[test]
    fn ln_with_zero_lower_yields_neg_inf() {
        let log = iv(0.0, 1.0).ln();
        assert_eq!(log.lo, Number::NEG_INFINITY);
        assert!(log.contains(0.0));
    }

    #[test]
    fn ln_strict_positive() {
        let log = iv(1.0, std::f64::consts::E).ln();
        assert!(log.contains(0.0) && log.contains(1.0));
    }

    #[test]
    fn abs_negative_interval() {
        let mag = iv(-3.0, -1.0).abs();
        assert!(mag.contains(1.0) && mag.contains(3.0));
    }

    #[test]
    fn abs_straddling_interval() {
        let mag = iv(-2.0, 3.0).abs();
        assert_eq!(mag.lo, 0.0);
        assert!(mag.hi >= 3.0);
    }

    #[test]
    fn sin_full_range() {
        let wave = iv(0.0, TWO_PI).sin();
        assert!(approx_eq(wave.lo, -1.0, 1e-15));
        assert!(approx_eq(wave.hi, 1.0, 1e-15));
    }

    #[test]
    fn sin_within_one_branch() {
        let wave = iv(0.0, std::f64::consts::FRAC_PI_2).sin();
        assert!(wave.contains(0.0) && wave.contains(1.0));
    }

    #[test]
    fn cos_at_zero() {
        let wave = iv(-0.1, 0.1).cos();
        assert!(wave.contains(1.0));
        assert!(wave.lo < 1.0);
    }

    #[test]
    fn intersect_disjoint_is_empty() {
        let meet = iv(0.0, 1.0).intersect(iv(2.0, 3.0));
        assert!(meet.is_empty());
    }

    #[test]
    fn intersect_overlap() {
        assert_eq!(iv(0.0, 5.0).intersect(iv(3.0, 10.0)), iv(3.0, 5.0));
    }

    #[test]
    fn hull_combines() {
        assert_eq!(iv(0.0, 1.0).hull(iv(5.0, 6.0)), iv(0.0, 6.0));
    }

    #[test]
    fn min_max_pairs() {
        let (a, b) = (iv(1.0, 5.0), iv(2.0, 7.0));
        let lower = a.min(b);
        let upper = a.max(b);
        assert!(lower.contains(1.0) && lower.contains(5.0));
        assert!(upper.contains(2.0) && upper.contains(7.0));
    }

    #[test]
    fn fuzz_add_contains_pointwise() {
        let cases = [
            ((0.5, 2.5), (1.0, 1.5), 1.5, 2.0),
            ((-3.0, 1.0), (-1.0, 4.0), 0.5, 2.5),
            ((1.0e-10, 1.0e10), (1.0, 1.0), 100.0, 1.0),
        ];
        for ((a, b), (c, d), x, y) in cases {
            let sum = iv(a, b).add(iv(c, d));
            assert!(sum.contains(x + y), "{a},{b} + {c},{d} ∌ {x}+{y}");
        }
    }

    #[test]
    fn fuzz_mul_contains_pointwise() {
        let cases = [
            ((-2.0, 3.0), (-1.0, 4.0), 0.5, 2.0),
            ((1.0, 10.0), (0.1, 0.2), 5.0, 0.15),
            ((-5.0, -1.0), (-3.0, -1.0), -3.0, -2.0),
        ];
        for ((a, b), (c, d), x, y) in cases {
            let prod = iv(a, b).mul(iv(c, d));
            assert!(prod.contains(x * y), "{a},{b} × {c},{d} ∌ {x}×{y}");
        }
    }

    #[test]
    fn fuzz_div_contains_pointwise() {
        let cases = [
            ((0.1, 0.2), (3.0, 4.0), 0.15, 3.5),
            ((1.0, 1.0), (3.0, 3.0), 1.0, 3.0),
            ((-2.0, 5.0), (1.0, 7.0), 1.5, 3.0),
        ];
        for ((a, b), (c, d), x, y) in cases {
            let quot = iv(a, b).div(iv(c, d));
            assert!(quot.contains(x / y), "{a},{b} / {c},{d} ∌ {x}/{y}");
        }
    }

    #[test]
    fn rounding_does_not_shrink_below_truth() {
        let sum = Interval::point(0.1).add(Interval::point(0.2));
        assert!(sum.contains(0.3));
    }
}