Skip to main content

crypto_bigint/uint/
sqrt.rs

//! [`Uint`] square root operations.
2
3use ctutils::Choice;
4
5use crate::{CheckedSquareRoot, CtOption, FloorSquareRoot, NonZero, Uint};
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// Computes `floor(√(self))` in constant time.
9    ///
10    /// Callers can check if `self` is a square by squaring the result.
11    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
12    #[must_use]
13    pub const fn sqrt(&self) -> Self {
14        self.floor_sqrt()
15    }
16
17    /// Computes `floor(√(self))` in constant time.
18    ///
19    /// Callers can check if `self` is a square by squaring the result, or use
20    /// `checked_sqrt`.
21    #[must_use]
22    pub const fn floor_sqrt(&self) -> Self {
23        let mut root = *self;
24        root.floor_sqrt_assign();
25        root
26    }
27
28    /// Computes `floor(√(self))`.
29    ///
30    /// Callers can check if `self` is a square by squaring the result.
31    ///
32    /// Variable time with respect to `self`.
33    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
34    #[must_use]
35    pub const fn sqrt_vartime(&self) -> Self {
36        self.floor_sqrt_vartime()
37    }
38
39    /// Computes `floor(√(self))`.
40    ///
41    /// Callers can check if `self` is a square by squaring the result.
42    ///
43    /// Variable time with respect to `self`.
44    #[must_use]
45    pub const fn floor_sqrt_vartime(&self) -> Self {
46        let mut root = *self;
47        root.floor_sqrt_assign_vartime();
48        root
49    }
50
51    /// Wrapped sqrt is just `floor(√(self))`.
52    /// There’s no way wrapping could ever happen.
53    /// This function exists so that all operations are accounted for in the wrapping operations.
54    #[must_use]
55    pub const fn wrapping_sqrt(&self) -> Self {
56        self.floor_sqrt()
57    }
58
59    /// Wrapped sqrt is just `floor(√(self))`.
60    /// There’s no way wrapping could ever happen.
61    /// This function exists so that all operations are accounted for in the wrapping operations.
62    ///
63    /// Variable time with respect to `self`.
64    #[must_use]
65    pub const fn wrapping_sqrt_vartime(&self) -> Self {
66        self.floor_sqrt_vartime()
67    }
68
69    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
70    /// only if the square root is exact.
71    #[must_use]
72    pub fn checked_sqrt(&self) -> CtOption<Self> {
73        let mut root = *self;
74        let exact = root.floor_sqrt_assign();
75        CtOption::new(root, exact)
76    }
77
78    /// Perform checked sqrt, returning an [`Option`] which `is_some`
79    /// only if the square root is exact.
80    ///
81    /// Variable time with respect to `self`.
82    #[must_use]
83    pub fn checked_sqrt_vartime(&self) -> Option<Self> {
84        let mut root = *self;
85        if root.floor_sqrt_assign_vartime() {
86            Some(root)
87        } else {
88            None
89        }
90    }
91
92    /// Assigns `floor(√(self))` to `self` and returns a [`Choice`] indicating
93    /// whether the square root is exact.
94    const fn floor_sqrt_assign(&mut self) -> Choice {
95        let mut buf = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
96        self.as_mut_uint_ref()
97            .sqrt_assign((buf.0.as_mut_uint_ref(), buf.1.as_mut_uint_ref()))
98    }
99
100    /// Assigns `floor(√(self))` to `self` and returns a [`bool`] indicating
101    /// whether the square root is exact.
102    ///
103    /// Variable time with respect to `self`.
104    const fn floor_sqrt_assign_vartime(&mut self) -> bool {
105        let mut buf = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
106        self.as_mut_uint_ref()
107            .sqrt_assign_vartime((buf.0.as_mut_uint_ref(), buf.1.as_mut_uint_ref()))
108    }
109}
110
111impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
112    /// Computes `floor(√(self))` in constant time.
113    ///
114    /// Callers can check if `self` is a square by squaring the result, or
115    /// use `checked_sqrt`.
116    #[must_use]
117    pub const fn floor_sqrt(&self) -> Self {
118        NonZero::new_unchecked(self.as_ref().floor_sqrt())
119    }
120
121    /// Computes `floor(√(self))`.
122    ///
123    /// Callers can check if `self` is a square by squaring the result, or
124    /// use `checked_sqrt_vartime`.
125    ///
126    /// Variable time with respect to `self`.
127    #[must_use]
128    pub const fn floor_sqrt_vartime(&self) -> Self {
129        NonZero::new_unchecked(self.as_ref().floor_sqrt_vartime())
130    }
131
132    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
133    /// only if the square root is exact.
134    #[must_use]
135    pub fn checked_sqrt(&self) -> CtOption<Self> {
136        self.as_ref().checked_sqrt().map(NonZero::new_unchecked)
137    }
138
139    /// Perform checked sqrt, returning an [`Option`] which `is_some`
140    /// only if the square root is exact.
141    #[must_use]
142    pub fn checked_sqrt_vartime(&self) -> Option<Self> {
143        self.as_ref()
144            .checked_sqrt_vartime()
145            .map(NonZero::new_unchecked)
146    }
147}
148
149impl<const LIMBS: usize> CheckedSquareRoot for Uint<LIMBS> {
150    type Output = Self;
151
152    fn checked_sqrt(&self) -> CtOption<Self> {
153        self.checked_sqrt()
154    }
155
156    fn checked_sqrt_vartime(&self) -> Option<Self> {
157        self.checked_sqrt_vartime()
158    }
159}
160
161impl<const LIMBS: usize> FloorSquareRoot for Uint<LIMBS> {
162    fn floor_sqrt(&self) -> Self {
163        self.floor_sqrt()
164    }
165
166    fn floor_sqrt_vartime(&self) -> Self {
167        self.floor_sqrt_vartime()
168    }
169}
170
171impl<const LIMBS: usize> CheckedSquareRoot for NonZero<Uint<LIMBS>> {
172    type Output = Self;
173
174    fn checked_sqrt(&self) -> CtOption<Self> {
175        self.checked_sqrt()
176    }
177
178    fn checked_sqrt_vartime(&self) -> Option<Self> {
179        self.checked_sqrt_vartime()
180    }
181}
182
183impl<const LIMBS: usize> FloorSquareRoot for NonZero<Uint<LIMBS>> {
184    fn floor_sqrt(&self) -> Self {
185        self.floor_sqrt()
186    }
187
188    fn floor_sqrt_vartime(&self) -> Self {
189        self.floor_sqrt_vartime()
190    }
191}
192
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use crate::{Limb, U192, U256};

    #[cfg(feature = "rand_core")]
    use {
        crate::{CheckedAdd, CheckedSquareRoot, FloorSquareRoot, Random, RandomBits, U512},
        chacha20::ChaCha8Rng,
        rand_core::SeedableRng,
    };

    /// `floor(√(U256::MAX))`, i.e. `2^128 - 1`: the low half of the limbs set
    /// to `Limb::MAX`, the high half zero.
    fn max_root() -> U256 {
        let mut half = U256::ZERO;
        let n = half.limbs.len() / 2;
        for limb in &mut half.limbs[..n] {
            *limb = Limb::MAX;
        }
        half
    }

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.floor_sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt(), U256::ONE);
        assert_eq!(U256::MAX.floor_sqrt(), max_root());

        // Test edge cases that use up the maximum number of iterations.

        // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").floor_sqrt(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d")
                .floor_sqrt_vartime(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );

        // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .floor_sqrt_vartime(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(U256::ZERO.floor_sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.floor_sqrt_vartime(), U256::ONE);
        assert_eq!(U256::MAX.floor_sqrt_vartime(), max_root());
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = U256::from(*a);
            let r = U256::from(*e);
            assert_eq!(l.floor_sqrt(), r);
            assert_eq!(l.floor_sqrt_vartime(), r);
            assert!(l.checked_sqrt().is_some().to_bool());
            assert!(l.checked_sqrt_vartime().is_some());
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(U256::from(2u8).floor_sqrt(), U256::from(1u8));
        assert!(!U256::from(2u8).checked_sqrt().is_some().to_bool());
        assert_eq!(U256::from(3u8).floor_sqrt(), U256::from(1u8));
        assert!(!U256::from(3u8).checked_sqrt().is_some().to_bool());
        assert_eq!(U256::from(5u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(6u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(7u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(8u8).floor_sqrt(), U256::from(2u8));
        assert_eq!(U256::from(10u8).floor_sqrt(), U256::from(3u8));
    }

    #[test]
    fn nonsquares_vartime() {
        assert_eq!(U256::from(2u8).floor_sqrt_vartime(), U256::from(1u8));
        assert!(U256::from(2u8).checked_sqrt_vartime().is_none());
        assert_eq!(U256::from(3u8).floor_sqrt_vartime(), U256::from(1u8));
        assert!(U256::from(3u8).checked_sqrt_vartime().is_none());
        assert_eq!(U256::from(5u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(6u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(7u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(8u8).floor_sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(10u8).floor_sqrt_vartime(), U256::from(3u8));
    }

    #[test]
    fn wrapping() {
        // Wrapping sqrt can never wrap, so it must agree with `floor_sqrt`.
        for &(a, e) in &[(0u8, 0u8), (1, 1), (2, 1), (4, 2), (8, 2), (9, 3), (10, 3)] {
            let l = U256::from(a);
            let r = U256::from(e);
            assert_eq!(l.wrapping_sqrt(), r);
            assert_eq!(l.wrapping_sqrt_vartime(), r);
        }
        assert_eq!(U256::MAX.wrapping_sqrt(), max_root());
        assert_eq!(U256::MAX.wrapping_sqrt_vartime(), max_root());
    }

    #[test]
    #[allow(deprecated, reason = "exercising deprecated aliases")]
    fn deprecated_aliases() {
        assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
        assert_eq!(U256::from(10u8).sqrt_vartime(), U256::from(3u8));
    }

    #[test]
    fn nonzero() {
        let square = U256::from(144u8)
            .to_nz()
            .into_option()
            .expect("144 is nonzero");
        assert_eq!(square.floor_sqrt().get(), U256::from(12u8));
        assert_eq!(square.floor_sqrt_vartime().get(), U256::from(12u8));
        assert!(square.checked_sqrt().is_some().to_bool());
        assert!(square.checked_sqrt_vartime().is_some());

        let nonsquare = U256::from(143u8)
            .to_nz()
            .into_option()
            .expect("143 is nonzero");
        assert_eq!(nonsquare.floor_sqrt().get(), U256::from(11u8));
        assert_eq!(nonsquare.floor_sqrt_vartime().get(), U256::from(11u8));
        assert!(nonsquare.checked_sqrt().is_none().to_bool());
        assert!(nonsquare.checked_sqrt_vartime().is_none());
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let s = U256::random_bits(&mut rng, 128);
            let s2 = s.checked_square().unwrap();
            assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s);
            assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s);
            assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool());
            assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some());

            if let Some(nz) = s2.to_nz().into_option() {
                assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s);
                assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s);
                assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool());
                assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some());
            }

            if let Some(sx) = s2.checked_add(&U256::ONE).into_option() {
                assert_eq!(FloorSquareRoot::floor_sqrt(&sx), s);
                assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&sx), s);
                assert!(CheckedSquareRoot::checked_sqrt(&sx).is_none().to_bool());
                assert!(CheckedSquareRoot::checked_sqrt_vartime(&sx).is_none());
            }
        }

        for _ in 0..50 {
            let s = U256::random_from_rng(&mut rng);
            // `s` zero-extended to 512 bits: the expected root of `s^2`.
            let mut expected = U512::ZERO;
            expected.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.concatenating_square().floor_sqrt(), expected);
            assert_eq!(s.concatenating_square().floor_sqrt_vartime(), expected);
        }
    }
}
335            assert_eq!(s.concatenating_square().floor_sqrt(), s2);
336            assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
337        }
338    }
339}