1use core::{marker::PhantomData, ptr::NonNull, task::Poll};
17
18use portable_atomic::{AtomicBool, Ordering};
19use procmacros::{handler, ram};
20
21use crate::{
22 Async,
23 Blocking,
24 DriverMode,
25 asynch::AtomicWaker,
26 interrupt::InterruptHandler,
27 pac,
28 peripherals::{Interrupt, RSA},
29 system::{Cpu, GenericPeripheralGuard, Peripheral as PeripheralEnable},
30 trm_markdown_link,
31 work_queue::{self, Status, VTable, WorkQueue, WorkQueueDriver, WorkQueueFrontend},
32};
33
34pub struct Rsa<'d, Dm: DriverMode> {
36 rsa: RSA<'d>,
37 phantom: PhantomData<Dm>,
38 #[cfg(not(esp32))]
39 _memory_guard: RsaMemoryPowerGuard,
40 _guard: GenericPeripheralGuard<{ PeripheralEnable::Rsa as u8 }>,
41}
42
43const WORDS_PER_INCREMENT: u32 = property!("rsa.size_increment") / 32;
52
53#[cfg(not(esp32))]
54struct RsaMemoryPowerGuard;
55
56#[cfg(not(esp32))]
57impl RsaMemoryPowerGuard {
58 fn new() -> Self {
59 crate::peripherals::SYSTEM::regs()
60 .rsa_pd_ctrl()
61 .modify(|_, w| {
62 w.rsa_mem_force_pd().clear_bit();
63 w.rsa_mem_force_pu().set_bit();
64 w.rsa_mem_pd().clear_bit()
65 });
66 Self
67 }
68}
69
70#[cfg(not(esp32))]
71impl Drop for RsaMemoryPowerGuard {
72 fn drop(&mut self) {
73 unsafe {
74 crate::peripherals::RSA::steal().disable_peri_interrupt_on_all_cores();
79 }
80 crate::peripherals::SYSTEM::regs()
81 .rsa_pd_ctrl()
82 .modify(|_, w| {
83 w.rsa_mem_force_pd().clear_bit();
84 w.rsa_mem_force_pu().clear_bit();
85 w.rsa_mem_pd().set_bit()
86 });
87 }
88}
89
90impl<'d> Rsa<'d, Blocking> {
91 pub fn new(rsa: RSA<'d>) -> Self {
95 let guard = GenericPeripheralGuard::new();
96
97 let this = Self {
98 rsa,
99 phantom: PhantomData,
100 #[cfg(not(esp32))]
101 _memory_guard: RsaMemoryPowerGuard::new(),
102 _guard: guard,
103 };
104
105 while !this.ready() {}
106
107 this
108 }
109
110 pub fn into_async(mut self) -> Rsa<'d, Async> {
112 self.set_interrupt_handler(rsa_interrupt_handler);
113 self.enable_disable_interrupt(true);
114
115 Rsa {
116 rsa: self.rsa,
117 phantom: PhantomData,
118 #[cfg(not(esp32))]
119 _memory_guard: self._memory_guard,
120 _guard: self._guard,
121 }
122 }
123
124 pub fn enable_disable_interrupt(&mut self, enable: bool) {
129 self.internal_enable_disable_interrupt(enable);
130 }
131
132 #[instability::unstable]
137 pub fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
138 self.rsa.disable_peri_interrupt_on_all_cores();
139 self.rsa.bind_peri_interrupt(handler);
140 }
141}
142
143impl crate::private::Sealed for Rsa<'_, Blocking> {}
144
145#[instability::unstable]
146impl crate::interrupt::InterruptConfigurable for Rsa<'_, Blocking> {
147 fn set_interrupt_handler(&mut self, handler: InterruptHandler) {
148 self.set_interrupt_handler(handler);
149 }
150}
151
152impl<'d> Rsa<'d, Async> {
153 pub fn into_blocking(self) -> Rsa<'d, Blocking> {
155 self.internal_enable_disable_interrupt(false);
156 self.rsa.disable_peri_interrupt_on_all_cores();
157
158 crate::interrupt::disable(Cpu::current(), Interrupt::RSA);
159 Rsa {
160 rsa: self.rsa,
161 phantom: PhantomData,
162 #[cfg(not(esp32))]
163 _memory_guard: self._memory_guard,
164 _guard: self._guard,
165 }
166 }
167}
168
169impl<'d, Dm: DriverMode> Rsa<'d, Dm> {
170 fn internal_enable_disable_interrupt(&self, enable: bool) {
171 cfg_if::cfg_if! {
172 if #[cfg(esp32)] {
173 self.regs().interrupt().write(|w| w.interrupt().bit(enable));
175 } else {
176 self.regs().int_ena().write(|w| w.int_ena().bit(enable));
177 }
178 }
179 }
180
181 fn regs(&self) -> &pac::rsa::RegisterBlock {
182 self.rsa.register_block()
183 }
184
185 fn ready(&self) -> bool {
190 cfg_if::cfg_if! {
191 if #[cfg(any(esp32, esp32s2, esp32s3))] {
192 self.regs().clean().read().clean().bit_is_set()
193 } else {
194 self.regs().query_clean().read().query_clean().bit_is_set()
195 }
196 }
197 }
198
199 fn start_modexp(&self) {
201 cfg_if::cfg_if! {
202 if #[cfg(any(esp32, esp32s2, esp32s3))] {
203 self.regs()
204 .modexp_start()
205 .write(|w| w.modexp_start().set_bit());
206 } else {
207 self.regs()
208 .set_start_modexp()
209 .write(|w| w.set_start_modexp().set_bit());
210 }
211 }
212 }
213
214 fn start_multi(&self) {
216 cfg_if::cfg_if! {
217 if #[cfg(any(esp32, esp32s2, esp32s3))] {
218 self.regs().mult_start().write(|w| w.mult_start().set_bit());
219 } else {
220 self.regs()
221 .set_start_mult()
222 .write(|w| w.set_start_mult().set_bit());
223 }
224 }
225 }
226
227 fn start_modmulti(&self) {
229 cfg_if::cfg_if! {
230 if #[cfg(esp32)] {
231 self.start_multi();
233 } else if #[cfg(any(esp32s2, esp32s3))] {
234 self.regs()
235 .modmult_start()
236 .write(|w| w.modmult_start().set_bit());
237 } else {
238 self.regs()
239 .set_start_modmult()
240 .write(|w| w.set_start_modmult().set_bit());
241 }
242 }
243 }
244
245 fn clear_interrupt(&mut self) {
247 cfg_if::cfg_if! {
248 if #[cfg(esp32)] {
249 self.regs().interrupt().write(|w| w.interrupt().set_bit());
250 } else {
251 self.regs().int_clr().write(|w| w.int_clr().set_bit());
252 }
253 }
254 }
255
256 fn is_idle(&self) -> bool {
258 cfg_if::cfg_if! {
259 if #[cfg(esp32)] {
260 self.regs().interrupt().read().interrupt().bit_is_set()
261 } else if #[cfg(any(esp32s2, esp32s3))] {
262 self.regs().idle().read().idle().bit_is_set()
263 } else {
264 self.regs().query_idle().read().query_idle().bit_is_set()
265 }
266 }
267 }
268
269 fn wait_for_idle(&mut self) {
270 while !self.is_idle() {}
271 self.clear_interrupt();
272 }
273
274 fn write_multi_mode(&mut self, mode: u32, modular: bool) {
276 let mode = if cfg!(esp32) && !modular {
277 const NON_MODULAR: u32 = 8;
278 mode | NON_MODULAR
279 } else {
280 mode
281 };
282
283 cfg_if::cfg_if! {
284 if #[cfg(esp32)] {
285 self.regs().mult_mode().write(|w| unsafe { w.bits(mode) });
286 } else {
287 self.regs().mode().write(|w| unsafe { w.bits(mode) });
288 }
289 }
290 }
291
292 fn write_modexp_mode(&mut self, mode: u32) {
294 cfg_if::cfg_if! {
295 if #[cfg(esp32)] {
296 self.regs().modexp_mode().write(|w| unsafe { w.bits(mode) });
297 } else {
298 self.regs().mode().write(|w| unsafe { w.bits(mode) });
299 }
300 }
301 }
302
303 fn write_operand_b(&mut self, operand: &[u32]) {
304 for (reg, op) in self.regs().y_mem_iter().zip(operand.iter().copied()) {
305 reg.write(|w| unsafe { w.bits(op) });
306 }
307 }
308
309 fn write_modulus(&mut self, modulus: &[u32]) {
310 for (reg, op) in self.regs().m_mem_iter().zip(modulus.iter().copied()) {
311 reg.write(|w| unsafe { w.bits(op) });
312 }
313 }
314
315 fn write_mprime(&mut self, m_prime: u32) {
316 self.regs().m_prime().write(|w| unsafe { w.bits(m_prime) });
317 }
318
319 fn write_operand_a(&mut self, operand: &[u32]) {
320 for (reg, op) in self.regs().x_mem_iter().zip(operand.iter().copied()) {
321 reg.write(|w| unsafe { w.bits(op) });
322 }
323 }
324
325 fn write_multi_operand_b(&mut self, operand: &[u32]) {
326 for (reg, op) in self
327 .regs()
328 .z_mem_iter()
329 .skip(operand.len())
330 .zip(operand.iter().copied())
331 {
332 reg.write(|w| unsafe { w.bits(op) });
333 }
334 }
335
336 fn write_r(&mut self, r: &[u32]) {
337 for (reg, op) in self.regs().z_mem_iter().zip(r.iter().copied()) {
338 reg.write(|w| unsafe { w.bits(op) });
339 }
340 }
341
342 fn read_out(&self, outbuf: &mut [u32]) {
343 for (reg, op) in self.regs().z_mem_iter().zip(outbuf.iter_mut()) {
344 *op = reg.read().bits();
345 }
346 }
347
348 fn read_results(&mut self, outbuf: &mut [u32]) {
349 self.wait_for_idle();
350 self.read_out(outbuf);
351 }
352
353 #[doc = trm_markdown_link!("rsa")]
364 #[cfg(not(esp32))]
365 pub fn disable_constant_time(&mut self, disable: bool) {
366 self.regs()
367 .constant_time()
368 .write(|w| w.constant_time().bit(disable));
369 }
370
371 #[doc = trm_markdown_link!("rsa")]
381 #[cfg(not(esp32))]
382 pub fn search_acceleration(&mut self, enable: bool) {
383 self.regs()
384 .search_enable()
385 .write(|w| w.search_enable().bit(enable));
386 }
387
388 #[cfg(not(esp32))]
390 fn is_search_enabled(&mut self) -> bool {
391 self.regs()
392 .search_enable()
393 .read()
394 .search_enable()
395 .bit_is_set()
396 }
397
398 #[cfg(not(esp32))]
400 fn write_search_position(&mut self, search_position: u32) {
401 self.regs()
402 .search_pos()
403 .write(|w| unsafe { w.bits(search_position) });
404 }
405}
406
407pub trait RsaMode: crate::private::Sealed {
409 type InputType: AsRef<[u32]> + AsMut<[u32]>;
411}
412
413pub trait Multi: RsaMode {
415 type OutputType: AsRef<[u32]> + AsMut<[u32]>;
417}
418
419pub mod operand_sizes {
421 for_each_rsa_exponentiation!(
422 ($x:literal) => {
423 paste::paste! {
424 #[doc = concat!(stringify!($x), "-bit RSA operation.")]
425 pub struct [<Op $x>];
426
427 impl crate::private::Sealed for [<Op $x>] {}
428 impl crate::rsa::RsaMode for [<Op $x>] {
429 type InputType = [u32; $x / 32];
430 }
431 }
432 };
433 );
434
435 for_each_rsa_multiplication!(
436 ($x:literal) => {
437 impl crate::rsa::Multi for paste::paste!( [<Op $x>] ) {
438 type OutputType = [u32; $x * 2 / 32];
439 }
440 };
441 );
442}
443
444pub struct RsaModularExponentiation<'a, 'd, T: RsaMode, Dm: DriverMode> {
449 rsa: &'a mut Rsa<'d, Dm>,
450 phantom: PhantomData<T>,
451}
452
453impl<'a, 'd, T: RsaMode, Dm: DriverMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Dm>
454where
455 T: RsaMode<InputType = [u32; N]>,
456{
457 #[doc = trm_markdown_link!("rsa")]
464 pub fn new(
465 rsa: &'a mut Rsa<'d, Dm>,
466 exponent: &T::InputType,
467 modulus: &T::InputType,
468 m_prime: u32,
469 ) -> Self {
470 Self::write_mode(rsa);
471 rsa.write_operand_b(exponent);
472 rsa.write_modulus(modulus);
473 rsa.write_mprime(m_prime);
474
475 #[cfg(not(esp32))]
476 if rsa.is_search_enabled() {
477 rsa.write_search_position(Self::find_search_pos(exponent));
478 }
479
480 Self {
481 rsa,
482 phantom: PhantomData,
483 }
484 }
485
486 fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
487 self.rsa.write_operand_a(base);
488 self.rsa.write_r(r);
489 }
490
491 #[doc = trm_markdown_link!("rsa")]
497 pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
498 self.set_up_exponentiation(base, r);
499 self.rsa.start_modexp();
500 }
501
502 pub fn read_results(&mut self, outbuf: &mut T::InputType) {
508 self.rsa.read_results(outbuf);
509 }
510
511 #[cfg(not(esp32))]
512 fn find_search_pos(exponent: &T::InputType) -> u32 {
513 for (i, byte) in exponent.iter().rev().enumerate() {
514 if *byte == 0 {
515 continue;
516 }
517 return (exponent.len() * 32) as u32 - (byte.leading_zeros() + i as u32 * 32) - 1;
518 }
519 0
520 }
521
522 fn write_mode(rsa: &mut Rsa<'d, Dm>) {
524 rsa.write_modexp_mode(N as u32 / WORDS_PER_INCREMENT - 1);
525 }
526}
527
528pub struct RsaModularMultiplication<'a, 'd, T, Dm>
533where
534 T: RsaMode,
535 Dm: DriverMode,
536{
537 rsa: &'a mut Rsa<'d, Dm>,
538 phantom: PhantomData<T>,
539}
540
541impl<'a, 'd, T, Dm, const N: usize> RsaModularMultiplication<'a, 'd, T, Dm>
542where
543 T: RsaMode<InputType = [u32; N]>,
544 Dm: DriverMode,
545{
546 #[doc = trm_markdown_link!("rsa")]
553 pub fn new(
554 rsa: &'a mut Rsa<'d, Dm>,
555 operand_a: &T::InputType,
556 modulus: &T::InputType,
557 r: &T::InputType,
558 m_prime: u32,
559 ) -> Self {
560 rsa.write_multi_mode(N as u32 / WORDS_PER_INCREMENT - 1, true);
561
562 rsa.write_mprime(m_prime);
563 rsa.write_modulus(modulus);
564 rsa.write_operand_a(operand_a);
565 rsa.write_r(r);
566
567 Self {
568 rsa,
569 phantom: PhantomData,
570 }
571 }
572
573 #[doc = trm_markdown_link!("rsa")]
577 pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) {
578 self.set_up_modular_multiplication(operand_b);
579 self.rsa.start_modmulti();
580 }
581
582 pub fn read_results(&mut self, outbuf: &mut T::InputType) {
588 self.rsa.read_results(outbuf);
589 }
590
591 fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
592 if cfg!(esp32) {
593 self.rsa.start_multi();
594 self.rsa.wait_for_idle();
595
596 self.rsa.write_operand_a(operand_b);
597 } else {
598 self.rsa.write_operand_b(operand_b);
599 }
600 }
601}
602
603pub struct RsaMultiplication<'a, 'd, T, Dm>
608where
609 T: RsaMode + Multi,
610 Dm: DriverMode,
611{
612 rsa: &'a mut Rsa<'d, Dm>,
613 phantom: PhantomData<T>,
614}
615
616impl<'a, 'd, T, Dm, const N: usize> RsaMultiplication<'a, 'd, T, Dm>
617where
618 T: RsaMode<InputType = [u32; N]>,
619 T: Multi,
620 Dm: DriverMode,
621{
622 pub fn new(rsa: &'a mut Rsa<'d, Dm>, operand_a: &T::InputType) -> Self {
624 rsa.write_multi_mode(2 * N as u32 / WORDS_PER_INCREMENT - 1, false);
626 rsa.write_operand_a(operand_a);
627
628 Self {
629 rsa,
630 phantom: PhantomData,
631 }
632 }
633
634 pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
636 self.set_up_multiplication(operand_b);
637 self.rsa.start_multi();
638 }
639
640 pub fn read_results<const O: usize>(&mut self, outbuf: &mut T::OutputType)
646 where
647 T: Multi<OutputType = [u32; O]>,
648 {
649 self.rsa.read_results(outbuf);
650 }
651
652 fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
653 self.rsa.write_multi_operand_b(operand_b);
654 }
655}
656
657static WAKER: AtomicWaker = AtomicWaker::new();
658static SIGNALED: AtomicBool = AtomicBool::new(false);
660
661#[must_use = "futures do nothing unless you `.await` or poll them"]
663struct RsaFuture<'a, 'd> {
664 driver: &'a Rsa<'d, Async>,
665}
666
667impl<'a, 'd> RsaFuture<'a, 'd> {
668 fn new(driver: &'a Rsa<'d, Async>) -> Self {
669 SIGNALED.store(false, Ordering::Relaxed);
670
671 driver.internal_enable_disable_interrupt(true);
672
673 Self { driver }
674 }
675
676 fn is_done(&self) -> bool {
677 SIGNALED.load(Ordering::Acquire)
678 }
679}
680
681impl Drop for RsaFuture<'_, '_> {
682 fn drop(&mut self) {
683 self.driver.internal_enable_disable_interrupt(false);
684 }
685}
686
687impl core::future::Future for RsaFuture<'_, '_> {
688 type Output = ();
689
690 fn poll(
691 self: core::pin::Pin<&mut Self>,
692 cx: &mut core::task::Context<'_>,
693 ) -> core::task::Poll<Self::Output> {
694 WAKER.register(cx.waker());
695 if self.is_done() {
696 Poll::Ready(())
697 } else {
698 Poll::Pending
699 }
700 }
701}
702
703impl<T: RsaMode, const N: usize> RsaModularExponentiation<'_, '_, T, Async>
704where
705 T: RsaMode<InputType = [u32; N]>,
706{
707 pub async fn exponentiation(
709 &mut self,
710 base: &T::InputType,
711 r: &T::InputType,
712 outbuf: &mut T::InputType,
713 ) {
714 self.set_up_exponentiation(base, r);
715 let fut = RsaFuture::new(self.rsa);
716 self.rsa.start_modexp();
717 fut.await;
718 self.rsa.read_out(outbuf);
719 }
720}
721
722impl<T: RsaMode, const N: usize> RsaModularMultiplication<'_, '_, T, Async>
723where
724 T: RsaMode<InputType = [u32; N]>,
725{
726 pub async fn modular_multiplication(
728 &mut self,
729 operand_b: &T::InputType,
730 outbuf: &mut T::InputType,
731 ) {
732 if cfg!(esp32) {
733 let fut = RsaFuture::new(self.rsa);
734 self.rsa.start_multi();
735 fut.await;
736
737 self.rsa.write_operand_a(operand_b);
738 } else {
739 self.set_up_modular_multiplication(operand_b);
740 }
741
742 let fut = RsaFuture::new(self.rsa);
743 self.rsa.start_modmulti();
744 fut.await;
745 self.rsa.read_out(outbuf);
746 }
747}
748
749impl<T: RsaMode + Multi, const N: usize> RsaMultiplication<'_, '_, T, Async>
750where
751 T: RsaMode<InputType = [u32; N]>,
752{
753 pub async fn multiplication<const O: usize>(
755 &mut self,
756 operand_b: &T::InputType,
757 outbuf: &mut T::OutputType,
758 ) where
759 T: Multi<OutputType = [u32; O]>,
760 {
761 self.set_up_multiplication(operand_b);
762 let fut = RsaFuture::new(self.rsa);
763 self.rsa.start_multi();
764 fut.await;
765 self.rsa.read_out(outbuf);
766 }
767}
768
769#[handler]
770pub(super) fn rsa_interrupt_handler() {
772 let rsa = RSA::regs();
773 SIGNALED.store(true, Ordering::Release);
774 cfg_if::cfg_if! {
775 if #[cfg(esp32)] {
776 rsa.interrupt().write(|w| w.interrupt().set_bit());
777 } else {
778 rsa.int_clr().write(|w| w.int_clr().set_bit());
779 }
780 }
781
782 WAKER.wake();
783}
784
785static RSA_WORK_QUEUE: WorkQueue<RsaWorkItem> = WorkQueue::new();
786const RSA_VTABLE: VTable<RsaWorkItem> = VTable {
787 post: |driver, item| {
788 let driver = unsafe { RsaBackend::from_raw(driver) };
790 Some(driver.process_item(item))
791 },
792 poll: |driver, item| {
793 let driver = unsafe { RsaBackend::from_raw(driver) };
794 driver.process_item(item)
795 },
796 cancel: |driver, item| {
797 let driver = unsafe { RsaBackend::from_raw(driver) };
798 driver.cancel(item)
799 },
800 stop: |driver| {
801 let driver = unsafe { RsaBackend::from_raw(driver) };
802 driver.deinitialize()
803 },
804};
805
806#[derive(Default)]
807enum RsaBackendState<'d> {
808 #[default]
809 Idle,
810 Initializing(Rsa<'d, Blocking>),
811 Ready(Rsa<'d, Blocking>),
812 #[cfg(esp32)]
813 ModularMultiplicationRoundOne(Rsa<'d, Blocking>),
814 Processing(Rsa<'d, Blocking>),
815}
816
817#[procmacros::doc_replace]
818pub struct RsaBackend<'d> {
848 peri: RSA<'d>,
849 state: RsaBackendState<'d>,
850}
851
852impl<'d> RsaBackend<'d> {
853 #[procmacros::doc_replace]
854 pub fn new(rsa: RSA<'d>) -> Self {
866 Self {
867 peri: rsa,
868 state: RsaBackendState::Idle,
869 }
870 }
871
872 #[procmacros::doc_replace]
873 pub fn start(&mut self) -> RsaWorkQueueDriver<'_, 'd> {
889 RsaWorkQueueDriver {
890 inner: WorkQueueDriver::new(self, RSA_VTABLE, &RSA_WORK_QUEUE),
891 }
892 }
893
894 unsafe fn from_raw<'any>(ptr: NonNull<()>) -> &'any mut Self {
897 unsafe { ptr.cast::<RsaBackend<'_>>().as_mut() }
898 }
899
900 fn process_item(&mut self, item: &mut RsaWorkItem) -> work_queue::Poll {
901 match core::mem::take(&mut self.state) {
902 RsaBackendState::Idle => {
903 let driver = Rsa {
904 rsa: unsafe { self.peri.clone_unchecked() },
905 phantom: PhantomData,
906 #[cfg(not(esp32))]
907 _memory_guard: RsaMemoryPowerGuard::new(),
908 _guard: GenericPeripheralGuard::new(),
909 };
910 self.state = RsaBackendState::Initializing(driver);
911 work_queue::Poll::Pending(true)
912 }
913 RsaBackendState::Initializing(mut rsa) => {
914 self.state = if rsa.ready() {
917 rsa.set_interrupt_handler(rsa_work_queue_handler);
918 rsa.enable_disable_interrupt(true);
919 RsaBackendState::Ready(rsa)
920 } else {
921 RsaBackendState::Initializing(rsa)
922 };
923 work_queue::Poll::Pending(true)
924 }
925 RsaBackendState::Ready(mut rsa) => {
926 #[cfg(not(esp32))]
927 {
928 rsa.disable_constant_time(!item.constant_time);
929 rsa.search_acceleration(item.search_acceleration);
930 }
931
932 match item.operation {
933 RsaOperation::Multiplication { x, y } => {
934 let n = x.len() as u32;
935 rsa.write_operand_a(unsafe { x.as_ref() });
936
937 rsa.write_multi_mode(2 * n / WORDS_PER_INCREMENT - 1, false);
939 rsa.write_multi_operand_b(unsafe { y.as_ref() });
940 rsa.start_multi();
941 }
942
943 RsaOperation::ModularMultiplication {
944 x,
945 #[cfg(not(esp32))]
946 y,
947 m,
948 m_prime,
949 r: r_inv,
950 ..
951 } => {
952 let n = x.len() as u32;
953 rsa.write_operand_a(unsafe { x.as_ref() });
954
955 rsa.write_multi_mode(n / WORDS_PER_INCREMENT - 1, true);
956
957 #[cfg(not(esp32))]
958 rsa.write_operand_b(unsafe { y.as_ref() });
959
960 rsa.write_modulus(unsafe { m.as_ref() });
961 rsa.write_mprime(m_prime);
962 rsa.write_r(unsafe { r_inv.as_ref() });
963
964 rsa.start_modmulti();
965
966 #[cfg(esp32)]
967 {
968 self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
971
972 return work_queue::Poll::Pending(false);
973 }
974 }
975 RsaOperation::ModularExponentiation {
976 x,
977 y,
978 m,
979 m_prime,
980 r_inv,
981 } => {
982 let n = x.len() as u32;
983 rsa.write_operand_a(unsafe { x.as_ref() });
984
985 rsa.write_modexp_mode(n / WORDS_PER_INCREMENT - 1);
986 rsa.write_operand_b(unsafe { y.as_ref() });
987 rsa.write_modulus(unsafe { m.as_ref() });
988 rsa.write_mprime(m_prime);
989 rsa.write_r(unsafe { r_inv.as_ref() });
990
991 #[cfg(not(esp32))]
992 if item.search_acceleration {
993 fn find_search_pos(exponent: &[u32]) -> u32 {
994 for (i, byte) in exponent.iter().rev().enumerate() {
995 if *byte == 0 {
996 continue;
997 }
998 return (exponent.len() * 32) as u32
999 - (byte.leading_zeros() + i as u32 * 32)
1000 - 1;
1001 }
1002 0
1003 }
1004 rsa.write_search_position(find_search_pos(unsafe { y.as_ref() }));
1005 }
1006
1007 rsa.start_modexp();
1008 }
1009 }
1010
1011 self.state = RsaBackendState::Processing(rsa);
1012
1013 work_queue::Poll::Pending(false)
1014 }
1015
1016 #[cfg(esp32)]
1017 RsaBackendState::ModularMultiplicationRoundOne(mut rsa) => {
1018 if rsa.is_idle() {
1019 let RsaOperation::ModularMultiplication { y, .. } = item.operation else {
1020 unreachable!();
1021 };
1022
1023 rsa.write_operand_a(unsafe { y.as_ref() });
1025 rsa.start_modmulti();
1026
1027 self.state = RsaBackendState::Processing(rsa);
1028 } else {
1029 self.state = RsaBackendState::ModularMultiplicationRoundOne(rsa);
1031 }
1032 work_queue::Poll::Pending(false)
1033 }
1034
1035 RsaBackendState::Processing(rsa) => {
1036 if rsa.is_idle() {
1037 rsa.read_out(unsafe { item.result.as_mut() });
1038
1039 self.state = RsaBackendState::Ready(rsa);
1040 work_queue::Poll::Ready(Status::Completed)
1041 } else {
1042 self.state = RsaBackendState::Processing(rsa);
1043 work_queue::Poll::Pending(false)
1044 }
1045 }
1046 }
1047 }
1048
1049 fn cancel(&mut self, _item: &mut RsaWorkItem) {
1050 self.state = RsaBackendState::Idle;
1053 }
1054
1055 fn deinitialize(&mut self) {
1056 self.state = RsaBackendState::Idle;
1057 }
1058}
1059
1060pub struct RsaWorkQueueDriver<'t, 'd> {
1066 inner: WorkQueueDriver<'t, RsaBackend<'d>, RsaWorkItem>,
1067}
1068
1069impl<'t, 'd> RsaWorkQueueDriver<'t, 'd> {
1070 pub fn stop(self) -> impl Future<Output = ()> {
1072 self.inner.stop()
1073 }
1074}
1075
1076#[derive(Clone)]
1077struct RsaWorkItem {
1078 #[cfg(not(esp32))]
1080 search_acceleration: bool,
1081 #[cfg(not(esp32))]
1082 constant_time: bool,
1083
1084 operation: RsaOperation,
1086 result: NonNull<[u32]>,
1087}
1088
1089unsafe impl Sync for RsaWorkItem {}
1090unsafe impl Send for RsaWorkItem {}
1091
1092#[derive(Clone)]
1093enum RsaOperation {
1094 Multiplication {
1097 x: NonNull<[u32]>,
1098 y: NonNull<[u32]>,
1099 },
1100 ModularMultiplication {
1102 x: NonNull<[u32]>,
1103 y: NonNull<[u32]>,
1104 m: NonNull<[u32]>,
1105 r: NonNull<[u32]>,
1106 m_prime: u32,
1107 },
1108 ModularExponentiation {
1110 x: NonNull<[u32]>,
1111 y: NonNull<[u32]>,
1112 m: NonNull<[u32]>,
1113 r_inv: NonNull<[u32]>,
1114 m_prime: u32,
1115 },
1116}
1117
1118#[handler]
1119#[ram]
1120fn rsa_work_queue_handler() {
1121 if !RSA_WORK_QUEUE.process() {
1122 cfg_if::cfg_if! {
1125 if #[cfg(esp32)] {
1126 RSA::regs().interrupt().write(|w| w.interrupt().set_bit());
1127 } else {
1128 RSA::regs().int_clr().write(|w| w.int_clr().set_bit());
1129 }
1130 }
1131 }
1132}
1133
1134#[cfg_attr(
1141 not(esp32),
1142 doc = " \nThe context is created with a secure configuration by default. You can enable hardware acceleration
1143 options using [enable_search_acceleration][Self::enable_search_acceleration] and
1144 [enable_acceleration][Self::enable_acceleration] when appropriate."
1145)]
1146#[derive(Clone)]
1147pub struct RsaContext {
1148 frontend: WorkQueueFrontend<RsaWorkItem>,
1149}
1150
1151impl Default for RsaContext {
1152 fn default() -> Self {
1153 Self::new()
1154 }
1155}
1156
1157impl RsaContext {
1158 pub fn new() -> Self {
1160 Self {
1161 frontend: WorkQueueFrontend::new(RsaWorkItem {
1162 #[cfg(not(esp32))]
1163 search_acceleration: false,
1164 #[cfg(not(esp32))]
1165 constant_time: true,
1166 operation: RsaOperation::Multiplication {
1167 x: NonNull::from(&[]),
1168 y: NonNull::from(&[]),
1169 },
1170 result: NonNull::from(&mut []),
1171 }),
1172 }
1173 }
1174
1175 #[cfg(not(esp32))]
1176 #[doc = trm_markdown_link!("rsa")]
1186 pub fn enable_search_acceleration(&mut self) {
1187 self.frontend.data_mut().search_acceleration = true;
1188 }
1189
1190 #[cfg(not(esp32))]
1191 #[doc = trm_markdown_link!("rsa")]
1202 pub fn enable_acceleration(&mut self) {
1203 self.frontend.data_mut().constant_time = false;
1204 }
1205
1206 fn post(&mut self) -> RsaHandle<'_> {
1207 RsaHandle(self.frontend.post(&RSA_WORK_QUEUE))
1208 }
1209
1210 #[procmacros::doc_replace]
1211 pub fn modular_exponentiate<'t, OP>(
1276 &'t mut self,
1277 x: &'t OP::InputType,
1278 y: &'t OP::InputType,
1279 m: &'t OP::InputType,
1280 r: &'t OP::InputType,
1281 m_prime: u32,
1282 result: &'t mut OP::InputType,
1283 ) -> RsaHandle<'t>
1284 where
1285 OP: RsaMode,
1286 {
1287 self.frontend.data_mut().operation = RsaOperation::ModularExponentiation {
1288 x: NonNull::from(x.as_ref()),
1289 y: NonNull::from(y.as_ref()),
1290 m: NonNull::from(m.as_ref()),
1291 r_inv: NonNull::from(r.as_ref()),
1292 m_prime,
1293 };
1294 self.frontend.data_mut().result = NonNull::from(result.as_mut());
1295 self.post()
1296 }
1297
1298 pub fn modular_multiply<'t, OP>(
1314 &'t mut self,
1315 x: &'t OP::InputType,
1316 y: &'t OP::InputType,
1317 m: &'t OP::InputType,
1318 r: &'t OP::InputType,
1319 m_prime: u32,
1320 result: &'t mut OP::InputType,
1321 ) -> RsaHandle<'t>
1322 where
1323 OP: RsaMode,
1324 {
1325 self.frontend.data_mut().operation = RsaOperation::ModularMultiplication {
1326 x: NonNull::from(x.as_ref()),
1327 y: NonNull::from(y.as_ref()),
1328 m: NonNull::from(m.as_ref()),
1329 r: NonNull::from(r.as_ref()),
1330 m_prime,
1331 };
1332 self.frontend.data_mut().result = NonNull::from(result.as_mut());
1333 self.post()
1334 }
1335
1336 #[procmacros::doc_replace]
1337 pub fn multiply<'t, OP>(
1367 &'t mut self,
1368 x: &'t OP::InputType,
1369 y: &'t OP::InputType,
1370 result: &'t mut OP::OutputType,
1371 ) -> RsaHandle<'t>
1372 where
1373 OP: Multi,
1374 {
1375 self.frontend.data_mut().operation = RsaOperation::Multiplication {
1376 x: NonNull::from(x.as_ref()),
1377 y: NonNull::from(y.as_ref()),
1378 };
1379 self.frontend.data_mut().result = NonNull::from(result.as_mut());
1380 self.post()
1381 }
1382}
1383
1384pub struct RsaHandle<'t>(work_queue::Handle<'t, RsaWorkItem>);
1386
1387impl RsaHandle<'_> {
1388 #[inline]
1390 pub fn poll(&mut self) -> bool {
1391 self.0.poll()
1392 }
1393
1394 #[inline]
1396 pub fn wait_blocking(self) {
1397 self.0.wait_blocking();
1398 }
1399
1400 #[inline]
1402 pub fn wait(&mut self) -> impl Future<Output = Status> {
1403 self.0.wait()
1404 }
1405}