1use super::Uint;
4use crate::{CtLt, Encoding, Limb, NonZero, Random, RandomBits, RandomBitsError, RandomMod};
5use rand_core::{Rng, TryRng};
6
7impl<const LIMBS: usize> Random for Uint<LIMBS> {
8 fn try_random_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
9 let mut limbs = [Limb::ZERO; LIMBS];
10
11 for limb in &mut limbs {
12 *limb = Limb::try_random_from_rng(rng)?;
13 }
14
15 Ok(limbs.into())
16 }
17}
18
19#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
29pub(crate) fn random_bits_core<T, R: TryRng + ?Sized>(
30 rng: &mut R,
31 x: &mut T,
32 n_bits: u32,
33) -> Result<(), R::Error>
34where
35 T: Encoding,
36{
37 if n_bits == 0 {
38 return Ok(());
39 }
40
41 let n_bytes = n_bits.div_ceil(u8::BITS) as usize;
42 let hi_mask = u8::MAX >> ((u8::BITS - (n_bits % u8::BITS)) % u8::BITS);
43
44 let mut buffer = x.to_le_bytes();
45 let slice = buffer.as_mut();
46 rng.try_fill_bytes(&mut slice[..n_bytes])?;
47 slice[n_bytes - 1] &= hi_mask;
48 *x = T::from_le_bytes(buffer);
49
50 Ok(())
51}
52
53impl<const LIMBS: usize> RandomBits for Uint<LIMBS> {
54 fn try_random_bits<R: TryRng + ?Sized>(
55 rng: &mut R,
56 bit_length: u32,
57 ) -> Result<Self, RandomBitsError<R::Error>> {
58 Self::try_random_bits_with_precision(rng, bit_length, Self::BITS)
59 }
60
61 fn try_random_bits_with_precision<R: TryRng + ?Sized>(
62 rng: &mut R,
63 bit_length: u32,
64 bits_precision: u32,
65 ) -> Result<Self, RandomBitsError<R::Error>> {
66 if bits_precision != Self::BITS {
67 return Err(RandomBitsError::BitsPrecisionMismatch {
68 bits_precision,
69 integer_bits: Self::BITS,
70 });
71 }
72 if bit_length > Self::BITS {
73 return Err(RandomBitsError::BitLengthTooLarge {
74 bit_length,
75 bits_precision,
76 });
77 }
78 let mut x = Self::ZERO;
79 random_bits_core(rng, &mut x, bit_length).map_err(RandomBitsError::RandCore)?;
80 Ok(x)
81 }
82}
83
84impl<const LIMBS: usize> RandomMod for Uint<LIMBS> {
85 fn random_mod_vartime<R: Rng + ?Sized>(rng: &mut R, modulus: &NonZero<Self>) -> Self {
86 let mut x = Self::ZERO;
87 let Ok(()) = random_mod_vartime_core(rng, &mut x, modulus, modulus.bits_vartime());
88 x
89 }
90
91 fn try_random_mod_vartime<R: TryRng + ?Sized>(
92 rng: &mut R,
93 modulus: &NonZero<Self>,
94 ) -> Result<Self, R::Error> {
95 let mut x = Self::ZERO;
96 random_mod_vartime_core(rng, &mut x, modulus, modulus.bits_vartime())?;
97 Ok(x)
98 }
99}
100
101pub(super) fn random_mod_vartime_core<T, R: TryRng + ?Sized>(
104 rng: &mut R,
105 x: &mut T,
106 modulus: &NonZero<T>,
107 n_bits: u32,
108) -> Result<(), R::Error>
109where
110 T: Encoding + CtLt,
111{
112 loop {
113 random_bits_core(rng, x, n_bits)?;
114 if x.ct_lt(modulus).into() {
115 return Ok(());
116 }
117 }
118}
119
#[cfg(test)]
mod tests {
    use crate::uint::rand::random_bits_core;
    use crate::{Limb, NonZero, Random, RandomBits, RandomMod, U256, U1024, Uint};
    use chacha20::ChaCha8Rng;
    use rand_core::{Rng, SeedableRng};

    /// Expected output of `U1024::random_from_rng` driven by `ChaCha8Rng::seed_from_u64(0)`.
    /// The `random_platform_independence` test pins this value. Other tests derive their
    /// expectations from it.
    const RANDOM_OUTPUT: U1024 = Uint::from_be_hex(concat![
        "A484C4C693EECC47C3B919AE0D16DF2259CD1A8A9B8EA8E0862878227D4B40A3",
        "C54302F2EB1E2F69E17653A37F1BCC44277FA208E6B31E08CDC4A23A7E88E660",
        "EF781C7DD2D368BAD438539D6A2E923C8CAE14CB947EB0BDE10D666732024679",
        "0F6760A48F9B887CB2FB0D3281E2A6E67746A55FBAD8C037B585F767A79A3B6C"
    ]);

    /// Returns a `ChaCha8Rng` seeded with 0. The platform-independence tests below all start
    /// from this generator.
    // NOTE(review): the name mentions "four sequential" output, but the body is plain ChaCha8
    // seeded with 0. Confirm whether the name is a leftover from an older test RNG.
    fn get_four_sequential_rng() -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(0)
    }

    /// A full-width draw from the fixed seed must equal the pinned `RANDOM_OUTPUT`.
    #[test]
    fn random_platform_independence() {
        let mut rng = get_four_sequential_rng();
        assert_eq!(U1024::random_from_rng(&mut rng), RANDOM_OUTPUT);
    }

    /// Sampled values must lie strictly below the modulus. Two moduli are checked: a small
    /// one (42) and one just over 64 bits (`2^64 + 1`).
    #[test]
    fn random_mod_vartime() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let modulus = NonZero::new(U256::from(42u8)).unwrap();
        let res = U256::random_mod_vartime(&mut rng, &modulus);

        assert!(res < U256::from(42u8));

        let modulus = NonZero::new(U256::from(0x10000000000000001u128)).unwrap();
        let res = U256::random_mod_vartime(&mut rng, &modulus);

        assert!(res < U256::from(0x10000000000000001u128));
    }

    /// `random_bits` must respect `bit_length` at several alignments. Where the value is
    /// representable, each result is checked to be below `2^bit_length`. Each result is also
    /// checked to exceed `2^(bit_length - lower_bound)`, which shows the high bits are
    /// populated. That lower bound is not guaranteed for arbitrary randomness; it holds for
    /// the fixed seed used here.
    #[test]
    fn random_bits() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let lower_bound = 16;

        // Full width: no upper-bound check, since `2^256` does not fit in a `U256`.
        let bit_length = U256::BITS;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
        }

        // One limb short of full width (a limb boundary).
        let bit_length = U256::BITS - Limb::BITS;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
            assert!(res < (U256::ONE << bit_length));
        }

        // A further 8 bits short (a byte boundary inside a limb).
        let bit_length = U256::BITS - Limb::BITS - 8;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
            assert!(res < (U256::ONE << bit_length));
        }

        // Not byte-aligned, which exercises the partial-top-byte mask.
        let bit_length = U256::BITS - Limb::BITS - 8 - 3;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
            assert!(res < (U256::ONE << bit_length));
        }

        // Sub-byte length: only the upper bound is meaningful.
        let bit_length = 7;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res < (U256::ONE << bit_length));
        }

        // Zero bits must always yield zero.
        let bit_length = 0;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert_eq!(res, U256::ZERO);
        }
    }

    /// A 989-bit draw must equal the low 989 bits of `RANDOM_OUTPUT`. The bytes the RNG
    /// yields afterwards are then pinned, so any change in how many bytes
    /// `random_bits_core` consumes shows up here.
    #[test]
    fn random_bits_platform_independence() {
        let mut rng = get_four_sequential_rng();

        let bit_length = 989;
        let mut val = U1024::ZERO;
        random_bits_core(&mut rng, &mut val, bit_length).expect("safe");

        assert_eq!(
            val,
            RANDOM_OUTPUT.bitand(&U1024::ONE.shl(bit_length).wrapping_sub(&Uint::ONE))
        );

        // Pin the RNG position after the draw.
        let mut state = [0u8; 16];
        rng.fill_bytes(&mut state);

        assert_eq!(
            state,
            [
                198, 196, 132, 164, 240, 211, 223, 12, 36, 189, 139, 48, 94, 1, 123, 253
            ]
        );
    }

    /// Pins `random_mod_vartime` outputs from the fixed seed, including rejection-loop
    /// behavior. Two moduli are used: 8192, and a modulus of the form `2^256 - r` where `r`
    /// is drawn from the RNG. The RNG position afterwards is pinned as well.
    #[test]
    fn random_mod_vartime_platform_independence() {
        let mut rng = get_four_sequential_rng();

        let modulus = NonZero::new(U256::from_u32(8192)).unwrap();
        let mut vals = [U256::ZERO; 5];
        for val in &mut vals {
            *val = U256::random_mod_vartime(&mut rng, &modulus);
        }
        let expected = [55, 3378, 2172, 1657, 5323];
        for (want, got) in expected.into_iter().zip(vals.into_iter()) {
            assert_eq!(got, U256::from_u32(want));
        }

        // A modulus just below 2^256, derived from the next RNG word.
        let modulus =
            NonZero::new(U256::ZERO.wrapping_sub(&U256::from_u64(rng.next_u64()))).unwrap();
        let val = U256::random_mod_vartime(&mut rng, &modulus);
        assert_eq!(
            val,
            U256::from_be_hex("E17653A37F1BCC44277FA208E6B31E08CDC4A23A7E88E660EF781C7DD2D368BA")
        );

        // Pin the RNG position after sampling.
        let mut state = [0u8; 16];
        rng.fill_bytes(&mut state);

        assert_eq!(
            state,
            [
                105, 47, 30, 235, 242, 2, 67, 197, 163, 64, 75, 125, 34, 120, 40, 134,
            ],
        );
    }

    /// A 1024-bit draw is split into two `random_bits_core` calls at a 32-bit-aligned
    /// boundary. Recombining the two parts must reproduce `RANDOM_OUTPUT`, which shows the
    /// RNG bytes are consumed sequentially regardless of where the split falls.
    #[test]
    fn random_bits_4_bytes_sequential() {
        let bit_lengths = [0, 32, 64, 128, 192, 992];

        for bit_length in bit_lengths {
            let mut rng = get_four_sequential_rng();
            let mut first = U1024::ZERO;
            let mut second = U1024::ZERO;
            random_bits_core(&mut rng, &mut first, bit_length).expect("safe");
            random_bits_core(&mut rng, &mut second, U1024::BITS - bit_length).expect("safe");
            assert_eq!(second.shl(bit_length).bitor(&first), RANDOM_OUTPUT);
        }
    }
}