use std::num::NonZeroU32;
use std::sync::Arc;
use num_bigint::BigInt;
use num_traits::{One, Signed, Zero};
use parking_lot::RwLock;
use crate::binary::{Binary, Bounds, FiniteBounds, UXBinary, XBinary};
use crate::binary_utils::bisection::{
PrefixBisectionResult, PrefixBounds, bisection_step_normalized, midpoint, normalize_bounds,
};
use crate::binary_utils::power::binary_pow;
use crate::error::ComputableError;
use crate::node::{Node, NodeOp};
/// Node operation that computes the `degree`-th root of the value produced by `inner`.
///
/// Until the first `refine_step` call creates the bisection state, bounds come from
/// `compute_initial_bounds`; afterwards they are derived from the stored state.
pub struct NthRootOp {
/// Child node whose bounds supply the radicand.
pub inner: Arc<Node>,
/// Degree of the root; the type guarantees it is non-zero.
pub degree: NonZeroU32,
/// Bisection progress. `None` until the first `refine_step` call initializes it.
pub bisection_state: RwLock<Option<BisectionState>>,
}
/// Progress of the bisection search used to refine an n-th root.
///
/// The search always runs on a non-negative target; `negate_result` records whether
/// the interval has to be negated before being reported.
#[derive(Clone, Debug)]
pub struct BisectionState {
/// Current bracketing interval for the root of `target`.
pub bounds: PrefixBounds,
/// Value whose root is being searched for (already negated when `negate_result` is set).
pub target: Binary,
/// `true` when the original target was negative (odd degree), so results are negated.
pub negate_result: bool,
/// Set once a bisection step reports an exact hit; refinement then stops.
pub exact_value: Option<Binary>,
}
impl NodeOp for NthRootOp {
    /// Returns the current bounds on the root.
    ///
    /// Before any refinement has happened the bounds are the coarse estimate from
    /// `compute_initial_bounds`. Once bisection state exists, the bounds are the
    /// bisection interval (or the exact value, if one was found), negated when the
    /// state says so.
    ///
    /// # Errors
    /// Propagates any error from reading the child's bounds or from
    /// `compute_initial_bounds`.
    fn compute_bounds(&self) -> Result<Bounds, ComputableError> {
        // Read the child's bounds before taking the lock.
        let input_bounds = self.inner.get_bounds()?;
        let guard = self.bisection_state.read();

        let Some(current) = &*guard else {
            return compute_initial_bounds(&input_bounds, self.degree.get());
        };

        // Interval for the magnitude of the root.
        let magnitude_bounds = match &current.exact_value {
            Some(exact) => FiniteBounds::point(exact.clone()),
            None => current.bounds.to_finite_bounds(),
        };
        let signed_bounds = if current.negate_result {
            magnitude_bounds.interval_neg()
        } else {
            magnitude_bounds
        };

        let lower = XBinary::Finite(signed_bounds.small().clone());
        let width = UXBinary::Finite(signed_bounds.width().clone());
        Ok(Bounds::from_lower_and_width(lower, width))
    }

    /// Performs one refinement step.
    ///
    /// The first call only initializes the bisection state. Later calls run a single
    /// bisection step comparing `mid^degree` with the stored target.
    ///
    /// Returns `Ok(true)` when state was initialized or a step was taken, and
    /// `Ok(false)` once an exact value is already known.
    ///
    /// # Errors
    /// Propagates errors from reading the child's bounds and from
    /// `initialize_nth_root_bisection_state`.
    fn refine_step(&self, _precision_bits: usize) -> Result<bool, ComputableError> {
        let input_bounds = self.inner.get_bounds()?;
        let degree = self.degree.get();
        let mut guard = self.bisection_state.write();

        let Some(current) = &mut *guard else {
            *guard = Some(initialize_nth_root_bisection_state(&input_bounds, degree)?);
            return Ok(true);
        };

        // Nothing left to refine once the root is known exactly.
        if current.exact_value.is_some() {
            return Ok(false);
        }

        let goal = &current.target;
        let outcome =
            bisection_step_normalized(&current.bounds, |mid| binary_pow(mid, degree).cmp(goal));
        match outcome {
            PrefixBisectionResult::Narrowed(narrowed) => current.bounds = narrowed,
            PrefixBisectionResult::Exact(root) => current.exact_value = Some(root),
        }
        Ok(true)
    }

    /// The only child is the radicand node.
    fn children(&self) -> Vec<Arc<Node>> {
        vec![Arc::clone(&self.inner)]
    }

    /// This operation always reports itself as a refiner.
    fn is_refiner(&self) -> bool {
        true
    }
}
/// Derives coarse bounds on the `degree`-th root directly from the input bounds.
///
/// The lower output bound is computed from the lower input bound and the upper
/// output bound from the upper input bound; the lower bound is evaluated first, so
/// its error (if any) wins.
///
/// # Errors
/// Propagates errors from `compute_output_lower_bound` and
/// `compute_output_upper_bound`.
fn compute_initial_bounds(input_bounds: &Bounds, degree: u32) -> Result<Bounds, ComputableError> {
    let is_even = degree.is_multiple_of(2);
    let input_low = input_bounds.small();
    let input_high = input_bounds.large();

    let output_low = compute_output_lower_bound(input_low, is_even, degree)?;
    let output_high = compute_output_upper_bound(&input_high, is_even, degree)?;
    Ok(Bounds::new(output_low, output_high))
}
/// Computes a lower bound on the root from the lower bound of the input.
///
/// * `NegInf` maps to zero for even degrees and stays `NegInf` for odd degrees.
/// * `PosInf` is reported through `detected_computable_with_infinite_value!` and
///   then passed through unchanged.
/// * A non-negative finite value uses `compute_root_lower_bound`.
/// * A negative finite value maps to zero for even degrees; for odd degrees the
///   bound is the negated upper bound of the root of its negation.
fn compute_output_lower_bound(
    lower_input: &XBinary,
    is_even: bool,
    degree: u32,
) -> Result<XBinary, ComputableError> {
    let lower_bin = match lower_input {
        XBinary::Finite(finite) => finite,
        XBinary::NegInf => {
            let bound = if is_even {
                XBinary::Finite(Binary::zero())
            } else {
                XBinary::NegInf
            };
            return Ok(bound);
        }
        XBinary::PosInf => {
            crate::detected_computable_with_infinite_value!("lower input bound is PosInf");
            return Ok(XBinary::PosInf);
        }
    };

    if !lower_bin.mantissa().is_negative() {
        return Ok(XBinary::Finite(compute_root_lower_bound(lower_bin, degree)));
    }

    // Negative lower bound: even roots are clamped at zero; odd roots mirror the
    // positive case through negation.
    let root = if is_even {
        Binary::zero()
    } else {
        compute_root_upper_bound(&lower_bin.neg(), degree).neg()
    };
    Ok(XBinary::Finite(root))
}
/// Computes an upper bound on the root from the upper bound of the input.
///
/// * `PosInf` passes through unchanged.
/// * `NegInf` is reported through `detected_computable_with_infinite_value!` and
///   then passed through unchanged.
/// * A non-negative finite value uses `compute_root_upper_bound`.
/// * A negative finite value with an odd degree yields the negated lower bound of
///   the root of its negation.
///
/// # Errors
/// Returns `ComputableError::DomainError` when the upper bound is a negative finite
/// value and the degree is even.
fn compute_output_upper_bound(
    upper_input: &XBinary,
    is_even: bool,
    degree: u32,
) -> Result<XBinary, ComputableError> {
    match upper_input {
        XBinary::PosInf => Ok(XBinary::PosInf),
        XBinary::NegInf => {
            crate::detected_computable_with_infinite_value!("upper input bound is NegInf");
            Ok(XBinary::NegInf)
        }
        XBinary::Finite(upper_bin) => {
            let is_negative = upper_bin.mantissa().is_negative();
            if is_negative && is_even {
                return Err(ComputableError::DomainError);
            }
            let root = if is_negative {
                compute_root_lower_bound(&upper_bin.neg(), degree).neg()
            } else {
                compute_root_upper_bound(upper_bin, degree)
            };
            Ok(XBinary::Finite(root))
        }
    }
}
/// Returns a coarse upper bound for a root of `x`: the larger of one and the
/// magnitude of `x`. The degree is not used.
fn compute_root_upper_bound(x: &Binary, _degree: u32) -> Binary {
    let unit = Binary::new(BigInt::one(), BigInt::zero());
    let magnitude = x.magnitude().to_binary();
    if magnitude <= unit { unit } else { magnitude }
}
/// Returns a coarse lower bound for a root of `x`: zero when `x` is not positive,
/// otherwise the smaller of `x` and one. The degree is not used.
fn compute_root_lower_bound(x: &Binary, _degree: u32) -> Binary {
    if !x.mantissa().is_positive() {
        return Binary::zero();
    }
    let unit = Binary::new(BigInt::one(), BigInt::zero());
    if *x < unit { x.clone() } else { unit }
}
/// Builds the initial bisection state for the `degree`-th root.
///
/// The search target is the midpoint of the (finite) input bounds. A negative
/// target with an odd degree is replaced by its negation and `negate_result` is
/// set. The starting bracket is `[target, 1]` when the target is below one,
/// `[1, target]` otherwise, and `[0, 0]` when the target is zero; it is then
/// passed through `normalize_bounds`.
///
/// # Errors
/// * `ComputableError::InfiniteBounds` if either input bound is not finite.
/// * `ComputableError::DomainError` if the target is negative and the degree is even.
/// * Any error returned by `normalize_bounds`.
fn initialize_nth_root_bisection_state(
    input_bounds: &Bounds,
    degree: u32,
) -> Result<BisectionState, ComputableError> {
    let lower = input_bounds.small();
    let upper = input_bounds.large();
    let (XBinary::Finite(low), XBinary::Finite(high)) = (lower, &upper) else {
        return Err(ComputableError::InfiniteBounds);
    };
    let target = midpoint(low, high);

    let target_is_negative = target.mantissa().is_negative();
    if target_is_negative && degree.is_multiple_of(2) {
        return Err(ComputableError::DomainError);
    }

    // Past the domain check, a negative target implies an odd degree, so the search
    // runs on the negated target and the result is flipped back afterwards.
    let negate_result = target_is_negative;
    let actual_target = if negate_result { target.neg() } else { target };

    let unit = Binary::new(BigInt::one(), BigInt::zero());
    let (bisection_lower, bisection_upper) = if actual_target.mantissa().is_zero() {
        (Binary::zero(), Binary::zero())
    } else if actual_target < unit {
        (actual_target.clone(), unit)
    } else {
        (unit, actual_target.clone())
    };

    let normalized = normalize_bounds(&FiniteBounds::new(bisection_lower, bisection_upper))?;
    let exponent = normalized.width().exponent().clone();
    let mantissa = normalized.small().mantissa().clone();

    Ok(BisectionState {
        bounds: PrefixBounds::new(mantissa, exponent),
        target: actual_target,
        negate_result,
        exact_value: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::computable::Computable;
    use crate::refinement::{XUsize, bounds_width_leq};
    use crate::test_utils::{bin, interval_noop_computable, unwrap_finite};

    /// Wraps a literal degree in `NonZeroU32`, panicking if it is zero.
    fn nz(n: u32) -> NonZeroU32 {
        NonZeroU32::new(n).expect("test degree must be non-zero")
    }

    /// Asserts that `expected` lies inside `bounds` and that the bounds satisfy
    /// `bounds_width_leq` for `tolerance_exp`.
    fn assert_bounds_compatible_with_expected(
        bounds: &Bounds,
        expected: &Binary,
        tolerance_exp: &XUsize,
    ) {
        let lo = unwrap_finite(bounds.small());
        let hi = unwrap_finite(&bounds.large());
        assert!(
            lo <= *expected && *expected <= hi,
            "Expected {} to be in bounds [{}, {}]",
            expected,
            lo,
            hi
        );
        assert!(
            bounds_width_leq(bounds, tolerance_exp),
            "Bounds width should be <= tolerance",
        );
    }

    /// The square root of 4 refines to bounds containing 2.
    #[test]
    fn sqrt_of_4() {
        let root = Computable::constant(bin(4, 0)).nth_root(nz(2));
        let tolerance = XUsize::Finite(8);
        let bounds = root
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        assert_bounds_compatible_with_expected(&bounds, &bin(2, 0), &tolerance);
    }

    /// The square root of 2 refines to bounds containing the `f64` approximation.
    #[test]
    fn sqrt_of_2() {
        let root = Computable::constant(bin(2, 0)).nth_root(nz(2));
        let tolerance = XUsize::Finite(8);
        let bounds = root
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        let reference = XBinary::from_f64(2.0_f64.sqrt())
            .expect("expected value should convert to extended binary");
        let expected = unwrap_finite(&reference);
        assert_bounds_compatible_with_expected(&bounds, &expected, &tolerance);
    }

    /// The cube root of 8 refines to bounds containing 2.
    #[test]
    fn cbrt_of_8() {
        let root = Computable::constant(bin(8, 0)).nth_root(nz(3));
        let tolerance = XUsize::Finite(8);
        let bounds = root
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        assert_bounds_compatible_with_expected(&bounds, &bin(2, 0), &tolerance);
    }

    /// The cube root of -8 refines to bounds containing -2.
    #[test]
    fn cbrt_of_negative_8() {
        let root = Computable::constant(bin(-8, 0)).nth_root(nz(3));
        let tolerance = XUsize::Finite(8);
        let bounds = root
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        assert_bounds_compatible_with_expected(&bounds, &bin(-2, 0), &tolerance);
    }

    /// The fourth root of 16 refines to bounds containing 2.
    #[test]
    fn fourth_root_of_16() {
        let root = Computable::constant(bin(16, 0)).nth_root(nz(4));
        let tolerance = XUsize::Finite(8);
        let bounds = root
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        assert_bounds_compatible_with_expected(&bounds, &bin(2, 0), &tolerance);
    }

    /// A radicand below one (1/2) refines to bounds containing the `f64` approximation.
    #[test]
    fn sqrt_of_half() {
        let root = Computable::constant(bin(1, -1)).nth_root(nz(2));
        let tolerance = XUsize::Finite(8);
        let bounds = root
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        let reference = XBinary::from_f64(0.5_f64.sqrt())
            .expect("expected value should convert to extended binary");
        let expected = unwrap_finite(&reference);
        assert_bounds_compatible_with_expected(&bounds, &expected, &tolerance);
    }

    /// Roots can be combined with other operations: sqrt(2) + cbrt(8).
    #[test]
    fn nth_root_in_expression() {
        let sqrt_two = Computable::constant(bin(2, 0)).nth_root(nz(2));
        let cbrt_eight = Computable::constant(bin(8, 0)).nth_root(nz(3));
        let total = sqrt_two + cbrt_eight;
        let tolerance = XUsize::Finite(8);
        let bounds = total
            .refine_to_default(tolerance)
            .expect("refine_to should succeed");
        let reference = XBinary::from_f64(2.0_f64.sqrt() + 2.0)
            .expect("expected value should convert to extended binary");
        let expected = unwrap_finite(&reference);
        assert_bounds_compatible_with_expected(&bounds, &expected, &tolerance);
    }

    /// Unrefined bounds for sqrt(0) contain zero.
    #[test]
    fn sqrt_of_zero() {
        let root = Computable::constant(bin(0, 0)).nth_root(nz(2));
        let bounds = root.bounds().expect("bounds should succeed");
        let zero = bin(0, 0);
        let lo = unwrap_finite(bounds.small());
        let hi = unwrap_finite(&bounds.large());
        assert!(lo <= zero && zero <= hi);
    }

    /// An even root of an interval straddling zero has its lower bound clamped to zero.
    #[test]
    fn sqrt_of_interval_overlapping_zero() {
        let interval = interval_noop_computable(-1, 4);
        let root = interval.nth_root(nz(2));
        let bounds = root.bounds().expect("bounds should succeed");
        let lo = unwrap_finite(bounds.small());
        let hi = unwrap_finite(&bounds.large());
        assert_eq!(lo, bin(0, 0));
        assert!(hi >= bin(2, 0));
    }

    /// An odd root of an interval straddling zero keeps both signs.
    #[test]
    fn cbrt_of_interval_overlapping_zero() {
        let interval = interval_noop_computable(-8, 27);
        let root = interval.nth_root(nz(3));
        let bounds = root.bounds().expect("bounds should succeed");
        let lo = unwrap_finite(bounds.small());
        let hi = unwrap_finite(&bounds.large());
        assert!(lo <= bin(-2, 0));
        assert!(hi >= bin(3, 0));
    }
}