use oxinum_core::{OxiNumResult, Sign};
use oxinum_int::native::{BigInt, BigUint};
use super::binary_splitting::{binary_split, BSSeries, BSSplit};
use super::float::{BigFloat, RoundingMode};
/// Precision threshold, in bits, associated with the binary-splitting paths in
/// this module.
///
/// NOTE(review): not referenced by the code visible here; presumably callers
/// route to `exp_bs` / `sincos_bs` at or above this precision — confirm.
pub(crate) const BS_THRESHOLD_BITS: u32 = 512;
/// Convert an arbitrary-precision integer into a `BigFloat` of `prec` bits.
///
/// Nonzero values are built from their sign and magnitude with a binary
/// exponent of 0, rounding with `mode` if the magnitude needs more than
/// `prec` bits. Zero is produced via `BigFloat::zero` instead.
fn bigfloat_from_bigint(n: &BigInt, prec: u32, mode: RoundingMode) -> BigFloat {
    if !n.is_zero() {
        BigFloat::from_parts(n.sign(), n.magnitude().clone(), 0, prec, mode)
    } else {
        BigFloat::zero(prec)
    }
}
/// Number of terms to request from `binary_split` for the exp series at
/// `target_bits` of precision.
///
/// Picks the smallest `k >= 2` with `log2(k!) >= target_bits + 16` and returns
/// `k + 1`. The bound does not involve the argument's magnitude.
/// NOTE(review): presumably callers pass a reduced argument (`|y| <= 1`) —
/// confirm.
pub(crate) fn term_count_exp(target_bits: u32) -> u64 {
    let goal = f64::from(target_bits) + 16.0;
    // Running log2(j!), seeded with log2(2!) = 1 and extended as `find` advances.
    let mut log2_factorial = 1.0_f64;
    let k = (2_u64..)
        .find(|&j| {
            if j > 2 {
                log2_factorial += (j as f64).log2();
            }
            log2_factorial >= goal
        })
        .expect("log2(k!) grows without bound, so some k reaches the goal");
    (k + 1).max(2)
}
/// Number of terms to request from `binary_split` for the sine/cosine series
/// at `target_bits` of precision.
///
/// Finds the smallest `k >= 1` with `log2((2k)!) > target_bits + 16` and
/// returns `k + 2`, leaving spare terms beyond that point. The bound does not
/// involve the argument's magnitude. NOTE(review): presumably callers pass a
/// reduced argument (`|u| <= 1`) — confirm.
///
/// No iteration cap is applied. `log2((2k)!)` grows without bound, so the loop
/// always terminates. An arbitrary cap would silently truncate the series, and
/// with it the accuracy of the result, for very large `target_bits`.
pub(crate) fn term_count_trig(target_bits: u32) -> u64 {
    let target = f64::from(target_bits) + 16.0;
    let mut k: u64 = 1;
    // Running value of log2((2k)!); each step adds the two new factors 2k-1 and 2k.
    let mut log2_fact: f64 = 0.0;
    loop {
        let a = (2 * k - 1) as f64;
        let b = (2 * k) as f64;
        log2_fact += a.log2() + b.log2();
        if log2_fact > target {
            return k + 2;
        }
        k += 1;
    }
}
/// Express `y` exactly as a rational pair `(p, q)` with `q > 0`.
///
/// This assumes a `BigFloat`'s value is `sign · mantissa · 2^exponent`, which
/// is the premise of this conversion. For a non-negative exponent the result
/// is the integer `(±m << e, 1)`. Otherwise it is `(±m, 2^(-e))`, so `q` is
/// always a power of two.
fn split_arg(y: &BigFloat) -> (BigInt, BigInt) {
    let mantissa = y.mantissa().clone();
    let sign = y.sign();
    let scale = y.exponent();
    if scale < 0 {
        // Fractional value: ±m / 2^(-e).
        let denom = BigUint::one().shl_bits((-scale) as u64);
        let numer = BigInt::from_parts(sign, mantissa);
        return (numer, BigInt::from_parts(Sign::Positive, denom));
    }
    // Integral value: ±(m << e) over 1.
    let numer = BigInt::from_parts(sign, mantissa.shl_bits(scale as u64));
    (numer, BigInt::one())
}
/// Binary-splitting terms for `exp(p/q) = Σ_{k≥0} (p/q)^k / k!`.
///
/// `term(k)` returns a 4-tuple named `(p, q, b, a)` in this file's tests.
/// Term 0 is all ones. Term `k >= 1` carries `p` over `q·k`.
/// NOTE(review): this assumes `binary_split` multiplies the successive
/// `p(j)/q(j)` ratios together — confirm against `binary_splitting`.
struct ExpSeries {
    /// Numerator of the rational argument.
    p: BigInt,
    /// Denominator of the rational argument.
    q: BigInt,
}
impl BSSeries for ExpSeries {
    fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
        match k {
            0 => (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one()),
            _ => {
                let ratio_den = &self.q * &BigInt::from(k as i64);
                (self.p.clone(), ratio_den, BigInt::one(), BigInt::one())
            }
        }
    }
}
/// Binary-splitting terms for `sin(p/q) = Σ_{k≥0} (-1)^k (p/q)^(2k+1) / (2k+1)!`.
///
/// Term 0 carries the leading factor `p/q`. Each later term `k` carries
/// `-p2` over `q2·(2k)·(2k+1)`. `p2`/`q2` are expected to hold `p*p` and
/// `q*q`, which is how `sincos_bs` constructs this struct.
/// NOTE(review): this assumes `binary_split` multiplies the successive
/// `p(j)/q(j)` ratios together — confirm against `binary_splitting`.
struct SinSeries {
    p: BigInt,
    q: BigInt,
    p2: BigInt,
    q2: BigInt,
}
impl BSSeries for SinSeries {
    fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
        if k != 0 {
            let (lo, hi) = ((2 * k) as i64, (2 * k + 1) as i64);
            let den = &self.q2 * &BigInt::from(lo) * &BigInt::from(hi);
            return (-self.p2.clone(), den, BigInt::one(), BigInt::one());
        }
        (self.p.clone(), self.q.clone(), BigInt::one(), BigInt::one())
    }
}
/// Binary-splitting terms for `cos(p/q) = Σ_{k≥0} (-1)^k (p/q)^(2k) / (2k)!`.
///
/// Term 0 is all ones. Each later term `k` carries `-p2` over
/// `q2·(2k-1)·(2k)`. `p2`/`q2` are expected to hold `p*p` and `q*q`, which is
/// how `sincos_bs` constructs this struct.
/// NOTE(review): this assumes `binary_split` multiplies the successive
/// `p(j)/q(j)` ratios together — confirm against `binary_splitting`.
struct CosSeries {
    p2: BigInt,
    q2: BigInt,
}
impl BSSeries for CosSeries {
    fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
        match k {
            0 => (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one()),
            _ => {
                let odd = BigInt::from((2 * k - 1) as i64);
                let even = BigInt::from((2 * k) as i64);
                let den = &self.q2 * &odd * &even;
                (-self.p2.clone(), den, BigInt::one(), BigInt::one())
            }
        }
    }
}
/// Turn a finished binary split into its floating-point value `T / (Q·B)`.
///
/// The numerator and the integer product `Q·B` are each converted to
/// `work_prec`-bit floats before dividing with `mode`.
///
/// # Errors
/// Propagates any error returned by `BigFloat::div_ref_with_mode`.
fn reconstruct(split: BSSplit, work_prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
    let to_float = |v: &BigInt| bigfloat_from_bigint(v, work_prec, mode);
    let numerator = to_float(&split.t);
    let denominator = to_float(&(&split.q * &split.b));
    numerator.div_ref_with_mode(&denominator, mode)
}
/// Compute `exp(y)` at `work_prec` bits by binary splitting of the Taylor
/// series `Σ y^k / k!`.
///
/// `y` is first turned into an exact rational `p/q` with `split_arg`. The
/// series is then summed over `term_count_exp(work_prec)` terms and the
/// quotient is rounded to `work_prec` bits with `mode`.
/// NOTE(review): the term count ignores `|y|`, so accuracy presumably relies on
/// the caller having reduced `y` to a small magnitude — confirm.
///
/// # Errors
/// Propagates the error from the final division in `reconstruct`.
pub(crate) fn exp_bs(y: &BigFloat, work_prec: u32, mode: RoundingMode) -> OxiNumResult<BigFloat> {
    if !y.is_zero() {
        let (p, q) = split_arg(y);
        let series = ExpSeries { p, q };
        let end = term_count_exp(work_prec);
        reconstruct(binary_split(&series, 0, end), work_prec, mode)
            .map(|sum| sum.with_precision(work_prec, mode))
    } else {
        // exp(0) = 1 exactly; no series evaluation needed.
        Ok(BigFloat::from_i64(1, work_prec, mode))
    }
}
/// Compute `(sin u, cos u)` at `work_prec` bits by binary splitting of their
/// Taylor series.
///
/// `u` is first turned into an exact rational `p/q` with `split_arg`. Both
/// series are then summed over `term_count_trig(work_prec)` terms and each
/// result is rounded to `work_prec` bits with `mode`.
/// NOTE(review): the term count ignores `|u|`, so accuracy presumably relies on
/// the caller having reduced `u` to a small magnitude — confirm.
///
/// # Errors
/// Propagates the error from either final division in `reconstruct`.
pub(crate) fn sincos_bs(
    u: &BigFloat,
    work_prec: u32,
    mode: RoundingMode,
) -> OxiNumResult<(BigFloat, BigFloat)> {
    if u.is_zero() {
        let sin_val = BigFloat::zero(work_prec);
        let cos_val = BigFloat::from_i64(1, work_prec, mode);
        return Ok((sin_val, cos_val));
    }
    let (p, q) = split_arg(u);
    let p2 = &p * &p;
    let q2 = &q * &q;
    let n_trig = term_count_trig(work_prec);
    // `p` and `q` are only needed by the sine series, so they are moved in rather
    // than cloned. The squares are shared with the cosine series, so they are
    // cloned exactly once.
    let sin_series = SinSeries {
        p,
        q,
        p2: p2.clone(),
        q2: q2.clone(),
    };
    let cos_series = CosSeries { p2, q2 };
    let sin_split = binary_split(&sin_series, 0, n_trig);
    let cos_split = binary_split(&cos_series, 0, n_trig);
    let sin_val = reconstruct(sin_split, work_prec, mode)?;
    let cos_val = reconstruct(cos_split, work_prec, mode)?;
    Ok((
        sin_val.with_precision(work_prec, mode),
        cos_val.with_precision(work_prec, mode),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    const MODE: RoundingMode = RoundingMode::HalfEven;
    const PREC: u32 = 600;
    fn bi(n: i64) -> BigInt {
        BigInt::from(n)
    }
    #[test]
    fn exp_series_term_zero() {
        let s = ExpSeries { p: bi(1), q: bi(2) };
        let (p, q, b, a) = s.term(0);
        assert_eq!(p, BigInt::one());
        assert_eq!(q, BigInt::one());
        assert_eq!(b, BigInt::one());
        assert_eq!(a, BigInt::one());
    }
    #[test]
    fn exp_series_term_three() {
        let s = ExpSeries { p: bi(1), q: bi(2) };
        let (p, q, b, a) = s.term(3);
        assert_eq!(p, bi(1));
        assert_eq!(q, bi(6));
        assert_eq!(b, BigInt::one());
        assert_eq!(a, BigInt::one());
    }
    #[test]
    fn sin_series_term_zero() {
        let s = SinSeries {
            p: bi(1),
            q: bi(1),
            p2: bi(1),
            q2: bi(1),
        };
        let (p, q, b, a) = s.term(0);
        assert_eq!(p, bi(1));
        assert_eq!(q, bi(1));
        assert_eq!(b, BigInt::one());
        assert_eq!(a, BigInt::one());
    }
    #[test]
    fn sin_series_term_two() {
        let s = SinSeries {
            p: bi(1),
            q: bi(1),
            p2: bi(1),
            q2: bi(1),
        };
        let (p, q, b, a) = s.term(2);
        assert_eq!(p, bi(-1));
        assert_eq!(q, bi(20));
        assert_eq!(b, BigInt::one());
        assert_eq!(a, BigInt::one());
    }
    #[test]
    fn cos_series_term_one() {
        let s = CosSeries {
            p2: bi(1),
            q2: bi(1),
        };
        let (p, q, b, a) = s.term(1);
        assert_eq!(p, bi(-1));
        assert_eq!(q, bi(2));
        assert_eq!(b, BigInt::one());
        assert_eq!(a, BigInt::one());
    }
    #[test]
    fn term_count_exp_small_targets() {
        // log2(8!) ~ 15.30 < 16 <= log2(9!) ~ 18.47, so k = 9 and k + 1 terms.
        assert_eq!(term_count_exp(0), 10);
        assert_eq!(term_count_exp(2), 10);
        // log2(9!) ~ 18.47 < 19 <= log2(10!) ~ 21.79, so k = 10.
        assert_eq!(term_count_exp(3), 11);
    }
    #[test]
    fn term_count_trig_small_target() {
        // log2(8!) ~ 15.30 <= 16 < log2(10!) ~ 21.79, so k = 5 and k + 2 terms.
        assert_eq!(term_count_trig(0), 7);
    }
    #[test]
    fn term_counts_are_monotone() {
        let mut prev_exp = 0;
        let mut prev_trig = 0;
        for bits in (0u32..4096).step_by(64) {
            let e = term_count_exp(bits);
            let t = term_count_trig(bits);
            assert!(e >= prev_exp, "term_count_exp({bits}) decreased");
            assert!(t >= prev_trig, "term_count_trig({bits}) decreased");
            prev_exp = e;
            prev_trig = t;
        }
    }
    #[test]
    fn exp_bs_zero_is_one() {
        let y = BigFloat::zero(PREC);
        let result = exp_bs(&y, PREC, MODE).expect("exp_bs(0)");
        let diff = (result.to_f64() - 1.0_f64).abs();
        assert!(
            diff < 1e-15,
            "exp_bs(0) = {}, expected 1.0",
            result.to_f64()
        );
    }
    #[test]
    fn exp_bs_nonzero_matches_f64() {
        // Calls exp_bs directly, covering both signs of the argument.
        for &x in &[0.5_f64, -0.25, 0.125] {
            let y = BigFloat::from_f64(x, PREC).expect("from_f64");
            let r = exp_bs(&y, PREC, MODE).expect("exp_bs");
            let diff = (r.to_f64() - x.exp()).abs();
            assert!(
                diff < 1e-12,
                "exp_bs({x}) = {}, expected {}",
                r.to_f64(),
                x.exp()
            );
        }
    }
    #[test]
    fn sincos_bs_zero() {
        let u = BigFloat::zero(PREC);
        let (sin_u, cos_u) = sincos_bs(&u, PREC, MODE).expect("sincos_bs(0)");
        assert!(
            sin_u.to_f64().abs() < 1e-15,
            "sin_bs(0) = {}",
            sin_u.to_f64()
        );
        assert!(
            (cos_u.to_f64() - 1.0).abs() < 1e-15,
            "cos_bs(0) = {}",
            cos_u.to_f64()
        );
    }
    #[test]
    fn sincos_bs_nonzero_matches_f64() {
        // Calls sincos_bs directly, covering both signs of the argument.
        for &x in &[0.5_f64, -0.75, 0.0625] {
            let u = BigFloat::from_f64(x, PREC).expect("from_f64");
            let (s, c) = sincos_bs(&u, PREC, MODE).expect("sincos_bs");
            assert!(
                (s.to_f64() - x.sin()).abs() < 1e-12,
                "sin_bs({x}) = {}, expected {}",
                s.to_f64(),
                x.sin()
            );
            assert!(
                (c.to_f64() - x.cos()).abs() < 1e-12,
                "cos_bs({x}) = {}, expected {}",
                c.to_f64(),
                x.cos()
            );
        }
    }
    #[test]
    fn exp_bs_one_matches_e_const() {
        // Goes through the public `exp` entry point rather than `exp_bs` itself.
        use crate::native::constants::e_const;
        let one = BigFloat::from_i64(1, PREC, MODE);
        let e_val = one.exp(PREC, MODE).expect("exp(1)");
        let e_const_val = e_const(PREC).expect("e_const");
        let diff = (e_val.to_f64() - e_const_val.to_f64()).abs();
        assert!(diff < 1e-9, "exp(1) vs e_const diff = {diff}");
    }
    #[test]
    fn pythagorean_identity_high_prec() {
        let x = BigFloat::from_f64(0.7, PREC).expect("0.7");
        let sin_x = x.sin(PREC, MODE).expect("sin");
        let cos_x = x.cos(PREC, MODE).expect("cos");
        let s2 = sin_x.mul_ref_with_mode(&sin_x, MODE);
        let c2 = cos_x.mul_ref_with_mode(&cos_x, MODE);
        let sum = s2.add_ref_with_mode(&c2, MODE);
        let diff = (sum.to_f64() - 1.0).abs();
        assert!(diff < 1e-9, "sin²+cos² = {}", sum.to_f64());
    }
    #[test]
    fn sin_pi_over_6_high_prec() {
        use crate::native::constants::pi;
        let pi_val = pi(PREC).expect("pi");
        let six = BigFloat::from_i64(6, PREC, MODE);
        let x = pi_val.div_ref_with_mode(&six, MODE).expect("pi/6");
        let s = x.sin(PREC, MODE).expect("sin");
        let diff = (s.to_f64() - 0.5).abs();
        assert!(diff < 1e-9, "sin(π/6) = {}", s.to_f64());
    }
}