1use std::cmp::Ordering;
6use std::fmt;
7use std::str::FromStr;
8
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use thiserror::Error;
11
12use crate::nuts::CurrencyUnit;
13
14#[derive(Debug, Error)]
16pub enum Error {
17 #[error("Split Values must be less then or equal to amount")]
19 SplitValuesGreater,
20 #[error("Amount Overflow")]
22 AmountOverflow,
23 #[error("Cannot convert units")]
25 CannotConvertUnits,
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
30#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
31#[serde(transparent)]
32pub struct Amount(u64);
33
34impl Amount {
35 pub const ZERO: Amount = Amount(0);
37
38 pub fn split(&self) -> Vec<Self> {
40 let sats = self.0;
41 (0_u64..64)
42 .rev()
43 .filter_map(|bit| {
44 let part = 1 << bit;
45 ((sats & part) == part).then_some(Self::from(part))
46 })
47 .collect()
48 }
49
50 pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
52 let mut parts = match target {
53 SplitTarget::None => self.split(),
54 SplitTarget::Value(amount) => {
55 if self.le(amount) {
56 return Ok(self.split());
57 }
58
59 let mut parts_total = Amount::ZERO;
60 let mut parts = Vec::new();
61
62 let parts_of_value = amount.split();
64
65 while parts_total.lt(self) {
66 for part in parts_of_value.iter().copied() {
67 if (part + parts_total).le(self) {
68 parts.push(part);
69 } else {
70 let amount_left = *self - parts_total;
71 parts.extend(amount_left.split());
72 }
73
74 parts_total = Amount::try_sum(parts.clone().iter().copied())?;
75
76 if parts_total.eq(self) {
77 break;
78 }
79 }
80 }
81
82 parts
83 }
84 SplitTarget::Values(values) => {
85 let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
86
87 match self.cmp(&values_total) {
88 Ordering::Equal => values.clone(),
89 Ordering::Less => {
90 return Err(Error::SplitValuesGreater);
91 }
92 Ordering::Greater => {
93 let extra = *self - values_total;
94 let mut extra_amount = extra.split();
95 let mut values = values.clone();
96
97 values.append(&mut extra_amount);
98 values
99 }
100 }
101 }
102 };
103
104 parts.sort();
105 Ok(parts)
106 }
107
108 pub fn checked_add(self, other: Amount) -> Option<Amount> {
110 self.0.checked_add(other.0).map(Amount)
111 }
112
113 pub fn checked_sub(self, other: Amount) -> Option<Amount> {
115 self.0.checked_sub(other.0).map(Amount)
116 }
117
118 pub fn try_sum<I>(iter: I) -> Result<Self, Error>
120 where
121 I: IntoIterator<Item = Self>,
122 {
123 iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
124 acc.checked_add(x).ok_or(Error::AmountOverflow)
125 })
126 }
127}
128
129impl Default for Amount {
130 fn default() -> Self {
131 Amount::ZERO
132 }
133}
134
135impl Default for &Amount {
136 fn default() -> Self {
137 &Amount::ZERO
138 }
139}
140
141impl fmt::Display for Amount {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 if let Some(width) = f.width() {
144 write!(f, "{:width$}", self.0, width = width)
145 } else {
146 write!(f, "{}", self.0)
147 }
148 }
149}
150
151impl From<u64> for Amount {
152 fn from(value: u64) -> Self {
153 Self(value)
154 }
155}
156
157impl From<&u64> for Amount {
158 fn from(value: &u64) -> Self {
159 Self(*value)
160 }
161}
162
163impl From<Amount> for u64 {
164 fn from(value: Amount) -> Self {
165 value.0
166 }
167}
168
169impl AsRef<u64> for Amount {
170 fn as_ref(&self) -> &u64 {
171 &self.0
172 }
173}
174
175impl std::ops::Add for Amount {
176 type Output = Amount;
177
178 fn add(self, rhs: Amount) -> Self::Output {
179 Amount(self.0.checked_add(rhs.0).expect("Addition error"))
180 }
181}
182
183impl std::ops::AddAssign for Amount {
184 fn add_assign(&mut self, rhs: Self) {
185 self.0 = self.0.checked_add(rhs.0).expect("Addition error");
186 }
187}
188
189impl std::ops::Sub for Amount {
190 type Output = Amount;
191
192 fn sub(self, rhs: Amount) -> Self::Output {
193 Amount(self.0 - rhs.0)
194 }
195}
196
197impl std::ops::SubAssign for Amount {
198 fn sub_assign(&mut self, other: Self) {
199 self.0 -= other.0;
200 }
201}
202
203impl std::ops::Mul for Amount {
204 type Output = Self;
205
206 fn mul(self, other: Self) -> Self::Output {
207 Amount(self.0 * other.0)
208 }
209}
210
211impl std::ops::Div for Amount {
212 type Output = Self;
213
214 fn div(self, other: Self) -> Self::Output {
215 Amount(self.0 / other.0)
216 }
217}
218
219#[derive(Debug, Clone, PartialEq, Eq)]
225pub struct AmountStr(Amount);
226
227impl AmountStr {
228 pub(crate) fn from(amt: Amount) -> Self {
229 Self(amt)
230 }
231}
232
233impl PartialOrd<Self> for AmountStr {
234 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
235 Some(self.cmp(other))
236 }
237}
238
239impl Ord for AmountStr {
240 fn cmp(&self, other: &Self) -> Ordering {
241 self.0.cmp(&other.0)
242 }
243}
244
245impl<'de> Deserialize<'de> for AmountStr {
246 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
247 where
248 D: Deserializer<'de>,
249 {
250 let s = String::deserialize(deserializer)?;
251 u64::from_str(&s)
252 .map(Amount)
253 .map(Self)
254 .map_err(serde::de::Error::custom)
255 }
256}
257
258impl Serialize for AmountStr {
259 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
260 where
261 S: Serializer,
262 {
263 serializer.serialize_str(&self.0.to_string())
264 }
265}
266
267#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
269pub enum SplitTarget {
270 #[default]
272 None,
273 Value(Amount),
275 Values(Vec<Amount>),
277}
278
/// Number of millisatoshis (`Msat`) in one satoshi (`Sat`), used by [`to_unit`].
pub const MSAT_IN_SAT: u64 = 1_000;
281
282pub fn to_unit<T>(
284 amount: T,
285 current_unit: &CurrencyUnit,
286 target_unit: &CurrencyUnit,
287) -> Result<Amount, Error>
288where
289 T: Into<u64>,
290{
291 let amount = amount.into();
292 match (current_unit, target_unit) {
293 (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
294 (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
295 (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
296 (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
297 (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
298 (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
299 _ => Err(Error::CannotConvertUnits),
300 }
301}
302
#[cfg(test)]
mod tests {
    //! Unit tests for amount splitting, summing and unit conversion.

    use super::*;

    #[test]
    fn test_split_amount() {
        assert_eq!(Amount::from(1).split(), vec![Amount::from(1)]);
        assert_eq!(Amount::from(2).split(), vec![Amount::from(2)]);
        assert_eq!(
            Amount::from(3).split(),
            vec![Amount::from(2), Amount::from(1)]
        );
        let amounts: Vec<Amount> = [8, 2, 1].iter().map(|a| Amount::from(*a)).collect();
        assert_eq!(Amount::from(11).split(), amounts);
        let amounts: Vec<Amount> = [128, 64, 32, 16, 8, 4, 2, 1]
            .iter()
            .map(|a| Amount::from(*a))
            .collect();
        assert_eq!(Amount::from(255).split(), amounts);
    }

    #[test]
    fn test_split_target_amount() {
        let amount = Amount(65);

        let split = amount
            .split_targeted(&SplitTarget::Value(Amount(32)))
            .unwrap();
        assert_eq!(vec![Amount(1), Amount(32), Amount(32)], split);

        let amount = Amount(150);

        let split = amount
            .split_targeted(&SplitTarget::Value(Amount::from(50)))
            .unwrap();
        assert_eq!(
            vec![
                Amount(2),
                Amount(2),
                Amount(2),
                Amount(16),
                Amount(16),
                Amount(16),
                Amount(32),
                Amount(32),
                Amount(32)
            ],
            split
        );

        let amount = Amount::from(63);

        let split = amount
            .split_targeted(&SplitTarget::Value(Amount::from(32)))
            .unwrap();
        assert_eq!(
            vec![
                Amount(1),
                Amount(2),
                Amount(4),
                Amount(8),
                Amount(16),
                Amount(32)
            ],
            split
        );
    }

    #[test]
    fn test_split_values() {
        let amount = Amount(10);

        let target = vec![Amount(2), Amount(4), Amount(4)];

        let split_target = SplitTarget::Values(target.clone());

        let values = amount.split_targeted(&split_target).unwrap();

        assert_eq!(target, values);

        let target = vec![Amount(2), Amount(4), Amount(4)];

        let split_target = SplitTarget::Values(vec![Amount(2), Amount(4)]);

        let values = amount.split_targeted(&split_target).unwrap();

        assert_eq!(target, values);

        let split_target = SplitTarget::Values(vec![Amount(2), Amount(10)]);

        let values = amount.split_targeted(&split_target);

        assert!(values.is_err())
    }

    #[test]
    #[should_panic]
    fn test_amount_addition() {
        let amount_one: Amount = u64::MAX.into();
        let amount_two: Amount = 1.into();

        let amounts = vec![amount_one, amount_two];

        let _total: Amount = Amount::try_sum(amounts).unwrap();
    }

    #[test]
    fn test_try_amount_addition() {
        let amount_one: Amount = u64::MAX.into();
        let amount_two: Amount = 1.into();

        let amounts = vec![amount_one, amount_two];

        let total = Amount::try_sum(amounts);

        assert!(total.is_err());
        let amount_one: Amount = 10000.into();
        let amount_two: Amount = 1.into();

        let amounts = vec![amount_one, amount_two];
        let total = Amount::try_sum(amounts).unwrap();

        assert_eq!(total, 10001.into());
    }

    #[test]
    fn test_amount_to_unit() {
        let amount = Amount::from(1000);
        let current_unit = CurrencyUnit::Sat;
        let target_unit = CurrencyUnit::Msat;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1000000.into());

        let amount = Amount::from(1000);
        let current_unit = CurrencyUnit::Msat;
        let target_unit = CurrencyUnit::Sat;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1.into());

        let amount = Amount::from(1);
        let current_unit = CurrencyUnit::Usd;
        let target_unit = CurrencyUnit::Usd;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1.into());

        let amount = Amount::from(1);
        let current_unit = CurrencyUnit::Eur;
        let target_unit = CurrencyUnit::Eur;

        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();

        assert_eq!(converted, 1.into());

        let amount = Amount::from(1);
        let current_unit = CurrencyUnit::Sat;
        let target_unit = CurrencyUnit::Eur;

        let converted = to_unit(amount, &current_unit, &target_unit);

        assert!(converted.is_err());
    }
}