1use core::num::{NonZeroI64, NonZeroU64};
6use core::ops;
7
8use NumOpResult as R;
9
10use super::{Amount, SignedAmount};
11use crate::internal_macros::{
12 impl_add_assign_for_results, impl_div_assign, impl_mul_assign, impl_sub_assign_for_results,
13};
14use crate::result::{MathOp, NumOpError, NumOpResult, OptionExt};
15
16impl From<Amount> for NumOpResult<Amount> {
17 fn from(a: Amount) -> Self { Self::Valid(a) }
18}
19impl From<&Amount> for NumOpResult<Amount> {
20 fn from(a: &Amount) -> Self { Self::Valid(*a) }
21}
22
23impl From<SignedAmount> for NumOpResult<SignedAmount> {
24 fn from(a: SignedAmount) -> Self { Self::Valid(a) }
25}
26impl From<&SignedAmount> for NumOpResult<SignedAmount> {
27 fn from(a: &SignedAmount) -> Self { Self::Valid(*a) }
28}
29
// Arithmetic operator impls for `Amount` and `SignedAmount`.
//
// Every impl below is handed to `impl_op_for_references!`, which presumably also emits the
// by-reference operand variants (the `Sum<&Self>` impls in this file rely on
// `Amount + &Amount` existing). NOTE(review): the macro likely pattern-matches this exact impl
// shape, so only plain `//` comments (no `///` attributes) are used inside it — confirm before
// adding attributes.
//
// Fallible operations return `NumOpResult`, so a failure is carried as a value that can be
// chained through further arithmetic instead of panicking.
crate::internal_macros::impl_op_for_references! {
    // `Amount + Amount`: checked addition. A failed `checked_add` is converted by
    // `valid_or_error` into an error tagged with `MathOp::Add`.
    impl ops::Add<Amount> for Amount {
        type Output = NumOpResult<Amount>;

        fn add(self, rhs: Amount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
    }
    // `Amount + NumOpResult<Amount>`: addition commutes, so a valid rhs is added as `a + self`.
    // NOTE(review): relies on `and_then` passing an rhs error through untouched.
    impl ops::Add<NumOpResult<Amount>> for Amount {
        type Output = NumOpResult<Amount>;

        fn add(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|a| a + self) }
    }

    // `Amount - Amount`: checked subtraction; a failed `checked_sub` becomes a `MathOp::Sub`
    // error.
    impl ops::Sub<Amount> for Amount {
        type Output = NumOpResult<Amount>;

        fn sub(self, rhs: Amount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
    }
    // `Amount - NumOpResult<Amount>`: subtraction does not commute, so the rhs is matched
    // explicitly. An rhs error is returned as-is, keeping its original `MathOp`.
    impl ops::Sub<NumOpResult<Amount>> for Amount {
        type Output = NumOpResult<Amount>;

        fn sub(self, rhs: NumOpResult<Amount>) -> Self::Output {
            match rhs {
                R::Valid(amount) => self - amount,
                R::Error(_) => rhs,
            }
        }
    }

    // Scaling by a `u64` factor. The factor can appear on either side and can also be applied
    // to an existing result. A failed `checked_mul` becomes a `MathOp::Mul` error; an existing
    // error operand goes through `and_then`.
    impl ops::Mul<u64> for Amount {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: u64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<u64> for NumOpResult<Amount> {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
    }
    impl ops::Mul<Amount> for u64 {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: Amount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<NumOpResult<Amount>> for u64 {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
    }

    // `Amount / u64`: checked division; a failed `checked_div` (presumably a zero divisor)
    // becomes a `MathOp::Div` error.
    impl ops::Div<u64> for Amount {
        type Output = NumOpResult<Amount>;

        fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
    }
    impl ops::Div<u64> for NumOpResult<Amount> {
        type Output = NumOpResult<Amount>;

        fn div(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
    }
    // `Amount / Amount`: the truncated ratio as a plain `u64`, not an amount.
    // `u64::checked_div` returns `None` only for a zero divisor, which becomes a `MathOp::Div`
    // error.
    impl ops::Div<Amount> for Amount {
        type Output = NumOpResult<u64>;

        fn div(self, rhs: Amount) -> Self::Output {
            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
        }
    }
    // `Amount / NonZeroU64`: infallible, so it returns `Amount` directly. The divisor can never
    // be zero and the quotient never exceeds the dividend, which is why the `expect` holds.
    impl ops::Div<NonZeroU64> for Amount {
        type Output = Amount;

        fn div(self, rhs: NonZeroU64) -> Self::Output {
            Self::from_sat(self.to_sat() / rhs.get())
                .expect("construction after division cannot fail")
        }
    }
    // `Amount % u64`: checked remainder; a failed `checked_rem` (presumably a zero modulus)
    // becomes a `MathOp::Rem` error.
    impl ops::Rem<u64> for Amount {
        type Output = NumOpResult<Amount>;

        fn rem(self, modulus: u64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
    }
    impl ops::Rem<u64> for NumOpResult<Amount> {
        type Output = NumOpResult<Amount>;

        fn rem(self, modulus: u64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
    }

    // `SignedAmount + SignedAmount`: checked addition; a failed `checked_add` becomes a
    // `MathOp::Add` error.
    impl ops::Add<SignedAmount> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn add(self, rhs: SignedAmount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
    }
    // `SignedAmount + NumOpResult<SignedAmount>`: addition commutes, so a valid rhs is added as
    // `a + self`; an rhs error goes through `and_then`.
    impl ops::Add<NumOpResult<SignedAmount>> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn add(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|a| a + self) }
    }

    // `SignedAmount - SignedAmount`: checked subtraction; a failed `checked_sub` becomes a
    // `MathOp::Sub` error.
    impl ops::Sub<SignedAmount> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn sub(self, rhs: SignedAmount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
    }
    // `SignedAmount - NumOpResult<SignedAmount>`: an rhs error is returned as-is, keeping its
    // original `MathOp`.
    impl ops::Sub<NumOpResult<SignedAmount>> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn sub(self, rhs: NumOpResult<SignedAmount>) -> Self::Output {
            match rhs {
                R::Valid(amount) => self - amount,
                R::Error(_) => rhs,
            }
        }
    }

    // Scaling by an `i64` factor (on either side, and on an existing result). A failed
    // `checked_mul` becomes a `MathOp::Mul` error.
    impl ops::Mul<i64> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: i64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<i64> for NumOpResult<SignedAmount> {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
    }
    impl ops::Mul<SignedAmount> for i64 {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: SignedAmount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<NumOpResult<SignedAmount>> for i64 {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
    }

    // `SignedAmount / i64`: checked division; a failed `checked_div` becomes a `MathOp::Div`
    // error.
    impl ops::Div<i64> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn div(self, rhs: i64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
    }
    impl ops::Div<i64> for NumOpResult<SignedAmount> {
        type Output = NumOpResult<SignedAmount>;

        fn div(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
    }
    // `SignedAmount / SignedAmount`: the truncated ratio as a plain `i64`. `i64::checked_div`
    // returns `None` for a zero divisor or for `i64::MIN / -1`; both become `MathOp::Div` errors.
    impl ops::Div<SignedAmount> for SignedAmount {
        type Output = NumOpResult<i64>;

        fn div(self, rhs: SignedAmount) -> Self::Output {
            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
        }
    }
    // `SignedAmount / NonZeroI64`: infallible; the divisor is never zero and |quotient| never
    // exceeds |dividend|. NOTE(review): plain `/` would panic on `i64::MIN / -1`. This relies on
    // `to_sat()` never returning `i64::MIN`, which is presumably guaranteed by the amount's valid
    // range — confirm.
    impl ops::Div<NonZeroI64> for SignedAmount {
        type Output = SignedAmount;

        fn div(self, rhs: NonZeroI64) -> Self::Output {
            Self::from_sat(self.to_sat() / rhs.get())
                .expect("construction after division cannot fail")
        }
    }
    // `SignedAmount % i64`: checked remainder; a failed `checked_rem` becomes a `MathOp::Rem`
    // error.
    impl ops::Rem<i64> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn rem(self, modulus: i64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
    }
    impl ops::Rem<i64> for NumOpResult<SignedAmount> {
        type Output = NumOpResult<SignedAmount>;

        fn rem(self, modulus: i64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
    }
}
193
194impl_mul_assign!(NumOpResult<Amount>, u64);
195impl_mul_assign!(NumOpResult<SignedAmount>, i64);
196impl_div_assign!(NumOpResult<Amount>, u64);
197impl_div_assign!(NumOpResult<SignedAmount>, i64);
198
199impl_add_assign_for_results!(Amount);
200impl_add_assign_for_results!(SignedAmount);
201impl_sub_assign_for_results!(Amount);
202impl_sub_assign_for_results!(SignedAmount);
203
204impl ops::Neg for SignedAmount {
205 type Output = Self;
206
207 fn neg(self) -> Self::Output {
208 Self::from_sat(self.to_sat().neg()).expect("all +ve and -ve values are valid")
209 }
210}
211
212impl core::iter::Sum<Self> for NumOpResult<Amount> {
213 fn sum<I>(iter: I) -> Self
214 where
215 I: Iterator<Item = Self>,
216 {
217 iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
218 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
219 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
220 })
221 }
222}
223impl<'a> core::iter::Sum<&'a Self> for NumOpResult<Amount> {
224 fn sum<I>(iter: I) -> Self
225 where
226 I: Iterator<Item = &'a Self>,
227 {
228 iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
229 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
230 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
231 })
232 }
233}
234
235impl core::iter::Sum<Self> for NumOpResult<SignedAmount> {
236 fn sum<I>(iter: I) -> Self
237 where
238 I: Iterator<Item = Self>,
239 {
240 iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
241 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
242 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
243 })
244 }
245}
246impl<'a> core::iter::Sum<&'a Self> for NumOpResult<SignedAmount> {
247 fn sum<I>(iter: I) -> Self
248 where
249 I: Iterator<Item = &'a Self>,
250 {
251 iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
252 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
253 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
254 })
255 }
256}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_amount_results() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_sum_signed_amount_results() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_op_assign_amount() {
        let sat = Amount::from_sat_u32(50);

        let mut res = sat + sat;
        res += Amount::from_sat_u32(50);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(150)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err;
        assert_eq!(res, add_err);

        let mut res = sat + sat;
        res -= Amount::from_sat_u32(20);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(80)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err;
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_signed_amount() {
        let ssat = SignedAmount::from_sat_i32(50);

        let mut res = ssat + ssat;
        res += SignedAmount::from_sat_i32(-30);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(70)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err;
        assert_eq!(res, add_err);

        let mut res = ssat + ssat;
        res -= SignedAmount::from_sat_i32(25);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(75)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err;
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_amount_error() {
        // These assertions pin that compound assignment on an error result yields an error
        // tagged with the operation just performed.
        let mut res: NumOpResult<Amount> = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));

        res += Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        res += NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        res -= Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));

        res -= NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));
    }

    #[test]
    fn test_div_and_rem_by_zero_are_errors() {
        let sat = Amount::from_sat_u32(100);
        assert!(matches!(sat / 0_u64, NumOpResult::Error(_)));
        assert!(matches!(sat / Amount::ZERO, NumOpResult::Error(_)));
        assert!(matches!(sat % 0_u64, NumOpResult::Error(_)));

        let ssat = SignedAmount::from_sat_i32(-100);
        assert!(matches!(ssat / 0_i64, NumOpResult::Error(_)));
        assert!(matches!(ssat / SignedAmount::ZERO, NumOpResult::Error(_)));
        assert!(matches!(ssat % 0_i64, NumOpResult::Error(_)));
    }

    #[test]
    fn test_amount_ratio() {
        let ratio = Amount::from_sat_u32(100) / Amount::from_sat_u32(30);
        assert_eq!(ratio, NumOpResult::Valid(3));

        let ratio = SignedAmount::from_sat_i32(-100) / SignedAmount::from_sat_i32(30);
        assert_eq!(ratio, NumOpResult::Valid(-3));
    }

    #[test]
    fn test_div_by_non_zero() {
        let three = NonZeroU64::new(3).expect("3 is non-zero");
        assert_eq!(Amount::from_sat_u32(100) / three, Amount::from_sat_u32(33));

        // Integer division truncates toward zero.
        let three = NonZeroI64::new(3).expect("3 is non-zero");
        assert_eq!(SignedAmount::from_sat_i32(-100) / three, SignedAmount::from_sat_i32(-33));
    }

    #[test]
    fn test_neg_signed_amount() {
        assert_eq!(-SignedAmount::from_sat_i32(50), SignedAmount::from_sat_i32(-50));
        assert_eq!(-SignedAmount::from_sat_i32(-50), SignedAmount::from_sat_i32(50));
        assert_eq!(-SignedAmount::ZERO, SignedAmount::ZERO);
    }
}