1#![allow(clippy::inconsistent_digit_grouping)]
59
60use std::{
61 fmt::{self, Display},
62 iter::Sum,
63 ops::{Add, AddAssign, Div, Mul, Sub},
64 str::FromStr,
65};
66
67use anyhow::format_err;
68use rust_decimal::{Decimal, RoundingStrategy, prelude::ToPrimitive};
69use rust_decimal_macros::dec;
70use serde::{Deserialize, Deserializer, Serialize};
71
#[macro_use]
mod amount_macros {
    /// Builds an `Amount` from a BTC decimal literal, e.g. `btc!(0.001)`.
    ///
    /// Panics if the value cannot be represented as an `Amount`. The call
    /// site must have both `Amount` and `dec!` in scope.
    #[macro_export]
    macro_rules! btc {
        ($btc_value:expr) => {{
            let btc_decimal = dec!($btc_value);
            Amount::try_from_btc(btc_decimal).unwrap()
        }};
    }

    /// Builds an `Amount` from a `u32` number of satoshis, e.g. `sat!(1234)`.
    ///
    /// The call site must have `Amount` in scope.
    #[macro_export]
    macro_rules! sat {
        ($sats_value:expr) => {{
            Amount::from_sats_u32($sats_value)
        }};
    }

    /// Builds an `Amount` from a `u64` number of millisatoshis,
    /// e.g. `msat!(1_000)`.
    ///
    /// The call site must have `Amount` in scope.
    #[macro_export]
    macro_rules! msat {
        ($msats_value:expr) => {{
            Amount::from_msat($msats_value)
        }};
    }
}
98
99#[derive(Debug, thiserror::Error)]
101pub enum Error {
102 #[error("Amount is negative")]
103 Negative,
104 #[error("Amount is too large")]
105 TooLarge,
106}
107
108#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize)]
122pub struct Amount(Decimal);
123
124impl Amount {
125 pub const MAX: Self =
131 Self(Decimal::from_parts(4294967295, 4294967295, 0, false, 3));
132
133 pub const ZERO: Self = Self(dec!(0));
135
136 pub const MAX_BITCOIN_SUPPLY: Self = Self(dec!(21_000_000_0000_0000));
140 pub const MAX_BITCOIN_SUPPLY_SATS_U64: u64 = 21_000_000_0000_0000;
141 pub const MAX_BITCOIN_SUPPLY_MSATS_U64: u64 = 21_000_000_0000_0000_000;
142
143 pub const INVOICE_MAX_AMOUNT_MSATS_U64: u64 = u64::MAX / 10;
147
148 #[inline]
152 pub fn from_msat(msats: u64) -> Self {
153 Self(Decimal::from(msats) / dec!(1000))
154 }
155
156 #[inline]
158 pub fn from_sats_u32(sats_u32: u32) -> Self {
159 Self::from_msat(u64::from(sats_u32) * 1000)
160 }
161
162 #[inline]
164 pub fn try_from_sats_u64(sats_u64: u64) -> Result<Self, Error> {
165 Self::try_from_sats(Decimal::from(sats_u64))
166 }
167
168 #[inline]
170 pub fn try_from_sats(sats: Decimal) -> Result<Self, Error> {
171 Self::try_from_inner(sats)
172 }
173
174 #[inline]
176 pub fn try_from_btc(btc: Decimal) -> Result<Self, Error> {
177 Self::try_from_inner(btc * dec!(1_0000_0000))
178 }
179
180 #[inline]
185 pub fn msat(&self) -> u64 {
186 (self.0 * dec!(1000))
187 .to_u64()
188 .expect("Amount::MAX == u64::MAX millisats")
189 }
190
191 pub fn invoice_safe_msat(&self) -> Result<u64, Error> {
194 let msat = self.msat();
195 if msat <= Self::INVOICE_MAX_AMOUNT_MSATS_U64 {
196 Ok(msat)
197 } else {
198 Err(Error::TooLarge)
199 }
200 }
201
202 #[inline]
204 pub fn sats_u64(&self) -> u64 {
205 self.sats().to_u64().expect("Msats fits => sats fits")
206 }
207
208 #[inline]
210 pub fn sats(&self) -> Decimal {
211 self.0
212 }
213
214 #[inline]
216 pub fn btc(&self) -> Decimal {
217 self.0 / dec!(1_0000_0000)
218 }
219
220 pub fn round_sat(&self) -> Self {
224 Self(self.0.round())
225 }
226
227 pub fn floor_sat(&self) -> Self {
229 Self(self.0.round_dp_with_strategy(0, RoundingStrategy::ToZero))
233 }
234
235 #[cfg(test)]
242 fn round_msat(&self) -> Self {
243 Self(self.0.round_dp(3))
244 }
245
246 #[inline]
248 pub fn abs_diff(self, other: Self) -> Amount {
249 if self >= other {
250 self - other
251 } else {
252 other - self
253 }
254 }
255
256 #[inline]
259 pub fn approx_eq(self, other: Self, epsilon: Self) -> bool {
260 self.abs_diff(other) <= epsilon
261 }
262
263 pub fn checked_add(self, rhs: Self) -> Option<Self> {
266 let inner = self.0.checked_add(rhs.0)?;
267 Self::try_from_inner(inner).ok()
268 }
269
270 pub fn checked_sub(self, rhs: Self) -> Option<Self> {
271 let inner = self.0.checked_sub(rhs.0)?;
272 Self::try_from_inner(inner).ok()
273 }
274
275 pub fn checked_mul(self, rhs: Decimal) -> Option<Self> {
277 let inner = self.0.checked_mul(rhs)?;
278 Self::try_from_inner(inner).ok()
279 }
280
281 pub fn checked_div(self, rhs: Decimal) -> Option<Self> {
283 let inner = self.0.checked_div(rhs)?;
284 Self::try_from_inner(inner).ok()
285 }
286
287 pub fn saturating_add(self, rhs: Self) -> Self {
290 Self::try_from_inner(self.0.saturating_add(rhs.0)).unwrap_or(Self::MAX)
291 }
292
293 pub fn saturating_sub(self, rhs: Self) -> Self {
294 Self::try_from_inner(self.0.saturating_sub(rhs.0)).unwrap_or(Self::ZERO)
295 }
296
297 pub fn saturating_mul(self, rhs: Decimal) -> Self {
299 Self::try_from_inner(self.0.saturating_mul(rhs)).unwrap_or(Self::MAX)
300 }
301
302 #[inline]
305 fn try_from_inner(inner: Decimal) -> Result<Self, Error> {
306 if inner.is_sign_negative() {
307 Err(Error::Negative)
308 } else if inner > Self::MAX.0 {
309 Err(Error::TooLarge)
310 } else {
311 Ok(Self(inner.round_dp(3)))
312 }
313 }
314}
315
316impl<'de> Deserialize<'de> for Amount {
317 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
318 where
319 D: Deserializer<'de>,
320 {
321 let inner: Decimal = Deserialize::deserialize(deserializer)?;
322
323 Self::try_from_inner(inner).map_err(|e| match e {
324 Error::Negative => serde::de::Error::custom("Amount was negative"),
325 Error::TooLarge => serde::de::Error::custom("Amount was too large"),
326 })
327 }
328}
329
330impl Display for Amount {
331 #[inline]
332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333 Decimal::fmt(&self.0, f)
335 }
336}
337
338impl FromStr for Amount {
339 type Err = anyhow::Error;
340
341 fn from_str(s: &str) -> Result<Self, Self::Err> {
342 let decimal =
343 Decimal::from_str(s).map_err(|err| format_err!("{err}"))?;
344 Ok(Amount::try_from_inner(decimal)?)
345 }
346}
347
348impl From<Amount> for bitcoin::Amount {
353 #[inline]
354 fn from(amt: Amount) -> Self {
355 Self::from_sat(amt.sats().to_u64().expect("safe by construction"))
356 }
357}
358
359impl TryFrom<bitcoin::Amount> for Amount {
360 type Error = Error;
361 #[inline]
362 fn try_from(amt: bitcoin::Amount) -> Result<Self, Self::Error> {
363 Self::try_from_sats(Decimal::from(amt.to_sat()))
364 }
365}
366
367impl Add for Amount {
370 type Output = Self;
371 fn add(self, rhs: Self) -> Self::Output {
372 Self::try_from_inner(self.0 + rhs.0).expect("Overflowed")
373 }
374}
375impl AddAssign for Amount {
376 #[inline]
377 fn add_assign(&mut self, rhs: Self) {
378 *self = *self + rhs;
379 }
380}
381
382impl Sub for Amount {
383 type Output = Self;
384 fn sub(self, rhs: Self) -> Self::Output {
385 Self::try_from_inner(self.0 - rhs.0).expect("Underflowed")
386 }
387}
388
389impl Mul<Decimal> for Amount {
391 type Output = Self;
392 fn mul(self, rhs: Decimal) -> Self::Output {
393 Self::try_from_inner(self.0 * rhs).expect("Overflowed")
394 }
395}
396impl Mul<Amount> for Decimal {
398 type Output = Amount;
399 fn mul(self, rhs: Amount) -> Self::Output {
400 Amount::try_from_inner(self * rhs.0).expect("Overflowed")
401 }
402}
403
404impl Div<Decimal> for Amount {
406 type Output = Self;
407 fn div(self, rhs: Decimal) -> Self::Output {
408 Self::try_from_inner(self.0 / rhs).expect("Overflowed")
409 }
410}
411
412impl Sum for Amount {
413 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
414 iter.fold(Amount::ZERO, Self::add)
415 }
416}
417
#[cfg(any(test, feature = "test-utils"))]
pub mod arb {
    use proptest::{
        arbitrary::Arbitrary,
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;

    /// Generates whole-msat amounts between zero and the total bitcoin supply
    /// (inclusive).
    impl Arbitrary for Amount {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
            let msat_range = 0_u64..=Amount::MAX_BITCOIN_SUPPLY_MSATS_U64;
            msat_range.prop_map(Amount::from_msat).boxed()
        }
    }

    /// A strategy for whole-sat amounts (no msat remainder) between zero and
    /// the total bitcoin supply (inclusive).
    pub fn sats_amount() -> impl Strategy<Value = Amount> {
        let sats_range = 0_u64..=Amount::MAX_BITCOIN_SUPPLY_SATS_U64;
        sats_range.prop_map(|whole_sats| {
            Amount::try_from_sats_u64(whole_sats).unwrap()
        })
    }
}
446
// Unit and property tests for `Amount`: constants, unit conversions,
// arithmetic, msat rounding invariants, and BTC string parsing.
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use lexe_std::Apply;
    use proptest::{
        arbitrary::any,
        prelude::{Strategy, TestCaseError},
        prop_assert, prop_assert_eq, proptest,
    };

    use super::*;
    use crate::test_utils::arbitrary;

    /// Checks that the hand-built constants match their intended values.
    #[test]
    fn check_associated_constants() {
        // `Amount::MAX` should equal `u64::MAX` msats expressed in sats.
        let max_u64_msat_in_sat = Decimal::from(u64::MAX) / dec!(1000);
        println!("{:?}", max_u64_msat_in_sat.unpack());
        assert_eq!(Amount::MAX, Amount(max_u64_msat_in_sat));

        assert_eq!(Amount::MAX.msat(), u64::MAX);
        assert_eq!(
            Amount::MAX_BITCOIN_SUPPLY.sats(),
            dec!(21_000_000) * dec!(100_000_000),
        );
        assert_eq!(
            Amount::MAX_BITCOIN_SUPPLY.msat(),
            21_000_000 * 100_000_000 * 1000,
        );
    }

    /// Any `u64` msat value survives a `from_msat` -> `msat` roundtrip.
    #[test]
    fn no_msat_u64_precision_loss() {
        proptest!(|(msat1 in any::<u64>())| {
            let amount = Amount::from_msat(msat1);
            let msat2 = amount.msat();
            prop_assert_eq!(msat1, msat2);
        })
    }

    /// Any `u32` sat value roundtrips through both `sats_u64` and `sats`.
    #[test]
    fn sat_u32_roundtrips() {
        proptest!(|(sat1 in any::<u32>())| {
            let amount = Amount::from_sats_u32(sat1);
            let sat2a = amount.sats_u64().apply(u32::try_from).unwrap();
            let sat2b = amount.sats().to_u32().unwrap();
            prop_assert_eq!(sat1, sat2a);
            prop_assert_eq!(sat1, sat2b);
        })
    }

    /// Converting out to sats/BTC and back does not lose precision. This holds
    /// whether the conversion uses `Amount`'s own accessors ("inside") or goes
    /// through the `u64` msat value and manual `Decimal` division ("outside").
    #[test]
    fn no_roundtrip_inside_outside_precision_loss() {
        proptest!(|(amount in any::<Amount>())| {
            // Sats roundtrip.
            {
                let roundtrip_inside =
                    Amount::try_from_sats(amount.sats()).unwrap();
                prop_assert_eq!(amount, roundtrip_inside);

                let msat_u64 = amount.msat();
                let msat_dec = Decimal::from(msat_u64);
                let sat_dec = msat_dec / dec!(1000);
                let roundtrip_outside = Amount::try_from_sats(sat_dec).unwrap();
                prop_assert_eq!(roundtrip_inside, roundtrip_outside);
            }

            // BTC roundtrip (1 BTC = 100_000_000_000 msats).
            {
                let roundtrip_inside = Amount::try_from_btc(amount.btc()).unwrap();
                prop_assert_eq!(amount, roundtrip_inside);

                let msat_u64 = amount.msat();
                let msat_dec = Decimal::from(msat_u64);
                let btc_dec = msat_dec / dec!(100_000_000_000);
                let roundtrip_outside = Amount::try_from_btc(btc_dec).unwrap();
                prop_assert_eq!(roundtrip_inside, roundtrip_outside);
            }
        })
    }

    /// `+` and `-` (and their checked variants) are inverses of each other,
    /// and checked ops reject underflow and overflow.
    #[test]
    fn amount_add_sub() {
        proptest!(|(
            amount1 in any::<Amount>(),
            amount2 in any::<Amount>(),
        )| {
            // Order the pair so `greater - lesser` never underflows.
            let (greater, lesser) = if amount1 >= amount2 {
                (amount1, amount2)
            } else {
                (amount2, amount1)
            };

            let diff = greater - lesser;
            prop_assert_eq!(greater, lesser + diff);
            prop_assert_eq!(lesser, greater - diff);

            let checked_diff = greater.checked_sub(lesser).unwrap();
            prop_assert_eq!(greater, lesser.checked_add(checked_diff).unwrap());
            prop_assert_eq!(lesser, greater.checked_sub(checked_diff).unwrap());

            // With `greater` strictly positive, these must fail.
            if greater > lesser {
                prop_assert!(lesser.checked_sub(greater).is_none());
                prop_assert!(Amount::MAX.checked_add(greater).is_none());
            }

            prop_assert!(amount1.abs_diff(amount2) >= Amount::ZERO);
        })
    }

    /// Dividing by 10 and multiplying back by 10 restores the original, for
    /// inputs with at most 2 decimal places.
    #[test]
    fn amount_mul_div() {
        proptest!(|(start in any::<Amount>())| {
            // Limit to 2 dp so `/ 10` stays within msat (3 dp) precision and
            // is exactly reversible.
            let amount1 = Amount(start.0.round_dp(2));

            let intermediate_a = amount1 / dec!(10);
            let intermediate_b = amount1.checked_div(dec!(10)).unwrap();
            prop_assert_eq!(intermediate_a, intermediate_b);

            let amount2_a = dec!(10) * intermediate_a;
            let amount2_b = intermediate_a * dec!(10);
            let amount2_c = intermediate_a.checked_mul(dec!(10)).unwrap();
            prop_assert_eq!(amount1, amount2_a);
            prop_assert_eq!(amount1, amount2_b);
            prop_assert_eq!(amount1, amount2_c);
        })
    }

    /// Strategy for `Decimal`s in `[Amount::ZERO, Amount::MAX]` (in sats)
    /// with up to 6 decimal places. This is finer than msat precision, so it
    /// exercises rounding.
    fn any_bounded_decimal() -> impl Strategy<Value = Decimal> {
        let min_nanosat: u128 = 0;
        // `Amount::MAX` expressed in nanosats (1 msat = 1_000 nanosat).
        let max_nanosat: u128 = u128::from(Amount::MAX.msat()) * 1_000;
        (min_nanosat..=max_nanosat)
            .prop_map(|nanosat| Decimal::from(nanosat) / dec!(1_000_000))
    }

    /// Sanity check that `any_bounded_decimal` stays within `Amount`'s range.
    #[test]
    fn test_bounded_decimal_strategy() {
        proptest!(|(
            bounded_decimal in any_bounded_decimal(),
        )| {
            prop_assert!(bounded_decimal >= Amount::ZERO.0);
            prop_assert!(bounded_decimal <= Amount::MAX.0);
        });
    }

    /// Every public way of producing an `Amount` yields a whole number of
    /// msats (no sub-msat precision).
    #[test]
    fn amount_msat_rounding() {
        fn assert_whole_msat(amount: Amount) -> Result<(), TestCaseError> {
            prop_assert_eq!(amount, amount.round_msat());
            Ok(())
        }

        proptest!(|(
            amount in any::<Amount>(),

            other_amount in any::<Amount>(),
            unbounded_dec in arbitrary::any_decimal(),
            bounded_dec in any_bounded_decimal(),
        )| {
            // Addition.
            assert_whole_msat(amount.saturating_add(other_amount))?;
            if let Some(added) = amount.checked_add(other_amount) {
                assert_whole_msat(added)?;
                assert_whole_msat(amount + other_amount)?;
            }

            // Subtraction.
            assert_whole_msat(amount.saturating_sub(other_amount))?;
            if let Some(subbed) = amount.checked_sub(other_amount) {
                assert_whole_msat(subbed)?;
                assert_whole_msat(amount - other_amount)?;
            }

            // Multiplication by an arbitrary (possibly negative or huge)
            // decimal; the panicking ops are only exercised when the checked
            // op succeeds.
            assert_whole_msat(amount.saturating_mul(unbounded_dec))?;
            if let Some(mulled) = amount.checked_mul(unbounded_dec) {
                assert_whole_msat(mulled)?;
                assert_whole_msat(amount * unbounded_dec)?;
                assert_whole_msat(unbounded_dec * amount)?;
            }

            // Division.
            if let Some(dived) = amount.checked_div(unbounded_dec) {
                assert_whole_msat(dived)?;
                assert_whole_msat(amount / unbounded_dec)?;
            }

            assert_whole_msat(amount.abs_diff(other_amount))?;

            // Construction from a sub-msat-precision decimal.
            assert_whole_msat(Amount::try_from_inner(bounded_dec).unwrap())?;

            // Parsing from a sub-msat-precision decimal string.
            let bounded_decimal_str = bounded_dec.to_string();
            assert_whole_msat(Amount::from_str(&bounded_decimal_str).unwrap())?;

        })
    }

    /// `floor_sat` never rounds up, matches `Decimal::floor`, and agrees with
    /// truncation via `sats_u64`.
    #[test]
    fn test_floor_sat() {
        proptest!(|(amount in any::<Amount>())| {
            let floored = amount.floor_sat();
            prop_assert!(floored <= amount);
            prop_assert_eq!(floored, Amount(amount.0.floor()));
            prop_assert_eq!(
                floored,
                Amount::try_from_sats(Decimal::from(amount.sats_u64())).unwrap()
            );
        });
    }

    /// `round_sat` leaves whole-sat amounts unchanged when viewed in BTC,
    /// and rounds fractional sats to the nearest whole sat.
    #[test]
    fn amount_round_sat_btc() {
        fn expect_no_precision_loss(amount: Amount) {
            assert_eq!(amount.btc(), amount.round_sat().btc());
        }

        expect_no_precision_loss(Amount::from_sats_u32(0));
        expect_no_precision_loss(Amount::from_sats_u32(10_0000));
        expect_no_precision_loss(Amount::from_sats_u32(10_0010_0005));
        expect_no_precision_loss(
            Amount::try_from_sats_u64(20_999_999_9999_9999).unwrap(),
        );

        proptest!(|(amount_u64: u64)| {
            // Keep the sat value below the 21M BTC supply (2.1e15 sats).
            let amount_u64 = amount_u64 % 2_100_000_000_000_000;
            let amount = Amount::try_from_sats_u64(amount_u64).unwrap();
            expect_no_precision_loss(amount);
        });

        // Fractional sats: 0.001 -> 0, 1.001 -> 1, 1.501 -> 2.
        assert_eq!(Amount::from_msat(1).round_sat().btc(), Amount::ZERO.btc());
        assert_eq!(
            Amount::from_msat(1_001).round_sat().btc(),
            Amount::from_sats_u32(1).btc(),
        );
        assert_eq!(
            Amount::from_msat(1_501).round_sat().btc(),
            Amount::from_sats_u32(2).btc(),
        );
    }

    /// Parsing BTC-denominated strings: accepted forms, rejected forms, and
    /// display/parse roundtrips.
    #[test]
    fn amount_btc_str() {
        // Parses `input` as a `Decimal` BTC value and converts it to an
        // `Amount`; `None` if either step fails.
        fn parse_btc_str(input: &str) -> Option<Amount> {
            Decimal::from_str(input)
                .ok()
                .and_then(|btc_decimal| Amount::try_from_btc(btc_decimal).ok())
        }
        fn parse_eq(input: &str, expected: Amount) {
            assert_eq!(parse_btc_str(input).unwrap(), expected);
        }
        fn parse_fail(input: &str) {
            if let Some(amount) = parse_btc_str(input) {
                panic!(
                    "Should fail to parse BTC str: '{input}', got: {amount:?}"
                );
            }
        }

        // Accepted forms, including a bare leading/trailing decimal point.
        parse_eq("0", Amount::ZERO);
        parse_eq("0.", Amount::ZERO);
        parse_eq(".0", Amount::ZERO);
        parse_eq("0.001", Amount::from_sats_u32(10_0000));
        parse_eq("10.00", Amount::from_sats_u32(10_0000_0000));
        parse_eq("10.", Amount::from_sats_u32(10_0000_0000));
        parse_eq("10", Amount::from_sats_u32(10_0000_0000));
        parse_eq("10.00000000", Amount::from_sats_u32(10_0000_0000));
        parse_eq("10.00001230", Amount::from_sats_u32(10_0000_1230));
        parse_eq("10.69696969", Amount::from_sats_u32(10_6969_6969));
        parse_eq("0.00001230", Amount::from_sats_u32(1230));
        parse_eq("0.69696969", Amount::from_sats_u32(6969_6969));
        parse_eq(".00001230", Amount::from_sats_u32(1230));
        parse_eq(".69696969", Amount::from_sats_u32(6969_6969));
        parse_eq(
            "20000000",
            Amount::try_from_sats_u64(20_000_000_0000_0000).unwrap(),
        );
        parse_eq(
            "20999999.99999999",
            Amount::try_from_sats_u64(20_999_999_9999_9999).unwrap(),
        );

        // Rejected: malformed, negative, or surrounded by whitespace.
        parse_fail(".");
        parse_fail("asdif.");
        parse_fail("156.(6kfjaosid");
        parse_fail("-156");
        parse_fail("-15.4984");
        parse_fail("-.4");
        parse_fail(" 0.4");
        parse_fail("0.4 ");

        // Formatting an amount in BTC and parsing it back is lossless, both
        // for the raw amount and for its whole-sat rounding.
        proptest!(|(amount: Amount)| {
            let amount_btc_str = amount.btc().to_string();
            let amount_round_sat_btc_str = amount.round_sat().btc().to_string();
            let amount_btc_str_btc = parse_btc_str(&amount_btc_str).unwrap();
            let amount_round_sat_btc_str_btc = parse_btc_str(&amount_round_sat_btc_str).unwrap();
            prop_assert_eq!(amount, amount_btc_str_btc);
            prop_assert_eq!(amount.btc(), amount_btc_str_btc.btc());
            prop_assert_eq!(amount.round_sat(), amount_round_sat_btc_str_btc);
            prop_assert_eq!(amount.round_sat().btc(), amount_round_sat_btc_str_btc.btc());
        });

        // Arbitrary strings must never panic the parser (only return None).
        proptest!(|(s in arbitrary::any_string())| {
            let _ = parse_btc_str(&s);
        });
    }
}