Skip to main content

crypto_bigint/int/
shr.rs

//! [`Int`] bitwise right shift operations.
2
3use crate::{Choice, CtOption, Int, ShrVartime, Uint, WrappingShr, primitives::u32_rem};
4use core::ops::{Shr, ShrAssign};
5
6impl<const LIMBS: usize> Int<LIMBS> {
7    /// Computes `self >> shift`.
8    ///
9    /// Note, this is _signed_ shift right, i.e., the value shifted in on the left is equal to
10    /// the most significant bit.
11    ///
12    /// # Panics
13    /// - if `shift >= Self::BITS`.
14    #[inline(always)]
15    #[must_use]
16    pub const fn shr(&self, shift: u32) -> Self {
17        let sign_bits = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
18        let res = Uint::shr(&self.0, shift);
19        Self::from_bits(res.bitor(&sign_bits.0.unbounded_shl(Self::BITS - shift)))
20    }
21
22    /// Computes `self >> shift` in variable time.
23    ///
24    /// Note, this is _signed_ shift right, i.e., the value shifted in on the left is equal to
25    /// the most significant bit.
26    ///
27    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
28    ///
29    /// When used with a fixed `shift`, this function is constant-time with respect to `self`.
30    ///
31    /// # Panics
32    /// - if `shift >= Self::BITS`.
33    #[inline(always)]
34    #[must_use]
35    #[track_caller]
36    pub const fn shr_vartime(&self, shift: u32) -> Self {
37        self.overflowing_shr_vartime(shift)
38            .expect("`shift` within the bit size of the integer")
39    }
40
41    /// Computes `self >> shift`.
42    ///
43    /// Note, this is _signed_ shift right, i.e., the value shifted in on the left is equal to
44    /// the most significant bit.
45    ///
46    /// Returns `None` if `shift >= Self::BITS`.
47    #[inline(always)]
48    #[must_use]
49    #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
50    pub const fn overflowing_shr(&self, shift: u32) -> CtOption<Self> {
51        let in_range = Choice::from_u32_lt(shift, Self::BITS);
52        let adj_shift = in_range.select_u32(0, shift);
53        CtOption::new(self.shr(adj_shift), in_range)
54    }
55
56    /// Computes `self >> shift`.
57    ///
58    /// NOTE: this is _signed_ shift right, i.e., the value shifted in on the left is equal to
59    /// the most significant bit.
60    ///
61    /// Returns `None` if `shift >= Self::BITS`.
62    ///
63    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
64    ///
65    /// When used with a fixed `shift`, this function is constant-time with respect to `self`.
66    #[inline(always)]
67    #[must_use]
68    pub const fn overflowing_shr_vartime(&self, shift: u32) -> Option<Self> {
69        if shift < Self::BITS {
70            Some(self.unbounded_shr_vartime(shift))
71        } else {
72            None
73        }
74    }
75
76    /// Computes `self >> shift` in a panic-free manner.
77    ///
78    /// If the shift exceeds the precision, returns
79    /// - `0` when `self` is non-negative, and
80    /// - `-1` when `self` is negative.
81    #[inline(always)]
82    #[must_use]
83    pub const fn unbounded_shr(&self, shift: u32) -> Self {
84        let default = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
85        ctutils::unwrap_or!(self.overflowing_shr(shift), default, Self::select)
86    }
87
88    /// Computes `self >> shift` in variable-time in a panic-free manner.
89    ///
90    /// If the shift exceeds the precision, returns
91    /// - `0` when `self` is non-negative, and
92    /// - `-1` when `self` is negative.
93    ///
94    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
95    ///
96    /// When used with a fixed `shift`, this function is constant-time with respect to `self`.
97    #[inline(always)]
98    #[must_use]
99    pub const fn unbounded_shr_vartime(&self, shift: u32) -> Self {
100        let sign_bits = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
101        if let Some(res) = self.0.overflowing_shr_vartime(shift) {
102            Self::from_bits(res.bitor(&sign_bits.0.unbounded_shl(Self::BITS - shift)))
103        } else {
104            sign_bits
105        }
106    }
107
108    /// Computes `self >> shift` in a panic-free manner.
109    ///
110    /// If the shift exceeds the precision, returns
111    /// - `0` when `self` is non-negative, and
112    /// - `-1` when `self` is negative.
113    #[inline]
114    #[must_use]
115    pub const fn wrapping_shr(&self, shift: u32) -> Self {
116        self.shr(u32_rem(shift, Self::BITS))
117    }
118
119    /// Computes `self >> shift` in variable-time in a panic-free manner.
120    ///
121    /// If the shift exceeds the precision, returns
122    /// - `0` when `self` is non-negative, and
123    /// - `-1` when `self` is negative.
124    ///
125    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
126    ///
127    /// When used with a fixed `shift`, this function is constant-time with respect to `self`.
128    #[inline]
129    #[must_use]
130    #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
131    pub const fn wrapping_shr_vartime(&self, shift: u32) -> Self {
132        self.unbounded_shr_vartime(shift % Self::BITS)
133    }
134}
135
136macro_rules! impl_shr {
137    ($($shift:ty),+) => {
138        $(
139            impl<const LIMBS: usize> Shr<$shift> for Int<LIMBS> {
140                type Output = Int<LIMBS>;
141
142                #[inline]
143                fn shr(self, shift: $shift) -> Int<LIMBS> {
144                    <&Self>::shr(&self, shift)
145                }
146            }
147
148            impl<const LIMBS: usize> Shr<$shift> for &Int<LIMBS> {
149                type Output = Int<LIMBS>;
150
151                #[inline]
152                fn shr(self, shift: $shift) -> Int<LIMBS> {
153                    Int::<LIMBS>::shr(self, u32::try_from(shift).expect("invalid shift"))
154                }
155            }
156
157            impl<const LIMBS: usize> ShrAssign<$shift> for Int<LIMBS> {
158                fn shr_assign(&mut self, shift: $shift) {
159                    *self = self.shr(shift)
160                }
161            }
162        )+
163    };
164}
165
166impl_shr!(i32, u32, usize);
167
168impl<const LIMBS: usize> WrappingShr for Int<LIMBS> {
169    fn wrapping_shr(&self, shift: u32) -> Int<LIMBS> {
170        self.wrapping_shr(shift)
171    }
172}
173
174impl<const LIMBS: usize> ShrVartime for Int<LIMBS> {
175    fn overflowing_shr_vartime(&self, shift: u32) -> Option<Self> {
176        self.overflowing_shr_vartime(shift)
177    }
178
179    fn unbounded_shr_vartime(&self, shift: u32) -> Self {
180        self.unbounded_shr_vartime(shift)
181    }
182
183    fn wrapping_shr_vartime(&self, shift: u32) -> Self {
184        self.wrapping_shr_vartime(shift)
185    }
186}
187
#[cfg(test)]
mod tests {
    use core::ops::Div;

    use crate::{I256, ShrVartime};

    // The secp256k1 group order bit pattern; its top bit is set, so it is negative as an `I256`.
    const N: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");

    // `N >> 1` with sign extension (the top bit stays set).
    const N_2: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501DDFE92F46681B20A0");

    #[test]
    fn shr0() {
        assert_eq!(I256::MAX >> 0, I256::MAX);
        assert_eq!(I256::MIN >> 0, I256::MIN);
    }

    #[test]
    fn shr1() {
        assert_eq!(N >> 1, N_2);
        assert_eq!(ShrVartime::overflowing_shr_vartime(&N, 1), Some(N_2));
        assert_eq!(ShrVartime::wrapping_shr_vartime(&N, 1), N_2);
    }

    #[test]
    fn shr5() {
        assert_eq!(
            I256::MAX >> 5,
            I256::MAX.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN >> 5,
            I256::MIN.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr7_vartime() {
        assert_eq!(
            I256::MAX.shr_vartime(7),
            I256::MAX.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN.shr_vartime(7),
            I256::MIN.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr_max_in_range_shift_sign_extends() {
        // Shifting by `BITS - 1` leaves nothing but copies of the sign bit.
        assert_eq!(I256::MAX >> 255, I256::ZERO);
        assert_eq!(I256::MIN >> 255, I256::MINUS_ONE);
        assert_eq!(I256::MINUS_ONE >> 255, I256::MINUS_ONE);
        assert_eq!(I256::MAX.shr_vartime(255), I256::ZERO);
        assert_eq!(I256::MIN.shr_vartime(255), I256::MINUS_ONE);
    }

    #[test]
    fn shr_assign() {
        let mut x = N;
        x >>= 1u32;
        assert_eq!(x, N_2);

        let mut y = I256::MIN;
        y >>= 5usize;
        assert_eq!(y, I256::MIN >> 5);
    }

    #[test]
    fn shr256_const() {
        assert!(N.overflowing_shr(256).is_none().to_bool_vartime());
        assert!(ShrVartime::overflowing_shr_vartime(&N, 256).is_none());
    }

    #[test]
    #[should_panic(expected = "`shift` exceeds upper bound")]
    fn shr_bounds_panic() {
        let _ = N >> 256;
    }

    #[test]
    #[should_panic(expected = "invalid shift")]
    fn shr_negative_shift_panics() {
        // Negative shift amounts cannot be converted to `u32`.
        let _ = N >> -1i32;
    }

    #[test]
    fn unbounded_shr_vartime_zero_shift() {
        assert_eq!(I256::MAX.unbounded_shr_vartime(0), I256::MAX);
        assert_eq!(I256::MIN.unbounded_shr_vartime(0), I256::MIN);
        assert_eq!(I256::ONE.unbounded_shr_vartime(0), I256::ONE);
        assert_eq!(I256::MINUS_ONE.unbounded_shr_vartime(0), I256::MINUS_ONE);
        assert_eq!(I256::ZERO.unbounded_shr_vartime(0), I256::ZERO);
    }

    #[test]
    fn overflowing_shr_vartime_zero_shift() {
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO];
        for &val in &values {
            assert_eq!(val.overflowing_shr_vartime(0), Some(val));
        }
    }

    #[test]
    fn shr_vartime_zero_shift() {
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO];
        for &val in &values {
            assert_eq!(val.shr_vartime(0), val);
        }
    }

    #[test]
    fn wrapping_shr_vartime_multiple_of_bits_is_identity() {
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO];
        for &val in &values {
            // Shift by 0 and multiples of the bit size should be identity.
            for i in 0..4 {
                assert_eq!(val.wrapping_shr_vartime(i * I256::BITS), val);
            }
        }
    }

    #[test]
    fn wrapping_shr_multiple_of_bits_is_identity() {
        // Constant-time counterpart of the test above.
        let values = [I256::MAX, I256::MIN, I256::ONE, I256::MINUS_ONE, I256::ZERO, N];
        for &val in &values {
            for i in 0..4 {
                assert_eq!(val.wrapping_shr(i * I256::BITS), val);
            }
        }
    }

    #[test]
    fn unbounded_shr_in_range_matches_shr() {
        assert_eq!(N.unbounded_shr(1), N_2);
        assert_eq!(N.unbounded_shr_vartime(1), N_2);
    }

    #[test]
    fn unbounded_shr() {
        assert_eq!(I256::MAX.unbounded_shr(257), I256::ZERO);
        assert_eq!(I256::MIN.unbounded_shr(257), I256::MINUS_ONE);
        assert_eq!(
            ShrVartime::unbounded_shr_vartime(&I256::MAX, 257),
            I256::ZERO
        );
        assert_eq!(
            ShrVartime::unbounded_shr_vartime(&I256::MIN, 257),
            I256::MINUS_ONE
        );
    }

    #[test]
    fn wrapping_shr() {
        assert_eq!(I256::MAX.wrapping_shr(257), I256::MAX.shr(1));
        assert_eq!(I256::MIN.wrapping_shr(257), I256::MIN.shr(1));
        assert_eq!(
            ShrVartime::wrapping_shr_vartime(&I256::MAX, 257),
            I256::MAX.shr(1)
        );
        assert_eq!(
            ShrVartime::wrapping_shr_vartime(&I256::MIN, 257),
            I256::MIN.shr(1)
        );
    }
}