use std::sync::Arc;
use num_bigint::BigInt;
use num_traits::Signed;
use parking_lot::RwLock;
use crate::binary::{Binary, Bounds, ReciprocalRounding, XBinary, reciprocal_rounded_abs_extended};
use crate::error::ComputableError;
use crate::node::{Node, NodeOp};
use crate::sane;
/// Lower bound, in bits, on the precision handed to `seed_reciprocal` when the
/// Newton state is first created. `InvOp::refine_step` uses the requested
/// precision instead when it is larger and does not exceed
/// `crate::MAX_COMPUTATION_BITS`.
const MIN_SEED_PRECISION_BITS: usize = 64;
/// Two-sided dyadic enclosure of the reciprocal of one positive denominator.
///
/// Intended invariant: `lower <= 1/d <= upper`, where `d` is the denominator
/// this enclosure is paired with in `NewtonState`. It is seeded from
/// floor/ceil-rounded reciprocals (`seed_reciprocal`) and then tightened by
/// `newton_step`.
struct ReciprocalApprox {
// Lower bound on the reciprocal; also serves as the Newton iterate.
lower: Binary,
// Upper bound on the reciprocal.
upper: Binary,
// Mantissa bit budget: `newton_step` doubles it and truncates both bounds to the
// doubled value. NOTE(review): it is not reduced when `update_denominators`
// replaces a bound with a looser one, so it can overstate the accuracy achieved.
precision: usize,
}
/// Newton-iteration state for `InvOp`, created once the input bounds are finite
/// and exclude zero.
///
/// The input interval is stored by magnitude (`abs_lower <= abs_upper`, both
/// positive). `negate_result` records whether the original interval was
/// negative, so `compute_bounds` can restore the sign.
pub(crate) struct NewtonState {
// Enclosure of `1 / abs_upper`; its `lower` gives the output's smallest magnitude.
lo: ReciprocalApprox,
// Enclosure of `1 / abs_lower`; its `upper` gives the output's largest magnitude.
hi: ReciprocalApprox,
// Smaller-magnitude endpoint of the input interval.
abs_lower: Binary,
// Larger-magnitude endpoint of the input interval.
abs_upper: Binary,
// True when the input interval lies strictly below zero.
negate_result: bool,
}
/// Graph node computing the reciprocal `1 / inner`.
pub struct InvOp {
// Operand whose reciprocal is taken.
pub inner: Arc<Node>,
// `None` until `refine_step` observes finite input bounds that exclude zero;
// until then `compute_bounds` reports only sign information.
pub newton_state: RwLock<Option<NewtonState>>,
}
impl NodeOp for InvOp {
    /// Returns an enclosure of `1 / inner`.
    ///
    /// Without Newton state, only the sign of the input is used:
    /// - an input interval with `lower <= 0 <= upper` yields `(-inf, +inf)`;
    /// - a strictly negative input yields `(-inf, 0)`;
    /// - anything else yields `(0, +inf)`.
    ///
    /// With Newton state, the result is assembled from the two reciprocal
    /// enclosures. When the input is negative, they are negated and swapped.
    fn compute_bounds(&self) -> Result<Bounds, ComputableError> {
        let guard = self.newton_state.read();
        let Some(state) = &*guard else {
            let inner_bounds = self.inner.get_bounds()?;
            let zero = XBinary::zero();
            let small = inner_bounds.small();
            let large = inner_bounds.large();
            let sign_only = if small <= &zero && large >= zero {
                Bounds::new(XBinary::NegInf, XBinary::PosInf)
            } else if large < zero {
                Bounds::new(XBinary::NegInf, XBinary::zero())
            } else {
                Bounds::new(XBinary::zero(), XBinary::PosInf)
            };
            return Ok(sign_only);
        };
        // For x in [-abs_upper, -abs_lower], 1/x lies in [-1/abs_lower, -1/abs_upper].
        let (smallest, largest) = if state.negate_result {
            (state.hi.upper.neg(), state.lo.lower.neg())
        } else {
            (state.lo.lower.clone(), state.hi.upper.clone())
        };
        Ok(Bounds::new(
            XBinary::Finite(smallest),
            XBinary::Finite(largest),
        ))
    }

    /// Performs one refinement step.
    ///
    /// Before the Newton state exists, this tries to create it from the current
    /// input bounds. The seed precision is `precision_bits` (at least
    /// `MIN_SEED_PRECISION_BITS`); when `precision_bits` exceeds
    /// `crate::MAX_COMPUTATION_BITS`, it falls back to `MIN_SEED_PRECISION_BITS`.
    /// Creation is skipped if the input still contains zero or is unbounded.
    ///
    /// Once the state exists, each call folds in the latest input bounds and
    /// applies one Newton step to both enclosures. Returns `Ok(true)` whenever
    /// no error occurs.
    fn refine_step(&self, precision_bits: usize) -> Result<bool, ComputableError> {
        // Query the child before taking our own write lock.
        let input_bounds = self.inner.get_bounds()?;
        let mut guard = self.newton_state.write();
        if let Some(state) = &mut *guard {
            update_denominators(state, &input_bounds)?;
            newton_step(&mut state.lo, &state.abs_upper);
            newton_step(&mut state.hi, &state.abs_lower);
        } else {
            let seed_precision = if precision_bits > crate::MAX_COMPUTATION_BITS {
                MIN_SEED_PRECISION_BITS
            } else {
                precision_bits.max(MIN_SEED_PRECISION_BITS)
            };
            *guard = try_initialize(&input_bounds, seed_precision)?;
        }
        Ok(true)
    }

    /// The single operand of the reciprocal.
    fn children(&self) -> Vec<Arc<Node>> {
        vec![Arc::clone(&self.inner)]
    }

    /// `InvOp` tightens its own bounds through `refine_step`.
    fn is_refiner(&self) -> bool {
        true
    }
}
/// Builds the Newton state from the current input bounds, if possible.
///
/// Returns `Ok(None)` when the interval contains zero or has an infinite
/// endpoint. Otherwise the interval is converted to magnitudes (a negative
/// interval is negated and its endpoints swapped). Each magnitude endpoint then
/// gets a seeded reciprocal enclosure at `seed_precision` bits.
///
/// # Errors
/// Propagates any error from `seed_reciprocal`.
fn try_initialize(
    input_bounds: &Bounds,
    seed_precision: usize,
) -> Result<Option<NewtonState>, ComputableError> {
    let zero = XBinary::zero();
    let small = input_bounds.small();
    let large = input_bounds.large();
    let touches_zero = small <= &zero && large >= zero;
    if touches_zero {
        return Ok(None);
    }
    let (XBinary::Finite(small_ref), XBinary::Finite(large_ref)) = (small, &large) else {
        return Ok(None);
    };
    let (small_value, large_value) = (small_ref.clone(), large_ref.clone());
    // With zero excluded, a negative upper endpoint means the whole interval is negative.
    let negate_result = large_value.mantissa().is_negative();
    let (abs_lower, abs_upper) = if negate_result {
        (large_value.neg(), small_value.neg())
    } else {
        (small_value, large_value)
    };
    // 1/abs_upper is the smaller reciprocal, 1/abs_lower the larger one.
    let lo = seed_reciprocal(&abs_upper, seed_precision)?;
    let hi = seed_reciprocal(&abs_lower, seed_precision)?;
    Ok(Some(NewtonState {
        lo,
        hi,
        abs_lower,
        abs_upper,
        negate_result,
    }))
}
fn seed_reciprocal(
denom: &Binary,
seed_precision: usize,
) -> Result<ReciprocalApprox, ComputableError> {
let precision = BigInt::from(seed_precision);
let xb_denom = XBinary::Finite(denom.clone());
let lower_xb =
reciprocal_rounded_abs_extended(&xb_denom, &precision, ReciprocalRounding::Floor)?;
let upper_xb =
reciprocal_rounded_abs_extended(&xb_denom, &precision, ReciprocalRounding::Ceil)?;
let lower = match lower_xb {
XBinary::Finite(b) => b,
XBinary::NegInf | XBinary::PosInf => Binary::zero(),
};
let upper = match upper_xb {
XBinary::Finite(b) => b,
XBinary::NegInf | XBinary::PosInf => Binary::zero(),
};
Ok(ReciprocalApprox {
lower,
upper,
precision: seed_precision,
})
}
/// Re-synchronises the Newton state with freshly observed input bounds.
///
/// If either endpoint is infinite, the state is left untouched. Otherwise the
/// bounds are converted to magnitudes using the sign fixed at initialisation.
/// For each magnitude endpoint that changed, the enclosure bound that may have
/// become invalid is checked:
/// - `lo.upper * abs_upper >= 1`; if violated, `lo.upper` is replaced by `hi.upper`;
/// - `hi.lower * abs_lower <= 1`; if violated, `hi.lower` is replaced by `lo.lower`.
///
/// NOTE(review): these repairs are only valid when the input interval narrows
/// (`abs_lower` grows, `abs_upper` shrinks). If the inner bounds can widen,
/// `lo.lower` and `hi.upper` would need re-checking too. Confirm that the inner
/// node's bounds are monotone.
fn update_denominators(
    state: &mut NewtonState,
    input_bounds: &Bounds,
) -> Result<(), ComputableError> {
    let small = input_bounds.small();
    let large = input_bounds.large();
    let (XBinary::Finite(small_ref), XBinary::Finite(large_ref)) = (small, &large) else {
        return Ok(());
    };
    let (small_value, large_value) = (small_ref.clone(), large_ref.clone());
    let (new_abs_lower, new_abs_upper) = if state.negate_result {
        (large_value.neg(), small_value.neg())
    } else {
        (small_value, large_value)
    };
    let unit = Binary::new(BigInt::from(1_i32), BigInt::from(0_i32));

    if new_abs_upper != state.abs_upper {
        // A smaller abs_upper raises 1/abs_upper, which may now exceed lo.upper.
        // hi.upper bounds 1/abs_lower >= 1/abs_upper, so it is a safe fallback.
        if state.lo.upper.mul(&new_abs_upper) < unit {
            state.lo.upper = state.hi.upper.clone();
        }
        state.abs_upper = new_abs_upper;
    }

    if new_abs_lower != state.abs_lower {
        // A larger abs_lower lowers 1/abs_lower, which may now fall below hi.lower.
        // lo.lower bounds 1/abs_upper <= 1/abs_lower, so it is a safe fallback.
        if state.hi.lower.mul(&new_abs_lower) > unit {
            state.hi.lower = state.lo.lower.clone();
        }
        state.abs_lower = new_abs_lower;
    }

    Ok(())
}
/// Advances one reciprocal enclosure by a single Newton step.
///
/// Let `x = approx.lower` and `d = denom`, and assume `d > 0` and
/// `lower <= 1/d <= upper` on entry.
///
/// The update `x' = x * (2 - d*x)` satisfies `1/d - x' = d * (1/d - x)^2`, so
/// `x'` never exceeds `1/d`. Because `1/d - x <= upper - x`, the value
/// `x' + d * (upper - x)^2` is an upper bound on `1/d`.
///
/// Both candidates are rounded outward to twice the current precision. Each is
/// adopted only if it tightens the enclosure; the precision is doubled
/// unconditionally.
fn newton_step(approx: &mut ReciprocalApprox, denom: &Binary) {
    let doubled_precision = approx.precision.saturating_mul(2_usize);
    let two = Binary::new(BigInt::from(1_i32), BigInt::from(1_i32));
    let iterate = &approx.lower;
    let width = approx.upper.sub(iterate);

    let correction = two.sub(&denom.mul(iterate));
    let next_lower = iterate.mul(&correction);
    let error_bound = denom.mul(&width.mul(&width));
    let next_upper = next_lower.add(&error_bound);

    // Round outward so the enclosure stays valid after truncation.
    let rounded_lower = truncate_floor(&next_lower, doubled_precision);
    let rounded_upper = truncate_ceil(&next_upper, doubled_precision);

    approx.precision = doubled_precision;
    if rounded_lower > approx.lower {
        approx.lower = rounded_lower;
    }
    if rounded_upper < approx.upper {
        approx.upper = rounded_upper;
    }
}
/// Direction in which [`truncate_directed`] rounds away discarded mantissa bits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TruncateDirection {
    /// Round toward negative infinity (result `<= x`).
    Floor,
    /// Round toward positive infinity (result `>= x`).
    Ceil,
}

/// Rounds `x` toward negative infinity, keeping only the top `precision_bits`
/// bits of its mantissa. The result is always `<= x`. Values whose mantissa
/// already fits are returned unchanged.
fn truncate_floor(x: &Binary, precision_bits: usize) -> Binary {
    truncate_directed(x, precision_bits, TruncateDirection::Floor)
}

/// Rounds `x` toward positive infinity, keeping only the top `precision_bits`
/// bits of its mantissa. The result is always `>= x`. Values whose mantissa
/// already fits are returned unchanged.
fn truncate_ceil(x: &Binary, precision_bits: usize) -> Binary {
    truncate_directed(x, precision_bits, TruncateDirection::Ceil)
}

/// Shared implementation of [`truncate_floor`] and [`truncate_ceil`].
///
/// Drops the low `bit_length - precision_bits` bits of the mantissa magnitude
/// and adds the same count to the exponent. Shifting the magnitude right rounds
/// toward zero. That already matches the requested direction, except when
/// nonzero bits were discarded and the sign points the other way (floor of a
/// negative value, ceil of a positive value). In that case the kept magnitude
/// is bumped by one unit in its last place. The bump can carry into one extra
/// bit (e.g. `0b111` -> `0b1000`).
fn truncate_directed(x: &Binary, precision_bits: usize, direction: TruncateDirection) -> Binary {
    let mantissa = x.mantissa();
    let magnitude = mantissa.magnitude();
    let bit_length = sane::bits_as_usize(magnitude.bits());
    // Nothing to drop when the mantissa already fits in `precision_bits`.
    let Some(shift) = bit_length
        .checked_sub(precision_bits)
        .filter(|&s| s > 0_usize)
    else {
        return x.clone();
    };
    let kept = magnitude >> shift;
    let discarded_nonzero = (&kept << shift) != *magnitude;
    let negative = mantissa.is_negative();
    let round_away_from_zero = discarded_nonzero
        && match direction {
            TruncateDirection::Floor => negative,
            TruncateDirection::Ceil => !negative,
        };
    let kept = if round_away_from_zero {
        BigInt::from(kept) + BigInt::from(1_i32)
    } else {
        BigInt::from(kept)
    };
    let signed = if negative { -kept } else { kept };
    Binary::new(signed, x.exponent() + BigInt::from(shift))
}
#[cfg(test)]
mod tests {
    use crate::binary::{Bounds, XBinary};
    use crate::refinement::{XUsize, bounds_width_leq};
    use crate::test_utils::{bin, interval_midpoint_computable, unwrap_finite};

    /// An input interval containing zero has no bounded reciprocal, so the
    /// result must be the whole extended line.
    #[test]
    fn inv_allows_infinite_bounds() {
        let straddling = interval_midpoint_computable(-1, 1);
        let reciprocal = straddling.inv();
        let result = reciprocal.bounds().expect("bounds should succeed");
        assert_eq!(result, Bounds::new(XBinary::NegInf, XBinary::PosInf));
    }

    /// The midpoint of [2, 4] is 3. After refinement, the bounds must enclose
    /// 1/3 and meet the requested width tolerance.
    #[test]
    fn inv_bounds_for_positive_interval() {
        let midpoint_three = interval_midpoint_computable(2, 4);
        let reciprocal = midpoint_three.inv();
        let tolerance = XUsize::Finite(8);
        let refined = reciprocal
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");

        let low = unwrap_finite(refined.small());
        let high = unwrap_finite(&refined.large());
        let three = bin(3, 0);
        let one = bin(1, 0);

        let low_times_three = low.mul(&three);
        assert!(
            low_times_three <= one,
            "lower bound {low} exceeds 1/3: lower * 3 = {low_times_three}"
        );

        let high_times_three = high.mul(&three);
        assert!(
            high_times_three >= one,
            "upper bound {high} is below 1/3: upper * 3 = {high_times_three}"
        );

        assert!(
            bounds_width_leq(&refined, &tolerance),
            "bounds width exceeds tolerance"
        );
    }
}