use core::fmt;
use core::ops::{Add, Div, Mul, Neg, Sub};
use num_traits::Float;
/// Returns the largest `f64` strictly less than `x` (one step toward −∞).
///
/// * NaN is returned unchanged.
/// * `-∞` has no predecessor and is returned unchanged.
/// * `+0.0` and `-0.0` both map to `-2^-1074`, the negative subnormal closest to zero.
/// * `+∞` maps to `f64::MAX`.
#[inline]
pub fn prev_float(x: f64) -> f64 {
    if x.is_nan() || x == f64::NEG_INFINITY {
        return x;
    }
    if x == 0.0 {
        // Bit pattern 1 is the smallest positive subnormal (2^-1074).
        return -f64::from_bits(1);
    }
    let raw = x.to_bits();
    // Within each sign, IEEE-754 bit patterns are ordered by magnitude: moving toward −∞
    // shrinks a positive magnitude and grows a negative one.
    let stepped = if x.is_sign_positive() { raw - 1 } else { raw + 1 };
    f64::from_bits(stepped)
}
/// Returns the smallest `f64` strictly greater than `x` (one step toward +∞).
///
/// * NaN is returned unchanged.
/// * `+∞` has no successor and is returned unchanged.
/// * `+0.0` and `-0.0` both map to `2^-1074`, the smallest positive subnormal.
/// * `-∞` maps to `f64::MIN` (the most negative finite value).
#[inline]
pub fn next_float(x: f64) -> f64 {
    if x.is_nan() || x == f64::INFINITY {
        return x;
    }
    if x == 0.0 {
        // Bit pattern 1 is the smallest positive subnormal (2^-1074).
        return f64::from_bits(1);
    }
    let raw = x.to_bits();
    // Moving toward +∞ grows a positive magnitude and shrinks a negative one.
    let stepped = if x.is_sign_positive() { raw + 1 } else { raw - 1 };
    f64::from_bits(stepped)
}
/// Steps `x` one representable value toward negative infinity; the generic
/// counterpart of `prev_float`.
///
/// * NaN and `-∞` are returned unchanged.
/// * `+0` and `-0` both map to the negative subnormal closest to zero.
/// * Any other value moves strictly downward by one unit in its bit pattern
///   (obtained through [`FloatBits`]): a positive magnitude shrinks, a
///   negative magnitude grows.
#[inline]
pub fn round_down<T: Float + FloatBits>(x: T) -> T {
    if x.is_nan() || x == T::neg_infinity() {
        return x;
    }
    if x == T::zero() {
        // Zero must also step strictly downward (an underflowed result may stand
        // for a tiny negative value). -0 is just the sign bit, so adding one in
        // bit space yields the negative subnormal closest to zero.
        return T::from_bits(T::wrapping_add_bits(
            T::neg_zero().to_bits(),
            T::one_bits(),
        ));
    }
    T::from_bits(if x > T::zero() {
        T::wrapping_sub_bits(x.to_bits(), T::one_bits())
    } else {
        T::wrapping_add_bits(x.to_bits(), T::one_bits())
    })
}
/// Steps `x` one representable value toward positive infinity; the generic
/// counterpart of `next_float`.
///
/// * NaN and `+∞` are returned unchanged.
/// * `+0` and `-0` both map to the smallest positive subnormal.
/// * Any other value moves strictly upward by one unit in its bit pattern
///   (obtained through [`FloatBits`]): a positive magnitude grows, a
///   negative magnitude shrinks.
#[inline]
pub fn round_up<T: Float + FloatBits>(x: T) -> T {
    if x.is_nan() || x == T::infinity() {
        return x;
    }
    if x == T::zero() {
        // Zero must also step strictly upward. +0 has an all-zero bit pattern,
        // so adding one in bit space yields the smallest positive subnormal.
        return T::from_bits(T::wrapping_add_bits(T::zero().to_bits(), T::one_bits()));
    }
    T::from_bits(if x > T::zero() {
        T::wrapping_add_bits(x.to_bits(), T::one_bits())
    } else {
        T::wrapping_sub_bits(x.to_bits(), T::one_bits())
    })
}
/// Raw bit-pattern access for a floating-point type.
///
/// Generic code uses this to move to an adjacent representable value by
/// adding or subtracting one in bit space. In this file it is implemented for
/// `f32` (with `Bits = u32`) and `f64` (with `Bits = u64`).
pub trait FloatBits: Copy + Sized {
    /// Integer type holding the raw bit pattern of `Self`.
    type Bits: Copy
        + Eq
        + PartialOrd
        + Add<Output = Self::Bits>
        + Sub<Output = Self::Bits>;

    /// Reinterprets `self` as its raw bit pattern.
    fn to_bits(self) -> Self::Bits;

    /// Reinterprets a raw bit pattern as `Self`.
    fn from_bits(bits: Self::Bits) -> Self;

    /// The value one in the `Bits` type (a single step in bit space).
    fn one_bits() -> Self::Bits;

    /// `a + b` on bit patterns, wrapping on overflow.
    fn wrapping_add_bits(a: Self::Bits, b: Self::Bits) -> Self::Bits;

    /// `a - b` on bit patterns, wrapping on underflow.
    fn wrapping_sub_bits(a: Self::Bits, b: Self::Bits) -> Self::Bits;
}
/// `f32` exposes its IEEE-754 single-precision bits as a `u32`.
impl FloatBits for f32 {
    type Bits = u32;

    /// Delegates to the inherent `f32::to_bits`.
    #[inline]
    fn to_bits(self) -> Self::Bits {
        f32::to_bits(self)
    }

    /// Delegates to the inherent `f32::from_bits`.
    #[inline]
    fn from_bits(bits: Self::Bits) -> Self {
        f32::from_bits(bits)
    }

    #[inline]
    fn one_bits() -> Self::Bits {
        1
    }

    #[inline]
    fn wrapping_add_bits(a: Self::Bits, b: Self::Bits) -> Self::Bits {
        u32::wrapping_add(a, b)
    }

    #[inline]
    fn wrapping_sub_bits(a: Self::Bits, b: Self::Bits) -> Self::Bits {
        u32::wrapping_sub(a, b)
    }
}
/// `f64` exposes its IEEE-754 double-precision bits as a `u64`.
impl FloatBits for f64 {
    type Bits = u64;

    /// Delegates to the inherent `f64::to_bits`.
    #[inline]
    fn to_bits(self) -> Self::Bits {
        f64::to_bits(self)
    }

    /// Delegates to the inherent `f64::from_bits`.
    #[inline]
    fn from_bits(bits: Self::Bits) -> Self {
        f64::from_bits(bits)
    }

    #[inline]
    fn one_bits() -> Self::Bits {
        1
    }

    #[inline]
    fn wrapping_add_bits(a: Self::Bits, b: Self::Bits) -> Self::Bits {
        u64::wrapping_add(a, b)
    }

    #[inline]
    fn wrapping_sub_bits(a: Self::Bits, b: Self::Bits) -> Self::Bits {
        u64::wrapping_sub(a, b)
    }
}
/// A closed interval `[lo, hi]` of floating-point numbers.
///
/// Any value with `lo > hi` is treated as the empty interval by the methods
/// in this module. Equality is field-wise (derived), so two different empty
/// encodings compare unequal.
///
/// The `Float` bound lives on the `impl` blocks that need it rather than on
/// the type itself, so the struct can be named and copied without it.
#[derive(Clone, Copy, PartialEq)]
pub struct Interval<T> {
    /// Lower bound.
    pub lo: T,
    /// Upper bound.
    pub hi: T,
}
impl Interval<f64> {
pub const ENTIRE: Self = Self {
lo: f64::NEG_INFINITY,
hi: f64::INFINITY,
};
pub const EMPTY: Self = Self {
lo: f64::INFINITY,
hi: f64::NEG_INFINITY,
};
}
impl Interval<f32> {
pub const ENTIRE: Self = Self {
lo: f32::NEG_INFINITY,
hi: f32::INFINITY,
};
pub const EMPTY: Self = Self {
lo: f32::INFINITY,
hi: f32::NEG_INFINITY,
};
}
impl<T: Float> Interval<T> {
    /// Builds `[lo, hi]` verbatim, without validation or reordering.
    ///
    /// Passing `lo > hi` produces an interval that `is_empty` reports as empty.
    #[inline]
    pub fn new(lo: T, hi: T) -> Self {
        Self { lo, hi }
    }

    /// The degenerate interval `[x, x]`.
    #[inline]
    pub fn point(x: T) -> Self {
        Self { lo: x, hi: x }
    }

    /// Builds `[midpoint - radius, midpoint + radius]`.
    ///
    /// Returns `None` when `radius` is negative or NaN, or when `midpoint` is
    /// NaN. The bounds use ordinary floating-point arithmetic (no outward
    /// rounding).
    #[inline]
    pub fn from_midpoint_radius(midpoint: T, radius: T) -> Option<Self> {
        // `radius < 0` alone is false for NaN, which would yield a NaN interval.
        if midpoint.is_nan() || radius.is_nan() || radius < T::zero() {
            return None;
        }
        Some(Self {
            lo: midpoint - radius,
            hi: midpoint + radius,
        })
    }

    /// Builds the interval spanned by `a` and `b`, whichever order they are in.
    ///
    /// If the two values are unordered (a NaN is involved) they are taken as
    /// `[b, a]`.
    #[inline]
    pub fn from_bounds(a: T, b: T) -> Self {
        if a <= b {
            Self { lo: a, hi: b }
        } else {
            Self { lo: b, hi: a }
        }
    }
}
impl Interval<f64> {
    /// Builds `[lo, hi]` widened outward by one ULP at each end, using
    /// `prev_float` for the lower bound and `next_float` for the upper bound.
    #[inline]
    pub fn rounded(lo: f64, hi: f64) -> Self {
        let widened_lo = prev_float(lo);
        let widened_hi = next_float(hi);
        Self::new(widened_lo, widened_hi)
    }

    /// The point `x` widened outward by one ULP on each side.
    #[inline]
    pub fn point_rounded(x: f64) -> Self {
        Self::rounded(x, x)
    }
}
impl<T: Float> Interval<T> {
    /// `true` when the interval holds no values, i.e. its upper bound is below its lower bound.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.hi < self.lo
    }

    /// `true` when both bounds compare equal.
    #[inline]
    pub fn is_point(&self) -> bool {
        self.hi == self.lo
    }

    /// `true` when the lower bound is `-∞`.
    #[inline]
    pub fn is_unbounded_below(&self) -> bool {
        T::neg_infinity() == self.lo
    }

    /// `true` when the upper bound is `+∞`.
    #[inline]
    pub fn is_unbounded_above(&self) -> bool {
        T::infinity() == self.hi
    }

    /// `true` for `[-∞, +∞]`.
    #[inline]
    pub fn is_entire(&self) -> bool {
        self.is_unbounded_above() && self.is_unbounded_below()
    }

    /// `true` when `x` lies in `[lo, hi]`; always `false` for an empty interval or a NaN `x`.
    #[inline]
    pub fn contains(&self, x: T) -> bool {
        if self.is_empty() {
            return false;
        }
        self.lo <= x && x <= self.hi
    }

    /// `true` when every value of `other` lies in `self`.
    ///
    /// The empty set is contained in every interval, including another empty one.
    #[inline]
    pub fn contains_interval(&self, other: &Self) -> bool {
        if other.is_empty() {
            return true;
        }
        if self.is_empty() {
            return false;
        }
        self.lo <= other.lo && other.hi <= self.hi
    }

    /// `true` when the two intervals share at least one value.
    #[inline]
    pub fn intersects(&self, other: &Self) -> bool {
        if self.is_empty() || other.is_empty() {
            return false;
        }
        self.lo <= other.hi && other.lo <= self.hi
    }
}
impl<T: Float> Interval<T> {
    /// Width `hi - lo`, or NaN for an empty interval.
    ///
    /// Computed with ordinary floating-point subtraction (no outward rounding).
    #[inline]
    pub fn width(&self) -> T {
        if self.is_empty() {
            T::nan()
        } else {
            self.hi - self.lo
        }
    }

    /// A representative point inside the interval, or NaN when it is empty.
    ///
    /// For a non-empty interval the result lies in `[lo, hi]`:
    /// * a point interval returns its single value;
    /// * `[-∞, +∞]` returns zero;
    /// * `[-∞, hi]` returns `T::min_value()` and `[lo, +∞]` returns `T::max_value()`;
    /// * otherwise `(lo + hi) / 2`, or `lo / 2 + hi / 2` when the sum overflows.
    #[inline]
    pub fn midpoint(&self) -> T {
        if self.is_empty() {
            return T::nan();
        }
        // Covers subnormal points, where halving each bound first can underflow
        // to 0 and land outside the interval.
        if self.lo == self.hi {
            return self.lo;
        }
        match (self.is_unbounded_below(), self.is_unbounded_above()) {
            (true, true) => T::zero(),
            (true, false) => T::min_value(),
            (false, true) => T::max_value(),
            (false, false) => {
                let two = T::one() + T::one();
                // Rounding is monotone, so a finite (lo + hi) / 2 stays within
                // [lo, hi]; only an overflowing sum needs the halve-first form.
                let m = (self.lo + self.hi) / two;
                if m.is_finite() {
                    m
                } else {
                    self.lo / two + self.hi / two
                }
            }
        }
    }

    /// Half the width, `(hi - lo) / 2`, or NaN for an empty interval.
    ///
    /// Computed with ordinary floating-point arithmetic (no outward rounding).
    #[inline]
    pub fn radius(&self) -> T {
        if self.is_empty() {
            T::nan()
        } else {
            let two = T::one() + T::one();
            (self.hi - self.lo) / two
        }
    }

    /// Mignitude: the smallest absolute value in the interval (0 if it
    /// straddles zero), or NaN for an empty interval.
    #[inline]
    pub fn mig(&self) -> T {
        if self.is_empty() {
            return T::nan();
        }
        if self.lo <= T::zero() && T::zero() <= self.hi {
            T::zero()
        } else {
            self.lo.abs().min(self.hi.abs())
        }
    }

    /// Magnitude: the largest absolute value in the interval, or NaN for an
    /// empty interval.
    #[inline]
    pub fn mag(&self) -> T {
        if self.is_empty() {
            return T::nan();
        }
        self.lo.abs().max(self.hi.abs())
    }

    /// Set intersection. Returns the canonical empty interval `[+∞, -∞]` when
    /// either operand is empty or the two do not overlap.
    #[inline]
    pub fn intersection(&self, other: &Self) -> Self {
        if self.is_empty() || other.is_empty() || !self.intersects(other) {
            Self {
                lo: T::infinity(),
                hi: T::neg_infinity(),
            }
        } else {
            Self {
                lo: self.lo.max(other.lo),
                hi: self.hi.min(other.hi),
            }
        }
    }

    /// Smallest interval containing both operands; an empty operand is ignored.
    #[inline]
    pub fn hull(&self, other: &Self) -> Self {
        if self.is_empty() {
            return *other;
        }
        if other.is_empty() {
            return *self;
        }
        Self {
            lo: self.lo.min(other.lo),
            hi: self.hi.max(other.hi),
        }
    }

    /// Image of the interval under `|x|`; an empty interval is returned as-is.
    #[inline]
    pub fn abs(&self) -> Self {
        if self.is_empty() {
            return *self;
        }
        if self.lo >= T::zero() {
            *self
        } else if self.hi <= T::zero() {
            // Entirely non-positive: reflect through zero.
            Self {
                lo: -self.hi,
                hi: -self.lo,
            }
        } else {
            // Straddles zero: the minimum is 0, the maximum the larger magnitude.
            Self {
                lo: T::zero(),
                hi: self.lo.abs().max(self.hi.abs()),
            }
        }
    }
}
impl Interval<f64> {
    // NOTE(review): `exp`, `ln`, `sin`, `cos` and `atan` below widen the libm
    // result by a single ULP. Rust does not guarantee the precision of those
    // `f64` methods, so this assumes the platform's math library is accurate
    // to within one ULP — confirm for the targets in use.

    /// Outward-rounded sum `[a.lo + b.lo, a.hi + b.hi]`.
    ///
    /// Each bound is computed in round-to-nearest and then stepped one ULP
    /// outward, which covers the at-most-half-ULP error of IEEE-754 addition.
    /// Returns [`Interval::EMPTY`] if either operand is empty.
    #[inline]
    pub fn add_rounded(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: prev_float(self.lo + rhs.lo),
            hi: next_float(self.hi + rhs.hi),
        }
    }

    /// Outward-rounded difference `[a.lo - b.hi, a.hi - b.lo]`.
    /// Returns [`Interval::EMPTY`] if either operand is empty.
    #[inline]
    pub fn sub_rounded(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: prev_float(self.lo - rhs.hi),
            hi: next_float(self.hi - rhs.lo),
        }
    }

    /// Outward-rounded product: the hull of the four endpoint products,
    /// widened by one ULP on each side. Returns [`Interval::EMPTY`] if either
    /// operand is empty.
    #[inline]
    pub fn mul_rounded(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        let p1 = self.lo * rhs.lo;
        let p2 = self.lo * rhs.hi;
        let p3 = self.hi * rhs.lo;
        let p4 = self.hi * rhs.hi;
        // `f64::min`/`max` return the non-NaN operand, so a NaN from `0 × ∞`
        // drops out of the hull instead of poisoning it.
        let raw_lo = p1.min(p2).min(p3).min(p4);
        let raw_hi = p1.max(p2).max(p3).max(p4);
        Self {
            lo: prev_float(raw_lo),
            hi: next_float(raw_hi),
        }
    }

    /// Outward-rounded quotient.
    ///
    /// A divisor that contains zero yields [`Interval::ENTIRE`], a valid but
    /// loose enclosure. Returns [`Interval::EMPTY`] if either operand is empty.
    #[inline]
    pub fn div_rounded(self, rhs: Self) -> Self {
        if self.is_empty() || rhs.is_empty() {
            return Self::EMPTY;
        }
        if rhs.lo <= 0.0 && rhs.hi >= 0.0 {
            return Self::ENTIRE;
        }
        let p1 = self.lo / rhs.lo;
        let p2 = self.lo / rhs.hi;
        let p3 = self.hi / rhs.lo;
        let p4 = self.hi / rhs.hi;
        let raw_lo = p1.min(p2).min(p3).min(p4);
        let raw_hi = p1.max(p2).max(p3).max(p4);
        Self {
            lo: prev_float(raw_lo),
            hi: next_float(raw_hi),
        }
    }

    /// Square root over the non-negative part of the interval.
    ///
    /// Returns [`Interval::EMPTY`] when the interval is empty or lies entirely
    /// below zero. `f64::sqrt` is correctly rounded, so one ULP of widening
    /// suffices; the lower bound is kept at or above zero.
    #[inline]
    pub fn sqrt(&self) -> Self {
        if self.is_empty() || self.hi < 0.0 {
            return Self::EMPTY;
        }
        let lo_clamped = self.lo.max(0.0);
        Self {
            // sqrt is never negative, so the outward step must not cross zero.
            lo: prev_float(lo_clamped.sqrt()).max(0.0),
            hi: next_float(self.hi.sqrt()),
        }
    }

    /// Exponential; monotone increasing, so the endpoints map to the bounds.
    /// The lower bound is kept at or above zero because `exp` is positive.
    #[inline]
    pub fn exp(&self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: prev_float(self.lo.exp()).max(0.0),
            hi: next_float(self.hi.exp()),
        }
    }

    /// Natural logarithm over the positive part of the interval.
    ///
    /// Returns [`Interval::EMPTY`] when the interval is empty or `hi <= 0`. A
    /// lower bound at or below zero yields `-∞`.
    #[inline]
    pub fn ln(&self) -> Self {
        if self.is_empty() || self.hi <= 0.0 {
            return Self::EMPTY;
        }
        let lo_clamped = self.lo.max(0.0);
        Self {
            lo: prev_float(lo_clamped.ln()),
            hi: next_float(self.hi.ln()),
        }
    }

    /// Sine. Intervals at least `2π` wide map to `[-1, 1]`; otherwise the
    /// endpoint values are combined with ±1 whenever a maximum (`π/2 + 2πk`)
    /// or minimum (`-π/2 + 2πk`) lies inside. Bounds are clamped to `[-1, 1]`.
    #[inline]
    pub fn sin(&self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        let two_pi = core::f64::consts::TAU;
        if self.hi - self.lo >= two_pi {
            return Self::new(-1.0, 1.0);
        }
        let s_lo = self.lo.sin();
        let s_hi = self.hi.sin();
        let mut lo = s_lo.min(s_hi);
        let mut hi = s_lo.max(s_hi);
        let pi_half = core::f64::consts::FRAC_PI_2;
        // Is there an integer k with lo <= π/2 + 2πk <= hi?
        let k_min = ((self.lo - pi_half) / two_pi).ceil() as i64;
        let k_max = ((self.hi - pi_half) / two_pi).floor() as i64;
        if k_min <= k_max {
            hi = 1.0_f64;
        }
        // Is there an integer k with lo <= -π/2 + 2πk <= hi?
        let k_min2 = ((self.lo + pi_half) / two_pi).ceil() as i64;
        let k_max2 = ((self.hi + pi_half) / two_pi).floor() as i64;
        if k_min2 <= k_max2 {
            lo = -1.0_f64;
        }
        Self {
            lo: prev_float(lo).max(-1.0),
            hi: next_float(hi).min(1.0),
        }
    }

    /// Cosine. Intervals at least `2π` wide map to `[-1, 1]`; otherwise the
    /// endpoint values are combined with ±1 whenever a maximum (`2πk`) or
    /// minimum (`π + 2πk`) lies inside. Bounds are clamped to `[-1, 1]`.
    #[inline]
    pub fn cos(&self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        let two_pi = core::f64::consts::TAU;
        let pi = core::f64::consts::PI;
        let pi_half = core::f64::consts::FRAC_PI_2;
        if self.hi - self.lo >= two_pi {
            return Self::new(-1.0, 1.0);
        }
        let c_lo = self.lo.cos();
        let c_hi = self.hi.cos();
        let mut lo = c_lo.min(c_hi);
        let mut hi = c_lo.max(c_hi);
        // Maxima sit at 2πk.
        let k_min = (self.lo / two_pi).ceil() as i64;
        let k_max = (self.hi / two_pi).floor() as i64;
        if k_min <= k_max {
            hi = 1.0_f64;
        }
        // Minima sit at π + 2πk.
        let k_min2 = ((self.lo - pi) / two_pi).ceil() as i64;
        let k_max2 = ((self.hi - pi) / two_pi).floor() as i64;
        if k_min2 <= k_max2 {
            lo = -1.0_f64;
        }
        // Zeros sit at π/2 + πk; if one is inside, make sure 0 is in the range.
        let zk_min = ((self.lo - pi_half) / pi).ceil() as i64;
        let zk_max = ((self.hi - pi_half) / pi).floor() as i64;
        if zk_min <= zk_max {
            if lo > 0.0 {
                lo = 0.0;
            }
            if hi < 0.0 {
                hi = 0.0;
            }
        }
        Self {
            lo: prev_float(lo).max(-1.0),
            hi: next_float(hi).min(1.0),
        }
    }

    /// Arctangent; monotone increasing, so the endpoints map to the bounds.
    #[inline]
    pub fn atan(&self) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        Self {
            lo: prev_float(self.lo.atan()),
            hi: next_float(self.hi.atan()),
        }
    }

    /// Integer power `x^n` with outward rounding.
    ///
    /// * `n == 0` gives `[1, 1]`.
    /// * Even `n` works on `|x|`, so the lower bound is never negative.
    /// * Negative `n` is computed as `1 / x^|n|`; if `x^|n|` contains zero the
    ///   result is [`Interval::ENTIRE`] (from `div_rounded`).
    ///
    /// Bounds come from binary exponentiation with every partial product
    /// rounded outward, rather than from `f64::powi`, whose precision Rust does
    /// not specify (so widening its result by one ULP is not a guaranteed
    /// enclosure).
    pub fn powi(&self, n: i32) -> Self {
        if self.is_empty() {
            return Self::EMPTY;
        }
        if n == 0 {
            return Self::point(1.0);
        }
        // `unsigned_abs` is total; negating `n` overflowed for `i32::MIN`.
        let magnitude = self.pow_unsigned(n.unsigned_abs());
        if n < 0 {
            Self::point(1.0).div_rounded(magnitude)
        } else {
            magnitude
        }
    }

    /// Enclosure of `self^k` for `k >= 1` on a non-empty interval.
    fn pow_unsigned(&self, k: u32) -> Self {
        if k % 2 == 0 {
            // Even powers are increasing in |x|.
            let a = self.abs();
            Self {
                lo: Self::pow_nonneg_directed(a.lo, k, Self::mul_down_nonneg),
                hi: Self::pow_nonneg_directed(a.hi, k, Self::mul_up),
            }
        } else {
            // Odd powers are increasing in x.
            Self {
                lo: Self::pow_odd_down(self.lo, k),
                hi: Self::pow_odd_up(self.hi, k),
            }
        }
    }

    /// Lower bound of `x^k` for odd `k`, using `x^k = -(|x|^k)` when `x < 0`.
    fn pow_odd_down(x: f64, k: u32) -> f64 {
        if x >= 0.0 {
            Self::pow_nonneg_directed(x, k, Self::mul_down_nonneg)
        } else {
            -Self::pow_nonneg_directed(-x, k, Self::mul_up)
        }
    }

    /// Upper bound of `x^k` for odd `k`, using `x^k = -(|x|^k)` when `x < 0`.
    fn pow_odd_up(x: f64, k: u32) -> f64 {
        if x >= 0.0 {
            Self::pow_nonneg_directed(x, k, Self::mul_up)
        } else {
            -Self::pow_nonneg_directed(-x, k, Self::mul_down_nonneg)
        }
    }

    /// `x^k` for `x >= 0` by binary exponentiation, rounding every partial
    /// product with `step`. Since multiplication is monotone on non-negative
    /// operands, a consistently downward (upward) `step` gives a lower (upper)
    /// bound.
    fn pow_nonneg_directed(x: f64, k: u32, step: fn(f64, f64) -> f64) -> f64 {
        let mut acc = 1.0_f64;
        let mut base = x;
        let mut e = k;
        while e != 0 {
            if e & 1 == 1 {
                acc = step(acc, base);
            }
            e >>= 1;
            if e != 0 {
                base = step(base, base);
            }
        }
        acc
    }

    /// Lower bound of `a * b` for non-negative operands, never below zero.
    fn mul_down_nonneg(a: f64, b: f64) -> f64 {
        prev_float(a * b).max(0.0)
    }

    /// Upper bound of `a * b`.
    fn mul_up(a: f64, b: f64) -> f64 {
        next_float(a * b)
    }
}
impl<T: Float> Neg for Interval<T> {
    type Output = Self;

    /// Reflects through zero: `-[lo, hi] = [-hi, -lo]`. Every empty interval
    /// maps to the canonical empty encoding `[+∞, -∞]`.
    #[inline]
    fn neg(self) -> Self {
        let (lo, hi) = if self.is_empty() {
            (T::infinity(), T::neg_infinity())
        } else {
            (-self.hi, -self.lo)
        };
        Self { lo, hi }
    }
}
impl Add for Interval<f64> {
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self {
self.add_rounded(rhs)
}
}
impl Sub for Interval<f64> {
type Output = Self;
#[inline]
fn sub(self, rhs: Self) -> Self {
self.sub_rounded(rhs)
}
}
impl Mul for Interval<f64> {
type Output = Self;
#[inline]
fn mul(self, rhs: Self) -> Self {
self.mul_rounded(rhs)
}
}
impl Div for Interval<f64> {
type Output = Self;
#[inline]
fn div(self, rhs: Self) -> Self {
self.div_rounded(rhs)
}
}
impl Add<f64> for Interval<f64> {
type Output = Self;
#[inline]
fn add(self, rhs: f64) -> Self {
self.add_rounded(Self::point(rhs))
}
}
impl Sub<f64> for Interval<f64> {
type Output = Self;
#[inline]
fn sub(self, rhs: f64) -> Self {
self.sub_rounded(Self::point(rhs))
}
}
impl Mul<f64> for Interval<f64> {
type Output = Self;
#[inline]
fn mul(self, rhs: f64) -> Self {
self.mul_rounded(Self::point(rhs))
}
}
impl Div<f64> for Interval<f64> {
type Output = Self;
#[inline]
fn div(self, rhs: f64) -> Self {
self.div_rounded(Self::point(rhs))
}
}
impl<T: Float + fmt::Display> fmt::Display for Interval<T> {
    /// Renders `[lo, hi]`, or `∅` for an empty interval.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_empty() {
            return f.write_str("∅");
        }
        write!(f, "[{}, {}]", self.lo, self.hi)
    }
}
impl<T: Float + fmt::Debug> fmt::Debug for Interval<T> {
    /// Prints `Interval::EMPTY` for any empty interval, otherwise the two
    /// bounds in struct-literal form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if !self.is_empty() {
            return write!(f, "Interval {{ lo: {:?}, hi: {:?} }}", self.lo, self.hi);
        }
        f.write_str("Interval::EMPTY")
    }
}
impl<T: Float> PartialOrd for Interval<T> {
    /// Orders intervals by set inclusion: `a < b` when `a` is a strict subset
    /// of `b`, `a > b` for a strict superset, `Equal` when the bounds compare
    /// equal, and `None` when neither contains the other.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        use core::cmp::Ordering;
        if self == other {
            return Some(Ordering::Equal);
        }
        // Two different encodings of the empty set contain each other, so the
        // inclusion tests below would report both `a < b` and `b < a`, breaking
        // PartialOrd's duality. They are unequal under the derived PartialEq,
        // so `Equal` would break consistency with `==`; "incomparable" is the
        // only answer that satisfies both requirements.
        if self.is_empty() && other.is_empty() {
            return None;
        }
        if other.contains_interval(self) {
            Some(Ordering::Less)
        } else if self.contains_interval(other) {
            Some(Ordering::Greater)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `f64` interval `[lo, hi]`.
    fn iv(lo: f64, hi: f64) -> Interval<f64> {
        Interval::new(lo, hi)
    }

    #[test]
    fn test_add_containment() {
        let sum = iv(1.0, 2.0) + iv(3.0, 4.0);
        assert!(sum.lo <= 4.0, "lo should be ≤ 4, got {}", sum.lo);
        assert!(sum.hi >= 6.0, "hi should be ≥ 6, got {}", sum.hi);
    }

    #[test]
    fn test_sub_containment() {
        let diff = iv(5.0, 7.0) - iv(1.0, 3.0);
        assert!(diff.lo <= 2.0, "lo should be ≤ 2, got {}", diff.lo);
        assert!(diff.hi >= 6.0, "hi should be ≥ 6, got {}", diff.hi);
    }

    #[test]
    fn test_mul_containment() {
        let prod = iv(2.0, 3.0) * iv(4.0, 5.0);
        assert!(prod.lo <= 8.0, "lo should be ≤ 8, got {}", prod.lo);
        assert!(prod.hi >= 15.0, "hi should be ≥ 15, got {}", prod.hi);
    }

    #[test]
    fn test_div_containment() {
        let quot = iv(6.0, 8.0) / iv(2.0, 4.0);
        assert!(quot.lo <= 1.5, "lo should be ≤ 1.5, got {}", quot.lo);
        assert!(quot.hi >= 4.0, "hi should be ≥ 4, got {}", quot.hi);
    }

    #[test]
    fn test_div_zero_divisor() {
        let quot = iv(1.0, 2.0) / iv(-1.0, 1.0);
        assert!(quot.is_entire(), "dividing by interval containing zero should give ENTIRE");
    }

    #[test]
    fn test_sqrt_containment() {
        let root = iv(4.0, 9.0).sqrt();
        assert!(root.lo <= 2.0, "lo should be ≤ 2, got {}", root.lo);
        assert!(root.hi >= 3.0, "hi should be ≥ 3, got {}", root.hi);
    }

    #[test]
    fn test_sqrt_negative() {
        let negative = iv(-1.0, -0.5);
        assert!(negative.sqrt().is_empty());
    }

    #[test]
    fn test_exp_containment() {
        let e = iv(0.0, 1.0).exp();
        assert!(e.lo <= 1.0, "e.lo should be ≤ 1, got {}", e.lo);
        assert!(e.hi >= core::f64::consts::E, "e.hi should be ≥ e, got {}", e.hi);
    }

    #[test]
    fn test_ln_containment() {
        let l = iv(1.0, core::f64::consts::E).ln();
        assert!(l.lo <= 0.0, "l.lo should be ≤ 0, got {}", l.lo);
        assert!(l.hi >= 1.0, "l.hi should be ≥ 1, got {}", l.hi);
    }

    #[test]
    fn test_sin_wide_interval() {
        let s = iv(0.0, 7.0).sin();
        assert!(s.lo <= -1.0, "sin wide lo should be ≤ -1");
        assert!(s.hi >= 1.0, "sin wide hi should be ≥ 1");
    }

    #[test]
    fn test_contains() {
        let span = iv(1.0, 3.0);
        assert!(span.contains(2.0));
        assert!(!span.contains(0.0));
        assert!(!span.contains(4.0));
    }

    #[test]
    fn test_empty_interval() {
        let nothing = Interval::<f64>::EMPTY;
        assert!(nothing.is_empty());
        assert!(!nothing.contains(0.0));
    }

    #[test]
    fn test_entire_interval() {
        let everything = Interval::<f64>::ENTIRE;
        assert!(everything.is_entire());
        assert!(everything.contains(1e300));
        assert!(everything.contains(-1e300));
    }

    #[test]
    fn test_width_midpoint() {
        let span = iv(1.0, 5.0);
        assert_eq!(span.width(), 4.0);
        assert_eq!(span.midpoint(), 3.0);
    }

    #[test]
    fn test_hull() {
        let joined = iv(1.0, 3.0).hull(&iv(5.0, 7.0));
        assert_eq!(joined.lo, 1.0);
        assert_eq!(joined.hi, 7.0);
    }

    #[test]
    fn test_intersection() {
        let left = iv(1.0, 5.0);
        let overlap = left.intersection(&iv(3.0, 7.0));
        assert_eq!(overlap.lo, 3.0);
        assert_eq!(overlap.hi, 5.0);
        assert!(left.intersection(&iv(6.0, 8.0)).is_empty());
    }

    #[test]
    fn test_powi_even() {
        let squared = iv(-2.0, 3.0).powi(2);
        assert!(squared.lo <= 0.0);
        assert!(squared.hi >= 9.0);
    }

    #[test]
    fn test_powi_odd() {
        let cubed = iv(-2.0, 3.0).powi(3);
        assert!(cubed.lo <= -8.0);
        assert!(cubed.hi >= 27.0);
    }

    #[test]
    fn test_neg() {
        let flipped = -iv(1.0, 3.0);
        assert_eq!(flipped.lo, -3.0);
        assert_eq!(flipped.hi, -1.0);
    }

    #[test]
    fn test_next_prev_float() {
        let one = 1.0_f64;
        assert!(prev_float(one) < one);
        assert!(next_float(one) > one);
        assert_eq!(prev_float(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert_eq!(next_float(f64::INFINITY), f64::INFINITY);
    }
}