use super::decimal_compute::{
ComputeStorage, DECIMAL_COMPUTE_DP,
decimal_compute_zero, decimal_compute_one,
decimal_compute_add, decimal_compute_sub, decimal_compute_mul, decimal_compute_div,
decimal_compute_div_int, decimal_compute_halve,
decimal_compute_is_zero, decimal_compute_is_negative, decimal_compute_cmp,
decimal_compute_neg, pow10_compute_ct,
};
use crate::fixed_point::domains::symbolic::rational::rational_number::OverflowDetected;
use std::cell::RefCell;
thread_local! {
// Per-thread cache for the ln(2) constant used by `ln2_at_compute`.
// Starts as `None`; `ln2_at_compute` stores the computed value on first use.
static LN2_FOR_EXP: RefCell<Option<ComputeStorage>> = const { RefCell::new(None) };
}
fn ln2_at_compute() -> Result<ComputeStorage, OverflowDetected> {
let cached: Option<ComputeStorage> = LN2_FOR_EXP.with(|c| c.borrow().clone());
if let Some(v) = cached {
return Ok(v);
}
let one = decimal_compute_one();
let three = super::decimal_compute::decimal_compute_from_int(3);
let s = decimal_compute_div(one, three)?;
let s_sq = decimal_compute_mul(s, s);
let mut term = s;
let mut sum = s;
let max_terms = (DECIMAL_COMPUTE_DP as u32) + 20;
for k in 1..=max_terms {
term = decimal_compute_mul(term, s_sq);
if decimal_compute_is_zero(&term) { break; }
let divisor = (2 * k as u64) + 1;
let contribution = decimal_compute_div_int(term, divisor);
if decimal_compute_is_zero(&contribution) { break; }
sum = decimal_compute_add(sum, contribution);
}
let ln2 = decimal_compute_add(sum, sum);
LN2_FOR_EXP.with(|c: &RefCell<Option<ComputeStorage>>| *c.borrow_mut() = Some(ln2));
Ok(ln2)
}
/// Upper bound on the number of terms evaluated by the series loops in this file:
/// `DECIMAL_COMPUTE_DP` plus a fixed margin of 20.
const fn max_taylor_terms() -> u32 {
    const EXTRA_TERMS: u32 = 20;
    DECIMAL_COMPUTE_DP as u32 + EXTRA_TERMS
}
/// Rounds a compute-tier value to the nearest integer and returns it as `i64`.
///
/// Half of `decimal_compute_one()` is added for non-negative inputs and subtracted
/// for negative ones before dividing by `decimal_compute_one()`. With a division
/// that truncates toward zero this gives round-half-away-from-zero.
///
/// # Errors
///
/// Returns `OverflowDetected::Overflow` when the rounded quotient is outside the
/// `i64` range.
fn round_to_int(v: ComputeStorage) -> Result<i64, OverflowDetected> {
    let half = decimal_compute_halve(decimal_compute_one());
    let rounded = if decimal_compute_is_negative(&v) {
        decimal_compute_sub(v, half)
    } else {
        decimal_compute_add(v, half)
    };
    let scale = decimal_compute_one();
    #[cfg(table_format = "q16_16")]
    { Ok(rounded / scale) }
    #[cfg(table_format = "q32_32")]
    {
        // Range-check before narrowing so an out-of-range quotient is reported
        // as overflow instead of being silently truncated by the cast.
        let q = rounded / scale;
        if q > i64::MAX as i128 || q < i64::MIN as i128 {
            return Err(OverflowDetected::Overflow);
        }
        Ok(q as i64)
    }
    #[cfg(table_format = "q64_64")]
    {
        let q = rounded / scale;
        if !q.fits_in_i128() { return Err(OverflowDetected::Overflow); }
        let q_i128 = q.as_i128();
        if q_i128 > i64::MAX as i128 || q_i128 < i64::MIN as i128 {
            return Err(OverflowDetected::Overflow);
        }
        Ok(q_i128 as i64)
    }
    #[cfg(table_format = "q128_128")]
    {
        // NOTE(review): unlike the q64_64 arm there is no `fits_in_i128()` guard
        // before `as_i128()` here — confirm whether this type offers one.
        let q = rounded / scale;
        let q_i128 = q.as_i128();
        if q_i128 > i64::MAX as i128 || q_i128 < i64::MIN as i128 {
            return Err(OverflowDetected::Overflow);
        }
        Ok(q_i128 as i64)
    }
    #[cfg(table_format = "q256_256")]
    {
        // NOTE(review): same missing `fits_in_i128()` guard as the q128_128 arm — confirm.
        let q = rounded / scale;
        let q_i128 = q.as_i128();
        if q_i128 > i64::MAX as i128 || q_i128 < i64::MIN as i128 {
            return Err(OverflowDetected::Overflow);
        }
        Ok(q_i128 as i64)
    }
}
fn mul_by_pow2(v: ComputeStorage, k: i64) -> Result<ComputeStorage, OverflowDetected> {
if k == 0 {
return Ok(v);
}
if k > 0 {
if k > 200 {
return Err(OverflowDetected::Overflow);
}
#[cfg(table_format = "q16_16")]
{ Ok(v << (k as u32)) }
#[cfg(table_format = "q32_32")]
{ Ok(v << (k as u32)) }
#[cfg(table_format = "q64_64")]
{ Ok(v << (k as usize)) }
#[cfg(table_format = "q128_128")]
{ Ok(v << (k as usize)) }
#[cfg(table_format = "q256_256")]
{ Ok(v << (k as usize)) }
} else {
let abs_k = (-k) as u64;
if abs_k >= 250 {
return Ok(decimal_compute_zero());
}
let one_compute = {
#[cfg(table_format = "q16_16")]
{ 1i64 }
#[cfg(table_format = "q32_32")]
{ 1i128 }
#[cfg(table_format = "q64_64")]
{ crate::fixed_point::i256::I256::from_i128(1) }
#[cfg(table_format = "q128_128")]
{ crate::fixed_point::i512::I512::from_i128(1) }
#[cfg(table_format = "q256_256")]
{ crate::fixed_point::I1024::from_i128(1) }
};
let round_bit = {
#[cfg(table_format = "q16_16")]
{ one_compute << ((abs_k - 1) as u32) }
#[cfg(table_format = "q32_32")]
{ one_compute << ((abs_k - 1) as u32) }
#[cfg(table_format = "q64_64")]
{ one_compute << ((abs_k - 1) as usize) }
#[cfg(table_format = "q128_128")]
{ one_compute << ((abs_k - 1) as usize) }
#[cfg(table_format = "q256_256")]
{ one_compute << ((abs_k - 1) as usize) }
};
let rounded = if decimal_compute_is_negative(&v) {
v - round_bit
} else {
v + round_bit
};
#[cfg(table_format = "q16_16")]
{ Ok(rounded >> (abs_k as u32)) }
#[cfg(table_format = "q32_32")]
{ Ok(rounded >> (abs_k as u32)) }
#[cfg(table_format = "q64_64")]
{ Ok(rounded >> (abs_k as u32)) }
#[cfg(table_format = "q128_128")]
{ Ok(rounded >> (abs_k as usize)) }
#[cfg(table_format = "q256_256")]
{ Ok(rounded >> (abs_k as usize)) }
}
}
/// Precomputed exponentials consumed by `decimal_exp_table_path`.
///
/// Each table holds successive powers of `exp(step)` for one decimal position,
/// as filled in by `build_exp_tables`.
#[derive(Clone)]
struct ExpTables {
    /// `exp(k)` for integer `k` in `0..=30`.
    exp_int: [ComputeStorage; 31],
    /// `exp(d / 10)` for digit `d` in `0..=9`.
    exp_tenths: [ComputeStorage; 10],
    /// `exp(d / 100)` for digit `d` in `0..=9`.
    exp_hundredths: [ComputeStorage; 10],
    /// `exp(d / 1000)` for digit `d` in `0..=9`.
    exp_thousandths: [ComputeStorage; 10],
}
thread_local! {
// Per-thread cache of the lookup tables used by `decimal_exp_table_path`.
// Starts as `None`; the table path fills it via `build_exp_tables` on first use.
static EXP_TABLES: RefCell<Option<ExpTables>> = const { RefCell::new(None) };
}
/// Evaluates `exp(x)` by direct Taylor summation `1 + x + x^2/2! + ...`.
///
/// No range reduction is applied here. Summation stops as soon as a term is zero
/// at compute precision and is otherwise capped at `max_taylor_terms()` terms.
/// The function currently always returns `Ok`.
fn exp_taylor_raw(x: ComputeStorage) -> Result<ComputeStorage, OverflowDetected> {
    let unity = decimal_compute_one();
    let mut current = unity;
    let mut total = unity;
    for n in 1..=max_taylor_terms() {
        // current = x^n / n!, derived from the previous term.
        current = decimal_compute_div_int(decimal_compute_mul(current, x), u64::from(n));
        if decimal_compute_is_zero(&current) {
            break;
        }
        total = decimal_compute_add(total, current);
    }
    Ok(total)
}
/// Narrows a raw compute-tier value to `i64`.
///
/// Every backend arm uses a plain `as` cast (going through `as_i128()` on the
/// wide integer types); this function performs no explicit range check.
/// NOTE(review): the callers in this file pass quotients from the table-path
/// digit decomposition, which are presumably small — confirm before reusing
/// this helper elsewhere.
fn compute_as_i64(v: ComputeStorage) -> i64 {
#[cfg(table_format = "q16_16")]
{ v }
#[cfg(table_format = "q32_32")]
{ v as i64 }
#[cfg(table_format = "q64_64")]
{ v.as_i128() as i64 }
#[cfg(table_format = "q128_128")]
{ v.as_i128() as i64 }
#[cfg(table_format = "q256_256")]
{ v.as_i128() as i64 }
}
fn build_exp_tables() -> Result<ExpTables, OverflowDetected> {
let one = decimal_compute_one();
let zero = decimal_compute_zero();
let e1 = exp_taylor_raw(one)?;
let mut exp_int = [zero; 31];
exp_int[0] = one;
exp_int[1] = e1;
for k in 2..=30usize {
exp_int[k] = decimal_compute_mul(exp_int[k - 1], e1);
}
let mut exp_tenths = [zero; 10];
exp_tenths[0] = one;
let pt1 = pow10_compute_ct(DECIMAL_COMPUTE_DP - 1); let e_pt1 = exp_taylor_raw(pt1)?;
exp_tenths[1] = e_pt1;
for d in 2..=9usize {
exp_tenths[d] = decimal_compute_mul(exp_tenths[d - 1], e_pt1);
}
let mut exp_hundredths = [zero; 10];
exp_hundredths[0] = one;
let pt01 = pow10_compute_ct(DECIMAL_COMPUTE_DP - 2);
let e_pt01 = exp_taylor_raw(pt01)?;
exp_hundredths[1] = e_pt01;
for d in 2..=9usize {
exp_hundredths[d] = decimal_compute_mul(exp_hundredths[d - 1], e_pt01);
}
let mut exp_thousandths = [zero; 10];
exp_thousandths[0] = one;
let pt001 = pow10_compute_ct(DECIMAL_COMPUTE_DP - 3);
let e_pt001 = exp_taylor_raw(pt001)?;
exp_thousandths[1] = e_pt001;
for d in 2..=9usize {
exp_thousandths[d] = decimal_compute_mul(exp_thousandths[d - 1], e_pt001);
}
Ok(ExpTables { exp_int, exp_tenths, exp_hundredths, exp_thousandths })
}
/// Computes `exp(abs_x)` from cached tables plus a short Taylor correction.
///
/// `abs_x` is split into an integer part `k`, three decimal digits `d1 d2 d3`
/// and a remainder. The result is the product of the matching table entries,
/// multiplied by `exp(remainder)` when the remainder is non-zero. If `k` is
/// outside `0..=30` or any digit is outside `0..=9`, the computation is handed
/// to `decimal_exp_ln2_reduction` instead.
///
/// # Errors
///
/// Propagates errors from `build_exp_tables`, `exp_taylor_raw` and
/// `decimal_exp_ln2_reduction`.
fn decimal_exp_table_path(abs_x: ComputeStorage) -> Result<ComputeStorage, OverflowDetected> {
    // Splits off the leading digit of `value` at place value `unit`,
    // returning the digit and what is left after removing it.
    let peel = |value: ComputeStorage, unit: ComputeStorage| -> (i64, ComputeStorage) {
        let digit = compute_as_i64(value / unit);
        (digit, decimal_compute_sub(value, mul_compute_by_int(unit, digit)))
    };

    let (k, frac) = peel(abs_x, decimal_compute_one());
    let (d1, frac) = peel(frac, pow10_compute_ct(DECIMAL_COMPUTE_DP - 1));
    let (d2, frac) = peel(frac, pow10_compute_ct(DECIMAL_COMPUTE_DP - 2));
    let (d3, remainder) = peel(frac, pow10_compute_ct(DECIMAL_COMPUTE_DP - 3));

    let in_table_range =
        (0..=30).contains(&k) && [d1, d2, d3].iter().all(|d| (0..=9).contains(d));
    if !in_table_range {
        return decimal_exp_ln2_reduction(abs_x);
    }

    EXP_TABLES.with(|cell| {
        // The shared borrow taken in the condition ends before the body runs,
        // so the mutable borrow below does not conflict with it.
        if cell.borrow().is_none() {
            let built = build_exp_tables()?;
            *cell.borrow_mut() = Some(built);
        }
        let guard = cell.borrow();
        let tables = guard.as_ref().unwrap();

        let mut product = decimal_compute_mul(
            tables.exp_int[k as usize],
            tables.exp_tenths[d1 as usize],
        );
        product = decimal_compute_mul(product, tables.exp_hundredths[d2 as usize]);
        product = decimal_compute_mul(product, tables.exp_thousandths[d3 as usize]);
        if !decimal_compute_is_zero(&remainder) {
            product = decimal_compute_mul(product, exp_taylor_raw(remainder)?);
        }
        Ok(product)
    })
}
/// Computes `exp(x)` by reducing with ln(2): `x = k * ln2 + r`, so
/// `exp(x) = exp(r) * 2^k`.
///
/// `k` is `x / ln2` rounded to the nearest integer, `exp(r)` comes from
/// `exp_taylor_raw`, and the final scaling is done by `mul_by_pow2`.
///
/// # Errors
///
/// Propagates errors from `ln2_at_compute`, `decimal_compute_div`,
/// `round_to_int`, `exp_taylor_raw` and `mul_by_pow2`.
fn decimal_exp_ln2_reduction(x: ComputeStorage) -> Result<ComputeStorage, OverflowDetected> {
    let ln2 = ln2_at_compute()?;
    let k = round_to_int(decimal_compute_div(x, ln2)?)?;
    let reduced = decimal_compute_sub(x, mul_compute_by_int(ln2, k));
    exp_taylor_raw(reduced).and_then(|exp_r| mul_by_pow2(exp_r, k))
}
/// Computes `exp(x)` for a compute-tier decimal value.
///
/// * `x == 0` returns one.
/// * `|x| <= 30` uses the table path on `|x|`; for negative `x` the result is
///   then inverted with `1 / exp(|x|)`.
/// * `|x| > 30` goes straight to the ln(2) reduction with the signed input.
///
/// # Errors
///
/// Propagates errors from `decimal_exp_table_path`, `decimal_exp_ln2_reduction`
/// and the reciprocal `decimal_compute_div`.
pub fn decimal_exp(x: ComputeStorage) -> Result<ComputeStorage, OverflowDetected> {
    if decimal_compute_is_zero(&x) {
        return Ok(decimal_compute_one());
    }

    let negative = decimal_compute_is_negative(&x);
    let magnitude = if negative { decimal_compute_neg(x) } else { x };
    let one = decimal_compute_one();
    let table_limit = mul_compute_by_int(one, 30);

    if decimal_compute_cmp(&magnitude, &table_limit) == std::cmp::Ordering::Greater {
        return decimal_exp_ln2_reduction(x);
    }

    let positive = decimal_exp_table_path(magnitude)?;
    if negative {
        decimal_compute_div(one, positive)
    } else {
        Ok(positive)
    }
}
/// Multiplies a compute-tier value by a plain integer `n`.
///
/// `|n|` is converted to the backend's raw integer type and multiplied into `v`
/// with the raw `*` operator; the product is negated when `n` is negative.
/// `n == 0` short-circuits to zero.
fn mul_compute_by_int(v: ComputeStorage, n: i64) -> ComputeStorage {
    if n == 0 {
        return decimal_compute_zero();
    }

    let magnitude = n.unsigned_abs();
    let factor: ComputeStorage = {
        #[cfg(table_format = "q16_16")]
        { magnitude as i64 }
        #[cfg(table_format = "q32_32")]
        { magnitude as i128 }
        #[cfg(table_format = "q64_64")]
        { crate::fixed_point::i256::I256::from_i128(magnitude as i128) }
        #[cfg(table_format = "q128_128")]
        { crate::fixed_point::i512::I512::from_i128(magnitude as i128) }
        #[cfg(table_format = "q256_256")]
        { crate::fixed_point::I1024::from_i128(magnitude as i128) }
    };

    let product = v * factor;
    if n < 0 {
        decimal_compute_neg(product)
    } else {
        product
    }
}
/// Computes `exp(-x)` by negating the argument and delegating to `decimal_exp`.
///
/// # Errors
///
/// Propagates any error returned by `decimal_exp`.
// No caller is visible in this file, hence the lint allowance.
#[allow(dead_code)]
pub fn decimal_exp_neg(x: ComputeStorage) -> Result<ComputeStorage, OverflowDetected> {
    let negated = decimal_compute_neg(x);
    decimal_exp(negated)
}
/// Computes `(sinh(x), cosh(x))` from a pair of exponentials.
///
/// Uses `sinh(x) = (e^x - e^-x) / 2` and `cosh(x) = (e^x + e^-x) / 2`;
/// `x == 0` returns `(0, 1)` without evaluating any exponential.
///
/// # Errors
///
/// Propagates any error returned by `decimal_exp`.
pub fn decimal_sinhcosh(x: ComputeStorage) -> Result<(ComputeStorage, ComputeStorage), OverflowDetected> {
    if decimal_compute_is_zero(&x) {
        return Ok((decimal_compute_zero(), decimal_compute_one()));
    }

    let grow = decimal_exp(x)?;
    let decay = decimal_exp(decimal_compute_neg(x))?;

    let difference = decimal_compute_sub(grow, decay);
    let total = decimal_compute_add(grow, decay);
    Ok((decimal_compute_halve(difference), decimal_compute_halve(total)))
}
// Precision tests for `decimal_exp`; compiled only for the q64_64 table format.
#[cfg(all(test, table_format = "q64_64"))]
mod tests {
use super::*;
use super::super::decimal_compute::{decimal_compute_from_int, pow10_compute_ct};
use crate::fixed_point::i256::I256;
// exp(0) must be exactly the representation of one.
#[test]
fn exp_zero_is_one() {
let result = decimal_exp(decimal_compute_zero()).unwrap();
assert_eq!(result, decimal_compute_one());
}
// Compares exp(1) against the digits of e written as a raw scaled integer.
// NOTE(review): the name says "1 ulp" but the tolerance below is 10_000 raw
// units — confirm which is intended.
#[test]
fn exp_one_within_1_ulp() {
let result = decimal_exp(decimal_compute_one()).unwrap();
let expected_str = "271828182845904523536028747135266249776";
let expected = parse_decimal_str_q64_64(expected_str);
let diff = if result > expected { result - expected } else { expected - result };
let tolerance = I256::from_i128(10_000);
assert!(
diff < tolerance,
"exp(1) precision check: got={:?}, expected={:?}, diff={:?}",
result, expected, diff
);
}
// Compares exp(0.5) against reference digits; the input is built as
// 5 * pow10_compute_ct(37).
// NOTE(review): same name/tolerance mismatch as `exp_one_within_1_ulp`.
#[test]
fn exp_half_within_1_ulp() {
let half = pow10_compute_ct(37) * I256::from_i128(5);
let result = decimal_exp(half).unwrap();
let expected_str = "164872127070012814684865078781416357165";
let expected = parse_decimal_str_q64_64(expected_str);
let diff = if result > expected { result - expected } else { expected - result };
let tolerance = I256::from_i128(10_000);
assert!(
diff < tolerance,
"exp(0.5) precision check: got={:?}, expected={:?}, diff={:?}",
result, expected, diff
);
}
// Parses a string of ASCII decimal digits into an `I256`, most significant
// digit first. Panics on any non-digit character.
fn parse_decimal_str_q64_64(s: &str) -> I256 {
let mut result = I256::from_i128(0);
let ten = I256::from_i128(10);
for ch in s.chars() {
let digit = ch.to_digit(10).expect("non-digit in test string");
result = result * ten + I256::from_i128(digit as i128);
}
result
}
// Compares exp(2) against reference digits with a looser tolerance of
// 100_000 raw units.
#[test]
fn exp_two_within_reasonable() {
let two = decimal_compute_from_int(2);
let result = decimal_exp(two).unwrap();
let expected_str = "738905609893065022723042746057500781318";
let expected = parse_decimal_str_q64_64(expected_str);
let diff = if result > expected { result - expected } else { expected - result };
let tolerance = I256::from_i128(100_000);
assert!(
diff < tolerance,
"exp(2) precision check: got={:?}, expected={:?}, diff={:?}",
result, expected, diff
);
}
// End-to-end check through the canonical evaluator: the formatted result of
// exp(-1) must start with the expected leading digits.
#[test]
fn exp_neg_one() {
use crate::canonical::{gmath, evaluate};
let result = evaluate(&gmath("-1.0").exp()).unwrap();
let s = format!("{}", result);
assert!(s.starts_with("0.3678794411714"),
"exp(-1) at storage tier should match mpmath to 13+ digits, got: {}", s);
}
}