use alloc::rc::Rc;
use core::{cell::UnsafeCell, convert::Infallible};
use std::thread_local;
use chacha20::ChaCha8Rng;
use rand_core::{SeedableRng, TryCryptoRng, TryRng};
thread_local! {
    // This thread's ChaCha8 generator, seeded from `getrandom::SysRng` the first
    // time the thread accesses it. It sits behind an `Rc` so every
    // `ThreadLocalEntropy` handle created on this thread shares the same state.
    static SOURCE: Rc<UnsafeCell<ChaCha8Rng>> = {
        let seeded = ChaCha8Rng::try_from_rng(&mut getrandom::SysRng)
            .expect("Unable to source entropy for initialisation");
        Rc::new(UnsafeCell::new(seeded))
    };
}
/// A handle to the calling thread's lazily seeded [`ChaCha8Rng`].
///
/// Every handle obtained on a given thread shares one underlying generator,
/// so draws through different handles advance a single stream. The handle
/// holds an [`Rc`], which makes it neither `Send` nor `Sync`: it can never
/// leave the thread that created it.
pub struct ThreadLocalEntropy(Rc<UnsafeCell<ChaCha8Rng>>);

impl ThreadLocalEntropy {
    /// Returns a handle to the calling thread's entropy source, seeding the
    /// source on first use.
    ///
    /// # Errors
    ///
    /// Returns [`std::thread::AccessError`] if this thread's source is being
    /// (or has been) destroyed, e.g. when called from another thread-local's
    /// destructor during thread teardown.
    ///
    /// # Panics
    ///
    /// Panics if this is the first access on this thread and seeding the
    /// generator from `getrandom::SysRng` fails.
    pub fn get() -> Result<Self, std::thread::AccessError> {
        SOURCE.try_with(Rc::clone).map(Self)
    }

    /// Runs `f` with exclusive access to the thread's shared generator.
    ///
    /// Callers must not pass a closure that obtains or uses another
    /// `ThreadLocalEntropy` handle on this thread, because that would create
    /// a second `&mut` to the same generator.
    #[inline(always)]
    fn access_local_source<F, O>(&mut self, f: F) -> Result<O, Infallible>
    where
        F: FnOnce(&mut ChaCha8Rng) -> Result<O, Infallible>,
    {
        // SAFETY: `ThreadLocalEntropy` holds an `Rc`, so it is neither `Send`
        // nor `Sync`. Every handle to this cell therefore lives on the current
        // thread, and in this file this is the only place the cell is
        // dereferenced. The higher-ranked `F` bound stops the `&mut` from
        // escaping the call to `f`. The closures passed within this module only
        // invoke `ChaCha8Rng` methods, which cannot reach another handle. So no
        // other reference to the generator is live while `source` exists.
        let source = unsafe { &mut *self.0.get() };
        f(source)
    }
}
/// Deliberately opaque: only the type name is printed, so the generator's
/// internal state never shows up in debug output.
impl core::fmt::Debug for ThreadLocalEntropy {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Same output as an empty `debug_tuple`: the bare name, with no parentheses.
        f.write_str("ThreadLocalEntropy")
    }
}
/// Forwards every draw to the thread's shared [`ChaCha8Rng`]. Generation cannot
/// fail, hence the [`Infallible`] error type.
impl TryRng for ThreadLocalEntropy {
    type Error = Infallible;

    fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
        self.access_local_source(|rng| rng.try_next_u32())
    }

    fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
        self.access_local_source(|rng| rng.try_next_u64())
    }

    fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
        // The output buffer is moved into the closure; it never aliases the generator.
        self.access_local_source(move |rng| rng.try_fill_bytes(dst))
    }
}

// Every value this handle produces comes straight from the wrapped
// `ChaCha8Rng`, so the handle carries that generator's cryptographic marker.
impl TryCryptoRng for ThreadLocalEntropy {}
#[cfg(test)]
mod tests {
    use alloc::{format, vec, vec::Vec};

    use rand_core::Rng;

    use super::*;

    /// Two handles on one thread share a single stream, so successive draws
    /// through either handle should not coincide.
    #[test]
    fn smoke_test() -> Result<(), std::thread::AccessError> {
        let (mut first, mut second) = (ThreadLocalEntropy::get()?, ThreadLocalEntropy::get()?);

        first.next_u32();
        second.next_u64();

        let mut out_first = vec![0u8; 128];
        let mut out_second = vec![0u8; 128];
        first.fill_bytes(&mut out_first);
        second.fill_bytes(&mut out_second);

        assert_ne!(out_first, out_second);
        Ok(())
    }

    /// Each thread seeds its own generator, so output drawn on two different
    /// threads should differ.
    #[test]
    fn unique_source_per_thread() {
        // Executed on a worker thread: fills `buf`, then reads one more u64
        // directly from the underlying generator.
        let draw = |buf: &mut [u8]| {
            let mut rng =
                ThreadLocalEntropy::get().expect("Should not fail when accessing local source");
            rng.fill_bytes(buf);
            rng.access_local_source(|rng| Ok(rng.next_u64()))
        };

        let mut bytes1: Vec<u8> = vec![0u8; 128];
        let mut bytes2: Vec<u8> = vec![0u8; 128];

        let (a, b) = std::thread::scope(|s| {
            let first = s.spawn(|| draw(bytes1.as_mut_slice()));
            let second = s.spawn(|| draw(bytes2.as_mut_slice()));
            (first.join(), second.join())
        });

        let Ok(a) = a.unwrap();
        let Ok(b) = b.unwrap();
        assert_ne!(a, b);
        assert_ne!(&bytes1, &bytes2);
    }

    /// `Debug` must print only the type name, never generator state.
    #[test]
    fn non_leaking_debug() {
        let rendered = format!("{:?}", ThreadLocalEntropy::get());
        assert_eq!(rendered, "Ok(ThreadLocalEntropy)");
    }
}