use core::future::poll_fn;
use core::marker::PhantomData;
use core::ops::Not;
use core::task::Poll;
use embassy_hal_internal::{Peri, PeripheralType};
use embassy_sync::waitqueue::AtomicWaker;
use crate::interrupt::typelevel::{Binding, Interrupt};
use crate::peripherals::TRNG;
use crate::{interrupt, pac};
/// Crate-private half of [`Instance`]: hands the driver the peripheral's registers and its waker.
trait SealedInstance {
/// Returns the PAC register block handle for this TRNG instance.
fn regs() -> pac::trng::Trng;
/// Returns the static waker the interrupt handler wakes and the async code registers on.
fn waker() -> &'static AtomicWaker;
}
// `SealedInstance` is not `pub`, so using it as a supertrait of a public trait
// needs the `private_bounds` allowance below.
/// TRNG peripheral instance.
#[allow(private_bounds)]
pub trait Instance: SealedInstance + PeripheralType {
/// Interrupt type associated with this TRNG instance.
type Interrupt: Interrupt;
}
/// Register and waker access for the `TRNG` peripheral.
impl SealedInstance for TRNG {
    // Spelled via the `pac` alias so the signature reads the same as the trait
    // declaration and the returned `pac::TRNG` value.
    fn regs() -> pac::trng::Trng {
        pac::TRNG
    }

    fn waker() -> &'static AtomicWaker {
        // A function-local static: every call hands out the same waker.
        static WAKER: AtomicWaker = AtomicWaker::new();
        &WAKER
    }
}
/// Associates the `TRNG` peripheral with its `TRNG_IRQ` interrupt.
impl Instance for TRNG {
type Interrupt = interrupt::typelevel::TRNG_IRQ;
}
/// Selectable inverter chain lengths for the TRNG.
///
/// Each variant's numeric value is the raw value the driver writes into the
/// `rnd_src_sel` field of the TRNG configuration register.
#[derive(Copy, Clone, Debug)]
#[allow(missing_docs)]
pub enum InverterChainLength {
    None = 0,
    One = 1,
    Two = 2,
    Three = 3,
    Four = 4,
}

impl From<InverterChainLength> for u8 {
    /// Returns the raw discriminant (`0..=4`) of the chain length.
    fn from(value: InverterChainLength) -> Self {
        // The enum is field-less with explicit discriminants, so the cast is exact.
        value as Self
    }
}
/// Configuration the driver programs into the TRNG when it is initialised.
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
pub struct Config {
/// When `true`, the autocorrelation-test bypass bit is set in the debug control register.
pub disable_autocorrelation_test: bool,
/// When `true`, the CRNGT bypass bit is set in the debug control register.
pub disable_crngt_test: bool,
/// When `true`, the Von Neumann balancer bypass bit is set in the debug control register.
pub disable_von_neumann_balancer: bool,
/// Value written to the `sample_cnt1` register.
///
/// The interrupt handler's warnings suggest increasing this value when
/// entropy-test errors are reported.
pub sample_count: u32,
/// Inverter chain selection written to the `rnd_src_sel` field of the TRNG config register.
pub inverter_chain_length: InverterChainLength,
}
impl Default for Config {
    /// Returns the default configuration: no test is bypassed, the sample
    /// count is 25 and the inverter chain length is [`InverterChainLength::One`].
    fn default() -> Self {
        Self {
            inverter_chain_length: InverterChainLength::One,
            sample_count: 25,
            disable_von_neumann_balancer: false,
            disable_crngt_test: false,
            disable_autocorrelation_test: false,
        }
    }
}
/// TRNG driver.
pub struct Trng<'d, T: Instance> {
// Ties the driver to the `'d` borrow of the peripheral without storing a value.
phantom: PhantomData<&'d mut T>,
// Configuration that is (re)applied every time the TRNG is initialised.
config: Config,
}
/// Size in bits of one block of random data (six 32-bit `ehr_data` registers).
const TRNG_BLOCK_SIZE_BITS: usize = 192;
/// Size in bytes of one block of random data.
const TRNG_BLOCK_SIZE_BYTES: usize = TRNG_BLOCK_SIZE_BITS / 8;
impl<'d, T: Instance> Trng<'d, T> {
pub fn new(_trng: Peri<'d, T>, _irq: impl Binding<T::Interrupt, InterruptHandler<T>> + 'd, config: Config) -> Self {
let trng = Trng {
phantom: PhantomData,
config: config,
};
trng.initialize_rng();
trng
}
/// Sets the `rnd_src_en` bit, switching the random source on.
fn start_rng(&self) {
    T::regs().rnd_source_enable().write(|w| w.set_rnd_src_en(true));
}
/// Clears the `rnd_src_en` bit and then requests a reset of the bits counter.
fn stop_rng(&self) {
    let trng = T::regs();
    // Keep this order: source off first, counter reset second.
    trng.rnd_source_enable().write(|w| w.set_rnd_src_en(false));
    trng.rst_bits_counter().write(|w| w.set_rst_bits_counter(true));
}
/// Programs the TRNG registers from the stored [`Config`].
///
/// Called on construction and again after a software reset of the TRNG.
///
/// NOTE(review): every register is written with `write` rather than `modify`,
/// so fields not set here take whatever value the PAC's `write` starts from —
/// confirm this is intended for `rng_imr`.
fn initialize_rng(&self) {
    let regs = T::regs();
    let config = &self.config;

    // Clear the mask bit of the EHR-valid interrupt.
    regs.rng_imr().write(|w| w.set_ehr_valid_int_mask(false));

    // `InverterChainLength` is `Copy`, so it converts directly without a clone.
    regs.trng_config()
        .write(|w| w.set_rnd_src_sel(config.inverter_chain_length.into()));

    regs.sample_cnt1().write(|w| *w = config.sample_count);

    regs.trng_debug_control().write(|w| {
        w.set_auto_correlate_bypass(config.disable_autocorrelation_test);
        w.set_trng_crngt_bypass(config.disable_crngt_test);
        w.set_vnc_bypass(config.disable_von_neumann_balancer);
    });
}
/// Enables the interrupt associated with this TRNG instance.
fn enable_irq(&self) {
// SAFETY: NOTE(review): no justification is recorded here. Presumably sound because
// `new` demands a `Binding` for this interrupt, so a handler is registered — confirm.
unsafe { T::Interrupt::enable() }
}
/// Disables the interrupt associated with this TRNG instance.
fn disable_irq(&self) {
T::Interrupt::disable();
}
fn blocking_wait_for_successful_generation(&self) {
let regs = T::regs();
let trng_busy_register = regs.trng_busy();
let trng_valid_register = regs.trng_valid();
let mut success = false;
while success.not() {
while trng_busy_register.read().trng_busy() {}
if trng_valid_register.read().ehr_valid().not() {
if regs.rng_isr().read().autocorr_err() {
regs.trng_sw_reset().write(|w| w.set_trng_sw_reset(true));
regs.trng_sw_reset().read();
self.initialize_rng();
self.start_rng();
} else {
panic!("RNG not busy, but ehr is not valid!")
}
} else {
success = true
}
}
}
/// Copies the six `ehr_data` registers into `buffer`, four native-endian bytes
/// per register, in register order (`ehr_data0` first, `ehr_data5` last).
///
/// Does not wait for the data to become valid; callers check that beforehand.
fn read_ehr_registers_into_array(&mut self, buffer: &mut [u8; TRNG_BLOCK_SIZE_BYTES]) {
    let regs = T::regs();
    let ehr_data_regs = [
        regs.ehr_data0(),
        regs.ehr_data1(),
        regs.ehr_data2(),
        regs.ehr_data3(),
        regs.ehr_data4(),
        regs.ehr_data5(),
    ];
    // Pair each 4-byte window of the buffer with its register; the registers are
    // read in ascending order.
    // NOTE(review): reading `ehr_data5` last presumably matters to the hardware
    // (clearing the valid flag) — confirm against the datasheet.
    for (chunk, reg) in buffer.chunks_exact_mut(4).zip(ehr_data_regs.iter()) {
        chunk.copy_from_slice(&reg.read().to_ne_bytes());
    }
}
/// Busy-waits for a valid block of random data, then copies it into `buffer`.
///
/// Can panic: see `blocking_wait_for_successful_generation`.
fn blocking_read_ehr_registers_into_array(&mut self, buffer: &mut [u8; TRNG_BLOCK_SIZE_BYTES]) {
self.blocking_wait_for_successful_generation();
self.read_ehr_registers_into_array(buffer);
}
/// Asynchronously fills `destination` with random bytes.
///
/// Starts the random source and enables the TRNG interrupt, then copies one
/// block of up to `TRNG_BLOCK_SIZE_BYTES` bytes each time a poll finds valid
/// data. Once the buffer is full the source is stopped and the interrupt is
/// disabled. An empty `destination` returns immediately without touching the
/// peripheral.
///
/// NOTE(review): if the returned future is dropped before completion, the
/// source stays enabled and the interrupt is not disabled.
pub async fn fill_bytes(&mut self, destination: &mut [u8]) {
    if destination.is_empty() {
        return;
    }

    self.start_rng();
    self.enable_irq();

    let regs = T::regs();
    let busy = regs.trng_busy();
    let valid = regs.trng_valid();
    let waker = T::waker();

    let total = destination.len();
    let mut filled = 0usize;
    let mut block = [0u8; TRNG_BLOCK_SIZE_BYTES];

    poll_fn(|cx| {
        // Register before inspecting the hardware so that a wake-up arriving
        // in between is not lost.
        waker.register(cx.waker());

        if filled == total {
            self.stop_rng();
            self.disable_irq();
            return Poll::Ready(());
        }

        if busy.read().trng_busy() {
            return Poll::Pending;
        }

        // Idle but no valid data: reprogram and restart, then wait for the next wake-up.
        if valid.read().ehr_valid().not() {
            self.initialize_rng();
            self.start_rng();
            return Poll::Pending;
        }

        self.read_ehr_registers_into_array(&mut block);

        // Copy a whole block, or only what is still missing for the final partial block.
        let take = (total - filled).min(TRNG_BLOCK_SIZE_BYTES);
        destination[filled..filled + take].copy_from_slice(&block[..take]);
        filled += take;

        if filled != total {
            return Poll::Pending;
        }

        self.stop_rng();
        self.disable_irq();
        Poll::Ready(())
    })
    .await
}
/// Fills `destination` with random bytes, busy-waiting for each block.
///
/// The random source is started before the first block and stopped after the
/// last one. An empty `destination` returns immediately without touching the
/// peripheral.
///
/// # Panics
///
/// Panics if the TRNG goes idle without valid data and without an
/// autocorrelation error being flagged.
pub fn blocking_fill_bytes(&mut self, destination: &mut [u8]) {
    if destination.is_empty() {
        return;
    }

    self.start_rng();

    let mut buffer = [0u8; TRNG_BLOCK_SIZE_BYTES];
    for chunk in destination.chunks_mut(TRNG_BLOCK_SIZE_BYTES) {
        // This helper already waits for a successful generation before reading,
        // so no separate wait is needed here.
        self.blocking_read_ehr_registers_into_array(&mut buffer);
        // The last chunk may be shorter than a full block.
        chunk.copy_from_slice(&buffer[..chunk.len()]);
    }

    self.stop_rng();
}
/// Busy-waits for a block of random data and returns the word read from `ehr_data5`.
///
/// The random source is started before the wait and stopped after the read.
///
/// # Panics
///
/// Panics if the TRNG goes idle without valid data and without an
/// autocorrelation error being flagged.
pub fn blocking_next_u32(&mut self) -> u32 {
    self.start_rng();
    self.blocking_wait_for_successful_generation();
    let word = T::regs().ehr_data5().read();
    self.stop_rng();
    word
}
/// Busy-waits for a block of random data and returns a 64-bit value: `ehr_data5`
/// supplies the upper 32 bits and `ehr_data4` the lower 32 bits.
///
/// The random source is started before the wait and stopped after the reads.
///
/// # Panics
///
/// Panics if the TRNG goes idle without valid data and without an
/// autocorrelation error being flagged.
pub fn blocking_next_u64(&mut self) -> u64 {
    self.start_rng();
    self.blocking_wait_for_successful_generation();

    let regs = T::regs();
    // Read order is kept: `ehr_data4` first, `ehr_data5` second.
    let lower = u64::from(regs.ehr_data4().read());
    let upper = u64::from(regs.ehr_data5().read());

    self.stop_rng();
    (upper << 32) | lower
}
}
/// `RngCore` from the `rand_core_06` crate, backed by the blocking TRNG methods.
impl<'d, T: Instance> rand_core_06::RngCore for Trng<'d, T> {
/// Delegates to [`Trng::blocking_next_u32`].
fn next_u32(&mut self) -> u32 {
self.blocking_next_u32()
}
/// Delegates to [`Trng::blocking_next_u64`].
fn next_u64(&mut self) -> u64 {
self.blocking_next_u64()
}
/// Delegates to [`Trng::blocking_fill_bytes`].
fn fill_bytes(&mut self, dest: &mut [u8]) {
self.blocking_fill_bytes(dest);
}
/// Delegates to [`Trng::blocking_fill_bytes`] and always returns `Ok(())`.
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core_06::Error> {
self.blocking_fill_bytes(dest);
Ok(())
}
}
/// Marker impl: declares [`Trng`] a `CryptoRng` for the `rand_core_06` crate.
impl<'d, T: Instance> rand_core_06::CryptoRng for Trng<'d, T> {}
/// `RngCore` from the `rand_core_09` crate, backed by the blocking TRNG methods.
impl<'d, T: Instance> rand_core_09::RngCore for Trng<'d, T> {
/// Delegates to [`Trng::blocking_next_u32`].
fn next_u32(&mut self) -> u32 {
self.blocking_next_u32()
}
/// Delegates to [`Trng::blocking_next_u64`].
fn next_u64(&mut self) -> u64 {
self.blocking_next_u64()
}
/// Delegates to [`Trng::blocking_fill_bytes`].
fn fill_bytes(&mut self, dest: &mut [u8]) {
self.blocking_fill_bytes(dest);
}
}
/// Marker impl: declares [`Trng`] a `CryptoRng` for the `rand_core_09` crate.
impl<'d, T: Instance> rand_core_09::CryptoRng for Trng<'d, T> {}
/// Interrupt handler type for the TRNG; [`Trng::new`] requires a binding that names it.
pub struct InterruptHandler<T: Instance> {
// Only carries the instance type; no value is stored.
_trng: PhantomData<T>,
}
impl<T: Instance> interrupt::typelevel::Handler<T::Interrupt> for InterruptHandler<T> {
unsafe fn on_interrupt() {
let regs = T::regs();
let isr = regs.rng_isr().read();
if isr.ehr_valid() {
regs.rng_icr().write(|w| {
w.set_ehr_valid(true);
});
T::waker().wake();
} else if isr.crngt_err() {
warn!("TRNG CRNGT error! Increase sample count to reduce likelihood");
regs.rng_icr().write(|w| {
w.set_crngt_err(true);
});
} else if isr.vn_err() {
warn!("TRNG Von-Neumann balancer error! Increase sample count to reduce likelihood");
regs.rng_icr().write(|w| {
w.set_vn_err(true);
});
} else if isr.autocorr_err() {
warn!("TRNG Autocorrect error! Resetting TRNG. Increase sample count to reduce likelihood");
regs.trng_sw_reset().write(|w| {
w.set_trng_sw_reset(true);
});
regs.trng_sw_reset().read();
T::waker().wake();
}
}
}