use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops::{Add, AddAssign, Neg, Sub, SubAssign};
use crate::Currency;
#[cfg(feature = "rkyv")]
use crate::intern::AsDecimal;
/// A quantity of one commodity or currency, such as `100.00 USD`.
///
/// The number keeps its `Decimal` scale (digits after the decimal point); that
/// scale is what [`Amount::inferred_tolerance`] derives its tolerance from.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct Amount {
/// The numeric quantity. With the `rkyv` feature enabled it is archived
/// through the crate's `AsDecimal` wrapper.
#[cfg_attr(feature = "rkyv", rkyv(with = AsDecimal))]
pub number: Decimal,
/// The commodity/currency the quantity is denominated in.
pub currency: Currency,
}
impl Amount {
    /// Creates an amount from a number and anything convertible into a [`Currency`].
    #[must_use]
    pub fn new(number: Decimal, currency: impl Into<Currency>) -> Self {
        Self {
            number,
            currency: currency.into(),
        }
    }

    /// Creates an amount of `Decimal::ZERO` in the given currency.
    #[must_use]
    pub fn zero(currency: impl Into<Currency>) -> Self {
        Self {
            number: Decimal::ZERO,
            currency: currency.into(),
        }
    }

    /// Returns `true` if the number is zero.
    #[must_use]
    pub const fn is_zero(&self) -> bool {
        self.number.is_zero()
    }

    /// Returns `true` if the number is strictly greater than zero.
    #[must_use]
    pub const fn is_positive(&self) -> bool {
        self.number.is_sign_positive() && !self.number.is_zero()
    }

    /// Returns `true` if the number is strictly less than zero.
    ///
    /// A `Decimal` zero can carry the negative sign bit (for example after
    /// negation, as `Neg for &Amount` does), so the sign bit alone is not
    /// enough: zero of either sign is neither positive nor negative, mirroring
    /// [`Amount::is_positive`].
    #[must_use]
    pub const fn is_negative(&self) -> bool {
        self.number.is_sign_negative() && !self.number.is_zero()
    }

    /// Returns a copy with the absolute value of the number and the same currency.
    #[must_use]
    pub fn abs(&self) -> Self {
        Self {
            number: self.number.abs(),
            currency: self.currency.clone(),
        }
    }

    /// Number of digits after the decimal point in the stored number.
    #[must_use]
    pub const fn scale(&self) -> u32 {
        self.number.scale()
    }

    /// Tolerance implied by this amount's precision: half a unit in its last
    /// decimal place (`100` → `0.5`, `100.00` → `0.005`, `100.000` → `0.0005`).
    ///
    /// `Decimal` holds at most 28 fractional digits and `Decimal::new` panics
    /// for larger scales, so the half-unit `5e-29` of a scale-28 amount cannot
    /// be built. Zero is returned instead: every non-zero `Decimal` has a
    /// magnitude of at least `1e-28`, so `|x| <= 0` and `|x| <= 5e-29` accept
    /// exactly the same values.
    #[must_use]
    pub fn inferred_tolerance(&self) -> Decimal {
        // Largest scale `Decimal` can represent.
        const MAX_SCALE: u32 = 28;
        let scale = self.number.scale();
        if scale < MAX_SCALE {
            Decimal::new(5, scale + 1)
        } else {
            Decimal::ZERO
        }
    }

    /// Returns `true` if `|number| <= tolerance` (inclusive).
    #[must_use]
    pub fn is_near_zero(&self, tolerance: Decimal) -> bool {
        self.number.abs() <= tolerance
    }

    /// Returns `true` if both amounts share a currency and their numbers differ
    /// by at most `tolerance` (inclusive).
    ///
    /// Uses `checked_sub` because `Decimal`'s `-` operator panics on overflow
    /// (e.g. `Decimal::MAX` vs `Decimal::MIN`); a difference too large to
    /// represent is certainly outside any tolerance, so it yields `false`.
    #[must_use]
    pub fn is_near(&self, other: &Self, tolerance: Decimal) -> bool {
        self.currency == other.currency
            && matches!(
                self.number.checked_sub(other.number),
                Some(diff) if diff.abs() <= tolerance
            )
    }

    /// Alias for [`Amount::is_near`].
    #[must_use]
    pub fn eq_with_tolerance(&self, other: &Self, tolerance: Decimal) -> bool {
        self.is_near(other, tolerance)
    }

    /// Compares two amounts using the looser of their two inferred tolerances
    /// (see [`Amount::inferred_tolerance`]). Amounts in different currencies
    /// are never equal.
    #[must_use]
    pub fn eq_auto_tolerance(&self, other: &Self) -> bool {
        let tolerance = self.inferred_tolerance().max(other.inferred_tolerance());
        self.is_near(other, tolerance)
    }

    /// Rounds the number to `dp` decimal places with `Decimal::round_dp`
    /// (banker's rounding per the rust_decimal docs), keeping the currency.
    #[must_use]
    pub fn round_dp(&self, dp: u32) -> Self {
        Self {
            number: self.number.round_dp(dp),
            currency: self.currency.clone(),
        }
    }
}
impl fmt::Display for Amount {
    /// Renders the number, one space, then the currency (e.g. `1234.56 USD`);
    /// this is the same shape `FromStr for Amount` accepts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { number, currency } = self;
        write!(f, "{} {}", number, currency)
    }
}
/// Error returned when a string cannot be parsed into an [`Amount`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AmountParseError {
    /// The complete, unmodified input that failed to parse.
    pub input: String,
    /// Which validation step rejected the input.
    pub reason: AmountParseErrorReason,
}

/// The specific reason an amount literal was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AmountParseErrorReason {
    /// The input did not split into exactly two whitespace-separated tokens.
    NotTwoTokens,
    /// The first token (carried here) is not a valid decimal number.
    InvalidNumber(String),
    /// The second token (carried here) is not a valid commodity name.
    InvalidCurrency(String),
}

impl fmt::Display for AmountParseError {
    /// Produces a message that quotes the offending input and explains the
    /// failing rule.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        type Reason = AmountParseErrorReason;
        let input = &self.input;
        match self.reason {
            Reason::NotTwoTokens => write!(
                f,
                "invalid amount literal {:?}: expected `<number> <currency>` (e.g. \"100 USD\")",
                input,
            ),
            Reason::InvalidNumber(ref tok) => write!(
                f,
                "invalid amount literal {:?}: {:?} doesn't parse as a decimal number",
                input, tok,
            ),
            Reason::InvalidCurrency(ref tok) => write!(
                f,
                "invalid amount literal {:?}: {:?} isn't a valid commodity \
                 (uppercase ASCII, may contain digits/'./_/-, max 24 chars)",
                input, tok,
            ),
        }
    }
}

impl std::error::Error for AmountParseError {}
impl std::str::FromStr for Amount {
type Err = AmountParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut iter = s.split_whitespace();
let (Some(num_tok), Some(cur_tok), None) = (iter.next(), iter.next(), iter.next()) else {
return Err(AmountParseError {
input: s.to_string(),
reason: AmountParseErrorReason::NotTwoTokens,
});
};
let number = Decimal::from_str_exact(num_tok).map_err(|_| AmountParseError {
input: s.to_string(),
reason: AmountParseErrorReason::InvalidNumber(num_tok.to_string()),
})?;
if !is_valid_commodity(cur_tok) {
return Err(AmountParseError {
input: s.to_string(),
reason: AmountParseErrorReason::InvalidCurrency(cur_tok.to_string()),
});
}
Ok(Self::new(number, cur_tok))
}
}
/// Returns `true` if `s` is an acceptable commodity name.
///
/// The rules are: 1 to 24 bytes long; the first character is an uppercase
/// ASCII letter; every later character is an uppercase ASCII letter, an ASCII
/// digit, or one of `'`, `.`, `_`, `-`.
///
/// Checking bytes instead of chars is equivalent here. Every accepted
/// character is single-byte ASCII, and every byte of a non-ASCII character is
/// `>= 0x80`, so it fails all of the predicates.
fn is_valid_commodity(s: &str) -> bool {
    const MAX_LEN: usize = 24;
    let allowed_after_first = |b: u8| {
        b.is_ascii_uppercase() || b.is_ascii_digit() || matches!(b, b'\'' | b'.' | b'_' | b'-')
    };
    match s.as_bytes() {
        [head, rest @ ..] if s.len() <= MAX_LEN && head.is_ascii_uppercase() => {
            rest.iter().all(|&b| allowed_after_first(b))
        }
        _ => false,
    }
}
impl Add for &Amount {
    type Output = Amount;

    /// Sums two amounts, keeping the left operand's currency.
    ///
    /// Mismatched currencies trip a `debug_assert_eq!` in debug builds.
    /// Release builds do not check, and the result takes `self`'s currency.
    fn add(self, other: &Amount) -> Amount {
        debug_assert_eq!(
            self.currency, other.currency,
            "Cannot add amounts with different currencies"
        );
        let total = self.number + other.number;
        Amount {
            currency: self.currency.clone(),
            number: total,
        }
    }
}
impl Sub for &Amount {
    type Output = Amount;

    /// Subtracts `other` from `self`, keeping the left operand's currency.
    ///
    /// Mismatched currencies trip a `debug_assert_eq!` in debug builds.
    /// Release builds do not check, and the result takes `self`'s currency.
    fn sub(self, other: &Amount) -> Amount {
        debug_assert_eq!(
            self.currency, other.currency,
            "Cannot subtract amounts with different currencies"
        );
        let difference = self.number - other.number;
        Amount {
            currency: self.currency.clone(),
            number: difference,
        }
    }
}
impl Neg for &Amount {
    type Output = Amount;

    /// Returns an amount with the number's sign flipped and the same currency.
    fn neg(self) -> Amount {
        let flipped = -self.number;
        Amount {
            currency: self.currency.clone(),
            number: flipped,
        }
    }
}
impl Add for Amount {
    type Output = Self;

    /// By-value addition; forwards to the `&Amount + &Amount` implementation.
    fn add(self, other: Self) -> Self {
        Add::add(&self, &other)
    }
}
impl Sub for Amount {
    type Output = Self;

    /// By-value subtraction; forwards to the `&Amount - &Amount` implementation.
    fn sub(self, other: Self) -> Self {
        Sub::sub(&self, &other)
    }
}
impl Neg for Amount {
    type Output = Self;

    /// By-value negation; forwards to the `-&Amount` implementation.
    fn neg(self) -> Self {
        Neg::neg(&self)
    }
}
impl AddAssign<&Self> for Amount {
    /// Adds `other`'s number into `self` in place; the currency is unchanged.
    ///
    /// Mismatched currencies trip a `debug_assert_eq!` in debug builds only.
    fn add_assign(&mut self, other: &Self) {
        let Self { number, currency } = self;
        debug_assert_eq!(
            *currency, other.currency,
            "Cannot add amounts with different currencies"
        );
        *number += other.number;
    }
}
impl SubAssign<&Self> for Amount {
    /// Subtracts `other`'s number from `self` in place; the currency is unchanged.
    ///
    /// Mismatched currencies trip a `debug_assert_eq!` in debug builds only.
    fn sub_assign(&mut self, other: &Self) {
        let Self { number, currency } = self;
        debug_assert_eq!(
            *currency, other.currency,
            "Cannot subtract amounts with different currencies"
        );
        *number -= other.number;
    }
}
/// An amount whose number, currency, or both are known.
///
/// The accessor methods on this type return `None` for whichever part is
/// absent. NOTE(review): presumably the missing part is inferred later (e.g.
/// during interpolation/booking); confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum IncompleteAmount {
/// Both the number and the currency are present.
Complete(Amount),
/// Only the number is present; the currency is missing. With the `rkyv`
/// feature enabled the number is archived through `AsDecimal`.
NumberOnly(#[cfg_attr(feature = "rkyv", rkyv(with = AsDecimal))] Decimal),
/// Only the currency is present; the number is missing.
CurrencyOnly(Currency),
}
impl IncompleteAmount {
#[must_use]
pub fn complete(number: Decimal, currency: impl Into<Currency>) -> Self {
Self::Complete(Amount::new(number, currency))
}
#[must_use]
pub const fn number_only(number: Decimal) -> Self {
Self::NumberOnly(number)
}
#[must_use]
pub fn currency_only(currency: impl Into<Currency>) -> Self {
Self::CurrencyOnly(currency.into())
}
#[must_use]
pub const fn number(&self) -> Option<Decimal> {
match self {
Self::Complete(a) => Some(a.number),
Self::NumberOnly(n) => Some(*n),
Self::CurrencyOnly(_) => None,
}
}
#[must_use]
pub fn currency(&self) -> Option<&str> {
match self {
Self::Complete(a) => Some(&a.currency),
Self::NumberOnly(_) => None,
Self::CurrencyOnly(c) => Some(c),
}
}
#[must_use]
pub const fn is_complete(&self) -> bool {
matches!(self, Self::Complete(_))
}
#[must_use]
pub const fn as_amount(&self) -> Option<&Amount> {
match self {
Self::Complete(a) => Some(a),
_ => None,
}
}
#[must_use]
pub fn into_amount(self) -> Option<Amount> {
match self {
Self::Complete(a) => Some(a),
_ => None,
}
}
}
impl From<Amount> for IncompleteAmount {
fn from(amount: Amount) -> Self {
Self::Complete(amount)
}
}
impl fmt::Display for IncompleteAmount {
    /// Shows whatever parts are present: `Complete` renders like [`Amount`],
    /// and the partial variants render just their number or just their currency.
    /// Outer formatter flags (width, precision) are not forwarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown: &dyn fmt::Display = match self {
            Self::Complete(amount) => amount,
            Self::NumberOnly(number) => number,
            Self::CurrencyOnly(currency) => currency,
        };
        write!(f, "{shown}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;
    use std::str::FromStr;

    #[test]
    fn test_new() {
        let amount = Amount::new(dec!(100.00), "USD");
        assert_eq!(amount.number, dec!(100.00));
        assert_eq!(amount.currency, "USD");
    }

    #[test]
    fn test_zero() {
        let amount = Amount::zero("EUR");
        assert!(amount.is_zero());
        assert_eq!(amount.currency, "EUR");
    }

    #[test]
    fn test_is_positive_negative() {
        let pos = Amount::new(dec!(100), "USD");
        let neg = Amount::new(dec!(-100), "USD");
        let zero = Amount::zero("USD");
        assert!(pos.is_positive());
        assert!(!pos.is_negative());
        assert!(!neg.is_positive());
        assert!(neg.is_negative());
        assert!(!zero.is_positive());
        assert!(!zero.is_negative());
    }

    #[test]
    fn test_add() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(50.00), "USD");
        let sum = &a + &b;
        assert_eq!(sum.number, dec!(150.00));
        assert_eq!(sum.currency, "USD");
    }

    #[test]
    fn test_sub() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(50.00), "USD");
        let diff = &a - &b;
        assert_eq!(diff.number, dec!(50.00));
    }

    #[test]
    fn test_neg() {
        let a = Amount::new(dec!(100.00), "USD");
        let neg_a = -&a;
        assert_eq!(neg_a.number, dec!(-100.00));
    }

    #[test]
    fn test_add_assign() {
        let mut a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(50.00), "USD");
        a += &b;
        assert_eq!(a.number, dec!(150.00));
    }

    #[test]
    fn test_inferred_tolerance() {
        let a = Amount::new(dec!(100), "USD");
        assert_eq!(a.inferred_tolerance(), dec!(0.5));
        let b = Amount::new(dec!(100.00), "USD");
        assert_eq!(b.inferred_tolerance(), dec!(0.005));
        let c = Amount::new(dec!(100.000), "USD");
        assert_eq!(c.inferred_tolerance(), dec!(0.0005));
    }

    #[test]
    fn test_is_near_zero() {
        let a = Amount::new(dec!(0.004), "USD");
        assert!(a.is_near_zero(dec!(0.005)));
        assert!(!a.is_near_zero(dec!(0.003)));
    }

    #[test]
    fn test_is_near() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(100.004), "USD");
        assert!(a.is_near(&b, dec!(0.005)));
        assert!(!a.is_near(&b, dec!(0.003)));
        let c = Amount::new(dec!(100.00), "EUR");
        assert!(!a.is_near(&c, dec!(1.0)));
    }

    #[test]
    fn test_display() {
        let a = Amount::new(dec!(1234.56), "USD");
        assert_eq!(format!("{a}"), "1234.56 USD");
    }

    #[test]
    fn test_abs() {
        let neg = Amount::new(dec!(-100.00), "USD");
        let abs = neg.abs();
        assert_eq!(abs.number, dec!(100.00));
    }

    #[test]
    fn test_eq_with_tolerance() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(100.004), "USD");
        assert!(a.eq_with_tolerance(&b, dec!(0.005)));
        assert!(b.eq_with_tolerance(&a, dec!(0.005)));
        assert!(!a.eq_with_tolerance(&b, dec!(0.003)));
        let c = Amount::new(dec!(100.00), "EUR");
        assert!(!a.eq_with_tolerance(&c, dec!(1.0)));
        let d = Amount::new(dec!(100.00), "USD");
        assert!(a.eq_with_tolerance(&d, dec!(0.0)));
    }

    #[test]
    // Short fixture names keep this table-like test compact and readable.
    #[allow(clippy::many_single_char_names)]
    fn test_eq_auto_tolerance() {
        let a = Amount::new(dec!(100.00), "USD");
        let b = Amount::new(dec!(100.004), "USD");
        assert!(a.eq_auto_tolerance(&b));
        let c = Amount::new(dec!(100.000), "USD");
        let d = Amount::new(dec!(100.001), "USD");
        assert!(!c.eq_auto_tolerance(&d));
        let e = Amount::new(dec!(100.0004), "USD");
        assert!(c.eq_auto_tolerance(&e));
        let f = Amount::new(dec!(100.00), "EUR");
        assert!(!a.eq_auto_tolerance(&f));
    }

    #[test]
    fn amount_from_str_round_trips_display() {
        for amt in [
            Amount::new(dec!(100), "USD"),
            Amount::new(dec!(-50.25), "EUR"),
            Amount::new(dec!(0), "GBP"),
            Amount::new(dec!(1234567.89), "JPY"),
            Amount::new(dec!(0.0001), "USD"),
        ] {
            let displayed = amt.to_string();
            assert_eq!(
                Amount::from_str(&displayed),
                Ok(amt.clone()),
                "round-trip lost data: Display produced {displayed:?}"
            );
        }
    }

    #[test]
    fn amount_from_str_accepts_canonical_forms() {
        assert_eq!(
            Amount::from_str("100 USD"),
            Ok(Amount::new(dec!(100), "USD"))
        );
        assert_eq!(
            Amount::from_str("-50.25 EUR"),
            Ok(Amount::new(dec!(-50.25), "EUR"))
        );
        assert_eq!(
            Amount::from_str("  100   USD  "),
            Ok(Amount::new(dec!(100), "USD"))
        );
        assert_eq!(Amount::from_str("1 X"), Ok(Amount::new(dec!(1), "X")));
        assert_eq!(
            Amount::from_str("100 RY-2024"),
            Ok(Amount::new(dec!(100), "RY-2024"))
        );
        assert_eq!(
            Amount::from_str("1 A'B._-9"),
            Ok(Amount::new(dec!(1), "A'B._-9"))
        );
    }

    #[test]
    fn amount_from_str_rejects_currency_first() {
        let err = Amount::from_str("USD 100").expect_err("currency-first must reject");
        assert!(matches!(
            err.reason,
            AmountParseErrorReason::InvalidNumber(_)
        ));
    }

    #[test]
    fn amount_from_str_rejects_single_token() {
        for s in ["", "   ", "100", "USD"] {
            let err = Amount::from_str(s).expect_err("single token must reject");
            assert!(
                matches!(err.reason, AmountParseErrorReason::NotTwoTokens),
                "expected NotTwoTokens for {s:?}, got {:?}",
                err.reason
            );
        }
    }

    #[test]
    fn amount_from_str_rejects_extra_tokens() {
        let err = Amount::from_str("100 USD extra").expect_err("trailing token must reject");
        assert!(matches!(err.reason, AmountParseErrorReason::NotTwoTokens));
    }

    #[test]
    fn amount_from_str_rejects_scientific_notation() {
        let err = Amount::from_str("1e2 USD").expect_err("scientific must reject");
        assert!(matches!(
            err.reason,
            AmountParseErrorReason::InvalidNumber(_)
        ));
    }

    #[test]
    fn amount_from_str_rejects_thousands_separator() {
        let err = Amount::from_str("1,000 USD").expect_err("thousands sep must reject");
        assert!(matches!(
            err.reason,
            AmountParseErrorReason::InvalidNumber(_)
        ));
    }

    #[test]
    fn amount_from_str_rejects_lowercase_currency() {
        let err = Amount::from_str("100 usd").expect_err("lowercase commodity must reject");
        assert!(matches!(
            err.reason,
            AmountParseErrorReason::InvalidCurrency(_)
        ));
    }

    #[test]
    fn amount_from_str_rejects_currency_starting_with_digit() {
        let err = Amount::from_str("100 1USD").expect_err("digit-first commodity must reject");
        assert!(matches!(
            err.reason,
            AmountParseErrorReason::InvalidCurrency(_)
        ));
    }

    #[test]
    fn amount_from_str_enforces_commodity_length_limit() {
        let longest = "A".repeat(24);
        assert_eq!(
            Amount::from_str(&format!("1 {longest}")),
            Ok(Amount::new(dec!(1), longest.as_str()))
        );
        let too_long = "A".repeat(25);
        let err = Amount::from_str(&format!("1 {too_long}"))
            .expect_err("25-char commodity must reject");
        assert!(matches!(
            err.reason,
            AmountParseErrorReason::InvalidCurrency(_)
        ));
    }

    #[test]
    fn amount_from_str_error_message_names_input() {
        let err = Amount::from_str("oopsie daisy").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("oopsie daisy"), "error must echo input: {msg}");
        assert!(msg.contains("doesn't parse"), "error must explain: {msg}");
    }

    #[test]
    fn incomplete_amount_complete_accessors() {
        let complete = IncompleteAmount::complete(dec!(1.50), "USD");
        assert!(complete.is_complete());
        assert_eq!(complete.number(), Some(dec!(1.50)));
        assert_eq!(complete.currency(), Some("USD"));
        assert_eq!(
            complete.as_amount(),
            Some(&Amount::new(dec!(1.50), "USD"))
        );
        assert_eq!(complete.to_string(), "1.50 USD");
        assert_eq!(
            complete.into_amount(),
            Some(Amount::new(dec!(1.50), "USD"))
        );
    }

    #[test]
    fn incomplete_amount_partial_accessors() {
        let number_only = IncompleteAmount::number_only(dec!(5));
        assert!(!number_only.is_complete());
        assert_eq!(number_only.number(), Some(dec!(5)));
        assert_eq!(number_only.currency(), None);
        assert_eq!(number_only.as_amount(), None);
        assert_eq!(number_only.to_string(), "5");
        assert_eq!(number_only.into_amount(), None);

        let currency_only = IncompleteAmount::currency_only("EUR");
        assert!(!currency_only.is_complete());
        assert_eq!(currency_only.number(), None);
        assert_eq!(currency_only.currency(), Some("EUR"));
        assert_eq!(currency_only.as_amount(), None);
        assert_eq!(currency_only.to_string(), "EUR");
        assert_eq!(currency_only.into_amount(), None);
    }

    #[test]
    fn incomplete_amount_from_amount_is_complete() {
        let amt = Amount::new(dec!(2), "GBP");
        assert_eq!(
            IncompleteAmount::from(amt.clone()),
            IncompleteAmount::Complete(amt)
        );
    }
}