Skip to main content

crypto_bigint/uint/boxed/
sqrt.rs

1//! [`BoxedUint`] square root operations.
2
3use crate::{
4    BitOps, BoxedUint, CheckedSquareRoot, CtAssign, CtEq, CtGt, CtOption, FloorSquareRoot, Limb,
5};
6
7impl BoxedUint {
8    /// Computes `floor(√(self))` in constant time.
9    ///
10    /// Callers can check if `self` is a square by squaring the result.
11    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
12    #[must_use]
13    pub fn sqrt(&self) -> Self {
14        self.floor_sqrt()
15    }
16
17    /// Computes √(`self`) in constant time.
18    ///
19    /// Callers can check if `self` is a square by squaring the result.
20    #[must_use]
21    pub fn floor_sqrt(&self) -> Self {
22        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
23        //
24        // See Hast, "Note on computation of integer square roots"
25        // for the proof of the sufficiency of the bound on iterations.
26        // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
27
28        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
29        // Will not overflow since `b <= BITS`.
30        let mut x = Self::one_with_precision(self.bits_precision());
31        x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`)
32
33        let mut nz_x = x.clone();
34        let mut quo = Self::zero_with_precision(self.bits_precision());
35        let mut rem = Self::zero_with_precision(self.bits_precision());
36        let mut i = 0;
37
38        // Repeat enough times to guarantee result has stabilized.
39        // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
40        while i < self.log2_bits() + 2 {
41            let x_nonzero = x.is_nonzero();
42            nz_x.ct_assign(&x, x_nonzero);
43
44            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
45            quo.limbs.copy_from_slice(&self.limbs);
46            rem.limbs.copy_from_slice(&nz_x.limbs);
47            quo.as_mut_uint_ref().div_rem(rem.as_mut_uint_ref());
48            x.conditional_carrying_add_assign(&quo, x_nonzero);
49            x.shr1_assign();
50
51            i += 1;
52        }
53
54        // At this point `x_prev == x_{n}` and `x == x_{n+1}`
55        // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
56        // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
57        x.ct_assign(&nz_x, x.ct_gt(&nz_x));
58        x
59    }
60
61    /// Computes `floor(√(self))`.
62    ///
63    /// Callers can check if `self` is a square by squaring the result.
64    ///
65    /// Variable time with respect to `self`.
66    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
67    #[must_use]
68    pub fn sqrt_vartime(&self) -> Self {
69        self.floor_sqrt_vartime()
70    }
71
72    /// Computes √(`self`).
73    ///
74    /// Callers can check if `self` is a square by squaring the result.
75    ///
76    /// Variable time with respect to `self`.
77    #[must_use]
78    pub fn floor_sqrt_vartime(&self) -> Self {
79        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
80
81        if self.is_zero_vartime() {
82            return Self::zero_with_precision(self.bits_precision());
83        }
84
85        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
86        // Will not overflow since `b <= BITS`.
87        // The initial value of `x` is always greater than zero.
88        let mut x = Self::one_with_precision(self.bits_precision());
89        x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`)
90
91        let mut quo = Self::zero_with_precision(self.bits_precision());
92        let mut rem = Self::zero_with_precision(self.bits_precision());
93
94        loop {
95            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
96            quo.limbs.copy_from_slice(&self.limbs);
97            rem.limbs.copy_from_slice(&x.limbs);
98            quo.as_mut_uint_ref().div_rem_vartime(rem.as_mut_uint_ref());
99            quo.carrying_add_assign(&x, Limb::ZERO);
100            quo.shr1_assign();
101
102            // If `quo` is the same as `x` or greater, we reached convergence
103            // (`x` is guaranteed to either go down or oscillate between
104            // `sqrt(self)` and `sqrt(self) + 1`)
105            if !x.cmp_vartime(&quo).is_gt() {
106                break;
107            }
108            x.limbs.copy_from_slice(&quo.limbs);
109            if x.is_zero_vartime() {
110                break;
111            }
112        }
113
114        x
115    }
116
117    /// Wrapped sqrt is just `floor(√(self))`.
118    /// There’s no way wrapping could ever happen.
119    /// This function exists so that all operations are accounted for in the wrapping operations.
120    #[must_use]
121    pub fn wrapping_sqrt(&self) -> Self {
122        self.floor_sqrt()
123    }
124
125    /// Wrapped sqrt is just `floor(√(self))`.
126    /// There’s no way wrapping could ever happen.
127    /// This function exists so that all operations are accounted for in the wrapping operations.
128    ///
129    /// Variable time with respect to `self`.
130    #[must_use]
131    pub fn wrapping_sqrt_vartime(&self) -> Self {
132        self.floor_sqrt_vartime()
133    }
134
135    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
136    /// only if the square root is exact.
137    #[must_use]
138    pub fn checked_sqrt(&self) -> CtOption<Self> {
139        let r = self.floor_sqrt();
140        let s = r.wrapping_square();
141        CtOption::new(r, self.ct_eq(&s))
142    }
143
144    /// Perform checked sqrt, returning an [`Option`] which `is_some`
145    /// only if the square root is exact.
146    ///
147    /// Variable time with respect to `self`.
148    #[must_use]
149    pub fn checked_sqrt_vartime(&self) -> Option<Self> {
150        let r = self.floor_sqrt_vartime();
151        let s = r.wrapping_square();
152        if self.cmp_vartime(&s).is_eq() {
153            Some(r)
154        } else {
155            None
156        }
157    }
158}
159
160impl CheckedSquareRoot for BoxedUint {
161    type Output = Self;
162
163    fn checked_sqrt(&self) -> CtOption<Self::Output> {
164        self.checked_sqrt()
165    }
166
167    fn checked_sqrt_vartime(&self) -> Option<Self::Output> {
168        self.checked_sqrt_vartime()
169    }
170}
171
172impl FloorSquareRoot for BoxedUint {
173    fn floor_sqrt(&self) -> Self {
174        self.floor_sqrt()
175    }
176
177    fn floor_sqrt_vartime(&self) -> Self {
178        self.floor_sqrt_vartime()
179    }
180}
181
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use crate::{BoxedUint, Limb};

    #[cfg(feature = "rand_core")]
    use {
        crate::RandomBits,
        chacha20::ChaCha8Rng,
        rand_core::{Rng, SeedableRng},
    };

    /// 256-bit value whose low half of limbs is all ones, i.e. `floor(√(2^256 - 1))`.
    fn low_half_ones_256() -> BoxedUint {
        let mut value = BoxedUint::zero_with_precision(256);
        let half_len = value.limbs.len() / 2;
        value.limbs[..half_len].fill(Limb::MAX);
        value
    }

    #[test]
    fn edge() {
        let zero = BoxedUint::zero_with_precision(256);
        let one = BoxedUint::one_with_precision(256);
        assert_eq!(zero.floor_sqrt(), zero);
        assert_eq!(one.floor_sqrt(), one);

        let max = !BoxedUint::zero_with_precision(256);
        assert_eq!(max.floor_sqrt(), low_half_ones_256());

        // Test edge cases that use up the maximum number of iterations.

        // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
        let x = BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192)
            .unwrap();
        let r = BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192)
            .unwrap();
        assert_eq!(x.floor_sqrt(), r);
        assert_eq!(x.floor_sqrt_vartime(), r);

        // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
        let x = BoxedUint::from_be_hex(
            "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
            256,
        )
        .unwrap();
        let r = BoxedUint::from_be_hex(
            "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
            256,
        )
        .unwrap();
        assert_eq!(x.floor_sqrt(), r);
        assert_eq!(x.floor_sqrt_vartime(), r);
    }

    #[test]
    fn edge_vartime() {
        let zero = BoxedUint::zero_with_precision(256);
        let one = BoxedUint::one_with_precision(256);
        assert_eq!(zero.floor_sqrt_vartime(), zero);
        assert_eq!(one.floor_sqrt_vartime(), one);

        let max = !BoxedUint::zero_with_precision(256);
        assert_eq!(max.floor_sqrt_vartime(), low_half_ones_256());
    }

    #[test]
    fn simple() {
        // Perfect squares 4, 9, ..., 169 (roots 2..=13 all fit in a `u8` when squared).
        for root in 2u8..=13 {
            let square = BoxedUint::from(root * root);
            let expected = BoxedUint::from(root);
            assert_eq!(square.floor_sqrt(), expected);
            assert_eq!(square.floor_sqrt_vartime(), expected);
            assert!(square.checked_sqrt().is_some().to_bool());
            assert!(square.checked_sqrt_vartime().is_some());
        }
    }

    #[test]
    fn nonsquares() {
        let cases = [(2u8, 1u8), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];
        for (value, floor_root) in cases {
            assert_eq!(
                BoxedUint::from(value).floor_sqrt(),
                BoxedUint::from(floor_root)
            );
        }
        for value in [2u8, 3] {
            assert!(!BoxedUint::from(value).checked_sqrt().is_some().to_bool());
        }
    }

    #[test]
    fn nonsquares_vartime() {
        let cases = [(2u8, 1u8), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];
        for (value, floor_root) in cases {
            assert_eq!(
                BoxedUint::from(value).floor_sqrt_vartime(),
                BoxedUint::from(floor_root)
            );
        }
        for value in [2u8, 3] {
            assert!(BoxedUint::from(value).checked_sqrt_vartime().is_none());
        }
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        use crate::{CheckedSquareRoot, ConcatenatingSquare};

        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
        let rounds = if cfg!(miri) { 10 } else { 50 };

        // Squares of random values below 2^32.
        for _ in 0..rounds {
            let root = BoxedUint::from(u64::from(rng.next_u32()));
            let square = root.checked_mul(&root).unwrap();
            assert_eq!(square.floor_sqrt(), root);
            assert_eq!(square.floor_sqrt_vartime(), root);
            assert!(CheckedSquareRoot::checked_sqrt(&square).is_some().to_bool());
            assert!(CheckedSquareRoot::checked_sqrt_vartime(&square).is_some());
        }

        // Squares of random 512-bit values via `concatenating_square`.
        for _ in 0..rounds {
            let root = BoxedUint::random_bits(&mut rng, 512);
            let mut expected = BoxedUint::zero_with_precision(512);
            expected.limbs[..root.limbs.len()].copy_from_slice(&root.limbs);
            let square = root.concatenating_square();
            assert_eq!(square.floor_sqrt(), expected);
            assert_eq!(square.floor_sqrt_vartime(), expected);
        }
    }
}