use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use std::str::FromStr;
use super::error::{OtomlError, Result};
/// Number of micro-units that make up one whole unit of any asset (10^6).
pub const MICRO_UNITS: u64 = 10_u64.pow(6);

/// A non-negative amount of a single asset, held as an integer count of
/// micro-units so that arithmetic never goes through floating point.
///
/// Values produced by the public constructors carry an asset code that has been
/// trimmed, lowercased and checked to be 1–5 characters of `a-z` / `0-9`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OCur {
    /// Normalized asset code, e.g. `"usd"` or `"btc"`.
    asset: String,
    /// Amount in micro-units; one whole unit equals [`MICRO_UNITS`].
    amount: u64,
}
impl OCur {
    /// Creates an amount of `asset` from a raw count of micro-units
    /// (one whole unit is [`MICRO_UNITS`] micro-units).
    ///
    /// The asset code is normalized and checked by `validate_asset`.
    ///
    /// # Errors
    /// Returns `OtomlError::InvalidCurrency` if the asset code is invalid.
    pub fn new(asset: &str, micro_units: u64) -> Result<Self> {
        let asset = validate_asset(asset)?;
        Ok(OCur {
            asset,
            amount: micro_units,
        })
    }

    /// Creates an amount from a fractional number of whole units, rounding to
    /// the nearest micro-unit.
    ///
    /// # Errors
    /// Returns `OtomlError::InvalidCurrency` if `units` is negative, NaN,
    /// infinite, or too large to be represented as `u64` micro-units, or if the
    /// asset code is invalid.
    pub fn from_units(asset: &str, units: f64) -> Result<Self> {
        if units < 0.0 {
            return Err(OtomlError::InvalidCurrency(
                "amount cannot be negative".to_string(),
            ));
        }
        // A bare `as u64` cast would silently map NaN to 0 and +inf to u64::MAX.
        if !units.is_finite() {
            return Err(OtomlError::InvalidCurrency(
                "amount must be a finite number".to_string(),
            ));
        }
        let scaled = (units * MICRO_UNITS as f64).round();
        // `u64::MAX as f64` rounds up to exactly 2^64, the smallest value that
        // does not fit; anything at or above it would otherwise saturate.
        if scaled >= u64::MAX as f64 {
            return Err(OtomlError::InvalidCurrency(
                "amount too large".to_string(),
            ));
        }
        Self::new(asset, scaled as u64)
    }

    /// Returns the normalized asset code.
    pub fn asset(&self) -> &str {
        &self.asset
    }

    /// Returns the exact amount in micro-units.
    pub fn micro_units(&self) -> u64 {
        self.amount
    }

    /// Returns the amount in whole units as `f64`.
    ///
    /// Amounts above 2^53 micro-units cannot be represented exactly and lose
    /// precision in this conversion.
    pub fn as_units(&self) -> f64 {
        self.amount as f64 / MICRO_UNITS as f64
    }

    /// Creates a zero amount of `asset`.
    ///
    /// # Errors
    /// Returns `OtomlError::InvalidCurrency` if the asset code is invalid.
    pub fn zero(asset: &str) -> Result<Self> {
        Self::new(asset, 0)
    }

    /// Returns `true` if the amount is exactly zero.
    pub fn is_zero(&self) -> bool {
        self.amount == 0
    }

    /// Returns an error mentioning `verb` unless both amounts share an asset code.
    fn ensure_same_asset(&self, other: &OCur, verb: &str) -> Result<()> {
        if self.asset == other.asset {
            Ok(())
        } else {
            Err(OtomlError::InvalidCurrency(format!(
                "cannot {} {} and {} (different assets)",
                verb, self.asset, other.asset
            )))
        }
    }

    /// Adds two amounts of the same asset.
    ///
    /// # Errors
    /// Returns `OtomlError::InvalidCurrency` if the assets differ or the sum
    /// overflows `u64`.
    pub fn add(&self, other: &OCur) -> Result<OCur> {
        self.ensure_same_asset(other, "add")?;
        let sum = self
            .amount
            .checked_add(other.amount)
            .ok_or_else(|| OtomlError::InvalidCurrency("overflow".to_string()))?;
        Ok(OCur {
            asset: self.asset.clone(),
            amount: sum,
        })
    }

    /// Subtracts `other` from `self` for amounts of the same asset.
    ///
    /// # Errors
    /// Returns `OtomlError::InvalidCurrency` if the assets differ or the result
    /// would be negative.
    pub fn sub(&self, other: &OCur) -> Result<OCur> {
        self.ensure_same_asset(other, "subtract")?;
        let diff = self
            .amount
            .checked_sub(other.amount)
            .ok_or_else(|| OtomlError::InvalidCurrency("underflow".to_string()))?;
        Ok(OCur {
            asset: self.asset.clone(),
            amount: diff,
        })
    }
}
impl Default for OCur {
fn default() -> Self {
OCur {
asset: "usd".to_string(),
amount: 0,
}
}
}
impl fmt::Display for OCur {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[\"{}\", {}]", self.asset, self.amount)
}
}
impl FromStr for OCur {
    type Err = OtomlError;

    /// Parses the `["asset", amount]` text form produced by `Display`.
    ///
    /// Whitespace around the brackets, the asset and the amount is ignored. The
    /// amount must be a non-negative integer count of micro-units.
    ///
    /// # Errors
    /// Returns `OtomlError::InvalidCurrency` for malformed input (missing brackets
    /// or comma, unquoted asset, non-integer amount) and for invalid asset codes.
    /// Malformed input such as `[", 5]` is reported as an error, never a panic.
    fn from_str(s: &str) -> Result<Self> {
        let s = s.trim();
        let format_error = || {
            OtomlError::InvalidCurrency(format!(
                "invalid format, expected [\"asset\", amount], got '{}'",
                s
            ))
        };

        let inner = s
            .strip_prefix('[')
            .and_then(|rest| rest.strip_suffix(']'))
            .ok_or_else(format_error)?;
        // Split on the first comma only; a valid asset code never contains one.
        let (asset_part, amount_part) = inner.split_once(',').ok_or_else(format_error)?;
        let asset_part = asset_part.trim();
        let amount_part = amount_part.trim();

        // strip_prefix/strip_suffix instead of slicing `[1..len - 1]`: a lone `"`
        // satisfies both starts_with and ends_with, and the slice would panic.
        let asset = asset_part
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .ok_or_else(|| {
                OtomlError::InvalidCurrency(format!(
                    "asset must be a quoted string, got '{}'",
                    asset_part
                ))
            })?;
        let amount: u64 = amount_part.parse().map_err(|_| {
            OtomlError::InvalidCurrency(format!("invalid amount '{}'", amount_part))
        })?;
        OCur::new(asset, amount)
    }
}
/// Normalizes and validates an asset code.
///
/// Surrounding whitespace is trimmed and ASCII letters are lowercased. The result
/// must be 1–5 characters drawn from `a-z` and `0-9`.
///
/// # Errors
/// Returns `OtomlError::InvalidCurrency` if the code is empty, contains any other
/// character, or is longer than 5 characters.
fn validate_asset(asset: &str) -> Result<String> {
    // ASCII-only lowercasing: Unicode `to_lowercase` maps some non-ASCII characters
    // onto ASCII letters (e.g. KELVIN SIGN U+212A -> 'k'), which would let them
    // slip through the character check below.
    let asset = asset.trim().to_ascii_lowercase();
    if asset.is_empty() {
        return Err(OtomlError::InvalidCurrency(
            "asset code cannot be empty".to_string(),
        ));
    }
    if let Some(c) = asset
        .chars()
        .find(|c| !c.is_ascii_lowercase() && !c.is_ascii_digit())
    {
        return Err(OtomlError::InvalidCurrency(format!(
            "asset code '{}' contains invalid character '{}' (only a-z and 0-9 allowed)",
            asset, c
        )));
    }
    // Every character is ASCII at this point, so byte length equals character count.
    if asset.len() > 5 {
        return Err(OtomlError::InvalidCurrency(format!(
            "asset code '{}' too long (max 5 chars)",
            asset
        )));
    }
    Ok(asset)
}
impl Serialize for OCur {
    /// Encodes the amount as a two-element tuple `(asset, micro_units)`, matching
    /// the `["asset", amount]` shape of the `Display` output.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // serde's 2-tuple impl issues `serialize_tuple(2)` followed by both
        // elements, the same calls as spelling them out by hand.
        (self.asset.as_str(), self.amount).serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for OCur {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let (asset, amount): (String, u64) = Deserialize::deserialize(deserializer)?;
OCur::new(&asset, amount).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction and conversion.

    #[test]
    fn test_new() {
        let cur = OCur::new("usd", 1_000_000).unwrap();
        assert_eq!(cur.asset(), "usd");
        assert_eq!(cur.micro_units(), 1_000_000);
        assert_eq!(cur.as_units(), 1.0);
    }

    #[test]
    fn test_from_units() {
        let cur = OCur::from_units("usd", 1.25).unwrap();
        assert_eq!(cur.micro_units(), 1_250_000);
        assert_eq!(cur.as_units(), 1.25);
    }

    // Text form.

    #[test]
    fn test_display() {
        let cur = OCur::new("usdc", 1_250_000).unwrap();
        assert_eq!(cur.to_string(), "[\"usdc\", 1250000]");
    }

    #[test]
    fn test_parse() {
        let cur: OCur = "[\"usd\", 1250000]".parse().unwrap();
        assert_eq!(cur.asset(), "usd");
        assert_eq!(cur.micro_units(), 1_250_000);
    }

    #[test]
    fn test_display_parse_roundtrip() {
        for (asset, amount) in [("usd", 0), ("usdc", 1_250_000), ("btc", u64::MAX)] {
            let cur = OCur::new(asset, amount).unwrap();
            let parsed: OCur = cur.to_string().parse().unwrap();
            assert_eq!(parsed, cur);
        }
    }

    // Asset-code validation.

    #[test]
    fn test_asset_validation() {
        assert!(OCur::new("usd", 0).is_ok());
        assert!(OCur::new("usdc", 0).is_ok());
        assert!(OCur::new("btc", 0).is_ok());
        assert!(OCur::new("eth", 0).is_ok());
        assert!(OCur::new("gold", 0).is_ok());
        assert!(OCur::new("abc12", 0).is_ok());
        assert!(OCur::new("", 0).is_err());
        assert!(OCur::new("toolong", 0).is_err());
        assert!(OCur::new("USD", 0).is_ok());
        assert!(OCur::new("us-d", 0).is_err());
        assert!(OCur::new("us d", 0).is_err());
        assert!(OCur::new("us_d", 0).is_err());
    }

    #[test]
    fn test_asset_normalization() {
        let cur = OCur::new("USD", 100).unwrap();
        assert_eq!(cur.asset(), "usd");
        let cur = OCur::new("BtC", 100).unwrap();
        assert_eq!(cur.asset(), "btc");
    }

    // Arithmetic.

    #[test]
    fn test_arithmetic() {
        let a = OCur::new("usd", 1_000_000).unwrap();
        let b = OCur::new("usd", 500_000).unwrap();
        let sum = a.add(&b).unwrap();
        assert_eq!(sum.micro_units(), 1_500_000);
        let diff = a.sub(&b).unwrap();
        assert_eq!(diff.micro_units(), 500_000);
    }

    #[test]
    fn test_arithmetic_different_assets() {
        let a = OCur::new("usd", 1_000_000).unwrap();
        let b = OCur::new("eur", 500_000).unwrap();
        assert!(a.add(&b).is_err());
        assert!(a.sub(&b).is_err());
    }

    #[test]
    fn test_underflow() {
        let a = OCur::new("usd", 100).unwrap();
        let b = OCur::new("usd", 200).unwrap();
        assert!(a.sub(&b).is_err());
    }

    #[test]
    fn test_zero() {
        let cur = OCur::zero("usd").unwrap();
        assert!(cur.is_zero());
        assert_eq!(cur.micro_units(), 0);
    }

    #[test]
    fn test_common_amounts() {
        let one_dollar = OCur::new("usd", 1_000_000).unwrap();
        assert_eq!(one_dollar.as_units(), 1.0);
        let one_cent = OCur::new("usd", 10_000).unwrap();
        assert_eq!(one_cent.as_units(), 0.01);
        let quarter = OCur::new("usd", 250_000).unwrap();
        assert_eq!(quarter.as_units(), 0.25);
        let tiny_btc = OCur::new("btc", 1).unwrap();
        assert_eq!(tiny_btc.as_units(), 0.000001);
    }

    // Serialization round trips through the crate's text and binary formats.

    #[test]
    fn test_serde_roundtrip() {
        use serde::{Deserialize, Serialize};
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Invoice {
            total: OCur,
        }
        let invoice = Invoice {
            total: OCur::new("usdc", 1_250_000).unwrap(),
        };
        let otoml = crate::dump_otoml(&invoice).unwrap();
        assert!(otoml.contains("total = [\"usdc\", 1250000]"));
        let parsed: Invoice = crate::load_otoml(&otoml).unwrap();
        assert_eq!(invoice, parsed);
    }

    #[test]
    fn test_large_amounts() {
        let billion = OCur::new("usd", 1_000_000_000_000_000).unwrap();
        assert_eq!(billion.as_units(), 1_000_000_000.0);
        let max = OCur::new("usd", u64::MAX).unwrap();
        assert_eq!(max.micro_units(), u64::MAX);
    }

    #[test]
    fn test_binary_roundtrip() {
        use serde::{Deserialize, Serialize};
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct Wallet {
            balance: OCur,
            pending: Option<OCur>,
        }
        let wallet = Wallet {
            balance: OCur::new("usdc", 1_500_000_000).unwrap(),
            pending: Some(OCur::new("usdc", 250_000_000).unwrap()),
        };
        let bytes = crate::dump_obin(&wallet).unwrap();
        let parsed: Wallet = crate::load_obin(&bytes).unwrap();
        assert_eq!(wallet, parsed);
    }

    // Derived / standard trait behavior.

    #[test]
    fn test_default() {
        let cur = OCur::default();
        assert_eq!(cur.asset(), "usd");
        assert_eq!(cur.micro_units(), 0);
        assert!(cur.is_zero());
    }

    #[test]
    fn test_hash() {
        use std::collections::HashSet;
        let c1 = OCur::new("usd", 1_000_000).unwrap();
        let c2 = OCur::new("usd", 1_000_000).unwrap();
        let c3 = OCur::new("usd", 2_000_000).unwrap();
        let c4 = OCur::new("eur", 1_000_000).unwrap();
        let mut set = HashSet::new();
        set.insert(c1.clone());
        set.insert(c2);
        set.insert(c3);
        set.insert(c4);
        assert_eq!(set.len(), 3);
    }

    #[test]
    fn test_clone() {
        let c1 = OCur::new("btc", 50_000_000).unwrap();
        let c2 = c1.clone();
        assert_eq!(c1, c2);
        assert_eq!(c1.asset(), c2.asset());
        assert_eq!(c1.micro_units(), c2.micro_units());
    }

    // Error and edge cases.

    #[test]
    fn test_negative_from_units() {
        let result = OCur::from_units("usd", -10.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_overflow_add() {
        let max = OCur::new("usd", u64::MAX).unwrap();
        let one = OCur::new("usd", 1).unwrap();
        let result = max.add(&one);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_edge_cases() {
        let c: OCur = "[ \"usd\" , 1000000 ]".parse().unwrap();
        assert_eq!(c.micro_units(), 1_000_000);
        assert!("\"usd\", 1000000".parse::<OCur>().is_err());
        assert!("[\"usd\"; 1000000]".parse::<OCur>().is_err());
        assert!("[\"usd\", -1000000]".parse::<OCur>().is_err());
    }

    #[test]
    fn test_crypto_assets() {
        let assets = ["btc", "eth", "usdc", "usdt", "sol", "matic"];
        for asset in assets {
            let cur = OCur::new(asset, 1_000_000).unwrap();
            assert_eq!(cur.asset(), asset);
        }
    }

    #[test]
    fn test_precision_preservation() {
        let mut total = OCur::zero("usd").unwrap();
        for _ in 0..100 {
            let penny = OCur::new("usd", 10_000).unwrap();
            total = total.add(&penny).unwrap();
        }
        assert_eq!(total.micro_units(), 1_000_000);
        assert_eq!(total.as_units(), 1.0);
    }
}