use std::num::NonZeroU32;
use std::sync::Arc;
use crate::binary::{Bounds, XBinary};
use crate::binary_utils::power::{is_negative, is_positive, xbinary_max, xbinary_pow};
use crate::error::ComputableError;
use crate::node::{Node, NodeOp};
/// Computation node for `inner` raised to a fixed positive integer power.
///
/// NOTE(review): the exponent type excludes zero, so a zero exponent
/// (which the tests below expect `pow(0)` to map to 1) must be handled
/// before this node is built — confirm in `Computable::pow`.
pub struct PowOp {
    /// The base node whose bounds are raised to `exponent`.
    pub inner: Arc<Node>,
    /// The power to raise `inner` to; `NonZeroU32` rules out zero at the type level.
    pub exponent: NonZeroU32,
}
impl NodeOp for PowOp {
    /// Computes bounds for `inner^exponent` from the child's current bounds.
    ///
    /// An exponent of one passes the child's bounds through untouched. Otherwise
    /// the endpoints are raised to the power, and even and odd exponents are
    /// handled separately: odd powers are monotone, while even powers fold
    /// negative inputs onto positive outputs.
    ///
    /// # Errors
    /// Propagates any error from fetching the child's bounds.
    fn compute_bounds(&self) -> Result<Bounds, ComputableError> {
        let child_bounds = self.inner.get_bounds()?;
        let power = self.exponent;

        // Identity power: nothing to compute.
        if power.get() == 1 {
            return Ok(child_bounds);
        }

        let low_end = child_bounds.small();
        let high_end = child_bounds.large();
        let result = match power.get() % 2 {
            0 => compute_even_power_bounds(low_end, &high_end, power),
            _ => compute_odd_power_bounds(low_end, &high_end, power),
        };
        Ok(result)
    }

    /// This node never refines on its own; it always reports `Ok(false)`.
    fn refine_step(&self, _precision_bits: usize) -> Result<bool, ComputableError> {
        Ok(false)
    }

    /// The single child of this node: the base `inner`.
    fn children(&self) -> Vec<Arc<Node>> {
        Vec::from([Arc::clone(&self.inner)])
    }

    /// Power nodes are not refiners; precision comes from their child.
    fn is_refiner(&self) -> bool {
        false
    }
}
/// Raises both endpoints of `[lower, upper]` to an odd power `n`.
///
/// Odd powers are monotonically increasing over the whole real line, so the
/// image of the interval is exactly `[lower^n, upper^n]`.
fn compute_odd_power_bounds(lower: &XBinary, upper: &XBinary, n: NonZeroU32) -> Bounds {
    let [low_pow, high_pow] = [lower, upper].map(|endpoint| xbinary_pow(endpoint, n));
    Bounds::new(low_pow, high_pow)
}
/// Image of `[lower, upper]` under `x -> x^n` for an even power `n`.
///
/// Even powers decrease on the non-positive half-line and increase on the
/// non-negative half-line, so the result depends on where the interval sits
/// relative to zero.
fn compute_even_power_bounds(lower: &XBinary, upper: &XBinary, n: NonZeroU32) -> Bounds {
    let raise = |value: &XBinary| xbinary_pow(value, n);

    // Entirely non-negative: x^n is increasing, endpoints map in order.
    if !is_negative(lower) {
        return Bounds::new(raise(lower), raise(upper));
    }

    // Entirely non-positive: x^n is decreasing, endpoints swap roles.
    if !is_positive(upper) {
        return Bounds::new(raise(upper), raise(lower));
    }

    // Straddles zero: the minimum 0 is attained at x = 0, and the maximum
    // is the larger of the two endpoint powers.
    let from_lower = raise(lower);
    let from_upper = raise(upper);
    Bounds::new(XBinary::zero(), xbinary_max(&from_lower, &from_upper))
}
#[cfg(test)]
mod tests {
    use crate::binary::{Binary, Bounds};
    use crate::computable::Computable;
    use crate::refinement::XUsize;
    use crate::test_utils::{bin, interval_noop_computable, unwrap_finite};

    /// Asserts that `expected` lies within the finite interval described by `bounds`.
    ///
    /// The tolerance argument is accepted for call-site symmetry but is not checked.
    fn assert_bounds_contain_expected(bounds: &Bounds, expected: &Binary, _tolerance_exp: &XUsize) {
        let (lower, upper) = (unwrap_finite(bounds.small()), unwrap_finite(&bounds.large()));
        let inside = lower <= *expected && *expected <= upper;
        assert!(
            inside,
            "Expected {} to be in bounds [{}, {}]",
            expected, lower, upper
        );
    }

    /// Asserts that both endpoints of `bounds` are finite and equal `lower` / `upper` exactly.
    fn assert_exact_bounds(bounds: &Bounds, lower: &Binary, upper: &Binary) {
        assert_eq!(unwrap_finite(bounds.small()), *lower);
        assert_eq!(unwrap_finite(&bounds.large()), *upper);
    }

    #[test]
    fn pow_constant_squared() {
        let base = Computable::constant(bin(3, 0));
        let powered = base.pow(2);
        let bounds = powered.bounds().expect("bounds should succeed");
        let nine = bin(9, 0);
        assert_exact_bounds(&bounds, &nine, &nine);
    }

    #[test]
    fn pow_constant_cubed() {
        let base = Computable::constant(bin(2, 0));
        let powered = base.pow(3);
        let bounds = powered.bounds().expect("bounds should succeed");
        let eight = bin(8, 0);
        assert_exact_bounds(&bounds, &eight, &eight);
    }

    #[test]
    fn pow_negative_even() {
        let base = Computable::constant(bin(-3, 0));
        let powered = base.pow(2);
        let bounds = powered.bounds().expect("bounds should succeed");
        let nine = bin(9, 0);
        assert_exact_bounds(&bounds, &nine, &nine);
    }

    #[test]
    fn pow_negative_odd() {
        let base = Computable::constant(bin(-2, 0));
        let powered = base.pow(3);
        let bounds = powered.bounds().expect("bounds should succeed");
        let minus_eight = bin(-8, 0);
        assert_exact_bounds(&bounds, &minus_eight, &minus_eight);
    }

    #[test]
    fn pow_interval_positive_even() {
        let base = interval_noop_computable(2, 4);
        let powered = base.pow(2);
        let bounds = powered.bounds().expect("bounds should succeed");
        assert_exact_bounds(&bounds, &bin(4, 0), &bin(16, 0));
    }

    #[test]
    fn pow_interval_negative_even() {
        let base = interval_noop_computable(-4, -2);
        let powered = base.pow(2);
        let bounds = powered.bounds().expect("bounds should succeed");
        assert_exact_bounds(&bounds, &bin(4, 0), &bin(16, 0));
    }

    #[test]
    fn pow_interval_spanning_zero_even() {
        let base = interval_noop_computable(-2, 3);
        let powered = base.pow(2);
        let bounds = powered.bounds().expect("bounds should succeed");
        assert_exact_bounds(&bounds, &bin(0, 0), &bin(9, 0));
    }

    #[test]
    fn pow_interval_odd() {
        let base = interval_noop_computable(2, 4);
        let powered = base.pow(3);
        let bounds = powered.bounds().expect("bounds should succeed");
        assert_exact_bounds(&bounds, &bin(8, 0), &bin(64, 0));
    }

    #[test]
    fn pow_interval_negative_odd() {
        let base = interval_noop_computable(-4, -2);
        let powered = base.pow(3);
        let bounds = powered.bounds().expect("bounds should succeed");
        assert_exact_bounds(&bounds, &bin(-64, 0), &bin(-8, 0));
    }

    #[test]
    fn pow_exponent_one() {
        let base = Computable::constant(bin(3, 0));
        let powered = base.pow(1);
        let bounds = powered.bounds().expect("bounds should succeed");
        let three = bin(3, 0);
        assert_exact_bounds(&bounds, &three, &three);
    }

    #[test]
    fn pow_exponent_zero() {
        let base = Computable::constant(bin(3, 0));
        let powered = base.pow(0);
        let bounds = powered.bounds().expect("bounds should succeed");
        let one = bin(1, 0);
        assert_exact_bounds(&bounds, &one, &one);
    }

    #[test]
    fn pow_zero_to_zero() {
        let base = Computable::constant(bin(0, 0));
        let powered = base.pow(0);
        let bounds = powered.bounds().expect("bounds should succeed");
        let one = bin(1, 0);
        assert_exact_bounds(&bounds, &one, &one);
    }

    #[test]
    fn pow_in_expression() {
        let lhs = Computable::constant(bin(2, 0)).pow(2);
        let rhs = Computable::constant(bin(3, 0)).pow(2);
        let total = lhs + rhs;
        let tolerance_exp = XUsize::Finite(8);
        let bounds = total
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let thirteen = bin(13, 0);
        assert_bounds_contain_expected(&bounds, &thirteen, &tolerance_exp);
    }

    #[test]
    fn pow_with_sqrt() {
        let radicand = Computable::constant(bin(2, 0));
        let root = radicand.nth_root(std::num::NonZeroU32::new(2).expect("2 is non-zero"));
        let round_trip = root.pow(2);
        let tolerance_exp = XUsize::Finite(8);
        let bounds = round_trip
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let two = bin(2, 0);
        assert_bounds_contain_expected(&bounds, &two, &tolerance_exp);
    }

    #[test]
    fn pow_of_zero() {
        let base = Computable::constant(bin(0, 0));
        let powered = base.pow(2);
        let bounds = powered.bounds().expect("bounds should succeed");
        assert!(bounds.small().is_zero());
        assert!(bounds.large().is_zero());
    }
}