Skip to main content

bitcoin_units/
result.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Provides a monadic type returned by mathematical operations (`core::ops`).
4
5use core::convert::Infallible;
6use core::{fmt, ops};
7
8#[cfg(feature = "arbitrary")]
9use arbitrary::{Arbitrary, Unstructured};
10use NumOpResult as R;
11
12use crate::{Amount, FeeRate, SignedAmount, Weight};
13
14#[rustfmt::skip]                // Keep public re-exports separate.
15#[doc(no_inline)]
16pub use self::error::NumOpError;
17
18/// Result of a mathematical operation on two numeric types.
19///
20/// In order to prevent overflow we provide a custom result type that is similar to the normal
21/// [`core::result::Result`] but implements mathematical operations (e.g. [`core::ops::Add`]) so that
22/// math operations can be chained ergonomically. This is very similar to how `NaN` works.
23///
24/// `NumOpResult` is a monadic type that contains `Valid` and `Error` (similar to `Ok` and `Err`).
25/// It supports a subset of functions similar to `Result` (e.g. `unwrap`).
26///
27/// # Examples
28///
29/// The `NumOpResult` type provides protection against overflow and div-by-zero.
30///
31/// ### Overflow protection
32///
33/// ```
34/// # use bitcoin_units::{amount, Amount};
35/// // Example UTXO value.
36/// let a1 = Amount::from_sat(1_000_000)?;
37/// // And another value from some other UTXO.
38/// let a2 = Amount::from_sat(765_432)?;
39/// // Just an example (typically one would calculate fee using weight and fee rate).
40/// let fee = Amount::from_sat(1_00)?;
41/// // The amount we want to send.
42/// let spend = Amount::from_sat(1_200_000)?;
43///
44/// // We can error if the change calculation overflows.
45/// //
46/// // For example if the `spend` value comes from the user and the `change` value is later
47/// // used then overflow here could be an attack vector.
48/// let _change = (a1 + a2 - spend - fee).into_result().expect("handle this error");
49///
50/// // Or if we control all the values and know they are sane we can just `unwrap`.
51/// let _change = (a1 + a2 - spend - fee).unwrap();
52/// // `NumOpResult` also implements `expect`.
53/// let _change = (a1 + a2 - spend - fee).expect("we know values don't overflow");
54/// # Ok::<_, amount::OutOfRangeError>(())
55/// ```
56///
57/// ### Divide-by-zero (overflow in `Div` or `Rem`)
58///
59/// In some instances one may wish to differentiate div-by-zero from overflow.
60///
61/// ```
62/// # use bitcoin_units::{Amount, FeeRate, NumOpResult, result::NumOpError};
63/// // Two amounts that will be added to calculate the max fee.
64/// let a = Amount::from_sat(123).expect("valid amount");
65/// let b = Amount::from_sat(467).expect("valid amount");
66/// // Fee rate for transaction.
67/// let fee_rate = FeeRate::from_sat_per_vb(1);
68///
69/// // Somewhat contrived example to show addition operator chained with division.
70/// let max_fee = a + b;
71/// let _fee = match max_fee / fee_rate {
72///     NumOpResult::Valid(fee) => fee,
73///     NumOpResult::Error(e) if e.is_div_by_zero() => {
74///         // Do something when div by zero.
75///         return Err(e);
76///     },
77///     NumOpResult::Error(e) => {
78///         // We separate div-by-zero from overflow in case it needs to be handled separately.
79///         //
80///         // This branch could be hit since `max_fee` came from some previous calculation. And if
81///         // an input to that calculation was from the user then overflow could be an attack vector.
82///         return Err(e);
83///     }
84/// };
85/// # Ok::<_, NumOpError>(())
86/// ```
87#[derive(Debug, Copy, Clone, PartialEq, Eq)]
88#[must_use]
89pub enum NumOpResult<T> {
90    /// Result of a successful mathematical operation.
91    Valid(T),
92    /// Result of an unsuccessful mathematical operation.
93    Error(NumOpError),
94}
95
96impl<T> NumOpResult<T> {
97    /// Maps a `NumOpResult<T>` to `NumOpResult<U>` by applying a function to a
98    /// contained [`NumOpResult::Valid`] value, leaving a [`NumOpResult::Error`] value untouched.
99    #[inline]
100    pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> NumOpResult<U> {
101        match self {
102            Self::Valid(t) => NumOpResult::Valid(op(t)),
103            Self::Error(e) => NumOpResult::Error(e),
104        }
105    }
106}
107
108impl<T: fmt::Debug> NumOpResult<T> {
109    /// Returns the contained valid numeric type, consuming `self`.
110    ///
111    /// # Panics
112    ///
113    /// Panics with `msg` if the numeric result is an `Error`.
114    #[inline]
115    #[track_caller]
116    pub fn expect(self, msg: &str) -> T {
117        match self {
118            Self::Valid(x) => x,
119            Self::Error(_) => panic!("{}", msg),
120        }
121    }
122
123    /// Returns the contained valid numeric type, consuming `self`.
124    ///
125    /// # Panics
126    ///
127    /// Panics if the numeric result is an `Error`.
128    #[inline]
129    #[track_caller]
130    pub fn unwrap(self) -> T {
131        match self {
132            Self::Valid(x) => x,
133            Self::Error(e) => panic!("tried to unwrap an invalid numeric result: {:?}", e),
134        }
135    }
136
137    /// Returns the contained error, consuming `self`.
138    ///
139    /// # Panics
140    ///
141    /// Panics if the numeric result is valid.
142    #[inline]
143    #[track_caller]
144    pub fn unwrap_err(self) -> NumOpError {
145        match self {
146            Self::Error(e) => e,
147            Self::Valid(a) => panic!("tried to unwrap a valid numeric result: {:?}", a),
148        }
149    }
150
151    /// Returns the contained Some value or a provided default.
152    ///
153    /// Arguments passed to `unwrap_or` are eagerly evaluated; if you are passing the result of a
154    /// function call, it is recommended to use `unwrap_or_else`, which is lazily evaluated.
155    #[inline]
156    #[track_caller]
157    pub fn unwrap_or(self, default: T) -> T {
158        match self {
159            Self::Valid(x) => x,
160            Self::Error(_) => default,
161        }
162    }
163
164    /// Returns the contained `Some` value or computes it from a closure.
165    #[inline]
166    #[track_caller]
167    pub fn unwrap_or_else<F>(self, f: F) -> T
168    where
169        F: FnOnce() -> T,
170    {
171        match self {
172            Self::Valid(x) => x,
173            Self::Error(_) => f(),
174        }
175    }
176
177    /// Converts this `NumOpResult` to an `Option<T>`.
178    #[inline]
179    pub fn ok(self) -> Option<T> {
180        match self {
181            Self::Valid(x) => Some(x),
182            Self::Error(_) => None,
183        }
184    }
185
186    /// Converts this `NumOpResult` to a `Result<T, NumOpError>`.
187    #[inline]
188    #[allow(clippy::missing_errors_doc)]
189    pub fn into_result(self) -> Result<T, NumOpError> {
190        match self {
191            Self::Valid(x) => Ok(x),
192            Self::Error(e) => Err(e),
193        }
194    }
195
196    /// Calls `op` if the numeric result is `Valid`, otherwise returns the `Error` value of `self`.
197    #[inline]
198    pub fn and_then<F>(self, op: F) -> Self
199    where
200        F: FnOnce(T) -> Self,
201    {
202        match self {
203            Self::Valid(x) => op(x),
204            Self::Error(e) => Self::Error(e),
205        }
206    }
207
208    /// Returns `true` if the numeric result is valid.
209    #[inline]
210    pub fn is_valid(&self) -> bool {
211        match self {
212            Self::Valid(_) => true,
213            Self::Error(_) => false,
214        }
215    }
216
217    /// Returns `true` if the numeric result is invalid.
218    #[inline]
219    pub fn is_error(&self) -> bool { !self.is_valid() }
220}
221
// Implement Add/Sub on NumOpResults for all wrapped types that already implement Add/Sub on themselves
//
// NOTE(review): the grammar of `impl_op_for_references!` lives in `crate::internal_macros` and is
// not visible here; judging by its name it also generates the by-reference variants (`&a + &b`,
// ...) — confirm there. Only `//` comments are used inside the invocation so that no `#[doc]`
// attributes reach the macro's matcher.
crate::internal_macros::impl_op_for_references! {
    // `NumOpResult + NumOpResult`: when both sides are `Valid`, the wrapped values are added with
    // `T`'s own `Add` (which itself returns a `NumOpResult`). If either side is an `Error`, the
    // result is a fresh `Add` error: the `MathOp` recorded in an incoming error (e.g. an earlier
    // `Div`) is not carried forward.
    impl<T> ops::Add<NumOpResult<T>> for NumOpResult<T>
    where
        (T: Copy + ops::Add<Output = NumOpResult<T>>)
    {
        type Output = NumOpResult<T>;

        fn add(self, rhs: Self) -> Self::Output {
            match (self, rhs) {
                (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
                (_, _) => R::Error(NumOpError::while_doing(MathOp::Add)),
            }
        }
    }

    // `NumOpResult + T`: addition commutes, so this delegates to `T + NumOpResult<T>`, which the
    // bound requires and which is presumably implemented per wrapped type elsewhere.
    // NOTE(review): how that impl treats an `Error` operand is not visible here — confirm it
    // matches the `NumOpResult + NumOpResult` impl above.
    impl<T> ops::Add<T> for NumOpResult<T>
    where
        (T: Copy + ops::Add<NumOpResult<T>, Output = NumOpResult<T>>)
    {
        type Output = NumOpResult<T>;

        fn add(self, rhs: T) -> Self::Output { rhs + self }
    }

    // `NumOpResult - NumOpResult`: subtracts the wrapped values when both sides are `Valid`;
    // otherwise yields a fresh `Sub` error (any incoming error's `MathOp` is replaced).
    impl<T> ops::Sub<NumOpResult<T>> for NumOpResult<T>
    where
        (T: Copy + ops::Sub<Output = NumOpResult<T>>)
    {
        type Output = NumOpResult<T>;

        fn sub(self, rhs: Self) -> Self::Output {
            match (self, rhs) {
                (R::Valid(lhs), R::Valid(rhs)) => lhs - rhs,
                (_, _) => R::Error(NumOpError::while_doing(MathOp::Sub)),
            }
        }
    }

    // `NumOpResult - T`: subtracts `rhs` from a `Valid` value using `T`'s own `Sub`. Unlike the
    // `NumOpResult - NumOpResult` impl, an `Error` on the left is returned unchanged, so its
    // original `MathOp` is preserved.
    // NOTE(review): this asymmetry with the impl above looks unintentional — confirm.
    impl<T> ops::Sub<T> for NumOpResult<T>
    where
        (T: Copy + ops::Sub<Output = NumOpResult<T>>)
    {
        type Output = NumOpResult<T>;

        fn sub(self, rhs: T) -> Self::Output {
            match self {
                R::Valid(amount) => amount - rhs,
                R::Error(_) => self,
            }
        }
    }
}
275
276// Implement AddAssign on NumOpResults for all wrapped types that already implement AddAssign on themselves
277impl<T: ops::AddAssign> ops::AddAssign<T> for NumOpResult<T> {
278    #[inline]
279    fn add_assign(&mut self, rhs: T) {
280        if let Self::Valid(ref mut lhs) = self {
281            *lhs += rhs;
282        }
283    }
284}
285
286impl<T: ops::AddAssign + Copy> ops::AddAssign<Self> for NumOpResult<T> {
287    #[inline]
288    fn add_assign(&mut self, rhs: Self) {
289        match (&self, rhs) {
290            (Self::Valid(_), Self::Valid(rhs)) => *self += rhs,
291            (_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Add)),
292        }
293    }
294}
295
296// Implement SubAssign on NumOpResults for all wrapped types that already implement SubAssign on themselves
297impl<T: ops::SubAssign> ops::SubAssign<T> for NumOpResult<T> {
298    #[inline]
299    fn sub_assign(&mut self, rhs: T) {
300        if let Self::Valid(ref mut lhs) = self {
301            *lhs -= rhs;
302        }
303    }
304}
305
306impl<T: ops::SubAssign + Copy> ops::SubAssign<Self> for NumOpResult<T> {
307    #[inline]
308    fn sub_assign(&mut self, rhs: Self) {
309        match (&self, rhs) {
310            (Self::Valid(_), Self::Valid(rhs)) => *self -= rhs,
311            (_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Sub)),
312        }
313    }
314}
315
316pub(crate) trait OptionExt<T> {
317    fn valid_or_error(self, op: MathOp) -> NumOpResult<T>;
318}
319
320macro_rules! impl_opt_ext {
321    ($($ty:ident),* $(,)?) => {
322        $(
323            impl OptionExt<$ty> for Option<$ty> {
324                #[inline]
325                fn valid_or_error(self, op: MathOp) -> NumOpResult<$ty> {
326                    match self {
327                        Some(amount) => R::Valid(amount),
328                        None => R::Error(NumOpError(op)),
329                    }
330                }
331            }
332        )*
333    }
334}
335impl_opt_ext!(Amount, SignedAmount, u64, i64, FeeRate, Weight);
336
/// The math operation that caused the error.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MathOp {
    /// Addition failed ([`core::ops::Add`] resulted in an invalid value).
    Add,
    /// Subtraction failed ([`core::ops::Sub`] resulted in an invalid value).
    Sub,
    /// Multiplication failed ([`core::ops::Mul`] resulted in an invalid value).
    Mul,
    /// Division failed ([`core::ops::Div`] attempted div-by-zero).
    Div,
    /// Calculating the remainder failed ([`core::ops::Rem`] attempted div-by-zero).
    Rem,
    /// Negation failed ([`core::ops::Neg`] resulted in an invalid value).
    Neg,
    /// Stops users from casting this enum to an integer.
    // May get removed if one day Rust supports disabling casts natively.
    #[doc(hidden)]
    _DoNotUse(Infallible),
}

impl MathOp {
    /// Returns `true` if this operation error'ed due to overflow.
    ///
    /// This is the case for `Add`, `Sub`, `Mul` and `Neg`.
    #[inline]
    pub fn is_overflow(self) -> bool {
        matches!(self, Self::Add | Self::Sub | Self::Mul | Self::Neg)
    }

    /// Returns `true` if this operation error'ed due to division by zero.
    ///
    /// This is the case for `Div` and `Rem`. The variants are listed explicitly (rather than
    /// defined as `!is_overflow()`) so that a future variant of this `#[non_exhaustive]` enum is
    /// not silently classified as div-by-zero.
    #[inline]
    pub fn is_div_by_zero(self) -> bool { matches!(self, Self::Div | Self::Rem) }

    /// Returns `true` if this operation error'ed due to addition.
    #[inline]
    pub fn is_addition(self) -> bool { self == Self::Add }

    /// Returns `true` if this operation error'ed due to subtraction.
    #[inline]
    pub fn is_subtraction(self) -> bool { self == Self::Sub }

    /// Returns `true` if this operation error'ed due to multiplication.
    #[inline]
    pub fn is_multiplication(self) -> bool { self == Self::Mul }

    /// Returns `true` if this operation error'ed due to negation.
    #[inline]
    pub fn is_negation(self) -> bool { self == Self::Neg }
}

impl fmt::Display for MathOp {
    /// Writes the short lowercase name of the operation (e.g. `add`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match *self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Rem => "rem",
            Self::Neg => "neg",
            Self::_DoNotUse(infallible) => match infallible {},
        };
        f.write_str(name)
    }
}
399
400/// Error types for mathematical operations.
401pub mod error {
402    use core::convert::Infallible;
403    use core::fmt;
404
405    use super::MathOp;
406
407    /// Error returned when a mathematical operation fails.
408    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
409    #[non_exhaustive]
410    pub struct NumOpError(pub(super) MathOp);
411
412    impl NumOpError {
413        /// Constructs a [`NumOpError`] caused by `op`.
414        #[inline]
415        pub(crate) const fn while_doing(op: MathOp) -> Self { Self(op) }
416
417        /// Returns `true` if this operation error'ed due to overflow.
418        #[inline]
419        pub fn is_overflow(self) -> bool { self.0.is_overflow() }
420
421        /// Returns `true` if this operation error'ed due to division by zero.
422        #[inline]
423        pub fn is_div_by_zero(self) -> bool { self.0.is_div_by_zero() }
424
425        /// Returns the [`MathOp`] that caused this error.
426        #[inline]
427        pub fn operation(self) -> MathOp { self.0 }
428    }
429
430    impl From<Infallible> for NumOpError {
431        fn from(never: Infallible) -> Self { match never {} }
432    }
433
434    impl fmt::Display for NumOpError {
435        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
436            write!(f, "math operation '{}' gave an invalid numeric result", self.operation())
437        }
438    }
439
440    #[cfg(feature = "std")]
441    impl std::error::Error for NumOpError {
442        #[inline]
443        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None }
444    }
445}
446
#[cfg(feature = "arbitrary")]
impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NumOpResult<T> {
    /// Generates either a `Valid` value or an `Error` with an arbitrary [`MathOp`].
    ///
    /// A leading `bool` selects the variant, so the valid and the error path are both exercised.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        if bool::arbitrary(u)? {
            Ok(Self::Valid(T::arbitrary(u)?))
        } else {
            let op = MathOp::arbitrary(u)?;
            Ok(Self::Error(NumOpError::while_doing(op)))
        }
    }
}
456
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for MathOp {
    /// Picks one of the six real operations; the uninhabited `_DoNotUse` is never produced.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let op = match u.int_in_range(0..=5)? {
            0 => Self::Add,
            1 => Self::Sub,
            2 => Self::Mul,
            3 => Self::Div,
            4 => Self::Rem,
            _ => Self::Neg,
        };
        Ok(op)
    }
}
471
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use alloc::string::ToString;
    #[cfg(feature = "std")]
    use std::error::Error;

    use crate::result::{MathOp, NumOpError, NumOpResult};
    use crate::{Amount, FeeRate, Weight};

    #[test]
    fn mathop_predicates() {
        assert!(MathOp::Add.is_overflow());
        assert!(MathOp::Sub.is_overflow());
        assert!(MathOp::Mul.is_overflow());
        assert!(MathOp::Neg.is_overflow());
        assert!(!MathOp::Div.is_overflow());
        assert!(!MathOp::Rem.is_overflow());

        assert!(MathOp::Div.is_div_by_zero());
        assert!(MathOp::Rem.is_div_by_zero());
        assert!(!MathOp::Add.is_div_by_zero());
        assert!(!MathOp::Sub.is_div_by_zero());
        assert!(!MathOp::Mul.is_div_by_zero());
        assert!(!MathOp::Neg.is_div_by_zero());

        assert!(MathOp::Add.is_addition());
        assert!(!MathOp::Sub.is_addition());

        assert!(MathOp::Sub.is_subtraction());
        assert!(!MathOp::Add.is_subtraction());

        assert!(MathOp::Mul.is_multiplication());
        assert!(!MathOp::Div.is_multiplication());

        assert!(MathOp::Neg.is_negation());
        assert!(!MathOp::Add.is_negation());
    }

    #[test]
    fn mathop_map() {
        // op is evaluated for valid results
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.map(|val| (val / FeeRate::from_sat_per_kwu(10)).unwrap());
        assert_eq!(new_value, NumOpResult::Valid(Weight::from_wu(10_000)));

        // op is not evaluated for error results
        let res = NumOpResult::<Weight>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.map(|_| {
            panic!("map should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }

    #[test]
    fn mathop_expect() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(
                NumOpResult::Valid(amount).expect("unreachable"),
                NumOpResult::Valid(amount).unwrap(),
            );
            assert_eq!(NumOpResult::Valid(amount).expect("unreachable"), amount);
        }
    }

    #[test]
    #[should_panic(expected = "test error message")]
    fn mathop_expect_panics_on_error() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add))
            .expect("test error message");
    }

    #[test]
    fn mathop_unwrap() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(NumOpResult::Valid(amount).unwrap(), amount);
        }
        let weights = [Weight::from_wu(0), Weight::from_wu(16_384_000), Weight::from_wu(u64::MAX)];
        for weight in weights {
            assert_eq!(NumOpResult::Valid(weight).unwrap(), weight);
        }
    }

    // Pin the actual panic message: `expected = ""` would accept any panic at all.
    #[test]
    #[should_panic(expected = "tried to unwrap an invalid numeric result")]
    fn mathop_unwrap_panics_on_err() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add)).unwrap();
    }

    #[test]
    fn mathop_unwrap_err() {
        let errs = [
            NumOpError::while_doing(MathOp::Add),
            NumOpError::while_doing(MathOp::Sub),
            NumOpError::while_doing(MathOp::Mul),
            NumOpError::while_doing(MathOp::Div),
            NumOpError::while_doing(MathOp::Neg),
            NumOpError::while_doing(MathOp::Rem),
        ];
        for err in errs {
            assert_eq!(NumOpResult::<Amount>::Error(err).unwrap_err(), err);
        }
    }

    // Pin the actual panic message: `expected = ""` would accept any panic at all.
    #[test]
    #[should_panic(expected = "tried to unwrap a valid numeric result")]
    fn mathop_unwrap_err_panics_on_valid() {
        let value = Amount::from_sat_u32(150);
        NumOpResult::<Amount>::Valid(value).unwrap_err();
    }

    #[test]
    fn mathop_unwrap_or() {
        let base_amount = Amount::from_sat_u32(100);

        // default is returned for error results
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or(base_amount);
        assert_eq!(res_default, base_amount);

        // wrapped value is returned for valid results
        let res = NumOpResult::Valid(base_amount);
        let new_amount = res.unwrap_or(Amount::from_sat_u32(50));
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_unwrap_or_else() {
        let base_amount = Amount::from_sat_u32(100);

        // op is evaluated for error results
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or_else(|| base_amount);
        assert_eq!(res_default, base_amount);

        // op is not evaluated for valid results
        let res = NumOpResult::<Amount>::Valid(base_amount);
        let new_amount = res.unwrap_or_else(|| {
            panic!("unwrap_or_else should not evaluate for wrapped valid values");
        });
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_ok() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).ok(), Some(amt));

        let err = NumOpError::while_doing(MathOp::Add);
        assert_eq!(NumOpResult::<Amount>::Error(err).ok(), None);
    }

    #[test]
    fn mathop_into_result_and_predicates() {
        let amt = Amount::from_sat_u32(42);
        let valid = NumOpResult::Valid(amt);
        assert!(valid.is_valid());
        assert!(!valid.is_error());
        assert_eq!(valid.into_result(), Ok(amt));

        let err = NumOpError::while_doing(MathOp::Sub);
        let invalid = NumOpResult::<Amount>::Error(err);
        assert!(invalid.is_error());
        assert!(!invalid.is_valid());
        assert_eq!(invalid.into_result(), Err(err));
    }

    #[test]
    fn mathop_and_then() {
        // op is evaluated for valid results
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.and_then(|val| val + Amount::from_sat_u32(50));
        assert_eq!(new_value, NumOpResult::Valid(Amount::from_sat_u32(150)));

        // op is not evaluated for error results
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.and_then(|_| {
            panic!("and_then should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn error_display_is_non_empty() {
        // NumOpError - math operation error
        let e = (Amount::MAX + Amount::MAX).unwrap_err();
        assert!(!e.to_string().is_empty());
        #[cfg(feature = "std")]
        assert!(e.source().is_none());
    }
}