Skip to main content

crypto_bigint/uint/
rand.rs

1//! Random number generator support
2
3use super::Uint;
4use crate::{CtLt, Encoding, Limb, NonZero, Random, RandomBits, RandomBitsError, RandomMod};
5use rand_core::{Rng, TryRng};
6
7impl<const LIMBS: usize> Random for Uint<LIMBS> {
8    fn try_random_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
9        let mut limbs = [Limb::ZERO; LIMBS];
10
11        for limb in &mut limbs {
12            *limb = Limb::try_random_from_rng(rng)?;
13        }
14
15        Ok(limbs.into())
16    }
17}
18
19/// Fill the given limbs slice with random bits.
20///
21/// NOTE: Assumes that the limbs in the given slice are zeroed!
22///
23/// When combined with a platform-independent "4-byte sequential" `rng`, this function is
24/// platform-independent. We consider an RNG "`X`-byte sequential" whenever
25/// `rng.fill_bytes(&mut bytes[..i]); rng.fill_bytes(&mut bytes[i..])` constructs the same `bytes`,
26/// as long as `i` is a multiple of `X`.
27/// Note that the `TryRng` trait does _not_ require this behaviour from `rng`.
28#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
29pub(crate) fn random_bits_core<T, R: TryRng + ?Sized>(
30    rng: &mut R,
31    x: &mut T,
32    n_bits: u32,
33) -> Result<(), R::Error>
34where
35    T: Encoding,
36{
37    if n_bits == 0 {
38        return Ok(());
39    }
40
41    let n_bytes = n_bits.div_ceil(u8::BITS) as usize;
42    let hi_mask = u8::MAX >> ((u8::BITS - (n_bits % u8::BITS)) % u8::BITS);
43
44    let mut buffer = x.to_le_bytes();
45    let slice = buffer.as_mut();
46    rng.try_fill_bytes(&mut slice[..n_bytes])?;
47    slice[n_bytes - 1] &= hi_mask;
48    *x = T::from_le_bytes(buffer);
49
50    Ok(())
51}
52
53impl<const LIMBS: usize> RandomBits for Uint<LIMBS> {
54    fn try_random_bits<R: TryRng + ?Sized>(
55        rng: &mut R,
56        bit_length: u32,
57    ) -> Result<Self, RandomBitsError<R::Error>> {
58        Self::try_random_bits_with_precision(rng, bit_length, Self::BITS)
59    }
60
61    fn try_random_bits_with_precision<R: TryRng + ?Sized>(
62        rng: &mut R,
63        bit_length: u32,
64        bits_precision: u32,
65    ) -> Result<Self, RandomBitsError<R::Error>> {
66        if bits_precision != Self::BITS {
67            return Err(RandomBitsError::BitsPrecisionMismatch {
68                bits_precision,
69                integer_bits: Self::BITS,
70            });
71        }
72        if bit_length > Self::BITS {
73            return Err(RandomBitsError::BitLengthTooLarge {
74                bit_length,
75                bits_precision,
76            });
77        }
78        let mut x = Self::ZERO;
79        random_bits_core(rng, &mut x, bit_length).map_err(RandomBitsError::RandCore)?;
80        Ok(x)
81    }
82}
83
84impl<const LIMBS: usize> RandomMod for Uint<LIMBS> {
85    fn random_mod_vartime<R: Rng + ?Sized>(rng: &mut R, modulus: &NonZero<Self>) -> Self {
86        let mut x = Self::ZERO;
87        let Ok(()) = random_mod_vartime_core(rng, &mut x, modulus, modulus.bits_vartime());
88        x
89    }
90
91    fn try_random_mod_vartime<R: TryRng + ?Sized>(
92        rng: &mut R,
93        modulus: &NonZero<Self>,
94    ) -> Result<Self, R::Error> {
95        let mut x = Self::ZERO;
96        random_mod_vartime_core(rng, &mut x, modulus, modulus.bits_vartime())?;
97        Ok(x)
98    }
99}
100
101/// Generic implementation of `random_mod_vartime` which can be shared with `BoxedUint`.
102// TODO(tarcieri): obtain `n_bits` via a trait like `Integer`
103pub(super) fn random_mod_vartime_core<T, R: TryRng + ?Sized>(
104    rng: &mut R,
105    x: &mut T,
106    modulus: &NonZero<T>,
107    n_bits: u32,
108) -> Result<(), R::Error>
109where
110    T: Encoding + CtLt,
111{
112    loop {
113        random_bits_core(rng, x, n_bits)?;
114        if x.ct_lt(modulus).into() {
115            return Ok(());
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    use crate::uint::rand::random_bits_core;
    use crate::{Limb, NonZero, Random, RandomBits, RandomMod, U256, U1024, Uint};
    use chacha20::ChaCha8Rng;
    use rand_core::{Rng, SeedableRng};

    /// Value the tests below expect `U1024::random_from_rng` to produce from a freshly
    /// constructed `get_four_sequential_rng()`.
    const RANDOM_OUTPUT: U1024 = Uint::from_be_hex(concat![
        "A484C4C693EECC47C3B919AE0D16DF2259CD1A8A9B8EA8E0862878227D4B40A3",
        "C54302F2EB1E2F69E17653A37F1BCC44277FA208E6B31E08CDC4A23A7E88E660",
        "EF781C7DD2D368BAD438539D6A2E923C8CAE14CB947EB0BDE10D666732024679",
        "0F6760A48F9B887CB2FB0D3281E2A6E67746A55FBAD8C037B585F767A79A3B6C"
    ]);

    /// Construct a 4-sequential `rng`, i.e., an `rng` such that
    /// `rng.fill_bytes(&mut buffer[..x]); rng.fill_bytes(&mut buffer[x..])` constructs the
    /// same `buffer` for any `x` in `0..buffer.len()` that is `0 mod 4`.
    fn get_four_sequential_rng() -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(0)
    }

    /// Make sure the random value constructed is consistent across platforms
    #[test]
    fn random_platform_independence() {
        let mut rng = get_four_sequential_rng();
        let sampled = U1024::random_from_rng(&mut rng);
        assert_eq!(sampled, RANDOM_OUTPUT);
    }

    /// Ensure `random_mod_vartime` finishes in a reasonable amount of time and stays in range,
    /// both for a tiny modulus and for one that is larger than a single limb.
    #[test]
    fn random_mod_vartime() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let bounds = [U256::from(42u8), U256::from(0x10000000000000001u128)];
        for bound in bounds {
            let modulus = NonZero::new(bound).unwrap();
            let res = U256::random_mod_vartime(&mut rng, &modulus);

            // Check that the value is in range
            assert!(res < bound);
        }
    }

    #[test]
    fn random_bits() {
        const SAMPLES: usize = 10;

        let mut rng = ChaCha8Rng::seed_from_u64(1);

        // Each sample is expected to have a set bit within its top `lower_bound` bits.
        let lower_bound: u32 = 16;

        // Full length of the integer: there is no upper bound to check.
        let full_length = U256::BITS;
        for _ in 0..SAMPLES {
            let res = U256::random_bits(&mut rng, full_length);
            assert!(res > (U256::ONE << (full_length - lower_bound)));
        }

        let shorter_lengths = [
            // A multiple of limb size
            U256::BITS - Limb::BITS,
            // A multiple of 8
            U256::BITS - Limb::BITS - 8,
            // Not a multiple of 8
            U256::BITS - Limb::BITS - 8 - 3,
        ];
        for bit_length in shorter_lengths {
            for _ in 0..SAMPLES {
                let res = U256::random_bits(&mut rng, bit_length);
                assert!(res > (U256::ONE << (bit_length - lower_bound)));
                assert!(res < (U256::ONE << bit_length));
            }
        }

        // One incomplete limb
        let tiny_length: u32 = 7;
        for _ in 0..SAMPLES {
            let res = U256::random_bits(&mut rng, tiny_length);
            assert!(res < (U256::ONE << tiny_length));
        }

        // Zero bits
        for _ in 0..SAMPLES {
            let res = U256::random_bits(&mut rng, 0);
            assert_eq!(res, U256::ZERO);
        }
    }

    /// Make sure the `random_bits` output is consistent across platforms
    #[test]
    fn random_bits_platform_independence() {
        let mut rng = get_four_sequential_rng();

        let bit_length = 989;
        let mut val = U1024::ZERO;
        random_bits_core(&mut rng, &mut val, bit_length).expect("safe");

        // Only the low `bit_length` bits of the reference output are expected to survive.
        let low_bits_mask = U1024::ONE.shl(bit_length).wrapping_sub(&Uint::ONE);
        assert_eq!(val, RANDOM_OUTPUT.bitand(&low_bits_mask));

        // Test that the RNG is in the same state
        let mut state = [0u8; 16];
        rng.fill_bytes(&mut state);

        let expected_state = [
            198, 196, 132, 164, 240, 211, 223, 12, 36, 189, 139, 48, 94, 1, 123, 253,
        ];
        assert_eq!(state, expected_state);
    }

    /// Make sure `random_mod_vartime` output is consistent across platforms
    #[test]
    fn random_mod_vartime_platform_independence() {
        let mut rng = get_four_sequential_rng();

        // Draw all five samples first, then compare them against the pinned values.
        let modulus = NonZero::new(U256::from_u32(8192)).unwrap();
        let vals: [U256; 5] =
            core::array::from_fn(|_| U256::random_mod_vartime(&mut rng, &modulus));
        let expected = [55, 3378, 2172, 1657, 5323];
        for (want, got) in expected.into_iter().zip(vals) {
            assert_eq!(got, U256::from_u32(want));
        }

        // A modulus close to the maximum value, derived from the next RNG output.
        let offset = U256::from_u64(rng.next_u64());
        let modulus = NonZero::new(U256::ZERO.wrapping_sub(&offset)).unwrap();
        let val = U256::random_mod_vartime(&mut rng, &modulus);
        let expected_val =
            U256::from_be_hex("E17653A37F1BCC44277FA208E6B31E08CDC4A23A7E88E660EF781C7DD2D368BA");
        assert_eq!(val, expected_val);

        // Test that the RNG is in the same state
        let mut state = [0u8; 16];
        rng.fill_bytes(&mut state);

        let expected_state = [
            105, 47, 30, 235, 242, 2, 67, 197, 163, 64, 75, 125, 34, 120, 40, 134,
        ];
        assert_eq!(state, expected_state);
    }

    /// Test that random bytes are sampled consecutively.
    #[test]
    fn random_bits_4_bytes_sequential() {
        // Test for multiples of 4 bytes, i.e., multiples of 32 bits.
        for bit_length in [0, 32, 64, 128, 192, 992] {
            let mut rng = get_four_sequential_rng();

            // Sample the low part and the high part with two separate calls ...
            let mut low = U1024::ZERO;
            random_bits_core(&mut rng, &mut low, bit_length).expect("safe");
            let mut high = U1024::ZERO;
            random_bits_core(&mut rng, &mut high, U1024::BITS - bit_length).expect("safe");

            // ... and check that gluing them together reproduces the single-call output.
            let combined = high.shl(bit_length).bitor(&low);
            assert_eq!(combined, RANDOM_OUTPUT);
        }
    }
}