1use crate::drbg::HmacDrbgSha256;
19use crate::hash::{sha1, sha256, sha384, sha512};
20use crate::internal_alloc::Vec;
21use noxtls_core::{Error, Result};
22
23use super::bignum::BigUint;
24
/// Smallest modulus size (bits) accepted by the key-generation backend.
const RSA_KEYGEN_MIN_BITS: usize = 1024;
/// Largest modulus size (bits) accepted by the key-generation backend.
const RSA_KEYGEN_MAX_BITS: usize = 4096;
/// Minimum modulus size (bits) enforced when importing keys (unless the
/// `hazardous-legacy-crypto` feature is enabled) and by
/// `RsaKeySizePolicy::Minimum2048`.
const RSA_MIN_SECURE_BITS: usize = 2048;
/// Modulus size (bits) required by `RsaKeySizePolicy::Minimum3072`.
const RSA_RECOMMENDED_SECURE_BITS: usize = 3072;
29
30#[derive(Debug, Clone)]
32pub struct RsaPrivateKey {
33 pub n: BigUint,
34 pub d: BigUint,
35 crt: Option<RsaPrivateCrtComponents>,
36}
37
/// RSA public key: modulus `n` and public exponent `e`.
#[derive(Debug, Clone)]
pub struct RsaPublicKey {
    /// Public modulus.
    pub n: BigUint,
    /// Public exponent.
    pub e: BigUint,
}
44
45#[derive(Debug, Clone)]
47struct RsaPrivateCrtComponents {
48 p: BigUint,
49 q: BigUint,
50 dp: BigUint,
51 dq: BigUint,
52 qinv: BigUint,
53}
54
/// Minimum-modulus-size policy applied by the policy-checked key generators.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum RsaKeySizePolicy {
    /// Require a modulus of at least 2048 bits.
    Minimum2048,
    /// Require a modulus of at least 3072 bits.
    Minimum3072,
}
63
64impl RsaKeySizePolicy {
65 fn min_bits(self) -> usize {
79 match self {
80 Self::Minimum2048 => RSA_MIN_SECURE_BITS,
81 Self::Minimum3072 => RSA_RECOMMENDED_SECURE_BITS,
82 }
83 }
84}
85
86impl RsaPrivateKey {
87 pub fn from_be_bytes(n: &[u8], d: &[u8]) -> Result<Self> {
102 if n.is_empty() || d.is_empty() {
103 return Err(Error::InvalidLength(
104 "rsa private key fields must not be empty",
105 ));
106 }
107 let key = Self {
108 n: BigUint::from_be_bytes(n),
109 d: BigUint::from_be_bytes(d),
110 crt: None,
111 };
112 if !cfg!(feature = "hazardous-legacy-crypto") && key.n.bit_len() < RSA_MIN_SECURE_BITS {
113 return Err(Error::InvalidLength(
114 "rsa private key modulus must be at least 2048 bits",
115 ));
116 }
117 validate_private_components(&key.n, &key.d)?;
118 Ok(key)
119 }
120
121 #[must_use]
130 pub fn from_u128(n: u128, d: u128) -> Self {
131 Self {
132 n: BigUint::from_u128(n),
133 d: BigUint::from_u128(d),
134 crt: None,
135 }
136 }
137
138 pub fn clear(&mut self) {
143 self.n.clear();
144 self.d.clear();
145 if let Some(crt) = self.crt.as_mut() {
146 crt.p.clear();
147 crt.q.clear();
148 crt.dp.clear();
149 crt.dq.clear();
150 crt.qinv.clear();
151 }
152 self.crt = None;
153 }
154
155 pub fn with_crt_components(
167 mut self,
168 p: &[u8],
169 q: &[u8],
170 dp: &[u8],
171 dq: &[u8],
172 qinv: &[u8],
173 ) -> Result<Self> {
174 let crt = RsaPrivateCrtComponents {
175 p: BigUint::from_be_bytes(p),
176 q: BigUint::from_be_bytes(q),
177 dp: BigUint::from_be_bytes(dp),
178 dq: BigUint::from_be_bytes(dq),
179 qinv: BigUint::from_be_bytes(qinv),
180 };
181 validate_crt_components(&self.n, &crt)?;
182 self.crt = Some(crt);
183 Ok(self)
184 }
185
186 pub fn sign_digest(&self, digest: &[u8]) -> Result<Vec<u8>> {
194 if digest.is_empty() {
195 return Err(Error::InvalidLength("digest must not be empty"));
196 }
197 validate_private_components(&self.n, &self.d)?;
198 let m = BigUint::from_be_bytes(digest).modulo(&self.n);
199 let s = BigUint::mod_exp(&m, &self.d, &self.n);
200 s.to_be_bytes_padded(self.modulus_len())
201 }
202
203 pub fn sign_pkcs1_v15_sha256(&self, msg: &[u8]) -> Result<Vec<u8>> {
211 validate_private_components(&self.n, &self.d)?;
212 let hash = sha256(msg);
213 let em = emsa_pkcs1_v15_encode(
214 &hash,
215 PKCS1_V15_DIGESTINFO_SHA256_PREFIX,
216 self.modulus_len(),
217 )?;
218 let m = BigUint::from_be_bytes(&em);
219 let s = BigUint::mod_exp(&m, &self.d, &self.n);
220 s.to_be_bytes_padded(self.modulus_len())
221 }
222
223 pub fn sign_pkcs1_v15_sha1(&self, msg: &[u8]) -> Result<Vec<u8>> {
231 validate_private_components(&self.n, &self.d)?;
232 let hash = sha1(msg);
233 let em =
234 emsa_pkcs1_v15_encode(&hash, PKCS1_V15_DIGESTINFO_SHA1_PREFIX, self.modulus_len())?;
235 let m = BigUint::from_be_bytes(&em);
236 let s = BigUint::mod_exp(&m, &self.d, &self.n);
237 s.to_be_bytes_padded(self.modulus_len())
238 }
239
240 pub fn sign_pkcs1_v15_sha384(&self, msg: &[u8]) -> Result<Vec<u8>> {
248 validate_private_components(&self.n, &self.d)?;
249 let hash = sha384(msg);
250 let em = emsa_pkcs1_v15_encode(
251 &hash,
252 PKCS1_V15_DIGESTINFO_SHA384_PREFIX,
253 self.modulus_len(),
254 )?;
255 let m = BigUint::from_be_bytes(&em);
256 let s = BigUint::mod_exp(&m, &self.d, &self.n);
257 s.to_be_bytes_padded(self.modulus_len())
258 }
259
260 pub fn sign_pkcs1_v15_sha512(&self, msg: &[u8]) -> Result<Vec<u8>> {
268 validate_private_components(&self.n, &self.d)?;
269 let hash = sha512(msg);
270 let em = emsa_pkcs1_v15_encode(
271 &hash,
272 PKCS1_V15_DIGESTINFO_SHA512_PREFIX,
273 self.modulus_len(),
274 )?;
275 let m = BigUint::from_be_bytes(&em);
276 let s = BigUint::mod_exp(&m, &self.d, &self.n);
277 s.to_be_bytes_padded(self.modulus_len())
278 }
279
280 pub fn sign_pss_sha256(&self, msg: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
289 validate_private_components(&self.n, &self.d)?;
290 let em_bits = self.n.bit_len().saturating_sub(1);
291 let em_len = em_bits.div_ceil(8);
292 let m_hash = sha256(msg);
293 let em = emsa_pss_encode_sha256(&m_hash, salt, em_bits, em_len)?;
294 let s = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.d, &self.n);
295 s.to_be_bytes_padded(self.modulus_len())
296 }
297
298 pub fn sign_pss_sha384(&self, msg: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
307 validate_private_components(&self.n, &self.d)?;
308 let em_bits = self.n.bit_len().saturating_sub(1);
309 let em_len = em_bits.div_ceil(8);
310 let m_hash = sha384(msg);
311 let em = emsa_pss_encode_sha384(&m_hash, salt, em_bits, em_len)?;
312 let s = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.d, &self.n);
313 s.to_be_bytes_padded(self.modulus_len())
314 }
315
316 pub fn decrypt_pkcs1_v15(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
324 validate_private_components(&self.n, &self.d)?;
325 if ciphertext.len() != self.modulus_len() {
326 return Err(Error::CryptoFailure("rsa decryption failed"));
327 }
328 let em = BigUint::mod_exp(&BigUint::from_be_bytes(ciphertext), &self.d, &self.n)
329 .to_be_bytes_padded(self.modulus_len())?;
330 decode_pkcs1_v15_plaintext(&em)
331 }
332
333 pub fn decrypt_pkcs1_v15_crt_only(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
341 validate_private_components(&self.n, &self.d)?;
342 if ciphertext.len() != self.modulus_len() {
343 return Err(Error::CryptoFailure("rsa decryption failed"));
344 }
345 let crt = self
346 .crt
347 .as_ref()
348 .ok_or(Error::StateError("rsa crt parameters are not configured"))?;
349 let c = BigUint::from_be_bytes(ciphertext);
350 let m1 = BigUint::mod_exp(&c, &crt.dp, &crt.p);
351 let m2 = BigUint::mod_exp(&c, &crt.dq, &crt.q);
352 let diff = if m1.cmp(&m2).is_ge() {
353 m1.sub(&m2)
354 } else {
355 m1.add(&crt.p).sub(&m2)
356 };
357 let h = crt.qinv.mul(&diff).modulo(&crt.p);
358 let m = m2.add(&crt.q.mul(&h));
359 let em = m.to_be_bytes_padded(self.modulus_len())?;
360 decode_pkcs1_v15_plaintext(&em)
361 }
362
363 pub fn decrypt_oaep_sha256(&self, ciphertext: &[u8], label: &[u8]) -> Result<Vec<u8>> {
372 validate_private_components(&self.n, &self.d)?;
373 if ciphertext.len() != self.modulus_len() {
374 return Err(Error::CryptoFailure("rsa decryption failed"));
375 }
376 let em = BigUint::mod_exp(&BigUint::from_be_bytes(ciphertext), &self.d, &self.n)
377 .to_be_bytes_padded(self.modulus_len())?;
378 decode_oaep_sha256_plaintext(&em, label)
379 }
380
381 pub fn decrypt_oaep_sha256_crt_only(&self, ciphertext: &[u8], label: &[u8]) -> Result<Vec<u8>> {
390 validate_private_components(&self.n, &self.d)?;
391 if ciphertext.len() != self.modulus_len() {
392 return Err(Error::CryptoFailure("rsa decryption failed"));
393 }
394 let crt = self
395 .crt
396 .as_ref()
397 .ok_or(Error::StateError("rsa crt parameters are not configured"))?;
398 let c = BigUint::from_be_bytes(ciphertext);
399 let m1 = BigUint::mod_exp(&c, &crt.dp, &crt.p);
400 let m2 = BigUint::mod_exp(&c, &crt.dq, &crt.q);
401 let diff = if m1.cmp(&m2).is_ge() {
402 m1.sub(&m2)
403 } else {
404 m1.add(&crt.p).sub(&m2)
405 };
406 let h = crt.qinv.mul(&diff).modulo(&crt.p);
407 let m = m2.add(&crt.q.mul(&h));
408 let em = m.to_be_bytes_padded(self.modulus_len())?;
409 decode_oaep_sha256_plaintext(&em, label)
410 }
411
412 fn modulus_len(&self) -> usize {
426 self.n.to_be_bytes().len()
427 }
428}
429
430impl Drop for RsaPrivateKey {
431 fn drop(&mut self) {
432 self.clear();
433 }
434}
435
436impl RsaPublicKey {
437 pub fn from_be_bytes(n: &[u8], e: &[u8]) -> Result<Self> {
452 if n.is_empty() || e.is_empty() {
453 return Err(Error::InvalidLength(
454 "rsa public key fields must not be empty",
455 ));
456 }
457 let key = Self {
458 n: BigUint::from_be_bytes(n),
459 e: BigUint::from_be_bytes(e),
460 };
461 if !cfg!(feature = "hazardous-legacy-crypto") && key.n.bit_len() < RSA_MIN_SECURE_BITS {
462 return Err(Error::InvalidLength(
463 "rsa public key modulus must be at least 2048 bits",
464 ));
465 }
466 validate_public_components(&key.n, &key.e)?;
467 Ok(key)
468 }
469
470 #[must_use]
479 pub fn from_u128(n: u128, e: u128) -> Self {
480 Self {
481 n: BigUint::from_u128(n),
482 e: BigUint::from_u128(e),
483 }
484 }
485
486 pub fn clear(&mut self) {
491 self.n = BigUint::zero();
492 self.e = BigUint::zero();
493 }
494
495 pub fn verify_digest(&self, digest: &[u8], signature: &[u8]) -> Result<()> {
504 if digest.is_empty() {
505 return Err(Error::InvalidLength("digest must not be empty"));
506 }
507 validate_public_components(&self.n, &self.e)?;
508 let k = self.modulus_len();
509 let expected = BigUint::from_be_bytes(digest)
510 .modulo(&self.n)
511 .to_be_bytes_padded(k)?;
512 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
513 .to_be_bytes_padded(k)?;
514 if ct_bytes_eq(&recovered, &expected) {
515 Ok(())
516 } else {
517 Err(Error::CryptoFailure("RSA verification failed"))
518 }
519 }
520
521 pub fn verify_pkcs1_v15_sha256(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
530 validate_public_components(&self.n, &self.e)?;
531 if signature.len() != self.modulus_len() {
532 return Err(Error::InvalidLength("rsa signature length mismatch"));
533 }
534 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
535 .to_be_bytes_padded(self.modulus_len())?;
536 let expected = emsa_pkcs1_v15_encode(
537 &sha256(msg),
538 PKCS1_V15_DIGESTINFO_SHA256_PREFIX,
539 self.modulus_len(),
540 )?;
541 if ct_bytes_eq(&recovered, &expected) {
542 Ok(())
543 } else {
544 Err(Error::CryptoFailure("RSA verification failed"))
545 }
546 }
547
548 pub fn verify_pkcs1_v15_sha1(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
557 validate_public_components(&self.n, &self.e)?;
558 if signature.len() != self.modulus_len() {
559 return Err(Error::InvalidLength("rsa signature length mismatch"));
560 }
561 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
562 .to_be_bytes_padded(self.modulus_len())?;
563 let expected = emsa_pkcs1_v15_encode(
564 &sha1(msg),
565 PKCS1_V15_DIGESTINFO_SHA1_PREFIX,
566 self.modulus_len(),
567 )?;
568 if ct_bytes_eq(&recovered, &expected) {
569 Ok(())
570 } else {
571 Err(Error::CryptoFailure("RSA verification failed"))
572 }
573 }
574
575 pub fn verify_pkcs1_v15_sha384(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
584 validate_public_components(&self.n, &self.e)?;
585 if signature.len() != self.modulus_len() {
586 return Err(Error::InvalidLength("rsa signature length mismatch"));
587 }
588 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
589 .to_be_bytes_padded(self.modulus_len())?;
590 let expected = emsa_pkcs1_v15_encode(
591 &sha384(msg),
592 PKCS1_V15_DIGESTINFO_SHA384_PREFIX,
593 self.modulus_len(),
594 )?;
595 if ct_bytes_eq(&recovered, &expected) {
596 Ok(())
597 } else {
598 Err(Error::CryptoFailure("RSA verification failed"))
599 }
600 }
601
602 pub fn verify_pkcs1_v15_sha512(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
611 validate_public_components(&self.n, &self.e)?;
612 if signature.len() != self.modulus_len() {
613 return Err(Error::InvalidLength("rsa signature length mismatch"));
614 }
615 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
616 .to_be_bytes_padded(self.modulus_len())?;
617 let expected = emsa_pkcs1_v15_encode(
618 &sha512(msg),
619 PKCS1_V15_DIGESTINFO_SHA512_PREFIX,
620 self.modulus_len(),
621 )?;
622 if ct_bytes_eq(&recovered, &expected) {
623 Ok(())
624 } else {
625 Err(Error::CryptoFailure("RSA verification failed"))
626 }
627 }
628
629 pub fn verify_pss_sha256(&self, msg: &[u8], signature: &[u8], salt_len: usize) -> Result<()> {
639 validate_public_components(&self.n, &self.e)?;
640 if signature.len() != self.modulus_len() {
641 return Err(Error::InvalidLength("rsa signature length mismatch"));
642 }
643 let em_bits = self.n.bit_len().saturating_sub(1);
644 let em_len = em_bits.div_ceil(8);
645 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
646 .to_be_bytes_padded(self.modulus_len())?;
647 let em = &recovered[recovered.len() - em_len..];
648 emsa_pss_verify_sha256(&sha256(msg), em, em_bits, salt_len)
649 }
650
651 pub fn verify_pss_sha384(&self, msg: &[u8], signature: &[u8], salt_len: usize) -> Result<()> {
661 validate_public_components(&self.n, &self.e)?;
662 if signature.len() != self.modulus_len() {
663 return Err(Error::InvalidLength("rsa signature length mismatch"));
664 }
665 let em_bits = self.n.bit_len().saturating_sub(1);
666 let em_len = em_bits.div_ceil(8);
667 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
668 .to_be_bytes_padded(self.modulus_len())?;
669 let em = &recovered[recovered.len() - em_len..];
670 emsa_pss_verify_sha384(&sha384(msg), em, em_bits, salt_len)
671 }
672
673 pub fn encrypt_pkcs1_v15_auto(
682 &self,
683 plaintext: &[u8],
684 drbg: &mut HmacDrbgSha256,
685 ) -> Result<Vec<u8>> {
686 validate_public_components(&self.n, &self.e)?;
687 let k = self.modulus_len();
688 if plaintext.len() > k.saturating_sub(11) {
689 return Err(Error::InvalidLength(
690 "rsa plaintext too long for pkcs1 v1.5 encryption",
691 ));
692 }
693 let ps_len = k - plaintext.len() - 3;
694 let ps = drbg_nonzero_padding(drbg, ps_len)?;
695 let mut em = Vec::with_capacity(k);
696 em.push(0x00);
697 em.push(0x02);
698 em.extend_from_slice(&ps);
699 em.push(0x00);
700 em.extend_from_slice(plaintext);
701 let c = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.e, &self.n);
702 c.to_be_bytes_padded(k)
703 }
704
705 pub fn encrypt_oaep_sha256_auto(
715 &self,
716 plaintext: &[u8],
717 label: &[u8],
718 drbg: &mut HmacDrbgSha256,
719 ) -> Result<Vec<u8>> {
720 validate_public_components(&self.n, &self.e)?;
721 let k = self.modulus_len();
722 let seed = drbg.generate(32, b"rsa_oaep_sha256_seed")?;
723 let em = emea_oaep_encode_sha256(plaintext, label, &seed, k)?;
724 let c = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.e, &self.n);
725 c.to_be_bytes_padded(k)
726 }
727
728 fn modulus_len(&self) -> usize {
742 self.n.to_be_bytes().len()
743 }
744}
745
/// Generates an RSA key pair with a caller-chosen public exponent, accepting
/// any modulus size supported by the backend (including sub-2048-bit legacy
/// sizes). Only available with the `hazardous-legacy-crypto` feature.
///
/// # Errors
///
/// Propagates every error from the key-generation backend.
#[cfg(feature = "hazardous-legacy-crypto")]
pub fn rsa_generate_keypair_with_exponent_auto(
    modulus_bits: usize,
    public_exponent: u32,
    drbg: &mut HmacDrbgSha256,
) -> Result<(RsaPrivateKey, RsaPublicKey)> {
    let keypair = rsa_generate_keypair_backend_auto(modulus_bits, public_exponent, drbg)?;
    Ok(keypair)
}
763
764fn rsa_generate_keypair_backend_auto(
784 modulus_bits: usize,
785 public_exponent: u32,
786 drbg: &mut HmacDrbgSha256,
787) -> Result<(RsaPrivateKey, RsaPublicKey)> {
788 if !(RSA_KEYGEN_MIN_BITS..=RSA_KEYGEN_MAX_BITS).contains(&modulus_bits) {
789 return Err(Error::InvalidLength(
790 "rsa modulus bits must be in supported range 1024..=4096",
791 ));
792 }
793 if public_exponent < 3 || (public_exponent & 1) == 0 {
794 return Err(Error::CryptoFailure(
795 "rsa public exponent must be odd and at least 3",
796 ));
797 }
798 let e = BigUint::from_u128(u128::from(public_exponent));
799 let one = BigUint::one();
800 let p_bits = modulus_bits / 2;
801 let q_bits = modulus_bits - p_bits;
802 let mut attempts = 0_u32;
803 while attempts < 256 {
804 let mut p = generate_rsa_prime_candidate_auto(p_bits, &e, drbg)?;
805 let mut q = generate_rsa_prime_candidate_auto(q_bits, &e, drbg)?;
806 let mut distinct_attempts = 0_u32;
807 while p.cmp(&q).is_eq() {
808 if distinct_attempts >= 32 {
809 break;
810 }
811 q = generate_rsa_prime_candidate_auto(q_bits, &e, drbg)?;
812 distinct_attempts = distinct_attempts.saturating_add(1);
813 }
814 if p.cmp(&q).is_eq() {
815 attempts = attempts.saturating_add(1);
816 continue;
817 }
818 if p.cmp(&q).is_gt() {
819 core::mem::swap(&mut p, &mut q);
820 }
821 let n = p.mul(&q);
822 if n.bit_len() != modulus_bits {
823 attempts = attempts.saturating_add(1);
824 continue;
825 }
826 let pm1 = p.sub(&one);
827 let qm1 = q.sub(&one);
828 let phi = pm1.mul(&qm1);
829 if BigUint::gcd(&e, &phi).cmp(&one).is_ne() {
830 attempts = attempts.saturating_add(1);
831 continue;
832 }
833 let Some(d) = BigUint::mod_inverse(&e, &phi) else {
834 attempts = attempts.saturating_add(1);
835 continue;
836 };
837 let dp = d.modulo(&pm1);
838 let dq = d.modulo(&qm1);
839 let Some(qinv) = BigUint::mod_inverse(&q, &p) else {
840 attempts = attempts.saturating_add(1);
841 continue;
842 };
843 let private = RsaPrivateKey {
844 n: n.clone(),
845 d,
846 crt: Some(RsaPrivateCrtComponents { p, q, dp, dq, qinv }),
847 };
848 let public = RsaPublicKey { n, e };
849 validate_private_components(&private.n, &private.d)?;
850 validate_public_components(&public.n, &public.e)?;
851 validate_crt_components(&private.n, private.crt.as_ref().expect("crt must exist"))?;
852 return Ok((private, public));
853 }
854 Err(Error::StateError(
855 "rsa key generation exhausted attempt budget",
856 ))
857}
858
/// Generates an RSA key pair with the public exponent 65537, accepting any
/// modulus size supported by the backend. Only available with the
/// `hazardous-legacy-crypto` feature.
///
/// # Errors
///
/// Propagates every error from the key-generation backend.
#[cfg(feature = "hazardous-legacy-crypto")]
pub fn rsa_generate_keypair_auto(
    modulus_bits: usize,
    drbg: &mut HmacDrbgSha256,
) -> Result<(RsaPrivateKey, RsaPublicKey)> {
    const DEFAULT_PUBLIC_EXPONENT: u32 = 65_537;
    rsa_generate_keypair_backend_auto(modulus_bits, DEFAULT_PUBLIC_EXPONENT, drbg)
}
874
875pub fn rsa_generate_keypair_with_policy_auto(
886 modulus_bits: usize,
887 public_exponent: u32,
888 policy: RsaKeySizePolicy,
889 drbg: &mut HmacDrbgSha256,
890) -> Result<(RsaPrivateKey, RsaPublicKey)> {
891 if !(RSA_MIN_SECURE_BITS..=RSA_KEYGEN_MAX_BITS).contains(&modulus_bits) {
892 return Err(Error::InvalidLength(
893 "secure rsa modulus bits must be in supported range 2048..=4096",
894 ));
895 }
896 if modulus_bits < policy.min_bits() {
897 return Err(Error::InvalidLength(
898 "rsa modulus bits do not satisfy configured secure policy minimum",
899 ));
900 }
901 rsa_generate_keypair_backend_auto(modulus_bits, public_exponent, drbg)
902}
903
904pub fn rsa_generate_keypair_secure_auto(
914 modulus_bits: usize,
915 policy: RsaKeySizePolicy,
916 drbg: &mut HmacDrbgSha256,
917) -> Result<(RsaPrivateKey, RsaPublicKey)> {
918 rsa_generate_keypair_with_policy_auto(modulus_bits, 65_537, policy, drbg)
919}
920
921pub fn rsassa_sha256_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
930 private.sign_pkcs1_v15_sha256(msg)
931}
932
933pub fn rsassa_sha256_verify(public: &RsaPublicKey, msg: &[u8], signature: &[u8]) -> Result<()> {
943 public.verify_pkcs1_v15_sha256(msg, signature)
944}
945
946pub fn rsassa_sha1_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
955 private.sign_pkcs1_v15_sha1(msg)
956}
957
958pub fn rsassa_sha1_verify(public: &RsaPublicKey, msg: &[u8], signature: &[u8]) -> Result<()> {
968 public.verify_pkcs1_v15_sha1(msg, signature)
969}
970
971pub fn rsassa_sha384_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
980 private.sign_pkcs1_v15_sha384(msg)
981}
982
983pub fn rsassa_sha384_verify(public: &RsaPublicKey, msg: &[u8], signature: &[u8]) -> Result<()> {
993 public.verify_pkcs1_v15_sha384(msg, signature)
994}
995
996pub fn rsassa_sha512_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
1005 private.sign_pkcs1_v15_sha512(msg)
1006}
1007
1008pub fn rsassa_sha512_verify(public: &RsaPublicKey, msg: &[u8], signature: &[u8]) -> Result<()> {
1018 public.verify_pkcs1_v15_sha512(msg, signature)
1019}
1020
1021pub fn rsassa_pss_sha256_sign(private: &RsaPrivateKey, msg: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
1031 private.sign_pss_sha256(msg, salt)
1032}
1033
1034pub fn rsassa_pss_sha256_sign_auto(
1045 private: &RsaPrivateKey,
1046 msg: &[u8],
1047 drbg: &mut HmacDrbgSha256,
1048 salt_len: usize,
1049) -> Result<Vec<u8>> {
1050 let salt = drbg.generate(salt_len, b"rsa_pss_sha256_salt")?;
1051 private.sign_pss_sha256(msg, &salt)
1052}
1053
1054pub fn rsassa_pss_sha256_verify(
1065 public: &RsaPublicKey,
1066 msg: &[u8],
1067 signature: &[u8],
1068 salt_len: usize,
1069) -> Result<()> {
1070 public.verify_pss_sha256(msg, signature, salt_len)
1071}
1072
1073pub fn rsassa_pss_sha384_sign(private: &RsaPrivateKey, msg: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
1083 private.sign_pss_sha384(msg, salt)
1084}
1085
1086pub fn rsassa_pss_sha384_sign_auto(
1097 private: &RsaPrivateKey,
1098 msg: &[u8],
1099 drbg: &mut HmacDrbgSha256,
1100 salt_len: usize,
1101) -> Result<Vec<u8>> {
1102 let salt = drbg.generate(salt_len, b"rsa_pss_sha384_salt")?;
1103 private.sign_pss_sha384(msg, &salt)
1104}
1105
1106pub fn rsassa_pss_sha384_verify(
1117 public: &RsaPublicKey,
1118 msg: &[u8],
1119 signature: &[u8],
1120 salt_len: usize,
1121) -> Result<()> {
1122 public.verify_pss_sha384(msg, signature, salt_len)
1123}
1124
1125pub fn rsaes_pkcs1_v15_encrypt_auto(
1135 public: &RsaPublicKey,
1136 plaintext: &[u8],
1137 drbg: &mut HmacDrbgSha256,
1138) -> Result<Vec<u8>> {
1139 public.encrypt_pkcs1_v15_auto(plaintext, drbg)
1140}
1141
1142pub fn rsaes_pkcs1_v15_decrypt(private: &RsaPrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>> {
1151 private.decrypt_pkcs1_v15(ciphertext)
1152}
1153
1154pub fn rsaes_pkcs1_v15_decrypt_crt_only(
1168 private: &RsaPrivateKey,
1169 ciphertext: &[u8],
1170) -> Result<Vec<u8>> {
1171 private.decrypt_pkcs1_v15_crt_only(ciphertext)
1172}
1173
1174pub fn rsaes_oaep_sha256_encrypt_auto(
1185 public: &RsaPublicKey,
1186 plaintext: &[u8],
1187 label: &[u8],
1188 drbg: &mut HmacDrbgSha256,
1189) -> Result<Vec<u8>> {
1190 public.encrypt_oaep_sha256_auto(plaintext, label, drbg)
1191}
1192
1193pub fn rsaes_oaep_sha256_decrypt(
1203 private: &RsaPrivateKey,
1204 ciphertext: &[u8],
1205 label: &[u8],
1206) -> Result<Vec<u8>> {
1207 private.decrypt_oaep_sha256(ciphertext, label)
1208}
1209
1210pub fn rsaes_oaep_sha256_decrypt_crt_only(
1220 private: &RsaPrivateKey,
1221 ciphertext: &[u8],
1222 label: &[u8],
1223) -> Result<Vec<u8>> {
1224 private.decrypt_oaep_sha256_crt_only(ciphertext, label)
1225}
1226
/// DER prefix of the PKCS#1 v1.5 `DigestInfo` for SHA-1 (RFC 8017 §9.2,
/// note 1). The 20-byte digest is appended after it.
const PKCS1_V15_DIGESTINFO_SHA1_PREFIX: &[u8] = &[
    0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2B, 0x0E, 0x03, 0x02, 0x1A, 0x05, 0x00, 0x04, 0x14,
];
/// DER prefix of the PKCS#1 v1.5 `DigestInfo` for SHA-256; the 32-byte digest follows.
const PKCS1_V15_DIGESTINFO_SHA256_PREFIX: &[u8] = &[
    0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
    0x00, 0x04, 0x20,
];
/// DER prefix of the PKCS#1 v1.5 `DigestInfo` for SHA-384; the 48-byte digest follows.
const PKCS1_V15_DIGESTINFO_SHA384_PREFIX: &[u8] = &[
    0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
    0x00, 0x04, 0x30,
];
/// DER prefix of the PKCS#1 v1.5 `DigestInfo` for SHA-512; the 64-byte digest follows.
const PKCS1_V15_DIGESTINFO_SHA512_PREFIX: &[u8] = &[
    0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
    0x00, 0x04, 0x40,
];
1242
1243fn emsa_pkcs1_v15_encode(hash: &[u8], digest_info_prefix: &[u8], k: usize) -> Result<Vec<u8>> {
1263 let t_len = digest_info_prefix.len() + hash.len();
1264 if k < t_len + 11 {
1265 return Err(Error::InvalidLength("rsa modulus too short for pkcs1 v1.5"));
1266 }
1267 let ps_len = k - t_len - 3;
1268 let mut em = Vec::with_capacity(k);
1269 em.push(0x00);
1270 em.push(0x01);
1271 em.extend(core::iter::repeat_n(0xff_u8, ps_len));
1272 em.push(0x00);
1273 em.extend_from_slice(digest_info_prefix);
1274 em.extend_from_slice(hash);
1275 Ok(em)
1276}
1277
1278fn emsa_pss_encode_sha256(
1299 m_hash: &[u8; 32],
1300 salt: &[u8],
1301 em_bits: usize,
1302 em_len: usize,
1303) -> Result<Vec<u8>> {
1304 const HASH_LEN: usize = 32;
1305 if em_len < HASH_LEN + salt.len() + 2 {
1306 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1307 }
1308
1309 let mut m_prime = vec![0_u8; 8];
1310 m_prime.extend_from_slice(m_hash);
1311 m_prime.extend_from_slice(salt);
1312 let h = sha256(&m_prime);
1313
1314 let ps_len = em_len - salt.len() - HASH_LEN - 2;
1315 let mut db = vec![0_u8; ps_len];
1316 db.push(0x01);
1317 db.extend_from_slice(salt);
1318
1319 let db_mask = mgf1_sha256(&h, em_len - HASH_LEN - 1)?;
1320 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1321 *byte ^= *mask;
1322 }
1323
1324 let unused_bits = 8 * em_len - em_bits;
1325 if unused_bits > 0 {
1326 db[0] &= 0xff_u8 >> unused_bits;
1327 }
1328
1329 let mut em = db;
1330 em.extend_from_slice(&h);
1331 em.push(0xbc);
1332 Ok(em)
1333}
1334
1335fn emsa_pss_verify_sha256(
1356 m_hash: &[u8; 32],
1357 em: &[u8],
1358 em_bits: usize,
1359 salt_len: usize,
1360) -> Result<()> {
1361 const HASH_LEN: usize = 32;
1362 if em.len() < HASH_LEN + salt_len + 2 {
1363 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1364 }
1365 if em.last().copied() != Some(0xbc) {
1366 return Err(Error::CryptoFailure("RSA verification failed"));
1367 }
1368
1369 let db_len = em.len() - HASH_LEN - 1;
1370 let (masked_db, rest) = em.split_at(db_len);
1371 let h = &rest[..HASH_LEN];
1372
1373 let unused_bits = 8 * em.len() - em_bits;
1374 if unused_bits > 0 {
1375 let mask = 0xff_u8 << (8 - unused_bits);
1376 if masked_db[0] & mask != 0 {
1377 return Err(Error::CryptoFailure("RSA verification failed"));
1378 }
1379 }
1380
1381 let db_mask = mgf1_sha256(h, db_len)?;
1382 let mut db = masked_db.to_vec();
1383 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1384 *byte ^= *mask;
1385 }
1386 if unused_bits > 0 {
1387 db[0] &= 0xff_u8 >> unused_bits;
1388 }
1389
1390 let ps_len = em.len() - HASH_LEN - salt_len - 2;
1391 if !ct_all_zero(&db[..ps_len]) || db[ps_len] != 0x01 {
1392 return Err(Error::CryptoFailure("RSA verification failed"));
1393 }
1394 let salt = &db[db.len() - salt_len..];
1395
1396 let mut m_prime = vec![0_u8; 8];
1397 m_prime.extend_from_slice(m_hash);
1398 m_prime.extend_from_slice(salt);
1399 let expected_h = sha256(&m_prime);
1400 if ct_bytes_eq(expected_h.as_slice(), h) {
1401 Ok(())
1402 } else {
1403 Err(Error::CryptoFailure("RSA verification failed"))
1404 }
1405}
1406
1407fn emsa_pss_encode_sha384(
1428 m_hash: &[u8; 48],
1429 salt: &[u8],
1430 em_bits: usize,
1431 em_len: usize,
1432) -> Result<Vec<u8>> {
1433 const HASH_LEN: usize = 48;
1434 if em_len < HASH_LEN + salt.len() + 2 {
1435 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1436 }
1437
1438 let mut m_prime = vec![0_u8; 8];
1439 m_prime.extend_from_slice(m_hash);
1440 m_prime.extend_from_slice(salt);
1441 let h = sha384(&m_prime);
1442
1443 let ps_len = em_len - salt.len() - HASH_LEN - 2;
1444 let mut db = vec![0_u8; ps_len];
1445 db.push(0x01);
1446 db.extend_from_slice(salt);
1447
1448 let db_mask = mgf1_sha384(&h, em_len - HASH_LEN - 1)?;
1449 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1450 *byte ^= *mask;
1451 }
1452
1453 let unused_bits = 8 * em_len - em_bits;
1454 if unused_bits > 0 {
1455 db[0] &= 0xff_u8 >> unused_bits;
1456 }
1457
1458 let mut em = db;
1459 em.extend_from_slice(&h);
1460 em.push(0xbc);
1461 Ok(em)
1462}
1463
1464fn emsa_pss_verify_sha384(
1485 m_hash: &[u8; 48],
1486 em: &[u8],
1487 em_bits: usize,
1488 salt_len: usize,
1489) -> Result<()> {
1490 const HASH_LEN: usize = 48;
1491 if em.len() < HASH_LEN + salt_len + 2 {
1492 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1493 }
1494 if em.last().copied() != Some(0xbc) {
1495 return Err(Error::CryptoFailure("RSA verification failed"));
1496 }
1497
1498 let db_len = em.len() - HASH_LEN - 1;
1499 let (masked_db, rest) = em.split_at(db_len);
1500 let h = &rest[..HASH_LEN];
1501
1502 let unused_bits = 8 * em.len() - em_bits;
1503 if unused_bits > 0 {
1504 let mask = 0xff_u8 << (8 - unused_bits);
1505 if masked_db[0] & mask != 0 {
1506 return Err(Error::CryptoFailure("RSA verification failed"));
1507 }
1508 }
1509
1510 let db_mask = mgf1_sha384(h, db_len)?;
1511 let mut db = masked_db.to_vec();
1512 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1513 *byte ^= *mask;
1514 }
1515 if unused_bits > 0 {
1516 db[0] &= 0xff_u8 >> unused_bits;
1517 }
1518
1519 let ps_len = em.len() - HASH_LEN - salt_len - 2;
1520 if !ct_all_zero(&db[..ps_len]) || db[ps_len] != 0x01 {
1521 return Err(Error::CryptoFailure("RSA verification failed"));
1522 }
1523 let salt = &db[db.len() - salt_len..];
1524
1525 let mut m_prime = vec![0_u8; 8];
1526 m_prime.extend_from_slice(m_hash);
1527 m_prime.extend_from_slice(salt);
1528 let expected_h = sha384(&m_prime);
1529 if ct_bytes_eq(expected_h.as_slice(), h) {
1530 Ok(())
1531 } else {
1532 Err(Error::CryptoFailure("RSA verification failed"))
1533 }
1534}
1535
1536fn mgf1_sha256(seed: &[u8], out_len: usize) -> Result<Vec<u8>> {
1555 let mut out = Vec::with_capacity(out_len);
1556 let mut counter = 0_u32;
1557 while out.len() < out_len {
1558 if counter == u32::MAX {
1559 return Err(Error::InvalidLength("mgf1 output too large"));
1560 }
1561 let mut block_input = Vec::with_capacity(seed.len() + 4);
1562 block_input.extend_from_slice(seed);
1563 block_input.extend_from_slice(&counter.to_be_bytes());
1564 out.extend_from_slice(&sha256(&block_input));
1565 counter = counter.wrapping_add(1);
1566 }
1567 out.truncate(out_len);
1568 Ok(out)
1569}
1570
1571fn mgf1_sha384(seed: &[u8], out_len: usize) -> Result<Vec<u8>> {
1590 let mut out = Vec::with_capacity(out_len);
1591 let mut counter = 0_u32;
1592 while out.len() < out_len {
1593 if counter == u32::MAX {
1594 return Err(Error::InvalidLength("mgf1 output too large"));
1595 }
1596 let mut block_input = Vec::with_capacity(seed.len() + 4);
1597 block_input.extend_from_slice(seed);
1598 block_input.extend_from_slice(&counter.to_be_bytes());
1599 out.extend_from_slice(&sha384(&block_input));
1600 counter = counter.wrapping_add(1);
1601 }
1602 out.truncate(out_len);
1603 Ok(out)
1604}
1605
1606fn drbg_nonzero_padding(drbg: &mut HmacDrbgSha256, len: usize) -> Result<Vec<u8>> {
1625 let mut out = Vec::with_capacity(len);
1626 while out.len() < len {
1627 let block = drbg.generate(len.saturating_sub(out.len()), b"rsa_pkcs1_v15_ps")?;
1628 for byte in block {
1629 if byte != 0 {
1630 out.push(byte);
1631 if out.len() == len {
1632 break;
1633 }
1634 }
1635 }
1636 }
1637 Ok(out)
1638}
1639
1640fn emea_oaep_encode_sha256(
1661 plaintext: &[u8],
1662 label: &[u8],
1663 seed: &[u8],
1664 k: usize,
1665) -> Result<Vec<u8>> {
1666 const HASH_LEN: usize = 32;
1667 if seed.len() != HASH_LEN {
1668 return Err(Error::InvalidLength("rsa oaep seed must be 32 bytes"));
1669 }
1670 if k < (2 * HASH_LEN + 2) {
1671 return Err(Error::InvalidLength(
1672 "rsa modulus too short for oaep sha256",
1673 ));
1674 }
1675 if plaintext.len() > k - (2 * HASH_LEN + 2) {
1676 return Err(Error::InvalidLength(
1677 "rsa plaintext too long for oaep sha256",
1678 ));
1679 }
1680 let l_hash = sha256(label);
1681 let ps_len = k - plaintext.len() - (2 * HASH_LEN + 2);
1682 let mut db = Vec::with_capacity(k - HASH_LEN - 1);
1683 db.extend_from_slice(&l_hash);
1684 db.extend(core::iter::repeat_n(0_u8, ps_len));
1685 db.push(0x01);
1686 db.extend_from_slice(plaintext);
1687 let db_mask = mgf1_sha256(seed, k - HASH_LEN - 1)?;
1688 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1689 *byte ^= *mask;
1690 }
1691 let seed_mask = mgf1_sha256(&db, HASH_LEN)?;
1692 let mut masked_seed = seed.to_vec();
1693 for (byte, mask) in masked_seed.iter_mut().zip(seed_mask.iter()) {
1694 *byte ^= *mask;
1695 }
1696 let mut em = Vec::with_capacity(k);
1697 em.push(0x00);
1698 em.extend_from_slice(&masked_seed);
1699 em.extend_from_slice(&db);
1700 Ok(em)
1701}
1702
/// Decodes an EME-OAEP (SHA-256, MGF1-SHA-256) encoded message and returns
/// the embedded plaintext. `encoded` is the padded output of the RSA
/// decryption primitive.
///
/// All padding checks accumulate into a single `invalid` flag, so every
/// malformed input produces the same `"rsa decryption failed"` error. The
/// search for the `0x01` separator records the index with `ct_select_usize`
/// and scans every byte instead of stopping early.
///
/// # Errors
///
/// * `Error::InvalidLength` if `encoded` is shorter than `2 * 32 + 2` bytes.
/// * `Error::CryptoFailure("rsa decryption failed")` on any padding failure.
/// * Errors from `mgf1_sha256` are propagated.
fn decode_oaep_sha256_plaintext(encoded: &[u8], label: &[u8]) -> Result<Vec<u8>> {
    const HASH_LEN: usize = 32;
    if encoded.len() < (2 * HASH_LEN + 2) {
        return Err(Error::InvalidLength(
            "rsa modulus too short for oaep sha256",
        ));
    }
    // The leading byte Y must be zero.
    let mut invalid = 0_u8;
    invalid |= encoded[0];
    // EM = Y || maskedSeed (hLen bytes) || maskedDB
    let (masked_seed, masked_db) = encoded[1..].split_at(HASH_LEN);
    // seed = maskedSeed XOR MGF1(maskedDB, hLen)
    let seed_mask = mgf1_sha256(masked_db, HASH_LEN)?;
    let mut seed = masked_seed.to_vec();
    for (byte, mask) in seed.iter_mut().zip(seed_mask.iter()) {
        *byte ^= *mask;
    }
    // DB = maskedDB XOR MGF1(seed, len(maskedDB))
    let db_mask = mgf1_sha256(&seed, masked_db.len())?;
    let mut db = masked_db.to_vec();
    for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
        *byte ^= *mask;
    }
    // DB = lHash' || PS || 0x01 || M, where lHash' must equal SHA-256(label).
    let expected_l_hash = sha256(label);
    invalid |= u8::from(!ct_bytes_eq(&db[..HASH_LEN], expected_l_hash.as_slice()));
    let rest = &db[HASH_LEN..];
    let mut marker_idx = 0_usize;
    let mut found_marker = 0_u8;
    let mut invalid_ps = 0_u8;
    // Walk every byte of PS || 0x01 || M. Record the index of the first 0x01,
    // and flag any byte before it that is neither 0x00 nor 0x01.
    for (idx, &byte) in rest.iter().enumerate() {
        let is_zero = u8::from(byte == 0);
        let is_one = u8::from(byte == 1);
        let before_marker = 1_u8 ^ found_marker;
        let should_set = before_marker & is_one;
        marker_idx = ct_select_usize(should_set, idx, marker_idx);
        invalid_ps |= before_marker & (1_u8 ^ is_zero) & (1_u8 ^ is_one);
        found_marker |= is_one;
    }
    invalid |= invalid_ps;
    // A missing 0x01 separator is also a padding failure.
    invalid |= 1_u8 ^ found_marker;
    if invalid != 0 {
        return Err(Error::CryptoFailure("rsa decryption failed"));
    }
    // The message starts right after the 0x01 separator.
    Ok(rest[marker_idx.saturating_add(1)..].to_vec())
}
1763
1764fn decode_pkcs1_v15_plaintext(encoded: &[u8]) -> Result<Vec<u8>> {
1782 if encoded.len() < 11 {
1783 return Err(Error::CryptoFailure("rsa decryption failed"));
1784 }
1785 let mut invalid = 0_u8;
1786 invalid |= encoded[0];
1787 invalid |= encoded[1] ^ 0x02;
1788
1789 let mut sep_idx = 0_usize;
1790 let mut found_sep = 0_u8;
1791 for (idx, &byte) in encoded.iter().enumerate().skip(2) {
1792 let is_zero = u8::from(byte == 0);
1793 let should_set = is_zero & (1_u8 ^ found_sep);
1794 sep_idx = ct_select_usize(should_set, idx, sep_idx);
1795 found_sep |= is_zero;
1796 }
1797 if found_sep == 0 {
1798 invalid |= 1;
1799 }
1800 if sep_idx < 10 {
1801 invalid |= 1;
1802 }
1803 if invalid != 0 {
1804 return Err(Error::CryptoFailure("rsa decryption failed"));
1805 }
1806 Ok(encoded[sep_idx + 1..].to_vec())
1807}
1808
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths are reported unequal immediately; the lengths
/// are treated as public. For equal-length slices, every byte pair is visited
/// and the XOR differences are OR-accumulated before the single final test.
fn ct_bytes_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    let accumulated = left
        .iter()
        .zip(right)
        .fold(0_u8, |acc, (&l, &r)| acc | (l ^ r));
    accumulated == 0
}
1833
/// Returns `true` when every byte of `bytes` is zero (and for an empty slice).
///
/// All bytes are OR-accumulated before the single comparison, with no early
/// exit on the first non-zero byte.
fn ct_all_zero(bytes: &[u8]) -> bool {
    bytes.iter().fold(0_u8, |acc, &byte| acc | byte) == 0
}
1854
/// Branch-free select: returns `if_one` when `selector == 1` and `if_zero`
/// when `selector == 0`.
///
/// The selector is expanded into an all-ones or all-zeros mask, and the result
/// is blended bitwise. Callers are expected to pass only 0 or 1; other values
/// yield a bitwise mix of both inputs.
fn ct_select_usize(selector: u8, if_one: usize, if_zero: usize) -> usize {
    let mask = usize::from(selector).wrapping_neg();
    if_zero ^ ((if_one ^ if_zero) & mask)
}
1874
1875fn validate_private_components(n: &BigUint, d: &BigUint) -> Result<()> {
1894 validate_modulus(n)?;
1895 if d.is_zero() {
1896 return Err(Error::CryptoFailure(
1897 "rsa private exponent must be non-zero",
1898 ));
1899 }
1900 if !d.is_odd() {
1901 return Err(Error::CryptoFailure("rsa private exponent must be odd"));
1902 }
1903 if d.cmp(n).is_ge() {
1904 return Err(Error::CryptoFailure(
1905 "rsa private exponent must be smaller than modulus",
1906 ));
1907 }
1908 Ok(())
1909}
1910
1911fn validate_public_components(n: &BigUint, e: &BigUint) -> Result<()> {
1930 validate_modulus(n)?;
1931 let three = BigUint::from_u128(3);
1932 if e.cmp(&three).is_lt() {
1933 return Err(Error::CryptoFailure(
1934 "rsa public exponent must be at least 3",
1935 ));
1936 }
1937 if !e.is_odd() {
1938 return Err(Error::CryptoFailure("rsa public exponent must be odd"));
1939 }
1940 if e.cmp(n).is_ge() {
1941 return Err(Error::CryptoFailure(
1942 "rsa public exponent must be smaller than modulus",
1943 ));
1944 }
1945 Ok(())
1946}
1947
1948fn validate_modulus(n: &BigUint) -> Result<()> {
1966 let three = BigUint::from_u128(3);
1967 if n.cmp(&three).is_lt() {
1968 return Err(Error::CryptoFailure("rsa modulus must be greater than 3"));
1969 }
1970 if !n.is_odd() {
1971 return Err(Error::CryptoFailure("rsa modulus must be odd"));
1972 }
1973 Ok(())
1974}
1975
1976fn validate_crt_components(n: &BigUint, crt: &RsaPrivateCrtComponents) -> Result<()> {
1995 if crt.p.is_zero()
1996 || crt.q.is_zero()
1997 || crt.dp.is_zero()
1998 || crt.dq.is_zero()
1999 || crt.qinv.is_zero()
2000 {
2001 return Err(Error::CryptoFailure("rsa crt parameters must be non-zero"));
2002 }
2003 if !crt.p.is_odd() || !crt.q.is_odd() {
2004 return Err(Error::CryptoFailure("rsa crt primes must be odd"));
2005 }
2006 if crt.p.mul(&crt.q).cmp(n).is_ne() {
2007 return Err(Error::CryptoFailure(
2008 "rsa crt prime product must equal modulus",
2009 ));
2010 }
2011 if crt.dp.cmp(&crt.p).is_ge() || crt.dq.cmp(&crt.q).is_ge() {
2012 return Err(Error::CryptoFailure("rsa crt exponents must be reduced"));
2013 }
2014 if crt.qinv.cmp(&crt.p).is_ge() {
2015 return Err(Error::CryptoFailure(
2016 "rsa crt coefficient must be smaller than p",
2017 ));
2018 }
2019 let one = BigUint::one();
2020 if crt.q.mul(&crt.qinv).modulo(&crt.p).cmp(&one).is_ne() {
2021 return Err(Error::CryptoFailure(
2022 "rsa crt coefficient must be inverse of q modulo p",
2023 ));
2024 }
2025 Ok(())
2026}
2027
2028fn generate_rsa_prime_candidate_auto(
2048 bits: usize,
2049 e: &BigUint,
2050 drbg: &mut HmacDrbgSha256,
2051) -> Result<BigUint> {
2052 let one = BigUint::one();
2053 let mut attempts = 0_u32;
2054 while attempts < 20_000 {
2055 let candidate = random_biguint_with_bits(bits, drbg, b"rsa_prime_candidate")?;
2056 if candidate.bit_len() != bits {
2057 attempts = attempts.saturating_add(1);
2058 continue;
2059 }
2060 if !is_probable_prime(&candidate) {
2061 attempts = attempts.saturating_add(1);
2062 continue;
2063 }
2064 let pm1 = candidate.sub(&one);
2065 if BigUint::gcd(e, &pm1).cmp(&one).is_eq() {
2066 return Ok(candidate);
2067 }
2068 attempts = attempts.saturating_add(1);
2069 }
2070 Err(Error::StateError(
2071 "rsa prime generation exhausted attempt budget",
2072 ))
2073}
2074
2075fn random_biguint_with_bits(
2095 bits: usize,
2096 drbg: &mut HmacDrbgSha256,
2097 label: &[u8],
2098) -> Result<BigUint> {
2099 if bits < 2 {
2100 return Err(Error::InvalidLength(
2101 "rsa prime candidate bits must be at least 2",
2102 ));
2103 }
2104 let byte_len = bits.div_ceil(8);
2105 let mut random = drbg.generate(byte_len, label)?;
2106 let top_bits = bits % 8;
2107 if top_bits != 0 {
2108 random[0] &= (1_u8 << top_bits) - 1;
2109 }
2110 let high_bit_index = (bits - 1) % 8;
2111 random[0] |= 1_u8 << high_bit_index;
2112 let last = random.len() - 1;
2113 random[last] |= 1;
2114 Ok(BigUint::from_be_bytes(&random))
2115}
2116
2117fn is_probable_prime(n: &BigUint) -> bool {
2131 let two = BigUint::from_u128(2);
2132 if n.cmp(&two).is_lt() {
2133 return false;
2134 }
2135 for small in [2_u32, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37] {
2136 let small_bn = BigUint::from_u128(u128::from(small));
2137 if n.cmp(&small_bn).is_eq() {
2138 return true;
2139 }
2140 if n.mod_u32(small) == 0 {
2141 return false;
2142 }
2143 }
2144 let one = BigUint::one();
2145 let n_minus_one = n.sub(&one);
2146 let mut d = n_minus_one.clone();
2147 let mut s = 0_u32;
2148 while d.is_even() {
2149 d = d.shr1();
2150 s = s.saturating_add(1);
2151 }
2152 for witness in [2_u32, 3, 5, 7, 11, 13, 17, 19, 23, 29] {
2153 if !miller_rabin_round(n, &d, s, witness) {
2154 return false;
2155 }
2156 }
2157 true
2158}
2159
2160fn miller_rabin_round(n: &BigUint, d: &BigUint, s: u32, witness: u32) -> bool {
2177 let a = BigUint::from_u128(u128::from(witness)).modulo(n);
2178 if a.is_zero() {
2179 return true;
2180 }
2181 let one = BigUint::one();
2182 let n_minus_one = n.sub(&one);
2183 let mut x = BigUint::mod_exp(&a, d, n);
2184 if x.cmp(&one).is_eq() || x.cmp(&n_minus_one).is_eq() {
2185 return true;
2186 }
2187 for _ in 1..s {
2188 x = x.mul(&x).modulo(n);
2189 if x.cmp(&n_minus_one).is_eq() {
2190 return true;
2191 }
2192 }
2193 false
2194}