use core::cmp::Ordering;
use core::fmt;
use core::ops::{Add, Div, Mul, Neg, Sub};
use crate::error::{CoreError, CoreResult, ErrorContext};
/// Builds a [`CoreError::ComputationError`] whose context carries `msg`.
///
/// Only reached on failure paths (division by zero, domain errors), so it is
/// marked `#[cold]` rather than force-inlined: this keeps the error
/// construction out of the hot arithmetic callers and hints branch layout.
#[cold]
fn comp_err(msg: impl Into<String>) -> CoreError {
    CoreError::ComputationError(ErrorContext::new(msg))
}
/// Knuth's TwoSum error-free transformation.
///
/// Returns `(sum, err)` where `sum = fl(a + b)` and `a + b == sum + err`
/// exactly (for finite inputs without overflow). No precondition on the
/// relative magnitudes of `a` and `b`.
#[inline]
pub fn two_sum(a: f64, b: f64) -> (f64, f64) {
    let sum = a + b;
    // Portion of `sum` attributable to `b`, and the matching portion for `a`.
    let b_virtual = sum - a;
    let a_virtual = sum - b_virtual;
    let err = (a - a_virtual) + (b - b_virtual);
    (sum, err)
}
/// Veltkamp split of `a` into `(hi, lo)` with `a == hi + lo` exactly, each
/// part holding at most 26 significant bits, as required by [`two_prod`].
///
/// Inputs whose magnitude exceeds ~2^996 are pre-scaled by 2^-28 (and the
/// halves scaled back by 2^28, both exact) so that the multiplication by the
/// splitter constant cannot overflow to infinity and poison the split with NaN.
#[inline]
fn split_f64(a: f64) -> (f64, f64) {
    const SPLITTER: f64 = 134_217_729.0; // 2^27 + 1
    const SPLIT_THRESH: f64 = 6.696_928_794_914_17e299; // ≈ 2^996
    const TWO_POW_28: f64 = 268_435_456.0;
    if a.abs() > SPLIT_THRESH {
        let scaled = a / TWO_POW_28;
        let c = SPLITTER * scaled;
        let hi = c - (c - scaled);
        let lo = scaled - hi;
        (hi * TWO_POW_28, lo * TWO_POW_28)
    } else {
        let c = SPLITTER * a;
        let hi = c - (c - a);
        (hi, a - hi)
    }
}

/// Dekker's TwoProd error-free transformation.
///
/// Returns `(p, e)` with `p = fl(a * b)` and `a * b == p + e` exactly,
/// provided the product itself neither overflows nor underflows.
#[inline]
pub fn two_prod(a: f64, b: f64) -> (f64, f64) {
    let p = a * b;
    let (a_hi, a_lo) = split_f64(a);
    let (b_hi, b_lo) = split_f64(b);
    // Each partial product of 26-bit halves is exact; the grouping follows
    // Dekker's proof so every subtraction is exact as well.
    let e = ((a_hi * b_hi - p) + a_hi * b_lo + a_lo * b_hi) + a_lo * b_lo;
    (p, e)
}
/// Error-free difference: returns `(d, err)` with `d = fl(a - b)` and
/// `a - b == d + err` exactly, by running [`two_sum`] on `a` and `-b`.
#[inline]
pub fn two_diff(a: f64, b: f64) -> (f64, f64) {
    let neg_b = -b;
    two_sum(a, neg_b)
}
/// Double-double number: the unevaluated sum `hi + lo` of two `f64`s,
/// giving roughly 106 bits of significand.
///
/// Values produced by this module's constructors and arithmetic are passed
/// through `renorm` (TwoSum), so `hi` is the rounded value and `lo` the
/// rounding error. The fields are public, so that is not enforced for
/// values built by hand.
#[derive(Clone, Copy, Debug, Default)]
pub struct DD {
    /// Leading (most significant) component.
    pub hi: f64,
    /// Trailing error component.
    pub lo: f64,
}
impl DD {
    /// Wraps a single `f64` exactly, with a zero low part.
    #[inline]
    #[must_use]
    pub const fn from_f64(x: f64) -> Self {
        Self { hi: x, lo: 0.0 }
    }

    /// Converts an `i64` exactly.
    ///
    /// `x as f64` rounds to 53 bits; the rounding residual `x - hi` is tiny
    /// (well under 2^53) and is therefore exactly representable in `lo`.
    #[inline]
    #[must_use]
    pub fn from_i64(x: i64) -> Self {
        let hi = x as f64;
        // The residual is computed in i128 because `hi` can be exactly 2^63
        // (e.g. for `i64::MAX`), which `hi as i64` would saturate to
        // `i64::MAX`, silently dropping the residual.
        let lo = (i128::from(x) - hi as i128) as f64;
        Self::renorm(hi, lo)
    }

    /// Builds a DD from an arbitrary `(hi, lo)` pair, renormalizing it.
    #[inline]
    #[must_use]
    pub fn from_parts(hi: f64, lo: f64) -> Self {
        Self::renorm(hi, lo)
    }

    /// Exact zero.
    pub const ZERO: DD = DD { hi: 0.0, lo: 0.0 };
    /// Exact one.
    pub const ONE: DD = DD { hi: 1.0, lo: 0.0 };

    /// π to double-double precision.
    #[must_use]
    pub const fn pi() -> DD {
        DD { hi: 3.141_592_653_589_793_1_f64, lo: 1.224_646_799_147_353_2e-16 }
    }

    /// Euler's number e to double-double precision.
    #[must_use]
    pub const fn e() -> DD {
        DD { hi: 2.718_281_828_459_045_f64, lo: 1.445_646_891_729_250_2e-16 }
    }

    /// ln 2 to double-double precision.
    #[must_use]
    pub const fn ln2() -> DD {
        DD { hi: 0.693_147_180_559_945_3_f64, lo: 2.319_046_813_846_299_6e-17 }
    }

    /// √2 to double-double precision.
    #[must_use]
    pub const fn sqrt2() -> DD {
        DD { hi: 1.414_213_562_373_095_f64, lo: -9.667_293_313_452_914e-17 }
    }

    /// True when both components compare equal to zero (either sign).
    #[inline]
    #[must_use]
    pub fn is_zero(&self) -> bool {
        self.hi == 0.0 && self.lo == 0.0
    }

    /// True when the leading component is finite (only `hi` is inspected).
    #[inline]
    #[must_use]
    pub fn is_finite(&self) -> bool {
        self.hi.is_finite()
    }

    /// True when the leading component is NaN (only `hi` is inspected).
    #[inline]
    #[must_use]
    pub fn is_nan(&self) -> bool {
        self.hi.is_nan()
    }

    /// Returns the leading component, discarding `lo`.
    #[inline]
    #[must_use]
    pub fn to_f64(self) -> f64 {
        self.hi
    }

    /// Returns `hi + lo` rounded to `f64` (equals `hi` for renormalized values).
    #[inline]
    #[must_use]
    pub fn to_f64_round(self) -> f64 {
        self.hi + self.lo
    }

    /// Negates both components (exact).
    #[inline]
    #[must_use]
    pub fn negate(self) -> DD {
        DD { hi: -self.hi, lo: -self.lo }
    }

    /// Accurate double-double addition: both component pairs are summed with
    /// TwoSum and their errors propagated before the final renormalization.
    #[must_use]
    pub fn dd_add(self, rhs: DD) -> DD {
        let (s1, s2) = two_sum(self.hi, rhs.hi);
        let (t1, t2) = two_sum(self.lo, rhs.lo);
        let c = s2 + t1;
        let (v_hi, v_lo) = two_sum(s1, c);
        let w = t2 + v_lo;
        DD::renorm(v_hi, w)
    }

    /// Double-double subtraction, `self + (-rhs)`.
    #[must_use]
    pub fn dd_sub(self, rhs: DD) -> DD {
        self.dd_add(rhs.negate())
    }

    /// Double-double multiplication.
    ///
    /// The exact `hi·hi` product comes from TwoProd; the cross terms are added
    /// in plain `f64` and the negligible `lo·lo` term is dropped.
    #[must_use]
    pub fn dd_mul(self, rhs: DD) -> DD {
        let (p1, p2) = two_prod(self.hi, rhs.hi);
        let p2 = p2 + self.hi * rhs.lo + self.lo * rhs.hi;
        DD::renorm(p1, p2)
    }

    /// Double-double division via a first `f64` quotient plus one correction.
    ///
    /// Non-finite divisors or quotients follow `f64` semantics (`x / ∞ = 0`,
    /// `∞ / x = ∞`, NaN propagates).
    ///
    /// # Errors
    /// Returns a computation error when `rhs` is exactly zero.
    pub fn dd_div(self, rhs: DD) -> CoreResult<DD> {
        if rhs.is_zero() {
            return Err(comp_err("DD::div — division by zero"));
        }
        let q1 = self.hi / rhs.hi;
        // The correction below would form inf - inf = NaN for non-finite
        // operands/quotients, so return the IEEE f64 quotient directly.
        if !q1.is_finite() || !rhs.hi.is_finite() {
            return Ok(DD::from_f64(q1));
        }
        let r = self.dd_sub(DD::from_f64(q1).dd_mul(rhs));
        let q2 = r.hi / rhs.hi;
        Ok(DD::renorm(q1, q2))
    }

    /// Absolute value, decided by the sign of the leading component.
    #[inline]
    #[must_use]
    pub fn abs(self) -> DD {
        if self.hi < 0.0 { self.negate() } else { self }
    }

    /// Double-double square root.
    ///
    /// Starts from the `f64` root and applies two Newton/Heron steps
    /// `x ← (x + self/x) / 2`; one step roughly doubles the 53 correct bits,
    /// the second absorbs rounding from the first.
    ///
    /// # Errors
    /// Returns a computation error when `self.hi < 0`, and propagates
    /// division errors.
    pub fn sqrt(self) -> CoreResult<DD> {
        if self.hi < 0.0 {
            return Err(comp_err("DD::sqrt of negative number"));
        }
        if self.is_zero() {
            return Ok(DD::ZERO);
        }
        // Newton's step would evaluate inf / inf = NaN.
        if self.hi == f64::INFINITY {
            return Ok(DD::from_f64(f64::INFINITY));
        }
        let x0 = DD::from_f64(self.hi.sqrt());
        let half = DD::from_f64(0.5);
        let x1 = x0.dd_add(self.dd_div(x0)?).dd_mul(half);
        let x2 = x1.dd_add(self.dd_div(x1)?).dd_mul(half);
        Ok(x2)
    }

    /// Squares `self`; cheaper than `dd_mul(self)` since the cross term is doubled.
    #[must_use]
    pub fn square(self) -> DD {
        let (p1, p2) = two_prod(self.hi, self.hi);
        let p2 = p2 + 2.0 * self.hi * self.lo;
        DD::renorm(p1, p2)
    }

    /// Renormalizes `(hi, lo)` with TwoSum so `hi` is the rounded sum and
    /// `lo` its exact rounding error.
    #[inline]
    #[must_use]
    pub fn renorm(hi: f64, lo: f64) -> DD {
        let (s, e) = two_sum(hi, lo);
        DD { hi: s, lo: e }
    }

    /// Total-ish ordering on `(hi, lo)`: compares `hi`, then `lo` on a tie.
    /// Incomparable (NaN) components are treated as `Equal`.
    #[must_use]
    pub fn compare(&self, rhs: &DD) -> Ordering {
        match self.hi.partial_cmp(&rhs.hi) {
            Some(Ordering::Equal) => self.lo.partial_cmp(&rhs.lo).unwrap_or(Ordering::Equal),
            Some(ord) => ord,
            None => Ordering::Equal,
        }
    }
}
impl PartialEq for DD {
    /// Component-wise IEEE equality: `hi == hi` and `lo == lo`
    /// (so NaN is never equal and `-0.0 == 0.0`).
    fn eq(&self, other: &Self) -> bool {
        [self.hi, self.lo] == [other.hi, other.lo]
    }
}
impl PartialOrd for DD {
    /// Lexicographic comparison on `(hi, lo)`.
    ///
    /// Returns `None` when a NaN makes the values unordered. This agrees with
    /// the component-wise `PartialEq` (`a == b` iff `Some(Equal)`), so NaN is
    /// neither `<=` nor `>=` any value.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match self.hi.partial_cmp(&other.hi)? {
            Ordering::Equal => self.lo.partial_cmp(&other.lo),
            ord => Some(ord),
        }
    }
}
impl Neg for DD {
type Output = DD;
fn neg(self) -> DD {
self.negate()
}
}
impl Add for DD {
type Output = DD;
fn add(self, rhs: DD) -> DD {
self.dd_add(rhs)
}
}
impl Sub for DD {
type Output = DD;
fn sub(self, rhs: DD) -> DD {
self.dd_sub(rhs)
}
}
impl Mul for DD {
type Output = DD;
fn mul(self, rhs: DD) -> DD {
self.dd_mul(rhs)
}
}
impl Div for DD {
type Output = DD;
fn div(self, rhs: DD) -> DD {
self.dd_div(rhs).unwrap_or(DD { hi: f64::NAN, lo: f64::NAN })
}
}
impl fmt::Display for DD {
    /// Formats `hi + lo` (rounded to a single `f64`) in scientific notation
    /// with 30 fractional digits.
    ///
    /// NOTE(review): because the value is collapsed to `f64` first, only about
    /// 17 significant digits reflect the number; the remaining digits are the
    /// exact decimal expansion of that `f64`, not double-double precision.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let collapsed = self.hi + self.lo;
        write!(f, "{collapsed:.30e}")
    }
}
impl From<f64> for DD {
fn from(x: f64) -> DD {
DD::from_f64(x)
}
}
impl From<i64> for DD {
fn from(x: i64) -> DD {
DD::from_i64(x)
}
}
impl From<i32> for DD {
fn from(x: i32) -> DD {
DD::from_f64(x as f64)
}
}
/// Double-double exponential `e^x`.
///
/// Reduces `x = k·ln2 + r` with `|r| <= ln2/2` (plus rounding), sums the Taylor
/// series of `e^r` until terms fall below double-double precision (2^-104
/// relative), then scales by `2^k`.
///
/// Results that overflow `f64` are returned as `+∞`, and results that underflow
/// are returned as zero (matching `f64::exp`).
///
/// # Errors
/// Returns a computation error if `x` is NaN or infinite.
pub fn dd_exp(x: DD) -> CoreResult<DD> {
    // Relative truncation tolerance: 2^-104, the double-double precision scale.
    // (Using f64::EPSILON here would stop at ~1e-17 and waste the `lo` part.)
    const TOL: f64 = f64::EPSILON * f64::EPSILON;
    // Above ln(f64::MAX) ≈ 709.7827 the result overflows; the small margin is
    // handled by the scaling step overflowing naturally.
    const OVERFLOW_X: f64 = 709.79;
    // Below ln(smallest subnormal) ≈ -745.13 the result rounds to zero.
    const UNDERFLOW_X: f64 = -746.0;

    if !x.is_finite() {
        return Err(comp_err("dd_exp: non-finite input"));
    }
    if x.hi > OVERFLOW_X {
        return Ok(DD::from_f64(f64::INFINITY));
    }
    if x.hi < UNDERFLOW_X {
        return Ok(DD::ZERO);
    }

    let ln2 = DD::ln2();
    // Thanks to the range checks, k lies in about [-1076, 1024], so the cast
    // is exact and the exponent arithmetic below cannot overflow.
    let k_f = (x.hi / ln2.hi).round();
    let k = k_f as i64;
    let r = x.dd_sub(DD::from_f64(k_f).dd_mul(ln2));

    // e^r = Σ r^n / n!; with |r| <= ~0.35 about 25 terms reach TOL.
    let mut sum = DD::ONE;
    let mut term = DD::ONE;
    for n in 1..=30_i64 {
        term = term.dd_mul(r).dd_div(DD::from_i64(n))?;
        sum = sum.dd_add(term);
        if term.hi.abs() <= sum.hi.abs() * TOL {
            break;
        }
    }

    let hi = scale_by_pow2(sum.hi, k);
    // Overflowed to ±∞: renormalizing ∞ with a finite lo would produce NaN.
    if !hi.is_finite() {
        return Ok(DD::from_f64(hi));
    }
    Ok(DD::renorm(hi, scale_by_pow2(sum.lo, k)))
}

/// Returns `v * 2^k` for `k` in `[-2044, 2046]`.
///
/// When `2^k` is not itself a normal `f64`, the scale is split so the first
/// multiplication is exact for `v` of moderate magnitude and only the final
/// one rounds (to ±∞ on overflow, or into the subnormal range on underflow).
fn scale_by_pow2(v: f64, k: i64) -> f64 {
    // Exact 2^e for e in [-1022, 1023], built from the IEEE-754 exponent field.
    let pow2 = |e: i64| f64::from_bits(((1023 + e) as u64) << 52);
    debug_assert!((-2044..=2046).contains(&k));
    if k > 1023 {
        v * pow2(k - 1023) * pow2(1023)
    } else if k < -1022 {
        v * pow2(k + 1022) * pow2(-1022)
    } else {
        v * pow2(k)
    }
}
/// Double-double natural logarithm.
///
/// Starts from the `f64` logarithm `a0 = ln(x.hi)` and applies one Newton
/// step to `f(a) = e^a - x`, i.e. `a1 = a0 + (x - e^a0) / e^a0`. The final
/// accuracy is bounded by that of `dd_exp`.
///
/// # Errors
/// Returns a computation error when `x.hi <= 0` or `x` is not finite, and
/// propagates errors from `dd_exp` and the division.
pub fn dd_ln(x: DD) -> CoreResult<DD> {
    if x.hi <= 0.0 {
        return Err(comp_err("dd_ln: argument must be positive"));
    }
    if !x.is_finite() {
        return Err(comp_err("dd_ln: non-finite input"));
    }
    let guess = DD::from_f64(x.hi.ln());
    let exp_guess = dd_exp(guess)?;
    let residual = x.dd_sub(exp_guess);
    let step = residual.dd_div(exp_guess)?;
    Ok(guess.dd_add(step))
}
/// Simultaneous double-double sine and cosine; returns `(sin x, cos x)`.
///
/// Reduces `x = k·(π/2) + r` with `|r| <= ~π/4`, sums both Taylor series until
/// each term falls below double-double precision (2^-104 relative), and then
/// uses the quadrant `k mod 4` to pick and negate the results.
///
/// Accuracy degrades for large `|x|`: the reduction uses a double-double π,
/// so an absolute error of roughly `|x| · 2^-104` enters `r`. Huge arguments
/// would need a Payne–Hanek style reduction.
///
/// # Errors
/// Returns a computation error if `x` is NaN or infinite.
pub fn dd_sincos(x: DD) -> CoreResult<(DD, DD)> {
    // Double-double relative precision; f64::EPSILON would stop at ~1e-17.
    const TOL: f64 = f64::EPSILON * f64::EPSILON;

    if !x.is_finite() {
        return Err(comp_err("dd_sincos: non-finite input"));
    }
    let pi = DD::pi();
    let two_over_pi = DD::from_f64(2.0).dd_div(pi)?;
    let k_f = x.dd_mul(two_over_pi).hi.round();
    let half_pi = pi.dd_mul(DD::from_f64(0.5));
    // Reduce with the f64 multiple itself: casting to i64 would saturate for
    // |x| > ~1.4e19 and yield a wrong reduction and quadrant.
    let r = x.dd_sub(DD::from_f64(k_f).dd_mul(half_pi));
    let neg_r2 = r.square().negate();

    let mut sin_val = r;
    let mut term_sin = r;
    let mut cos_val = DD::ONE;
    let mut term_cos = DD::ONE;
    for i in 1..=20_i64 {
        // sin term: r^(2i+1)/(2i+1)!, cos term: r^(2i)/(2i)!, each derived
        // from its predecessor with one combined (exact-integer) divisor.
        term_sin = term_sin
            .dd_mul(neg_r2)
            .dd_div(DD::from_i64((2 * i) * (2 * i + 1)))?;
        term_cos = term_cos
            .dd_mul(neg_r2)
            .dd_div(DD::from_i64((2 * i - 1) * (2 * i)))?;
        sin_val = sin_val.dd_add(term_sin);
        cos_val = cos_val.dd_add(term_cos);
        // Both series must converge: relative to its value, the cos series
        // lags the sin series by a factor of about (2i+1).
        let sin_done = term_sin.hi.abs() <= sin_val.hi.abs() * TOL;
        let cos_done = term_cos.hi.abs() <= cos_val.hi.abs() * TOL;
        if sin_done && cos_done {
            break;
        }
    }

    // k_f is an integer-valued f64, so rem_euclid yields exactly 0, 1, 2 or 3.
    let (s, c) = match k_f.rem_euclid(4.0) as u8 {
        0 => (sin_val, cos_val),
        1 => (cos_val, sin_val.negate()),
        2 => (sin_val.negate(), cos_val.negate()),
        _ => (cos_val.negate(), sin_val),
    };
    Ok((s, c))
}
pub mod qd_functions;
pub mod quad_double;
pub mod solver_integration;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_two_sum_exact() {
        // 1 + 2^-53 is a tie that rounds to even (1.0); TwoSum must recover the
        // discarded 2^-53 exactly in its error term.
        let a = 1.0_f64;
        let b = f64::EPSILON / 2.0;
        let (s, e) = two_sum(a, b);
        assert_eq!(s, 1.0, "TwoSum sum should round to even");
        assert_eq!(e, b, "TwoSum error term should be exact");
    }

    #[test]
    fn test_two_prod_exact() {
        // (1 + 2^-52)^2 = 1 + 2^-51 + 2^-104: the product rounds to 1 + 2^-51
        // and the 2^-104 tail must land exactly in the error term.
        let a = 1.0_f64 + f64::EPSILON;
        let b = 1.0_f64 + f64::EPSILON;
        let (p, e) = two_prod(a, b);
        assert_eq!(p, 1.0 + 2.0 * f64::EPSILON, "TwoProd product");
        assert_eq!(e, f64::EPSILON * f64::EPSILON, "TwoProd error term");
    }

    #[test]
    fn test_dd_add_basic() {
        let a = DD::from_f64(1.0);
        let b = DD::from_f64(2.0);
        let c = a.dd_add(b);
        assert_eq!(c.hi, 3.0, "1 + 2 should be 3");
        assert_eq!(c.lo, 0.0);
    }

    #[test]
    fn test_dd_sub_basic() {
        let a = DD::from_f64(5.0);
        let b = DD::from_f64(3.0);
        let c = a.dd_sub(b);
        assert_eq!(c.hi, 2.0);
    }

    #[test]
    fn test_dd_mul_basic() {
        let a = DD::from_f64(3.0);
        let b = DD::from_f64(4.0);
        let c = a.dd_mul(b);
        assert_eq!(c.hi, 12.0);
        assert_eq!(c.lo, 0.0);
    }

    #[test]
    fn test_dd_div_basic() {
        let a = DD::from_f64(10.0);
        let b = DD::from_f64(4.0);
        let c = a.dd_div(b).expect("should succeed");
        let diff = (c.hi - 2.5).abs();
        assert!(diff < f64::EPSILON * 4.0, "10/4 should be 2.5, got {}", c.hi);
    }

    #[test]
    fn test_dd_div_zero() {
        let a = DD::from_f64(1.0);
        let b = DD::ZERO;
        assert!(a.dd_div(b).is_err());
    }

    #[test]
    fn test_dd_sqrt() {
        let two = DD::from_f64(2.0);
        let s = two.sqrt().expect("should succeed");
        let expected = std::f64::consts::SQRT_2;
        let diff = (s.hi - expected).abs();
        assert!(diff < f64::EPSILON * 4.0, "sqrt(2) error: {diff}");
        // NOTE: `s.hi + s.lo` rounds back to `s.hi` for a renormalized DD, so
        // this effectively checks that `s.hi` is the correctly rounded f64.
        let reconst = s.hi + s.lo;
        let better = (reconst - expected).abs();
        assert!(better < 1e-31 || better < diff, "DD sqrt should be more precise than f64");
    }

    #[test]
    fn test_dd_sqrt_negative() {
        let neg = DD::from_f64(-1.0);
        assert!(neg.sqrt().is_err());
    }

    #[test]
    fn test_dd_pi_accuracy() {
        let pi = DD::pi();
        let diff = (pi.hi - std::f64::consts::PI).abs();
        assert!(diff < f64::EPSILON * 2.0, "DD::pi hi part error: {diff}");
        assert!(pi.lo.abs() > 0.0, "DD::pi lo part should be non-zero");
    }

    #[test]
    fn test_dd_e_accuracy() {
        let e = DD::e();
        let diff = (e.hi - std::f64::consts::E).abs();
        assert!(diff < f64::EPSILON * 2.0, "DD::e hi part error: {diff}");
    }

    #[test]
    fn test_dd_ln2_accuracy() {
        let ln2 = DD::ln2();
        let diff = (ln2.hi - std::f64::consts::LN_2).abs();
        assert!(diff < f64::EPSILON * 2.0, "DD::ln2 hi part error: {diff}");
    }

    #[test]
    fn test_dd_sqrt2_accuracy() {
        let sqrt2 = DD::sqrt2();
        let diff = (sqrt2.hi - std::f64::consts::SQRT_2).abs();
        assert!(diff < f64::EPSILON * 2.0, "DD::sqrt2 hi part error: {diff}");
    }

    #[test]
    fn test_dd_exp() {
        // NOTE: `hi + lo` collapses to f64, so this checks f64-level agreement
        // with `E`, not double-double accuracy.
        let one = DD::ONE;
        let e_val = dd_exp(one).expect("should succeed");
        let expected = std::f64::consts::E;
        let diff = (e_val.hi + e_val.lo - expected).abs();
        assert!(diff < 1e-30, "dd_exp(1) - e = {diff}");
    }

    #[test]
    fn test_dd_ln() {
        let e_val = DD::e();
        let ln_e = dd_ln(e_val).expect("should succeed");
        let diff = (ln_e.hi + ln_e.lo - 1.0).abs();
        assert!(diff < 1e-28, "ln(e) - 1 = {diff}");
    }

    #[test]
    fn test_dd_sincos() {
        let x = DD::from_f64(1.0);
        let (s, c) = dd_sincos(x).expect("should succeed");
        let expected_sin = 1.0_f64.sin();
        let expected_cos = 1.0_f64.cos();
        let diff_s = (s.hi - expected_sin).abs();
        let diff_c = (c.hi - expected_cos).abs();
        assert!(diff_s < 1e-15, "sin(1) diff: {diff_s}");
        assert!(diff_c < 1e-15, "cos(1) diff: {diff_c}");
    }

    #[test]
    fn test_operator_overloads() {
        let a = DD::from_f64(3.0);
        let b = DD::from_f64(4.0);
        let sum = a + b;
        let diff = a - b;
        let prod = a * b;
        let quot = a / b;
        assert_eq!(sum.hi, 7.0);
        assert_eq!(diff.hi, -1.0);
        assert_eq!(prod.hi, 12.0);
        assert!((quot.hi - 0.75).abs() < f64::EPSILON * 4.0);
    }

    #[test]
    fn test_partial_ord() {
        let a = DD::from_f64(1.0);
        let b = DD::from_f64(2.0);
        assert!(a < b);
        assert!(b > a);
        assert!(a <= a);
    }

    #[test]
    fn test_from_i64() {
        let x = DD::from_i64(1_000_000_000_000i64);
        assert_eq!(x.hi, 1_000_000_000_000.0_f64);
    }

    #[test]
    fn test_square() {
        let x = DD::from_f64(3.0);
        let sq = x.square();
        assert_eq!(sq.hi, 9.0);
        assert_eq!(sq.lo, 0.0);
    }

    #[test]
    fn test_display() {
        let x = DD::from_f64(1.5);
        let s = format!("{x}");
        assert!(s.contains("1.5"), "Display: {s}");
    }
}