1use std::cmp::Ordering;
6use std::fmt;
7use std::str::FromStr;
8
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
12use crate::nuts::CurrencyUnit;
13
/// Errors produced by [`Amount`] parsing, arithmetic, splitting and unit conversion.
#[derive(Debug, Error)]
pub enum Error {
    /// The values requested via `SplitTarget::Values` sum to more than the amount being split.
    #[error("Split Values must be less then or equal to amount")]
    SplitValuesGreater,
    /// A checked arithmetic operation on amounts overflowed.
    #[error("Amount Overflow")]
    AmountOverflow,
    /// `to_unit` has no conversion between the given pair of currency units.
    #[error("Cannot convert units")]
    CannotConvertUnits,
    /// A string could not be parsed as an amount; carries the rejected input.
    #[error("Invalid Amount: {0}")]
    InvalidAmount(String),
}
30
/// An amount, stored as a bare `u64`.
///
/// The currency unit is not part of the value. Thanks to `#[serde(transparent)]`
/// an `Amount` (de)serializes exactly like the wrapped integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
#[serde(transparent)]
pub struct Amount(u64);
36
37impl FromStr for Amount {
38 type Err = Error;
39
40 fn from_str(s: &str) -> Result<Self, Self::Err> {
41 let value = s
42 .parse::<u64>()
43 .map_err(|_| Error::InvalidAmount(s.to_owned()))?;
44 Ok(Amount(value))
45 }
46}
47
48impl Amount {
49 pub const ZERO: Amount = Amount(0);
51
52 pub const ONE: Amount = Amount(1);
54
55 pub fn split(&self) -> Vec<Self> {
57 let sats = self.0;
58 (0_u64..64)
59 .rev()
60 .filter_map(|bit| {
61 let part = 1 << bit;
62 ((sats & part) == part).then_some(Self::from(part))
63 })
64 .collect()
65 }
66
67 pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
69 let mut parts = match target {
70 SplitTarget::None => self.split(),
71 SplitTarget::Value(amount) => {
72 if self.le(amount) {
73 return Ok(self.split());
74 }
75
76 let mut parts_total = Amount::ZERO;
77 let mut parts = Vec::new();
78
79 let parts_of_value = amount.split();
81
82 while parts_total.lt(self) {
83 for part in parts_of_value.iter().copied() {
84 if (part + parts_total).le(self) {
85 parts.push(part);
86 } else {
87 let amount_left = *self - parts_total;
88 parts.extend(amount_left.split());
89 }
90
91 parts_total = Amount::try_sum(parts.clone().iter().copied())?;
92
93 if parts_total.eq(self) {
94 break;
95 }
96 }
97 }
98
99 parts
100 }
101 SplitTarget::Values(values) => {
102 let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
103
104 match self.cmp(&values_total) {
105 Ordering::Equal => values.clone(),
106 Ordering::Less => {
107 return Err(Error::SplitValuesGreater);
108 }
109 Ordering::Greater => {
110 let extra = *self - values_total;
111 let mut extra_amount = extra.split();
112 let mut values = values.clone();
113
114 values.append(&mut extra_amount);
115 values
116 }
117 }
118 }
119 };
120
121 parts.sort();
122 Ok(parts)
123 }
124
125 pub fn split_with_fee(&self, fee_ppk: u64) -> Result<Vec<Self>, Error> {
127 let without_fee_amounts = self.split();
128 let fee_ppk = fee_ppk * without_fee_amounts.len() as u64;
129 let fee = Amount::from((fee_ppk + 999) / 1000);
130 let new_amount = self.checked_add(fee).ok_or(Error::AmountOverflow)?;
131
132 let split = new_amount.split();
133 let split_fee_ppk = split.len() as u64 * fee_ppk;
134 let split_fee = Amount::from((split_fee_ppk + 999) / 1000);
135
136 if let Some(net_amount) = new_amount.checked_sub(split_fee) {
137 if net_amount >= *self {
138 return Ok(split);
139 }
140 }
141 self.checked_add(Amount::ONE)
142 .ok_or(Error::AmountOverflow)?
143 .split_with_fee(fee_ppk)
144 }
145
146 pub fn checked_add(self, other: Amount) -> Option<Amount> {
148 self.0.checked_add(other.0).map(Amount)
149 }
150
151 pub fn checked_sub(self, other: Amount) -> Option<Amount> {
153 self.0.checked_sub(other.0).map(Amount)
154 }
155
156 pub fn checked_mul(self, other: Amount) -> Option<Amount> {
158 self.0.checked_mul(other.0).map(Amount)
159 }
160
161 pub fn checked_div(self, other: Amount) -> Option<Amount> {
163 self.0.checked_div(other.0).map(Amount)
164 }
165
166 pub fn try_sum<I>(iter: I) -> Result<Self, Error>
168 where
169 I: IntoIterator<Item = Self>,
170 {
171 iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
172 acc.checked_add(x).ok_or(Error::AmountOverflow)
173 })
174 }
175}
176
177impl Default for Amount {
178 fn default() -> Self {
179 Amount::ZERO
180 }
181}
182
183impl Default for &Amount {
184 fn default() -> Self {
185 &Amount::ZERO
186 }
187}
188
189impl fmt::Display for Amount {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 if let Some(width) = f.width() {
192 write!(f, "{:width$}", self.0, width = width)
193 } else {
194 write!(f, "{}", self.0)
195 }
196 }
197}
198
199impl From<u64> for Amount {
200 fn from(value: u64) -> Self {
201 Self(value)
202 }
203}
204
205impl From<&u64> for Amount {
206 fn from(value: &u64) -> Self {
207 Self(*value)
208 }
209}
210
211impl From<Amount> for u64 {
212 fn from(value: Amount) -> Self {
213 value.0
214 }
215}
216
217impl AsRef<u64> for Amount {
218 fn as_ref(&self) -> &u64 {
219 &self.0
220 }
221}
222
223impl std::ops::Add for Amount {
224 type Output = Amount;
225
226 fn add(self, rhs: Amount) -> Self::Output {
227 Amount(self.0.checked_add(rhs.0).expect("Addition error"))
228 }
229}
230
231impl std::ops::AddAssign for Amount {
232 fn add_assign(&mut self, rhs: Self) {
233 self.0 = self.0.checked_add(rhs.0).expect("Addition error");
234 }
235}
236
237impl std::ops::Sub for Amount {
238 type Output = Amount;
239
240 fn sub(self, rhs: Amount) -> Self::Output {
241 Amount(self.0 - rhs.0)
242 }
243}
244
245impl std::ops::SubAssign for Amount {
246 fn sub_assign(&mut self, other: Self) {
247 self.0 -= other.0;
248 }
249}
250
251impl std::ops::Mul for Amount {
252 type Output = Self;
253
254 fn mul(self, other: Self) -> Self::Output {
255 Amount(self.0 * other.0)
256 }
257}
258
259impl std::ops::Div for Amount {
260 type Output = Self;
261
262 fn div(self, other: Self) -> Self::Output {
263 Amount(self.0 / other.0)
264 }
265}
266
/// How `Amount::split_targeted` should break an amount into parts.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
pub enum SplitTarget {
    /// No target: use the plain power-of-two split (the default).
    #[default]
    None,
    /// Repeat the power-of-two parts of this value for as long as they fit.
    Value(Amount),
    /// Use exactly these values; any remainder is split into powers of two.
    Values(Vec<Amount>),
}
278
/// Number of msat in one sat; the conversion factor used by `to_unit`.
pub const MSAT_IN_SAT: u64 = 1000;
281
282pub fn to_unit<T>(
284 amount: T,
285 current_unit: &CurrencyUnit,
286 target_unit: &CurrencyUnit,
287) -> Result<Amount, Error>
288where
289 T: Into<u64>,
290{
291 let amount = amount.into();
292 match (current_unit, target_unit) {
293 (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
294 (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
295 (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
296 (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
297 (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
298 (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
299 _ => Err(Error::CannotConvertUnits),
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// The plain split yields one power of two per set bit, largest first.
    #[test]
    fn test_split_amount() {
        assert_eq!(Amount::from(1).split(), vec![Amount::from(1)]);
        assert_eq!(Amount::from(2).split(), vec![Amount::from(2)]);
        assert_eq!(
            Amount::from(3).split(),
            vec![Amount::from(2), Amount::from(1)]
        );
        let amounts: Vec<Amount> = [8, 2, 1].iter().map(|a| Amount::from(*a)).collect();
        assert_eq!(Amount::from(11).split(), amounts);
        let amounts: Vec<Amount> = [128, 64, 32, 16, 8, 4, 2, 1]
            .iter()
            .map(|a| Amount::from(*a))
            .collect();
        assert_eq!(Amount::from(255).split(), amounts);
    }

    /// A `SplitTarget::Value` repeats the target's parts and sorts ascending.
    #[test]
    fn test_split_target_amount() {
        let amount = Amount(65);

        let split = amount
            .split_targeted(&SplitTarget::Value(Amount(32)))
            .unwrap();
        assert_eq!(vec![Amount(1), Amount(32), Amount(32)], split);

        let amount = Amount(150);

        let split = amount
            .split_targeted(&SplitTarget::Value(Amount::from(50)))
            .unwrap();
        assert_eq!(
            vec![
                Amount(2),
                Amount(2),
                Amount(2),
                Amount(16),
                Amount(16),
                Amount(16),
                Amount(32),
                Amount(32),
                Amount(32)
            ],
            split
        );

        let amount = Amount::from(63);

        let split = amount
            .split_targeted(&SplitTarget::Value(Amount::from(32)))
            .unwrap();
        assert_eq!(
            vec![
                Amount(1),
                Amount(2),
                Amount(4),
                Amount(8),
                Amount(16),
                Amount(32)
            ],
            split
        );
    }

    /// Pins the current results of `split_with_fee` for small amounts.
    #[test]
    fn test_split_with_fee() {
        let amount = Amount(2);
        let fee_ppk = 1;

        let split = amount.split_with_fee(fee_ppk).unwrap();
        assert_eq!(split, vec![Amount(2), Amount(1)]);

        let amount = Amount(3);
        let fee_ppk = 1;

        let split = amount.split_with_fee(fee_ppk).unwrap();
        assert_eq!(split, vec![Amount(4)]);

        let amount = Amount(3);
        let fee_ppk = 1000;

        let split = amount.split_with_fee(fee_ppk).unwrap();
        assert_eq!(split, vec![Amount(32)]);
    }

    /// `SplitTarget::Values` keeps the given values, tops up any remainder and
    /// rejects values that exceed the amount.
    #[test]
    fn test_split_values() {
        let amount = Amount(10);

        let target = vec![Amount(2), Amount(4), Amount(4)];

        let split_target = SplitTarget::Values(target.clone());

        let values = amount.split_targeted(&split_target).unwrap();

        assert_eq!(target, values);

        let target = vec![Amount(2), Amount(4), Amount(4)];

        let split_target = SplitTarget::Values(vec![Amount(2), Amount(4)]);

        let values = amount.split_targeted(&split_target).unwrap();

        assert_eq!(target, values);

        let split_target = SplitTarget::Values(vec![Amount(2), Amount(10)]);

        let values = amount.split_targeted(&split_target);

        assert!(values.is_err())
    }

    /// Unwrapping an overflowing `try_sum` panics.
    #[test]
    #[should_panic]
    fn test_amount_addition() {
        let amount_one: Amount = u64::MAX.into();
        let amount_two: Amount = 1.into();

        let amounts = vec![amount_one, amount_two];

        let _total: Amount = Amount::try_sum(amounts).unwrap();
    }

    /// `try_sum` reports overflow as an error and sums normally otherwise.
    #[test]
    fn test_try_amount_addition() {
        let amount_one: Amount = u64::MAX.into();
        let amount_two: Amount = 1.into();

        let amounts = vec![amount_one, amount_two];

        let total = Amount::try_sum(amounts);

        assert!(total.is_err());
        let amount_one: Amount = 10000.into();
        let amount_two: Amount = 1.into();

        let amounts = vec![amount_one, amount_two];
        let total = Amount::try_sum(amounts).unwrap();

        assert_eq!(total, 10001.into());
    }

    /// `to_unit` converts sat <-> msat, passes identical units through and
    /// rejects unsupported pairs.
    #[test]
    fn test_amount_to_unit() {
        let amount = Amount::from(1000);
        let current_unit = CurrencyUnit::Sat;
        let target_unit = CurrencyUnit::Msat;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1000000.into());

        let amount = Amount::from(1000);
        let current_unit = CurrencyUnit::Msat;
        let target_unit = CurrencyUnit::Sat;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1.into());

        let amount = Amount::from(1);
        let current_unit = CurrencyUnit::Usd;
        let target_unit = CurrencyUnit::Usd;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1.into());

        let amount = Amount::from(1);
        let current_unit = CurrencyUnit::Eur;
        let target_unit = CurrencyUnit::Eur;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1.into());

        let amount = Amount::from(1);
        let current_unit = CurrencyUnit::Sat;
        let target_unit = CurrencyUnit::Eur;

        let converted = to_unit(amount, &current_unit, &target_unit);

        assert!(converted.is_err());
    }
}