use rand_core::TryRng;
/// Largest number of bytes requested from the kernel RNG in a single `read`
/// (128 bytes); `try_fill_bytes` splits bigger requests into chunks of at most
/// this size.
/// NOTE(review): presumably mirrors the kernel's per-call limit for AF_ALG rng
/// reads — confirm against the target kernel.
const MAX_RETURN_CHUNK_SIZE: usize = 1 << 7;
/// Handle to the Linux kernel's `jitterentropy_rng`, opened through an `AF_ALG` socket.
///
/// Random bytes are obtained by `read`ing from the wrapped operation fd, which is
/// closed when the handle is dropped.
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct RandJitterKernel {
    /// Operation fd returned by `accept` on the bound `AF_ALG` socket.
    rng_fd: libc::c_int,
}
impl RandJitterKernel {
    /// Opens a handle to the kernel's `jitterentropy_rng` through an `AF_ALG` socket.
    ///
    /// An `AF_ALG` socket is bound to algorithm type `rng` / name
    /// `jitterentropy_rng`, an operation fd is obtained with `accept`, and the
    /// bound socket is closed again (only the operation fd is kept).
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::Error`] of kind `Other` if the socket cannot be
    /// created, bound or accepted; the message includes the underlying OS error.
    pub fn new() -> Result<Self, std::io::Error> {
        #[cfg(not(target_os = "linux"))]
        compile_error!("Only Linux is supported");

        // Everything that can fail without talking to the kernel happens before
        // the socket exists, so these early returns cannot leak an fd.
        // SAFETY: `sockaddr_alg` is a plain C struct of integers and byte arrays,
        // for which the all-zero bit pattern is a valid value.
        let mut sock_addr: libc::sockaddr_alg = unsafe { std::mem::zeroed() };
        sock_addr.salg_family = u16::try_from(libc::AF_ALG)
            .map_err(|_| std::io::Error::other("unable to convert socket algorithm family"))?;
        let rng_type: &[u8] = b"rng";
        let rng_name: &[u8] = b"jitterentropy_rng";
        // The arrays were zeroed above, so both strings stay NUL-terminated.
        sock_addr.salg_type[..rng_type.len()].copy_from_slice(rng_type);
        sock_addr.salg_name[..rng_name.len()].copy_from_slice(rng_name);
        let addr_len = libc::socklen_t::try_from(std::mem::size_of_val(&sock_addr))
            .map_err(|_| std::io::Error::other("unable to convert size of sock_addr"))?;

        // SAFETY: FFI call with constant arguments; the result is checked below.
        let fam_fd = unsafe { libc::socket(libc::AF_ALG, libc::SOCK_SEQPACKET, 0) };
        if fam_fd < 0 {
            return Err(Self::os_error(
                "unable to create AF_ALG socket for jitterentropy_rng",
            ));
        }

        // SAFETY: `sock_addr` is fully initialised, outlives the call, and
        // `addr_len` is exactly its size.
        let bind_ret = unsafe {
            libc::bind(
                fam_fd,
                std::ptr::addr_of!(sock_addr).cast::<libc::sockaddr>(),
                addr_len,
            )
        };
        if bind_ret != 0 {
            // Capture errno before `close` can overwrite it.
            let err = Self::os_error("unable to bind AF_ALG socket");
            // SAFETY: `fam_fd` is an fd we own and close exactly once.
            unsafe { libc::close(fam_fd) };
            return Err(err);
        }

        // SAFETY: `fam_fd` is a bound AF_ALG socket; null pointers mean the peer
        // address is not requested.
        let rng_fd = unsafe { libc::accept(fam_fd, std::ptr::null_mut(), std::ptr::null_mut()) };
        // Capture errno (if any) before `close` can overwrite it.
        let accept_err = (rng_fd < 0).then(|| Self::os_error("unable to get rng_fd from kernel"));
        // The bound socket is not needed any more, whether or not `accept` succeeded.
        // SAFETY: `fam_fd` is an fd we own and close exactly once.
        unsafe { libc::close(fam_fd) };
        match accept_err {
            Some(err) => Err(err),
            None => Ok(RandJitterKernel { rng_fd }),
        }
    }

    /// Wraps the calling thread's current `errno` in an `Other` error prefixed
    /// with `context`.
    ///
    /// Must be called right after the failing libc call, before anything else
    /// (such as `close`) can overwrite `errno`.
    fn os_error(context: &str) -> std::io::Error {
        let os_err = std::io::Error::last_os_error();
        std::io::Error::other(format!("{context}: {os_err}"))
    }

    /// Fills `dst` with bytes from one successful `read` of the kernel RNG fd.
    ///
    /// # Errors
    ///
    /// Fails if `dst` is longer than [`MAX_RETURN_CHUNK_SIZE`], if the fd is
    /// invalid, if `read` fails (reads interrupted by a signal are retried), or
    /// if `read` returns fewer than `dst.len()` bytes.
    fn try_fill_bytes_max_chunk_size(&mut self, dst: &mut [u8]) -> Result<(), std::io::Error> {
        if dst.len() > MAX_RETURN_CHUNK_SIZE {
            return Err(std::io::Error::other(format!(
                "Cannot return more than {} byte in a single call. Requested: {} byte",
                MAX_RETURN_CHUNK_SIZE,
                dst.len()
            )));
        }
        if self.rng_fd < 0 {
            return Err(std::io::Error::other(format!(
                "Cannot get entropy from jitterentropy_rng in kernel with invalid fd {}",
                self.rng_fd
            )));
        }
        loop {
            // SAFETY: `dst` is a valid, exclusively borrowed buffer of
            // `dst.len()` writable bytes for the duration of the call.
            let size = unsafe {
                libc::read(
                    self.rng_fd,
                    dst.as_mut_ptr().cast::<libc::c_void>(),
                    dst.len(),
                )
            };
            // A negative return value (error) fails the conversion to `usize`.
            match usize::try_from(size) {
                Ok(read) if read == dst.len() => return Ok(()),
                Ok(read) => {
                    return Err(std::io::Error::other(format!(
                        "Cannot get entropy from jitterentropy_rng in kernel: got {read} of {} byte",
                        dst.len()
                    )));
                }
                Err(_) => {
                    let os_err = std::io::Error::last_os_error();
                    // A signal arriving mid-read is not a real failure; just retry.
                    if os_err.kind() == std::io::ErrorKind::Interrupted {
                        continue;
                    }
                    return Err(std::io::Error::other(format!(
                        "Cannot get entropy from jitterentropy_rng in kernel: {os_err}"
                    )));
                }
            }
        }
    }
}
impl Default for RandJitterKernel {
    /// Opens the kernel RNG exactly like [`RandJitterKernel::new`].
    ///
    /// # Panics
    ///
    /// Panics if [`RandJitterKernel::new`] fails, because `Default` cannot report
    /// errors. Call `new` directly when the failure should be handled.
    fn default() -> Self {
        Self::new().expect("failed to open the kernel jitterentropy_rng via AF_ALG")
    }
}
impl Drop for RandJitterKernel {
    /// Closes the kernel RNG operation fd and marks it invalid.
    ///
    /// Never panics: a panic inside `drop` while another panic is already
    /// unwinding aborts the whole process, so an unexpected invalid fd is simply
    /// left alone instead of asserting.
    fn drop(&mut self) {
        if self.rng_fd >= 0 {
            // SAFETY: `rng_fd` is the fd obtained from `accept` in `new`, owned
            // by this handle and closed only here.
            unsafe {
                libc::close(self.rng_fd);
            }
            self.rng_fd = -1;
        }
    }
}
impl TryRng for RandJitterKernel {
    type Error = std::io::Error;

    /// Returns a `u32` made of 4 bytes read from the kernel RNG (native endian).
    ///
    /// Requests exactly 4 bytes so that no generated entropy is read and then
    /// thrown away.
    fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
        let mut bytes = [0u8; 4];
        self.try_fill_bytes(&mut bytes)?;
        Ok(u32::from_ne_bytes(bytes))
    }

    /// Returns a `u64` made of 8 bytes read from the kernel RNG (native endian).
    fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
        let mut bytes = [0u8; 8];
        self.try_fill_bytes(&mut bytes)?;
        Ok(u64::from_ne_bytes(bytes))
    }

    /// Fills all of `dst`, issuing one kernel read per chunk of at most
    /// [`MAX_RETURN_CHUNK_SIZE`] bytes; stops at the first failing chunk.
    fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
        dst.chunks_mut(MAX_RETURN_CHUNK_SIZE)
            .try_for_each(|chunk| self.try_fill_bytes_max_chunk_size(chunk))
    }
}
#[cfg(test)]
mod tests {
    use crate::RandJitterKernel;
    use rand_core::TryRng;

    #[test]
    fn test_u32() {
        let mut rng = RandJitterKernel::new().unwrap();
        for _ in 0..1000 {
            let u = rng.try_next_u32();
            assert!(u.is_ok());
        }
    }

    #[test]
    fn test_u64() {
        let mut rng = RandJitterKernel::new().unwrap();
        for _ in 0..1000 {
            let u = rng.try_next_u64();
            assert!(u.is_ok());
        }
    }

    #[test]
    fn test_speed() {
        use std::time::Instant;
        let start = Instant::now();
        let mut num_bytes = 0usize;
        let mut rng = RandJitterKernel::new().unwrap();
        loop {
            let mut b = [0u8; 32];
            rng.try_fill_bytes(&mut b).unwrap();
            num_bytes += b.len();
            let elapsed = start.elapsed();
            if elapsed.as_secs() > 2 {
                let datarate = f64::from(u32::try_from(num_bytes).unwrap())
                    / elapsed.as_secs_f64()
                    / 1024.0;
                println!("datarate: {datarate} KiB/s");
                break;
            }
        }
    }

    #[test]
    fn test_bytes() {
        let mut rng = RandJitterKernel::new().unwrap();
        for buffer_size in 0..=256 {
            let mut buffer = vec![0u8; buffer_size];
            assert!(rng.try_fill_bytes(&mut buffer).is_ok());
            println!("{buffer_size}: {buffer:#04X?}");
        }
    }

    #[test]
    fn test_large_bytes_but_ok() {
        let mut rng = RandJitterKernel::new().unwrap();
        let mut buffer = [0u8; 128];
        assert!(rng.try_fill_bytes_max_chunk_size(&mut buffer).is_ok());
    }

    #[test]
    fn test_too_large_bytes() {
        let mut rng = RandJitterKernel::new().unwrap();
        let mut buffer = [0u8; 129];
        assert!(rng.try_fill_bytes_max_chunk_size(&mut buffer).is_err());
    }

    #[test]
    fn test_multi_instantiation() {
        for _ in 0..256 {
            let mut rng = RandJitterKernel::new().unwrap();
            let u = rng.try_next_u32().unwrap();
            println!("Got {u}");
        }
    }

    #[test]
    fn test_multi_threading() {
        let mut rng = RandJitterKernel::new().unwrap();
        let _ = rng.try_next_u64().unwrap();
        println!("Got bytes (single threaded)!");
        let threads: Vec<_> = (0..6)
            .map(|_| {
                std::thread::spawn(move || {
                    for _ in 0..1024 {
                        let mut rng = RandJitterKernel::new().unwrap();
                        let _ = rng.try_next_u64().unwrap();
                    }
                })
            })
            .collect();
        for t in threads {
            // Propagate worker panics: discarding the join result would let the
            // test pass even when every worker failed.
            t.join().expect("worker thread panicked");
        }
    }
}