Skip to main content

fixed_bigint/fixeduint/
isqrt_impl.rs

// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Integer square root for FixedUInt.
16
17use super::{const_set_bit, FixedUInt, MachineWord};
18use crate::const_numtraits::{ConstBitPrimInt, ConstIsqrt, ConstZero};
19use crate::machineword::ConstMachineWord;
20use crate::personality::Nct;
21
22c0nst::c0nst! {
23    impl<T: [c0nst] ConstMachineWord + MachineWord, const N: usize> c0nst ConstIsqrt for FixedUInt<T, N, Nct> {
24        fn isqrt(self) -> Self {
25            // For unsigned types, isqrt always succeeds
26            match ConstIsqrt::checked_isqrt(self) {
27                Some(v) => v,
28                None => unreachable!(),
29            }
30        }
31
32        fn checked_isqrt(self) -> Option<Self> {
33            // Bit-by-bit algorithm for integer square root
34            // Returns the largest r such that r * r <= self
35            if self.is_zero() {
36                return Some(Self::zero());
37            }
38
39            // Start with the highest bit position that could be set in the result
40            // For sqrt, the result has at most half the bits of the input
41            let mut result = Self::zero();
42
43            // Find starting bit position: half of the bit length of self
44            let bit_len = Self::BIT_SIZE - ConstBitPrimInt::leading_zeros(self) as usize;
45            let start_bit = bit_len.div_ceil(2);
46
47            let mut bit_pos = start_bit;
48            while bit_pos > 0 {
49                bit_pos -= 1;
50
51                // Try setting this bit in the result
52                let mut candidate = result;
53                const_set_bit(&mut candidate.array, bit_pos);
54
55                // Check if candidate * candidate <= self
56                // Since candidate has at most half the bits of self,
57                // candidate * candidate won't overflow.
58                let square = candidate * candidate;
59                if square <= self {
60                    result = candidate;
61                }
62            }
63
64            Some(result)
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{CheckedAdd, CheckedMul};

    #[test]
    fn test_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        // Perfect squares
        assert_eq!(ConstIsqrt::isqrt(U16::from(0u8)), U16::from(0u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(1u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(4u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(9u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(16u8)), U16::from(4u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(25u8)), U16::from(5u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(100u8)), U16::from(10u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(144u8)), U16::from(12u8));

        // Non-perfect squares (floor)
        assert_eq!(ConstIsqrt::isqrt(U16::from(2u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(3u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(5u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(8u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(10u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(15u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(24u8)), U16::from(4u8));
    }

    #[test]
    fn test_isqrt_larger_values() {
        type U16 = FixedUInt<u8, 2>;

        // Larger values
        assert_eq!(ConstIsqrt::isqrt(U16::from(10000u16)), U16::from(100u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(65535u16)), U16::from(255u8)); // sqrt(65535) = 255.998...
        assert_eq!(ConstIsqrt::isqrt(U16::from(65025u16)), U16::from(255u8)); // 255^2 = 65025
    }

    #[test]
    fn test_checked_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        // For unsigned, checked_isqrt always returns Some
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(0u8)),
            Some(U16::from(0u8))
        );
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(16u8)),
            Some(U16::from(4u8))
        );
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(17u8)),
            Some(U16::from(4u8))
        );
    }

    #[test]
    fn test_isqrt_correctness() {
        type U16 = FixedUInt<u8, 2>;

        // Exhaustively verify r^2 <= n < (r+1)^2 over the whole 16-bit domain,
        // including values with the top bit set, where the root's square comes
        // closest to the type's capacity.
        for n in 0..=u16::MAX {
            let n_int = U16::from(n);
            let r = ConstIsqrt::isqrt(n_int);

            // r^2 <= n
            assert!(r * r <= n_int, "Failed: {}^2 > {}", r, n);

            // (r+1)^2 > n. Checked arithmetic is needed because (r+1)^2 can
            // exceed U16: if either r+1 or its square overflows, that value is
            // necessarily greater than any n that fits in U16, so there is
            // nothing left to verify.
            if let Some(r_plus_1) = r.checked_add(&U16::from(1u8)) {
                if let Some(square) = r_plus_1.checked_mul(&r_plus_1) {
                    assert!(square > n_int, "Failed: {}^2 <= {}", r_plus_1, n);
                }
            }
        }
    }

    #[test]
    fn test_isqrt_wider_types() {
        // Wider word type: inputs here fit in the low u32 word, so this covers
        // the single-word path with 32-bit words. High-word inputs are covered
        // by `test_isqrt_high_word_inputs`.
        type U32x2 = FixedUInt<u32, 2>;

        // Perfect squares
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(0u8)), U32x2::from(0u8));
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(1u8)), U32x2::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(16u8)), U32x2::from(4u8));

        // Larger values that span multiple bits
        assert_eq!(
            ConstIsqrt::isqrt(U32x2::from(1000000u32)),
            U32x2::from(1000u32)
        );
        assert_eq!(
            ConstIsqrt::isqrt(U32x2::from(0xFFFFFFFFu32)),
            U32x2::from(0xFFFFu32)
        );

        // u8x4: the root itself spans more than one word (e.g. 256 sets a bit
        // in word 1), exercising cross-word bit-setting.
        type U8x4 = FixedUInt<u8, 4>;
        assert_eq!(ConstIsqrt::isqrt(U8x4::from(65536u32)), U8x4::from(256u32));
        assert_eq!(
            ConstIsqrt::isqrt(U8x4::from(1000000u32)),
            U8x4::from(1000u32)
        );
        assert_eq!(
            ConstIsqrt::isqrt(U8x4::from(0xFFFF_FFFFu32)),
            U8x4::from(0xFFFFu32)
        );

        // Verify correctness for a range
        for n in (0..=10000u32).step_by(100) {
            let n_int = U32x2::from(n);
            let r = ConstIsqrt::isqrt(n_int);

            // r^2 <= n
            assert!(r * r <= n_int, "Failed: {}^2 > {} for U32x2", r, n);

            // (r+1)^2 > n
            if let Some(r_plus_1) = r.checked_add(&U32x2::from(1u8)) {
                if let Some(square) = r_plus_1.checked_mul(&r_plus_1) {
                    assert!(square > n_int, "Failed: {}^2 <= {} for U32x2", r_plus_1, n);
                }
            }
        }
    }

    #[test]
    fn test_isqrt_high_word_inputs() {
        // Inputs with the high u32 word set, so the squaring and comparison in
        // checked_isqrt operate across both words, up to the full 64-bit range.
        type U32x2 = FixedUInt<u32, 2>;
        let one = U32x2::from(1u8);

        for k in [0x1_0000u32, 0x1_0001, 0x1234_5678, 0xFFFF_FFFE, 0xFFFF_FFFF] {
            let root = U32x2::from(k);
            // k < 2^32, so k^2 and k^2 + 2k = (k+1)^2 - 1 both fit in 64 bits.
            let square = root * root;

            // Exact square.
            assert_eq!(ConstIsqrt::isqrt(square), root);
            // Just below k^2 floors to k - 1.
            assert_eq!(ConstIsqrt::isqrt(square - one), root - one);
            // (k+1)^2 - 1 is the largest value whose root is still k; for
            // k = 0xFFFF_FFFF this is the type's maximum value.
            assert_eq!(ConstIsqrt::isqrt(square + root + root), root);
        }
    }

    c0nst::c0nst! {
        pub c0nst fn const_isqrt<T: [c0nst] ConstMachineWord + MachineWord, const N: usize>(
            v: FixedUInt<T, N, Nct>,
        ) -> FixedUInt<T, N, Nct> {
            ConstIsqrt::isqrt(v)
        }
    }

    #[test]
    fn test_const_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        assert_eq!(const_isqrt(U16::from(16u8)), U16::from(4u8));
        assert_eq!(const_isqrt(U16::from(100u8)), U16::from(10u8));

        #[cfg(feature = "nightly")]
        {
            const SIXTEEN: U16 = FixedUInt::from_array([16, 0]);
            const RESULT: U16 = const_isqrt(SIXTEEN);
            assert_eq!(RESULT, FixedUInt::from_array([4, 0]));
        }
    }
}