Skip to main content

bitcoin_units/amount/
result.rs

1// SPDX-License-Identifier: CC0-1.0
2
//! Implements `core::ops` arithmetic for `Amount` and `SignedAmount`, returning the monadic
//! `NumOpResult` type wherever an operation can fail.
4
5use core::num::{NonZeroI64, NonZeroU64};
6use core::ops;
7
8use NumOpResult as R;
9
10use super::{Amount, SignedAmount};
11use crate::internal_macros::{
12    impl_add_assign_for_results, impl_div_assign, impl_mul_assign, impl_sub_assign_for_results,
13};
14use crate::result::{MathOp, NumOpError, NumOpResult, OptionExt};
15
16impl From<Amount> for NumOpResult<Amount> {
17    #[inline]
18    fn from(a: Amount) -> Self { Self::Valid(a) }
19}
20impl From<&Amount> for NumOpResult<Amount> {
21    #[inline]
22    fn from(a: &Amount) -> Self { Self::Valid(*a) }
23}
24
25impl From<SignedAmount> for NumOpResult<SignedAmount> {
26    #[inline]
27    fn from(a: SignedAmount) -> Self { Self::Valid(a) }
28}
29impl From<&SignedAmount> for NumOpResult<SignedAmount> {
30    #[inline]
31    fn from(a: &SignedAmount) -> Self { Self::Valid(*a) }
32}
33
34crate::internal_macros::impl_op_for_references! {
35    impl ops::Add<Amount> for Amount {
36        type Output = NumOpResult<Amount>;
37
38        fn add(self, rhs: Amount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
39    }
40    impl ops::Add<NumOpResult<Amount>> for Amount {
41        type Output = NumOpResult<Amount>;
42
43        fn add(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|a| a + self) }
44    }
45
46    impl ops::Sub<Amount> for Amount {
47        type Output = NumOpResult<Amount>;
48
49        fn sub(self, rhs: Amount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
50    }
51    impl ops::Sub<NumOpResult<Amount>> for Amount {
52        type Output = NumOpResult<Amount>;
53
54        fn sub(self, rhs: NumOpResult<Amount>) -> Self::Output {
55            match rhs {
56                R::Valid(amount) => self - amount,
57                R::Error(_) => rhs,
58            }
59        }
60    }
61
62    impl ops::Mul<u64> for Amount {
63        type Output = NumOpResult<Amount>;
64
65        fn mul(self, rhs: u64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
66    }
67    impl ops::Mul<u64> for NumOpResult<Amount> {
68        type Output = NumOpResult<Amount>;
69
70        fn mul(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
71    }
72    impl ops::Mul<Amount> for u64 {
73        type Output = NumOpResult<Amount>;
74
75        fn mul(self, rhs: Amount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
76    }
77    impl ops::Mul<NumOpResult<Amount>> for u64 {
78        type Output = NumOpResult<Amount>;
79
80        fn mul(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
81    }
82
83    impl ops::Div<u64> for Amount {
84        type Output = NumOpResult<Amount>;
85
86        fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
87    }
88    impl ops::Div<u64> for NumOpResult<Amount> {
89        type Output = NumOpResult<Amount>;
90
91        fn div(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
92    }
93    impl ops::Div<Amount> for Amount {
94        type Output = NumOpResult<u64>;
95
96        fn div(self, rhs: Amount) -> Self::Output {
97            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
98        }
99    }
100    impl ops::Div<NonZeroU64> for Amount {
101        type Output = Amount;
102
103        fn div(self, rhs: NonZeroU64) -> Self::Output { Self::from_sat(self.to_sat() / rhs.get()).expect("construction after division cannot fail") }
104    }
105    impl ops::Rem<u64> for Amount {
106        type Output = NumOpResult<Amount>;
107
108        fn rem(self, modulus: u64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
109    }
110    impl ops::Rem<u64> for NumOpResult<Amount> {
111        type Output = NumOpResult<Amount>;
112
113        fn rem(self, modulus: u64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
114    }
115
116    impl ops::Add<SignedAmount> for SignedAmount {
117        type Output = NumOpResult<SignedAmount>;
118
119        fn add(self, rhs: SignedAmount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
120    }
121    impl ops::Add<NumOpResult<SignedAmount>> for SignedAmount {
122        type Output = NumOpResult<SignedAmount>;
123
124        fn add(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|a| a + self) }
125    }
126
127    impl ops::Sub<SignedAmount> for SignedAmount {
128        type Output = NumOpResult<SignedAmount>;
129
130        fn sub(self, rhs: SignedAmount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
131    }
132    impl ops::Sub<NumOpResult<SignedAmount>> for SignedAmount {
133        type Output = NumOpResult<SignedAmount>;
134
135        fn sub(self, rhs: NumOpResult<SignedAmount>) -> Self::Output {
136            match rhs {
137                R::Valid(amount) => self - amount,
138                R::Error(_) => rhs,
139            }
140        }
141    }
142
143    impl ops::Mul<i64> for SignedAmount {
144        type Output = NumOpResult<SignedAmount>;
145
146        fn mul(self, rhs: i64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
147    }
148    impl ops::Mul<i64> for NumOpResult<SignedAmount> {
149        type Output = NumOpResult<SignedAmount>;
150
151        fn mul(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
152    }
153    impl ops::Mul<SignedAmount> for i64 {
154        type Output = NumOpResult<SignedAmount>;
155
156        fn mul(self, rhs: SignedAmount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
157    }
158    impl ops::Mul<NumOpResult<SignedAmount>> for i64 {
159        type Output = NumOpResult<SignedAmount>;
160
161        fn mul(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
162    }
163
164    impl ops::Div<i64> for SignedAmount {
165        type Output = NumOpResult<SignedAmount>;
166
167        fn div(self, rhs: i64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
168    }
169    impl ops::Div<i64> for NumOpResult<SignedAmount> {
170        type Output = NumOpResult<SignedAmount>;
171
172        fn div(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
173    }
174    impl ops::Div<SignedAmount> for SignedAmount {
175        type Output = NumOpResult<i64>;
176
177        fn div(self, rhs: SignedAmount) -> Self::Output {
178            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
179        }
180    }
181    impl ops::Div<NonZeroI64> for SignedAmount {
182        type Output = SignedAmount;
183
184        fn div(self, rhs: NonZeroI64) -> Self::Output { Self::from_sat(self.to_sat() / rhs.get()).expect("construction after division cannot fail") }
185    }
186    impl ops::Rem<i64> for SignedAmount {
187        type Output = NumOpResult<SignedAmount>;
188
189        fn rem(self, modulus: i64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
190    }
191    impl ops::Rem<i64> for NumOpResult<SignedAmount> {
192        type Output = NumOpResult<SignedAmount>;
193
194        fn rem(self, modulus: i64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
195    }
196}
197
198impl_mul_assign!(NumOpResult<Amount>, u64);
199impl_mul_assign!(NumOpResult<SignedAmount>, i64);
200impl_div_assign!(NumOpResult<Amount>, u64);
201impl_div_assign!(NumOpResult<SignedAmount>, i64);
202
203impl_add_assign_for_results!(Amount);
204impl_add_assign_for_results!(SignedAmount);
205impl_sub_assign_for_results!(Amount);
206impl_sub_assign_for_results!(SignedAmount);
207
208impl ops::Neg for SignedAmount {
209    type Output = Self;
210
211    #[inline]
212    fn neg(self) -> Self::Output {
213        Self::from_sat(self.to_sat().neg()).expect("all +ve and -ve values are valid")
214    }
215}
216
217impl core::iter::Sum<Self> for NumOpResult<Amount> {
218    fn sum<I>(iter: I) -> Self
219    where
220        I: Iterator<Item = Self>,
221    {
222        iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
223            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
224            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
225        })
226    }
227}
228impl<'a> core::iter::Sum<&'a Self> for NumOpResult<Amount> {
229    fn sum<I>(iter: I) -> Self
230    where
231        I: Iterator<Item = &'a Self>,
232    {
233        iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
234            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
235            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
236        })
237    }
238}
239
240impl core::iter::Sum<Self> for NumOpResult<SignedAmount> {
241    fn sum<I>(iter: I) -> Self
242    where
243        I: Iterator<Item = Self>,
244    {
245        iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
246            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
247            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
248        })
249    }
250}
251impl<'a> core::iter::Sum<&'a Self> for NumOpResult<SignedAmount> {
252    fn sum<I>(iter: I) -> Self
253    where
254        I: Iterator<Item = &'a Self>,
255    {
256        iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
257            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
258            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
259        })
260    }
261}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_amount_results() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_sum_signed_amount_results() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_op_assign_amount() {
        let sat = Amount::from_sat_u32(50);

        let mut res = sat + sat;
        res += Amount::from_sat_u32(50);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(150)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err; // Add an error result
        assert_eq!(res, add_err);

        let mut res = sat + sat;
        res -= Amount::from_sat_u32(20);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(80)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err; // Subtract an error result
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_signed_amount() {
        let ssat = SignedAmount::from_sat_i32(50);

        let mut res = ssat + ssat;
        res += SignedAmount::from_sat_i32(-30);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(70)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err; // Add an error result
        assert_eq!(res, add_err);

        let mut res = ssat + ssat;
        res -= SignedAmount::from_sat_i32(25);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(75)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err; // Subtract an error result
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_amount_error() {
        let mut res: NumOpResult<Amount> = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));

        // Adding a valid amount to an error yields an `Add` error.
        res += Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        // Adding an error to an error yields an `Add` error, whatever the original ops were.
        res += NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        // Subtracting a valid amount from an error yields a `Sub` error.
        res -= Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));

        // Subtracting an error from an error yields a `Sub` error, whatever the original ops were.
        res -= NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));
    }

    #[test]
    fn test_neg_signed_amount() {
        let ssat = SignedAmount::from_sat_i32(50);

        assert_eq!(-ssat, SignedAmount::from_sat_i32(-50));
        // Negating twice is the identity.
        assert_eq!(-(-ssat), ssat);
    }

    #[test]
    fn test_div_by_non_zero() {
        let three = NonZeroU64::new(3).expect("3 is non-zero");
        // Integer division truncates.
        assert_eq!(Amount::from_sat_u32(100) / three, Amount::from_sat_u32(33));

        let minus_three = NonZeroI64::new(-3).expect("-3 is non-zero");
        // Signed division truncates toward zero.
        assert_eq!(
            SignedAmount::from_sat_i32(100) / minus_three,
            SignedAmount::from_sat_i32(-33)
        );
    }

    #[test]
    fn test_div_and_rem_by_zero_are_errors() {
        let sat = Amount::from_sat_u32(10);

        assert!(matches!(sat / 0_u64, NumOpResult::Error(_)));
        assert!(matches!(sat % 0_u64, NumOpResult::Error(_)));
        assert!(matches!(sat / Amount::ZERO, NumOpResult::Error(_)));
    }
}