use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops::{Add, AddAssign, Neg, Sub, SubAssign};
use crate::intern::InternedStr;
#[cfg(feature = "rkyv")]
use crate::intern::{AsDecimal, AsInternedStr};
/// A decimal quantity tagged with the currency it is denominated in.
///
/// Equality and hashing are derived field-wise from `number` and `currency`.
/// When the `rkyv` feature is enabled the type is also archivable, with the two
/// fields routed through the `AsDecimal` / `AsInternedStr` adapters.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct Amount {
/// The numeric quantity. Its scale (digits after the decimal point) is kept
/// as written and is what `Amount::inferred_tolerance` derives a tolerance from.
#[cfg_attr(feature = "rkyv", rkyv(with = AsDecimal))]
pub number: Decimal,
/// The currency the quantity is expressed in (e.g. `"USD"`), stored interned.
#[cfg_attr(feature = "rkyv", rkyv(with = AsInternedStr))]
pub currency: InternedStr,
}
impl Amount {
    /// Creates an amount of `number` in `currency`.
    #[must_use]
    pub fn new(number: Decimal, currency: impl Into<InternedStr>) -> Self {
        Self {
            number,
            currency: currency.into(),
        }
    }

    /// Creates a zero amount in `currency`.
    #[must_use]
    pub fn zero(currency: impl Into<InternedStr>) -> Self {
        Self::new(Decimal::ZERO, currency)
    }

    /// Returns `true` if the number is zero.
    #[must_use]
    pub const fn is_zero(&self) -> bool {
        self.number.is_zero()
    }

    /// Returns `true` if the number is strictly greater than zero.
    ///
    /// A zero amount is never positive.
    #[must_use]
    pub const fn is_positive(&self) -> bool {
        self.number.is_sign_positive() && !self.number.is_zero()
    }

    /// Returns `true` if the number is strictly less than zero.
    ///
    /// `Decimal` keeps a sign bit even on zero (negating a zero amount yields a
    /// "negative zero"). Checking the sign bit alone would report such an amount
    /// as negative while `is_zero` also returns `true`. Excluding zero keeps this
    /// symmetric with [`Self::is_positive`]: a zero amount is neither.
    #[must_use]
    pub const fn is_negative(&self) -> bool {
        self.number.is_sign_negative() && !self.number.is_zero()
    }

    /// Returns a copy holding the absolute value of the number, same currency.
    #[must_use]
    pub fn abs(&self) -> Self {
        Self {
            number: self.number.abs(),
            currency: self.currency.clone(),
        }
    }

    /// Returns the number of digits after the decimal point in `number`.
    #[must_use]
    pub const fn scale(&self) -> u32 {
        self.number.scale()
    }

    /// Returns half a unit in the last decimal place of `number`.
    ///
    /// Examples: `100` gives `0.5`, `100.00` gives `0.005`, and `100.000`
    /// gives `0.0005`.
    ///
    /// `Decimal` supports at most 28 decimal places, and `Decimal::new` panics
    /// above that. For an amount already at scale 28 the scale is therefore
    /// clamped, so the tolerance is `5e-28` instead of a panic.
    #[must_use]
    pub fn inferred_tolerance(&self) -> Decimal {
        // Largest scale `Decimal::new` accepts without panicking.
        const MAX_SCALE: u32 = 28;
        Decimal::new(5, (self.number.scale() + 1).min(MAX_SCALE))
    }

    /// Returns `true` if `|number| <= tolerance`, regardless of currency.
    #[must_use]
    pub fn is_near_zero(&self, tolerance: Decimal) -> bool {
        self.number.abs() <= tolerance
    }

    /// Returns `true` if both amounts share a currency and their numbers differ
    /// by at most `tolerance`.
    #[must_use]
    pub fn is_near(&self, other: &Self, tolerance: Decimal) -> bool {
        self.currency == other.currency && (self.number - other.number).abs() <= tolerance
    }

    /// Alias for [`Self::is_near`].
    #[must_use]
    pub fn eq_with_tolerance(&self, other: &Self, tolerance: Decimal) -> bool {
        self.is_near(other, tolerance)
    }

    /// Compares like [`Self::is_near`], using the larger (coarser) of the two
    /// amounts' [`Self::inferred_tolerance`] values as the tolerance.
    #[must_use]
    pub fn eq_auto_tolerance(&self, other: &Self) -> bool {
        if self.currency != other.currency {
            return false;
        }
        let tolerance = self.inferred_tolerance().max(other.inferred_tolerance());
        (self.number - other.number).abs() <= tolerance
    }

    /// Returns a copy with `number` rounded to `dp` decimal places via
    /// `Decimal::round_dp`, same currency.
    #[must_use]
    pub fn round_dp(&self, dp: u32) -> Self {
        Self {
            number: self.number.round_dp(dp),
            currency: self.currency.clone(),
        }
    }
}
impl fmt::Display for Amount {
    /// Renders the amount as the number, a single space, then the currency.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { number, currency } = self;
        write!(f, "{} {}", number, currency)
    }
}
impl Add for &Amount {
type Output = Amount;
fn add(self, other: &Amount) -> Amount {
debug_assert_eq!(
self.currency, other.currency,
"Cannot add amounts with different currencies"
);
Amount {
number: self.number + other.number,
currency: self.currency.clone(),
}
}
}
impl Sub for &Amount {
type Output = Amount;
fn sub(self, other: &Amount) -> Amount {
debug_assert_eq!(
self.currency, other.currency,
"Cannot subtract amounts with different currencies"
);
Amount {
number: self.number - other.number,
currency: self.currency.clone(),
}
}
}
impl Neg for &Amount {
type Output = Amount;
fn neg(self) -> Amount {
Amount {
number: -self.number,
currency: self.currency.clone(),
}
}
}
impl Add for Amount {
    type Output = Self;

    /// Adds two owned amounts, reusing `self` as the result.
    ///
    /// Delegates to `AddAssign<&Amount>`, which carries the same debug-only
    /// currency check and message as the borrowed `Add`.
    fn add(mut self, rhs: Self) -> Self {
        self += &rhs;
        self
    }
}
impl Sub for Amount {
    type Output = Self;

    /// Subtracts `rhs` from an owned amount, reusing `self` as the result.
    ///
    /// Delegates to `SubAssign<&Amount>`, which carries the same debug-only
    /// currency check and message as the borrowed `Sub`.
    fn sub(mut self, rhs: Self) -> Self {
        self -= &rhs;
        self
    }
}
impl Neg for Amount {
    type Output = Self;

    /// Flips the number's sign, moving the currency across instead of cloning it.
    fn neg(self) -> Self {
        Self {
            number: -self.number,
            currency: self.currency,
        }
    }
}
impl AddAssign<&Self> for Amount {
    /// Adds `rhs` into this amount in place. The currency is left untouched.
    ///
    /// Mismatched currencies only trip a `debug_assert_eq!`.
    fn add_assign(&mut self, rhs: &Self) {
        let Self { number, currency } = self;
        debug_assert_eq!(
            *currency, rhs.currency,
            "Cannot add amounts with different currencies"
        );
        *number += rhs.number;
    }
}
impl SubAssign<&Self> for Amount {
    /// Subtracts `rhs` from this amount in place. The currency is left untouched.
    ///
    /// Mismatched currencies only trip a `debug_assert_eq!`.
    fn sub_assign(&mut self, rhs: &Self) {
        let Self { number, currency } = self;
        debug_assert_eq!(
            *currency, rhs.currency,
            "Cannot subtract amounts with different currencies"
        );
        *number -= rhs.number;
    }
}
/// An amount in which either the number or the currency may be missing.
///
/// `Complete` wraps a full [`Amount`]. `NumberOnly` and `CurrencyOnly` each
/// carry just one half. With the `rkyv` feature enabled, the lone fields are
/// archived through the `AsDecimal` / `AsInternedStr` adapters.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum IncompleteAmount {
/// Both the number and the currency are known.
Complete(Amount),
/// Only the number is known; the currency is missing.
NumberOnly(#[cfg_attr(feature = "rkyv", rkyv(with = AsDecimal))] Decimal),
/// Only the currency is known; the number is missing.
CurrencyOnly(#[cfg_attr(feature = "rkyv", rkyv(with = AsInternedStr))] InternedStr),
}
impl IncompleteAmount {
#[must_use]
pub fn complete(number: Decimal, currency: impl Into<InternedStr>) -> Self {
Self::Complete(Amount::new(number, currency))
}
#[must_use]
pub const fn number_only(number: Decimal) -> Self {
Self::NumberOnly(number)
}
#[must_use]
pub fn currency_only(currency: impl Into<InternedStr>) -> Self {
Self::CurrencyOnly(currency.into())
}
#[must_use]
pub const fn number(&self) -> Option<Decimal> {
match self {
Self::Complete(a) => Some(a.number),
Self::NumberOnly(n) => Some(*n),
Self::CurrencyOnly(_) => None,
}
}
#[must_use]
pub fn currency(&self) -> Option<&str> {
match self {
Self::Complete(a) => Some(&a.currency),
Self::NumberOnly(_) => None,
Self::CurrencyOnly(c) => Some(c),
}
}
#[must_use]
pub const fn is_complete(&self) -> bool {
matches!(self, Self::Complete(_))
}
#[must_use]
pub const fn as_amount(&self) -> Option<&Amount> {
match self {
Self::Complete(a) => Some(a),
_ => None,
}
}
#[must_use]
pub fn into_amount(self) -> Option<Amount> {
match self {
Self::Complete(a) => Some(a),
_ => None,
}
}
}
impl From<Amount> for IncompleteAmount {
fn from(amount: Amount) -> Self {
Self::Complete(amount)
}
}
impl fmt::Display for IncompleteAmount {
    /// Prints whichever halves are present. A complete value renders exactly
    /// like its [`Amount`]; a lone number or lone currency renders on its own.
    ///
    /// Each arm goes through `write!`, so width/precision flags on the outer
    /// formatter are not forwarded to the inner value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NumberOnly(n) => write!(f, "{n}"),
            Self::CurrencyOnly(c) => write!(f, "{c}"),
            Self::Complete(a) => write!(f, "{a}"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    #[test]
    fn test_new() {
        let amount = Amount::new(dec!(100.00), "USD");
        assert_eq!(amount.number, dec!(100.00));
        assert_eq!(amount.currency, "USD");
    }

    #[test]
    fn test_zero() {
        let amount = Amount::zero("EUR");
        assert!(amount.is_zero());
        assert_eq!(amount.currency, "EUR");
    }

    #[test]
    fn test_is_positive_negative() {
        let pos = Amount::new(dec!(100), "USD");
        let neg = Amount::new(dec!(-100), "USD");
        let zero = Amount::zero("USD");
        assert!(pos.is_positive());
        assert!(!pos.is_negative());
        assert!(!neg.is_positive());
        assert!(neg.is_negative());
        assert!(!zero.is_positive());
        assert!(!zero.is_negative());
    }

    #[test]
    fn test_add() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(50.00), "USD");
        let sum = &a + &b;
        assert_eq!(sum.number, dec!(150.00));
        assert_eq!(sum.currency, "USD");
    }

    #[test]
    fn test_sub() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(50.00), "USD");
        let diff = &a - &b;
        assert_eq!(diff.number, dec!(50.00));
        assert_eq!(diff.currency, "USD");
    }

    #[test]
    fn test_neg() {
        let a = Amount::new(dec!(100.00), "USD");
        let neg_a = -&a;
        assert_eq!(neg_a.number, dec!(-100.00));
        assert_eq!(neg_a.currency, "USD");
    }

    #[test]
    fn test_owned_operators() {
        let a = Amount::new(dec!(10), "USD");
        let b = Amount::new(dec!(4), "USD");
        let sum = a.clone() + b.clone();
        assert_eq!(sum.number, dec!(14));
        assert_eq!(sum.currency, "USD");
        assert_eq!((a.clone() - b).number, dec!(6));
        let negated = -a;
        assert_eq!(negated.number, dec!(-10));
        assert_eq!(negated.currency, "USD");
    }

    #[test]
    fn test_add_assign() {
        let mut a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(50.00), "USD");
        a += &b;
        assert_eq!(a.number, dec!(150.00));
    }

    #[test]
    fn test_sub_assign() {
        let mut a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(30.50), "USD");
        a -= &b;
        assert_eq!(a.number, dec!(69.50));
        assert_eq!(a.currency, "USD");
    }

    #[test]
    fn test_inferred_tolerance() {
        let a = Amount::new(dec!(100), "USD");
        assert_eq!(a.inferred_tolerance(), dec!(0.5));
        let b = Amount::new(dec!(100.00), "USD");
        assert_eq!(b.inferred_tolerance(), dec!(0.005));
        let c = Amount::new(dec!(100.000), "USD");
        assert_eq!(c.inferred_tolerance(), dec!(0.0005));
    }

    #[test]
    fn test_is_near_zero() {
        let a = Amount::new(dec!(0.004), "USD");
        assert!(a.is_near_zero(dec!(0.005)));
        assert!(!a.is_near_zero(dec!(0.003)));
    }

    #[test]
    fn test_is_near() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(100.004), "USD");
        assert!(a.is_near(&b, dec!(0.005)));
        assert!(!a.is_near(&b, dec!(0.003)));
        let c = Amount::new(dec!(100.00), "EUR");
        assert!(!a.is_near(&c, dec!(1.0)));
    }

    #[test]
    fn test_display() {
        let a = Amount::new(dec!(1234.56), "USD");
        assert_eq!(format!("{a}"), "1234.56 USD");
    }

    #[test]
    fn test_abs() {
        let neg = Amount::new(dec!(-100.00), "USD");
        let abs = neg.abs();
        assert_eq!(abs.number, dec!(100.00));
        assert_eq!(abs.currency, "USD");
    }

    #[test]
    fn test_round_dp() {
        let a = Amount::new(dec!(1.2345), "USD");
        let rounded = a.round_dp(2);
        assert_eq!(rounded.number, dec!(1.23));
        assert_eq!(rounded.currency, "USD");
    }

    #[test]
    fn test_eq_with_tolerance() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(100.004), "USD");
        assert!(a.eq_with_tolerance(&b, dec!(0.005)));
        assert!(b.eq_with_tolerance(&a, dec!(0.005)));
        assert!(!a.eq_with_tolerance(&b, dec!(0.003)));
        let c = Amount::new(dec!(100.00), "EUR");
        assert!(!a.eq_with_tolerance(&c, dec!(1.0)));
        let d = Amount::new(dec!(100.00), "USD");
        assert!(a.eq_with_tolerance(&d, dec!(0.0)));
    }

    #[test]
    // Scenario amounts are lettered a..f; longer names would not add clarity here.
    #[allow(clippy::many_single_char_names)]
    fn test_eq_auto_tolerance() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(100.004), "USD");
        assert!(a.eq_auto_tolerance(&b));
        let c = Amount::new(dec!(100.000), "USD");
        let d = Amount::new(dec!(100.001), "USD");
        assert!(!c.eq_auto_tolerance(&d));
        let e = Amount::new(dec!(100.0004), "USD");
        assert!(c.eq_auto_tolerance(&e));
        let f = Amount::new(dec!(100.00), "EUR");
        assert!(!a.eq_auto_tolerance(&f));
    }

    #[test]
    fn test_incomplete_amount_accessors() {
        let complete = IncompleteAmount::complete(dec!(5), "EUR");
        assert!(complete.is_complete());
        assert_eq!(complete.number(), Some(dec!(5)));
        assert_eq!(complete.currency(), Some("EUR"));
        assert_eq!(complete.as_amount(), Some(&Amount::new(dec!(5), "EUR")));
        assert_eq!(complete.into_amount(), Some(Amount::new(dec!(5), "EUR")));

        let number_only = IncompleteAmount::number_only(dec!(7));
        assert!(!number_only.is_complete());
        assert_eq!(number_only.number(), Some(dec!(7)));
        assert_eq!(number_only.currency(), None);
        assert!(number_only.as_amount().is_none());

        let currency_only = IncompleteAmount::currency_only("JPY");
        assert!(!currency_only.is_complete());
        assert_eq!(currency_only.number(), None);
        assert_eq!(currency_only.currency(), Some("JPY"));
        assert!(currency_only.into_amount().is_none());
    }

    #[test]
    fn test_incomplete_amount_from_and_display() {
        assert_eq!(
            IncompleteAmount::from(Amount::new(dec!(3), "CHF")),
            IncompleteAmount::complete(dec!(3), "CHF")
        );
        assert_eq!(
            IncompleteAmount::complete(dec!(1.50), "USD").to_string(),
            "1.50 USD"
        );
        assert_eq!(IncompleteAmount::number_only(dec!(42)).to_string(), "42");
        assert_eq!(IncompleteAmount::currency_only("GBP").to_string(), "GBP");
    }
}