1use crate::{CheckedSquareRoot, CtEq, CtOption, FloorSquareRoot, Limb, NonZero, Uint};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6 #[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
10 #[must_use]
11 pub const fn sqrt(&self) -> Self {
12 self.floor_sqrt()
13 }
14
15 #[must_use]
19 pub const fn floor_sqrt(&self) -> Self {
20 let self_is_nz = self.is_nonzero();
21 let root_nz = NonZero(Self::select(&Self::ONE, self, self_is_nz))
22 .floor_sqrt()
23 .get_copy();
24 Self::select(&Self::ZERO, &root_nz, self_is_nz)
25 }
26
27 #[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
33 #[must_use]
34 pub const fn sqrt_vartime(&self) -> Self {
35 self.floor_sqrt_vartime()
36 }
37
38 #[must_use]
44 pub const fn floor_sqrt_vartime(&self) -> Self {
45 if self.is_zero_vartime() {
46 Self::ZERO
47 } else {
48 NonZero(*self).floor_sqrt_vartime().get_copy()
49 }
50 }
51
52 #[must_use]
56 pub const fn wrapping_sqrt(&self) -> Self {
57 self.floor_sqrt()
58 }
59
60 #[must_use]
66 pub const fn wrapping_sqrt_vartime(&self) -> Self {
67 self.floor_sqrt_vartime()
68 }
69
70 #[must_use]
73 pub fn checked_sqrt(&self) -> CtOption<Self> {
74 let self_is_nz = self.is_nonzero();
75 NonZero(Self::select(&Self::ONE, self, self_is_nz))
76 .checked_sqrt()
77 .map(|nz| Self::select(&Self::ZERO, nz.as_ref(), self_is_nz))
78 }
79
80 pub fn checked_sqrt_vartime(&self) -> Option<Self> {
85 if self.is_zero_vartime() {
86 Some(Self::ZERO)
87 } else {
88 NonZero(*self).checked_sqrt_vartime().map(NonZero::get)
89 }
90 }
91}
92
93impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
94 #[must_use]
98 pub const fn floor_sqrt(&self) -> Self {
99 let rt_bits = self.0.bits().div_ceil(2);
106 let mut x = Uint::<LIMBS>::ZERO.set_bit_vartime(rt_bits, true);
109 let mut q = self.0.shr(rt_bits);
111 let mut i = 1;
113
114 loop {
115 x = Uint::select(&x.wrapping_add(&q).shr1(), &x, Uint::lt(&x, &q));
118
119 i += 1;
122 if i >= Uint::<LIMBS>::LOG2_BITS + 2 {
123 return x.to_nz().expect_copied("ensured non-zero");
124 }
125
126 (q, _) = self.0.div_rem(x.to_nz().expect_ref("ensured non-zero"));
127 }
128 }
129
130 #[must_use]
136 pub const fn floor_sqrt_vartime(&self) -> Self {
137 let bits = self.0.bits_vartime();
140 if bits <= Limb::BITS {
141 let rt = self.0.limbs[0].0.isqrt();
142 return Uint::from_word(rt)
143 .to_nz()
144 .expect_copied("ensured non-zero");
145 }
146 let rt_bits = bits.div_ceil(2);
147
148 let mut x = Uint::ZERO.set_bit_vartime(rt_bits, true);
151 let mut q = self.0.shr_vartime(rt_bits);
153
154 loop {
155 if q.cmp_vartime(&x).is_ge() {
157 return x.to_nz().expect_copied("ensured non-zero");
158 }
159 x = x.wrapping_add(&q).shr_vartime(1);
161 q = self
162 .0
163 .wrapping_div_vartime(x.to_nz().expect_ref("ensured non-zero"));
164 }
165 }
166
167 #[must_use]
170 pub fn checked_sqrt(&self) -> CtOption<Self> {
171 let r = self.floor_sqrt();
172 let s = r.wrapping_square();
173 CtOption::new(r, self.0.ct_eq(&s))
174 }
175
176 #[must_use]
179 pub fn checked_sqrt_vartime(&self) -> Option<Self> {
180 let r = self.floor_sqrt_vartime();
181 let s = r.wrapping_square();
182 if self.0.cmp_vartime(&s).is_eq() {
183 Some(r)
184 } else {
185 None
186 }
187 }
188}
189
190impl<const LIMBS: usize> CheckedSquareRoot for Uint<LIMBS> {
191 type Output = Self;
192
193 fn checked_sqrt(&self) -> CtOption<Self> {
194 self.checked_sqrt()
195 }
196
197 fn checked_sqrt_vartime(&self) -> Option<Self> {
198 self.checked_sqrt_vartime()
199 }
200}
201
202impl<const LIMBS: usize> FloorSquareRoot for Uint<LIMBS> {
203 fn floor_sqrt(&self) -> Self {
204 self.floor_sqrt()
205 }
206
207 fn floor_sqrt_vartime(&self) -> Self {
208 self.floor_sqrt_vartime()
209 }
210}
211
212impl<const LIMBS: usize> CheckedSquareRoot for NonZero<Uint<LIMBS>> {
213 type Output = Self;
214
215 fn checked_sqrt(&self) -> CtOption<Self> {
216 self.checked_sqrt()
217 }
218
219 fn checked_sqrt_vartime(&self) -> Option<Self> {
220 self.checked_sqrt_vartime()
221 }
222}
223
224impl<const LIMBS: usize> FloorSquareRoot for NonZero<Uint<LIMBS>> {
225 fn floor_sqrt(&self) -> Self {
226 self.floor_sqrt()
227 }
228
229 fn floor_sqrt_vartime(&self) -> Self {
230 self.floor_sqrt_vartime()
231 }
232}
233
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use crate::{Limb, U192, U256};

    #[cfg(feature = "rand_core")]
    use {
        crate::{Random, U512},
        chacha20::ChaCha8Rng,
        rand_core::{Rng, SeedableRng},
    };

    /// Returns `2^128 - 1` (the low half of the limbs set to all ones),
    /// which is `floor(√(U256::MAX))`.
    fn half_max() -> U256 {
        let mut half = U256::ZERO;
        let low = half.limbs.len() / 2;
        for limb in &mut half.limbs[..low] {
            *limb = Limb::MAX;
        }
        half
    }

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.floor_sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt(), U256::ONE);
        assert_eq!(U256::MAX.floor_sqrt(), half_max());

        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").floor_sqrt(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d")
                .floor_sqrt_vartime(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );

        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt_vartime(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(U256::ZERO.floor_sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt_vartime(), U256::ONE);
        assert_eq!(U256::MAX.floor_sqrt_vartime(), half_max());
    }

    #[test]
    fn zero_is_a_square() {
        assert!(U256::ZERO.checked_sqrt().is_some().to_bool());
        assert_eq!(U256::ZERO.checked_sqrt_vartime(), Some(U256::ZERO));
    }

    #[test]
    fn max_is_not_a_square() {
        // 2^256 - 1 ≡ 3 (mod 4), and no perfect square is ≡ 3 (mod 4).
        assert!(!U256::MAX.checked_sqrt().is_some().to_bool());
        assert!(U256::MAX.checked_sqrt_vartime().is_none());
    }

    #[test]
    fn multi_limb_square_boundary() {
        // r = 2^128 - 1; r^2 is a perfect square and r^2 - 1 has floor root r - 1.
        let root = half_max();
        let square = root.checked_square().unwrap();
        let below = square.wrapping_sub(&U256::ONE);
        let root_minus_one = root.wrapping_sub(&U256::ONE);

        assert_eq!(square.floor_sqrt(), root);
        assert_eq!(square.floor_sqrt_vartime(), root);
        assert!(square.checked_sqrt().is_some().to_bool());
        assert_eq!(square.checked_sqrt_vartime(), Some(root));

        assert_eq!(below.floor_sqrt(), root_minus_one);
        assert_eq!(below.floor_sqrt_vartime(), root_minus_one);
        assert!(!below.checked_sqrt().is_some().to_bool());
        assert!(below.checked_sqrt_vartime().is_none());
    }

    #[test]
    fn word_boundary() {
        // floor(√(2^64 - 1)) = 2^32 - 1 and √(2^64) = 2^32: on 64-bit limbs these straddle
        // the single-limb fast path of the variable-time implementation.
        let word_max = U256::from(u64::MAX);
        let word_max_root = U256::from(u32::MAX);
        assert_eq!(word_max.floor_sqrt(), word_max_root);
        assert_eq!(word_max.floor_sqrt_vartime(), word_max_root);

        let past_word = word_max.wrapping_add(&U256::ONE);
        let past_word_root = U256::from(1u64 << 32);
        assert_eq!(past_word.floor_sqrt(), past_word_root);
        assert_eq!(past_word.floor_sqrt_vartime(), past_word_root);
        assert_eq!(past_word.checked_sqrt_vartime(), Some(past_word_root));
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = U256::from(*a);
            let r = U256::from(*e);
            assert_eq!(l.floor_sqrt(), r);
            assert_eq!(l.floor_sqrt_vartime(), r);
            assert!(l.checked_sqrt().is_some().to_bool());
            assert!(l.checked_sqrt_vartime().is_some());
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(U256::from(2u8).floor_sqrt(), U256::from(1u8));
        assert!(!U256::from(2u8).checked_sqrt().is_some().to_bool());
        assert_eq!(U256::from(3u8).floor_sqrt(), U256::from(1u8));
        assert!(!U256::from(3u8).checked_sqrt().is_some().to_bool());
        assert_eq!(U256::from(5u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(6u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(7u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(8u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(10u8).floor_sqrt(), U256::from(3u8));
    }

    #[test]
    fn nonsquares_vartime() {
        assert_eq!(U256::from(2u8).floor_sqrt_vartime(), U256::from(1u8));
        assert!(U256::from(2u8).checked_sqrt_vartime().is_none());
        assert_eq!(U256::from(3u8).floor_sqrt_vartime(), U256::from(1u8));
        assert!(U256::from(3u8).checked_sqrt_vartime().is_none());
        assert_eq!(U256::from(5u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(6u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(7u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(8u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(10u8).floor_sqrt_vartime(), U256::from(3u8));
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        use crate::{CheckedSquareRoot, FloorSquareRoot};

        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let t = u64::from(rng.next_u32());
            let s = U256::from(t);
            let s2 = s.checked_square().unwrap();
            assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s);
            assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s);
            assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool());
            assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some());

            if let Some(nz) = s2.to_nz().into_option() {
                assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s);
                assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s);
                assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool());
                assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some());
            }
        }

        for _ in 0..50 {
            let s = U256::random_from_rng(&mut rng);
            let mut s2 = U512::ZERO;
            s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.concatenating_square().floor_sqrt(), s2);
            assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
        }
    }
}