use starknet::core::types::Felt;
use std::fmt;
use crate::{
error::{Result, StarkzapError},
tokens::Token,
};
/// A token amount stored as an integer number of base units, together with the
/// decimals and symbol of the token it was created for.
///
/// Note that the derived ordering compares fields in declaration order: `raw`
/// first, then `decimals`, then `symbol`. It does not normalise amounts that
/// use different decimals.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Amount {
/// Quantity in the token's smallest unit (the human-readable value scaled by `10^decimals`).
raw: u128,
/// Number of fractional digits, copied from the token the amount was built with.
decimals: u8,
/// Ticker symbol copied from the token; appended when the amount is formatted.
symbol: String,
}
impl Amount {
    /// Parses a human-readable decimal string (for example `"10.5"`) as an
    /// amount of `token`.
    ///
    /// The text is converted to base units using `token.decimals`; the token's
    /// decimals and symbol are copied into the returned value.
    ///
    /// # Errors
    ///
    /// Returns whatever error `parse_decimal` reports for `value`.
    pub fn parse(value: &str, token: &Token) -> Result<Self> {
        parse_decimal(value, token.decimals).map(|raw| Self::from_raw(raw, token))
    }

    /// Wraps an already-scaled base-unit quantity, copying `token`'s decimals
    /// and symbol. No validation is performed on `raw`.
    pub fn from_raw(raw: u128, token: &Token) -> Self {
        Self {
            raw,
            decimals: token.decimals,
            symbol: token.symbol.clone(),
        }
    }

    /// Returns the quantity in base units.
    pub fn raw(&self) -> u128 {
        self.raw
    }

    /// Returns the decimal representation followed by a space and the symbol,
    /// e.g. `"10.5 USDC"`.
    pub fn to_formatted(&self) -> String {
        format!("{} {}", self.to_decimal_string(), self.symbol)
    }

    /// Returns the decimal representation without the symbol, e.g. `"10.5"`.
    pub fn to_decimal_string(&self) -> String {
        format_decimal(self.raw, self.decimals)
    }

    /// Returns the amount as a `[low, high]` pair of felts.
    ///
    /// `raw` is a `u128`, so the whole value is placed in the first element
    /// and the second element is always `Felt::ZERO`.
    pub fn to_u256_felts(&self) -> [Felt; 2] {
        [Felt::from(self.raw), Felt::ZERO]
    }

    /// Adds two amounts of the same token.
    ///
    /// Returns `None` when the operands do not share the same decimals and
    /// symbol, or when the sum of the raw values overflows `u128`.
    pub fn checked_add(&self, other: &Amount) -> Option<Amount> {
        // Base units are only comparable within one denomination: summing a
        // 6-decimal value with an 18-decimal one would yield a meaningless
        // number that is still labelled with `self`'s symbol.
        if self.decimals != other.decimals || self.symbol != other.symbol {
            return None;
        }
        let raw = self.raw.checked_add(other.raw)?;
        Some(Amount {
            raw,
            decimals: self.decimals,
            symbol: self.symbol.clone(),
        })
    }

    /// Returns `true` when the amount is exactly zero base units.
    pub fn is_zero(&self) -> bool {
        self.raw == 0
    }
}
/// Displays the amount exactly as [`Amount::to_formatted`] renders it
/// (decimal value, a space, then the symbol).
impl fmt::Display for Amount {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.to_formatted();
        f.write_str(&rendered)
    }
}
/// Converts a human-readable decimal string into an integer count of base
/// units, i.e. the value multiplied by `10^decimals`.
///
/// `value` is trimmed and then read as `<digits>[.<digits>]`. Either side of
/// the point may be omitted (`"5."`, `".5"`), but at least one digit is
/// required. Signs, exponents, separators and non-ASCII digits are rejected.
/// Fractional digits beyond `decimals` are truncated (rounded toward zero)
/// rather than rejected.
///
/// # Errors
///
/// * `StarkzapError::AmountParse`, carrying the trimmed input, when the input
///   contains no digits, contains anything other than ASCII digits and a
///   single `.`, or has an integer part that does not fit in a `u128`.
/// * `StarkzapError::AmountOverflow` when `10^decimals` or the scaled result
///   does not fit in a `u128`.
fn parse_decimal(value: &str, decimals: u8) -> Result<u128> {
    let value = value.trim();
    let invalid = || StarkzapError::AmountParse {
        input: value.to_string(),
    };

    let (integer_part, fractional_part) = value.split_once('.').unwrap_or((value, ""));

    // Covers both "" and a bare ".": with no digits on either side the digit
    // checks below would pass vacuously and the input would parse as zero.
    if integer_part.is_empty() && fractional_part.is_empty() {
        return Err(invalid());
    }

    let is_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    if !(is_digits(integer_part) && is_digits(fractional_part)) {
        return Err(invalid());
    }

    let int_val: u128 = if integer_part.is_empty() {
        0
    } else {
        integer_part.parse::<u128>().map_err(|_| invalid())?
    };

    let scale = 10u128
        .checked_pow(u32::from(decimals))
        .ok_or(StarkzapError::AmountOverflow)?;

    // Keep at most `decimals` fractional digits, then right-pad with zeros so
    // the digit string is already expressed in base units. The input is pure
    // ASCII at this point, so `get` only returns `None` when it is shorter
    // than `dec`.
    let dec = usize::from(decimals);
    let kept = fractional_part.get(..dec).unwrap_or(fractional_part);
    let frac_val: u128 = if dec == 0 {
        0
    } else {
        format!("{:0<width$}", kept, width = dec)
            .parse::<u128>()
            .map_err(|_| invalid())?
    };

    int_val
        .checked_mul(scale)
        .and_then(|v| v.checked_add(frac_val))
        .ok_or(StarkzapError::AmountOverflow)
}
/// Renders a base-unit quantity as a decimal string with `decimals`
/// fractional places, dropping trailing fractional zeros and omitting the
/// decimal point entirely when the fraction is zero.
///
/// The split is done on the digit string rather than by dividing by
/// `10^decimals`, so it stays correct (and never panics) for every `u8`
/// value of `decimals`, including those where `10^decimals` exceeds `u128`.
fn format_decimal(raw: u128, decimals: u8) -> String {
    let places = usize::from(decimals);
    if places == 0 {
        return raw.to_string();
    }

    // Left-pad to `places + 1` digits so there is always at least one integer
    // digit in front of the fractional part.
    let digits = format!("{:0>width$}", raw, width = places + 1);
    let (integer, fraction) = digits.split_at(digits.len() - places);
    let fraction = fraction.trim_end_matches('0');

    if fraction.is_empty() {
        integer.to_string()
    } else {
        format!("{}.{}", integer, fraction)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::Token;

    /// Builds a token fixture with the given metadata and a zero address.
    fn fixture(symbol: &str, name: &str, decimals: u8) -> Token {
        Token {
            symbol: symbol.to_string(),
            name: name.to_string(),
            decimals,
            address: Felt::ZERO,
        }
    }

    /// Six-decimal token used by most tests.
    fn usdc() -> Token {
        fixture("USDC", "USD Coin", 6)
    }

    /// Eighteen-decimal token used to exercise large scale factors.
    fn strk() -> Token {
        fixture("STRK", "Starknet Token", 18)
    }

    #[test]
    fn parse_whole_number() {
        let token = usdc();
        let amount = Amount::parse("10", &token).unwrap();
        assert_eq!(amount.raw(), 10_000_000);
    }

    #[test]
    fn parse_decimal() {
        let token = usdc();
        let amount = Amount::parse("10.5", &token).unwrap();
        assert_eq!(amount.raw(), 10_500_000);
    }

    #[test]
    fn parse_max_precision() {
        let token = usdc();
        let amount = Amount::parse("0.000001", &token).unwrap();
        assert_eq!(amount.raw(), 1);
    }

    #[test]
    fn parse_strk_18_decimals() {
        let token = strk();
        let amount = Amount::parse("1.5", &token).unwrap();
        assert_eq!(amount.raw(), 1_500_000_000_000_000_000u128);
    }

    #[test]
    fn format_round_trip() {
        let token = usdc();
        let amount = Amount::parse("10.5", &token).unwrap();
        assert_eq!(amount.to_decimal_string(), "10.5");
        assert_eq!(amount.to_formatted(), "10.5 USDC");
    }

    #[test]
    fn format_no_trailing_zeros() {
        let token = usdc();
        let amount = Amount::parse("1.10", &token).unwrap();
        assert_eq!(amount.to_decimal_string(), "1.1");
    }

    #[test]
    fn to_u256_felts() {
        let token = usdc();
        let amount = Amount::parse("1", &token).unwrap();
        // The whole value sits in the low limb; the high limb stays zero.
        assert_eq!(
            amount.to_u256_felts(),
            [Felt::from(1_000_000u128), Felt::ZERO]
        );
    }

    #[test]
    fn invalid_input_errors() {
        let token = usdc();
        for bad in ["abc", "", "1.2.3"] {
            assert!(Amount::parse(bad, &token).is_err());
        }
    }
}