cashu/
amount.rs

//! CDK Amount
//!
//! An [`Amount`] carries no unit of its own; it is interpreted in the unit of the wallet.
4
5use std::cmp::Ordering;
6use std::fmt;
7use std::str::FromStr;
8
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use thiserror::Error;
11
12use crate::nuts::CurrencyUnit;
13
14/// Amount Error
15#[derive(Debug, Error)]
16pub enum Error {
17    /// Split Values must be less then or equal to amount
18    #[error("Split Values must be less then or equal to amount")]
19    SplitValuesGreater,
20    /// Amount overflow
21    #[error("Amount Overflow")]
22    AmountOverflow,
23    /// Cannot convert units
24    #[error("Cannot convert units")]
25    CannotConvertUnits,
26}
27
28/// Amount can be any unit
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
30#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
31#[serde(transparent)]
32pub struct Amount(u64);
33
34impl Amount {
35    /// Amount zero
36    pub const ZERO: Amount = Amount(0);
37
38    /// Split into parts that are powers of two
39    pub fn split(&self) -> Vec<Self> {
40        let sats = self.0;
41        (0_u64..64)
42            .rev()
43            .filter_map(|bit| {
44                let part = 1 << bit;
45                ((sats & part) == part).then_some(Self::from(part))
46            })
47            .collect()
48    }
49
50    /// Split into parts that are powers of two by target
51    pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
52        let mut parts = match target {
53            SplitTarget::None => self.split(),
54            SplitTarget::Value(amount) => {
55                if self.le(amount) {
56                    return Ok(self.split());
57                }
58
59                let mut parts_total = Amount::ZERO;
60                let mut parts = Vec::new();
61
62                // The powers of two that are need to create target value
63                let parts_of_value = amount.split();
64
65                while parts_total.lt(self) {
66                    for part in parts_of_value.iter().copied() {
67                        if (part + parts_total).le(self) {
68                            parts.push(part);
69                        } else {
70                            let amount_left = *self - parts_total;
71                            parts.extend(amount_left.split());
72                        }
73
74                        parts_total = Amount::try_sum(parts.clone().iter().copied())?;
75
76                        if parts_total.eq(self) {
77                            break;
78                        }
79                    }
80                }
81
82                parts
83            }
84            SplitTarget::Values(values) => {
85                let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
86
87                match self.cmp(&values_total) {
88                    Ordering::Equal => values.clone(),
89                    Ordering::Less => {
90                        return Err(Error::SplitValuesGreater);
91                    }
92                    Ordering::Greater => {
93                        let extra = *self - values_total;
94                        let mut extra_amount = extra.split();
95                        let mut values = values.clone();
96
97                        values.append(&mut extra_amount);
98                        values
99                    }
100                }
101            }
102        };
103
104        parts.sort();
105        Ok(parts)
106    }
107
108    /// Checked addition for Amount. Returns None if overflow occurs.
109    pub fn checked_add(self, other: Amount) -> Option<Amount> {
110        self.0.checked_add(other.0).map(Amount)
111    }
112
113    /// Checked subtraction for Amount. Returns None if overflow occurs.
114    pub fn checked_sub(self, other: Amount) -> Option<Amount> {
115        self.0.checked_sub(other.0).map(Amount)
116    }
117
118    /// Try sum to check for overflow
119    pub fn try_sum<I>(iter: I) -> Result<Self, Error>
120    where
121        I: IntoIterator<Item = Self>,
122    {
123        iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
124            acc.checked_add(x).ok_or(Error::AmountOverflow)
125        })
126    }
127}
128
129impl Default for Amount {
130    fn default() -> Self {
131        Amount::ZERO
132    }
133}
134
135impl Default for &Amount {
136    fn default() -> Self {
137        &Amount::ZERO
138    }
139}
140
141impl fmt::Display for Amount {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        if let Some(width) = f.width() {
144            write!(f, "{:width$}", self.0, width = width)
145        } else {
146            write!(f, "{}", self.0)
147        }
148    }
149}
150
151impl From<u64> for Amount {
152    fn from(value: u64) -> Self {
153        Self(value)
154    }
155}
156
157impl From<&u64> for Amount {
158    fn from(value: &u64) -> Self {
159        Self(*value)
160    }
161}
162
163impl From<Amount> for u64 {
164    fn from(value: Amount) -> Self {
165        value.0
166    }
167}
168
169impl AsRef<u64> for Amount {
170    fn as_ref(&self) -> &u64 {
171        &self.0
172    }
173}
174
175impl std::ops::Add for Amount {
176    type Output = Amount;
177
178    fn add(self, rhs: Amount) -> Self::Output {
179        Amount(self.0.checked_add(rhs.0).expect("Addition error"))
180    }
181}
182
183impl std::ops::AddAssign for Amount {
184    fn add_assign(&mut self, rhs: Self) {
185        self.0 = self.0.checked_add(rhs.0).expect("Addition error");
186    }
187}
188
189impl std::ops::Sub for Amount {
190    type Output = Amount;
191
192    fn sub(self, rhs: Amount) -> Self::Output {
193        Amount(self.0 - rhs.0)
194    }
195}
196
197impl std::ops::SubAssign for Amount {
198    fn sub_assign(&mut self, other: Self) {
199        self.0 -= other.0;
200    }
201}
202
203impl std::ops::Mul for Amount {
204    type Output = Self;
205
206    fn mul(self, other: Self) -> Self::Output {
207        Amount(self.0 * other.0)
208    }
209}
210
211impl std::ops::Div for Amount {
212    type Output = Self;
213
214    fn div(self, other: Self) -> Self::Output {
215        Amount(self.0 / other.0)
216    }
217}
218
219/// String wrapper for an [Amount].
220///
221/// It ser-/deserializes the inner [Amount] to a string, while at the same time using the [u64]
222/// value of the [Amount] for comparison and ordering. This helps automatically sort the keys of
223/// a [BTreeMap] when [AmountStr] is used as key.
224#[derive(Debug, Clone, PartialEq, Eq)]
225pub struct AmountStr(Amount);
226
227impl AmountStr {
228    pub(crate) fn from(amt: Amount) -> Self {
229        Self(amt)
230    }
231}
232
233impl PartialOrd<Self> for AmountStr {
234    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
235        Some(self.cmp(other))
236    }
237}
238
239impl Ord for AmountStr {
240    fn cmp(&self, other: &Self) -> Ordering {
241        self.0.cmp(&other.0)
242    }
243}
244
245impl<'de> Deserialize<'de> for AmountStr {
246    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
247    where
248        D: Deserializer<'de>,
249    {
250        let s = String::deserialize(deserializer)?;
251        u64::from_str(&s)
252            .map(Amount)
253            .map(Self)
254            .map_err(serde::de::Error::custom)
255    }
256}
257
258impl Serialize for AmountStr {
259    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
260    where
261        S: Serializer,
262    {
263        serializer.serialize_str(&self.0.to_string())
264    }
265}
266
267/// Kinds of targeting that are supported
268#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
269pub enum SplitTarget {
270    /// Default target; least amount of proofs
271    #[default]
272    None,
273    /// Target amount for wallet to have most proofs that add up to value
274    Value(Amount),
275    /// Specific amounts to split into **MUST** equal amount being split
276    Values(Vec<Amount>),
277}
278
/// Number of millisatoshis in one satoshi
pub const MSAT_IN_SAT: u64 = 1_000;
281
282/// Helper function to convert units
283pub fn to_unit<T>(
284    amount: T,
285    current_unit: &CurrencyUnit,
286    target_unit: &CurrencyUnit,
287) -> Result<Amount, Error>
288where
289    T: Into<u64>,
290{
291    let amount = amount.into();
292    match (current_unit, target_unit) {
293        (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
294        (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
295        (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
296        (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
297        (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
298        (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
299        _ => Err(Error::CannotConvertUnits),
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Vec<Amount>` from raw `u64` values.
    fn amounts(raw: &[u64]) -> Vec<Amount> {
        raw.iter().copied().map(Amount::from).collect()
    }

    #[test]
    fn test_split_amount() {
        let cases: [(u64, &[u64]); 5] = [
            (1, &[1]),
            (2, &[2]),
            (3, &[2, 1]),
            (11, &[8, 2, 1]),
            (255, &[128, 64, 32, 16, 8, 4, 2, 1]),
        ];

        for &(value, expected) in cases.iter() {
            assert_eq!(Amount::from(value).split(), amounts(expected));
        }
    }

    #[test]
    fn test_split_target_amount() {
        let split = Amount(65)
            .split_targeted(&SplitTarget::Value(Amount(32)))
            .unwrap();
        assert_eq!(amounts(&[1, 32, 32]), split);

        let split = Amount(150)
            .split_targeted(&SplitTarget::Value(Amount::from(50)))
            .unwrap();
        assert_eq!(amounts(&[2, 2, 2, 16, 16, 16, 32, 32, 32]), split);

        let split = Amount::from(63)
            .split_targeted(&SplitTarget::Value(Amount::from(32)))
            .unwrap();
        assert_eq!(amounts(&[1, 2, 4, 8, 16, 32]), split);
    }

    #[test]
    fn test_split_values() {
        let amount = Amount(10);

        // Values that already sum to the amount come back unchanged.
        let target = amounts(&[2, 4, 4]);
        let values = amount
            .split_targeted(&SplitTarget::Values(target.clone()))
            .unwrap();
        assert_eq!(target, values);

        // A shortfall is topped up with powers of two.
        let values = amount
            .split_targeted(&SplitTarget::Values(amounts(&[2, 4])))
            .unwrap();
        assert_eq!(amounts(&[2, 4, 4]), values);

        // Values summing to more than the amount are rejected.
        let values = amount.split_targeted(&SplitTarget::Values(amounts(&[2, 10])));
        assert!(values.is_err())
    }

    #[test]
    #[should_panic]
    fn test_amount_addition() {
        let inputs = vec![Amount::from(u64::MAX), Amount::from(1)];

        let _total: Amount = Amount::try_sum(inputs).unwrap();
    }

    #[test]
    fn test_try_amount_addition() {
        let overflowing = vec![Amount::from(u64::MAX), Amount::from(1)];
        assert!(Amount::try_sum(overflowing).is_err());

        let total = Amount::try_sum(amounts(&[10000, 1])).unwrap();
        assert_eq!(total, 10001.into());
    }

    #[test]
    fn test_amount_to_unit() {
        let convertible = [
            (1000, CurrencyUnit::Sat, CurrencyUnit::Msat, 1000000),
            (1000, CurrencyUnit::Msat, CurrencyUnit::Sat, 1),
            (1, CurrencyUnit::Usd, CurrencyUnit::Usd, 1),
            (1, CurrencyUnit::Eur, CurrencyUnit::Eur, 1),
        ];

        for (value, current_unit, target_unit, expected) in convertible.iter() {
            let converted = to_unit(Amount::from(*value), current_unit, target_unit).unwrap();
            assert_eq!(converted, Amount::from(*expected));
        }

        let converted = to_unit(Amount::from(1), &CurrencyUnit::Sat, &CurrencyUnit::Eur);
        assert!(converted.is_err());
    }
}