crypto_bigint/uint/
sqrt.rs1use ctutils::Choice;
4
5use crate::{CheckedSquareRoot, CtOption, FloorSquareRoot, NonZero, Uint};
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8 #[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
12 #[must_use]
13 pub const fn sqrt(&self) -> Self {
14 self.floor_sqrt()
15 }
16
17 #[must_use]
22 pub const fn floor_sqrt(&self) -> Self {
23 let mut root = *self;
24 root.floor_sqrt_assign();
25 root
26 }
27
28 #[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
34 #[must_use]
35 pub const fn sqrt_vartime(&self) -> Self {
36 self.floor_sqrt_vartime()
37 }
38
39 #[must_use]
45 pub const fn floor_sqrt_vartime(&self) -> Self {
46 let mut root = *self;
47 root.floor_sqrt_assign_vartime();
48 root
49 }
50
51 #[must_use]
55 pub const fn wrapping_sqrt(&self) -> Self {
56 self.floor_sqrt()
57 }
58
59 #[must_use]
65 pub const fn wrapping_sqrt_vartime(&self) -> Self {
66 self.floor_sqrt_vartime()
67 }
68
69 #[must_use]
72 pub fn checked_sqrt(&self) -> CtOption<Self> {
73 let mut root = *self;
74 let exact = root.floor_sqrt_assign();
75 CtOption::new(root, exact)
76 }
77
78 #[must_use]
83 pub fn checked_sqrt_vartime(&self) -> Option<Self> {
84 let mut root = *self;
85 if root.floor_sqrt_assign_vartime() {
86 Some(root)
87 } else {
88 None
89 }
90 }
91
92 const fn floor_sqrt_assign(&mut self) -> Choice {
95 let mut buf = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
96 self.as_mut_uint_ref()
97 .sqrt_assign((buf.0.as_mut_uint_ref(), buf.1.as_mut_uint_ref()))
98 }
99
100 const fn floor_sqrt_assign_vartime(&mut self) -> bool {
105 let mut buf = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
106 self.as_mut_uint_ref()
107 .sqrt_assign_vartime((buf.0.as_mut_uint_ref(), buf.1.as_mut_uint_ref()))
108 }
109}
110
111impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
112 #[must_use]
117 pub const fn floor_sqrt(&self) -> Self {
118 NonZero::new_unchecked(self.as_ref().floor_sqrt())
119 }
120
121 #[must_use]
128 pub const fn floor_sqrt_vartime(&self) -> Self {
129 NonZero::new_unchecked(self.as_ref().floor_sqrt_vartime())
130 }
131
132 #[must_use]
135 pub fn checked_sqrt(&self) -> CtOption<Self> {
136 self.as_ref().checked_sqrt().map(NonZero::new_unchecked)
137 }
138
139 #[must_use]
142 pub fn checked_sqrt_vartime(&self) -> Option<Self> {
143 self.as_ref()
144 .checked_sqrt_vartime()
145 .map(NonZero::new_unchecked)
146 }
147}
148
149impl<const LIMBS: usize> CheckedSquareRoot for Uint<LIMBS> {
150 type Output = Self;
151
152 fn checked_sqrt(&self) -> CtOption<Self> {
153 self.checked_sqrt()
154 }
155
156 fn checked_sqrt_vartime(&self) -> Option<Self> {
157 self.checked_sqrt_vartime()
158 }
159}
160
161impl<const LIMBS: usize> FloorSquareRoot for Uint<LIMBS> {
162 fn floor_sqrt(&self) -> Self {
163 self.floor_sqrt()
164 }
165
166 fn floor_sqrt_vartime(&self) -> Self {
167 self.floor_sqrt_vartime()
168 }
169}
170
171impl<const LIMBS: usize> CheckedSquareRoot for NonZero<Uint<LIMBS>> {
172 type Output = Self;
173
174 fn checked_sqrt(&self) -> CtOption<Self> {
175 self.checked_sqrt()
176 }
177
178 fn checked_sqrt_vartime(&self) -> Option<Self> {
179 self.checked_sqrt_vartime()
180 }
181}
182
183impl<const LIMBS: usize> FloorSquareRoot for NonZero<Uint<LIMBS>> {
184 fn floor_sqrt(&self) -> Self {
185 self.floor_sqrt()
186 }
187
188 fn floor_sqrt_vartime(&self) -> Self {
189 self.floor_sqrt_vartime()
190 }
191}
192
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use crate::{Limb, U192, U256};

    #[cfg(feature = "rand_core")]
    use {
        crate::{CheckedAdd, CheckedSquareRoot, FloorSquareRoot, Random, RandomBits, U512},
        chacha20::ChaCha8Rng,
        rand_core::SeedableRng,
    };

    /// Returns `2^128 - 1` (every limb of the low half set), which is `floor(√U256::MAX)`.
    fn low_half_ones() -> U256 {
        let mut half = U256::ZERO;
        let n = half.limbs.len() / 2;
        half.limbs[..n].fill(Limb::MAX);
        half
    }

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.floor_sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt(), U256::ONE);
        assert_eq!(U256::MAX.floor_sqrt(), low_half_ones());

        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").floor_sqrt(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d")
                .floor_sqrt_vartime(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );

        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt_vartime(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(U256::ZERO.floor_sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt_vartime(), U256::ONE);
        assert_eq!(U256::MAX.floor_sqrt_vartime(), low_half_ones());
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = U256::from(*a);
            let r = U256::from(*e);
            assert_eq!(l.floor_sqrt(), r);
            assert_eq!(l.floor_sqrt_vartime(), r);
            // Pin the returned root, not merely that one is present.
            assert_eq!(l.checked_sqrt().into_option(), Some(r));
            assert_eq!(l.checked_sqrt_vartime(), Some(r));
        }
    }

    /// `(non-square input, expected floor root)` pairs shared by the non-square tests.
    const NONSQUARES: [(u8, u8); 7] = [(2, 1), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];

    #[test]
    fn nonsquares() {
        for (n, root) in NONSQUARES {
            let n = U256::from(n);
            assert_eq!(n.floor_sqrt(), U256::from(root));
            assert!(!n.checked_sqrt().is_some().to_bool());
        }
    }

    #[test]
    fn nonsquares_vartime() {
        for (n, root) in NONSQUARES {
            let n = U256::from(n);
            assert_eq!(n.floor_sqrt_vartime(), U256::from(root));
            assert!(n.checked_sqrt_vartime().is_none());
        }
    }

    #[test]
    fn wrapping_matches_floor() {
        for n in [0u8, 1, 2, 15, 16, 17, 255] {
            let n = U256::from(n);
            assert_eq!(n.wrapping_sqrt(), n.floor_sqrt());
            assert_eq!(n.wrapping_sqrt_vartime(), n.floor_sqrt_vartime());
        }
        assert_eq!(U256::MAX.wrapping_sqrt(), low_half_ones());
        assert_eq!(U256::MAX.wrapping_sqrt_vartime(), low_half_ones());
    }

    #[test]
    fn nonzero_inherent() {
        let seven = U256::from(7u8);

        // 50 is not a perfect square: floor root is 7, checked roots are absent.
        let nz = U256::from(50u8).to_nz().into_option().expect("50 is non-zero");
        assert_eq!(nz.floor_sqrt().get(), seven);
        assert_eq!(nz.floor_sqrt_vartime().get(), seven);
        assert!(nz.checked_sqrt().into_option().is_none());
        assert!(nz.checked_sqrt_vartime().is_none());

        // 49 is a perfect square: checked roots are present and equal 7.
        let nz = U256::from(49u8).to_nz().into_option().expect("49 is non-zero");
        assert_eq!(nz.floor_sqrt().get(), seven);
        assert_eq!(
            nz.checked_sqrt().into_option().map(|r| r.get()),
            Some(seven)
        );
        assert_eq!(nz.checked_sqrt_vartime().map(|r| r.get()), Some(seven));
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let s = U256::random_bits(&mut rng, 128);
            let s2 = s.checked_square().unwrap();
            assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s);
            assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s);
            assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool());
            assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some());

            if let Some(nz) = s2.to_nz().into_option() {
                assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s);
                assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s);
                assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool());
                assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some());
            }

            if let Some(sx) = s2.checked_add(&U256::ONE).into_option() {
                assert_eq!(FloorSquareRoot::floor_sqrt(&sx), s);
                assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&sx), s);
                assert!(CheckedSquareRoot::checked_sqrt(&sx).is_none().to_bool());
                assert!(CheckedSquareRoot::checked_sqrt_vartime(&sx).is_none());
            }
        }

        for _ in 0..50 {
            let s = U256::random_from_rng(&mut rng);
            let mut s2 = U512::ZERO;
            s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.concatenating_square().floor_sqrt(), s2);
            assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
        }
    }
}