use crate::prelude::*;
/// The constant 2, used for halving in `sqrt`, `mean_of_2`, `ln` and `norm_cdf`.
const TWO: Decimal = Decimal::from_parts_raw(2, 0, 0, 0);
/// π, given as raw 96-bit mantissa words plus a flags word.
// NOTE(review): flags 1835008 == 0x001C_0000, presumably scale 28 in the flags' scale byte.
const PI: Decimal = Decimal::from_parts_raw(1102470953, 185874565, 1703060790, 1835008);
/// ln(2), used by `ln` to undo the power-of-two scaling applied before the AGM step.
// NOTE(review): flags 1900544 == 0x001D_0000, which would be scale 29 if the scale lives in
// bits 16..23 of the flags word — one above the usual 28-digit maximum. Arithmetic on it
// appears to be rescaled by the multiply routine, but confirm this encoding is intended.
const LN2: Decimal = Decimal::from_parts_raw(2831677809, 328455696, 3757558395, 1900544);
/// Default convergence tolerance for `exp`: mantissa 2 at scale 7, i.e. 0.0000002.
const EXP_TOLERANCE: Decimal = Decimal::from_parts(2, 0, 0, false, 7);
/// Transcendental, power and statistical functions for [`Decimal`].
pub trait MathematicalOps {
    // --- exponential and logarithm -------------------------------------------------

    /// Returns `e` raised to the power of `self`.
    fn exp(&self) -> Decimal;

    /// Returns `e` raised to the power of `self`, using `tolerance` as the convergence
    /// threshold for the series evaluation.
    fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal;

    /// Returns the natural logarithm of `self`, or zero when `self` is negative.
    ///
    /// # Panics
    /// Panics (division by zero) when `self` is zero.
    fn ln(&self) -> Decimal;

    // --- powers and roots ----------------------------------------------------------

    /// Raises `self` to the integer power `exp`.
    ///
    /// # Panics
    /// Panics if the result overflows; see [`MathematicalOps::checked_powi`].
    fn powi(&self, exp: u64) -> Decimal;

    /// Raises `self` to the integer power `exp`, returning `None` on overflow.
    fn checked_powi(&self, exp: u64) -> Option<Decimal>;

    /// Returns the square root of `self`, or `None` when `self` is negative.
    fn sqrt(&self) -> Option<Decimal>;

    // --- statistics ----------------------------------------------------------------

    /// Returns an approximation of the error function `erf(self)`.
    fn erf(&self) -> Decimal;

    /// Returns the standard normal probability density at `self`.
    fn norm_pdf(&self) -> Decimal;

    /// Returns the standard normal cumulative distribution at `self`.
    fn norm_cdf(&self) -> Decimal;
}
impl MathematicalOps for Decimal {
    /// Returns `e^self`, using [`EXP_TOLERANCE`] as the series convergence tolerance.
    fn exp(&self) -> Decimal {
        self.exp_with_tolerance(EXP_TOLERANCE)
    }

    /// Returns `e^self` from its Taylor series.
    ///
    /// Each series term is rounded to 8 decimal places before it is added, and summation
    /// stops once a rounded term is no larger than `tolerance`.
    ///
    /// * Negative exponents are evaluated as `1 / e^|self|`. Summing the alternating
    ///   series directly loses almost all precision to cancellation once `self` is
    ///   below about -4.
    /// * Each term is derived from the previous one (`t_n = t_{n-1} * x / n`). Separate
    ///   `x^n` and `n!` values would overflow for moderate inputs. This also lets the
    ///   series run until it converges instead of stopping after a fixed term count.
    ///
    /// # Panics
    /// Panics if `e^|self|` does not fit in a `Decimal` (roughly `|self| > 66`).
    #[inline]
    fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal {
        if self.is_zero() {
            return Decimal::ONE;
        }
        if self.is_sign_negative() {
            return Decimal::ONE / (-*self).exp_with_tolerance(tolerance);
        }
        // `term` holds x^n / n!, starting at n = 1.
        let mut term = *self;
        let mut result = self + Decimal::ONE;
        let mut n: u32 = 1;
        loop {
            n += 1;
            // Divide x by n before multiplying so `term` never exceeds the final result.
            term *= *self / Decimal::from_parts_raw(n, 0, 0, 0);
            let increment = term.round_dp(8);
            result += increment;
            // `term` reaching zero guarantees termination even for a negative tolerance.
            if increment <= tolerance || term.is_zero() {
                break;
            }
        }
        result
    }

    /// Raises `self` to the power `exp`.
    ///
    /// # Panics
    /// Panics with "Pow overflowed" if the result does not fit in a `Decimal`.
    fn powi(&self, exp: u64) -> Decimal {
        self.checked_powi(exp)
            .unwrap_or_else(|| panic!("Pow overflowed"))
    }

    /// Raises `self` to the power `exp` by square-and-multiply, returning `None` on
    /// overflow.
    ///
    /// This uses O(log exp) multiplications instead of O(exp). The base is only squared
    /// while higher exponent bits remain. Every intermediate therefore stays within the
    /// magnitude of the final result (for |self| >= 1) or of 1 (for |self| < 1), so
    /// `None` is returned only when the result itself overflows.
    fn checked_powi(&self, exp: u64) -> Option<Decimal> {
        let mut base = *self;
        let mut remaining = exp;
        let mut product = Decimal::ONE;
        while remaining > 0 {
            if remaining & 1 == 1 {
                product = product.checked_mul(base)?;
            }
            remaining >>= 1;
            if remaining > 0 {
                base = base.checked_mul(base)?;
            }
        }
        Some(product)
    }

    /// Returns the square root of `self` via Newton's method, or `None` when `self` is
    /// negative. Negative zero is treated as zero.
    ///
    /// # Panics
    /// Panics with "geo mean circuit breaker" if the iteration fails to settle within
    /// 1000 steps.
    fn sqrt(&self) -> Option<Decimal> {
        if self.is_zero() {
            return Some(Decimal::ZERO);
        }
        if self.is_sign_negative() {
            return None;
        }
        // Start from self / 2; fall back to self when that underflows to zero.
        let mut result = self / TWO;
        if result.is_zero() {
            result = *self;
        }
        // Dummy value that differs from `result` so the loop body runs at least once.
        let mut last = result + Decimal::ONE;
        let mut circuit_breaker = 0;
        while last != result {
            circuit_breaker += 1;
            assert!(circuit_breaker < 1000, "geo mean circuit breaker");
            let two_back = last;
            last = result;
            result = (result + self / result) / TWO;
            // Rounding can trap Newton's method bouncing between two neighbouring values.
            // The iteration is deterministic, so revisiting the value from two steps ago
            // means it will never settle. Return the smaller of the pair rather than
            // spinning into the circuit breaker.
            if circuit_breaker > 1 && result != last && result == two_back {
                return Some(if result < last { result } else { last });
            }
        }
        Some(result)
    }

    /// Returns the natural logarithm of `self`, or zero when `self` is negative.
    ///
    /// For `x > 1` this uses the AGM formula `ln(s) ≈ π / (2·AGM(1, 4/s))` with
    /// `s = 256·x`, then subtracts `8·ln 2`. If `256·x` would overflow, `s = x` is used
    /// unscaled, since the formula is already accurate for such large arguments.
    ///
    /// The formula is only accurate for large `s`. Inputs below 1 are therefore
    /// evaluated as `-ln(1/x)`; evaluated directly they were badly wrong (off by more
    /// than 1 for `x = 0.001`).
    ///
    /// # Panics
    /// Panics (division by zero) when `self` is zero.
    fn ln(&self) -> Decimal {
        if !self.is_sign_positive() {
            return Decimal::ZERO;
        }
        if *self == Decimal::ONE {
            return Decimal::ZERO;
        }
        if *self < Decimal::ONE {
            return -(Decimal::ONE / *self).ln();
        }
        let (s, shift) = match self.checked_mul(Decimal::new(256, 0)) {
            Some(scaled) => (scaled, 8),
            None => (*self, 0),
        };
        let arith_geo_mean = arithmetic_geo_mean_of_2(&Decimal::ONE, &(Decimal::new(4, 0) / s));
        PI / (arith_geo_mean * TWO) - (Decimal::new(shift, 0) * LN2)
    }

    /// Returns an approximation of `erf(self)`.
    ///
    /// Uses Abramowitz & Stegun 7.1.28:
    /// `erf(x) ≈ 1 - 1 / (1 + a1·x + … + a6·x^6)^16` for `x >= 0`, mirrored for `x < 0`.
    ///
    /// Once `sum^16` exceeds `Decimal::MAX`, `1/sum^16` is below half the smallest
    /// representable step. The formula then rounds to exactly 1, so 1 is returned instead
    /// of overflowing. That happens for every `x >= 10`, where the polynomial exceeds 87.
    fn erf(&self) -> Decimal {
        if !self.is_sign_positive() {
            return -self.abs().erf();
        }
        if *self >= Decimal::new(10, 0) {
            return Decimal::ONE;
        }
        let one = &Decimal::ONE;
        let xa1 = self * Decimal::from_parts(705230784, 0, 0, false, 10);
        let xa2 = self.powi(2) * Decimal::from_parts(422820123, 0, 0, false, 10);
        let xa3 = self.powi(3) * Decimal::from_parts(92705272, 0, 0, false, 10);
        let xa4 = self.powi(4) * Decimal::from_parts(1520143, 0, 0, false, 10);
        let xa5 = self.powi(5) * Decimal::from_parts(2765672, 0, 0, false, 10);
        let xa6 = self.powi(6) * Decimal::from_parts(430638, 0, 0, false, 10);
        let sum = one + xa1 + xa2 + xa3 + xa4 + xa5 + xa6;
        match sum.checked_powi(16) {
            Some(denominator) => one - (one / denominator),
            None => Decimal::ONE,
        }
    }

    /// Returns the standard normal CDF `(1 + erf(x / √2)) / 2`.
    fn norm_cdf(&self) -> Decimal {
        // from_parts(2318911239, 3292722, 0, false, 16) is √2 to 16 decimal places.
        (Decimal::ONE + (self / Decimal::from_parts(2318911239, 3292722, 0, false, 16)).erf()) / TWO
    }

    /// Returns the standard normal PDF `exp(-x² / 2) / √(2π)`.
    fn norm_pdf(&self) -> Decimal {
        // Raw parts for √(2π) ≈ 2.5066 at scale 28.
        let sqrt2pi = Decimal::from_parts_raw(2133383024, 2079885984, 1358845910, 1835008);
        (-self.powi(2) / TWO).exp() / sqrt2pi
    }
}
/// Arithmetic–geometric mean of `a` and `b`.
///
/// Repeatedly replaces the pair with its arithmetic and geometric means. It stops once
/// the two differ by less than 0.0000005 and returns the arithmetic member of the pair.
fn arithmetic_geo_mean_of_2(a: &Decimal, b: &Decimal) -> Decimal {
    const TOLERANCE: Decimal = Decimal::from_parts(5, 0, 0, false, 7);
    let mut arith = *a;
    let mut geo = *b;
    while (arith - geo).abs() >= TOLERANCE {
        // Tuple fields are evaluated left to right: arithmetic mean first, then geometric.
        let next = (mean_of_2(&arith, &geo), geo_mean_of_2(&arith, &geo));
        arith = next.0;
        geo = next.1;
    }
    arith
}
/// Arithmetic mean `(a + b) / 2` of two decimals.
fn mean_of_2(a: &Decimal, b: &Decimal) -> Decimal {
    let sum = *a + *b;
    sum / TWO
}
/// Geometric mean `sqrt(a * b)` of two decimals.
///
/// # Panics
/// Panics if `a * b` is negative (no real square root) or if the product overflows.
fn geo_mean_of_2(a: &Decimal, b: &Decimal) -> Decimal {
    (a * b)
        .sqrt()
        .expect("geometric mean requires a non-negative product")
}
#[cfg(test)]
mod test {
    use super::*;
    use std::str::FromStr;

    /// Parses a fixture literal into a `Decimal`.
    fn dec(literal: &str) -> Decimal {
        Decimal::from_str(literal).unwrap()
    }

    #[test]
    fn test_geo_mean_of_2() {
        // (a, b, expected geometric mean)
        let fixtures = [
            ("2", "2", "2"),
            ("4", "3", "3.4641016151377545870548926830"),
            ("12", "3", "6.000000000000000000000000000"),
        ];
        for &(a, b, expected) in fixtures.iter() {
            assert_eq!(dec(expected), geo_mean_of_2(&dec(a), &dec(b)));
        }
    }

    #[test]
    fn test_mean_of_2() {
        // (a, b, expected arithmetic mean)
        let fixtures = [("2", "2", "2"), ("4", "3", "3.5"), ("12", "3", "7.5")];
        for &(a, b, expected) in fixtures.iter() {
            assert_eq!(dec(expected), mean_of_2(&dec(a), &dec(b)));
        }
    }
}