use crate::{conversions::to_u32, errors::ParclMathErrorCode, uint::U256};
use anchor_lang::prelude::*;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
/// Integer type holding the raw (scaled) value of a [`PreciseNumber`].
type InnerUint = U256;
/// Fixed-point scale factor: a raw value of `ONE` represents 1.0, giving
/// 12 decimal places of precision.
pub const ONE: u128 = 1_000_000_000_000;
/// Decimal exponent of one basis point (1 bps = 10^-4); used by
/// `PreciseNumber::from_bps`.
pub const BPS_EXPO: i32 = -4;
/// Unsigned fixed-point number with 12 decimal places.
///
/// The logical value is `val / ONE`; e.g. `PreciseNumber::new(3)` stores
/// `3 * ONE` in `val`. Derived comparisons operate on the raw value, which
/// orders identically to the logical value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct PreciseNumber {
    /// Raw scaled value (logical value multiplied by `ONE`).
    pub val: InnerUint,
}
/// Raw representation of 1.0, i.e. the fixed-point scale `ONE` as an `InnerUint`.
fn one() -> InnerUint {
    let scale: u128 = ONE;
    scale.into()
}
/// Raw representation of 0.
fn zero() -> InnerUint {
    let nothing: u128 = 0;
    InnerUint::from(nothing)
}
/// Arithmetic, conversion and approximation routines for [`PreciseNumber`].
///
/// Values are unsigned fixed-point numbers whose raw `val` is the logical
/// value multiplied by [`ONE`] (10^12). Fallible operations report failures
/// as `ParclMathErrorCode::IntegerOverflow`.
impl PreciseNumber {
    /// Half of `ONE` in raw units. Added to a numerator before a truncating
    /// division by `ONE` so the result rounds to nearest instead of truncating.
    fn rounding_correction() -> InnerUint {
        InnerUint::from(ONE / 2)
    }

    /// Convergence tolerance, in raw units, used by the iterative
    /// approximations (`checked_pow_approximation`, `newtonian_root_approximation`).
    fn precision() -> InnerUint {
        InnerUint::from(100)
    }

    /// The value 0.
    pub fn zero() -> Self {
        Self { val: zero() }
    }

    /// The value 1.0 (raw `ONE`).
    pub fn one() -> Self {
        Self { val: one() }
    }

    /// Iteration cap used by `checked_pow_fraction` and `sqrt`.
    const MAX_APPROXIMATION_ITERATIONS: u128 = 100;

    /// Smallest raw base accepted by the power approximations (one raw unit).
    fn min_pow_base() -> InnerUint {
        InnerUint::from(1)
    }

    /// Largest raw base accepted by the power approximations (2.0).
    fn max_pow_base() -> InnerUint {
        InnerUint::from(2 * ONE)
    }

    /// Creates a number from a whole value: raw `val * ONE`.
    ///
    /// # Errors
    /// `IntegerOverflow` if the scaled value does not fit in `InnerUint`.
    pub fn new(val: u128) -> Result<Self> {
        let val = InnerUint::from(val)
            .checked_mul(one())
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Wraps an already-scaled raw value; no multiplication by `ONE` is applied.
    pub fn from(val: u128) -> Self {
        let val = InnerUint::from(val);
        Self { val }
    }

    /// Creates a number from basis points (`bps * 10^-4`).
    ///
    /// # Errors
    /// Propagates any error from [`PreciseNumber::from_decimal`].
    pub fn from_bps(bps: u16) -> Result<Self> {
        Self::from_decimal(bps.into(), BPS_EXPO)
    }

    /// Creates a number from `decimal * 10^exponent`.
    ///
    /// `decimal` is shifted by `12 + exponent` digits onto the 12-decimal
    /// fixed-point grid. A non-negative shift multiplies. A negative shift
    /// divides, truncating digits finer than 10^-12; a value entirely below
    /// that resolution becomes 0.
    ///
    /// # Errors
    /// `IntegerOverflow` if `12 + exponent` overflows `i32`, if an upward shift
    /// needs a power of ten that does not fit in `u128`, or if the shifted
    /// value does not fit in `InnerUint`. Errors from `to_u32` are propagated.
    pub fn from_decimal(decimal: u128, exponent: i32) -> Result<Self> {
        let precision_expo = exponent
            .checked_add(12)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        // `precision_expo` cannot be i32::MIN (that would need exponent < i32::MIN),
        // so `abs` cannot overflow here.
        let scale = 10u128.checked_pow(to_u32(precision_expo.abs())?);
        let val = if precision_expo < 0 {
            // More fractional digits than the format keeps: divide them away.
            // A power of ten too large for u128 exceeds any u128 `decimal`,
            // so the truncated quotient is 0.
            match scale {
                Some(divisor) => InnerUint::from(decimal / divisor),
                None => zero(),
            }
        } else {
            let multiplier = scale.ok_or(ParclMathErrorCode::IntegerOverflow)?;
            // Multiply in InnerUint so results wider than u128 are representable.
            InnerUint::from(decimal)
                .checked_mul(InnerUint::from(multiplier))
                .ok_or(ParclMathErrorCode::IntegerOverflow)?
        };
        Ok(Self { val })
    }

    /// Rounds to the nearest whole number (half-up) and returns it unscaled.
    fn rounded_whole(&self) -> Result<InnerUint> {
        let val = self
            .val
            .checked_add(Self::rounding_correction())
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_div(one())
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(val)
    }

    /// Rounds to the nearest whole number (half-up) as a `u128`.
    ///
    /// # Errors
    /// `IntegerOverflow` if rounding overflows or the result exceeds `u128::MAX`.
    pub fn to_imprecise(&self) -> Result<u128> {
        let whole = self.rounded_whole()?;
        // Range check first: `as_u128` does not return an error for out-of-range values.
        if whole > InnerUint::from(u128::MAX) {
            return Err(error!(ParclMathErrorCode::IntegerOverflow));
        }
        Ok(whole.as_u128())
    }

    /// Rounds to the nearest whole number (half-up) as a `u64`.
    ///
    /// # Errors
    /// `IntegerOverflow` if rounding overflows or the result exceeds `u64::MAX`.
    pub fn to_imprecise_u64(&self) -> Result<u64> {
        let whole = self.rounded_whole()?;
        // Range check first: `as_u64` does not return an error for out-of-range values.
        if whole > InnerUint::from(u64::MAX) {
            return Err(error!(ParclMathErrorCode::IntegerOverflow));
        }
        Ok(whole.as_u64())
    }

    /// Multiplies, rounding the raw result up (`ceil(a * b / ONE)`).
    ///
    /// The computation is done in `InnerUint`, so operands wider than `u128`
    /// are handled.
    ///
    /// # Errors
    /// `IntegerOverflow` if the intermediate product overflows `InnerUint`.
    pub fn mul_up(self, rhs: Self) -> Result<Self> {
        let val = self
            .val
            .checked_mul(rhs.val)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_add(InnerUint::from(ONE - 1))
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_div(one())
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Divides, rounding the raw result up (`ceil(a * ONE / b)`).
    ///
    /// # Errors
    /// `IntegerOverflow` if `rhs` is zero or an intermediate overflows `InnerUint`.
    pub fn div_up(self, rhs: Self) -> Result<Self> {
        if rhs == Self::zero() {
            return Err(error!(ParclMathErrorCode::IntegerOverflow));
        }
        let val = self
            .val
            .checked_mul(one())
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_add(
                rhs.val
                    .checked_sub(InnerUint::from(1))
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?,
            )
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_div(rhs.val)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Returns true when `|self - rhs|` is strictly less than `precision` raw units.
    pub fn almost_eq(&self, rhs: &Self, precision: InnerUint) -> bool {
        let (difference, _) = self.unsigned_sub(rhs);
        difference.val < precision
    }

    /// `self < rhs`.
    pub fn less_than(&self, rhs: &Self) -> bool {
        self.val < rhs.val
    }

    /// `self > rhs`.
    pub fn greater_than(&self, rhs: &Self) -> bool {
        self.val > rhs.val
    }

    /// `self <= rhs`.
    pub fn less_than_or_equal(&self, rhs: &Self) -> bool {
        self.val <= rhs.val
    }

    /// `self >= rhs`.
    pub fn greater_than_or_equal(&self, rhs: &Self) -> bool {
        self.val >= rhs.val
    }

    /// Rounds down to a whole number (fraction dropped), keeping the scale.
    pub fn floor(&self) -> Result<Self> {
        let one = one();
        let val = self
            .val
            .checked_div(one)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_mul(one)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Rounds up to a whole number, keeping the scale.
    ///
    /// # Errors
    /// `IntegerOverflow` if adding `ONE - 1` overflows.
    pub fn ceil(&self) -> Result<Self> {
        let one = one();
        let val = self
            .val
            .checked_add(
                one.checked_sub(InnerUint::from(1))
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?,
            )
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_div(one)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            .checked_mul(one)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Divides `self / rhs`.
    ///
    /// Normally computes `(self.val * ONE + ONE / 2) / rhs.val`. If
    /// `self.val * ONE` overflows, it falls back to
    /// `((self.val + ONE / 2) / rhs.val) * ONE`, which yields a whole-number
    /// result only.
    ///
    /// # Errors
    /// `IntegerOverflow` if `rhs` is zero or an intermediate overflows.
    pub fn checked_div(&self, rhs: &Self) -> Result<Self> {
        if *rhs == Self::zero() {
            return Err(error!(ParclMathErrorCode::IntegerOverflow));
        }
        match self.val.checked_mul(one()) {
            Some(v) => {
                let val = v
                    .checked_add(Self::rounding_correction())
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?
                    .checked_div(rhs.val)
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?;
                Ok(Self { val })
            }
            None => {
                let val = self
                    .val
                    .checked_add(Self::rounding_correction())
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?
                    .checked_div(rhs.val)
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?
                    .checked_mul(one())
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?;
                Ok(Self { val })
            }
        }
    }

    /// Multiplies `self * rhs`, rounding the raw result to nearest.
    ///
    /// If the raw product overflows, it falls back to
    /// `(larger / ONE) * smaller`. That path truncates the larger operand to a
    /// whole number and applies no rounding.
    ///
    /// # Errors
    /// `IntegerOverflow` if the fallback also overflows.
    pub fn checked_mul(&self, rhs: &Self) -> Result<Self> {
        let one = one();
        match self.val.checked_mul(rhs.val) {
            Some(v) => {
                let val = v
                    .checked_add(Self::rounding_correction())
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?
                    .checked_div(one)
                    .ok_or(ParclMathErrorCode::IntegerOverflow)?;
                Ok(Self { val })
            }
            None => {
                let val = if self.val >= rhs.val {
                    self.val
                        .checked_div(one)
                        .ok_or(ParclMathErrorCode::IntegerOverflow)?
                        .checked_mul(rhs.val)
                        .ok_or(ParclMathErrorCode::IntegerOverflow)?
                } else {
                    rhs.val
                        .checked_div(one)
                        .ok_or(ParclMathErrorCode::IntegerOverflow)?
                        .checked_mul(self.val)
                        .ok_or(ParclMathErrorCode::IntegerOverflow)?
                };
                Ok(Self { val })
            }
        }
    }

    /// Adds `self + rhs`.
    ///
    /// # Errors
    /// `IntegerOverflow` on overflow.
    pub fn checked_add(&self, rhs: &Self) -> Result<Self> {
        let val = self
            .val
            .checked_add(rhs.val)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Subtracts `self - rhs`.
    ///
    /// # Errors
    /// `IntegerOverflow` if `rhs > self`.
    pub fn checked_sub(&self, rhs: &Self) -> Result<Self> {
        let val = self
            .val
            .checked_sub(rhs.val)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        Ok(Self { val })
    }

    /// Returns `(|self - rhs|, rhs > self)`: the absolute difference and
    /// whether the signed result would be negative.
    pub fn unsigned_sub(&self, rhs: &Self) -> (Self, bool) {
        match self.val.checked_sub(rhs.val) {
            None => {
                // Cannot fail: the subtraction above underflowed, so rhs > self.
                let val = rhs.val.checked_sub(self.val).unwrap();
                (Self { val }, true)
            }
            Some(val) => (Self { val }, false),
        }
    }

    /// Raises to a whole-number power using binary exponentiation.
    ///
    /// Each step goes through `checked_mul`, so rounding happens per step.
    ///
    /// # Errors
    /// `IntegerOverflow` if any intermediate multiplication fails.
    pub fn checked_pow(&self, exponent: u128) -> Result<Self> {
        // Odd exponents start with one factor of the base, since the exponent
        // is halved before the loop.
        let val = if exponent
            .checked_rem(2)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?
            == 0
        {
            one()
        } else {
            self.val
        };
        let mut result = Self { val };
        let mut squared_base = *self;
        let mut current_exponent = exponent
            .checked_div(2)
            .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        while current_exponent != 0 {
            squared_base = squared_base.checked_mul(&squared_base)?;
            // A set bit in the remaining exponent folds this square into the result.
            if current_exponent
                .checked_rem(2)
                .ok_or(ParclMathErrorCode::IntegerOverflow)?
                != 0
            {
                result = result.checked_mul(&squared_base)?;
            }
            current_exponent = current_exponent
                .checked_div(2)
                .ok_or(ParclMathErrorCode::IntegerOverflow)?;
        }
        Ok(result)
    }

    /// Approximates `self^exponent` with the binomial series around 1:
    /// `(1 + x)^a = Σ C(a, k) x^k` with `x = self - 1`.
    ///
    /// The loop stops once a term's magnitude drops below `precision()` raw
    /// units, or after `max_iterations - 1` terms.
    ///
    /// # Panics
    /// Panics if the base is outside `[min_pow_base(), max_pow_base()]`.
    fn checked_pow_approximation(&self, exponent: &Self, max_iterations: u128) -> Result<Self> {
        assert!(self.val >= Self::min_pow_base());
        assert!(self.val <= Self::max_pow_base());
        let one = Self::one();
        if *exponent == Self::zero() {
            return Ok(one);
        }
        let mut precise_guess = one;
        let mut term = precise_guess;
        let (x_minus_a, x_minus_a_negative) = self.unsigned_sub(&precise_guess);
        let exponent_plus_one = exponent.checked_add(&one)?;
        // Magnitudes are unsigned, so the sign of the running term is tracked separately.
        let mut negative = false;
        for k in 1..max_iterations {
            let k = Self::new(k)?;
            // term_k = term_{k-1} * (a - k + 1) * x / k
            let (current_exponent, current_exponent_negative) = exponent_plus_one.unsigned_sub(&k);
            term = term.checked_mul(&current_exponent)?;
            term = term.checked_mul(&x_minus_a)?;
            term = term.checked_div(&k)?;
            if term.val < Self::precision() {
                break;
            }
            if x_minus_a_negative {
                negative = !negative;
            }
            if current_exponent_negative {
                negative = !negative;
            }
            if negative {
                precise_guess = precise_guess.checked_sub(&term)?;
            } else {
                precise_guess = precise_guess.checked_add(&term)?;
            }
        }
        Ok(precise_guess)
    }

    /// Raises to a fractional power: the whole part of the exponent goes
    /// through `checked_pow`, and the fractional remainder through
    /// `checked_pow_approximation`.
    ///
    /// # Panics
    /// Panics if the base is outside `[min_pow_base(), max_pow_base()]`.
    #[allow(dead_code)]
    fn checked_pow_fraction(&self, exponent: &Self) -> Result<Self> {
        assert!(self.val >= Self::min_pow_base());
        assert!(self.val <= Self::max_pow_base());
        let whole_exponent = exponent.floor()?;
        let precise_whole = self.checked_pow(whole_exponent.to_imprecise()?)?;
        let (remainder_exponent, negative) = exponent.unsigned_sub(&whole_exponent);
        // floor(x) <= x, so the remainder can never be negative.
        assert!(!negative);
        if remainder_exponent.val == InnerUint::from(0) {
            return Ok(precise_whole);
        }
        let precise_remainder = self
            .checked_pow_approximation(&remainder_exponent, Self::MAX_APPROXIMATION_ITERATIONS)?;
        precise_whole.checked_mul(&precise_remainder)
    }

    /// Newton's method for the `root`-th root of `self`:
    /// `g' = ((root - 1) * g + self / g^(root - 1)) / root`.
    ///
    /// `root - 1` is rounded to a whole number for the power. If `g^(root - 1)`
    /// overflows, the second term is treated as zero. Iteration stops when two
    /// consecutive guesses are within `precision()` raw units, or after
    /// `iterations` steps.
    ///
    /// # Errors
    /// `IntegerOverflow` if `root` is zero or smaller than 1, or on arithmetic failure.
    fn newtonian_root_approximation(
        &self,
        root: &Self,
        mut guess: Self,
        iterations: u128,
    ) -> Result<Self> {
        let zero = Self::zero();
        if *self == zero {
            return Ok(zero);
        }
        if *root == zero {
            return Err(error!(ParclMathErrorCode::IntegerOverflow));
        }
        let one = Self::new(1)?;
        let root_minus_one = root.checked_sub(&one)?;
        let root_minus_one_whole = root_minus_one.to_imprecise()?;
        let mut last_guess = guess;
        let precision = Self::precision();
        for _ in 0..iterations {
            let first_term = root_minus_one.checked_mul(&guess)?;
            let power = guess.checked_pow(root_minus_one_whole);
            // Deliberately tolerant: an overflowing power means self / power ~ 0.
            let second_term = match power {
                Ok(num) => self.checked_div(&num)?,
                Err(_) => Self::new(0)?,
            };
            guess = first_term.checked_add(&second_term)?.checked_div(root)?;
            if last_guess.almost_eq(&guess, precision) {
                break;
            } else {
                last_guess = guess;
            }
        }
        Ok(guess)
    }

    /// Smallest value accepted by `sqrt` (0).
    fn minimum_sqrt_base() -> Self {
        Self {
            val: InnerUint::from(0),
        }
    }

    /// Largest value accepted by `sqrt` (`u128::MAX` as a whole number).
    fn maximum_sqrt_base() -> Self {
        // `new` only fails if u128::MAX * ONE overflows InnerUint; it does not
        // fail for the InnerUint in use (the tests exercise this value).
        Self::new(u128::MAX).unwrap()
    }

    /// Square root via Newton's method, seeded with `(self + 1) / 2`.
    ///
    /// # Errors
    /// `IntegerOverflow` if `self` is outside the accepted range or the
    /// iteration fails.
    pub fn sqrt(&self) -> Result<Self> {
        if self.less_than(&Self::minimum_sqrt_base())
            || self.greater_than(&Self::maximum_sqrt_base())
        {
            return Err(error!(ParclMathErrorCode::IntegerOverflow));
        }
        let two = PreciseNumber::new(2)?;
        let one = PreciseNumber::new(1)?;
        let guess = self.checked_add(&one)?.checked_div(&two)?;
        self.newtonian_root_approximation(&two, guess, Self::MAX_APPROXIMATION_ITERATIONS)
    }
}
impl Add<PreciseNumber> for PreciseNumber {
type Output = Self;
fn add(self, rhs: PreciseNumber) -> Self::Output {
self.checked_add(&rhs).unwrap()
}
}
impl Sub<PreciseNumber> for PreciseNumber {
type Output = Self;
fn sub(self, rhs: PreciseNumber) -> Self::Output {
self.checked_sub(&rhs).unwrap()
}
}
impl Mul<PreciseNumber> for PreciseNumber {
type Output = Self;
fn mul(self, rhs: PreciseNumber) -> Self::Output {
self.checked_mul(&rhs).unwrap()
}
}
impl Div<PreciseNumber> for PreciseNumber {
type Output = Self;
fn div(self, rhs: PreciseNumber) -> Self::Output {
self.checked_div(&rhs).unwrap()
}
}
impl AddAssign<PreciseNumber> for PreciseNumber {
    /// In-place addition of the raw values.
    ///
    /// Overflow behaviour is that of `InnerUint`'s `+=` (presumably a panic;
    /// confirm against the `uint` module).
    fn add_assign(&mut self, rhs: PreciseNumber) {
        self.val += rhs.val;
    }
}
impl SubAssign<PreciseNumber> for PreciseNumber {
    /// In-place subtraction of the raw values.
    ///
    /// Underflow behaviour is that of `InnerUint`'s `-=` (presumably a panic;
    /// confirm against the `uint` module).
    fn sub_assign(&mut self, rhs: PreciseNumber) {
        self.val -= rhs.val;
    }
}
impl MulAssign<PreciseNumber> for PreciseNumber {
fn mul_assign(&mut self, rhs: PreciseNumber) {
self.val.mul_assign(rhs.val);
self.val.div_assign(one());
}
}
impl DivAssign<PreciseNumber> for PreciseNumber {
fn div_assign(&mut self, rhs: PreciseNumber) {
self.val.mul_assign(one());
self.val.div_assign(rhs.val);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Asserts that the binomial-series approximation of `base^exponent` is
    /// within 5_000_000 raw units of `expected`.
    fn check_pow_approximation(base: InnerUint, exponent: InnerUint, expected: InnerUint) {
        let tolerance = InnerUint::from(5_000_000);
        let base = PreciseNumber { val: base };
        let approx = base
            .checked_pow_approximation(
                &PreciseNumber { val: exponent },
                PreciseNumber::MAX_APPROXIMATION_ITERATIONS,
            )
            .unwrap();
        assert!(approx.almost_eq(&PreciseNumber { val: expected }, tolerance));
    }

    #[test]
    fn test_root_approximation() {
        let one = one();
        // (base, exponent, expected base^exponent), all raw values.
        let cases = [
            (one / 4, one / 2, one / 2),
            (one * 11 / 10, one / 2, InnerUint::from(1_048808848161u128)),
            (one * 4 / 5, one * 2 / 5, InnerUint::from(914610103850u128)),
            (one / 2, one * 4 / 50, InnerUint::from(946057646730u128)),
        ];
        for &(base, exponent, expected) in cases.iter() {
            check_pow_approximation(base, exponent, expected);
        }
    }

    /// Asserts that `base^exponent` for a fractional exponent is within
    /// `precision` raw units of `expected`.
    fn check_pow_fraction(
        base: InnerUint,
        exponent: InnerUint,
        expected: InnerUint,
        precision: InnerUint,
    ) {
        let base = PreciseNumber { val: base };
        let power = base
            .checked_pow_fraction(&PreciseNumber { val: exponent })
            .unwrap();
        assert!(power.almost_eq(&PreciseNumber { val: expected }, precision));
    }

    #[test]
    fn test_pow_fraction() {
        let one = one();
        let precision = InnerUint::from(50_000_000);
        // Large results accumulate more error, so they get a looser bound.
        let less_precision = precision * 1_000;
        // (base, exponent, expected, tolerance), all raw values.
        let cases = [
            (one, one, one, precision),
            (
                one * 20 / 13,
                one * 50 / 3,
                InnerUint::from(1312_534484739100u128),
                precision,
            ),
            (one * 2 / 7, one * 49 / 4, InnerUint::from(2163), precision),
            (
                one * 5000 / 5100,
                one / 9,
                InnerUint::from(997802126900u128),
                precision,
            ),
            (
                one * 2,
                one * 27 / 5,
                InnerUint::from(42_224253144700u128),
                less_precision,
            ),
            (
                one * 18 / 10,
                one * 11 / 3,
                InnerUint::from(8_629769290500u128),
                less_precision,
            ),
        ];
        for &(base, exponent, expected, tolerance) in cases.iter() {
            check_pow_fraction(base, exponent, expected, tolerance);
        }
    }

    #[test]
    fn test_newtonian_approximation() {
        // Seeding a guess with a zero root must fail.
        let zero_value = PreciseNumber::new(0).unwrap();
        let zero_root = PreciseNumber::new(0).unwrap();
        assert!(zero_value.checked_div(&zero_root).is_err());

        // (radicand, root, expected rounded result)
        let cases: [(u128, u128, u128); 4] = [
            (9, 2, 3),
            (101, 2, 10),
            (1_000_000_000, 2, 31_623),
            (500, 5, 3),
        ];
        for &(radicand, nth, expected) in cases.iter() {
            let test = PreciseNumber::new(radicand).unwrap();
            let nth_root = PreciseNumber::new(nth).unwrap();
            let guess = test.checked_div(&nth_root).unwrap();
            let root = test
                .newtonian_root_approximation(
                    &nth_root,
                    guess,
                    PreciseNumber::MAX_APPROXIMATION_ITERATIONS,
                )
                .unwrap()
                .to_imprecise()
                .unwrap();
            assert_eq!(root, expected);
        }
    }

    #[test]
    fn test_checked_mul() {
        // 0 * 0 == 0
        let zero = PreciseNumber::new(0).unwrap();
        assert_eq!(
            zero.checked_mul(&zero).unwrap(),
            PreciseNumber { val: U256::from(0) }
        );

        // 2 * 2 == 4
        let two = PreciseNumber::new(2).unwrap();
        assert_eq!(
            two.checked_mul(&two).unwrap(),
            PreciseNumber::new(2 * 2).unwrap()
        );

        // MAX * 1 overflows the raw product; the fallback drops MAX's fractional digits.
        let max = PreciseNumber { val: U256::MAX };
        let unit = PreciseNumber::new(1).unwrap();
        assert_eq!(max.checked_mul(&unit).unwrap().val, U256::MAX / one() * one());

        // MAX * (1 + one raw unit) is not representable.
        let mut just_over_one = PreciseNumber::new(1).unwrap();
        just_over_one.val += U256::from(1);
        assert!(max.checked_mul(&just_over_one).is_err());
    }

    /// Asserts that `sqrt(check)`, widened or narrowed by 10 raw units
    /// relative to 1.0 and then squared, brackets `check`.
    fn check_square_root(check: &PreciseNumber) {
        let epsilon = PreciseNumber {
            val: InnerUint::from(10),
        };
        let one = PreciseNumber::one();
        let one_plus_epsilon = one.checked_add(&epsilon).unwrap();
        let one_minus_epsilon = one.checked_sub(&epsilon).unwrap();
        let root = check.sqrt().unwrap();
        let squared_after = |factor: &PreciseNumber| {
            root.checked_mul(factor).unwrap().checked_pow(2).unwrap()
        };
        let lower_bound = squared_after(&one_minus_epsilon);
        let upper_bound = squared_after(&one_plus_epsilon);
        assert!(check.less_than_or_equal(&upper_bound));
        assert!(check.greater_than_or_equal(&lower_bound));
    }

    #[test]
    fn test_square_root_min_max() {
        let extremes = [
            PreciseNumber::minimum_sqrt_base(),
            PreciseNumber::maximum_sqrt_base(),
        ];
        for value in extremes.iter() {
            check_square_root(value);
        }
    }

    #[test]
    fn test_floor() {
        let whole_number = PreciseNumber::new(2).unwrap();
        // 2 plus one raw unit floors back to exactly 2, and floor is idempotent.
        let mut decimal_number = whole_number;
        decimal_number.val += InnerUint::from(1);
        let floored = decimal_number.floor().unwrap();
        let floored_again = floored.floor().unwrap();
        assert_eq!(whole_number.val, floored.val);
        assert_eq!(whole_number.val, floored_again.val);
    }

    #[test]
    fn test_ceiling() {
        let whole_number = PreciseNumber::new(2).unwrap();
        // 2 minus one raw unit ceils up to exactly 2, and ceil is idempotent.
        let mut decimal_number = whole_number;
        decimal_number.val -= InnerUint::from(1);
        let ceiled = decimal_number.ceil().unwrap();
        let ceiled_again = ceiled.ceil().unwrap();
        assert_eq!(whole_number.val, ceiled.val);
        assert_eq!(whole_number.val, ceiled_again.val);
    }

    proptest! {
        #[test]
        fn test_square_root(a in 0..u128::MAX) {
            let a = PreciseNumber { val: InnerUint::from(a) };
            check_square_root(&a);
        }
    }
}