use super::{const_set_bit, FixedUInt, MachineWord};
use crate::const_numtraits::{ConstIsqrt, ConstPrimInt, ConstZero};
use crate::machineword::ConstMachineWord;
c0nst::c0nst! {
// Integer square root for `FixedUInt`, built one result bit at a time.
//
// A value with `b` significant bits has a root of at most `ceil(b / 2)` bits.
// The search therefore starts at that bit and walks down to bit 0. It keeps a
// bit only if setting it leaves the trial root's square no larger than the input.
impl<T: [c0nst] ConstMachineWord + MachineWord, const N: usize> c0nst ConstIsqrt for FixedUInt<T, N> {
    // Returns floor(sqrt(self)).
    //
    // `checked_isqrt` always yields `Some` for this unsigned type, so the
    // `else` arm is never taken.
    fn isqrt(self) -> Self {
        if let Some(root) = ConstIsqrt::checked_isqrt(self) {
            root
        } else {
            unreachable!()
        }
    }

    // Returns `Some(floor(sqrt(self)))`; never `None` for unsigned input.
    fn checked_isqrt(self) -> Option<Self> {
        // Zero is its own root; handle it before the bit-length arithmetic.
        if self.is_zero() {
            return Some(Self::zero());
        }

        let mut root = Self::zero();

        // Count of significant bits in `self`; the root needs about half of them.
        let significant = Self::BIT_SIZE - ConstPrimInt::leading_zeros(self) as usize;
        let mut shift = significant.div_ceil(2);

        while shift != 0 {
            shift -= 1;
            // Tentatively set the next lower bit and keep it only if the square
            // still fits under `self`. The trial root has at most
            // ceil(significant / 2) bits. Its square therefore fits in `Self`,
            // assuming BIT_SIZE is even.
            let mut trial = root;
            const_set_bit(&mut trial.array, shift);
            if trial * trial <= self {
                root = trial;
            }
        }

        Some(root)
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{CheckedAdd, CheckedMul};

    /// Small perfect squares and the non-squares between them.
    #[test]
    fn test_isqrt() {
        type U16 = FixedUInt<u8, 2>;
        assert_eq!(ConstIsqrt::isqrt(U16::from(0u8)), U16::from(0u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(1u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(4u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(9u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(16u8)), U16::from(4u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(25u8)), U16::from(5u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(100u8)), U16::from(10u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(144u8)), U16::from(12u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(2u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(3u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(5u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(8u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(10u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(15u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(24u8)), U16::from(4u8));
    }

    /// Values that need both limbs of a two-byte `FixedUInt`.
    #[test]
    fn test_isqrt_larger_values() {
        type U16 = FixedUInt<u8, 2>;
        assert_eq!(ConstIsqrt::isqrt(U16::from(10000u16)), U16::from(100u8));
        // 255^2 = 65025 is the largest perfect square that fits in 16 bits.
        assert_eq!(ConstIsqrt::isqrt(U16::from(65535u16)), U16::from(255u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(65025u16)), U16::from(255u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(65024u16)), U16::from(254u8));
    }

    /// Every perfect square `r^2` representable in `U16` maps to `r`, and
    /// `r^2 - 1` maps to `r - 1`.
    #[test]
    fn test_isqrt_perfect_square_boundaries() {
        type U16 = FixedUInt<u8, 2>;
        for r in 0..=255u16 {
            let square = r * r;
            assert_eq!(ConstIsqrt::isqrt(U16::from(square)), U16::from(r));
            if square > 0 {
                assert_eq!(ConstIsqrt::isqrt(U16::from(square - 1)), U16::from(r - 1));
            }
        }
    }

    /// All-ones inputs have the top bit set. This is where the trial square
    /// inside the algorithm comes closest to overflowing.
    #[test]
    fn test_isqrt_max_values() {
        type U32x2 = FixedUInt<u32, 2>;
        let max = U32x2 { array: [u32::MAX; 2] };
        assert_eq!(ConstIsqrt::isqrt(max), U32x2::from(u32::MAX));

        type U8x4 = FixedUInt<u8, 4>;
        let max = U8x4 { array: [u8::MAX; 4] };
        assert_eq!(ConstIsqrt::isqrt(max), U8x4::from(0xFFFFu32));
    }

    #[test]
    fn test_checked_isqrt() {
        type U16 = FixedUInt<u8, 2>;
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(0u8)),
            Some(U16::from(0u8))
        );
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(16u8)),
            Some(U16::from(4u8))
        );
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(17u8)),
            Some(U16::from(4u8))
        );
    }

    /// Checks the defining property `r^2 <= n < (r + 1)^2` over a range of inputs.
    #[test]
    fn test_isqrt_correctness() {
        type U16 = FixedUInt<u8, 2>;
        for n in 0..=1000u16 {
            let n_int = U16::from(n);
            let r = ConstIsqrt::isqrt(n_int);
            assert!(r * r <= n_int, "Failed: {}^2 > {}", r, n);
            if let Some(r_plus_1) = r.checked_add(&U16::from(1u8)) {
                if let Some(square) = r_plus_1.checked_mul(&r_plus_1) {
                    assert!(square > n_int, "Failed: {}^2 <= {}", r_plus_1, n);
                }
            }
        }
    }

    /// Same checks on wider limb types and limb counts.
    #[test]
    fn test_isqrt_wider_types() {
        type U32x2 = FixedUInt<u32, 2>;
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(0u8)), U32x2::from(0u8));
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(1u8)), U32x2::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(16u8)), U32x2::from(4u8));
        assert_eq!(
            ConstIsqrt::isqrt(U32x2::from(1000000u32)),
            U32x2::from(1000u32)
        );
        assert_eq!(
            ConstIsqrt::isqrt(U32x2::from(0xFFFFFFFFu32)),
            U32x2::from(0xFFFFu32)
        );
        type U8x4 = FixedUInt<u8, 4>;
        assert_eq!(ConstIsqrt::isqrt(U8x4::from(65536u32)), U8x4::from(256u32));
        assert_eq!(
            ConstIsqrt::isqrt(U8x4::from(1000000u32)),
            U8x4::from(1000u32)
        );
        for n in (0..=10000u32).step_by(100) {
            let n_int = U32x2::from(n);
            let r = ConstIsqrt::isqrt(n_int);
            assert!(r * r <= n_int, "Failed: {}^2 > {} for U32x2", r, n);
            if let Some(r_plus_1) = r.checked_add(&U32x2::from(1u8)) {
                if let Some(square) = r_plus_1.checked_mul(&r_plus_1) {
                    assert!(square > n_int, "Failed: {}^2 <= {} for U32x2", r_plus_1, n);
                }
            }
        }
    }

    // Const-callable wrapper used to exercise `isqrt` from a const context.
    c0nst::c0nst! {
        pub c0nst fn const_isqrt<T: [c0nst] ConstMachineWord + MachineWord, const N: usize>(
            v: FixedUInt<T, N>,
        ) -> FixedUInt<T, N> {
            ConstIsqrt::isqrt(v)
        }
    }

    #[test]
    fn test_const_isqrt() {
        type U16 = FixedUInt<u8, 2>;
        assert_eq!(const_isqrt(U16::from(16u8)), U16::from(4u8));
        assert_eq!(const_isqrt(U16::from(100u8)), U16::from(10u8));
        #[cfg(feature = "nightly")]
        {
            const SIXTEEN: U16 = FixedUInt { array: [16, 0] };
            const RESULT: U16 = const_isqrt(SIXTEEN);
            assert_eq!(RESULT, FixedUInt { array: [4, 0] });
        }
    }
}