p2panda_encryption/crypto/
rng.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use std::sync::Mutex;
4
5use rand_chacha::rand_core::{SeedableRng, TryRngCore};
6use thiserror::Error;
7
8/// Cryptographically-secure random number generator that uses the ChaCha algorithm.
9#[derive(Debug)]
10pub struct Rng {
11    rng: Mutex<rand_chacha::ChaCha20Rng>,
12}
13
14impl Default for Rng {
15    fn default() -> Self {
16        Self {
17            rng: Mutex::new(rand_chacha::ChaCha20Rng::from_os_rng()),
18        }
19    }
20}
21
22impl Rng {
23    pub fn from_rng(rng: &Rng) -> Result<Self, RngError> {
24        Ok(Self::from_seed(rng.random_array()?))
25    }
26
27    pub fn random_array<const N: usize>(&self) -> Result<[u8; N], RngError> {
28        let mut rng = self.rng.lock().map_err(|_| RngError::LockPoisoned)?;
29        let mut out = [0u8; N];
30        rng.try_fill_bytes(&mut out)
31            .map_err(|_| RngError::NotEnoughRandomness)?;
32        Ok(out)
33    }
34
35    pub fn random_vec(&self, len: usize) -> Result<Vec<u8>, RngError> {
36        let mut rng = self.rng.lock().map_err(|_| RngError::LockPoisoned)?;
37        let mut out = vec![0u8; len];
38        rng.try_fill_bytes(&mut out)
39            .map_err(|_| RngError::NotEnoughRandomness)?;
40        Ok(out)
41    }
42
43    #[cfg(any(test, feature = "test_utils"))]
44    pub fn from_seed(seed: [u8; 32]) -> Self {
45        Self {
46            rng: Mutex::new(rand_chacha::ChaCha20Rng::from_seed(seed)),
47        }
48    }
49
50    #[cfg(not(any(test, feature = "test_utils")))]
51    fn from_seed(seed: [u8; 32]) -> Self {
52        Self {
53            rng: Mutex::new(rand_chacha::ChaCha20Rng::from_seed(seed)),
54        }
55    }
56}
57
58#[derive(Debug, Error)]
59pub enum RngError {
60    #[error("rng lock is poisoned")]
61    LockPoisoned,
62
63    #[error("unable to collect enough randomness")]
64    NotEnoughRandomness,
65}
66
#[cfg(test)]
mod tests {
    use super::Rng;

    /// Draws `len` bytes from a fresh generator seeded with `seed`.
    fn sample(seed: [u8; 32], len: usize) -> Vec<u8> {
        Rng::from_seed(seed).random_vec(len).unwrap()
    }

    #[test]
    fn deterministic_randomness() {
        assert_eq!(sample([1; 32], 128), sample([1; 32], 128));
    }

    #[test]
    fn different_seeds_give_different_output() {
        assert_ne!(sample([1; 32], 128), sample([2; 32], 128));
    }

    #[test]
    fn derived_rng_is_deterministic() {
        let derive = || {
            let parent = Rng::from_seed([1; 32]);
            let child = Rng::from_rng(&parent).unwrap();
            child.random_vec(128).unwrap()
        };

        assert_eq!(derive(), derive());
        // The child is seeded from the parent's output, so it must not replay the parent's stream.
        assert_ne!(derive(), sample([1; 32], 128));
    }
}