Skip to main content

bitcoin_units/
result.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Provides a monadic type returned by mathematical operations (`core::ops`).
4
5use core::convert::Infallible;
6use core::{fmt, ops};
7
8#[cfg(feature = "arbitrary")]
9use arbitrary::{Arbitrary, Unstructured};
10use NumOpResult as R;
11
12use crate::{Amount, FeeRate, SignedAmount, Weight};
13
14/// Result of a mathematical operation on two numeric types.
15///
16/// In order to prevent overflow we provide a custom result type that is similar to the normal
17/// [`core::result::Result`] but implements mathematical operations (e.g. [`core::ops::Add`]) so that
18/// math operations can be chained ergonomically. This is very similar to how `NaN` works.
19///
20/// `NumOpResult` is a monadic type that contains `Valid` and `Error` (similar to `Ok` and `Err`).
21/// It supports a subset of functions similar to `Result` (e.g. `unwrap`).
22///
23/// # Examples
24///
25/// The `NumOpResult` type provides protection against overflow and div-by-zero.
26///
27/// ### Overflow protection
28///
29/// ```
30/// # use bitcoin_units::{amount, Amount};
31/// // Example UTXO value.
32/// let a1 = Amount::from_sat(1_000_000)?;
33/// // And another value from some other UTXO.
34/// let a2 = Amount::from_sat(765_432)?;
35/// // Just an example (typically one would calculate fee using weight and fee rate).
36/// let fee = Amount::from_sat(1_00)?;
37/// // The amount we want to send.
38/// let spend = Amount::from_sat(1_200_000)?;
39///
40/// // We can error if the change calculation overflows.
41/// //
42/// // For example if the `spend` value comes from the user and the `change` value is later
43/// // used then overflow here could be an attack vector.
44/// let _change = (a1 + a2 - spend - fee).into_result().expect("handle this error");
45///
46/// // Or if we control all the values and know they are sane we can just `unwrap`.
47/// let _change = (a1 + a2 - spend - fee).unwrap();
48/// // `NumOpResult` also implements `expect`.
49/// let _change = (a1 + a2 - spend - fee).expect("we know values don't overflow");
50/// # Ok::<_, amount::OutOfRangeError>(())
51/// ```
52///
53/// ### Divide-by-zero (overflow in `Div` or `Rem`)
54///
55/// In some instances one may wish to differentiate div-by-zero from overflow.
56///
57/// ```
58/// # use bitcoin_units::{Amount, FeeRate, NumOpResult, result::NumOpError};
59/// // Two amounts that will be added to calculate the max fee.
60/// let a = Amount::from_sat(123).expect("valid amount");
61/// let b = Amount::from_sat(467).expect("valid amount");
62/// // Fee rate for transaction.
63/// let fee_rate = FeeRate::from_sat_per_vb(1);
64///
65/// // Somewhat contrived example to show addition operator chained with division.
66/// let max_fee = a + b;
67/// let _fee = match max_fee / fee_rate {
68///     NumOpResult::Valid(fee) => fee,
69///     NumOpResult::Error(e) if e.is_div_by_zero() => {
70///         // Do something when div by zero.
71///         return Err(e);
72///     },
73///     NumOpResult::Error(e) => {
74///         // We separate div-by-zero from overflow in case it needs to be handled separately.
75///         //
76///         // This branch could be hit since `max_fee` came from some previous calculation. And if
77///         // an input to that calculation was from the user then overflow could be an attack vector.
78///         return Err(e);
79///     }
80/// };
81/// # Ok::<_, NumOpError>(())
82/// ```
83#[derive(Debug, Copy, Clone, PartialEq, Eq)]
84#[must_use]
85pub enum NumOpResult<T> {
86    /// Result of a successful mathematical operation.
87    Valid(T),
88    /// Result of an unsuccessful mathematical operation.
89    Error(NumOpError),
90}
91
92impl<T> NumOpResult<T> {
93    /// Maps a `NumOpResult<T>` to `NumOpResult<U>` by applying a function to a
94    /// contained [`NumOpResult::Valid`] value, leaving a [`NumOpResult::Error`] value untouched.
95    #[inline]
96    pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> NumOpResult<U> {
97        match self {
98            Self::Valid(t) => NumOpResult::Valid(op(t)),
99            Self::Error(e) => NumOpResult::Error(e),
100        }
101    }
102}
103
104impl<T: fmt::Debug> NumOpResult<T> {
105    /// Returns the contained valid numeric type, consuming `self`.
106    ///
107    /// # Panics
108    ///
109    /// Panics with `msg` if the numeric result is an `Error`.
110    #[inline]
111    #[track_caller]
112    pub fn expect(self, msg: &str) -> T {
113        match self {
114            Self::Valid(x) => x,
115            Self::Error(_) => panic!("{}", msg),
116        }
117    }
118
119    /// Returns the contained valid numeric type, consuming `self`.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the numeric result is an `Error`.
124    #[inline]
125    #[track_caller]
126    pub fn unwrap(self) -> T {
127        match self {
128            Self::Valid(x) => x,
129            Self::Error(e) => panic!("tried to unwrap an invalid numeric result: {:?}", e),
130        }
131    }
132
133    /// Returns the contained error, consuming `self`.
134    ///
135    /// # Panics
136    ///
137    /// Panics if the numeric result is valid.
138    #[inline]
139    #[track_caller]
140    pub fn unwrap_err(self) -> NumOpError {
141        match self {
142            Self::Error(e) => e,
143            Self::Valid(a) => panic!("tried to unwrap a valid numeric result: {:?}", a),
144        }
145    }
146
147    /// Returns the contained Some value or a provided default.
148    ///
149    /// Arguments passed to `unwrap_or` are eagerly evaluated; if you are passing the result of a
150    /// function call, it is recommended to use `unwrap_or_else`, which is lazily evaluated.
151    #[inline]
152    #[track_caller]
153    pub fn unwrap_or(self, default: T) -> T {
154        match self {
155            Self::Valid(x) => x,
156            Self::Error(_) => default,
157        }
158    }
159
160    /// Returns the contained `Some` value or computes it from a closure.
161    #[inline]
162    #[track_caller]
163    pub fn unwrap_or_else<F>(self, f: F) -> T
164    where
165        F: FnOnce() -> T,
166    {
167        match self {
168            Self::Valid(x) => x,
169            Self::Error(_) => f(),
170        }
171    }
172
173    /// Converts this `NumOpResult` to an `Option<T>`.
174    #[inline]
175    pub fn ok(self) -> Option<T> {
176        match self {
177            Self::Valid(x) => Some(x),
178            Self::Error(_) => None,
179        }
180    }
181
182    /// Converts this `NumOpResult` to a `Result<T, NumOpError>`.
183    #[inline]
184    #[allow(clippy::missing_errors_doc)]
185    pub fn into_result(self) -> Result<T, NumOpError> {
186        match self {
187            Self::Valid(x) => Ok(x),
188            Self::Error(e) => Err(e),
189        }
190    }
191
192    /// Calls `op` if the numeric result is `Valid`, otherwise returns the `Error` value of `self`.
193    #[inline]
194    pub fn and_then<F>(self, op: F) -> Self
195    where
196        F: FnOnce(T) -> Self,
197    {
198        match self {
199            Self::Valid(x) => op(x),
200            Self::Error(e) => Self::Error(e),
201        }
202    }
203
204    /// Returns `true` if the numeric result is valid.
205    #[inline]
206    pub fn is_valid(&self) -> bool {
207        match self {
208            Self::Valid(_) => true,
209            Self::Error(_) => false,
210        }
211    }
212
213    /// Returns `true` if the numeric result is invalid.
214    #[inline]
215    pub fn is_error(&self) -> bool { !self.is_valid() }
216}
217
218// Implement Add/Sub on NumOpResults for all wrapped types that already implement Add/Sub on themselves
219crate::internal_macros::impl_op_for_references! {
220    impl<T> ops::Add<NumOpResult<T>> for NumOpResult<T>
221    where
222        (T: Copy + ops::Add<Output = NumOpResult<T>>)
223    {
224        type Output = NumOpResult<T>;
225
226        fn add(self, rhs: Self) -> Self::Output {
227            match (self, rhs) {
228                (R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
229                (_, _) => R::Error(NumOpError::while_doing(MathOp::Add)),
230            }
231        }
232    }
233
234    impl<T> ops::Add<T> for NumOpResult<T>
235    where
236        (T: Copy + ops::Add<NumOpResult<T>, Output = NumOpResult<T>>)
237    {
238        type Output = NumOpResult<T>;
239
240        fn add(self, rhs: T) -> Self::Output { rhs + self }
241    }
242
243    impl<T> ops::Sub<NumOpResult<T>> for NumOpResult<T>
244    where
245        (T: Copy + ops::Sub<Output = NumOpResult<T>>)
246    {
247        type Output = NumOpResult<T>;
248
249        fn sub(self, rhs: Self) -> Self::Output {
250            match (self, rhs) {
251                (R::Valid(lhs), R::Valid(rhs)) => lhs - rhs,
252                (_, _) => R::Error(NumOpError::while_doing(MathOp::Sub)),
253            }
254        }
255    }
256
257    impl<T> ops::Sub<T> for NumOpResult<T>
258    where
259        (T: Copy + ops::Sub<Output = NumOpResult<T>>)
260    {
261        type Output = NumOpResult<T>;
262
263        fn sub(self, rhs: T) -> Self::Output {
264            match self {
265                R::Valid(amount) => amount - rhs,
266                R::Error(_) => self,
267            }
268        }
269    }
270}
271
272// Implement AddAssign on NumOpResults for all wrapped types that already implement AddAssign on themselves
273impl<T: ops::AddAssign> ops::AddAssign<T> for NumOpResult<T> {
274    fn add_assign(&mut self, rhs: T) {
275        if let Self::Valid(ref mut lhs) = self {
276            *lhs += rhs;
277        }
278    }
279}
280
281impl<T: ops::AddAssign + Copy> ops::AddAssign<Self> for NumOpResult<T> {
282    fn add_assign(&mut self, rhs: Self) {
283        match (&self, rhs) {
284            (Self::Valid(_), Self::Valid(rhs)) => *self += rhs,
285            (_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Add)),
286        }
287    }
288}
289
290// Implement SubAssign on NumOpResults for all wrapped types that already implement SubAssign on themselves
291impl<T: ops::SubAssign> ops::SubAssign<T> for NumOpResult<T> {
292    fn sub_assign(&mut self, rhs: T) {
293        if let Self::Valid(ref mut lhs) = self {
294            *lhs -= rhs;
295        }
296    }
297}
298
299impl<T: ops::SubAssign + Copy> ops::SubAssign<Self> for NumOpResult<T> {
300    fn sub_assign(&mut self, rhs: Self) {
301        match (&self, rhs) {
302            (Self::Valid(_), Self::Valid(rhs)) => *self -= rhs,
303            (_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Sub)),
304        }
305    }
306}
307
308pub(crate) trait OptionExt<T> {
309    fn valid_or_error(self, op: MathOp) -> NumOpResult<T>;
310}
311
312macro_rules! impl_opt_ext {
313    ($($ty:ident),* $(,)?) => {
314        $(
315            impl OptionExt<$ty> for Option<$ty> {
316                #[inline]
317                fn valid_or_error(self, op: MathOp) -> NumOpResult<$ty> {
318                    match self {
319                        Some(amount) => R::Valid(amount),
320                        None => R::Error(NumOpError(op)),
321                    }
322                }
323            }
324        )*
325    }
326}
327impl_opt_ext!(Amount, SignedAmount, u64, i64, FeeRate, Weight);
328
329/// Error returned when a mathematical operation fails.
330#[derive(Debug, Copy, Clone, PartialEq, Eq)]
331#[non_exhaustive]
332pub struct NumOpError(MathOp);
333
334impl NumOpError {
335    /// Constructs a [`NumOpError`] caused by `op`.
336    pub(crate) const fn while_doing(op: MathOp) -> Self { Self(op) }
337
338    /// Returns `true` if this operation error'ed due to overflow.
339    pub fn is_overflow(self) -> bool { self.0.is_overflow() }
340
341    /// Returns `true` if this operation error'ed due to division by zero.
342    pub fn is_div_by_zero(self) -> bool { self.0.is_div_by_zero() }
343
344    /// Returns the [`MathOp`] that caused this error.
345    pub fn operation(self) -> MathOp { self.0 }
346}
347
348impl fmt::Display for NumOpError {
349    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
350        write!(f, "math operation '{}' gave an invalid numeric result", self.operation())
351    }
352}
353
354#[cfg(feature = "std")]
355impl std::error::Error for NumOpError {}
356
/// The math operation that caused the error.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MathOp {
    /// Addition failed ([`core::ops::Add`] resulted in an invalid value).
    Add,
    /// Subtraction failed ([`core::ops::Sub`] resulted in an invalid value).
    Sub,
    /// Multiplication failed ([`core::ops::Mul`] resulted in an invalid value).
    Mul,
    /// Division failed ([`core::ops::Div`] attempted div-by-zero).
    Div,
    /// Calculating the remainder failed ([`core::ops::Rem`] attempted div-by-zero).
    Rem,
    /// Negation failed ([`core::ops::Neg`] resulted in an invalid value).
    Neg,
    /// Stops users from casting this enum to an integer.
    // May get removed if one day Rust supports disabling casts natively.
    #[doc(hidden)]
    _DoNotUse(Infallible),
}

impl MathOp {
    /// Returns `true` if this operation error'ed due to overflow.
    ///
    /// This is the case for [`MathOp::Add`], [`MathOp::Sub`], [`MathOp::Mul`] and [`MathOp::Neg`].
    pub fn is_overflow(self) -> bool {
        matches!(self, Self::Add | Self::Sub | Self::Mul | Self::Neg)
    }

    /// Returns `true` if this operation error'ed due to division by zero.
    ///
    /// This is the case for [`MathOp::Div`] and [`MathOp::Rem`].
    // The variants are listed explicitly rather than defined as "not overflow". A variant added
    // later to this `#[non_exhaustive]` enum must then be classified deliberately, instead of
    // silently counting as a division by zero.
    pub fn is_div_by_zero(self) -> bool { matches!(self, Self::Div | Self::Rem) }

    /// Returns `true` if this operation error'ed due to addition.
    pub fn is_addition(self) -> bool { matches!(self, Self::Add) }

    /// Returns `true` if this operation error'ed due to subtraction.
    pub fn is_subtraction(self) -> bool { matches!(self, Self::Sub) }

    /// Returns `true` if this operation error'ed due to multiplication.
    pub fn is_multiplication(self) -> bool { matches!(self, Self::Mul) }

    /// Returns `true` if this operation error'ed due to negation.
    pub fn is_negation(self) -> bool { matches!(self, Self::Neg) }
}

impl fmt::Display for MathOp {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match *self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Rem => "rem",
            Self::Neg => "neg",
            Self::_DoNotUse(infallible) => match infallible {},
        };
        f.write_str(name)
    }
}
414
#[cfg(feature = "arbitrary")]
impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NumOpResult<T> {
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        // A leading bool selects the variant: `true` -> `Valid`, `false` -> `Error`.
        let is_valid = bool::arbitrary(u)?;
        let result = if is_valid {
            Self::Valid(T::arbitrary(u)?)
        } else {
            Self::Error(NumOpError::while_doing(MathOp::arbitrary(u)?))
        };
        Ok(result)
    }
}
424
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for MathOp {
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        // Pick one of the six real operations; the uninhabited `_DoNotUse` is never produced.
        let op = match u.int_in_range(0..=5)? {
            0 => Self::Add,
            1 => Self::Sub,
            2 => Self::Mul,
            3 => Self::Div,
            4 => Self::Rem,
            _ => Self::Neg,
        };
        Ok(op)
    }
}
439
#[cfg(test)]
mod tests {
    use super::{MathOp, NumOpError, NumOpResult};
    use crate::{Amount, FeeRate, Weight};

    #[test]
    fn mathop_predicates() {
        assert!(MathOp::Add.is_overflow());
        assert!(MathOp::Sub.is_overflow());
        assert!(MathOp::Mul.is_overflow());
        assert!(MathOp::Neg.is_overflow());
        assert!(!MathOp::Div.is_overflow());
        assert!(!MathOp::Rem.is_overflow());

        assert!(MathOp::Div.is_div_by_zero());
        assert!(MathOp::Rem.is_div_by_zero());
        assert!(!MathOp::Add.is_div_by_zero());

        assert!(MathOp::Add.is_addition());
        assert!(!MathOp::Sub.is_addition());

        assert!(MathOp::Sub.is_subtraction());
        assert!(!MathOp::Add.is_subtraction());

        assert!(MathOp::Mul.is_multiplication());
        assert!(!MathOp::Div.is_multiplication());

        assert!(MathOp::Neg.is_negation());
        assert!(!MathOp::Add.is_negation());
    }

    #[test]
    fn mathop_map() {
        // op is evaluated for valid results
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.map(|val| (val / FeeRate::from_sat_per_kwu(10)).unwrap());
        assert_eq!(new_value, NumOpResult::Valid(Weight::from_wu(10_000)));

        // op is not evaluated for error results
        let res = NumOpResult::<Weight>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.map(|_| {
            panic!("map should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }

    #[test]
    fn mathop_expect() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(
                NumOpResult::Valid(amount).expect("unreachable"),
                NumOpResult::Valid(amount).unwrap(),
            );
            assert_eq!(NumOpResult::Valid(amount).expect("unreachable"), amount);
        }
    }

    #[test]
    #[should_panic(expected = "test error message")]
    fn mathop_expect_panics_on_error() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add))
            .expect("test error message");
    }

    #[test]
    fn mathop_unwrap() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(NumOpResult::Valid(amount).unwrap(), amount);
        }
        let weights = [Weight::from_wu(0), Weight::from_wu(16_384_000), Weight::from_wu(u64::MAX)];
        for weight in weights {
            assert_eq!(NumOpResult::Valid(weight).unwrap(), weight);
        }
    }

    // `expected` pins the real panic message; an empty string would match any panic at all.
    #[test]
    #[should_panic(expected = "tried to unwrap an invalid numeric result")]
    fn mathop_unwrap_panics_on_err() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add)).unwrap();
    }

    #[test]
    fn mathop_unwrap_err() {
        let errs = [
            NumOpError::while_doing(MathOp::Add),
            NumOpError::while_doing(MathOp::Sub),
            NumOpError::while_doing(MathOp::Mul),
            NumOpError::while_doing(MathOp::Div),
            NumOpError::while_doing(MathOp::Neg),
            NumOpError::while_doing(MathOp::Rem),
        ];
        for err in errs {
            assert_eq!(NumOpResult::<Amount>::Error(err).unwrap_err(), err);
        }
    }

    #[test]
    #[should_panic(expected = "tried to unwrap a valid numeric result")]
    fn mathop_unwrap_err_panics_on_valid() {
        let value = Amount::from_sat_u32(150);
        NumOpResult::<Amount>::Valid(value).unwrap_err();
    }

    #[test]
    fn mathop_unwrap_or() {
        let base_amount = Amount::from_sat_u32(100);

        // default is returned for error results
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or(base_amount);
        assert_eq!(res_default, base_amount);

        // wrapped value is returned for valid results
        let res = NumOpResult::Valid(base_amount);
        let new_amount = res.unwrap_or(Amount::from_sat_u32(50));
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_unwrap_or_else() {
        let base_amount = Amount::from_sat_u32(100);

        // op is evaluated for error results
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or_else(|| base_amount);
        assert_eq!(res_default, base_amount);

        // op is not evaluated for valid results
        let res = NumOpResult::<Amount>::Valid(base_amount);
        let new_amount = res.unwrap_or_else(|| {
            panic!("unwrap_or_else should not evaluate for wrapped valid values");
        });
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_ok() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).ok(), Some(amt));

        let err = NumOpError::while_doing(MathOp::Add);
        assert_eq!(NumOpResult::<Amount>::Error(err).ok(), None);
    }

    #[test]
    fn mathop_into_result() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).into_result(), Ok(amt));

        let err = NumOpError::while_doing(MathOp::Div);
        assert_eq!(NumOpResult::<Amount>::Error(err).into_result(), Err(err));
    }

    #[test]
    fn mathop_is_valid_is_error() {
        let valid = NumOpResult::Valid(Amount::from_sat_u32(1));
        assert!(valid.is_valid());
        assert!(!valid.is_error());

        let invalid = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Mul));
        assert!(!invalid.is_valid());
        assert!(invalid.is_error());
    }

    #[test]
    fn mathop_and_then() {
        // op is evaluated for valid results
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.and_then(|val| val + Amount::from_sat_u32(50));
        assert_eq!(new_value, NumOpResult::Valid(Amount::from_sat_u32(150)));

        // op is not evaluated for error results
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.and_then(|_| {
            panic!("and_then should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }
}