// aptos_infallible_link/math.rs
// Copyright (c) Aptos
// SPDX-License-Identifier: Apache-2.0

/// Utility macro for writing secure arithmetic operations in order to avoid
/// integer overflows.
///
/// `checked!(x op y op z ...)` evaluates a chain of operands joined by `+`, `-`, `*` or `/`,
/// performing every step with `checked_add`, `checked_sub`, `checked_mul` or `checked_div`.
/// Operators follow ordinary Rust precedence (`*` and `/` bind tighter than `+` and `-`) and
/// associate left to right, so `checked!(a - b - c)` means `(a - b) - c`. Operands are
/// evaluated left to right, each exactly once.
///
/// Every operand must be a single token tree: a literal, an identifier, or a parenthesized
/// group. A parenthesized group is evaluated as an ordinary Rust expression, so arithmetic
/// *inside* the parentheses is not checked.
///
/// # Errors
///
/// Returns `Err(ArithmeticError)` for the first step that overflows, underflows or divides
/// by zero; operands to the right of that step are not evaluated.
///
/// # Examples
///
/// ```
/// # use crate::aptos_infallible::checked;
/// let a: i64 = 1;
/// let b: i64 = 2;
/// let c: i64 = 3;
///
/// assert_eq!(checked!(a + b).unwrap(), 3);
/// assert_eq!(checked!(a + b + c).unwrap(), 6);
///
/// // Precedence and associativity are the same as for plain Rust arithmetic.
/// assert_eq!(checked!(a - b - c).unwrap(), -4);
/// assert_eq!(checked!(c * b + a).unwrap(), 7);
/// assert_eq!(checked!(c / b / a).unwrap(), 1);
///
/// // A parenthesized group is a single operand; the arithmetic inside it is NOT checked.
/// assert_eq!(checked!(a + ((b - c) * c)).unwrap(), -2);
///
/// // When using numeric literals, the compiler might not be able to infer the type properly,
/// // so if it complains, just add the type to the number.
/// assert_eq!(checked!(10_u32 / 2_u32).unwrap(), 5);
/// assert_eq!(checked!(10_u32 * 2_u32).unwrap(), 20);
/// assert_eq!(checked!(10_u32 - 2_u32).unwrap(), 8);
/// assert_eq!(checked!(2_i32 - 10_i32).unwrap(), -8);
/// assert_eq!(checked!(10_u32 + 2_u32).unwrap(), 12);
///
/// // Casts using the `as` operator must appear within parentheses.
/// assert_eq!(checked!(10_u32 + (2_u16 as u32)).unwrap(), 12);
///
/// assert_eq!(checked!(10_u32 / (1_u32 + 1_u32)).unwrap(), 5);
/// assert_eq!(checked!(10_u32 * (1_u32 + 1_u32)).unwrap(), 20);
/// assert_eq!(checked!(10_u32 - (1_u32 + 1_u32)).unwrap(), 8);
/// assert_eq!(checked!(10_u32 + (1_u32 + 1_u32)).unwrap(), 12);
///
/// let max = u32::MAX;
/// assert!(checked!(max + 1_u32).is_err());
/// assert!(checked!(0_u32 - 1_u32).is_err());
/// assert!(checked!(1_u32 / 0_u32).is_err());
///
/// # struct Foo {
/// #    pub bar: i32
/// # }
/// # impl Foo {
/// #    pub fn one() -> i32 {
/// #         1
/// #    }
/// # }
/// // Field accesses and function calls are not single token trees, so wrap them in
/// // parentheses.
/// # let foo = Foo { bar: 1 };
/// assert_eq!(checked!((foo.bar) + 1_i32).unwrap(), 2);
/// assert_eq!(checked!(1_i32 + (Foo::one())).unwrap(), 2);
/// ```
#[macro_export]
macro_rules! checked {
    // ------------------------------------------------------------------------------------
    // Internal rules. They all start with `@` and are not part of the public syntax; the
    // public entry point is the last rule.
    // ------------------------------------------------------------------------------------

    // Evaluate one operand exactly once and lift it into `Ok`. `unused_parens` is allowed
    // because operands are frequently parenthesized groups.
    (@ok $v:tt) => {{
        #[allow(unused_parens)]
        let v = $v;
        ::std::result::Result::<_, $crate::ArithmeticError>::Ok(v)
    }};

    // One checked primitive on two already-evaluated values bound to identifiers, so the
    // error message can print them without re-evaluating the operand expressions.
    (@raw $a:ident + $b:ident) => {
        $a.checked_add($b).ok_or_else(|| {
            $crate::ArithmeticError(format!(
                "Operation results in overflow/underflow: {} + {}",
                $a, $b
            ))
        })
    };
    (@raw $a:ident - $b:ident) => {
        $a.checked_sub($b).ok_or_else(|| {
            $crate::ArithmeticError(format!(
                "Operation results in overflow/underflow: {} - {}",
                $a, $b
            ))
        })
    };
    (@raw $a:ident * $b:ident) => {
        $a.checked_mul($b).ok_or_else(|| {
            $crate::ArithmeticError(format!(
                "Operation results in overflow/underflow: {} * {}",
                $a, $b
            ))
        })
    };
    (@raw $a:ident / $b:ident) => {
        $a.checked_div($b).ok_or_else(|| {
            $crate::ArithmeticError(format!(
                "Operation results in overflow/underflow: {} / {}",
                $a, $b
            ))
        })
    };

    // Combine two `Result` operands with `$op`. The left side is evaluated first and the
    // right side only if the left succeeded (short-circuit on the first error).
    (@apply $lhs:expr, $op:tt, $rhs:expr) => {
        match $lhs {
            ::std::result::Result::Ok(a) => match $rhs {
                ::std::result::Result::Ok(b) => $crate::checked!(@raw a $op b),
                ::std::result::Result::Err(e) => ::std::result::Result::Err(e),
            },
            ::std::result::Result::Err(e) => ::std::result::Result::Err(e),
        }
    };

    // `@term`: no additive operator seen yet; `$term` is the running product/quotient.
    (@term $term:expr ;) => {
        $term
    };
    (@term $term:expr ; * $b:tt $($rest:tt)*) => {
        $crate::checked!(
            @term $crate::checked!(@apply $term, *, $crate::checked!(@ok $b)) ; $($rest)*
        )
    };
    (@term $term:expr ; / $b:tt $($rest:tt)*) => {
        $crate::checked!(
            @term $crate::checked!(@apply $term, /, $crate::checked!(@ok $b)) ; $($rest)*
        )
    };
    (@term $term:expr ; + $b:tt $($rest:tt)*) => {
        $crate::checked!(@sum $term, +, $crate::checked!(@ok $b) ; $($rest)*)
    };
    (@term $term:expr ; - $b:tt $($rest:tt)*) => {
        $crate::checked!(@sum $term, -, $crate::checked!(@ok $b) ; $($rest)*)
    };

    // `@sum`: `$sum` is everything left of the pending additive operator `$op`, and
    // `$term` is the product/quotient being built on its right. `*`/`/` extend `$term`
    // (higher precedence); `+`/`-` fold `$term` into `$sum` (left associativity).
    (@sum $sum:expr, $op:tt, $term:expr ;) => {
        $crate::checked!(@apply $sum, $op, $term)
    };
    (@sum $sum:expr, $op:tt, $term:expr ; * $b:tt $($rest:tt)*) => {
        $crate::checked!(
            @sum $sum, $op, $crate::checked!(@apply $term, *, $crate::checked!(@ok $b)) ;
            $($rest)*
        )
    };
    (@sum $sum:expr, $op:tt, $term:expr ; / $b:tt $($rest:tt)*) => {
        $crate::checked!(
            @sum $sum, $op, $crate::checked!(@apply $term, /, $crate::checked!(@ok $b)) ;
            $($rest)*
        )
    };
    (@sum $sum:expr, $op:tt, $term:expr ; + $b:tt $($rest:tt)*) => {
        $crate::checked!(
            @sum $crate::checked!(@apply $sum, $op, $term), +, $crate::checked!(@ok $b) ;
            $($rest)*
        )
    };
    (@sum $sum:expr, $op:tt, $term:expr ; - $b:tt $($rest:tt)*) => {
        $crate::checked!(
            @sum $crate::checked!(@apply $sum, $op, $term), -, $crate::checked!(@ok $b) ;
            $($rest)*
        )
    };

    // Any internal state that matched nothing above means the input was malformed; report
    // it clearly instead of recursing into the public entry point.
    (@ $($unexpected:tt)*) => {
        compile_error!(
            "checked!: expected single-token-tree operands (wrap anything more complex in \
             parentheses) separated by `+`, `-`, `*` or `/`"
        )
    };

    // ------------------------------------------------------------------------------------
    // Public entry point: `operand (op operand)+`.
    // ------------------------------------------------------------------------------------
    ($first:tt $($rest:tt)+) => {
        $crate::checked!(@term $crate::checked!(@ok $first) ; $($rest)+)
    };
}

/// Error produced by the `checked!` macro when an arithmetic step fails (a `checked_*`
/// operation returned `None`, i.e. overflow, underflow, or division by zero).
///
/// The wrapped `String` is a human-readable description of the failing operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArithmeticError(pub String);

100impl std::fmt::Display for ArithmeticError {
101    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102        write!(f, "{:?}", self.0)
103    }
104}
105
106impl std::error::Error for ArithmeticError {
107    fn description(&self) -> &str {
108        &self.0
109    }
110}