use num_traits::{AsPrimitive, FromPrimitive};
use std::cmp::Ordering;
/// An `i64` extended with negative and positive infinity.
///
/// The variants are deliberately declared from smallest to largest: the derived
/// `PartialOrd`/`Ord` implementations compare enum values by variant declaration order first,
/// so this layout yields `NegativeInfinity < Integer(_) < PositiveInfinity`, which agrees with
/// the comparisons between this type and plain `i64` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum IntegerOrInfinity {
    /// Compares below every finite integer.
    NegativeInfinity,
    /// A finite integer value.
    Integer(i64),
    /// Compares above every finite integer.
    PositiveInfinity,
}
impl IntegerOrInfinity {
    /// Restricts `self` to the inclusive range `[min, max]` and returns the result as `I`.
    ///
    /// Infinities collapse to the matching bound (`NegativeInfinity` to `min`,
    /// `PositiveInfinity` to `max`); a finite integer is clamped between the two bounds.
    ///
    /// # Panics
    ///
    /// Panics if `min > max`, or if the clamped value cannot be converted back into `I`.
    // NOTE(review): both bounds are converted to `i64` with `as_()` before clamping; presumably
    // callers never pass bounds that do not fit in an `i64` (e.g. large `u64`) — confirm.
    #[must_use]
    pub fn clamp_finite<I: Ord + AsPrimitive<i64> + FromPrimitive>(self, min: I, max: I) -> I {
        assert!(min <= max);

        // Infinities need no arithmetic: they map straight onto the nearest bound.
        let finite = match self {
            Self::NegativeInfinity => return min,
            Self::PositiveInfinity => return max,
            Self::Integer(value) => value,
        };

        let lower: i64 = min.as_();
        let upper: i64 = max.as_();
        I::from_i64(finite.clamp(lower, upper)).expect("`i` should already be clamped")
    }

    /// Returns the wrapped integer, or `None` when `self` is one of the infinities.
    #[must_use]
    pub const fn as_integer(self) -> Option<i64> {
        if let Self::Integer(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
impl From<f64> for IntegerOrInfinity {
    /// Converts a float by truncating it toward zero.
    ///
    /// `NaN`, `+0.0` and `-0.0` become `Integer(0)`; the float infinities become the matching
    /// infinity variant. The rules match ECMAScript's `ToIntegerOrInfinity` abstract operation.
    /// Finite values outside the `i64` range saturate, because that is how a float-to-integer
    /// `as` cast behaves.
    fn from(number: f64) -> Self {
        // `NaN` never compares equal to anything, so it needs its own check next to zero.
        if number.is_nan() || number == 0.0 {
            return Self::Integer(0);
        }

        if number.is_infinite() {
            return if number.is_sign_positive() {
                Self::PositiveInfinity
            } else {
                Self::NegativeInfinity
            };
        }

        // Floor the magnitude, then restore the sign: this rounds toward zero.
        let magnitude = number.abs().floor();
        Self::Integer(magnitude.copysign(number) as i64)
    }
}
impl PartialEq<i64> for IntegerOrInfinity {
    /// Returns `true` only when `self` is a finite integer equal to `other`; an infinity never
    /// equals an `i64`.
    fn eq(&self, other: &i64) -> bool {
        matches!(self, Self::Integer(value) if value == other)
    }
}
impl PartialEq<IntegerOrInfinity> for i64 {
fn eq(&self, other: &IntegerOrInfinity) -> bool {
other.eq(self)
}
}
impl PartialOrd<i64> for IntegerOrInfinity {
    /// Orders `self` relative to a plain `i64`.
    ///
    /// Always returns `Some`: `NegativeInfinity` is below every `i64`, `PositiveInfinity` is
    /// above every `i64`, and a finite integer compares by value.
    fn partial_cmp(&self, other: &i64) -> Option<Ordering> {
        let ordering = match *self {
            Self::NegativeInfinity => Ordering::Less,
            Self::Integer(value) => value.cmp(other),
            Self::PositiveInfinity => Ordering::Greater,
        };
        Some(ordering)
    }
}
impl PartialOrd<IntegerOrInfinity> for i64 {
fn partial_cmp(&self, other: &IntegerOrInfinity) -> Option<Ordering> {
other.partial_cmp(self).map(Ordering::reverse)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Equality between `i64` and `IntegerOrInfinity` must work with either operand first.
    #[test]
    fn test_eq() {
        let wrapped = IntegerOrInfinity::Integer(10);

        let different: i64 = 42;
        assert!(different != wrapped);
        assert!(wrapped != different);

        let same: i64 = 10;
        assert!(same == wrapped);
        assert!(wrapped == same);
    }

    /// Ordering between `i64` and `IntegerOrInfinity` must work with either operand first.
    #[test]
    fn test_ord() {
        let pivot: i64 = 42;

        // Everything here must sort strictly below the pivot.
        let smaller_values = [
            IntegerOrInfinity::Integer(10),
            IntegerOrInfinity::NegativeInfinity,
        ];
        for smaller in smaller_values.iter().copied() {
            assert!(smaller < pivot);
            assert!(pivot > smaller);
        }

        // Everything here must sort strictly above the pivot.
        let larger_values = [
            IntegerOrInfinity::Integer(100),
            IntegerOrInfinity::PositiveInfinity,
        ];
        for larger in larger_values.iter().copied() {
            assert!(larger > pivot);
            assert!(pivot < larger);
        }
    }
}