use oxinum_core::{OxiNumError, OxiNumResult, Sign};
use super::float::{BigFloat, RoundingMode};
impl BigFloat {
    /// Computes `e^self`, rounded to `prec` bits using `mode`.
    ///
    /// Special inputs:
    /// - NaN returns `BigFloat::nan(prec)`;
    /// - an infinity with negative sign returns zero, any other infinity returns
    ///   `BigFloat::infinity(prec)`;
    /// - zero returns exactly one.
    ///
    /// For finite, non-zero input the argument is reduced to `y = x / 2^k` by
    /// lowering the binary exponent. `e^y` is summed with a Taylor series, and the
    /// result is squared `k` times, since `e^x = (e^y)^(2^k)`. Intermediate work
    /// uses `prec` plus a guard-bit allowance that grows with `prec`, to absorb the
    /// error amplification caused by the repeated squaring.
    ///
    /// # Panics
    ///
    /// Panics if `prec == 0`.
    ///
    /// # Errors
    ///
    /// Returns [`OxiNumError::Overflow`] when the argument, converted to `f64`,
    /// is greater than `745.0`. Arguments below `-745.0` return zero instead of an
    /// error. Errors from the series division are propagated.
    pub fn exp(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
        assert!(prec > 0, "BigFloat precision must be > 0");
        if self.is_nan() {
            return Ok(BigFloat::nan(prec));
        }
        if self.is_infinite() {
            return if self.sign == Sign::Negative {
                Ok(BigFloat::zero(prec))
            } else {
                Ok(BigFloat::infinity(prec))
            };
        }
        if self.is_zero() {
            return Ok(BigFloat::from_i64(1, prec, mode));
        }

        // An f64 approximation of x is only used for range checks and to choose
        // the reduction depth; it does not feed into the computed value.
        let x_f64 = self.to_f64();
        if x_f64 > 745.0 {
            return Err(OxiNumError::Overflow(
                "exp: argument too large (result exceeds BigFloat range)".into(),
            ));
        }
        if x_f64 < -745.0 {
            return Ok(BigFloat::zero(prec));
        }

        // ceil(log2|x|) when |x| >= 1, so that |x| / 2^log2_x_abs is at most ~1.
        let x_abs = x_f64.abs();
        let log2_x_abs = if x_abs >= 1.0 {
            x_abs.log2().ceil() as u64
        } else {
            0u64
        };

        // Each of the k squarings roughly doubles the relative error, so the
        // working precision carries extra guard bits that scale with prec.
        let guard = 32u32 + (prec / 64 + 4);
        let work_prec = prec.saturating_add(guard);
        let k = log2_x_abs + (prec as u64 / 64) + 4;

        // y = x / 2^k, built by lowering the exponent of x by k. `self` is
        // non-zero here because the zero case returned above.
        let x_reduced = BigFloat::from_parts(
            self.sign,
            self.mantissa.clone(),
            self.exponent.saturating_sub(k as i64),
            work_prec,
            mode,
        );

        let n_terms = (prec / 4 + 16).max(64) as u64;
        let mut result = exp_taylor(&x_reduced, n_terms, work_prec, mode)?;
        // Undo the reduction: e^x = (e^y)^(2^k). Borrow `result` for both
        // operands; no per-iteration copy of the mantissa is needed.
        for _ in 0..k {
            result = result
                .mul_ref_with_mode(&result, mode)
                .with_precision(work_prec, mode);
        }
        Ok(result.with_precision(prec, mode))
    }
}
/// Sums the Taylor series of `e^y` at `work_prec` bits.
///
/// Starting from `1`, adds the terms `y^n / n!` for `n = 1..=n_terms`. Every
/// intermediate product, quotient and partial sum is rounded to `work_prec`
/// with `mode`.
///
/// Summation stops early in two cases:
/// - a term rounds to zero;
/// - the term's leading bit lies more than `work_prec + 8` bits below the
///   leading bit of the running sum.
///
/// Intended for a small, already-reduced `|y|`, as produced by `BigFloat::exp`.
///
/// # Errors
///
/// Propagates any error returned by the per-term division.
fn exp_taylor(
    y: &BigFloat,
    n_terms: u64,
    work_prec: u32,
    mode: RoundingMode,
) -> OxiNumResult<BigFloat> {
    // Bit position of the most significant set bit of a non-zero value:
    // exponent + (mantissa bit length - 1).
    let top_bit = |v: &BigFloat| -> i64 {
        v.exponent
            .saturating_add(v.mantissa.bit_length() as i64 - 1)
    };
    let cutoff = work_prec as i64 + 8;

    let mut current = BigFloat::from_i64(1, work_prec, mode);
    let mut sum = current.clone();
    for n in 1..=n_terms {
        // current <- current * y / n, which keeps current == y^n / n!.
        let scaled = current
            .mul_ref_with_mode(y, mode)
            .with_precision(work_prec, mode);
        let divisor = BigFloat::from_i64(n as i64, work_prec, mode);
        current = scaled
            .div_ref_with_mode(&divisor, mode)?
            .with_precision(work_prec, mode);

        sum = sum
            .add_ref_with_mode(&current, mode)
            .with_precision(work_prec, mode);

        // Stop once the term no longer affects the sum at working precision.
        let negligible = current.is_zero()
            || (!sum.is_zero() && top_bit(&current) < top_bit(&sum).saturating_sub(cutoff));
        if negligible {
            break;
        }
    }
    Ok(sum)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native::{e_const, RoundingMode};

    /// Builds an integer-valued `BigFloat` at the given precision.
    fn mk(n: i64, prec: u32) -> BigFloat {
        BigFloat::from_i64(n, prec, RoundingMode::HalfEven)
    }

    /// Relative difference between `got` and a non-zero reference `want`.
    fn rel_err(got: f64, want: f64) -> f64 {
        ((got - want) / want).abs()
    }

    #[test]
    fn exp_zero_is_one() {
        let x = mk(0, 64);
        let result = x.exp(64, RoundingMode::HalfEven).expect("exp(0)");
        let one = mk(1, 64);
        assert_eq!(result, one, "exp(0) should == 1");
    }

    #[test]
    fn exp_one_approx_e() {
        let x = mk(1, 100);
        let result = x.exp(100, RoundingMode::HalfEven).expect("exp(1)");
        let e = e_const(100).expect("e_const(100)");
        let diff = (result.to_f64() - e.to_f64()).abs();
        assert!(diff < 1e-14, "exp(1) diff from e: {diff}");
    }

    #[test]
    fn exp_overflow() {
        let x = BigFloat::from_f64(800.0, 64).expect("from_f64");
        let result = x.exp(64, RoundingMode::HalfEven);
        assert!(
            matches!(result, Err(OxiNumError::Overflow(_))),
            "Expected Overflow, got: {result:?}"
        );
    }

    #[test]
    fn exp_large_negative_returns_zero() {
        let x = BigFloat::from_f64(-800.0, 64).expect("from_f64");
        let result = x.exp(64, RoundingMode::HalfEven).expect("exp(-800)");
        assert!(result.is_zero(), "exp(-800) should be zero");
    }

    #[test]
    fn exp_small_value_cross_val() {
        let x = BigFloat::from_f64(0.5, 64).expect("0.5");
        let result = x.exp(64, RoundingMode::HalfEven).expect("exp(0.5)");
        let expected = 0.5_f64.exp();
        let diff = (result.to_f64() - expected).abs();
        assert!(diff < 1e-14, "exp(0.5) diff: {diff}");
    }

    #[test]
    fn exp_nan_is_nan() {
        let result = BigFloat::nan(64)
            .exp(64, RoundingMode::HalfEven)
            .expect("exp(NaN)");
        assert!(result.is_nan(), "exp(NaN) should be NaN");
    }

    #[test]
    fn exp_infinity_is_infinite() {
        let result = BigFloat::infinity(64)
            .exp(64, RoundingMode::HalfEven)
            .expect("exp(inf)");
        assert!(result.is_infinite(), "exp(+inf) should be infinite");
    }

    #[test]
    fn exp_small_negative_cross_val() {
        let x = BigFloat::from_f64(-0.5, 64).expect("-0.5");
        let result = x.exp(64, RoundingMode::HalfEven).expect("exp(-0.5)");
        let expected = (-0.5_f64).exp();
        let diff = (result.to_f64() - expected).abs();
        assert!(diff < 1e-14, "exp(-0.5) diff: {diff}");
    }

    #[test]
    fn exp_large_magnitude_relative_error() {
        // |x| >= 1 exercises the extra log2|x| reduction steps and squarings.
        for &v in &[10.0_f64, -10.0, 100.0, -100.0] {
            let x = BigFloat::from_f64(v, 64).expect("from_f64");
            let result = x.exp(64, RoundingMode::HalfEven).expect("exp");
            let err = rel_err(result.to_f64(), v.exp());
            assert!(err < 1e-13, "exp({v}) relative error: {err}");
        }
    }
}