use crate::error::Result;
use serde::{Deserialize, Serialize};
use std::{
fmt,
ops::{Add, Sub},
};
/// Currencies a [`Money`] value can be denominated in.
///
/// Each variant fixes how many decimal places separate one whole unit from the
/// smallest stored unit; see [`Currency::decimals`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Currency {
    /// US dollar; stored in cents (2 decimals).
    USD,
    /// USDC; stored in base units of 10^-6 USDC (6 decimals).
    USDC,
    /// SOL; stored in lamports (9 decimals).
    SOL,
}
impl Currency {
pub fn decimals(&self) -> u8 {
match self {
Currency::USD => 2,
Currency::USDC => 6, Currency::SOL => 9, }
}
pub fn symbol(&self) -> &'static str {
match self {
Currency::USD => "$",
Currency::USDC => "USDC",
Currency::SOL => "SOL",
}
}
}
/// Renders a currency as its [`Currency::symbol`].
impl fmt::Display for Currency {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.symbol())
    }
}
/// A monetary amount stored as an integer count of its currency's smallest
/// unit, so arithmetic on it is exact.
///
/// Equality and hashing take both the amount and the currency into account:
/// 100 cents and 100 lamports are different values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Money {
    /// Amount in the smallest unit of `currency` (cents, USDC base units, lamports).
    amount: u64,
    /// Currency the amount is denominated in.
    currency: Currency,
}
impl Money {
    /// Creates an amount from a raw count of `currency`'s smallest unit
    /// (e.g. cents for USD, lamports for SOL).
    pub fn new(amount: u64, currency: Currency) -> Self {
        Self { amount, currency }
    }

    /// Creates an amount from a decimal value in whole units (e.g. `1.50`
    /// dollars). The value is scaled by `10^currency.decimals()` and rounded to
    /// the nearest smallest unit.
    ///
    /// The conversion goes through `f64`, so inputs that are not exactly
    /// representable in binary (such as `1.005`) may round to an adjacent unit.
    /// The final float-to-integer `as` cast saturates: negative inputs and NaN
    /// yield `0`, and values too large for `u64` yield `u64::MAX`.
    pub fn from_decimal(decimal: f64, currency: Currency) -> Self {
        let multiplier = 10_u64.pow(u32::from(currency.decimals()));
        let amount = (decimal * multiplier as f64).round() as u64;
        Self { amount, currency }
    }

    /// Returns a zero amount in `currency`.
    pub fn zero(currency: Currency) -> Self {
        Self {
            amount: 0,
            currency,
        }
    }

    /// Creates a USD amount from a number of cents.
    pub fn usd_cents(cents: u64) -> Self {
        Self::new(cents, Currency::USD)
    }

    /// Creates a USD amount from a decimal dollar value; see [`Money::from_decimal`].
    pub fn usd(dollars: f64) -> Self {
        Self::from_decimal(dollars, Currency::USD)
    }

    /// Creates a USDC amount from raw base units (10^-6 USDC each).
    pub fn usdc(amount: u64) -> Self {
        Self::new(amount, Currency::USDC)
    }

    /// Creates a USDC amount from a decimal value; see [`Money::from_decimal`].
    pub fn usdc_decimal(amount: f64) -> Self {
        Self::from_decimal(amount, Currency::USDC)
    }

    /// Creates a SOL amount from a number of lamports (10^-9 SOL each).
    pub fn lamports(lamports: u64) -> Self {
        Self::new(lamports, Currency::SOL)
    }

    /// Creates a SOL amount from a decimal value; see [`Money::from_decimal`].
    pub fn sol(amount: f64) -> Self {
        Self::from_decimal(amount, Currency::SOL)
    }

    /// Raw amount in the currency's smallest unit.
    pub fn amount(&self) -> u64 {
        self.amount
    }

    /// Currency this amount is denominated in.
    pub fn currency(&self) -> Currency {
        self.currency
    }

    /// Converts to a decimal value in whole units (e.g. 150 cents -> `1.5`).
    ///
    /// This is lossy: `f64` represents integers exactly only up to 2^53, so very
    /// large amounts are approximated. Use [`Money::amount`] for exact arithmetic.
    pub fn as_decimal(&self) -> f64 {
        let divisor = 10_u64.pow(u32::from(self.currency.decimals()));
        self.amount as f64 / divisor as f64
    }

    /// Returns `true` if the amount is zero.
    pub fn is_zero(&self) -> bool {
        self.amount == 0
    }

    /// Returns `true` if the amount is greater than zero. Amounts are unsigned,
    /// so this is always the negation of [`Money::is_zero`].
    pub fn is_positive(&self) -> bool {
        self.amount > 0
    }

    /// Checks that `self` is at least `minimum`.
    ///
    /// # Errors
    ///
    /// - `AllSourceError::InvalidInput` if the two amounts use different currencies.
    /// - `AllSourceError::ValidationError` if `self` is smaller than `minimum`.
    pub fn at_least(&self, minimum: &Money) -> Result<()> {
        if self.currency != minimum.currency {
            return Err(crate::error::AllSourceError::InvalidInput(format!(
                "Cannot compare {} with {}",
                self.currency, minimum.currency
            )));
        }
        if self.amount < minimum.amount {
            return Err(crate::error::AllSourceError::ValidationError(format!(
                "Amount {} is less than minimum {}",
                self.as_decimal(),
                minimum.as_decimal()
            )));
        }
        Ok(())
    }

    /// Returns `percent`% of this amount in the same currency, rounded down to
    /// the smallest unit.
    ///
    /// The intermediate product is computed in 128-bit arithmetic, so large
    /// amounts cannot overflow. A result that exceeds `u64::MAX` (only possible
    /// when `percent > 100`) saturates at `u64::MAX`.
    pub fn percentage(&self, percent: u64) -> Self {
        Self {
            amount: Self::percent_of(self.amount, percent),
            currency: self.currency,
        }
    }

    /// Returns this amount minus `percent`% of itself (e.g. after deducting a
    /// fee). The deducted part is rounded down, and the result is clamped at
    /// zero when `percent > 100`.
    pub fn subtract_percentage(&self, percent: u64) -> Self {
        let fee = Self::percent_of(self.amount, percent);
        Self {
            amount: self.amount.saturating_sub(fee),
            currency: self.currency,
        }
    }

    /// Computes `floor(amount * percent / 100)` without intermediate overflow,
    /// saturating at `u64::MAX`.
    fn percent_of(amount: u64, percent: u64) -> u64 {
        // (2^64 - 1)^2 < 2^128, so the product always fits in u128.
        let scaled = u128::from(amount) * u128::from(percent) / 100;
        // After clamping to u64::MAX the cast cannot truncate.
        scaled.min(u128::from(u64::MAX)) as u64
    }
}
/// Human-readable rendering of an amount.
///
/// USD is shown as `$1.50`. USDC is shown with two decimals and SOL with four,
/// each followed by its ticker. The decimal value is rounded for display, so
/// this is not a lossless representation of the stored amount.
impl fmt::Display for Money {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.as_decimal();
        match self.currency {
            Currency::USD => write!(f, "${:.2}", value),
            Currency::USDC => write!(f, "{:.2} USDC", value),
            Currency::SOL => write!(f, "{:.4} SOL", value),
        }
    }
}
impl Add for Money {
type Output = Result<Money>;
fn add(self, other: Money) -> Self::Output {
if self.currency != other.currency {
return Err(crate::error::AllSourceError::InvalidInput(format!(
"Cannot add {} to {}",
self.currency, other.currency
)));
}
Ok(Money {
amount: self.amount + other.amount,
currency: self.currency,
})
}
}
impl Sub for Money {
type Output = Result<Money>;
fn sub(self, other: Money) -> Self::Output {
if self.currency != other.currency {
return Err(crate::error::AllSourceError::InvalidInput(format!(
"Cannot subtract {} from {}",
other.currency, self.currency
)));
}
Ok(Money {
amount: self.amount.saturating_sub(other.amount),
currency: self.currency,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- construction ------------------------------------------------------

    #[test]
    fn test_create_money_from_smallest_unit() {
        let fifty_cents = Money::new(50, Currency::USD);
        assert_eq!(
            (fifty_cents.amount(), fifty_cents.currency()),
            (50, Currency::USD)
        );
    }

    #[test]
    fn test_create_money_from_decimal() {
        let half_dollar = Money::from_decimal(0.50, Currency::USD);
        assert_eq!(half_dollar.amount(), 50);
        assert_eq!(half_dollar.currency(), Currency::USD);

        // One whole USDC is 10^6 base units.
        let one_usdc = Money::from_decimal(1.00, Currency::USDC);
        assert_eq!(one_usdc.amount(), 1_000_000);
    }

    #[test]
    fn test_usd_helpers() {
        let from_cents = Money::usd_cents(150);
        assert_eq!(from_cents.amount(), 150);
        assert_eq!(from_cents.as_decimal(), 1.50);

        let from_dollars = Money::usd(1.50);
        assert_eq!(from_dollars.amount(), 150);
    }

    #[test]
    fn test_usdc_helpers() {
        let one = Money::usdc_decimal(1.0);
        assert_eq!(one.amount(), 1_000_000);
        assert_eq!(one.currency(), Currency::USDC);
    }

    #[test]
    fn test_sol_helpers() {
        // 1 SOL == 10^9 lamports, checked in both directions.
        let one_sol = Money::sol(1.0);
        assert_eq!(one_sol.amount(), 1_000_000_000);
        assert_eq!(one_sol.currency(), Currency::SOL);

        assert_eq!(Money::lamports(1_000_000_000).as_decimal(), 1.0);
    }

    // ---- predicates --------------------------------------------------------

    #[test]
    fn test_zero() {
        let nothing = Money::zero(Currency::USD);
        assert!(nothing.is_zero());
        assert!(!nothing.is_positive());
    }

    #[test]
    fn test_is_positive() {
        let one_dollar = Money::usd_cents(100);
        assert!(one_dollar.is_positive());
        assert!(!one_dollar.is_zero());
    }

    #[test]
    fn test_at_least() {
        let floor = Money::usd_cents(50);
        assert!(Money::usd_cents(100).at_least(&floor).is_ok());
        assert!(Money::usd_cents(25).at_least(&floor).is_err());
    }

    #[test]
    fn test_at_least_different_currency_fails() {
        let outcome = Money::usd_cents(100).at_least(&Money::lamports(100));
        assert!(outcome.is_err());
    }

    // ---- percentages -------------------------------------------------------

    #[test]
    fn test_percentage() {
        // 7% of $10.00 is $0.70.
        let ten_dollars = Money::usd_cents(1000);
        assert_eq!(ten_dollars.percentage(7).amount(), 70);
    }

    #[test]
    fn test_subtract_percentage() {
        // $10.00 minus a 7% fee leaves $9.30.
        let ten_dollars = Money::usd_cents(1000);
        assert_eq!(ten_dollars.subtract_percentage(7).amount(), 930);
    }

    // ---- arithmetic --------------------------------------------------------

    #[test]
    fn test_add_same_currency() {
        let total = (Money::usd_cents(100) + Money::usd_cents(50)).unwrap();
        assert_eq!(total.amount(), 150);
    }

    #[test]
    fn test_add_different_currency_fails() {
        let mixed = Money::usd_cents(100) + Money::lamports(100);
        assert!(mixed.is_err());
    }

    #[test]
    fn test_sub_same_currency() {
        let difference = (Money::usd_cents(100) - Money::usd_cents(50)).unwrap();
        assert_eq!(difference.amount(), 50);
    }

    #[test]
    fn test_sub_saturating() {
        // Taking away more than is available clamps to zero instead of underflowing.
        let difference = (Money::usd_cents(50) - Money::usd_cents(100)).unwrap();
        assert_eq!(difference.amount(), 0);
    }

    #[test]
    fn test_sub_different_currency_fails() {
        let mixed = Money::usd_cents(100) - Money::lamports(100);
        assert!(mixed.is_err());
    }

    // ---- formatting --------------------------------------------------------

    #[test]
    fn test_display_usd() {
        assert_eq!(Money::usd_cents(150).to_string(), "$1.50");
    }

    #[test]
    fn test_display_usdc() {
        assert_eq!(Money::usdc_decimal(1.50).to_string(), "1.50 USDC");
    }

    #[test]
    fn test_display_sol() {
        assert_eq!(Money::sol(0.001).to_string(), "0.0010 SOL");
    }

    // ---- currency metadata -------------------------------------------------

    #[test]
    fn test_currency_decimals() {
        let expected = [(Currency::USD, 2), (Currency::USDC, 6), (Currency::SOL, 9)];
        for (currency, decimals) in expected {
            assert_eq!(currency.decimals(), decimals);
        }
    }

    #[test]
    fn test_currency_symbol() {
        let expected = [
            (Currency::USD, "$"),
            (Currency::USDC, "USDC"),
            (Currency::SOL, "SOL"),
        ];
        for (currency, symbol) in expected {
            assert_eq!(currency.symbol(), symbol);
        }
    }

    // ---- derived traits ----------------------------------------------------

    #[test]
    fn test_equality() {
        let hundred_cents = Money::usd_cents(100);
        assert_eq!(hundred_cents, Money::usd_cents(100));
        assert_ne!(hundred_cents, Money::usd_cents(200));
        // Same raw amount in a different currency must not compare equal.
        assert_ne!(hundred_cents, Money::new(100, Currency::SOL));
    }

    #[test]
    fn test_cloning() {
        let original = Money::usd_cents(100);
        // `Money` is `Copy`, so `original` remains usable after this move.
        let copy = original;
        assert_eq!(original, copy);
    }

    #[test]
    fn test_hash_consistency() {
        use std::collections::HashSet;

        let set: HashSet<Money> = std::iter::once(Money::usd_cents(100)).collect();
        assert!(set.contains(&Money::usd_cents(100)));
    }

    #[test]
    fn test_serde_serialization() {
        let original = Money::usd_cents(150);
        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: Money = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_as_decimal_precision() {
        assert!((Money::usd_cents(123).as_decimal() - 1.23).abs() < f64::EPSILON);
        assert!((Money::usdc(1_234_567).as_decimal() - 1.234567).abs() < 0.000001);
        assert!(
            (Money::lamports(1_234_567_890).as_decimal() - 1.23456789).abs() < 0.000000001
        );
    }
}