use std::{
borrow::Borrow,
ops::{Range, RangeInclusive},
};
use num_traits::PrimInt;
/// A closed (inclusive) interval `[lo, hi]` of primitive integers.
///
/// The interval is empty whenever `lo > hi`; [`Extent::empty`] represents
/// that state as `lo = 1, hi = 0`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Extent<N>
where
    N: PrimInt,
{
    // NOTE: field order matters — the derived `PartialOrd`/`Ord` compare
    // `lo` first, then `hi`.
    /// Inclusive lower bound; only meaningful for a non-empty extent.
    lo: N,
    /// Inclusive upper bound; only meaningful for a non-empty extent.
    hi: N,
}
impl<N: PrimInt> Default for Extent<N> {
fn default() -> Self {
Self::empty()
}
}
impl<N: PrimInt> Extent<N> {
    /// Returns the inclusive lower bound, or `None` if the extent is empty.
    pub fn lo(&self) -> Option<N> {
        if self.is_empty() {
            None
        } else {
            Some(self.lo)
        }
    }

    /// Returns the inclusive upper bound, or `None` if the extent is empty.
    pub fn hi(&self) -> Option<N> {
        if self.is_empty() {
            None
        } else {
            Some(self.hi)
        }
    }

    /// Returns the stored lower bound without checking for emptiness.
    ///
    /// # Safety
    ///
    /// This function performs no memory-unsafe operation. The caller is
    /// responsible for only interpreting the result when `!self.is_empty()`;
    /// for an empty extent the value is meaningless.
    pub unsafe fn lo_unchecked(&self) -> N {
        self.lo
    }

    /// Returns the stored upper bound without checking for emptiness.
    ///
    /// # Safety
    ///
    /// This function performs no memory-unsafe operation. The caller is
    /// responsible for only interpreting the result when `!self.is_empty()`;
    /// for an empty extent the value is meaningless.
    pub unsafe fn hi_unchecked(&self) -> N {
        self.hi
    }

    /// Returns the number of integers in the extent (`0` when empty).
    ///
    /// # Panics
    ///
    /// The count is returned as `N`, so it is not representable when the
    /// extent holds more than `N::max_value()` integers (e.g. the full range
    /// of the type). The arithmetic then overflows, which panics in builds
    /// with overflow checks enabled and wraps otherwise.
    pub fn len(&self) -> N {
        if self.is_empty() {
            N::zero()
        } else {
            N::one() + (self.hi - self.lo)
        }
    }

    /// Returns the empty extent, represented as `lo = 1, hi = 0`.
    pub fn empty() -> Self {
        Self {
            lo: N::one(),
            hi: N::zero(),
        }
    }

    /// Returns `true` if the extent contains no integers (`lo > hi`).
    pub fn is_empty(&self) -> bool {
        self.lo > self.hi
    }

    /// Builds the extent spanning `lo` and `hi`, given in either order.
    ///
    /// The smaller argument becomes the lower bound, so the result always
    /// contains both arguments and is never empty.
    pub fn new<T: Borrow<N>, U: Borrow<N>>(lo: T, hi: U) -> Self {
        let lo: N = *lo.borrow();
        let hi: N = *hi.borrow();
        Self {
            lo: lo.min(hi),
            hi: hi.max(lo),
        }
    }

    /// Builds `[lo, hi]` without reordering the bounds.
    ///
    /// Returns the empty extent when `lo > hi`.
    ///
    /// # Safety
    ///
    /// Despite the `unsafe` marker, this function performs no unsafe
    /// operation; callers need not uphold any memory-safety invariant.
    pub unsafe fn new_unchecked<T: Borrow<N>, U: Borrow<N>>(lo: T, hi: U) -> Self {
        let lo: N = *lo.borrow();
        let hi: N = *hi.borrow();
        if lo > hi {
            Self::empty()
        } else {
            Self { lo, hi }
        }
    }

    /// Returns the smallest extent covering both `self` and `other`.
    ///
    /// This is the convex hull, so it also covers any gap between two
    /// disjoint operands. An empty operand contributes nothing: the union
    /// with an empty extent is the other operand unchanged.
    pub fn union<S: Borrow<Self>>(&self, other: S) -> Self {
        let other = *other.borrow();
        // The empty sentinel (lo = 1, hi = 0) must not be mixed into the
        // min/max below, or it would stretch the result towards 0/1.
        if self.is_empty() {
            return other;
        }
        if other.is_empty() {
            return *self;
        }
        Self::new(self.lo.min(other.lo), self.hi.max(other.hi))
    }

    /// Returns the integers common to `self` and `other`.
    ///
    /// The result is empty when either operand is empty or when the
    /// operands do not overlap.
    pub fn intersect<S: Borrow<Self>>(&self, other: S) -> Self {
        let other = *other.borrow();
        if self.is_empty() || other.is_empty() {
            return Self::empty();
        }
        let lo = self.lo.max(other.lo);
        let hi = self.hi.min(other.hi);
        // Do not use `Self::new` here: it swaps crossed bounds, which would
        // turn the gap between disjoint extents into a bogus non-empty result.
        if lo > hi {
            Self::empty()
        } else {
            Self { lo, hi }
        }
    }

    /// Returns `true` if `n` lies within the extent (never for an empty one).
    pub fn contains<T: Borrow<N>>(&self, n: T) -> bool {
        let n = *n.borrow();
        self.lo <= n && n <= self.hi
    }

    /// Returns an iterator over the extent's integers in ascending order.
    pub fn iter(&self) -> ExtentIter<N> {
        ExtentIter(*self)
    }
}
/// Ascending iterator over the integers of an [`Extent`], produced by
/// [`Extent::iter`]. It holds the not-yet-yielded part of the extent.
#[derive(Default, Debug, Clone)]
pub struct ExtentIter<N>(Extent<N>)
where
    N: PrimInt;
impl<N: PrimInt> Iterator for ExtentIter<N> {
type Item = N;
fn next(&mut self) -> Option<Self::Item> {
if self.0.is_empty() {
None
} else {
let v = self.0.lo;
self.0.lo = self.0.lo + N::one();
Some(v)
}
}
}
impl<N: PrimInt> ExtentIter<N> {
pub fn rev(self) -> ExtentRevIter<N> {
ExtentRevIter(self.0)
}
}
/// Descending iterator over the integers of an [`Extent`], produced by
/// [`ExtentIter::rev`]. It holds the not-yet-yielded part of the extent.
// Derives mirror `ExtentIter` so both public iterator types can be cloned,
// debug-printed and defaulted alike.
#[derive(Clone, Debug, Default)]
pub struct ExtentRevIter<N: PrimInt>(Extent<N>);
impl<N: PrimInt> Iterator for ExtentRevIter<N> {
type Item = N;
fn next(&mut self) -> Option<Self::Item> {
if self.0.is_empty() {
None
} else {
let v = self.0.hi;
self.0.hi = self.0.hi - N::one();
Some(v)
}
}
}
/// Converts the half-open range `start..end` into the extent
/// `[start, end - 1]`; a range with `start >= end` becomes the empty extent.
impl<N: PrimInt> From<Range<N>> for Extent<N> {
    fn from(r: Range<N>) -> Self {
        let Range { start, end } = r;
        // `start < end` guarantees `end - 1` cannot underflow.
        if start < end {
            Self {
                lo: start,
                hi: end - N::one(),
            }
        } else {
            Self::empty()
        }
    }
}
/// Converts an extent into the half-open range `lo..hi + 1`.
///
/// Empty extents become `0..0`. The conversion fails when
/// `hi == N::max_value()`, because the exclusive end is not representable.
impl<N: PrimInt> TryFrom<Extent<N>> for Range<N> {
    type Error = &'static str;

    fn try_from(e: Extent<N>) -> Result<Self, Self::Error> {
        if e.is_empty() {
            return Ok(N::zero()..N::zero());
        }
        if e.hi == N::max_value() {
            return Err("Extent.hi is N::max_value(), can't represent as Range");
        }
        Ok(e.lo..e.hi + N::one())
    }
}
/// Converts `start..=end` into the extent `[start, end]`; an empty (or
/// exhausted) inclusive range becomes the empty extent.
impl<N: PrimInt> From<RangeInclusive<N>> for Extent<N> {
    fn from(r: RangeInclusive<N>) -> Self {
        // `is_empty` also accounts for an exhausted range, so `into_inner`
        // is only consulted for ranges that still contain items.
        if r.is_empty() {
            return Self::empty();
        }
        let (lo, hi) = r.into_inner();
        Self::new(lo, hi)
    }
}
impl<N: PrimInt> From<Extent<N>> for RangeInclusive<N> {
fn from(e: Extent<N>) -> Self {
RangeInclusive::new(e.lo, e.hi)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use core::convert::TryInto;
    use num_traits::PrimInt;
    use std::fmt::Debug;

    /// Checks that the extent spanning `a` and `b` contains both endpoints
    /// and survives round-trips through `RangeInclusive` and, when
    /// representable, `Range`.
    fn check_sensible<N: PrimInt + Debug>(a: N, b: N) {
        let extent = Extent::new(a, b);
        for endpoint in [a, b] {
            assert!(extent.contains(endpoint));
        }
        assert!(extent.lo() <= extent.hi());

        let inclusive: RangeInclusive<N> = extent.into();
        assert_eq!(extent, Extent::from(inclusive));

        let half_open: Result<Range<N>, _> = extent.try_into();
        if let Ok(range) = half_open {
            assert_eq!(extent, Extent::from(range));
        } else {
            // The only refusal is an upper bound with no exclusive successor.
            assert_eq!(extent.hi, N::max_value());
        }
    }

    /// Checks union/intersection identities for three ordered points
    /// `lo <= mid <= hi`.
    fn check_set_ops<N: PrimInt + Debug>(a: N, b: N, c: N) {
        let mut points = [a, b, c];
        points.sort_unstable();
        let [lo, mid, hi] = points;

        let lo_mid = Extent::from(lo..=mid);
        let mid_hi = Extent::from(mid..=hi);
        let lo_hi = Extent::from(lo..=hi);
        let mid_only = Extent::from(mid..=mid);

        let unions = [
            (lo_mid, mid_hi),
            (lo_mid, lo_hi),
            (mid_hi, lo_hi),
            (lo_hi, lo_mid),
            (lo_hi, mid_hi),
        ];
        for (x, y) in unions {
            assert_eq!(x.union(y), lo_hi);
        }

        let intersections = [
            (lo_mid, mid_hi, mid_only),
            (lo_mid, lo_hi, lo_mid),
            (mid_hi, lo_hi, mid_hi),
            (mid_only, lo_hi, mid_only),
            (mid_only, lo_mid, mid_only),
            (mid_only, mid_hi, mid_only),
        ];
        for (x, y, expected) in intersections {
            assert_eq!(x.intersect(y), expected);
        }
    }

    #[test]
    fn test_basics() {
        let elts = [
            i32::MIN,
            i32::MIN + 1,
            i32::MIN + 2,
            -2,
            -1,
            0,
            1,
            2,
            i32::MAX - 2,
            i32::MAX - 1,
            i32::MAX,
        ];
        for &a in &elts {
            for &b in &elts {
                check_sensible(a, b);
                for &c in &elts {
                    check_set_ops(a, b, c);
                }
            }
        }
    }
}