use core::{marker::PhantomData, ptr::NonNull, task::Poll};
use portable_atomic::{AtomicBool, Ordering};
use procmacros::{handler, ram};
use crate::{
Async,
Blocking,
DriverMode,
asynch::AtomicWaker,
interrupt::InterruptHandler,
pac,
peripherals::{Interrupt, RSA},
system::{Cpu, GenericPeripheralGuard, Peripheral as PeripheralEnable},
trm_markdown_link,
work_queue::{self, Status, VTable, WorkQueue, WorkQueueDriver, WorkQueueFrontend},
};
/// Driver handle for the RSA peripheral, parameterised over the driver mode `Dm`.
pub struct Rsa<'d, Dm: DriverMode> {
// The RSA peripheral instance this driver operates on.
rsa: RSA<'d>,
// Zero-sized marker that carries the driver mode type parameter.
phantom: PhantomData<Dm>,
// Only present on chips other than the ESP32. Creating the guard reconfigures the
// RSA memory power-down control bits; dropping it reverts them.
#[cfg(not(esp32))]
_memory_guard: RsaMemoryPowerGuard,
// Peripheral guard for the RSA block.
// NOTE(review): presumably keeps the peripheral enabled while the driver exists —
// confirm against `GenericPeripheralGuard`.
_guard: GenericPeripheralGuard<{ PeripheralEnable::Rsa as u8 }>,
}
/// The chip's `rsa.size_increment` property divided by 32, i.e. the number of 32-bit
/// words per operand size step (assuming the property is expressed in bits — TODO confirm).
///
/// Used to turn an operand length in words into the value written to the mode registers.
const WORDS_PER_INCREMENT: u32 = property!("rsa.size_increment") / 32;
/// Guard tied to the RSA memory power-down control bits in the `SYSTEM` peripheral.
///
/// Constructing it sets `rsa_mem_force_pu` and clears both power-down bits; dropping it
/// clears the force bits and sets `rsa_mem_pd`.
#[cfg(not(esp32))]
struct RsaMemoryPowerGuard;
#[cfg(not(esp32))]
impl RsaMemoryPowerGuard {
fn new() -> Self {
crate::peripherals::SYSTEM::regs()
.rsa_pd_ctrl()
.modify(|_, w| {
w.rsa_mem_force_pd().clear_bit();
w.rsa_mem_force_pu().set_bit();
w.rsa_mem_pd().clear_bit()
});
Self
}
}
#[cfg(not(esp32))]
impl Drop for RsaMemoryPowerGuard {
    /// Releases the forced power-up and sets the RSA memory power-down bit.
    fn drop(&mut self) {
        let system = crate::peripherals::SYSTEM::regs();
        // Both force bits are cleared and the regular power-down bit is set, in one
        // read-modify-write of the register.
        system.rsa_pd_ctrl().modify(|_, w| {
            w.rsa_mem_force_pd()
                .clear_bit()
                .rsa_mem_force_pu()
                .clear_bit()
                .rsa_mem_pd()
                .set_bit()
        });
    }
}
impl<'d> Rsa<'d, Blocking> {
    /// Creates a driver in [`Blocking`] mode for the given RSA peripheral.
    ///
    /// Acquires the peripheral guard (and, on chips other than the ESP32, the memory
    /// power guard), then busy-waits until the peripheral reports that it is ready.
    pub fn new(rsa: RSA<'d>) -> Self {
        // The peripheral guard is acquired before the memory power guard.
        let peripheral_guard = GenericPeripheralGuard::new();
        let driver = Self {
            rsa,
            phantom: PhantomData,
            #[cfg(not(esp32))]
            _memory_guard: RsaMemoryPowerGuard::new(),
            _guard: peripheral_guard,
        };

        loop {
            if driver.ready() {
                break driver;
            }
        }
    }

    /// Converts the driver into [`Async`] mode.
    ///
    /// Binds `rsa_interrupt_handler` to the peripheral interrupt and enables the RSA
    /// interrupt before handing over the peripheral and its guards.
    pub fn into_async(mut self) -> Rsa<'d, Async> {
        self.set_interrupt_handler(rsa_interrupt_handler);
        self.enable_disable_interrupt(true);

        let Self {
            rsa,
            #[cfg(not(esp32))]
            _memory_guard,
            _guard,
            ..
        } = self;

        Rsa {
            rsa,
            phantom: PhantomData,
            #[cfg(not(esp32))]
            _memory_guard,
            _guard,
        }
    }

    /// Enables (`true`) or disables (`false`) the RSA peripheral interrupt.
    pub fn enable_disable_interrupt(&mut self, enable: bool) {
        self.internal_enable_disable_interrupt(enable);
    }

    /// Registers `handler` as the interrupt handler of the RSA peripheral.
    ///
    /// The peripheral interrupt is first disabled on all cores, then the new handler
    /// is bound.
    #[instability::unstable]
    pub fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
        self.rsa.disable_peri_interrupt_on_all_cores();
        self.rsa.bind_peri_interrupt(handler);
    }
}
// Marker impl of the crate-private `Sealed` trait for the blocking driver.
impl crate::private::Sealed for Rsa<'_, Blocking> {}
// Exposes interrupt handler registration for the blocking driver through the
// `InterruptConfigurable` trait.
#[instability::unstable]
impl crate::interrupt::InterruptConfigurable for Rsa<'_, Blocking> {
fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
// Method-call syntax prefers the inherent `Rsa::set_interrupt_handler` over this
// trait method, so this forwards to the inherent implementation rather than recursing.
self.set_interrupt_handler(handler);
}
}
impl<'d> Rsa<'d, Async> {
    /// Converts the driver back into [`Blocking`] mode.
    ///
    /// Disables the RSA interrupt in the peripheral, disables the peripheral interrupt
    /// on all cores and disables `Interrupt::RSA` on the current CPU, then hands over the
    /// peripheral and its guards.
    pub fn into_blocking(self) -> Rsa<'d, Blocking> {
        self.internal_enable_disable_interrupt(false);
        self.rsa.disable_peri_interrupt_on_all_cores();
        crate::interrupt::disable(Cpu::current(), Interrupt::RSA);

        let Self {
            rsa,
            #[cfg(not(esp32))]
            _memory_guard,
            _guard,
            ..
        } = self;

        Rsa {
            rsa,
            phantom: PhantomData,
            #[cfg(not(esp32))]
            _memory_guard,
            _guard,
        }
    }
}
impl<'d, Dm: DriverMode> Rsa<'d, Dm> {
    /// Writes `enable` to the interrupt control bit of the peripheral.
    ///
    /// The ESP32 uses the `interrupt` register for this, other chips use `int_ena`.
    fn internal_enable_disable_interrupt(&self, enable: bool) {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(esp32)] {
                regs.interrupt().write(|w| w.interrupt().bit(enable));
            } else {
                regs.int_ena().write(|w| w.int_ena().bit(enable));
            }
        }
    }

    /// Returns the register block of the RSA peripheral.
    fn regs(&self) -> &pac::rsa::RegisterBlock {
        self.rsa.register_block()
    }

    /// Returns `true` once the peripheral's "clean" flag is set.
    ///
    /// `Rsa::new` and the work-queue backend wait for this before using the peripheral.
    fn ready(&self) -> bool {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(any(esp32, esp32s2, esp32s3))] {
                regs.clean().read().clean().bit_is_set()
            } else {
                regs.query_clean().read().query_clean().bit_is_set()
            }
        }
    }

    /// Sets the start bit for a modular exponentiation.
    fn start_modexp(&self) {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(any(esp32, esp32s2, esp32s3))] {
                regs.modexp_start().write(|w| w.modexp_start().set_bit());
            } else {
                regs.set_start_modexp().write(|w| w.set_start_modexp().set_bit());
            }
        }
    }

    /// Sets the start bit for a multiplication.
    fn start_multi(&self) {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(any(esp32, esp32s2, esp32s3))] {
                regs.mult_start().write(|w| w.mult_start().set_bit());
            } else {
                regs.set_start_mult().write(|w| w.set_start_mult().set_bit());
            }
        }
    }

    /// Sets the start bit for a modular multiplication.
    ///
    /// On the ESP32 this goes through `start_multi`; the other chips have a dedicated
    /// start register.
    fn start_modmulti(&self) {
        cfg_if::cfg_if! {
            if #[cfg(esp32)] {
                self.start_multi();
            } else if #[cfg(any(esp32s2, esp32s3))] {
                self.regs().modmult_start().write(|w| w.modmult_start().set_bit());
            } else {
                self.regs().set_start_modmult().write(|w| w.set_start_modmult().set_bit());
            }
        }
    }

    /// Clears the interrupt flag by writing a set bit to the clear register
    /// (`interrupt` on the ESP32, `int_clr` elsewhere).
    fn clear_interrupt(&mut self) {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(esp32)] {
                regs.interrupt().write(|w| w.interrupt().set_bit());
            } else {
                regs.int_clr().write(|w| w.int_clr().set_bit());
            }
        }
    }

    /// Returns `true` when the peripheral signals that no operation is in progress.
    ///
    /// The ESP32 reads the `interrupt` bit for this; other chips read an idle flag.
    fn is_idle(&self) -> bool {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(esp32)] {
                regs.interrupt().read().interrupt().bit_is_set()
            } else if #[cfg(any(esp32s2, esp32s3))] {
                regs.idle().read().idle().bit_is_set()
            } else {
                regs.query_idle().read().query_idle().bit_is_set()
            }
        }
    }

    /// Busy-waits until the peripheral is idle, then clears the interrupt flag.
    fn wait_for_idle(&mut self) {
        loop {
            if self.is_idle() {
                break;
            }
        }
        self.clear_interrupt();
    }

    /// Writes the operand-size mode used by (modular) multiplication.
    ///
    /// On the ESP32 a non-modular multiplication additionally has bit 3 (value 8) set
    /// in the mode value.
    fn write_multi_mode(&mut self, mode: u32, modular: bool) {
        const NON_MODULAR: u32 = 8;

        let extra_bits = if cfg!(esp32) && !modular {
            NON_MODULAR
        } else {
            0
        };
        let value = mode | extra_bits;

        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(esp32)] {
                regs.mult_mode().write(|w| unsafe { w.bits(value) });
            } else {
                regs.mode().write(|w| unsafe { w.bits(value) });
            }
        }
    }

    /// Writes the operand-size mode used by modular exponentiation.
    fn write_modexp_mode(&mut self, mode: u32) {
        let regs = self.regs();
        cfg_if::cfg_if! {
            if #[cfg(esp32)] {
                regs.modexp_mode().write(|w| unsafe { w.bits(mode) });
            } else {
                regs.mode().write(|w| unsafe { w.bits(mode) });
            }
        }
    }

    /// Copies `operand` word by word into the Y memory, starting at its first word.
    fn write_operand_b(&mut self, operand: &[u32]) {
        let words = operand.iter().copied();
        self.regs().y_mem_iter().zip(words).for_each(|(cell, word)| {
            cell.write(|w| unsafe { w.bits(word) });
        });
    }

    /// Copies `modulus` word by word into the M memory, starting at its first word.
    fn write_modulus(&mut self, modulus: &[u32]) {
        let words = modulus.iter().copied();
        self.regs().m_mem_iter().zip(words).for_each(|(cell, word)| {
            cell.write(|w| unsafe { w.bits(word) });
        });
    }

    /// Writes `m_prime` to the `m_prime` register.
    fn write_mprime(&mut self, m_prime: u32) {
        let regs = self.regs();
        regs.m_prime().write(|w| unsafe { w.bits(m_prime) });
    }

    /// Copies `operand` word by word into the X memory, starting at its first word.
    fn write_operand_a(&mut self, operand: &[u32]) {
        let words = operand.iter().copied();
        self.regs().x_mem_iter().zip(words).for_each(|(cell, word)| {
            cell.write(|w| unsafe { w.bits(word) });
        });
    }

    /// Copies `operand` into the Z memory, skipping the first `operand.len()` words so
    /// the data lands in the words that follow.
    fn write_multi_operand_b(&mut self, operand: &[u32]) {
        let offset = operand.len();
        let words = operand.iter().copied();
        self.regs()
            .z_mem_iter()
            .skip(offset)
            .zip(words)
            .for_each(|(cell, word)| {
                cell.write(|w| unsafe { w.bits(word) });
            });
    }

    /// Copies `r` word by word into the Z memory, starting at its first word.
    fn write_r(&mut self, r: &[u32]) {
        let words = r.iter().copied();
        self.regs().z_mem_iter().zip(words).for_each(|(cell, word)| {
            cell.write(|w| unsafe { w.bits(word) });
        });
    }

    /// Fills `outbuf` from the Z memory, one word per element.
    fn read_out(&self, outbuf: &mut [u32]) {
        self.regs()
            .z_mem_iter()
            .zip(outbuf.iter_mut())
            .for_each(|(cell, slot)| *slot = cell.read().bits());
    }

    /// Waits for the running operation to finish, then reads the result into `outbuf`.
    fn read_results(&mut self, outbuf: &mut [u32]) {
        self.wait_for_idle();
        self.read_out(outbuf);
    }

    /// Writes `disable` to the `constant_time` register bit.
    ///
    /// The work-queue backend passes `true` here when an item has its `constant_time`
    /// option turned off.
    ///
    #[doc = trm_markdown_link!("rsa")]
    #[cfg(not(esp32))]
    pub fn disable_constant_time(&mut self, disable: bool) {
        let regs = self.regs();
        regs.constant_time().write(|w| w.constant_time().bit(disable));
    }

    /// Writes `enable` to the `search_enable` register bit.
    ///
    /// While the bit is set, `RsaModularExponentiation::new` also programs a search
    /// position derived from the exponent.
    ///
    #[doc = trm_markdown_link!("rsa")]
    #[cfg(not(esp32))]
    pub fn search_acceleration(&mut self, enable: bool) {
        let regs = self.regs();
        regs.search_enable().write(|w| w.search_enable().bit(enable));
    }

    /// Reads back the `search_enable` register bit.
    #[cfg(not(esp32))]
    fn is_search_enabled(&mut self) -> bool {
        let search_enable = self.regs().search_enable().read();
        search_enable.search_enable().bit_is_set()
    }

    /// Writes `search_position` to the `search_pos` register.
    #[cfg(not(esp32))]
    fn write_search_position(&mut self, search_position: u32) {
        let regs = self.regs();
        regs.search_pos().write(|w| unsafe { w.bits(search_position) });
    }
}
/// Sealed trait describing an operand size supported by the RSA driver.
pub trait RsaMode: crate::private::Sealed {
/// Operand storage for this size, viewable as a `[u32]` slice.
type InputType: AsRef<[u32]> + AsMut<[u32]>;
}
/// Extension of [`RsaMode`] for operand sizes that also support plain multiplication.
pub trait Multi: RsaMode {
/// Result storage for a multiplication with this operand size, viewable as a `[u32]` slice.
type OutputType: AsRef<[u32]> + AsMut<[u32]>;
}
/// Marker types for the operand sizes supported by the RSA peripheral.
pub mod operand_sizes {
// For every size `$x` yielded by `for_each_rsa_exponentiation!`, declare a marker
// type `Op$x` whose operands are `$x / 32` words long.
for_each_rsa_exponentiation!(
($x:literal) => {
paste::paste! {
#[doc = concat!(stringify!($x), "-bit RSA operation.")]
pub struct [<Op $x>];
impl crate::private::Sealed for [<Op $x>] {}
impl crate::rsa::RsaMode for [<Op $x>] {
type InputType = [u32; $x / 32];
}
}
};
);
// For every size `$x` yielded by `for_each_rsa_multiplication!`, additionally implement
// `Multi` with an output that is twice as many words as the input.
for_each_rsa_multiplication!(
($x:literal) => {
impl crate::rsa::Multi for paste::paste!( [<Op $x>] ) {
type OutputType = [u32; $x * 2 / 32];
}
};
);
}
/// Modular exponentiation operation bound to an [`Rsa`] driver for operand size `T`.
pub struct RsaModularExponentiation<'a, 'd, T: RsaMode, Dm: DriverMode> {
// Exclusive borrow of the driver for as long as the operation object exists.
rsa: &'a mut Rsa<'d, Dm>,
// Zero-sized marker for the operand size type.
phantom: PhantomData<T>,
}
impl<'a, 'd, T: RsaMode, Dm: DriverMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Dm>
where
T: RsaMode<InputType = [u32; N]>,
{
#[doc = trm_markdown_link!("rsa")]
pub fn new(
rsa: &'a mut Rsa<'d, Dm>,
exponent: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::write_mode(rsa);
rsa.write_operand_b(exponent);
rsa.write_modulus(modulus);
rsa.write_mprime(m_prime);
#[cfg(not(esp32))]
if rsa.is_search_enabled() {
rsa.write_search_position(Self::find_search_pos(exponent));
}
Self {
rsa,
phantom: PhantomData,
}
}
fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
self.rsa.write_operand_a(base);
self.rsa.write_r(r);
}
#[doc = trm_markdown_link!("rsa")]
pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
self.set_up_exponentiation(base, r);
self.rsa.start_modexp();
}
pub fn read_results(&mut self, outbuf: &mut T::InputType) {
self.rsa.read_results(outbuf);
}
#[cfg(not(esp32))]
fn find_search_pos(exponent: &T::InputType) -> u32 {
for (i, byte) in exponent.iter().rev().enumerate() {
if *byte == 0 {
continue;
}
return (exponent.len() * 32) as u32 - (byte.leading_zeros() + i as u32 * 32) - 1;
}
0
}
fn write_mode(rsa: &mut Rsa<'d, Dm>) {
rsa.write_modexp_mode(N as u32 / WORDS_PER_INCREMENT - 1);
}
}
/// Modular multiplication operation bound to an [`Rsa`] driver for operand size `T`.
pub struct RsaModularMultiplication<'a, 'd, T, Dm>
where
T: RsaMode,
Dm: DriverMode,
{
// Exclusive borrow of the driver for as long as the operation object exists.
rsa: &'a mut Rsa<'d, Dm>,
// Zero-sized marker for the operand size type.
phantom: PhantomData<T>,
}
impl<'a, 'd, T, Dm, const N: usize> RsaModularMultiplication<'a, 'd, T, Dm>
where
T: RsaMode<InputType = [u32; N]>,
Dm: DriverMode,
{
#[doc = trm_markdown_link!("rsa")]
pub fn new(
rsa: &'a mut Rsa<'d, Dm>,
operand_a: &T::InputType,
modulus: &T::InputType,
r: &T::InputType,
m_prime: u32,
) -> Self {
rsa.write_multi_mode(N as u32 / WORDS_PER_INCREMENT - 1, true);
rsa.write_mprime(m_prime);
rsa.write_modulus(modulus);
rsa.write_operand_a(operand_a);
rsa.write_r(r);
Self {
rsa,
phantom: PhantomData,
}
}
#[doc = trm_markdown_link!("rsa")]
pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) {
self.set_up_modular_multiplication(operand_b);
self.rsa.start_modmulti();
}
pub fn read_results(&mut self, outbuf: &mut T::InputType) {
self.rsa.read_results(outbuf);
}
fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
if cfg!(esp32) {
self.rsa.start_multi();
self.rsa.wait_for_idle();
self.rsa.write_operand_a(operand_b);
} else {
self.rsa.write_operand_b(operand_b);
}
}
}
/// Plain (non-modular) multiplication operation bound to an [`Rsa`] driver for operand size `T`.
pub struct RsaMultiplication<'a, 'd, T, Dm>
where
T: RsaMode + Multi,
Dm: DriverMode,
{
// Exclusive borrow of the driver for as long as the operation object exists.
rsa: &'a mut Rsa<'d, Dm>,
// Zero-sized marker for the operand size type.
phantom: PhantomData<T>,
}
impl<'a, 'd, T, Dm, const N: usize> RsaMultiplication<'a, 'd, T, Dm>
where
T: RsaMode<InputType = [u32; N]>,
T: Multi,
Dm: DriverMode,
{
pub fn new(rsa: &'a mut Rsa<'d, Dm>, operand_a: &T::InputType) -> Self {
rsa.write_multi_mode(2 * N as u32 / WORDS_PER_INCREMENT - 1, false);
rsa.write_operand_a(operand_a);
Self {
rsa,
phantom: PhantomData,
}
}
pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
self.set_up_multiplication(operand_b);
self.rsa.start_multi();
}
pub fn read_results<const O: usize>(&mut self, outbuf: &mut T::OutputType)
where
T: Multi<OutputType = [u32; O]>,
{
self.rsa.read_results(outbuf);
}
fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_multi_operand_b(operand_b);
}
}
// Waker registered by `RsaFuture::poll` and woken by `rsa_interrupt_handler`.
static WAKER: AtomicWaker = AtomicWaker::new();
// Completion flag: cleared when an `RsaFuture` is created, set by `rsa_interrupt_handler`.
static SIGNALED: AtomicBool = AtomicBool::new(false);
/// Future that completes once the RSA interrupt handler has set `SIGNALED`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct RsaFuture<'a, 'd> {
// Driver whose interrupt is enabled on creation and disabled again on drop.
driver: &'a Rsa<'d, Async>,
}
impl<'a, 'd> RsaFuture<'a, 'd> {
fn new(driver: &'a Rsa<'d, Async>) -> Self {
SIGNALED.store(false, Ordering::Relaxed);
driver.internal_enable_disable_interrupt(true);
Self { driver }
}
fn is_done(&self) -> bool {
SIGNALED.load(Ordering::Acquire)
}
}
impl Drop for RsaFuture<'_, '_> {
    /// Disables the RSA interrupt again, whether or not the future completed.
    fn drop(&mut self) {
        let driver = self.driver;
        driver.internal_enable_disable_interrupt(false);
    }
}
impl core::future::Future for RsaFuture<'_, '_> {
    type Output = ();

    /// Registers the task's waker, then reports readiness based on the completion flag.
    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        // Register before checking the flag so that a completion arriving in between
        // still wakes the task.
        WAKER.register(cx.waker());

        if !self.is_done() {
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}
impl<T: RsaMode, const N: usize> RsaModularExponentiation<'_, '_, T, Async>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Asynchronously computes the modular exponentiation of `base` and stores the
    /// result in `outbuf`.
    ///
    /// `r` is written to the peripheral together with `base` before the operation starts.
    pub async fn exponentiation(
        &mut self,
        base: &T::InputType,
        r: &T::InputType,
        outbuf: &mut T::InputType,
    ) {
        self.set_up_exponentiation(base, r);

        // The future is armed before the operation is started.
        let completion = RsaFuture::new(self.rsa);
        self.rsa.start_modexp();
        completion.await;

        self.rsa.read_out(outbuf);
    }
}
impl<T: RsaMode, const N: usize> RsaModularMultiplication<'_, '_, T, Async>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Asynchronously performs the modular multiplication with `operand_b` and stores
    /// the result in `outbuf`.
    pub async fn modular_multiplication(
        &mut self,
        operand_b: &T::InputType,
        outbuf: &mut T::InputType,
    ) {
        if cfg!(esp32) {
            // ESP32: run a first multiplication round, wait for it, then load
            // `operand_b` into the X memory.
            let first_round = RsaFuture::new(self.rsa);
            self.rsa.start_multi();
            first_round.await;

            self.rsa.write_operand_a(operand_b);
        } else {
            self.set_up_modular_multiplication(operand_b);
        }

        // The future is armed before the operation is started.
        let completion = RsaFuture::new(self.rsa);
        self.rsa.start_modmulti();
        completion.await;

        self.rsa.read_out(outbuf);
    }
}
impl<T: RsaMode + Multi, const N: usize> RsaMultiplication<'_, '_, T, Async>
where
    T: RsaMode<InputType = [u32; N]>,
{
    /// Asynchronously performs the multiplication with `operand_b` and stores the
    /// result in `outbuf`.
    pub async fn multiplication<const O: usize>(
        &mut self,
        operand_b: &T::InputType,
        outbuf: &mut T::OutputType,
    ) where
        T: Multi<OutputType = [u32; O]>,
    {
        self.set_up_multiplication(operand_b);

        // The future is armed before the operation is started.
        let completion = RsaFuture::new(self.rsa);
        self.rsa.start_multi();
        completion.await;

        self.rsa.read_out(outbuf);
    }
}
/// Interrupt handler used by the async driver.
///
/// Sets the completion flag, clears the peripheral's interrupt flag and wakes the
/// task registered in `WAKER`.
#[handler]
pub(super) fn rsa_interrupt_handler() {
    SIGNALED.store(true, Ordering::Release);

    let regs = RSA::regs();
    cfg_if::cfg_if! {
        if #[cfg(esp32)] {
            regs.interrupt().write(|w| w.interrupt().set_bit());
        } else {
            regs.int_clr().write(|w| w.int_clr().set_bit());
        }
    }

    WAKER.wake();
}
// Global queue of RSA work items; `RsaContext` posts into it and `RsaBackend` serves it.
static RSA_WORK_QUEUE: WorkQueue<RsaWorkItem> = WorkQueue::new();
/// Callbacks through which the work queue drives an [`RsaBackend`].
///
/// Every callback receives the type-erased driver pointer and turns it back into a
/// backend reference.
const RSA_VTABLE: VTable<RsaWorkItem> = VTable {
    // A newly posted item is processed right away.
    post: |driver, item| {
        // SAFETY: relies on `driver` being the `RsaBackend` registered in
        // `RsaBackend::start` — NOTE(review): confirm the work queue upholds this.
        let backend = unsafe { RsaBackend::from_raw(driver) };
        Some(backend.process_item(item))
    },
    poll: |driver, item| {
        // SAFETY: see `post`.
        let backend = unsafe { RsaBackend::from_raw(driver) };
        backend.process_item(item)
    },
    cancel: |driver, item| {
        // SAFETY: see `post`.
        let backend = unsafe { RsaBackend::from_raw(driver) };
        backend.cancel(item)
    },
    stop: |driver| {
        // SAFETY: see `post`.
        let backend = unsafe { RsaBackend::from_raw(driver) };
        backend.deinitialize()
    },
};
/// State machine of the work-queue backend. Every state except `Idle` owns the driver.
#[derive(Default)]
enum RsaBackendState<'d> {
/// No driver exists; also the value left behind by `core::mem::take`.
#[default]
Idle,
/// The driver was created and the backend is waiting for the peripheral to report ready.
Initializing(Rsa<'d, Blocking>),
/// The peripheral is ready to start the next work item.
Ready(Rsa<'d, Blocking>),
/// ESP32 only: the first round of a modular multiplication is in progress.
#[cfg(esp32)]
ModularMultiplicationRoundOne(Rsa<'d, Blocking>),
/// An operation was started and its result has not been read yet.
Processing(Rsa<'d, Blocking>),
}
/// Work-queue backend that executes queued RSA operations on the RSA peripheral.
#[procmacros::doc_replace]
pub struct RsaBackend<'d> {
// The peripheral from which a driver instance is created on demand.
peri: RSA<'d>,
// Current state of the processing state machine.
state: RsaBackendState<'d>,
}
impl<'d> RsaBackend<'d> {
/// Creates a backend for the given RSA peripheral. The backend starts out idle; no
/// driver is created until the first work item is processed.
#[procmacros::doc_replace]
pub fn new(rsa: RSA<'d>) -> Self {
Self {
peri: rsa,
state: RsaBackendState::Idle,
}
}
/// Registers this backend as the driver of the global RSA work queue and returns a
/// handle that can be used to stop it.
#[procmacros::doc_replace]
pub fn start(&mut self) -> RsaWorkQueueDriver<'_, 'd> {
RsaWorkQueueDriver {
inner: WorkQueueDriver::new(self, RSA_VTABLE, &RSA_WORK_QUEUE),
}
}
/// Turns a type-erased driver pointer back into a backend reference.
///
/// # Safety
///
/// `ptr` must point to a live `RsaBackend` that is not otherwise borrowed for the
/// chosen lifetime `'any`.
unsafe fn from_raw<'any>(ptr: NonNull<()>) -> &'any mut Self {
unsafe { ptr.cast::<RsaBackend<'_>>().as_mut() }
}
/// Advances the state machine by one step for `item` and reports whether the item
/// has completed.
///
/// NOTE(review): the `bool` in `Poll::Pending` presumably tells the queue whether to
/// poll again immediately — confirm against `work_queue::Poll`.
fn process_item(&mut self, item: &mut RsaWorkItem) -> work_queue::Poll {
// Take the state out (leaving `Idle`); every arm stores the follow-up state.
match core::mem::take(&mut self.state) {
RsaBackendState::Idle => {
// Create the driver by hand instead of via `Rsa::new`, which would busy-wait
// for the peripheral to become ready.
let driver = Rsa {
rsa: unsafe { self.peri.clone_unchecked() },
phantom: PhantomData,
#[cfg(not(esp32))]
_memory_guard: RsaMemoryPowerGuard::new(),
_guard: GenericPeripheralGuard::new(),
};
self.state = RsaBackendState::Initializing(driver);
work_queue::Poll::Pending(true)
}
RsaBackendState::Initializing(mut rsa) => {
// Once the peripheral is ready, bind the work-queue interrupt handler and
// enable the interrupt; otherwise stay in this state.
self.state = if rsa.ready() {
rsa.set_interrupt_handler(rsa_work_queue_handler);
rsa.enable_disable_interrupt(true);
RsaBackendState::Ready(rsa)
} else {
RsaBackendState::Initializing(rsa)
};
work_queue::Poll::Pending(true)
}
RsaBackendState::Ready(mut rsa) => {
// Apply the per-item acceleration options before loading operands.
#[cfg(not(esp32))]
{
rsa.disable_constant_time(!item.constant_time);
rsa.search_acceleration(item.search_acceleration);
}
match item.operation {
RsaOperation::Multiplication { x, y } => {
let n = x.len() as u32;
rsa.write_operand_a(unsafe { x.as_ref() });
// The mode is computed for twice the operand length.
rsa.write_multi_mode(2 * n / WORDS_PER_INCREMENT - 1, false);
rsa.write_multi_operand_b(unsafe { y.as_ref() });
rsa.start_multi();
}
RsaOperation::ModularMultiplication {
x,
#[cfg(not(esp32))]
y,
m,
m_prime,
r: r_inv,
..
} => {
let n = x.len() as u32;
rsa.write_operand_a(unsafe { x.as_ref() });
rsa.write_multi_mode(n / WORDS_PER_INCREMENT - 1, true);
#[cfg(not(esp32))]
rsa.write_operand_b(unsafe { y.as_ref() });
rsa.write_modulus(unsafe { m.as_ref() });
rsa.write_mprime(m_prime);
rsa.write_r(unsafe { r_inv.as_ref() });
rsa.start_modmulti();
// On the ESP32 `y` is not loaded here; it is written to the X memory in a
// second round once this first round has finished.
#[cfg(esp32)]
{
self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
return work_queue::Poll::Pending(false);
}
}
RsaOperation::ModularExponentiation {
x,
y,
m,
m_prime,
r_inv,
} => {
let n = x.len() as u32;
rsa.write_operand_a(unsafe { x.as_ref() });
rsa.write_modexp_mode(n / WORDS_PER_INCREMENT - 1);
rsa.write_operand_b(unsafe { y.as_ref() });
rsa.write_modulus(unsafe { m.as_ref() });
rsa.write_mprime(m_prime);
rsa.write_r(unsafe { r_inv.as_ref() });
#[cfg(not(esp32))]
if item.search_acceleration {
// Bit index of the most significant set bit of the exponent (the last word is
// the most significant one); 0 if the exponent is all zeros.
fn find_search_pos(exponent: &[u32]) -> u32 {
for (i, byte) in exponent.iter().rev().enumerate() {
if *byte == 0 {
continue;
}
return (exponent.len() * 32) as u32
- (byte.leading_zeros() + i as u32 * 32)
- 1;
}
0
}
rsa.write_search_position(find_search_pos(unsafe { y.as_ref() }));
}
rsa.start_modexp();
}
}
self.state = RsaBackendState::Processing(rsa);
work_queue::Poll::Pending(false)
}
#[cfg(esp32)]
RsaBackendState::ModularMultiplicationRoundOne(mut rsa) => {
if rsa.is_idle() {
// Only a modular multiplication can put the backend into this state.
let RsaOperation::ModularMultiplication { y, .. } = item.operation else {
unreachable!();
};
// Second round: `y` goes into the X memory, then the operation is started again.
rsa.write_operand_a(unsafe { y.as_ref() });
rsa.start_modmulti();
self.state = RsaBackendState::Processing(rsa);
} else {
// First round still running; stay in this state.
self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
}
work_queue::Poll::Pending(false)
}
RsaBackendState::Processing(rsa) => {
if rsa.is_idle() {
// Operation finished: copy the result out and become ready for the next item.
rsa.read_out(unsafe { item.result.as_mut() });
self.state = RsaBackendState::Ready(rsa);
work_queue::Poll::Ready(Status::Completed)
} else {
self.state = RsaBackendState::Processing(rsa);
work_queue::Poll::Pending(false)
}
}
}
}
/// Cancels the current work item by returning to `Idle`, which drops the driver
/// (if any) without reading a result.
fn cancel(&mut self, _item: &mut RsaWorkItem) {
self.state = RsaBackendState::Idle;
}
/// Returns to `Idle`, dropping the driver (if any) together with its guards.
fn deinitialize(&mut self) {
self.state = RsaBackendState::Idle;
}
}
/// Handle returned by `RsaBackend::start`, representing the backend's registration
/// with the RSA work queue.
pub struct RsaWorkQueueDriver<'t, 'd> {
// The generic work-queue driver this handle wraps.
inner: WorkQueueDriver<'t, RsaBackend<'d>, RsaWorkItem>,
}
impl<'t, 'd> RsaWorkQueueDriver<'t, 'd> {
    /// Consumes the handle and returns the wrapped work-queue driver's `stop` future.
    pub fn stop(self) -> impl Future<Output = ()> {
        let Self { inner } = self;
        inner.stop()
    }
}
/// One queued RSA operation together with its options and output location.
#[derive(Clone)]
struct RsaWorkItem {
// Whether the backend enables search acceleration for this item.
#[cfg(not(esp32))]
search_acceleration: bool,
// Whether the operation runs in constant time; the backend passes the negation to
// `Rsa::disable_constant_time`.
#[cfg(not(esp32))]
constant_time: bool,
// The operation to run, including pointers to its operands.
operation: RsaOperation,
// Where the backend writes the result once the operation has finished.
result: NonNull<[u32]>,
}
// The raw operand/result pointers make `RsaWorkItem` neither `Send` nor `Sync` automatically.
// NOTE(review): soundness presumably rests on `RsaHandle<'t>` keeping the pointed-to
// buffers borrowed while the item is queued — confirm against the work queue.
unsafe impl Sync for RsaWorkItem {}
unsafe impl Send for RsaWorkItem {}
/// The operation a work item requests, with raw pointers to its operands.
#[derive(Clone)]
enum RsaOperation {
/// Plain multiplication of `x` and `y`.
Multiplication {
x: NonNull<[u32]>,
y: NonNull<[u32]>,
},
/// Modular multiplication of `x` and `y` with modulus `m`; `r` and `m_prime` are
/// written to the peripheral alongside the operands.
ModularMultiplication {
x: NonNull<[u32]>,
y: NonNull<[u32]>,
m: NonNull<[u32]>,
r: NonNull<[u32]>,
m_prime: u32,
},
/// Modular exponentiation with base `x`, exponent `y` and modulus `m`; `r_inv` and
/// `m_prime` are written to the peripheral alongside the operands.
ModularExponentiation {
x: NonNull<[u32]>,
y: NonNull<[u32]>,
m: NonNull<[u32]>,
r_inv: NonNull<[u32]>,
m_prime: u32,
},
}
#[handler]
#[ram]
fn rsa_work_queue_handler() {
if !RSA_WORK_QUEUE.process() {
cfg_if::cfg_if! {
if #[cfg(esp32)] {
RSA::regs().interrupt().write(|w| w.interrupt().set_bit());
} else {
RSA::regs().int_clr().write(|w| w.int_clr().set_bit());
}
}
}
}
/// Front-end for submitting RSA operations to the RSA work queue.
#[cfg_attr(
not(esp32),
doc = " \nThe context is created with a secure configuration by default. You can enable hardware acceleration
options using [enable_search_acceleration][Self::enable_search_acceleration] and
[enable_acceleration][Self::enable_acceleration] when appropriate."
)]
#[derive(Clone)]
pub struct RsaContext {
// Work-queue front-end holding the work item that is configured and posted.
frontend: WorkQueueFrontend<RsaWorkItem>,
}
impl Default for RsaContext {
fn default() -> Self {
Self::new()
}
}
impl RsaContext {
    /// Creates a context with both acceleration options turned off (where the chip
    /// has them) and a placeholder operation over empty buffers.
    pub fn new() -> Self {
        // The placeholder operation is replaced before anything is posted.
        let placeholder = RsaWorkItem {
            #[cfg(not(esp32))]
            search_acceleration: false,
            #[cfg(not(esp32))]
            constant_time: true,
            operation: RsaOperation::Multiplication {
                x: NonNull::from(&[]),
                y: NonNull::from(&[]),
            },
            result: NonNull::from(&mut []),
        };

        Self {
            frontend: WorkQueueFrontend::new(placeholder),
        }
    }

    /// Turns on search acceleration for operations posted through this context.
    ///
    /// With this option the backend programs the position of the exponent's most
    /// significant set bit before a modular exponentiation.
    ///
    #[cfg(not(esp32))]
    #[doc = trm_markdown_link!("rsa")]
    pub fn enable_search_acceleration(&mut self) {
        let item = self.frontend.data_mut();
        item.search_acceleration = true;
    }

    /// Turns off constant-time operation for operations posted through this context.
    ///
    #[cfg(not(esp32))]
    #[doc = trm_markdown_link!("rsa")]
    pub fn enable_acceleration(&mut self) {
        let item = self.frontend.data_mut();
        item.constant_time = false;
    }

    /// Posts the currently configured work item to the RSA work queue.
    fn post(&mut self) -> RsaHandle<'_> {
        let handle = self.frontend.post(&RSA_WORK_QUEUE);
        RsaHandle(handle)
    }

    /// Queues a modular exponentiation with base `x`, exponent `y` and modulus `m`.
    ///
    /// `r` and `m_prime` are passed to the peripheral alongside the operands. The
    /// result is written to `result`; the returned handle keeps all buffers borrowed
    /// for `'t`.
    #[procmacros::doc_replace]
    pub fn modular_exponentiate<'t, OP>(
        &'t mut self,
        x: &'t OP::InputType,
        y: &'t OP::InputType,
        m: &'t OP::InputType,
        r: &'t OP::InputType,
        m_prime: u32,
        result: &'t mut OP::InputType,
    ) -> RsaHandle<'t>
    where
        OP: RsaMode,
    {
        let operation = RsaOperation::ModularExponentiation {
            x: NonNull::from(x.as_ref()),
            y: NonNull::from(y.as_ref()),
            m: NonNull::from(m.as_ref()),
            r_inv: NonNull::from(r.as_ref()),
            m_prime,
        };
        let output = NonNull::from(result.as_mut());

        self.frontend.data_mut().operation = operation;
        self.frontend.data_mut().result = output;
        self.post()
    }

    /// Queues a modular multiplication of `x` and `y` with modulus `m`.
    ///
    /// `r` and `m_prime` are passed to the peripheral alongside the operands. The
    /// result is written to `result`; the returned handle keeps all buffers borrowed
    /// for `'t`.
    pub fn modular_multiply<'t, OP>(
        &'t mut self,
        x: &'t OP::InputType,
        y: &'t OP::InputType,
        m: &'t OP::InputType,
        r: &'t OP::InputType,
        m_prime: u32,
        result: &'t mut OP::InputType,
    ) -> RsaHandle<'t>
    where
        OP: RsaMode,
    {
        let operation = RsaOperation::ModularMultiplication {
            x: NonNull::from(x.as_ref()),
            y: NonNull::from(y.as_ref()),
            m: NonNull::from(m.as_ref()),
            r: NonNull::from(r.as_ref()),
            m_prime,
        };
        let output = NonNull::from(result.as_mut());

        self.frontend.data_mut().operation = operation;
        self.frontend.data_mut().result = output;
        self.post()
    }

    /// Queues a plain multiplication of `x` and `y`.
    ///
    /// The result is written to `result`; the returned handle keeps all buffers
    /// borrowed for `'t`.
    #[procmacros::doc_replace]
    pub fn multiply<'t, OP>(
        &'t mut self,
        x: &'t OP::InputType,
        y: &'t OP::InputType,
        result: &'t mut OP::OutputType,
    ) -> RsaHandle<'t>
    where
        OP: Multi,
    {
        let operation = RsaOperation::Multiplication {
            x: NonNull::from(x.as_ref()),
            y: NonNull::from(y.as_ref()),
        };
        let output = NonNull::from(result.as_mut());

        self.frontend.data_mut().operation = operation;
        self.frontend.data_mut().result = output;
        self.post()
    }
}
/// Handle to an RSA operation that was posted to the work queue.
pub struct RsaHandle<'t>(work_queue::Handle<'t, RsaWorkItem>);
impl RsaHandle<'_> {
    /// Polls the underlying work-queue handle once and returns its result;
    /// [`Self::wait_blocking`] treats `true` as "finished".
    #[inline]
    pub fn poll(&mut self) -> bool {
        let handle = &mut self.0;
        handle.poll()
    }

    /// Busy-waits, polling repeatedly, until [`Self::poll`] returns `true`.
    #[inline]
    pub fn wait_blocking(mut self) {
        loop {
            if self.poll() {
                break;
            }
        }
    }

    /// Returns a future that resolves to the final [`Status`] of the operation.
    #[inline]
    pub fn wait(&mut self) -> impl Future<Output = Status> {
        let handle = &mut self.0;
        handle.wait()
    }
}