1use core::convert::Infallible;
6use core::{fmt, ops};
7
8#[cfg(feature = "arbitrary")]
9use arbitrary::{Arbitrary, Unstructured};
10use NumOpResult as R;
11
12use crate::{Amount, FeeRate, SignedAmount, Weight};
13
14#[derive(Debug, Copy, Clone, PartialEq, Eq)]
84#[must_use]
85pub enum NumOpResult<T> {
86 Valid(T),
88 Error(NumOpError),
90}
91
92impl<T> NumOpResult<T> {
93 #[inline]
96 pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> NumOpResult<U> {
97 match self {
98 Self::Valid(t) => NumOpResult::Valid(op(t)),
99 Self::Error(e) => NumOpResult::Error(e),
100 }
101 }
102}
103
104impl<T: fmt::Debug> NumOpResult<T> {
105 #[inline]
111 #[track_caller]
112 pub fn expect(self, msg: &str) -> T {
113 match self {
114 Self::Valid(x) => x,
115 Self::Error(_) => panic!("{}", msg),
116 }
117 }
118
119 #[inline]
125 #[track_caller]
126 pub fn unwrap(self) -> T {
127 match self {
128 Self::Valid(x) => x,
129 Self::Error(e) => panic!("tried to unwrap an invalid numeric result: {:?}", e),
130 }
131 }
132
133 #[inline]
139 #[track_caller]
140 pub fn unwrap_err(self) -> NumOpError {
141 match self {
142 Self::Error(e) => e,
143 Self::Valid(a) => panic!("tried to unwrap a valid numeric result: {:?}", a),
144 }
145 }
146
147 #[inline]
152 #[track_caller]
153 pub fn unwrap_or(self, default: T) -> T {
154 match self {
155 Self::Valid(x) => x,
156 Self::Error(_) => default,
157 }
158 }
159
160 #[inline]
162 #[track_caller]
163 pub fn unwrap_or_else<F>(self, f: F) -> T
164 where
165 F: FnOnce() -> T,
166 {
167 match self {
168 Self::Valid(x) => x,
169 Self::Error(_) => f(),
170 }
171 }
172
173 #[inline]
175 pub fn ok(self) -> Option<T> {
176 match self {
177 Self::Valid(x) => Some(x),
178 Self::Error(_) => None,
179 }
180 }
181
182 #[inline]
184 #[allow(clippy::missing_errors_doc)]
185 pub fn into_result(self) -> Result<T, NumOpError> {
186 match self {
187 Self::Valid(x) => Ok(x),
188 Self::Error(e) => Err(e),
189 }
190 }
191
192 #[inline]
194 pub fn and_then<F>(self, op: F) -> Self
195 where
196 F: FnOnce(T) -> Self,
197 {
198 match self {
199 Self::Valid(x) => op(x),
200 Self::Error(e) => Self::Error(e),
201 }
202 }
203
204 #[inline]
206 pub fn is_valid(&self) -> bool {
207 match self {
208 Self::Valid(_) => true,
209 Self::Error(_) => false,
210 }
211 }
212
213 #[inline]
215 pub fn is_error(&self) -> bool { !self.is_valid() }
216}
217
218crate::internal_macros::impl_op_for_references! {
220 impl<T> ops::Add<NumOpResult<T>> for NumOpResult<T>
221 where
222 (T: Copy + ops::Add<Output = NumOpResult<T>>)
223 {
224 type Output = NumOpResult<T>;
225
226 fn add(self, rhs: Self) -> Self::Output {
227 match (self, rhs) {
228 (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
229 (_, _) => R::Error(NumOpError::while_doing(MathOp::Add)),
230 }
231 }
232 }
233
234 impl<T> ops::Add<T> for NumOpResult<T>
235 where
236 (T: Copy + ops::Add<NumOpResult<T>, Output = NumOpResult<T>>)
237 {
238 type Output = NumOpResult<T>;
239
240 fn add(self, rhs: T) -> Self::Output { rhs + self }
241 }
242
243 impl<T> ops::Sub<NumOpResult<T>> for NumOpResult<T>
244 where
245 (T: Copy + ops::Sub<Output = NumOpResult<T>>)
246 {
247 type Output = NumOpResult<T>;
248
249 fn sub(self, rhs: Self) -> Self::Output {
250 match (self, rhs) {
251 (R::Valid(lhs), R::Valid(rhs)) => lhs - rhs,
252 (_, _) => R::Error(NumOpError::while_doing(MathOp::Sub)),
253 }
254 }
255 }
256
257 impl<T> ops::Sub<T> for NumOpResult<T>
258 where
259 (T: Copy + ops::Sub<Output = NumOpResult<T>>)
260 {
261 type Output = NumOpResult<T>;
262
263 fn sub(self, rhs: T) -> Self::Output {
264 match self {
265 R::Valid(amount) => amount - rhs,
266 R::Error(_) => self,
267 }
268 }
269 }
270}
271
272impl<T: ops::AddAssign> ops::AddAssign<T> for NumOpResult<T> {
274 fn add_assign(&mut self, rhs: T) {
275 if let Self::Valid(ref mut lhs) = self {
276 *lhs += rhs;
277 }
278 }
279}
280
281impl<T: ops::AddAssign + Copy> ops::AddAssign<Self> for NumOpResult<T> {
282 fn add_assign(&mut self, rhs: Self) {
283 match (&self, rhs) {
284 (Self::Valid(_), Self::Valid(rhs)) => *self += rhs,
285 (_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Add)),
286 }
287 }
288}
289
290impl<T: ops::SubAssign> ops::SubAssign<T> for NumOpResult<T> {
292 fn sub_assign(&mut self, rhs: T) {
293 if let Self::Valid(ref mut lhs) = self {
294 *lhs -= rhs;
295 }
296 }
297}
298
299impl<T: ops::SubAssign + Copy> ops::SubAssign<Self> for NumOpResult<T> {
300 fn sub_assign(&mut self, rhs: Self) {
301 match (&self, rhs) {
302 (Self::Valid(_), Self::Valid(rhs)) => *self -= rhs,
303 (_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Sub)),
304 }
305 }
306}
307
308pub(crate) trait OptionExt<T> {
309 fn valid_or_error(self, op: MathOp) -> NumOpResult<T>;
310}
311
312macro_rules! impl_opt_ext {
313 ($($ty:ident),* $(,)?) => {
314 $(
315 impl OptionExt<$ty> for Option<$ty> {
316 #[inline]
317 fn valid_or_error(self, op: MathOp) -> NumOpResult<$ty> {
318 match self {
319 Some(amount) => R::Valid(amount),
320 None => R::Error(NumOpError(op)),
321 }
322 }
323 }
324 )*
325 }
326}
327impl_opt_ext!(Amount, SignedAmount, u64, i64, FeeRate, Weight);
328
329#[derive(Debug, Copy, Clone, PartialEq, Eq)]
331#[non_exhaustive]
332pub struct NumOpError(MathOp);
333
334impl NumOpError {
335 pub(crate) const fn while_doing(op: MathOp) -> Self { Self(op) }
337
338 pub fn is_overflow(self) -> bool { self.0.is_overflow() }
340
341 pub fn is_div_by_zero(self) -> bool { self.0.is_div_by_zero() }
343
344 pub fn operation(self) -> MathOp { self.0 }
346}
347
348impl fmt::Display for NumOpError {
349 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
350 write!(f, "math operation '{}' gave an invalid numeric result", self.operation())
351 }
352}
353
354#[cfg(feature = "std")]
355impl std::error::Error for NumOpError {}
356
/// An arithmetic operation, as recorded in a [`NumOpError`].
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MathOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Remainder.
    Rem,
    /// Negation.
    Neg,
    /// Never constructible: the `Infallible` payload makes this variant uninhabited.
    #[doc(hidden)]
    _DoNotUse(Infallible),
}
378
379impl MathOp {
380 pub fn is_overflow(self) -> bool {
382 matches!(self, Self::Add | Self::Sub | Self::Mul | Self::Neg)
383 }
384
385 pub fn is_div_by_zero(self) -> bool { !self.is_overflow() }
387
388 pub fn is_addition(self) -> bool { self == Self::Add }
390
391 pub fn is_subtraction(self) -> bool { self == Self::Sub }
393
394 pub fn is_multiplication(self) -> bool { self == Self::Mul }
396
397 pub fn is_negation(self) -> bool { self == Self::Neg }
399}
400
401impl fmt::Display for MathOp {
402 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
403 match *self {
404 Self::Add => write!(f, "add"),
405 Self::Sub => write!(f, "sub"),
406 Self::Mul => write!(f, "mul"),
407 Self::Div => write!(f, "div"),
408 Self::Rem => write!(f, "rem"),
409 Self::Neg => write!(f, "neg"),
410 Self::_DoNotUse(infallible) => match infallible {},
411 }
412 }
413}
414
#[cfg(feature = "arbitrary")]
impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NumOpResult<T> {
    /// Draws a `bool` first: `true` produces a valid `T`, `false` an error for a random
    /// [`MathOp`].
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let valid = bool::arbitrary(u)?;
        if valid {
            T::arbitrary(u).map(NumOpResult::Valid)
        } else {
            let op = MathOp::arbitrary(u)?;
            Ok(Self::Error(NumOpError(op)))
        }
    }
}
424
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for MathOp {
    /// Picks one of the six public operations from an integer in `0..=5`.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let choice = u.int_in_range(0..=5)?;
        let op = match choice {
            0 => Self::Add,
            1 => Self::Sub,
            2 => Self::Mul,
            3 => Self::Div,
            4 => Self::Rem,
            _ => Self::Neg,
        };
        Ok(op)
    }
}
439
#[cfg(test)]
mod tests {
    use super::{MathOp, NumOpError, NumOpResult};
    use crate::{Amount, FeeRate, Weight};

    #[test]
    fn mathop_predicates() {
        assert!(MathOp::Add.is_overflow());
        assert!(MathOp::Sub.is_overflow());
        assert!(MathOp::Mul.is_overflow());
        assert!(MathOp::Neg.is_overflow());
        assert!(!MathOp::Div.is_overflow());
        assert!(!MathOp::Rem.is_overflow());

        assert!(MathOp::Div.is_div_by_zero());
        assert!(MathOp::Rem.is_div_by_zero());
        assert!(!MathOp::Add.is_div_by_zero());

        assert!(MathOp::Add.is_addition());
        assert!(!MathOp::Sub.is_addition());

        assert!(MathOp::Sub.is_subtraction());
        assert!(!MathOp::Add.is_subtraction());

        assert!(MathOp::Mul.is_multiplication());
        assert!(!MathOp::Div.is_multiplication());

        assert!(MathOp::Neg.is_negation());
        assert!(!MathOp::Add.is_negation());
    }

    #[test]
    fn numoperror_accessors() {
        let err = NumOpError::while_doing(MathOp::Div);
        assert_eq!(err.operation(), MathOp::Div);
        assert!(err.is_div_by_zero());
        assert!(!err.is_overflow());

        let err = NumOpError::while_doing(MathOp::Neg);
        assert_eq!(err.operation(), MathOp::Neg);
        assert!(err.is_overflow());
        assert!(!err.is_div_by_zero());
    }

    #[test]
    fn mathop_map() {
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.map(|val| (val / FeeRate::from_sat_per_kwu(10)).unwrap());
        assert_eq!(new_value, NumOpResult::Valid(Weight::from_wu(10_000)));

        let res = NumOpResult::<Weight>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.map(|_| {
            panic!("map should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }

    #[test]
    fn mathop_expect() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(
                NumOpResult::Valid(amount).expect("unreachable"),
                NumOpResult::Valid(amount).unwrap(),
            );
            assert_eq!(NumOpResult::Valid(amount).expect("unreachable"), amount);
        }
    }

    #[test]
    #[should_panic(expected = "test error message")]
    fn mathop_expect_panics_on_error() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add))
            .expect("test error message");
    }

    #[test]
    fn mathop_unwrap() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(NumOpResult::Valid(amount).unwrap(), amount);
        }
        let weights = [Weight::from_wu(0), Weight::from_wu(16_384_000), Weight::from_wu(u64::MAX)];
        for weight in weights {
            assert_eq!(NumOpResult::Valid(weight).unwrap(), weight);
        }
    }

    // Pin the message: an empty `expected` string matches any panic, including unrelated ones.
    #[test]
    #[should_panic(expected = "tried to unwrap an invalid numeric result")]
    fn mathop_unwrap_panics_on_err() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add)).unwrap();
    }

    #[test]
    fn mathop_unwrap_err() {
        let errs = [
            NumOpError::while_doing(MathOp::Add),
            NumOpError::while_doing(MathOp::Sub),
            NumOpError::while_doing(MathOp::Mul),
            NumOpError::while_doing(MathOp::Div),
            NumOpError::while_doing(MathOp::Neg),
            NumOpError::while_doing(MathOp::Rem),
        ];
        for err in errs {
            assert_eq!(NumOpResult::<Amount>::Error(err).unwrap_err(), err);
        }
    }

    // Pin the message: an empty `expected` string matches any panic, including unrelated ones.
    #[test]
    #[should_panic(expected = "tried to unwrap a valid numeric result")]
    fn mathop_unwrap_err_panics_on_valid() {
        let value = Amount::from_sat_u32(150);
        NumOpResult::<Amount>::Valid(value).unwrap_err();
    }

    #[test]
    fn mathop_unwrap_or() {
        let base_amount = Amount::from_sat_u32(100);

        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or(base_amount);
        assert_eq!(res_default, base_amount);

        let res = NumOpResult::Valid(base_amount);
        let new_amount = res.unwrap_or(Amount::from_sat_u32(50));
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_unwrap_or_else() {
        let base_amount = Amount::from_sat_u32(100);

        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or_else(|| base_amount);
        assert_eq!(res_default, base_amount);

        let res = NumOpResult::<Amount>::Valid(base_amount);
        let new_amount = res.unwrap_or_else(|| {
            panic!("unwrap_or_else should not evaluate for wrapped valid values");
        });
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_ok() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).ok(), Some(amt));

        let err = NumOpError::while_doing(MathOp::Add);
        assert_eq!(NumOpResult::<Amount>::Error(err).ok(), None);
    }

    #[test]
    fn mathop_into_result() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).into_result(), Ok(amt));

        let err = NumOpError::while_doing(MathOp::Sub);
        assert_eq!(NumOpResult::<Amount>::Error(err).into_result(), Err(err));
    }

    #[test]
    fn mathop_is_valid_is_error() {
        let valid = NumOpResult::Valid(Amount::from_sat_u32(1));
        assert!(valid.is_valid());
        assert!(!valid.is_error());

        let invalid = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Mul));
        assert!(invalid.is_error());
        assert!(!invalid.is_valid());
    }

    #[test]
    fn mathop_and_then() {
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.and_then(|val| val + Amount::from_sat_u32(50));
        assert_eq!(new_value, NumOpResult::Valid(Amount::from_sat_u32(150)));

        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.and_then(|_| {
            panic!("and_then should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }
}