1use crate::{Choice, CtOption, Int, ShrVartime, Uint, WrappingShr, primitives::u32_rem};
4use core::ops::{Shr, ShrAssign};
5
6impl<const LIMBS: usize> Int<LIMBS> {
7 #[inline(always)]
15 #[must_use]
16 pub const fn shr(&self, shift: u32) -> Self {
17 let sign_bits = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
18 let res = Uint::shr(&self.0, shift);
19 Self::from_bits(res.bitor(&sign_bits.0.unbounded_shl(Self::BITS - shift)))
20 }
21
22 #[inline(always)]
34 #[must_use]
35 #[track_caller]
36 pub const fn shr_vartime(&self, shift: u32) -> Self {
37 self.overflowing_shr_vartime(shift)
38 .expect("`shift` within the bit size of the integer")
39 }
40
41 #[inline(always)]
48 #[must_use]
49 #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
50 pub const fn overflowing_shr(&self, shift: u32) -> CtOption<Self> {
51 let in_range = Choice::from_u32_lt(shift, Self::BITS);
52 let adj_shift = in_range.select_u32(0, shift);
53 CtOption::new(self.shr(adj_shift), in_range)
54 }
55
56 #[inline(always)]
67 #[must_use]
68 pub const fn overflowing_shr_vartime(&self, shift: u32) -> Option<Self> {
69 if shift < Self::BITS {
70 Some(self.unbounded_shr_vartime(shift))
71 } else {
72 None
73 }
74 }
75
76 #[inline(always)]
82 #[must_use]
83 pub const fn unbounded_shr(&self, shift: u32) -> Self {
84 let default = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
85 ctutils::unwrap_or!(self.overflowing_shr(shift), default, Self::select)
86 }
87
88 #[inline(always)]
98 #[must_use]
99 pub const fn unbounded_shr_vartime(&self, shift: u32) -> Self {
100 let sign_bits = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
101 if let Some(res) = self.0.overflowing_shr_vartime(shift) {
102 Self::from_bits(res.bitor(&sign_bits.0.unbounded_shl(Self::BITS - shift)))
103 } else {
104 sign_bits
105 }
106 }
107
108 #[inline]
114 #[must_use]
115 pub const fn wrapping_shr(&self, shift: u32) -> Self {
116 self.shr(u32_rem(shift, Self::BITS))
117 }
118
119 #[inline]
129 #[must_use]
130 #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
131 pub const fn wrapping_shr_vartime(&self, shift: u32) -> Self {
132 self.unbounded_shr_vartime(shift % Self::BITS)
133 }
134}
135
136macro_rules! impl_shr {
137 ($($shift:ty),+) => {
138 $(
139 impl<const LIMBS: usize> Shr<$shift> for Int<LIMBS> {
140 type Output = Int<LIMBS>;
141
142 #[inline]
143 fn shr(self, shift: $shift) -> Int<LIMBS> {
144 <&Self>::shr(&self, shift)
145 }
146 }
147
148 impl<const LIMBS: usize> Shr<$shift> for &Int<LIMBS> {
149 type Output = Int<LIMBS>;
150
151 #[inline]
152 fn shr(self, shift: $shift) -> Int<LIMBS> {
153 Int::<LIMBS>::shr(self, u32::try_from(shift).expect("invalid shift"))
154 }
155 }
156
157 impl<const LIMBS: usize> ShrAssign<$shift> for Int<LIMBS> {
158 fn shr_assign(&mut self, shift: $shift) {
159 *self = self.shr(shift)
160 }
161 }
162 )+
163 };
164}
165
166impl_shr!(i32, u32, usize);
167
168impl<const LIMBS: usize> WrappingShr for Int<LIMBS> {
169 fn wrapping_shr(&self, shift: u32) -> Int<LIMBS> {
170 self.wrapping_shr(shift)
171 }
172}
173
174impl<const LIMBS: usize> ShrVartime for Int<LIMBS> {
175 fn overflowing_shr_vartime(&self, shift: u32) -> Option<Self> {
176 self.overflowing_shr_vartime(shift)
177 }
178
179 fn unbounded_shr_vartime(&self, shift: u32) -> Self {
180 self.unbounded_shr_vartime(shift)
181 }
182
183 fn wrapping_shr_vartime(&self, shift: u32) -> Self {
184 self.wrapping_shr_vartime(shift)
185 }
186}
187
#[cfg(test)]
mod tests {
    use core::ops::Div;

    use crate::{I256, ShrVartime};

    const N: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");

    const N_2: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501DDFE92F46681B20A0");

    #[test]
    fn shr0() {
        assert_eq!(I256::MAX >> 0, I256::MAX);
        assert_eq!(I256::MIN >> 0, I256::MIN);
    }

    #[test]
    fn shr1() {
        assert_eq!(N >> 1, N_2);
        assert_eq!(ShrVartime::overflowing_shr_vartime(&N, 1), Some(N_2));
        assert_eq!(ShrVartime::wrapping_shr_vartime(&N, 1), N_2);
    }

    #[test]
    fn shr5() {
        assert_eq!(
            I256::MAX >> 5,
            I256::MAX.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN >> 5,
            I256::MIN.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr7_vartime() {
        assert_eq!(
            I256::MAX.shr_vartime(7),
            I256::MAX.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN.shr_vartime(7),
            I256::MIN.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr256_const() {
        assert!(N.overflowing_shr(256).is_none().to_bool_vartime());
        assert!(ShrVartime::overflowing_shr_vartime(&N, 256).is_none());
    }

    #[test]
    #[should_panic(expected = "`shift` exceeds upper bound")]
    fn shr_bounds_panic() {
        let _ = N >> 256;
    }

    #[test]
    fn shr_negative_rounds_toward_negative_infinity() {
        // An arithmetic shift floors inexact results: -3 >> 1 == -2 (not -1).
        assert_eq!(I256::from(-3) >> 1, I256::from(-2));
        assert_eq!(I256::from(-3).shr_vartime(1), I256::from(-2));
        // -1 is a fixed point of every in-range arithmetic right shift.
        assert_eq!(I256::MINUS_ONE >> 1, I256::MINUS_ONE);
        assert_eq!(I256::MINUS_ONE.shr_vartime(200), I256::MINUS_ONE);
    }

    #[test]
    fn shr_largest_in_range_shift() {
        let last = I256::BITS - 1;
        assert_eq!(I256::MAX >> last, I256::ZERO);
        assert_eq!(I256::MIN >> last, I256::MINUS_ONE);
        assert_eq!(I256::MIN.shr_vartime(last), I256::MINUS_ONE);
        assert!(!N.overflowing_shr(last).is_none().to_bool_vartime());
        assert_eq!(
            ShrVartime::overflowing_shr_vartime(&N, last),
            Some(I256::MINUS_ONE)
        );
    }

    #[test]
    fn unbounded_shr_vartime_zero_shift() {
        assert_eq!(I256::MAX.unbounded_shr_vartime(0), I256::MAX);
        assert_eq!(I256::MIN.unbounded_shr_vartime(0), I256::MIN);
        assert_eq!(I256::ONE.unbounded_shr_vartime(0), I256::ONE);
        assert_eq!(I256::MINUS_ONE.unbounded_shr_vartime(0), I256::MINUS_ONE);
        assert_eq!(I256::ZERO.unbounded_shr_vartime(0), I256::ZERO);
    }

    #[test]
    fn overflowing_shr_vartime_zero_shift() {
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO];
        for &val in &values {
            assert_eq!(val.overflowing_shr_vartime(0), Some(val));
        }
    }

    #[test]
    fn shr_vartime_zero_shift() {
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO];
        for &val in &values {
            assert_eq!(val.shr_vartime(0), val);
        }
    }

    #[test]
    fn wrapping_shr_vartime_multiple_of_bits_is_identity() {
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO];
        for &val in &values {
            for i in 0..4 {
                // Any multiple of the bit width reduces to a shift of zero.
                assert_eq!(val.wrapping_shr_vartime(i * I256::BITS), val);
            }
        }
    }

    #[test]
    fn unbounded_shr() {
        assert_eq!(I256::MAX.unbounded_shr(257), I256::ZERO);
        assert_eq!(I256::MIN.unbounded_shr(257), I256::MINUS_ONE);
        assert_eq!(
            ShrVartime::unbounded_shr_vartime(&I256::MAX, 257),
            I256::ZERO
        );
        assert_eq!(
            ShrVartime::unbounded_shr_vartime(&I256::MIN, 257),
            I256::MINUS_ONE
        );
    }

    #[test]
    fn unbounded_shr_exact_bit_width() {
        assert_eq!(I256::MAX.unbounded_shr(I256::BITS), I256::ZERO);
        assert_eq!(I256::MIN.unbounded_shr(I256::BITS), I256::MINUS_ONE);
        assert_eq!(I256::MAX.unbounded_shr_vartime(I256::BITS), I256::ZERO);
        assert_eq!(I256::MIN.unbounded_shr_vartime(I256::BITS), I256::MINUS_ONE);
    }

    #[test]
    fn wrapping_shr() {
        assert_eq!(I256::MAX.wrapping_shr(257), I256::MAX.shr(1));
        assert_eq!(I256::MIN.wrapping_shr(257), I256::MIN.shr(1));
        assert_eq!(
            ShrVartime::wrapping_shr_vartime(&I256::MAX, 257),
            I256::MAX.shr(1)
        );
        assert_eq!(
            ShrVartime::wrapping_shr_vartime(&I256::MIN, 257),
            I256::MIN.shr(1)
        );
    }
}