use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
use serde::Serialize;
use std::{
convert::TryFrom,
fmt,
iter::Sum,
ops::{Add, Div, Mul, Sub},
};
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Rational {
pub numerator: u64,
pub denominator: u64,
}
impl PartialOrd for Rational {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
if self.denominator == 0 || other.denominator == 0 {
None
} else {
let x = self.numerator as u128 * other.denominator as u128;
let y = other.numerator as u128 * self.denominator as u128;
Some(x.cmp(&y))
}
}
}
#[derive(Debug, Eq, PartialEq)]
pub struct ArithmeticError;
pub type Result<T> = std::result::Result<T, ArithmeticError>;
macro_rules! impl_token {
($TokenLamports:ident, $symbol:expr) => {
#[derive(
Copy,
Clone,
Default,
Eq,
Ord,
PartialEq,
PartialOrd,
BorshDeserialize,
BorshSerialize,
BorshSchema,
Serialize,
)]
pub struct $TokenLamports(pub u64);
impl fmt::Display for $TokenLamports {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}.{:0>9} {}",
self.0 / 1_000_000_000,
self.0 % 1_000_000_000,
$symbol
)
}
}
impl fmt::Debug for $TokenLamports {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(self, f)
}
}
impl Mul<Rational> for $TokenLamports {
type Output = Result<$TokenLamports>;
fn mul(self, other: Rational) -> Result<$TokenLamports> {
let result_u128 = ((self.0 as u128) * (other.numerator as u128))
.checked_div(other.denominator as u128)
.ok_or(ArithmeticError)?;
u64::try_from(result_u128)
.map($TokenLamports)
.map_err(|_| ArithmeticError)
}
}
impl Mul<u64> for $TokenLamports {
type Output = Result<$TokenLamports>;
fn mul(self, other: u64) -> Result<$TokenLamports> {
self.0
.checked_mul(other)
.map($TokenLamports)
.ok_or(ArithmeticError)
}
}
impl Div<u64> for $TokenLamports {
type Output = Result<$TokenLamports>;
fn div(self, other: u64) -> Result<$TokenLamports> {
self.0
.checked_div(other)
.map($TokenLamports)
.ok_or(ArithmeticError)
}
}
impl Sub<$TokenLamports> for $TokenLamports {
type Output = Result<$TokenLamports>;
fn sub(self, other: $TokenLamports) -> Result<$TokenLamports> {
self.0
.checked_sub(other.0)
.map($TokenLamports)
.ok_or(ArithmeticError)
}
}
impl Add<$TokenLamports> for $TokenLamports {
type Output = Result<$TokenLamports>;
fn add(self, other: $TokenLamports) -> Result<$TokenLamports> {
self.0
.checked_add(other.0)
.map($TokenLamports)
.ok_or(ArithmeticError)
}
}
impl Sum<$TokenLamports> for Result<$TokenLamports> {
fn sum<I: Iterator<Item = $TokenLamports>>(iter: I) -> Self {
let mut sum = $TokenLamports(0);
for item in iter {
sum = (sum + item)?;
}
Ok(sum)
}
}
impl std::str::FromStr for $TokenLamports {
type Err = &'static str;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let mut value = 0_u64;
let mut is_after_decimal = false;
let mut exponent = 9_i32;
let mut had_digit = false;
for ch in s.as_bytes() {
match ch {
b'0'..=b'9' => {
value = value * 10 + ((ch - b'0') as u64);
if is_after_decimal {
exponent -= 1;
}
had_digit = true;
}
b'.' if !is_after_decimal => is_after_decimal = true,
b'.' => return Err("Value can contain at most one '.' (decimal point)."),
b'_' => { }
_ => return Err("Invalid value, only digits, '_', and '.' are allowed."),
}
if exponent < 0 {
return Err("Value can contain at most 9 digits after the decimal point.");
}
}
if !had_digit {
return Err("Value must contain at least one digit.");
}
while exponent > 0 {
value *= 10;
exponent -= 1;
}
Ok($TokenLamports(value))
}
}
};
}
impl_token!(Lamports, "SOL");
impl_token!(StLamports, "stSOL");
#[cfg(test)]
pub mod test {
use super::*;
use std::str::FromStr;
#[test]
fn test_lamports_from_str_roundtrip() {
let mut x = 0;
while x < u64::MAX / 17 {
let orig = Lamports(x);
let s = format!("{}", orig);
let without_suffix = &s[..s.len() - 4];
let reconstructed = Lamports::from_str(without_suffix).unwrap();
assert_eq!(reconstructed, orig);
x += 1;
x *= 17;
}
}
#[test]
fn test_lamports_from_str_handles_more_than_f64() {
let x = "9007199.254740993";
let expected = Lamports(9007199_254740993);
assert_eq!(Lamports::from_str(x), Ok(expected));
assert_eq!((f64::from_str(x).unwrap() * 1e9) as u64, expected.0 - 1);
}
#[test]
fn test_lamports_from_str_examples() {
assert_eq!(Lamports::from_str("1_000"), Ok(Lamports(1_000_000_000_000)));
assert_eq!(Lamports::from_str("1"), Ok(Lamports(1_000_000_000)));
assert_eq!(Lamports::from_str("1."), Ok(Lamports(1_000_000_000)));
assert_eq!(Lamports::from_str("1.0"), Ok(Lamports(1_000_000_000)));
assert_eq!(Lamports::from_str("1.02"), Ok(Lamports(1_020_000_000)));
assert_eq!(
Lamports::from_str("1.000_000_001"),
Ok(Lamports(1_000_000_001))
);
assert_eq!(Lamports::from_str(".1"), Ok(Lamports(100_000_000)));
assert!(Lamports::from_str("").is_err());
assert!(Lamports::from_str(".").is_err());
assert!(Lamports::from_str("_").is_err());
assert!(Lamports::from_str("_._").is_err());
assert!(Lamports::from_str("0.000_000_000_1").is_err());
assert!(Lamports::from_str("0.0.0").is_err());
assert!(Lamports::from_str("lol, sol").is_err());
}
#[test]
fn test_rational_cmp() {
let x = Rational {
numerator: 1 << 53,
denominator: 1,
};
let y = Rational {
numerator: x.numerator + 1,
denominator: x.denominator,
};
assert_eq!(x.partial_cmp(&y), Some(std::cmp::Ordering::Less));
assert_eq!(y.partial_cmp(&x), Some(std::cmp::Ordering::Greater));
}
#[test]
fn test_equal_cmp() {
let x = Rational {
numerator: 1,
denominator: 1,
};
let y = Rational {
numerator: 1,
denominator: 1,
};
assert_eq!(x.partial_cmp(&y), Some(std::cmp::Ordering::Equal));
assert_eq!(y.partial_cmp(&x), Some(std::cmp::Ordering::Equal));
}
#[test]
fn test_division_by_zero_cmp() {
let x = Rational {
numerator: 1,
denominator: 0,
};
let y = Rational {
numerator: x.numerator,
denominator: x.denominator + 1,
};
assert_eq!(x.partial_cmp(&y), None);
assert_eq!(y.partial_cmp(&x), None);
let y = Rational {
numerator: x.numerator,
denominator: x.denominator,
};
assert_eq!(x.partial_cmp(&y), None);
assert_eq!(y.partial_cmp(&x), None);
}
}