mars_utils/
math.rs

1use std::convert::TryInto;
2
3use cosmwasm_std::{
4    CheckedFromRatioError, Decimal, Fraction, OverflowError, OverflowOperation, StdError,
5    StdResult, Uint128, Uint256,
6};
7
8pub fn uint128_checked_div_with_ceil(
9    numerator: Uint128,
10    denominator: Uint128,
11) -> StdResult<Uint128> {
12    let mut result = numerator.checked_div(denominator)?;
13
14    if !numerator.checked_rem(denominator)?.is_zero() {
15        result += Uint128::from(1_u128);
16    }
17
18    Ok(result)
19}
20
21/// Divide 'a' by 'b'.
22pub fn divide_decimal_by_decimal(a: Decimal, b: Decimal) -> StdResult<Decimal> {
23    Decimal::checked_from_ratio(a.numerator(), b.numerator()).map_err(|e| match e {
24        CheckedFromRatioError::Overflow => StdError::Overflow {
25            source: OverflowError {
26                operation: OverflowOperation::Mul,
27                operand1: a.numerator().to_string(),
28                operand2: a.denominator().to_string(),
29            },
30        },
31        CheckedFromRatioError::DivideByZero => StdError::DivideByZero {
32            source: cosmwasm_std::DivideByZeroError {
33                operand: b.to_string(),
34            },
35        },
36    })
37}
38
39/// Divide Uint128 by Decimal.
40/// (Uint128 / numerator / denominator) is equal to (Uint128 * denominator / numerator).
41pub fn divide_uint128_by_decimal(a: Uint128, b: Decimal) -> StdResult<Uint128> {
42    // (Uint128 / numerator / denominator) is equal to (Uint128 * denominator / numerator).
43    let numerator_u256 = a.full_mul(b.denominator());
44    let denominator_u256 = Uint256::from(b.numerator());
45
46    let result_u256 = numerator_u256 / denominator_u256;
47
48    let result = result_u256.try_into()?;
49    Ok(result)
50}
51
52/// Divide Uint128 by Decimal, rounding up to the nearest integer.
53pub fn divide_uint128_by_decimal_and_ceil(a: Uint128, b: Decimal) -> StdResult<Uint128> {
54    // (Uint128 / numerator / denominator) is equal to (Uint128 * denominator / numerator).
55    let numerator_u256 = a.full_mul(b.denominator());
56    let denominator_u256 = Uint256::from(b.numerator());
57
58    let mut result_u256 = numerator_u256 / denominator_u256;
59
60    if numerator_u256.checked_rem(denominator_u256)? > Uint256::zero() {
61        result_u256 += Uint256::from(1_u32);
62    }
63
64    let result = result_u256.try_into()?;
65    Ok(result)
66}
67
68/// Multiply Uint128 by Decimal, rounding up to the nearest integer.
69pub fn multiply_uint128_by_decimal_and_ceil(a: Uint128, b: Decimal) -> StdResult<Uint128> {
70    let numerator_u256 = a.full_mul(b.numerator());
71    let denominator_u256 = Uint256::from(b.denominator());
72
73    let mut result_u256 = numerator_u256 / denominator_u256;
74
75    if numerator_u256.checked_rem(denominator_u256)? > Uint256::zero() {
76        result_u256 += Uint256::from(1_u32);
77    }
78
79    let result = result_u256.try_into()?;
80    Ok(result)
81}
82
/// Unit tests for the fixed-point math helpers in this module.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    // `OverflowOperation` and `StdError` already come in through `super::*`.
    use cosmwasm_std::ConversionOverflowError;

    use super::*;

    const DECIMAL_FRACTIONAL: Uint128 = Uint128::new(1_000_000_000_000_000_000u128); // 1*10**18
    const DECIMAL_FRACTIONAL_SQUARED: Uint128 =
        Uint128::new(1_000_000_000_000_000_000_000_000_000_000_000_000u128); // (1*10**18)**2 = 1*10**36

    #[test]
    fn test_uint128_checked_div_with_ceil() {
        // Division by zero is reported as an error, not a panic
        let a = Uint128::new(120u128);
        let b = Uint128::zero();
        let res_error = uint128_checked_div_with_ceil(a, b).unwrap_err();
        assert!(matches!(res_error, StdError::DivideByZero { .. }));

        let a = Uint128::new(120u128);
        let b = Uint128::new(60_u128);
        let c = uint128_checked_div_with_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(2u128));

        let a = Uint128::new(120u128);
        let b = Uint128::new(119_u128);
        let c = uint128_checked_div_with_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(2u128));

        let a = Uint128::new(120u128);
        let b = Uint128::new(120_u128);
        let c = uint128_checked_div_with_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(1u128));

        let a = Uint128::new(120u128);
        let b = Uint128::new(121_u128);
        let c = uint128_checked_div_with_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(1u128));

        let a = Uint128::zero();
        let b = Uint128::new(121_u128);
        let c = uint128_checked_div_with_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::zero());
    }

    #[test]
    fn checked_decimal_division() {
        let a = Decimal::from_ratio(99988u128, 100u128);
        let b = Decimal::from_ratio(24997u128, 100u128);
        let c = divide_decimal_by_decimal(a, b).unwrap();
        assert_eq!(c, Decimal::from_str("4.0").unwrap());

        let a = Decimal::from_ratio(123456789u128, 1000000u128);
        let b = Decimal::from_ratio(33u128, 1u128);
        let c = divide_decimal_by_decimal(a, b).unwrap();
        assert_eq!(c, Decimal::from_str("3.741114818181818181").unwrap());

        let a = Decimal::MAX;
        let b = Decimal::MAX;
        let c = divide_decimal_by_decimal(a, b).unwrap();
        assert_eq!(c, Decimal::one());

        // Dividing by zero yields a DivideByZero error; the operand text is not asserted here
        let a = Decimal::one();
        let b = Decimal::zero();
        let res_error = divide_decimal_by_decimal(a, b).unwrap_err();
        assert!(matches!(res_error, StdError::DivideByZero { .. }));

        let a = Decimal::MAX;
        let b = Decimal::from_ratio(1u128, DECIMAL_FRACTIONAL);
        let res_error = divide_decimal_by_decimal(a, b).unwrap_err();
        assert_eq!(
            res_error,
            OverflowError::new(OverflowOperation::Mul, Uint128::MAX, DECIMAL_FRACTIONAL).into()
        );
    }

    #[test]
    fn test_divide_uint128_by_decimal() {
        let a = Uint128::new(120u128);
        let b = Decimal::from_ratio(120u128, 15u128);
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::new(15u128));

        let a = Uint128::new(DECIMAL_FRACTIONAL.u128());
        let b = Decimal::from_ratio(DECIMAL_FRACTIONAL.u128(), 1u128);
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::new(1u128));

        let a = Uint128::new(DECIMAL_FRACTIONAL.u128());
        let b = Decimal::from_ratio(1u128, DECIMAL_FRACTIONAL.u128());
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::new(DECIMAL_FRACTIONAL_SQUARED.u128()));

        let a = Uint128::MAX;
        let b = Decimal::one();
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::MAX);

        let a = Uint128::new(1_000_000_000_000_000_000);
        let b = Decimal::from_ratio(1u128, DECIMAL_FRACTIONAL);
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::new(1_000_000_000_000_000_000_000_000_000_000_000_000));

        // Division is truncated
        let a = Uint128::new(100);
        let b = Decimal::from_ratio(3u128, 1u128);
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::new(33));

        let a = Uint128::new(75);
        let b = Decimal::from_ratio(100u128, 1u128);
        let c = divide_uint128_by_decimal(a, b).unwrap();
        assert_eq!(c, Uint128::new(0));

        // Overflow
        let a = Uint128::MAX;
        let b = Decimal::from_ratio(1_u128, 10_u128);
        let res_error = divide_uint128_by_decimal(a, b).unwrap_err();
        assert_eq!(
            res_error,
            ConversionOverflowError::new(
                "Uint256",
                "Uint128",
                "3402823669209384634633746074317682114550"
            )
            .into()
        );
    }

    #[test]
    fn test_divide_uint128_by_decimal_and_ceil() {
        let a = Uint128::new(120u128);
        let b = Decimal::from_ratio(120u128, 15u128);
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(15u128));

        let a = Uint128::new(DECIMAL_FRACTIONAL.u128());
        let b = Decimal::from_ratio(DECIMAL_FRACTIONAL.u128(), 1u128);
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(1u128));

        let a = Uint128::new(DECIMAL_FRACTIONAL.u128());
        let b = Decimal::from_ratio(1u128, DECIMAL_FRACTIONAL.u128());
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(DECIMAL_FRACTIONAL_SQUARED.u128()));

        let a = Uint128::MAX;
        let b = Decimal::one();
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::MAX);

        let a = Uint128::new(1_000_000_000_000_000_000);
        let b = Decimal::from_ratio(1u128, DECIMAL_FRACTIONAL);
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(1_000_000_000_000_000_000_000_000_000_000_000_000));

        // Division is rounded up
        let a = Uint128::new(100);
        let b = Decimal::from_ratio(3u128, 1u128);
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(34));

        let a = Uint128::new(75);
        let b = Decimal::from_ratio(100u128, 1u128);
        let c = divide_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(1));

        // Overflow
        let a = Uint128::MAX;
        let b = Decimal::from_ratio(1_u128, 10_u128);
        let res_error = divide_uint128_by_decimal_and_ceil(a, b).unwrap_err();
        assert_eq!(
            res_error,
            ConversionOverflowError::new(
                "Uint256",
                "Uint128",
                "3402823669209384634633746074317682114550"
            )
            .into()
        );
    }

    #[test]
    fn test_multiply_uint128_by_decimal_and_ceil() {
        // Exact products are returned unchanged
        let a = Uint128::new(10u128);
        let b = Decimal::from_ratio(3u128, 2u128);
        let c = multiply_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(15u128));

        let a = Uint128::MAX;
        let b = Decimal::one();
        let c = multiply_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::MAX);

        let a = Uint128::zero();
        let b = Decimal::from_ratio(7u128, 3u128);
        let c = multiply_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::zero());

        // Fractional results are rounded up
        let a = Uint128::new(100u128);
        let b = Decimal::from_ratio(1u128, 3u128);
        let c = multiply_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(34u128));

        let a = Uint128::new(1u128);
        let b = Decimal::from_ratio(1u128, DECIMAL_FRACTIONAL);
        let c = multiply_uint128_by_decimal_and_ceil(a, b).unwrap();
        assert_eq!(c, Uint128::new(1u128));

        // Overflow
        let a = Uint128::MAX;
        let b = Decimal::from_ratio(2u128, 1u128);
        let res_error = multiply_uint128_by_decimal_and_ceil(a, b).unwrap_err();
        assert_eq!(
            res_error,
            ConversionOverflowError::new(
                "Uint256",
                "Uint128",
                "680564733841876926926749214863536422910"
            )
            .into()
        );
    }
}