fixed_bigint/fixeduint/
isqrt_impl.rs1use super::{const_set_bit, FixedUInt, MachineWord};
18use crate::const_numtraits::{ConstBitPrimInt, ConstIsqrt, ConstZero};
19use crate::machineword::ConstMachineWord;
20use crate::personality::Nct;
21
22c0nst::c0nst! {
23 impl<T: [c0nst] ConstMachineWord + MachineWord, const N: usize> c0nst ConstIsqrt for FixedUInt<T, N, Nct> {
24 fn isqrt(self) -> Self {
25 match ConstIsqrt::checked_isqrt(self) {
27 Some(v) => v,
28 None => unreachable!(),
29 }
30 }
31
32 fn checked_isqrt(self) -> Option<Self> {
33 if self.is_zero() {
36 return Some(Self::zero());
37 }
38
39 let mut result = Self::zero();
42
43 let bit_len = Self::BIT_SIZE - ConstBitPrimInt::leading_zeros(self) as usize;
45 let start_bit = bit_len.div_ceil(2);
46
47 let mut bit_pos = start_bit;
48 while bit_pos > 0 {
49 bit_pos -= 1;
50
51 let mut candidate = result;
53 const_set_bit(&mut candidate.array, bit_pos);
54
55 let square = candidate * candidate;
59 if square <= self {
60 result = candidate;
61 }
62 }
63
64 Some(result)
65 }
66 }
67}
68
// Unit tests for the `ConstIsqrt` implementation on `FixedUInt`.
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{CheckedAdd, CheckedMul};

    // Small inputs: exact roots of perfect squares, then rounded-down roots.
    #[test]
    fn test_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        // (input, expected root) where the input is a perfect square.
        const PERFECT_SQUARES: [(u8, u8); 8] = [
            (0, 0),
            (1, 1),
            (4, 2),
            (9, 3),
            (16, 4),
            (25, 5),
            (100, 10),
            (144, 12),
        ];
        // (input, expected root) where the exact root is not an integer.
        const ROUNDED_DOWN: [(u8, u8); 7] = [
            (2, 1),
            (3, 1),
            (5, 2),
            (8, 2),
            (10, 3),
            (15, 3),
            (24, 4),
        ];

        for &(input, expected) in PERFECT_SQUARES.iter().chain(ROUNDED_DOWN.iter()) {
            assert_eq!(ConstIsqrt::isqrt(U16::from(input)), U16::from(expected));
        }
    }

    // Inputs near the top of the 16-bit range.
    #[test]
    fn test_isqrt_larger_values() {
        type U16 = FixedUInt<u8, 2>;

        // (input, expected root); 65535 is the largest 16-bit value and
        // 65025 is 255 squared.
        let cases: [(u16, u8); 3] = [(10000, 100), (65535, 255), (65025, 255)];

        for &(input, expected) in &cases {
            assert_eq!(ConstIsqrt::isqrt(U16::from(input)), U16::from(expected));
        }
    }

    // `checked_isqrt` wraps the same results in `Some`.
    #[test]
    fn test_checked_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        let cases: [(u8, u8); 3] = [(0, 0), (16, 4), (17, 4)];

        for &(input, expected) in &cases {
            let root = ConstIsqrt::checked_isqrt(U16::from(input));
            assert_eq!(root, Some(U16::from(expected)));
        }
    }

    // Checks the defining property r^2 <= n < (r + 1)^2 for n in 0..=1000.
    #[test]
    fn test_isqrt_correctness() {
        type U16 = FixedUInt<u8, 2>;

        for n in 0..=1000u16 {
            let value = U16::from(n);
            let root = ConstIsqrt::isqrt(value);

            // Lower bound: the root squared must not exceed the input.
            assert!(root * root <= value, "Failed: {}^2 > {}", root, n);

            // Upper bound: only checked when neither `root + 1` nor its
            // square overflows the type.
            let Some(next) = root.checked_add(&U16::from(1u8)) else {
                continue;
            };
            let Some(next_squared) = next.checked_mul(&next) else {
                continue;
            };
            assert!(next_squared > value, "Failed: {}^2 <= {}", next, n);
        }
    }

    // Same checks with other word sizes and word counts.
    #[test]
    fn test_isqrt_wider_types() {
        type U32x2 = FixedUInt<u32, 2>;

        let small_cases: [(u8, u8); 3] = [(0, 0), (1, 1), (16, 4)];
        for &(input, expected) in &small_cases {
            assert_eq!(
                ConstIsqrt::isqrt(U32x2::from(input)),
                U32x2::from(expected)
            );
        }

        let large_cases: [(u32, u32); 2] = [(1000000, 1000), (0xFFFFFFFF, 0xFFFF)];
        for &(input, expected) in &large_cases {
            assert_eq!(
                ConstIsqrt::isqrt(U32x2::from(input)),
                U32x2::from(expected)
            );
        }

        type U8x4 = FixedUInt<u8, 4>;

        let byte_word_cases: [(u32, u32); 2] = [(65536, 256), (1000000, 1000)];
        for &(input, expected) in &byte_word_cases {
            assert_eq!(ConstIsqrt::isqrt(U8x4::from(input)), U8x4::from(expected));
        }

        // Defining property r^2 <= n < (r + 1)^2, sampled every 100 values.
        for n in (0..=10000u32).step_by(100) {
            let value = U32x2::from(n);
            let root = ConstIsqrt::isqrt(value);

            assert!(root * root <= value, "Failed: {}^2 > {} for U32x2", root, n);

            let Some(next) = root.checked_add(&U32x2::from(1u8)) else {
                continue;
            };
            let Some(next_squared) = next.checked_mul(&next) else {
                continue;
            };
            assert!(
                next_squared > value,
                "Failed: {}^2 <= {} for U32x2",
                next,
                n
            );
        }
    }

    c0nst::c0nst! {
        // Generic wrapper that forwards to `ConstIsqrt::isqrt`; used below to
        // call the implementation through a `c0nst fn`.
        pub c0nst fn const_isqrt<T: [c0nst] ConstMachineWord + MachineWord, const N: usize>(
            v: FixedUInt<T, N, Nct>,
        ) -> FixedUInt<T, N, Nct> {
            ConstIsqrt::isqrt(v)
        }
    }

    // Exercises `const_isqrt` at runtime and, with the `nightly` feature, in a
    // `const` initializer.
    #[test]
    fn test_const_isqrt() {
        type U16 = FixedUInt<u8, 2>;

        let sixteen = U16::from(16u8);
        assert_eq!(const_isqrt(sixteen), U16::from(4u8));

        let hundred = U16::from(100u8);
        assert_eq!(const_isqrt(hundred), U16::from(10u8));

        #[cfg(feature = "nightly")]
        {
            // Evaluated in a const initializer rather than at runtime.
            const SIXTEEN: U16 = FixedUInt::from_array([16, 0]);
            const RESULT: U16 = const_isqrt(SIXTEEN);
            assert_eq!(RESULT, FixedUInt::from_array([4, 0]));
        }
    }
}