use num_traits::Float;
use std::ops::Range;
/// Error returned when two intervals have an empty intersection.
///
/// Both operands of the failed meet are retained so callers can inspect or
/// report them; use [`MeetError::left`] and [`MeetError::right`] to access them.
#[derive(thiserror::Error, Debug)]
#[error("meet failure between {left:?} and {right:?}")]
pub struct MeetError<T> {
    /// The interval the meet was invoked on.
    left: Interval<T>,
    /// The interval passed as the argument of the meet.
    right: Interval<T>,
}

impl<T> MeetError<T> {
    /// Returns the interval the failed meet was invoked on (the receiver).
    pub fn left(&self) -> &Interval<T> {
        &self.left
    }

    /// Returns the interval that was passed as the argument of the failed meet.
    pub fn right(&self) -> &Interval<T> {
        &self.right
    }
}
#[derive(thiserror::Error, Debug)]
pub enum IntervalError<T> {
#[error("invalid interval bounds: lower={lower:?}, upper={upper:?}")]
InvalidBounds { lower: T, upper: T },
#[error("invalid domain: {0:?}")]
InvalidDomain(Range<T>),
}
/// A closed interval `[lower, upper]`; both endpoints are included.
///
/// The constructors in this file keep `lower <= upper` with neither bound NaN.
/// [`Interval::new`] and [`Interval::from_domain`] additionally require both
/// bounds to be finite.
///
/// `PartialEq` is derived so intervals can be compared directly (e.g. in tests
/// or when checking whether a meet changed anything).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Interval<T> {
    /// Inclusive lower bound.
    lower: T,
    /// Inclusive upper bound.
    upper: T,
}
impl<T> Interval<T> {
    /// Builds the closed interval `[lower, upper]`.
    ///
    /// Degenerate intervals where `lower == upper` are accepted.
    ///
    /// # Errors
    ///
    /// Returns [`IntervalError::InvalidBounds`] when either bound is NaN or
    /// infinite, or when `lower > upper`.
    pub fn new(lower: T, upper: T) -> Result<Self, IntervalError<T>>
    where
        T: Float,
    {
        // Once both bounds are known to be finite (hence not NaN), `<=` is the
        // exact negation of `>`.
        let both_finite = lower.is_finite() && upper.is_finite();
        if both_finite && lower <= upper {
            Ok(Self { lower, upper })
        } else {
            Err(IntervalError::InvalidBounds { lower, upper })
        }
    }

    /// Builds the closed interval `[domain.start, domain.end]` from a range.
    ///
    /// The range must be strictly non-empty: `start == end` is rejected here,
    /// unlike in [`Interval::new`].
    ///
    /// # Errors
    ///
    /// Returns [`IntervalError::InvalidDomain`] carrying the original range when
    /// either endpoint is not finite or `start >= end`.
    pub fn from_domain(domain: Range<T>) -> Result<Self, IntervalError<T>>
    where
        T: Float,
    {
        let Range { start, end } = domain;
        if start.is_finite() && end.is_finite() && start < end {
            Ok(Self {
                lower: start,
                upper: end,
            })
        } else {
            Err(IntervalError::InvalidDomain(start..end))
        }
    }

    /// Returns the inclusive lower bound.
    pub fn lower(&self) -> T
    where
        T: Copy,
    {
        let Self { lower, .. } = self;
        *lower
    }

    /// Returns the inclusive upper bound.
    pub fn upper(&self) -> T
    where
        T: Copy,
    {
        let Self { upper, .. } = self;
        *upper
    }

    /// Returns `upper - lower`.
    ///
    /// This is a plain floating-point subtraction, so it can round, or
    /// overflow to infinity, for extreme bounds.
    pub fn width(&self) -> T
    where
        T: Float,
    {
        let Self { lower, upper } = *self;
        upper - lower
    }

    /// Intersects two closed intervals.
    ///
    /// Intervals that merely touch (`self.upper == other.lower`, or vice
    /// versa) meet in a degenerate single-point interval.
    ///
    /// # Errors
    ///
    /// Returns a [`MeetError`] holding both operands when the intersection
    /// is empty.
    pub fn meet(self, other: Self) -> Result<Self, MeetError<T>>
    where
        T: Float,
    {
        let candidate = Self {
            lower: self.lower.max(other.lower),
            upper: self.upper.min(other.upper),
        };
        if candidate.lower > candidate.upper {
            Err(MeetError {
                left: self,
                right: other,
            })
        } else {
            Ok(candidate)
        }
    }

    /// Reports whether `x` lies in `[lower, upper]`, endpoints included.
    ///
    /// A NaN `x` is never contained.
    pub fn contains(&self, x: T) -> bool
    where
        T: Float,
    {
        x >= self.lower && x <= self.upper
    }
}
/// Holds a current [`Interval`] that can be narrowed by meeting it with further
/// intervals (see `try_meet` and `meet_or_keep`).
#[derive(Debug, Clone)]
pub struct SequentialInterval<T> {
    /// The interval as of the most recent successful meet, or the initial domain.
    pub current: Interval<T>,
}
impl<T> SequentialInterval<T> {
    /// Creates a sequential interval covering the closed interval
    /// `[domain.start, domain.end]`.
    ///
    /// Unlike [`Interval::from_domain`], this accepts infinite endpoints.
    /// It also panics on bad input instead of returning an error.
    ///
    /// # Panics
    ///
    /// Panics with `"domain contained NaN"` if either endpoint is NaN.
    /// Panics with `"ill-defined domain"` if `domain.start >= domain.end`.
    pub(crate) fn instantiate(domain: Range<T>) -> SequentialInterval<T>
    where
        T: Float,
    {
        if domain.start.is_nan() || domain.end.is_nan() {
            panic!("domain contained NaN");
        }
        if domain.start >= domain.end {
            panic!("ill-defined domain");
        }
        Self {
            current: Interval {
                lower: domain.start,
                upper: domain.end,
            },
        }
    }

    /// Returns the width (`upper - lower`) of the current interval.
    pub(crate) fn width(&self) -> T
    where
        T: Float,
    {
        self.current.width()
    }

    /// Returns the inclusive lower bound of the current interval.
    pub(crate) fn lower(&self) -> T
    where
        T: Copy,
    {
        self.current.lower()
    }

    /// Returns the inclusive upper bound of the current interval.
    pub(crate) fn upper(&self) -> T
    where
        T: Copy,
    {
        self.current.upper()
    }

    /// Narrows the current interval to its intersection with `other`.
    ///
    /// This delegates to [`Interval::meet`], so both share one definition of
    /// intersection.
    ///
    /// # Errors
    ///
    /// Returns a [`MeetError`] holding the current interval and `other` when
    /// they do not intersect. In that case `self` is consumed; use
    /// [`SequentialInterval::meet_or_keep`] to retain it instead.
    pub fn try_meet(self, other: Interval<T>) -> Result<Self, MeetError<T>>
    where
        T: Float,
    {
        self.current.meet(other).map(|current| Self { current })
    }

    /// Narrows the current interval by `other` if they intersect, otherwise
    /// keeps it unchanged.
    ///
    /// Returns the resulting sequential interval and `true` if the meet
    /// succeeded, or the unchanged `self` and `false` if it failed.
    pub fn meet_or_keep(self, other: Interval<T>) -> (Self, bool)
    where
        T: Float,
    {
        // `Interval<T>` is `Copy` here (`Float: Copy`), so this copies
        // `current` out and `self` is still available for the failure branch;
        // no clone of the whole wrapper is needed.
        match self.current.meet(other) {
            Ok(current) => (Self { current }, true),
            Err(_) => (self, false),
        }
    }
}