use oxinum_core::{OxiNumError, OxiNumResult};
use super::constants::pi;
use super::float::{BigFloat, RoundingMode};
/// Position of the most significant set bit of `|x|` (i.e. `floor(log2|x|)`),
/// or `None` when `x` is zero. `x` is read as `mantissa · 2^exponent`.
fn taylor_msb(x: &BigFloat) -> Option<i64> {
    if x.is_zero() {
        return None;
    }
    let bits = x.mantissa().bit_length() as i64;
    Some(x.exponent().saturating_add(bits - 1))
}

/// Sums the alternating series `t₀ − t₀·u²/D(1) + t₀·u⁴/(D(1)·D(2)) − …`
/// shared by `sin_taylor` and `cos_taylor`.
///
/// `D(k) = a·(a+1)` with `a = 2k − shift`: `shift = 0` gives the sine
/// denominators `(2k)(2k+1)`, `shift = 1` the cosine denominators `(2k−1)(2k)`.
///
/// Every intermediate is rounded to `prec` bits with `mode`. At most
/// `prec / 2 + 10` terms are added. When `|u²| < 1` the loop stops as soon as
/// a term lies more than `prec + 2` bits below the running sum: every later
/// denominator is at least 12, so the terms keep shrinking by a factor of at
/// least twelve and the alternating tail is smaller than the last term.
///
/// # Errors
///
/// Returns `OxiNumError::Precision` (message prefixed with `label`) if a
/// division by the denominator fails.
fn alternating_taylor_sum(
    first: BigFloat,
    u_sq: &BigFloat,
    shift: u64,
    prec: u32,
    mode: RoundingMode,
    label: &str,
) -> OxiNumResult<BigFloat> {
    let max_terms = u64::from(prec) / 2 + 10;
    // The early exit is only sound when successive terms are guaranteed to shrink.
    let terms_shrink = !matches!(taylor_msb(u_sq), Some(m) if m >= 0);
    let mut term = first;
    let mut sum = term.clone();
    for k in 1..=max_terms {
        let a = 2 * k - shift;
        // Only absurd precisions push the denominator past i64; long before
        // that every remaining term is far below the working precision.
        let denom = match a.checked_mul(a + 1).filter(|&d| d <= i64::MAX as u64) {
            Some(d) => d as i64,
            None => break,
        };
        let denom_f = BigFloat::from_i64(denom, prec, mode);
        term = term
            .mul_ref_with_mode(u_sq, mode)
            .with_precision(prec, mode)
            .div_ref_with_mode(&denom_f, mode)
            .map_err(|e| OxiNumError::Precision(format!("{label} denom zero: {e}").into()))?
            .neg()
            .with_precision(prec, mode);
        sum = sum.add_ref_with_mode(&term, mode).with_precision(prec, mode);

        let negligible = terms_shrink
            && match (taylor_msb(&term), taylor_msb(&sum)) {
                (Some(t), Some(s)) => t.saturating_add(i64::from(prec) + 2) < s,
                _ => false,
            };
        // A zero term makes every later term zero as well.
        if term.is_zero() || negligible {
            break;
        }
    }
    Ok(sum.with_precision(prec, mode))
}

/// Taylor series `sin u = u − u³/3! + u⁵/5! − …` evaluated at `prec` bits.
///
/// Intended for the reduced argument produced by `sincos_impl` (`|u| ≲ π/4`),
/// where the series converges quickly.
fn sin_taylor(u: &BigFloat, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
    if u.is_zero() {
        return Ok(BigFloat::zero(prec));
    }
    let u_sq = u.mul_ref_with_mode(u, mode).with_precision(prec, mode);
    let first = u.clone().with_precision(prec, mode);
    alternating_taylor_sum(first, &u_sq, 0, prec, mode, "sin_taylor")
}

/// Taylor series `cos u = 1 − u²/2! + u⁴/4! − …` evaluated at `prec` bits.
///
/// Intended for the reduced argument produced by `sincos_impl` (`|u| ≲ π/4`).
fn cos_taylor(u: &BigFloat, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
    let u_sq = u.mul_ref_with_mode(u, mode).with_precision(prec, mode);
    let one = BigFloat::from_i64(1, prec, mode);
    alternating_taylor_sum(one, &u_sq, 1, prec, mode, "cos_taylor")
}
/// Converts `x` to an `i64` by discarding its fractional bits, i.e.
/// truncating the magnitude toward zero.
///
/// `x` is read as `mantissa · 2^exponent`, so its top set bit sits at
/// `exponent + bit_length - 1`. Zero and any value with `|x| < 1` give
/// `Some(0)`. `|x| >= 2^63` gives `None`, as does a failed conversion of the
/// aligned mantissa to `u64`. Callers that want round-to-nearest add or
/// subtract one half first.
fn bigfloat_to_i64_round(x: &BigFloat) -> Option<i64> {
    if x.is_zero() {
        return Some(0);
    }
    let exp = x.exponent();
    let mant_bits = x.mantissa().bit_length() as i64;
    let msb = exp.saturating_add(mant_bits - 1);
    if msb < 0 {
        // |x| < 1: the integer part is zero.
        return Some(0);
    }
    if msb >= 63 {
        // The magnitude would not fit in a non-negative i64.
        return None;
    }
    // Align the mantissa so bit 0 is the units place; shifting right drops the fraction.
    let whole = if exp < 0 {
        x.mantissa().shr_bits((-exp) as u64)
    } else {
        x.mantissa().shl_bits(exp as u64)
    };
    let magnitude = whole.to_u64()? as i64;
    Some(if x.signum() < 0 { -magnitude } else { magnitude })
}
/// Returns `x / (π/2)` rounded to the nearest integer, with halves rounded
/// away from zero. Returns `None` if the division fails or the quotient does
/// not fit in an `i64`.
///
/// The quotient is formed at `work_prec` bits. It is then nudged by ±½
/// (matching its sign) so that the truncation in `bigfloat_to_i64_round`
/// yields the nearest integer.
fn round_div_pi_over_2(
    x: &BigFloat,
    pi_over_2: &BigFloat,
    work_prec: u32,
    mode: RoundingMode,
) -> Option<i64> {
    let widened = x.clone().with_precision(work_prec, mode);
    let quotient = widened.div_ref_with_mode(pi_over_2, mode).ok()?;
    let one_half = BigFloat::from_f64(0.5, work_prec).ok()?;
    let nudged = if quotient.signum() < 0 {
        quotient.sub_ref_with_mode(&one_half, mode)
    } else {
        quotient.add_ref_with_mode(&one_half, mode)
    };
    bigfloat_to_i64_round(&nudged)
}
/// Position of the most significant set bit of `|x|` (i.e. `floor(log2|x|)`),
/// or `None` when `x` is zero. `x` is read as `mantissa · 2^exponent`.
fn sincos_msb(x: &BigFloat) -> Option<i64> {
    if x.is_zero() {
        return None;
    }
    let bits = x.mantissa().bit_length() as i64;
    Some(x.exponent().saturating_add(bits - 1))
}

/// Computes `(sin x, cos x)`, each rounded to `prec` bits with `mode`.
///
/// The argument is reduced as `x = q·(π/2) + u`, where `q` is the integer
/// nearest to `x / (π/2)`. `sin u` and `cos u` are evaluated by Taylor series
/// and mapped back to `x` through the quadrant `q mod 4`.
///
/// The working precision is `prec` plus 32 guard bits plus `log2|x|` bits.
/// The extra `log2|x|` bits are needed because the absolute error of
/// `q·(π/2)` grows with `|q| ≈ |x|`. When `x` lies close to a multiple of
/// π/2, the subtraction cancels leading bits of `u`. In that case the
/// reduction is repeated at a higher precision (up to a fixed cap) until the
/// cancelled bits are covered again.
///
/// # Errors
///
/// Returns `OxiNumError::Precision` when `x / (π/2)` does not fit in an
/// `i64`. Also propagates errors from `pi` and from the series evaluation.
pub(crate) fn sincos_impl(
    x: &BigFloat,
    prec: u32,
    mode: RoundingMode,
) -> OxiNumResult<(BigFloat, BigFloat)> {
    const GUARD_BITS: u32 = 32;
    // Minimum number of guard bits that must survive cancellation in `u`.
    const CANCEL_MARGIN: i64 = 16;

    // Use the true magnitude (top bit) of x. `exponent()` alone only locates
    // the lowest mantissa bit, which can be far below log2|x|.
    let exp_guard = sincos_msb(x).map_or(0, |m| m.clamp(0, 256) as u32);
    // Cap on the precision added by cancellation retries. Beyond this the
    // best-effort result is returned as-is.
    let x_bits = (x.mantissa().bit_length() as i64).clamp(0, i64::from(u32::MAX)) as u32;
    let max_extra = x_bits
        .saturating_add(prec)
        .saturating_mul(2)
        .saturating_add(256);

    let mut extra: u32 = 0;
    let (q, u, work_prec) = loop {
        let work_prec = prec
            .saturating_add(exp_guard)
            .saturating_add(GUARD_BITS)
            .saturating_add(extra);
        let pi_val = pi(work_prec)?;
        let two = BigFloat::from_i64(2, work_prec, mode);
        let pi_over_2 = pi_val.div_ref_with_mode(&two, mode)?;
        let q = match round_div_pi_over_2(x, &pi_over_2, work_prec, mode) {
            Some(v) => v,
            None => {
                return Err(OxiNumError::Precision(
                    "sin/cos: argument magnitude too large (|x| > 2^62 * π/2); \
                     use argument reduction before calling"
                        .into(),
                ));
            }
        };
        let q_bf = BigFloat::from_i64(q, work_prec, mode);
        let u = x
            .clone()
            .with_precision(work_prec, mode)
            .sub_ref_with_mode(&q_bf.mul_ref_with_mode(&pi_over_2, mode), mode);

        // q == 0 means nothing was subtracted, so nothing can have cancelled.
        if q == 0 || extra >= max_extra {
            break (q, u, work_prec);
        }
        // exp_guard already absorbs |x|, so the absolute error of u is about
        // 2^-(prec + GUARD_BITS + extra). Each bit that |u| sits below 1
        // therefore costs one bit of relative precision.
        let slack = i64::from(GUARD_BITS) + i64::from(extra);
        match sincos_msb(&u) {
            Some(u_msb) => {
                let lost = u_msb.saturating_neg();
                if lost.saturating_add(CANCEL_MARGIN) <= slack {
                    break (q, u, work_prec);
                }
                // Retry with enough extra bits to restore the full guard band.
                // Here lost > extra + CANCEL_MARGIN, so `extra` strictly grows.
                extra = lost.min(i64::from(u32::MAX)) as u32;
            }
            None => {
                // Total cancellation at this precision: grow geometrically.
                extra = extra
                    .saturating_mul(2)
                    .max(prec.saturating_add(GUARD_BITS));
            }
        }
        extra = extra.min(max_extra);
    };

    let quadrant = q.rem_euclid(4) as u32;
    let sin_u = sin_taylor(&u, work_prec, mode)?;
    let cos_u = cos_taylor(&u, work_prec, mode)?;
    let (sin_x, cos_x) = match quadrant {
        0 => (sin_u, cos_u),
        1 => (cos_u, sin_u.neg()),
        2 => (sin_u.neg(), cos_u.neg()),
        3 => (cos_u.neg(), sin_u),
        _ => unreachable!("rem_euclid(4) always in 0..=3"),
    };
    Ok((
        sin_x.with_precision(prec, mode),
        cos_x.with_precision(prec, mode),
    ))
}
impl BigFloat {
    /// Sine of `self`, rounded to `prec` bits with rounding mode `mode`.
    ///
    /// NaN and infinite inputs return NaN. Zero returns zero.
    ///
    /// # Errors
    ///
    /// Propagates errors from argument reduction, for example when `|self|`
    /// is too large for the quadrant index to fit in an `i64`.
    pub fn sin(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
        if self.is_nan() || self.is_infinite() {
            return Ok(BigFloat::nan(prec));
        }
        if self.is_zero() {
            return Ok(BigFloat::zero(prec));
        }
        let (s, _c) = sincos_impl(self, prec, mode)?;
        Ok(s)
    }

    /// Cosine of `self`, rounded to `prec` bits with rounding mode `mode`.
    ///
    /// NaN and infinite inputs return NaN. Zero returns exactly one.
    ///
    /// # Errors
    ///
    /// Propagates errors from argument reduction, for example when `|self|`
    /// is too large for the quadrant index to fit in an `i64`.
    pub fn cos(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
        if self.is_nan() || self.is_infinite() {
            return Ok(BigFloat::nan(prec));
        }
        if self.is_zero() {
            // cos(0) = 1 exactly; no need to compute π or the series.
            return Ok(BigFloat::from_i64(1, prec, mode));
        }
        let (_s, c) = sincos_impl(self, prec, mode)?;
        Ok(c)
    }

    /// Tangent of `self`, rounded to `prec` bits with rounding mode `mode`.
    ///
    /// NaN and infinite inputs return NaN. Zero returns zero.
    ///
    /// # Errors
    ///
    /// Returns `OxiNumError::Domain` if the computed cosine is exactly zero.
    /// Also propagates argument-reduction and division errors.
    pub fn tan(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
        // sin and cos get extra bits so the quotient is rounded only once, at `prec`.
        const TAN_GUARD_BITS: u32 = 16;
        if self.is_nan() || self.is_infinite() {
            return Ok(BigFloat::nan(prec));
        }
        if self.is_zero() {
            return Ok(BigFloat::zero(prec));
        }
        let (s, c) = sincos_impl(self, prec.saturating_add(TAN_GUARD_BITS), mode)?;
        if c.is_zero() {
            return Err(OxiNumError::Domain("tan undefined at π/2 + k·π".into()));
        }
        let result = s.div_ref_with_mode(&c, mode)?;
        Ok(result.with_precision(prec, mode))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an integer-valued `BigFloat` with round-half-even.
    fn mk(n: i64, prec: u32) -> BigFloat {
        BigFloat::from_i64(n, prec, RoundingMode::HalfEven)
    }

    #[test]
    fn sin_zero_is_zero() {
        let z = mk(0, 64);
        let s = z.sin(64, RoundingMode::HalfEven).expect("sin(0)");
        assert!(s.is_zero());
    }

    #[test]
    fn cos_zero_is_one() {
        let z = mk(0, 64);
        let c = z.cos(64, RoundingMode::HalfEven).expect("cos(0)");
        assert!((c.to_f64() - 1.0).abs() < 1e-14, "cos(0) = {}", c.to_f64());
    }

    #[test]
    fn sin_pi_over_2_is_one() {
        let prec = 64u32;
        let p = pi(prec).expect("pi");
        let two = mk(2, prec);
        let pi_over_2 = p
            .div_ref_with_mode(&two, RoundingMode::HalfEven)
            .expect("pi/2");
        let s = pi_over_2
            .sin(prec, RoundingMode::HalfEven)
            .expect("sin(pi/2)");
        assert!(
            (s.to_f64() - 1.0).abs() < 1e-14,
            "sin(π/2) = {}",
            s.to_f64()
        );
    }

    #[test]
    fn pythagorean_identity_at_f64_angle() {
        let prec = 100u32;
        for x_f64 in [0.3f64, 1.0, 2.7, -1.5, 5.0, -2.9] {
            let x = BigFloat::from_f64(x_f64, prec).expect("from_f64");
            let s = x.sin(prec, RoundingMode::HalfEven).expect("sin");
            let c = x.cos(prec, RoundingMode::HalfEven).expect("cos");
            let sum = s
                .mul_ref_with_mode(&s, RoundingMode::HalfEven)
                .add_ref_with_mode(
                    &c.mul_ref_with_mode(&c, RoundingMode::HalfEven),
                    RoundingMode::HalfEven,
                );
            let err = (sum.to_f64() - 1.0).abs();
            assert!(
                err < 1e-25,
                "sin²+cos² error = {:.2e} for x = {}",
                err,
                x_f64
            );
        }
    }

    #[test]
    fn tan_zero_is_zero() {
        let t = mk(0, 64).tan(64, RoundingMode::HalfEven).expect("tan(0)");
        assert!(t.is_zero());
    }

    #[test]
    fn tan_matches_f64() {
        let prec = 96u32;
        for x_f64 in [0.25f64, 1.0, -0.7, 2.0, 4.0] {
            let x = BigFloat::from_f64(x_f64, prec).expect("from_f64");
            let t = x.tan(prec, RoundingMode::HalfEven).expect("tan");
            let expected = x_f64.tan();
            let rel = ((t.to_f64() - expected) / expected).abs();
            assert!(
                rel < 1e-13,
                "tan({}) = {}, expected {}",
                x_f64,
                t.to_f64(),
                expected
            );
        }
    }

    /// Exercises argument reduction with a large quadrant index `q`.
    #[test]
    fn large_argument_matches_f64() {
        let prec = 64u32;
        for x_f64 in [1.0e6f64, -12345.678, 1.0e9] {
            let x = BigFloat::from_f64(x_f64, prec).expect("from_f64");
            let s = x.sin(prec, RoundingMode::HalfEven).expect("sin");
            let c = x.cos(prec, RoundingMode::HalfEven).expect("cos");
            assert!(
                (s.to_f64() - x_f64.sin()).abs() < 1e-12,
                "sin({}) = {}, expected {}",
                x_f64,
                s.to_f64(),
                x_f64.sin()
            );
            assert!(
                (c.to_f64() - x_f64.cos()).abs() < 1e-12,
                "cos({}) = {}, expected {}",
                x_f64,
                c.to_f64(),
                x_f64.cos()
            );
        }
    }

    #[test]
    fn sin_is_odd_cos_is_even() {
        let prec = 80u32;
        for x_f64 in [0.4f64, 1.9, 3.3, 7.5] {
            let pos = BigFloat::from_f64(x_f64, prec).expect("from_f64");
            let neg = BigFloat::from_f64(-x_f64, prec).expect("from_f64");
            let sp = pos.sin(prec, RoundingMode::HalfEven).expect("sin").to_f64();
            let sn = neg.sin(prec, RoundingMode::HalfEven).expect("sin").to_f64();
            let cp = pos.cos(prec, RoundingMode::HalfEven).expect("cos").to_f64();
            let cn = neg.cos(prec, RoundingMode::HalfEven).expect("cos").to_f64();
            assert!((sp + sn).abs() < 1e-15, "sin not odd at {}", x_f64);
            assert!((cp - cn).abs() < 1e-15, "cos not even at {}", x_f64);
        }
    }
}