use crate::{DBig, OxiNumError, OxiNumResult};
use dashu_float::round::mode::HalfEven;
use std::str::FromStr;
/// Computes `e^x` and returns it with at most `precision` significant decimal digits.
///
/// The result is computed in binary floating point with extra guard bits, converted
/// back to decimal, and then truncated (not rounded) to `precision` digits.
///
/// # Errors
/// Returns `OxiNumError::Precision` when `precision` is 0, and `OxiNumError::Parse`
/// if one of the internal decimal literals fails to parse.
pub fn exp(x: &DBig, precision: usize) -> OxiNumResult<DBig> {
    if precision == 0 {
        return Err(OxiNumError::Precision("precision must be > 0".into()));
    }
    let parse_literal = |literal: &str| {
        DBig::from_str(literal).map_err(|e| OxiNumError::Parse(format!("{e}").into()))
    };
    // e^0 is exactly 1, so skip the arbitrary-precision evaluation.
    if *x == parse_literal("0.0")? {
        return parse_literal("1.0");
    }
    // About 4 bits per requested decimal digit, plus a fixed 20-bit cushion.
    let working_bits = precision * 4 + 20;
    let value = convert_dbig_to_fbig(x, working_bits).exp();
    Ok(truncate_to_precision(
        fbig_to_dbig(&value, precision),
        precision,
    ))
}
/// Computes the natural logarithm of `x` with at most `precision` significant
/// decimal digits (truncated).
///
/// # Errors
/// Returns `OxiNumError::Precision` when `precision` is 0 or when `x <= 0`, and
/// `OxiNumError::Parse` if one of the internal decimal literals fails to parse.
pub fn ln(x: &DBig, precision: usize) -> OxiNumResult<DBig> {
    if precision == 0 {
        return Err(OxiNumError::Precision("precision must be > 0".into()));
    }
    let parse_literal = |literal: &str| {
        DBig::from_str(literal).map_err(|e| OxiNumError::Parse(format!("{e}").into()))
    };
    // The logarithm is only defined for strictly positive arguments.
    if *x <= parse_literal("0.0")? {
        return Err(OxiNumError::Precision("ln(x) requires x > 0".into()));
    }
    // ln(1) is exactly 0.
    if *x == parse_literal("1.0")? {
        return parse_literal("0.0");
    }
    let working_bits = precision * 4 + 20;
    let value = convert_dbig_to_fbig(x, working_bits).ln();
    Ok(truncate_to_precision(
        fbig_to_dbig(&value, precision),
        precision,
    ))
}
/// Computes the square root of `x` with at most `precision` significant decimal
/// digits (truncated).
///
/// # Errors
/// Returns `OxiNumError::Precision` when `precision` is 0 or when `x < 0`, and
/// `OxiNumError::Parse` if the internal zero literal fails to parse.
pub fn sqrt(x: &DBig, precision: usize) -> OxiNumResult<DBig> {
    if precision == 0 {
        return Err(OxiNumError::Precision("precision must be > 0".into()));
    }
    let parse_literal = |literal: &str| {
        DBig::from_str(literal).map_err(|e| OxiNumError::Parse(format!("{e}").into()))
    };
    let zero = parse_literal("0.0")?;
    if *x < zero {
        return Err(OxiNumError::Precision("sqrt(x) requires x >= 0".into()));
    }
    // sqrt(0) is exactly 0; return the parsed zero directly.
    if *x == zero {
        return Ok(zero);
    }
    let working_bits = precision * 4 + 20;
    let root = dashu_base::SquareRoot::sqrt(&convert_dbig_to_fbig(x, working_bits));
    Ok(truncate_to_precision(
        fbig_to_dbig(&root, precision),
        precision,
    ))
}
/// Computes `base^exponent` with at most `precision` significant decimal digits
/// (truncated).
///
/// A zero exponent yields 1 before the base is validated, so `pow(0, 0)` returns 1.
///
/// # Errors
/// Returns `OxiNumError::Precision` when `precision` is 0, or when the exponent is
/// non-zero and `base <= 0`. Returns `OxiNumError::Parse` if an internal decimal
/// literal fails to parse.
pub fn pow(base: &DBig, exponent: &DBig, precision: usize) -> OxiNumResult<DBig> {
    if precision == 0 {
        return Err(OxiNumError::Precision("precision must be > 0".into()));
    }
    let parse_literal = |literal: &str| {
        DBig::from_str(literal).map_err(|e| OxiNumError::Parse(format!("{e}").into()))
    };
    let zero = parse_literal("0.0")?;
    if *exponent == zero {
        return parse_literal("1.0");
    }
    // Only strictly positive bases are supported by the real-valued powf path.
    if *base <= zero {
        return Err(OxiNumError::Precision(
            "pow(base, exp) requires base > 0".into(),
        ));
    }
    let working_bits = precision * 4 + 20;
    let lhs = convert_dbig_to_fbig(base, working_bits);
    let rhs = convert_dbig_to_fbig(exponent, working_bits);
    let raised = lhs.powf(&rhs);
    Ok(truncate_to_precision(
        fbig_to_dbig(&raised, precision),
        precision,
    ))
}
/// Converts a decimal `DBig` into a base-2 `FBig` using half-to-even rounding.
///
/// The requested `binary_precision` is clamped to at least 10 bits. The same clamped
/// value is used both for the base conversion and for the returned context, so the
/// significand never holds more bits than the context's precision limit.
/// NOTE(review): presumably the clamp also prevents a precision of 0, which dashu
/// treats as "unlimited" — confirm intent.
pub(crate) fn convert_dbig_to_fbig(
    value: &DBig,
    binary_precision: usize,
) -> dashu_float::FBig<HalfEven, 2> {
    // Previously the context was built from the unclamped precision while the
    // conversion used the clamped one. For inputs below 10 bits that created an FBig
    // whose significand exceeded its own context precision.
    let bits = binary_precision.max(10);
    let ctx = dashu_float::Context::<HalfEven>::new(bits);
    let converted = value
        .clone()
        .with_rounding::<HalfEven>()
        .with_base_and_precision::<2>(bits)
        .value();
    dashu_float::FBig::from_repr(converted.repr().clone(), ctx)
}
/// Converts a base-2 `FBig` back into a decimal `DBig`.
///
/// The base conversion uses `max(decimal_precision, 5)` decimal digits of precision,
/// and the result is switched to the `HalfAway` rounding mode. An input whose
/// `digits()` count is 0 is treated as zero and maps to a parsed `"0.0"`.
pub(crate) fn fbig_to_dbig(
    fbig: &dashu_float::FBig<HalfEven, 2>,
    decimal_precision: usize,
) -> DBig {
    if fbig.digits() != 0 {
        let target_digits = decimal_precision.max(5);
        let decimal = fbig
            .clone()
            .with_base_and_precision::<10>(target_digits)
            .value();
        decimal.with_rounding::<dashu_float::round::mode::HalfAway>()
    } else {
        DBig::from_str("0.0").expect("valid literal")
    }
}
/// Cuts `value` down to `precision` significant decimal digits.
///
/// The value is rendered with `Display`, truncated textually by
/// `truncate_decimal_str`, and parsed back. This is best-effort: if the shortened
/// text does not parse as a `DBig`, the original `value` is returned unchanged.
pub(crate) fn truncate_to_precision(value: DBig, precision: usize) -> DBig {
    let rendered = value.to_string();
    match DBig::from_str(&truncate_decimal_str(&rendered, precision)) {
        Ok(shortened) => shortened,
        Err(_) => value,
    }
}
/// Truncates the decimal string `src` to at most `sig_digits` significant digits.
///
/// Digits are cut, never rounded. When the integer part is zero (e.g. `0.00123`), the
/// zeros before the first non-zero digit are kept but not counted as significant.
/// If the budget runs out inside the integer part, the remaining integer digits are
/// replaced with `0` so the magnitude is preserved (`"22026.46"` with 3 digits becomes
/// `"22000"`). An exponent suffix (`e…` / `E…`) is re-attached unchanged. Characters
/// other than ASCII digits, `-` and `.` are dropped from the mantissa. A request for
/// 0 digits still keeps the first significant digit.
///
/// Returns `"0"`, or `"-0"` for a lone minus sign, when no mantissa characters remain.
pub(crate) fn truncate_decimal_str(src: &str, sig_digits: usize) -> String {
    // Split off the exponent so truncating the mantissa can never discard it.
    let (mantissa, exponent) = match src.find(|c: char| c == 'e' || c == 'E') {
        Some(idx) => src.split_at(idx),
        None => (src, ""),
    };

    let unsigned = mantissa.trim_start_matches('-');
    let integer_is_zero = unsigned.starts_with("0.") || unsigned == "0";
    // Leading zeros only count as significant when the integer part is non-zero.
    let mut seen_nonzero = !integer_is_zero;
    let mut sig_count = 0usize;
    let mut in_fraction = false;
    // Set once the budget is used up inside the integer part; every later integer
    // digit is emitted as '0' to keep the number's magnitude.
    let mut exhausted = false;
    let mut result = String::with_capacity(mantissa.len() + exponent.len());

    for ch in mantissa.chars() {
        match ch {
            '-' => result.push(ch),
            '.' => {
                if exhausted {
                    // No significant fractional digits remain to keep.
                    break;
                }
                in_fraction = true;
                result.push(ch);
            }
            '0'..='9' => {
                if exhausted {
                    result.push('0');
                    continue;
                }
                result.push(ch);
                if !seen_nonzero && ch == '0' {
                    continue;
                }
                seen_nonzero = true;
                sig_count += 1;
                if sig_count >= sig_digits {
                    if in_fraction {
                        break;
                    }
                    exhausted = true;
                }
            }
            _ => {}
        }
    }

    if result.trim_start_matches('-').is_empty() {
        return if result.starts_with('-') { "-0" } else { "0" }.to_string();
    }
    result.push_str(exponent);
    result
}
// Unit tests for the transcendental functions and the decimal-truncation helper.
// Several string assertions accept more than one rendering (e.g. "1" or "1.0000...")
// because the exact `Display` output of a `DBig` is not pinned down here.
// NOTE(review): confirm the expected formatting against dashu's `Display` impl.
#[cfg(test)]
mod tests {
use super::*;
// exp(0) takes the early-return path and must render as one.
#[test]
fn exp_of_zero() {
let x = DBig::from_str("0.0").expect("ok");
let result = exp(&x, 30).expect("ok");
let s = result.to_string();
assert!(s.starts_with("1.0000") || s == "1", "exp(0) = {s}");
}
// Checks only the leading digits of Euler's number.
#[test]
fn exp_of_one() {
let x = DBig::from_str("1.0").expect("ok");
let result = exp(&x, 30).expect("ok");
let s = result.to_string();
assert!(s.starts_with("2.71828"), "exp(1) = {s}");
}
// ln(1) takes the early-return path; a possible sign is stripped before checking that
// the value is zero-like.
#[test]
fn ln_of_one() {
let x = DBig::from_str("1.0").expect("ok");
let result = ln(&x, 30).expect("ok");
let s = result.to_string();
let s_clean = s.trim_start_matches('-');
assert!(
s_clean.starts_with("0") && !s_clean.starts_with("0.1"),
"ln(1) = {s}"
);
}
// Non-positive arguments are rejected with an error.
#[test]
fn ln_negative_errors() {
let x = DBig::from_str("-1.0").expect("ok");
assert!(ln(&x, 30).is_err());
}
// A perfect square should come back as an exact-looking root.
#[test]
fn sqrt_of_four() {
let x = DBig::from_str("4.0").expect("ok");
let result = sqrt(&x, 30).expect("ok");
let s = result.to_string();
assert!(s.starts_with("2.0000") || s == "2", "sqrt(4) = {s}");
}
// Irrational root: checks only the leading digits.
#[test]
fn sqrt_of_two() {
let x = DBig::from_str("2.0").expect("ok");
let result = sqrt(&x, 30).expect("ok");
let s = result.to_string();
assert!(s.starts_with("1.4142135"), "sqrt(2) = {s}");
}
// Negative arguments are rejected with an error.
#[test]
fn sqrt_negative_errors() {
let x = DBig::from_str("-1.0").expect("ok");
assert!(sqrt(&x, 30).is_err());
}
// sqrt(0) takes the early-return path.
#[test]
fn sqrt_of_zero() {
let x = DBig::from_str("0.0").expect("ok");
let result = sqrt(&x, 30).expect("ok");
let s = result.to_string();
assert!(s.starts_with("0"), "sqrt(0) = {s}");
}
// Integer power computed through the powf path.
#[test]
fn pow_two_to_ten() {
let base = DBig::from_str("2.0").expect("ok");
let exponent = DBig::from_str("10.0").expect("ok");
let result = pow(&base, &exponent, 20).expect("ok");
let s = result.to_string();
assert!(s.starts_with("1024"), "2^10 = {s}");
}
// A zero exponent short-circuits to one.
#[test]
fn pow_zero_exponent() {
let base = DBig::from_str("5.0").expect("ok");
let exponent = DBig::from_str("0.0").expect("ok");
let result = pow(&base, &exponent, 20).expect("ok");
let s = result.to_string();
assert!(s.starts_with("1"), "5^0 = {s}");
}
// Every public entry point rejects a precision of 0.
#[test]
fn precision_zero_errors() {
let x = DBig::from_str("1.0").expect("ok");
assert!(exp(&x, 0).is_err());
assert!(ln(&x, 0).is_err());
assert!(sqrt(&x, 0).is_err());
assert!(pow(&x, &x, 0).is_err());
}
// Leading zeros after "0." are kept but not counted as significant.
#[test]
fn truncate_leading_zeros() {
let s = truncate_decimal_str("0.00123456789", 5);
assert_eq!(s, "0.0012345");
}
// Integer digits count toward the significant-digit budget.
#[test]
fn truncate_integer_part() {
let s = truncate_decimal_str("123.456789", 6);
assert_eq!(s, "123.456");
}
// The minus sign is preserved and not counted.
#[test]
fn truncate_negative() {
let s = truncate_decimal_str("-3.14159", 4);
assert_eq!(s, "-3.141");
}
}