use core::{marker::PhantomData, ptr::copy_nonoverlapping};
use crate::{
peripheral::{Peripheral, PeripheralRef},
peripherals::RSA,
system::{Peripheral as PeripheralEnable, PeripheralClockControl},
};
// Chip-family-specific half of the driver. The `path` attributes select the
// source file for the target chip: ESP32-S2/S3 share `esp32sX.rs`, ESP32-C3/C6/H2
// share `esp32cX.rs`, and the original ESP32 uses `esp32.rs`. Methods used in
// this file but not defined here (e.g. `is_idle`, `clear_interrupt`, `set_start`)
// presumably live in that module — verify.
#[cfg_attr(esp32s2, path = "esp32sX.rs")]
#[cfg_attr(esp32s3, path = "esp32sX.rs")]
#[cfg_attr(esp32c3, path = "esp32cX.rs")]
#[cfg_attr(esp32c6, path = "esp32cX.rs")]
#[cfg_attr(esp32h2, path = "esp32cX.rs")]
#[cfg_attr(esp32, path = "esp32.rs")]
mod rsa_spec_impl;
// Re-export of the chip-specific `operand_sizes` module (presumably the `Op<bits>`
// marker types generated with `implement_op!` — verify).
pub use rsa_spec_impl::operand_sizes;
/// Driver for the RSA hardware accelerator.
///
/// Wraps a `PeripheralRef` to the `RSA` peripheral for the lifetime `'d`.
pub struct Rsa<'d> {
// Handle to the RSA register block; all register/memory access goes through it.
rsa: PeripheralRef<'d, RSA>,
}
impl<'d> Rsa<'d> {
    /// Creates the driver from the `RSA` peripheral and enables the RSA
    /// peripheral via `PeripheralClockControl`.
    pub fn new(rsa: impl Peripheral<P = RSA> + 'd) -> Self {
        crate::into_ref!(rsa);
        PeripheralClockControl::enable(PeripheralEnable::Rsa);
        Self { rsa }
    }

    /// Copies `operand_b` into the accelerator's `y_mem` block.
    ///
    /// # Safety
    ///
    /// Writes `N` words starting at `y_mem(0)`; the caller must ensure the
    /// hardware memory block is at least `N` words long.
    unsafe fn write_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
        copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N);
    }

    /// Copies `modulus` into the accelerator's `m_mem` block.
    ///
    /// # Safety
    ///
    /// Writes `N` words starting at `m_mem(0)`; the caller must ensure the
    /// hardware memory block is at least `N` words long.
    unsafe fn write_modulus<const N: usize>(&mut self, modulus: &[u32; N]) {
        copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N);
    }

    /// Writes the `M'` value into the `m_prime` register.
    fn write_mprime(&mut self, m_prime: u32) {
        // SAFETY: the PAC marks raw `bits` writes unsafe; this assumes every 32-bit
        // value is acceptable for `m_prime` — NOTE(review): confirm against the TRM.
        self.rsa.m_prime().write(|w| unsafe { w.bits(m_prime) });
    }

    /// Copies `operand_a` into the accelerator's `x_mem` block.
    ///
    /// # Safety
    ///
    /// Writes `N` words starting at `x_mem(0)`; the caller must ensure the
    /// hardware memory block is at least `N` words long.
    unsafe fn write_operand_a<const N: usize>(&mut self, operand_a: &[u32; N]) {
        copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N);
    }

    /// Copies `r` into the accelerator's `z_mem` block.
    ///
    /// `z_mem` is also the block that [`Self::read_out`] reads results from.
    ///
    /// # Safety
    ///
    /// Writes `N` words starting at `z_mem(0)`; the caller must ensure the
    /// hardware memory block is at least `N` words long.
    unsafe fn write_r<const N: usize>(&mut self, r: &[u32; N]) {
        copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N);
    }

    /// Copies `N` words from the accelerator's `z_mem` block into `outbuf`.
    ///
    /// # Safety
    ///
    /// Reads `N` words starting at `z_mem(0)`; the caller must ensure the
    /// hardware memory block is at least `N` words long.
    unsafe fn read_out<const N: usize>(&mut self, outbuf: &mut [u32; N]) {
        // The destination pointer must be derived from the unique (`&mut`) borrow:
        // `as_ptr()` goes through a shared reborrow, and writing through a pointer
        // obtained that way is undefined behaviour.
        copy_nonoverlapping(
            self.rsa.z_mem(0).as_ptr() as *const u32,
            outbuf.as_mut_ptr(),
            N,
        );
    }
}
/// Operand-size marker for RSA operations.
///
/// Implemented by the `Op<bits>` types generated with `implement_op!`. The
/// `crate::private::Sealed` supertrait presumably restricts implementations to
/// this crate (sealed-trait pattern) — verify.
pub trait RsaMode: crate::private::Sealed {
// Operand buffer type; `implement_op!` sets it to `[u32; BITS / 32]`.
type InputType;
}
/// Operand sizes that can also be used with [`RsaMultiplication`].
pub trait Multi: RsaMode {
// Product buffer type; `implement_op!` sets it to `[u32; BITS * 2 / 32]`,
// i.e. twice as many words as `InputType`.
type OutputType;
}
/// Generates operand-size marker types `Op<BITS>` implementing [`RsaMode`]
/// (and optionally [`Multi`]).
///
/// Accepts a comma-separated list of entries, each either `(BITS)` or
/// `(BITS, multi)`; e.g. `(512, multi)` yields `Op512` with `InputType = [u32; 16]`
/// and `OutputType = [u32; 32]`. The expansion invokes `paste!` unqualified, so it
/// must be in scope at the invocation site (presumably `paste::paste` — verify).
macro_rules! implement_op {
// `(BITS, multi)`: marker type usable for modular ops and plain multiplication.
(($x:literal, multi)) => {
paste! {pub struct [<Op $x>];}
paste! {
impl Multi for [<Op $x>] {
type OutputType = [u32; $x*2 / 32];
}}
paste! {
impl crate::private::Sealed for [<Op $x>] {}
}
paste! {
impl RsaMode for [<Op $x>] {
type InputType = [u32; $x / 32];
}}
};
// `(BITS)`: marker type usable for the modular operations only.
(($x:literal)) => {
paste! {pub struct [<Op $x>];}
paste! {
impl crate::private::Sealed for [<Op $x>] {}
}
paste!{
impl RsaMode for [<Op $x>] {
type InputType = [u32; $x / 32];
}}
};
// List form: expand the first entry, then recurse on the rest. A single
// trailing entry matches one of the two arms above.
($x:tt, $($y:tt),+) => {
implement_op!($x);
implement_op!($($y),+);
};
}
use implement_op;
/// Handle for running modular exponentiations with operand size `T`.
///
/// Holds the [`Rsa`] driver by `&'a mut`, so the driver cannot be used for
/// anything else while this handle exists.
pub struct RsaModularExponentiation<'a, 'd, T: RsaMode> {
rsa: &'a mut Rsa<'d>,
// `T` only selects the operand size at the type level.
phantom: PhantomData<T>,
}
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Writes `base` into the X memory block and `r` into the Z memory block,
    /// then triggers the operation via `set_start`.
    ///
    /// `set_start` is not defined in this file (presumably chip-specific — verify).
    pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
        // SAFETY: both operands are `[u32; N]`, the width selected by `T`; assumes
        // the hardware X/Z blocks hold at least `N` words — TODO confirm.
        unsafe {
            self.rsa.write_operand_a(base);
            self.rsa.write_r(r);
        }
        self.set_start();
    }

    /// Spins until the accelerator reports idle, copies the result into
    /// `outbuf`, and clears the RSA interrupt.
    pub fn read_results(&mut self, outbuf: &mut T::InputType) {
        while !self.rsa.is_idle() {}

        // SAFETY: `outbuf` is `[u32; N]`; assumes the hardware Z block holds at
        // least `N` words — TODO confirm.
        unsafe { self.rsa.read_out(outbuf) };
        self.rsa.clear_interrupt();
    }
}
/// Handle for running modular multiplications with operand size `T`.
///
/// Holds the [`Rsa`] driver by `&'a mut`, so the driver cannot be used for
/// anything else while this handle exists.
pub struct RsaModularMultiplication<'a, 'd, T: RsaMode> {
rsa: &'a mut Rsa<'d>,
// `T` only selects the operand size at the type level.
phantom: PhantomData<T>,
}
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Spins until the accelerator reports idle, copies the result into
    /// `outbuf`, and clears the RSA interrupt.
    pub fn read_results(&mut self, outbuf: &mut T::InputType) {
        while !self.rsa.is_idle() {}

        // SAFETY: `outbuf` is `[u32; N]`; assumes the hardware Z block holds at
        // least `N` words — TODO confirm.
        unsafe { self.rsa.read_out(outbuf) };
        self.rsa.clear_interrupt();
    }
}
/// Handle for running plain (double-width result) multiplications with operand
/// size `T`.
///
/// Holds the [`Rsa`] driver by `&'a mut`, so the driver cannot be used for
/// anything else while this handle exists.
pub struct RsaMultiplication<'a, 'd, T: RsaMode + Multi> {
rsa: &'a mut Rsa<'d>,
// `T` only selects the operand/result sizes at the type level.
phantom: PhantomData<T>,
}
impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Spins until the accelerator reports idle, copies the product (`O` words,
    /// the `Multi::OutputType` width) into `outbuf`, and clears the RSA interrupt.
    pub fn read_results<const O: usize>(&mut self, outbuf: &mut T::OutputType)
    where
        T: Multi<OutputType = [u32; O]>,
    {
        while !self.rsa.is_idle() {}

        // SAFETY: `outbuf` is `[u32; O]`; assumes the hardware Z block holds at
        // least `O` words — TODO confirm.
        unsafe { self.rsa.read_out(outbuf) };
        self.rsa.clear_interrupt();
    }
}
#[cfg(feature = "async")]
pub(crate) mod asynch {
    //! Async front-ends for the RSA operations.
    //!
    //! Each operation is started synchronously, the RSA interrupt is enabled, and
    //! the task then awaits an [`RsaFuture`]. The `RSA` interrupt handler at the
    //! end of this module clears the interrupt-enable bit and wakes the task; the
    //! future treats a cleared enable bit as "operation finished".

    use core::task::Poll;

    use embassy_sync::waitqueue::AtomicWaker;
    use procmacros::interrupt;

    use crate::rsa::{
        Multi,
        RsaMode,
        RsaModularExponentiation,
        RsaModularMultiplication,
        RsaMultiplication,
    };

    /// Waker shared between the pending [`RsaFuture`] and the `RSA` interrupt handler.
    static WAKER: AtomicWaker = AtomicWaker::new();

    /// Future that resolves once the `RSA` interrupt handler has run.
    pub(crate) struct RsaFuture<'d> {
        instance: &'d crate::peripherals::RSA,
    }

    impl<'d> RsaFuture<'d> {
        /// Enables the RSA interrupt and returns a future that completes when the
        /// interrupt handler has cleared the enable bit again.
        ///
        /// Deliberately a plain `fn`: callers write `RsaFuture::new(..).await`,
        /// which must await the returned future itself. As an `async fn`, that
        /// expression would only await construction and then drop the `RsaFuture`
        /// unpolled, so the task would never wait for the interrupt.
        pub fn new(instance: &'d crate::peripherals::RSA) -> Self {
            #[cfg(not(any(esp32, esp32s2, esp32s3)))]
            instance.int_ena().modify(|_, w| w.int_ena().set_bit());
            #[cfg(any(esp32s2, esp32s3))]
            instance
                .interrupt_ena()
                .modify(|_, w| w.interrupt_ena().set_bit());
            #[cfg(esp32)]
            instance.interrupt().modify(|_, w| w.interrupt().set_bit());
            Self { instance }
        }

        /// Returns `true` once the interrupt-enable bit set in [`Self::new`] has been
        /// cleared (done by the `RSA` interrupt handler in this module).
        fn event_bit_is_clear(&self) -> bool {
            #[cfg(not(any(esp32, esp32s2, esp32s3)))]
            return self.instance.int_ena().read().int_ena().bit_is_clear();
            #[cfg(any(esp32s2, esp32s3))]
            return self
                .instance
                .interrupt_ena()
                .read()
                .interrupt_ena()
                .bit_is_clear();
            #[cfg(esp32)]
            return self.instance.interrupt().read().interrupt().bit_is_clear();
        }
    }

    impl<'d> core::future::Future for RsaFuture<'d> {
        type Output = ();

        fn poll(
            self: core::pin::Pin<&mut Self>,
            cx: &mut core::task::Context<'_>,
        ) -> core::task::Poll<Self::Output> {
            // Register before checking, so a wake-up that happens between the check
            // and returning `Pending` is not lost.
            WAKER.register(cx.waker());
            if self.event_bit_is_clear() {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }

    impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T>
    where
        T: RsaMode<InputType = [u32; N]>,
    {
        /// Starts a modular exponentiation of `base` (with `r` loaded into the Z
        /// block), awaits the RSA interrupt, then copies the result into `outbuf`.
        pub async fn exponentiation(
            &mut self,
            base: &T::InputType,
            r: &T::InputType,
            outbuf: &mut T::InputType,
        ) {
            self.start_exponentiation(base, r);
            RsaFuture::new(&self.rsa.rsa).await;
            self.read_results(outbuf);
        }
    }

    impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T>
    where
        T: RsaMode<InputType = [u32; N]>,
    {
        /// Starts a modular multiplication via `start_modular_multiplication(r)`
        /// (not defined in this file — presumably chip-specific), awaits the RSA
        /// interrupt, then copies the result into `outbuf`.
        #[cfg(not(esp32))]
        pub async fn modular_multiplication(
            &mut self,
            r: &T::InputType,
            outbuf: &mut T::InputType,
        ) {
            self.start_modular_multiplication(r);
            RsaFuture::new(&self.rsa.rsa).await;
            self.read_results(outbuf);
        }

        /// ESP32 variant: starts the operation in two steps (`start_step1` with
        /// `operand_a` and `r`, then `start_step2` with `operand_b`), awaits the RSA
        /// interrupt, then copies the result into `outbuf`.
        #[cfg(esp32)]
        pub async fn modular_multiplication(
            &mut self,
            operand_a: &T::InputType,
            operand_b: &T::InputType,
            r: &T::InputType,
            outbuf: &mut T::InputType,
        ) {
            self.start_step1(operand_a, r);
            self.start_step2(operand_b);
            RsaFuture::new(&self.rsa.rsa).await;
            self.read_results(outbuf);
        }
    }

    impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T>
    where
        T: RsaMode<InputType = [u32; N]>,
    {
        /// Starts a multiplication with `operand_b`, awaits the RSA interrupt, then
        /// copies the `O`-word product into `outbuf`.
        // The unused `'b` parameter is kept so the generic signature is unchanged.
        #[cfg(not(esp32))]
        pub async fn multiplication<'b, const O: usize>(
            &mut self,
            operand_b: &T::InputType,
            outbuf: &mut T::OutputType,
        ) where
            T: Multi<OutputType = [u32; O]>,
        {
            self.start_multiplication(operand_b);
            RsaFuture::new(&self.rsa.rsa).await;
            self.read_results(outbuf);
        }

        /// ESP32 variant: starts a multiplication of `operand_a` by `operand_b`,
        /// awaits the RSA interrupt, then copies the `O`-word product into `outbuf`.
        // The unused `'b` parameter is kept so the generic signature is unchanged.
        #[cfg(esp32)]
        pub async fn multiplication<'b, const O: usize>(
            &mut self,
            operand_a: &T::InputType,
            operand_b: &T::InputType,
            outbuf: &mut T::OutputType,
        ) where
            T: Multi<OutputType = [u32; O]>,
        {
            self.start_multiplication(operand_a, operand_b);
            RsaFuture::new(&self.rsa.rsa).await;
            self.read_results(outbuf);
        }
    }

    /// RSA interrupt handler: clears the interrupt-enable bit (which is what
    /// [`RsaFuture`] polls for) and wakes the waiting task.
    #[interrupt]
    fn RSA() {
        #[cfg(not(any(esp32, esp32s2, esp32s3)))]
        unsafe { &*crate::peripherals::RSA::ptr() }
            .int_ena()
            .modify(|_, w| w.int_ena().clear_bit());
        #[cfg(esp32)]
        unsafe { &*crate::peripherals::RSA::ptr() }
            .interrupt()
            .modify(|_, w| w.interrupt().clear_bit());
        #[cfg(any(esp32s2, esp32s3))]
        unsafe { &*crate::peripherals::RSA::ptr() }
            .interrupt_ena()
            .modify(|_, w| w.interrupt_ena().clear_bit());
        WAKER.wake();
    }
}