solana_maths/
rate.rs

//! Math for preserving precision of ratios and percentages.
//!
//! Rates are internally scaled by a WAD (10^18) to preserve
//! precision up to 18 decimal places. Rates are sized to support
//! both serialization and precise math for the full range of
//! unsigned 8-bit integers. The underlying representation is a
//! u128 rather than a u192 to reduce compute cost, at the expense of
//! arithmetic support at the high end of the u8 range.
9
10#![allow(clippy::assign_op_pattern)]
11#![allow(clippy::ptr_offset_with_cast)]
12#![allow(clippy::reversed_empty_ranges)]
13#![allow(clippy::manual_range_contains)]
14
15use crate::{
16    Decimal, MathError, TryAdd, TryDiv, TryMul, TrySub, BIPS_SCALER, HALF_WAD, PERCENT_SCALER,
17    SCALE, WAD,
18};
19use solana_program::program_error::ProgramError;
20use std::{convert::TryFrom, fmt};
21use uint::construct_uint;
22
23// U128 with 128 bits consisting of 2 x 64-bit words
24construct_uint! {
25    pub struct U128(2);
26}
27
28/// Small decimal values, precise to 18 digits
29#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
30pub struct Rate(pub U128);
31
32impl Rate {
33    /// One
34    pub fn one() -> Self {
35        Self(Self::wad())
36    }
37
38    /// Zero
39    pub fn zero() -> Self {
40        Self(U128::from(0))
41    }
42
43    // OPTIMIZE: use const slice when fixed in BPF toolchain
44    fn wad() -> U128 {
45        U128::from(WAD)
46    }
47
48    // OPTIMIZE: use const slice when fixed in BPF toolchain
49    fn half_wad() -> U128 {
50        U128::from(HALF_WAD)
51    }
52
53    /// Create scaled decimal from percent value
54    pub fn from_percent(percent: u8) -> Self {
55        Self(U128::from(percent as u64 * PERCENT_SCALER))
56    }
57
58    /// Create scaled decimal from percent value
59    pub fn from_bips(bips: u64) -> Self {
60        Self(U128::from(bips * BIPS_SCALER))
61    }
62
63    /// Return raw scaled value
64    #[allow(clippy::wrong_self_convention)]
65    pub fn to_scaled_val(&self) -> u128 {
66        self.0.as_u128()
67    }
68
69    /// Create decimal from scaled value
70    pub fn from_scaled_val(scaled_val: u64) -> Self {
71        Self(U128::from(scaled_val))
72    }
73
74    /// Round scaled decimal to u64
75    pub fn try_round_u64(&self) -> Result<u64, ProgramError> {
76        let rounded_val = Self::half_wad()
77            .checked_add(self.0)
78            .ok_or(MathError::AddOverflow)?
79            .checked_div(Self::wad())
80            .ok_or(MathError::DividedByZero)?;
81        Ok(u64::try_from(rounded_val).map_err(|_| MathError::UnableToRoundU64)?)
82    }
83
84    /// Calculates base^exp
85    pub fn try_pow(&self, mut exp: u64) -> Result<Rate, ProgramError> {
86        let mut base = *self;
87        let mut ret = if exp % 2 != 0 {
88            base
89        } else {
90            Rate(Self::wad())
91        };
92
93        while exp > 0 {
94            exp /= 2;
95            base = base.try_mul(base)?;
96
97            if exp % 2 != 0 {
98                ret = ret.try_mul(base)?;
99            }
100        }
101
102        Ok(ret)
103    }
104}
105
106impl fmt::Display for Rate {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        let mut scaled_val = self.0.to_string();
109        if scaled_val.len() <= SCALE {
110            scaled_val.insert_str(0, &vec!["0"; SCALE - scaled_val.len()].join(""));
111            scaled_val.insert_str(0, "0.");
112        } else {
113            scaled_val.insert(scaled_val.len() - SCALE, '.');
114        }
115        f.write_str(&scaled_val)
116    }
117}
118
119impl TryFrom<Decimal> for Rate {
120    type Error = ProgramError;
121    fn try_from(decimal: Decimal) -> Result<Self, Self::Error> {
122        Ok(Self(U128::from(decimal.to_scaled_val()?)))
123    }
124}
125
126impl TryAdd for Rate {
127    fn try_add(self, rhs: Self) -> Result<Self, ProgramError> {
128        Ok(Self(
129            self.0.checked_add(rhs.0).ok_or(MathError::AddOverflow)?,
130        ))
131    }
132}
133
134impl TrySub for Rate {
135    fn try_sub(self, rhs: Self) -> Result<Self, ProgramError> {
136        Ok(Self(
137            self.0.checked_sub(rhs.0).ok_or(MathError::SubUnderflow)?,
138        ))
139    }
140}
141
142impl TryDiv<u64> for Rate {
143    fn try_div(self, rhs: u64) -> Result<Self, ProgramError> {
144        Ok(Self(
145            self.0
146                .checked_div(U128::from(rhs))
147                .ok_or(MathError::DividedByZero)?,
148        ))
149    }
150}
151
152impl TryDiv<Rate> for Rate {
153    fn try_div(self, rhs: Self) -> Result<Self, ProgramError> {
154        Ok(Self(
155            self.0
156                .checked_mul(Self::wad())
157                .ok_or(MathError::MulOverflow)?
158                .checked_div(rhs.0)
159                .ok_or(MathError::DividedByZero)?,
160        ))
161    }
162}
163
164impl TryMul<u64> for Rate {
165    fn try_mul(self, rhs: u64) -> Result<Self, ProgramError> {
166        Ok(Self(
167            self.0
168                .checked_mul(U128::from(rhs))
169                .ok_or(MathError::MulOverflow)?,
170        ))
171    }
172}
173
174impl TryMul<Rate> for Rate {
175    fn try_mul(self, rhs: Self) -> Result<Self, ProgramError> {
176        Ok(Self(
177            self.0
178                .checked_mul(rhs.0)
179                .ok_or(MathError::MulOverflow)?
180                .checked_div(Self::wad())
181                .ok_or(MathError::DividedByZero)?,
182        ))
183    }
184}
185
#[cfg(test)]
mod test {
    use super::*;

    /// One raised to the largest possible exponent must still be exactly one.
    #[test]
    fn checked_pow() {
        let raised = Rate::one().try_pow(u64::MAX).unwrap();
        assert_eq!(Rate::one(), raised);
    }
}