use crate::core::app_errors::Result as AppResult;
use crate::crypto::rng::SecureRng;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use zeroize::{Zeroize, Zeroizing};
/// Random generator wrapping a [`SecureRng`] that is reseeded automatically once either
/// an output-volume threshold or an elapsed-time threshold has been reached.
pub struct GlobalRng {
    /// Underlying generator that produces the bytes and performs reseeds.
    rng: Arc<SecureRng>,
    /// Bytes requested since the last successful reseed (or since construction).
    output_counter: AtomicU64,
    /// Byte count at which the next request triggers a reseed.
    reseed_threshold: u64,
    /// Unix time, in whole seconds, of the last reseed (or of construction).
    last_reseed: AtomicU64,
}
/// A source of bytes served from an internal buffer that is refilled one block at a time.
pub trait ByteStream {
    /// Refills the internal buffer with a fresh block of bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when the underlying source cannot produce a block.
    fn fill_next_block(&mut self) -> AppResult<()>;

    /// Returns the buffered bytes that have not been consumed yet.
    fn remaining_bytes(&self) -> &[u8];

    /// Marks the first `n` remaining bytes as consumed; implementations may clamp `n`
    /// to the number of bytes actually available.
    fn consume(&mut self, n: usize);
}
/// Lets a mutable borrow of any stream (including unsized ones such as trait objects)
/// act as a stream itself by forwarding every call to the borrowed value.
impl<T: ByteStream + ?Sized> ByteStream for &mut T {
    fn fill_next_block(&mut self) -> AppResult<()> {
        <T as ByteStream>::fill_next_block(&mut **self)
    }

    fn remaining_bytes(&self) -> &[u8] {
        <T as ByteStream>::remaining_bytes(&**self)
    }

    fn consume(&mut self, n: usize) {
        <T as ByteStream>::consume(&mut **self, n)
    }
}
impl GlobalRng {
    /// Output volume, in bytes (2^20), after which the next request triggers a reseed.
    const DEFAULT_RESEED_THRESHOLD: u64 = 1_048_576;
    /// Age, in seconds (one hour), after which the next request triggers a reseed.
    const RESEED_TIME_THRESHOLD: u64 = 3600;

    /// Creates a generator around a freshly constructed [`SecureRng`].
    ///
    /// The byte counter starts at zero and the last-reseed timestamp is the current Unix
    /// time in whole seconds, or `0` when the system clock reads earlier than the epoch.
    ///
    /// # Errors
    ///
    /// Propagates any error from `SecureRng::new`.
    pub fn new() -> AppResult<Self> {
        let rng = Arc::new(SecureRng::new()?);
        let created_at = Self::unix_seconds_now().unwrap_or_default();
        Ok(Self {
            rng,
            output_counter: AtomicU64::new(0),
            reseed_threshold: Self::DEFAULT_RESEED_THRESHOLD,
            last_reseed: AtomicU64::new(created_at),
        })
    }

    /// Fills `dest` with random bytes, reseeding first when a threshold has been reached.
    ///
    /// `dest.len()` is added to the output counter before generation starts, so the
    /// request counts toward the volume threshold even if generation then fails.
    ///
    /// # Errors
    ///
    /// Fails if the system clock reads earlier than the Unix epoch, if a required reseed
    /// fails, or if the underlying generator fails.
    pub fn generate_bytes(&self, dest: &mut [u8]) -> AppResult<()> {
        let reseed_due = self.should_reseed()?;
        if reseed_due {
            self.reseed()?;
        }
        let requested = dest.len() as u64;
        self.output_counter.fetch_add(requested, Ordering::Relaxed);
        self.rng.generate_bytes(dest)
    }

    /// Reports whether the output-volume or the elapsed-time threshold has been reached.
    ///
    /// A clock that moved backwards since the last reseed counts as zero elapsed time.
    fn should_reseed(&self) -> AppResult<bool> {
        let produced = self.output_counter.load(Ordering::Relaxed);
        let now = Self::unix_seconds_now().ok_or_else(|| {
            crate::core::app_errors::GenerationError::IoError(std::io::Error::other(
                "Time error",
            ))
        })?;
        let age = now.saturating_sub(self.last_reseed.load(Ordering::Relaxed));
        Ok(should_reseed_by(
            produced,
            age,
            self.reseed_threshold,
            Self::RESEED_TIME_THRESHOLD,
        ))
    }

    /// Reseeds the underlying generator, then restarts the byte counter and the clock.
    ///
    /// Counter and timestamp are reset only after `SecureRng::reseed` succeeds, so a
    /// failed reseed leaves the trigger condition in place and the next request retries.
    fn reseed(&self) -> AppResult<()> {
        self.rng.reseed()?;
        self.output_counter.store(0, Ordering::Relaxed);
        let reseeded_at = Self::unix_seconds_now().unwrap_or_default();
        self.last_reseed.store(reseeded_at, Ordering::Relaxed);
        Ok(())
    }

    /// Returns the current counters.
    ///
    /// Each field is loaded independently with relaxed ordering, so under concurrent use
    /// the values may come from slightly different moments.
    pub fn get_statistics(&self) -> GlobalRngStatistics {
        let output_bytes = self.output_counter.load(Ordering::Relaxed);
        let last_reseed = self.last_reseed.load(Ordering::Relaxed);
        GlobalRngStatistics {
            output_bytes,
            last_reseed,
            reseed_threshold: self.reseed_threshold,
        }
    }

    /// Opens a buffered [`GlobalRngStream`] that shares this generator.
    pub fn stream(self: &Arc<Self>) -> GlobalRngStream {
        GlobalRngStream::new(Arc::clone(self))
    }

    /// Current wall-clock time as whole seconds since the Unix epoch, or `None` when the
    /// system clock reads earlier than the epoch.
    fn unix_seconds_now() -> Option<u64> {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .map(|since_epoch| since_epoch.as_secs())
    }
}
/// Number of random bytes fetched from the shared generator on each refill of a
/// [`GlobalRngStream`].
const GLOBAL_STREAM_BLOCK_SIZE: usize = 256;

/// Buffered byte stream over a shared [`GlobalRng`].
///
/// Bytes are fetched in fixed-size blocks into a `Zeroizing` cache; the live, unconsumed
/// region is `cache[cursor..cursor + available]`.
pub struct GlobalRngStream {
    /// Shared generator used to refill `cache`.
    rng: Arc<GlobalRng>,
    /// Buffered random bytes; consumed bytes are zeroized in place.
    cache: Zeroizing<[u8; GLOBAL_STREAM_BLOCK_SIZE]>,
    /// Offset of the first unconsumed byte in `cache`.
    cursor: usize,
    /// Number of unconsumed bytes starting at `cursor`.
    available: usize,
}
impl GlobalRngStream {
    /// Creates an empty stream over `rng`.
    ///
    /// Nothing is buffered until the first `fill_next_block` call.
    pub fn new(rng: Arc<GlobalRng>) -> Self {
        let empty_cache = Zeroizing::new([0u8; GLOBAL_STREAM_BLOCK_SIZE]);
        Self {
            rng,
            cache: empty_cache,
            cursor: 0,
            available: 0,
        }
    }
}
impl ByteStream for GlobalRngStream {
    /// Overwrites the whole cache with a new block from the shared generator.
    ///
    /// Any unconsumed bytes are discarded. On failure the cache is wiped and the stream
    /// reports no remaining bytes before the error is returned.
    fn fill_next_block(&mut self) -> AppResult<()> {
        let refill = self.rng.generate_bytes(self.cache.as_mut());
        self.cursor = 0;
        match refill {
            Ok(()) => {
                self.available = self.cache.len();
                Ok(())
            }
            Err(err) => {
                // A failed refill may have left partial output behind; do not keep it.
                self.cache.as_mut().zeroize();
                self.available = 0;
                Err(err)
            }
        }
    }

    /// Returns the unconsumed slice `cache[cursor..cursor + available]`, clamped to the
    /// cache length.
    fn remaining_bytes(&self) -> &[u8] {
        let limit = self.cache.len();
        let end = limit.min(self.cursor.saturating_add(self.available));
        &self.cache[self.cursor..end]
    }

    /// Consumes up to `n` bytes (clamped to what is available), zeroizing them in the
    /// cache so handed-out randomness does not linger in memory.
    fn consume(&mut self, n: usize) {
        let take = self.available.min(n);
        let capacity = self.cache.len();
        if take != 0 {
            let start = self.cursor;
            let end = capacity.min(start.saturating_add(take));
            self.cache.as_mut()[start..end].zeroize();
        }
        self.cursor = capacity.min(self.cursor + take);
        self.available = self.available.saturating_sub(take);
        if self.available == 0 {
            // Nothing left: rewind the cursor to the start of the cache.
            self.cursor = 0;
        }
    }
}
impl Drop for GlobalRngStream {
    /// Wipes the buffered bytes and clears the cursor bookkeeping.
    ///
    /// The cache is also held in `Zeroizing`; this explicit wipe is a belt-and-braces
    /// measure.
    fn drop(&mut self) {
        self.available = 0;
        self.cursor = 0;
        self.cache.as_mut().zeroize();
    }
}
/// Statistics returned by `GlobalRng::get_statistics`.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare reported values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GlobalRngStatistics {
    /// Bytes requested since the last successful reseed (or since construction).
    pub output_bytes: u64,
    /// Unix time, in whole seconds, of the last reseed (or of construction).
    pub last_reseed: u64,
    /// Output volume, in bytes, at which the next request triggers a reseed.
    pub reseed_threshold: u64,
}
use std::sync::Mutex;
/// Lazily created shared generator handed out by [`get_global_rng`].
static GLOBAL_RNG: Mutex<Option<Arc<GlobalRng>>> = Mutex::new(None);

/// Returns the shared [`GlobalRng`], creating it on first use.
///
/// If the lock was poisoned by a panic in another thread, the poison flag is cleared and
/// an existing generator is reseeded before being handed out. If that reseed fails, or no
/// generator existed yet, a brand-new generator is constructed in its place.
///
/// # Errors
///
/// Returns an error if a new generator must be created and `GlobalRng::new` fails.
pub fn get_global_rng() -> AppResult<Arc<GlobalRng>> {
    let mut guard = match GLOBAL_RNG.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            let mut guard = poisoned.into_inner();
            // Without clearing, the mutex stays poisoned forever and every later call
            // would land here again, reseeding the generator on each lookup. If the
            // recovery below panics while `guard` is held, the lock is re-poisoned.
            GLOBAL_RNG.clear_poison();
            if let Some(rng) = guard.as_ref() {
                if rng.reseed().is_ok() {
                    return Ok(Arc::clone(rng));
                }
                // The existing instance could not be refreshed; replace it below.
                *guard = None;
            }
            guard
        }
    };
    if let Some(rng) = guard.as_ref() {
        return Ok(Arc::clone(rng));
    }
    let rng = Arc::new(GlobalRng::new()?);
    *guard = Some(Arc::clone(&rng));
    Ok(rng)
}
/// Decides whether a reseed is due.
///
/// Returns `true` once either the emitted byte count has reached `bytes_threshold` or the
/// seconds since the last reseed have reached `time_threshold` (both inclusive).
fn should_reseed_by(
    output_bytes: u64,
    elapsed_secs: u64,
    bytes_threshold: u64,
    time_threshold: u64,
) -> bool {
    let volume_reached = output_bytes >= bytes_threshold;
    let age_reached = elapsed_secs >= time_threshold;
    volume_reached || age_reached
}