1#![allow(clippy::inconsistent_digit_grouping)]
59
60use std::{
61 fmt::{self, Display},
62 iter::Sum,
63 ops::{Add, AddAssign, Div, Mul, Sub},
64 str::FromStr,
65};
66
67use anyhow::format_err;
68use rust_decimal::{Decimal, RoundingStrategy, prelude::ToPrimitive};
69use serde::{Deserialize, Deserializer, Serialize};
70
71use crate::dec;
72
/// Shorthand macros for building `Amount`s from literals.
///
/// Each macro is `#[macro_export]`ed, so it is reachable from the crate root.
#[macro_use]
mod amount_macros {
    /// Builds an `Amount` from a BTC-denominated decimal literal.
    ///
    /// Expands to `Amount::try_from_btc(dec!(..)).unwrap()`, so it panics if
    /// `try_from_btc` returns an error.
    #[macro_export]
    macro_rules! btc {
        ($amount:expr) => {
            $crate::ln::amount::Amount::try_from_btc($crate::dec!($amount))
                .unwrap()
        };
    }

    /// Builds an `Amount` from a satoshi value.
    ///
    /// Expands to `Amount::from_sats_u32(..)`.
    #[macro_export]
    macro_rules! sat {
        ($amount:expr) => {
            $crate::ln::amount::Amount::from_sats_u32($amount)
        };
    }

    /// Builds an `Amount` from a millisatoshi value.
    ///
    /// Expands to `Amount::from_msat(..)`.
    #[macro_export]
    macro_rules! msat {
        ($amount:expr) => {
            $crate::ln::amount::Amount::from_msat($amount)
        };
    }
}
100
/// Reasons a `Decimal` value can be rejected when constructing an [`Amount`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The candidate value had its sign bit set.
    #[error("Amount is negative")]
    Negative,
    /// The candidate value was greater than [`Amount::MAX`].
    #[error("Amount is too large")]
    TooLarge,
}
109
/// A bitcoin amount, stored internally as a `Decimal` number of satoshis.
///
/// Values built through the fallible constructors are non-negative, at most
/// [`Amount::MAX`], and rounded to 3 decimal places (whole millisatoshis).
///
/// `Serialize` is derived, while `Deserialize` is implemented by hand so that
/// incoming values pass through the same validation.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize)]
pub struct Amount(Decimal);
125
126impl Amount {
127 pub const MAX: Self =
133 Self(Decimal::from_parts(4294967295, 4294967295, 0, false, 3));
134
135 pub const ZERO: Self = Self(dec!(0));
137
138 pub const MAX_BITCOIN_SUPPLY: Self = Self(dec!(21_000_000_0000_0000));
142 pub const MAX_BITCOIN_SUPPLY_SATS_U64: u64 = 21_000_000_0000_0000;
143 pub const MAX_BITCOIN_SUPPLY_MSATS_U64: u64 = 21_000_000_0000_0000_000;
144
145 pub const INVOICE_MAX_AMOUNT_MSATS_U64: u64 = u64::MAX / 10;
149
150 #[inline]
154 pub fn from_msat(msats: u64) -> Self {
155 Self(Decimal::from(msats) / dec!(1000))
156 }
157
158 #[inline]
160 pub fn from_sats_u32(sats_u32: u32) -> Self {
161 Self::from_msat(u64::from(sats_u32) * 1000)
162 }
163
164 #[inline]
166 pub fn try_from_sats_u64(sats_u64: u64) -> Result<Self, Error> {
167 Self::try_from_sats(Decimal::from(sats_u64))
168 }
169
170 #[inline]
172 pub fn try_from_sats(sats: Decimal) -> Result<Self, Error> {
173 Self::try_from_inner(sats)
174 }
175
176 #[inline]
178 pub fn try_from_btc(btc: Decimal) -> Result<Self, Error> {
179 Self::try_from_inner(btc * dec!(1_0000_0000))
180 }
181
182 #[inline]
187 pub fn msat(&self) -> u64 {
188 (self.0 * dec!(1000))
189 .to_u64()
190 .expect("Amount::MAX == u64::MAX millisats")
191 }
192
193 pub fn invoice_safe_msat(&self) -> Result<u64, Error> {
196 let msat = self.msat();
197 if msat <= Self::INVOICE_MAX_AMOUNT_MSATS_U64 {
198 Ok(msat)
199 } else {
200 Err(Error::TooLarge)
201 }
202 }
203
204 #[inline]
206 pub fn sats_u64(&self) -> u64 {
207 self.sats().to_u64().expect("Msats fits => sats fits")
208 }
209
210 #[inline]
212 pub fn sats(&self) -> Decimal {
213 self.0
214 }
215
216 #[inline]
218 pub fn btc(&self) -> Decimal {
219 self.0 / dec!(1_0000_0000)
220 }
221
222 pub fn round_sat(&self) -> Self {
226 Self(self.0.round())
227 }
228
229 pub fn floor_sat(&self) -> Self {
231 Self(self.0.round_dp_with_strategy(0, RoundingStrategy::ToZero))
235 }
236
237 #[cfg(test)]
244 fn round_msat(&self) -> Self {
245 Self(self.0.round_dp(3))
246 }
247
248 #[inline]
250 pub fn abs_diff(self, other: Self) -> Amount {
251 if self >= other {
252 self - other
253 } else {
254 other - self
255 }
256 }
257
258 #[inline]
261 pub fn approx_eq(self, other: Self, epsilon: Self) -> bool {
262 self.abs_diff(other) <= epsilon
263 }
264
265 pub fn checked_add(self, rhs: Self) -> Option<Self> {
268 let inner = self.0.checked_add(rhs.0)?;
269 Self::try_from_inner(inner).ok()
270 }
271
272 pub fn checked_sub(self, rhs: Self) -> Option<Self> {
273 let inner = self.0.checked_sub(rhs.0)?;
274 Self::try_from_inner(inner).ok()
275 }
276
277 pub fn checked_mul(self, rhs: Decimal) -> Option<Self> {
279 let inner = self.0.checked_mul(rhs)?;
280 Self::try_from_inner(inner).ok()
281 }
282
283 pub fn checked_div(self, rhs: Decimal) -> Option<Self> {
285 let inner = self.0.checked_div(rhs)?;
286 Self::try_from_inner(inner).ok()
287 }
288
289 pub fn saturating_add(self, rhs: Self) -> Self {
292 Self::try_from_inner(self.0.saturating_add(rhs.0)).unwrap_or(Self::MAX)
293 }
294
295 pub fn saturating_sub(self, rhs: Self) -> Self {
296 Self::try_from_inner(self.0.saturating_sub(rhs.0)).unwrap_or(Self::ZERO)
297 }
298
299 pub fn saturating_mul(self, rhs: Decimal) -> Self {
301 Self::try_from_inner(self.0.saturating_mul(rhs)).unwrap_or(Self::MAX)
302 }
303
304 #[inline]
307 fn try_from_inner(inner: Decimal) -> Result<Self, Error> {
308 if inner.is_sign_negative() {
309 Err(Error::Negative)
310 } else if inner > Self::MAX.0 {
311 Err(Error::TooLarge)
312 } else {
313 Ok(Self(inner.round_dp(3)))
314 }
315 }
316}
317
318impl<'de> Deserialize<'de> for Amount {
319 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
320 where
321 D: Deserializer<'de>,
322 {
323 let inner: Decimal = Deserialize::deserialize(deserializer)?;
324
325 Self::try_from_inner(inner).map_err(|e| match e {
326 Error::Negative => serde::de::Error::custom("Amount was negative"),
327 Error::TooLarge => serde::de::Error::custom("Amount was too large"),
328 })
329 }
330}
331
332impl Display for Amount {
333 #[inline]
334 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335 Decimal::fmt(&self.0, f)
337 }
338}
339
340impl FromStr for Amount {
341 type Err = anyhow::Error;
342
343 fn from_str(s: &str) -> Result<Self, Self::Err> {
344 let decimal =
345 Decimal::from_str(s).map_err(|err| format_err!("{err}"))?;
346 Ok(Amount::try_from_inner(decimal)?)
347 }
348}
349
350impl From<Amount> for bitcoin::Amount {
355 #[inline]
356 fn from(amt: Amount) -> Self {
357 Self::from_sat(amt.sats().to_u64().expect("safe by construction"))
358 }
359}
360
361impl TryFrom<bitcoin::Amount> for Amount {
362 type Error = Error;
363 #[inline]
364 fn try_from(amt: bitcoin::Amount) -> Result<Self, Self::Error> {
365 Self::try_from_sats(Decimal::from(amt.to_sat()))
366 }
367}
368
369impl Add for Amount {
372 type Output = Self;
373 fn add(self, rhs: Self) -> Self::Output {
374 Self::try_from_inner(self.0 + rhs.0).expect("Overflowed")
375 }
376}
377impl AddAssign for Amount {
378 #[inline]
379 fn add_assign(&mut self, rhs: Self) {
380 *self = *self + rhs;
381 }
382}
383
384impl Sub for Amount {
385 type Output = Self;
386 fn sub(self, rhs: Self) -> Self::Output {
387 Self::try_from_inner(self.0 - rhs.0).expect("Underflowed")
388 }
389}
390
391impl Mul<Decimal> for Amount {
393 type Output = Self;
394 fn mul(self, rhs: Decimal) -> Self::Output {
395 Self::try_from_inner(self.0 * rhs).expect("Overflowed")
396 }
397}
398impl Mul<Amount> for Decimal {
400 type Output = Amount;
401 fn mul(self, rhs: Amount) -> Self::Output {
402 Amount::try_from_inner(self * rhs.0).expect("Overflowed")
403 }
404}
405
406impl Div<Decimal> for Amount {
408 type Output = Self;
409 fn div(self, rhs: Decimal) -> Self::Output {
410 Self::try_from_inner(self.0 / rhs).expect("Overflowed")
411 }
412}
413
414impl Sum for Amount {
415 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
416 iter.fold(Amount::ZERO, Self::add)
417 }
418}
419
/// `proptest` strategies for generating [`Amount`]s.
#[cfg(any(test, feature = "test-utils"))]
pub mod arb {
    use proptest::{
        arbitrary::Arbitrary,
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;

    /// Generates amounts with msat precision, anywhere from zero up to the
    /// maximum bitcoin supply.
    impl Arbitrary for Amount {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            let msat_range = 0_u64..=Amount::MAX_BITCOIN_SUPPLY_MSATS_U64;
            msat_range.prop_map(Amount::from_msat).boxed()
        }
    }

    /// Generates whole-satoshi amounts, anywhere from zero up to the maximum
    /// bitcoin supply.
    pub fn sats_amount() -> impl Strategy<Value = Amount> {
        let sats_range = 0_u64..=Amount::MAX_BITCOIN_SUPPLY_SATS_U64;
        sats_range.prop_map(|sats| Amount::try_from_sats_u64(sats).unwrap())
    }
}
448
449#[cfg(test)]
450mod test {
451 use std::str::FromStr;
452
453 use lexe_std::Apply;
454 use proptest::{
455 arbitrary::any,
456 prelude::{Strategy, TestCaseError},
457 prop_assert, prop_assert_eq, proptest,
458 };
459
460 use super::*;
461 use crate::test_utils::arbitrary;
462
463 #[test]
465 fn check_associated_constants() {
466 let max_u64_msat_in_sat = Decimal::from(u64::MAX) / dec!(1000);
468 println!("{:?}", max_u64_msat_in_sat.unpack());
469 assert_eq!(Amount::MAX, Amount(max_u64_msat_in_sat));
470
471 assert_eq!(Amount::MAX.msat(), u64::MAX);
472 assert_eq!(
473 Amount::MAX_BITCOIN_SUPPLY.sats(),
474 dec!(21_000_000) * dec!(100_000_000),
475 );
476 assert_eq!(
477 Amount::MAX_BITCOIN_SUPPLY.msat(),
478 21_000_000 * 100_000_000 * 1000,
479 );
480 }
481
482 #[test]
485 fn no_msat_u64_precision_loss() {
486 proptest!(|(msat1 in any::<u64>())| {
487 let amount = Amount::from_msat(msat1);
488 let msat2 = amount.msat();
489 prop_assert_eq!(msat1, msat2);
490 })
491 }
492
493 #[test]
495 fn sat_u32_roundtrips() {
496 proptest!(|(sat1 in any::<u32>())| {
497 let amount = Amount::from_sats_u32(sat1);
498 let sat2a = amount.sats_u64().apply(u32::try_from).unwrap();
499 let sat2b = amount.sats().to_u32().unwrap();
500 prop_assert_eq!(sat1, sat2a);
501 prop_assert_eq!(sat1, sat2b);
502 })
503 }
504
505 #[test]
512 fn no_roundtrip_inside_outside_precision_loss() {
513 proptest!(|(amount in any::<Amount>())| {
514 {
515 let roundtrip_inside =
517 Amount::try_from_sats(amount.sats()).unwrap();
518 prop_assert_eq!(amount, roundtrip_inside);
519
520 let msat_u64 = amount.msat();
523 let msat_dec = Decimal::from(msat_u64);
524 let sat_dec = msat_dec / dec!(1000);
525 let roundtrip_outside = Amount::try_from_sats(sat_dec).unwrap();
526 prop_assert_eq!(roundtrip_inside, roundtrip_outside);
527 }
528
529 {
531 let roundtrip_inside = Amount::try_from_btc(amount.btc()).unwrap();
533 prop_assert_eq!(amount, roundtrip_inside);
534
535 let msat_u64 = amount.msat();
537 let msat_dec = Decimal::from(msat_u64);
538 let btc_dec = msat_dec / dec!(100_000_000_000);
539 let roundtrip_outside = Amount::try_from_btc(btc_dec).unwrap();
540 prop_assert_eq!(roundtrip_inside, roundtrip_outside);
541 }
542 })
543 }
544
545 #[test]
547 fn amount_add_sub() {
548 proptest!(|(
549 amount1 in any::<Amount>(),
550 amount2 in any::<Amount>(),
551 )| {
552 let (greater, lesser) = if amount1 >= amount2 {
553 (amount1, amount2)
554 } else {
555 (amount2, amount1)
556 };
557
558 let diff = greater - lesser;
559 prop_assert_eq!(greater, lesser + diff);
560 prop_assert_eq!(lesser, greater - diff);
561
562 let checked_diff = greater.checked_sub(lesser).unwrap();
563 prop_assert_eq!(greater, lesser.checked_add(checked_diff).unwrap());
564 prop_assert_eq!(lesser, greater.checked_sub(checked_diff).unwrap());
565
566 if greater > lesser {
567 prop_assert!(lesser.checked_sub(greater).is_none());
568 prop_assert!(Amount::MAX.checked_add(greater).is_none());
569 }
570
571 prop_assert!(amount1.abs_diff(amount2) >= Amount::ZERO);
573 })
574 }
575
576 #[test]
578 fn amount_mul_div() {
579 proptest!(|(start in any::<Amount>())| {
580 let amount1 = Amount(start.0.round_dp(2));
583
584 let intermediate_a = amount1 / dec!(10);
585 let intermediate_b = amount1.checked_div(dec!(10)).unwrap();
586 prop_assert_eq!(intermediate_a, intermediate_b);
587
588 let amount2_a = dec!(10) * intermediate_a;
589 let amount2_b = intermediate_a * dec!(10);
590 let amount2_c = intermediate_a.checked_mul(dec!(10)).unwrap();
591 prop_assert_eq!(amount1, amount2_a);
592 prop_assert_eq!(amount1, amount2_b);
593 prop_assert_eq!(amount1, amount2_c);
594 })
595 }
596
597 fn any_bounded_decimal() -> impl Strategy<Value = Decimal> {
600 let min_nanosat: u128 = 0;
604 let max_nanosat: u128 = u128::from(Amount::MAX.msat()) * 1_000;
605 (min_nanosat..=max_nanosat)
606 .prop_map(|nanosat| Decimal::from(nanosat) / dec!(1_000_000))
608 }
609
610 #[test]
613 fn test_bounded_decimal_strategy() {
614 proptest!(|(
615 bounded_decimal in any_bounded_decimal(),
616 )| {
617 prop_assert!(bounded_decimal >= Amount::ZERO.0);
618 prop_assert!(bounded_decimal <= Amount::MAX.0);
619 });
620 }
621
622 #[test]
624 fn amount_msat_rounding() {
625 fn assert_whole_msat(amount: Amount) -> Result<(), TestCaseError> {
626 prop_assert_eq!(amount, amount.round_msat());
627 Ok(())
628 }
629
630 proptest!(|(
631 amount in any::<Amount>(),
632
633 other_amount in any::<Amount>(),
634 unbounded_dec in arbitrary::any_decimal(),
635 bounded_dec in any_bounded_decimal(),
636 )| {
637 assert_whole_msat(amount.saturating_add(other_amount))?;
639 if let Some(added) = amount.checked_add(other_amount) {
640 assert_whole_msat(added)?;
641 assert_whole_msat(amount + other_amount)?;
642 }
643
644 assert_whole_msat(amount.saturating_sub(other_amount))?;
646 if let Some(subbed) = amount.checked_sub(other_amount) {
647 assert_whole_msat(subbed)?;
648 assert_whole_msat(amount - other_amount)?;
649 }
650
651 assert_whole_msat(amount.saturating_mul(unbounded_dec))?;
654 if let Some(mulled) = amount.checked_mul(unbounded_dec) {
655 assert_whole_msat(mulled)?;
656 assert_whole_msat(amount * unbounded_dec)?;
657 assert_whole_msat(unbounded_dec * amount)?;
658 }
659
660 if let Some(dived) = amount.checked_div(unbounded_dec) {
662 assert_whole_msat(dived)?;
663 assert_whole_msat(amount / unbounded_dec)?;
664 }
665
666 assert_whole_msat(amount.abs_diff(other_amount))?;
668
669 assert_whole_msat(Amount::try_from_inner(bounded_dec).unwrap())?;
671
672 let bounded_decimal_str = bounded_dec.to_string();
674 assert_whole_msat(Amount::from_str(&bounded_decimal_str).unwrap())?;
675
676 })
677 }
678
679 #[test]
681 fn test_floor_sat() {
682 proptest!(|(amount in any::<Amount>())| {
683 let floored = amount.floor_sat();
684 prop_assert!(floored <= amount);
685 prop_assert_eq!(floored, Amount(amount.0.floor()));
688 prop_assert_eq!(
689 floored,
690 Amount::try_from_sats(Decimal::from(amount.sats_u64())).unwrap()
691 );
692 });
693 }
694
695 #[test]
697 fn amount_round_sat_btc() {
698 fn expect_no_precision_loss(amount: Amount) {
703 assert_eq!(amount.btc(), amount.round_sat().btc());
704 }
705
706 expect_no_precision_loss(Amount::from_sats_u32(0));
707 expect_no_precision_loss(Amount::from_sats_u32(10_0000));
708 expect_no_precision_loss(Amount::from_sats_u32(10_0010_0005));
709 expect_no_precision_loss(
710 Amount::try_from_sats_u64(20_999_999_9999_9999).unwrap(),
711 );
712
713 proptest!(|(amount_u64: u64)| {
714 let amount_u64 = amount_u64 % 2_100_000_000_000_000;
716 let amount = Amount::try_from_sats_u64(amount_u64).unwrap();
717 expect_no_precision_loss(amount);
718 });
719
720 assert_eq!(Amount::from_msat(1).round_sat().btc(), Amount::ZERO.btc());
725 assert_eq!(
726 Amount::from_msat(1_001).round_sat().btc(),
727 Amount::from_sats_u32(1).btc(),
728 );
729 assert_eq!(
730 Amount::from_msat(1_501).round_sat().btc(),
731 Amount::from_sats_u32(2).btc(),
732 );
733 }
734
735 #[test]
737 fn amount_btc_str() {
738 fn parse_btc_str(input: &str) -> Option<Amount> {
739 Decimal::from_str(input)
740 .ok()
741 .and_then(|btc_decimal| Amount::try_from_btc(btc_decimal).ok())
742 }
743 fn parse_eq(input: &str, expected: Amount) {
744 assert_eq!(parse_btc_str(input).unwrap(), expected);
745 }
746 fn parse_fail(input: &str) {
747 if let Some(amount) = parse_btc_str(input) {
748 panic!(
749 "Should fail to parse BTC str: '{input}', got: {amount:?}"
750 );
751 }
752 }
753
754 parse_eq("0", Amount::ZERO);
757 parse_eq("0.", Amount::ZERO);
758 parse_eq(".0", Amount::ZERO);
759 parse_eq("0.001", Amount::from_sats_u32(10_0000));
760 parse_eq("10.00", Amount::from_sats_u32(10_0000_0000));
761 parse_eq("10.", Amount::from_sats_u32(10_0000_0000));
762 parse_eq("10", Amount::from_sats_u32(10_0000_0000));
763 parse_eq("10.00000000", Amount::from_sats_u32(10_0000_0000));
764 parse_eq("10.00001230", Amount::from_sats_u32(10_0000_1230));
765 parse_eq("10.69696969", Amount::from_sats_u32(10_6969_6969));
766 parse_eq("0.00001230", Amount::from_sats_u32(1230));
767 parse_eq("0.69696969", Amount::from_sats_u32(6969_6969));
768 parse_eq(".00001230", Amount::from_sats_u32(1230));
769 parse_eq(".69696969", Amount::from_sats_u32(6969_6969));
770 parse_eq(
771 "20000000",
772 Amount::try_from_sats_u64(20_000_000_0000_0000).unwrap(),
773 );
774 parse_eq(
775 "20999999.99999999",
776 Amount::try_from_sats_u64(20_999_999_9999_9999).unwrap(),
777 );
778
779 parse_fail(".");
782 parse_fail("asdif.");
783 parse_fail("156.(6kfjaosid");
784 parse_fail("-156");
785 parse_fail("-15.4984");
786 parse_fail("-.4");
787 parse_fail(" 0.4");
788 parse_fail("0.4 ");
789
790 proptest!(|(amount: Amount)| {
793 let amount_btc_str = amount.btc().to_string();
794 let amount_round_sat_btc_str = amount.round_sat().btc().to_string();
795 let amount_btc_str_btc = parse_btc_str(&amount_btc_str).unwrap();
796 let amount_round_sat_btc_str_btc = parse_btc_str(&amount_round_sat_btc_str).unwrap();
797 prop_assert_eq!(amount, amount_btc_str_btc);
798 prop_assert_eq!(amount.btc(), amount_btc_str_btc.btc());
799 prop_assert_eq!(amount.round_sat(), amount_round_sat_btc_str_btc);
800 prop_assert_eq!(amount.round_sat().btc(), amount_round_sat_btc_str_btc.btc());
801 });
802
803 proptest!(|(s in arbitrary::any_string())| {
806 let _ = parse_btc_str(&s);
807 });
808 }
809}