Skip to main content

bitcoin_units/amount/
result.rs

1// SPDX-License-Identifier: CC0-1.0
2
//! Implements the mathematical operators (`core::ops`) for `Amount` and `SignedAmount`, returning the monadic `NumOpResult` type for operations that can fail.
4
5use core::num::{NonZeroI64, NonZeroU64};
6use core::ops;
7
8use NumOpResult as R;
9
10use super::{Amount, SignedAmount};
11use crate::internal_macros::{
12    impl_add_assign_for_results, impl_div_assign, impl_mul_assign, impl_sub_assign_for_results,
13};
14use crate::result::{MathOp, NumOpError, NumOpResult, OptionExt};
15
16impl From<Amount> for NumOpResult<Amount> {
17    fn from(a: Amount) -> Self { Self::Valid(a) }
18}
19impl From<&Amount> for NumOpResult<Amount> {
20    fn from(a: &Amount) -> Self { Self::Valid(*a) }
21}
22
23impl From<SignedAmount> for NumOpResult<SignedAmount> {
24    fn from(a: SignedAmount) -> Self { Self::Valid(a) }
25}
26impl From<&SignedAmount> for NumOpResult<SignedAmount> {
27    fn from(a: &SignedAmount) -> Self { Self::Valid(*a) }
28}
29
30crate::internal_macros::impl_op_for_references! {
31    impl ops::Add<Amount> for Amount {
32        type Output = NumOpResult<Amount>;
33
34        fn add(self, rhs: Amount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
35    }
36    impl ops::Add<NumOpResult<Amount>> for Amount {
37        type Output = NumOpResult<Amount>;
38
39        fn add(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|a| a + self) }
40    }
41
42    impl ops::Sub<Amount> for Amount {
43        type Output = NumOpResult<Amount>;
44
45        fn sub(self, rhs: Amount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
46    }
47    impl ops::Sub<NumOpResult<Amount>> for Amount {
48        type Output = NumOpResult<Amount>;
49
50        fn sub(self, rhs: NumOpResult<Amount>) -> Self::Output {
51            match rhs {
52                R::Valid(amount) => self - amount,
53                R::Error(_) => rhs,
54            }
55        }
56    }
57
58    impl ops::Mul<u64> for Amount {
59        type Output = NumOpResult<Amount>;
60
61        fn mul(self, rhs: u64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
62    }
63    impl ops::Mul<u64> for NumOpResult<Amount> {
64        type Output = NumOpResult<Amount>;
65
66        fn mul(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
67    }
68    impl ops::Mul<Amount> for u64 {
69        type Output = NumOpResult<Amount>;
70
71        fn mul(self, rhs: Amount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
72    }
73    impl ops::Mul<NumOpResult<Amount>> for u64 {
74        type Output = NumOpResult<Amount>;
75
76        fn mul(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
77    }
78
79    impl ops::Div<u64> for Amount {
80        type Output = NumOpResult<Amount>;
81
82        fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
83    }
84    impl ops::Div<u64> for NumOpResult<Amount> {
85        type Output = NumOpResult<Amount>;
86
87        fn div(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
88    }
89    impl ops::Div<Amount> for Amount {
90        type Output = NumOpResult<u64>;
91
92        fn div(self, rhs: Amount) -> Self::Output {
93            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
94        }
95    }
96    impl ops::Div<NonZeroU64> for Amount {
97        type Output = Amount;
98
99        fn div(self, rhs: NonZeroU64) -> Self::Output { Self::from_sat(self.to_sat() / rhs.get()).expect("construction after division cannot fail") }
100    }
101    impl ops::Rem<u64> for Amount {
102        type Output = NumOpResult<Amount>;
103
104        fn rem(self, modulus: u64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
105    }
106    impl ops::Rem<u64> for NumOpResult<Amount> {
107        type Output = NumOpResult<Amount>;
108
109        fn rem(self, modulus: u64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
110    }
111
112    impl ops::Add<SignedAmount> for SignedAmount {
113        type Output = NumOpResult<SignedAmount>;
114
115        fn add(self, rhs: SignedAmount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
116    }
117    impl ops::Add<NumOpResult<SignedAmount>> for SignedAmount {
118        type Output = NumOpResult<SignedAmount>;
119
120        fn add(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|a| a + self) }
121    }
122
123    impl ops::Sub<SignedAmount> for SignedAmount {
124        type Output = NumOpResult<SignedAmount>;
125
126        fn sub(self, rhs: SignedAmount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
127    }
128    impl ops::Sub<NumOpResult<SignedAmount>> for SignedAmount {
129        type Output = NumOpResult<SignedAmount>;
130
131        fn sub(self, rhs: NumOpResult<SignedAmount>) -> Self::Output {
132            match rhs {
133                R::Valid(amount) => self - amount,
134                R::Error(_) => rhs,
135            }
136        }
137    }
138
139    impl ops::Mul<i64> for SignedAmount {
140        type Output = NumOpResult<SignedAmount>;
141
142        fn mul(self, rhs: i64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
143    }
144    impl ops::Mul<i64> for NumOpResult<SignedAmount> {
145        type Output = NumOpResult<SignedAmount>;
146
147        fn mul(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
148    }
149    impl ops::Mul<SignedAmount> for i64 {
150        type Output = NumOpResult<SignedAmount>;
151
152        fn mul(self, rhs: SignedAmount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
153    }
154    impl ops::Mul<NumOpResult<SignedAmount>> for i64 {
155        type Output = NumOpResult<SignedAmount>;
156
157        fn mul(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
158    }
159
160    impl ops::Div<i64> for SignedAmount {
161        type Output = NumOpResult<SignedAmount>;
162
163        fn div(self, rhs: i64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
164    }
165    impl ops::Div<i64> for NumOpResult<SignedAmount> {
166        type Output = NumOpResult<SignedAmount>;
167
168        fn div(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
169    }
170    impl ops::Div<SignedAmount> for SignedAmount {
171        type Output = NumOpResult<i64>;
172
173        fn div(self, rhs: SignedAmount) -> Self::Output {
174            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
175        }
176    }
177    impl ops::Div<NonZeroI64> for SignedAmount {
178        type Output = SignedAmount;
179
180        fn div(self, rhs: NonZeroI64) -> Self::Output { Self::from_sat(self.to_sat() / rhs.get()).expect("construction after division cannot fail") }
181    }
182    impl ops::Rem<i64> for SignedAmount {
183        type Output = NumOpResult<SignedAmount>;
184
185        fn rem(self, modulus: i64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
186    }
187    impl ops::Rem<i64> for NumOpResult<SignedAmount> {
188        type Output = NumOpResult<SignedAmount>;
189
190        fn rem(self, modulus: i64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
191    }
192}
193
194impl_mul_assign!(NumOpResult<Amount>, u64);
195impl_mul_assign!(NumOpResult<SignedAmount>, i64);
196impl_div_assign!(NumOpResult<Amount>, u64);
197impl_div_assign!(NumOpResult<SignedAmount>, i64);
198
199impl_add_assign_for_results!(Amount);
200impl_add_assign_for_results!(SignedAmount);
201impl_sub_assign_for_results!(Amount);
202impl_sub_assign_for_results!(SignedAmount);
203
204impl ops::Neg for SignedAmount {
205    type Output = Self;
206
207    fn neg(self) -> Self::Output {
208        Self::from_sat(self.to_sat().neg()).expect("all +ve and -ve values are valid")
209    }
210}
211
212impl core::iter::Sum<Self> for NumOpResult<Amount> {
213    fn sum<I>(iter: I) -> Self
214    where
215        I: Iterator<Item = Self>,
216    {
217        iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
218            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
219            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
220        })
221    }
222}
223impl<'a> core::iter::Sum<&'a Self> for NumOpResult<Amount> {
224    fn sum<I>(iter: I) -> Self
225    where
226        I: Iterator<Item = &'a Self>,
227    {
228        iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
229            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
230            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
231        })
232    }
233}
234
235impl core::iter::Sum<Self> for NumOpResult<SignedAmount> {
236    fn sum<I>(iter: I) -> Self
237    where
238        I: Iterator<Item = Self>,
239    {
240        iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
241            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
242            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
243        })
244    }
245}
246impl<'a> core::iter::Sum<&'a Self> for NumOpResult<SignedAmount> {
247    fn sum<I>(iter: I) -> Self
248    where
249        I: Iterator<Item = &'a Self>,
250    {
251        iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
252            (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
253            (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
254        })
255    }
256}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_amount_results() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_sum_amount_with_references_error_propagation() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
        ];

        let sum: NumOpResult<Amount> = amounts.iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_sum_amount_reports_add_error_for_any_failed_item() {
        // Summing collapses any item error into an `Add` error.
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Mul)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));
    }

    #[test]
    fn test_sum_empty_is_zero() {
        let amounts: [NumOpResult<Amount>; 0] = [];
        let sum: NumOpResult<Amount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::ZERO));

        let signed: [NumOpResult<SignedAmount>; 0] = [];
        let sum: NumOpResult<SignedAmount> = signed.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::ZERO));
    }

    #[test]
    fn test_sum_signed_amount_results() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_sum_signed_amount_with_references_error_propagation() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));
    }

    #[test]
    fn test_op_assign_amount() {
        let sat = Amount::from_sat_u32(50);

        let mut res = sat + sat;
        res += Amount::from_sat_u32(50);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(150)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err; // Add an error result
        assert_eq!(res, add_err);

        let mut res = sat + sat;
        res -= Amount::from_sat_u32(20);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(80)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err; // Subtract an error result
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_signed_amount() {
        let ssat = SignedAmount::from_sat_i32(50);

        let mut res = ssat + ssat;
        res += SignedAmount::from_sat_i32(-30);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(70)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err; // Add an error result
        assert_eq!(res, add_err);

        let mut res = ssat + ssat;
        res -= SignedAmount::from_sat_i32(25);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(75)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err; // Subtract an error result
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_amount_error() {
        let mut res: NumOpResult<Amount> = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));

        // Adding a valid amount to an error results in an Add error.
        res += Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        // Adding an error to an error results in an Add error.
        res += NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        // Subtracting a valid amount from an error results in a Sub error.
        res -= Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));

        // Subtracting an error from an error results in a Sub error.
        res -= NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));
    }

    #[test]
    fn test_neg_signed_amount() {
        let ssat = SignedAmount::from_sat_i32(50);
        assert_eq!(-ssat, SignedAmount::from_sat_i32(-50));
        assert_eq!(-SignedAmount::from_sat_i32(-50), ssat);
        assert_eq!(-SignedAmount::ZERO, SignedAmount::ZERO);
    }

    #[test]
    fn test_div_by_non_zero() {
        // Division by a non-zero divisor truncates toward zero and is infallible.
        let sat = Amount::from_sat_u32(100);
        assert_eq!(sat / NonZeroU64::new(3).unwrap(), Amount::from_sat_u32(33));

        let ssat = SignedAmount::from_sat_i32(-100);
        assert_eq!(ssat / NonZeroI64::new(3).unwrap(), SignedAmount::from_sat_i32(-33));
    }
}