use std::{cmp::Ordering, marker::PhantomData, ops::Add};
use num_traits::Zero;
use super::{Constrained, Constraint, ConstraintError};
/// Marker type for the constraint that a value is strictly less than zero.
///
/// It carries no data; it is used as the constraint parameter of
/// [`Constrained`]. Zero, values greater than zero, and values that cannot be
/// ordered against zero (such as floating-point NaN) are rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StrictlyNegative;
impl StrictlyNegative {
    /// Builds a [`Constrained`] value tagged with the [`StrictlyNegative`]
    /// constraint.
    ///
    /// This is a convenience shorthand for
    /// `Constrained::<T, StrictlyNegative>::new(value)` and simply forwards
    /// to it.
    ///
    /// # Errors
    ///
    /// Propagates the [`ConstraintError`] returned by `Constrained::new`,
    /// which the unit tests in this file show is produced for zero, positive,
    /// and NaN inputs.
    pub fn new<T>(value: T) -> Result<Constrained<T, Self>, ConstraintError>
    where
        T: PartialOrd + Zero,
    {
        Constrained::<T, Self>::new(value)
    }
}
/// Validation rule for [`StrictlyNegative`]: the value must compare strictly
/// less than `T::zero()`.
impl<T: PartialOrd + Zero> Constraint<T> for StrictlyNegative {
    /// Checks that `value` is strictly less than zero.
    ///
    /// The comparison goes through `partial_cmp`, so a value that has no
    /// ordering relative to zero (for example a floating-point NaN) is
    /// rejected instead of being silently accepted.
    ///
    /// # Errors
    ///
    /// - `ConstraintError::Zero` if `value` equals zero.
    /// - `ConstraintError::Positive` if `value` is greater than zero.
    /// - `ConstraintError::NotANumber` if `value` cannot be ordered against
    ///   zero.
    fn check(value: &T) -> Result<(), ConstraintError> {
        match value.partial_cmp(&T::zero()) {
            Some(Ordering::Less) => Ok(()),
            Some(Ordering::Equal) => Err(ConstraintError::Zero),
            // A value above zero is positive; reporting it as `Negative`
            // would describe the opposite of what actually went wrong.
            // NOTE(review): relies on `ConstraintError` declaring a `Positive`
            // variant; the enum is defined outside this file, so confirm.
            Some(Ordering::Greater) => Err(ConstraintError::Positive),
            None => Err(ConstraintError::NotANumber),
        }
    }
}
impl<T> Add for Constrained<T, StrictlyNegative>
where
T: Add<Output = T> + PartialOrd + Zero,
{
type Output = Self;
fn add(self, rhs: Self) -> Self {
let value = self.value + rhs.value;
debug_assert!(
value < T::zero(),
"Addition produced a non-negative value, violating StrictlyNegative bound invariant"
);
Self {
value,
_marker: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer values: negatives are accepted, zero and positives rejected.
    #[test]
    fn integers() {
        let x = Constrained::<i32, StrictlyNegative>::new(-1).unwrap();
        assert_eq!(x.into_inner(), -1);

        let y = StrictlyNegative::new(-42).unwrap();
        assert_eq!(y.as_ref(), &-42);

        assert!(StrictlyNegative::new(0).is_err());
        assert!(StrictlyNegative::new(2).is_err());
    }

    /// Floating-point values, including the NaN case that has no ordering.
    #[test]
    fn floats() {
        assert!(Constrained::<f64, StrictlyNegative>::new(-1.0).is_ok());
        assert!(StrictlyNegative::new(-0.1).is_ok());
        assert!(StrictlyNegative::new(0.0).is_err());
        assert!(StrictlyNegative::new(5.0).is_err());
        assert!(StrictlyNegative::new(f64::NAN).is_err());
    }

    /// Rejections carry the specific reason, not just "some error".
    #[test]
    fn rejection_reasons() {
        assert!(matches!(
            StrictlyNegative::new(0),
            Err(ConstraintError::Zero)
        ));
        assert!(matches!(
            StrictlyNegative::new(0.0),
            Err(ConstraintError::Zero)
        ));
        assert!(matches!(
            StrictlyNegative::new(f64::NAN),
            Err(ConstraintError::NotANumber)
        ));
    }

    /// Adding two strictly negative values yields their (negative) sum.
    #[test]
    fn addition_stays_strictly_negative() {
        let a = StrictlyNegative::new(-1).unwrap();
        let b = StrictlyNegative::new(-2).unwrap();
        assert_eq!((a + b).into_inner(), -3);

        // -0.5 and -0.25 are exactly representable, so the sum is exact.
        let c = StrictlyNegative::new(-0.5).unwrap();
        let d = StrictlyNegative::new(-0.25).unwrap();
        assert_eq!((c + d).into_inner(), -0.75);
    }
}