b3_helper_lib/ledger/currency/amount/
token_amount.rs

1use crate::currency::ICPToken;
2use candid::CandidType;
3use candid::Nat;
4use serde::{Deserialize, Serialize};
5use std::{
6    fmt,
7    ops::{Add, Div, Mul, Sub},
8    str::FromStr,
9};
10
11use super::error::TokenAmountError;
12
13#[derive(CandidType, Deserialize, PartialEq, Eq, Serialize, Copy, Clone, Debug)]
14pub struct TokenAmount {
15    pub amount: u128,
16    pub decimals: u8,
17}
18
19impl PartialOrd for TokenAmount {
20    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
21        Some(
22            self.amount
23                .cmp(&other.amount)
24                .then_with(|| self.decimals.cmp(&other.decimals)),
25        )
26    }
27}
28
29impl TokenAmount {
30    pub fn new(amount: u128, decimals: u8) -> Self {
31        Self { amount, decimals }
32    }
33
34    pub fn from_tokens(tokens: ICPToken) -> Self {
35        Self {
36            amount: tokens.e8s as u128,
37            decimals: 8,
38        }
39    }
40
41    /// Returns the amount as a u64, if the amount has no decimals.
42    /// Otherwise returns an error.
43    /// # Example
44    /// ```
45    /// use b3_helper_lib::currency::TokenAmount;
46    ///
47    /// let amount = TokenAmount::new(100, 0);
48    ///
49    /// assert_eq!(amount.as_u64().unwrap(), 100);
50    ///
51    /// let amount = TokenAmount::new(100, 1);
52    ///
53    /// assert!(amount.as_u64().is_err());
54    /// ```
55    pub fn as_u64(&self) -> Result<u64, TokenAmountError> {
56        if self.decimals > 0 {
57            return Err(TokenAmountError::PrecisionLoss);
58        }
59
60        match self.amount.try_into() {
61            Ok(val) => Ok(val),
62            Err(_) => Err(TokenAmountError::Overflow),
63        }
64    }
65
66    /// Returns the amount as a u128, if the amount has no decimals.
67    /// Otherwise returns an error.
68    ///
69    /// # Example
70    /// ```
71    /// use b3_helper_lib::currency::TokenAmount;
72    ///
73    /// let amount = TokenAmount::new(100, 0);
74    ///
75    /// assert_eq!(amount.as_u128().unwrap(), 100);
76    /// ```
77    pub fn as_u128(&self) -> Result<u128, TokenAmountError> {
78        if self.decimals > 0 {
79            return Err(TokenAmountError::PrecisionLoss);
80        }
81
82        Ok(self.amount)
83    }
84
85    /// Returns the amount to Nat.
86    pub fn to_nat(&self) -> Nat {
87        Nat::from(self.amount)
88    }
89
90    /// Returns the amount to Satoshi.
91    pub fn to_satoshi(&self) -> Result<u64, TokenAmountError> {
92        if self.decimals != 8 {
93            return Err(TokenAmountError::PrecisionLoss);
94        }
95
96        match self.amount.try_into() {
97            Ok(val) => Ok(val),
98            Err(_) => Err(TokenAmountError::Overflow),
99        }
100    }
101
102    /// Returns the Tokens representation of this amount.
103    /// Throws an error if the decimals are not 8.
104    pub fn to_tokens(&self) -> Result<ICPToken, TokenAmountError> {
105        self.try_into()
106    }
107}
108
109impl Add for TokenAmount {
110    type Output = Result<Self, TokenAmountError>;
111
112    fn add(self, other: Self) -> Self::Output {
113        if self.decimals != other.decimals {
114            return Err(TokenAmountError::DifferentDecimals(
115                self.decimals,
116                other.decimals,
117            ));
118        }
119
120        self.amount
121            .checked_add(other.amount)
122            .map(|amount| Self {
123                amount,
124                decimals: self.decimals,
125            })
126            .ok_or(TokenAmountError::Overflow)
127    }
128}
129
130impl Sub for TokenAmount {
131    type Output = Result<Self, TokenAmountError>;
132
133    fn sub(self, other: Self) -> Self::Output {
134        if self.decimals != other.decimals {
135            return Err(TokenAmountError::DifferentDecimals(
136                self.decimals,
137                other.decimals,
138            ));
139        }
140
141        self.amount
142            .checked_sub(other.amount)
143            .map(|amount| Self {
144                amount,
145                decimals: self.decimals,
146            })
147            .ok_or(TokenAmountError::Underflow)
148    }
149}
150
151impl Mul for TokenAmount {
152    type Output = Result<Self, TokenAmountError>;
153
154    fn mul(self, other: Self) -> Self::Output {
155        match self.amount.checked_mul(other.amount) {
156            Some(amount) => {
157                let decimals = self.decimals.saturating_add(other.decimals);
158                Ok(Self { amount, decimals })
159            }
160            None => Err(TokenAmountError::Overflow),
161        }
162    }
163}
164
165impl Div for TokenAmount {
166    type Output = Result<Self, TokenAmountError>;
167
168    fn div(self, other: Self) -> Self::Output {
169        if other.amount == 0 {
170            return Err(TokenAmountError::DivisionByZero);
171        }
172
173        // Adjust the divisor and dividend to have the same decimal places
174        let max_decimals = self.decimals.max(other.decimals);
175        let self_amount = self.amount * 10u128.pow((max_decimals - self.decimals) as u32);
176        let other_amount = other.amount * 10u128.pow((max_decimals - other.decimals) as u32);
177
178        match self_amount.checked_div(other_amount) {
179            Some(amount) => Ok(Self {
180                amount,
181                decimals: max_decimals,
182            }),
183            None => Err(TokenAmountError::Underflow),
184        }
185    }
186}
187
188impl From<u128> for TokenAmount {
189    fn from(amount: u128) -> Self {
190        Self {
191            amount,
192            decimals: 0,
193        }
194    }
195}
196
197impl TryFrom<TokenAmount> for Nat {
198    type Error = TokenAmountError;
199
200    fn try_from(amount: TokenAmount) -> Result<Self, Self::Error> {
201        if amount.decimals > 0 {
202            return Err(TokenAmountError::PrecisionLoss);
203        }
204
205        match amount.amount.try_into() {
206            Ok(val) => Ok(Nat(val)),
207            Err(_) => Err(TokenAmountError::Overflow),
208        }
209    }
210}
211
212impl TryFrom<&TokenAmount> for ICPToken {
213    type Error = TokenAmountError;
214
215    fn try_from(amount: &TokenAmount) -> Result<Self, Self::Error> {
216        if amount.decimals != ICPToken::DECIMALS {
217            return Err(TokenAmountError::DecimalsMismatch);
218        }
219
220        match amount.amount.try_into() {
221            Ok(val) => Ok(ICPToken::from_e8s(val)),
222            Err(_) => Err(TokenAmountError::Overflow),
223        }
224    }
225}
226
227impl From<ICPToken> for TokenAmount {
228    fn from(tokens: ICPToken) -> Self {
229        Self {
230            amount: tokens.e8s as u128,
231            decimals: 8,
232        }
233    }
234}
235
236impl fmt::Display for TokenAmount {
237    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238        let amount = self.amount.to_string();
239        let len = amount.len();
240
241        if self.decimals > 0 && len > self.decimals as usize {
242            let (integral, fractional) = amount.split_at(len - self.decimals as usize);
243            let fractional = fractional.trim_end_matches('0');
244            if fractional.is_empty() {
245                write!(f, "{}", integral)
246            } else {
247                write!(f, "{}.{}", integral, fractional)
248            }
249        } else {
250            if self.decimals == 0 {
251                write!(f, "{}", amount)
252            } else {
253                let zeros = if len <= self.decimals as usize {
254                    "0".repeat(self.decimals as usize - len)
255                } else {
256                    String::new()
257                };
258                let result = format!("0.{}{}", zeros, amount);
259                let result = result.trim_end_matches('0');
260                if result.ends_with('.') {
261                    write!(f, "{}", result.trim_end_matches('.'))
262                } else {
263                    write!(f, "{}", result)
264                }
265            }
266        }
267    }
268}
269
270impl FromStr for TokenAmount {
271    type Err = TokenAmountError;
272
273    fn from_str(s: &str) -> Result<Self, Self::Err> {
274        let parts: Vec<&str> = s.split('.').collect();
275        let amount: u128;
276        let mut decimals: u8 = 0;
277
278        if parts.len() == 1 {
279            // If there's no decimal point
280            amount = parts[0]
281                .parse::<u128>()
282                .map_err(|e| TokenAmountError::InvalidAmount(e.to_string()))?;
283        } else if parts.len() == 2 {
284            // If there's a decimal point
285            decimals = parts[1].len() as u8;
286            let whole = parts.join("");
287            amount = whole
288                .parse::<u128>()
289                .map_err(|e| TokenAmountError::InvalidAmount(e.to_string()))?;
290        } else {
291            return Err(TokenAmountError::ToManyDecimals);
292        }
293
294        Ok(Self { amount, decimals })
295    }
296}