use core::f32;
use core::fmt::Debug;
use core::fmt::Display;
use core::ops::Add;
use core::ops::Neg;
use core::ops::Sub;
use num_traits::CheckedAdd;
use num_traits::CheckedSub;
use num_traits::WrappingAdd;
use num_traits::WrappingSub;
use thiserror::Error;
/// Decimal with 1 fractional digit; `x` is stored as the `i32` `x * 10`.
pub type D1 = Decimal32<1>;
/// Decimal with 2 fractional digits; `x` is stored as the `i32` `x * 100`.
pub type D2 = Decimal32<2>;
/// Decimal with 3 fractional digits; `x` is stored as the `i32` `x * 1_000`.
pub type D3 = Decimal32<3>;
/// Decimal with 4 fractional digits; `x` is stored as the `i32` `x * 10_000`.
pub type D4 = Decimal32<4>;
/// Decimal with 5 fractional digits; `x` is stored as the `i32` `x * 100_000`.
pub type D5 = Decimal32<5>;
/// Decimal with 6 fractional digits; `x` is stored as the `i32` `x * 10^6`.
pub type D6 = Decimal32<6>;
/// Decimal with 7 fractional digits; `x` is stored as the `i32` `x * 10^7`.
pub type D7 = Decimal32<7>;
/// Decimal with 8 fractional digits; `x` is stored as the `i32` `x * 10^8`.
pub type D8 = Decimal32<8>;
/// Decimal with 9 fractional digits (the maximum supported); `x` is stored
/// as the `i32` `x * 10^9`, so the representable range is about ±2.147.
pub type D9 = Decimal32<9>;
/// A fixed-point decimal stored as an `i32` count of `10^-PRECISION` units.
///
/// For example, a `Decimal32<3>` holding `7.12` stores the raw integer `7120`.
/// The derived `PartialEq`/`Ord`/`Hash` compare that raw integer, so equality
/// and ordering are exact. Both operands must have the same `PRECISION`, which
/// the type system enforces. `#[repr(transparent)]` gives the type the same
/// layout as `i32`.
///
/// `PRECISION` is intended to be at most 9 (see the `_PRECISION_CHECK`
/// associated constant), since `10^PRECISION` must fit in a `u32`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct Decimal32<const PRECISION: u32>(i32);
/// Errors produced by fallible conversions involving [`Decimal32`].
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum DecimalErr {
    /// The value would change in the conversion. The `f32` conversions
    /// return this when the result does not round-trip exactly.
    #[error("The attempted operation would cause a loss of precision.")]
    Lossy,
}
impl<const PRECISION: u32> Decimal32<PRECISION> {
    /// Compile-time guard that `10^PRECISION` fits in a `u32`, i.e. that the
    /// scaling factor used by the methods below is valid.
    ///
    /// Associated constants are only evaluated when they are referenced, so
    /// an unreferenced guard never fires. The methods that scale by
    /// `10^PRECISION` therefore bind it with `let () = Self::_PRECISION_CHECK;`.
    /// That turns an out-of-range `PRECISION` into a build error as soon as
    /// one of them is instantiated.
    const _PRECISION_CHECK: () = assert!(
        PRECISION <= 9,
        "PRECISION must be <= 9; 10^PRECISION would overflow u32 otherwise"
    );

    /// Zero (raw value 0).
    pub const ZERO: Self = Self(0);
    /// Largest representable value: `i32::MAX` units of `10^-PRECISION`.
    pub const MAX: Self = Self(i32::MAX);
    /// Most negative representable value: `i32::MIN` units of `10^-PRECISION`.
    pub const MIN: Self = Self(i32::MIN);
    /// Smallest positive value, `10^-PRECISION` (raw value 1).
    pub const MIN_UNIT: Self = Self(1);

    /// Converts `value` to a decimal, truncating toward zero any digits
    /// beyond `PRECISION` fractional places.
    ///
    /// The scaled value is converted with an `as` cast. It therefore
    /// saturates at [`Self::MIN`]/[`Self::MAX`] for out-of-range inputs
    /// (including infinities) and maps NaN to zero.
    ///
    /// Scaling happens in binary floating point. A decimal literal that has
    /// no exact `f64` representation can land just below an integer and be
    /// truncated one unit low. For example, `0.29 * 100.0` is
    /// `28.999999999999996`, so `Decimal32::<2>::cast(0.29)` holds `0.28`.
    /// Use [`Self::checked`] to detect such cases.
    #[must_use]
    #[inline]
    pub const fn cast(value: f64) -> Self {
        let () = Self::_PRECISION_CHECK;
        Self((value * self::scalar(PRECISION) as f64) as i32)
    }

    /// Like [`Self::cast`], but returns `None` unless the resulting decimal
    /// converts back (via [`Self::get`]) to exactly `value`.
    ///
    /// This rejects NaN, infinities, inputs whose scaled value lies outside
    /// the `i32` range, and inputs with more than `PRECISION` fractional
    /// digits. Because of binary rounding (see [`Self::cast`]), it can also
    /// reject some decimal literals with `PRECISION` or fewer digits, such as
    /// `0.29` at `PRECISION = 2`.
    #[must_use]
    pub const fn checked(value: f64) -> Option<Self> {
        let dec = Self::cast(value);
        if dec.get() != value {
            return None;
        }
        Some(dec)
    }

    /// Returns the decimal as an `f64`, computed as `raw / 10^PRECISION`.
    ///
    /// IEEE division is correctly rounded, so the result is the `f64`
    /// nearest to the exact decimal value.
    #[must_use]
    #[inline]
    pub const fn get(self) -> f64 {
        let () = Self::_PRECISION_CHECK;
        self.0 as f64 / self::scalar(PRECISION) as f64
    }
}
/// Returns `10^precision`, the factor between a decimal's raw integer and
/// its value.
///
/// # Panics
///
/// Panics if `precision > 9`, because `10^10` no longer fits in a `u32`.
/// When evaluated in a const context this becomes a compile-time error.
#[inline]
const fn scalar(precision: u32) -> u32 {
    // `pow` only panics on overflow when overflow checks are enabled (debug
    // builds) and silently wraps otherwise. `checked_pow` makes the failure
    // unconditional instead of producing a garbage scaling factor.
    match 10u32.checked_pow(precision) {
        Some(factor) => factor,
        None => panic!("decimal precision must be <= 9; 10^precision overflows u32"),
    }
}
impl<const P: u32> Debug for Decimal32<P> {
    /// Renders as `Decimal32<P>(value)`, e.g. `Decimal32<3>(7.12)`, so the
    /// precision is visible next to the numeric value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let value = self.get();
        write!(f, "Decimal32<{P}>({value})")
    }
}
impl<const P: u32> Display for Decimal32<P> {
    /// Formats the decimal as its `f64` value (e.g. `7.12`).
    ///
    /// Formatting is delegated to `f64`'s `Display` with the caller's
    /// formatter, so width, fill, alignment, sign and precision flags apply.
    /// For example, `{:.3}` yields `7.120` and `{:>6}` right-aligns. A plain
    /// `write!(f, "{}", ..)` would silently drop those flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        Display::fmt(&self.get(), f)
    }
}
impl<const P: u32> TryFrom<f32> for Decimal32<P> {
    type Error = DecimalErr;

    /// Converts an `f32` into a decimal with `P` fractional digits, refusing
    /// any conversion that would change the value.
    ///
    /// The input is scaled by `10^P` in `f32` arithmetic and truncated toward
    /// zero. The result is accepted only if it converts back to exactly the
    /// same `f32`.
    ///
    /// # Errors
    ///
    /// Returns [`DecimalErr::Lossy`] if any of these hold:
    /// - `value` is NaN or infinite;
    /// - `value * 10^P` lies outside the `i32` range;
    /// - the truncated decimal does not convert back to `value`.
    fn try_from(value: f32) -> Result<Self, Self::Error> {
        // 2^31 is exactly representable in f32. The raw i32 must lie in
        // [-2^31, 2^31).
        const I32_BOUND: f32 = 2_147_483_648.0;

        // Reference the precision guard so that `P > 9` is rejected at compile
        // time (associated consts are only evaluated when used).
        let () = Self::_PRECISION_CHECK;

        // Scaling in f32 rather than f64 matters here. 7.12_f32 is really
        // 7.1199998855..., and the exact f64 product 7119.9998855... would
        // truncate to 7119. The rounded f32 product is exactly 7120.0.
        let scaled = value * self::scalar(P) as f32;

        // `as i32` saturates. Without this range check, a saturated result can
        // pass the round-trip test below, because i32::MAX rounds up to 2^31
        // as f32. For example, 2^31 at P = 0 would be stored as i32::MAX and
        // accepted. `contains` is also false for NaN.
        if !(-I32_BOUND..I32_BOUND).contains(&scaled) {
            return Err(DecimalErr::Lossy);
        }

        let dec = Self(scaled as i32);
        if dec.get() as f32 != value {
            return Err(DecimalErr::Lossy);
        }
        Ok(dec)
    }
}
impl<const P: u32> TryFrom<Decimal32<P>> for f32 {
    type Error = DecimalErr;

    /// Converts a decimal to `f32`. The conversion is refused when the raw
    /// scaled integer has more significant bits than an `f32` can hold exactly.
    ///
    /// The check is on the raw integer (e.g. `7120` for `7.12` at `P = 3`),
    /// not on the decimal value itself. The returned value is
    /// [`Decimal32::get`] rounded to `f32`. For `P > 0` that is usually not
    /// exact, since e.g. `7.12` has no exact binary representation.
    ///
    /// # Errors
    ///
    /// Returns [`DecimalErr::Lossy`] if the raw integer is not exactly
    /// representable as an `f32` (for example `16_777_217 = 2^24 + 1`, or
    /// `i32::MAX`).
    fn try_from(value: Decimal32<P>) -> Result<Self, Self::Error> {
        // Compare in f64, which represents every i32 exactly. Round-tripping
        // through `as i32` instead would miss i32::MAX: it rounds up to 2^31
        // as f32, which saturates back to i32::MAX and so looks exact.
        if f64::from(value.0 as Self) != f64::from(value.0) {
            return Err(DecimalErr::Lossy);
        }
        Ok(value.get() as Self)
    }
}
impl<const P: u32> From<Decimal32<P>> for f64 {
fn from(val: Decimal32<P>) -> Self {
val.get()
}
}
impl<const P: u32> Add for Decimal32<P> {
    type Output = Self;

    /// Adds two decimals of the same precision.
    ///
    /// Overflow of the raw `i32` wraps in every build profile, unlike `i32`'s
    /// `+`, which panics when overflow checks are on. Use `CheckedAdd` to
    /// detect overflow.
    fn add(self, rhs: Self) -> Self::Output {
        Self(i32::wrapping_add(self.0, rhs.0))
    }
}
impl<const P: u32> CheckedAdd for Decimal32<P> {
    /// Adds two decimals of the same precision, returning `None` if the raw
    /// `i32` sum overflows.
    fn checked_add(&self, v: &Self) -> Option<Self> {
        let sum = self.0.checked_add(v.0)?;
        Some(Self(sum))
    }
}
impl<const P: u32> WrappingAdd for Decimal32<P> {
    /// Adds two decimals of the same precision, wrapping around on overflow
    /// of the raw `i32`.
    fn wrapping_add(&self, v: &Self) -> Self {
        let (Self(lhs), Self(rhs)) = (*self, *v);
        Self(lhs.wrapping_add(rhs))
    }
}
impl<const P: u32> Default for Decimal32<P> {
fn default() -> Self {
Self::ZERO
}
}
impl<const P: u32> Sub for Decimal32<P> {
    type Output = Self;

    /// Subtracts two decimals of the same precision.
    ///
    /// Overflow of the raw `i32` wraps in every build profile. Use
    /// `CheckedSub` to detect it.
    fn sub(self, rhs: Self) -> Self::Output {
        Self(i32::wrapping_sub(self.0, rhs.0))
    }
}
impl<const P: u32> CheckedSub for Decimal32<P> {
    /// Subtracts two decimals of the same precision, returning `None` if the
    /// raw `i32` difference overflows.
    fn checked_sub(&self, v: &Self) -> Option<Self> {
        let difference = self.0.checked_sub(v.0)?;
        Some(Self(difference))
    }
}
impl<const P: u32> WrappingSub for Decimal32<P> {
    /// Subtracts two decimals of the same precision, wrapping around on
    /// overflow of the raw `i32`.
    #[inline]
    fn wrapping_sub(&self, v: &Self) -> Self {
        let (Self(lhs), Self(rhs)) = (*self, *v);
        Self(lhs.wrapping_sub(rhs))
    }
}
impl<const P: u32> Neg for Decimal32<P> {
    type Output = Self;

    /// Negates the decimal.
    ///
    /// `-i32::MIN` does not fit in an `i32`, so negating `MIN` wraps back to
    /// `MIN` instead of panicking.
    fn neg(self) -> Self::Output {
        Self(0_i32.wrapping_sub(self.0))
    }
}
#[cfg(test)]
impl<const P: u32> From<Decimal32<P>> for num_bigint::BigInt {
    /// Test-only: exposes the raw scaled integer (not the decimal value) as a
    /// `BigInt`, e.g. for overflow-free reference arithmetic in tests.
    fn from(val: Decimal32<P>) -> Self {
        let Decimal32(raw) = val;
        raw.into()
    }
}
#[cfg(test)]
impl<const P: u32> TryFrom<&num_bigint::BigInt> for Decimal32<P> {
    type Error = ();

    /// Test-only: interprets `val` as a raw scaled integer. Fails with `()`
    /// if it does not fit in an `i32`.
    fn try_from(val: &num_bigint::BigInt) -> Result<Self, Self::Error> {
        i32::try_from(val).map(Self).map_err(|_| ())
    }
}
// Lints relaxed for this test-only module:
// - `missing_panics_doc`: `assert_eq_f64` is a public assertion helper that
//   panics by design.
// - `expect_used`: tests call `expect` to fail loudly on unexpected errors.
#[allow(clippy::missing_panics_doc)]
#[allow(clippy::expect_used)]
#[cfg(test)]
pub mod decimal_tests {
    use num_traits::CheckedAdd;
    use num_traits::CheckedSub;
    use num_traits::WrappingAdd;
    use num_traits::WrappingSub;

    use crate::decimal::Decimal32;
    use crate::decimal::DecimalErr;
    use crate::decimal::D3;

    /// Asserts that `left` and `right` differ by strictly less than
    /// `10^-precision`, i.e. less than one unit of a `precision`-digit decimal.
    ///
    /// # Panics
    ///
    /// Panics, reporting both values and the tolerance, if the difference is
    /// not below the tolerance. This includes the case where either value is
    /// NaN, because then the comparison is false.
    pub fn assert_eq_f64(left: f64, right: f64, precision: u32) {
        let tolerance = 1.0 / super::scalar(precision) as f64;
        assert!(
            (left - right).abs() < tolerance,
            "equality failed: {left:?} != {right:?} (tolerance {tolerance})"
        );
    }

    /// Values with at most 3 fractional digits survive `cast` + `get`.
    #[test]
    fn cast_from_exact() {
        assert_eq_f64(D3::cast(0.001).get(), 0.001, 3);
        assert_eq_f64(D3::cast(7.120).get(), 7.120, 3);
        assert_eq_f64(D3::cast(-3.500).get(), -3.500, 3);
    }

    /// `cast` truncates toward zero rather than rounding.
    #[test]
    fn cast_truncates() {
        let d = Decimal32::<0>::cast(0.9999_f64);
        assert_eq!(d.get(), 0.0_f64);
    }

    /// `f32` inputs whose decimal converts back to the same `f32` are accepted.
    #[test]
    fn try_from_lossless() {
        assert!(D3::try_from(0.001_f32).is_ok());
        assert!(D3::try_from(7.120_f32).is_ok());
        assert!(D3::try_from(0.0_f32).is_ok());
    }

    /// A repeating fraction is rejected by `try_from`, while `cast` simply
    /// truncates it.
    #[test]
    fn try_from_lossy() {
        assert!(matches!(
            D3::try_from(1.0_f32 / 3.0_f32),
            Err(DecimalErr::Lossy)
        ));
        assert_eq_f64(D3::cast(1. / 3.).get(), 0.333, 3);
    }

    /// A small raw value converts to `f32` successfully.
    #[test]
    fn decimal_to_f32_ok() {
        let d = D3::cast(7.120);
        let f: f32 = f32::try_from(d).expect("should round-trip");
        assert_eq_f64(f as f64, 7.120, 3);
    }

    #[test]
    fn add_basic() {
        let sum = D3::cast(0.001) + D3::cast(7.120);
        assert_eq_f64(sum.get(), 7.121, 3);
    }

    #[test]
    fn add_negative() {
        let sum = D3::cast(1.000) + D3::cast(-3.500);
        assert_eq_f64(sum.get(), -2.500, 3);
    }

    #[test]
    fn checked_add_ok() {
        let a = D3::cast(1.000);
        let b = D3::cast(2.000);
        assert_eq!(a.checked_add(&b), Some(D3::cast(3.000)));
    }

    /// Overflow of the raw `i32` yields `None` from `checked_add`.
    #[test]
    fn checked_add_overflow() {
        let max = Decimal32::<0>(i32::MAX);
        let one = Decimal32::<0>(1);
        assert_eq!(max.checked_add(&one), None);
    }

    /// `wrapping_add` wraps the raw `i32` from `MAX` to `MIN`.
    #[test]
    fn wrapping_add_overflow() {
        let max = Decimal32::<0>(i32::MAX);
        let one = Decimal32::<0>(1);
        let expected = Decimal32::<0>(i32::MIN);
        assert_eq!(max.wrapping_add(&one), expected);
    }

    #[test]
    fn sub_basic() {
        let diff = D3::cast(7.121) - D3::cast(0.001);
        assert_eq_f64(diff.get(), 7.120, 3);
    }

    #[test]
    fn sub_to_negative() {
        let diff = D3::cast(1.000) - D3::cast(3.500);
        assert_eq_f64(diff.get(), -2.500, 3);
    }

    #[test]
    fn checked_sub_ok() {
        let a = D3::cast(5.000);
        let b = D3::cast(2.000);
        assert_eq!(a.checked_sub(&b), Some(D3::cast(3.000)));
    }

    /// Underflow of the raw `i32` yields `None` from `checked_sub`.
    #[test]
    fn checked_sub_underflow() {
        let min = Decimal32::<0>(i32::MIN);
        let one = Decimal32::<0>(1);
        assert_eq!(min.checked_sub(&one), None);
    }

    /// `wrapping_sub` wraps the raw `i32` from `MIN` to `MAX`.
    #[test]
    fn wrapping_sub_underflow() {
        let min = Decimal32::<0>(i32::MIN);
        let one = Decimal32::<0>(1);
        let expected = Decimal32::<0>(i32::MAX);
        assert_eq!(min.wrapping_sub(&one), expected);
    }

    #[test]
    fn zero_additive_identity() {
        let x = D3::cast(4.200);
        assert_eq!(x + D3::ZERO, x);
        assert_eq!(D3::ZERO + x, x);
    }

    #[test]
    fn zero_subtractive_identity() {
        let x = D3::cast(4.200);
        assert_eq!(x - D3::ZERO, x);
    }

    /// The derived ordering on the raw integer matches numeric order.
    #[test]
    fn ordering() {
        let neg = D3::cast(-1.000);
        let zero = D3::ZERO;
        let pos = D3::cast(1.000);
        assert!(neg < zero);
        assert!(zero < pos);
        assert!(neg < pos);
        assert_eq!(zero, D3::cast(0.000));
    }

    #[test]
    fn default_is_zero() {
        assert_eq!(D3::default(), D3::ZERO);
    }

    #[test]
    fn neg_basic() {
        let x = D3::cast(4.200);
        assert_eq!(-x, D3::cast(-4.200));
        assert_eq!(-(-x), x);
    }

    /// Negating `MIN` wraps back to `MIN`, since `-i32::MIN` is not an `i32`.
    #[test]
    fn neg_min_wraps() {
        let min = Decimal32::<0>(i32::MIN);
        assert_eq!(-min, min);
    }

    /// 16_777_217 (2^24 + 1) is the smallest positive integer an `f32` cannot
    /// represent exactly, so a decimal with that raw value is rejected.
    #[test]
    fn decimal_to_f32_lossy() {
        let d = Decimal32::<1>(16_777_217_i32);
        assert!(f32::try_from(d).is_err());
        let d_neg = Decimal32::<1>(-16_777_217_i32);
        assert!(f32::try_from(d_neg).is_err());
    }
}