use core::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8};
macro_rules! impl_sqrt {
    ($u:ident, $NZ:ident) => {
        /// Square root of a nonzero unsigned fixed-point value, rounded down.
        ///
        /// `val` holds the raw bits of the operand and `frac_nbits` its number
        /// of fractional bits. The result is returned as raw bits with the same
        /// number of fractional bits. Each result bit is kept only if the square
        /// of the partial root still does not exceed the operand, so the result
        /// is the floor of the exact root.
        ///
        /// NOTE(review): assumes `frac_nbits <= BITS`, because the integer-bit
        /// count below is an unsigned subtraction.
        pub const fn $u(val: $NZ, frac_nbits: u32) -> $u {
            let int_nbits = $u::BITS - frac_nbits;
            let leading = val.leading_zeros();

            // Number of significant bit pairs in the integer part. Pairs are
            // aligned to the binary point. When `frac_nbits` is odd, the
            // integer part has an odd width and its top pair is only half
            // present.
            let sig_int_pairs = if frac_nbits % 2 == 0 {
                (int_nbits / 2) as i32 - (leading / 2) as i32
            } else {
                ((int_nbits + 1) / 2) as i32 - ((leading + 1) / 2) as i32
            };

            // Working format: BITS - 2 fractional bits, so this value is 1.0
            // and anything in [0, 4) fits. The input is normalized into [1, 4),
            // so the root always starts at 1.0.
            let mut root: $u = 1 << ($u::BITS - 2);
            // Weight of the next candidate root bit (0.5 at step 1).
            let mut bit = root >> 1;
            let mut rem = val.get();
            let mut step = 1;

            // Normalize by an even shift, then subtract 1.0 * 1.0 for the
            // leading root bit.
            let shl = int_nbits as i32 - 2 * sig_int_pairs;
            if shl >= 0 {
                rem = (rem << shl) - root;
            } else {
                // Reached only when the top input bit is set and `frac_nbits`
                // is odd.
                debug_assert!(shl == -1);
                // A right shift would drop the lowest input bit. Instead, keep
                // every bit and treat the remainder as already doubled once.
                // That skips step 1, whose root bit is 0 here because the
                // normalized input lies in [1, 2).
                rem -= root * 2;
                bit >>= 1;
                step += 1;
            }

            // Number of root bits to produce after the leading one.
            let iters = (frac_nbits as i32 - 1 + sig_int_pairs) as u32;
            while step <= iters {
                let half = bit >> 1;
                if half == 0 {
                    // The working format has no fractional bits left. The last
                    // one or two root bits are decided by integer comparisons
                    // that stand in for the sub-ulp tests.
                    if step == iters {
                        debug_assert!(int_nbits as i32 - 1 - sig_int_pairs == 0);
                        // root + 0.5 <= rem  <=>  root < rem  (integers)
                        return if root < rem { root + 1 } else { root };
                    }
                    debug_assert!(step == iters - 1);
                    debug_assert!(int_nbits as i32 - 1 - sig_int_pairs == -1);
                    if root < rem {
                        // rem = (rem - (root + 0.5)) * 2, in integer form
                        rem = (rem - (root + 1)) * 2 + 1;
                        root += 1;
                    } else {
                        rem *= 2;
                    }
                    // The final bit is worth a quarter ulp. Widen the root by
                    // one bit and set it iff root + 0.25 <= rem, i.e. root < rem.
                    return if root < rem {
                        (root << 1) + 1
                    } else {
                        root << 1
                    };
                }
                // Restoring step: accept `bit` iff (root + bit)^2 <= input.
                // In terms of the scaled remainder, that is root + bit/2 <= rem.
                let trial = root + half;
                if trial <= rem {
                    rem -= trial;
                    root += bit;
                }
                rem *= 2;
                bit = half;
                step += 1;
            }
            // Rescale from the working format to `frac_nbits` fractional bits.
            root >> (int_nbits as i32 - 1 - sig_int_pairs)
        }
    };
}
impl_sqrt! { u8, NonZeroU8 }
impl_sqrt! { u16, NonZeroU16 }
impl_sqrt! { u32, NonZeroU32 }
impl_sqrt! { u64, NonZeroU64 }
impl_sqrt! { u128, NonZeroU128 }
#[cfg(test)]
mod tests {
    use crate::types::extra::{
        U0, U1, U125, U126, U127, U128, U13, U14, U15, U16, U17, U29, U3, U30, U31, U32, U33, U4,
        U5, U6, U61, U62, U63, U64, U65, U7, U8, U9,
    };
    use crate::{
        FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
        FixedU8,
    };

    /// Checks that `$val.sqrt()` is bracketed correctly.
    ///
    /// The square of the root must not exceed `$val`. The square of the root
    /// plus one ulp, when that can be computed without overflow, must not be
    /// below `$val`.
    macro_rules! check_sqrt {
        ($val:expr) => {{
            let val = $val;
            let root = val.sqrt();
            assert!(root * root <= val);
            // In two's complement -x - !x == 1, so `ulp` is the raw value 1
            // (one DELTA) for both signed and unsigned fixed types.
            let ulp = val.wrapping_neg().wrapping_sub(!val);
            if let Some(above_sq) = root.checked_add(ulp).and_then(|r| r.checked_mul(r)) {
                assert!(above_sq >= val);
            }
        }};
    }

    /// Runs `check_sqrt!` on `MAX` of each listed fixed-point type, in order.
    macro_rules! check_max {
        ($($Fixed:ty),+ $(,)?) => {
            $(check_sqrt!(<$Fixed>::MAX);)+
        };
    }

    /// Asserts that `from_num(2).sqrt()` equals `SQRT_2` for each listed type.
    macro_rules! check_two {
        ($($Fixed:ty),+ $(,)?) => {
            $(assert_eq!(<$Fixed>::from_num(2).sqrt(), <$Fixed>::SQRT_2);)+
        };
    }

    /// For a type too narrow to hold 2, asserts that `MAX.sqrt()` is either
    /// `SQRT_2` or one `DELTA` below it.
    macro_rules! check_max_near_sqrt_2 {
        ($Fixed:ty) => {{
            let root = <$Fixed>::MAX.sqrt();
            assert!(root == <$Fixed>::SQRT_2 - <$Fixed>::DELTA || root == <$Fixed>::SQRT_2);
        }};
    }

    #[test]
    fn check_max_8() {
        check_max!(
            FixedU8<U0>, FixedU8<U1>, FixedU8<U3>, FixedU8<U4>,
            FixedU8<U5>, FixedU8<U7>, FixedU8<U8>
        );
        assert_eq!(FixedU8::<U8>::MAX.sqrt(), FixedU8::<U8>::MAX);
        check_max!(
            FixedI8<U0>, FixedI8<U1>, FixedI8<U3>, FixedI8<U4>,
            FixedI8<U5>, FixedI8<U7>
        );
        assert!(FixedI8::<U8>::MAX.checked_sqrt().is_none());
    }

    #[test]
    fn check_max_16() {
        check_max!(
            FixedU16<U0>, FixedU16<U1>, FixedU16<U7>, FixedU16<U8>,
            FixedU16<U9>, FixedU16<U15>, FixedU16<U16>
        );
        assert_eq!(FixedU16::<U16>::MAX.sqrt(), FixedU16::<U16>::MAX);
        check_max!(
            FixedI16<U0>, FixedI16<U1>, FixedI16<U7>, FixedI16<U8>,
            FixedI16<U9>, FixedI16<U15>
        );
        assert!(FixedI16::<U16>::MAX.checked_sqrt().is_none());
    }

    #[test]
    fn check_max_32() {
        check_max!(
            FixedU32<U0>, FixedU32<U1>, FixedU32<U15>, FixedU32<U16>,
            FixedU32<U17>, FixedU32<U31>, FixedU32<U32>
        );
        assert_eq!(FixedU32::<U32>::MAX.sqrt(), FixedU32::<U32>::MAX);
        check_max!(
            FixedI32<U0>, FixedI32<U1>, FixedI32<U15>, FixedI32<U16>,
            FixedI32<U17>, FixedI32<U31>
        );
        assert!(FixedI32::<U32>::MAX.checked_sqrt().is_none());
    }

    #[test]
    fn check_max_64() {
        check_max!(
            FixedU64<U0>, FixedU64<U1>, FixedU64<U31>, FixedU64<U32>,
            FixedU64<U33>, FixedU64<U63>, FixedU64<U64>
        );
        assert_eq!(FixedU64::<U64>::MAX.sqrt(), FixedU64::<U64>::MAX);
        check_max!(
            FixedI64<U0>, FixedI64<U1>, FixedI64<U31>, FixedI64<U32>,
            FixedI64<U33>, FixedI64<U63>
        );
        assert!(FixedI64::<U64>::MAX.checked_sqrt().is_none());
    }

    #[test]
    fn check_max_128() {
        check_max!(
            FixedU128<U0>, FixedU128<U1>, FixedU128<U63>, FixedU128<U64>,
            FixedU128<U65>, FixedU128<U127>, FixedU128<U128>
        );
        assert_eq!(FixedU128::<U128>::MAX.sqrt(), FixedU128::<U128>::MAX);
        check_max!(
            FixedI128<U0>, FixedI128<U1>, FixedI128<U63>, FixedI128<U64>,
            FixedI128<U65>, FixedI128<U127>
        );
        assert!(FixedI128::<U128>::MAX.checked_sqrt().is_none());
    }

    #[test]
    fn check_two_8() {
        check_two!(
            FixedU8<U0>, FixedU8<U1>, FixedU8<U3>, FixedU8<U4>,
            FixedU8<U5>, FixedU8<U6>
        );
        check_max_near_sqrt_2!(FixedU8<U7>);
        check_two!(FixedI8<U0>, FixedI8<U1>, FixedI8<U3>, FixedI8<U4>, FixedI8<U5>);
        check_max_near_sqrt_2!(FixedI8<U6>);
    }

    #[test]
    fn check_two_16() {
        check_two!(
            FixedU16<U0>, FixedU16<U1>, FixedU16<U7>, FixedU16<U8>,
            FixedU16<U9>, FixedU16<U13>, FixedU16<U14>
        );
        check_max_near_sqrt_2!(FixedU16<U15>);
        check_two!(
            FixedI16<U0>, FixedI16<U1>, FixedI16<U7>, FixedI16<U8>,
            FixedI16<U9>, FixedI16<U13>
        );
        check_max_near_sqrt_2!(FixedI16<U14>);
    }

    #[test]
    fn check_two_32() {
        check_two!(
            FixedU32<U0>, FixedU32<U1>, FixedU32<U15>, FixedU32<U16>,
            FixedU32<U17>, FixedU32<U29>, FixedU32<U30>
        );
        check_max_near_sqrt_2!(FixedU32<U31>);
        check_two!(
            FixedI32<U0>, FixedI32<U1>, FixedI32<U15>, FixedI32<U16>,
            FixedI32<U17>, FixedI32<U29>
        );
        check_max_near_sqrt_2!(FixedI32<U30>);
    }

    #[test]
    fn check_two_64() {
        check_two!(
            FixedU64<U0>, FixedU64<U1>, FixedU64<U31>, FixedU64<U32>,
            FixedU64<U33>, FixedU64<U61>, FixedU64<U62>
        );
        check_max_near_sqrt_2!(FixedU64<U63>);
        check_two!(
            FixedI64<U0>, FixedI64<U1>, FixedI64<U31>, FixedI64<U32>,
            FixedI64<U33>, FixedI64<U61>
        );
        check_max_near_sqrt_2!(FixedI64<U62>);
    }

    #[test]
    fn check_two_128() {
        check_two!(
            FixedU128<U0>, FixedU128<U1>, FixedU128<U63>, FixedU128<U64>,
            FixedU128<U65>, FixedU128<U125>, FixedU128<U126>
        );
        check_max_near_sqrt_2!(FixedU128<U127>);
        check_two!(
            FixedI128<U0>, FixedI128<U1>, FixedI128<U63>, FixedI128<U64>,
            FixedI128<U65>, FixedI128<U125>
        );
        check_max_near_sqrt_2!(FixedI128<U126>);
    }
}