1use aes::cipher::generic_array::{typenum::U16, typenum::U8, GenericArray};
63use aes::cipher::{BlockEncrypt, KeyInit};
64use aes::Aes128;
65use byteorder::{ByteOrder, LittleEndian};
66use rand::{CryptoRng, RngCore, SeedableRng};
67use std::mem;
68use std::slice;
69
/// Size in bytes of one AES block.
const AES_BLK_SIZE: usize = 16;
/// Number of AES blocks kept in the buffered state, as a `usize` for sizing and indexing.
const PIPELINES_USIZE: usize = 8;
/// The same block count as a `u128`, the type the block counter is stored in.
const PIPELINES_U128: u128 = PIPELINES_USIZE as u128;
/// Total size in bytes of the buffered state (`PIPELINES_USIZE` blocks of `AES_BLK_SIZE` bytes).
const STATE_SIZE: usize = AES_BLK_SIZE * PIPELINES_USIZE;
/// Size in bytes of a seed; equal to the size of one AES block.
pub const SEED_SIZE: usize = AES_BLK_SIZE;
/// Byte array used to seed the generator.
pub type RngSeed = [u8; SEED_SIZE];
76
/// A single 16-byte block, i.e. the unit the AES cipher operates on.
type Block128 = GenericArray<u8, U16>;
/// Eight [`Block128`] values stored together.
type Block128x8 = GenericArray<Block128, U8>;
79
/// Buffered block state of the generator together with its bookkeeping counters.
#[derive(Clone)]
pub struct AesRngState {
    /// The eight 16-byte blocks. `AesRngState::next` fills them with little-endian
    /// counter values; `AesRng` then encrypts them in place before reading from them.
    blocks: Block128x8,
    /// Counter value that will be written into the first block on the next refill.
    next_index: u128,
    /// Number of bytes of `blocks` that have already been handed out.
    used_bytes: usize,
}
86
87impl Default for AesRngState {
88 fn default() -> Self {
89 AesRngState::init()
90 }
91}
92
93fn create_init_state() -> Block128x8 {
97 let mut state = [0_u8; STATE_SIZE];
98 Block128x8::from_exact_iter((0..PIPELINES_USIZE).map(|i| {
99 LittleEndian::write_u128(
100 &mut state[i * AES_BLK_SIZE..(i + 1) * AES_BLK_SIZE],
101 i as u128,
102 );
103 let sliced_state = &mut state[i * AES_BLK_SIZE..(i + 1) * AES_BLK_SIZE];
104 let block = GenericArray::from_mut_slice(sliced_state);
105 *block
106 }))
107 .unwrap()
108}
109
110impl AesRngState {
111 fn as_mut_bytes(&mut self) -> &mut [u8] {
113 #[allow(unsafe_code)]
114 unsafe {
115 slice::from_raw_parts_mut(&mut self.blocks as *mut Block128x8 as *mut u8, STATE_SIZE)
116 }
117 }
118
119 fn init() -> Self {
121 AesRngState {
122 blocks: create_init_state(),
123 next_index: PIPELINES_U128,
124 used_bytes: 0,
125 }
126 }
127
128 fn next(&mut self) {
131 let counter = self.next_index;
132 let blocks_bytes = self.as_mut_bytes();
133 for i in 0..PIPELINES_USIZE {
134 LittleEndian::write_u128(
135 &mut blocks_bytes[i * AES_BLK_SIZE..(i + 1) * AES_BLK_SIZE],
136 counter + i as u128,
137 );
138 }
139 self.next_index += PIPELINES_U128;
140 self.used_bytes = 0;
141 }
142}
143
/// Random number generator that serves bytes from AES-128 encrypted counter blocks.
#[derive(Clone)]
pub struct AesRng {
    /// Buffered blocks plus the counter and the number of bytes already consumed.
    state: AesRngState,
    /// AES-128 cipher instance; `from_seed` keys it with the seed.
    cipher: Aes128,
    /// Number of bits still available in `cached_bits` for `get_bit`.
    n_cached_bits: usize,
    /// Bits drawn via `next_u64` and handed out one at a time, lowest bit first.
    cached_bits: u64,
}
152
153impl SeedableRng for AesRng {
154 type Seed = RngSeed;
155 #[inline]
159 fn from_seed(seed: Self::Seed) -> Self {
160 let key = GenericArray::from(seed);
161 let mut out = AesRng {
162 state: AesRngState::default(),
163 cipher: Aes128::new(&key),
164 n_cached_bits: 0,
165 cached_bits: 0,
166 };
167 out.init();
168 out
169 }
170}
171
172impl AesRng {
173 fn init(&mut self) {
176 self.cipher.encrypt_blocks(&mut self.state.blocks);
177 }
178
179 fn next(&mut self) {
183 self.state.next();
184 self.cipher.encrypt_blocks(&mut self.state.blocks);
185 }
186
187 #[deprecated(since = "0.2.0")]
188 pub fn generate_random_key() -> [u8; SEED_SIZE] {
189 Self::generate_random_seed()
190 }
191
192 pub fn generate_random_seed() -> [u8; SEED_SIZE] {
193 let mut seed = [0u8; SEED_SIZE];
194 let mut rng = rand::rng();
195 rng.fill_bytes(&mut seed);
196 seed
197 }
198
199 pub fn from_random_seed() -> Self {
207 let seed = AesRng::generate_random_seed();
208 Self::from_seed(seed)
209 }
210
211 pub fn get_bit(&mut self) -> u8 {
213 if self.n_cached_bits == 0 {
214 self.cached_bits = self.next_u64();
215 self.n_cached_bits = 64;
216 }
217 self.n_cached_bits -= 1;
218 let result: u8 = (self.cached_bits & 1) as u8;
219 self.cached_bits >>= 1;
220 result
221 }
222}
223
224impl RngCore for AesRng {
225 fn next_u32(&mut self) -> u32 {
227 let u32_size = mem::size_of::<u32>();
228 if self.state.used_bytes >= STATE_SIZE - u32_size {
229 self.next();
230 }
231 let used_bytes = self.state.used_bytes;
232 self.state.used_bytes += u32_size; let blocks_bytes = self.state.as_mut_bytes();
234 LittleEndian::read_u32(&blocks_bytes[used_bytes..used_bytes + u32_size])
235 }
236
237 fn next_u64(&mut self) -> u64 {
239 let u64_size = mem::size_of::<u64>();
240 if self.state.used_bytes >= STATE_SIZE - u64_size {
241 self.next();
242 }
243 let used_bytes = self.state.used_bytes;
244 self.state.used_bytes += u64_size; LittleEndian::read_u64(&self.state.as_mut_bytes()[used_bytes..used_bytes + u64_size])
246 }
247
248 fn fill_bytes(&mut self, dest: &mut [u8]) {
250 let mut read_len = STATE_SIZE - self.state.used_bytes;
251 let mut dest_start = 0;
252
253 while read_len < dest.len() {
254 let src_start = self.state.used_bytes;
255 dest[dest_start..read_len]
256 .copy_from_slice(&self.state.as_mut_bytes()[src_start..STATE_SIZE]);
257 self.next();
258 dest_start = read_len;
259 read_len += STATE_SIZE;
260 }
261
262 let src_start = self.state.used_bytes;
263 let remainder = dest.len() - dest_start;
264 let dest_len = dest.len();
265
266 dest[dest_start..dest_len]
267 .copy_from_slice(&self.state.as_mut_bytes()[src_start..src_start + remainder]);
268 self.state.used_bytes += remainder;
269 }
270}
271
// Marker impl: declares `AesRng` as a `CryptoRng`. It adds no methods.
// NOTE(review): this is a claim of cryptographic suitability; confirm it is intended.
impl CryptoRng for AesRng {}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// The first buffered state, and the bytes served from it, must equal AES-128
    /// (keyed with the seed) applied to the counter blocks 0 through 7.
    #[test]
    fn test_prng_match_aes() {
        let seed = [0u8; SEED_SIZE];
        let key: Block128 = GenericArray::clone_from_slice(&seed);
        let cipher = Aes128::new(&key);

        // Reference input, built independently of the generator: block `i` is the
        // little-endian encoding of `i`, i.e. `i` in byte 0 and zeros elsewhere.
        let mut blocks = Block128x8::from_exact_iter((0u8..8).map(|counter| {
            let mut block = Block128::default();
            block[0] = counter;
            block
        }))
        .unwrap();
        cipher.encrypt_blocks(&mut blocks);

        let mut rng = AesRng::from_seed(seed);
        let mut out = [0u8; 16 * 8];
        rng.fill_bytes(&mut out);

        assert_eq!(rng.state.blocks, blocks);

        // Also check the bytes actually returned to the caller, not only the
        // internal state they were copied from.
        let expected: Vec<u8> = blocks
            .iter()
            .flat_map(|block| block.iter().copied())
            .collect();
        assert_eq!(&out[..], &expected[..]);
    }

    /// After 129 reads of 16 bytes the buffer holds the 129th 16-byte chunk of the
    /// stream, which is the encryption of counter value 128.
    #[test]
    fn test_prng_vector1() {
        let seed = [0u8; SEED_SIZE];

        let mut rng = AesRng::from_seed(seed);
        let mut out = [0u8; 16];

        for _ in 0..129 {
            rng.fill_bytes(&mut out);
        }

        let expected: [u8; 16] = [
            58, 215, 142, 114, 108, 30, 192, 43, 126, 191, 233, 43, 35, 217, 236, 52,
        ];
        assert_eq!(expected, out);
    }

    /// After 17 reads of 16 bytes the buffer holds the 17th 16-byte chunk of the
    /// stream, which is the encryption of counter value 16.
    #[test]
    fn test_prng_vector2() {
        let seed = [0u8; SEED_SIZE];

        let mut rng = AesRng::from_seed(seed);
        let mut out = [0u8; 16];
        for _ in 0..17 {
            rng.fill_bytes(&mut out);
        }

        let expected: [u8; 16] = [
            245, 86, 155, 58, 182, 166, 209, 30, 253, 225, 191, 10, 100, 198, 133, 74,
        ];
        assert_eq!(expected, out);
    }

    /// Draining the buffer exactly leaves it fully used; the next read refills it
    /// and consumes from the start of the fresh buffer.
    #[test]
    fn test_prng_used_bytes() {
        let mut rng: AesRng = AesRng::from_random_seed();
        let mut out = [0u8; 16 * 8];
        rng.fill_bytes(&mut out);

        assert_eq!(rng.state.used_bytes, 16 * 8);

        let _ = rng.next_u32();
        assert_eq!(rng.state.used_bytes, 4);
    }

    /// Smoke test: a randomly seeded generator can produce `u32` and `u64` values.
    #[test]
    fn test_seeded_prng() {
        let mut rng: AesRng = AesRng::from_random_seed();
        let _ = rng.next_u32();
        let _ = rng.next_u64();
    }

    /// A clone continues with the same output as the generator it was cloned from.
    #[test]
    fn test_cloned_prng() {
        let mut rng1: AesRng = AesRng::from_random_seed();
        let mut rng2 = rng1.clone();

        assert_eq!(rng1.next_u32(), rng2.next_u32());
    }
}