b3_helper_lib/ledger/currency/amount/
token_amount.rs1use crate::currency::ICPToken;
2use candid::CandidType;
3use candid::Nat;
4use serde::{Deserialize, Serialize};
5use std::{
6 fmt,
7 ops::{Add, Div, Mul, Sub},
8 str::FromStr,
9};
10
11use super::error::TokenAmountError;
12
13#[derive(CandidType, Deserialize, PartialEq, Eq, Serialize, Copy, Clone, Debug)]
14pub struct TokenAmount {
15 pub amount: u128,
16 pub decimals: u8,
17}
18
19impl PartialOrd for TokenAmount {
20 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
21 Some(
22 self.amount
23 .cmp(&other.amount)
24 .then_with(|| self.decimals.cmp(&other.decimals)),
25 )
26 }
27}
28
29impl TokenAmount {
30 pub fn new(amount: u128, decimals: u8) -> Self {
31 Self { amount, decimals }
32 }
33
34 pub fn from_tokens(tokens: ICPToken) -> Self {
35 Self {
36 amount: tokens.e8s as u128,
37 decimals: 8,
38 }
39 }
40
41 pub fn as_u64(&self) -> Result<u64, TokenAmountError> {
56 if self.decimals > 0 {
57 return Err(TokenAmountError::PrecisionLoss);
58 }
59
60 match self.amount.try_into() {
61 Ok(val) => Ok(val),
62 Err(_) => Err(TokenAmountError::Overflow),
63 }
64 }
65
66 pub fn as_u128(&self) -> Result<u128, TokenAmountError> {
78 if self.decimals > 0 {
79 return Err(TokenAmountError::PrecisionLoss);
80 }
81
82 Ok(self.amount)
83 }
84
85 pub fn to_nat(&self) -> Nat {
87 Nat::from(self.amount)
88 }
89
90 pub fn to_satoshi(&self) -> Result<u64, TokenAmountError> {
92 if self.decimals != 8 {
93 return Err(TokenAmountError::PrecisionLoss);
94 }
95
96 match self.amount.try_into() {
97 Ok(val) => Ok(val),
98 Err(_) => Err(TokenAmountError::Overflow),
99 }
100 }
101
102 pub fn to_tokens(&self) -> Result<ICPToken, TokenAmountError> {
105 self.try_into()
106 }
107}
108
109impl Add for TokenAmount {
110 type Output = Result<Self, TokenAmountError>;
111
112 fn add(self, other: Self) -> Self::Output {
113 if self.decimals != other.decimals {
114 return Err(TokenAmountError::DifferentDecimals(
115 self.decimals,
116 other.decimals,
117 ));
118 }
119
120 self.amount
121 .checked_add(other.amount)
122 .map(|amount| Self {
123 amount,
124 decimals: self.decimals,
125 })
126 .ok_or(TokenAmountError::Overflow)
127 }
128}
129
130impl Sub for TokenAmount {
131 type Output = Result<Self, TokenAmountError>;
132
133 fn sub(self, other: Self) -> Self::Output {
134 if self.decimals != other.decimals {
135 return Err(TokenAmountError::DifferentDecimals(
136 self.decimals,
137 other.decimals,
138 ));
139 }
140
141 self.amount
142 .checked_sub(other.amount)
143 .map(|amount| Self {
144 amount,
145 decimals: self.decimals,
146 })
147 .ok_or(TokenAmountError::Underflow)
148 }
149}
150
151impl Mul for TokenAmount {
152 type Output = Result<Self, TokenAmountError>;
153
154 fn mul(self, other: Self) -> Self::Output {
155 match self.amount.checked_mul(other.amount) {
156 Some(amount) => {
157 let decimals = self.decimals.saturating_add(other.decimals);
158 Ok(Self { amount, decimals })
159 }
160 None => Err(TokenAmountError::Overflow),
161 }
162 }
163}
164
165impl Div for TokenAmount {
166 type Output = Result<Self, TokenAmountError>;
167
168 fn div(self, other: Self) -> Self::Output {
169 if other.amount == 0 {
170 return Err(TokenAmountError::DivisionByZero);
171 }
172
173 let max_decimals = self.decimals.max(other.decimals);
175 let self_amount = self.amount * 10u128.pow((max_decimals - self.decimals) as u32);
176 let other_amount = other.amount * 10u128.pow((max_decimals - other.decimals) as u32);
177
178 match self_amount.checked_div(other_amount) {
179 Some(amount) => Ok(Self {
180 amount,
181 decimals: max_decimals,
182 }),
183 None => Err(TokenAmountError::Underflow),
184 }
185 }
186}
187
188impl From<u128> for TokenAmount {
189 fn from(amount: u128) -> Self {
190 Self {
191 amount,
192 decimals: 0,
193 }
194 }
195}
196
197impl TryFrom<TokenAmount> for Nat {
198 type Error = TokenAmountError;
199
200 fn try_from(amount: TokenAmount) -> Result<Self, Self::Error> {
201 if amount.decimals > 0 {
202 return Err(TokenAmountError::PrecisionLoss);
203 }
204
205 match amount.amount.try_into() {
206 Ok(val) => Ok(Nat(val)),
207 Err(_) => Err(TokenAmountError::Overflow),
208 }
209 }
210}
211
212impl TryFrom<&TokenAmount> for ICPToken {
213 type Error = TokenAmountError;
214
215 fn try_from(amount: &TokenAmount) -> Result<Self, Self::Error> {
216 if amount.decimals != ICPToken::DECIMALS {
217 return Err(TokenAmountError::DecimalsMismatch);
218 }
219
220 match amount.amount.try_into() {
221 Ok(val) => Ok(ICPToken::from_e8s(val)),
222 Err(_) => Err(TokenAmountError::Overflow),
223 }
224 }
225}
226
227impl From<ICPToken> for TokenAmount {
228 fn from(tokens: ICPToken) -> Self {
229 Self {
230 amount: tokens.e8s as u128,
231 decimals: 8,
232 }
233 }
234}
235
236impl fmt::Display for TokenAmount {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 let amount = self.amount.to_string();
239 let len = amount.len();
240
241 if self.decimals > 0 && len > self.decimals as usize {
242 let (integral, fractional) = amount.split_at(len - self.decimals as usize);
243 let fractional = fractional.trim_end_matches('0');
244 if fractional.is_empty() {
245 write!(f, "{}", integral)
246 } else {
247 write!(f, "{}.{}", integral, fractional)
248 }
249 } else {
250 if self.decimals == 0 {
251 write!(f, "{}", amount)
252 } else {
253 let zeros = if len <= self.decimals as usize {
254 "0".repeat(self.decimals as usize - len)
255 } else {
256 String::new()
257 };
258 let result = format!("0.{}{}", zeros, amount);
259 let result = result.trim_end_matches('0');
260 if result.ends_with('.') {
261 write!(f, "{}", result.trim_end_matches('.'))
262 } else {
263 write!(f, "{}", result)
264 }
265 }
266 }
267 }
268}
269
270impl FromStr for TokenAmount {
271 type Err = TokenAmountError;
272
273 fn from_str(s: &str) -> Result<Self, Self::Err> {
274 let parts: Vec<&str> = s.split('.').collect();
275 let amount: u128;
276 let mut decimals: u8 = 0;
277
278 if parts.len() == 1 {
279 amount = parts[0]
281 .parse::<u128>()
282 .map_err(|e| TokenAmountError::InvalidAmount(e.to_string()))?;
283 } else if parts.len() == 2 {
284 decimals = parts[1].len() as u8;
286 let whole = parts.join("");
287 amount = whole
288 .parse::<u128>()
289 .map_err(|e| TokenAmountError::InvalidAmount(e.to_string()))?;
290 } else {
291 return Err(TokenAmountError::ToManyDecimals);
292 }
293
294 Ok(Self { amount, decimals })
295 }
296}