cashu/
amount.rs

1//! CDK Amount
2//!
//! An amount can be in any unit; it is treated as the unit of the wallet.
4
5use std::cmp::Ordering;
6use std::fmt;
7use std::str::FromStr;
8
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
12use crate::nuts::CurrencyUnit;
13
/// Amount Error
///
/// Errors produced by [`Amount`] splitting, summing, parsing and unit conversion.
#[derive(Debug, Error)]
pub enum Error {
    /// Split values must be less than or equal to the amount: the explicit
    /// split values add up to more than the amount being split.
    #[error("Split Values must be less then or equal to amount")]
    SplitValuesGreater,
    /// Amount overflow: a checked arithmetic operation exceeded the range of `u64`.
    #[error("Amount Overflow")]
    AmountOverflow,
    /// Cannot convert units: no conversion is defined for the requested unit pair.
    #[error("Cannot convert units")]
    CannotConvertUnits,
    /// Invalid amount: the input could not be parsed; carries the offending string.
    #[error("Invalid Amount: {0}")]
    InvalidAmount(String),
}
30
/// Amount can be any unit
///
/// A unit-agnostic quantity backed by a `u64`. The value carries no unit of its
/// own; it is interpreted in whatever unit applies (see [`to_unit`] for the
/// supported conversions). Ordering and equality are those of the inner integer,
/// and `#[serde(transparent)]` makes it serialize as the bare integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
#[serde(transparent)]
pub struct Amount(u64);
36
37impl FromStr for Amount {
38    type Err = Error;
39
40    fn from_str(s: &str) -> Result<Self, Self::Err> {
41        let value = s
42            .parse::<u64>()
43            .map_err(|_| Error::InvalidAmount(s.to_owned()))?;
44        Ok(Amount(value))
45    }
46}
47
48impl Amount {
49    /// Amount zero
50    pub const ZERO: Amount = Amount(0);
51
52    // Amount one
53    pub const ONE: Amount = Amount(1);
54
55    /// Split into parts that are powers of two
56    pub fn split(&self) -> Vec<Self> {
57        let sats = self.0;
58        (0_u64..64)
59            .rev()
60            .filter_map(|bit| {
61                let part = 1 << bit;
62                ((sats & part) == part).then_some(Self::from(part))
63            })
64            .collect()
65    }
66
67    /// Split into parts that are powers of two by target
68    pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
69        let mut parts = match target {
70            SplitTarget::None => self.split(),
71            SplitTarget::Value(amount) => {
72                if self.le(amount) {
73                    return Ok(self.split());
74                }
75
76                let mut parts_total = Amount::ZERO;
77                let mut parts = Vec::new();
78
79                // The powers of two that are need to create target value
80                let parts_of_value = amount.split();
81
82                while parts_total.lt(self) {
83                    for part in parts_of_value.iter().copied() {
84                        if (part + parts_total).le(self) {
85                            parts.push(part);
86                        } else {
87                            let amount_left = *self - parts_total;
88                            parts.extend(amount_left.split());
89                        }
90
91                        parts_total = Amount::try_sum(parts.clone().iter().copied())?;
92
93                        if parts_total.eq(self) {
94                            break;
95                        }
96                    }
97                }
98
99                parts
100            }
101            SplitTarget::Values(values) => {
102                let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
103
104                match self.cmp(&values_total) {
105                    Ordering::Equal => values.clone(),
106                    Ordering::Less => {
107                        return Err(Error::SplitValuesGreater);
108                    }
109                    Ordering::Greater => {
110                        let extra = *self - values_total;
111                        let mut extra_amount = extra.split();
112                        let mut values = values.clone();
113
114                        values.append(&mut extra_amount);
115                        values
116                    }
117                }
118            }
119        };
120
121        parts.sort();
122        Ok(parts)
123    }
124
125    /// Splits amount into powers of two while accounting for the swap fee
126    pub fn split_with_fee(&self, fee_ppk: u64) -> Result<Vec<Self>, Error> {
127        let without_fee_amounts = self.split();
128        let fee_ppk = fee_ppk * without_fee_amounts.len() as u64;
129        let fee = Amount::from((fee_ppk + 999) / 1000);
130        let new_amount = self.checked_add(fee).ok_or(Error::AmountOverflow)?;
131
132        let split = new_amount.split();
133        let split_fee_ppk = split.len() as u64 * fee_ppk;
134        let split_fee = Amount::from((split_fee_ppk + 999) / 1000);
135
136        if let Some(net_amount) = new_amount.checked_sub(split_fee) {
137            if net_amount >= *self {
138                return Ok(split);
139            }
140        }
141        self.checked_add(Amount::ONE)
142            .ok_or(Error::AmountOverflow)?
143            .split_with_fee(fee_ppk)
144    }
145
146    /// Checked addition for Amount. Returns None if overflow occurs.
147    pub fn checked_add(self, other: Amount) -> Option<Amount> {
148        self.0.checked_add(other.0).map(Amount)
149    }
150
151    /// Checked subtraction for Amount. Returns None if overflow occurs.
152    pub fn checked_sub(self, other: Amount) -> Option<Amount> {
153        self.0.checked_sub(other.0).map(Amount)
154    }
155
156    /// Checked multiplication for Amount. Returns None if overflow occurs.
157    pub fn checked_mul(self, other: Amount) -> Option<Amount> {
158        self.0.checked_mul(other.0).map(Amount)
159    }
160
161    /// Checked division for Amount. Returns None if overflow occurs.
162    pub fn checked_div(self, other: Amount) -> Option<Amount> {
163        self.0.checked_div(other.0).map(Amount)
164    }
165
166    /// Try sum to check for overflow
167    pub fn try_sum<I>(iter: I) -> Result<Self, Error>
168    where
169        I: IntoIterator<Item = Self>,
170    {
171        iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
172            acc.checked_add(x).ok_or(Error::AmountOverflow)
173        })
174    }
175}
176
177impl Default for Amount {
178    fn default() -> Self {
179        Amount::ZERO
180    }
181}
182
183impl Default for &Amount {
184    fn default() -> Self {
185        &Amount::ZERO
186    }
187}
188
189impl fmt::Display for Amount {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        if let Some(width) = f.width() {
192            write!(f, "{:width$}", self.0, width = width)
193        } else {
194            write!(f, "{}", self.0)
195        }
196    }
197}
198
199impl From<u64> for Amount {
200    fn from(value: u64) -> Self {
201        Self(value)
202    }
203}
204
205impl From<&u64> for Amount {
206    fn from(value: &u64) -> Self {
207        Self(*value)
208    }
209}
210
211impl From<Amount> for u64 {
212    fn from(value: Amount) -> Self {
213        value.0
214    }
215}
216
217impl AsRef<u64> for Amount {
218    fn as_ref(&self) -> &u64 {
219        &self.0
220    }
221}
222
223impl std::ops::Add for Amount {
224    type Output = Amount;
225
226    fn add(self, rhs: Amount) -> Self::Output {
227        Amount(self.0.checked_add(rhs.0).expect("Addition error"))
228    }
229}
230
231impl std::ops::AddAssign for Amount {
232    fn add_assign(&mut self, rhs: Self) {
233        self.0 = self.0.checked_add(rhs.0).expect("Addition error");
234    }
235}
236
237impl std::ops::Sub for Amount {
238    type Output = Amount;
239
240    fn sub(self, rhs: Amount) -> Self::Output {
241        Amount(self.0 - rhs.0)
242    }
243}
244
245impl std::ops::SubAssign for Amount {
246    fn sub_assign(&mut self, other: Self) {
247        self.0 -= other.0;
248    }
249}
250
251impl std::ops::Mul for Amount {
252    type Output = Self;
253
254    fn mul(self, other: Self) -> Self::Output {
255        Amount(self.0 * other.0)
256    }
257}
258
259impl std::ops::Div for Amount {
260    type Output = Self;
261
262    fn div(self, other: Self) -> Self::Output {
263        Amount(self.0 / other.0)
264    }
265}
266
267/// Kinds of targeting that are supported
268#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
269pub enum SplitTarget {
270    /// Default target; least amount of proofs
271    #[default]
272    None,
273    /// Target amount for wallet to have most proofs that add up to value
274    Value(Amount),
275    /// Specific amounts to split into **MUST** equal amount being split
276    Values(Vec<Amount>),
277}
278
279/// Msats in sat
280pub const MSAT_IN_SAT: u64 = 1000;
281
282/// Helper function to convert units
283pub fn to_unit<T>(
284    amount: T,
285    current_unit: &CurrencyUnit,
286    target_unit: &CurrencyUnit,
287) -> Result<Amount, Error>
288where
289    T: Into<u64>,
290{
291    let amount = amount.into();
292    match (current_unit, target_unit) {
293        (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
294        (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
295        (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
296        (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
297        (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
298        (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
299        _ => Err(Error::CannotConvertUnits),
300    }
301}
302
303#[cfg(test)]
304mod tests {
305    use super::*;
306
307    #[test]
308    fn test_split_amount() {
309        assert_eq!(Amount::from(1).split(), vec![Amount::from(1)]);
310        assert_eq!(Amount::from(2).split(), vec![Amount::from(2)]);
311        assert_eq!(
312            Amount::from(3).split(),
313            vec![Amount::from(2), Amount::from(1)]
314        );
315        let amounts: Vec<Amount> = [8, 2, 1].iter().map(|a| Amount::from(*a)).collect();
316        assert_eq!(Amount::from(11).split(), amounts);
317        let amounts: Vec<Amount> = [128, 64, 32, 16, 8, 4, 2, 1]
318            .iter()
319            .map(|a| Amount::from(*a))
320            .collect();
321        assert_eq!(Amount::from(255).split(), amounts);
322    }
323
324    #[test]
325    fn test_split_target_amount() {
326        let amount = Amount(65);
327
328        let split = amount
329            .split_targeted(&SplitTarget::Value(Amount(32)))
330            .unwrap();
331        assert_eq!(vec![Amount(1), Amount(32), Amount(32)], split);
332
333        let amount = Amount(150);
334
335        let split = amount
336            .split_targeted(&SplitTarget::Value(Amount::from(50)))
337            .unwrap();
338        assert_eq!(
339            vec![
340                Amount(2),
341                Amount(2),
342                Amount(2),
343                Amount(16),
344                Amount(16),
345                Amount(16),
346                Amount(32),
347                Amount(32),
348                Amount(32)
349            ],
350            split
351        );
352
353        let amount = Amount::from(63);
354
355        let split = amount
356            .split_targeted(&SplitTarget::Value(Amount::from(32)))
357            .unwrap();
358        assert_eq!(
359            vec![
360                Amount(1),
361                Amount(2),
362                Amount(4),
363                Amount(8),
364                Amount(16),
365                Amount(32)
366            ],
367            split
368        );
369    }
370
371    #[test]
372    fn test_split_with_fee() {
373        let amount = Amount(2);
374        let fee_ppk = 1;
375
376        let split = amount.split_with_fee(fee_ppk).unwrap();
377        assert_eq!(split, vec![Amount(2), Amount(1)]);
378
379        let amount = Amount(3);
380        let fee_ppk = 1;
381
382        let split = amount.split_with_fee(fee_ppk).unwrap();
383        assert_eq!(split, vec![Amount(4)]);
384
385        let amount = Amount(3);
386        let fee_ppk = 1000;
387
388        let split = amount.split_with_fee(fee_ppk).unwrap();
389        assert_eq!(split, vec![Amount(32)]);
390    }
391
392    #[test]
393    fn test_split_values() {
394        let amount = Amount(10);
395
396        let target = vec![Amount(2), Amount(4), Amount(4)];
397
398        let split_target = SplitTarget::Values(target.clone());
399
400        let values = amount.split_targeted(&split_target).unwrap();
401
402        assert_eq!(target, values);
403
404        let target = vec![Amount(2), Amount(4), Amount(4)];
405
406        let split_target = SplitTarget::Values(vec![Amount(2), Amount(4)]);
407
408        let values = amount.split_targeted(&split_target).unwrap();
409
410        assert_eq!(target, values);
411
412        let split_target = SplitTarget::Values(vec![Amount(2), Amount(10)]);
413
414        let values = amount.split_targeted(&split_target);
415
416        assert!(values.is_err())
417    }
418
419    #[test]
420    #[should_panic]
421    fn test_amount_addition() {
422        let amount_one: Amount = u64::MAX.into();
423        let amount_two: Amount = 1.into();
424
425        let amounts = vec![amount_one, amount_two];
426
427        let _total: Amount = Amount::try_sum(amounts).unwrap();
428    }
429
430    #[test]
431    fn test_try_amount_addition() {
432        let amount_one: Amount = u64::MAX.into();
433        let amount_two: Amount = 1.into();
434
435        let amounts = vec![amount_one, amount_two];
436
437        let total = Amount::try_sum(amounts);
438
439        assert!(total.is_err());
440        let amount_one: Amount = 10000.into();
441        let amount_two: Amount = 1.into();
442
443        let amounts = vec![amount_one, amount_two];
444        let total = Amount::try_sum(amounts).unwrap();
445
446        assert_eq!(total, 10001.into());
447    }
448
449    #[test]
450    fn test_amount_to_unit() {
451        let amount = Amount::from(1000);
452        let current_unit = CurrencyUnit::Sat;
453        let target_unit = CurrencyUnit::Msat;
454
455        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
456
457        assert_eq!(converted, 1000000.into());
458
459        let amount = Amount::from(1000);
460        let current_unit = CurrencyUnit::Msat;
461        let target_unit = CurrencyUnit::Sat;
462
463        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
464
465        assert_eq!(converted, 1.into());
466
467        let amount = Amount::from(1);
468        let current_unit = CurrencyUnit::Usd;
469        let target_unit = CurrencyUnit::Usd;
470
471        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
472
473        assert_eq!(converted, 1.into());
474
475        let amount = Amount::from(1);
476        let current_unit = CurrencyUnit::Eur;
477        let target_unit = CurrencyUnit::Eur;
478
479        let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
480
481        assert_eq!(converted, 1.into());
482
483        let amount = Amount::from(1);
484        let current_unit = CurrencyUnit::Sat;
485        let target_unit = CurrencyUnit::Eur;
486
487        let converted = to_unit(amount, &current_unit, &target_unit);
488
489        assert!(converted.is_err());
490    }
491}