use crate::common::cbor::{
CborDecoder, CborDeserialize, CborEncoder, CborSerializationResult, CborSerialize,
UnsignedDecimalFraction,
};
use anyhow::Context;
use rust_decimal::Error;
use std::str::FromStr;
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, serde::Deserialize, serde::Serialize)]
#[serde(try_from = "TokenAmountJson", into = "TokenAmountJson")]
pub struct TokenAmount {
value: u64,
decimals: u8,
}
impl CborSerialize for TokenAmount {
fn serialize<C: CborEncoder>(&self, encoder: C) -> CborSerializationResult<()> {
let decimal_fraction = UnsignedDecimalFraction::new(
i64::from(self.decimals)
.checked_neg()
.context("convert decimals to exponent")?,
self.value,
);
decimal_fraction.serialize(encoder)
}
}
impl CborDeserialize for TokenAmount {
fn deserialize<C: CborDecoder>(decoder: C) -> CborSerializationResult<Self>
where
Self: Sized,
{
let decimal_fraction = UnsignedDecimalFraction::deserialize(decoder)?;
let decimals = decimal_fraction
.exponent()
.checked_neg()
.and_then(|val| u8::try_from(val).ok())
.context("convert exponent to decimals")?;
let value = decimal_fraction.mantissa();
Ok(Self { decimals, value })
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ConversionRule {
Exact,
AllowRounding,
}
impl TokenAmount {
pub fn from_raw(value: u64, decimals: u8) -> Self {
Self { value, decimals }
}
pub fn decimals(&self) -> u8 {
self.decimals
}
pub fn value(&self) -> u64 {
self.value
}
pub fn try_from_rust_decimal(
decimal: rust_decimal::Decimal,
decimals: u8,
conversion_rule: ConversionRule,
) -> TokenAmountConversionResult<Self> {
let decimals_u32 = decimals as u32;
let mut decimal_scaled = decimal;
decimal_scaled.rescale(decimals_u32);
if decimals_u32 > rust_decimal::Decimal::MAX_SCALE {
return Err(TokenAmountConversionError::RustDecimal(
rust_decimal::Error::ScaleExceedsMaximumPrecision(decimals_u32),
));
}
if decimal_scaled.scale() != decimals_u32 {
return Err(TokenAmountConversionError::ValueOverflow);
}
let this = Self::from_raw(
decimal_scaled
.mantissa()
.try_into()
.map_err(|_| TokenAmountConversionError::ValueOverflow)?,
decimals,
);
decimal_scaled.rescale(decimal.scale());
if decimal_scaled.mantissa() != decimal.mantissa()
&& conversion_rule == ConversionRule::Exact
{
return Err(TokenAmountConversionError::LossOfPrecision);
}
Ok(this)
}
pub fn try_to_rust_decimal(&self) -> TokenAmountConversionResult<rust_decimal::Decimal> {
Ok(rust_decimal::Decimal::try_from_i128_with_scale(
self.value as i128,
self.decimals as u32,
)?)
}
pub fn from_str(
decimal_str: &str,
decimals: u8,
conversion_rule: ConversionRule,
) -> TokenAmountConversionResult<Self> {
let decimal = match conversion_rule {
ConversionRule::Exact => {
rust_decimal::Decimal::from_str_exact(decimal_str).map_err(|err| match err {
Error::Underflow => TokenAmountConversionError::LossOfPrecision,
err => TokenAmountConversionError::RustDecimal(err),
})?
}
ConversionRule::AllowRounding => rust_decimal::Decimal::from_str(decimal_str)?,
};
Self::try_from_rust_decimal(decimal, decimals, conversion_rule)
}
}
pub type TokenAmountConversionResult<T> = Result<T, TokenAmountConversionError>;
#[derive(Debug, thiserror::Error)]
pub enum TokenAmountConversionError {
#[error("error converting into rust_decimal::Decimal")]
RustDecimal(#[from] rust_decimal::Error),
#[error("precision lost due to rounding")]
LossOfPrecision,
#[error("value overflow")]
ValueOverflow,
}
impl std::fmt::Display for TokenAmount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.decimals == 0 {
self.value.fmt(f)
} else {
let mut value = format!(
"{value:0width$}",
width = self.decimals as usize + 1,
value = self.value
);
value.insert(value.len() - self.decimals() as usize, '.');
write!(f, "{}", value)
}
}
}
impl PartialOrd for TokenAmount {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
(self.decimals == other.decimals).then(|| self.value.cmp(&other.value))
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct TokenAmountJson {
value: String,
decimals: u8,
}
impl From<TokenAmount> for TokenAmountJson {
fn from(amount: TokenAmount) -> Self {
Self {
value: amount.value.to_string(),
decimals: amount.decimals,
}
}
}
impl TryFrom<TokenAmountJson> for TokenAmount {
type Error = std::num::ParseIntError;
fn try_from(json: TokenAmountJson) -> Result<Self, Self::Error> {
Ok(Self {
value: json.value.parse()?,
decimals: json.decimals,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::common::cbor;
use assert_matches::assert_matches;
use std::cmp::Ordering;
#[test]
fn test_display() {
let amount = TokenAmount {
value: 123450,
decimals: 3,
};
assert_eq!(amount.to_string().as_str(), "123.450");
let amount = TokenAmount {
value: 10,
decimals: 5,
};
assert_eq!(amount.to_string().as_str(), "0.00010");
}
#[test]
fn test_try_from_str_exact() {
let str = "123.450";
let token_amount = TokenAmount::from_str(str, 3, ConversionRule::Exact).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(123450, 3));
let token_amount = TokenAmount::from_str(str, 2, ConversionRule::Exact).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(12345, 2));
let token_amount = TokenAmount::from_str(str, 5, ConversionRule::Exact).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(12345000, 5));
let res = TokenAmount::from_str(str, 1, ConversionRule::Exact);
assert_matches!(res, Err(TokenAmountConversionError::LossOfPrecision));
let token_amount = TokenAmount::from_str(str, 1, ConversionRule::AllowRounding).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(1235, 1));
let str = "1.000000000000000000000000000000000000001";
let res = TokenAmount::from_str(str, 1, ConversionRule::Exact);
assert_matches!(res, Err(TokenAmountConversionError::LossOfPrecision));
let token_amount = TokenAmount::from_str(str, 1, ConversionRule::AllowRounding).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(10, 1));
}
#[test]
fn test_try_from_rust_decimal() {
let decimal = rust_decimal::Decimal::new(12600, 4);
let token_amount =
TokenAmount::try_from_rust_decimal(decimal, 4, ConversionRule::Exact).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(12600, 4));
let token_amount =
TokenAmount::try_from_rust_decimal(decimal, 2, ConversionRule::Exact).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(126, 2));
let token_amount =
TokenAmount::try_from_rust_decimal(decimal, 6, ConversionRule::Exact).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(1260000, 6));
let res = TokenAmount::try_from_rust_decimal(decimal, 1, ConversionRule::Exact);
assert_matches!(res, Err(TokenAmountConversionError::LossOfPrecision));
let token_amount =
TokenAmount::try_from_rust_decimal(decimal, 1, ConversionRule::AllowRounding).unwrap();
assert_eq!(token_amount, TokenAmount::from_raw(13, 1));
let res = TokenAmount::try_from_rust_decimal(decimal, 30, ConversionRule::Exact);
assert_matches!(
res,
Err(TokenAmountConversionError::RustDecimal(
rust_decimal::Error::ScaleExceedsMaximumPrecision(_)
))
);
let decimal = rust_decimal::Decimal::new(i64::MAX, 0);
let res = TokenAmount::try_from_rust_decimal(decimal, 28, ConversionRule::Exact);
assert_matches!(res, Err(TokenAmountConversionError::ValueOverflow));
}
#[test]
fn test_try_to_rust_decimal() {
let token_amount = TokenAmount::from_raw(12640, 4);
let decimal = token_amount.try_to_rust_decimal().unwrap();
assert_eq!(decimal.mantissa(), 12640);
assert_eq!(decimal.scale(), 4);
}
#[test]
fn test_partial_cmp() {
{
let amount1 = TokenAmount::from_raw(123, 3);
let amount2 = TokenAmount::from_raw(456, 3);
assert_eq!(amount1.partial_cmp(&amount2), Some(Ordering::Less));
}
{
let amount1 = TokenAmount::from_raw(123, 3);
let amount2 = TokenAmount::from_raw(456, 2);
assert_eq!(amount1.partial_cmp(&amount2), None);
}
}
#[test]
fn test_token_amount_cbor() {
let token_amount = TokenAmount::from_raw(12300, 3);
let cbor = cbor::cbor_encode(&token_amount).unwrap();
assert_eq!(hex::encode(&cbor), "c4822219300c");
let token_amount_decoded: TokenAmount = cbor::cbor_decode(&cbor).unwrap();
assert_eq!(token_amount_decoded, token_amount);
let token_amount = TokenAmount {
value: u64::MAX,
decimals: 3,
};
let cbor = cbor::cbor_encode(&token_amount).unwrap();
assert_eq!(hex::encode(&cbor), "c482221bffffffffffffffff");
let token_amount_decoded: TokenAmount = cbor::cbor_decode(&cbor).unwrap();
assert_eq!(token_amount_decoded, token_amount);
}
}