use oxinum_core::{OxiNumError, OxiNumResult, Sign};
use super::constants::{ln2, pi};
use super::float::{BigFloat, RoundingMode};
impl BigFloat {
    /// Computes the natural logarithm `ln(self)` to `prec` bits using the
    /// arithmetic–geometric mean (AGM).
    ///
    /// The argument is rescaled by an exact power of two, `s = self · 2^k`, so that
    /// `s ≈ 2^(w/2 + 10)` where `w` is the internal working precision. For such a
    /// large `s`, `ln(s) ≈ π / (2 · AGM(1, 4/s))`. The result is recovered as
    /// `ln(self) = ln(s) − k · ln(2)` and finally rounded to `prec` bits with `mode`.
    ///
    /// Special inputs: NaN → NaN, +∞ → +∞, −∞ → NaN, and exactly `1` → `0`.
    ///
    /// # Errors
    ///
    /// Returns [`OxiNumError::Domain`] if `prec == 0`, if `self` is zero, or if
    /// `self` is a finite negative number. Errors from the `ln2` / `pi` constant
    /// routines and from the AGM's division and square root are propagated.
    pub fn ln_agm(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
        if prec == 0 {
            return Err(OxiNumError::Domain("ln_agm: precision must be > 0".into()));
        }
        if self.is_nan() {
            return Ok(BigFloat::nan(prec));
        }
        if self.is_infinite() {
            // NOTE(review): −∞ maps to NaN while finite negatives are a Domain error;
            // kept as-is for caller compatibility.
            return if self.sign == Sign::Negative {
                Ok(BigFloat::nan(prec))
            } else {
                Ok(BigFloat::infinity(prec))
            };
        }
        if self.is_zero() {
            return Err(OxiNumError::Domain("ln_agm of zero is undefined".into()));
        }
        if self.sign == Sign::Negative {
            return Err(OxiNumError::Domain(
                "ln_agm of a negative number is undefined for real BigFloat".into(),
            ));
        }

        // floor(log2(self)): position of the leading mantissa bit.
        let cur_log2: i64 = self
            .exponent
            .saturating_add(self.mantissa.bit_length() as i64 - 1);

        // Only inputs in [0.5, 2) can be close to 1; elsewhere |ln x| >= ln 2 and the
        // base guard bits suffice. Restricting the subtraction to this range also
        // avoids aligning `1` against a hugely scaled exponent.
        let cancel_bits: u32 = if (-1..=0).contains(&cur_log2) {
            let one = BigFloat::from_i64(1, prec, mode);
            let dist = self.sub_ref_with_mode(&one, mode).abs();
            if dist.is_zero() {
                // Exactly 1: ln(1) = 0 with no rounding error.
                return Ok(BigFloat::zero(prec));
            }
            let dist_log2 = dist
                .exponent()
                .saturating_add(dist.mantissa().bit_length() as i64 - 1);
            // Number of leading zero bits of |self - 1| below the binary point.
            dist_log2.saturating_neg().clamp(0, i64::from(u32::MAX)) as u32
        } else {
            0
        };

        // ln(s) and k·ln2 are both about w/3 in magnitude. When self ≈ 1 they nearly
        // cancel (ln(self) ≈ self − 1), so each leading zero bit of |self − 1| costs
        // one bit of relative accuracy. Compensate with matching extra guard bits.
        let guard = 64u32.saturating_add(cancel_bits);
        let work_prec = prec.saturating_add(guard);

        // Scale so that s ≈ 2^(w/2 + 10): large enough that the truncation error of
        // the AGM formula falls below the working precision.
        let target_log2 = (work_prec / 2 + 10) as i64;
        let shift_k: i64 = target_log2.saturating_sub(cur_log2);
        let s = BigFloat::from_parts(
            Sign::Positive,
            self.mantissa.clone(),
            self.exponent.saturating_add(shift_k),
            work_prec,
            mode,
        );

        let ln_s = agm_ln_large(&s, work_prec, mode)?;
        let ln_result = if shift_k == 0 {
            ln_s
        } else {
            // Undo the scaling: ln(self) = ln(s) - shift_k * ln(2).
            let ln2_val = ln2(work_prec)?;
            let shift_f = BigFloat::from_i64(shift_k, work_prec, mode);
            let correction = shift_f.mul_ref_with_mode(&ln2_val, mode);
            ln_s.sub_ref_with_mode(&correction, mode)
        };
        Ok(ln_result.with_precision(prec, mode))
    }
}
/// Approximates `ln(s)` for a large positive `s` as `π / (2 · AGM(1, 4/s))`.
///
/// All intermediate values use `work_prec` bits and rounding `mode`. Errors from
/// the division, the AGM iteration, and the `pi` constant are propagated.
fn agm_ln_large(s: &BigFloat, work_prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
    // Small integer constants at the working precision.
    let small_int = |v: i64| BigFloat::from_i64(v, work_prec, mode);

    let start_b = small_int(4).div_ref_with_mode(s, mode)?;
    let mean = agm_iterate(small_int(1), start_b, work_prec, mode)?;

    let denominator = small_int(2).mul_ref_with_mode(&mean, mode);
    let quotient = pi(work_prec)?.div_ref_with_mode(&denominator, mode)?;
    Ok(quotient)
}
/// Runs the arithmetic–geometric mean iteration on `(a, b)` at `work_prec` bits and
/// returns an approximation of the common limit `AGM(a, b)`.
///
/// Each step replaces `(a, b)` with `((a + b) / 2, sqrt(a · b))`. The loop exits once
/// `|a − b| < 2^-(work_prec − 4)`. This is an absolute bound, suited to the caller in
/// this file, which passes `a = 1` and `0 < b < 1`, so the limit lies in `(0, 1)`.
///
/// # Errors
///
/// Propagates any error returned by the square root.
fn agm_iterate(
    mut a: BigFloat,
    mut b: BigFloat,
    work_prec: u32,
    mode: RoundingMode,
) -> OxiNumResult<BigFloat> {
    // Starting from b ≈ 2^-(w/2), about log2(w) steps are needed before a and b agree
    // to within a constant factor. Quadratic convergence then needs about log2(w)
    // more steps to reach w bits, so ~2·log2(w) steps in total. A cap of only
    // log2(w) + 8 is hit before convergence at a few thousand bits. This cap is just a
    // non-termination guard; the threshold test below is the normal exit.
    let ceil_log2_prec = u32::BITS - (work_prec.max(2) - 1).leading_zeros();
    let max_iters = 2 * ceil_log2_prec + 8;
    let threshold = -((work_prec as i64) - 4);

    for _ in 0..max_iters {
        let sum = a.add_ref_with_mode(&b, mode);
        // Halve exactly by decrementing the binary exponent.
        let a_new = BigFloat::from_parts(
            sum.sign(),
            sum.mantissa().clone(),
            sum.exponent().saturating_sub(1),
            work_prec,
            mode,
        );
        let product = a.mul_ref_with_mode(&b, mode);
        let b_new = product.sqrt(work_prec, mode)?;

        let diff = a_new.sub_ref_with_mode(&b_new, mode).abs();
        if diff.is_zero() {
            return Ok(a_new);
        }
        // Position of the leading bit of |a_new - b_new|.
        let top_bit_pos = diff
            .exponent()
            .saturating_add(diff.mantissa().bit_length() as i64 - 1);
        if top_bit_pos < threshold {
            return Ok(a_new);
        }
        a = a_new;
        b = b_new;
    }
    // Not reached for inputs from this file's caller with the cap above; return
    // the best estimate rather than failing.
    Ok(a)
}
#[cfg(test)]
mod tests {
    //! Tests for `BigFloat::ln_agm`: special values, domain errors, and agreement
    //! with `e`, `ln 2`, and the Newton-iteration `ln`.

    use super::*;
    use crate::native::RoundingMode;

    /// Rounding mode shared by every test in this module.
    const MODE: RoundingMode = RoundingMode::HalfEven;

    /// The integer `n` as a `BigFloat` with `prec` bits.
    fn int_float(n: i64, prec: u32) -> BigFloat {
        BigFloat::from_i64(n, prec, MODE)
    }

    /// `true` when `|a − b| < 2^(−tol_bits)`; note the tolerance is absolute.
    fn close_within_bits(a: &BigFloat, b: &BigFloat, tol_bits: u32) -> bool {
        let gap = a.sub_ref_with_mode(b, MODE).abs();
        gap.is_zero() || {
            let leading_bit = gap
                .exponent()
                .saturating_add(gap.mantissa().bit_length() as i64 - 1);
            leading_bit < -i64::from(tol_bits)
        }
    }

    #[test]
    fn ln_agm_one_is_zero() {
        let result = int_float(1, 100).ln_agm(100, MODE).expect("ln_agm(1)");
        assert!(result.is_zero(), "ln_agm(1) should be 0, got: {result:?}");
    }

    #[test]
    fn ln_agm_zero_is_domain_error() {
        let result = int_float(0, 64).ln_agm(64, MODE);
        assert!(
            matches!(result, Err(OxiNumError::Domain(_))),
            "Expected Domain error for ln_agm(0), got: {result:?}"
        );
    }

    #[test]
    fn ln_agm_negative_is_domain_error() {
        let result = int_float(-1, 64).ln_agm(64, MODE);
        assert!(
            matches!(result, Err(OxiNumError::Domain(_))),
            "Expected Domain error for ln_agm(-1), got: {result:?}"
        );
    }

    #[test]
    fn ln_agm_prec_zero_is_domain_error() {
        let result = int_float(2, 64).ln_agm(0, MODE);
        assert!(
            matches!(result, Err(OxiNumError::Domain(_))),
            "Expected Domain error for prec=0, got: {result:?}"
        );
    }

    #[test]
    fn ln_agm_e_is_approximately_one() {
        use crate::native::e_const;
        let e = e_const(100).expect("e_const");
        let result = e.ln_agm(60, MODE).expect("ln_agm(e)");
        assert!(
            close_within_bits(&result, &int_float(1, 60), 45),
            "ln_agm(e) should be ≈ 1, got: {} (diff from 1: {})",
            result.to_f64(),
            (result.to_f64() - 1.0).abs()
        );
    }

    #[test]
    fn ln_agm_matches_newton_ln() {
        const PREC: u32 = 80;
        const TOL: u32 = 60;
        for &n in &[2i64, 7, 100, 1000] {
            let x = BigFloat::from_i64(n, PREC + 64, MODE);
            let (newton, agm) = (
                x.ln(PREC, MODE).expect("newton ln"),
                x.ln_agm(PREC, MODE).expect("agm ln"),
            );
            assert!(
                close_within_bits(&newton, &agm, TOL),
                "ln_agm({n}) vs Newton mismatch: newton={}, agm={}",
                newton.to_f64(),
                agm.to_f64()
            );
        }
    }

    #[test]
    fn ln_agm_small_fraction() {
        use crate::native::ln2 as ln2_const;
        const PREC: u32 = 80;
        // 0.5 = 1 · 2^-1
        let half = BigFloat::from_parts(
            Sign::Positive,
            oxinum_int::native::BigUint::one(),
            -1,
            PREC,
            MODE,
        );
        let ln_half = half.ln_agm(PREC, MODE).expect("ln_agm(0.5)");
        let minus_ln2 = ln2_const(PREC).expect("ln2").neg();
        assert!(
            close_within_bits(&ln_half, &minus_ln2, 60),
            "ln_agm(0.5) should be -ln(2), got: {}",
            ln_half.to_f64()
        );
    }
}