use crate::core_type::D38;
use crate::mg_divide::mul_div_pow10;
impl<const SCALE: u32> D38<SCALE> {
    /// Raises `self` to the integer power `exp` by binary exponentiation
    /// (square-and-multiply), using this type's `*=` for every product.
    ///
    /// `x.pow(0)` is `ONE` for every `x`, including `ZERO`. Overflow
    /// behaviour is whatever `*=` does for this type; see
    /// [`Self::checked_pow`], [`Self::saturating_pow`],
    /// [`Self::wrapping_pow`] and [`Self::overflowing_pow`] for explicit
    /// overflow policies.
    #[inline]
    #[must_use]
    pub fn pow(self, exp: u32) -> Self {
        let mut acc = Self::ONE;
        let mut base = self;
        let mut e = exp;
        while e > 0 {
            if e & 1 == 1 {
                acc *= base;
            }
            e >>= 1;
            // Skip the squaring after the last set bit: its result would be
            // discarded and could overflow for no reason.
            if e > 0 {
                base *= base;
            }
        }
        acc
    }

    /// Raises `self` to the signed integer power `exp`.
    ///
    /// Non-negative exponents delegate to [`Self::pow`]. A negative exponent
    /// computes `ONE / self.pow(|exp|)`. Using `unsigned_abs` keeps
    /// `i32::MIN` well-defined. A zero base with a negative exponent
    /// evaluates `ONE / ZERO` and behaves however this type's `/` does in
    /// that case.
    #[inline]
    #[must_use]
    pub fn powi(self, exp: i32) -> Self {
        let magnitude = exp.unsigned_abs();
        if exp >= 0 {
            self.pow(magnitude)
        } else {
            Self::ONE / self.pow(magnitude)
        }
    }

    /// Raises `self` to the fixed-point power `exp`, computing
    /// `exp(exp * ln(self))` with the strict wide fixed-point `ln`/`exp`
    /// kernels at `SCALE + STRICT_GUARD` decimal digits. The result is then
    /// rounded back to `SCALE` digits.
    ///
    /// Special cases:
    /// - `x.powf_strict(ZERO)` is `ONE` for every `x`, including zero and
    ///   negative bases. This matches [`Self::pow`]`(0)`, [`Self::powi`]`(0)`
    ///   and `f64::powf`.
    /// - Otherwise, a zero or negative base yields `ZERO`.
    ///
    /// # Panics
    ///
    /// Panics if the result does not fit in the representable range.
    #[inline]
    #[must_use]
    pub fn powf_strict(self, exp: D38<SCALE>) -> Self {
        use crate::d_w128_kernels::Fixed;
        // x^0 == 1 for every base. This check must come before the
        // non-positive-base shortcut below, which would otherwise turn
        // 0^0 (and (-x)^0) into ZERO.
        if exp.to_bits() == 0 {
            return Self::ONE;
        }
        // ln is undefined for non-positive bases; saturate to ZERO.
        if self.to_bits() <= 0 {
            return Self::ZERO;
        }
        // Work with `guard` extra decimal digits. The intent is that kernel
        // rounding error stays below the final rounding to SCALE digits.
        let guard = crate::log_exp_strict::STRICT_GUARD;
        let w = SCALE + guard;
        let pow = 10u128.pow(guard);
        let ln_x = crate::log_exp_strict::ln_fixed(
            Fixed::from_u128_mag(self.to_bits() as u128, false).mul_u128(pow),
            w,
        );
        // Rescale the exponent to `w` digits from its magnitude, then
        // re-apply the sign.
        let y_neg = exp.to_bits() < 0;
        let y_w = Fixed::from_u128_mag(exp.to_bits().unsigned_abs(), false).mul_u128(pow);
        let y_w = if y_neg { y_w.neg() } else { y_w };
        // x^y = exp(y * ln x), rounded from `w` digits back to SCALE.
        let raw = crate::log_exp_strict::exp_fixed(y_w.mul(ln_x, w), w)
            .round_to_i128(w, SCALE)
            .expect("D38::powf: result overflows the representable range");
        Self::from_bits(raw)
    }

    /// `powf` for builds with `strict` enabled and `fast` disabled; forwards
    /// to [`Self::powf_strict`].
    #[cfg(all(feature = "strict", not(feature = "fast")))]
    #[inline]
    #[must_use]
    pub fn powf(self, exp: D38<SCALE>) -> Self {
        self.powf_strict(exp)
    }

    /// Square root rounded to the nearest representable value at `SCALE`
    /// digits. Delegates to `mg_divide::sqrt_raw_correctly_rounded`.
    ///
    /// Zero and negative inputs return `ZERO`.
    #[inline]
    #[must_use]
    pub fn sqrt_strict(self) -> Self {
        if self.to_bits() <= 0 {
            return Self::ZERO;
        }
        let raw = self.to_bits() as u128;
        let q = crate::mg_divide::sqrt_raw_correctly_rounded(raw, SCALE);
        Self::from_bits(q as i128)
    }

    /// `sqrt` for builds with `strict` enabled and `fast` disabled; forwards
    /// to [`Self::sqrt_strict`].
    #[cfg(all(feature = "strict", not(feature = "fast")))]
    #[inline]
    #[must_use]
    pub fn sqrt(self) -> Self {
        self.sqrt_strict()
    }

    /// `cbrt` for builds with `strict` enabled and `fast` disabled; forwards
    /// to [`Self::cbrt_strict`].
    #[cfg(all(feature = "strict", not(feature = "fast")))]
    #[inline]
    #[must_use]
    pub fn cbrt(self) -> Self {
        self.cbrt_strict()
    }

    /// Cube root rounded to the nearest representable value at `SCALE`
    /// digits. Delegates to `mg_divide::cbrt_raw_correctly_rounded` on the
    /// magnitude.
    ///
    /// The sign is preserved: the cube root of a negative value is negative.
    /// Zero maps to `ZERO`.
    #[inline]
    #[must_use]
    pub fn cbrt_strict(self) -> Self {
        let raw = self.to_bits();
        if raw == 0 {
            return Self::ZERO;
        }
        let negative = raw < 0;
        let q = crate::mg_divide::cbrt_raw_correctly_rounded(raw.unsigned_abs(), SCALE);
        let result = q as i128;
        Self::from_bits(if negative { -result } else { result })
    }

    /// Computes `sqrt(self^2 + other^2)` without squaring the operands
    /// directly. It evaluates `large * sqrt(1 + (small / large)^2)`, where
    /// `large` and `small` are the larger and smaller magnitudes, so the
    /// squared term never exceeds one.
    ///
    /// Returns `ZERO` when both inputs are zero.
    #[inline]
    #[must_use]
    pub fn hypot_strict(self, other: Self) -> Self {
        let a = self.abs();
        let b = other.abs();
        let (large, small) = if a >= b { (a, b) } else { (b, a) };
        if large == Self::ZERO {
            Self::ZERO
        } else {
            let ratio = small / large;
            let one_plus_sq = Self::ONE + ratio * ratio;
            large * one_plus_sq.sqrt_strict()
        }
    }

    /// `hypot` for builds with `strict` enabled and `fast` disabled; forwards
    /// to [`Self::hypot_strict`].
    #[cfg(all(feature = "strict", not(feature = "fast")))]
    #[inline]
    #[must_use]
    pub fn hypot(self, other: Self) -> Self {
        self.hypot_strict(other)
    }

    /// Checked integer power.
    ///
    /// Returns `Some(self^exp)`, or `None` if any fixed-point multiply step
    /// (`mul_div_pow10`) overflows. As in [`Self::pow`], the base is not
    /// squared after the last set bit of `exp`, so no spurious overflow is
    /// reported.
    #[inline]
    #[must_use]
    pub fn checked_pow(self, exp: u32) -> Option<Self> {
        let mut acc: i128 = Self::ONE.0;
        let mut base: i128 = self.0;
        let mut e = exp;
        while e > 0 {
            if e & 1 == 1 {
                acc = mul_div_pow10::<SCALE>(acc, base)?;
            }
            e >>= 1;
            if e > 0 {
                base = mul_div_pow10::<SCALE>(base, base)?;
            }
        }
        Some(Self(acc))
    }

    /// Wrapping integer power: [`Self::overflowing_pow`] with the overflow
    /// flag discarded.
    #[inline]
    #[must_use]
    pub fn wrapping_pow(self, exp: u32) -> Self {
        self.overflowing_pow(exp).0
    }

    /// Saturating integer power.
    ///
    /// Returns the power when no step overflows. On overflow it returns
    /// `MIN` if the true result is negative (negative base, odd exponent)
    /// and `MAX` otherwise. Overflow is detected exactly as in
    /// [`Self::checked_pow`], which this delegates to.
    #[inline]
    #[must_use]
    pub fn saturating_pow(self, exp: u32) -> Self {
        match self.checked_pow(exp) {
            Some(value) => value,
            // Only the sign of the true result matters for saturation.
            None if self.0 < 0 && exp & 1 == 1 => Self::MIN,
            None => Self::MAX,
        }
    }

    /// Integer power that reports overflow instead of panicking or
    /// saturating.
    ///
    /// Returns the value together with `true` if any step overflowed. An
    /// overflowing step falls back to
    /// `a.wrapping_mul(b).wrapping_div(Self::multiplier())`. The 128-bit
    /// product therefore wraps *before* rescaling, so after an overflow the
    /// value is only that wrapped approximation.
    #[inline]
    #[must_use]
    pub fn overflowing_pow(self, exp: u32) -> (Self, bool) {
        let mut acc: i128 = Self::ONE.0;
        let mut base: i128 = self.0;
        let mut e = exp;
        let mut overflowed = false;
        let mult = Self::multiplier();
        while e > 0 {
            if e & 1 == 1 {
                acc = if let Some(q) = mul_div_pow10::<SCALE>(acc, base) {
                    q
                } else {
                    overflowed = true;
                    acc.wrapping_mul(base).wrapping_div(mult)
                };
            }
            e >>= 1;
            if e > 0 {
                base = if let Some(q) = mul_div_pow10::<SCALE>(base, base) {
                    q
                } else {
                    overflowed = true;
                    base.wrapping_mul(base).wrapping_div(mult)
                };
            }
        }
        (Self(acc), overflowed)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the D38 power / root / hypot family.
    // Tests that call the feature-dependent entry points (`powf`, `sqrt`,
    // `cbrt`, `hypot`) carry `cfg` gates. The `*_strict` correctness tests
    // run in every configuration.
    use crate::core_type::D38s12;

    /// Tolerance, in raw units of the last place, for comparisons against
    /// independently computed results.
    #[cfg(feature = "std")]
    const TWO_LSB: i128 = 2;

    /// `true` if `actual` and `expected` differ by at most `lsb` raw units.
    ///
    /// Used by the `std`-gated tests and by the strict-only `powf` test, so it
    /// is compiled whenever either of them is. Gating on `std` alone broke
    /// `strict`-without-`std` test builds.
    #[cfg(any(feature = "std", all(feature = "strict", not(feature = "fast"))))]
    fn within_lsb(actual: D38s12, expected: D38s12, lsb: i128) -> bool {
        let diff = (actual.to_bits() - expected.to_bits()).abs();
        diff <= lsb
    }

    #[test]
    fn pow_zero_is_one_for_nonzero() {
        let v = D38s12::from_int(7);
        assert_eq!(v.pow(0), D38s12::ONE);
    }

    #[test]
    fn pow_one_is_self() {
        let v = D38s12::from_int(7);
        assert_eq!(v.pow(1), v);
    }

    #[test]
    fn pow_two_matches_mul() {
        let v = D38s12::from_int(13);
        assert_eq!(v.pow(2), v * v);
    }

    #[test]
    fn pow_two_matches_mul_fractional() {
        let v = D38s12::from_bits(1_500_000_000_000);
        assert_eq!(v.pow(2), v * v);
    }

    #[test]
    fn pow_two_to_the_ten() {
        let two = D38s12::from_int(2);
        assert_eq!(two.pow(10), D38s12::from_int(1024));
    }

    #[test]
    fn zero_pow_zero_is_one() {
        assert_eq!(D38s12::ZERO.pow(0), D38s12::ONE);
    }

    #[test]
    fn zero_pow_positive_is_zero() {
        assert_eq!(D38s12::ZERO.pow(1), D38s12::ZERO);
        assert_eq!(D38s12::ZERO.pow(5), D38s12::ZERO);
    }

    #[test]
    fn one_pow_n_is_one() {
        assert_eq!(D38s12::ONE.pow(0), D38s12::ONE);
        assert_eq!(D38s12::ONE.pow(1), D38s12::ONE);
        assert_eq!(D38s12::ONE.pow(100), D38s12::ONE);
    }

    #[test]
    fn negative_one_pow_alternates() {
        let neg_one = -D38s12::ONE;
        assert_eq!(neg_one.pow(0), D38s12::ONE);
        assert_eq!(neg_one.pow(1), neg_one);
        assert_eq!(neg_one.pow(2), D38s12::ONE);
        assert_eq!(neg_one.pow(3), neg_one);
    }

    #[test]
    fn powi_zero_is_one() {
        let v = D38s12::from_int(7);
        assert_eq!(v.powi(0), D38s12::ONE);
    }

    #[test]
    fn powi_one_is_self() {
        let v = D38s12::from_int(7);
        assert_eq!(v.powi(1), v);
    }

    #[test]
    fn powi_minus_one_is_reciprocal() {
        let v = D38s12::from_int(7);
        assert_eq!(v.powi(-1), D38s12::ONE / v);
    }

    #[test]
    fn powi_negative_matches_reciprocal_of_positive() {
        let v = D38s12::from_int(2);
        assert_eq!(v.powi(-3), D38s12::ONE / v.pow(3));
    }

    #[test]
    fn powi_positive_matches_pow() {
        let v = D38s12::from_int(3);
        for e in 0_i32..6 {
            assert_eq!(v.powi(e), v.pow(e as u32));
        }
    }

    // `i32::MIN` has no positive counterpart; this exercises the
    // `unsigned_abs` path.
    #[test]
    fn powi_i32_min_for_one_base() {
        assert_eq!(D38s12::ONE.powi(i32::MIN), D38s12::ONE);
    }

    #[cfg(feature = "std")]
    #[test]
    fn powf_half_matches_sqrt() {
        let v = D38s12::from_int(4);
        let half = D38s12::from_bits(500_000_000_000);
        let powf_result = v.powf(half);
        let sqrt_result = v.sqrt();
        assert!(
            within_lsb(powf_result, sqrt_result, TWO_LSB),
            "powf(0.5)={}, sqrt={}, diff={}",
            powf_result.to_bits(),
            sqrt_result.to_bits(),
            (powf_result.to_bits() - sqrt_result.to_bits()).abs(),
        );
    }

    // Non-strict `powf`: a looser tolerance applies.
    #[cfg(all(feature = "std", any(not(feature = "strict"), feature = "fast")))]
    #[test]
    fn powf_two_matches_pow_two_within_lsb() {
        let v = D38s12::from_int(7);
        let two = D38s12::from_int(2);
        assert!(within_lsb(v.powf(two), v.pow(2), TWO_LSB));
    }

    // Strict `powf`: must agree with integer `pow` to within one LSB.
    #[cfg(all(feature = "strict", not(feature = "fast")))]
    #[test]
    fn powf_two_matches_pow_two_within_lsb() {
        let v = D38s12::from_int(7);
        let two = D38s12::from_int(2);
        assert!(within_lsb(v.powf(two), v.pow(2), 1));
        for base in [2_i64, 3, 5, 11] {
            let b = D38s12::from_int(base);
            assert!(
                within_lsb(b.powf(D38s12::from_int(3)), b.pow(3), 1),
                "powf({base}, 3)"
            );
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn sqrt_zero_is_zero() {
        assert_eq!(D38s12::ZERO.sqrt(), D38s12::ZERO);
    }

    #[cfg(feature = "std")]
    #[test]
    fn sqrt_one_is_one_bit_exact() {
        assert_eq!(D38s12::ONE.sqrt(), D38s12::ONE);
    }

    #[cfg(feature = "std")]
    #[test]
    fn sqrt_four_is_two() {
        let four = D38s12::from_int(4);
        let two = D38s12::from_int(2);
        assert_eq!(four.sqrt(), two);
    }

    // Checks round-to-nearest exactly: with N = raw * 10^S and q the result,
    // (q - 0.5)^2 < N <= (q + 0.5)^2, evaluated in 256-bit integers as
    // q^2 - q < N <= q^2 + q.
    #[test]
    fn strict_sqrt_is_correctly_rounded() {
        fn check<const S: u32>(raw: i128) {
            let x = crate::core_type::D38::<S>::from_bits(raw);
            let q = x.sqrt_strict().to_bits();
            assert!(q >= 0, "sqrt result must be non-negative");
            let mult = 10u128.pow(S);
            let (n_hi, n_lo) = crate::mg_divide::mul2(raw as u128, mult);
            let (qsq_hi, qsq_lo) = crate::mg_divide::mul2(q as u128, q as u128);
            let q_u = q as u128;
            let (uphi, uplo) = {
                let (lo, c) = qsq_lo.overflowing_add(q_u);
                (qsq_hi + c as u128, lo)
            };
            let n_le_upper = n_hi < uphi || (n_hi == uphi && n_lo <= uplo);
            assert!(n_le_upper, "sqrt({raw} @ s{S}) = {q}: N exceeds (q+0.5)^2");
            if q > 0 {
                let (nphi, nplo) = {
                    let (lo, c) = n_lo.overflowing_add(q_u);
                    (n_hi + c as u128, lo)
                };
                let above_lower = nphi > qsq_hi || (nphi == qsq_hi && nplo > qsq_lo);
                assert!(above_lower, "sqrt({raw} @ s{S}) = {q}: N below (q-0.5)^2");
            }
        }
        for &raw in &[
            1_i128,
            2,
            3,
            4,
            5,
            999_999_999_999,
            1_000_000_000_000,
            1_500_000_000_000,
            123_456_789_012_345,
            i128::MAX,
            i128::MAX / 7,
        ] {
            check::<0>(raw);
            check::<6>(raw);
            check::<12>(raw);
            check::<19>(raw);
        }
        for &raw in &[1_i128, 2, 17, i128::MAX, i128::MAX / 3] {
            check::<38>(raw);
        }
    }

    // Checks round-to-nearest exactly: with N = |raw| * 10^(2S) and q = |result|,
    // (2q - 1)^3 < 8N <= (2q + 1)^3.
    #[test]
    fn strict_cbrt_is_correctly_rounded() {
        use i256::U256;
        fn check<const S: u32>(raw: i128) {
            let x = crate::core_type::D38::<S>::from_bits(raw);
            let q = x.cbrt_strict().to_bits();
            assert_eq!(q.signum(), raw.signum(), "cbrt sign mismatch");
            let qa = q.unsigned_abs();
            let ra = raw.unsigned_abs();
            let m = U256::from(10u8).pow(2 * S);
            let n = U256::from(ra) * m;
            let eight_n = n << 3;
            let two_q = U256::from(qa) * U256::from(2u8);
            let upper = {
                let t = two_q + U256::from(1u8);
                t * t * t
            };
            assert!(eight_n <= upper, "cbrt({raw} @ s{S}) = {q}: 8N exceeds (2q+1)^3");
            if qa > 0 {
                let t = two_q - U256::from(1u8);
                let lower = t * t * t;
                assert!(eight_n > lower, "cbrt({raw} @ s{S}) = {q}: 8N at/below (2q-1)^3");
            }
        }
        for &raw in &[
            1_i128, 2, 7, 8, 9, 26, 27, 28,
            999_999_999_999, 1_000_000_000_000, 123_456_789_012_345,
            -8, -27, -1_000_000_000_000,
        ] {
            check::<0>(raw);
            check::<6>(raw);
            check::<12>(raw);
        }
        for &raw in &[i128::MAX, i128::MIN + 1, i128::MAX / 11] {
            check::<0>(raw);
            check::<2>(raw);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn sqrt_of_square_recovers_abs() {
        let v = D38s12::from_bits(1_234_567_890_123);
        let squared = v * v;
        let recovered = squared.sqrt();
        let abs_v = v.abs();
        assert!(
            within_lsb(recovered, abs_v, TWO_LSB),
            "sqrt({})={}, expected~={}, diff={}",
            squared.to_bits(),
            recovered.to_bits(),
            abs_v.to_bits(),
            (recovered.to_bits() - abs_v.to_bits()).abs(),
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn sqrt_of_square_negative_recovers_abs() {
        let v = -D38s12::from_bits(4_567_891_234_567);
        let squared = v * v;
        let recovered = squared.sqrt();
        let abs_v = v.abs();
        assert!(within_lsb(recovered, abs_v, TWO_LSB));
    }

    #[cfg(feature = "std")]
    #[test]
    fn sqrt_negative_saturates_to_zero() {
        let v = -D38s12::from_int(4);
        assert_eq!(v.sqrt(), D38s12::ZERO);
    }

    #[cfg(feature = "std")]
    #[test]
    fn cbrt_zero_is_zero() {
        assert_eq!(D38s12::ZERO.cbrt(), D38s12::ZERO);
    }

    #[cfg(feature = "std")]
    #[test]
    fn cbrt_one_is_one() {
        assert_eq!(D38s12::ONE.cbrt(), D38s12::ONE);
    }

    #[cfg(feature = "std")]
    #[test]
    fn cbrt_eight_is_two() {
        let eight = D38s12::from_int(8);
        let two = D38s12::from_int(2);
        assert!(within_lsb(eight.cbrt(), two, TWO_LSB));
    }

    #[cfg(feature = "std")]
    #[test]
    fn cbrt_minus_eight_is_minus_two() {
        let neg_eight = D38s12::from_int(-8);
        let neg_two = D38s12::from_int(-2);
        assert!(
            within_lsb(neg_eight.cbrt(), neg_two, TWO_LSB),
            "cbrt(-8) = {}, expected ~ {}",
            neg_eight.cbrt().to_bits(),
            neg_two.to_bits(),
        );
    }

    #[test]
    fn checked_pow_max_squared_is_none() {
        assert!(D38s12::MAX.checked_pow(2).is_none());
    }

    #[test]
    fn checked_pow_one_is_some_one() {
        assert_eq!(D38s12::ONE.checked_pow(1_000_000), Some(D38s12::ONE));
        assert_eq!(D38s12::ONE.checked_pow(0), Some(D38s12::ONE));
    }

    #[test]
    fn checked_pow_matches_pow_when_no_overflow() {
        let v = D38s12::from_int(3);
        assert_eq!(v.checked_pow(0), Some(v.pow(0)));
        assert_eq!(v.checked_pow(1), Some(v.pow(1)));
        assert_eq!(v.checked_pow(5), Some(v.pow(5)));
    }

    #[test]
    fn saturating_pow_max_squared_is_max() {
        assert_eq!(D38s12::MAX.saturating_pow(2), D38s12::MAX);
    }

    #[test]
    fn saturating_pow_min_cubed_is_min() {
        assert_eq!(D38s12::MIN.saturating_pow(3), D38s12::MIN);
    }

    #[test]
    fn saturating_pow_min_squared_is_max() {
        assert_eq!(D38s12::MIN.saturating_pow(2), D38s12::MAX);
    }

    #[test]
    fn saturating_pow_one_is_one() {
        assert_eq!(D38s12::ONE.saturating_pow(1_000_000), D38s12::ONE);
    }

    #[test]
    fn saturating_pow_zero_exp_is_one() {
        assert_eq!(D38s12::MAX.saturating_pow(0), D38s12::ONE);
        assert_eq!(D38s12::MIN.saturating_pow(0), D38s12::ONE);
    }

    #[test]
    fn overflowing_pow_max_squared_flags_overflow() {
        let (value, overflowed) = D38s12::MAX.overflowing_pow(2);
        assert!(overflowed);
        assert_eq!(value, D38s12::MAX.wrapping_pow(2));
    }

    #[test]
    fn overflowing_pow_one_no_overflow() {
        let (value, overflowed) = D38s12::ONE.overflowing_pow(1_000_000);
        assert!(!overflowed);
        assert_eq!(value, D38s12::ONE);
    }

    #[test]
    fn overflowing_pow_zero_exp_no_overflow() {
        let (value, overflowed) = D38s12::MAX.overflowing_pow(0);
        assert!(!overflowed);
        assert_eq!(value, D38s12::ONE);
    }

    #[test]
    fn wrapping_pow_max_squared_matches_overflowing() {
        let wrap = D38s12::MAX.wrapping_pow(2);
        let (over, _flag) = D38s12::MAX.overflowing_pow(2);
        assert_eq!(wrap, over);
    }

    #[test]
    fn wrapping_pow_one_is_one() {
        assert_eq!(D38s12::ONE.wrapping_pow(1_000_000), D38s12::ONE);
    }

    #[test]
    fn wrapping_pow_matches_pow_when_no_overflow() {
        let v = D38s12::from_int(3);
        for e in 0..6 {
            assert_eq!(v.wrapping_pow(e), v.pow(e));
        }
    }

    #[test]
    fn pow_two_property_safe_values() {
        for raw in [
            1_234_567_890_123_i128,
            4_567_891_234_567_i128,
            7_890_123_456_789_i128,
        ] {
            let v = D38s12::from_bits(raw);
            assert_eq!(v.pow(2), v * v, "raw bits {raw}");
        }
    }

    #[test]
    fn mul_add_zero_zero_zero_is_zero() {
        let z = D38s12::ZERO;
        assert_eq!(z.mul_add(z, z), D38s12::ZERO);
    }

    #[test]
    fn mul_add_two_three_four_is_ten() {
        let two = D38s12::from_int(2);
        let three = D38s12::from_int(3);
        let four = D38s12::from_int(4);
        assert_eq!(two.mul_add(three, four), D38s12::from_int(10));
    }

    #[test]
    fn mul_add_identity_collapses() {
        let v = D38s12::from_int(7);
        assert_eq!(v.mul_add(D38s12::ONE, D38s12::ZERO), v);
    }

    #[test]
    fn mul_add_zero_factor_yields_addend() {
        let v = D38s12::from_int(7);
        let b = D38s12::from_int(13);
        assert_eq!(v.mul_add(D38s12::ZERO, b), b);
    }

    #[test]
    fn mul_add_matches_mul_then_add_safe_values() {
        for (a_raw, b_raw, c_raw) in [
            (1_234_567_890_123_i128, 2_345_678_901_234_i128, 3_456_789_012_345_i128),
            (4_567_891_234_567_i128, 5_678_912_345_678_i128, 6_789_123_456_789_i128),
            (7_890_123_456_789_i128, 8_901_234_567_891_i128, 9_012_345_678_912_i128),
        ] {
            let a = D38s12::from_bits(a_raw);
            let b = D38s12::from_bits(b_raw);
            let c = D38s12::from_bits(c_raw);
            assert_eq!(
                a.mul_add(b, c),
                a * b + c,
                "raw bits ({a_raw}, {b_raw}, {c_raw})",
            );
        }
    }

    #[test]
    fn mul_add_sign_propagates_through_factor() {
        let a = D38s12::from_int(3);
        let b = D38s12::from_int(5);
        let c = D38s12::from_int(7);
        assert_eq!((-a).mul_add(b, c), D38s12::from_int(-8));
    }

    /// Tolerance, in raw units, for `hypot` comparisons.
    #[cfg(feature = "std")]
    const HYPOT_TOLERANCE_LSB: i128 = 32;

    #[cfg(feature = "std")]
    #[test]
    fn hypot_three_four_is_five() {
        let three = D38s12::from_int(3);
        let four = D38s12::from_int(4);
        let five = D38s12::from_int(5);
        let result = three.hypot(four);
        assert!(
            within_lsb(result, five, HYPOT_TOLERANCE_LSB),
            "hypot(3, 4)={}, expected~={}, diff={}",
            result.to_bits(),
            five.to_bits(),
            (result.to_bits() - five.to_bits()).abs(),
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn hypot_zero_zero_is_zero_bit_exact() {
        assert_eq!(D38s12::ZERO.hypot(D38s12::ZERO), D38s12::ZERO);
    }

    #[cfg(feature = "std")]
    #[test]
    fn hypot_zero_x_is_abs_x() {
        let x = D38s12::from_int(7);
        let result = D38s12::ZERO.hypot(x);
        assert!(
            within_lsb(result, x.abs(), HYPOT_TOLERANCE_LSB),
            "hypot(0, 7)={}, expected~={}",
            result.to_bits(),
            x.abs().to_bits(),
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn hypot_x_zero_is_abs_x() {
        let x = D38s12::from_int(-9);
        let result = x.hypot(D38s12::ZERO);
        assert!(
            within_lsb(result, x.abs(), HYPOT_TOLERANCE_LSB),
            "hypot(-9, 0)={}, expected~={}",
            result.to_bits(),
            x.abs().to_bits(),
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn hypot_sign_invariant() {
        let three = D38s12::from_int(3);
        let four = D38s12::from_int(4);
        let pos = three.hypot(four);
        let neg_a = (-three).hypot(four);
        let neg_b = three.hypot(-four);
        let neg_both = (-three).hypot(-four);
        assert_eq!(pos, neg_a);
        assert_eq!(pos, neg_b);
        assert_eq!(pos, neg_both);
    }

    #[cfg(feature = "std")]
    #[test]
    fn hypot_large_magnitudes_does_not_panic() {
        let half_max = D38s12::from_bits(i128::MAX / 2);
        let result = half_max.hypot(half_max);
        assert!(result > D38s12::ZERO);
        assert!(result >= half_max);
    }

    #[cfg(feature = "std")]
    #[test]
    fn hypot_matches_naive_sqrt_of_sum_of_squares() {
        let a = D38s12::from_int(12);
        let b = D38s12::from_int(13);
        let h = a.hypot(b);
        let naive = (a * a + b * b).sqrt();
        assert!(
            within_lsb(h, naive, HYPOT_TOLERANCE_LSB),
            "hypot(12, 13)={}, naive sqrt(a^2+b^2)={}, diff={}",
            h.to_bits(),
            naive.to_bits(),
            (h.to_bits() - naive.to_bits()).abs(),
        );
    }
}