use oxinum_core::{OxiNumError, OxiNumResult, Sign};
use super::constants::pi;
use super::float::{BigFloat, RoundingMode};
/// Shrinks an arctangent argument by repeated angle halving.
///
/// Each step replaces `u` with `u / (1 + sqrt(1 + u^2))`, the half-angle identity
/// `tan(θ/2) = tan θ / (1 + sqrt(1 + tan² θ))`, so `atan(u)` halves on every step.
/// Returns `(u_m, m)` such that `atan(x) = 2^m * atan(u_m)`.
///
/// Halving stops as soon as `|u|` (as seen through `to_f64`) drops below `2^-31`,
/// or after 300 steps, whichever comes first. Every intermediate is rounded to
/// `prec` bits with `mode`.
///
/// # Errors
/// Propagates failures from the square root or the division.
fn half_angle_reduce(x: BigFloat, prec: u32, mode: RoundingMode) -> OxiNumResult<(BigFloat, u32)> {
    // 2^-31, written as an exact power-of-two quotient.
    const SMALL_ENOUGH: f64 = 1.0 / 2_147_483_648.0;
    // Safety cap so a non-converging input (e.g. NaN) cannot loop forever.
    const MAX_HALVINGS: u32 = 300;

    let one = BigFloat::from_i64(1, prec, mode);
    let mut tangent = x;
    let mut halvings = 0u32;
    while halvings < MAX_HALVINGS {
        // A NaN magnitude compares false here, so it keeps halving until the cap.
        if tangent.abs().to_f64() < SMALL_ENOUGH {
            break;
        }
        let squared = tangent
            .mul_ref_with_mode(&tangent, mode)
            .with_precision(prec, mode);
        let secant = one.add_ref_with_mode(&squared, mode).sqrt(prec, mode)?;
        let divisor = one.add_ref_with_mode(&secant, mode);
        tangent = tangent
            .div_ref_with_mode(&divisor, mode)?
            .with_precision(prec, mode);
        halvings += 1;
    }
    Ok((tangent, halvings))
}
fn atan_taylor(u: &BigFloat, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
if u.is_zero() {
return Ok(BigFloat::zero(prec));
}
let u_sq = u.mul_ref_with_mode(u, mode).with_precision(prec, mode);
let mut term = u.clone().with_precision(prec, mode);
let mut result = term.clone();
let n_terms: u64 = (prec as u64) / 64 + 8;
for k in 1..=n_terms {
let numer = term
.mul_ref_with_mode(&u_sq, mode)
.with_precision(prec, mode);
term = numer.neg().with_precision(prec, mode);
let denom_val = 2 * k + 1;
let denom_i64 = denom_val.min(i64::MAX as u64) as i64;
let denom_f = BigFloat::from_i64(denom_i64, prec, mode);
let scaled = term
.div_ref_with_mode(&denom_f, mode)
.map_err(|e| OxiNumError::Precision(format!("atan_taylor denom zero: {e}").into()))?;
result = result
.add_ref_with_mode(&scaled, mode)
.with_precision(prec, mode);
}
Ok(result.with_precision(prec, mode))
}
impl BigFloat {
pub fn atan(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
if self.is_nan() {
return Ok(BigFloat::nan(prec));
}
if self.is_infinite() {
let pi_val = pi(prec.saturating_add(16))?;
let two = BigFloat::from_i64(2, prec.saturating_add(16), mode);
let half_pi = pi_val.div_ref_with_mode(&two, mode)?;
return if self.sign == Sign::Negative {
Ok(half_pi.neg().with_precision(prec, mode))
} else {
Ok(half_pi.with_precision(prec, mode))
};
}
if self.is_zero() {
return Ok(BigFloat::zero(prec));
}
let work_prec = prec.saturating_add(64);
let orig_sign = self.signum();
let abs_x = self.abs().with_precision(work_prec, mode);
let one = BigFloat::from_i64(1, work_prec, mode);
let (working_x, use_complement): (BigFloat, bool) = if abs_x > one.clone() {
let inv = one.div_ref_with_mode(&abs_x, mode)?;
(inv.with_precision(work_prec, mode), true)
} else {
(abs_x, false)
};
let (reduced, m) = half_angle_reduce(working_x, work_prec, mode)?;
let mut result = atan_taylor(&reduced, work_prec, mode)?;
let two = BigFloat::from_i64(2, work_prec, mode);
for _ in 0..m {
result = result
.mul_ref_with_mode(&two, mode)
.with_precision(work_prec, mode);
}
if use_complement {
let pi_val = pi(work_prec)?;
let two_wp = BigFloat::from_i64(2, work_prec, mode);
let pi_over_2 = pi_val.div_ref_with_mode(&two_wp, mode)?;
result = pi_over_2
.sub_ref_with_mode(&result, mode)
.with_precision(work_prec, mode);
}
if orig_sign < 0 {
result = result.neg();
}
Ok(result.with_precision(prec, mode))
}
pub fn atan2(&self, x: &BigFloat, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
let y = self;
if y.is_nan() || x.is_nan() {
return Ok(BigFloat::nan(prec));
}
if y.is_infinite() && x.is_infinite() {
let pi_val = pi(prec.saturating_add(16))?;
let four = BigFloat::from_i64(4, prec.saturating_add(16), mode);
let pi_over_4 = pi_val
.div_ref_with_mode(&four, mode)?
.with_precision(prec, mode);
let three_pi_over_4 = {
let three = BigFloat::from_i64(3, prec, mode);
three
.mul_ref_with_mode(&pi_over_4, mode)
.with_precision(prec, mode)
};
let (mag, apply_neg) = if x.sign == Sign::Negative {
(three_pi_over_4, y.sign == Sign::Negative)
} else {
(pi_over_4, y.sign == Sign::Negative)
};
return Ok(if apply_neg { mag.neg() } else { mag });
}
if y.is_infinite() {
let pi_val = pi(prec.saturating_add(16))?;
let two = BigFloat::from_i64(2, prec.saturating_add(16), mode);
let half_pi = pi_val
.div_ref_with_mode(&two, mode)?
.with_precision(prec, mode);
return if y.sign == Sign::Negative {
Ok(half_pi.neg())
} else {
Ok(half_pi)
};
}
if x.is_infinite() {
if x.sign == Sign::Positive {
return Ok(BigFloat::zero(prec));
} else {
let pi_val = pi(prec.saturating_add(16))?.with_precision(prec, mode);
return if y.sign == Sign::Negative {
Ok(pi_val.neg())
} else {
Ok(pi_val)
};
}
}
let y_sign = y.signum();
let x_sign = x.signum();
if y.is_zero() && x.is_zero() {
return Ok(BigFloat::zero(prec));
}
if x.is_zero() {
let pi_val = pi(prec.saturating_add(16))?;
let two = BigFloat::from_i64(2, prec.saturating_add(16), mode);
let pi_over_2 = pi_val.div_ref_with_mode(&two, mode)?;
return if y_sign >= 0 {
Ok(pi_over_2.with_precision(prec, mode))
} else {
Ok(pi_over_2.neg().with_precision(prec, mode))
};
}
if x_sign > 0 {
let work_prec = prec.saturating_add(16);
let ratio = y
.clone()
.with_precision(work_prec, mode)
.div_ref_with_mode(&x.clone().with_precision(work_prec, mode), mode)?;
return ratio.atan(prec, mode);
}
let work_prec = prec.saturating_add(32);
let ratio = y
.clone()
.with_precision(work_prec, mode)
.div_ref_with_mode(&x.clone().with_precision(work_prec, mode), mode)?;
let atan_val = ratio.atan(work_prec, mode)?;
let pi_val = pi(work_prec)?;
if y_sign >= 0 {
let result = pi_val.add_ref_with_mode(&atan_val, mode);
Ok(result.with_precision(prec, mode))
} else {
let result = atan_val.sub_ref_with_mode(&pi_val, mode);
Ok(result.with_precision(prec, mode))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an integer-valued `BigFloat` with half-even rounding.
    fn mk(n: i64, prec: u32) -> BigFloat {
        BigFloat::from_i64(n, prec, RoundingMode::HalfEven)
    }

    #[test]
    fn atan_zero() {
        let z = mk(0, 64);
        let r = z.atan(64, RoundingMode::HalfEven).expect("atan(0)");
        assert!(r.is_zero());
    }

    #[test]
    fn atan_one_is_pi_over_4() {
        let prec = 64u32;
        let one = mk(1, prec);
        let a = one.atan(prec, RoundingMode::HalfEven).expect("atan(1)");
        assert!(
            (a.to_f64() - std::f64::consts::FRAC_PI_4).abs() < 1e-14,
            "got {}",
            a.to_f64()
        );
    }

    #[test]
    fn atan_minus_one() {
        let prec = 64u32;
        let minus_one = mk(-1, prec);
        let a = minus_one
            .atan(prec, RoundingMode::HalfEven)
            .expect("atan(-1)");
        assert!((a.to_f64() + std::f64::consts::FRAC_PI_4).abs() < 1e-14);
    }

    /// Covers the `|x| > 1` complement branch, including a negative argument
    /// and one very close to the π/2 asymptote.
    #[test]
    fn atan_large_arguments_match_f64() {
        let prec = 64u32;
        for n in [2i64, 3, 10, 1_000_000, -7] {
            let a = mk(n, prec)
                .atan(prec, RoundingMode::HalfEven)
                .expect("atan(n)");
            let expected = (n as f64).atan();
            assert!(
                (a.to_f64() - expected).abs() < 1e-14,
                "atan({n}): got {}, expected {expected}",
                a.to_f64()
            );
        }
    }

    #[test]
    fn atan2_quadrant_i() {
        let prec = 64u32;
        let mode = RoundingMode::HalfEven;
        let one = mk(1, prec);
        let a = one.atan2(&one, prec, mode).expect("atan2(1,1)");
        assert!((a.to_f64() - std::f64::consts::FRAC_PI_4).abs() < 1e-14);
    }

    #[test]
    fn atan2_negative_x() {
        let prec = 64u32;
        let mode = RoundingMode::HalfEven;
        let one = mk(1, prec);
        let neg_one = mk(-1, prec);
        let a = one.atan2(&neg_one, prec, mode).expect("atan2(1,-1)");
        let expected = 3.0 * std::f64::consts::FRAC_PI_4;
        assert!(
            (a.to_f64() - expected).abs() < 1e-13,
            "got {}, expected {}",
            a.to_f64(),
            expected
        );
    }

    /// Exercises every quadrant plus the positive and negative x axes.
    #[test]
    fn atan2_matches_f64_in_all_quadrants() {
        let prec = 64u32;
        let mode = RoundingMode::HalfEven;
        for (y, x) in [(3i64, 2i64), (3, -2), (-3, -2), (-3, 2), (-1, -1), (0, -1), (0, 5)] {
            let a = mk(y, prec)
                .atan2(&mk(x, prec), prec, mode)
                .expect("atan2(y, x)");
            let expected = (y as f64).atan2(x as f64);
            assert!(
                (a.to_f64() - expected).abs() < 1e-13,
                "atan2({y}, {x}): got {}, expected {expected}",
                a.to_f64()
            );
        }
    }

    #[test]
    fn atan2_zero_x_positive_y() {
        let prec = 64u32;
        let mode = RoundingMode::HalfEven;
        let one = mk(1, prec);
        let zero = mk(0, prec);
        let a = one.atan2(&zero, prec, mode).expect("atan2(1,0)");
        assert!((a.to_f64() - std::f64::consts::FRAC_PI_2).abs() < 1e-14);
    }

    #[test]
    fn atan2_zero_x_negative_y() {
        let prec = 64u32;
        let mode = RoundingMode::HalfEven;
        let neg_one = mk(-1, prec);
        let zero = mk(0, prec);
        let a = neg_one.atan2(&zero, prec, mode).expect("atan2(-1,0)");
        assert!((a.to_f64() + std::f64::consts::FRAC_PI_2).abs() < 1e-14);
    }
}