use crate::{CheckedSquareRoot, CtEq, CtOption, FloorSquareRoot, Limb, NonZero, Uint};
impl<const LIMBS: usize> Uint<LIMBS> {
#[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
#[must_use]
pub const fn sqrt(&self) -> Self {
self.floor_sqrt()
}
#[must_use]
pub const fn floor_sqrt(&self) -> Self {
let self_is_nz = self.is_nonzero();
let root_nz = NonZero(Self::select(&Self::ONE, self, self_is_nz))
.floor_sqrt()
.get_copy();
Self::select(&Self::ZERO, &root_nz, self_is_nz)
}
#[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
#[must_use]
pub const fn sqrt_vartime(&self) -> Self {
self.floor_sqrt_vartime()
}
#[must_use]
pub const fn floor_sqrt_vartime(&self) -> Self {
if self.is_zero_vartime() {
Self::ZERO
} else {
NonZero(*self).floor_sqrt_vartime().get_copy()
}
}
#[must_use]
pub const fn wrapping_sqrt(&self) -> Self {
self.floor_sqrt()
}
#[must_use]
pub const fn wrapping_sqrt_vartime(&self) -> Self {
self.floor_sqrt_vartime()
}
#[must_use]
pub fn checked_sqrt(&self) -> CtOption<Self> {
let self_is_nz = self.is_nonzero();
NonZero(Self::select(&Self::ONE, self, self_is_nz))
.checked_sqrt()
.map(|nz| Self::select(&Self::ZERO, nz.as_ref(), self_is_nz))
}
pub fn checked_sqrt_vartime(&self) -> Option<Self> {
if self.is_zero_vartime() {
Some(Self::ZERO)
} else {
NonZero(*self).checked_sqrt_vartime().map(NonZero::get)
}
}
}
impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
#[must_use]
pub const fn floor_sqrt(&self) -> Self {
let rt_bits = self.0.bits().div_ceil(2);
let mut x = Uint::<LIMBS>::ZERO.set_bit_vartime(rt_bits, true);
let mut q = self.0.shr(rt_bits);
let mut i = 1;
loop {
x = Uint::select(&x.wrapping_add(&q).shr1(), &x, Uint::lt(&x, &q));
i += 1;
if i >= Uint::<LIMBS>::LOG2_BITS + 2 {
return x.to_nz().expect_copied("ensured non-zero");
}
(q, _) = self.0.div_rem(x.to_nz().expect_ref("ensured non-zero"));
}
}
#[must_use]
pub const fn floor_sqrt_vartime(&self) -> Self {
let bits = self.0.bits_vartime();
if bits <= Limb::BITS {
let rt = self.0.limbs[0].0.isqrt();
return Uint::from_word(rt)
.to_nz()
.expect_copied("ensured non-zero");
}
let rt_bits = bits.div_ceil(2);
let mut x = Uint::ZERO.set_bit_vartime(rt_bits, true);
let mut q = self.0.shr_vartime(rt_bits);
loop {
if q.cmp_vartime(&x).is_ge() {
return x.to_nz().expect_copied("ensured non-zero");
}
x = x.wrapping_add(&q).shr_vartime(1);
q = self
.0
.wrapping_div_vartime(x.to_nz().expect_ref("ensured non-zero"));
}
}
#[must_use]
pub fn checked_sqrt(&self) -> CtOption<Self> {
let r = self.floor_sqrt();
let s = r.wrapping_square();
CtOption::new(r, self.0.ct_eq(&s))
}
#[must_use]
pub fn checked_sqrt_vartime(&self) -> Option<Self> {
let r = self.floor_sqrt_vartime();
let s = r.wrapping_square();
if self.0.cmp_vartime(&s).is_eq() {
Some(r)
} else {
None
}
}
}
impl<const LIMBS: usize> CheckedSquareRoot for Uint<LIMBS> {
type Output = Self;
fn checked_sqrt(&self) -> CtOption<Self> {
self.checked_sqrt()
}
fn checked_sqrt_vartime(&self) -> Option<Self> {
self.checked_sqrt_vartime()
}
}
impl<const LIMBS: usize> FloorSquareRoot for Uint<LIMBS> {
fn floor_sqrt(&self) -> Self {
self.floor_sqrt()
}
fn floor_sqrt_vartime(&self) -> Self {
self.floor_sqrt_vartime()
}
}
impl<const LIMBS: usize> CheckedSquareRoot for NonZero<Uint<LIMBS>> {
type Output = Self;
fn checked_sqrt(&self) -> CtOption<Self> {
self.checked_sqrt()
}
fn checked_sqrt_vartime(&self) -> Option<Self> {
self.checked_sqrt_vartime()
}
}
impl<const LIMBS: usize> FloorSquareRoot for NonZero<Uint<LIMBS>> {
fn floor_sqrt(&self) -> Self {
self.floor_sqrt()
}
fn floor_sqrt_vartime(&self) -> Self {
self.floor_sqrt_vartime()
}
}
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
use crate::{Limb, U192, U256};
#[cfg(feature = "rand_core")]
use {
crate::{Random, U512},
chacha20::ChaCha8Rng,
rand_core::{Rng, SeedableRng},
};
#[test]
fn edge() {
assert_eq!(U256::ZERO.floor_sqrt(), U256::ZERO);
assert_eq!(U256::ONE.floor_sqrt(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.floor_sqrt(), half);
assert_eq!(
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").floor_sqrt(),
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
);
assert_eq!(
U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d")
.floor_sqrt_vartime(),
U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
);
assert_eq!(
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
.floor_sqrt(),
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
);
assert_eq!(
U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
.floor_sqrt_vartime(),
U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
);
}
#[test]
fn edge_vartime() {
assert_eq!(U256::ZERO.floor_sqrt_vartime(), U256::ZERO);
assert_eq!(U256::ONE.floor_sqrt_vartime(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.floor_sqrt_vartime(), half);
}
#[test]
fn simple() {
let tests = [
(4u8, 2u8),
(9, 3),
(16, 4),
(25, 5),
(36, 6),
(49, 7),
(64, 8),
(81, 9),
(100, 10),
(121, 11),
(144, 12),
(169, 13),
];
for (a, e) in &tests {
let l = U256::from(*a);
let r = U256::from(*e);
assert_eq!(l.floor_sqrt(), r);
assert_eq!(l.floor_sqrt_vartime(), r);
assert!(l.checked_sqrt().is_some().to_bool());
assert!(l.checked_sqrt_vartime().is_some());
}
}
#[test]
fn nonsquares() {
assert_eq!(U256::from(2u8).floor_sqrt(), U256::from(1u8));
assert!(!U256::from(2u8).checked_sqrt().is_some().to_bool());
assert_eq!(U256::from(3u8).floor_sqrt(), U256::from(1u8));
assert!(!U256::from(3u8).checked_sqrt().is_some().to_bool());
assert_eq!(U256::from(5u8).floor_sqrt(), U256::from(2u8));
assert_eq!(U256::from(6u8).floor_sqrt(), U256::from(2u8));
assert_eq!(U256::from(7u8).floor_sqrt(), U256::from(2u8));
assert_eq!(U256::from(8u8).floor_sqrt(), U256::from(2u8));
assert_eq!(U256::from(10u8).floor_sqrt(), U256::from(3u8));
}
#[test]
fn nonsquares_vartime() {
assert_eq!(U256::from(2u8).floor_sqrt_vartime(), U256::from(1u8));
assert!(U256::from(2u8).checked_sqrt_vartime().is_none());
assert_eq!(U256::from(3u8).floor_sqrt_vartime(), U256::from(1u8));
assert!(U256::from(3u8).checked_sqrt_vartime().is_none());
assert_eq!(U256::from(5u8).floor_sqrt_vartime(), U256::from(2u8));
assert_eq!(U256::from(6u8).floor_sqrt_vartime(), U256::from(2u8));
assert_eq!(U256::from(7u8).floor_sqrt_vartime(), U256::from(2u8));
assert_eq!(U256::from(8u8).floor_sqrt_vartime(), U256::from(2u8));
assert_eq!(U256::from(10u8).floor_sqrt_vartime(), U256::from(3u8));
}
#[cfg(feature = "rand_core")]
#[test]
fn fuzz() {
use crate::{CheckedSquareRoot, FloorSquareRoot};
let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
for _ in 0..50 {
let t = u64::from(rng.next_u32());
let s = U256::from(t);
let s2 = s.checked_square().unwrap();
assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s);
assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s);
assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool());
assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some());
if let Some(nz) = s2.to_nz().into_option() {
assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s);
assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s);
assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool());
assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some());
}
}
for _ in 0..50 {
let s = U256::random_from_rng(&mut rng);
let mut s2 = U512::ZERO;
s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
assert_eq!(s.concatenating_square().floor_sqrt(), s2);
assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
}
}
}