Skip to main content

crypto_bigint/uint/
sqrt.rs

1//! [`Uint`] square root operations.
2
3use crate::{CheckedSquareRoot, CtEq, CtOption, FloorSquareRoot, Limb, NonZero, Uint};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6    /// Computes `floor(√(self))` in constant time.
7    ///
8    /// Callers can check if `self` is a square by squaring the result.
9    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
10    #[must_use]
11    pub const fn sqrt(&self) -> Self {
12        self.floor_sqrt()
13    }
14
15    /// Computes `floor(√(self))` in constant time.
16    ///
17    /// Callers can check if `self` is a square by squaring the result.
18    #[must_use]
19    pub const fn floor_sqrt(&self) -> Self {
20        let self_is_nz = self.is_nonzero();
21        let root_nz = NonZero(Self::select(&Self::ONE, self, self_is_nz))
22            .floor_sqrt()
23            .get_copy();
24        Self::select(&Self::ZERO, &root_nz, self_is_nz)
25    }
26
27    /// Computes `floor(√(self))`.
28    ///
29    /// Callers can check if `self` is a square by squaring the result.
30    ///
31    /// Variable time with respect to `self`.
32    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
33    #[must_use]
34    pub const fn sqrt_vartime(&self) -> Self {
35        self.floor_sqrt_vartime()
36    }
37
38    /// Computes `floor(√(self))`.
39    ///
40    /// Callers can check if `self` is a square by squaring the result.
41    ///
42    /// Variable time with respect to `self`.
43    #[must_use]
44    pub const fn floor_sqrt_vartime(&self) -> Self {
45        if self.is_zero_vartime() {
46            Self::ZERO
47        } else {
48            NonZero(*self).floor_sqrt_vartime().get_copy()
49        }
50    }
51
52    /// Wrapped sqrt is just `floor(√(self))`.
53    /// There’s no way wrapping could ever happen.
54    /// This function exists so that all operations are accounted for in the wrapping operations.
55    #[must_use]
56    pub const fn wrapping_sqrt(&self) -> Self {
57        self.floor_sqrt()
58    }
59
60    /// Wrapped sqrt is just `floor(√(self))`.
61    /// There’s no way wrapping could ever happen.
62    /// This function exists so that all operations are accounted for in the wrapping operations.
63    ///
64    /// Variable time with respect to `self`.
65    #[must_use]
66    pub const fn wrapping_sqrt_vartime(&self) -> Self {
67        self.floor_sqrt_vartime()
68    }
69
70    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
71    /// only if the square root is exact.
72    #[must_use]
73    pub fn checked_sqrt(&self) -> CtOption<Self> {
74        let self_is_nz = self.is_nonzero();
75        NonZero(Self::select(&Self::ONE, self, self_is_nz))
76            .checked_sqrt()
77            .map(|nz| Self::select(&Self::ZERO, nz.as_ref(), self_is_nz))
78    }
79
80    /// Perform checked sqrt, returning an [`Option`] which `is_some`
81    /// only if the square root is exact.
82    ///
83    /// Variable time with respect to `self`.
84    pub fn checked_sqrt_vartime(&self) -> Option<Self> {
85        if self.is_zero_vartime() {
86            Some(Self::ZERO)
87        } else {
88            NonZero(*self).checked_sqrt_vartime().map(NonZero::get)
89        }
90    }
91}
92
93impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
94    /// Computes `floor(√(self))` in constant time.
95    ///
96    /// Callers can check if `self` is a square by squaring the result.
97    #[must_use]
98    pub const fn floor_sqrt(&self) -> Self {
99        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
100        //
101        // See Hast, "Note on computation of integer square roots"
102        // for the proof of the sufficiency of the bound on iterations.
103        // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
104
105        let rt_bits = self.0.bits().div_ceil(2);
106        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < 2^b`.
107        // Will not overflow since `b <= BITS`.
108        let mut x = Uint::<LIMBS>::ZERO.set_bit_vartime(rt_bits, true);
109        // Compute `self.0 / x_0` by shifting.
110        let mut q = self.0.shr(rt_bits);
111        // The first division has been performed.
112        let mut i = 1;
113
114        loop {
115            // Calculate `x_{i+1} = floor((x_i + self_nz / x_i) / 2)`, leaving `x` unmodified
116            // if it would increase.
117            x = Uint::select(&x.wrapping_add(&q).shr1(), &x, Uint::lt(&x, &q));
118
119            // We repeat enough times to guarantee the result has stabilized.
120            // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
121            i += 1;
122            if i >= Uint::<LIMBS>::LOG2_BITS + 2 {
123                return x.to_nz().expect_copied("ensured non-zero");
124            }
125
126            (q, _) = self.0.div_rem(x.to_nz().expect_ref("ensured non-zero"));
127        }
128    }
129
130    /// Computes `floor(√(self))`.
131    ///
132    /// Callers can check if `self` is a square by squaring the result.
133    ///
134    /// Variable time with respect to `self`.
135    #[must_use]
136    pub const fn floor_sqrt_vartime(&self) -> Self {
137        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
138
139        let bits = self.0.bits_vartime();
140        if bits <= Limb::BITS {
141            let rt = self.0.limbs[0].0.isqrt();
142            return Uint::from_word(rt)
143                .to_nz()
144                .expect_copied("ensured non-zero");
145        }
146        let rt_bits = bits.div_ceil(2);
147
148        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
149        // Will not overflow since `b <= BITS`.
150        let mut x = Uint::ZERO.set_bit_vartime(rt_bits, true);
151        // Compute `self / x_0` by shifting.
152        let mut q = self.0.shr_vartime(rt_bits);
153
154        loop {
155            // Terminate if `x_{i+1}` >= `x`.
156            if q.cmp_vartime(&x).is_ge() {
157                return x.to_nz().expect_copied("ensured non-zero");
158            }
159            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
160            x = x.wrapping_add(&q).shr_vartime(1);
161            q = self
162                .0
163                .wrapping_div_vartime(x.to_nz().expect_ref("ensured non-zero"));
164        }
165    }
166
167    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
168    /// only if the square root is exact.
169    #[must_use]
170    pub fn checked_sqrt(&self) -> CtOption<Self> {
171        let r = self.floor_sqrt();
172        let s = r.wrapping_square();
173        CtOption::new(r, self.0.ct_eq(&s))
174    }
175
176    /// Perform checked sqrt, returning an [`Option`] which `is_some`
177    /// only if the square root is exact.
178    #[must_use]
179    pub fn checked_sqrt_vartime(&self) -> Option<Self> {
180        let r = self.floor_sqrt_vartime();
181        let s = r.wrapping_square();
182        if self.0.cmp_vartime(&s).is_eq() {
183            Some(r)
184        } else {
185            None
186        }
187    }
188}
189
190impl<const LIMBS: usize> CheckedSquareRoot for Uint<LIMBS> {
191    type Output = Self;
192
193    fn checked_sqrt(&self) -> CtOption<Self> {
194        self.checked_sqrt()
195    }
196
197    fn checked_sqrt_vartime(&self) -> Option<Self> {
198        self.checked_sqrt_vartime()
199    }
200}
201
202impl<const LIMBS: usize> FloorSquareRoot for Uint<LIMBS> {
203    fn floor_sqrt(&self) -> Self {
204        self.floor_sqrt()
205    }
206
207    fn floor_sqrt_vartime(&self) -> Self {
208        self.floor_sqrt_vartime()
209    }
210}
211
212impl<const LIMBS: usize> CheckedSquareRoot for NonZero<Uint<LIMBS>> {
213    type Output = Self;
214
215    fn checked_sqrt(&self) -> CtOption<Self> {
216        self.checked_sqrt()
217    }
218
219    fn checked_sqrt_vartime(&self) -> Option<Self> {
220        self.checked_sqrt_vartime()
221    }
222}
223
224impl<const LIMBS: usize> FloorSquareRoot for NonZero<Uint<LIMBS>> {
225    fn floor_sqrt(&self) -> Self {
226        self.floor_sqrt()
227    }
228
229    fn floor_sqrt_vartime(&self) -> Self {
230        self.floor_sqrt_vartime()
231    }
232}
233
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use crate::{Limb, U192, U256};

    #[cfg(feature = "rand_core")]
    use {
        crate::{Random, U512},
        chacha20::ChaCha8Rng,
        rand_core::{Rng, SeedableRng},
    };

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.floor_sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt(), U256::ONE);
        let mut half = U256::ZERO;
        for i in 0..half.limbs.len() / 2 {
            half.limbs[i] = Limb::MAX;
        }
        assert_eq!(U256::MAX.floor_sqrt(), half);

        // Test edge cases that use up the maximum number of iterations.

        // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").floor_sqrt(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d")
                .floor_sqrt_vartime(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );

        // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt_vartime(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(U256::ZERO.floor_sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt_vartime(), U256::ONE);
        let mut half = U256::ZERO;
        for i in 0..half.limbs.len() / 2 {
            half.limbs[i] = Limb::MAX;
        }
        assert_eq!(U256::MAX.floor_sqrt_vartime(), half);
    }

    /// Zero is a perfect square: both checked variants must return `Some(ZERO)`.
    ///
    /// This exercises the zero-substitution path in `Uint::checked_sqrt`.
    #[test]
    fn checked_zero() {
        assert_eq!(U256::ZERO.checked_sqrt().into_option(), Some(U256::ZERO));
        assert_eq!(U256::ZERO.checked_sqrt_vartime(), Some(U256::ZERO));
    }

    /// Cross-checks every entry point against the native `u64::isqrt` on a contiguous
    /// range of small inputs (including zero and one), covering squares and non-squares.
    #[test]
    fn matches_isqrt_small() {
        for n in 0u64..512 {
            let root = n.isqrt();
            let is_square = root * root == n;
            let x = U256::from(n);
            let expected = U256::from(root);

            assert_eq!(x.floor_sqrt(), expected);
            assert_eq!(x.floor_sqrt_vartime(), expected);
            assert_eq!(x.checked_sqrt().is_some().to_bool(), is_square);
            assert_eq!(x.checked_sqrt_vartime().is_some(), is_square);
            if is_square {
                assert_eq!(x.checked_sqrt().into_option(), Some(expected));
                assert_eq!(x.checked_sqrt_vartime(), Some(expected));
            }
        }
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = U256::from(*a);
            let r = U256::from(*e);
            assert_eq!(l.floor_sqrt(), r);
            assert_eq!(l.floor_sqrt_vartime(), r);
            assert!(l.checked_sqrt().is_some().to_bool());
            assert!(l.checked_sqrt_vartime().is_some());
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(U256::from(2u8).floor_sqrt(), U256::from(1u8));
        assert!(!U256::from(2u8).checked_sqrt().is_some().to_bool());
        assert_eq!(U256::from(3u8).floor_sqrt(), U256::from(1u8));
        assert!(!U256::from(3u8).checked_sqrt().is_some().to_bool());
        assert_eq!(U256::from(5u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(6u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(7u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(8u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(10u8).floor_sqrt(), U256::from(3u8));
    }

    #[test]
    fn nonsquares_vartime() {
        assert_eq!(U256::from(2u8).floor_sqrt_vartime(), U256::from(1u8));
        assert!(U256::from(2u8).checked_sqrt_vartime().is_none());
        assert_eq!(U256::from(3u8).floor_sqrt_vartime(), U256::from(1u8));
        assert!(U256::from(3u8).checked_sqrt_vartime().is_none());
        assert_eq!(U256::from(5u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(6u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(7u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(8u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(10u8).floor_sqrt_vartime(), U256::from(3u8));
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        use crate::{CheckedSquareRoot, FloorSquareRoot};

        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let t = u64::from(rng.next_u32());
            let s = U256::from(t);
            let s2 = s.checked_square().unwrap();
            assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s);
            assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s);
            assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool());
            assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some());

            if let Some(nz) = s2.to_nz().into_option() {
                assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s);
                assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s);
                assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool());
                assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some());
            }
        }

        for _ in 0..50 {
            let s = U256::random_from_rng(&mut rng);
            let mut s2 = U512::ZERO;
            s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.concatenating_square().floor_sqrt(), s2);
            assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
        }
    }
}