1use core::num::{NonZeroI64, NonZeroU64};
6use core::ops;
7
8use NumOpResult as R;
9
10use super::{Amount, SignedAmount};
11use crate::internal_macros::{
12 impl_add_assign_for_results, impl_div_assign, impl_mul_assign, impl_sub_assign_for_results,
13};
14use crate::result::{MathOp, NumOpError, NumOpResult, OptionExt};
15
16impl From<Amount> for NumOpResult<Amount> {
17 #[inline]
18 fn from(a: Amount) -> Self { Self::Valid(a) }
19}
20impl From<&Amount> for NumOpResult<Amount> {
21 #[inline]
22 fn from(a: &Amount) -> Self { Self::Valid(*a) }
23}
24
25impl From<SignedAmount> for NumOpResult<SignedAmount> {
26 #[inline]
27 fn from(a: SignedAmount) -> Self { Self::Valid(a) }
28}
29impl From<&SignedAmount> for NumOpResult<SignedAmount> {
30 #[inline]
31 fn from(a: &SignedAmount) -> Self { Self::Valid(*a) }
32}
33
// By-value arithmetic operators for `Amount` and `SignedAmount`.
//
// Operations that can fail return a `NumOpResult` instead of panicking. Such
// failures include overflow, an out-of-range result, and division or remainder
// by zero. The `checked_*` result is converted with `valid_or_error(MathOp::..)`
// (from `OptionExt`). Presumably that maps `Some` to `Valid` and `None` to an
// error recording the given `MathOp`; confirm in `crate::result`.
//
// `impl_op_for_references!` (from `crate::internal_macros`) receives only the
// by-value impls written below. Presumably it also emits the `&T` operand
// variants: the `Sum<&Self>` impls later in this file rely on
// `Amount + &Amount`, which is not written out here. TODO confirm against the
// macro definition.
//
// NOTE(review): only plain `//` comments are used inside the invocation.
// `///` doc comments become `#[doc]` attributes, which the macro's input
// pattern may not accept. Check the macro before adding any.
crate::internal_macros::impl_op_for_references! {
    // ---- Amount (unsigned, `u64` scalars) ----

    impl ops::Add<Amount> for Amount {
        type Output = NumOpResult<Amount>;

        fn add(self, rhs: Amount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
    }
    impl ops::Add<NumOpResult<Amount>> for Amount {
        type Output = NumOpResult<Amount>;

        // Addition commutes, so this defers to `Amount + Amount`. An error in
        // `rhs` presumably passes through `and_then` unchanged (like
        // `Result::and_then`).
        fn add(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|a| a + self) }
    }

    impl ops::Sub<Amount> for Amount {
        type Output = NumOpResult<Amount>;

        fn sub(self, rhs: Amount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
    }
    impl ops::Sub<NumOpResult<Amount>> for Amount {
        type Output = NumOpResult<Amount>;

        // Subtraction does not commute, so `rhs` is unwrapped explicitly. An
        // error in `rhs` is returned as-is.
        fn sub(self, rhs: NumOpResult<Amount>) -> Self::Output {
            match rhs {
                R::Valid(amount) => self - amount,
                R::Error(_) => rhs,
            }
        }
    }

    impl ops::Mul<u64> for Amount {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: u64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<u64> for NumOpResult<Amount> {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
    }
    // Scalar-on-the-left forms, so both `amount * n` and `n * amount` compile.
    impl ops::Mul<Amount> for u64 {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: Amount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<NumOpResult<Amount>> for u64 {
        type Output = NumOpResult<Amount>;

        fn mul(self, rhs: NumOpResult<Amount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
    }

    impl ops::Div<u64> for Amount {
        type Output = NumOpResult<Amount>;

        // Presumably `Amount::checked_div` returns `None` for a zero divisor,
        // as `u64::checked_div` does, so `x / 0` is an error, not a panic.
        fn div(self, rhs: u64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
    }
    impl ops::Div<u64> for NumOpResult<Amount> {
        type Output = NumOpResult<Amount>;

        fn div(self, rhs: u64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
    }
    // Ratio of two amounts, returned as a plain `u64` (integer division, so
    // truncated). `u64::checked_div` returns `None` for a zero divisor, so
    // dividing by a zero amount is reported as an error.
    impl ops::Div<Amount> for Amount {
        type Output = NumOpResult<u64>;

        fn div(self, rhs: Amount) -> Self::Output {
            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
        }
    }
    // Dividing by a non-zero `u64` cannot fail, and the quotient never exceeds
    // the dividend. That is why this returns a bare `Amount` rather than a
    // `NumOpResult`.
    impl ops::Div<NonZeroU64> for Amount {
        type Output = Amount;

        fn div(self, rhs: NonZeroU64) -> Self::Output { Self::from_sat(self.to_sat() / rhs.get()).expect("construction after division cannot fail") }
    }
    impl ops::Rem<u64> for Amount {
        type Output = NumOpResult<Amount>;

        fn rem(self, modulus: u64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
    }
    impl ops::Rem<u64> for NumOpResult<Amount> {
        type Output = NumOpResult<Amount>;

        fn rem(self, modulus: u64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
    }

    // ---- SignedAmount (`i64` scalars); mirrors the `Amount` impls above ----

    impl ops::Add<SignedAmount> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn add(self, rhs: SignedAmount) -> Self::Output { self.checked_add(rhs).valid_or_error(MathOp::Add) }
    }
    impl ops::Add<NumOpResult<SignedAmount>> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        // Addition commutes, so this defers to `SignedAmount + SignedAmount`.
        fn add(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|a| a + self) }
    }

    impl ops::Sub<SignedAmount> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn sub(self, rhs: SignedAmount) -> Self::Output { self.checked_sub(rhs).valid_or_error(MathOp::Sub) }
    }
    impl ops::Sub<NumOpResult<SignedAmount>> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        // Subtraction does not commute, so `rhs` is unwrapped explicitly. An
        // error in `rhs` is returned as-is.
        fn sub(self, rhs: NumOpResult<SignedAmount>) -> Self::Output {
            match rhs {
                R::Valid(amount) => self - amount,
                R::Error(_) => rhs,
            }
        }
    }

    impl ops::Mul<i64> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: i64) -> Self::Output { self.checked_mul(rhs).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<i64> for NumOpResult<SignedAmount> {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs * rhs) }
    }
    // Scalar-on-the-left forms, so both `amount * n` and `n * amount` compile.
    impl ops::Mul<SignedAmount> for i64 {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: SignedAmount) -> Self::Output { rhs.checked_mul(self).valid_or_error(MathOp::Mul) }
    }
    impl ops::Mul<NumOpResult<SignedAmount>> for i64 {
        type Output = NumOpResult<SignedAmount>;

        fn mul(self, rhs: NumOpResult<SignedAmount>) -> Self::Output { rhs.and_then(|rhs| self * rhs) }
    }

    impl ops::Div<i64> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn div(self, rhs: i64) -> Self::Output { self.checked_div(rhs).valid_or_error(MathOp::Div) }
    }
    impl ops::Div<i64> for NumOpResult<SignedAmount> {
        type Output = NumOpResult<SignedAmount>;

        fn div(self, rhs: i64) -> Self::Output { self.and_then(|lhs| lhs / rhs) }
    }
    // Ratio of two signed amounts as an `i64` (integer division truncates
    // toward zero). `i64::checked_div` returns `None` for a zero divisor, which
    // becomes an error.
    impl ops::Div<SignedAmount> for SignedAmount {
        type Output = NumOpResult<i64>;

        fn div(self, rhs: SignedAmount) -> Self::Output {
            self.to_sat().checked_div(rhs.to_sat()).valid_or_error(MathOp::Div)
        }
    }
    // Non-zero divisor: |quotient| <= |dividend|, so the quotient is a valid
    // amount whenever the dividend was.
    // NOTE(review): `i64::MIN / -1` would overflow. This presumably relies on a
    // `SignedAmount` never holding `i64::MIN` sats; confirm against the range
    // check in `SignedAmount::from_sat`.
    impl ops::Div<NonZeroI64> for SignedAmount {
        type Output = SignedAmount;

        fn div(self, rhs: NonZeroI64) -> Self::Output { Self::from_sat(self.to_sat() / rhs.get()).expect("construction after division cannot fail") }
    }
    impl ops::Rem<i64> for SignedAmount {
        type Output = NumOpResult<SignedAmount>;

        fn rem(self, modulus: i64) -> Self::Output { self.checked_rem(modulus).valid_or_error(MathOp::Rem) }
    }
    impl ops::Rem<i64> for NumOpResult<SignedAmount> {
        type Output = NumOpResult<SignedAmount>;

        fn rem(self, modulus: i64) -> Self::Output { self.and_then(|lhs| lhs % modulus) }
    }
}
197
198impl_mul_assign!(NumOpResult<Amount>, u64);
199impl_mul_assign!(NumOpResult<SignedAmount>, i64);
200impl_div_assign!(NumOpResult<Amount>, u64);
201impl_div_assign!(NumOpResult<SignedAmount>, i64);
202
203impl_add_assign_for_results!(Amount);
204impl_add_assign_for_results!(SignedAmount);
205impl_sub_assign_for_results!(Amount);
206impl_sub_assign_for_results!(SignedAmount);
207
208impl ops::Neg for SignedAmount {
209 type Output = Self;
210
211 #[inline]
212 fn neg(self) -> Self::Output {
213 Self::from_sat(self.to_sat().neg()).expect("all +ve and -ve values are valid")
214 }
215}
216
217impl core::iter::Sum<Self> for NumOpResult<Amount> {
218 fn sum<I>(iter: I) -> Self
219 where
220 I: Iterator<Item = Self>,
221 {
222 iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
223 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
224 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
225 })
226 }
227}
228impl<'a> core::iter::Sum<&'a Self> for NumOpResult<Amount> {
229 fn sum<I>(iter: I) -> Self
230 where
231 I: Iterator<Item = &'a Self>,
232 {
233 iter.fold(Self::Valid(Amount::ZERO), |acc, amount| match (acc, amount) {
234 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
235 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
236 })
237 }
238}
239
240impl core::iter::Sum<Self> for NumOpResult<SignedAmount> {
241 fn sum<I>(iter: I) -> Self
242 where
243 I: Iterator<Item = Self>,
244 {
245 iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
246 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
247 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
248 })
249 }
250}
251impl<'a> core::iter::Sum<&'a Self> for NumOpResult<SignedAmount> {
252 fn sum<I>(iter: I) -> Self
253 where
254 I: Iterator<Item = &'a Self>,
255 {
256 iter.fold(Self::Valid(SignedAmount::ZERO), |acc, amount| match (acc, amount) {
257 (Self::Valid(lhs), Self::Valid(rhs)) => lhs + rhs,
258 (_, _) => Self::Error(NumOpError::while_doing(MathOp::Add)),
259 })
260 }
261}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_amount_results() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
            NumOpResult::Valid(Amount::from_sat_u32(300)),
        ];

        let sum: NumOpResult<Amount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(Amount::from_sat_u32(600)));
    }

    #[test]
    fn test_sum_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(Amount::from_sat_u32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(Amount::from_sat_u32(200)),
        ];

        let sum: NumOpResult<Amount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_sum_signed_amount_results() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_results_with_references() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(-50)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.iter().sum();
        assert_eq!(sum, NumOpResult::Valid(SignedAmount::from_sat_i32(250)));
    }

    #[test]
    fn test_sum_signed_amount_with_error_propagation() {
        let amounts = [
            NumOpResult::Valid(SignedAmount::from_sat_i32(100)),
            NumOpResult::Error(NumOpError::while_doing(MathOp::Add)),
            NumOpResult::Valid(SignedAmount::from_sat_i32(200)),
        ];

        let sum: NumOpResult<SignedAmount> = amounts.into_iter().sum();
        assert!(matches!(sum, NumOpResult::Error(_)));
    }

    #[test]
    fn test_op_assign_amount() {
        let sat = Amount::from_sat_u32(50);

        let mut res = sat + sat;
        res += Amount::from_sat_u32(50);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(150)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err;
        assert_eq!(res, add_err);

        let mut res = sat + sat;
        res -= Amount::from_sat_u32(20);
        assert_eq!(res, NumOpResult::Valid(Amount::from_sat_u32(80)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err;
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_signed_amount() {
        let ssat = SignedAmount::from_sat_i32(50);

        let mut res = ssat + ssat;
        res += SignedAmount::from_sat_i32(-30);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(70)));

        let add_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        res += add_err;
        assert_eq!(res, add_err);

        let mut res = ssat + ssat;
        res -= SignedAmount::from_sat_i32(25);
        assert_eq!(res, NumOpResult::Valid(SignedAmount::from_sat_i32(75)));

        let sub_err = NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        res -= sub_err;
        assert_eq!(res, sub_err);
    }

    #[test]
    fn test_op_assign_amount_error() {
        let mut res: NumOpResult<Amount> = NumOpResult::Error(NumOpError::while_doing(MathOp::Add));

        // Adding a valid amount to an error leaves an `Add` error.
        res += Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        // Adding an error to an error yields an `Add` error.
        res += NumOpResult::Error(NumOpError::while_doing(MathOp::Sub));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));

        // Subtracting from an error yields an error tagged with `Sub`.
        res -= Amount::from_sat_u32(10);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));

        // Subtracting an error from an error yields a `Sub` error.
        res -= NumOpResult::Error(NumOpError::while_doing(MathOp::Add));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));
    }

    // Division and remainder by zero go through `checked_*`, so they must
    // report an error instead of panicking.
    #[test]
    fn test_div_rem_by_zero_is_error() {
        let amount = Amount::from_sat_u32(10);
        assert!(matches!(amount / 0_u64, NumOpResult::Error(_)));
        assert!(matches!(amount % 0_u64, NumOpResult::Error(_)));
        assert!(matches!(amount / Amount::ZERO, NumOpResult::Error(_)));

        let signed = SignedAmount::from_sat_i32(-10);
        assert!(matches!(signed / 0_i64, NumOpResult::Error(_)));
        assert!(matches!(signed % 0_i64, NumOpResult::Error(_)));
        assert!(matches!(signed / SignedAmount::ZERO, NumOpResult::Error(_)));
    }

    // Amount-by-amount division yields a plain integer ratio (truncated).
    #[test]
    fn test_div_amount_by_amount() {
        let ratio = Amount::from_sat_u32(7) / Amount::from_sat_u32(2);
        assert!(matches!(ratio, NumOpResult::Valid(3)));

        let signed_ratio = SignedAmount::from_sat_i32(-7) / SignedAmount::from_sat_i32(2);
        assert!(matches!(signed_ratio, NumOpResult::Valid(-3)));
    }

    // Exercises the `expect`-guarded `NonZero*` division paths.
    #[test]
    fn test_div_by_non_zero() {
        let two = NonZeroU64::new(2).expect("2 is non-zero");
        assert_eq!(Amount::from_sat_u32(7) / two, Amount::from_sat_u32(3));

        let minus_two = NonZeroI64::new(-2).expect("-2 is non-zero");
        assert_eq!(SignedAmount::from_sat_i32(7) / minus_two, SignedAmount::from_sat_i32(-3));
    }

    #[test]
    fn test_scalar_mul_is_commutative() {
        let amount = Amount::from_sat_u32(5);
        assert_eq!(amount * 3_u64, 3_u64 * amount);
        assert_eq!(amount * 3_u64, NumOpResult::Valid(Amount::from_sat_u32(15)));

        let signed = SignedAmount::from_sat_i32(-5);
        assert_eq!(signed * 3_i64, 3_i64 * signed);
        assert_eq!(signed * 3_i64, NumOpResult::Valid(SignedAmount::from_sat_i32(-15)));
    }

    #[test]
    fn test_neg_signed_amount() {
        assert_eq!(-SignedAmount::from_sat_i32(5), SignedAmount::from_sat_i32(-5));
        assert_eq!(-SignedAmount::from_sat_i32(-5), SignedAmount::from_sat_i32(5));
        assert_eq!(-SignedAmount::ZERO, SignedAmount::ZERO);
    }
}