use std::{cmp::Ordering, marker::PhantomData, ops::Add};
use num_traits::Zero;
use super::{Constrained, Constraint, ConstraintError};
/// Marker type for a [`Constrained`] value that must never be below zero.
///
/// A value is accepted when it compares greater than or equal to `T::zero()`.
/// Values that compare less than zero, or that cannot be ordered against zero at all
/// (such as `f64::NAN`), are rejected by this constraint's `check`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct NonNegative;
impl NonNegative {
    /// Convenience constructor: wraps `value` in a [`Constrained`] bound by `NonNegative`.
    ///
    /// This lets callers write `NonNegative::new(x)` instead of spelling out
    /// `Constrained::<T, NonNegative>::new(x)`.
    ///
    /// # Errors
    ///
    /// Returns a [`ConstraintError`] when [`Constrained::new`] rejects `value`. Under this
    /// constraint's `check`, that is expected to be `ConstraintError::Negative` for values
    /// below zero and `ConstraintError::NotANumber` for values that cannot be ordered
    /// against zero (such as NaN).
    pub fn new<T: PartialOrd + Zero>(
        value: T,
    ) -> Result<Constrained<T, NonNegative>, ConstraintError> {
        Constrained::<T, Self>::new(value)
    }

    /// Returns the constrained zero value, which trivially satisfies the non-negative bound.
    #[must_use]
    pub fn zero<T: PartialOrd + Zero>() -> Constrained<T, NonNegative> {
        Constrained::<T, Self>::zero()
    }
}
impl<T: PartialOrd + Zero> Constraint<T> for NonNegative {
    /// Accepts `value` when it is greater than or equal to `T::zero()`.
    ///
    /// # Errors
    ///
    /// - `ConstraintError::NotANumber` if `value` cannot be ordered against zero
    ///   (`partial_cmp` returns `None`, e.g. for NaN).
    /// - `ConstraintError::Negative` if `value` compares strictly less than zero.
    fn check(value: &T) -> Result<(), ConstraintError> {
        // An unordered comparison (NaN) is reported before any sign test.
        let ordering = value
            .partial_cmp(&T::zero())
            .ok_or(ConstraintError::NotANumber)?;

        if ordering.is_lt() {
            Err(ConstraintError::Negative)
        } else {
            // Equal or Greater: zero itself is allowed.
            Ok(())
        }
    }
}
impl<T> Add for Constrained<T, NonNegative>
where
    T: Add<Output = T> + PartialOrd + Zero,
{
    type Output = Self;

    /// Adds the two wrapped values without going through the fallible constructor.
    ///
    /// The sum of two non-negative values is itself non-negative, so the bound is only
    /// re-verified with `debug_assert!`, and that check is compiled out of release builds.
    /// NOTE(review): for signed integers, release-mode overflow wraps and could produce a
    /// negative value that this impl would not detect.
    fn add(self, rhs: Self) -> Self {
        let sum = self.value + rhs.value;

        debug_assert!(
            sum >= T::zero(),
            "Addition produced a negative value, violating NonNegative bound invariant"
        );

        Self { value: sum, _marker: PhantomData }
    }
}
impl<T> Zero for Constrained<T, NonNegative>
where
    T: PartialOrd + Zero,
{
    /// Builds a constrained value holding `T::zero()`.
    ///
    /// Zero satisfies the non-negative bound, so the value is constructed directly
    /// rather than via the fallible constructor.
    fn zero() -> Self {
        let value = T::zero();
        Self { value, _marker: PhantomData }
    }

    /// Reports whether the wrapped value compares equal to `T::zero()`.
    fn is_zero(&self) -> bool {
        let zero = T::zero();
        self.value == zero
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use uom::si::{f64::MassRate, mass_rate::kilogram_per_second};

    #[test]
    fn integers() {
        let one = Constrained::<i32, NonNegative>::new(1).unwrap();
        assert_eq!(one.into_inner(), 1);

        let two = NonNegative::new(2).unwrap();
        assert_eq!(two.as_ref(), &2);

        let zero = NonNegative::zero();
        assert_eq!(zero.into_inner(), 0);

        let sum = one + two + zero;
        assert_eq!(sum.into_inner(), 3);

        // Pin the specific error variant, not merely that construction failed.
        assert!(matches!(NonNegative::new(-1), Err(ConstraintError::Negative)));
    }

    #[test]
    fn floats() {
        assert!(Constrained::<f64, NonNegative>::new(2.0).is_ok());
        assert!(NonNegative::new(0.0).is_ok());
        // -0.0 compares equal to 0.0, so it satisfies the bound.
        assert!(NonNegative::new(-0.0).is_ok());
        assert!(NonNegative::new(f64::INFINITY).is_ok());

        assert!(matches!(NonNegative::new(-2.0), Err(ConstraintError::Negative)));
        assert!(matches!(
            NonNegative::new(f64::NEG_INFINITY),
            Err(ConstraintError::Negative)
        ));
        // NaN is unordered and must be reported as NotANumber, not Negative.
        assert!(matches!(
            NonNegative::new(f64::NAN),
            Err(ConstraintError::NotANumber)
        ));
    }

    #[test]
    fn zero_trait() {
        let zero = Constrained::<f64, NonNegative>::zero();
        assert!(zero.is_zero());

        let one = NonNegative::new(1.0).unwrap();
        assert!(!one.is_zero());
    }

    #[test]
    fn mass_rates() {
        let mass_rate = MassRate::new::<kilogram_per_second>(5.0);
        assert!(NonNegative::new(mass_rate).is_ok());

        let mass_rate = MassRate::new::<kilogram_per_second>(0.0);
        assert!(NonNegative::new(mass_rate).is_ok());

        let mass_rate = MassRate::new::<kilogram_per_second>(-2.0);
        assert!(matches!(
            NonNegative::new(mass_rate),
            Err(ConstraintError::Negative)
        ));
    }
}