Skip to main content

fixed_bigint/fixeduint/
isqrt_impl.rs

1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Integer square root for FixedUInt.
16
17use super::{const_set_bit, FixedUInt, MachineWord};
18use crate::const_numtraits::{ConstIsqrt, ConstPrimInt, ConstZero};
19use crate::machineword::ConstMachineWord;
20
21c0nst::c0nst! {
22    impl<T: [c0nst] ConstMachineWord + MachineWord, const N: usize> c0nst ConstIsqrt for FixedUInt<T, N> {
23        fn isqrt(self) -> Self {
24            // For unsigned types, isqrt always succeeds
25            match ConstIsqrt::checked_isqrt(self) {
26                Some(v) => v,
27                None => unreachable!(),
28            }
29        }
30
31        fn checked_isqrt(self) -> Option<Self> {
32            // Bit-by-bit algorithm for integer square root
33            // Returns the largest r such that r * r <= self
34            if self.is_zero() {
35                return Some(Self::zero());
36            }
37
38            // Start with the highest bit position that could be set in the result
39            // For sqrt, the result has at most half the bits of the input
40            let mut result = Self::zero();
41
42            // Find starting bit position: half of the bit length of self
43            let bit_len = Self::BIT_SIZE - ConstPrimInt::leading_zeros(self) as usize;
44            let start_bit = bit_len.div_ceil(2);
45
46            let mut bit_pos = start_bit;
47            while bit_pos > 0 {
48                bit_pos -= 1;
49
50                // Try setting this bit in the result
51                let mut candidate = result;
52                const_set_bit(&mut candidate.array, bit_pos);
53
54                // Check if candidate * candidate <= self
55                // Since candidate has at most half the bits of self,
56                // candidate * candidate won't overflow.
57                let square = candidate * candidate;
58                if square <= self {
59                    result = candidate;
60                }
61            }
62
63            Some(result)
64        }
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{CheckedAdd, CheckedMul};

    // Asserts the floor-square-root contract `r * r <= n < (r + 1) * (r + 1)`
    // for `r = isqrt(n)`.
    //
    // When `r + 1` or `(r + 1)^2` overflows `$ty`, the upper bound holds
    // trivially because `n` itself fits in `$ty`, so that half is skipped.
    macro_rules! assert_floor_sqrt {
        ($ty:ty, $n:expr) => {{
            let n: $ty = $n;
            let r = ConstIsqrt::isqrt(n);
            assert!(r * r <= n, "Failed: {}^2 > {}", r, n);
            if let Some(r_plus_1) = r.checked_add(&<$ty>::from(1u8)) {
                if let Some(square) = r_plus_1.checked_mul(&r_plus_1) {
                    assert!(square > n, "Failed: {}^2 <= {}", r_plus_1, n);
                }
            }
        }};
    }

    #[test]
    fn test_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        // Perfect squares
        assert_eq!(ConstIsqrt::isqrt(U16::from(0u8)), U16::from(0u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(1u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(4u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(9u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(16u8)), U16::from(4u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(25u8)), U16::from(5u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(100u8)), U16::from(10u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(144u8)), U16::from(12u8));

        // Non-perfect squares (floor)
        assert_eq!(ConstIsqrt::isqrt(U16::from(2u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(3u8)), U16::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(5u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(8u8)), U16::from(2u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(10u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(15u8)), U16::from(3u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(24u8)), U16::from(4u8));
    }

    #[test]
    fn test_isqrt_larger_values() {
        type U16 = FixedUInt<u8, 2>;

        // Larger values
        assert_eq!(ConstIsqrt::isqrt(U16::from(10000u16)), U16::from(100u8));
        assert_eq!(ConstIsqrt::isqrt(U16::from(65535u16)), U16::from(255u8)); // sqrt(65535) = 255.998...
        assert_eq!(ConstIsqrt::isqrt(U16::from(65025u16)), U16::from(255u8)); // 255^2 = 65025
        assert_eq!(ConstIsqrt::isqrt(U16::from(65024u16)), U16::from(254u8)); // one below 255^2
    }

    #[test]
    fn test_checked_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        // For unsigned, checked_isqrt always returns Some
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(0u8)),
            Some(U16::from(0u8))
        );
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(16u8)),
            Some(U16::from(4u8))
        );
        assert_eq!(
            ConstIsqrt::checked_isqrt(U16::from(17u8)),
            Some(U16::from(4u8))
        );
    }

    #[test]
    fn test_isqrt_correctness() {
        type U16 = FixedUInt<u8, 2>;

        // A 16-bit type is small enough to verify every representable value,
        // including the full-width inputs near u16::MAX.
        for n in 0..=u16::MAX {
            assert_floor_sqrt!(U16, U16::from(n));
        }
    }

    #[test]
    fn test_isqrt_wider_types() {
        // Wider word type. The root of a 64-bit value has at most 32 bits, so for
        // U32x2 every result bit lands in word 0, while inputs and intermediate
        // squares span both words.
        type U32x2 = FixedUInt<u32, 2>;

        // Perfect squares
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(0u8)), U32x2::from(0u8));
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(1u8)), U32x2::from(1u8));
        assert_eq!(ConstIsqrt::isqrt(U32x2::from(16u8)), U32x2::from(4u8));

        // Larger values that span multiple bits
        assert_eq!(
            ConstIsqrt::isqrt(U32x2::from(1000000u32)),
            U32x2::from(1000u32)
        );
        assert_eq!(
            ConstIsqrt::isqrt(U32x2::from(0xFFFFFFFFu32)),
            U32x2::from(0xFFFFu32)
        );

        // Inputs that live in the upper word: 2^32 = (2^16)^2.
        let two_pow_16 = U32x2::from(0x1_0000u32);
        let two_pow_32 = two_pow_16 * two_pow_16;
        assert_eq!(ConstIsqrt::isqrt(two_pow_32), two_pow_16);

        // Full-width inputs, where candidate squares approach 2^64:
        // (2^32 - 1)^2 and 2^64 - 1 = (2^32 - 1)^2 + 2 * (2^32 - 1).
        let max_root = U32x2::from(0xFFFF_FFFFu32);
        let max_square = max_root * max_root;
        assert_eq!(ConstIsqrt::isqrt(max_square), max_root);
        let all_ones = max_square + max_root + max_root;
        assert_eq!(ConstIsqrt::isqrt(all_ones), max_root);

        // u8x4: roots of inputs >= 2^16 need bits in word 1, which exercises
        // cross-word bit-setting.
        type U8x4 = FixedUInt<u8, 4>;
        assert_eq!(ConstIsqrt::isqrt(U8x4::from(65536u32)), U8x4::from(256u32));
        assert_eq!(
            ConstIsqrt::isqrt(U8x4::from(1000000u32)),
            U8x4::from(1000u32)
        );
        assert_eq!(
            ConstIsqrt::isqrt(U8x4::from(0xFFFE_0001u32)),
            U8x4::from(0xFFFFu32)
        ); // 0xFFFF^2
        assert_eq!(
            ConstIsqrt::isqrt(U8x4::from(u32::MAX)),
            U8x4::from(0xFFFFu32)
        );

        // Verify the floor-sqrt invariant over a sparse low range...
        for n in (0..=10000u32).step_by(100) {
            assert_floor_sqrt!(U32x2, U32x2::from(n));
        }
        // ...just above the 2^32 word boundary...
        for k in 0..100u32 {
            assert_floor_sqrt!(U32x2, two_pow_32 + U32x2::from(k));
        }
        // ...and just below 2^64, where `(r + 1)^2` overflows.
        for k in 0..100u32 {
            assert_floor_sqrt!(U32x2, max_square + U32x2::from(k));
        }
    }

    c0nst::c0nst! {
        pub c0nst fn const_isqrt<T: [c0nst] ConstMachineWord + MachineWord, const N: usize>(
            v: FixedUInt<T, N>,
        ) -> FixedUInt<T, N> {
            ConstIsqrt::isqrt(v)
        }
    }

    #[test]
    fn test_const_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        assert_eq!(const_isqrt(U16::from(16u8)), U16::from(4u8));
        assert_eq!(const_isqrt(U16::from(100u8)), U16::from(10u8));

        #[cfg(feature = "nightly")]
        {
            const SIXTEEN: U16 = FixedUInt { array: [16, 0] };
            const RESULT: U16 = const_isqrt(SIXTEEN);
            assert_eq!(RESULT, FixedUInt { array: [4, 0] });
        }
    }
}